diff --git a/.clang-tidy b/.clang-tidy index 5b36ac93d48d..02b565b23585 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -1,20 +1,26 @@ Checks: > *, -abseil-*, + -altera-struct-pack-align, + -altera-unroll-loops, -android-*, - -cert-err58-cpp, + -boost-use-ranges, + -bugprone-easily-swappable-parameters, -cert-err58-cpp, -clang-analyzer-osx-*, -cppcoreguidelines-avoid-c-arrays, + -cppcoreguidelines-avoid-const-or-ref-data-members, + -cppcoreguidelines-avoid-do-while, -cppcoreguidelines-avoid-goto, -cppcoreguidelines-avoid-magic-numbers, -cppcoreguidelines-avoid-non-const-global-variables, + -cppcoreguidelines-non-private-member-variables-in-classes, -cppcoreguidelines-owning-memory, -cppcoreguidelines-pro-bounds-array-to-pointer-decay, + -cppcoreguidelines-pro-bounds-constant-array-index, -cppcoreguidelines-pro-bounds-pointer-arithmetic, -cppcoreguidelines-pro-type-reinterpret-cast, -cppcoreguidelines-pro-type-vararg, - -cppcoreguidelines-pro-type-vararg, -cppcoreguidelines-special-member-functions, -fuchsia-*, -google-*, @@ -25,25 +31,31 @@ Checks: > -hicpp-special-member-functions, -hicpp-use-equals-default, -hicpp-vararg, - -hicpp-vararg, -llvm-header-guard, -llvm-include-order, + -llvm-qualified-auto, -llvmlibc-*, -misc-no-recursion, - -misc-no-recursion, -misc-non-private-member-variables-in-classes, -misc-unused-parameters, -modernize-avoid-c-arrays, -modernize-deprecated-headers, + -modernize-use-designated-initializers, -modernize-use-nodiscard, -modernize-use-trailing-return-type, -mpi-*, -objc-*, -openmp-*, + -performance-avoid-endl, + -performance-enum-size, -readability-avoid-const-params-in-decls, -readability-convert-member-functions-to-static, + -readability-function-cognitive-complexity, + -readability-identifier-length, -readability-implicit-bool-conversion, -readability-magic-numbers, + -readability-math-missing-parentheses, + -readability-qualified-auto, -zircon-*, HeaderFilterRegex: '.*' diff --git a/.cmake-format.yaml b/.cmake-format.yaml deleted file mode 100644 index bbbd89f433a5..000000000000 --- a/.cmake-format.yaml +++ /dev/null @@ -1,83 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Currently this config mostly mirrors the default with the addition of custom functions -format: - line_width: 80 - tab_size: 2 - use_tabchars: false - max_pargs_hwrap: 4 - max_subgroups_hwrap: 2 - min_prefix_chars: 4 - max_prefix_chars: 6 - separate_ctrl_name_with_space: false - separate_fn_name_with_space: false - dangle_parens: false - command_case: canonical - keyword_case: unchanged - always_wrap: - - set_target_properties - - target_sources - - target_link_libraries - -parse: - # We define these for our custom - # functions so they get formatted correctly - additional_commands: - velox_add_library: - pargs: - nargs: 1+ - flags: - - OBJECT - - STATIC - - SHARED - - INTERFACE - kwargs: {} - - velox_base_add_library: - pargs: - nargs: 1+ - flags: - - OBJECT - - STATIC - - SHARED - - INTERFACE - kwargs: {} - - velox_compile_definitions: - pargs: 1 - kwargs: - PRIVATE: '*' - PUBLIC: '*' - INTERFACE: '*' - - velox_include_directories: - pargs: 1+ - flags: - - SYSTEM - - BEFORE - - AFTER - kwargs: - PRIVATE: '*' - PUBLIC: '*' - INTERFACE: '*' - - velox_link_libraries: - pargs: 1+ - kwargs: - PRIVATE: '*' - PUBLIC: '*' - INTERFACE: '*' -markup: - first_comment_is_literal: true diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 000000000000..6506b8ae624c --- /dev/null +++ b/.dockerignore @@ -0,0 +1,37 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +<<<<<<<< HEAD:.yamlfmt.yml +match_type: doublestar +exclude: + - '**/.clang-format' + - '**/.clang-tidy' +formatter: + type: basic + retain_line_breaks_single: true + scan_folded_as_literal: true + indent: 2 +======== + +.git* +docs/ +static/ +website/ + +_build/ + +# Python binary dirs +.venv/ +dist/ +wheelhouse/ +>>>>>>>> 59b492a9bce45f487f24b5cbae7dc845ea3d0827:.dockerignore diff --git a/.gersemirc b/.gersemirc new file mode 100644 index 000000000000..6d6ca521020c --- /dev/null +++ b/.gersemirc @@ -0,0 +1,13 @@ +# vim: set filetype=yaml : + +line_length: 100 +indent: 2 +definitions: + - CMake/third-party/FBCMakeParseArgs.cmake + - CMake/third-party/FBThriftCppLibrary.cmake + - CMake/FindThrift.cmake + - velox/experimental/breeze/cmake + - velox/experimental/breeze/test + - velox/experimental/breeze/perftest +extensions: + - scripts/ci/gersemi_cmd_definitions.py diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 03e20d010ca9..79cac5c9a768 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -24,14 +24,18 @@ # Build & CI CMake/ @assignUser @majetideepak -*.cmake @assignUser @majetideepak -**/CMakeLists.txt @assignUser @majetideepak +/CMakeLists.txt @assignUser +*.cmake @majetideepak +**/CMakeLists.txt @majetideepak scripts/ @assignUser @majetideepak .github/ @assignUser @majetideepak # Breeze velox/experimental/breeze @dreveman +# cuDF +velox/experimental/cudf/ @bdice @karthikeyann @devavret @mhaseeb123 + # Parquet velox/dwio/parquet/ @majetideepak @@ -43,3 +47,6 @@ velox/connectors/ @majetideepak # Caching velox/common/caching/ @majetideepak + +# Spark Functions +velox/functions/sparksql/ @jinchengchenghh @rui-mo @zhli1142015 diff --git a/.github/workflows/linux-build-base.yml b/.github/workflows/linux-build-base.yml index 044c9757fedd..de2a187dcbea 100644 --- a/.github/workflows/linux-build-base.yml +++ b/.github/workflows/linux-build-base.yml @@ -84,7 +84,7 @@ jobs: export CC=/usr/bin/clang-15 export CXX=/usr/bin/clang++-15 fi - make release + make release TREAT_WARNINGS_AS_ERRORS=0 - name: Show CCache stats after build run: ccache -vs diff --git a/.github/zizmor.yml b/.github/zizmor.yml index 82e325d5047c..0919ec9a9d26 100644 --- a/.github/zizmor.yml +++ b/.github/zizmor.yml @@ -15,9 +15,6 @@ rules: use-trusted-publishing: ignore: - build_pyvelox.yml - unpinned-images: ignore: - # linux-build-base.yml builds and publishes a container image which we then use immediately; - # so we need to use the images dynamically generated (i.e., unpinned) tag. - linux-build-base.yml diff --git a/.gitignore b/.gitignore index 3ebb2942c5c4..9ce01803eaf9 100644 --- a/.gitignore +++ b/.gitignore @@ -293,8 +293,7 @@ distribute/* *.bin cmake_build .cmake_build -cmake-build-debug -cmake-build-release +cmake-build-* # tests test/test.sql @@ -325,5 +324,6 @@ src/amalgamation/ velox/docs/sphinx/source/README_generated_* velox/docs/bindings/python/_generate/* scripts/bm-report/report.html +scripts/bm-report/result dist/ wheelhouse/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c902b08ed1ff..2a6563fdc5df 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -35,22 +35,16 @@ repos: - repo: local hooks: - - id: cmake-format - name: cmake-format - description: Format CMake files. - entry: cmake-format - language: python - files: (CMakeLists.*|.*\.cmake|.*\.cmake.in)$ - args: [--in-place] - require_serial: false - additional_dependencies: [cmake-format==0.6.13, pyyaml] - - id: clang-tidy name: clang-tidy - description: Run clang-tidy on C/C++ files + description: >- + Run clang-tidy on C/C++ files, requires 'compile_commands.json' + to be available in the repo root (e.g. symlinked) or BUILD_PATH + set to it's path e.g. _build/release. stages: - - manual # Needs compile_commands.json - entry: clang-tidy + - pre-commit + entry: ./scripts/checks/run-clang-tidy.py + args: [--commit, HEAD] language: python types_or: [c++, c] additional_dependencies: [clang-tidy==18.1.8] @@ -76,8 +70,14 @@ repos: NOTICE.txt )$ + - repo: https://github.com/BlankSpruce/gersemi + rev: 0.21.0 + hooks: + - id: gersemi + name: CMake formatter + - repo: https://github.com/pre-commit/mirrors-clang-format - rev: v18.1.3 + rev: v21.1.2 hooks: - id: clang-format # types_or: [c++, c, cuda, metal, objective-c] @@ -91,6 +91,19 @@ repos: args: [--fix] - id: ruff-format + - repo: https://github.com/shellcheck-py/shellcheck-py + rev: v0.10.0.1 + hooks: + - id: shellcheck + args: [-x, --severity=warning] + + - repo: https://github.com/scop/pre-commit-shfmt + rev: v3.11.0-1 + hooks: + - id: shfmt + # w: write changes, s: simplify, i set indent to 2 spaces + args: [-w, -s, -i, '2'] + # The following checks mostly target GitHub Actions workflows. - repo: https://github.com/adrienverge/yamllint.git rev: v1.37.0 diff --git a/CMake/FindArrow.cmake b/CMake/FindArrow.cmake index 4e09b3971dcd..99bd6412f6fb 100644 --- a/CMake/FindArrow.cmake +++ b/CMake/FindArrow.cmake @@ -25,7 +25,8 @@ find_package_handle_standard_args( ARROW_LIB ARROW_TESTING_LIB ARROW_INCLUDE_PATH - Thrift_FOUND) + Thrift_FOUND +) # Only add the libraries once. if(Arrow_FOUND AND NOT TARGET arrow) @@ -34,10 +35,13 @@ if(Arrow_FOUND AND NOT TARGET arrow) add_library(thrift ALIAS thrift::thrift) set_target_properties( - arrow arrow_testing PROPERTIES INTERFACE_INCLUDE_DIRECTORIES - ${ARROW_INCLUDE_PATH}) - set_target_properties(arrow PROPERTIES IMPORTED_LOCATION ${ARROW_LIB} - INTERFACE_LINK_LIBRARIES thrift) - set_target_properties(arrow_testing PROPERTIES IMPORTED_LOCATION - ${ARROW_TESTING_LIB}) + arrow + arrow_testing + PROPERTIES INTERFACE_INCLUDE_DIRECTORIES ${ARROW_INCLUDE_PATH} + ) + set_target_properties( + arrow + PROPERTIES IMPORTED_LOCATION ${ARROW_LIB} INTERFACE_LINK_LIBRARIES thrift + ) + set_target_properties(arrow_testing PROPERTIES IMPORTED_LOCATION ${ARROW_TESTING_LIB}) endif() diff --git a/CMake/FindSnappy.cmake b/CMake/FindSnappy.cmake deleted file mode 100644 index e2a0359f8f7c..000000000000 --- a/CMake/FindSnappy.cmake +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# - Try to find snappy -# Once done, this will define -# -# SNAPPY_FOUND - system has Glog -# SNAPPY_INCLUDE_DIRS - deprecated -# SNAPPY_LIBRARIES - deprecated -# Snappy::snappy will be defined based on CMAKE_FIND_LIBRARY_SUFFIXES priority - -include(FindPackageHandleStandardArgs) -include(SelectLibraryConfigurations) - -find_library(SNAPPY_LIBRARY_RELEASE snappy PATHS $SNAPPY_LIBRARYDIR}) -find_library(SNAPPY_LIBRARY_DEBUG snappyd PATHS ${SNAPPY_LIBRARYDIR}) - -find_path(SNAPPY_INCLUDE_DIR snappy.h PATHS ${SNAPPY_INCLUDEDIR}) - -select_library_configurations(SNAPPY) - -find_package_handle_standard_args(Snappy DEFAULT_MSG SNAPPY_LIBRARY - SNAPPY_INCLUDE_DIR) - -mark_as_advanced(SNAPPY_LIBRARY SNAPPY_INCLUDE_DIR) - -get_filename_component(libsnappy_ext ${SNAPPY_LIBRARY} EXT) -if(libsnappy_ext STREQUAL ".a") - set(libsnappy_type STATIC) -else() - set(libsnappy_type SHARED) -endif() - -if(NOT TARGET Snappy::snappy) - add_library(Snappy::snappy ${libsnappy_type} IMPORTED) - set_target_properties(Snappy::snappy PROPERTIES INTERFACE_INCLUDE_DIRECTORIES - "${SNAPPY_INCLUDE_DIR}") - set_target_properties( - Snappy::snappy PROPERTIES IMPORTED_LINK_INTERFACE_LANGUAGES "C" - IMPORTED_LOCATION "${SNAPPY_LIBRARIES}") -endif() diff --git a/CMake/FindSodium.cmake b/CMake/FindSodium.cmake index 68ea1f96550c..2cd1c4e3de2e 100644 --- a/CMake/FindSodium.cmake +++ b/CMake/FindSodium.cmake @@ -40,9 +40,12 @@ if(NOT (sodium_USE_STATIC_LIBS EQUAL sodium_USE_STATIC_LIBS_LAST)) unset(sodium_LIBRARY_RELEASE CACHE) unset(sodium_DLL_DEBUG CACHE) unset(sodium_DLL_RELEASE CACHE) - set(sodium_USE_STATIC_LIBS_LAST - ${sodium_USE_STATIC_LIBS} - CACHE INTERNAL "internal change tracking variable") + set( + sodium_USE_STATIC_LIBS_LAST + ${sodium_USE_STATIC_LIBS} + CACHE INTERNAL + "internal change tracking variable" + ) endif() # ############################################################################## @@ -57,7 +60,7 @@ if(UNIX) if(sodium_USE_STATIC_LIBS) foreach(_libname ${sodium_PKG_STATIC_LIBRARIES}) if(NOT _libname MATCHES "^lib.*\\.a$") # ignore strings already ending - # with .a + # with .a list(INSERT sodium_PKG_STATIC_LIBRARIES 0 "lib${_libname}.a") endif() endforeach() @@ -79,32 +82,26 @@ if(UNIX) endif() find_path(sodium_INCLUDE_DIR sodium.h HINTS ${${XPREFIX}_INCLUDE_DIRS}) - find_library( - sodium_LIBRARY_DEBUG - NAMES ${${XPREFIX}_LIBRARIES} - HINTS ${${XPREFIX}_LIBRARY_DIRS}) + find_library(sodium_LIBRARY_DEBUG NAMES ${${XPREFIX}_LIBRARIES} HINTS ${${XPREFIX}_LIBRARY_DIRS}) find_library( sodium_LIBRARY_RELEASE NAMES ${${XPREFIX}_LIBRARIES} - HINTS ${${XPREFIX}_LIBRARY_DIRS}) + HINTS ${${XPREFIX}_LIBRARY_DIRS} + ) # ############################################################################ # Windows elseif(WIN32) - set(sodium_DIR - "$ENV{sodium_DIR}" - CACHE FILEPATH "sodium install directory") + set(sodium_DIR "$ENV{sodium_DIR}" CACHE FILEPATH "sodium install directory") mark_as_advanced(sodium_DIR) - find_path( - sodium_INCLUDE_DIR sodium.h - HINTS ${sodium_DIR} - PATH_SUFFIXES include) + find_path(sodium_INCLUDE_DIR sodium.h HINTS ${sodium_DIR} PATH_SUFFIXES include) if(MSVC) # detect target architecture file( - WRITE "${CMAKE_CURRENT_BINARY_DIR}/arch.cpp" + WRITE + "${CMAKE_CURRENT_BINARY_DIR}/arch.cpp" [=[ #if defined _M_IX86 #error ARCH_VALUE x86_32 @@ -112,13 +109,15 @@ elseif(WIN32) #error ARCH_VALUE x86_64 #endif #error ARCH_VALUE unknown - ]=]) + ]=] + ) try_compile( - _UNUSED_VAR "${CMAKE_CURRENT_BINARY_DIR}" + _UNUSED_VAR + "${CMAKE_CURRENT_BINARY_DIR}" "${CMAKE_CURRENT_BINARY_DIR}/arch.cpp" - OUTPUT_VARIABLE _COMPILATION_LOG) - string(REGEX REPLACE ".*ARCH_VALUE ([a-zA-Z0-9_]+).*" "\\1" _TARGET_ARCH - "${_COMPILATION_LOG}") + OUTPUT_VARIABLE _COMPILATION_LOG + ) + string(REGEX REPLACE ".*ARCH_VALUE ([a-zA-Z0-9_]+).*" "\\1" _TARGET_ARCH "${_COMPILATION_LOG}") # construct library path if(_TARGET_ARCH STREQUAL "x86_32") @@ -126,10 +125,7 @@ elseif(WIN32) elseif(_TARGET_ARCH STREQUAL "x86_64") string(APPEND _PLATFORM_PATH "x64") else() - message( - FATAL_ERROR - "the ${_TARGET_ARCH} architecture is not supported by Findsodium.cmake." - ) + message(FATAL_ERROR "the ${_TARGET_ARCH} architecture is not supported by Findsodium.cmake.") endif() string(APPEND _PLATFORM_PATH "/$$CONFIG$$") @@ -147,64 +143,53 @@ elseif(WIN32) endif() string(REPLACE "$$CONFIG$$" "Debug" _DEBUG_PATH_SUFFIX "${_PLATFORM_PATH}") - string(REPLACE "$$CONFIG$$" "Release" _RELEASE_PATH_SUFFIX - "${_PLATFORM_PATH}") + string(REPLACE "$$CONFIG$$" "Release" _RELEASE_PATH_SUFFIX "${_PLATFORM_PATH}") find_library( - sodium_LIBRARY_DEBUG libsodium.lib + sodium_LIBRARY_DEBUG + libsodium.lib HINTS ${sodium_DIR} - PATH_SUFFIXES ${_DEBUG_PATH_SUFFIX}) + PATH_SUFFIXES ${_DEBUG_PATH_SUFFIX} + ) find_library( - sodium_LIBRARY_RELEASE libsodium.lib + sodium_LIBRARY_RELEASE + libsodium.lib HINTS ${sodium_DIR} - PATH_SUFFIXES ${_RELEASE_PATH_SUFFIX}) + PATH_SUFFIXES ${_RELEASE_PATH_SUFFIX} + ) if(NOT sodium_USE_STATIC_LIBS) set(CMAKE_FIND_LIBRARY_SUFFIXES_BCK ${CMAKE_FIND_LIBRARY_SUFFIXES}) set(CMAKE_FIND_LIBRARY_SUFFIXES ".dll") find_library( - sodium_DLL_DEBUG libsodium + sodium_DLL_DEBUG + libsodium HINTS ${sodium_DIR} - PATH_SUFFIXES ${_DEBUG_PATH_SUFFIX}) + PATH_SUFFIXES ${_DEBUG_PATH_SUFFIX} + ) find_library( - sodium_DLL_RELEASE libsodium + sodium_DLL_RELEASE + libsodium HINTS ${sodium_DIR} - PATH_SUFFIXES ${_RELEASE_PATH_SUFFIX}) + PATH_SUFFIXES ${_RELEASE_PATH_SUFFIX} + ) set(CMAKE_FIND_LIBRARY_SUFFIXES ${CMAKE_FIND_LIBRARY_SUFFIXES_BCK}) endif() - elseif(_GCC_COMPATIBLE) if(sodium_USE_STATIC_LIBS) - find_library( - sodium_LIBRARY_DEBUG libsodium.a - HINTS ${sodium_DIR} - PATH_SUFFIXES lib) - find_library( - sodium_LIBRARY_RELEASE libsodium.a - HINTS ${sodium_DIR} - PATH_SUFFIXES lib) + find_library(sodium_LIBRARY_DEBUG libsodium.a HINTS ${sodium_DIR} PATH_SUFFIXES lib) + find_library(sodium_LIBRARY_RELEASE libsodium.a HINTS ${sodium_DIR} PATH_SUFFIXES lib) else() - find_library( - sodium_LIBRARY_DEBUG libsodium.dll.a - HINTS ${sodium_DIR} - PATH_SUFFIXES lib) - find_library( - sodium_LIBRARY_RELEASE libsodium.dll.a - HINTS ${sodium_DIR} - PATH_SUFFIXES lib) + find_library(sodium_LIBRARY_DEBUG libsodium.dll.a HINTS ${sodium_DIR} PATH_SUFFIXES lib) + find_library(sodium_LIBRARY_RELEASE libsodium.dll.a HINTS ${sodium_DIR} PATH_SUFFIXES lib) file( GLOB _DLL LIST_DIRECTORIES false RELATIVE "${sodium_DIR}/bin" - "${sodium_DIR}/bin/libsodium*.dll") - find_library( - sodium_DLL_DEBUG ${_DLL} libsodium - HINTS ${sodium_DIR} - PATH_SUFFIXES bin) - find_library( - sodium_DLL_RELEASE ${_DLL} libsodium - HINTS ${sodium_DIR} - PATH_SUFFIXES bin) + "${sodium_DIR}/bin/libsodium*.dll" + ) + find_library(sodium_DLL_DEBUG ${_DLL} libsodium HINTS ${sodium_DIR} PATH_SUFFIXES bin) + find_library(sodium_DLL_RELEASE ${_DLL} libsodium HINTS ${sodium_DIR} PATH_SUFFIXES bin) endif() else() message(FATAL_ERROR "this platform is not supported by FindSodium.cmake") @@ -225,12 +210,13 @@ if(sodium_INCLUDE_DIR) if(EXISTS _VERSION_HEADER) file(READ "${_VERSION_HEADER}" _VERSION_HEADER_CONTENT) string( - REGEX - REPLACE ".*#[ \t]*define[ \t]*SODIUM_VERSION_STRING[ \t]*\"([^\n]*)\".*" - "\\1" sodium_VERSION "${_VERSION_HEADER_CONTENT}") - set(sodium_VERSION - "${sodium_VERSION}" - PARENT_SCOPE) + REGEX REPLACE + ".*#[ \t]*define[ \t]*SODIUM_VERSION_STRING[ \t]*\"([^\n]*)\".*" + "\\1" + sodium_VERSION + "${_VERSION_HEADER_CONTENT}" + ) + set(sodium_VERSION "${sodium_VERSION}" PARENT_SCOPE) endif() endif() @@ -239,11 +225,11 @@ include(FindPackageHandleStandardArgs) find_package_handle_standard_args( Sodium # The name must be either uppercase or match the filename case. REQUIRED_VARS sodium_LIBRARY_RELEASE sodium_LIBRARY_DEBUG sodium_INCLUDE_DIR - VERSION_VAR sodium_VERSION) + VERSION_VAR sodium_VERSION +) if(Sodium_FOUND) - set(sodium_LIBRARIES optimized ${sodium_LIBRARY_RELEASE} debug - ${sodium_LIBRARY_DEBUG}) + set(sodium_LIBRARIES optimized ${sodium_LIBRARY_RELEASE} debug ${sodium_LIBRARY_DEBUG}) endif() # mark file paths as advanced @@ -268,8 +254,10 @@ endif() set_target_properties( sodium - PROPERTIES INTERFACE_INCLUDE_DIRECTORIES "${sodium_INCLUDE_DIR}" - IMPORTED_LINK_INTERFACE_LANGUAGES "C") + PROPERTIES + INTERFACE_INCLUDE_DIRECTORIES "${sodium_INCLUDE_DIR}" + IMPORTED_LINK_INTERFACE_LANGUAGES "C" +) if(sodium_USE_STATIC_LIBS) set_target_properties( @@ -277,19 +265,25 @@ if(sodium_USE_STATIC_LIBS) PROPERTIES INTERFACE_COMPILE_DEFINITIONS "SODIUM_STATIC" IMPORTED_LOCATION "${sodium_LIBRARY_RELEASE}" - IMPORTED_LOCATION_DEBUG "${sodium_LIBRARY_DEBUG}") + IMPORTED_LOCATION_DEBUG "${sodium_LIBRARY_DEBUG}" + ) else() if(UNIX) set_target_properties( - sodium PROPERTIES IMPORTED_LOCATION "${sodium_LIBRARY_RELEASE}" - IMPORTED_LOCATION_DEBUG "${sodium_LIBRARY_DEBUG}") + sodium + PROPERTIES + IMPORTED_LOCATION "${sodium_LIBRARY_RELEASE}" + IMPORTED_LOCATION_DEBUG "${sodium_LIBRARY_DEBUG}" + ) elseif(WIN32) set_target_properties( - sodium PROPERTIES IMPORTED_IMPLIB "${sodium_LIBRARY_RELEASE}" - IMPORTED_IMPLIB_DEBUG "${sodium_LIBRARY_DEBUG}") + sodium + PROPERTIES + IMPORTED_IMPLIB "${sodium_LIBRARY_RELEASE}" + IMPORTED_IMPLIB_DEBUG "${sodium_LIBRARY_DEBUG}" + ) if(NOT (sodium_DLL_DEBUG MATCHES ".*-NOTFOUND")) - set_target_properties(sodium PROPERTIES IMPORTED_LOCATION_DEBUG - "${sodium_DLL_DEBUG}") + set_target_properties(sodium PROPERTIES IMPORTED_LOCATION_DEBUG "${sodium_DLL_DEBUG}") endif() if(NOT (sodium_DLL_RELEASE MATCHES ".*-NOTFOUND")) set_target_properties( @@ -297,7 +291,8 @@ else() PROPERTIES IMPORTED_LOCATION_RELWITHDEBINFO "${sodium_DLL_RELEASE}" IMPORTED_LOCATION_MINSIZEREL "${sodium_DLL_RELEASE}" - IMPORTED_LOCATION_RELEASE "${sodium_DLL_RELEASE}") + IMPORTED_LOCATION_RELEASE "${sodium_DLL_RELEASE}" + ) endif() endif() endif() diff --git a/CMake/FindThrift.cmake b/CMake/FindThrift.cmake index 273500a6ae36..48c1a84b1d1a 100644 --- a/CMake/FindThrift.cmake +++ b/CMake/FindThrift.cmake @@ -30,16 +30,16 @@ function(EXTRACT_THRIFT_VERSION) if(THRIFT_INCLUDE_DIR) file(READ "${THRIFT_INCLUDE_DIR}/thrift/config.h" THRIFT_CONFIG_H_CONTENT) - string(REGEX MATCH "#define PACKAGE_VERSION \"[0-9.]+\"" - THRIFT_VERSION_DEFINITION "${THRIFT_CONFIG_H_CONTENT}") + string( + REGEX MATCH + "#define PACKAGE_VERSION \"[0-9.]+\"" + THRIFT_VERSION_DEFINITION + "${THRIFT_CONFIG_H_CONTENT}" + ) string(REGEX MATCH "[0-9.]+" Thrift_VERSION "${THRIFT_VERSION_DEFINITION}") - set(Thrift_VERSION - "${Thrift_VERSION}" - PARENT_SCOPE) + set(Thrift_VERSION "${Thrift_VERSION}" PARENT_SCOPE) else() - set(Thrift_VERSION - "" - PARENT_SCOPE) + set(Thrift_VERSION "" PARENT_SCOPE) endif() endfunction(EXTRACT_THRIFT_VERSION) @@ -77,8 +77,9 @@ if(ARROW_THRIFT_USE_SHARED) "${CMAKE_SHARED_LIBRARY_PREFIX}${THRIFT_LIB_NAME_BASE}${CMAKE_SHARED_LIBRARY_SUFFIX}" ) else() - set(THRIFT_LIB_NAMES - "${CMAKE_STATIC_LIBRARY_PREFIX}${THRIFT_LIB_NAME_BASE}${CMAKE_STATIC_LIBRARY_SUFFIX}" + set( + THRIFT_LIB_NAMES + "${CMAKE_STATIC_LIBRARY_PREFIX}${THRIFT_LIB_NAME_BASE}${CMAKE_STATIC_LIBRARY_SUFFIX}" ) endif() @@ -87,16 +88,11 @@ if(Thrift_ROOT) THRIFT_LIB NAMES ${THRIFT_LIB_NAMES} PATHS ${Thrift_ROOT} - PATH_SUFFIXES "lib/${CMAKE_LIBRARY_ARCHITECTURE}" "lib") - find_path( - THRIFT_INCLUDE_DIR thrift/Thrift.h - PATHS ${Thrift_ROOT} - PATH_SUFFIXES "include") - find_program( - THRIFT_COMPILER thrift - PATHS ${Thrift_ROOT} - PATH_SUFFIXES "bin") - extract_thrift_version() + PATH_SUFFIXES "lib/${CMAKE_LIBRARY_ARCHITECTURE}" "lib" + ) + find_path(THRIFT_INCLUDE_DIR thrift/Thrift.h PATHS ${Thrift_ROOT} PATH_SUFFIXES "include") + find_program(THRIFT_COMPILER thrift PATHS ${Thrift_ROOT} PATH_SUFFIXES "bin") + EXTRACT_THRIFT_VERSION() else() # THRIFT-4760: The pkgconfig files are currently only installed when using # autotools. Starting with 0.13, they are also installed for the CMake-based @@ -112,21 +108,25 @@ else() THRIFT_LIB NAMES ${THRIFT_LIB_NAMES} PATHS ${THRIFT_PC_LIBRARY_DIRS} - NO_DEFAULT_PATH) + NO_DEFAULT_PATH + ) find_program( - THRIFT_COMPILER thrift + THRIFT_COMPILER + thrift HINTS ${THRIFT_PC_PREFIX} NO_DEFAULT_PATH - PATH_SUFFIXES "bin") + PATH_SUFFIXES "bin" + ) set(Thrift_VERSION ${THRIFT_PC_VERSION}) else() find_library( THRIFT_LIB NAMES ${THRIFT_LIB_NAMES} - PATH_SUFFIXES "lib/${CMAKE_LIBRARY_ARCHITECTURE}" "lib") + PATH_SUFFIXES "lib/${CMAKE_LIBRARY_ARCHITECTURE}" "lib" + ) find_path(THRIFT_INCLUDE_DIR thrift/Thrift.h PATH_SUFFIXES "include") find_program(THRIFT_COMPILER thrift PATH_SUFFIXES "bin") - extract_thrift_version() + EXTRACT_THRIFT_VERSION() endif() endif() @@ -140,7 +140,8 @@ find_package_handle_standard_args( Thrift REQUIRED_VARS THRIFT_LIB THRIFT_INCLUDE_DIR VERSION_VAR Thrift_VERSION - HANDLE_COMPONENTS) + HANDLE_COMPONENTS +) if(Thrift_FOUND) if(ARROW_THRIFT_USE_SHARED) @@ -150,18 +151,18 @@ if(Thrift_FOUND) endif() set_target_properties( thrift::thrift - PROPERTIES IMPORTED_LOCATION "${THRIFT_LIB}" INTERFACE_INCLUDE_DIRECTORIES - "${THRIFT_INCLUDE_DIR}") + PROPERTIES + IMPORTED_LOCATION "${THRIFT_LIB}" + INTERFACE_INCLUDE_DIRECTORIES "${THRIFT_INCLUDE_DIR}" + ) if(WIN32 AND NOT MSVC_TOOLCHAIN) # We don't need this for Visual C++ because Thrift uses "#pragma # comment(lib, "Ws2_32.lib")" in thrift/windows/config.h for Visual C++. - set_target_properties(thrift::thrift PROPERTIES INTERFACE_LINK_LIBRARIES - "ws2_32") + set_target_properties(thrift::thrift PROPERTIES INTERFACE_LINK_LIBRARIES "ws2_32") endif() if(Thrift_COMPILER_FOUND) add_executable(thrift::compiler IMPORTED) - set_target_properties(thrift::compiler PROPERTIES IMPORTED_LOCATION - "${THRIFT_COMPILER}") + set_target_properties(thrift::compiler PROPERTIES IMPORTED_LOCATION "${THRIFT_COMPILER}") endif() endif() diff --git a/CMake/FindXxhash.cmake b/CMake/FindXxhash.cmake index 04419b7c238f..31b828243274 100644 --- a/CMake/FindXxhash.cmake +++ b/CMake/FindXxhash.cmake @@ -25,8 +25,7 @@ include(SelectLibraryConfigurations) select_library_configurations(Xxhash) include(FindPackageHandleStandardArgs) -find_package_handle_standard_args(Xxhash DEFAULT_MSG Xxhash_LIBRARY - Xxhash_INCLUDE_DIR) +find_package_handle_standard_args(Xxhash DEFAULT_MSG Xxhash_LIBRARY Xxhash_INCLUDE_DIR) if(Xxhash_FOUND) message(STATUS "Found xxhash: ${Xxhash_LIBRARY}") diff --git a/CMake/Findc-ares.cmake b/CMake/Findc-ares.cmake index 0e3f1fe0938a..58087c018fa9 100644 --- a/CMake/Findc-ares.cmake +++ b/CMake/Findc-ares.cmake @@ -19,22 +19,20 @@ if(c-ares_FOUND) endif() endif() -find_path( - C_ARES_INCLUDE_DIR - NAMES ares.h - PATH_SUFFIXES c-ares) +find_path(C_ARES_INCLUDE_DIR NAMES ares.h PATH_SUFFIXES c-ares) find_library(C_ARES_LIBRARY NAMES c-ares) include(FindPackageHandleStandardArgs) -find_package_handle_standard_args(c-ares DEFAULT_MSG C_ARES_LIBRARY - C_ARES_INCLUDE_DIR) +find_package_handle_standard_args(c-ares DEFAULT_MSG C_ARES_LIBRARY C_ARES_INCLUDE_DIR) if(c-ares_FOUND AND NOT TARGET c-ares::cares) add_library(c-ares::cares UNKNOWN IMPORTED) set_target_properties( c-ares::cares - PROPERTIES IMPORTED_LOCATION "${C_ARES_LIBRARY}" - INTERFACE_INCLUDE_DIRECTORIES "${C_ARES_INCLUDE_DIR}") + PROPERTIES + IMPORTED_LOCATION "${C_ARES_LIBRARY}" + INTERFACE_INCLUDE_DIRECTORIES "${C_ARES_INCLUDE_DIR}" + ) endif() mark_as_advanced(C_ARES_INCLUDE_DIR C_ARES_LIBRARY) diff --git a/CMake/Finddouble-conversion.cmake b/CMake/Finddouble-conversion.cmake index 457bd88e26c1..ec1059067dfe 100644 --- a/CMake/Finddouble-conversion.cmake +++ b/CMake/Finddouble-conversion.cmake @@ -12,23 +12,25 @@ if(double-conversion_FOUND) endif() endif() -find_path( - DOUBLE_CONVERSION_INCLUDE_DIR - NAMES double-conversion.h - PATH_SUFFIXES double-conversion) +find_path(DOUBLE_CONVERSION_INCLUDE_DIR NAMES double-conversion.h PATH_SUFFIXES double-conversion) find_library(DOUBLE_CONVERSION_LIBRARY NAMES double-conversion) include(FindPackageHandleStandardArgs) find_package_handle_standard_args( - double-conversion DEFAULT_MSG DOUBLE_CONVERSION_LIBRARY - DOUBLE_CONVERSION_INCLUDE_DIR) + double-conversion + DEFAULT_MSG + DOUBLE_CONVERSION_LIBRARY + DOUBLE_CONVERSION_INCLUDE_DIR +) if(double-conversion_FOUND AND NOT TARGET double-conversion::double-conversion) add_library(double-conversion::double-conversion UNKNOWN IMPORTED) set_target_properties( double-conversion::double-conversion - PROPERTIES IMPORTED_LOCATION "${DOUBLE_CONVERSION_LIBRARY}" - INTERFACE_INCLUDE_DIRECTORIES "${DOUBLE_CONVERSION_INCLUDE_DIR}") + PROPERTIES + IMPORTED_LOCATION "${DOUBLE_CONVERSION_LIBRARY}" + INTERFACE_INCLUDE_DIRECTORIES "${DOUBLE_CONVERSION_INCLUDE_DIR}" + ) endif() mark_as_advanced(DOUBLE_CONVERSION_INCLUDE_DIR DOUBLE_CONVERSION_LIBRARY) diff --git a/CMake/Findglog.cmake b/CMake/Findglog.cmake index 81deadb36442..067669eece03 100644 --- a/CMake/Findglog.cmake +++ b/CMake/Findglog.cmake @@ -29,8 +29,7 @@ find_path(GLOG_INCLUDE_DIR glog/logging.h PATHS ${GLOG_INCLUDEDIR}) select_library_configurations(GLOG) -find_package_handle_standard_args(glog DEFAULT_MSG GLOG_LIBRARY - GLOG_INCLUDE_DIR) +find_package_handle_standard_args(glog DEFAULT_MSG GLOG_LIBRARY GLOG_INCLUDE_DIR) mark_as_advanced(GLOG_LIBRARY GLOG_INCLUDE_DIR) @@ -39,9 +38,9 @@ set(GLOG_INCLUDE_DIRS ${GLOG_INCLUDE_DIR}) if(NOT TARGET glog::glog) add_library(glog::glog UNKNOWN IMPORTED) - set_target_properties(glog::glog PROPERTIES INTERFACE_INCLUDE_DIRECTORIES - "${GLOG_INCLUDE_DIRS}") + set_target_properties(glog::glog PROPERTIES INTERFACE_INCLUDE_DIRECTORIES "${GLOG_INCLUDE_DIRS}") set_target_properties( - glog::glog PROPERTIES IMPORTED_LINK_INTERFACE_LANGUAGES "C" - IMPORTED_LOCATION "${GLOG_LIBRARIES}") + glog::glog + PROPERTIES IMPORTED_LINK_INTERFACE_LANGUAGES "C" IMPORTED_LOCATION "${GLOG_LIBRARIES}" + ) endif() diff --git a/CMake/Findlz4.cmake b/CMake/Findlz4.cmake index d13c951b8898..2840d0174795 100644 --- a/CMake/Findlz4.cmake +++ b/CMake/Findlz4.cmake @@ -43,9 +43,9 @@ endif() if(NOT TARGET lz4::lz4) add_library(lz4::lz4 ${liblz4_type} IMPORTED) - set_target_properties(lz4::lz4 PROPERTIES INTERFACE_INCLUDE_DIRECTORIES - "${LZ4_INCLUDE_DIR}") + set_target_properties(lz4::lz4 PROPERTIES INTERFACE_INCLUDE_DIRECTORIES "${LZ4_INCLUDE_DIR}") set_target_properties( - lz4::lz4 PROPERTIES IMPORTED_LINK_INTERFACE_LANGUAGES "C" - IMPORTED_LOCATION "${LZ4_LIBRARIES}") + lz4::lz4 + PROPERTIES IMPORTED_LINK_INTERFACE_LANGUAGES "C" IMPORTED_LOCATION "${LZ4_LIBRARIES}" + ) endif() diff --git a/CMake/Findlzo2.cmake b/CMake/Findlzo2.cmake index 9f9fbbbe11ca..f052fd8fc0dc 100644 --- a/CMake/Findlzo2.cmake +++ b/CMake/Findlzo2.cmake @@ -30,8 +30,7 @@ find_path(LZO2_INCLUDE_DIR lzo/lzo1a.h PATHS ${LZO2_INCLUDEDIR}) select_library_configurations(LZO2) -find_package_handle_standard_args(lzo2 DEFAULT_MSG LZO2_LIBRARY - LZO2_INCLUDE_DIR) +find_package_handle_standard_args(lzo2 DEFAULT_MSG LZO2_LIBRARY LZO2_INCLUDE_DIR) mark_as_advanced(LZO2_LIBRARY LZO2_INCLUDE_DIR) @@ -44,9 +43,9 @@ endif() if(NOT TARGET lzo2::lzo2) add_library(lzo2::lzo2 ${liblzo2_type} IMPORTED) - set_target_properties(lzo2::lzo2 PROPERTIES INTERFACE_INCLUDE_DIRECTORIES - "${LZO2_INCLUDE_DIR}") + set_target_properties(lzo2::lzo2 PROPERTIES INTERFACE_INCLUDE_DIRECTORIES "${LZO2_INCLUDE_DIR}") set_target_properties( - lzo2::lzo2 PROPERTIES IMPORTED_LINK_INTERFACE_LANGUAGES "C" - IMPORTED_LOCATION "${LZO2_LIBRARIES}") + lzo2::lzo2 + PROPERTIES IMPORTED_LINK_INTERFACE_LANGUAGES "C" IMPORTED_LOCATION "${LZO2_LIBRARIES}" + ) endif() diff --git a/CMake/Findpyarrow.cmake b/CMake/Findpyarrow.cmake index 593ef06fc170..f7df31083301 100644 --- a/CMake/Findpyarrow.cmake +++ b/CMake/Findpyarrow.cmake @@ -16,32 +16,30 @@ find_package(Python REQUIRED COMPONENTS Interpreter) execute_process( COMMAND - "${Python_EXECUTABLE}" -c "\ + "${Python_EXECUTABLE}" -c + "\ import pyarrow print(pyarrow.get_include()) " OUTPUT_VARIABLE _pyarrow_include_dir - OUTPUT_STRIP_TRAILING_WHITESPACE) + OUTPUT_STRIP_TRAILING_WHITESPACE +) execute_process( COMMAND - "${Python_EXECUTABLE}" -c "\ + "${Python_EXECUTABLE}" -c + "\ import pyarrow pyarrow.create_library_symlinks() print(pyarrow.get_library_dirs()[0]) " OUTPUT_VARIABLE _pyarrow_lib_dir - OUTPUT_STRIP_TRAILING_WHITESPACE) + OUTPUT_STRIP_TRAILING_WHITESPACE +) -find_library( - _libarrow arrow - PATHS ${_pyarrow_lib_dir} - NO_DEFAULT_PATH) +find_library(_libarrow arrow PATHS ${_pyarrow_lib_dir} NO_DEFAULT_PATH) -find_library( - _libarrow_py arrow_python - PATHS ${_pyarrow_lib_dir} - NO_DEFAULT_PATH) +find_library(_libarrow_py arrow_python PATHS ${_pyarrow_lib_dir} NO_DEFAULT_PATH) set(pyarrow_LIBARROW ${_libarrow}) set(pyarrow_LIBARROW_PYTHON ${_libarrow_py}) @@ -51,8 +49,9 @@ mark_as_advanced(_libarrow _libarrow_py _pyarrow_include_dir _pyarrow_lib_dir) include(FindPackageHandleStandardArgs) find_package_handle_standard_args( - pyarrow REQUIRED_VARS pyarrow_LIBARROW pyarrow_LIBARROW_PYTHON - pyarrow_INCLUDE_DIR) + pyarrow + REQUIRED_VARS pyarrow_LIBARROW pyarrow_LIBARROW_PYTHON pyarrow_INCLUDE_DIR +) if(pyarrow_FOUND) if(NOT TARGET pyarrow::dev) @@ -63,7 +62,7 @@ if(pyarrow_FOUND) IMPORTED_LINK_INTERFACE_LANGUAGES "CXX" IMPORTED_LOCATION "${pyarrow_LIBARROW_PYTHON}" INTERFACE_INCLUDE_DIRECTORIES "${pyarrow_INCLUDE_DIR}" - INTERFACE_LINK_LIBRARIES "${pyarrow_LIBARROW}") - + INTERFACE_LINK_LIBRARIES "${pyarrow_LIBARROW}" + ) endif() endif() diff --git a/CMake/Findre2.cmake b/CMake/Findre2.cmake index 1a438cce0e0a..ff8b484972c8 100644 --- a/CMake/Findre2.cmake +++ b/CMake/Findre2.cmake @@ -35,8 +35,7 @@ if(RE2_FOUND) set(re2_FOUND "${RE2_FOUND}") add_library(re2::re2 INTERFACE IMPORTED) if(RE2_INCLUDE_DIRS) - set_property(TARGET re2::re2 PROPERTY INTERFACE_INCLUDE_DIRECTORIES - "${RE2_INCLUDE_DIRS}") + set_property(TARGET re2::re2 PROPERTY INTERFACE_INCLUDE_DIRECTORIES "${RE2_INCLUDE_DIRS}") endif() if(RE2_CFLAGS_OTHER) # Filter out the -std flag, which is handled by CMAKE_CXX_STANDARD. @@ -47,12 +46,10 @@ if(RE2_FOUND) list(REMOVE_ITEM RE2_CFLAGS_OTHER "${flag}") endif() endforeach() - set_property(TARGET re2::re2 PROPERTY INTERFACE_COMPILE_OPTIONS - "${RE2_CFLAGS_OTHER}") + set_property(TARGET re2::re2 PROPERTY INTERFACE_COMPILE_OPTIONS "${RE2_CFLAGS_OTHER}") endif() if(RE2_LDFLAGS) - set_property(TARGET re2::re2 PROPERTY INTERFACE_LINK_LIBRARIES - "${RE2_LDFLAGS}") + set_property(TARGET re2::re2 PROPERTY INTERFACE_LINK_LIBRARIES "${RE2_LDFLAGS}") endif() message(STATUS "Found RE2 via pkg-config.") return() diff --git a/CMake/Findstemmer.cmake b/CMake/Findstemmer.cmake index 8796ce461ac5..7b22996d8943 100644 --- a/CMake/Findstemmer.cmake +++ b/CMake/Findstemmer.cmake @@ -25,6 +25,8 @@ if(NOT TARGET stemmer::stemmer) find_path(STEMMER_INCLUDE_PATH libstemmer.h) set_target_properties( stemmer::stemmer - PROPERTIES IMPORTED_LOCATION ${STEMMER_LIB} INTERFACE_INCLUDE_DIRECTORIES - ${STEMMER_INCLUDE_PATH}) + PROPERTIES + IMPORTED_LOCATION ${STEMMER_LIB} + INTERFACE_INCLUDE_DIRECTORIES ${STEMMER_INCLUDE_PATH} + ) endif() diff --git a/CMake/Findzstd.cmake b/CMake/Findzstd.cmake index 3285c5744d5b..26f7917e98bb 100644 --- a/CMake/Findzstd.cmake +++ b/CMake/Findzstd.cmake @@ -30,8 +30,7 @@ find_path(ZSTD_INCLUDE_DIR zstd.h PATHS ${ZSTD_INCLUDEDIR}) select_library_configurations(ZSTD) -find_package_handle_standard_args(zstd DEFAULT_MSG ZSTD_LIBRARY - ZSTD_INCLUDE_DIR) +find_package_handle_standard_args(zstd DEFAULT_MSG ZSTD_LIBRARY ZSTD_INCLUDE_DIR) mark_as_advanced(ZSTD_LIBRARY ZSTD_INCLUDE_DIR) @@ -44,11 +43,11 @@ endif() if(NOT TARGET zstd::zstd) add_library(zstd::zstd ${libzstd_type} IMPORTED) - set_target_properties(zstd::zstd PROPERTIES INTERFACE_INCLUDE_DIRECTORIES - "${ZSTD_INCLUDE_DIR}") + set_target_properties(zstd::zstd PROPERTIES INTERFACE_INCLUDE_DIRECTORIES "${ZSTD_INCLUDE_DIR}") set_target_properties( - zstd::zstd PROPERTIES IMPORTED_LINK_INTERFACE_LANGUAGES "C" - IMPORTED_LOCATION "${ZSTD_LIBRARIES}") + zstd::zstd + PROPERTIES IMPORTED_LINK_INTERFACE_LANGUAGES "C" IMPORTED_LOCATION "${ZSTD_LIBRARIES}" + ) endif() if(NOT TARGET zstd::libzstd_shared) diff --git a/CMake/ResolveDependency.cmake b/CMake/ResolveDependency.cmake index 0290c154d316..3c3517a18cf3 100644 --- a/CMake/ResolveDependency.cmake +++ b/CMake/ResolveDependency.cmake @@ -32,8 +32,7 @@ include(FetchContent) include(ExternalProject) include(ProcessorCount) include(CheckCXXCompilerFlag) -list(APPEND CMAKE_MODULE_PATH - ${CMAKE_CURRENT_LIST_DIR}/resolve_dependency_modules) +list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_LIST_DIR}/resolve_dependency_modules) # Enable SSL certificate verification for file downloads set(CMAKE_TLS_VERIFY true) @@ -83,9 +82,7 @@ macro(velox_resolve_dependency dependency_name) elseif(${dependency_name}_SOURCE STREQUAL "BUNDLED") velox_build_dependency(${dependency_name}) else() - message( - FATAL_ERROR - "Invalid source for ${dependency_name}: ${${dependency_name}_SOURCE}") + message(FATAL_ERROR "Invalid source for ${dependency_name}: ${${dependency_name}_SOURCE}") endif() list(POP_BACK CMAKE_MESSAGE_INDENT) @@ -93,10 +90,12 @@ endmacro() # By using a macro we don't need to propagate the value into the parent scope. macro(velox_set_source dependency_name) - velox_set_with_default(${dependency_name}_SOURCE ${dependency_name}_SOURCE - ${VELOX_DEPENDENCY_SOURCE}) - message( - STATUS "Setting ${dependency_name} source to ${${dependency_name}_SOURCE}") + velox_set_with_default( + ${dependency_name}_SOURCE + ${dependency_name}_SOURCE + ${VELOX_DEPENDENCY_SOURCE} + ) + message(STATUS "Setting ${dependency_name} source to ${${dependency_name}_SOURCE}") endmacro() # Set var_name to the value of $ENV{envvar_name} if ENV is defined. If neither @@ -105,13 +104,9 @@ endmacro() # automatically! Use PARENT_SCOPE. function(velox_set_with_default var_name envvar_name default) if(DEFINED ENV{${envvar_name}}) - set(${var_name} - $ENV{${envvar_name}} - PARENT_SCOPE) + set(${var_name} $ENV{${envvar_name}} PARENT_SCOPE) elseif(NOT DEFINED ${var_name}) - set(${var_name} - ${default} - PARENT_SCOPE) + set(${var_name} ${default} PARENT_SCOPE) endif() endfunction() @@ -121,13 +116,21 @@ macro(velox_resolve_dependency_url dependency_name) string(PREPEND VELOX_${dependency_name}_BUILD_SHA256_CHECKSUM "SHA256=") velox_set_with_default( - VELOX_${dependency_name}_SOURCE_URL VELOX_${dependency_name}_URL - ${VELOX_${dependency_name}_SOURCE_URL}) - message(VERBOSE "Set VELOX_${dependency_name}_SOURCE_URL to " - "${VELOX_${dependency_name}_SOURCE_URL}") + VELOX_${dependency_name}_SOURCE_URL + VELOX_${dependency_name}_URL + ${VELOX_${dependency_name}_SOURCE_URL} + ) + message( + VERBOSE + "Set VELOX_${dependency_name}_SOURCE_URL to " + "${VELOX_${dependency_name}_SOURCE_URL}" + ) if(DEFINED ENV{VELOX_${dependency_name}_URL}) - velox_set_with_default(VELOX_${dependency_name}_BUILD_SHA256_CHECKSUM - VELOX_${dependency_name}_SHA256 "") + velox_set_with_default( + VELOX_${dependency_name}_BUILD_SHA256_CHECKSUM + VELOX_${dependency_name}_SHA256 + "" + ) if(DEFINED ENV{VELOX_${dependency_name}_SHA256}) string(PREPEND VELOX_${dependency_name}_BUILD_SHA256_CHECKSUM "SHA256=") endif() diff --git a/CMake/VeloxConfig.cmake.in b/CMake/VeloxConfig.cmake.in new file mode 100644 index 000000000000..158f47277b4a --- /dev/null +++ b/CMake/VeloxConfig.cmake.in @@ -0,0 +1,70 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +@PACKAGE_INIT@ + +include(CMakeFindDependencyMacro) + +block() + list(PREPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_LIST_DIR}") + + if("@Arrow_SOURCE@" STREQUAL "SYSTEM") + find_dependency(Arrow) + endif() + if("@Boost_SOURCE@" STREQUAL "SYSTEM") + find_dependency(Boost COMPONENTS "@BOOST_INCLUDE_LIBRARIES@") + endif() + find_dependency(double-conversion) + if("@folly_SOURCE@" STREQUAL "SYSTEM") + find_dependency(folly) + endif() + if("@fmt_SOURCE@" STREQUAL "SYSTEM") + find_dependency(fmt) + endif() + if("@gflags_SOURCE@" STREQUAL "SYSTEM") + find_dependency(gflags) + endif() + if("@glog_SOURCE@" STREQUAL "SYSTEM") + find_dependency(glog) + endif() + if("@VELOX_ENABLE_COMPRESSION_LZ4@") + find_dependency(lz4) + endif() + if("@Protobuf_SOURCE@" STREQUAL "SYSTEM") + find_dependency(Protobuf) + endif() + if("@re2_SOURCE@" STREQUAL "SYSTEM") + find_dependency(re2) + endif() + if("@stemmer_SOURCE@" STREQUAL "SYSTEM") + find_dependency(stemmer) + endif() + if("@VELOX_BUILD_MINIMAL_WITH_DWIO@" OR "@VELOX_ENABLE_HIVE_CONNECTOR@") + find_dependency(Snappy) + find_dependency(ZLIB) + find_dependency(zstd) + endif() + if("@simdjson_SOURCE@" STREQUAL "SYSTEM") + find_dependency(simdjson) + endif() + if("@Thrift_FOUND@") + find_dependency(Thrift) + endif() + if("@xsimd_SOURCE@" STREQUAL "SYSTEM") + find_dependency(xsimd) + endif() +endblock() + +include("${CMAKE_CURRENT_LIST_DIR}/VeloxTargets.cmake") + +check_required_components(Velox) diff --git a/CMake/VeloxUtils.cmake b/CMake/VeloxUtils.cmake index 5db4624555a8..8a63332fd335 100644 --- a/CMake/VeloxUtils.cmake +++ b/CMake/VeloxUtils.cmake @@ -12,34 +12,34 @@ # See the License for the specific language governing permissions and # limitations under the License. include_guard(GLOBAL) -function(get_rpath_origin VAR) + +include(CMakePackageConfigHelpers) + +function(velox_get_rpath_origin VAR) if(APPLE) set(_origin @loader_path) else() set(_origin "\$ORIGIN") endif() - set(${VAR} - ${_origin} - PARENT_SCOPE) + set(${VAR} ${_origin} PARENT_SCOPE) endfunction() function(pyvelox_add_module TARGET) pybind11_add_module(${TARGET} ${ARGN}) if(DEFINED SKBUILD_PROJECT_VERSION_FULL) - target_compile_definitions( - ${TARGET} PRIVATE PYVELOX_VERSION=${SKBUILD_PROJECT_VERSION_FULL}) + target_compile_definitions(${TARGET} PRIVATE PYVELOX_VERSION=${SKBUILD_PROJECT_VERSION_FULL}) else() target_compile_definitions(${TARGET} PRIVATE PYVELOX_VERSION=dev) endif() # Set the rpath so linker looks within pyvelox package for libs - get_rpath_origin(_origin) + velox_get_rpath_origin(_origin) set_target_properties( - ${TARGET} PROPERTIES INSTALL_RPATH "${_origin}/;${CMAKE_BINARY_DIR}/lib" - INSTALL_RPATH_USE_LINK_PATH TRUE) - install(TARGETS ${TARGET} LIBRARY DESTINATION pyvelox - COMPONENT pyvelox_libraries) + ${TARGET} + PROPERTIES INSTALL_RPATH "${_origin}/;${CMAKE_BINARY_DIR}/lib" INSTALL_RPATH_USE_LINK_PATH TRUE + ) + install(TARGETS ${TARGET} LIBRARY DESTINATION pyvelox COMPONENT pyvelox_libraries) endfunction() # TODO use file sets @@ -50,10 +50,9 @@ function(velox_install_library_headers) cmake_path( RELATIVE_PATH CMAKE_CURRENT_SOURCE_DIR - BASE_DIRECTORY - "${CMAKE_SOURCE_DIR}" - OUTPUT_VARIABLE - _hdr_dir) + BASE_DIRECTORY "${CMAKE_SOURCE_DIR}" + OUTPUT_VARIABLE _hdr_dir + ) install(FILES ${_hdrs} DESTINATION include/${_hdr_dir}) endif() endfunction() @@ -70,12 +69,7 @@ function(velox_add_library TARGET) set(options OBJECT STATIC SHARED INTERFACE) set(oneValueArgs) set(multiValueArgs) - cmake_parse_arguments( - VELOX - "${options}" - "${oneValueArgs}" - "${multiValueArgs}" - ${ARGN}) + cmake_parse_arguments(VELOX "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) # Remove library type specifiers from ARGN set(library_type) @@ -98,8 +92,9 @@ function(velox_add_library TARGET) if(TARGET velox) # Target already exists, append sources to it. target_sources(velox PRIVATE ${ARGN}) - install(TARGETS velox LIBRARY DESTINATION pyvelox - COMPONENT pyvelox_libraries) + if(VELOX_BUILD_PYTHON_PACKAGE) + install(TARGETS velox LIBRARY DESTINATION pyvelox COMPONENT pyvelox_libraries) + endif() else() set(_type STATIC) if(VELOX_BUILD_SHARED) @@ -107,11 +102,56 @@ function(velox_add_library TARGET) endif() # Create the target if this is the first invocation. add_library(velox ${_type} ${ARGN}) - set_target_properties(velox PROPERTIES LIBRARY_OUTPUT_DIRECTORY - ${PROJECT_BINARY_DIR}/lib) - set_target_properties(velox PROPERTIES ARCHIVE_OUTPUT_DIRECTORY - ${PROJECT_BINARY_DIR}/lib) - install(TARGETS velox DESTINATION lib/velox) + set_target_properties(velox PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}/lib) + set_target_properties(velox PROPERTIES ARCHIVE_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}/lib) + install(TARGETS velox DESTINATION lib/velox EXPORT velox_targets) + if(VELOX_BUILD_CMAKE_PACKAGE) + set(package_cmake_dir "lib/cmake/Velox") + set(config_cmake_in "${PROJECT_SOURCE_DIR}/CMake/VeloxConfig.cmake.in") + set(config_cmake "${PROJECT_BINARY_DIR}/CMake/VeloxConfig.cmake") + configure_package_config_file( + "${config_cmake_in}" + "${config_cmake}" + INSTALL_DESTINATION "${package_cmake_dir}" + ) + install(FILES "${config_cmake}" DESTINATION "${package_cmake_dir}") + set(system_dependencies) + if(Arrow_SOURCE STREQUAL "SYSTEM") + list(APPEND system_dependencies Arrow) + endif() + if(glog_SOURCE STREQUAL "SYSTEM") + list(APPEND system_dependencies glog) + endif() + if(VELOX_ENABLE_COMPRESSION_LZ4) + list(APPEND system_dependencies lz4) + endif() + if(re2_SOURCE STREQUAL "SYSTEM") + list(APPEND system_dependencies re2) + endif() + if(stemmer_SOURCE STREQUAL "SYSTEM") + list(APPEND system_dependencies stemmer) + endif() + if(VELOX_BUILD_MINIMAL_WITH_DWIO OR VELOX_ENABLE_HIVE_CONNECTOR) + list(APPEND system_dependencies Snappy zstd) + endif() + foreach(system_dependency ${system_dependencies}) + set(velox_find_module "${PROJECT_SOURCE_DIR}/CMake/Find${system_dependency}.cmake") + if(EXISTS "${velox_find_module}") + install(FILES "${velox_find_module}" DESTINATION "${package_cmake_dir}") + endif() + endforeach() + # TODO: We can enable this once we add version to Velox. + # set(version_cmake "${PROJECT_BINARY_DIR}/CMake/VeloxConfigVersion.cmake") + # write_basic_package_version_file("${version_cmake}" + # COMPATIBILITY SameMajorVersion) + # install(FILES "${version_cmake}" DESTINATION "${package_cmake_dir}") + install( + EXPORT velox_targets + DESTINATION "${package_cmake_dir}" + NAMESPACE "Velox::" + FILE "VeloxTargets.cmake" + ) + endif() endif() # create alias for compatability if(NOT TARGET ${TARGET}) @@ -131,15 +171,17 @@ function(velox_link_libraries TARGET) # These targets follow the velox_* name for consistency but are NOT actually # aliases to velox when building the mono lib and need to be linked # explicitly (this is a hack) - set(explicit_targets - velox_exec_test_lib - # see velox/experimental/wave/README.md - velox_wave_common - velox_wave_decode - velox_wave_dwio - velox_wave_exec - velox_wave_stream - velox_wave_vector) + set( + explicit_targets + velox_exec_test_lib + # see velox/experimental/wave/README.md + velox_wave_common + velox_wave_decode + velox_wave_dwio + velox_wave_exec + velox_wave_stream + velox_wave_vector + ) foreach(_arg ${ARGN}) list(FIND explicit_targets ${_arg} _explicit) diff --git a/CMake/resolve_dependency_modules/README.md b/CMake/resolve_dependency_modules/README.md index 268617cbabe7..a1314eec783c 100644 --- a/CMake/resolve_dependency_modules/README.md +++ b/CMake/resolve_dependency_modules/README.md @@ -6,45 +6,44 @@ The versions of certain libraries is the default provided by the platform's package manager. Some libraries can be bundled by Velox. See details on bundling below. -| Library Name | Minimum Version | Bundled? | -|-------------------|-----------------|----------| -| ninja | default | No | -| ccache | default | No | -| icu4c | default | Yes | -| gflags | default | Yes | -| glog | default | Yes | -| gtest (testing) | default | Yes | -| libevent | default | No | -| libsodium | default | No | -| lz4 | default | No | -| snappy | default | No | -| lzo | default | No | -| xz | default | No | -| zstd | default | No | -| openssl | default | No | -| protobuf | 21.7 >= x < 22 | Yes | -| boost | 1.77.0 | Yes | -| flex | 2.5.13 | No | -| bison | 3.0.4 | No | -| cmake | 3.28 | No | -| double-conversion | 3.1.5 | No | -| xsimd | 10.0.0 | Yes | -| re2 | 2021-04-01 | Yes | -| fmt | 10.1.1 | Yes | -| simdjson | 3.9.3 | Yes | -| folly | v2025.04.28.00 | Yes | -| fizz | v2025.04.28.00 | No | -| wangle | v2025.04.28.00 | No | -| mvfst | v2025.04.28.00 | No | -| fbthrift | v2025.04.28.00 | No | -| libstemmer | 2.2.0 | Yes | -| DuckDB (testing) | 0.8.1 | Yes | -| cpr (testing) | 1.10.15 | Yes | -| arrow | 15.0.0 | Yes | -| geos | 3.10.2 | Yes | -| fast_float | v8.0.2 | Yes | -| xxhash | default | No | -| thrift | 0.16 | No | +| Library Name | Minimum Version | Bundled? | Comment | +|-------------------|-----------------|----------|---------| +| ninja | default | No || +| ccache | default | No || +| icu4c | default | Yes || +| gflags | default | Yes || +| glog | default | Yes || +| gtest (testing) | default | Yes || +| libevent | default | No || +| libsodium | default | No || +| lz4 | default | No || +| snappy | default | No || +| xz | default | No || +| zstd | default | No || +| openssl | default | No || +| protobuf | 21.7 >= x < 22 | Yes || +| boost | 1.77.0 | Yes || +| flex | 2.5.13 | No || +| bison | 3.0.4 | No || +| cmake | 3.28 | No || +| double-conversion | 3.1.5 | No || +| xsimd | 10.0.0 | Yes || +| re2 | 2024-07-02 | Yes || +| fmt | 11.2.0 | Yes | Used API must be fmt 9 compatible | +| simdjson | 4.1.0 | Yes || +| faiss | 1.11.0 | Yes || +| folly | v2025.04.28.00 | Yes || +| fizz | v2025.04.28.00 | No || +| wangle | v2025.04.28.00 | No || +| mvfst | v2025.04.28.00 | No || +| fbthrift | v2025.04.28.00 | No || +| libstemmer | 2.2.0 | Yes || +| DuckDB (testing) | 0.8.1 | Yes || +| arrow | 15.0.0 | Yes || +| geos | 3.10.7 | Yes || +| fast_float | v8.0.2 | Yes || +| xxhash | default | No || +| thrift | 0.16 | No || # Bundled Dependency Management This module provides a dependency management system that allows us to automatically fetch and build dependencies from source if needed. diff --git a/CMake/resolve_dependency_modules/absl.cmake b/CMake/resolve_dependency_modules/absl.cmake index c382e128eabf..505e39df6349 100644 --- a/CMake/resolve_dependency_modules/absl.cmake +++ b/CMake/resolve_dependency_modules/absl.cmake @@ -14,12 +14,16 @@ include_guard(GLOBAL) set(VELOX_ABSL_BUILD_VERSION 20240116.2) -set(VELOX_ABSL_BUILD_SHA256_CHECKSUM - 733726b8c3a6d39a4120d7e45ea8b41a434cdacde401cba500f14236c49b39dc) +set( + VELOX_ABSL_BUILD_SHA256_CHECKSUM + 733726b8c3a6d39a4120d7e45ea8b41a434cdacde401cba500f14236c49b39dc +) string( - CONCAT VELOX_ABSL_SOURCE_URL - "https://github.com/abseil/abseil-cpp/archive/refs/tags/" - "${VELOX_ABSL_BUILD_VERSION}.tar.gz") + CONCAT + VELOX_ABSL_SOURCE_URL + "https://github.com/abseil/abseil-cpp/archive/refs/tags/" + "${VELOX_ABSL_BUILD_VERSION}.tar.gz" +) velox_resolve_dependency_url(ABSL) @@ -29,8 +33,11 @@ FetchContent_Declare( absl URL ${VELOX_ABSL_SOURCE_URL} URL_HASH ${VELOX_ABSL_BUILD_SHA256_CHECKSUM} - OVERRIDE_FIND_PACKAGE EXCLUDE_FROM_ALL SYSTEM - PATCH_COMMAND git apply ${CMAKE_CURRENT_LIST_DIR}/absl/absl-macos.patch) + OVERRIDE_FIND_PACKAGE + EXCLUDE_FROM_ALL + SYSTEM + PATCH_COMMAND git apply ${CMAKE_CURRENT_LIST_DIR}/absl/absl-macos.patch +) set(ABSL_BUILD_TESTING OFF) set(ABSL_PROPAGATE_CXX_STD ON) diff --git a/CMake/resolve_dependency_modules/antlr4-runtime.cmake b/CMake/resolve_dependency_modules/antlr4-runtime.cmake index 86bc799ce715..cf8d77325462 100644 --- a/CMake/resolve_dependency_modules/antlr4-runtime.cmake +++ b/CMake/resolve_dependency_modules/antlr4-runtime.cmake @@ -14,10 +14,13 @@ include_guard(GLOBAL) set(VELOX_ANTLR4_RUNTIME_VERSION 4.13.2) -set(VELOX_ANTLR4_RUNTIME_BUILD_SHA256_CHECKSUM - 9f18272a9b32b622835a3365f850dd1063d60f5045fb1e12ce475ae6e18a35bb) -set(VELOX_ANTLR4_RUNTIME_SOURCE_URL - "https://github.com/antlr/antlr4/archive/refs/tags/${VELOX_ANTLR4_RUNTIME_VERSION}.tar.gz" +set( + VELOX_ANTLR4_RUNTIME_BUILD_SHA256_CHECKSUM + 9f18272a9b32b622835a3365f850dd1063d60f5045fb1e12ce475ae6e18a35bb +) +set( + VELOX_ANTLR4_RUNTIME_SOURCE_URL + "https://github.com/antlr/antlr4/archive/refs/tags/${VELOX_ANTLR4_RUNTIME_VERSION}.tar.gz" ) velox_resolve_dependency_url(ANTLR4_RUNTIME) @@ -28,9 +31,10 @@ FetchContent_Declare( antlr4-runtime URL ${VELOX_ANTLR4_RUNTIME_SOURCE_URL} URL_HASH ${VELOX_ANTLR4_RUNTIME_BUILD_SHA256_CHECKSUM} - SOURCE_SUBDIR runtime/Cpp OVERRIDE_FIND_PACKAGE) + SOURCE_SUBDIR + runtime/Cpp + OVERRIDE_FIND_PACKAGE +) -set(ANTLR4_INSTALL - ON - CACHE BOOL "" FORCE) +set(ANTLR4_INSTALL ON CACHE BOOL "" FORCE) FetchContent_MakeAvailable(antlr4-runtime) diff --git a/CMake/resolve_dependency_modules/arrow/CMakeLists.txt b/CMake/resolve_dependency_modules/arrow/CMakeLists.txt index 5144d646e788..401644d11f85 100644 --- a/CMake/resolve_dependency_modules/arrow/CMakeLists.txt +++ b/CMake/resolve_dependency_modules/arrow/CMakeLists.txt @@ -21,28 +21,31 @@ if(VELOX_ENABLE_ARROW) endif() set(ARROW_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/arrow_ep") - set(ARROW_CMAKE_ARGS - -DARROW_PARQUET=OFF - -DARROW_DEPENDENCY_SOURCE=AUTO - -DARROW_WITH_THRIFT=ON - -DARROW_WITH_LZ4=ON - -DARROW_WITH_SNAPPY=ON - -DARROW_WITH_ZLIB=ON - -DARROW_WITH_ZSTD=ON - -DARROW_JEMALLOC=OFF - -DARROW_SIMD_LEVEL=NONE - -DARROW_RUNTIME_SIMD_LEVEL=NONE - -DARROW_WITH_UTF8PROC=OFF - -DARROW_TESTING=ON - -DCMAKE_INSTALL_PREFIX=${ARROW_PREFIX}/install - -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} - -DARROW_BUILD_STATIC=ON - -DThrift_SOURCE=${THRIFT_SOURCE} - -DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH} - -DCMAKE_POLICY_VERSION_MINIMUM=3.5 - # Remove with Arrow upgrade to Arrow 20. - -DARROW_CXXFLAGS=-Wno-documentation) - set(ARROW_LIBDIR ${ARROW_PREFIX}/install/${CMAKE_INSTALL_LIBDIR}) + set( + ARROW_CMAKE_ARGS + -DARROW_PARQUET=OFF + -DARROW_DEPENDENCY_SOURCE=AUTO + -DARROW_WITH_THRIFT=ON + -DARROW_WITH_LZ4=ON + -DARROW_WITH_SNAPPY=ON + -DARROW_WITH_ZLIB=ON + -DARROW_WITH_ZSTD=ON + -DARROW_JEMALLOC=OFF + -DARROW_SIMD_LEVEL=NONE + -DARROW_RUNTIME_SIMD_LEVEL=NONE + -DARROW_WITH_UTF8PROC=OFF + -DARROW_TESTING=ON + -DCMAKE_INSTALL_LIBDIR=lib + -DCMAKE_INSTALL_PREFIX=${ARROW_PREFIX}/install + -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} + -DARROW_BUILD_STATIC=ON + -DThrift_SOURCE=${THRIFT_SOURCE} + -DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH} + -DCMAKE_POLICY_VERSION_MINIMUM=3.5 + # Remove with Arrow upgrade to Arrow 20. + -DARROW_CXXFLAGS=-Wno-documentation + ) + set(ARROW_LIBDIR ${ARROW_PREFIX}/install/lib) add_library(thrift STATIC IMPORTED GLOBAL) if(NOT Thrift_FOUND) @@ -53,15 +56,17 @@ if(VELOX_ENABLE_ARROW) set(THRIFT_INCLUDE_DIR ${THRIFT_ROOT}/include) endif() - set_property(TARGET thrift PROPERTY INTERFACE_INCLUDE_DIRECTORIES - ${THRIFT_INCLUDE_DIR}) + set_property(TARGET thrift PROPERTY INTERFACE_INCLUDE_DIRECTORIES ${THRIFT_INCLUDE_DIR}) set_property(TARGET thrift PROPERTY IMPORTED_LOCATION ${THRIFT_LIB}) set(VELOX_ARROW_BUILD_VERSION 15.0.0) - set(VELOX_ARROW_BUILD_SHA256_CHECKSUM - ab74c60c46938505c8cd7599b1d2826c68450645d5860d0ff40f67e371a5d0b5) - set(VELOX_ARROW_SOURCE_URL - "https://github.com/apache/arrow/archive/refs/tags/apache-arrow-${VELOX_ARROW_BUILD_VERSION}.tar.gz" + set( + VELOX_ARROW_BUILD_SHA256_CHECKSUM + ab74c60c46938505c8cd7599b1d2826c68450645d5860d0ff40f67e371a5d0b5 + ) + set( + VELOX_ARROW_SOURCE_URL + "https://github.com/apache/arrow/archive/refs/tags/apache-arrow-${VELOX_ARROW_BUILD_VERSION}.tar.gz" ) velox_resolve_dependency_url(ARROW) @@ -73,9 +78,11 @@ if(VELOX_ENABLE_ARROW) URL_HASH ${VELOX_ARROW_BUILD_SHA256_CHECKSUM} SOURCE_SUBDIR cpp CMAKE_ARGS ${ARROW_CMAKE_ARGS} - BUILD_BYPRODUCTS ${ARROW_LIBDIR}/libarrow.a - ${ARROW_LIBDIR}/libarrow_testing.a ${THRIFT_LIB} - PATCH_COMMAND git apply ${CMAKE_CURRENT_LIST_DIR}/thrift-download.patch) + BUILD_BYPRODUCTS ${ARROW_LIBDIR}/libarrow.a ${ARROW_LIBDIR}/libarrow_testing.a ${THRIFT_LIB} + PATCH_COMMAND + git apply ${CMAKE_CURRENT_LIST_DIR}/thrift-download.patch && git apply + ${CMAKE_CURRENT_LIST_DIR}/cmake-compatibility.patch + ) add_library(arrow STATIC IMPORTED GLOBAL) add_library(arrow_testing STATIC IMPORTED GLOBAL) @@ -83,13 +90,14 @@ if(VELOX_ENABLE_ARROW) add_dependencies(arrow_testing arrow) file(MAKE_DIRECTORY ${ARROW_PREFIX}/install/include) set_target_properties( - arrow arrow_testing PROPERTIES INTERFACE_INCLUDE_DIRECTORIES - ${ARROW_PREFIX}/install/include) - set_target_properties(arrow PROPERTIES IMPORTED_LOCATION - ${ARROW_LIBDIR}/libarrow.a) + arrow + arrow_testing + PROPERTIES INTERFACE_INCLUDE_DIRECTORIES ${ARROW_PREFIX}/install/include + ) + set_target_properties(arrow PROPERTIES IMPORTED_LOCATION ${ARROW_LIBDIR}/libarrow.a) set_property(TARGET arrow PROPERTY INTERFACE_LINK_LIBRARIES ${RE2} thrift) set_target_properties( - arrow_testing PROPERTIES IMPORTED_LOCATION - ${ARROW_LIBDIR}/libarrow_testing.a) - + arrow_testing + PROPERTIES IMPORTED_LOCATION ${ARROW_LIBDIR}/libarrow_testing.a + ) endif() diff --git a/CMake/resolve_dependency_modules/arrow/cmake-compatibility.patch b/CMake/resolve_dependency_modules/arrow/cmake-compatibility.patch new file mode 100644 index 000000000000..249ea6090483 --- /dev/null +++ b/CMake/resolve_dependency_modules/arrow/cmake-compatibility.patch @@ -0,0 +1,34 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +--- a/cpp/cmake_modules/ThirdpartyToolchain.cmake ++++ b/cpp/cmake_modules/ThirdpartyToolchain.cmake +@@ -971,7 +971,8 @@ + -DCMAKE_FIND_PACKAGE_NO_PACKAGE_REGISTRY=${CMAKE_FIND_PACKAGE_NO_PACKAGE_REGISTRY} + -DCMAKE_INSTALL_LIBDIR=lib + -DCMAKE_OSX_SYSROOT=${CMAKE_OSX_SYSROOT} +- -DCMAKE_VERBOSE_MAKEFILE=${CMAKE_VERBOSE_MAKEFILE}) ++ -DCMAKE_VERBOSE_MAKEFILE=${CMAKE_VERBOSE_MAKEFILE} ++ -DCMAKE_POLICY_VERSION_MINIMUM=3.5) + + # Enable s/ccache if set by parent. + if(CMAKE_C_COMPILER_LAUNCHER AND CMAKE_CXX_COMPILER_LAUNCHER) +@@ -1026,6 +1027,7 @@ + set(CMAKE_COMPILE_WARNING_AS_ERROR FALSE) + set(CMAKE_EXPORT_NO_PACKAGE_REGISTRY TRUE) + set(CMAKE_MACOSX_RPATH ${ARROW_INSTALL_NAME_RPATH}) ++ set(CMAKE_POLICY_VERSION_MINIMUM 3.5) + if(MSVC) + string(REPLACE "/WX" "" CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG}") + string(REPLACE "/WX" "" CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG}") diff --git a/CMake/resolve_dependency_modules/boost.cmake b/CMake/resolve_dependency_modules/boost.cmake index 842cba6f1334..dfbc61698477 100644 --- a/CMake/resolve_dependency_modules/boost.cmake +++ b/CMake/resolve_dependency_modules/boost.cmake @@ -15,13 +15,6 @@ include_guard(GLOBAL) add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/boost) -if(ICU_SOURCE) - if(${ICU_SOURCE} STREQUAL "BUNDLED") - # ensure ICU is built before Boost - add_dependencies(boost_regex ICU ICU::i18n) - endif() -endif() - # This prevents system boost from leaking in set(Boost_NO_SYSTEM_PATHS ON) # We have to keep the FindBoost.cmake in an subfolder to prevent it from diff --git a/CMake/resolve_dependency_modules/boost/CMakeLists.txt b/CMake/resolve_dependency_modules/boost/CMakeLists.txt index 69018a0f3a68..c653eba61f3f 100644 --- a/CMake/resolve_dependency_modules/boost/CMakeLists.txt +++ b/CMake/resolve_dependency_modules/boost/CMakeLists.txt @@ -23,12 +23,16 @@ add_compile_options(-w) # officale releases for some reason) set(VELOX_BOOST_BUILD_VERSION 1.84.0) string( - CONCAT VELOX_BOOST_SOURCE_URL - "https://github.com/boostorg/boost/releases/download/" - "boost-${VELOX_BOOST_BUILD_VERSION}/" - "boost-${VELOX_BOOST_BUILD_VERSION}.tar.gz") -set(VELOX_BOOST_BUILD_SHA256_CHECKSUM - 4d27e9efed0f6f152dc28db6430b9d3dfb40c0345da7342eaa5a987dde57bd95) + CONCAT + VELOX_BOOST_SOURCE_URL + "https://github.com/boostorg/boost/releases/download/" + "boost-${VELOX_BOOST_BUILD_VERSION}/" + "boost-${VELOX_BOOST_BUILD_VERSION}.tar.gz" +) +set( + VELOX_BOOST_BUILD_SHA256_CHECKSUM + 4d27e9efed0f6f152dc28db6430b9d3dfb40c0345da7342eaa5a987dde57bd95 +) velox_resolve_dependency_url(BOOST) message(STATUS "Building boost from source") @@ -42,22 +46,28 @@ endif() FetchContent_Declare( Boost URL ${VELOX_BOOST_SOURCE_URL} - URL_HASH ${VELOX_BOOST_BUILD_SHA256_CHECKSUM}) + URL_HASH ${VELOX_BOOST_BUILD_SHA256_CHECKSUM} +) # Configure the file before adding the header only libs -configure_file(${CMAKE_CURRENT_LIST_DIR}/FindBoost.cmake.in - ${CMAKE_CURRENT_LIST_DIR}/FindBoost.cmake @ONLY) +configure_file( + ${CMAKE_CURRENT_LIST_DIR}/FindBoost.cmake.in + ${CMAKE_CURRENT_LIST_DIR}/FindBoost.cmake + @ONLY +) -set(BOOST_HEADER_ONLY - crc - circular_buffer - math - multi_index - multiprecision - numeric_conversion - random - uuid - variant) +set( + BOOST_HEADER_ONLY + crc + circular_buffer + math + multi_index + multiprecision + numeric_conversion + random + uuid + variant +) list(APPEND BOOST_INCLUDE_LIBRARIES ${BOOST_HEADER_ONLY}) # The `headers` target is not created by Boost cmake and leads to a warning @@ -66,7 +76,5 @@ set(BUILD_SHARED_LIBS OFF) FetchContent_MakeAvailable(Boost) list(TRANSFORM BOOST_HEADER_ONLY PREPEND Boost::) -target_link_libraries( - boost_headers - INTERFACE ${BOOST_HEADER_ONLY}) +target_link_libraries(boost_headers INTERFACE ${BOOST_HEADER_ONLY}) add_library(Boost::headers ALIAS boost_headers) diff --git a/CMake/resolve_dependency_modules/c-ares.cmake b/CMake/resolve_dependency_modules/c-ares.cmake index ccc48b632728..267ee27758c9 100644 --- a/CMake/resolve_dependency_modules/c-ares.cmake +++ b/CMake/resolve_dependency_modules/c-ares.cmake @@ -14,12 +14,16 @@ include_guard(GLOBAL) set(VELOX_CARES_BUILD_VERSION cares-1_13_0) -set(VELOX_CARES_BUILD_SHA256_CHECKSUM - 7c48c57706a38691041920e705d2a04426ad9c68d40edd600685323f214b2d57) +set( + VELOX_CARES_BUILD_SHA256_CHECKSUM + 7c48c57706a38691041920e705d2a04426ad9c68d40edd600685323f214b2d57 +) string( - CONCAT VELOX_CARES_SOURCE_URL - "https://github.com/c-ares/c-ares/archive/refs/tags/" - "${VELOX_CARES_BUILD_VERSION}.tar.gz") + CONCAT + VELOX_CARES_SOURCE_URL + "https://github.com/c-ares/c-ares/archive/refs/tags/" + "${VELOX_CARES_BUILD_VERSION}.tar.gz" +) velox_resolve_dependency_url(CARES) @@ -29,10 +33,11 @@ FetchContent_Declare( c-ares URL ${VELOX_CARES_SOURCE_URL} URL_HASH ${VELOX_CARES_BUILD_SHA256_CHECKSUM} - PATCH_COMMAND - git init && git apply - ${CMAKE_CURRENT_LIST_DIR}/c-ares/c-ares-random-file.patch - OVERRIDE_FIND_PACKAGE EXCLUDE_FROM_ALL SYSTEM) + PATCH_COMMAND git init && git apply ${CMAKE_CURRENT_LIST_DIR}/c-ares/c-ares-random-file.patch + OVERRIDE_FIND_PACKAGE + EXCLUDE_FROM_ALL + SYSTEM +) set(CARES_STATIC ON) set(CARES_INSTALL ON) diff --git a/CMake/resolve_dependency_modules/clp.cmake b/CMake/resolve_dependency_modules/clp.cmake index 72e5e65b9f41..c9579fbc314b 100644 --- a/CMake/resolve_dependency_modules/clp.cmake +++ b/CMake/resolve_dependency_modules/clp.cmake @@ -16,28 +16,16 @@ include_guard(GLOBAL) FetchContent_Declare( clp GIT_REPOSITORY https://github.com/y-scope/clp.git - GIT_TAG v0.8.0) + GIT_TAG f82e6114160a6addd4727259906bcf621ac9912c +) -set(CLP_BUILD_CLP_REGEX_UTILS - OFF - CACHE BOOL "Build CLP regex utils") -set(CLP_BUILD_CLP_S_JSONCONSTRUCTOR - OFF - CACHE BOOL "Build CLP-S JSON constructor") -set(CLP_BUILD_CLP_S_REDUCER_DEPENDENCIES - OFF - CACHE BOOL "Build CLP-S reducer dependencies") -set(CLP_BUILD_CLP_S_SEARCH_SQL - OFF - CACHE BOOL "Build CLP-S search SQL") -set(CLP_BUILD_EXECUTABLES - OFF - CACHE BOOL "Build CLP executables") -set(CLP_BUILD_TESTING - OFF - CACHE BOOL "Build CLP tests") +set(CLP_BUILD_CLP_REGEX_UTILS OFF CACHE BOOL "Build CLP regex utils") +set(CLP_BUILD_CLP_S_JSONCONSTRUCTOR OFF CACHE BOOL "Build CLP-S JSON constructor") +set(CLP_BUILD_CLP_S_REDUCER_DEPENDENCIES OFF CACHE BOOL "Build CLP-S reducer dependencies") +set(CLP_BUILD_CLP_S_SEARCH_SQL OFF CACHE BOOL "Build CLP-S search SQL") +set(CLP_BUILD_EXECUTABLES OFF CACHE BOOL "Build CLP executables") +set(CLP_BUILD_TESTING OFF CACHE BOOL "Build CLP tests") FetchContent_Populate(clp) -add_subdirectory(${clp_SOURCE_DIR}/components/core - ${clp_BINARY_DIR}/components/core) +add_subdirectory(${clp_SOURCE_DIR}/components/core ${clp_BINARY_DIR}/components/core) diff --git a/CMake/resolve_dependency_modules/cpr.cmake b/CMake/resolve_dependency_modules/cpr.cmake deleted file mode 100644 index 488595e08512..000000000000 --- a/CMake/resolve_dependency_modules/cpr.cmake +++ /dev/null @@ -1,47 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -include_guard(GLOBAL) - -set(VELOX_CPR_VERSION 1.10.5) -set(VELOX_CPR_BUILD_SHA256_CHECKSUM - c8590568996cea918d7cf7ec6845d954b9b95ab2c4980b365f582a665dea08d8) -set(VELOX_CPR_SOURCE_URL - "https://github.com/libcpr/cpr/archive/refs/tags/${VELOX_CPR_VERSION}.tar.gz" -) - -# Add the dependency for curl, so that we can define the source URL for curl in -# curl.cmake. This will override the curl version declared by cpr. -set(curl_SOURCE BUNDLED) -velox_resolve_dependency(curl) - -velox_resolve_dependency_url(CPR) - -message(STATUS "Building cpr from source") -FetchContent_Declare( - cpr - URL ${VELOX_CPR_SOURCE_URL} - URL_HASH ${VELOX_CPR_BUILD_SHA256_CHECKSUM} - PATCH_COMMAND - git apply ${CMAKE_CURRENT_LIST_DIR}/cpr/cpr-libcurl-compatible.patch && git - apply ${CMAKE_CURRENT_LIST_DIR}/cpr/cpr-remove-sancheck.patch) -set(BUILD_SHARED_LIBS ${VELOX_BUILD_SHARED}) -set(CPR_USE_SYSTEM_CURL OFF) -# ZLIB has already been found by find_package(ZLIB, REQUIRED), set CURL_ZLIB=OFF -# to save compile time. -set(CURL_ZLIB OFF) -FetchContent_MakeAvailable(cpr) -# libcpr in its CMakeLists.txt file disables the BUILD_TESTING globally when -# CPR_USE_SYSTEM_CURL=OFF. unset BUILD_TESTING here. -unset(BUILD_TESTING) -unset(BUILD_SHARED_LIBS) diff --git a/CMake/resolve_dependency_modules/cpr/cpr-libcurl-compatible.patch b/CMake/resolve_dependency_modules/cpr/cpr-libcurl-compatible.patch deleted file mode 100644 index 49821889f2bf..000000000000 --- a/CMake/resolve_dependency_modules/cpr/cpr-libcurl-compatible.patch +++ /dev/null @@ -1,41 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# This can be removed once we upgrade to curl >= 7.68.0 ---- a/cpr/multiperform.cpp -+++ b/cpr/multiperform.cpp -@@ -97,9 +97,9 @@ void MultiPerform::DoMultiPerform() { - - if (still_running) { - const int timeout_ms{250}; -- error_code = curl_multi_poll(multicurl_->handle, nullptr, 0, timeout_ms, nullptr); -+ error_code = curl_multi_wait(multicurl_->handle, nullptr, 0, timeout_ms, nullptr); - if (error_code) { -- std::cerr << "curl_multi_poll() failed, code " << static_cast(error_code) << std::endl; -+ std::cerr << "curl_multi_wait() failed, code " << static_cast(error_code) << std::endl; - break; - } - } - ---- a/include/cpr/util.h -+++ b/include/cpr/util.h -@@ -23,7 +23,7 @@ size_t writeUserFunction(char* ptr, size_t size, size_t nmemb, const WriteCallba - template - int progressUserFunction(const T* progress, cpr_pf_arg_t dltotal, cpr_pf_arg_t dlnow, cpr_pf_arg_t ultotal, cpr_pf_arg_t ulnow) { - const int cancel_retval{1}; -- static_assert(cancel_retval != CURL_PROGRESSFUNC_CONTINUE); -+ static_assert(cancel_retval != 0x10000001); - return (*progress)(dltotal, dlnow, ultotal, ulnow) ? 0 : cancel_retval; - } - int debugUserFunction(CURL* handle, curl_infotype type, char* data, size_t size, const DebugCallback* debug); diff --git a/CMake/resolve_dependency_modules/cudf.cmake b/CMake/resolve_dependency_modules/cudf.cmake index 56f7b9bf1d18..e21406411865 100644 --- a/CMake/resolve_dependency_modules/cudf.cmake +++ b/CMake/resolve_dependency_modules/cudf.cmake @@ -17,79 +17,48 @@ include_guard(GLOBAL) # 3.30.4 is the minimum version required by cudf cmake_minimum_required(VERSION 3.30.4) -set(VELOX_rapids_cmake_VERSION 25.04) -set(VELOX_rapids_cmake_BUILD_SHA256_CHECKSUM - 458c14eaff9000067b32d65c8c914f4521090ede7690e16eb57035ce731386db) -set(VELOX_rapids_cmake_SOURCE_URL - "https://github.com/rapidsai/rapids-cmake/archive/7828fc8ff2e9f4fa86099f3c844505c2f47ac672.tar.gz" -) -velox_resolve_dependency_url(rapids_cmake) +# Add velox_resolve_dependency_url here for rapids-cmake, rmm, and kvikio if a specific version or commit is needed. -set(VELOX_rmm_VERSION 25.04) -set(VELOX_rmm_BUILD_SHA256_CHECKSUM - 294905094213a2d1fd8e024500359ff871bc52f913a3fbaca3514727c49f62de) -set(VELOX_rmm_SOURCE_URL - "https://github.com/rapidsai/rmm/archive/d8b7dacdeda302d2e37313c02d14ef5e1d1e98ea.tar.gz" -) -velox_resolve_dependency_url(rmm) +set(VELOX_cudf_VERSION 25.12 CACHE STRING "cudf version") -set(VELOX_kvikio_VERSION 25.04) -set(VELOX_kvikio_BUILD_SHA256_CHECKSUM - 4a0b15295d0a397433930bf9a309e4ad2361b25dc7a7b3e6a35d0c9419d0cb62) -set(VELOX_kvikio_SOURCE_URL - "https://github.com/rapidsai/kvikio/archive/5c710f37236bda76e447e929e17b1efbc6c632c3.tar.gz" +set( + VELOX_cudf_BUILD_SHA256_CHECKSUM + 4ec101a368e1423a1a3831a121f0d24e2d91ed044aed68faee0ba18fda38b450 ) -velox_resolve_dependency_url(kvikio) - -set(VELOX_cudf_VERSION 25.04) -set(VELOX_cudf_BUILD_SHA256_CHECKSUM - e5a1900dfaf23dab2c5808afa17a2d04fa867d2892ecec1cb37908f3b73715c2) -set(VELOX_cudf_SOURCE_URL - "https://github.com/rapidsai/cudf/archive/4c1c99011da2c23856244e05adda78ba66697105.tar.gz" +set( + VELOX_cudf_SOURCE_URL + "https://github.com/rapidsai/cudf/archive/181bd7b85c614e9c8f755a62f07f4b9c9334b615.tar.gz" ) velox_resolve_dependency_url(cudf) # Use block so we don't leak variables block(SCOPE_FOR VARIABLES) -# Setup libcudf build to not have testing components -set(BUILD_TESTS OFF) -set(CUDF_BUILD_TESTUTIL OFF) -set(BUILD_SHARED_LIBS ON) - -FetchContent_Declare( - rapids-cmake - URL ${VELOX_rapids_cmake_SOURCE_URL} - URL_HASH ${VELOX_rapids_cmake_BUILD_SHA256_CHECKSUM} - UPDATE_DISCONNECTED 1) - -FetchContent_Declare( - rmm - URL ${VELOX_rmm_SOURCE_URL} - URL_HASH ${VELOX_rmm_BUILD_SHA256_CHECKSUM} - UPDATE_DISCONNECTED 1) + # Setup libcudf build to not have testing components + set(BUILD_TESTS OFF) + set(CUDF_BUILD_TESTUTIL OFF) + set(BUILD_SHARED_LIBS ON) -FetchContent_Declare( - kvikio - URL ${VELOX_kvikio_SOURCE_URL} - URL_HASH ${VELOX_kvikio_BUILD_SHA256_CHECKSUM} - SOURCE_SUBDIR cpp - UPDATE_DISCONNECTED 1) + # Add FetchContent_Declare here for rapids-cmake, rmm, and kvikio if a specific version or commit is needed. -FetchContent_Declare( - cudf - URL ${VELOX_cudf_SOURCE_URL} - URL_HASH ${VELOX_cudf_BUILD_SHA256_CHECKSUM} - SOURCE_SUBDIR cpp - UPDATE_DISCONNECTED 1) + FetchContent_Declare( + cudf + URL ${VELOX_cudf_SOURCE_URL} + URL_HASH ${VELOX_cudf_BUILD_SHA256_CHECKSUM} + SOURCE_SUBDIR + cpp + UPDATE_DISCONNECTED 1 + ) -FetchContent_MakeAvailable(cudf) + FetchContent_MakeAvailable(cudf) -# cudf sets all warnings as errors, and therefore fails to compile with velox -# expanded set of warnings. We selectively disable problematic warnings just for -# cudf -target_compile_options( - cudf PRIVATE -Wno-non-virtual-dtor -Wno-missing-field-initializers - -Wno-deprecated-copy) + # cudf sets all warnings as errors, and therefore fails to compile with velox + # expanded set of warnings. We selectively disable problematic warnings just for + # cudf + target_compile_options( + cudf + PRIVATE -Wno-non-virtual-dtor -Wno-missing-field-initializers -Wno-deprecated-copy -Wno-restrict + ) -unset(BUILD_SHARED_LIBS) + unset(BUILD_SHARED_LIBS) + unset(BUILD_TESTING CACHE) endblock() diff --git a/CMake/resolve_dependency_modules/curl.cmake b/CMake/resolve_dependency_modules/curl.cmake index 80be3e09b975..fd4b51adf2b4 100644 --- a/CMake/resolve_dependency_modules/curl.cmake +++ b/CMake/resolve_dependency_modules/curl.cmake @@ -13,18 +13,4 @@ # limitations under the License. include_guard(GLOBAL) -set(VELOX_CURL_VERSION 8.4.0) -string(REPLACE "." "_" VELOX_CURL_VERSION_UNDERSCORES ${VELOX_CURL_VERSION}) -set(VELOX_CURL_BUILD_SHA256_CHECKSUM - 16c62a9c4af0f703d28bda6d7bbf37ba47055ad3414d70dec63e2e6336f2a82d) -string( - CONCAT - VELOX_CURL_SOURCE_URL "https://github.com/curl/curl/releases/download/" - "curl-${VELOX_CURL_VERSION_UNDERSCORES}/curl-${VELOX_CURL_VERSION}.tar.xz") - -velox_resolve_dependency_url(CURL) - -FetchContent_Declare( - curl - URL ${VELOX_CURL_SOURCE_URL} - URL_HASH ${VELOX_CURL_BUILD_SHA256_CHECKSUM}) +add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/curl) diff --git a/CMake/resolve_dependency_modules/curl/CMakeLists.txt b/CMake/resolve_dependency_modules/curl/CMakeLists.txt new file mode 100644 index 000000000000..867c19cf242f --- /dev/null +++ b/CMake/resolve_dependency_modules/curl/CMakeLists.txt @@ -0,0 +1,42 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set(VELOX_CURL_VERSION 8.4.0) +string(REPLACE "." "_" VELOX_CURL_VERSION_UNDERSCORES ${VELOX_CURL_VERSION}) +set( + VELOX_CURL_BUILD_SHA256_CHECKSUM + 16c62a9c4af0f703d28bda6d7bbf37ba47055ad3414d70dec63e2e6336f2a82d +) +string( + CONCAT + VELOX_CURL_SOURCE_URL + "https://github.com/curl/curl/releases/download/" + "curl-${VELOX_CURL_VERSION_UNDERSCORES}/curl-${VELOX_CURL_VERSION}.tar.xz" +) + +set(BUILD_TESTING OFF) +set(BUILD_SHARED_LIBS ON) + +velox_resolve_dependency_url(CURL) + +message(STATUS "Building CURL from source") + +FetchContent_Declare(curl URL ${VELOX_CURL_SOURCE_URL} URL_HASH ${VELOX_CURL_BUILD_SHA256_CHECKSUM}) + +FetchContent_MakeAvailable(curl) + +# Curl uses CMake option for BUILD_TESTING and BUILD_SHARED_LIBS +# See CMake option semantics. +unset(BUILD_TESTING CACHE) +unset(BUILD_SHARED_LIBS CACHE) diff --git a/CMake/resolve_dependency_modules/date.cmake b/CMake/resolve_dependency_modules/date.cmake index 275aa6e0a080..4daaee85f067 100644 --- a/CMake/resolve_dependency_modules/date.cmake +++ b/CMake/resolve_dependency_modules/date.cmake @@ -14,18 +14,19 @@ include_guard(GLOBAL) set(VELOX_DATE_BUILD_VERSION 3.0.1) -set(VELOX_DATE_BUILD_SHA256_CHECKSUM - 7a390f200f0ccd207e8cff6757e04817c1a0aec3e327b006b7eb451c57ee3538) -set(VELOX_DATE_SOURCE_URL - "https://github.com/HowardHinnant/date/archive/refs/tags/v${VELOX_DATE_BUILD_VERSION}.tar.gz" +set( + VELOX_DATE_BUILD_SHA256_CHECKSUM + 7a390f200f0ccd207e8cff6757e04817c1a0aec3e327b006b7eb451c57ee3538 +) +set( + VELOX_DATE_SOURCE_URL + "https://github.com/HowardHinnant/date/archive/refs/tags/v${VELOX_DATE_BUILD_VERSION}.tar.gz" ) velox_resolve_dependency_url(DATE) # Optionally set CMake variables *before* make-available -set(CMAKE_INSTALL_MESSAGE - LAZY - CACHE STRING "" FORCE) +set(CMAKE_INSTALL_MESSAGE LAZY CACHE STRING "" FORCE) message(STATUS "Building date from source") @@ -33,6 +34,7 @@ FetchContent_Declare( date URL ${VELOX_DATE_SOURCE_URL} URL_HASH ${VELOX_DATE_BUILD_SHA256_CHECKSUM} - OVERRIDE_FIND_PACKAGE) + OVERRIDE_FIND_PACKAGE +) FetchContent_MakeAvailable(date) diff --git a/CMake/resolve_dependency_modules/duckdb.cmake b/CMake/resolve_dependency_modules/duckdb.cmake index 75dde1f8eebd..bd3fbd7e14f8 100644 --- a/CMake/resolve_dependency_modules/duckdb.cmake +++ b/CMake/resolve_dependency_modules/duckdb.cmake @@ -14,10 +14,13 @@ include_guard(GLOBAL) set(VELOX_DUCKDB_VERSION 0.8.1) -set(VELOX_DUCKDB_BUILD_SHA256_CHECKSUM - a0674f7e320dc7ebcf51990d7fc1c0e7f7b2c335c08f5953702b5285e6c30694) -set(VELOX_DUCKDB_SOURCE_URL - "https://github.com/duckdb/duckdb/archive/refs/tags/v${VELOX_DUCKDB_VERSION}.tar.gz" +set( + VELOX_DUCKDB_BUILD_SHA256_CHECKSUM + a0674f7e320dc7ebcf51990d7fc1c0e7f7b2c335c08f5953702b5285e6c30694 +) +set( + VELOX_DUCKDB_SOURCE_URL + "https://github.com/duckdb/duckdb/archive/refs/tags/v${VELOX_DUCKDB_VERSION}.tar.gz" ) set(CMAKE_POLICY_VERSION_MINIMUM 3.5) @@ -25,19 +28,23 @@ velox_resolve_dependency_url(DUCKDB) message(STATUS "Building DuckDB from source") # We need remove-ccache.patch to remove adding ccache to the build command -# twice. Velox already does this. We need fix-duckdbversion.patch as DuckDB -# tries to infer the version via a git commit hash or git tag. This inference -# can lead to errors when building in another git project such as Prestissimo. +# twice. Velox already does this. FetchContent_Declare( duckdb URL ${VELOX_DUCKDB_SOURCE_URL} URL_HASH ${VELOX_DUCKDB_BUILD_SHA256_CHECKSUM} PATCH_COMMAND git apply ${CMAKE_CURRENT_LIST_DIR}/duckdb/remove-ccache.patch && git apply - ${CMAKE_CURRENT_LIST_DIR}/duckdb/fix-duckdbversion.patch && git apply - ${CMAKE_CURRENT_LIST_DIR}/duckdb/re2.patch) + ${CMAKE_CURRENT_LIST_DIR}/duckdb/re2.patch +) +# DuckDB uses git commands to retrieve version information during the build, +# which works with git clone. To prevent incorrectly using the parent project's +# git version when building from a tarball, we define GIT_COMMIT_HASH to skip +# that. +set(GIT_COMMIT_HASH "6536a77") set(BUILD_UNITTESTS OFF) +set(BUILD_TESTING OFF) set(ENABLE_SANITIZER OFF) set(ENABLE_UBSAN OFF) set(BUILD_SHELL OFF) @@ -46,6 +53,10 @@ set(PREVIOUS_BUILD_TYPE ${CMAKE_BUILD_TYPE}) set(CMAKE_BUILD_TYPE Release) set(PREVIOUS_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS}) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-non-virtual-dtor") +# Clang17 requires this. See issue #13215. +if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 17.0.0) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-missing-template-arg-list-after-template-kw") +endif() FetchContent_MakeAvailable(duckdb) @@ -55,3 +66,5 @@ endif() set(CMAKE_CXX_FLAGS ${PREVIOUS_CMAKE_CXX_FLAGS}) set(CMAKE_BUILD_TYPE ${PREVIOUS_BUILD_TYPE}) +# Some DuckDB third-party package sets this flags. We cannot control that. +unset(BUILD_TESTING) diff --git a/CMake/resolve_dependency_modules/duckdb/fix-duckdbversion.patch b/CMake/resolve_dependency_modules/duckdb/fix-duckdbversion.patch deleted file mode 100644 index d990646800f5..000000000000 --- a/CMake/resolve_dependency_modules/duckdb/fix-duckdbversion.patch +++ /dev/null @@ -1,59 +0,0 @@ ---- a/CMakeLists.txt -+++ b/CMakeLists.txt -@@ -210,56 +210,8 @@ - set(CXX_EXTRA "${CXX_EXTRA} -mimpure-text") - add_definitions(-DSUN=1) - set(SUN TRUE) --endif() -- --find_package(Git) --if(Git_FOUND) -- if (NOT DEFINED GIT_COMMIT_HASH) -- execute_process( -- COMMAND ${GIT_EXECUTABLE} log -1 --format=%h -- WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} -- RESULT_VARIABLE GIT_RESULT -- OUTPUT_VARIABLE GIT_COMMIT_HASH -- OUTPUT_STRIP_TRAILING_WHITESPACE) -- endif() -- execute_process( -- COMMAND ${GIT_EXECUTABLE} describe --tags --abbrev=0 -- WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} -- OUTPUT_VARIABLE GIT_LAST_TAG -- OUTPUT_STRIP_TRAILING_WHITESPACE) -- execute_process( -- COMMAND ${GIT_EXECUTABLE} describe --tags --long -- WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} -- OUTPUT_VARIABLE GIT_ITERATION -- OUTPUT_STRIP_TRAILING_WHITESPACE) --else() -- message("Git NOT FOUND") --endif() -- --if(GIT_RESULT EQUAL "0") -- string(REGEX REPLACE "v([0-9]+).[0-9]+.[0-9]+" "\\1" DUCKDB_MAJOR_VERSION "${GIT_LAST_TAG}") -- string(REGEX REPLACE "v[0-9]+.([0-9]+).[0-9]+" "\\1" DUCKDB_MINOR_VERSION "${GIT_LAST_TAG}") -- string(REGEX REPLACE "v[0-9]+.[0-9]+.([0-9]+)" "\\1" DUCKDB_PATCH_VERSION "${GIT_LAST_TAG}") -- string(REGEX REPLACE ".*-([0-9]+)-.*" "\\1" DUCKDB_DEV_ITERATION "${GIT_ITERATION}") -- -- if(DUCKDB_DEV_ITERATION EQUAL 0) -- # on a tag; directly use the version -- set(DUCKDB_VERSION "${GIT_LAST_TAG}") -- else() -- # not on a tag, increment the patch version by one and add a -devX suffix -- math(EXPR DUCKDB_PATCH_VERSION "${DUCKDB_PATCH_VERSION}+1") -- set(DUCKDB_VERSION "v${DUCKDB_MAJOR_VERSION}.${DUCKDB_MINOR_VERSION}.${DUCKDB_PATCH_VERSION}-dev${DUCKDB_DEV_ITERATION}") -- endif() --else() -- # fallback for when building from tarball -- set(DUCKDB_MAJOR_VERSION 0) -- set(DUCKDB_MINOR_VERSION 0) -- set(DUCKDB_PATCH_VERSION 1) -- set(DUCKDB_DEV_ITERATION 0) -- set(DUCKDB_VERSION "v${DUCKDB_MAJOR_VERSION}.${DUCKDB_MINOR_VERSION}.${DUCKDB_PATCH_VERSION}-dev${DUCKDB_DEV_ITERATION}") - endif() - --message(STATUS "git hash ${GIT_COMMIT_HASH}, version ${DUCKDB_VERSION}") - - option(AMALGAMATION_BUILD - "Build from the amalgamation files, rather than from the normal sources." diff --git a/CMake/resolve_dependency_modules/faiss.cmake b/CMake/resolve_dependency_modules/faiss.cmake new file mode 100644 index 000000000000..db158dbe65d3 --- /dev/null +++ b/CMake/resolve_dependency_modules/faiss.cmake @@ -0,0 +1,73 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +include_guard(GLOBAL) + +set(VELOX_FAISS_BUILD_VERSION 1.11.0) +set( + VELOX_FAISS_BUILD_SHA256_CHECKSUM + c5d517da6deb6a6d74290d7145331fc7474426025e2d826fa4a6d40670f4493c +) +set( + VELOX_FAISS_SOURCE_URL + "https://github.com/facebookresearch/faiss/archive/refs/tags/v${VELOX_FAISS_BUILD_VERSION}.tar.gz" +) + +velox_resolve_dependency_url(FAISS) + +# We need these hints for macos to build. +if(CMAKE_SYSTEM_NAME MATCHES "Darwin") + message(STATUS "Detected Apple platform") + execute_process( + COMMAND brew --prefix libomp + RESULT_VARIABLE BREW_LIBOMP_RESULT + OUTPUT_VARIABLE BREW_LIBOMP_PREFIX + OUTPUT_STRIP_TRAILING_WHITESPACE + ) + if(BREW_LIBOMP_RESULT EQUAL 0 AND EXISTS "${BREW_LIBOMP_PREFIX}") + list(APPEND CMAKE_PREFIX_PATH "${BREW_LIBOMP_PREFIX}") + endif() + + execute_process( + COMMAND brew --prefix openblas + RESULT_VARIABLE BREW_OPENBLAS_RESULT + OUTPUT_VARIABLE BREW_OPENBLAS_PREFIX + OUTPUT_STRIP_TRAILING_WHITESPACE + ) + if(BREW_OPENBLAS_RESULT EQUAL 0 AND EXISTS "${BREW_OPENBLAS_PREFIX}") + list(APPEND CMAKE_PREFIX_PATH "${BREW_OPENBLAS_PREFIX}") + endif() +endif() + +FetchContent_Declare( + faiss + URL ${VELOX_FAISS_SOURCE_URL} + URL_HASH ${VELOX_FAISS_BUILD_SHA256_CHECKSUM} + SYSTEM + EXCLUDE_FROM_ALL +) + +# Set build options +block() + set(BUILD_SHARED_LIBS OFF) + set(BUILD_TESTING OFF) + set(CMAKE_BUILD_TYPE Release) + set(FAISS_ENABLE_GPU OFF) + set(FAISS_ENABLE_PYTHON OFF) + set(FAISS_ENABLE_GPU_TESTS OFF) + # Make FAISS available + FetchContent_MakeAvailable(faiss) + add_library(FAISS::faiss ALIAS faiss) + unset(BUILD_TESTING CACHE) + unset(BUILD_SHARED_LIBS CACHE) +endblock() diff --git a/CMake/resolve_dependency_modules/fastfloat.cmake b/CMake/resolve_dependency_modules/fastfloat.cmake index 2800cdf566fc..9f31b08c6c92 100644 --- a/CMake/resolve_dependency_modules/fastfloat.cmake +++ b/CMake/resolve_dependency_modules/fastfloat.cmake @@ -14,10 +14,13 @@ include_guard(GLOBAL) set(VELOX_FASTFLOAT_VERSION v8.0.2) -set(VELOX_FASTFLOAT_BUILD_SHA256_CHECKSUM - e14a33089712b681d74d94e2a11362643bd7d769ae8f7e7caefe955f57f7eacd) -set(VELOX_FASTFLOAT_SOURCE_URL - "https://github.com/fastfloat/fast_float/archive/refs/tags/${VELOX_FASTFLOAT_VERSION}.tar.gz" +set( + VELOX_FASTFLOAT_BUILD_SHA256_CHECKSUM + e14a33089712b681d74d94e2a11362643bd7d769ae8f7e7caefe955f57f7eacd +) +set( + VELOX_FASTFLOAT_SOURCE_URL + "https://github.com/fastfloat/fast_float/archive/refs/tags/${VELOX_FASTFLOAT_VERSION}.tar.gz" ) velox_resolve_dependency_url(FASTFLOAT) @@ -26,7 +29,8 @@ message(STATUS "Building fast_float from source") FetchContent_Declare( fastfloat URL ${VELOX_FASTFLOAT_SOURCE_URL} - URL_HASH ${VELOX_FASTFLOAT_BUILD_SHA256_CHECKSUM}) + URL_HASH ${VELOX_FASTFLOAT_BUILD_SHA256_CHECKSUM} +) set(fastfloat_BUILD_TESTS OFF) FetchContent_MakeAvailable(fastfloat) # If folly is bundled it uses find_path fast_float/fast_float.h to locate the @@ -39,4 +43,5 @@ find_path( FASTFLOAT_INCLUDE_DIR NAMES fast_float/fast_float.h PATHS ${FASTFLOAT_SOURCE_DIR} - PATH_SUFFIXES include) + PATH_SUFFIXES include +) diff --git a/CMake/resolve_dependency_modules/fmt.cmake b/CMake/resolve_dependency_modules/fmt.cmake index cb57fd28fa4c..9e8087f1da00 100644 --- a/CMake/resolve_dependency_modules/fmt.cmake +++ b/CMake/resolve_dependency_modules/fmt.cmake @@ -13,19 +13,17 @@ # limitations under the License. include_guard(GLOBAL) -set(VELOX_FMT_VERSION 10.1.1) -set(VELOX_FMT_BUILD_SHA256_CHECKSUM - 78b8c0a72b1c35e4443a7e308df52498252d1cefc2b08c9a97bc9ee6cfe61f8b) -set(VELOX_FMT_SOURCE_URL - "https://github.com/fmtlib/fmt/archive/${VELOX_FMT_VERSION}.tar.gz") +set(VELOX_FMT_VERSION 11.2.0) +set( + VELOX_FMT_BUILD_SHA256_CHECKSUM + bc23066d87ab3168f27cef3e97d545fa63314f5c79df5ea444d41d56f962c6af +) +set(VELOX_FMT_SOURCE_URL "https://github.com/fmtlib/fmt/archive/${VELOX_FMT_VERSION}.tar.gz") velox_resolve_dependency_url(FMT) message(STATUS "Building fmt from source") -FetchContent_Declare( - fmt - URL ${VELOX_FMT_SOURCE_URL} - URL_HASH ${VELOX_FMT_BUILD_SHA256_CHECKSUM}) +FetchContent_Declare(fmt URL ${VELOX_FMT_SOURCE_URL} URL_HASH ${VELOX_FMT_BUILD_SHA256_CHECKSUM}) # Force fmt to create fmt-config.cmake which can be found by other dependecies # (e.g. folly) set(FMT_INSTALL ON) diff --git a/CMake/resolve_dependency_modules/folly/CMakeLists.txt b/CMake/resolve_dependency_modules/folly/CMakeLists.txt index 57156ee0f3fd..884193118de8 100644 --- a/CMake/resolve_dependency_modules/folly/CMakeLists.txt +++ b/CMake/resolve_dependency_modules/folly/CMakeLists.txt @@ -14,14 +14,17 @@ project(Folly) cmake_minimum_required(VERSION 3.28) -velox_set_source(fastfloat) -velox_resolve_dependency(fastfloat CONFIG REQUIRED) +velox_set_source(FastFloat) +velox_resolve_dependency(FastFloat CONFIG REQUIRED) set(VELOX_FOLLY_BUILD_VERSION v2025.04.28.00) -set(VELOX_FOLLY_BUILD_SHA256_CHECKSUM - ccbb7eac662023f9f5beba94e51350d527f33d8a7a036eb5e3d8a5cf1b49d3bc) -set(VELOX_FOLLY_SOURCE_URL - "https://github.com/facebook/folly/releases/download/${VELOX_FOLLY_BUILD_VERSION}/folly-${VELOX_FOLLY_BUILD_VERSION}.tar.gz" +set( + VELOX_FOLLY_BUILD_SHA256_CHECKSUM + ccbb7eac662023f9f5beba94e51350d527f33d8a7a036eb5e3d8a5cf1b49d3bc +) +set( + VELOX_FOLLY_SOURCE_URL + "https://github.com/facebook/folly/releases/download/${VELOX_FOLLY_BUILD_VERSION}/folly-${VELOX_FOLLY_BUILD_VERSION}.tar.gz" ) velox_resolve_dependency_url(FOLLY) @@ -41,8 +44,11 @@ FetchContent_Declare( folly URL ${VELOX_FOLLY_SOURCE_URL} URL_HASH ${VELOX_FOLLY_BUILD_SHA256_CHECKSUM} - PATCH_COMMAND git apply ${CMAKE_CURRENT_LIST_DIR}/folly-no-export.patch - ${glog_patch} OVERRIDE_FIND_PACKAGE SYSTEM EXCLUDE_FROM_ALL) + PATCH_COMMAND git apply ${CMAKE_CURRENT_LIST_DIR}/folly-no-export.patch ${glog_patch} + OVERRIDE_FIND_PACKAGE + SYSTEM + EXCLUDE_FROM_ALL +) set(BUILD_SHARED_LIBS ${VELOX_BUILD_SHARED}) diff --git a/CMake/resolve_dependency_modules/geos.cmake b/CMake/resolve_dependency_modules/geos.cmake index 2cc129deda8b..f092ba823b1f 100644 --- a/CMake/resolve_dependency_modules/geos.cmake +++ b/CMake/resolve_dependency_modules/geos.cmake @@ -12,40 +12,47 @@ # See the License for the specific language governing permissions and # limitations under the License. include_guard(GLOBAL) -# GEOS Configuration -set(VELOX_GEOS_BUILD_VERSION 3.10.2) -set(VELOX_GEOS_BUILD_SHA256_CHECKSUM - 50bbc599ac386b4c2b3962dcc411f0040a61f204aaef4eba7225ecdd0cf45715) -string(CONCAT VELOX_GEOS_SOURCE_URL "https://download.osgeo.org/geos/" - "geos-${VELOX_GEOS_BUILD_VERSION}.tar.bz2") -velox_resolve_dependency_url(GEOS) - -FetchContent_Declare( - geos - URL ${VELOX_GEOS_SOURCE_URL} - URL_HASH ${VELOX_GEOS_BUILD_SHA256_CHECKSUM} - PATCH_COMMAND - git apply "${CMAKE_CURRENT_LIST_DIR}/geos/geos-cmakelists.patch" && git - apply "${CMAKE_CURRENT_LIST_DIR}/geos/geos-build.patch") - -list(APPEND CMAKE_MODULE_PATH "${geos_SOURCE_DIR}/cmake") -set(BUILD_SHARED_LIBS ${VELOX_BUILD_SHARED}) -set(CMAKE_BUILD_TYPE Release) -set(PREVIOUS_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS}) -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-nonnull ") - -# This option defaults to on and adds warning flags that fail the build. -set(GEOS_BUILD_DEVELOPER OFF) - -if("${CMAKE_CXX_COMPILER_ID}" MATCHES "GNU") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-dangling-pointer") -endif() - -FetchContent_MakeAvailable(geos) - -add_library(GEOS::geos ALIAS geos) - -unset(BUILD_SHARED_LIBS) -set(CMAKE_CXX_FLAGS ${PREVIOUS_CMAKE_CXX_FLAGS}) -set(CMAKE_BUILD_TYPE ${PREVIOUS_BUILD_TYPE}) +# This creates a separate scope so any changed variables don't affect +# the rest of the build. +block() + set(VELOX_GEOS_BUILD_VERSION 3.10.7) + set( + VELOX_GEOS_BUILD_SHA256_CHECKSUM + 8b2ab4d04d660e27f2006550798f49dd11748c3767455cae9f71967dc437da1f + ) + string( + CONCAT + VELOX_GEOS_SOURCE_URL + "https://download.osgeo.org/geos/" + "geos-${VELOX_GEOS_BUILD_VERSION}.tar.bz2" + ) + + velox_resolve_dependency_url(GEOS) + + FetchContent_Declare( + geos + URL ${VELOX_GEOS_SOURCE_URL} + URL_HASH ${VELOX_GEOS_BUILD_SHA256_CHECKSUM} + PATCH_COMMAND git apply "${CMAKE_CURRENT_LIST_DIR}/geos/geos-cmakelists.patch" + OVERRIDE_FIND_PACKAGE + SYSTEM + EXCLUDE_FROM_ALL + ) + + list(APPEND CMAKE_MODULE_PATH "${geos_SOURCE_DIR}/cmake") + set(BUILD_SHARED_LIBS OFF) + set(BUILD_TESTING OFF) + set(CMAKE_BUILD_TYPE Release) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-nonnull ") + # This option defaults to on and adds warning flags that fail the build. + set(GEOS_BUILD_DEVELOPER OFF) + + if("${CMAKE_CXX_COMPILER_ID}" MATCHES "GNU") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-dangling-pointer") + endif() + + FetchContent_MakeAvailable(geos) + + add_library(GEOS::geos ALIAS geos) +endblock() diff --git a/CMake/resolve_dependency_modules/geos/geos-build.patch b/CMake/resolve_dependency_modules/geos/geos-build.patch deleted file mode 100644 index 4194a5f44c15..000000000000 --- a/CMake/resolve_dependency_modules/geos/geos-build.patch +++ /dev/null @@ -1,46 +0,0 @@ ---- a/capi/geos_ts_c.cpp -+++ b/capi/geos_ts_c.cpp -@@ -2168,7 +2168,7 @@ extern "C" { - const char* GEOSversion() - { - static char version[256]; -- sprintf(version, "%s", GEOS_CAPI_VERSION); -+ snprintf(version, sizeof(version), "%s", GEOS_CAPI_VERSION); - return version; - } - ---- a/include/geos/io/ByteOrderDataInStream.h -+++ b/include/geos/io/ByteOrderDataInStream.h -@@ -22,6 +22,7 @@ - - #include - #include -+#include - - //#include - //#include ---- a/tests/unit/math/DDTest.cpp -+++ b/tests/unit/math/DDTest.cpp -@@ -160,11 +160,9 @@ struct test_dd_data { - DD t2 = t*t; - DD at(0.0); - DD two(2.0); -- int k = 0; - DD d(1.0); - int sign = 1; - while (t.doubleValue() > eps) { -- k++; - if (sign < 0) - at = at - (t / d); - else -@@ -187,10 +185,8 @@ struct test_dd_data { - DD s(2.0); - DD t(1.0); - double n = 1.0; -- int i = 0; - while(t.doubleValue() > eps) - { -- i++; - n += 1.0; - t = t / DD(n); - s = s + t; diff --git a/CMake/resolve_dependency_modules/gflags.cmake b/CMake/resolve_dependency_modules/gflags.cmake index 7c7aa2854cba..f937a7f7087d 100644 --- a/CMake/resolve_dependency_modules/gflags.cmake +++ b/CMake/resolve_dependency_modules/gflags.cmake @@ -14,12 +14,16 @@ include_guard(GLOBAL) set(VELOX_GFLAGS_VERSION 2.2.2) -set(VELOX_GFLAGS_BUILD_SHA256_CHECKSUM - 34af2f15cf7367513b352bdcd2493ab14ce43692d2dcd9dfc499492966c64dcf) +set( + VELOX_GFLAGS_BUILD_SHA256_CHECKSUM + 34af2f15cf7367513b352bdcd2493ab14ce43692d2dcd9dfc499492966c64dcf +) string( - CONCAT VELOX_GFLAGS_SOURCE_URL - "https://github.com/gflags/gflags/archive/refs/tags/" - "v${VELOX_GFLAGS_VERSION}.tar.gz") + CONCAT + VELOX_GFLAGS_SOURCE_URL + "https://github.com/gflags/gflags/archive/refs/tags/" + "v${VELOX_GFLAGS_VERSION}.tar.gz" +) velox_resolve_dependency_url(GFLAGS) @@ -29,7 +33,10 @@ FetchContent_Declare( URL ${VELOX_GFLAGS_SOURCE_URL} URL_HASH ${VELOX_GFLAGS_BUILD_SHA256_CHECKSUM} PATCH_COMMAND git apply ${CMAKE_CURRENT_LIST_DIR}/gflags/gflags-config.patch - OVERRIDE_FIND_PACKAGE EXCLUDE_FROM_ALL SYSTEM) + OVERRIDE_FIND_PACKAGE + EXCLUDE_FROM_ALL + SYSTEM +) # glog relies on the old `google` namespace set(GFLAGS_NAMESPACE "google;gflags") @@ -52,9 +59,13 @@ FetchContent_MakeAvailable(gflags) # Workaround for https://github.com/gflags/gflags/issues/277 if(DEFINED CACHED_BUILD_SHARED_LIBS) - set(BUILD_SHARED_LIBS - ${CACHED_BUILD_SHARED_LIBS} - CACHE BOOL "Restored after setting up gflags" FORCE) + set( + BUILD_SHARED_LIBS + ${CACHED_BUILD_SHARED_LIBS} + CACHE BOOL + "Restored after setting up gflags" + FORCE + ) endif() # This causes find_package(gflags) in other dependencies to search in the build diff --git a/CMake/resolve_dependency_modules/glog.cmake b/CMake/resolve_dependency_modules/glog.cmake index 61dfe479e71b..28256bfa0838 100644 --- a/CMake/resolve_dependency_modules/glog.cmake +++ b/CMake/resolve_dependency_modules/glog.cmake @@ -14,10 +14,13 @@ include_guard(GLOBAL) set(VELOX_GLOG_VERSION 0.6.0) -set(VELOX_GLOG_BUILD_SHA256_CHECKSUM - 8a83bf982f37bb70825df71a9709fa90ea9f4447fb3c099e1d720a439d88bad6) -set(VELOX_GLOG_SOURCE_URL - "https://github.com/google/glog/archive/refs/tags/v${VELOX_GLOG_VERSION}.tar.gz" +set( + VELOX_GLOG_BUILD_SHA256_CHECKSUM + 8a83bf982f37bb70825df71a9709fa90ea9f4447fb3c099e1d720a439d88bad6 +) +set( + VELOX_GLOG_SOURCE_URL + "https://github.com/google/glog/archive/refs/tags/v${VELOX_GLOG_VERSION}.tar.gz" ) velox_resolve_dependency_url(GLOG) @@ -29,8 +32,11 @@ FetchContent_Declare( URL_HASH ${VELOX_GLOG_BUILD_SHA256_CHECKSUM} PATCH_COMMAND git apply ${CMAKE_CURRENT_LIST_DIR}/glog/glog-no-export.patch && git apply - ${CMAKE_CURRENT_LIST_DIR}/glog/glog-config.patch SYSTEM - OVERRIDE_FIND_PACKAGE EXCLUDE_FROM_ALL) + ${CMAKE_CURRENT_LIST_DIR}/glog/glog-config.patch + SYSTEM + OVERRIDE_FIND_PACKAGE + EXCLUDE_FROM_ALL +) set(BUILD_SHARED_LIBS ${VELOX_BUILD_SHARED}) set(WITH_UNWIND OFF) @@ -47,17 +53,13 @@ add_dependencies(glog gflags::gflags) # The default target has the glog-src as an include dir but this causes issues # with folly due to an internal glog 'demangle.h' being mistaken for a system -# header so we remove glog_SOURCE_DIR by overwriting -# INTERFACE_INCLUDE_DIRECTORIES -get_target_property( - _glog_target glog::glog ALIASED_TARGET) # Can't set properties on ALIAS - # targets -set_target_properties( - ${_glog_target} - PROPERTIES INTERFACE_INCLUDE_DIRECTORIES ${glog_BINARY_DIR}) +# header so we remove glog_SOURCE_DIR by overwriting INTERFACE_INCLUDE_DIRECTORIES + +# Can't set properties on ALIAS targets +get_target_property(_glog_target glog::glog ALIASED_TARGET) + +set_target_properties(${_glog_target} PROPERTIES INTERFACE_INCLUDE_DIRECTORIES ${glog_BINARY_DIR}) # These headers are missing from glog_BINARY_DIR -file(COPY ${glog_SOURCE_DIR}/src/glog/platform.h - DESTINATION ${glog_BINARY_DIR}/glog) -file(COPY ${glog_SOURCE_DIR}/src/glog/log_severity.h - DESTINATION ${glog_BINARY_DIR}/glog) +file(COPY ${glog_SOURCE_DIR}/src/glog/platform.h DESTINATION ${glog_BINARY_DIR}/glog) +file(COPY ${glog_SOURCE_DIR}/src/glog/log_severity.h DESTINATION ${glog_BINARY_DIR}/glog) diff --git a/CMake/resolve_dependency_modules/google_cloud_cpp_storage.cmake b/CMake/resolve_dependency_modules/google_cloud_cpp_storage.cmake index 894cb7a216ed..403b252fa58c 100644 --- a/CMake/resolve_dependency_modules/google_cloud_cpp_storage.cmake +++ b/CMake/resolve_dependency_modules/google_cloud_cpp_storage.cmake @@ -17,12 +17,16 @@ velox_set_source(gRPC) velox_resolve_dependency(gRPC CONFIG 1.48.1 REQUIRED) set(VELOX_GOOGLE_CLOUD_CPP_BUILD_VERSION 2.22.0) -set(VELOX_GOOGLE_CLOUD_CPP_BUILD_SHA256_CHECKSUM - 0c68782e57959c82e0c81def805c01460a042c1aae0c2feee905acaa2a2dc9bf) +set( + VELOX_GOOGLE_CLOUD_CPP_BUILD_SHA256_CHECKSUM + 0c68782e57959c82e0c81def805c01460a042c1aae0c2feee905acaa2a2dc9bf +) string( - CONCAT VELOX_GOOGLE_CLOUD_CPP_SOURCE_URL - "https://github.com/googleapis/google-cloud-cpp/archive/refs/tags/" - "v${VELOX_GOOGLE_CLOUD_CPP_BUILD_VERSION}.tar.gz") + CONCAT + VELOX_GOOGLE_CLOUD_CPP_SOURCE_URL + "https://github.com/googleapis/google-cloud-cpp/archive/refs/tags/" + "v${VELOX_GOOGLE_CLOUD_CPP_BUILD_VERSION}.tar.gz" +) velox_resolve_dependency_url(GOOGLE_CLOUD_CPP) @@ -32,10 +36,11 @@ FetchContent_Declare( google_cloud_cpp URL ${VELOX_GOOGLE_CLOUD_CPP_SOURCE_URL} URL_HASH ${VELOX_GOOGLE_CLOUD_CPP_BUILD_SHA256_CHECKSUM} - OVERRIDE_FIND_PACKAGE EXCLUDE_FROM_ALL SYSTEM) + OVERRIDE_FIND_PACKAGE + EXCLUDE_FROM_ALL + SYSTEM +) set(GOOGLE_CLOUD_CPP_ENABLE_EXAMPLES OFF) -set(GOOGLE_CLOUD_CPP_ENABLE - "storage" - CACHE STRING "The list of libraries to build.") +set(GOOGLE_CLOUD_CPP_ENABLE "storage" CACHE STRING "The list of libraries to build.") FetchContent_MakeAvailable(google_cloud_cpp) diff --git a/CMake/resolve_dependency_modules/grpc.cmake b/CMake/resolve_dependency_modules/grpc.cmake index 7ee7dd2d9775..5b762077368a 100644 --- a/CMake/resolve_dependency_modules/grpc.cmake +++ b/CMake/resolve_dependency_modules/grpc.cmake @@ -17,12 +17,16 @@ velox_set_source(absl) velox_resolve_dependency(absl CONFIG REQUIRED) set(VELOX_GRPC_BUILD_VERSION 1.48.1) -set(VELOX_GRPC_BUILD_SHA256_CHECKSUM - 320366665d19027cda87b2368c03939006a37e0388bfd1091c8d2a96fbc93bd8) +set( + VELOX_GRPC_BUILD_SHA256_CHECKSUM + 320366665d19027cda87b2368c03939006a37e0388bfd1091c8d2a96fbc93bd8 +) string( - CONCAT VELOX_GRPC_SOURCE_URL - "https://github.com/grpc/grpc/archive/refs/tags/" - "v${VELOX_GRPC_BUILD_VERSION}.tar.gz") + CONCAT + VELOX_GRPC_SOURCE_URL + "https://github.com/grpc/grpc/archive/refs/tags/" + "v${VELOX_GRPC_BUILD_VERSION}.tar.gz" +) velox_resolve_dependency_url(GRPC) @@ -32,32 +36,20 @@ FetchContent_Declare( gRPC URL ${VELOX_GRPC_SOURCE_URL} URL_HASH ${VELOX_GRPC_BUILD_SHA256_CHECKSUM} - OVERRIDE_FIND_PACKAGE EXCLUDE_FROM_ALL) + OVERRIDE_FIND_PACKAGE + EXCLUDE_FROM_ALL +) # We need to specify CACHE explicitly even when we have # set(CMAKE_POLICY_DEFAULT_CMP0077 NEW). Because gRPC doesn't use option(). gRPC # uses set(... CACHE). So CMP0077 isn't affected. -set(gRPC_ABSL_PROVIDER - "package" - CACHE STRING "Provider of absl library") -set(gRPC_ZLIB_PROVIDER - "package" - CACHE STRING "Provider of zlib library") -set(gRPC_CARES_PROVIDER - "package" - CACHE STRING "Provider of c-ares library") -set(gRPC_RE2_PROVIDER - "package" - CACHE STRING "Provider of re2 library") -set(gRPC_SSL_PROVIDER - "package" - CACHE STRING "Provider of ssl library") -set(gRPC_PROTOBUF_PROVIDER - "package" - CACHE STRING "Provider of protobuf library") -set(gRPC_INSTALL - ON - CACHE BOOL "Generate installation target") +set(gRPC_ABSL_PROVIDER "package" CACHE STRING "Provider of absl library") +set(gRPC_ZLIB_PROVIDER "package" CACHE STRING "Provider of zlib library") +set(gRPC_CARES_PROVIDER "package" CACHE STRING "Provider of c-ares library") +set(gRPC_RE2_PROVIDER "package" CACHE STRING "Provider of re2 library") +set(gRPC_SSL_PROVIDER "package" CACHE STRING "Provider of ssl library") +set(gRPC_PROTOBUF_PROVIDER "package" CACHE STRING "Provider of protobuf library") +set(gRPC_INSTALL ON CACHE BOOL "Generate installation target") FetchContent_MakeAvailable(gRPC) add_library(gRPC::grpc ALIAS grpc) add_library(gRPC::grpc++ ALIAS grpc++) diff --git a/CMake/resolve_dependency_modules/gtest.cmake b/CMake/resolve_dependency_modules/gtest.cmake index f1c892bb4c43..09ea9611aa05 100644 --- a/CMake/resolve_dependency_modules/gtest.cmake +++ b/CMake/resolve_dependency_modules/gtest.cmake @@ -14,10 +14,13 @@ include_guard(GLOBAL) set(VELOX_GTEST_VERSION 1.13.0) -set(VELOX_GTEST_BUILD_SHA256_CHECKSUM - ad7fdba11ea011c1d925b3289cf4af2c66a352e18d4c7264392fead75e919363) -set(VELOX_GTEST_SOURCE_URL - "https://github.com/google/googletest/archive/refs/tags/v${VELOX_GTEST_VERSION}.tar.gz" +set( + VELOX_GTEST_BUILD_SHA256_CHECKSUM + ad7fdba11ea011c1d925b3289cf4af2c66a352e18d4c7264392fead75e919363 +) +set( + VELOX_GTEST_SOURCE_URL + "https://github.com/google/googletest/archive/refs/tags/v${VELOX_GTEST_VERSION}.tar.gz" ) velox_resolve_dependency_url(GTEST) @@ -27,7 +30,10 @@ FetchContent_Declare( googletest URL ${VELOX_GTEST_SOURCE_URL} URL_HASH ${VELOX_GTEST_BUILD_SHA256_CHECKSUM} - OVERRIDE_FIND_PACKAGE SYSTEM EXCLUDE_FROM_ALL) + OVERRIDE_FIND_PACKAGE + SYSTEM + EXCLUDE_FROM_ALL +) FetchContent_MakeAvailable(googletest) diff --git a/CMake/resolve_dependency_modules/icu.cmake b/CMake/resolve_dependency_modules/icu.cmake index 59f7f4eb0835..8f83c681e37e 100644 --- a/CMake/resolve_dependency_modules/icu.cmake +++ b/CMake/resolve_dependency_modules/icu.cmake @@ -14,13 +14,17 @@ include_guard(GLOBAL) set(VELOX_ICU4C_BUILD_VERSION 72) -set(VELOX_ICU4C_BUILD_SHA256_CHECKSUM - a2d2d38217092a7ed56635e34467f92f976b370e20182ad325edea6681a71d68) +set( + VELOX_ICU4C_BUILD_SHA256_CHECKSUM + a2d2d38217092a7ed56635e34467f92f976b370e20182ad325edea6681a71d68 +) string( - CONCAT VELOX_ICU4C_SOURCE_URL - "https://github.com/unicode-org/icu/releases/download/" - "release-${VELOX_ICU4C_BUILD_VERSION}-1/" - "icu4c-${VELOX_ICU4C_BUILD_VERSION}_1-src.tgz") + CONCAT + VELOX_ICU4C_SOURCE_URL + "https://github.com/unicode-org/icu/releases/download/" + "release-${VELOX_ICU4C_BUILD_VERSION}-1/" + "icu4c-${VELOX_ICU4C_BUILD_VERSION}_1-src.tgz" +) velox_resolve_dependency_url(ICU4C) @@ -31,15 +35,17 @@ velox_set_with_default(NUM_JOBS NUM_THREADS ${NUM_JOBS}) find_program(MAKE_PROGRAM make REQUIRED) set(ICU_CFG --disable-tests --disable-samples) -set(HOST_ENV_CMAKE - ${CMAKE_COMMAND} - -E - env - CC="${CMAKE_C_COMPILER}" - CXX="${CMAKE_CXX_COMPILER}" - CFLAGS="${CMAKE_C_FLAGS}" - CXXFLAGS="${CMAKE_CXX_FLAGS} -w" - LDFLAGS="${CMAKE_MODULE_LINKER_FLAGS}") +set( + HOST_ENV_CMAKE + ${CMAKE_COMMAND} + -E + env + CC="${CMAKE_C_COMPILER}" + CXX="${CMAKE_CXX_COMPILER}" + CFLAGS="${CMAKE_C_FLAGS}" + CXXFLAGS="${CMAKE_CXX_FLAGS} -w" + LDFLAGS="${CMAKE_MODULE_LINKER_FLAGS}" +) set(ICU_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps/icu) set(ICU_INCLUDE_DIRS ${ICU_DIR}/include) set(ICU_LIBRARIES ${ICU_DIR}/lib) @@ -51,10 +57,11 @@ ExternalProject_Add( URL_HASH ${VELOX_ICU4C_BUILD_SHA256_CHECKSUM} SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/icu-src BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/icu-bld - CONFIGURE_COMMAND /source/configure --prefix=${ICU_DIR} - --libdir=${ICU_LIBRARIES} ${ICU_CFG} + CONFIGURE_COMMAND + /source/configure --prefix=${ICU_DIR} --libdir=${ICU_LIBRARIES} ${ICU_CFG} BUILD_COMMAND ${MAKE_PROGRAM} -j ${NUM_JOBS} - INSTALL_COMMAND ${HOST_ENV_CMAKE} ${MAKE_PROGRAM} install) + INSTALL_COMMAND ${HOST_ENV_CMAKE} ${MAKE_PROGRAM} install +) add_library(ICU::ICU UNKNOWN IMPORTED) add_dependencies(ICU::ICU ICU-build) @@ -65,26 +72,25 @@ file(MAKE_DIRECTORY ${ICU_INCLUDE_DIRS}) file(MAKE_DIRECTORY ${ICU_LIBRARIES}) # Create a target for each component -set(icu_components - data - i18n - io - uc - tu) +set( + icu_components + data + i18n + io + uc + tu +) foreach(component ${icu_components}) add_library(ICU::${component} SHARED IMPORTED) - string( - CONCAT ICU_${component}_LIBRARY - ${ICU_LIBRARIES} - "/libicu" - ${component} - ".so") + string(CONCAT ICU_${component}_LIBRARY ${ICU_LIBRARIES} "/libicu" ${component} ".so") file(TOUCH ${ICU_${component}_LIBRARY}) set_target_properties( ICU::${component} - PROPERTIES IMPORTED_LOCATION ${ICU_${component}_LIBRARY} - INTERFACE_SYSTEM_INCLUDE_DIRECTORIES ${ICU_INCLUDE_DIRS}) + PROPERTIES + IMPORTED_LOCATION ${ICU_${component}_LIBRARY} + INTERFACE_SYSTEM_INCLUDE_DIRECTORIES ${ICU_INCLUDE_DIRS} + ) target_link_libraries(ICU::ICU INTERFACE ICU::${component}) endforeach() diff --git a/CMake/resolve_dependency_modules/icu/FindICU.cmake b/CMake/resolve_dependency_modules/icu/FindICU.cmake index 15413ecc415d..bfcd265b2dc7 100644 --- a/CMake/resolve_dependency_modules/icu/FindICU.cmake +++ b/CMake/resolve_dependency_modules/icu/FindICU.cmake @@ -13,12 +13,14 @@ # limitations under the License. message("Using ICU - Bundled") set(ICU_FOUND TRUE) -set(icu_components - data - i18n - io - uc - tu) +set( + icu_components + data + i18n + io + uc + tu +) foreach(component icu_components) set(ICU_${component}_FOUND TRUE) endforeach() diff --git a/CMake/resolve_dependency_modules/log_surgeon.cmake b/CMake/resolve_dependency_modules/log_surgeon.cmake index f6d206666a70..2f549adeff22 100644 --- a/CMake/resolve_dependency_modules/log_surgeon.cmake +++ b/CMake/resolve_dependency_modules/log_surgeon.cmake @@ -16,8 +16,9 @@ include_guard(GLOBAL) FetchContent_Declare( log_surgeon GIT_REPOSITORY https://github.com/y-scope/log-surgeon.git - GIT_TAG 85d4f2c09c0e55f1fb87cdc8b0f4d13fb1a733e1 - OVERRIDE_FIND_PACKAGE) + GIT_TAG 193e1f91eb137bb935a7f44b13cc8dd945a8d742 + OVERRIDE_FIND_PACKAGE +) set(log_surgeon_BUILD_TESTING OFF) FetchContent_MakeAvailable(log_surgeon) diff --git a/CMake/resolve_dependency_modules/microsoft_gsl.cmake b/CMake/resolve_dependency_modules/microsoft_gsl.cmake index 5fc1731231e3..808b4ad2d456 100644 --- a/CMake/resolve_dependency_modules/microsoft_gsl.cmake +++ b/CMake/resolve_dependency_modules/microsoft_gsl.cmake @@ -15,10 +15,13 @@ include_guard(GLOBAL) # Version you want to build set(VELOX_GSL_BUILD_VERSION 4.0.0) -set(VELOX_GSL_BUILD_SHA256_CHECKSUM - f0e32cb10654fea91ad56bde89170d78cfbf4363ee0b01d8f097de2ba49f6ce9) -set(VELOX_GSL_SOURCE_URL - "https://github.com/microsoft/GSL/archive/refs/tags/v${VELOX_GSL_BUILD_VERSION}.tar.gz" +set( + VELOX_GSL_BUILD_SHA256_CHECKSUM + f0e32cb10654fea91ad56bde89170d78cfbf4363ee0b01d8f097de2ba49f6ce9 +) +set( + VELOX_GSL_SOURCE_URL + "https://github.com/microsoft/GSL/archive/refs/tags/v${VELOX_GSL_BUILD_VERSION}.tar.gz" ) velox_resolve_dependency_url(GSL) @@ -29,6 +32,9 @@ FetchContent_Declare( Microsoft.GSL URL ${VELOX_GSL_SOURCE_URL} URL_HASH ${VELOX_GSL_BUILD_SHA256_CHECKSUM} - OVERRIDE_FIND_PACKAGE EXCLUDE_FROM_ALL SYSTEM) + OVERRIDE_FIND_PACKAGE + EXCLUDE_FROM_ALL + SYSTEM +) FetchContent_MakeAvailable(Microsoft.GSL) diff --git a/CMake/resolve_dependency_modules/msgpack-cxx.cmake b/CMake/resolve_dependency_modules/msgpack-cxx.cmake index 9f20800b74e3..3e1215d761db 100644 --- a/CMake/resolve_dependency_modules/msgpack-cxx.cmake +++ b/CMake/resolve_dependency_modules/msgpack-cxx.cmake @@ -14,10 +14,13 @@ include_guard(GLOBAL) set(VELOX_MSGPACK_BUILD_VERSION cpp-7.0.0) -set(VELOX_MSGPACK_BUILD_SHA256_CHECKSUM - 070881ebea9208cf7e731fd5a46a11404025b2f260ab9527e32dfcb7c689fbfc) -set(VELOX_MSGPACK_SOURCE_URL - "https://github.com/msgpack/msgpack-c/archive/refs/tags/${VELOX_MSGPACK_BUILD_VERSION}.tar.gz" +set( + VELOX_MSGPACK_BUILD_SHA256_CHECKSUM + 070881ebea9208cf7e731fd5a46a11404025b2f260ab9527e32dfcb7c689fbfc +) +set( + VELOX_MSGPACK_SOURCE_URL + "https://github.com/msgpack/msgpack-c/archive/refs/tags/${VELOX_MSGPACK_BUILD_VERSION}.tar.gz" ) velox_resolve_dependency_url(MSGPACK) @@ -28,7 +31,10 @@ FetchContent_Declare( msgpack-cxx URL ${VELOX_MSGPACK_SOURCE_URL} URL_HASH ${VELOX_MSGPACK_BUILD_SHA256_CHECKSUM} - OVERRIDE_FIND_PACKAGE EXCLUDE_FROM_ALL SYSTEM) + OVERRIDE_FIND_PACKAGE + EXCLUDE_FROM_ALL + SYSTEM +) set(MSGPACK_USE_BOOST OFF) diff --git a/CMake/resolve_dependency_modules/nlohmann_json.cmake b/CMake/resolve_dependency_modules/nlohmann_json.cmake index 1b0298f6618e..0405de452a6e 100644 --- a/CMake/resolve_dependency_modules/nlohmann_json.cmake +++ b/CMake/resolve_dependency_modules/nlohmann_json.cmake @@ -14,10 +14,13 @@ include_guard(GLOBAL) set(VELOX_NLOHMANN_JSON_BUILD_VERSION 3.11.3) -set(VELOX_NLOHMANN_JSON_BUILD_SHA256_CHECKSUM - 0d8ef5af7f9794e3263480193c491549b2ba6cc74bb018906202ada498a79406) -set(VELOX_NLOHMANN_JSON_SOURCE_URL - "https://github.com/nlohmann/json/archive/refs/tags/v${VELOX_NLOHMANN_JSON_BUILD_VERSION}.tar.gz" +set( + VELOX_NLOHMANN_JSON_BUILD_SHA256_CHECKSUM + 0d8ef5af7f9794e3263480193c491549b2ba6cc74bb018906202ada498a79406 +) +set( + VELOX_NLOHMANN_JSON_SOURCE_URL + "https://github.com/nlohmann/json/archive/refs/tags/v${VELOX_NLOHMANN_JSON_BUILD_VERSION}.tar.gz" ) velox_resolve_dependency_url(NLOHMANN_JSON) @@ -28,10 +31,9 @@ FetchContent_Declare( nlohmann_json URL ${VELOX_NLOHMANN_JSON_SOURCE_URL} URL_HASH ${VELOX_NLOHMANN_JSON_BUILD_SHA256_CHECKSUM} - OVERRIDE_FIND_PACKAGE) + OVERRIDE_FIND_PACKAGE +) -set(JSON_BuildTests - OFF - CACHE INTERNAL "") +set(JSON_BuildTests OFF CACHE INTERNAL "") FetchContent_MakeAvailable(nlohmann_json) diff --git a/CMake/resolve_dependency_modules/protobuf.cmake b/CMake/resolve_dependency_modules/protobuf.cmake index 3193eedac934..a7d54da6554b 100644 --- a/CMake/resolve_dependency_modules/protobuf.cmake +++ b/CMake/resolve_dependency_modules/protobuf.cmake @@ -14,22 +14,26 @@ include_guard(GLOBAL) set(VELOX_PROTOBUF_BUILD_VERSION 21.8) -set(VELOX_PROTOBUF_BUILD_SHA256_CHECKSUM - 83ad4faf95ff9cbece7cb9c56eb3ca9e42c3497b77001840ab616982c6269fb6) +set( + VELOX_PROTOBUF_BUILD_SHA256_CHECKSUM + 83ad4faf95ff9cbece7cb9c56eb3ca9e42c3497b77001840ab616982c6269fb6 +) if(${VELOX_PROTOBUF_BUILD_VERSION} LESS 22.0) string( CONCAT - VELOX_PROTOBUF_SOURCE_URL - "https://github.com/protocolbuffers/protobuf/releases/download/" - "v${VELOX_PROTOBUF_BUILD_VERSION}/protobuf-all-${VELOX_PROTOBUF_BUILD_VERSION}.tar.gz" + VELOX_PROTOBUF_SOURCE_URL + "https://github.com/protocolbuffers/protobuf/releases/download/" + "v${VELOX_PROTOBUF_BUILD_VERSION}/protobuf-all-${VELOX_PROTOBUF_BUILD_VERSION}.tar.gz" ) else() velox_set_source(absl) velox_resolve_dependency(absl CONFIG REQUIRED) string( - CONCAT VELOX_PROTOBUF_SOURCE_URL - "https://github.com/protocolbuffers/protobuf/archive/" - "v${VELOX_PROTOBUF_BUILD_VERSION}.tar.gz") + CONCAT + VELOX_PROTOBUF_SOURCE_URL + "https://github.com/protocolbuffers/protobuf/archive/" + "v${VELOX_PROTOBUF_BUILD_VERSION}.tar.gz" + ) endif() velox_resolve_dependency_url(PROTOBUF) @@ -40,7 +44,10 @@ FetchContent_Declare( protobuf URL ${VELOX_PROTOBUF_SOURCE_URL} URL_HASH ${VELOX_PROTOBUF_BUILD_SHA256_CHECKSUM} - OVERRIDE_FIND_PACKAGE EXCLUDE_FROM_ALL SYSTEM) + OVERRIDE_FIND_PACKAGE + EXCLUDE_FROM_ALL + SYSTEM +) set(protobuf_BUILD_TESTS OFF) set(protobuf_ABSL_PROVIDER "package") diff --git a/CMake/resolve_dependency_modules/pybind11.cmake b/CMake/resolve_dependency_modules/pybind11.cmake index d9577c9d9ee3..16c1b10ea9cc 100644 --- a/CMake/resolve_dependency_modules/pybind11.cmake +++ b/CMake/resolve_dependency_modules/pybind11.cmake @@ -14,12 +14,16 @@ include_guard(GLOBAL) set(VELOX_PYBIND11_BUILD_VERSION 2.10.0) -set(VELOX_PYBIND11_BUILD_SHA256_CHECKSUM - eacf582fa8f696227988d08cfc46121770823839fe9e301a20fbce67e7cd70ec) +set( + VELOX_PYBIND11_BUILD_SHA256_CHECKSUM + eacf582fa8f696227988d08cfc46121770823839fe9e301a20fbce67e7cd70ec +) string( - CONCAT VELOX_PYBIND11_SOURCE_URL - "https://github.com/pybind/pybind11/archive/refs/tags/" - "v${VELOX_PYBIND11_BUILD_VERSION}.tar.gz") + CONCAT + VELOX_PYBIND11_SOURCE_URL + "https://github.com/pybind/pybind11/archive/refs/tags/" + "v${VELOX_PYBIND11_BUILD_VERSION}.tar.gz" +) velox_resolve_dependency_url(PYBIND11) @@ -28,6 +32,7 @@ message(STATUS "Building Pybind11 from source") FetchContent_Declare( pybind11 URL ${VELOX_PYBIND11_SOURCE_URL} - URL_HASH ${VELOX_PYBIND11_BUILD_SHA256_CHECKSUM}) + URL_HASH ${VELOX_PYBIND11_BUILD_SHA256_CHECKSUM} +) FetchContent_MakeAvailable(pybind11) diff --git a/CMake/resolve_dependency_modules/re2.cmake b/CMake/resolve_dependency_modules/re2.cmake index 2caaf20d3162..738e0564b7cb 100644 --- a/CMake/resolve_dependency_modules/re2.cmake +++ b/CMake/resolve_dependency_modules/re2.cmake @@ -17,18 +17,22 @@ if(DEFINED ENV{VELOX_RE2_URL}) set(VELOX_RE2_SOURCE_URL "$ENV{VELOX_RE2_URL}") else() set(VELOX_RE2_VERSION 2024-07-02) - set(VELOX_RE2_SOURCE_URL - "https://github.com/google/re2/archive/refs/tags/${VELOX_RE2_VERSION}.tar.gz" + set( + VELOX_RE2_SOURCE_URL + "https://github.com/google/re2/archive/refs/tags/${VELOX_RE2_VERSION}.tar.gz" + ) + set( + VELOX_RE2_BUILD_SHA256_CHECKSUM + eb2df807c781601c14a260a507a5bb4509be1ee626024cb45acbd57cb9d4032b ) - set(VELOX_RE2_BUILD_SHA256_CHECKSUM - eb2df807c781601c14a260a507a5bb4509be1ee626024cb45acbd57cb9d4032b) endif() message(STATUS "Building re2 from source") FetchContent_Declare( re2 URL ${VELOX_RE2_SOURCE_URL} - URL_HASH SHA256=${VELOX_RE2_BUILD_SHA256_CHECKSUM}) + URL_HASH SHA256=${VELOX_RE2_BUILD_SHA256_CHECKSUM} +) set(RE2_USE_ICU ON) set(RE2_BUILD_TESTING OFF) @@ -40,9 +44,7 @@ velox_resolve_dependency(absl) FetchContent_MakeAvailable(re2) if("${absl_SOURCE}" STREQUAL "SYSTEM") if(DEFINED absl_VERSION AND "${absl_VERSION}" VERSION_LESS "20240116") - message( - FATAL_ERROR - "Abseil 20240116 or later is required for bundled RE2: ${absl_VERSION}") + message(FATAL_ERROR "Abseil 20240116 or later is required for bundled RE2: ${absl_VERSION}") endif() elseif("${absl_SOURCE}" STREQUAL "BUNDLED") # Build RE2 after Abseil so the files are available @@ -58,3 +60,5 @@ set(re2_INCLUDE_DIRS ${re2_SOURCE_DIR}) set(RE2_ROOT ${re2_BINARY_DIR}) set(re2_ROOT ${re2_BINARY_DIR}) + +unset(BUILD_TESTING CACHE) diff --git a/CMake/resolve_dependency_modules/simdjson.cmake b/CMake/resolve_dependency_modules/simdjson.cmake index 5d28789747ee..59894896c8bf 100644 --- a/CMake/resolve_dependency_modules/simdjson.cmake +++ b/CMake/resolve_dependency_modules/simdjson.cmake @@ -13,11 +13,14 @@ # limitations under the License. include_guard(GLOBAL) -set(VELOX_SIMDJSON_VERSION 3.9.3) -set(VELOX_SIMDJSON_BUILD_SHA256_CHECKSUM - 2e3d10abcde543d3dd8eba9297522cafdcebdd1db4f51b28f3bc95bf1d6ad23c) -set(VELOX_SIMDJSON_SOURCE_URL - "https://github.com/simdjson/simdjson/archive/refs/tags/v${VELOX_SIMDJSON_VERSION}.tar.gz" +set(VELOX_SIMDJSON_VERSION 4.1.0) +set( + VELOX_SIMDJSON_BUILD_SHA256_CHECKSUM + 78115e37b2e88ec63e6ae20bb148063a9112c55bcd71404c8572078fd8a6ac3e +) +set( + VELOX_SIMDJSON_SOURCE_URL + "https://github.com/simdjson/simdjson/archive/refs/tags/v${VELOX_SIMDJSON_VERSION}.tar.gz" ) velox_resolve_dependency_url(SIMDJSON) @@ -28,7 +31,7 @@ FetchContent_Declare( simdjson URL ${VELOX_SIMDJSON_SOURCE_URL} URL_HASH ${VELOX_SIMDJSON_BUILD_SHA256_CHECKSUM} - OVERRIDE_FIND_PACKAGE) +) if(${VELOX_SIMDJSON_SKIPUTF8VALIDATION}) set(SIMDJSON_SKIPUTF8VALIDATION ON) diff --git a/CMake/resolve_dependency_modules/spdlog.cmake b/CMake/resolve_dependency_modules/spdlog.cmake index a7a1ddfe6988..68c54dc41a0e 100644 --- a/CMake/resolve_dependency_modules/spdlog.cmake +++ b/CMake/resolve_dependency_modules/spdlog.cmake @@ -13,11 +13,14 @@ # limitations under the License. include_guard(GLOBAL) -set(VELOX_SPDLOG_BUILD_VERSION 1.12.0) -set(VELOX_SPDLOG_BUILD_SHA256_CHECKSUM - 4dccf2d10f410c1e2feaff89966bfc49a1abb29ef6f08246335b110e001e09a9) -set(VELOX_SPDLOG_SOURCE_URL - "https://github.com/gabime/spdlog/archive/refs/tags/v${VELOX_SPDLOG_BUILD_VERSION}.tar.gz" +set(VELOX_SPDLOG_BUILD_VERSION 1.15.3) +set( + VELOX_SPDLOG_BUILD_SHA256_CHECKSUM + 15a04e69c222eb6c01094b5c7ff8a249b36bb22788d72519646fb85feb267e67 +) +set( + VELOX_SPDLOG_SOURCE_URL + "https://github.com/gabime/spdlog/archive/refs/tags/v${VELOX_SPDLOG_BUILD_VERSION}.tar.gz" ) velox_resolve_dependency_url(SPDLOG) @@ -28,7 +31,10 @@ FetchContent_Declare( spdlog URL ${VELOX_SPDLOG_SOURCE_URL} URL_HASH ${VELOX_SPDLOG_BUILD_SHA256_CHECKSUM} - OVERRIDE_FIND_PACKAGE EXCLUDE_FROM_ALL SYSTEM) + OVERRIDE_FIND_PACKAGE + EXCLUDE_FROM_ALL + SYSTEM +) set(SPDLOG_FMT_EXTERNAL ON) FetchContent_MakeAvailable(spdlog) diff --git a/CMake/resolve_dependency_modules/stemmer.cmake b/CMake/resolve_dependency_modules/stemmer.cmake index a3ab123f1659..17686ea2832a 100644 --- a/CMake/resolve_dependency_modules/stemmer.cmake +++ b/CMake/resolve_dependency_modules/stemmer.cmake @@ -14,10 +14,13 @@ include_guard(GLOBAL) set(VELOX_STEMMER_VERSION 2.2.0) -set(VELOX_STEMMER_BUILD_SHA256_CHECKSUM - b941d9fe9cf36b4e2f8d3873cd4d8b8775bd94867a1df8d8c001bb8b688377c3) -set(VELOX_STEMMER_SOURCE_URL - "https://snowballstem.org/dist/libstemmer_c-${VELOX_STEMMER_VERSION}.tar.gz" +set( + VELOX_STEMMER_BUILD_SHA256_CHECKSUM + b941d9fe9cf36b4e2f8d3873cd4d8b8775bd94867a1df8d8c001bb8b688377c3 +) +set( + VELOX_STEMMER_SOURCE_URL + "https://snowballstem.org/dist/libstemmer_c-${VELOX_STEMMER_VERSION}.tar.gz" ) velox_resolve_dependency_url(STEMMER) @@ -51,7 +54,8 @@ set_target_properties( stemmer PROPERTIES IMPORTED_LOCATION - ${STEMMER_PREFIX}/src/libstemmer/${CMAKE_STATIC_LIBRARY_PREFIX}stemmer${CMAKE_STATIC_LIBRARY_SUFFIX} - INTERFACE_INCLUDE_DIRECTORIES ${STEMMER_INCLUDE_PATH}) + ${STEMMER_PREFIX}/src/libstemmer/${CMAKE_STATIC_LIBRARY_PREFIX}stemmer${CMAKE_STATIC_LIBRARY_SUFFIX} + INTERFACE_INCLUDE_DIRECTORIES ${STEMMER_INCLUDE_PATH} +) add_dependencies(stemmer libstemmer) diff --git a/CMake/resolve_dependency_modules/template.cmake b/CMake/resolve_dependency_modules/template.cmake index 1db752bac7f0..5455efb8781b 100644 --- a/CMake/resolve_dependency_modules/template.cmake +++ b/CMake/resolve_dependency_modules/template.cmake @@ -16,8 +16,8 @@ include_guard(GLOBAL) set(VELOX__VERSION x.y.z) # release artifacts are tough (except the auto generated ones) set(VELOX__BUILD_SHA256_CHECKSUM 123) -set(VELOX__SOURCE_URL "") # ideally don't use github archive links as - # they are not guranteed to be hash stable +# ideally don't use github archive links as they are not guaranteed to be hash stable +set(VELOX__SOURCE_URL "") velox_resolve_dependency_url() @@ -25,6 +25,7 @@ message(STATUS "Building from source") FetchContent_Declare( URL ${VELOX__SOURCE_URL} - URL_HASH ${VELOX__BUILD_SHA256_CHECKSUM}) + URL_HASH ${VELOX__BUILD_SHA256_CHECKSUM} +) FetchContent_MakeAvailable() diff --git a/CMake/resolve_dependency_modules/xsimd.cmake b/CMake/resolve_dependency_modules/xsimd.cmake index 56f1bc2aef97..7cbd4258ab25 100644 --- a/CMake/resolve_dependency_modules/xsimd.cmake +++ b/CMake/resolve_dependency_modules/xsimd.cmake @@ -14,10 +14,13 @@ include_guard(GLOBAL) set(VELOX_XSIMD_VERSION 10.0.0) -set(VELOX_XSIMD_BUILD_SHA256_CHECKSUM - 73f818368b3a4dad92fab1b2933d93694241bd2365a6181747b2df1768f6afdd) -set(VELOX_XSIMD_SOURCE_URL - "https://github.com/xtensor-stack/xsimd/archive/refs/tags/${VELOX_XSIMD_VERSION}.tar.gz" +set( + VELOX_XSIMD_BUILD_SHA256_CHECKSUM + 73f818368b3a4dad92fab1b2933d93694241bd2365a6181747b2df1768f6afdd +) +set( + VELOX_XSIMD_SOURCE_URL + "https://github.com/xtensor-stack/xsimd/archive/refs/tags/${VELOX_XSIMD_VERSION}.tar.gz" ) velox_resolve_dependency_url(XSIMD) @@ -26,6 +29,7 @@ message(STATUS "Building xsimd from source") FetchContent_Declare( xsimd URL ${VELOX_XSIMD_SOURCE_URL} - URL_HASH ${VELOX_XSIMD_BUILD_SHA256_CHECKSUM}) + URL_HASH ${VELOX_XSIMD_BUILD_SHA256_CHECKSUM} +) FetchContent_MakeAvailable(xsimd) diff --git a/CMake/resolve_dependency_modules/ystdlib.cmake b/CMake/resolve_dependency_modules/ystdlib.cmake index eca92f1d59e7..a6cb7b29d3f0 100644 --- a/CMake/resolve_dependency_modules/ystdlib.cmake +++ b/CMake/resolve_dependency_modules/ystdlib.cmake @@ -14,10 +14,13 @@ include_guard(GLOBAL) set(VELOX_YSTDLIB_BUILD_VERSION 9ed78cd) -set(VELOX_YSTDLIB_BUILD_SHA256_CHECKSUM - 65990dc2bcc4a355c2181bfe31a7800f492309d1bcd340f52a34e85047e61bc8) -set(VELOX_YSTDLIB_SOURCE_URL - "https://github.com/y-scope/ystdlib-cpp/archive/${VELOX_YSTDLIB_BUILD_VERSION}.tar.gz" +set( + VELOX_YSTDLIB_BUILD_SHA256_CHECKSUM + 65990dc2bcc4a355c2181bfe31a7800f492309d1bcd340f52a34e85047e61bc8 +) +set( + VELOX_YSTDLIB_SOURCE_URL + "https://github.com/y-scope/ystdlib-cpp/archive/${VELOX_YSTDLIB_BUILD_VERSION}.tar.gz" ) velox_resolve_dependency_url(YSTDLIB) @@ -28,7 +31,10 @@ FetchContent_Declare( ystdlib URL ${VELOX_YSTDLIB_SOURCE_URL} URL_HASH ${VELOX_YSTDLIB_BUILD_SHA256_CHECKSUM} - OVERRIDE_FIND_PACKAGE EXCLUDE_FROM_ALL SYSTEM) + OVERRIDE_FIND_PACKAGE + EXCLUDE_FROM_ALL + SYSTEM +) set(ystdlib_BUILD_TESTING OFF) diff --git a/CMakeLists.txt b/CMakeLists.txt index 8dbe5950d1b0..b2f0e2fc0101 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -13,7 +13,7 @@ # limitations under the License. cmake_minimum_required(VERSION 3.28) message(STATUS "Building using CMake version: ${CMAKE_VERSION}") -set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD 20) set(CMAKE_CXX_STANDARD_REQUIRED TRUE) set(CMAKE_CXX_EXTENSIONS ON) # Big Int is an extension message(STATUS "CXX standard: ${CMAKE_CXX_STANDARD}") @@ -36,15 +36,16 @@ project(velox) # If we are in an active conda env disable search in system paths and add env to # prefix path if(DEFINED ENV{CONDA_PREFIX}) - if(NOT DEFINED ENV{VELOX_DEPENDENCY_SOURCE} OR "$ENV{VELOX_DEPENDENCY_SOURCE}" - STREQUAL "CONDA") + if(NOT DEFINED ENV{VELOX_DEPENDENCY_SOURCE} OR "$ENV{VELOX_DEPENDENCY_SOURCE}" STREQUAL "CONDA") message(STATUS "Using Conda environment: $ENV{CONDA_PREFIX}") set(CMAKE_FIND_USE_SYSTEM_ENVIRONMENT_PATH FALSE) list(APPEND CMAKE_PREFIX_PATH "$ENV{CONDA_PREFIX}") # Override in case it was set to CONDA set(ENV{VELOX_DEPENDENCY_SOURCE} AUTO) - elseif(DEFINED ENV{VELOX_DEPENDENCY_SOURCE} - AND NOT "$ENV{VELOX_DEPENDENCY_SOURCE}" STREQUAL "CONDA") + elseif( + DEFINED ENV{VELOX_DEPENDENCY_SOURCE} + AND NOT "$ENV{VELOX_DEPENDENCY_SOURCE}" STREQUAL "CONDA" + ) message(STATUS "Overriding Conda environment: $ENV{CONDA_PREFIX}") endif() endif() @@ -54,37 +55,43 @@ if(DEFINED ENV{INSTALL_PREFIX}) list(APPEND CMAKE_PREFIX_PATH "$ENV{INSTALL_PREFIX}") endif() -list(PREPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/CMake" - "${PROJECT_SOURCE_DIR}/CMake/third-party") +list( + PREPEND + CMAKE_MODULE_PATH + "${PROJECT_SOURCE_DIR}/CMake" + "${PROJECT_SOURCE_DIR}/CMake/third-party" +) # Include our ThirdPartyToolchain dependencies macros include(ResolveDependency) include(VeloxUtils) include(CMakeDependentOption) -velox_set_with_default(VELOX_DEPENDENCY_SOURCE_DEFAULT VELOX_DEPENDENCY_SOURCE - AUTO) +velox_set_with_default(VELOX_DEPENDENCY_SOURCE_DEFAULT VELOX_DEPENDENCY_SOURCE AUTO) message(STATUS "Build type: ${CMAKE_BUILD_TYPE}") # Add all options below option( VELOX_BUILD_TESTING "Enable Velox tests. This will enable all other build options automatically." - ON) + ON +) option( VELOX_BUILD_MINIMAL "Build a minimal set of components only. This will override other build options." - OFF) + OFF +) option( VELOX_BUILD_MINIMAL_WITH_DWIO "Build a minimal set of components, including DWIO (file format readers/writers). This will override other build options." - OFF) + OFF +) option(VELOX_MONO_LIBRARY "Build single unified library." OFF) option(ENABLE_ALL_WARNINGS "Enable -Wall and -Wextra compiler warnings." ON) option(VELOX_BUILD_SHARED "Build Velox as shared libraries." OFF) -option(VELOX_SKIP_WAVE_BRANCH_KERNEL_TEST "Disable Wave branch kernel test." - OFF) +option(VELOX_BUILD_CMAKE_PACKAGE "Build CMake package for Velox." OFF) +option(VELOX_SKIP_WAVE_BRANCH_KERNEL_TEST "Disable Wave branch kernel test." OFF) # While it's possible to build both in one go we currently want to build either # static or shared. cmake_dependent_option( @@ -92,7 +99,8 @@ cmake_dependent_option( "Build Velox as static libraries." ON "NOT VELOX_BUILD_SHARED" - OFF) + OFF +) if(VELOX_BUILD_SHARED AND NOT VELOX_MONO_LIBRARY) # The large number of targets currently in use within Velox make a shared @@ -103,8 +111,9 @@ endif() if(VELOX_BUILD_SHARED) message( WARNING - "When building Velox as a shared library it's recommended to build against a shared build of folly to avoid issues with linking of gflags." - "This is currently NOT being enforced so user discretion is advised.") + "When building Velox as a shared library it's recommended to build against a shared build of folly to avoid issues with linking of gflags. " + "This is currently NOT being enforced so user discretion is advised." + ) endif() # option() always creates a BOOL variable so we have to use a normal cache @@ -113,30 +122,33 @@ endif() # * AUTO: Try SYSTEM first fall back to BUNDLED. # * SYSTEM: Use installed dependencies via find_package. # * BUNDLED: Build dependencies from source. -set(VELOX_DEPENDENCY_SOURCE - ${VELOX_DEPENDENCY_SOURCE_DEFAULT} - CACHE - STRING - "Default source for all dependencies with source builds enabled: AUTO SYSTEM BUNDLED." +set( + VELOX_DEPENDENCY_SOURCE + ${VELOX_DEPENDENCY_SOURCE_DEFAULT} + CACHE STRING + "Default source for all dependencies with source builds enabled: AUTO SYSTEM BUNDLED." ) -set(VELOX_GFLAGS_TYPE - "shared" - CACHE - STRING - "Specify whether to find the gflags package as a shared or static package" +set( + VELOX_GFLAGS_TYPE + "shared" + CACHE STRING + "Specify whether to find the gflags package as a shared or static package" ) option(VELOX_ENABLE_EXEC "Build exec." ON) option(VELOX_ENABLE_AGGREGATES "Build aggregates." ON) option(VELOX_ENABLE_CLP_CONNECTOR "Build CLP connector." ON) option(VELOX_ENABLE_HIVE_CONNECTOR "Build Hive connector." ON) option(VELOX_ENABLE_TPCH_CONNECTOR "Build TPC-H connector." ON) +option(VELOX_ENABLE_TPCDS_CONNECTOR "Build TPC-DS connector." ON) option(VELOX_ENABLE_PRESTO_FUNCTIONS "Build Presto SQL functions." ON) option(VELOX_ENABLE_SPARK_FUNCTIONS "Build Spark SQL functions." ON) +option(VELOX_ENABLE_ICEBERG_FUNCTIONS "Build Iceberg functions." ON) option(VELOX_ENABLE_EXPRESSION "Build expression." ON) option( VELOX_ENABLE_EXAMPLES - "Build examples. This will enable VELOX_ENABLE_EXPRESSION automatically." OFF) -option(VELOX_ENABLE_SUBSTRAIT "Build Substrait-to-Velox converter." OFF) + "Build examples. This will enable VELOX_ENABLE_EXPRESSION automatically." + OFF +) option(VELOX_ENABLE_BENCHMARKS "Enable Velox top level benchmarks." OFF) option(VELOX_ENABLE_BENCHMARKS_BASIC "Enable Velox basic benchmarks." OFF) option(VELOX_ENABLE_S3 "Build S3 Connector" OFF) @@ -158,15 +170,15 @@ option(VELOX_BUILD_RUNNER "Builds velox runner" ON) option( VELOX_ENABLE_INT64_BUILD_PARTITION_BOUND "make buildPartitionBounds_ a vector int64 instead of int32 to avoid integer overflow when the hashtable has billions of records" - OFF) -option(VELOX_SIMDJSON_SKIPUTF8VALIDATION - "Skip simdjson utf8 validation in JSON parsing" OFF) + OFF +) +option(VELOX_SIMDJSON_SKIPUTF8VALIDATION "Skip simdjson utf8 validation in JSON parsing" OFF) +option(VELOX_ENABLE_FAISS "Build faiss vector search support" OFF) # Explicitly force compilers to generate colored output. Compilers usually do # this by default if they detect the output is a terminal, but this assumption # is broken if you use ninja. -option(VELOX_FORCE_COLORED_OUTPUT - "Always produce ANSI-colored output (GNU/Clang only)." OFF) +option(VELOX_FORCE_COLORED_OUTPUT "Always produce ANSI-colored output (GNU/Clang only)." OFF) if(${VELOX_BUILD_MINIMAL} OR ${VELOX_BUILD_MINIMAL_WITH_DWIO}) # Enable and disable components for velox base build @@ -178,12 +190,20 @@ if(${VELOX_BUILD_MINIMAL} OR ${VELOX_BUILD_MINIMAL_WITH_DWIO}) set(VELOX_ENABLE_CLP_CONNECTOR OFF) set(VELOX_ENABLE_HIVE_CONNECTOR OFF) set(VELOX_ENABLE_TPCH_CONNECTOR OFF) + set(VELOX_ENABLE_TPCDS_CONNECTOR OFF) set(VELOX_ENABLE_SPARK_FUNCTIONS OFF) set(VELOX_ENABLE_EXAMPLES OFF) set(VELOX_ENABLE_S3 OFF) set(VELOX_ENABLE_GCS OFF) set(VELOX_ENABLE_ABFS OFF) - set(VELOX_ENABLE_SUBSTRAIT OFF) +else() + if(VELOX_BUILD_CMAKE_PACKAGE) + message( + FATAL_ERROR + "VELOX_BUILD_CMAKE_PACKAGE is only available with " + "VELOX_BUILD_MINIMAL=ON or VELOX_BUILD_MINIMAL_WITH_DWIO=ON for now." + ) + endif() endif() if(${VELOX_ENABLE_BENCHMARKS}) @@ -194,16 +214,24 @@ if(VELOX_ENABLE_BENCHMARKS_BASIC) set(VELOX_BUILD_TEST_UTILS ON) endif() +if(VELOX_ENABLE_CUDF) + message(STATUS "Building curl from source to satisfy cuDF curl version requirement") + set(CURL_SOURCE BUNDLED) + velox_resolve_dependency(CURL) +endif() + if(VELOX_BUILD_TESTING OR VELOX_BUILD_TEST_UTILS) - set(cpr_SOURCE BUNDLED) - velox_resolve_dependency(cpr) + # cuDF bundles curl since it needs a specific version. + # Use bundled or system curl otherwise. + if(NOT VELOX_ENABLE_CUDF) + velox_set_source(CURL) + velox_resolve_dependency(CURL) + endif() set(VELOX_ENABLE_DUCKDB ON) set(VELOX_ENABLE_PARSE ON) endif() -if(${VELOX_BUILD_TESTING} - OR ${VELOX_BUILD_MINIMAL_WITH_DWIO} - OR ${VELOX_ENABLE_HIVE_CONNECTOR}) +if(${VELOX_BUILD_TESTING} OR ${VELOX_BUILD_MINIMAL_WITH_DWIO} OR ${VELOX_ENABLE_HIVE_CONNECTOR}) set(VELOX_ENABLE_COMPRESSION_LZ4 ON) endif() @@ -231,6 +259,14 @@ if(${VELOX_ENABLE_DUCKDB}) velox_resolve_dependency(DuckDB) endif() +if(VELOX_ENABLE_FAISS) + velox_set_source(faiss) + velox_resolve_dependency(faiss) + if(NOT TARGET FAISS::faiss) + add_library(FAISS::faiss ALIAS faiss) + endif() +endif() + if(DEFINED ENV{INSTALL_PREFIX}) # Allow installed package headers to be picked up before brew/system package # headers. We set this after DuckDB bundling since DuckDB uses its own @@ -242,11 +278,12 @@ endif() # dependencies. find_package(OpenSSL REQUIRED) -if(VELOX_ENABLE_CCACHE - AND NOT CMAKE_C_COMPILER_LAUNCHER - AND NOT CMAKE_CXX_COMPILER_LAUNCHER - AND NOT CMAKE_CUDA_COMPILER_LAUNCHER) - +if( + VELOX_ENABLE_CCACHE + AND NOT CMAKE_C_COMPILER_LAUNCHER + AND NOT CMAKE_CXX_COMPILER_LAUNCHER + AND NOT CMAKE_CUDA_COMPILER_LAUNCHER +) find_program(CCACHE_FOUND ccache) if(CCACHE_FOUND) @@ -262,8 +299,10 @@ endif() if(${VELOX_FORCE_COLORED_OUTPUT}) if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") add_compile_options(-fdiagnostics-color=always) - elseif("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang" - OR "${CMAKE_CXX_COMPILER_ID}" STREQUAL "AppleClang") + elseif( + "${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang" + OR "${CMAKE_CXX_COMPILER_ID}" STREQUAL "AppleClang" + ) add_compile_options(-fcolor-diagnostics) endif() endif() @@ -273,7 +312,7 @@ if(VELOX_ENABLE_S3) if(AWSSDK_ROOT_DIR) list(APPEND CMAKE_PREFIX_PATH ${AWSSDK_ROOT_DIR}) endif() - find_package(AWSSDK REQUIRED COMPONENTS s3;identity-management) + find_package(AWSSDK 1.11.654 REQUIRED COMPONENTS s3;identity-management) add_definitions(-DVELOX_ENABLE_S3) endif() @@ -323,12 +362,38 @@ endif() # Required so velox code can be used in a dynamic library set(CMAKE_POSITION_INDEPENDENT_CODE TRUE) +# For C++20 support we need GNU GCC11 (or later versions) or Clang/AppleClang 15 +# (or later versions) to get support for the used features. +if( + NOT + ( + ( + "${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" + AND ${CMAKE_CXX_COMPILER_VERSION} VERSION_GREATER_EQUAL 11 + ) + OR + ( + ( + "${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang" + OR "${CMAKE_CXX_COMPILER_ID}" STREQUAL "AppleClang" + ) + AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 15 + ) + ) +) + message( + FATAL_ERROR + "Unsupported compiler ${CMAKE_CXX_COMPILER_ID} with version ${CMAKE_CXX_COMPILER_VERSION} found." + ) +endif() + execute_process( COMMAND bash -c "( source ${CMAKE_CURRENT_SOURCE_DIR}/scripts/setup-helper-functions.sh && echo -n $(get_cxx_flags $ENV{CPU_TARGET}))" OUTPUT_VARIABLE SCRIPT_CXX_FLAGS - RESULT_VARIABLE COMMAND_STATUS) + RESULT_VARIABLE COMMAND_STATUS +) if(COMMAND_STATUS EQUAL "1") message(FATAL_ERROR "Unable to determine compiler flags!") @@ -353,28 +418,28 @@ string(APPEND CMAKE_CXX_FLAGS " -DFOLLY_CFG_NO_COROUTINES") # Under Ninja, we are able to designate certain targets large enough to require # restricted parallelism. if("${MAX_HIGH_MEM_JOBS}") - set_property(GLOBAL PROPERTY JOB_POOLS - "high_memory_pool=${MAX_HIGH_MEM_JOBS}") + set_property(GLOBAL PROPERTY JOB_POOLS "high_memory_pool=${MAX_HIGH_MEM_JOBS}") else() set_property(GLOBAL PROPERTY JOB_POOLS high_memory_pool=1000) endif() if("${MAX_LINK_JOBS}") - set_property(GLOBAL APPEND PROPERTY JOB_POOLS - "link_job_pool=${MAX_LINK_JOBS}") + set_property(GLOBAL APPEND PROPERTY JOB_POOLS "link_job_pool=${MAX_LINK_JOBS}") set(CMAKE_JOB_POOL_LINK link_job_pool) endif() if(ENABLE_ALL_WARNINGS) if(CMAKE_CXX_COMPILER_ID MATCHES "Clang") - set(KNOWN_COMPILER_SPECIFIC_WARNINGS - "-Wno-implicit-const-int-float-conversion \ - -Wno-range-loop-analysis \ + set( + KNOWN_COMPILER_SPECIFIC_WARNINGS + "-Wno-range-loop-analysis \ -Wno-mismatched-tags \ - -Wno-nullability-completeness") + -Wno-nullability-completeness" + ) elseif(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") - set(KNOWN_COMPILER_SPECIFIC_WARNINGS - "-Wno-implicit-fallthrough \ + set( + KNOWN_COMPILER_SPECIFIC_WARNINGS + "-Wno-implicit-fallthrough \ -Wno-class-memaccess \ -Wno-comment \ -Wno-int-in-bool-context \ @@ -383,15 +448,28 @@ if(ENABLE_ALL_WARNINGS) -Wno-maybe-uninitialized \ -Wno-unused-result \ -Wno-format-overflow \ - -Wno-strict-aliasing") + -Wno-strict-aliasing" + ) + # Avoid compiler bug for GCC 12.2.1 + # https://gcc.gnu.org/bugzilla/show_bug.cgi?id=105329 + if(CMAKE_CXX_COMPILER_VERSION VERSION_EQUAL "12.2.1") + string(APPEND KNOWN_COMPILER_SPECIFIC_WARNINGS " -Wno-restrict") + endif() + if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL "14.0.0") + string(APPEND KNOWN_COMPILER_SPECIFIC_WARNINGS " -Wno-error=template-id-cdtor") + string(APPEND KNOWN_COMPILER_SPECIFIC_WARNINGS " -Wno-overloaded-virtual") + string(APPEND KNOWN_COMPILER_SPECIFIC_WARNINGS " -Wno-error=tautological-compare") + endif() endif() - set(KNOWN_WARNINGS - "-Wno-unused \ + set( + KNOWN_WARNINGS + "-Wno-unused \ -Wno-unused-parameter \ -Wno-sign-compare \ -Wno-ignored-qualifiers \ - ${KNOWN_COMPILER_SPECIFIC_WARNINGS}") + ${KNOWN_COMPILER_SPECIFIC_WARNINGS}" + ) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra ${KNOWN_WARNINGS}") endif() @@ -413,20 +491,32 @@ if(VELOX_ENABLE_WAVE OR VELOX_ENABLE_CUDF) if(arch LESS 70) message( FATAL_ERROR - "CUDA architecture ${arch} is below 70. CUDF requires Volta (SM 70) or newer GPUs." + "CUDA architecture ${arch} is below 70. CUDF requires Volta (SM 70) or newer GPUs." ) endif() endforeach() set(VELOX_ENABLE_ARROW ON) velox_set_source(cudf) velox_resolve_dependency(cudf) + if(TARGET aws-cpp-sdk-core) + # Fix for AWS SDK CPP using hardcoded system curl instead of soft link to curl + get_target_property(override_curl_lib aws-cpp-sdk-core INTERFACE_LINK_LIBRARIES) + list(REMOVE_ITEM override_curl_lib "/usr/lib64/libcurl.so") + list(APPEND override_curl_lib "\$") + set_target_properties( + aws-cpp-sdk-core + PROPERTIES INTERFACE_LINK_LIBRARIES "${override_curl_lib}" + ) + endif() endif() endif() # Set after the test of the CUDA compiler. Otherwise, the test fails with # -latomic not found because it is added right after the compiler exe. -if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang" - AND "${CMAKE_CXX_COMPILER_VERSION}" VERSION_GREATER_EQUAL 15) +if( + "${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang" + AND "${CMAKE_CXX_COMPILER_VERSION}" VERSION_GREATER_EQUAL 15 +) set(CMAKE_EXE_LINKER_FLAGS "-latomic") endif() @@ -437,7 +527,8 @@ if(CMAKE_SYSTEM_NAME MATCHES "Darwin") COMMAND brew --prefix icu4c RESULT_VARIABLE BREW_ICU4C OUTPUT_VARIABLE BREW_ICU4C_PREFIX - OUTPUT_STRIP_TRAILING_WHITESPACE) + OUTPUT_STRIP_TRAILING_WHITESPACE + ) if(BREW_ICU4C EQUAL 0 AND EXISTS "${BREW_ICU4C_PREFIX}") message(STATUS "Found icu4c installed by Homebrew at ${BREW_ICU4C_PREFIX}") list(APPEND CMAKE_PREFIX_PATH "${BREW_ICU4C_PREFIX}") @@ -449,24 +540,22 @@ endif() velox_set_source(ICU) velox_resolve_dependency( ICU - COMPONENTS - data - i18n - io - uc - tu) - -set(BOOST_INCLUDE_LIBRARIES - atomic - context - date_time - filesystem - iostreams - program_options - regex - system - url - thread) + COMPONENTS data i18n io uc tu +) + +set( + BOOST_INCLUDE_LIBRARIES + atomic + context + date_time + filesystem + iostreams + program_options + regex + system + url + thread +) velox_set_source(Boost) velox_resolve_dependency(Boost 1.77.0 COMPONENTS ${BOOST_INCLUDE_LIBRARIES}) @@ -480,9 +569,7 @@ if(NOT TARGET gflags::gflags) # target even when velox is built as a subproject which uses # `find_package(gflags)` which does not create a globally imported target that # we can ALIAS. - add_library(gflags_gflags INTERFACE) - target_link_libraries(gflags_gflags INTERFACE gflags) - add_library(gflags::gflags ALIAS gflags_gflags) + add_library(gflags::gflags ALIAS gflags) endif() if(${gflags_SOURCE} STREQUAL "BUNDLED") @@ -495,7 +582,7 @@ endif() velox_resolve_dependency(glog) velox_set_source(fmt) -velox_resolve_dependency(fmt 9.0.0) +velox_resolve_dependency(fmt 11.2.0) if(VELOX_ENABLE_COMPRESSION_LZ4) find_package(lz4 REQUIRED) @@ -506,9 +593,19 @@ if(${VELOX_BUILD_MINIMAL_WITH_DWIO} OR ${VELOX_ENABLE_HIVE_CONNECTOR}) # # TODO: make these optional and pluggable. find_package(ZLIB REQUIRED) - find_package(lzo2 REQUIRED) find_package(zstd REQUIRED) find_package(Snappy REQUIRED) + # Ensure zstd::zstd target exists - handle different zstd package configurations + if(NOT TARGET zstd::zstd) + if(TARGET zstd::libzstd_static) + add_library(zstd::zstd ALIAS zstd::libzstd_static) + elseif(TARGET zstd::libzstd_shared) + add_library(zstd::zstd ALIAS zstd::libzstd_shared) + else() + # Fallback: use Findzstd.cmake to create the target + include(Findzstd) + endif() + endif() endif() velox_set_source(re2) @@ -540,32 +637,25 @@ if(${VELOX_ENABLE_CLP_CONNECTOR}) velox_set_source(spdlog) velox_resolve_dependency(spdlog) + velox_set_source(ystdlib) + velox_resolve_dependency(ystdlib) + # Dependencies that depend on other dependencies velox_set_source(log_surgeon) velox_resolve_dependency(log_surgeon) - velox_set_source(ystdlib) - velox_resolve_dependency(ystdlib) - set(clp_SOURCE BUNDLED) velox_resolve_dependency(clp) endif() if(${VELOX_BUILD_PYTHON_PACKAGE}) - find_package( - Python 3.9 - COMPONENTS Interpreter Development.Module - REQUIRED) + find_package(Python 3.9 COMPONENTS Interpreter Development.Module REQUIRED) velox_set_source(pybind11) velox_resolve_dependency(pybind11 2.10.0) endif() -# DWIO (ORC/DWRF) and Substrait depend on protobuf. -if(${VELOX_BUILD_MINIMAL_WITH_DWIO} - OR ${VELOX_ENABLE_HIVE_CONNECTOR} - OR ${VELOX_ENABLE_SUBSTRAIT} - OR VELOX_ENABLE_GCS) - +# DWIO (ORC/DWRF) depends on protobuf. +if(${VELOX_BUILD_MINIMAL_WITH_DWIO} OR ${VELOX_ENABLE_HIVE_CONNECTOR} OR VELOX_ENABLE_GCS) # Locate or build protobuf. velox_set_source(Protobuf) velox_resolve_dependency(Protobuf 3.21.7 REQUIRED) @@ -573,7 +663,10 @@ if(${VELOX_BUILD_MINIMAL_WITH_DWIO} endif() velox_set_source(simdjson) -velox_resolve_dependency(simdjson 3.9.3) +velox_resolve_dependency(simdjson 4.1.0) + +velox_set_source(FastFloat) +velox_resolve_dependency(FastFloat) velox_set_source(folly) velox_resolve_dependency(folly) @@ -608,13 +701,6 @@ endif() # GCC needs to link a library to enable std::filesystem. if("${CMAKE_CXX_COMPILER_ID}" MATCHES "GNU") - - # Ensure we have gcc at least 9+. - if(CMAKE_CXX_COMPILER_VERSION LESS 9.0) - message( - FATAL_ERROR "VELOX requires gcc > 9. Found ${CMAKE_CXX_COMPILER_VERSION}") - endif() - # Find Threads library find_package(Threads REQUIRED) endif() @@ -622,7 +708,7 @@ endif() if(VELOX_BUILD_TESTING AND NOT VELOX_ENABLE_DUCKDB) message( FATAL_ERROR - "Unit tests require duckDB to be enabled (VELOX_ENABLE_DUCKDB=ON or VELOX_BUILD_TESTING=OFF)" + "Unit tests require duckDB to be enabled (VELOX_ENABLE_DUCKDB=ON or VELOX_BUILD_TESTING=OFF)" ) endif() @@ -639,10 +725,10 @@ if(CMAKE_HOST_SYSTEM_NAME MATCHES "Darwin") COMMAND brew --prefix bison RESULT_VARIABLE BREW_BISON OUTPUT_VARIABLE BREW_BISON_PREFIX - OUTPUT_STRIP_TRAILING_WHITESPACE) + OUTPUT_STRIP_TRAILING_WHITESPACE + ) if(BREW_BISON EQUAL 0 AND EXISTS "${BREW_BISON_PREFIX}") - message( - STATUS "Found Bison keg installed by Homebrew at ${BREW_BISON_PREFIX}") + message(STATUS "Found Bison keg installed by Homebrew at ${BREW_BISON_PREFIX}") set(BISON_EXECUTABLE "${BREW_BISON_PREFIX}/bin/bison") endif() @@ -650,10 +736,10 @@ if(CMAKE_HOST_SYSTEM_NAME MATCHES "Darwin") COMMAND brew --prefix flex RESULT_VARIABLE BREW_FLEX OUTPUT_VARIABLE BREW_FLEX_PREFIX - OUTPUT_STRIP_TRAILING_WHITESPACE) + OUTPUT_STRIP_TRAILING_WHITESPACE + ) if(BREW_FLEX EQUAL 0 AND EXISTS "${BREW_FLEX_PREFIX}") - message( - STATUS "Found Flex keg installed by Homebrew at ${BREW_FLEX_PREFIX}") + message(STATUS "Found Flex keg installed by Homebrew at ${BREW_FLEX_PREFIX}") set(FLEX_EXECUTABLE "${BREW_FLEX_PREFIX}/bin/flex") set(FLEX_INCLUDE_DIR "${BREW_FLEX_PREFIX}/include") endif() @@ -668,9 +754,7 @@ include_directories(SYSTEM velox/external) if(NOT VELOX_DISABLE_GOOGLETEST) velox_set_source(GTest) velox_resolve_dependency(GTest) - set(VELOX_GTEST_INCUDE_DIR - "${gtest_SOURCE_DIR}/include" - PARENT_SCOPE) + set(VELOX_GTEST_INCUDE_DIR "${gtest_SOURCE_DIR}/include" PARENT_SCOPE) endif() velox_set_source(xsimd) @@ -686,10 +770,6 @@ endif() include_directories(.) -# TODO: Include all other installation files. For now just making sure this -# generates an installable makefile. -install(FILES velox/type/Type.h DESTINATION "include/velox") - # Adding this down here prevents warnings in dependencies from stopping the # build if("${TREAT_WARNINGS_AS_ERRORS}") diff --git a/CODING_STYLE.md b/CODING_STYLE.md index 43947192cd74..97a60199941e 100644 --- a/CODING_STYLE.md +++ b/CODING_STYLE.md @@ -246,16 +246,17 @@ About comment style: * As a general rule, do not use string literals without declaring a named constant for them. * The best way to make a constant string literal is to use constexpr - `std::string_view`/`folly::StringPiece` + `std::string_view` * **NEVER** use `std::string` - this makes your code more prone to SIOF bugs. * Avoid `const char* const` and `const char*` - these are less efficient to convert to `std::string` later on in your program if you ever need to - because `std::string_view`/ `folly::StringPiece` knows its size and can use - a more efficient constructor. `std::string_view`/ `folly::StringPiece` also - has richer interfaces and often works as a drop-in replacement to - `std::string`. + because `std::string_view` knows its size and can use a more efficient + constructor. `std::string_view` also has richer interfaces and often + works as a drop-in replacement to `std::string`. * Need compile-time string concatenation? You can use `folly::FixedString` for that. + * Do not use `folly::StringPiece` in new code, use `std::string_view` + instead. ## Macros diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index a76dd9e334cc..8db9559bddd3 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -137,6 +137,7 @@ where: * *Type* can be any of the following keywords: * **feat** when new features are being added. * **fix** for bug fixes. + * **perf** for performance improvements. * **build** for build or CI-related improvements. * **test** for adding tests (only). * **docs** for enhancements to documentation (only). diff --git a/Makefile b/Makefile index be290e806ea8..b9548ea66f0a 100644 --- a/Makefile +++ b/Makefile @@ -56,6 +56,8 @@ endif ifdef CUDA_ARCHITECTURES CMAKE_FLAGS += -DCMAKE_CUDA_ARCHITECTURES="$(CUDA_ARCHITECTURES)" +else +EXTRA_CMAKE_CUDA_FLAGS = -DCMAKE_CUDA_ARCHITECTURES="native" endif ifdef CUDA_COMPILER @@ -101,10 +103,10 @@ cmake: #: Use CMake to create a Makefile build system ${EXTRA_CMAKE_FLAGS} cmake-wave: - $(MAKE) EXTRA_CMAKE_FLAGS="${EXTRA_CMAKE_FLAGS} -DVELOX_ENABLE_WAVE=ON" cmake + $(MAKE) EXTRA_CMAKE_FLAGS="${EXTRA_CMAKE_FLAGS} -DVELOX_ENABLE_WAVE=ON ${EXTRA_CMAKE_CUDA_FLAGS}" cmake cmake-cudf: - $(MAKE) EXTRA_CMAKE_FLAGS="${EXTRA_CMAKE_FLAGS} -DVELOX_ENABLE_CUDF=ON" cmake + $(MAKE) EXTRA_CMAKE_FLAGS="${EXTRA_CMAKE_FLAGS} -DVELOX_ENABLE_CUDF=ON ${EXTRA_CMAKE_CUDA_FLAGS}" cmake build: #: Build the software based in BUILD_DIR and BUILD_TYPE variables cmake --build $(BUILD_BASE_DIR)/$(BUILD_DIR) -j $(NUM_THREADS) @@ -118,13 +120,17 @@ release: #: Build the release version $(MAKE) build BUILD_DIR=release minimal_debug: #: Minimal build with debugging symbols - $(MAKE) cmake BUILD_DIR=debug BUILD_TYPE=debug EXTRA_CMAKE_FLAGS="${EXTRA_CMAKE_FLAGS} -DVELOX_BUILD_MINIMAL=ON" + $(MAKE) cmake BUILD_DIR=debug BUILD_TYPE=debug \ + EXTRA_CMAKE_FLAGS="${EXTRA_CMAKE_FLAGS} \ + -DVELOX_BUILD_MINIMAL=ON" $(MAKE) build BUILD_DIR=debug min_debug: minimal_debug minimal: #: Minimal build - $(MAKE) cmake BUILD_DIR=release BUILD_TYPE=release EXTRA_CMAKE_FLAGS="${EXTRA_CMAKE_FLAGS} -DVELOX_BUILD_MINIMAL=ON" + $(MAKE) cmake BUILD_DIR=release BUILD_TYPE=release \ + EXTRA_CMAKE_FLAGS="${EXTRA_CMAKE_FLAGS} \ + -DVELOX_BUILD_MINIMAL=ON" $(MAKE) build BUILD_DIR=release wave: #: Build with Wave GPU support @@ -144,26 +150,39 @@ cudf-debug: #: Build with debugging symbols and cuDF GPU support $(MAKE) build BUILD_DIR=debug dwio: #: Minimal build with dwio enabled. - $(MAKE) cmake BUILD_DIR=release BUILD_TYPE=release EXTRA_CMAKE_FLAGS="${EXTRA_CMAKE_FLAGS} \ - -DVELOX_BUILD_MINIMAL_WITH_DWIO=ON" + $(MAKE) cmake BUILD_DIR=release BUILD_TYPE=release \ + EXTRA_CMAKE_FLAGS="${EXTRA_CMAKE_FLAGS} \ + -DVELOX_BUILD_MINIMAL_WITH_DWIO=ON" $(MAKE) build BUILD_DIR=release dwio_debug: #: Minimal build with dwio debugging symbols. - $(MAKE) cmake BUILD_DIR=debug BUILD_TYPE=debug EXTRA_CMAKE_FLAGS="${EXTRA_CMAKE_FLAGS} \ - -DVELOX_BUILD_MINIMAL_WITH_DWIO=ON" + $(MAKE) cmake BUILD_DIR=debug BUILD_TYPE=debug \ + EXTRA_CMAKE_FLAGS="${EXTRA_CMAKE_FLAGS} \ + -DVELOX_BUILD_MINIMAL_WITH_DWIO=ON" $(MAKE) build BUILD_DIR=debug +sve_build: #: Build with SVE-specific configuration +# Check for SVE support and set appropriate compilers + @SVE_CC=$$(if [ "$$(lscpu | grep -q "sve" && grep -qi "ubuntu" /etc/os-release && echo 1)" = "1" ]; then echo "/usr/bin/gcc-12"; else echo "gcc"; fi); \ + SVE_CXX=$$(if [ "$$(lscpu | grep -q "sve" && grep -qi "ubuntu" /etc/os-release && echo 1)" = "1" ]; then echo "/usr/bin/g++-12"; else echo "g++"; fi); \ + echo "Using CC=$$SVE_CC, CXX=$$SVE_CXX"; \ + export CC=$$SVE_CC; export CXX=$$SVE_CXX; \ + $(MAKE) cmake BUILD_DIR=release BUILD_TYPE=release EXTRA_CMAKE_FLAGS="-DCMAKE_C_COMPILER=$$SVE_CC -DCMAKE_CXX_COMPILER=$$SVE_CXX -DCMAKE_CXX_FLAGS='$(COMPILER_FLAGS) -Wno-error=stringop-overflow $(shell ./scripts/setup-helper-functions.sh detect_sve_flags)'" && \ + $(MAKE) build BUILD_DIR=release + benchmarks-basic-build: - $(MAKE) release EXTRA_CMAKE_FLAGS=" ${EXTRA_CMAKE_FLAGS} \ - -DVELOX_BUILD_TESTING=OFF \ - -DVELOX_ENABLE_BENCHMARKS_BASIC=ON \ - -DVELOX_BUILD_RUNNER=OFF" + $(MAKE) release \ + EXTRA_CMAKE_FLAGS=" ${EXTRA_CMAKE_FLAGS} \ + -DVELOX_BUILD_TESTING=OFF \ + -DVELOX_ENABLE_BENCHMARKS_BASIC=ON \ + -DVELOX_BUILD_RUNNER=OFF" benchmarks-build: - $(MAKE) release EXTRA_CMAKE_FLAGS=" ${EXTRA_CMAKE_FLAGS} \ - -DVELOX_BUILD_TESTING=OFF \ - -DVELOX_ENABLE_BENCHMARKS=ON \ - -DVELOX_BUILD_RUNNER=OFF" + $(MAKE) release \ + EXTRA_CMAKE_FLAGS=" ${EXTRA_CMAKE_FLAGS} \ + -DVELOX_BUILD_TESTING=OFF \ + -DVELOX_ENABLE_BENCHMARKS=ON \ + -DVELOX_BUILD_RUNNER=OFF" benchmarks-basic-run: scripts/ci/benchmark-runner.py run \ diff --git a/NOTICE.txt b/NOTICE.txt index 4fb5849fba09..ca4cf2db6b4d 100644 --- a/NOTICE.txt +++ b/NOTICE.txt @@ -15,97 +15,142 @@ This product includes software from the Arrow project. * https://github.com/apache/arrow/blob/apache-arrow-15.0.0/cpp/src/arrow/io/hdfs_internal.cc Which contain the following NOTICE file: ------- - Apache Arrow - Copyright 2016-2024 The Apache Software Foundation - This product includes software developed at - The Apache Software Foundation (http://www.apache.org/). - This product includes software from the SFrame project (BSD, 3-clause). - * Copyright (C) 2015 Dato, Inc. - * Copyright (c) 2009 Carnegie Mellon University. - This product includes software from the Feather project (Apache 2.0) - https://github.com/wesm/feather - This product includes software from the DyND project (BSD 2-clause) - https://github.com/libdynd - This product includes software from the LLVM project - * distributed under the University of Illinois Open Source - This product includes software from the google-lint project - * Copyright (c) 2009 Google Inc. All rights reserved. - This product includes software from the mman-win32 project - * Copyright https://code.google.com/p/mman-win32/ - * Licensed under the MIT License; - This product includes software from the LevelDB project - * Copyright (c) 2011 The LevelDB Authors. All rights reserved. - * Use of this source code is governed by a BSD-style license that can be - * Moved from Kudu http://github.com/cloudera/kudu - This product includes software from the CMake project - * Copyright 2001-2009 Kitware, Inc. - * Copyright 2012-2014 Continuum Analytics, Inc. - * All rights reserved. - This product includes software from https://github.com/matthew-brett/multibuild (BSD 2-clause) - * Copyright (c) 2013-2016, Matt Terry and Matthew Brett; all rights reserved. - This product includes software from the Ibis project (Apache 2.0) - * Copyright (c) 2015 Cloudera, Inc. - * https://github.com/cloudera/ibis - This product includes software from Dremio (Apache 2.0) - * Copyright (C) 2017-2018 Dremio Corporation - * https://github.com/dremio/dremio-oss - This product includes software from Google Guava (Apache 2.0) - * Copyright (C) 2007 The Guava Authors - * https://github.com/google/guava - This product include software from CMake (BSD 3-Clause) - * CMake - Cross Platform Makefile Generator - * Copyright 2000-2019 Kitware, Inc. and Contributors - The web site includes files generated by Jekyll. - -------------------------------------------------------------------------------- - This product includes code from Apache Kudu, which includes the following in - its NOTICE file: - Apache Kudu - Copyright 2016 The Apache Software Foundation - This product includes software developed at - The Apache Software Foundation (http://www.apache.org/). - Portions of this software were developed at - Cloudera, Inc (http://www.cloudera.com/). - -------------------------------------------------------------------------------- - This product includes code from Apache ORC, which includes the following in - its NOTICE file: - Apache ORC - Copyright 2013-2019 The Apache Software Foundation - This product includes software developed by The Apache Software - Foundation (http://www.apache.org/). - This product includes software developed by Hewlett-Packard: - (c) Copyright [2014-2015] Hewlett-Packard Development Company, L.P + Apache Arrow + Copyright 2016-2024 The Apache Software Foundation + This product includes software developed at + The Apache Software Foundation (http://www.apache.org/). + This product includes software from the SFrame project (BSD, 3-clause). + * Copyright (C) 2015 Dato, Inc. + * Copyright (c) 2009 Carnegie Mellon University. + This product includes software from the Feather project (Apache 2.0) + https://github.com/wesm/feather + This product includes software from the DyND project (BSD 2-clause) + https://github.com/libdynd + This product includes software from the LLVM project + * distributed under the University of Illinois Open Source + This product includes software from the google-lint project + * Copyright (c) 2009 Google Inc. All rights reserved. + This product includes software from the mman-win32 project + * Copyright https://code.google.com/p/mman-win32/ + * Licensed under the MIT License; + This product includes software from the LevelDB project + * Copyright (c) 2011 The LevelDB Authors. All rights reserved. + * Use of this source code is governed by a BSD-style license that can be + * Moved from Kudu http://github.com/cloudera/kudu + This product includes software from the CMake project + * Copyright 2001-2009 Kitware, Inc. + * Copyright 2012-2014 Continuum Analytics, Inc. + * All rights reserved. + This product includes software from https://github.com/matthew-brett/multibuild (BSD 2-clause) + * Copyright (c) 2013-2016, Matt Terry and Matthew Brett; all rights reserved. + This product includes software from the Ibis project (Apache 2.0) + * Copyright (c) 2015 Cloudera, Inc. + * https://github.com/cloudera/ibis + This product includes software from Dremio (Apache 2.0) + * Copyright (C) 2017-2018 Dremio Corporation + * https://github.com/dremio/dremio-oss + This product includes software from Google Guava (Apache 2.0) + * Copyright (C) 2007 The Guava Authors + * https://github.com/google/guava + This product include software from CMake (BSD 3-Clause) + * CMake - Cross Platform Makefile Generator + * Copyright 2000-2019 Kitware, Inc. and Contributors + The web site includes files generated by Jekyll. + -------------------------------------------------------------------------------- + This product includes code from Apache Kudu, which includes the following in + its NOTICE file: + Apache Kudu + Copyright 2016 The Apache Software Foundation + This product includes software developed at + The Apache Software Foundation (http://www.apache.org/). + Portions of this software were developed at + Cloudera, Inc (http://www.cloudera.com/). + -------------------------------------------------------------------------------- + This product includes code from Apache ORC, which includes the following in + its NOTICE file: + Apache ORC + Copyright 2013-2019 The Apache Software Foundation + This product includes software developed by The Apache Software + Foundation (http://www.apache.org/). + This product includes software developed by Hewlett-Packard: + (c) Copyright [2014-2015] Hewlett-Packard Development Company, L.P ------- This product includes software from the The Hadoop project. * https://github.com/apache/hadoop/blob/release-3.3.0-RC0/hadoop-hdfs-project/hadoop-hdfs-native-client/src/main/native/libhdfs/include/hdfs/hdfs.h Which contains the following NOTICE file: ---- - Apache Hadoop - Copyright 2006 and onwards The Apache Software Foundation. - This product includes software developed at - The Apache Software Foundation (http://www.apache.org/). - Export Control Notice - --------------------- - This distribution includes cryptographic software. The country in - which you currently reside may have restrictions on the import, - possession, use, and/or re-export to another country, of - encryption software. BEFORE using any encryption software, please - check your country's laws, regulations and policies concerning the - import, possession, or use, and re-export of encryption software, to - see if this is permitted. See for more - information. - The U.S. Government Department of Commerce, Bureau of Industry and - Security (BIS), has classified this software as Export Commodity - Control Number (ECCN) 5D002.C.1, which includes information security - software using or performing cryptographic functions with asymmetric - algorithms. The form and manner of this Apache Software Foundation - distribution makes it eligible for export under the License Exception - ENC Technology Software Unrestricted (TSU) exception (see the BIS - Export Administration Regulations, Section 740.13) for both object - code and source code. - The following provides more details on the included cryptographic software: - This software uses the SSL libraries from the Jetty project written - by mortbay.org. - Hadoop Yarn Server Web Proxy uses the BouncyCastle Java - cryptography APIs written by the Legion of the Bouncy Castle Inc. + Apache Hadoop + Copyright 2006 and onwards The Apache Software Foundation. + This product includes software developed at + The Apache Software Foundation (http://www.apache.org/). + Export Control Notice + --------------------- + This distribution includes cryptographic software. The country in + which you currently reside may have restrictions on the import, + possession, use, and/or re-export to another country, of + encryption software. BEFORE using any encryption software, please + check your country's laws, regulations and policies concerning the + import, possession, or use, and re-export of encryption software, to + see if this is permitted. See for more + information. + The U.S. Government Department of Commerce, Bureau of Industry and + Security (BIS), has classified this software as Export Commodity + Control Number (ECCN) 5D002.C.1, which includes information security + software using or performing cryptographic functions with asymmetric + algorithms. The form and manner of this Apache Software Foundation + distribution makes it eligible for export under the License Exception + ENC Technology Software Unrestricted (TSU) exception (see the BIS + Export Administration Regulations, Section 740.13) for both object + code and source code. + The following provides more details on the included cryptographic software: + This software uses the SSL libraries from the Jetty project written + by mortbay.org. + Hadoop Yarn Server Web Proxy uses the BouncyCastle Java + cryptography APIs written by the Legion of the Bouncy Castle Inc. +---- + +This product includes software from the The Spark project. +* https://github.com/apache/spark/tree/v3.5.1/connector/connect/common/src/main/protobuf/spark/connect +Which contain the following NOTICE file: +---- + Apache Spark + Copyright 2014 and onwards The Apache Software Foundation. + + This product includes software developed at + The Apache Software Foundation (http://www.apache.org/). + Export Control Notice + --------------------- + + This distribution includes cryptographic software. The country in which you currently reside may have + restrictions on the import, possession, use, and/or re-export to another country, of encryption software. + BEFORE using any encryption software, please check your country's laws, regulations and policies concerning + the import, possession, or use, and re-export of encryption software, to see if this is permitted. See + for more information. + + The U.S. Government Department of Commerce, Bureau of Industry and Security (BIS), has classified this + software as Export Commodity Control Number (ECCN) 5D002.C.1, which includes information security software + using or performing cryptographic functions with asymmetric algorithms. The form and manner of this Apache + Software Foundation distribution makes it eligible for export under the License Exception ENC Technology + Software Unrestricted (TSU) exception (see the BIS Export Administration Regulations, Section 740.13) for + both object code and source code. + + The following provides more details on the included cryptographic software: + + This software uses Apache Commons Crypto (https://commons.apache.org/proper/commons-crypto/) to + support authentication, and encryption and decryption of data sent across the network between + services. + + + Metrics + Copyright 2010-2013 Coda Hale and Yammer, Inc. + + This product includes software developed by Coda Hale and Yammer, Inc. + + This product includes code derived from the JSR-166 project (ThreadLocalRandom, Striped64, + LongAdder), which was released with the following comments: + + Written by Doug Lea with assistance from members of JCP JSR-166 + Expert Group and released to the public domain, as explained at + http://creativecommons.org/publicdomain/zero/1.0/ ---- diff --git a/README.md b/README.md index 16fd3cd265cf..3fb83181fd9c 100644 --- a/README.md +++ b/README.md @@ -112,6 +112,32 @@ Details on the dependencies and how Velox manages some of them for you Velox also provides the following scripts to help developers setup and install Velox dependencies for a given platform. +### Supported OS and compiler matrix + +The minimum versions of supported compilers: + +| OS | Compiler | Version | +|----|----------|---------| +| Linux | gcc | 11 | +| Linux | clang | 15 | +| macOS | clang | 15 | + +The recommended OS versions and compilers: + +| OS | Compiler | Version | +|----|----------|---------| +| CentOS 9/RHEL 9 | gcc | 12 | +| Ubuntu 22.04 | gcc | 11 | +| macOS | clang | 16 | + +Alternative combinations: + +| OS | Compiler | Version | +|----|----------|---------| +| CentOS 9/RHEL 9 | gcc | 11 | +| Ubuntu 20.04 | gcc | 11 | +| Ubuntu 24.04 | clang | 15 | + ### Setting up dependencies The following setup scripts use the `DEPENDENCY_DIR` environment variable to set the @@ -125,7 +151,7 @@ Using the default install location `/usr/local` on macOS is discouraged since th location is used by certain Homebrew versions. Manually add the `INSTALL_PREFIX` value in the IDE or bash environment, -say `export INSTALL_PREFIX=/Users/$USERNAME/velox/deps-install` to `~/.zshrc` so that +say `export INSTALL_PREFIX=/Users/$USER/velox/deps-install` to `~/.zshrc` so that subsequent Velox builds can use the installed packages. *You can reuse `DEPENDENCY_INSTALL` and `INSTALL_PREFIX` for Velox clients such as Prestissimo diff --git a/docker-bake.hcl b/docker-bake.hcl new file mode 100644 index 000000000000..3a63980522a6 --- /dev/null +++ b/docker-bake.hcl @@ -0,0 +1,148 @@ +variable "tag" { + default = "ghcr.io/facebookincubator/velox-dev" +} + +# Variable to control cache pushing +variable "DOCKER_UPLOAD_CACHE" { + default = false +} + +function "cache-to-arch" { + params = [name, arch] + # We don't want to push local build to the cache by accident + # so we remove cache-to unless we are running in CI + result = DOCKER_UPLOAD_CACHE ? [{ + type = "registry" + ref = "${tag}:build-cache-${name}-${arch}" + mode = "max" + compression = "zstd" + }] : [] +} + +function "cache-from-arch" { + params = [name, arch] + result = [{ + type = "registry" + ref = "${tag}:build-cache-${name}-${arch}" + }] +} + +function "ci_images_by_arch" { + params = [arch] + result = ["centos9-${arch}", "adapters-${arch}"] +} + +group "ci-amd64" { + targets = ci_images_by_arch("amd64") +} + +group "ci-arm64" { + targets = ci_images_by_arch("arm64") +} + +group "default" { + targets = [] +} + +target "base" { + output = [ + DOCKER_UPLOAD_CACHE ? { + type = "registry" + compression = "zstd" + oci-mediatypes = true + } : { + # For local builds with the docker driver the image will be loaded + # even without explicitly adding the `docker` exporter + type = "cacheonly" + } + + ] +} + +target "pyvelox" { + inherits = ["base"] + context = "." + name = "pyvelox-${arch}" + dockerfile = "scripts/docker/centos-multi.dockerfile" + target = "pyvelox" + args = { + image = "quay.io/pypa/manylinux_2_28:latest" + VELOX_BUILD_SHARED = "OFF" + } + matrix = { + arch = ["amd64", "arm64"] + } + platforms = ["linux/${arch}"] + tags = ["${tag}:pyvelox-${arch}"] + cache-to = cache-to-arch("pyvelox", "${arch}") + cache-from = cache-from-arch("pyvelox", "${arch}") +} + +target "adapters" { + inherits = ["base","adapters-cpp"] + name = "adapters-${arch}" + matrix = { + arch = ["amd64", "arm64"] + } + platforms = ["linux/${arch}"] + tags = ["${tag}:adapters-${arch}"] + cache-to = cache-to-arch("adapters", "${arch}") + cache-from = cache-from-arch("adapters", "${arch}") +} + +target "centos9" { + inherits = ["base","centos-cpp"] + name = "centos9-${arch}" + matrix = { + arch = ["amd64", "arm64"] + } + platforms = ["linux/${arch}"] + tags = ["${tag}:centos9-${arch}"] + cache-to = cache-to-arch("centos9", "${arch}") + cache-from = cache-from-arch("centos9", "${arch}") +} + +target "ubuntu-amd64" { + inherits = ["base","ubuntu-cpp"] + cache-to = cache-to-arch("ubuntu", "amd64") + cache-from = cache-from-arch("ubuntu", "amd64") +} + +group "ubuntu-arm64" { + # We don't actually want to build the ubuntu arm image, this is a trick to simplify CI + # Empty targets don't fail the build. + targets = [] +} + +group "fedora-arm64" { + # We don't actually want to build the fedora arm image, this is a trick to simplify CI + # Empty targets don't fail the build. + targets = [] +} + +target "fedora-amd64" { + inherits = ["base", "fedora"] + dockerfile = "scripts/docker/fedora.dockerfile" + cache-to = cache-to-arch("fedora", "amd64") + cache-from = cache-from-arch("fedora", "amd64") +} + +group "java" { + # The main work is in the well cached download steps and the shared base stage, + # so these can easily be run on the same node in ci + targets = ["spark-server", "presto-java"] +} + +target "spark-server" { + inherits = ["base"] + target = "spark-server" + cache-to = cache-to-arch("spark-server", "amd64") + cache-from = cache-from-arch("spark-server", "amd64") +} + +target "presto-java" { + inherits = ["base"] + target = "presto-java" + cache-to = cache-to-arch("presto-java", "amd64") + cache-from = cache-from-arch("presto-java", "amd64") +} diff --git a/docker-compose.yml b/docker-compose.yml index d7cc50d1a8f0..65d99a93bf7d 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -11,77 +11,63 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -version: '3.5' +# The docker-compose file is used only for running the services: +# `docker compose run adapters-cpp` +# +# If you want to build any of the images you will need to use docker bake: +# `docker bake adapters-amd64` +# +# The target names are defined in docker-bake.hcl, you can use the names +# from `docker-compose.yml`, like `adapters-cpp` but those will not make use +# of layer caching and will tag images with the wrong tag for the multi-arch images. +# Do this for local use only! +# +# When testing changes you want to push as a PR use docker bake. services: - ubuntu-cpp: - # Usage: - # docker-compose pull ubuntu-cpp or docker-compose build ubuntu-cpp - # docker-compose run --rm ubuntu-cpp - # or - # docker-compose run -e NUM_THREADS= --rm ubuntu-cpp - # to set the number of threads used during compilation - image: ghcr.io/facebookincubator/velox-dev:ubuntu-22.04 + # This is not an actual target and just used to inherit the env and volume + # defaults + base-block: build: context: . - dockerfile: scripts/docker/ubuntu-22.04-cpp.dockerfile environment: - NUM_THREADS: 8 # default value for NUM_THREADS - VELOX_DEPENDENCY_SOURCE: BUNDLED # Build dependencies from source + NUM_THREADS: ${NUM_THREADS:-8} # default value for NUM_THREADS CCACHE_DIR: /velox/.ccache volumes: - .:/velox:delegated - command: scripts/docker/docker-command.sh + working_dir: /velox + command: /velox/scripts/docker/docker-command.sh - adapters-cpp: - # Usage: - # docker-compose pull adapters-cpp or docker-compose build adapters-cpp - # or - # docker-compose run --rm adapters-cpp - # or - # docker-compose run -e NUM_THREADS= --rm adapters-cpp - # to set the number of threads used during compilation - # scripts/docker/adapters.dockerfile uses SHELL which is not supported for OCI image format. - # podman users must specify "--podman-build-args='--format docker'" argument. - image: ghcr.io/facebookincubator/velox-dev:adapters + ubuntu-cpp: + extends: base-block + image: ghcr.io/facebookincubator/velox-dev:ubuntu-22.04 build: - context: . - dockerfile: scripts/docker/adapters.dockerfile - args: - image: ghcr.io/facebookincubator/velox-dev:centos9 + dockerfile: scripts/docker/ubuntu-22.04-cpp.dockerfile environment: - NUM_THREADS: 8 # default value for NUM_THREADS - CCACHE_DIR: /velox/.ccache - EXTRA_CMAKE_FLAGS: > - -DVELOX_ENABLE_PARQUET=ON - -DVELOX_ENABLE_S3=ON - volumes: - - .:/velox:delegated - working_dir: /velox - command: /velox/scripts/docker/docker-command.sh + VELOX_DEPENDENCY_SOURCE: BUNDLED # Build dependencies from source - adapters-cuda: - # Usage: - # docker-compose pull adapters-cuda or docker-compose build adapters-cuda - # or - # docker-compose run --rm adapters-cuda - # or - # docker-compose run -e NUM_THREADS= --rm adapters-cuda - # to set the number of threads used during compilation - # scripts/docker/adapters.dockerfile uses SHELL which is not supported for OCI image format. - # podman users must specify "--podman-build-args='--format docker'" argument. - image: ghcr.io/facebookincubator/velox-dev:adapters + centos-cpp: + extends: base-block + image: ghcr.io/facebookincubator/velox-dev:centos9 build: - context: . - dockerfile: scripts/docker/adapters.dockerfile + dockerfile: scripts/docker/centos-multi.dockerfile + target: centos9 args: - image: ghcr.io/facebookincubator/velox-dev:centos9 + image: quay.io/centos/centos:stream9 + VELOX_BUILD_SHARED: "ON" + + adapters-cpp: + extends: centos-cpp + image: ghcr.io/facebookincubator/velox-dev:adapters environment: - NUM_THREADS: 8 # default value for NUM_THREADS - CCACHE_DIR: /velox/.ccache - EXTRA_CMAKE_FLAGS: > + EXTRA_CMAKE_FLAGS: >- -DVELOX_ENABLE_PARQUET=ON -DVELOX_ENABLE_S3=ON + build: + target: adapters + + adapters-cuda: + extends: adapters-cpp privileged: true deploy: resources: @@ -90,70 +76,28 @@ services: - driver: nvidia count: 1 capabilities: [gpu] - volumes: - - .:/velox:delegated - working_dir: /velox - command: /velox/scripts/docker/docker-command.sh - - centos-cpp: - # Usage: - # docker-compose pull centos-cpp or docker-compose build centos-cpp - # docker-compose run --rm centos-cpp - # or - # docker-compose run -e NUM_THREADS= --rm centos-cpp - # to set the number of threads used during compilation - image: ghcr.io/facebookincubator/velox-dev:centos9 - build: - context: . - dockerfile: scripts/docker/centos.dockerfile - args: - image: quay.io/centos/centos:stream9 - environment: - NUM_THREADS: 8 # default value for NUM_THREADS - CCACHE_DIR: /velox/.ccache - volumes: - - .:/velox:delegated - working_dir: /velox - command: /velox/scripts/docker/docker-command.sh presto-java: - # Usage: - # docker-compose pull presto-java or docker-compose build presto-java - # docker-compose run --rm presto-java - # or - # docker-compose run -e NUM_THREADS= --rm presto-java - # to set the number of threads used during compilation + extends: base-block image: ghcr.io/facebookincubator/velox-dev:presto-java build: args: - - PRESTO_VERSION=0.290 - context: . - dockerfile: scripts/docker/prestojava-container.dockerfile - environment: - NUM_THREADS: 8 # default value for NUM_THREADS - CCACHE_DIR: /velox/.ccache - volumes: - - .:/velox:delegated - working_dir: /velox - command: /velox/scripts/docker/docker-command.sh + - PRESTO_VERSION=0.295 + dockerfile: scripts/docker/java.dockerfile spark-server: - # Usage: - # docker-compose pull spark-server or docker-compose build spark-server - # docker-compose run --rm spark-server - # or - # docker-compose run -e NUM_THREADS= --rm spark-server - # to set the number of threads used during compilation + extends: base-block image: ghcr.io/facebookincubator/velox-dev:spark-server build: args: - SPARK_VERSION=3.5.1 - context: . - dockerfile: scripts/docker/spark-container.dockerfile + dockerfile: scripts/docker/java.dockerfile + + fedora: + extends: base-block + image: ghcr.io/facebookincubator/velox-dev:fedora + build: + dockerfile: scripts/docker/fedora.dockerfile + target: fedora environment: - NUM_THREADS: 8 # default value for NUM_THREADS - CCACHE_DIR: /velox/.ccache - volumes: - - .:/velox:delegated - working_dir: /velox - command: /velox/scripts/docker/docker-command.sh + VELOX_BUILD_SHARED: "ON" diff --git a/pyproject.toml b/pyproject.toml index 439d447e583d..d22044ea01b2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,6 +46,9 @@ classifiers = [ [project.urls] homepage = "https://github.com/facebookincubator/velox" +[project.optional-dependencies] +docs = ["sphinx", "sphinx-tabs", "breathe", "sphinx_rtd_theme", "chardet"] + [tool.scikit-build] build-dir = "_build/python" metadata.version.provider = "scikit_build_core.metadata.setuptools_scm" @@ -110,7 +113,7 @@ cmake.define.TREAT_WARNINGS_AS_ERRORS = false [tool.setuptools_scm] root = '.' version_scheme = 'guess-next-dev' -git_describe_command = 'git describe --dirty --tags --long --match "pyvelox-v[0-9]*.*"' +git_describe_command = ['git', 'describe', '--dirty', '--tags', '--long', '--match', 'pyvelox-v[0-9]*'] fallback_version = '0.2.0a' [tool.cibuildwheel] diff --git a/python/test/test_runner.py b/python/test/test_runner.py index bafde6730e10..59a5baf4a430 100644 --- a/python/test/test_runner.py +++ b/python/test/test_runner.py @@ -40,7 +40,7 @@ def setUp(self) -> None: def tearDown(self) -> None: unregister_all() - def test_runner_empty(self): + def test_empty(self): plan_builder = PlanBuilder().values() runner = LocalRunner(plan_builder.get_plan_node()) total_size = 0 @@ -49,19 +49,19 @@ def test_runner_empty(self): total_size += vector.size() self.assertEqual(total_size, 0) - def test_runner_not_executed(self): + def test_not_executed(self): # Ensure it won't hang on destruction when it was not executed. plan_builder = PlanBuilder().values() LocalRunner(plan_builder.get_plan_node()) - def test_runner_executed_twice(self): + def test_executed_twice(self): # Ensure the runner fails if it is executed twice. plan_builder = PlanBuilder().values() runner = LocalRunner(plan_builder.get_plan_node()) runner.execute() self.assertRaises(RuntimeError, runner.execute) - def test_runner_with_values(self): + def test_values(self): vectors = [] batch_size = 10 num_batches = 10 @@ -79,7 +79,7 @@ def test_runner_with_values(self): total_size += vector.size() self.assertEqual(total_size, 100) - def test_runner_with_values_order_limit(self): + def test_values_order_limit(self): vectors = [] batch_size = 10 num_batches = 10 @@ -103,7 +103,7 @@ def test_runner_with_values_order_limit(self): ) self.assertEqual(output, expected_result) - def test_runner_with_hash_join(self): + def test_hash_join(self): batch_size = 100 probe = list(range(batch_size)) build = [i for i in probe if i % 2 == 0] @@ -138,7 +138,7 @@ def test_runner_with_hash_join(self): result = int(vector.child_at(0)[0]) self.assertEqual(result, sum(build)) - def test_runner_with_merge_join(self): + def test_merge_join(self): batch_size = 10 array = pyarrow.array([42] * batch_size) batch = to_velox(pyarrow.record_batch([array], names=["c0"])) @@ -159,7 +159,7 @@ def test_runner_with_merge_join(self): total_size += vector.size() self.assertEqual(total_size, batch_size * batch_size) - def test_runner_with_merge_sort(self): + def test_merge_sort(self): array = pyarrow.array([0, 1, 2, 3, 4]) batch = to_velox(pyarrow.record_batch([array], names=["c0"])) @@ -186,6 +186,43 @@ def test_runner_with_merge_sort(self): ) self.assertEqual(output, expected) + def test_unnest_and_streaming_aggregate(self): + batch_size = 100 + base = list(range(batch_size)) + + input_vector = to_velox( + pyarrow.record_batch([pyarrow.array([base])], names=["c0"]) + ) + # Single row containing an array column with `batch_size` elements. + self.assertEqual(input_vector.size(), 1) + + # Unnest it and check we get batch_size rows. + plan_builder = PlanBuilder() + plan_builder.values([input_vector]).unnest(unnest_columns=["c0"]) + + runner = LocalRunner(plan_builder.get_plan_node()) + iterator = runner.execute() + vector = next(iterator) + + self.assertRaises(StopIteration, next, iterator) + self.assertEqual(vector.size(), batch_size) + + # Unnest then stream aggregate it back to ensure we get the input + # vector back. + plan_builder = PlanBuilder() + plan_builder.values([input_vector]).unnest( + unnest_columns=["c0"], + ).streaming_aggregate( + grouping_keys=[], + aggregations=["array_agg(c0_e)"], + ) + + runner = LocalRunner(plan_builder.get_plan_node()) + iterator = runner.execute() + vector = next(iterator) + + self.assertEqual(vector, input_vector) + def test_register_connectors(self): register_hive("conn1") self.assertRaises(RuntimeError, register_hive, "conn1") @@ -209,11 +246,12 @@ def test_write_read_file(self): input_batch = to_velox(pyarrow.record_batch([array], names=["c0"])) with tempfile.TemporaryDirectory() as temp_dir: - output_file = f"{temp_dir}/output_file" + output_file_name = "output_file" + output_file_path = f"{temp_dir}/{output_file_name}" plan_builder = PlanBuilder() plan_builder.values([input_batch]).table_write( - output_file=DWRF(output_file), + output_file=DWRF(output_file_path), connector_id="hive", ) @@ -225,14 +263,14 @@ def test_write_read_file(self): self.assertNotEqual(runner.print_plan_with_stats(), "") output_file_from_table_writer = self.extract_file(output) - self.assertEqual(output_file, output_file_from_table_writer) + self.assertEqual(output_file_path, output_file_from_table_writer) # Now scan it back. scan_plan_builder = PlanBuilder() scan_plan_builder.table_scan( output_schema=ROW(["c0"], [BIGINT()]), connector_id="hive", - input_files=[DWRF(output_file)], + input_files=[DWRF(output_file_path)], ) runner = LocalRunner(scan_plan_builder.get_plan_node()) @@ -243,6 +281,28 @@ def test_write_read_file(self): # Ensure the read batch is the same as the one written. self.assertEqual(input_batch, result) + # Test special columns ($row_group_id and $row_number). + special_column_plan_builder = PlanBuilder().table_scan( + output_schema=ROW( + ["$row_group_id", "row_number"], + [VARCHAR(), BIGINT()], + ), + row_index="row_number", + connector_id="hive", + input_files=[DWRF(output_file_path)], + ) + + runner = LocalRunner(special_column_plan_builder.get_plan_node()) + iterator = runner.execute() + result = next(iterator) + self.assertRaises(StopIteration, next, iterator) + + # First column is always the output file name; the second is a + # monotonically increasing row id. + for i in range(batch_size): + self.assertEqual(result.child_at(0)[i], output_file_name) + self.assertEqual(result.child_at(1)[i], str(i)) + def test_tpch_gen(self): register_tpch("tpch") register_hive("hive") diff --git a/python/test/test_type.py b/python/test/test_type.py index ff079cd1cb17..b059ee587236 100644 --- a/python/test/test_type.py +++ b/python/test/test_type.py @@ -26,6 +26,7 @@ DOUBLE, VARCHAR, VARBINARY, + JSON, ARRAY, MAP, ROW, @@ -43,6 +44,7 @@ def test_simple_types(self): self.assertTrue(isinstance(DOUBLE(), Type)) self.assertTrue(isinstance(VARCHAR(), Type)) self.assertTrue(isinstance(VARBINARY(), Type)) + self.assertTrue(isinstance(JSON(), Type)) def test_complex_types(self): self.assertTrue(isinstance(ARRAY(VARCHAR()), Type)) @@ -86,3 +88,5 @@ def test_equality(self): self.assertNotEqual(BIGINT(), INTEGER()) self.assertNotEqual(ARRAY(BIGINT()), REAL()) + self.assertNotEqual(VARBINARY(), VARCHAR()) + self.assertNotEqual(JSON(), VARCHAR()) diff --git a/scripts/checks/license-header.py b/scripts/checks/license-header.py index 6f2ebb52e809..03dcdaff48f6 100755 --- a/scripts/checks/license-header.py +++ b/scripts/checks/license-header.py @@ -100,7 +100,7 @@ def wrapper_hash(header, args): "*.cmake": attrdict({"wrapper": wrapper_hash, "hashbang": False}), "*.cpp": attrdict({"wrapper": wrapper_chpp, "hashbang": False}), "*.hpp": attrdict({"wrapper": wrapper_chpp, "hashbang": False}), - "*.dockfile": attrdict({"wrapper": wrapper_hash, "hashbang": False}), + "*.dockerfile": attrdict({"wrapper": wrapper_hash, "hashbang": False}), "*.h": attrdict({"wrapper": wrapper_chpp, "hashbang": False}), "*.inc": attrdict({"wrapper": wrapper_chpp, "hashbang": False}), "*.java": attrdict({"wrapper": wrapper_chpp, "hashbang": False}), @@ -110,6 +110,7 @@ def wrapper_hash(header, args): "*.thrift": attrdict({"wrapper": wrapper_chpp, "hashbang": False}), "*.txt": attrdict({"wrapper": wrapper_hash, "hashbang": True}), "*.yml": attrdict({"wrapper": wrapper_hash, "hashbang": False}), + "*.yaml": attrdict({"wrapper": wrapper_hash, "hashbang": False}), "*.cu": attrdict({"wrapper": wrapper_chpp, "hashbang": False}), "*.cuh": attrdict({"wrapper": wrapper_chpp, "hashbang": False}), "*.clcpp": attrdict({"wrapper": wrapper_chpp, "hashbang": False}), @@ -228,7 +229,13 @@ def main(): # If removing the header text, zero it out there. header_comment = "" - message(log_to, "Fix : " + filepath) + if os.environ.get("GITHUB_ACTIONS"): + print( + f"::error file={filepath},line=1,endLine=1,title=Missing license header::" + + "Please add the properly commented license header." + ) + else: + message(log_to, "Fix : " + filepath) if args.check: fail = True diff --git a/scripts/checks/run-clang-tidy.py b/scripts/checks/run-clang-tidy.py index 5d7016c3d9ed..c694ebacbd2a 100755 --- a/scripts/checks/run-clang-tidy.py +++ b/scripts/checks/run-clang-tidy.py @@ -15,75 +15,12 @@ import argparse import json -import regex +import re import sys +import os import util -CODE_CHECKS = """* - -abseil-* - -android-* - -cert-err58-cpp - -clang-analyzer-osx-* - -cppcoreguidelines-avoid-c-arrays - -cppcoreguidelines-avoid-magic-numbers - -cppcoreguidelines-pro-bounds-array-to-pointer-decay - -cppcoreguidelines-pro-bounds-pointer-arithmetic - -cppcoreguidelines-pro-type-reinterpret-cast - -cppcoreguidelines-pro-type-vararg - -fuchsia-* - -google-* - -hicpp-avoid-c-arrays - -hicpp-deprecated-headers - -hicpp-no-array-decay - -hicpp-use-equals-default - -hicpp-vararg - -llvmlibc-* - -llvm-header-guard - -llvm-include-order - -mpi-* - -misc-non-private-member-variables-in-classes - -misc-no-recursion - -misc-unused-parameters - -modernize-avoid-c-arrays - -modernize-deprecated-headers - -modernize-use-nodiscard - -modernize-use-trailing-return-type - -objc-* - -openmp-* - -readability-avoid-const-params-in-decls - -readability-convert-member-functions-to-static - -readability-magic-numbers - -zircon-* -""" - -# Additional opt-outs because googletest macros trip too many things. -# -TEST_CHECKS = ( - CODE_CHECKS - + """ - -cert-err58-cpp - -cppcoreguidelines-avoid-goto - -cppcoreguidelines-avoid-non-const-global-variables - -cppcoreguidelines-owning-memory - -cppcoreguidelines-pro-type-vararg - -cppcoreguidelines-special-member-functions - -hicpp-avoid-goto - -hicpp-special-member-functions - -hicpp-vararg - -misc-no-recursion - -readability-implicit-bool-conversion -""" -) - - -def check_list(check_string): - return ",".join([check.strip() for check in check_string.strip().splitlines()]) - - -CODE_CHECKS = check_list(CODE_CHECKS) -TEST_CHECKS = check_list(TEST_CHECKS) - class Multimap(dict): def __setitem__(self, key, value): @@ -101,15 +38,15 @@ def git_changed_lines(commit): line = line.rstrip("\n") fields = line.split() - match = regex.match(r"^\+\+\+ b/.*", line) + match = re.match(r"^\+\+\+ b/.*", line) if match: file = "" - match = regex.match(r"^\+\+\+ b/(.*(\.cpp|\.h))$", line) + match = re.match(r"^\+\+\+ b/(.*(\.cpp|\.h|\.hpp))$", line) if match: file = match.group(1) - match = regex.match(r"^@@", line) + match = re.match(r"^@@", line) if match and file != "": lspan = fields[2].split(",") if len(lspan) <= 1: @@ -117,60 +54,71 @@ def git_changed_lines(commit): changed_lines[file] = [int(lspan[0]), int(lspan[0]) + int(lspan[1])] - return json.dumps( - [{"name": key, "lines": value} for key, value in changed_lines.items()] - ) - - -def checks(args): - status, stdout, stderr = util.run( - f"clang-tidy -checks='{CODE_CHECKS}' --list-checks" - ) - print(stdout) + return changed_lines def check_output(output): - return regex.match(r"^/.* warning: ", output) + return re.match(r"(^/.* warning: |^$)", output) def tidy(args): files = util.input_files(args.files) + files = [file for file in files if re.match(r".*(\.cpp|\.h|\.hpp)$", file)] - groups = Multimap() + in_gha = os.environ.get("GITHUB_ACTIONS") is not None - for file in files: - groups["test" if "/tests/" in file else "main"] = file + changed_lines = git_changed_lines(args.commit) - fix = "--fix" if args.fix == "fix" else "" - lines = ( - ("'--line-filter=" + git_changed_lines(args.commit)) + "'" - if args.commit is not None - else "" + line_filter = json.dumps( + [{"name": key, "lines": value} for key, value in changed_lines.items()] ) + filtered_files = [*changed_lines.keys()] + if len(filtered_files) == 0: + return 0 + + fix = "--fix" if args.fix == "fix" else "" + lines = f"'--line-filter={line_filter}'" if args.commit is not None else "" ok = True - if groups.get("main", None): - status, stdout, stderr = util.run( - f"xargs clang-tidy -p=build/release/ --format-style=file -header-filter='.*' --checks='{CODE_CHECKS}' --quiet {fix} {lines}", - input=groups["main"], - ) - ok = check_output(stdout) and ok + build_path = args.p or os.getenv("BUILD_PATH") + build_path_str = f"-p {build_path}" if build_path else "" + + if build_path_str == "" and not os.path.isfile( + os.getcwd().join("compile_commands.json") + ): + print("compile_commands.json not found, skipping clang-tidy") + return 0 - if groups.get("test", None): - status, stdout, stderr = util.run( - f"xargs clang-tidy -p=build/release/ --format-style=file -header-filter='.*' --checks='{TEST_CHECKS}' --quiet {fix} {lines}", - input=groups["test"], + status, stdout, stderr = util.run( + f"xargs clang-tidy --format-style=file -header-filter='.*' --quiet {build_path_str} {fix} {lines}", + input=filtered_files, + ) + + if in_gha: + clang_tidy_pattern = ( + r"^(.*):(\d+):(\d+):\s+(error|warning):\s+(.*) \[([a-z0-9,\-]+)\]\s*$" ) - ok = check_output(stdout) and ok + + for stdout_line in stdout.split("\n"): + m = re.match(clang_tidy_pattern, stdout_line) + if m is not None: + file, line, col, severity, message, rule = m.groups() + file = file.removeprefix("/__w/velox/velox/") + print( + f"::{severity} file={file},line={line},col={col},title={rule}::{message}" + ) + + ok = check_output(stdout) return 0 if ok else 1 def parse_args(): global parser - parser = argparse.ArgumentParser(description="CircliCi Utility") + parser = argparse.ArgumentParser(description="Clang Tidy Utility") parser.add_argument("--commit") parser.add_argument("--fix") + parser.add_argument("-p", help="Path containing 'compile_commands.json'") parser.add_argument("files", metavar="FILES", nargs="+", help="files to process") diff --git a/scripts/checks/util.py b/scripts/checks/util.py index c0b34532a945..fe7f05c86dd7 100644 --- a/scripts/checks/util.py +++ b/scripts/checks/util.py @@ -15,7 +15,7 @@ import gzip import json import os -import regex +import re import subprocess import sys @@ -27,7 +27,7 @@ class attrdict(dict): class string(str): def extract(self, rexp): - return regex.match(rexp, self).group(1) + return re.match(rexp, self).group(1) def json(self): return json.loads(self, object_hook=attrdict) diff --git a/scripts/ci/bm-report/report.qmd b/scripts/ci/bm-report/report.qmd index a91a116a6b93..b9ad14c79a44 100644 --- a/scripts/ci/bm-report/report.qmd +++ b/scripts/ci/bm-report/report.qmd @@ -50,7 +50,7 @@ run_shas <- runs |> jsonlite::fromJSON() run_ids <- mruns(run_shas) |> - filter(commit.branch == "facebookincubator:main", substr(id, 1, 2) == "BM") |> + filter(substr(id, 1, 2) == "BM") |> pull(id) # Speed up local dev by saving 'results' as conbench requests can't be memoised diff --git a/scripts/ci/gersemi_cmd_definitions.py b/scripts/ci/gersemi_cmd_definitions.py new file mode 100644 index 000000000000..78aca426e6ae --- /dev/null +++ b/scripts/ci/gersemi_cmd_definitions.py @@ -0,0 +1,63 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from copy import deepcopy + +from gersemi.builtin_commands import builtin_commands + +# Gersemi throws a runtime error if two commands use the same builtin definition +# but changing the canonical name in a deepcopy prevents this +velox_add_library = deepcopy(builtin_commands["add_library"]) +velox_base_add_library = deepcopy(builtin_commands["add_library"]) +velox_add_library["_canonical_name"] = "velox_add_library" +velox_base_add_library["_canonical_name"] = "velox_base_add_library" + +pybind11_add_module = { + "front_positional_arguments": ["target_name"], + "options": [ + "MODULE", + "SHARED", + "EXCLUDE_FROM_ALL", + "THIN_LTO", + "NO_EXTRAS", + "OPT_SIZE", + ], + "back_positional_arguments": ["sources"], + "_canonical_name": "pybind11_add_module", +} +pyvelox_add_module = deepcopy(pybind11_add_module) +pyvelox_add_module["_canonical_name"] = "pyvelox_add_module" + +# Define the argument structure of our custom functions, this influences how they are formatted +command_definitions = { + "pybind11_add_module": pybind11_add_module, + "pyvelox_add_module": pyvelox_add_module, + "velox_add_library": velox_add_library, + "velox_base_add_library": velox_base_add_library, + "velox_build_dependency": { + "front_positional_arguments": ["dependency_name"], + }, + "velox_compile_definitions": builtin_commands["target_compile_definitions"], + "velox_get_rpath_origin": {"front_positional_arguments": ["output_variable"]}, + "velox_include_directories": builtin_commands["target_include_directories"], + "velox_install_library_headers": {"options": ["nothing"]}, + "velox_link_libraries": builtin_commands["target_link_libraries"], + "velox_resolve_dependency": builtin_commands["find_package"], + "velox_resolve_dependency_url": {"front_positional_arguments": ["dependency_name"]}, + "velox_set_source": {"front_positional_arguments": ["dependency_name"]}, + "velox_set_with_default": { + "front_positional_arguments": ["var_name", "envvar_name", "default"] + }, + "velox_sources": builtin_commands["target_sources"], +} diff --git a/scripts/ci/presto/start-prestojava.sh b/scripts/ci/presto/start-prestojava.sh index 4a02636aa820..30ef1ecfc965 100755 --- a/scripts/ci/presto/start-prestojava.sh +++ b/scripts/ci/presto/start-prestojava.sh @@ -15,4 +15,4 @@ set -e -"$PRESTO_HOME"/bin/launcher --pid-file=/tmp/pidfile run +"$PRESTO_HOME"/bin/launcher --pid-file=/tmp/pidfile run diff --git a/scripts/ci/signature.py b/scripts/ci/signature.py index daa876942f1a..f9534e88d0cc 100644 --- a/scripts/ci/signature.py +++ b/scripts/ci/signature.py @@ -337,7 +337,7 @@ def parse_args(args): bias_command_parser.add_argument("contender", type=str) bias_command_parser.add_argument("output_path", type=str) bias_command_parser.add_argument( - "ticket_value", type=get_tickets, default=10, nargs="?" + "ticket_value", type=get_tickets, default=20, nargs="?" ) bias_command_parser.add_argument("error_path", type=str, default="") diff --git a/scripts/ci/spark/start-spark.sh b/scripts/ci/spark/start-spark.sh index 7aca5999f8f9..7f081386452d 100755 --- a/scripts/ci/spark/start-spark.sh +++ b/scripts/ci/spark/start-spark.sh @@ -15,4 +15,4 @@ set -e -"$SPARK_HOME"/sbin/start-connect-server.sh --jars "$SPARK_HOME"/misc/spark-connect_2.12-3.5.1.jar +"$SPARK_HOME"/sbin/start-connect-server.sh --jars "$SPARK_HOME"/misc/spark-connect.jar diff --git a/scripts/docker/adapters.dockerfile b/scripts/docker/adapters.dockerfile deleted file mode 100644 index 47ce766b95bf..000000000000 --- a/scripts/docker/adapters.dockerfile +++ /dev/null @@ -1,61 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Build the test and build container for presto_cpp -ARG image=ghcr.io/facebookincubator/velox-dev:centos9 -FROM $image - -COPY scripts/setup-centos9.sh / -COPY scripts/setup-common.sh / -COPY scripts/setup-versions.sh / -COPY scripts/setup-helper-functions.sh / -RUN mkdir build && bash -c "{ \ - cd build && \ - source /opt/rh/gcc-toolset-12/enable && \ - source /setup-centos9.sh && \ - install_adapters && \ - install_cuda 12.8; \ - }" && \ - rm -rf build && dnf remove -y conda && dnf clean all - -# put CUDA binaries on the PATH -ENV PATH=/usr/local/cuda/bin:${PATH} - -# configuration for nvidia-container-toolkit -ENV NVIDIA_VISIBLE_DEVICES=all -ENV NVIDIA_DRIVER_CAPABILITIES="compute,utility" - -# install miniforge -RUN curl -L -o /tmp/miniforge.sh https://github.com/conda-forge/miniforge/releases/download/23.11.0-0/Mambaforge-23.11.0-0-Linux-x86_64.sh && \ - bash /tmp/miniforge.sh -b -p /opt/miniforge && \ - rm /tmp/miniforge.sh -ENV PATH=/opt/miniforge/condabin:${PATH} - -# install test dependencies -RUN mamba create -y --name adapters python=3.8 -SHELL ["mamba", "run", "-n", "adapters", "/bin/bash", "-c"] - -RUN pip install https://github.com/googleapis/storage-testbench/archive/refs/tags/v0.36.0.tar.gz -RUN mamba install -y nodejs -RUN npm install -g azurite - -ENV HADOOP_HOME=/usr/local/hadoop \ - HADOOP_ROOT_LOGGER="WARN,DRFA" \ - LC_ALL=C \ - PATH=/usr/local/hadoop/bin:${PATH} \ - JAVA_HOME=/usr/lib/jvm/java-1.8.0-openjdk \ - PATH=/usr/lib/jvm/java-1.8.0-openjdk/bin:${PATH} - -COPY scripts/setup-classpath.sh / -ENTRYPOINT ["/bin/bash", "-c", "source /setup-classpath.sh && source /opt/rh/gcc-toolset-12/enable && exec \"$@\"", "--"] -CMD ["/bin/bash"] diff --git a/scripts/docker/centos-multi.dockerfile b/scripts/docker/centos-multi.dockerfile new file mode 100644 index 000000000000..d957e70d511d --- /dev/null +++ b/scripts/docker/centos-multi.dockerfile @@ -0,0 +1,166 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Build the test and build container for presto_cpp + +# This multi-stage Dockerfile contains the following relevant targets: +# - centos9: Our base CI build +# - adapters: Based on centos9 with all optional dependencies installed +# - pyvelox: Image used by cibuildwheel to build pyvelox + +######################## +# Stage 1: Base Build # +######################## +ARG image=quay.io/centos/centos:stream9 +FROM $image AS base-build + +COPY scripts/setup-helper-functions.sh / +COPY scripts/setup-versions.sh / +COPY scripts/setup-common.sh / +COPY scripts/setup-centos9.sh / +COPY CMake/resolve_dependency_modules/arrow/cmake-compatibility.patch / + +ARG VELOX_BUILD_SHARED=ON +# Building libvelox.so requires folly and gflags to be built shared as well for now +ENV VELOX_BUILD_SHARED=${VELOX_BUILD_SHARED} + +RUN mkdir build +WORKDIR /build + +# We don't want UV symlinks to be copied into the following stages +ENV UV_TOOL_BIN_DIR=/usr/local/bin \ + UV_INSTALL_DIR=/usr/local/bin \ + INSTALL_PREFIX=/deps + +# CMake 4.0 removed support for cmake minimums of <=3.5 and will fail builds, this overrides it +ENV CMAKE_POLICY_VERSION_MINIMUM="3.5" \ + VELOX_ARROW_CMAKE_PATCH=/cmake-compatibility.patch + +# Some CMake configs contain the hard coded prefix '/deps', we need to replace that with +# the future location to avoid build errors in the base-image +RUN bash /setup-centos9.sh && \ + find $INSTALL_PREFIX/lib/cmake -type f -name '*.cmake' -exec sed -i 's|/deps/|/usr/local/|g' {} \; + +######################## +# Stage 2: Base Image # +######################## +FROM $image AS base-image + +COPY scripts/setup-helper-functions.sh / +COPY scripts/setup-versions.sh / +COPY scripts/setup-common.sh / +COPY scripts/setup-centos9.sh / + +# This way it's on the PATH and doesn't clash with the version installed in manylinux +ENV UV_TOOL_BIN_DIR=/usr/local/bin \ + UV_INSTALL_DIR=/usr/local/bin + +RUN /bin/bash -c 'source /setup-centos9.sh && \ + install_build_prerequisites && \ + install_velox_deps_from_dnf && \ + dnf clean all' + +RUN ln -s $(which python3) /usr/bin/python + +COPY --from=base-build /deps /usr/local + +######################## +# Stage: Centos 9 # +######################## +FROM base-image AS centos9 + +# Install tools needed for CI +RUN /bin/bash -c "source /setup-centos9.sh && \ + dnf_install 'dnf-command(config-manager)' && \ + dnf config-manager --add-repo 'https://cli.github.com/packages/rpm/gh-cli.repo' && \ + dnf_install gh jq && \ + dnf clean all" + +ENV CC=/opt/rh/gcc-toolset-12/root/bin/gcc \ + CXX=/opt/rh/gcc-toolset-12/root/bin/g++ + +ENTRYPOINT ["/bin/bash", "-c", "source /opt/rh/gcc-toolset-12/enable && exec \"$@\"", "--"] +CMD ["/bin/bash"] + +######################## +# Stage: PyVelox # +######################## +FROM base-image AS pyvelox + +ENV LD_LIBRARY_PATH="/usr/local/lib:/usr/local/lib64:$LD_LIBRARY_PATH" + +######################## +# Stage: Adapters Build# +######################## +FROM $image AS adapters-build + +COPY scripts/setup-helper-functions.sh / +COPY scripts/setup-versions.sh / +COPY scripts/setup-common.sh / +COPY scripts/setup-centos9.sh / +COPY scripts/setup-centos-adapters.sh / + +RUN mkdir build +WORKDIR /build + +ENV UV_TOOL_BIN_DIR=/usr/local/bin \ + UV_INSTALL_DIR=/usr/local/bin \ + INSTALL_PREFIX=/deps + +RUN bash -c 'source /setup-centos9.sh && install_build_prerequisites' + +RUN bash /setup-centos-adapters.sh install_adapters && \ + find $INSTALL_PREFIX/lib/cmake -type f -name '*.cmake' -exec sed -i 's|/deps/|/usr/local/|g' {} \; + +######################## +# Stage: Adapters # +######################## +FROM centos9 AS adapters + +COPY scripts/setup-centos-adapters.sh / + +RUN bash /setup-centos-adapters.sh install_cuda && \ + dnf clean all + +RUN bash /setup-centos-adapters.sh install_adapters_deps_from_dnf && \ + dnf clean all + +# put CUDA binaries on the PATH +ENV PATH=/usr/local/cuda/bin:${PATH} + +# configuration for nvidia-container-toolkit +ENV NVIDIA_VISIBLE_DEVICES=all +ENV NVIDIA_DRIVER_CAPABILITIES="compute,utility" + +# Install test dependencies +RUN uv pip install --system https://github.com/googleapis/storage-testbench/archive/refs/tags/v0.36.0.tar.gz + +RUN curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.40.3/install.sh | bash +RUN . /root/.nvm/nvm.sh && nvm install 22 + +# Install azurite and symlink nvm managed binaries into /usr/local/bin which is on the path +RUN . /root/.nvm/nvm.sh && npm install -g azurite && \ + ln -s $(dirname $(which node))/* /usr/local/bin + +ENV HADOOP_HOME=/usr/local/hadoop \ + HADOOP_ROOT_LOGGER="WARN,DRFA" \ + LC_ALL=C \ + PATH=/usr/local/hadoop/bin:${PATH} \ + JAVA_HOME=/usr/lib/jvm/java-1.8.0-openjdk \ + PATH=/usr/lib/jvm/java-1.8.0-openjdk/bin:${PATH} + +COPY --from=adapters-build /deps /usr/local + +COPY scripts/setup-classpath.sh / +ENTRYPOINT ["/bin/bash", "-c", "source /setup-classpath.sh && source /opt/rh/gcc-toolset-12/enable && exec \"$@\"", "--"] +CMD ["/bin/bash"] diff --git a/scripts/docker/centos.dockerfile b/scripts/docker/centos.dockerfile deleted file mode 100644 index 8d9ffb24563e..000000000000 --- a/scripts/docker/centos.dockerfile +++ /dev/null @@ -1,37 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Build the test and build container for presto_cpp -ARG image=quay.io/centos/centos:stream9 -FROM $image - -COPY scripts/setup-helper-functions.sh / -COPY scripts/setup-versions.sh / -COPY scripts/setup-common.sh / -COPY scripts/setup-centos9.sh / - -# Building libvelox.so requires folly and gflags to be built shared as well for now -ENV VELOX_BUILD_SHARED=ON -# The removal of the build dir has to happen in the same layer as the build -# to minimize the image size. gh & jq are required for CI -RUN mkdir build && ( cd build && bash /setup-centos9.sh ) && rm -rf build && \ - dnf install -y -q 'dnf-command(config-manager)' && \ - dnf config-manager --add-repo 'https://cli.github.com/packages/rpm/gh-cli.repo' && \ - dnf install -y -q gh jq && \ - dnf clean all - -ENV CC=/opt/rh/gcc-toolset-12/root/bin/gcc \ - CXX=/opt/rh/gcc-toolset-12/root/bin/g++ - -ENTRYPOINT ["/bin/bash", "-c", "source /opt/rh/gcc-toolset-12/enable && exec \"$@\"", "--"] -CMD ["/bin/bash"] diff --git a/scripts/docker/docker-command.sh b/scripts/docker/docker-command.sh index 8abc2af30188..88faf22c997d 100755 --- a/scripts/docker/docker-command.sh +++ b/scripts/docker/docker-command.sh @@ -16,4 +16,4 @@ set -eu # Compilation and testing make -cd _build/release && ctest -j${NUM_THREADS} --output-on-failure --no-tests=error +cd _build/release && ctest -j"${NUM_THREADS}" --output-on-failure --no-tests=error diff --git a/scripts/docker/fedora.dockerfile b/scripts/docker/fedora.dockerfile new file mode 100644 index 000000000000..28478c5857bd --- /dev/null +++ b/scripts/docker/fedora.dockerfile @@ -0,0 +1,72 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +######################## +# Stage 1: Base Build # +######################## +ARG base=quay.io/fedora/fedora:42-x86_64 +FROM $base AS base-build + +COPY scripts/setup-helper-functions.sh / +COPY scripts/setup-versions.sh / +COPY scripts/setup-common.sh / +COPY scripts/setup-centos9.sh / +COPY scripts/setup-fedora.sh / +COPY CMake/resolve_dependency_modules/arrow/cmake-compatibility.patch / + +ARG VELOX_BUILD_SHARED=ON +# Building libvelox.so requires folly and gflags to be built shared as well for now +ENV VELOX_BUILD_SHARED=${VELOX_BUILD_SHARED} + +RUN mkdir build +WORKDIR /build + +# We don't want UV symlinks to be copied into the following stages +ENV UV_TOOL_BIN_DIR=/usr/local/bin \ + UV_INSTALL_DIR=/usr/local/bin \ + INSTALL_PREFIX=/deps + +# CMake 4.0 removed support for cmake minimums of <=3.5 and will fail builds, this overrides it +ENV CMAKE_POLICY_VERSION_MINIMUM="3.5" \ + VELOX_ARROW_CMAKE_PATCH=/cmake-compatibility.patch + +# Some CMake configs contain the hard coded prefix '/deps', we need to replace that with +# the future location to avoid build errors in the base-image +RUN bash /setup-fedora.sh && \ + find $INSTALL_PREFIX/lib/cmake -type f -name '*.cmake' -exec sed -i 's|/deps/|/usr/local/|g' {} \; + +######################## +# Stage 2: Base Image # +######################## +FROM $base AS fedora + +COPY scripts/setup-helper-functions.sh / +COPY scripts/setup-versions.sh / +COPY scripts/setup-common.sh / +COPY scripts/setup-centos9.sh / +COPY scripts/setup-fedora.sh / + +# This way it's on the PATH and doesn't clash with the version installed in manylinux +ENV UV_TOOL_BIN_DIR=/usr/local/bin \ + UV_INSTALL_DIR=/usr/local/bin + +RUN /bin/bash -c 'source /setup-fedora.sh && \ + install_build_prerequisites && \ + install_velox_deps_from_dnf && \ + dnf_install jq gh &&\ + dnf clean all' + +RUN ln -s $(which python3) /usr/bin/python + +COPY --from=base-build /deps /usr/local diff --git a/scripts/docker/java.dockerfile b/scripts/docker/java.dockerfile new file mode 100644 index 000000000000..0d75fb216d6c --- /dev/null +++ b/scripts/docker/java.dockerfile @@ -0,0 +1,98 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Build the test and build container for presto_cpp + +# Global arg default to share across stages +ARG SPARK_VERSION=3.5.1 +ARG PRESTO_VERSION=0.295 + +######################### +# Stage: Spark Download # +######################### +# This allows us to cache the (slow) download until we change Spark version +FROM alpine:3.22 AS spark-download +ARG SPARK_VERSION + +RUN wget -O spark.tgz https://archive.apache.org/dist/spark/spark-${SPARK_VERSION}/spark-${SPARK_VERSION}-bin-hadoop3.tgz +RUN wget -O spark-connect.jar \ + https://repo1.maven.org/maven2/org/apache/spark/spark-connect_2.12/${SPARK_VERSION}/spark-connect_2.12-${SPARK_VERSION}.jar + +RUN tar -zxf spark.tgz + +########################## +# Stage: Presto Download # +########################## +# This allows us to cache the (slow) download until we change version +FROM alpine:3.22 AS presto-download +ARG PRESTO_VERSION + +RUN wget -O presto-server.tar.gz \ + https://repo1.maven.org/maven2/com/facebook/presto/presto-server/${PRESTO_VERSION}/presto-server-${PRESTO_VERSION}.tar.gz +RUN wget -O presto-cli \ + https://github.com/prestodb/presto/releases/download/${PRESTO_VERSION}/presto-cli-${PRESTO_VERSION}-executable.jar + +RUN tar -xzf presto-server.tar.gz + +######################### +# Stage: Java Base # +######################### +FROM ghcr.io/facebookincubator/velox-dev:centos9 AS java-base + +RUN dnf install -y -q --setopt=install_weak_deps=False java-17-openjdk less procps tzdata + +# We set the timezone to America/Los_Angeles due to issue +# detailed here : https://github.com/facebookincubator/velox/issues/8127 +ENV TZ=America/Los_Angeles + +######################### +# Stage: Spark Server # +######################### +FROM java-base AS spark-server +ARG SPARK_VERSION + +ENV SPARK_HOME="/opt/spark-server" + +COPY scripts/ci/spark/conf/spark-defaults.conf.example $SPARK_HOME/conf/spark-defaults.conf +COPY scripts/ci/spark/conf/spark-env.sh.example $SPARK_HOME/conf/spark-env.sh +COPY scripts/ci/spark/conf/workers.example $SPARK_HOME/conf/workers +COPY scripts/ci/spark/start-spark.sh /opt/ + +COPY --from=spark-download /spark-${SPARK_VERSION}-bin-hadoop3 $SPARK_HOME/ +COPY --from=spark-download /spark-connect.jar $SPARK_HOME/misc/ + +WORKDIR /velox + +######################### +# Stage: Presto Java # +######################### +FROM java-base AS presto-java +ARG PRESTO_VERSION + +ENV PRESTO_HOME="/opt/presto-server" + +COPY scripts/ci/presto/etc/config.properties.example $PRESTO_HOME/etc/config.properties +COPY scripts/ci/presto/etc/jvm.config.example $PRESTO_HOME/etc/jvm.config +COPY scripts/ci/presto/etc/node.properties $PRESTO_HOME/etc/node.properties +COPY scripts/ci/presto/etc/hive.properties $PRESTO_HOME/etc/catalog/ +COPY scripts/ci/presto/start-prestojava.sh /opt/ + +COPY --from=presto-download /presto-server-$PRESTO_VERSION $PRESTO_HOME/ +COPY --from=presto-download --chmod=755 /presto-cli /opt/ + +RUN ln -s /opt/presto-cli /usr/local/bin/ && \ + mkdir -p $PRESTO_HOME/etc/data && \ + mkdir -p /usr/lib/presto/utils + + +WORKDIR /velox diff --git a/scripts/docker/prestojava-container.dockerfile b/scripts/docker/prestojava-container.dockerfile deleted file mode 100644 index 373e37db6286..000000000000 --- a/scripts/docker/prestojava-container.dockerfile +++ /dev/null @@ -1,51 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Build the test and build container for presto_cpp -# -FROM ghcr.io/facebookincubator/velox-dev:centos9 - -ARG PRESTO_VERSION=0.290 - -COPY scripts /velox/scripts/ -RUN wget https://repo1.maven.org/maven2/com/facebook/presto/presto-server/${PRESTO_VERSION}/presto-server-${PRESTO_VERSION}.tar.gz -RUN wget https://repo1.maven.org/maven2/com/facebook/presto/presto-cli/${PRESTO_VERSION}/presto-cli-${PRESTO_VERSION}-executable.jar - -ARG PRESTO_PKG=presto-server-$PRESTO_VERSION.tar.gz -ARG PRESTO_CLI_JAR=presto-cli-$PRESTO_VERSION-executable.jar - -ENV PRESTO_HOME="/opt/presto-server" -RUN cp $PRESTO_CLI_JAR /opt/presto-cli - -RUN dnf install -y java-11-openjdk less procps python3 tzdata \ - && ln -s $(which python3) /usr/bin/python \ - && tar -zxf $PRESTO_PKG \ - && mv ./presto-server-$PRESTO_VERSION $PRESTO_HOME \ - && chmod +x /opt/presto-cli \ - && ln -s /opt/presto-cli /usr/local/bin/ \ - && mkdir -p $PRESTO_HOME/etc \ - && mkdir -p $PRESTO_HOME/etc/catalog \ - && mkdir -p $PRESTO_HOME/etc/data \ - && mkdir -p /usr/lib/presto/utils - -# We set the timezone to America/Los_Angeles due to issue -# detailed here : https://github.com/facebookincubator/velox/issues/8127 -ENV TZ=America/Los_Angeles - -COPY scripts/ci/presto/etc/config.properties.example $PRESTO_HOME/etc/config.properties -COPY scripts/ci/presto/etc/jvm.config.example $PRESTO_HOME/etc/jvm.config -COPY scripts/ci/presto/etc/node.properties $PRESTO_HOME/etc/node.properties -COPY scripts/ci/presto/etc/hive.properties $PRESTO_HOME/etc/catalog -COPY scripts/ci/presto/start-prestojava.sh /opt - -WORKDIR /velox diff --git a/scripts/docker/pyvelox.dockerfile b/scripts/docker/pyvelox.dockerfile deleted file mode 100644 index 5757345ef9b8..000000000000 --- a/scripts/docker/pyvelox.dockerfile +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Build the test and build container for presto_cpp -FROM quay.io/pypa/manylinux_2_28_x86_64:latest - -COPY scripts/setup-helper-functions.sh / -COPY scripts/setup-manylinux.sh / - -# Build static folly to reduce wheel size (folly.so is ~120M) -ENV VELOX_BUILD_SHARED=OFF -# The removal of the build dir has to happen in the same layer as the build -# to minimize the image size. gh & jq are required for CI -RUN mkdir build && ( cd build && bash /setup-manylinux.sh ) && rm -rf build && \ - dnf install -y -q 'dnf-command(config-manager)' && \ - dnf config-manager --add-repo 'https://cli.github.com/packages/rpm/gh-cli.repo' && \ - dnf install -y -q gh jq && \ - dnf clean all - -ENV LD_LIBRARY_PATH="/usr/local/lib:/usr/local/lib64:$LD_LIBRARY_PATH" diff --git a/scripts/docker/spark-container.dockerfile b/scripts/docker/spark-container.dockerfile deleted file mode 100644 index a9825628088a..000000000000 --- a/scripts/docker/spark-container.dockerfile +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Build the test and build container for presto_cpp -# -FROM ghcr.io/facebookincubator/velox-dev:centos9 - -ARG SPARK_VERSION=3.5.1 - -COPY scripts /velox/scripts/ -RUN wget https://archive.apache.org/dist/spark/spark-${SPARK_VERSION}/spark-${SPARK_VERSION}-bin-hadoop3.tgz -RUN wget https://repo1.maven.org/maven2/org/apache/spark/spark-connect_2.12/${SPARK_VERSION}/spark-connect_2.12-${SPARK_VERSION}.jar - -ARG SPARK_PKG=spark-${SPARK_VERSION}-bin-hadoop3.tgz -ARG SPARK_CONNECT_JAR=spark-connect_2.12-${SPARK_VERSION}.jar - -ENV SPARK_HOME="/opt/spark-server" - -RUN dnf install -y java-11-openjdk less procps python3 tzdata \ - && ln -s $(which python3) /usr/bin/python \ - && tar -zxf $SPARK_PKG \ - && mv ./spark-${SPARK_VERSION}-bin-hadoop3 $SPARK_HOME \ - && mkdir ${SPARK_HOME}/misc/ \ - && mv ./$SPARK_CONNECT_JAR ${SPARK_HOME}/misc/ - -# We set the timezone to America/Los_Angeles due to issue -# detailed here : https://github.com/facebookincubator/velox/issues/8127 -ENV TZ=America/Los_Angeles - -COPY scripts/ci/spark/conf/spark-defaults.conf.example $SPARK_HOME/conf/spark-defaults.conf -COPY scripts/ci/spark/conf/spark-env.sh.example $SPARK_HOME/conf/spark-env.sh -COPY scripts/ci/spark/conf/workers.example $SPARK_HOME/conf/workers -COPY scripts/ci/spark/start-spark.sh /opt - -WORKDIR /velox diff --git a/scripts/docker/ubuntu-22.04-cpp.dockerfile b/scripts/docker/ubuntu-22.04-cpp.dockerfile index 5cd1f76099ac..8175ecd6823b 100644 --- a/scripts/docker/ubuntu-22.04-cpp.dockerfile +++ b/scripts/docker/ubuntu-22.04-cpp.dockerfile @@ -14,8 +14,6 @@ ARG base=ubuntu:22.04 FROM ${base} -SHELL ["/bin/bash", "-o", "pipefail", "-c"] - RUN apt update && \ apt install -y sudo \ lsb-release \ @@ -24,6 +22,11 @@ RUN apt update && \ COPY scripts /velox/scripts/ +COPY CMake/resolve_dependency_modules/arrow/cmake-compatibility.patch / + +ENV VELOX_ARROW_CMAKE_PATCH=/cmake-compatibility.patch \ + UV_TOOL_BIN_DIR=/usr/local/bin \ + UV_INSTALL_DIR=/usr/local/bin # TZ and DEBIAN_FRONTEND="noninteractive" # are required to avoid tzdata installation @@ -32,6 +35,6 @@ ARG DEBIAN_FRONTEND="noninteractive" # Set a default timezone, can be overriden via ARG ARG tz="Etc/UTC" ENV TZ=${tz} -RUN /velox/scripts/setup-ubuntu.sh +RUN /bin/bash -o pipefail /velox/scripts/setup-ubuntu.sh WORKDIR /velox diff --git a/scripts/docker/yscope-velox-builder.dockerfile b/scripts/docker/yscope-velox-builder.dockerfile index 38f04afb6ca2..c70c13419210 100644 --- a/scripts/docker/yscope-velox-builder.dockerfile +++ b/scripts/docker/yscope-velox-builder.dockerfile @@ -21,32 +21,18 @@ SHELL ["/bin/bash", "-o", "pipefail", "-c"] ENV TZ=Etc/UTC ENV DEBIAN_FRONTEND=noninteractive -# Install CMake 3.28.3 using its install script -# NOTE: `scripts/setup-ubuntu.sh` installs CMake via pip, but sometimes the pip-installed CMake -# doesn't show up on the path in container environments (causing, for example, FastFloat library -# build failures). Using CMake's install script avoids this issue. -RUN curl --fail --location --show-error --silent --remote-name \ - https://github.com/Kitware/CMake/releases/download/v3.28.3/cmake-3.28.3-linux-x86_64.sh \ - && chmod +x cmake-3.28.3-linux-x86_64.sh \ - && ./cmake-3.28.3-linux-x86_64.sh --skip-license --prefix=/usr/local \ - && rm cmake-3.28.3-linux-x86_64.sh +# Copy dependency installation scripts and CMake modules together so that setup-common.sh's +# relative path resolution (SCRIPT_DIR/../CMake/...) works correctly (after rarely-changing +# layers for better caching) +COPY scripts /tmp/velox-deps/scripts/ +COPY CMake/resolve_dependency_modules /tmp/velox-deps/CMake/resolve_dependency_modules/ -# Copy dependency installation scripts (after rarely-changing layers for better caching) -COPY scripts /tmp/velox-deps/ +ENV UV_TOOL_BIN_DIR=/usr/local/bin +ENV UV_INSTALL_DIR=/usr/local/bin -RUN /tmp/velox-deps/setup-ubuntu.sh \ - && mv /tmp/.venv /opt/velox-venv \ +RUN /tmp/velox-deps/scripts/setup-ubuntu.sh \ && rm -rf /tmp/velox-deps -# Activate the virtual environment. -# -# NOTE: We set `ENV` variables directly rather than using `source /opt/velox-venv/bin/activate` in -# a `RUN` command since the latter only persists for that single instruction (each `RUN` starts a -# fresh shell), whereas the former persists across all subsequent `RUN` commands and in containers -# that use the image. -ENV VIRTUAL_ENV="/opt/velox-venv" -ENV PATH="${VIRTUAL_ENV}/bin:${PATH}" - ENV CCACHE_DIR=/var/cache/ccache # Disable compression to trade disk space for speed in CI builds @@ -63,7 +49,7 @@ ENV CCACHE_NOHASHDIR=true # - We clear the stats after warmup so that CI builds only show their own cache hits. COPY . /tmp/velox-src/ WORKDIR /tmp/velox-src -RUN CCACHE_BASEDIR=/tmp/velox-src make release \ +RUN CCACHE_BASEDIR=/tmp/velox-src make release TREAT_WARNINGS_AS_ERRORS=0 \ && echo "CCache statistics after warmup build:" \ && ccache --verbose --show-stats \ && ccache --zero-stats diff --git a/scripts/info.sh b/scripts/info.sh index effbd94277c5..132a0fc3ca61 100755 --- a/scripts/info.sh +++ b/scripts/info.sh @@ -25,16 +25,16 @@ fi info=$(cmake --system-information) ext() { - grep -oE '".+"$' $1 | tr -d '"' + grep -oE '".+"$' | tr -d '"' } print_info() { -echo "$info" | grep -e "$1" | ext + echo "$info" | grep -e "$1" | ext } result=" Velox System Info v${version} -Commit: $(git rev-parse HEAD 2> /dev/null || echo "Not in a git repo.") +Commit: $(git rev-parse HEAD 2>/dev/null || echo "Not in a git repo.") CMake Version: $(cmake --version | grep -oE '[[:digit:]]+\.[[:digit:]]+\.[[:digit:]]+') System: $(print_info 'CMAKE_SYSTEM "') Arch: $(print_info 'CMAKE_SYSTEM_PROCESSOR') @@ -61,14 +61,14 @@ all="$result $conda" echo "$all" if [ -x "$(command -v xclip)" ]; then - clip="xclip -selection c" + clip="xclip -selection c" elif [ -x "$(command -v pbcopy)" ]; then clip="pbcopy" else - echo "\nThe results will be copied to your clipboard if xclip is installed." + printf "\nThe results will be copied to your clipboard if xclip is installed." fi -if [ ! -z "$clip" ]; then +if [ -n "$clip" ]; then echo "$all" | $clip echo "Result copied to clipboard!" fi diff --git a/scripts/setup-centos-adapters.sh b/scripts/setup-centos-adapters.sh new file mode 100755 index 000000000000..1167c95f69db --- /dev/null +++ b/scripts/setup-centos-adapters.sh @@ -0,0 +1,113 @@ +#!/bin/bash +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# shellcheck source-path=SCRIPT_DIR + +# This script installs addition dependencies for the adapters build +# of Velox. The scrip expects base dependencies to already be installed. +# +# This script is split of from setup-centos9.sh to improve docker caching +# +# Environment variables: +# * INSTALL_PREREQUISITES="N": Skip installation of packages for build. +# * PROMPT_ALWAYS_RESPOND="n": Automatically respond to interactive prompts. +# Use "n" to never wipe directories. +# * VELOX_CUDA_VERSION="12.8": Which version of CUDA to install, will pick up +# CUDA_VERSION from the env + +set -efx -o pipefail + +VELOX_CUDA_VERSION=${CUDA_VERSION:-"12.8"} +SCRIPT_DIR=$(dirname "${BASH_SOURCE[0]}") +source "$SCRIPT_DIR"/setup-centos9.sh + +function install_cuda { + # See https://developer.nvidia.com/cuda-downloads + local arch + arch="$(uname -m)" + local repo_url + version="${1:-$VELOX_CUDA_VERSION}" + + if [[ $arch == "x86_64" ]]; then + repo_url="https://developer.download.nvidia.com/compute/cuda/repos/rhel9/x86_64/cuda-rhel9.repo" + elif [[ $arch == "aarch64" ]]; then + # Using SBSA (Server Base System Architecture) repository for ARM64 servers + repo_url="https://developer.download.nvidia.com/compute/cuda/repos/rhel9/sbsa/cuda-rhel9.repo" + else + echo "Unsupported architecture: $arch" >&2 + return 1 + fi + + dnf config-manager --add-repo "$repo_url" + local dashed + dashed="$(echo "$version" | tr '.' '-')" + dnf_install \ + cuda-compat-"$dashed" \ + cuda-driver-devel-"$dashed" \ + cuda-minimal-build-"$dashed" \ + cuda-nvrtc-devel-"$dashed" \ + libcufile-devel-"$dashed" \ + libnvjitlink-devel-"$dashed" \ + numactl-libs +} + +function install_adapters_deps_from_dnf { + local gcs_deps=(curl-devel c-ares-devel re2-devel) + local azure_deps=(perl-IPC-Cmd openssl-devel libxml2-devel) + local hdfs_deps=(libxml2-devel libgsasl-devel libuuid-devel krb5-devel java-1.8.0-openjdk-devel) + + dnf_install "${azure_deps[@]}" "${gcs_deps[@]}" "${hdfs_deps[@]}" +} + +function install_s3 { + install_aws_deps + local MINIO_OS="linux" + install_minio ${MINIO_OS} +} + +function install_adapters { + run_and_time install_adapters_deps_from_dnf + run_and_time install_s3 + run_and_time install_gcs_sdk_cpp + run_and_time install_azure_storage_sdk_cpp + run_and_time install_hdfs_deps +} + +(return 2>/dev/null) && return # If script was sourced, don't run commands. + +( + if [[ $# -ne 0 ]]; then + # Activate gcc12; enable errors on unset variables afterwards. + source /opt/rh/gcc-toolset-12/enable || exit 1 + set -u + + for cmd in "$@"; do + run_and_time "${cmd}" + done + echo "All specified dependencies installed!" + else + # Activate gcc12; enable errors on unset variables afterwards. + source /opt/rh/gcc-toolset-12/enable || exit 1 + set -u + install_cuda "$VELOX_CUDA_VERSION" + install_adapters + echo "All dependencies for the Velox Adapters installed!" + if [[ ${USE_CLANG} != "false" ]]; then + echo "To use clang for the Velox build set the CC and CXX environment variables in your session." + echo " export CC=/usr/bin/clang-15" + echo " export CXX=/usr/bin/clang++-15" + fi + dnf clean all + fi +) diff --git a/scripts/setup-centos9.sh b/scripts/setup-centos9.sh index c0cd353fb667..54f6f94d697c 100755 --- a/scripts/setup-centos9.sh +++ b/scripts/setup-centos9.sh @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# shellcheck source-path=SCRIPT_DIR # This script documents setting up a Centos9 host for Velox # development. Running it should make you ready to compile. @@ -28,15 +29,16 @@ set -efx -o pipefail # Some of the packages must be build with the same compiler flags # so that some low level types are the same size. Also, disable warnings. -SCRIPTDIR=$(dirname "${BASH_SOURCE[0]}") -source $SCRIPTDIR/setup-common.sh -export CXXFLAGS=$(get_cxx_flags) # Used by boost. -export CFLAGS=${CXXFLAGS//"-std=c++17"/} # Used by LZO. +SCRIPT_DIR=$(dirname "${BASH_SOURCE[0]}") +source "$SCRIPT_DIR"/setup-common.sh +CXXFLAGS=$(get_cxx_flags) # Used by boost. +export CXXFLAGS export COMPILER_FLAGS=${CXXFLAGS} SUDO="${SUDO:-""}" USE_CLANG="${USE_CLANG:-false}" export INSTALL_PREFIX=${INSTALL_PREFIX:-"/usr/local"} DEPENDENCY_DIR=${DEPENDENCY_DIR:-$(pwd)/deps-download} +export UV_TOOL_BIN_DIR="${UV_TOOL_BIN_DIR:-"$INSTALL_PREFIX"/bin}" function dnf_install { dnf install -y -q --setopt=install_weak_deps=False "$@" @@ -50,12 +52,15 @@ function install_clang15 { function install_build_prerequisites { dnf update -y dnf_install epel-release dnf-plugins-core # For ccache, ninja - dnf config-manager --set-enabled crb - dnf update -y - dnf_install ninja-build cmake ccache gcc-toolset-12 git wget which - dnf_install autoconf automake python3-devel pip libtool + if grep -q CentOS /etc/os-release; then + dnf config-manager --set-enabled crb + dnf update -y + fi + dnf_install autoconf automake ccache clang gcc-toolset-12 gcc-toolset-14 git libtool \ + llvm ninja-build python3-pip python3-devel wget which - pip install cmake==3.30.4 + install_uv + uv_install cmake@3.31.1 if [[ ${USE_CLANG} != "false" ]]; then install_clang15 @@ -69,8 +74,7 @@ function install_velox_deps_from_dnf { libdwarf-devel elfutils-libelf-devel curl-devel libicu-devel bison flex \ libsodium-devel zlib-devel gtest-devel gmock-devel xxhash-devel - # install sphinx for doc gen - pip install sphinx sphinx-tabs breathe sphinx_rtd_theme + install_faiss_deps } function install_conda { @@ -80,51 +84,12 @@ function install_conda { function install_gflags { # Remove an older version if present. dnf remove -y gflags - wget_and_untar https://github.com/gflags/gflags/archive/${GFLAGS_VERSION}.tar.gz gflags + wget_and_untar https://github.com/gflags/gflags/archive/"${GFLAGS_VERSION}".tar.gz gflags cmake_install_dir gflags -DBUILD_SHARED_LIBS=ON -DBUILD_STATIC_LIBS=ON -DBUILD_gflags_LIB=ON -DLIB_SUFFIX=64 } -function install_cuda { - # See https://developer.nvidia.com/cuda-downloads - dnf config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel9/x86_64/cuda-rhel9.repo - local dashed="$(echo $1 | tr '.' '-')" - dnf install -y \ - cuda-compat-$dashed \ - cuda-driver-devel-$dashed \ - cuda-minimal-build-$dashed \ - cuda-nvrtc-devel-$dashed -} - -function install_s3 { - install_aws_deps - - local MINIO_OS="linux" - install_minio ${MINIO_OS} -} - -function install_gcs { - # Dependencies of GCS, probably a workaround until the docker image is rebuilt - dnf -y install npm curl-devel c-ares-devel - install_gcs-sdk-cpp -} - -function install_abfs { - # Dependencies of Azure Storage Blob cpp - dnf -y install perl-IPC-Cmd openssl libxml2-devel - install_azure-storage-sdk-cpp -} - -function install_hdfs { - dnf -y install libxml2-devel libgsasl-devel libuuid-devel krb5-devel - install_hdfs_deps - yum install -y java-1.8.0-openjdk-devel -} - -function install_adapters { - run_and_time install_s3 - run_and_time install_gcs - run_and_time install_abfs - run_and_time install_hdfs +function install_faiss_deps { + dnf_install openblas-devel libomp } function install_velox_deps { @@ -132,7 +97,6 @@ function install_velox_deps { run_and_time install_conda run_and_time install_gflags run_and_time install_glog - run_and_time install_lzo run_and_time install_snappy run_and_time install_boost run_and_time install_protobuf @@ -150,9 +114,10 @@ function install_velox_deps { run_and_time install_xsimd run_and_time install_simdjson run_and_time install_geos + run_and_time install_faiss } -(return 2> /dev/null) && return # If script was sourced, don't run commands. +(return 2>/dev/null) && return # If script was sourced, don't run commands. ( if [[ $# -ne 0 ]]; then diff --git a/scripts/setup-classpath.sh b/scripts/setup-classpath.sh index bfd7066dc63c..fbb99f6751a8 100644 --- a/scripts/setup-classpath.sh +++ b/scripts/setup-classpath.sh @@ -11,5 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# shellcheck shell=bash -export CLASSPATH=`/usr/local/hadoop/bin/hdfs classpath --glob` +CLASSPATH=$(/usr/local/hadoop/bin/hdfs classpath --glob) +export CLASSPATH diff --git a/scripts/setup-common.sh b/scripts/setup-common.sh index d9a02c47f0c5..b9ae39ed89ae 100755 --- a/scripts/setup-common.sh +++ b/scripts/setup-common.sh @@ -12,153 +12,205 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# trigger reinstall +# shellcheck source-path=SCRIPT_DIR -SCRIPTDIR=$(dirname "${BASH_SOURCE[0]}") -source $SCRIPTDIR/setup-helper-functions.sh -source $SCRIPTDIR/setup-versions.sh +SCRIPT_DIR=$(dirname "${BASH_SOURCE[0]}") +source "$SCRIPT_DIR"/setup-helper-functions.sh +source "$SCRIPT_DIR"/setup-versions.sh -VELOX_BUILD_SHARED=${VELOX_BUILD_SHARED:-"OFF"} #Build folly and gflags shared for use in libvelox.so. +VELOX_BUILD_SHARED=${VELOX_BUILD_SHARED:-"OFF"} #Build folly and gflags shared for use in libvelox.so. +VELOX_ARROW_CMAKE_PATCH=${VELOX_ARROW_CMAKE_PATCH:-""} # avoid error due to +u CMAKE_BUILD_TYPE="${BUILD_TYPE:-Release}" DEPENDENCY_DIR=${DEPENDENCY_DIR:-$(pwd)} BUILD_GEOS="${BUILD_GEOS:-true}" +BUILD_FAISS="${BUILD_FAISS:-true}" BUILD_DUCKDB="${BUILD_DUCKDB:-true}" EXTRA_ARROW_OPTIONS=${EXTRA_ARROW_OPTIONS:-""} +SIMDJSON_SKIPUTF8VALIDATION=${SIMDJSON_SKIPUTF8VALIDATION:-"OFF"} USE_CLANG="${USE_CLANG:-false}" MACHINE=$(uname -m) -WGET_OPTIONS=${WGET_OPTIONS:-""} +# Read WGET_OPTIONS into an array which can be expanded to nothing +# using a normal variable expands into an empty string causing wget to exit 1 +read -r -a WGET_OPTS <<<"${WGET_OPTIONS:-}" mkdir -p "${DEPENDENCY_DIR}" function install_fmt { - wget_and_untar https://github.com/fmtlib/fmt/archive/${FMT_VERSION}.tar.gz fmt + wget_and_untar https://github.com/fmtlib/fmt/archive/"${FMT_VERSION}".tar.gz fmt cmake_install_dir fmt -DFMT_TEST=OFF } function install_folly { # Folly Portability.h being used to decide whether or not support coroutines # causes issues (build, link) if the selection is not consistent across users of folly. + # shellcheck disable=SC2034 EXTRA_PKG_CXXFLAGS=" -DFOLLY_CFG_NO_COROUTINES" - wget_and_untar https://github.com/facebook/folly/archive/refs/tags/${FB_OS_VERSION}.tar.gz folly + wget_and_untar https://github.com/facebook/folly/archive/refs/tags/"${FB_OS_VERSION}".tar.gz folly cmake_install_dir folly -DBUILD_SHARED_LIBS="$VELOX_BUILD_SHARED" -DBUILD_TESTS=OFF -DFOLLY_HAVE_INT128_T=ON } function install_fizz { # Folly Portability.h being used to decide whether or not support coroutines # causes issues (build, link) if the selection is not consistent across users of folly. + # shellcheck disable=SC2034 EXTRA_PKG_CXXFLAGS=" -DFOLLY_CFG_NO_COROUTINES" - wget_and_untar https://github.com/facebookincubator/fizz/archive/refs/tags/${FB_OS_VERSION}.tar.gz fizz + wget_and_untar https://github.com/facebookincubator/fizz/archive/refs/tags/"${FB_OS_VERSION}".tar.gz fizz cmake_install_dir fizz/fizz -DBUILD_TESTS=OFF } function install_fast_float { - wget_and_untar https://github.com/fastfloat/fast_float/archive/refs/tags/${FAST_FLOAT_VERSION}.tar.gz fast_float + wget_and_untar https://github.com/fastfloat/fast_float/archive/refs/tags/"${FAST_FLOAT_VERSION}".tar.gz fast_float cmake_install_dir fast_float -DBUILD_TESTS=OFF } function install_wangle { # Folly Portability.h being used to decide whether or not support coroutines # causes issues (build, link) if the selection is not consistent across users of folly. + # shellcheck disable=SC2034 EXTRA_PKG_CXXFLAGS=" -DFOLLY_CFG_NO_COROUTINES" - wget_and_untar https://github.com/facebook/wangle/archive/refs/tags/${FB_OS_VERSION}.tar.gz wangle + wget_and_untar https://github.com/facebook/wangle/archive/refs/tags/"${FB_OS_VERSION}".tar.gz wangle cmake_install_dir wangle/wangle -DBUILD_TESTS=OFF } function install_mvfst { # Folly Portability.h being used to decide whether or not support coroutines # causes issues (build, link) if the selection is not consistent across users of folly. + # shellcheck disable=SC2034 EXTRA_PKG_CXXFLAGS=" -DFOLLY_CFG_NO_COROUTINES" - wget_and_untar https://github.com/facebook/mvfst/archive/refs/tags/${FB_OS_VERSION}.tar.gz mvfst + wget_and_untar https://github.com/facebook/mvfst/archive/refs/tags/"${FB_OS_VERSION}".tar.gz mvfst cmake_install_dir mvfst -DBUILD_TESTS=OFF } function install_fbthrift { # Folly Portability.h being used to decide whether or not support coroutines # causes issues (build, link) if the selection is not consistent across users of folly. + # shellcheck disable=SC2034 EXTRA_PKG_CXXFLAGS=" -DFOLLY_CFG_NO_COROUTINES" - wget_and_untar https://github.com/facebook/fbthrift/archive/refs/tags/${FB_OS_VERSION}.tar.gz fbthrift + wget_and_untar https://github.com/facebook/fbthrift/archive/refs/tags/"${FB_OS_VERSION}".tar.gz fbthrift cmake_install_dir fbthrift -Denable_tests=OFF -DBUILD_TESTS=OFF -DBUILD_SHARED_LIBS=OFF } function install_duckdb { - if $BUILD_DUCKDB ; then - wget_and_untar https://github.com/duckdb/duckdb/archive/refs/tags/${DUCKDB_VERSION}.tar.gz duckdb - cmake_install_dir duckdb -DBUILD_UNITTESTS=OFF -DENABLE_SANITIZER=OFF -DENABLE_UBSAN=OFF -DBUILD_SHELL=OFF -DEXPORT_DLL_SYMBOLS=OFF -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} + if $BUILD_DUCKDB; then + wget_and_untar https://github.com/duckdb/duckdb/archive/refs/tags/"${DUCKDB_VERSION}".tar.gz duckdb + # DuckDB uses git commands to retrieve version information during the build, + # which works with git clone. To prevent incorrectly using the parent project's + # git version when building from a tarball, we define GIT_COMMIT_HASH to skip + # that. + cmake_install_dir duckdb \ + -DGIT_COMMIT_HASH="6536a77" -DBUILD_UNITTESTS=OFF -DENABLE_SANITIZER=OFF -DENABLE_UBSAN=OFF \ + -DBUILD_SHELL=OFF -DEXPORT_DLL_SYMBOLS=OFF -DCMAKE_BUILD_TYPE="${CMAKE_BUILD_TYPE}" fi } function install_boost { - wget_and_untar https://github.com/boostorg/boost/releases/download/${BOOST_VERSION}/${BOOST_VERSION}.tar.gz boost + wget_and_untar https://github.com/boostorg/boost/releases/download/"${BOOST_VERSION}"/"${BOOST_VERSION}".tar.gz boost ( - cd ${DEPENDENCY_DIR}/boost + cd "${DEPENDENCY_DIR}"/boost || exit if [[ "$(uname)" == "Linux" && ${USE_CLANG} != "false" ]]; then - ./bootstrap.sh --prefix=${INSTALL_PREFIX} --with-toolset="clang-15" + ./bootstrap.sh --prefix="${INSTALL_PREFIX}" --with-toolset="clang-15" # Switch the compiler from the clang-15 toolset which doesn't exist (clang-15.jam) to # clang of version 15 when toolset clang-15 is used. # This reconciles the project-config.jam generation with what the b2 build system allows for customization. sed -i 's/using clang-15/using clang : 15/g' project-config.jam ${SUDO} ./b2 "-j${NPROC}" -d0 install threading=multi toolset=clang-15 --without-python else - ./bootstrap.sh --prefix=${INSTALL_PREFIX} + ./bootstrap.sh --prefix="${INSTALL_PREFIX}" ${SUDO} ./b2 "-j${NPROC}" -d0 install threading=multi --without-python fi ) } function install_protobuf { - wget_and_untar https://github.com/protocolbuffers/protobuf/releases/download/v${PROTOBUF_VERSION}/protobuf-all-${PROTOBUF_VERSION}.tar.gz protobuf + install_abseil + + wget_and_untar https://github.com/protocolbuffers/protobuf/releases/download/v"${PROTOBUF_VERSION}"/protobuf-all-"${PROTOBUF_VERSION}".tar.gz protobuf cmake_install_dir protobuf -Dprotobuf_BUILD_TESTS=OFF -Dprotobuf_ABSL_PROVIDER=package } function install_double_conversion { - wget_and_untar https://github.com/google/double-conversion/archive/refs/tags/${DOUBLE_CONVERSION_VERSION}.tar.gz double-conversion + wget_and_untar https://github.com/google/double-conversion/archive/refs/tags/"${DOUBLE_CONVERSION_VERSION}".tar.gz double-conversion cmake_install_dir double-conversion -DBUILD_TESTING=OFF } function install_ranges_v3 { - wget_and_untar https://github.com/ericniebler/range-v3/archive/refs/tags/${RANGE_V3_VERSION}.tar.gz ranges_v3 + wget_and_untar https://github.com/ericniebler/range-v3/archive/refs/tags/"${RANGE_V3_VERSION}".tar.gz ranges_v3 cmake_install_dir ranges_v3 -DRANGES_ENABLE_WERROR=OFF -DRANGE_V3_TESTS=OFF -DRANGE_V3_EXAMPLES=OFF } +function install_abseil { + wget_and_untar https://github.com/abseil/abseil-cpp/archive/refs/tags/"${ABSEIL_VERSION}".tar.gz abseil-cpp + local OS + OS=$(uname) + if [[ $OS == "Darwin" ]]; then + ABSOLUTE_SCRIPTDIR=$(realpath "$SCRIPT_DIR") + ( + cd "${DEPENDENCY_DIR}/abseil-cpp" || exit 1 + git apply $ABSOLUTE_SCRIPTDIR/../CMake/resolve_dependency_modules/absl/absl-macos.patch + ) + fi + cmake_install_dir abseil-cpp \ + -DABSL_BUILD_TESTING=OFF \ + -DCMAKE_CXX_STANDARD=17 \ + -DABSL_PROPAGATE_CXX_STD=ON \ + -DABSL_ENABLE_INSTALL=ON +} + function install_re2 { - wget_and_untar https://github.com/google/re2/archive/refs/tags/${RE2_VERSION}.tar.gz re2 - cmake_install_dir re2 -DRE2_BUILD_TESTING=OFF + install_abseil + + wget_and_untar https://github.com/google/re2/archive/refs/tags/"${RE2_VERSION}".tar.gz re2 + cmake_install_dir re2 -DRE2_BUILD_TESTING=OFF -Dabsl_DIR="${INSTALL_PREFIX}/lib/cmake/absl" } function install_glog { - wget_and_untar https://github.com/google/glog/archive/${GLOG_VERSION}.tar.gz glog + wget_and_untar https://github.com/google/glog/archive/"${GLOG_VERSION}".tar.gz glog cmake_install_dir glog -DBUILD_SHARED_LIBS=ON } function install_lzo { - wget_and_untar http://www.oberhumer.com/opensource/lzo/download/lzo-${LZO_VERSION}.tar.gz lzo + wget_and_untar http://www.oberhumer.com/opensource/lzo/download/lzo-"${LZO_VERSION}".tar.gz lzo ( - cd ${DEPENDENCY_DIR}/lzo - ./configure --prefix=${INSTALL_PREFIX} --enable-shared --disable-static --docdir=/usr/share/doc/lzo-${LZO_VERSION} + cd "${DEPENDENCY_DIR}"/lzo || exit + ./configure --prefix="${INSTALL_PREFIX}" --enable-shared --disable-static --docdir=/usr/share/doc/lzo-"${LZO_VERSION}" make "-j${NPROC}" ${SUDO} make install ) } function install_snappy { - wget_and_untar https://github.com/google/snappy/archive/${SNAPPY_VERSION}.tar.gz snappy + wget_and_untar https://github.com/google/snappy/archive/"${SNAPPY_VERSION}".tar.gz snappy cmake_install_dir snappy -DSNAPPY_BUILD_TESTS=OFF } function install_xsimd { - wget_and_untar https://github.com/xtensor-stack/xsimd/archive/refs/tags/${XSIMD_VERSION}.tar.gz xsimd + wget_and_untar https://github.com/xtensor-stack/xsimd/archive/refs/tags/"${XSIMD_VERSION}".tar.gz xsimd cmake_install_dir xsimd } function install_simdjson { - wget_and_untar https://github.com/simdjson/simdjson/archive/refs/tags/v${SIMDJSON_VERSION}.tar.gz simdjson - cmake_install_dir simdjson + wget_and_untar https://github.com/simdjson/simdjson/archive/refs/tags/v"${SIMDJSON_VERSION}".tar.gz simdjson + cmake_install_dir simdjson -DSIMDJSON_SKIPUTF8VALIDATION=${SIMDJSON_SKIPUTF8VALIDATION} } function install_arrow { - wget_and_untar https://github.com/apache/arrow/archive/apache-arrow-${ARROW_VERSION}.tar.gz arrow + wget_and_untar https://github.com/apache/arrow/archive/apache-arrow-"${ARROW_VERSION}".tar.gz arrow + ( + # Can be removed after an upgrade to Arrow 20.0.0 + if [ -z "$VELOX_ARROW_CMAKE_PATCH" ]; then + # We need to set a different path when building the Dockerfile. + ABSOLUTE_SCRIPTDIR=$(realpath "$SCRIPT_DIR") + VELOX_ARROW_CMAKE_PATCH="$ABSOLUTE_SCRIPTDIR/../CMake/resolve_dependency_modules/arrow/cmake-compatibility.patch" + fi + + cd "$DEPENDENCY_DIR"/arrow || exit 1 + git apply "$VELOX_ARROW_CMAKE_PATCH" + ) || exit 1 + cmake_install_dir arrow/cpp \ -DARROW_PARQUET=OFF \ -DARROW_WITH_THRIFT=ON \ @@ -171,74 +223,84 @@ function install_arrow { -DARROW_RUNTIME_SIMD_LEVEL=NONE \ -DARROW_WITH_UTF8PROC=OFF \ -DARROW_TESTING=ON \ - -DCMAKE_INSTALL_PREFIX=${INSTALL_PREFIX} \ + -DCMAKE_INSTALL_PREFIX="$INSTALL_PREFIX" \ -DCMAKE_BUILD_TYPE=Release \ -DARROW_BUILD_STATIC=ON \ - -DBOOST_ROOT=${INSTALL_PREFIX} + -DBOOST_ROOT="$INSTALL_PREFIX" \ + $EXTRA_ARROW_OPTIONS } function install_thrift { - wget_and_untar https://github.com/apache/thrift/archive/${THRIFT_VERSION}.tar.gz thrift - - EXTRA_CXXFLAGS="-O3 -fPIC" - # Clang will generate warnings and they need to be suppressed, otherwise the build will fail. - if [[ ${USE_CLANG} != "false" ]]; then - EXTRA_CXXFLAGS="-O3 -fPIC -Wno-inconsistent-missing-override -Wno-unused-but-set-variable" - fi - - CXX_FLAGS="$EXTRA_CXXFLAGS" cmake_install_dir thrift \ - -DBUILD_SHARED_LIBS=OFF \ - -DBUILD_COMPILER=ON \ - -DBUILD_EXAMPLES=OFF \ - -DBUILD_TUTORIALS=OFF \ - -DCMAKE_DEBUG_POSTFIX= \ - -DWITH_AS3=OFF \ - -DWITH_CPP=ON \ - -DWITH_C_GLIB=OFF \ - -DWITH_JAVA=OFF \ - -DWITH_JAVASCRIPT=OFF \ - -DWITH_LIBEVENT=OFF \ - -DWITH_NODEJS=OFF \ - -DWITH_PYTHON=OFF \ - -DWITH_QT5=OFF \ - -DWITH_ZLIB=OFF \ - ${EXTRA_ARROW_OPTIONS} + wget_and_untar https://github.com/apache/thrift/archive/"${THRIFT_VERSION}".tar.gz thrift + + EXTRA_CXXFLAGS="-O3 -fPIC" + # Clang will generate warnings and they need to be suppressed, otherwise the build will fail. + if [[ ${USE_CLANG} != "false" ]]; then + EXTRA_CXXFLAGS="-O3 -fPIC -Wno-inconsistent-missing-override -Wno-unused-but-set-variable" + fi + + CXX_FLAGS="$EXTRA_CXXFLAGS" cmake_install_dir thrift \ + -DBUILD_SHARED_LIBS=OFF \ + -DBUILD_COMPILER=ON \ + -DBUILD_EXAMPLES=OFF \ + -DBUILD_TUTORIALS=OFF \ + -DCMAKE_DEBUG_POSTFIX= \ + -DWITH_AS3=OFF \ + -DWITH_CPP=ON \ + -DWITH_C_GLIB=OFF \ + -DWITH_JAVA=OFF \ + -DWITH_JAVASCRIPT=OFF \ + -DWITH_LIBEVENT=OFF \ + -DWITH_NODEJS=OFF \ + -DWITH_PYTHON=OFF \ + -DWITH_QT5=OFF \ + -DWITH_ZLIB=OFF } function install_stemmer { - wget_and_untar https://snowballstem.org/dist/libstemmer_c-${STEMMER_VERSION}.tar.gz stemmer + wget_and_untar https://snowballstem.org/dist/libstemmer_c-"${STEMMER_VERSION}".tar.gz stemmer ( - cd ${DEPENDENCY_DIR}/stemmer - sed -i '/CPPFLAGS=-Iinclude/ s/$/ -fPIC/' Makefile + cd "${DEPENDENCY_DIR}"/stemmer || exit + sed -i='' '/CPPFLAGS=-Iinclude/ s/$/ -fPIC/' Makefile make clean && make "-j${NPROC}" - ${SUDO} cp libstemmer.a ${INSTALL_PREFIX}/lib/ - ${SUDO} cp include/libstemmer.h ${INSTALL_PREFIX}/include/ + ${SUDO} cp libstemmer.a "${INSTALL_PREFIX}"/lib/ + ${SUDO} cp include/libstemmer.h "${INSTALL_PREFIX}"/include/ ) } function install_geos { - if [[ "$BUILD_GEOS" == "true" ]]; then - wget_and_untar https://github.com/libgeos/geos/archive/${GEOS_VERSION}.tar.gz geos - if [[ "$(uname)" == "Darwin" ]]; then - ABSOLUTE_SCRIPTDIR=$(realpath ${SCRIPTDIR}) - ( - # Adopted from the bundled patching needed for macOS. - cd "${DEPENDENCY_DIR}/geos" || exit 1 - git apply "${ABSOLUTE_SCRIPTDIR}/../CMake/resolve_dependency_modules/geos/geos-cmakelists.patch" - git apply "${ABSOLUTE_SCRIPTDIR}/../CMake/resolve_dependency_modules/geos/geos-build.patch" - ) - fi + if [[ $BUILD_GEOS == "true" ]]; then + wget_and_untar https://github.com/libgeos/geos/archive/"${GEOS_VERSION}".tar.gz geos cmake_install_dir geos -DBUILD_TESTING=OFF fi } +function install_faiss_deps { + echo "Unsupported platform for faiss" +} + +function install_faiss { + if [[ $BUILD_FAISS == "true" ]]; then + # Install OpenBLAS and libomp if not already installed + install_faiss_deps + + wget_and_untar "https://github.com/facebookresearch/faiss/archive/refs/tags/v${FAISS_VERSION}.tar.gz" faiss + cmake_install_dir faiss \ + -DFAISS_ENABLE_GPU=OFF \ + -DFAISS_ENABLE_PYTHON=OFF \ + -DFAISS_ENABLE_REMOTE=OFF \ + -DFAISS_ENABLE_GPU_TESTS=OFF \ + -DFAISS_ENABLE_BENCHMARKS=OFF + fi +} + # Adapters that can be installed. function install_aws_deps { local AWS_REPO_NAME="aws/aws-sdk-cpp" - github_checkout $AWS_REPO_NAME $AWS_SDK_VERSION --depth 1 --recurse-submodules - cmake_install -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} -DBUILD_SHARED_LIBS:BOOL=OFF -DMINIMIZE_SIZE:BOOL=ON -DENABLE_TESTING:BOOL=OFF -DBUILD_ONLY:STRING="s3;identity-management" + github_checkout $AWS_REPO_NAME "$AWS_SDK_VERSION" --depth 1 --recurse-submodules + cmake_install_dir aws-sdk-cpp -DCMAKE_BUILD_TYPE="${CMAKE_BUILD_TYPE}" -DBUILD_SHARED_LIBS:BOOL=OFF -DMINIMIZE_SIZE:BOOL=ON -DENABLE_TESTING:BOOL=OFF -DBUILD_ONLY:STRING="s3;identity-management" } function install_minio { @@ -253,32 +315,28 @@ function install_minio { echo "Unsupported Minio platform" fi - wget ${WGET_OPTIONS} https://dl.min.io/server/minio/release/${MINIO_OS}-${MINIO_ARCH}/archive/minio.RELEASE.${MINIO_VERSION} -O ${MINIO_BINARY_NAME} - chmod +x ./${MINIO_BINARY_NAME} - ${SUDO} mv ./${MINIO_BINARY_NAME} /usr/local/bin/ + wget "${WGET_OPTS[@]}" https://dl.min.io/server/minio/release/"${MINIO_OS}"-${MINIO_ARCH}/archive/minio.RELEASE."${MINIO_VERSION}" -O "${MINIO_BINARY_NAME}" + chmod +x ./"${MINIO_BINARY_NAME}" + mkdir -p "$INSTALL_PREFIX"/bin/ + ${SUDO} mv ./"${MINIO_BINARY_NAME}" "$INSTALL_PREFIX"/bin/ } -function install_gcs-sdk-cpp { +function install_gcs_sdk_cpp { # Install gcs dependencies # https://github.com/googleapis/google-cloud-cpp/blob/main/doc/packaging.md#required-libraries # abseil-cpp - github_checkout abseil/abseil-cpp ${ABSEIL_VERSION} --depth 1 - cmake_install \ - -DABSL_BUILD_TESTING=OFF \ - -DCMAKE_CXX_STANDARD=17 \ - -DABSL_PROPAGATE_CXX_STD=ON \ - -DABSL_ENABLE_INSTALL=ON + install_abseil # protobuf - github_checkout protocolbuffers/protobuf v${PROTOBUF_VERSION} --depth 1 - cmake_install \ + github_checkout protocolbuffers/protobuf v"${PROTOBUF_VERSION}" --depth 1 + cmake_install_dir protobuf \ -Dprotobuf_BUILD_TESTS=OFF \ -Dprotobuf_ABSL_PROVIDER=package # grpc - github_checkout grpc/grpc ${GRPC_VERSION} --depth 1 - cmake_install \ + github_checkout grpc/grpc "${GRPC_VERSION}" --depth 1 + cmake_install_dir grpc \ -DgRPC_BUILD_TESTS=OFF \ -DgRPC_ABSL_PROVIDER=package \ -DgRPC_ZLIB_PROVIDER=package \ @@ -289,71 +347,95 @@ function install_gcs-sdk-cpp { -DgRPC_INSTALL=ON # crc32 - github_checkout google/crc32c ${CRC32_VERSION} --depth 1 - cmake_install \ + github_checkout google/crc32c "${CRC32_VERSION}" --depth 1 + cmake_install_dir crc32c \ -DCRC32C_BUILD_TESTS=OFF \ -DCRC32C_BUILD_BENCHMARKS=OFF \ -DCRC32C_USE_GLOG=OFF # nlohmann json - github_checkout nlohmann/json ${NLOHMAN_JSON_VERSION} --depth 1 - cmake_install \ + github_checkout nlohmann/json "${NLOHMAN_JSON_VERSION}" --depth 1 + cmake_install_dir json \ -DJSON_BuildTests=OFF # google-cloud-cpp - github_checkout googleapis/google-cloud-cpp ${GOOGLE_CLOUD_CPP_VERSION} --depth 1 - cmake_install \ + github_checkout googleapis/google-cloud-cpp "${GOOGLE_CLOUD_CPP_VERSION}" --depth 1 + cmake_install_dir google-cloud-cpp \ -DGOOGLE_CLOUD_CPP_ENABLE_EXAMPLES=OFF \ -DGOOGLE_CLOUD_CPP_ENABLE=storage } -function install_azure-storage-sdk-cpp { +function install_azure_storage_sdk_cpp { # Disable VCPKG to install additional static dependencies under the VCPKG installed path # instead of using system pre-installed dependencies. export AZURE_SDK_DISABLE_AUTO_VCPKG=ON vcpkg_commit_id=7a6f366cefd27210f6a8309aed10c31104436509 - github_checkout azure/azure-sdk-for-cpp azure-storage-files-datalake_${AZURE_SDK_VERSION} - sed -i "s/set(VCPKG_COMMIT_STRING .*)/set(VCPKG_COMMIT_STRING $vcpkg_commit_id)/" cmake-modules/AzureVcpkg.cmake + github_checkout azure/azure-sdk-for-cpp azure-storage-files-datalake_"${AZURE_SDK_VERSION}" + pushd azure-sdk-for-cpp || exit + sed -i='' "s/set(VCPKG_COMMIT_STRING .*)/set(VCPKG_COMMIT_STRING $vcpkg_commit_id)/" cmake-modules/AzureVcpkg.cmake azure_core_dir="sdk/core/azure-core" if ! grep -q "baseline" $azure_core_dir/vcpkg.json; then # build and install azure-core with the version compatible with system pre-installed openssl openssl_version=$(openssl version -v | awk '{print $2}') - if [[ "$openssl_version" == 1.1.1* ]]; then + if [[ $openssl_version == 1.1.1* ]]; then openssl_version="1.1.1n" fi - sed -i "s/\"version-string\"/\"builtin-baseline\": \"$vcpkg_commit_id\",\"version-string\"/" $azure_core_dir/vcpkg.json - sed -i "s/\"version-string\"/\"overrides\": [{ \"name\": \"openssl\", \"version-string\": \"$openssl_version\" }],\"version-string\"/" $azure_core_dir/vcpkg.json + sed -i='' "s/\"version-string\"/\"builtin-baseline\": \"$vcpkg_commit_id\",\"version-string\"/" $azure_core_dir/vcpkg.json + sed -i='' "s/\"version-string\"/\"overrides\": [{ \"name\": \"openssl\", \"version-string\": \"$openssl_version\" }],\"version-string\"/" $azure_core_dir/vcpkg.json fi ( - cd $azure_core_dir - cmake_install -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} -DBUILD_SHARED_LIBS=OFF + cd $azure_core_dir || exit + cmake_install -DCMAKE_BUILD_TYPE="${CMAKE_BUILD_TYPE}" -DBUILD_SHARED_LIBS=OFF ) # install azure-identity ( - cd sdk/identity/azure-identity - cmake_install -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} -DBUILD_SHARED_LIBS=OFF + cd sdk/identity/azure-identity || exit + cmake_install -DCMAKE_BUILD_TYPE="${CMAKE_BUILD_TYPE}" -DBUILD_SHARED_LIBS=OFF ) # install azure-storage-common ( - cd sdk/storage/azure-storage-common - cmake_install -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} -DBUILD_SHARED_LIBS=OFF + cd sdk/storage/azure-storage-common || exit + cmake_install -DCMAKE_BUILD_TYPE="${CMAKE_BUILD_TYPE}" -DBUILD_SHARED_LIBS=OFF ) # install azure-storage-blobs ( - cd sdk/storage/azure-storage-blobs - cmake_install -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} -DBUILD_SHARED_LIBS=OFF + cd sdk/storage/azure-storage-blobs || exit + cmake_install -DCMAKE_BUILD_TYPE="${CMAKE_BUILD_TYPE}" -DBUILD_SHARED_LIBS=OFF ) # install azure-storage-files-datalake ( - cd sdk/storage/azure-storage-files-datalake - cmake_install -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} -DBUILD_SHARED_LIBS=OFF + cd sdk/storage/azure-storage-files-datalake || exit + cmake_install -DCMAKE_BUILD_TYPE="${CMAKE_BUILD_TYPE}" -DBUILD_SHARED_LIBS=OFF ) + popd || exit } function install_hdfs_deps { # Dependencies for Hadoop testing - wget_and_untar https://archive.apache.org/dist/hadoop/common/hadoop-${HADOOP_VERSION}/hadoop-${HADOOP_VERSION}.tar.gz hadoop - cp -a ${DEPENDENCY_DIR}/hadoop /usr/local/ - wget ${WGET_OPTIONS} -P /usr/local/hadoop/share/hadoop/common/lib/ https://repo1.maven.org/maven2/junit/junit/4.11/junit-4.11.jar + wget_and_untar https://dlcdn.apache.org/hadoop/common/hadoop-"${HADOOP_VERSION}"/hadoop-"${HADOOP_VERSION}".tar.gz hadoop + cp -a "${DEPENDENCY_DIR}"/hadoop "$INSTALL_PREFIX" + wget "${WGET_OPTS[@]}" -P "$INSTALL_PREFIX"/hadoop/share/hadoop/common/lib/ https://repo1.maven.org/maven2/junit/junit/4.11/junit-4.11.jar +} + +function install_uv { + if command -v uv >/dev/null 2>&1; then + echo "uv is already installed." + else + echo "Installing uv..." + + export UV_TOOL_BIN_DIR="${UV_TOOL_BIN_DIR:-$INSTALL_PREFIX/bin}" + export UV_INSTALL_DIR=${UV_INSTALL_DIR:-"$UV_TOOL_BIN_DIR"} + + curl -LsSf https://astral.sh/uv/install.sh | sh + uv tool update-shell + fi +} + +function uv_install { + uv tool install "$@" || { + ret=$? + # exit code 2 means the binary already exists, so we can ignore that + [ "$ret" -eq 2 ] || exit "$ret" + } } diff --git a/scripts/setup-fedora.sh b/scripts/setup-fedora.sh new file mode 100755 index 000000000000..80d958f19a4b --- /dev/null +++ b/scripts/setup-fedora.sh @@ -0,0 +1,132 @@ +#!/bin/bash +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# shellcheck source-path=SCRIPT_DIR + +# This script documents setting up a Fedora host for Velox +# development. Running it should make you ready to compile. +# +# Environment variables: +# * INSTALL_PREREQUISITES="N": Skip installation of packages for build. +# * PROMPT_ALWAYS_RESPOND="n": Automatically respond to interactive prompts. +# Use "n" to never wipe directories. +# +# You can also run individual functions below by specifying them as arguments: +# $ scripts/setup-fedora.sh install_googletest install_fmt +# + +set -efx -o pipefail +# Some of the packages must be build with the same compiler flags +# so that some low level types are the same size. Also, disable warnings. +SCRIPT_DIR=$(dirname "${BASH_SOURCE[0]}") +source "$SCRIPT_DIR"/setup-centos9.sh + +# Install packages required for build. +function install_build_prerequisites { + dnf update -y + dnf_install dnf-plugins-core # For ccache, ninja + dnf_install autoconf automake ccache git g++-14 libtool \ + ninja-build python3-pip python3-devel wget which + + # Set up gcc alternatives + sudo alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-14 140 \ + --slave /usr/bin/g++ g++ /usr/bin/g++-14 \ + --slave /usr/bin/gcov gcov /usr/bin/gcov-14 \ + --slave /usr/bin/gcc-ar gcc-ar /usr/bin/gcc-ar-14 \ + --slave /usr/bin/gcc-ranlib gcc-ranlib /usr/bin/gcc-ranlib-14 \ + --slave /usr/bin/gcc-nm gcc-nm /usr/bin/gcc-nm-14 + + install_uv + uv_install cmake + + if [[ ${USE_CLANG} != "false" ]]; then + install_clang15 + fi +} + +function install_velox_deps_from_dnf { + dnf_install \ + bison boost-devel c-ares-devel curl-devel double-conversion-devel \ + elfutils-libelf-devel flex fmt-devel gflags-devel glog-devel gmock-devel \ + gtest-devel libdwarf-devel libevent-devel libicu-devel \ + libsodium-devel libzstd-devel lz4-devel openssl-devel-engine \ + re2-devel snappy-devel thrift-devel xxhash-devel zlib-devel grpc-devel grpc-plugins + + install_faiss_deps +} + +function install_velox_deps { + run_and_time install_velox_deps_from_dnf + run_and_time install_gcs_sdk_cpp #grpc, abseil, protobuf + run_and_time install_fast_float + run_and_time install_folly + run_and_time install_fizz + run_and_time install_wangle + run_and_time install_mvfst + run_and_time install_fbthrift + run_and_time install_duckdb + run_and_time install_stemmer + run_and_time install_arrow + run_and_time install_xsimd # to new in fedora repos + run_and_time install_simdjson # to new in fedora repos + run_and_time install_geos # to new in fedora repos + run_and_time install_faiss +} + +(return 2>/dev/null) && return # If script was sourced, don't run commands. + +( + if [[ $# -ne 0 ]]; then + if [[ ${USE_CLANG} != "false" ]]; then + export CC=/usr/bin/clang-15 + export CXX=/usr/bin/clang++-15 + else + # Activate gcc 14 + export CXX=/usr/bin/gcc-14 + export CXX=/usr/bin/g++-14 + set -u + fi + + for cmd in "$@"; do + run_and_time "${cmd}" + done + echo "All specified dependencies installed!" + else + if [ "${INSTALL_PREREQUISITES:-Y}" == "Y" ]; then + echo "Installing build dependencies" + run_and_time install_build_prerequisites + else + echo "Skipping installation of build dependencies since INSTALL_PREREQUISITES is not set" + fi + if [[ ${USE_CLANG} != "false" ]]; then + export CC=/usr/bin/clang-15 + export CXX=/usr/bin/clang++-15 + else + # Activate gcc 14 + export CXX=/usr/bin/gcc-14 + export CXX=/usr/bin/g++-14 + set -u + fi + install_velox_deps + # BUILD_TESTING requires grpc + dnf_install grpc + echo "All dependencies for Velox installed!" + if [[ ${USE_CLANG} != "false" ]]; then + echo "To use clang for the Velox build set the CC and CXX environment variables in your session." + echo " export CC=/usr/bin/clang-15" + echo " export CXX=/usr/bin/clang++-15" + fi + dnf clean all + fi +) diff --git a/scripts/setup-helper-functions.sh b/scripts/setup-helper-functions.sh index d5867e736783..a50fb02ae0ec 100755 --- a/scripts/setup-helper-functions.sh +++ b/scripts/setup-helper-functions.sh @@ -23,10 +23,16 @@ NPROC=${BUILD_THREADS:-$(getconf _NPROCESSORS_ONLN)} CURL_OPTIONS=${CURL_OPTIONS:-""} CMAKE_OPTIONS=${CMAKE_OPTIONS:-""} +ARM_BUILD_TARGET=${ARM_BUILD_TARGET:-"local"} function run_and_time { - time "$@" || (echo "Failed to run $* ." ; exit 1 ) - { echo "+ Finished running $*"; } 2> /dev/null + time "$@" + if [ $? -ne 0 ]; then + echo "Failed to run $* ." + exit 1 + else + echo "+ Finished running $*" 2>/dev/null + fi } function prompt { @@ -34,14 +40,14 @@ function prompt { while true; do local input="${PROMPT_ALWAYS_RESPOND:-}" echo -n "$(tput bold)$* [Y, n]$(tput sgr0) " - [[ -z "${input}" ]] && read input - if [[ "${input}" == "Y" || "${input}" == "y" || "${input}" == "" ]]; then + [[ -z ${input} ]] && read input + if [[ ${input} == "Y" || ${input} == "y" || ${input} == "" ]]; then return 0 - elif [[ "${input}" == "N" || "${input}" == "n" ]]; then + elif [[ ${input} == "N" || ${input} == "n" ]]; then return 1 fi done - ) 2> /dev/null + ) 2>/dev/null } function github_checkout { @@ -49,10 +55,11 @@ function github_checkout { shift local VERSION=$1 shift - local GIT_CLONE_PARAMS=$@ - local DIRNAME=$(basename $REPO) + local GIT_CLONE_PARAMS=("$@") + local DIRNAME + DIRNAME=$(basename "$REPO") SUDO="${SUDO:-""}" - cd "${DEPENDENCY_DIR}" + cd "${DEPENDENCY_DIR}" || exit if [ -z "${DIRNAME}" ]; then echo "Failed to get repo name from ${REPO}" exit 1 @@ -61,9 +68,8 @@ function github_checkout { ${SUDO} rm -rf "${DIRNAME}" fi if [ ! -d "${DIRNAME}" ]; then - git clone -q -b $VERSION $GIT_CLONE_PARAMS "https://github.com/${REPO}.git" + git clone -q -b "$VERSION" "${GIT_CLONE_PARAMS[@]}" "https://github.com/${REPO}.git" fi - cd "${DIRNAME}" } # get_cxx_flags [$CPU_ARCH] @@ -81,130 +87,167 @@ function github_checkout { # CXX_FLAGS=$(get_cxx_flags) or # CXX_FLAGS=$(get_cxx_flags "avx") +# shellcheck disable=SC2120 function get_cxx_flags { local CPU_ARCH=${1:-""} - local OS=$(uname) - local MACHINE=$(uname -m) - - if [[ -z "$CPU_ARCH" ]]; then - if [ "$OS" = "Darwin" ]; then - if [ "$MACHINE" = "arm64" ]; then - CPU_ARCH="arm64" - else # x86_64 - local CPU_CAPABILITIES=$(sysctl -a | grep machdep.cpu.features | awk '{print tolower($0)}') - if [[ $CPU_CAPABILITIES =~ "avx" ]]; then - CPU_ARCH="avx" - else - CPU_ARCH="sse" - fi - fi - elif [ "$OS" = "Linux" ]; then - if [ "$MACHINE" = "aarch64" ]; then - CPU_ARCH="aarch64" - else # x86_64 - local CPU_CAPABILITIES=$(cat /proc/cpuinfo | grep flags | head -n 1| awk '{print tolower($0)}') - if [[ $CPU_CAPABILITIES =~ "avx" ]]; then - CPU_ARCH="avx" - elif [[ $CPU_CAPABILITIES =~ "sse" ]]; then - CPU_ARCH="sse" - fi - fi - else - echo "Unsupported platform $OS"; exit 1; - fi + local OS + OS=$(uname) + local MACHINE + MACHINE=$(uname -m) + + if [[ -z $CPU_ARCH ]]; then + if [ "$OS" = "Darwin" ]; then + if [ "$MACHINE" = "arm64" ]; then + CPU_ARCH="arm64" + else # x86_64 + local CPU_CAPABILITIES + CPU_CAPABILITIES=$(sysctl -a | grep machdep.cpu.features | awk '{print tolower($0)}') + if [[ $CPU_CAPABILITIES =~ "avx" ]]; then + CPU_ARCH="avx" + else + CPU_ARCH="sse" + fi + fi + elif [ "$OS" = "Linux" ]; then + if [ "$MACHINE" = "aarch64" ]; then + CPU_ARCH="aarch64" + else # x86_64 + local CPU_CAPABILITIES + CPU_CAPABILITIES=$(cat /proc/cpuinfo | grep flags | head -n 1 | awk '{print tolower($0)}') + if [[ $CPU_CAPABILITIES =~ "avx" ]]; then + CPU_ARCH="avx" + elif [[ $CPU_CAPABILITIES =~ "sse" ]]; then + CPU_ARCH="sse" + fi + fi + else + echo "Unsupported platform $OS" + exit 1 + fi fi case $CPU_ARCH in - "arm64") - echo -n "-mcpu=apple-m1+crc" + "arm64") + echo -n "-mcpu=apple-m1+crc" ;; - "avx") - echo -n "-mavx2 -mfma -mavx -mf16c -mlzcnt -mbmi2" + "avx") + echo -n "-mavx2 -mfma -mavx -mf16c -mlzcnt -mbmi2" ;; - "sse") - echo -n "-msse4.2 " + "sse") + echo -n "-msse4.2 " ;; - "aarch64") - # Read Arm MIDR_EL1 register to detect Arm cpu. - # https://developer.arm.com/documentation/100616/0301/register-descriptions/aarch64-system-registers/midr-el1--main-id-register--el1 - ARM_CPU_FILE="/sys/devices/system/cpu/cpu0/regs/identification/midr_el1" - - # https://gitlab.arm.com/telemetry-solution/telemetry-solution/-/blob/main/data/pmu/cpu/neoverse/neoverse-n1.json#L13 - # N1:d0c; N2:d49; V1:d40; - Neoverse_N1="d0c" - Neoverse_N2="d49" - Neoverse_V1="d40" - Neoverse_V2="d4f" - if [ -f "$ARM_CPU_FILE" ]; then - hex_ARM_CPU_DETECT=`cat $ARM_CPU_FILE` - # PartNum, [15:4]: The primary part number such as Neoverse N1/N2 core. - ARM_CPU_PRODUCT=${hex_ARM_CPU_DETECT: -4:3} - - if [ "$ARM_CPU_PRODUCT" = "$Neoverse_N1" ]; then - echo -n "-mcpu=neoverse-n1 " - elif [ "$ARM_CPU_PRODUCT" = "$Neoverse_N2" ]; then - echo -n "-mcpu=neoverse-n2 " - elif [ "$ARM_CPU_PRODUCT" = "$Neoverse_V1" ]; then - echo -n "-mcpu=neoverse-v1 " - elif [ "$ARM_CPU_PRODUCT" = "$Neoverse_V2" ]; then - # Read the JEDEC JEP-106 manufacturer ID to distinguish different Neoverse V2 cores - # https://developer.arm.com/documentation/ka001301/latest/ - SOC_ID_FILE="/sys/devices/soc0/soc_id" - GRACE_SOC_ID="jep106:036b:0241" - # Check for NVIDIA Grace which has various extensions - if [ -f "$SOC_ID_FILE" ] && [ "$(cat $SOC_ID_FILE)" = "$GRACE_SOC_ID" ]; then - echo -n "-mcpu=neoverse-v2+crypto+sha3+sm4+sve2-aes+sve2-sha3+sve2-sm4" - else - echo -n "-mcpu=neoverse-v2 " - fi + "aarch64") + # Read Arm MIDR_EL1 register to detect Arm cpu. + # https://developer.arm.com/documentation/100616/0301/register-descriptions/aarch64-system-registers/midr-el1--main-id-register--el1 + ARM_CPU_FILE="/sys/devices/system/cpu/cpu0/regs/identification/midr_el1" + + # https://gitlab.arm.com/telemetry-solution/telemetry-solution/-/blob/main/data/pmu/cpu/neoverse/neoverse-n1.json#L13 + # N1:d0c; N2:d49; V1:d40; + Neoverse_N1="d0c" + Neoverse_N2="d49" + Neoverse_V1="d40" + Neoverse_V2="d4f" + if [ -f "$ARM_CPU_FILE" ] && [ "$ARM_BUILD_TARGET" = "local" ]; then + hex_ARM_CPU_DETECT=$(cat $ARM_CPU_FILE) + # PartNum, [15:4]: The primary part number such as Neoverse N1/N2 core. + ARM_CPU_PRODUCT=${hex_ARM_CPU_DETECT: -4:3} + + if [ "$ARM_CPU_PRODUCT" = "$Neoverse_N1" ]; then + echo -n "-mcpu=neoverse-n1 " + elif [ "$ARM_CPU_PRODUCT" = "$Neoverse_N2" ]; then + echo -n "-mcpu=neoverse-n2 " + elif [ "$ARM_CPU_PRODUCT" = "$Neoverse_V1" ]; then + echo -n "-mcpu=neoverse-v1 " + elif [ "$ARM_CPU_PRODUCT" = "$Neoverse_V2" ]; then + # Read the JEDEC JEP-106 manufacturer ID to distinguish different Neoverse V2 cores + # https://developer.arm.com/documentation/ka001301/latest/ + SOC_ID_FILE="/sys/devices/soc0/soc_id" + GRACE_SOC_ID="jep106:036b:0241" + # Check for NVIDIA Grace which has various extensions + if [ -f "$SOC_ID_FILE" ] && [ "$(cat $SOC_ID_FILE)" = "$GRACE_SOC_ID" ]; then + echo -n "-mcpu=neoverse-v2+crypto+sha3+sm4+sve2-aes+sve2-sha3+sve2-sm4" else - echo -n "-march=armv8-a+crc+crypto " + echo -n "-mcpu=neoverse-v2 " fi else - echo -n "" + echo -n "-march=armv8-a+crc+crypto " fi + else + echo -n "-march=armv8-a+crc+crypto " + fi ;; *) echo -n "Architecture not supported!" + ;; esac } +detect_sve_flags() { + if grep -q "sve" /proc/cpuinfo; then + ARCH_FLAGS="-march=armv8-a+sve" + SVE_VECTOR_BITS=$( + gcc $ARCH_FLAGS -o detect_sve_vector -xc++ - -lstdc++ < + #include + int main() { + std::cout << svcntb() * 8 << std::endl; + return 0; + } +EOF + ./detect_sve_vector 2>/dev/null + ) + + if [ "$SVE_VECTOR_BITS" ]; then + echo "-msve-vector-bits=$SVE_VECTOR_BITS -DSVE_BITS=$SVE_VECTOR_BITS" + fi + + rm -f detect_sve_vector + fi +} + +if [[ ${BASH_SOURCE[0]} == "${0}" && $1 == "detect_sve_flags" ]]; then + detect_sve_flags +fi + function wget_and_untar { local URL=$1 local DIR=$2 mkdir -p "${DEPENDENCY_DIR}" - pushd "${DEPENDENCY_DIR}" + pushd "${DEPENDENCY_DIR}" || exit SUDO="${SUDO:-""}" if [ -d "${DIR}" ]; then if prompt "${DIR} already exists. Delete?"; then ${SUDO} rm -rf "${DIR}" else - popd + popd || exit return fi fi mkdir -p "${DIR}" - pushd "${DIR}" - curl ${CURL_OPTIONS} -L "${URL}" > $2.tar.gz - tar -xz --strip-components=1 -f $2.tar.gz - popd - popd + pushd "${DIR}" || exit + # Use ${VAR:+"$VAR"} pattern to only include CURL_OPTIONS if it's not empty + # as curl >=8.6.0 rejects empty arguments + curl ${CURL_OPTIONS:+${CURL_OPTIONS}} -L "${URL}" -o "$2".tar.gz + tar -xz --strip-components=1 --no-same-owner -f "$2".tar.gz + popd || exit + popd || exit } function cmake_install_dir { - pushd "${DEPENDENCY_DIR}/$1" + pushd "${DEPENDENCY_DIR}/$1" || exit # remove the directory argument shift - cmake_install $@ - popd + cmake_install "$@" + popd || exit } function cmake_install { - local NAME=$(basename "$(pwd)") + local NAME + NAME=$(basename "$(pwd)") local BINARY_DIR=_build SUDO="${SUDO:-""}" if [ -d "${BINARY_DIR}" ]; then @@ -222,8 +265,9 @@ function cmake_install { COMPILER_FLAGS+=${EXTRA_PKG_CXXFLAGS} # CMAKE_POSITION_INDEPENDENT_CODE is required so that Velox can be built into dynamic libraries \ - cmake -Wno-dev ${CMAKE_OPTIONS} -B"${BINARY_DIR}" \ + cmake -Wno-dev "${CMAKE_OPTIONS}" -B"${BINARY_DIR}" \ -GNinja \ + -DCMAKE_POLICY_VERSION_MINIMUM=3.5 \ -DCMAKE_POSITION_INDEPENDENT_CODE=ON \ "${INSTALL_PREFIX+-DCMAKE_PREFIX_PATH=}${INSTALL_PREFIX-}" \ "${INSTALL_PREFIX+-DCMAKE_INSTALL_PREFIX=}${INSTALL_PREFIX-}" \ @@ -231,6 +275,9 @@ function cmake_install { -DBUILD_TESTING=OFF \ "$@" # Exit if the build fails. - cmake --build "${BINARY_DIR}" "-j ${NPROC}" || { echo 'build failed' ; exit 1; } + cmake --build "${BINARY_DIR}" "-j ${NPROC}" || { + echo 'build failed' + exit 1 + } ${SUDO} cmake --install "${BINARY_DIR}" } diff --git a/scripts/setup-macos.sh b/scripts/setup-macos.sh index ceae1e5eed22..a10295a372c7 100755 --- a/scripts/setup-macos.sh +++ b/scripts/setup-macos.sh @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# shellcheck source-path=SCRIPT_DIR # This script documents setting up a macOS host for Velox # development. Running it should make you ready to compile. @@ -28,26 +29,28 @@ set -e # Exit on error. set -x # Print commands that are executed. -SCRIPTDIR=$(dirname "${BASH_SOURCE[0]}") +SCRIPT_DIR=$(dirname "${BASH_SOURCE[0]}") export INSTALL_PREFIX=${INSTALL_PREFIX:-"$(pwd)/deps-install"} -source $SCRIPTDIR/setup-common.sh -PYTHON_VENV=${PYTHON_VENV:-"${SCRIPTDIR}/../.venv"} +source "$SCRIPT_DIR"/setup-common.sh +PYTHON_VENV=${PYTHON_VENV:-"${SCRIPT_DIR}/../.venv"} # Allow installed package headers to be picked up before brew package headers # by tagging the brew packages to be system packages. # This is used during package builds. -export OS_CXXFLAGS=" -isystem $(brew --prefix)/include " +OS_CXXFLAGS=" -isystem $(brew --prefix)/include " +export OS_CXXFLAGS +export CMAKE_POLICY_VERSION_MINIMUM="3.5" DEPENDENCY_DIR=${DEPENDENCY_DIR:-$(pwd)} -MACOS_VELOX_DEPS="bison flex gflags glog googletest icu4c libevent libsodium lz4 lzo openssl protobuf@21 simdjson snappy xz zstd" -MACOS_BUILD_DEPS="ninja cmake ccache" +MACOS_VELOX_DEPS="bison flex gflags glog googletest icu4c libevent libsodium lz4 openssl protobuf@21 simdjson snappy xz xxhash zstd" + +MACOS_BUILD_DEPS="ninja cmake" SUDO="${SUDO:-""}" function update_brew { DEFAULT_BREW_PATH=/usr/local/bin/brew - if [ `arch` == "arm64" ] ; - then - DEFAULT_BREW_PATH=$(which brew) ; + if [ "$(arch)" == "arm64" ]; then + DEFAULT_BREW_PATH=$(which brew) fi BREW_PATH=${BREW_PATH:-$DEFAULT_BREW_PATH} $BREW_PATH update --auto-update --verbose @@ -56,42 +59,47 @@ function update_brew { function install_from_brew { pkg=$1 - if [[ "${pkg}" =~ ^([0-9a-z-]*):([0-9](\.[0-9\])*)$ ]]; - then + if [[ ${pkg} =~ ^([0-9a-z-]*):([0-9](\.[0-9\])*)$ ]]; then pkg=${BASH_REMATCH[1]} ver=${BASH_REMATCH[2]} echo "Installing '${pkg}' at '${ver}'" tap="velox/local-${pkg}" brew tap-new "${tap}" brew extract "--version=${ver}" "${pkg}" "${tap}" - brew install "${tap}/${pkg}@${ver}" || ( echo "Failed to install ${tap}/${pkg}@${ver}" ; exit 1 ) + brew install "${tap}/${pkg}@${ver}" || ( + echo "Failed to install ${tap}/${pkg}@${ver}" + exit 1 + ) else - ( brew install --formula "${pkg}" && echo "Installation of ${pkg} is successful" || brew upgrade --formula "$pkg" ) || ( echo "Failed to install ${pkg}" ; exit 1 ) + (brew install --formula "${pkg}" && echo "Installation of ${pkg} is successful" || brew upgrade --formula "$pkg") || ( + echo "Failed to install ${pkg}" + exit 1 + ) fi } function install_build_prerequisites { - for pkg in ${MACOS_BUILD_DEPS} - do - install_from_brew ${pkg} + for pkg in ${MACOS_BUILD_DEPS}; do + install_from_brew "${pkg}" done - if [ ! -f ${PYTHON_VENV}/pyvenv.cfg ]; then + if [ ! -f "${PYTHON_VENV}"/pyvenv.cfg ]; then echo "Creating Python Virtual Environment at ${PYTHON_VENV}" - python3 -m venv ${PYTHON_VENV} - fi - source ${PYTHON_VENV}/bin/activate; pip3 install cmake-format regex pyyaml - if [ ! -f /usr/local/bin/ccache ]; then - curl -L https://github.com/ccache/ccache/releases/download/v4.10.2/ccache-4.10.2-darwin.tar.gz > ccache.tar.gz - tar -xf ccache.tar.gz - mv ccache-4.10.2-darwin/ccache /usr/local/bin/ - rm -rf ccache-4.10.2-darwin ccache.tar.gz + python3 -m venv "${PYTHON_VENV}" fi + source "${PYTHON_VENV}"/bin/activate + pip3 install regex pyyaml + + # Install ccache + curl -L https://github.com/ccache/ccache/releases/download/v"${CCACHE_VERSION}"/ccache-"${CCACHE_VERSION}"-darwin.tar.gz -o ccache.tar.gz + tar -xf ccache.tar.gz + $SUDO mkdir -p "$INSTALL_PREFIX"/bin + $SUDO mv ccache-"${CCACHE_VERSION}"-darwin/ccache "$INSTALL_PREFIX"/bin + rm -rf ccache-"${CCACHE_VERSION}"-darwin ccache.tar.gz } function install_velox_deps_from_brew { - for pkg in ${MACOS_VELOX_DEPS} - do - install_from_brew ${pkg} + for pkg in ${MACOS_VELOX_DEPS}; do + install_from_brew "${pkg}" done } @@ -103,11 +111,11 @@ function install_s3 { } function install_gcs { - install_gcs-sdk-cpp + install_gcs_sdk_cpp } function install_abfs { - install_azure-storage-sdk-cpp + install_azure_storage_sdk_cpp } function install_hdfs { @@ -122,6 +130,48 @@ function install_adapters { run_and_time install_hdfs } +function install_duckdb_clang { + clang_major_version=$(echo | clang -dM -E - | grep __clang_major__ | awk '{print $3}') + # Clang17 requires this. See issue #13215. + if [ "${clang_major_version}" -ge 17 ]; then + EXTRA_PKG_CXXFLAGS=" -Wno-missing-template-arg-list-after-template-kw" install_duckdb + else + install_duckdb + fi +} + +function install_faiss_deps { + brew install openblas + brew install libomp +} + +function install_faiss { + if [[ $BUILD_FAISS == "true" ]]; then + # Install OpenBLAS and libomp if not already installed + install_faiss_deps + + wget_and_untar "https://github.com/facebookresearch/faiss/archive/refs/tags/v${FAISS_VERSION}.tar.gz" faiss + + local cmake_args + cmake_args=( + -DFAISS_ENABLE_GPU=OFF + -DFAISS_ENABLE_PYTHON=OFF + -DFAISS_ENABLE_REMOTE=OFF + -DFAISS_ENABLE_GPU_TESTS=OFF + -DFAISS_ENABLE_BENCHMARKS=OFF + -DFAISS_ENABLE_GPU=OFF + -DFAISS_ENABLE_MKL=OFF + ) + + local libomp_prefix + libomp_prefix=$(brew --prefix libomp) + cmake_args+=( + "-DCMAKE_PREFIX_PATH=${libomp_prefix}" + ) + cmake_install_dir faiss "${cmake_args[@]}" + fi +} + function install_velox_deps { run_and_time install_velox_deps_from_brew run_and_time install_ranges_v3 @@ -136,14 +186,17 @@ function install_velox_deps { run_and_time install_mvfst run_and_time install_fbthrift run_and_time install_xsimd - run_and_time install_duckdb run_and_time install_stemmer - run_and_time install_thrift + # We allow arrow to bundle thrift on MacOS due to issues with bison and flex. + # See https://github.com/facebook/fbthrift/pull/317 for an explanation. + # run_and_time install_thrift run_and_time install_arrow + run_and_time install_duckdb_clang run_and_time install_geos + run_and_time install_faiss } -(return 2> /dev/null) && return # If script was sourced, don't run commands. +(return 2>/dev/null) && return # If script was sourced, don't run commands. ( update_brew @@ -160,7 +213,7 @@ function install_velox_deps { echo "Skipping installation of build dependencies since INSTALL_PREREQUISITES is not set" fi install_velox_deps - echo "All deps for Velox installed! Now try \"make\"" + echo 'All deps for Velox installed! Now try "make"' fi ) diff --git a/scripts/setup-manylinux.sh b/scripts/setup-manylinux.sh index b82db9263fbb..424db964cce3 100755 --- a/scripts/setup-manylinux.sh +++ b/scripts/setup-manylinux.sh @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# shellcheck source-path=SCRIPT_DIR # This script documents setting up a Centos9 host for Velox # development. Running it should make you ready to compile. @@ -28,27 +29,16 @@ set -efx -o pipefail # Some of the packages must be build with the same compiler flags # so that some low level types are the same size. Also, disable warnings. -SCRIPTDIR=$(dirname "${BASH_SOURCE[0]}") -source $SCRIPTDIR/setup-helper-functions.sh -NPROC=${BUILD_THREADS:-$(getconf _NPROCESSORS_ONLN)} -export CXXFLAGS=$(get_cxx_flags) # Used by boost. +SCRIPT_DIR=$(dirname "${BASH_SOURCE[0]}") +source "$SCRIPT_DIR"/setup-common.sh +CXXFLAGS=$(get_cxx_flags) # Used by boost. +export CXXFLAGS export CFLAGS=${CXXFLAGS//"-std=c++17"/} # Used by LZO. -CMAKE_BUILD_TYPE="${BUILD_TYPE:-Release}" -VELOX_BUILD_SHARED=${VELOX_BUILD_SHARED:-"OFF"} #Build folly and gflags shared for use in libvelox.so. -BUILD_DUCKDB="${BUILD_DUCKDB:-true}" USE_CLANG="${USE_CLANG:-false}" export INSTALL_PREFIX=${INSTALL_PREFIX:-"/usr/local"} DEPENDENCY_DIR=${DEPENDENCY_DIR:-$(pwd)/deps-download} -FB_OS_VERSION="v2025.04.28.00" -FMT_VERSION="10.1.1" -BOOST_VERSION="boost-1.84.0" -THRIFT_VERSION="v0.21.0" -# Note: when updating arrow check if thrift needs an update as well. -ARROW_VERSION="15.0.0" -STEMMER_VERSION="2.2.0" -DUCKDB_VERSION="v0.8.1" -FAST_FLOAT_VERSION="v8.0.2" +export THRIFT_VERSION="v0.21.0" # CMake 4.0 removed support for cmake minimums of <=3.5 and will fail builds, this overrides it export CMAKE_POLICY_VERSION_MINIMUM="3.5" @@ -70,7 +60,6 @@ function install_build_prerequisites { dnf_install ninja-build cmake ccache gcc-toolset-12 git wget which dnf_install autoconf automake python3-devel pip libtool - if [[ ${USE_CLANG} != "false" ]]; then install_clang15 fi @@ -81,7 +70,7 @@ function install_velox_deps_from_dnf { dnf_install libevent-devel \ openssl-devel re2-devel libzstd-devel lz4-devel double-conversion-devel \ libdwarf-devel elfutils-libelf-devel curl-devel libicu-devel bison flex \ - libsodium-devel zlib-devel xxhash-devel + libsodium-devel zlib-devel gtest-devel gmock-devel xxhash-devel } function install_conda { @@ -91,187 +80,36 @@ function install_conda { function install_gflags { # Remove an older version if present. dnf remove -y gflags - wget_and_untar https://github.com/gflags/gflags/archive/v2.2.2.tar.gz gflags + wget_and_untar https://github.com/gflags/gflags/archive/"${GFLAGS_VERSION}".tar.gz gflags cmake_install_dir gflags -DBUILD_SHARED_LIBS=ON -DBUILD_STATIC_LIBS=ON -DBUILD_gflags_LIB=ON -DLIB_SUFFIX=64 } -function install_glog { - wget_and_untar https://github.com/google/glog/archive/v0.6.0.tar.gz glog - cmake_install_dir glog -DBUILD_SHARED_LIBS=ON -} - -function install_lzo { - wget_and_untar http://www.oberhumer.com/opensource/lzo/download/lzo-2.10.tar.gz lzo - ( - cd ${DEPENDENCY_DIR}/lzo - ./configure --prefix=${INSTALL_PREFIX} --enable-shared --disable-static --docdir=/usr/share/doc/lzo-2.10 - make "-j${NPROC}" - make install - ) -} - -function install_boost { - wget_and_untar https://github.com/boostorg/boost/releases/download/${BOOST_VERSION}/${BOOST_VERSION}.tar.gz boost - ( - cd ${DEPENDENCY_DIR}/boost - if [[ ${USE_CLANG} != "false" ]]; then - ./bootstrap.sh --prefix=${INSTALL_PREFIX} --with-toolset="clang-15" - # Switch the compiler from the clang-15 toolset which doesn't exist (clang-15.jam) to - # clang of version 15 when toolset clang-15 is used. - # This reconciles the project-config.jam generation with what the b2 build system allows for customization. - sed -i 's/using clang-15/using clang : 15/g' project-config.jam - ${SUDO} ./b2 "-j${NPROC}" -d0 install threading=multi toolset=clang-15 --without-python - else - ./bootstrap.sh --prefix=${INSTALL_PREFIX} - ${SUDO} ./b2 "-j${NPROC}" -d0 install threading=multi --without-python - fi - ) -} - -function install_snappy { - wget_and_untar https://github.com/google/snappy/archive/1.1.8.tar.gz snappy - cmake_install_dir snappy -DSNAPPY_BUILD_TESTS=OFF -} - -function install_fmt { - wget_and_untar https://github.com/fmtlib/fmt/archive/${FMT_VERSION}.tar.gz fmt - cmake_install_dir fmt -DFMT_TEST=OFF -} - -function install_protobuf { - wget_and_untar https://github.com/protocolbuffers/protobuf/releases/download/v21.8/protobuf-all-21.8.tar.gz protobuf - ( - cd ${DEPENDENCY_DIR}/protobuf - ./configure CXXFLAGS="-fPIC" --prefix=${INSTALL_PREFIX} - make "-j${NPROC}" - make install - ldconfig - ) -} - -function install_fizz { - # Folly Portability.h being used to decide whether or not support coroutines - # causes issues (build, lin) if the selection is not consistent across users of folly. - EXTRA_PKG_CXXFLAGS=" -DFOLLY_CFG_NO_COROUTINES" - wget_and_untar https://github.com/facebookincubator/fizz/archive/refs/tags/${FB_OS_VERSION}.tar.gz fizz - cmake_install_dir fizz/fizz -DBUILD_TESTS=OFF -} - -function install_fast_float { - wget_and_untar https://github.com/fastfloat/fast_float/archive/refs/tags/${FAST_FLOAT_VERSION}.tar.gz fast_float - cmake_install_dir fast_float -DBUILD_TESTS=OFF -} - -function install_folly { - # Folly Portability.h being used to decide whether or not support coroutines - # causes issues (build, lin) if the selection is not consistent across users of folly. - EXTRA_PKG_CXXFLAGS=" -DFOLLY_CFG_NO_COROUTINES" - wget_and_untar https://github.com/facebook/folly/archive/refs/tags/${FB_OS_VERSION}.tar.gz folly - cmake_install_dir folly -DBUILD_SHARED_LIBS="$VELOX_BUILD_SHARED" -DBUILD_TESTS=OFF -DFOLLY_HAVE_INT128_T=ON -} - -function install_wangle { - # Folly Portability.h being used to decide whether or not support coroutines - # causes issues (build, lin) if the selection is not consistent across users of folly. - EXTRA_PKG_CXXFLAGS=" -DFOLLY_CFG_NO_COROUTINES" - wget_and_untar https://github.com/facebook/wangle/archive/refs/tags/${FB_OS_VERSION}.tar.gz wangle - cmake_install_dir wangle/wangle -DBUILD_TESTS=OFF -} - -function install_fbthrift { - # Folly Portability.h being used to decide whether or not support coroutines - # causes issues (build, lin) if the selection is not consistent across users of folly. - EXTRA_PKG_CXXFLAGS=" -DFOLLY_CFG_NO_COROUTINES" - wget_and_untar https://github.com/facebook/fbthrift/archive/refs/tags/${FB_OS_VERSION}.tar.gz fbthrift - cmake_install_dir fbthrift -Denable_tests=OFF -DBUILD_TESTS=OFF -DBUILD_SHARED_LIBS=OFF -} - -function install_mvfst { - # Folly Portability.h being used to decide whether or not support coroutines - # causes issues (build, lin) if the selection is not consistent across users of folly. - EXTRA_PKG_CXXFLAGS=" -DFOLLY_CFG_NO_COROUTINES" - wget_and_untar https://github.com/facebook/mvfst/archive/refs/tags/${FB_OS_VERSION}.tar.gz mvfst - cmake_install_dir mvfst -DBUILD_TESTS=OFF -} - -function install_duckdb { - if $BUILD_DUCKDB ; then - echo 'Building DuckDB' - wget_and_untar https://github.com/duckdb/duckdb/archive/refs/tags/${DUCKDB_VERSION}.tar.gz duckdb - cmake_install_dir duckdb -DBUILD_UNITTESTS=OFF -DENABLE_SANITIZER=OFF -DENABLE_UBSAN=OFF -DBUILD_SHELL=OFF -DEXPORT_DLL_SYMBOLS=OFF -DCMAKE_BUILD_TYPE=Release - fi -} - -function install_stemmer { - wget_and_untar https://snowballstem.org/dist/libstemmer_c-${STEMMER_VERSION}.tar.gz stemmer - ( - cd ${DEPENDENCY_DIR}/stemmer - sed -i '/CPPFLAGS=-Iinclude/ s/$/ -fPIC/' Makefile - make clean && make "-j${NPROC}" - ${SUDO} cp libstemmer.a ${INSTALL_PREFIX}/lib/ - ${SUDO} cp include/libstemmer.h ${INSTALL_PREFIX}/include/ - ) -} - -function install_thrift { - wget_and_untar https://github.com/apache/thrift/archive/${THRIFT_VERSION}.tar.gz thrift - - EXTRA_CXXFLAGS="-O3 -fPIC" - # Clang will generate warnings and they need to be suppressed, otherwise the build will fail. - if [[ ${USE_CLANG} != "false" ]]; then - EXTRA_CXXFLAGS="-O3 -fPIC -Wno-inconsistent-missing-override -Wno-unused-but-set-variable" - fi - - CXX_FLAGS="$EXTRA_CXXFLAGS" cmake_install_dir thrift \ - -DBUILD_SHARED_LIBS=OFF \ - -DBUILD_COMPILER=OFF \ - -DBUILD_EXAMPLES=OFF \ - -DBUILD_TUTORIALS=OFF \ - -DCMAKE_DEBUG_POSTFIX= \ - -DWITH_AS3=OFF \ - -DWITH_CPP=ON \ - -DWITH_C_GLIB=OFF \ - -DWITH_JAVA=OFF \ - -DWITH_JAVASCRIPT=OFF \ - -DWITH_LIBEVENT=OFF \ - -DWITH_NODEJS=OFF \ - -DWITH_PYTHON=OFF \ - -DWITH_QT5=OFF \ - -DWITH_ZLIB=OFF -} - -function install_arrow { - wget_and_untar https://github.com/apache/arrow/archive/apache-arrow-${ARROW_VERSION}.tar.gz arrow - cmake_install_dir arrow/cpp \ - -DARROW_PARQUET=OFF \ - -DARROW_WITH_THRIFT=ON \ - -DARROW_WITH_LZ4=ON \ - -DARROW_WITH_SNAPPY=ON \ - -DARROW_WITH_ZLIB=ON \ - -DARROW_WITH_ZSTD=ON \ - -DARROW_JEMALLOC=OFF \ - -DARROW_SIMD_LEVEL=NONE \ - -DARROW_RUNTIME_SIMD_LEVEL=NONE \ - -DARROW_WITH_UTF8PROC=OFF \ - -DARROW_TESTING=ON \ - -DCMAKE_INSTALL_PREFIX=${INSTALL_PREFIX} \ - -DCMAKE_BUILD_TYPE=Release \ - -DARROW_BUILD_STATIC=ON \ - -DBOOST_ROOT=${INSTALL_PREFIX} -} - function install_cuda { # See https://developer.nvidia.com/cuda-downloads - dnf config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel9/x86_64/cuda-rhel9.repo - local dashed="$(echo $1 | tr '.' '-')" - dnf install -y cuda-nvcc-$dashed cuda-cudart-devel-$dashed cuda-nvrtc-devel-$dashed cuda-driver-devel-$dashed + local arch + arch=$(uname -m) + local repo_url + + if [[ $arch == "x86_64" ]]; then + repo_url="https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo" + elif [[ $arch == "aarch64" ]]; then + # Using SBSA (Server Base System Architecture) repository for ARM64 servers + repo_url="https://developer.download.nvidia.com/compute/cuda/repos/rhel8/sbsa/cuda-rhel8.repo" + else + echo "Unsupported architecture: $arch" >&2 + return 1 + fi + + dnf config-manager --add-repo "$repo_url" + local dashed + dashed="$(echo "$1" | tr '.' '-')" + dnf install -y cuda-nvcc-"$dashed" cuda-cudart-devel-"$dashed" cuda-nvrtc-devel-"$dashed" cuda-driver-devel-"$dashed" libnvjitlink-devel-"$dashed" } function install_velox_deps { run_and_time install_velox_deps_from_dnf run_and_time install_gflags run_and_time install_glog - run_and_time install_lzo run_and_time install_snappy run_and_time install_boost run_and_time install_protobuf @@ -288,7 +126,7 @@ function install_velox_deps { run_and_time install_arrow } -(return 2> /dev/null) && return # If script was sourced, don't run commands. +(return 2>/dev/null) && return # If script was sourced, don't run commands. ( if [[ $# -ne 0 ]]; then diff --git a/scripts/setup-ubuntu.sh b/scripts/setup-ubuntu.sh index 33711884bc8e..221fe26853d1 100755 --- a/scripts/setup-ubuntu.sh +++ b/scripts/setup-ubuntu.sh @@ -12,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# shellcheck source-path=SCRIPT_DIR +# shellcheck disable=SC2076 # This script documents setting up a Ubuntu host for Velox # development. Running it should make you ready to compile. @@ -27,15 +29,15 @@ # Minimal setup for Ubuntu 22.04. set -eufx -o pipefail -SCRIPTDIR=$(dirname "${BASH_SOURCE[0]}") -source $SCRIPTDIR/setup-common.sh +SCRIPT_DIR=$(dirname "${BASH_SOURCE[0]}") +source "$SCRIPT_DIR"/setup-common.sh SUDO="${SUDO:-"sudo --preserve-env"}" USE_CLANG="${USE_CLANG:-false}" export INSTALL_PREFIX=${INSTALL_PREFIX:-"/usr/local"} DEPENDENCY_DIR=${DEPENDENCY_DIR:-$(pwd)/deps-download} VERSION=$(cat /etc/os-release | grep VERSION_ID) -PYTHON_VENV=${PYTHON_VENV:-"${SCRIPTDIR}/../.venv"} +PYTHON_VENV=${PYTHON_VENV:-"${SCRIPT_DIR}/../.venv"} # On Ubuntu 20.04 dependencies need to be built using gcc11. # On Ubuntu 22.04 gcc11 is already the system gcc installed. @@ -44,6 +46,10 @@ if [[ ${VERSION} =~ "20.04" ]]; then export CXX=/usr/bin/g++-11 fi +if lscpu | grep -q "sve"; then + $SUDO apt install -y gcc-12 g++-12 +fi + function install_clang15 { if [[ ! ${VERSION} =~ "22.04" && ! ${VERSION} =~ "24.04" ]]; then echo "Warning: using the Clang configuration is for Ubuntu 22.04 and 24.04. Errors might occur." @@ -82,13 +88,8 @@ function install_build_prerequisites { libtool \ wget - if [ ! -f ${PYTHON_VENV}/pyvenv.cfg ]; then - echo "Creating Python Virtual Environment at ${PYTHON_VENV}" - python3 -m venv ${PYTHON_VENV} - fi - source ${PYTHON_VENV}/bin/activate; - # Install to /usr/local to make it available to all users. - ${SUDO} pip3 install cmake==3.28.3 + install_uv + uv_install cmake==3.30.4 install_gcc11_if_needed @@ -98,14 +99,6 @@ function install_build_prerequisites { } -# Install packages required to fix format -function install_format_prerequisites { - pip3 install regex - ${SUDO} apt install -y \ - clang-format \ - cmake-format -} - # Install packages required for build. function install_velox_deps_from_apt { ${SUDO} apt update @@ -126,7 +119,6 @@ function install_velox_deps_from_apt { libre2-dev \ libsnappy-dev \ libsodium-dev \ - liblzo2-dev \ libelf-dev \ libdwarf-dev \ bison \ @@ -138,7 +130,7 @@ function install_velox_deps_from_apt { function install_conda { MINICONDA_PATH="${HOME:-/opt}/miniconda-for-velox" - if [ -e ${MINICONDA_PATH} ]; then + if [ -e "${MINICONDA_PATH}" ]; then echo "File or directory already exists: ${MINICONDA_PATH}" return fi @@ -149,25 +141,56 @@ function install_conda { fi ( mkdir -p conda && cd conda - wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-$ARCH.sh -O Miniconda3-latest-Linux-$ARCH.sh - bash Miniconda3-latest-Linux-$ARCH.sh -b -p $MINICONDA_PATH + wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-"$ARCH".sh -O Miniconda3-latest-Linux-"$ARCH".sh + bash Miniconda3-latest-Linux-"$ARCH".sh -b -p "$MINICONDA_PATH" ) } function install_cuda { # See https://developer.nvidia.com/cuda-downloads + local arch + arch=$(uname -m) + local os_ver + + if [[ ${VERSION} =~ "24.04" ]]; then + os_ver="ubuntu2404" + elif [[ ${VERSION} =~ "22.04" ]]; then + os_ver="ubuntu2204" + elif [[ ${VERSION} =~ "20.04" ]]; then + os_ver="ubuntu2004" + else + echo "Unsupported Ubuntu version: ${VERSION}" >&2 + return 1 + fi + + local cuda_repo + if [[ $arch == "x86_64" ]]; then + cuda_repo="${os_ver}/x86_64" + elif [[ $arch == "aarch64" ]]; then + cuda_repo="${os_ver}/sbsa" + else + echo "Unsupported architecture: $arch" >&2 + return 1 + fi + if ! dpkg -l cuda-keyring 1>/dev/null; then - wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb + wget https://developer.download.nvidia.com/compute/cuda/repos/${cuda_repo}/cuda-keyring_1.1-1_all.deb $SUDO dpkg -i cuda-keyring_1.1-1_all.deb rm cuda-keyring_1.1-1_all.deb $SUDO apt update fi - local dashed="$(echo $1 | tr '.' '-')" + + local dashed + dashed="$(echo "$1" | tr '.' '-')" + $SUDO apt install -y \ - cuda-compat-$dashed \ - cuda-driver-dev-$dashed \ - cuda-minimal-build-$dashed \ - cuda-nvrtc-dev-$dashed + cuda-compat-"$dashed" \ + cuda-driver-dev-"$dashed" \ + cuda-minimal-build-"$dashed" \ + cuda-nvrtc-dev-"$dashed" \ + libcufile-dev-"$dashed" \ + libnvjitlink-dev-"$dashed" \ + libnuma1 } function install_s3 { @@ -179,18 +202,18 @@ function install_s3 { function install_gcs { # Dependencies of GCS, probably a workaround until the docker image is rebuilt - apt install -y --no-install-recommends libc-ares-dev libcurl4-openssl-dev - install_gcs-sdk-cpp + ${SUDO} apt install -y --no-install-recommends libc-ares-dev libcurl4-openssl-dev + install_gcs_sdk_cpp } function install_abfs { # Dependencies of Azure Storage Blob cpp - apt install -y openssl libxml2-dev - install_azure-storage-sdk-cpp + ${SUDO} apt install -y openssl libxml2-dev + install_azure_storage_sdk_cpp } function install_hdfs { - apt install -y --no-install-recommends libxml2-dev libgsasl7-dev uuid-dev openjdk-8-jdk + ${SUDO} apt install -y --no-install-recommends libxml2-dev libgsasl7-dev uuid-dev openjdk-8-jdk install_hdfs_deps } @@ -201,6 +224,10 @@ function install_adapters { run_and_time install_hdfs } +function install_faiss_deps { + ${SUDO} apt-get install -y libopenblas-dev libomp-dev +} + function install_velox_deps { run_and_time install_velox_deps_from_apt run_and_time install_fmt @@ -220,15 +247,15 @@ function install_velox_deps { run_and_time install_xsimd run_and_time install_simdjson run_and_time install_geos + run_and_time install_faiss } function install_apt_deps { install_build_prerequisites - install_format_prerequisites install_velox_deps_from_apt } -(return 2> /dev/null) && return # If script was sourced, don't run commands. +(return 2>/dev/null) && return # If script was sourced, don't run commands. ( if [[ ${USE_CLANG} != "false" ]]; then diff --git a/scripts/setup-versions.sh b/scripts/setup-versions.sh index 836a97c3c7c7..685794a88160 100755 --- a/scripts/setup-versions.sh +++ b/scripts/setup-versions.sh @@ -1,4 +1,5 @@ #!/bin/bash +# shellcheck disable=SC2034 # Copyright (c) Facebook, Inc. and its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -19,28 +20,28 @@ # Build dependencies versions. FB_OS_VERSION="v2025.04.28.00" -FMT_VERSION="10.1.1" +FMT_VERSION="11.2.0" BOOST_VERSION="boost-1.84.0" ARROW_VERSION="15.0.0" DUCKDB_VERSION="v0.8.1" PROTOBUF_VERSION="21.8" -MINIO_VERSION="2022-05-26T05-48-41Z" -MINIO_BINARY_NAME="minio-2022-05-26" -AWS_SDK_VERSION="1.11.321" XSIMD_VERSION="10.0.0" -SIMDJSON_VERSION="3.9.3" +SIMDJSON_VERSION="4.1.0" CPR_VERSION="1.10.5" DOUBLE_CONVERSION_VERSION="v3.1.5" RANGE_V3_VERSION="0.12.0" -RE2_VERSION="2022-02-01" +RE2_VERSION="2024-07-02" GFLAGS_VERSION="v2.2.2" GLOG_VERSION="v0.6.0" LZO_VERSION="2.10" SNAPPY_VERSION="1.1.8" -THRIFT_VERSION="v0.16.0" +THRIFT_VERSION="${THRIFT_VERSION:-v0.16.0}" STEMMER_VERSION="2.2.0" -GEOS_VERSION="3.10.2" +GEOS_VERSION="3.10.7" +# shellcheck disable=SC2034 +FAISS_VERSION="1.11.0" FAST_FLOAT_VERSION="v8.0.2" +CCACHE_VERSION="4.11.3" # Adapter related versions. ABSEIL_VERSION="20240116.2" @@ -48,5 +49,8 @@ GRPC_VERSION="v1.48.1" CRC32_VERSION="1.1.2" NLOHMAN_JSON_VERSION="v3.11.3" GOOGLE_CLOUD_CPP_VERSION="v2.22.0" -HADOOP_VERSION="3.3.0" +HADOOP_VERSION="3.3.6" AZURE_SDK_VERSION="12.8.0" +MINIO_VERSION="2022-05-26T05-48-41Z" +MINIO_BINARY_NAME="minio-2022-05-26" +AWS_SDK_VERSION="1.11.654" diff --git a/scripts/tests/TestFramework.sh b/scripts/tests/TestFramework.sh index 03b73d3c4e71..8874bad39607 100644 --- a/scripts/tests/TestFramework.sh +++ b/scripts/tests/TestFramework.sh @@ -46,15 +46,15 @@ printf -- "------ %-52s ------ ------\n" ---- Test() { Expect=Pass - if [ x"$1" = "x!" ] ; then Expect=Fail; shift + if [ "$1" = "!" ] ; then Expect=Fail; shift fi Title=$1; shift - TestsRun=`expr $TestsRun + 1` + TestsRun=$(expr $TestsRun + 1) - printf "%6d %-52s" $TestsRun "$Title" 1>&2 + printf "%6d %-52s" "$TestsRun" "$Title" 1>&2 - Start=`Clock` + Start=$(Clock) if [ "$SetMinusX" = 1 ] ; then set -x fi @@ -70,12 +70,12 @@ Calc() { TestCleanUp() { local result="$1" - Now=`Clock`; Elapse=$(Calc $Now - $Start) + Now=$(Clock); Elapse=$(Calc "$Now" - "$Start") - if [ "$result" = Pass ] ; then TestsPassed=`expr $TestsPassed + 1` - else TestsFailed=`expr $TestsFailed + 1` + if [ "$result" = Pass ] ; then TestsPassed=$(expr $TestsPassed + 1) + else TestsFailed=$(expr $TestsFailed + 1) fi - printf " $result %7.3f\n" $Elapse 1>&2 + printf " $result %7.3f\n" "$Elapse" 1>&2 if [ "$result" = Fail -a $ExitOnFail = 1 ] ; then exit 1 @@ -132,7 +132,7 @@ CompareArgs() { CompareEval() { while [ $# -ge 2 ] ; do - x=`$1` + x=$($1) y=$2 shift; shift diff --git a/scripts/tests/test_LicenseHeader.sh b/scripts/tests/test_LicenseHeader.sh index fe6fc5e98169..6cadc727b9a0 100755 --- a/scripts/tests/test_LicenseHeader.sh +++ b/scripts/tests/test_LicenseHeader.sh @@ -12,16 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -SRC=$(dirname $0) +SRC=$(dirname "$0") DATA=$SRC/data -source $SRC/TestFramework.sh +source "$SRC"/TestFramework.sh LICENSE_HEADER=$SRC/../license-header.py LICENSE_HEADER_FILE=$SRC/../../license.header license_header() { - $SRC/../license-header.py --header $LICENSE_HEADER_FILE "$@" + "$SRC"/../license-header.py --header "$LICENSE_HEADER_FILE" "$@" } TestFile() { @@ -30,10 +30,10 @@ TestFile() { local expected="$1" ; shift Test "license-header.py $title file $file" - cp $DATA/$file . - license_header "$@" -i $file + cp "$DATA"/"$file" . + license_header "$@" -i "$file" - DiffFiles $file $DATA/$expected + DiffFiles "$file" "$DATA"/"$expected" } # Test header insertion @@ -69,15 +69,15 @@ TestFile Almost foo.almost.sh foo.expected.sh TestFile Almost hashbang.almost.sh hashbang.expected.sh Test "List of Files in stdin" - cp $DATA/foo.sh . + cp "$DATA"/foo.sh . echo foo.sh | license_header -i - - DiffFiles foo.sh $DATA/foo.expected.sh + DiffFiles foo.sh "$DATA"/foo.expected.sh Test "Header Check Only - OK" - cp $DATA/foo.expected.sh . - cp $DATA/foo.expected.cpp . + cp "$DATA"/foo.expected.sh . + cp "$DATA"/foo.expected.cpp . if license_header "$@" -k foo.expected.sh foo.expected.cpp ; then Pass @@ -86,8 +86,8 @@ Test "Header Check Only - OK" fi Test "Header Check Only - Fix" - cp $DATA/foo.sh . - cp $DATA/foo.cpp . + cp "$DATA"/foo.sh . + cp "$DATA"/foo.cpp . if license_header "$@" -k foo.sh foo.cpp; then Fail @@ -96,10 +96,10 @@ Test "Header Check Only - Fix" fi Test "Header Check Verbose" - cp $DATA/foo.sh . - cp $DATA/foo.expected.sh . - cp $DATA/foo.cpp . - cp $DATA/foo.expected.cpp . + cp "$DATA"/foo.sh . + cp "$DATA"/foo.expected.sh . + cp "$DATA"/foo.cpp . + cp "$DATA"/foo.expected.cpp . result=$(license_header "$@" -vk foo.sh foo.expected.sh foo.cpp foo.expected.cpp) expected="\ diff --git a/scripts/velox_env_linux.yml b/scripts/velox_env_linux.yml index 58e722cee5ac..2560bb3b29af 100644 --- a/scripts/velox_env_linux.yml +++ b/scripts/velox_env_linux.yml @@ -26,7 +26,7 @@ dependencies: - binutils - bison - clangxx=14 - - cmake=3.28.3 + - cmake=3.30.4 - ccache - flex - gxx=12 # has to be installed to get clang to work... @@ -58,7 +58,6 @@ dependencies: - libtool - libunwind - lz4-c - - lzo - openssl=1.1 - re2 - simdjson diff --git a/scripts/velox_env_mac.yml b/scripts/velox_env_mac.yml index 8c24af8e31a5..34c58679d734 100644 --- a/scripts/velox_env_mac.yml +++ b/scripts/velox_env_mac.yml @@ -56,7 +56,6 @@ dependencies: - libsodium - libtool - lz4-c - - lzo - openssl=1.1.* - re2 - snappy diff --git a/velox/CMakeLists.txt b/velox/CMakeLists.txt index 45d6a4540545..1bf6976e8e50 100644 --- a/velox/CMakeLists.txt +++ b/velox/CMakeLists.txt @@ -33,9 +33,7 @@ if(${VELOX_ENABLE_EXAMPLES} AND ${VELOX_ENABLE_EXPRESSION}) add_subdirectory(examples) endif() -if(${VELOX_ENABLE_BENCHMARKS} OR ${VELOX_ENABLE_BENCHMARKS_BASIC}) - add_subdirectory(benchmarks) -endif() +add_subdirectory(benchmarks) if(${VELOX_ENABLE_EXPRESSION}) add_subdirectory(expression) @@ -54,15 +52,15 @@ if(${VELOX_ENABLE_TPCH_CONNECTOR}) add_subdirectory(tpch/gen) endif() +if(${VELOX_ENABLE_TPCDS_CONNECTOR}) + add_subdirectory(tpcds/gen) +endif() + add_subdirectory(functions) # depends on md5 (postgresql) add_subdirectory(connectors) if(${VELOX_ENABLE_EXEC}) add_subdirectory(exec) - # Disable runner from pyvelox builds - if(${VELOX_BUILD_RUNNER}) - add_subdirectory(runner) - endif() endif() if(${VELOX_ENABLE_DUCKDB}) @@ -80,11 +78,6 @@ if(VELOX_ENABLE_WAVE OR VELOX_ENABLE_CUDF) endif() endif() -# substrait converter -if(${VELOX_ENABLE_SUBSTRAIT}) - add_subdirectory(substrait) -endif() - if(${VELOX_BUILD_TESTING}) add_subdirectory(tool) endif() diff --git a/velox/benchmarks/CMakeLists.txt b/velox/benchmarks/CMakeLists.txt index 45466b1ed3fd..48dc1fde93fc 100644 --- a/velox/benchmarks/CMakeLists.txt +++ b/velox/benchmarks/CMakeLists.txt @@ -11,9 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -add_subdirectory(basic) -set(velox_benchmark_deps +if(${VELOX_BUILD_TEST_UTILS}) + set( + velox_benchmark_deps velox_type velox_vector velox_vector_fuzzer @@ -26,40 +27,47 @@ set(velox_benchmark_deps Folly::follybenchmark ${DOUBLE_CONVERSION} gflags::gflags - glog::glog) + glog::glog + ) -add_library(velox_benchmark_builder ExpressionBenchmarkBuilder.cpp) -target_link_libraries( - velox_benchmark_builder ${velox_benchmark_deps}) -# This is a workaround for the use of VectorTestBase.h which includes gtest.h -target_link_libraries( - velox_benchmark_builder GTest::gtest) + add_library(velox_benchmark_builder ExpressionBenchmarkBuilder.cpp) + target_link_libraries(velox_benchmark_builder ${velox_benchmark_deps}) + + # This is a workaround for the use of VectorTestBase.h which includes gtest.h + target_link_libraries(velox_benchmark_builder GTest::gtest) + + add_library(velox_query_benchmark QueryBenchmarkBase.cpp) + + target_link_libraries( + velox_query_benchmark + velox_aggregates + velox_connector + velox_exec + velox_exec_test_lib + velox_dwio_common + velox_dwio_common_exception + velox_dwio_parquet_reader + velox_dwio_common_test_utils + velox_exception + velox_memory + velox_process + velox_serialization + velox_encode + velox_type + velox_type_fbhive + velox_caching + velox_vector_test_lib + Folly::folly + Folly::follybenchmark + fmt::fmt + ) +endif() + +if(${VELOX_ENABLE_BENCHMARKS_BASIC}) + add_subdirectory(basic) +endif() if(${VELOX_ENABLE_BENCHMARKS}) add_subdirectory(tpch) add_subdirectory(filesystem) endif() - -add_library(velox_query_benchmark QueryBenchmarkBase.cpp) -target_link_libraries( - velox_query_benchmark - velox_aggregates - velox_exec - velox_exec_test_lib - velox_dwio_common - velox_dwio_common_exception - velox_dwio_parquet_reader - velox_dwio_common_test_utils - velox_hive_connector - velox_exception - velox_memory - velox_process - velox_serialization - velox_encode - velox_type - velox_type_fbhive - velox_caching - velox_vector_test_lib - Folly::folly - Folly::follybenchmark - fmt::fmt) diff --git a/velox/benchmarks/QueryBenchmarkBase.cpp b/velox/benchmarks/QueryBenchmarkBase.cpp index 469b6c00542d..852e55d34173 100644 --- a/velox/benchmarks/QueryBenchmarkBase.cpp +++ b/velox/benchmarks/QueryBenchmarkBase.cpp @@ -15,12 +15,37 @@ */ #include "velox/benchmarks/QueryBenchmarkBase.h" +#include +#include "velox/common/base/SuccinctPrinter.h" +#include "velox/common/file/FileSystems.h" +#include "velox/connectors/hive/HiveConnector.h" +#include "velox/dwio/dwrf/RegisterDwrfReader.h" +#include "velox/dwio/parquet/RegisterParquetReader.h" +#include "velox/exec/Split.h" +#include "velox/exec/tests/utils/HiveConnectorTestBase.h" +#include "velox/functions/prestosql/aggregates/RegisterAggregateFunctions.h" +#include "velox/functions/prestosql/registration/RegistrationFunctions.h" +#include "velox/parse/TypeResolver.h" -DEFINE_string(data_format, "parquet", "Data format"); +namespace { -DEFINE_validator( - data_format, - &facebook::velox::QueryBenchmarkBase::validateDataFormat); +bool validateDataFormat(const char* flagname, const std::string& value) { + if ((value.compare("parquet") == 0) || (value.compare("dwrf") == 0)) { + return true; + } + std::cout + << fmt::format( + "Invalid value for --{}: {}. Allowed values are [\"parquet\", \"dwrf\"]", + flagname, + value) + << std::endl; + return false; +} +} // namespace + +DEFINE_string(data_format, "parquet", "Data format: parquet or dwrf."); + +DEFINE_validator(data_format, &validateDataFormat); DEFINE_bool( include_custom_stats, @@ -97,20 +122,24 @@ using namespace facebook::velox::dwio::common; namespace facebook::velox { -// static -bool QueryBenchmarkBase::validateDataFormat( - const char* flagname, - const std::string& value) { - if ((value.compare("parquet") == 0) || (value.compare("dwrf") == 0)) { - return true; +std::string RunStats::toString(bool detail) const { + std::stringstream out; + out << succinctNanos(micros * 1000) << " " + << succinctBytes(rawInputBytes / (micros / 1000000.0)) << "/s raw, " + << succinctNanos(userNanos) << " user " << succinctNanos(systemNanos) + << " system (" << (100 * (userNanos + systemNanos) / (micros * 1000)) + << "%)"; + if (!flags.empty()) { + out << " flags: "; + for (auto& pair : flags) { + out << pair.first << "=" << pair.second << " "; + } } - std::cout - << fmt::format( - "Invalid value for --{}: {}. Allowed values are [\"parquet\", \"dwrf\"]", - flagname, - value) - << std::endl; - return false; + out << std::endl << "======" << std::endl; + if (detail) { + out << std::endl << output << std::endl; + } + return out.str(); } // static @@ -175,6 +204,20 @@ void QueryBenchmarkBase::initialize() { std::make_unique(FLAGS_num_io_threads); // Add new values into the hive configuration... + auto properties = makeConnectorProperties(); + + // Create hive connector with config... + connector::hive::HiveConnectorFactory factory; + auto hiveConnector = + factory.newConnector(kHiveConnectorId, properties, ioExecutor_.get()); + connector::registerConnector(hiveConnector); + parquet::registerParquetReaderFactory(); + dwrf::registerDwrfReaderFactory(); +} + +std::shared_ptr +QueryBenchmarkBase::makeConnectorProperties() { + // Default behaviour identical to the original hard-coded version. auto configurationValues = std::unordered_map(); configurationValues[connector::hive::HiveConfig::kMaxCoalescedBytes] = std::to_string(FLAGS_max_coalesced_bytes); @@ -182,19 +225,8 @@ void QueryBenchmarkBase::initialize() { FLAGS_max_coalesced_distance_bytes; configurationValues[connector::hive::HiveConfig::kPrefetchRowGroups] = std::to_string(FLAGS_parquet_prefetch_rowgroups); - auto properties = std::make_shared( - std::move(configurationValues)); - - // Create hive connector with config... - connector::registerConnectorFactory( - std::make_shared()); - auto hiveConnector = - connector::getConnectorFactory( - connector::hive::HiveConnectorFactory::kHiveConnectorName) - ->newConnector(kHiveConnectorId, properties, ioExecutor_.get()); - connector::registerConnector(hiveConnector); - parquet::registerParquetReaderFactory(); - dwrf::registerDwrfReaderFactory(); + return std::make_shared( + std::move(configurationValues), true); } std::vector> @@ -218,13 +250,16 @@ void QueryBenchmarkBase::shutdown() { } std::pair, std::vector> -QueryBenchmarkBase::run(const TpchPlan& tpchPlan) { +QueryBenchmarkBase::run( + const TpchPlan& tpchPlan, + const std::unordered_map& queryConfigs) { int32_t repeat = 0; try { for (;;) { CursorParameters params; params.maxDrivers = FLAGS_num_drivers; params.planNode = tpchPlan.plan; + params.queryConfigs = queryConfigs; params.queryConfigs[core::QueryConfig::kMaxSplitPreloadPerDriver] = std::to_string(FLAGS_split_preload_per_driver); const int numSplitsPerFile = FLAGS_num_splits_per_file; @@ -326,10 +361,14 @@ void QueryBenchmarkBase::runCombinations(int32_t level) { auto tvNanos = [](struct timeval tv) { return tv.tv_sec * 1000000000 + tv.tv_usec * 1000; }; - stats.userNanos = tvNanos(final.ru_utime) - tvNanos(start.ru_utime); - stats.systemNanos = tvNanos(final.ru_stime) - tvNanos(start.ru_stime); + if (!stats.userNanos) { + stats.userNanos = tvNanos(final.ru_utime) - tvNanos(start.ru_utime); + stats.systemNanos = tvNanos(final.ru_stime) - tvNanos(start.ru_stime); + } + } + if (!stats.micros) { + stats.micros = micros; } - stats.micros = micros; stats.output = result.str(); for (auto i = 0; i < parameters_.size(); ++i) { std::string name; @@ -340,12 +379,19 @@ void QueryBenchmarkBase::runCombinations(int32_t level) { } else { auto& flag = parameters_[level].flag; for (auto& value : parameters_[level].values) { - std::string result = - gflags::SetCommandLineOption(flag.c_str(), value.c_str()); - if (result.empty()) { - LOG(ERROR) << "Failed to set " << flag << "=" << value; + if (flag.substr(0, 2) == "s-") { + auto config = flag.substr(2, flag.size() - 2); + config_[config] = value; + std::cout << "Set session config " << config << " = " << value + << std::endl; + } else { + std::string result = + gflags::SetCommandLineOption(flag.c_str(), value.c_str()); + if (result.empty()) { + LOG(ERROR) << "Failed to set " << flag << "=" << value; + } + std::cout << result << std::endl; } - std::cout << result << std::endl; runCombinations(level + 1); } } diff --git a/velox/benchmarks/QueryBenchmarkBase.h b/velox/benchmarks/QueryBenchmarkBase.h index a5a172a8216c..755b67d36a9c 100644 --- a/velox/benchmarks/QueryBenchmarkBase.h +++ b/velox/benchmarks/QueryBenchmarkBase.h @@ -17,31 +17,16 @@ #pragma once #include -#include -#include - #include +#include #include #include +#include #include +#include #include -#include - -#include "velox/common/base/SuccinctPrinter.h" -#include "velox/common/file/FileSystems.h" -#include "velox/common/memory/MmapAllocator.h" -#include "velox/connectors/hive/HiveConfig.h" -#include "velox/connectors/hive/HiveConnector.h" -#include "velox/dwio/common/Options.h" -#include "velox/dwio/dwrf/RegisterDwrfReader.h" -#include "velox/dwio/parquet/RegisterParquetReader.h" -#include "velox/exec/PlanNodeStats.h" -#include "velox/exec/Split.h" -#include "velox/exec/tests/utils/HiveConnectorTestBase.h" +#include "velox/exec/Cursor.h" #include "velox/exec/tests/utils/TpchQueryBuilder.h" -#include "velox/functions/prestosql/aggregates/RegisterAggregateFunctions.h" -#include "velox/functions/prestosql/registration/RegistrationFunctions.h" -#include "velox/parse/TypeResolver.h" DECLARE_string(test_flags_file); DECLARE_bool(include_results); @@ -58,22 +43,7 @@ struct RunStats { int64_t systemNanos{0}; std::string output; - std::string toString(bool detail) { - std::stringstream out; - out << succinctNanos(micros * 1000) << " " - << succinctBytes(rawInputBytes / (micros / 1000000.0)) << "/s raw, " - << succinctNanos(userNanos) << " user " << succinctNanos(systemNanos) - << " system (" << (100 * (userNanos + systemNanos) / (micros * 1000)) - << "%), flags: "; - for (auto& pair : flags) { - out << pair.first << "=" << pair.second << " "; - } - out << std::endl << "======" << std::endl; - if (detail) { - out << std::endl << output << std::endl; - } - return out.str(); - } + std::string toString(bool detail) const; }; struct ParameterDim { @@ -85,9 +55,10 @@ class QueryBenchmarkBase { public: virtual ~QueryBenchmarkBase() = default; virtual void initialize(); - void shutdown(); + virtual void shutdown(); std::pair, std::vector> run( - const exec::test::TpchPlan& tpchPlan); + const exec::test::TpchPlan& tpchPlan, + const std::unordered_map& queryConfigs = {}); virtual std::vector> listSplits( const std::string& path, @@ -96,10 +67,6 @@ class QueryBenchmarkBase { static void ensureTaskCompletion(exec::Task* task); - static bool validateDataFormat( - const char* flagname, - const std::string& value); - static void printResults( const std::vector& results, std::ostream& out); @@ -115,11 +82,17 @@ class QueryBenchmarkBase { void runAllCombinations(); + virtual std::shared_ptr makeConnectorProperties(); + protected: std::unique_ptr ioExecutor_; std::unique_ptr cacheExecutor_; std::shared_ptr allocator_; std::shared_ptr cache_; + + // QueryConfig properties. May be part of parameter sweep. + std::unordered_map config_; + // Parameter combinations to try. Each element specifies a flag and possible // values. All permutations are tried. std::vector parameters_; diff --git a/velox/benchmarks/basic/CMakeLists.txt b/velox/benchmarks/basic/CMakeLists.txt index b15bf4a07f11..7ba8c8fee604 100644 --- a/velox/benchmarks/basic/CMakeLists.txt +++ b/velox/benchmarks/basic/CMakeLists.txt @@ -11,62 +11,72 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -set(velox_benchmark_deps - velox_type - velox_vector - velox_vector_fuzzer - velox_expression - velox_parse_parser - velox_parse_utils - velox_parse_expression - velox_serialization - velox_benchmark_builder - velox_vector_test_lib - Folly::folly - Folly::follybenchmark - ${DOUBLE_CONVERSION} - gflags::gflags - glog::glog) +set( + velox_benchmark_deps + velox_type + velox_vector + velox_vector_fuzzer + velox_expression + velox_parse_parser + velox_parse_utils + velox_parse_expression + velox_serialization + velox_benchmark_builder + velox_vector_test_lib + Folly::folly + Folly::follybenchmark + ${DOUBLE_CONVERSION} + gflags::gflags + glog::glog +) add_executable(velox_benchmark_basic_simple_arithmetic SimpleArithmetic.cpp) -target_link_libraries( - velox_benchmark_basic_simple_arithmetic ${velox_benchmark_deps}) +target_link_libraries(velox_benchmark_basic_simple_arithmetic ${velox_benchmark_deps}) add_executable(velox_benchmark_basic_comparison_conjunct ComparisonConjunct.cpp) -target_link_libraries( - velox_benchmark_basic_comparison_conjunct ${velox_benchmark_deps}) +target_link_libraries(velox_benchmark_basic_comparison_conjunct ${velox_benchmark_deps}) add_executable(velox_benchmark_basic_simple_cast SimpleCastExpr.cpp) -target_link_libraries( - velox_benchmark_basic_simple_cast ${velox_benchmark_deps}) +target_link_libraries(velox_benchmark_basic_simple_cast ${velox_benchmark_deps}) add_executable(velox_benchmark_basic_decoded_vector DecodedVector.cpp) -target_link_libraries( - velox_benchmark_basic_decoded_vector ${velox_benchmark_deps}) +target_link_libraries(velox_benchmark_basic_decoded_vector ${velox_benchmark_deps}) + +add_executable(velox_benchmark_estimate_flat_size EstimateFlatSizeBenchmark.cpp) +target_link_libraries(velox_benchmark_estimate_flat_size ${velox_benchmark_deps}) add_executable(velox_benchmark_basic_selectivity_vector SelectivityVector.cpp) -target_link_libraries( - velox_benchmark_basic_selectivity_vector ${velox_benchmark_deps}) +target_link_libraries(velox_benchmark_basic_selectivity_vector ${velox_benchmark_deps}) add_executable(velox_benchmark_basic_vector_compare VectorCompare.cpp) target_link_libraries( - velox_benchmark_basic_vector_compare ${velox_benchmark_deps} - velox_vector_test_lib) + velox_benchmark_basic_vector_compare + ${velox_benchmark_deps} + velox_vector_test_lib +) add_executable(velox_benchmark_basic_vector_slice VectorSlice.cpp) target_link_libraries( - velox_benchmark_basic_vector_slice ${velox_benchmark_deps} - velox_vector_test_lib) + velox_benchmark_basic_vector_slice + ${velox_benchmark_deps} + velox_vector_test_lib +) add_executable(velox_benchmark_feature_normalization FeatureNormalization.cpp) target_link_libraries( - velox_benchmark_feature_normalization ${velox_benchmark_deps} velox_row_fast - velox_functions_prestosql) + velox_benchmark_feature_normalization + ${velox_benchmark_deps} + velox_row_fast + velox_functions_prestosql +) add_executable(velox_benchmark_basic_preproc Preproc.cpp) target_link_libraries( - velox_benchmark_basic_preproc ${velox_benchmark_deps} - velox_functions_prestosql velox_vector_test_lib) + velox_benchmark_basic_preproc + ${velox_benchmark_deps} + velox_functions_prestosql + velox_vector_test_lib +) add_executable(velox_like_tpch_benchmark LikeTpchBenchmark.cpp) target_link_libraries( @@ -74,7 +84,8 @@ target_link_libraries( ${velox_benchmark_deps} velox_functions_lib velox_tpch_gen - velox_vector_test_lib) + velox_vector_test_lib +) add_executable(velox_like_benchmark LikeBenchmark.cpp) target_link_libraries( @@ -82,16 +93,19 @@ target_link_libraries( ${velox_benchmark_deps} velox_functions_lib velox_functions_prestosql - velox_vector_test_lib) + velox_vector_test_lib +) add_executable(velox_benchmark_basic_vector_fuzzer VectorFuzzer.cpp) target_link_libraries( - velox_benchmark_basic_vector_fuzzer ${velox_benchmark_deps} - velox_vector_test_lib velox_common_fuzzer_util) + velox_benchmark_basic_vector_fuzzer + ${velox_benchmark_deps} + velox_vector_test_lib + velox_common_fuzzer_util +) add_executable(velox_cast_benchmark CastBenchmark.cpp) -target_link_libraries( - velox_cast_benchmark ${velox_benchmark_deps} velox_vector_test_lib) +target_link_libraries(velox_cast_benchmark ${velox_benchmark_deps} velox_vector_test_lib) add_executable(velox_format_datetime_benchmark FormatDateTimeBenchmark.cpp) target_link_libraries( @@ -100,4 +114,5 @@ target_link_libraries( velox_vector_test_lib velox_functions_spark velox_functions_prestosql - velox_row_fast) + velox_row_fast +) diff --git a/velox/benchmarks/basic/EstimateFlatSizeBenchmark.cpp b/velox/benchmarks/basic/EstimateFlatSizeBenchmark.cpp new file mode 100644 index 000000000000..a0ad3a8e4613 --- /dev/null +++ b/velox/benchmarks/basic/EstimateFlatSizeBenchmark.cpp @@ -0,0 +1,230 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#include "velox/functions/lib/benchmarks/FunctionBenchmarkBase.h" +#include "velox/vector/fuzzer/VectorFuzzer.h" + +DEFINE_int64(fuzzer_seed, 99887766, "Seed for random input dataset generator"); +DEFINE_int64(vector_size, 10000, "Size of vectors to benchmark"); +DEFINE_int64(row_children, 1000, "Number of children in row vector"); +DEFINE_int64(dict_nesting, 5, "Number of dictionary nesting levels"); + +using namespace facebook::velox; +using namespace facebook::velox::test; + +namespace { + +class EstimateFlatSizeBenchmark + : public functions::test::FunctionBenchmarkBase { + public: + EstimateFlatSizeBenchmark(size_t vectorSize, size_t rowChildren) + : FunctionBenchmarkBase(), + vectorSize_(vectorSize), + numRowChildren_(rowChildren) { + VectorFuzzer::Options opts; + opts.vectorSize = vectorSize_; + opts.nullRatio = 0; + opts.stringLength = 10; + VectorFuzzer fuzzer(opts, pool(), FLAGS_fuzzer_seed); + + // Create flat vectors of different types + flatBigintVector_ = fuzzer.fuzzFlat(BIGINT()); + flatVarcharVector_ = fuzzer.fuzzFlat(VARCHAR()); + + // Create constant vectors + constantBigintVector_ = fuzzer.fuzzConstant(BIGINT()); + constantVarcharVector_ = fuzzer.fuzzConstant(VARCHAR()); + + // Create dictionary vectors + dictionaryBigintVector_ = fuzzer.fuzzDictionary(fuzzer.fuzzFlat(BIGINT())); + dictionaryVarcharVector_ = + fuzzer.fuzzDictionary(fuzzer.fuzzFlat(VARCHAR())); + + // Create nested dictionary vector + nestedDictionaryVector_ = fuzzer.fuzzFlat(BIGINT()); + for (size_t i = 0; i < 5; ++i) { + nestedDictionaryVector_ = fuzzer.fuzzDictionary(nestedDictionaryVector_); + } + + // Create a nested row vector with complex children + std::vector names; + std::vector types; + std::vector children; + + // Create a mix of different types of children + for (size_t i = 0; i < numRowChildren_; ++i) { + names.push_back(fmt::format("field{}", i)); + + // Create different types of children based on the index + switch (i % 5) { + case 0: { + // Flat vector + types.push_back(BIGINT()); + children.push_back(fuzzer.fuzzFlat(BIGINT())); + break; + } + case 1: { + // Dictionary vector + types.push_back(VARCHAR()); + children.push_back(fuzzer.fuzzDictionary(fuzzer.fuzzFlat(VARCHAR()))); + break; + } + case 2: { + // Nested row vector + std::vector nestedNames = { + "nested1", "nested2", "nested3"}; + std::vector nestedTypes = {BIGINT(), VARCHAR(), DOUBLE()}; + auto nestedRowType = + ROW(std::move(nestedNames), std::move(nestedTypes)); + types.push_back(nestedRowType); + children.push_back(fuzzer.fuzzRow(nestedRowType, vectorSize_)); + break; + } + case 3: { + // Array vector + auto arrayType = ARRAY(BIGINT()); + types.push_back(arrayType); + children.push_back(fuzzer.fuzzArray(BIGINT(), vectorSize_)); + break; + } + case 4: { + // Map vector + auto mapType = MAP(VARCHAR(), BIGINT()); + types.push_back(mapType); + children.push_back(fuzzer.fuzzMap(VARCHAR(), BIGINT(), vectorSize_)); + break; + } + } + } + + auto rowType = ROW(std::move(names), std::move(types)); + rowVector_ = std::make_shared( + pool(), rowType, nullptr, vectorSize_, std::move(children)); + } + + // Benchmark methods for estimateFlatSize + void estimateFlatSizeFlatBigint() { + auto size = flatBigintVector_->estimateFlatSize(); + folly::doNotOptimizeAway(size); + } + + void estimateFlatSizeFlatVarchar() { + auto size = flatVarcharVector_->estimateFlatSize(); + folly::doNotOptimizeAway(size); + } + + void estimateFlatSizeConstantBigint() { + auto size = constantBigintVector_->estimateFlatSize(); + folly::doNotOptimizeAway(size); + } + + void estimateFlatSizeConstantVarchar() { + auto size = constantVarcharVector_->estimateFlatSize(); + folly::doNotOptimizeAway(size); + } + + void estimateFlatSizeDictionaryBigint() { + auto size = dictionaryBigintVector_->estimateFlatSize(); + folly::doNotOptimizeAway(size); + } + + void estimateFlatSizeDictionaryVarchar() { + auto size = dictionaryVarcharVector_->estimateFlatSize(); + folly::doNotOptimizeAway(size); + } + + void estimateFlatSizeNestedDictionary() { + auto size = nestedDictionaryVector_->estimateFlatSize(); + folly::doNotOptimizeAway(size); + } + + void estimateFlatSizeRowVector() { + auto size = rowVector_->estimateFlatSize(); + folly::doNotOptimizeAway(size); + } + + private: + const size_t vectorSize_; + const size_t numRowChildren_; + + VectorPtr flatBigintVector_; + VectorPtr flatVarcharVector_; + VectorPtr constantBigintVector_; + VectorPtr constantVarcharVector_; + VectorPtr dictionaryBigintVector_; + VectorPtr dictionaryVarcharVector_; + VectorPtr nestedDictionaryVector_; + RowVectorPtr rowVector_; +}; + +std::unique_ptr benchmark; + +template +void run(Func&& func, size_t iterations = 100) { + for (auto i = 0; i < iterations; i++) { + func(); + } +} + +BENCHMARK(estimateFlatSizeFlatBigint) { + run([&] { benchmark->estimateFlatSizeFlatBigint(); }); +} + +BENCHMARK(estimateFlatSizeFlatVarchar) { + run([&] { benchmark->estimateFlatSizeFlatVarchar(); }); +} + +BENCHMARK(estimateFlatSizeConstantBigint) { + run([&] { benchmark->estimateFlatSizeConstantBigint(); }); +} + +BENCHMARK(estimateFlatSizeConstantVarchar) { + run([&] { benchmark->estimateFlatSizeConstantVarchar(); }); +} + +BENCHMARK(estimateFlatSizeDictionaryBigint) { + run([&] { benchmark->estimateFlatSizeDictionaryBigint(); }); +} + +BENCHMARK(estimateFlatSizeDictionaryVarchar) { + run([&] { benchmark->estimateFlatSizeDictionaryVarchar(); }); +} + +BENCHMARK(estimateFlatSizeNestedDictionary) { + run([&] { benchmark->estimateFlatSizeNestedDictionary(); }); +} + +BENCHMARK(estimateFlatSizeRowVector) { + run([&] { benchmark->estimateFlatSizeRowVector(); }); +} + +} // namespace + +int main(int argc, char* argv[]) { + folly::Init init{&argc, &argv}; + ::gflags::ParseCommandLineFlags(&argc, &argv, true); + memory::MemoryManager::initialize(memory::MemoryManager::Options{}); + + benchmark = std::make_unique( + FLAGS_vector_size, FLAGS_row_children); + folly::runBenchmarks(); + benchmark.reset(); + return 0; +} diff --git a/velox/benchmarks/basic/LikeTpchBenchmark.cpp b/velox/benchmarks/basic/LikeTpchBenchmark.cpp index 1f4dab648f94..d904ae5118b2 100644 --- a/velox/benchmarks/basic/LikeTpchBenchmark.cpp +++ b/velox/benchmarks/basic/LikeTpchBenchmark.cpp @@ -127,8 +127,9 @@ class LikeFunctionsBenchmark : public FunctionBaseTest, return tpchSupplier->childAt(6); } default: - VELOX_FAIL(fmt::format( - "Tpch data generation for case {} is not supported", tpchCase)); + VELOX_FAIL( + fmt::format( + "Tpch data generation for case {} is not supported", tpchCase)); } } diff --git a/velox/benchmarks/basic/Preproc.cpp b/velox/benchmarks/basic/Preproc.cpp index f3fbca54b4f7..51e57dea9d94 100644 --- a/velox/benchmarks/basic/Preproc.cpp +++ b/velox/benchmarks/basic/Preproc.cpp @@ -186,11 +186,12 @@ std::vector> signatures() { std::vector> signatures; for (auto type : {"tinyint", "smallint", "integer", "bigint", "real", "double"}) { - signatures.push_back(exec::FunctionSignatureBuilder() - .returnType(type) - .argumentType(type) - .argumentType(type) - .build()); + signatures.push_back( + exec::FunctionSignatureBuilder() + .returnType(type) + .argumentType(type) + .argumentType(type) + .build()); } return signatures; } diff --git a/velox/benchmarks/basic/SimpleCastExpr.cpp b/velox/benchmarks/basic/SimpleCastExpr.cpp index 1ae05f03dd88..e2708e7f77d0 100644 --- a/velox/benchmarks/basic/SimpleCastExpr.cpp +++ b/velox/benchmarks/basic/SimpleCastExpr.cpp @@ -19,6 +19,7 @@ #include +#include "velox/core/Expressions.h" #include "velox/functions/lib/benchmarks/FunctionBenchmarkBase.h" #include "velox/vector/fuzzer/VectorFuzzer.h" diff --git a/velox/benchmarks/basic/VectorSlice.cpp b/velox/benchmarks/basic/VectorSlice.cpp index 44ad91f8200a..54b706265e71 100644 --- a/velox/benchmarks/basic/VectorSlice.cpp +++ b/velox/benchmarks/basic/VectorSlice.cpp @@ -37,9 +37,10 @@ constexpr int kVectorSize = 16 << 10; struct BenchmarkData { BenchmarkData() - : pool_(memory::memoryManager()->addLeafPool( - "BenchmarkData", - FLAGS_use_thread_safe_memory_usage_track)) { + : pool_( + memory::memoryManager()->addLeafPool( + "BenchmarkData", + FLAGS_use_thread_safe_memory_usage_track)) { VectorFuzzer::Options opts; opts.nullRatio = 0.01; opts.vectorSize = kVectorSize; diff --git a/velox/benchmarks/filesystem/CMakeLists.txt b/velox/benchmarks/filesystem/CMakeLists.txt index 10a5ebdf710d..5f5b411b94ab 100644 --- a/velox/benchmarks/filesystem/CMakeLists.txt +++ b/velox/benchmarks/filesystem/CMakeLists.txt @@ -16,16 +16,12 @@ add_library(velox_read_benchmark_lib ReadBenchmark.cpp) target_link_libraries( velox_read_benchmark_lib - PUBLIC velox_file velox_time Folly::folly gflags::gflags) + PUBLIC velox_file velox_time Folly::folly gflags::gflags +) add_executable(velox_read_benchmark ReadBenchmarkMain.cpp) target_link_libraries( velox_read_benchmark - PRIVATE - velox_read_benchmark_lib - velox_hive_config - velox_s3fs - velox_hdfs - velox_abfs - velox_gcs) + PRIVATE velox_read_benchmark_lib velox_hive_config velox_s3fs velox_hdfs velox_abfs velox_gcs +) diff --git a/velox/benchmarks/filesystem/ReadBenchmark.h b/velox/benchmarks/filesystem/ReadBenchmark.h index e033bc953aca..69b26082d647 100644 --- a/velox/benchmarks/filesystem/ReadBenchmark.h +++ b/velox/benchmarks/filesystem/ReadBenchmark.h @@ -180,8 +180,9 @@ class ReadBenchmark { } else { std::vector> ranges; for (auto start = 0; start < rangeSize; start += size + gap) { - ranges.push_back(folly::Range( - globalScratch.buffer.data() + start, size)); + ranges.push_back( + folly::Range( + globalScratch.buffer.data() + start, size)); if (gap && start + gap < rangeSize) { ranges.push_back(folly::Range(nullptr, gap)); } diff --git a/velox/benchmarks/tpch/CMakeLists.txt b/velox/benchmarks/tpch/CMakeLists.txt index 1ac7c3f1aee7..1dbe29ea7d2c 100644 --- a/velox/benchmarks/tpch/CMakeLists.txt +++ b/velox/benchmarks/tpch/CMakeLists.txt @@ -36,9 +36,9 @@ target_link_libraries( velox_vector_test_lib Folly::follybenchmark Folly::folly - fmt::fmt) + fmt::fmt +) add_executable(velox_tpch_benchmark TpchBenchmarkMain.cpp) -target_link_libraries( - velox_tpch_benchmark velox_tpch_benchmark_lib) +target_link_libraries(velox_tpch_benchmark velox_tpch_benchmark_lib) diff --git a/velox/benchmarks/tpch/TpchBenchmark.cpp b/velox/benchmarks/tpch/TpchBenchmark.cpp index 25d02224fc58..585e5c6eba98 100644 --- a/velox/benchmarks/tpch/TpchBenchmark.cpp +++ b/velox/benchmarks/tpch/TpchBenchmark.cpp @@ -14,7 +14,9 @@ * limitations under the License. */ -#include "velox/benchmarks/QueryBenchmarkBase.h" +#include "velox/benchmarks/tpch/TpchBenchmark.h" +#include +#include "velox/exec/PlanNodeStats.h" using namespace facebook::velox; using namespace facebook::velox::exec; @@ -57,178 +59,165 @@ DEFINE_int32( "include in IO meter query. The columns are sorted by name and the n% first " "are scanned"); -std::shared_ptr queryBuilder; - -class TpchBenchmark : public QueryBenchmarkBase { - public: - void runMain(std::ostream& out, RunStats& runStats) override { - if (FLAGS_run_query_verbose == -1 && FLAGS_io_meter_column_pct == 0) { - folly::runBenchmarks(); - } else { - const auto queryPlan = FLAGS_io_meter_column_pct > 0 - ? queryBuilder->getIoMeterPlan(FLAGS_io_meter_column_pct) - : queryBuilder->getQueryPlan(FLAGS_run_query_verbose); - auto [cursor, actualResults] = run(queryPlan); - if (!cursor) { - LOG(ERROR) << "Query terminated with error. Exiting"; - exit(1); - } - auto task = cursor->task(); - ensureTaskCompletion(task.get()); - if (FLAGS_include_results) { - printResults(actualResults, out); - out << std::endl; - } - const auto stats = task->taskStats(); - int64_t rawInputBytes = 0; - for (auto& pipeline : stats.pipelineStats) { - auto& first = pipeline.operatorStats[0]; - if (first.operatorType == "TableScan") { - rawInputBytes += first.rawInputBytes; - } +void TpchBenchmark::initQueryBuilder() { + queryBuilder_ = + std::make_shared(toFileFormat(FLAGS_data_format)); + queryBuilder_->initialize(FLAGS_data_path); +} + +void TpchBenchmark::initialize() { + QueryBenchmarkBase::initialize(); + initQueryBuilder(); +} + +void TpchBenchmark::shutdown() { + QueryBenchmarkBase::shutdown(); + queryBuilder_.reset(); +} + +void TpchBenchmark::runMain( + std::ostream& out, + facebook::velox::RunStats& runStats) { + if (FLAGS_run_query_verbose == -1 && FLAGS_io_meter_column_pct == 0) { + folly::runBenchmarks(); + } else { + auto queryPlan = FLAGS_io_meter_column_pct > 0 + ? queryBuilder_->getIoMeterPlan(FLAGS_io_meter_column_pct) + : queryBuilder_->getQueryPlan(FLAGS_run_query_verbose); + auto [cursor, actualResults] = run(queryPlan, queryConfigs_); + if (!cursor) { + LOG(ERROR) << "Query terminated with error. Exiting"; + exit(1); + } + auto task = cursor->task(); + ensureTaskCompletion(task.get()); + if (FLAGS_include_results) { + printResults(actualResults, out); + out << std::endl; + } + const auto stats = task->taskStats(); + int64_t rawInputBytes = 0; + for (auto& pipeline : stats.pipelineStats) { + auto& first = pipeline.operatorStats[0]; + if (first.operatorType == "TableScan") { + rawInputBytes += first.rawInputBytes; } - runStats.rawInputBytes = rawInputBytes; - out << fmt::format( - "Execution time: {}", - succinctMillis( - stats.executionEndTimeMs - stats.executionStartTimeMs)) - << std::endl; - out << fmt::format( - "Splits total: {}, finished: {}", - stats.numTotalSplits, - stats.numFinishedSplits) - << std::endl; - out << printPlanWithStats( - *queryPlan.plan, stats, FLAGS_include_custom_stats) - << std::endl; } + runStats.rawInputBytes = rawInputBytes; + out << fmt::format( + "Execution time: {}", + facebook::velox::succinctMillis( + stats.executionEndTimeMs - stats.executionStartTimeMs)) + << std::endl; + out << fmt::format( + "Splits total: {}, finished: {}", + stats.numTotalSplits, + stats.numFinishedSplits) + << std::endl; + out << printPlanWithStats( + *queryPlan.plan, stats, FLAGS_include_custom_stats) + << std::endl; } -}; +} -TpchBenchmark benchmark; +std::unique_ptr benchmark; BENCHMARK(q1) { - const auto planContext = queryBuilder->getQueryPlan(1); - benchmark.run(planContext); + benchmark->runQuery(1); } BENCHMARK(q2) { - const auto planContext = queryBuilder->getQueryPlan(2); - benchmark.run(planContext); + benchmark->runQuery(2); } BENCHMARK(q3) { - const auto planContext = queryBuilder->getQueryPlan(3); - benchmark.run(planContext); + benchmark->runQuery(3); } BENCHMARK(q4) { - const auto planContext = queryBuilder->getQueryPlan(4); - benchmark.run(planContext); + benchmark->runQuery(4); } BENCHMARK(q5) { - const auto planContext = queryBuilder->getQueryPlan(5); - benchmark.run(planContext); + benchmark->runQuery(5); } BENCHMARK(q6) { - const auto planContext = queryBuilder->getQueryPlan(6); - benchmark.run(planContext); + benchmark->runQuery(6); } BENCHMARK(q7) { - const auto planContext = queryBuilder->getQueryPlan(7); - benchmark.run(planContext); + benchmark->runQuery(7); } BENCHMARK(q8) { - const auto planContext = queryBuilder->getQueryPlan(8); - benchmark.run(planContext); + benchmark->runQuery(8); } BENCHMARK(q9) { - const auto planContext = queryBuilder->getQueryPlan(9); - benchmark.run(planContext); + benchmark->runQuery(9); } BENCHMARK(q10) { - const auto planContext = queryBuilder->getQueryPlan(10); - benchmark.run(planContext); + benchmark->runQuery(10); } BENCHMARK(q11) { - const auto planContext = queryBuilder->getQueryPlan(11); - benchmark.run(planContext); + benchmark->runQuery(11); } BENCHMARK(q12) { - const auto planContext = queryBuilder->getQueryPlan(12); - benchmark.run(planContext); + benchmark->runQuery(12); } BENCHMARK(q13) { - const auto planContext = queryBuilder->getQueryPlan(13); - benchmark.run(planContext); + benchmark->runQuery(13); } BENCHMARK(q14) { - const auto planContext = queryBuilder->getQueryPlan(14); - benchmark.run(planContext); + benchmark->runQuery(14); } BENCHMARK(q15) { - const auto planContext = queryBuilder->getQueryPlan(15); - benchmark.run(planContext); + benchmark->runQuery(15); } BENCHMARK(q16) { - const auto planContext = queryBuilder->getQueryPlan(16); - benchmark.run(planContext); + benchmark->runQuery(16); } BENCHMARK(q17) { - const auto planContext = queryBuilder->getQueryPlan(17); - benchmark.run(planContext); + benchmark->runQuery(17); } BENCHMARK(q18) { - const auto planContext = queryBuilder->getQueryPlan(18); - benchmark.run(planContext); + benchmark->runQuery(18); } BENCHMARK(q19) { - const auto planContext = queryBuilder->getQueryPlan(19); - benchmark.run(planContext); + benchmark->runQuery(19); } BENCHMARK(q20) { - const auto planContext = queryBuilder->getQueryPlan(20); - benchmark.run(planContext); + benchmark->runQuery(20); } BENCHMARK(q21) { - const auto planContext = queryBuilder->getQueryPlan(21); - benchmark.run(planContext); + benchmark->runQuery(21); } BENCHMARK(q22) { - const auto planContext = queryBuilder->getQueryPlan(22); - benchmark.run(planContext); + benchmark->runQuery(22); } -int tpchBenchmarkMain() { - benchmark.initialize(); - queryBuilder = - std::make_shared(toFileFormat(FLAGS_data_format)); - queryBuilder->initialize(FLAGS_data_path); +void tpchBenchmarkMain() { + VELOX_CHECK_NOT_NULL(benchmark); + benchmark->initialize(); if (FLAGS_test_flags_file.empty()) { RunStats ignore; - benchmark.runMain(std::cout, ignore); + benchmark->runMain(std::cout, ignore); } else { - benchmark.runAllCombinations(); + benchmark->runAllCombinations(); } - benchmark.shutdown(); - queryBuilder.reset(); - return 0; + benchmark->shutdown(); } diff --git a/velox/benchmarks/tpch/TpchBenchmark.h b/velox/benchmarks/tpch/TpchBenchmark.h index e66e7c53cbc5..7297f7db17ad 100644 --- a/velox/benchmarks/tpch/TpchBenchmark.h +++ b/velox/benchmarks/tpch/TpchBenchmark.h @@ -15,4 +15,31 @@ */ #pragma once +#include "velox/benchmarks/QueryBenchmarkBase.h" +#include "velox/exec/tests/utils/TpchQueryBuilder.h" + +class TpchBenchmark : public facebook::velox::QueryBenchmarkBase { + public: + void initialize() override; + + void shutdown() override; + + void runMain(std::ostream& out, facebook::velox::RunStats& runStats) override; + + void runQuery(int32_t queryId) { + const auto planContext = queryBuilder_->getQueryPlan(queryId); + run(planContext, queryConfigs_); + } + + protected: + std::unordered_map queryConfigs_; + + private: + void initQueryBuilder(); + + std::shared_ptr queryBuilder_; +}; + +extern std::unique_ptr benchmark; + void tpchBenchmarkMain(); diff --git a/velox/benchmarks/tpch/TpchBenchmarkMain.cpp b/velox/benchmarks/tpch/TpchBenchmarkMain.cpp index 4477455d8f3c..0fa9718ca7ba 100644 --- a/velox/benchmarks/tpch/TpchBenchmarkMain.cpp +++ b/velox/benchmarks/tpch/TpchBenchmarkMain.cpp @@ -24,5 +24,6 @@ int main(int argc, char** argv) { "This program benchmarks TPC-H queries. Run 'velox_tpch_benchmark -helpon=TpchBenchmark' for available options.\n"); gflags::SetUsageMessage(kUsage); folly::Init init{&argc, &argv, false}; + benchmark = std::make_unique(); tpchBenchmarkMain(); } diff --git a/velox/benchmarks/unstable/CMakeLists.txt b/velox/benchmarks/unstable/CMakeLists.txt index cbd8e1cf6dca..3ca3339fca9c 100644 --- a/velox/benchmarks/unstable/CMakeLists.txt +++ b/velox/benchmarks/unstable/CMakeLists.txt @@ -11,21 +11,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -set(velox_benchmark_deps - velox_type - velox_vector - velox_vector_fuzzer - velox_expression - velox_parse_parser - velox_parse_utils - velox_parse_expression - velox_serialization - Folly::folly - Folly::follybenchmark - ${DOUBLE_CONVERSION} - gflags::gflags - glog::glog) +set( + velox_benchmark_deps + velox_type + velox_vector + velox_vector_fuzzer + velox_expression + velox_parse_parser + velox_parse_utils + velox_parse_expression + velox_serialization + Folly::folly + Folly::follybenchmark + ${DOUBLE_CONVERSION} + gflags::gflags + glog::glog +) add_executable(velox_memory_alloc_benchmark MemoryAllocationBenchmark.cpp) -target_link_libraries( - velox_memory_alloc_benchmark ${velox_benchmark_deps} velox_memory pthread) +target_link_libraries(velox_memory_alloc_benchmark ${velox_benchmark_deps} velox_memory pthread) diff --git a/velox/buffer/Buffer.cpp b/velox/buffer/Buffer.cpp index 806959295401..abb7a3f09d63 100644 --- a/velox/buffer/Buffer.cpp +++ b/velox/buffer/Buffer.cpp @@ -18,9 +18,24 @@ namespace facebook::velox { +std::string Buffer::typeString(Type type) { + switch (type) { + case Type::kPOD: + return "kPOD"; + case Type::kNonPOD: + return "kNonPOD"; + case Type::kPODView: + return "kPODView"; + case Type::kNonPODView: + return "kNonPODView"; + default: + return fmt::format("Unknown({})", static_cast(type)); + } +} + namespace { struct BufferReleaser { - explicit BufferReleaser(const BufferPtr& parent) : parent_(parent) {} + explicit BufferReleaser(BufferPtr parent) : parent_{std::move(parent)} {} void addRef() const {} void release() const {} diff --git a/velox/buffer/Buffer.h b/velox/buffer/Buffer.h index 4e6dba9ebf2f..85ceab567123 100644 --- a/velox/buffer/Buffer.h +++ b/velox/buffer/Buffer.h @@ -19,6 +19,8 @@ #include #include +#include +#include #include "velox/common/base/BitUtil.h" #include "velox/common/base/CheckedArithmetic.h" #include "velox/common/base/Exceptions.h" @@ -26,8 +28,7 @@ #include "velox/common/base/SimdUtil.h" #include "velox/common/memory/Memory.h" -namespace facebook { -namespace velox { +namespace facebook::velox { class Buffer; class AlignedBuffer; @@ -58,21 +59,40 @@ class Buffer { // type. Thus the conditions are: trivial destructor (no resources to release) // and trivially copyable (so memcpy works) template - static inline constexpr bool is_pod_like_v = + static constexpr bool is_pod_like_v = std::is_trivially_destructible_v && std::is_trivially_copyable_v; - virtual ~Buffer() {} + virtual ~Buffer() = default; - void addRef() { - referenceCount_.fetch_add(1); + static constexpr uint8_t kPODBit = 0; + static constexpr uint8_t kPODMask = 1 << kPODBit; + static constexpr uint8_t kViewBit = 1; + static constexpr uint8_t kViewMask = 1 << kViewBit; + static_assert(kPODBit != kViewBit); + + enum class Type : uint8_t { + kNonPOD = 0 << kPODBit | 0 << kViewBit, + kPOD = 1 << kPODBit | 0 << kViewBit, + kNonPODView = 0 << kPODBit | 1 << kViewBit, + kPODView = 1 << kPODBit | 1 << kViewBit, + }; + + static std::string typeString(Type type); + + Type type() const { + return type_; + } + + void addRef() noexcept { + referenceCount_.fetch_add(1, std::memory_order_acq_rel); } - int refCount() const { - return referenceCount_; + int refCount() const noexcept { + return referenceCount_.load(std::memory_order_acquire); } void release() { - if (referenceCount_.fetch_sub(1) == 1) { + if (referenceCount_.fetch_sub(1, std::memory_order_acq_rel) == 1) { releaseResources(); if (pool_) { freeToPool(); @@ -86,13 +106,13 @@ class Buffer { const T* as() const { // We can't check actual types, but we can sanity-check POD/non-POD // conversion. `void` is special as it's used in type-erased contexts - VELOX_DCHECK((std::is_same_v) || podType_ == is_pod_like_v); + VELOX_DCHECK(std::is_void_v || isPOD() == is_pod_like_v); return reinterpret_cast(data_); } template Range asRange() { - return Range(as(), 0, size() / sizeof(T)); + return {as(), 0, static_cast(size() / sizeof(T))}; } template @@ -102,16 +122,16 @@ class Buffer { VELOX_CHECK(!isView()); // We can't check actual types, but we can sanity-check POD/non-POD // conversion. `void` is special as it's used in type-erased contexts - VELOX_DCHECK((std::is_same_v) || podType_ == is_pod_like_v); + VELOX_DCHECK(std::is_void_v || isPOD() == is_pod_like_v); return reinterpret_cast(data_); } template MutableRange asMutableRange() { - return MutableRange(asMutable(), 0, size() / sizeof(T)); + return {asMutable(), 0, static_cast(size() / sizeof(T))}; } - size_t size() const { + size_t size() const noexcept { return size_; } @@ -126,24 +146,28 @@ class Buffer { checkEndGuard(); } - uint64_t capacity() const { + uint64_t capacity() const noexcept { return capacity_; } - bool unique() const { - return referenceCount_ == 1; + bool unique() const noexcept { + return refCount() == 1; } - velox::memory::MemoryPool* pool() const { + velox::memory::MemoryPool* pool() const noexcept { return pool_; } - bool isMutable() const { + bool isMutable() const noexcept { return !isView() && unique(); } - virtual bool isView() const { - return false; + bool isView() const { + return (static_cast(type_) & kViewMask) != 0; + } + + bool isPOD() const { + return (static_cast(type_) & kPODMask) != 0; } friend std::ostream& operator<<(std::ostream& os, const Buffer& buffer) { @@ -199,6 +223,14 @@ class Buffer { sizeof(T), is_pod_like_v, buffer, offset, length); } + /// Transfers this buffer to 'pool'. Returns true if the transfer succeeds, or + /// false if the transfer fails. A buffer can be transferred to 'pool' if its + /// original pool and 'pool' are from the same MemoryAllocator and the buffer + /// is not a BufferView. + virtual bool transferTo(velox::memory::MemoryPool* /*pool*/) { + VELOX_NYI("{} unsupported", __FUNCTION__); + } + protected: // Writes a magic word at 'capacity_'. No-op for a BufferView. The actual // logic is inside a separate virtual function, allowing override by derived @@ -241,7 +273,7 @@ class Buffer { virtual void copyFrom(const Buffer* other, size_t bytes) { VELOX_CHECK(!isView()); VELOX_CHECK_GE(capacity_, bytes); - VELOX_CHECK(podType_); + VELOX_CHECK_EQ(type_, Type::kPOD); memcpy(data_, other->data_, bytes); } @@ -252,27 +284,24 @@ class Buffer { } Buffer( - velox::memory::MemoryPool* pool, + Type type, uint8_t* data, size_t capacity, - bool podType) - : pool_(pool), - data_(data), - capacity_(capacity), - referenceCount_(0), - podType_(podType) {} + velox::memory::MemoryPool* pool) + : pool_{pool}, data_{data}, capacity_{capacity}, type_{type} {} velox::memory::MemoryPool* const pool_; uint8_t* const data_; - uint64_t size_ = 0; - uint64_t capacity_ = 0; - std::atomic referenceCount_; - bool podType_ = true; - // Pad to 64 bytes. If using as int32_t[], guarantee that value at index -1 == - // -1. - uint64_t padding_[2] = {static_cast(-1), static_cast(-1)}; - // Needs to use setCapacity() from static method reallocate(). - friend class AlignedBuffer; + + uint64_t size_{0}; + uint64_t capacity_; + std::atomic_int32_t referenceCount_{0}; + + const Type type_; + + // Pad to 64 bytes. + // If using as int32_t[], guarantee that value at index -1 == -1. + uint64_t padding_[2]{static_cast(-1), static_cast(-1)}; private: static BufferPtr sliceBufferZeroCopy( @@ -281,6 +310,9 @@ class Buffer { const BufferPtr& buffer, size_t offset, size_t length); + + // Needs to use setCapacity() from static method reallocate(). + friend class AlignedBuffer; }; static_assert( @@ -289,12 +321,12 @@ static_assert( template <> inline Range Buffer::asRange() { - return Range(as(), 0, size() * 8); + return {as(), 0, static_cast(size() * 8)}; } template <> inline MutableRange Buffer::asMutableRange() { - return MutableRange(asMutable(), 0, size() * 8); + return {asMutable(), 0, static_cast(size() * 8)}; } template <> @@ -304,11 +336,11 @@ BufferPtr Buffer::slice( size_t length, memory::MemoryPool* pool); -static inline void intrusive_ptr_add_ref(Buffer* buffer) { +FOLLY_ALWAYS_INLINE void intrusive_ptr_add_ref(Buffer* buffer) noexcept { buffer->addRef(); } -static inline void intrusive_ptr_release(Buffer* buffer) { +FOLLY_ALWAYS_INLINE void intrusive_ptr_release(Buffer* buffer) noexcept { buffer->release(); } @@ -325,7 +357,7 @@ class AlignedBuffer : public Buffer { static constexpr int32_t kSizeofAlignedBuffer = 64; static constexpr int32_t kPaddedSize = kSizeofAlignedBuffer + simd::kPadding; - ~AlignedBuffer() { + ~AlignedBuffer() override { // This may throw, which is expected to signal an error to the // user. This is better for distributed debugging than killing the // process. In concept this indicates the possibility of memory @@ -337,10 +369,8 @@ class AlignedBuffer : public Buffer { // It's almost like partial specialization, but we redirect all POD types to // the same non-templated class template - using ImplClass = typename std::conditional< - is_pod_like_v, - AlignedBuffer, - NonPODAlignedBuffer>::type; + using ImplClass = std:: + conditional_t, AlignedBuffer, NonPODAlignedBuffer>; /** * Allocates enough memory to store numElements of type T. May @@ -367,7 +397,8 @@ class AlignedBuffer : public Buffer { } void* memory = pool->allocate(preferredSize); - auto* buffer = new (memory) ImplClass(pool, preferredSize - kPaddedSize); + VELOX_CHECK_NOT_NULL(memory); + auto* buffer = new (memory) ImplClass{pool, preferredSize - kPaddedSize}; // set size explicitly instead of setSize because `fillNewMemory` already // called the constructors buffer->size_ = size; @@ -376,6 +407,18 @@ class AlignedBuffer : public Buffer { return result; } + /// A verbose version of the allocate() with the exact size. + /// May allocate slightly more memory than strictly necessary. Guarantees that + /// simd::kPadding bytes past capacity() are addressable and asserts that + /// these do not get overrun. + template + static BufferPtr allocateExact( + size_t numElements, + velox::memory::MemoryPool* pool, + const std::optional& initValue = std::nullopt) { + return allocate(numElements, pool, initValue, true); + } + // Changes the capacity of '*buffer'. The buffer may grow/shrink in // place or may change addresses. The content is copied up to the // old size() or the new size, whichever is smaller. If the buffer grows, the @@ -417,34 +460,32 @@ class AlignedBuffer : public Buffer { // called the constructors newBuffer->size_ = size; *buffer = std::move(newBuffer); - return; - } - if (!old->unique()) { + } else if (!old->unique()) { auto newBuffer = allocate(numElements, pool); newBuffer->copyFrom(old, std::min(size, old->size())); reinterpret_cast(newBuffer.get()) ->template fillNewMemory(old->size(), size, initValue); newBuffer->size_ = size; *buffer = std::move(newBuffer); - return; - } - auto oldCapacity = checkedPlus(old->capacity(), kPaddedSize); - auto preferredSize = - pool->preferredSize(checkedPlus(size, kPaddedSize)); + } else { + auto oldCapacity = checkedPlus(old->capacity(), kPaddedSize); + auto preferredSize = + pool->preferredSize(checkedPlus(size, kPaddedSize)); - void* newPtr = pool->reallocate(old, oldCapacity, preferredSize); + void* newPtr = pool->reallocate(old, oldCapacity, preferredSize); - // Make the old buffer no longer owned by '*buffer' because reallocate - // freed the old buffer. Reassigning the new buffer to - // '*buffer' would be a double free if we didn't do this. - buffer->detach(); + // Make the old buffer no longer owned by '*buffer' because reallocate + // freed the old buffer. Reassigning the new buffer to + // '*buffer' would be a double free if we didn't do this. + buffer->detach(); - auto newBuffer = - new (newPtr) AlignedBuffer(pool, preferredSize - kPaddedSize); - newBuffer->setSize(size); - newBuffer->fillNewMemory(oldSize, size, initValue); + auto newBuffer = + new (newPtr) AlignedBuffer{pool, preferredSize - kPaddedSize}; + newBuffer->setSize(size); + newBuffer->fillNewMemory(oldSize, size, initValue); - *buffer = newBuffer; + *buffer = newBuffer; + } } // Appends bytes starting at 'items' for a length of 'sizeof(T) * @@ -479,7 +520,7 @@ class AlignedBuffer : public Buffer { } VELOX_CHECK( - bufferPtr->podType_, "Support for non POD types not implemented yet"); + bufferPtr->isPOD(), "Support for non POD types not implemented yet"); // The reason we use uint8_t is because mutableNulls()->size() will return // in byte count. We also don't bother initializing since copyFrom will be @@ -491,13 +532,49 @@ class AlignedBuffer : public Buffer { return newBuffer; } + template + static BufferPtr copy( + const BufferPtr& buffer, + velox::memory::MemoryPool* pool) { + if (buffer == nullptr) { + return nullptr; + } + + // The reason we use uint8_t is because mutableNulls()->size() will return + // in byte count. We also don't bother initializing since copyFrom will be + // overwriting anyway. + BufferPtr newBuffer; + if constexpr (std::is_same_v) { + newBuffer = AlignedBuffer::allocate(buffer->size(), pool); + } else { + const auto numElements = checkedDivide(buffer->size(), sizeof(T)); + newBuffer = AlignedBuffer::allocate(numElements, pool); + } + + newBuffer->copyFrom(buffer.get(), newBuffer->size()); + + return newBuffer; + } + + bool transferTo(velox::memory::MemoryPool* pool) override { + if (pool_ == pool) { + return true; + } + if (pool_->transferTo( + pool, this, checkedPlus(kPaddedSize, capacity_))) { + setPool(pool); + return true; + } + return false; + } + protected: AlignedBuffer(velox::memory::MemoryPool* pool, size_t capacity) - : Buffer( - pool, + : Buffer{ + Type::kPOD, reinterpret_cast(this) + sizeof(*this), capacity, - true /*podType*/) { + pool} { static_assert(sizeof(*this) == kAlignment); static_assert(sizeof(*this) == kSizeofAlignedBuffer); setEndGuard(); @@ -531,7 +608,12 @@ class AlignedBuffer : public Buffer { } } - protected: + void setPool(velox::memory::MemoryPool* pool) { + velox::memory::MemoryPool** poolPtr = + const_cast(&pool_); + *poolPtr = pool; + } + void setEndGuardImpl() override { *reinterpret_cast(data_ + capacity_) = kEndGuard; } @@ -596,13 +678,30 @@ class NonPODAlignedBuffer : public Buffer { } } + bool transferTo(velox::memory::MemoryPool* pool) override { + if (pool_ == pool) { + return true; + } + + if (pool_->transferTo( + pool, + this, + checkedPlus(AlignedBuffer::kPaddedSize, capacity_))) { + velox::memory::MemoryPool** poolPtr = + const_cast(&pool_); + *poolPtr = pool; + return true; + } + return false; + } + protected: NonPODAlignedBuffer(velox::memory::MemoryPool* pool, size_t capacity) - : Buffer( - pool, + : Buffer{ + Type::kNonPOD, reinterpret_cast(this) + sizeof(*this), capacity, - false /*podType*/) { + pool} { static_assert(sizeof(*this) == AlignedBuffer::kAlignment); static_assert(sizeof(*this) == sizeof(AlignedBuffer)); } @@ -610,8 +709,8 @@ class NonPODAlignedBuffer : public Buffer { void releaseResources() override { VELOX_CHECK_EQ(size_ % sizeof(T), 0); size_t numValues = size_ / sizeof(T); - // we can't use asMutable because it checks isMutable and we wan't to - // destroy regardless + // we can't use asMutable because it checks isMutable and we wan't + // to destroy regardless T* ptr = reinterpret_cast(data_); for (int i = 0; i < numValues; ++i) { ptr[i].~T(); @@ -619,6 +718,8 @@ class NonPODAlignedBuffer : public Buffer { } void copyFrom(const Buffer* other, size_t bytes) override { + // TODO: change this to isMutable(). See + // https://github.com/facebookincubator/velox/issues/6562. VELOX_CHECK(!isView()); VELOX_CHECK_GE(size_, bytes); VELOX_DCHECK( @@ -650,10 +751,12 @@ class NonPODAlignedBuffer : public Buffer { int oldNum = oldBytes / sizeof(T); int newNum = newBytes / sizeof(T); auto data = asMutable(); - for (int i = oldNum; i < newNum; ++i) { - if (initValue) { + if (initValue) { + for (int i = oldNum; i < newNum; ++i) { new (data + i) T(*initValue); - } else { + } + } else { + for (int i = oldNum; i < newNum; ++i) { new (data + i) T(); } } @@ -673,40 +776,60 @@ class NonPODAlignedBuffer : public Buffer { template class BufferView : public Buffer { public: - static BufferPtr create( - const uint8_t* data, - size_t size, - Releaser releaser, - bool podType = true) { - BufferView* view = new BufferView(data, size, releaser, podType); - BufferPtr result(view); + template + static BufferPtr + create(const uint8_t* data, size_t size, R&& releaser, bool podType = true) { + auto* view = new BufferView{data, size, std::forward(releaser), podType}; + BufferPtr result{view}; return result; } + // Helper method to create a buffer view referencing another existing Buffer. + template + static BufferPtr + create(const BufferPtr& innerBuffer, R&& releaser, bool podType = true) { + return create( + innerBuffer->as(), + innerBuffer->size(), + std::forward(releaser), + podType); + } + ~BufferView() override { releaser_.release(); } - bool isView() const override { - return true; + bool transferTo(velox::memory::MemoryPool* pool) override { + if (pool_ == pool) { + return true; + } + return false; } private: - BufferView(const uint8_t* data, size_t size, Releaser releaser, bool podType) + template + BufferView(const uint8_t* data, size_t size, R&& releaser, bool podType) // A BufferView must be created over the data held by a cache // pin, which is typically const. The Buffer enforces const-ness // when returning the pointer. We cast away the const here to // avoid a separate code path for const and non-const Buffer // payloads. - : Buffer(nullptr, const_cast(data), size, podType), - releaser_(releaser) { + : Buffer{podType ? Type::kPODView : Type::kNonPODView, const_cast(data), size, nullptr}, + releaser_{std::forward(releaser)} { size_ = size; - capacity_ = size; releaser_.addRef(); } - Releaser const releaser_; + [[no_unique_address]] const Releaser releaser_; }; -} // namespace velox -} // namespace facebook +} // namespace facebook::velox + +// fmt formatter specialization for Buffer::Type +template <> +struct fmt::formatter : formatter { + auto format(facebook::velox::Buffer::Type s, format_context& ctx) const { + return formatter::format( + facebook::velox::Buffer::typeString(s), ctx); + } +}; diff --git a/velox/buffer/StringViewBufferHolder.cpp b/velox/buffer/StringViewBufferHolder.cpp index 83479bba0a7c..750b011b2c10 100644 --- a/velox/buffer/StringViewBufferHolder.cpp +++ b/velox/buffer/StringViewBufferHolder.cpp @@ -26,8 +26,9 @@ StringView StringViewBufferHolder::getOwnedStringView( if (stringBuffers_.empty() || stringBuffers_.back()->size() + size > stringBuffers_.back()->capacity()) { - stringBuffers_.push_back(AlignedBuffer::allocate( - std::max(size, kInitialStringReservation), pool_)); + stringBuffers_.push_back( + AlignedBuffer::allocate( + std::max(size, kInitialStringReservation), pool_)); stringBuffers_.back()->setSize(0); } auto stringBuffer = stringBuffers_.back().get(); diff --git a/velox/buffer/StringViewBufferHolder.h b/velox/buffer/StringViewBufferHolder.h index c066d0dbc0a1..962b484fd9e4 100644 --- a/velox/buffer/StringViewBufferHolder.h +++ b/velox/buffer/StringViewBufferHolder.h @@ -33,7 +33,7 @@ class StringViewBufferHolder { /// Return a copy of the StringView where the StringView is copied to this /// StringViewBufferHolder if the StringView is not inlined. std::string and - /// folly::StringPiece are also copied to the internal buffers (see the + /// std::string_view are also copied to the internal buffers (see the /// specializations below). /// /// NOTE: Out of convenience, we allow different types to be passed in, but @@ -52,8 +52,8 @@ class StringViewBufferHolder { return getOwnedStringView(value.data(), value.size()); } - /// Specialization for folly::StringPiece type. - StringView getOwnedValue(folly::StringPiece value) { + /// Specialization for std::string_view type. + StringView getOwnedValue(std::string_view value) { return getOwnedStringView(value.data(), value.size()); } @@ -72,6 +72,17 @@ class StringViewBufferHolder { return stringBuffers_; } + /// Add a buffer to the list of buffers. This is used to allow bulk addition + /// of values with fewer overall buffers vs adding a value at a time via + /// getOwnedValue. The buffer must be allocated on the same underlying memory + /// pool. + void addOwnedBuffer(BufferPtr&& inBuffer) { + VELOX_CHECK( + inBuffer->pool() == pool_, + "Buffer must be allocated on same memory pool"); + stringBuffers_.push_back(std::move(inBuffer)); + } + private: StringView getOwnedStringView(StringView stringView); StringView getOwnedStringView(const char* data, int32_t size); diff --git a/velox/buffer/tests/BufferTest.cpp b/velox/buffer/tests/BufferTest.cpp index 9db4963221da..a45eeff24724 100644 --- a/velox/buffer/tests/BufferTest.cpp +++ b/velox/buffer/tests/BufferTest.cpp @@ -143,6 +143,20 @@ TEST_F(BufferTest, testAlignedBufferExact) { EXPECT_GE(buffer4->capacity(), oneMBMinusPad + 1); } +TEST_F(BufferTest, testAllocateExact) { + const int32_t oneMBMinusPad = 1024 * 1024 - AlignedBuffer::kPaddedSize; + + BufferPtr buffer1 = AlignedBuffer::allocateExact( + oneMBMinusPad + 1, pool_.get(), std::nullopt); + EXPECT_EQ(buffer1->size(), oneMBMinusPad + 1); + EXPECT_GE(buffer1->capacity(), oneMBMinusPad + 1); + + BufferPtr buffer2 = AlignedBuffer::allocateExact(3, pool_.get(), 'i'); + for (size_t i = 0; i < buffer2->size(); i++) { + EXPECT_EQ(buffer2->as()[i], 'i'); + } +} + TEST_F(BufferTest, testAsRange) { // Simple 2 element vector. std::vector testData({5, 255}); @@ -484,7 +498,9 @@ TEST_F(BufferTest, testNonPOD) { TEST_F(BufferTest, testNonPODMemoryUsage) { using T = std::shared_ptr; const int64_t currentBytes = pool_->usedBytes(); - { auto buffer = AlignedBuffer::allocate(0, pool_.get()); } + { + auto buffer = AlignedBuffer::allocate(0, pool_.get()); + } EXPECT_EQ(pool_->usedBytes(), currentBytes); } @@ -535,5 +551,34 @@ TEST_F(BufferTest, sliceBooleanBuffer) { Buffer::slice(bufferPtr, 5, 6, nullptr), "Pool must not be null."); } +TEST_F(BufferTest, testType) { + // Test AlignedBuffer type + auto alignedBuffer = AlignedBuffer::allocate(100, pool_.get()); + EXPECT_EQ(alignedBuffer->type(), Buffer::Type::kPOD); + EXPECT_TRUE(alignedBuffer->isPOD()); + EXPECT_FALSE(alignedBuffer->isView()); + + // Test NonPODAlignedBuffer type + auto nonPODBuffer = AlignedBuffer::allocate(10, pool_.get()); + EXPECT_EQ(nonPODBuffer->type(), Buffer::Type::kNonPOD); + EXPECT_FALSE(nonPODBuffer->isPOD()); + EXPECT_FALSE(nonPODBuffer->isView()); + + // Test BufferView type + MockCachePin pin; + const char* data = "test data"; + auto podBufferView = BufferView::create( + reinterpret_cast(data), 9, pin); + EXPECT_EQ(podBufferView->type(), Buffer::Type::kPODView); + EXPECT_TRUE(podBufferView->isPOD()); + EXPECT_TRUE(podBufferView->isView()); + + auto nonPodBufferView = BufferView::create( + reinterpret_cast(data), 9, pin, false); + EXPECT_EQ(nonPodBufferView->type(), Buffer::Type::kNonPODView); + EXPECT_FALSE(nonPodBufferView->isPOD()); + EXPECT_TRUE(nonPodBufferView->isView()); +} + } // namespace velox } // namespace facebook diff --git a/velox/buffer/tests/CMakeLists.txt b/velox/buffer/tests/CMakeLists.txt index 097f5ee3bb3f..9ca02cf2545b 100644 --- a/velox/buffer/tests/CMakeLists.txt +++ b/velox/buffer/tests/CMakeLists.txt @@ -25,4 +25,5 @@ target_link_libraries( GTest::gmock glog::glog gflags::gflags - pthread) + pthread +) diff --git a/velox/buffer/tests/StringViewBufferHolderTest.cpp b/velox/buffer/tests/StringViewBufferHolderTest.cpp index 4049be7503fa..4ae7ece1a381 100644 --- a/velox/buffer/tests/StringViewBufferHolderTest.cpp +++ b/velox/buffer/tests/StringViewBufferHolderTest.cpp @@ -161,23 +161,21 @@ TEST_F(StringViewBufferHolderTest, getOwnedValueCanBeCalledWithStringType) { ASSERT_EQ(1, holder.buffers().size()); } -TEST_F( - StringViewBufferHolderTest, - getOwnedValueCanBeCalledWithStringPieceType) { +TEST_F(StringViewBufferHolderTest, getOwnedValueCanBeCalledWithStringViewType) { const char* buf = "abcdefghijklmnopqrstuvxz"; StringView result; - folly::StringPiece piece; + std::string_view view; auto holder = makeHolder(); ASSERT_EQ(0, holder.buffers().size()); { std::string str = buf; - piece = str; - result = holder.getOwnedValue(piece); + view = str; + result = holder.getOwnedValue(view); } - // `str` is already destructed and piece is invalid. + // `str` is already destructed and `view` is invalid. ASSERT_EQ(StringView(buf), result); ASSERT_EQ(1, holder.buffers().size()); } @@ -208,4 +206,25 @@ TEST_F(StringViewBufferHolderTest, buffersCopy) { EXPECT_NE(&buffers, &firstMoved); } +TEST_F(StringViewBufferHolderTest, addOwnedBuffer) { + auto holder = makeHolder(); + ASSERT_EQ(0, holder.buffers().size()); + auto buffer = AlignedBuffer::allocate(10, pool_.get()); + holder.addOwnedBuffer(std::move(buffer)); + ASSERT_EQ(1, holder.buffers().size()); + ASSERT_EQ(1, holder.moveBuffers().size()); +} + +TEST_F(StringViewBufferHolderTest, addOwnedBufferThrowsForWrongPool) { + auto holder = makeHolder(); + ASSERT_EQ(0, holder.buffers().size()); + auto buffer = AlignedBuffer::allocate(10, pool_.get()); + holder.addOwnedBuffer(std::move(buffer)); + + auto newPool = memory::memoryManager()->addLeafPool(); + ASSERT_THROW( + holder.addOwnedBuffer(AlignedBuffer::allocate(10, newPool.get())), + VeloxException); +} + } // namespace facebook::velox diff --git a/velox/common/CMakeLists.txt b/velox/common/CMakeLists.txt index 6661427654be..f6457da34760 100644 --- a/velox/common/CMakeLists.txt +++ b/velox/common/CMakeLists.txt @@ -18,6 +18,8 @@ add_subdirectory(config) add_subdirectory(dynamic_registry) add_subdirectory(encode) add_subdirectory(file) +add_subdirectory(future) +add_subdirectory(geospatial) add_subdirectory(hyperloglog) add_subdirectory(io) add_subdirectory(memory) @@ -26,3 +28,5 @@ add_subdirectory(serialization) add_subdirectory(time) add_subdirectory(testutil) add_subdirectory(fuzzer) + +velox_install_library_headers() diff --git a/velox/common/Casts.h b/velox/common/Casts.h new file mode 100644 index 000000000000..0f03271d6495 --- /dev/null +++ b/velox/common/Casts.h @@ -0,0 +1,170 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "velox/common/base/Exceptions.h" + +#include + +namespace facebook::velox { + +namespace detail { + +template +void ensureCastSucceeded(To* casted, From* original) { + // Either casted or original will be nullptr. Otherwise it's a bad usage. + if (casted == nullptr) { + VELOX_CHECK_NOT_NULL( + original, "If casted is nullptr, original must not be."); + VELOX_FAIL( + "Failed to cast from '{}' to '{}'. Object is of type '{}'.", + folly::demangle(typeid(From).name()), + folly::demangle(typeid(To).name()), + folly::demangle(typeid(*original).name())); + } +} + +} // namespace detail + +// `checkedPointerCast` is a dynamic casting tool to throw a Velox exception +// when the casting failed. Use this instead of `std::dynamic_pointer_cast` +// when: +// 1) Casting must happen +// 2) We want a stack trace if it failed. +template +std::shared_ptr checkedPointerCast(const std::shared_ptr& input) { + VELOX_CHECK_NOT_NULL(input.get()); + auto casted = std::dynamic_pointer_cast(input); + detail::ensureCastSucceeded(casted.get(), input.get()); + return casted; +} + +template +std::unique_ptr checkedPointerCast(std::unique_ptr input) { + VELOX_CHECK_NOT_NULL(input.get()); + auto* released = input.release(); + To* casted{nullptr}; + try { + casted = dynamic_cast(released); + detail::ensureCastSucceeded(casted, released); + } catch (...) { + input.reset(released); + throw; + } + return std::unique_ptr(casted); +} + +template +To* checkedPointerCast(From* input) { + VELOX_CHECK_NOT_NULL(input); + auto* casted = dynamic_cast(input); + detail::ensureCastSucceeded(casted, input); + return casted; +} + +template +std::unique_ptr staticUniquePointerCast(std::unique_ptr input) { + VELOX_CHECK_NOT_NULL(input.get()); + auto* released = input.release(); + auto* casted = static_cast(released); + return std::unique_ptr(casted); +} + +template +bool isInstanceOf(const std::shared_ptr& input) { + VELOX_CHECK_NOT_NULL(input.get()); + auto* casted = dynamic_cast(input.get()); + return casted != nullptr; +} + +template +bool isInstanceOf(const std::unique_ptr& input) { + VELOX_CHECK_NOT_NULL(input.get()); + auto* casted = dynamic_cast(input.get()); + return casted != nullptr; +} + +template +bool isInstanceOf(const From* input) { + VELOX_CHECK_NOT_NULL(input); + auto* casted = dynamic_cast(input); + return casted != nullptr; +} + +#ifdef VELOX_ENABLE_BACKWARD_COMPATIBILITY +template +std::shared_ptr checked_pointer_cast(const std::shared_ptr& input) { + VELOX_CHECK_NOT_NULL(input.get()); + auto casted = std::dynamic_pointer_cast(input); + detail::ensureCastSucceeded(casted.get(), input.get()); + return casted; +} + +template +std::unique_ptr checked_pointer_cast(std::unique_ptr input) { + VELOX_CHECK_NOT_NULL(input.get()); + auto* released = input.release(); + To* casted{nullptr}; + try { + casted = dynamic_cast(released); + detail::ensureCastSucceeded(casted, released); + } catch (...) { + input.reset(released); + throw; + } + return std::unique_ptr(casted); +} + +template +To* checked_pointer_cast(From* input) { + VELOX_CHECK_NOT_NULL(input); + auto* casted = dynamic_cast(input); + detail::ensureCastSucceeded(casted, input); + return casted; +} + +template +std::unique_ptr static_unique_pointer_cast(std::unique_ptr input) { + VELOX_CHECK_NOT_NULL(input.get()); + auto* released = input.release(); + auto* casted = static_cast(released); + return std::unique_ptr(casted); +} + +template +bool is_instance_of(const std::shared_ptr& input) { + VELOX_CHECK_NOT_NULL(input.get()); + auto* casted = dynamic_cast(input.get()); + return casted != nullptr; +} + +template +bool is_instance_of(const std::unique_ptr& input) { + VELOX_CHECK_NOT_NULL(input.get()); + auto* casted = dynamic_cast(input.get()); + return casted != nullptr; +} + +template +bool is_instance_of(const From* input) { + VELOX_CHECK_NOT_NULL(input); + auto* casted = dynamic_cast(input); + return casted != nullptr; +} +#endif + +} // namespace facebook::velox diff --git a/velox/common/Enums.h b/velox/common/Enums.h new file mode 100644 index 000000000000..8f844ea1b050 --- /dev/null +++ b/velox/common/Enums.h @@ -0,0 +1,147 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "folly/container/F14Map.h" +#include "velox/common/base/Exceptions.h" + +namespace facebook::velox { + +struct Enums { + /// Helper function to invert a mapping from enum type to name. + template + static auto invertMap( + const folly::F14FastMap& mapping) { + folly::F14FastMap inverted; + for (const auto& [key, value] : mapping) { + const bool emplaced = inverted.emplace(value, key).second; + VELOX_USER_CHECK( + emplaced, "Cannot invert a map with duplicate values: {}", value); + } + return inverted; + } +}; + +} // namespace facebook::velox + +/// Helper macros to implement bi-direction mappings between enum values and +/// names. +/// +/// Usage: +/// +/// In the header file, define the enum: +/// +/// #include "velox/common/Enums.h" +/// +/// enum class Foo {...}; +/// +/// VELOX_DECLARE_ENUM_NAME(Foo); +/// +/// In the cpp file, define the mapping: +/// +/// namespace { +/// const auto& fooNames() { +/// static const folly::F14FastMap kNames = { +/// {Foo::kFirst, "FIRST"}, +/// {Foo::kSecond, "SECOND"}, +/// ... +/// }; +/// return kNames; +/// } +/// } // namespace +/// +/// VELOX_DEFINE_ENUM_NAME(Foo, fooNames); +/// +/// In the client code, use FooName::toName(Foo::kFirst) to get the name of the +/// enum and FooName::toFoo("FIRST") or FooName::tryToFoo("FIRST") to get the +/// enum value. toFoo throws an exception if input is not a valid name, while +/// tryToFoo returns a std::nullopt. +/// +/// Use _EMBEDDED_ versions of the macros to define enums embedded in other +/// classes. + +#define VELOX_DECLARE_ENUM_NAME(EnumType) \ + struct EnumType##Name { \ + static std::string_view toName(EnumType value); \ + static EnumType to##EnumType(std::string_view name); \ + static std::optional tryTo##EnumType(std::string_view name); \ + }; \ + std::ostream& operator<<(std::ostream& os, const EnumType& value); + +#define VELOX_DEFINE_ENUM_NAME(EnumType, Names) \ + std::string_view EnumType##Name::toName(EnumType value) { \ + const auto& names = Names(); \ + auto it = names.find(value); \ + VELOX_CHECK( \ + it != names.end(), \ + "Invalid enum value: {}", \ + static_cast>(value)); \ + return it->second; \ + } \ + \ + std::optional EnumType##Name::tryTo##EnumType( \ + std::string_view name) { \ + static const auto kValues = facebook::velox::Enums::invertMap(Names()); \ + \ + auto it = kValues.find(name); \ + if (it == kValues.end()) { \ + return std::nullopt; \ + } \ + return it->second; \ + } \ + std::ostream& operator<<(std::ostream& os, const EnumType& value) { \ + os << EnumType##Name::toName(value); \ + return os; \ + } \ + \ + EnumType EnumType##Name::to##EnumType(std::string_view name) { \ + const auto maybeType = EnumType##Name::tryTo##EnumType(name); \ + VELOX_CHECK(maybeType, "Invalid enum name: {}", name); \ + return *maybeType; \ + } + +#define VELOX_DECLARE_EMBEDDED_ENUM_NAME(EnumType) \ + static std::string_view toName(EnumType value); \ + static EnumType to##EnumType(std::string_view name); \ + static std::optional tryTo##EnumType(std::string_view name); + +#define VELOX_DEFINE_EMBEDDED_ENUM_NAME(Class, EnumType, Names) \ + std::string_view Class::toName(Class::EnumType value) { \ + const auto& names = Names(); \ + auto it = names.find(value); \ + VELOX_CHECK( \ + it != names.end(), \ + "Invalid enum value: {}", \ + static_cast>(value)); \ + return it->second; \ + } \ + \ + std::optional Class::tryTo##EnumType( \ + std::string_view name) { \ + static const auto kValues = facebook::velox::Enums::invertMap(Names()); \ + \ + auto it = kValues.find(name); \ + if (it == kValues.end()) { \ + return std::nullopt; \ + } \ + return it->second; \ + } \ + \ + Class::EnumType Class::to##EnumType(std::string_view name) { \ + const auto maybeType = Class::tryTo##EnumType(name); \ + VELOX_CHECK(maybeType, "Invalid enum name: {}", name); \ + return *maybeType; \ + } diff --git a/velox/common/base/AdmissionController.cpp b/velox/common/base/AdmissionController.cpp index c7e1b71ea5ac..e1dae70a402e 100644 --- a/velox/common/base/AdmissionController.cpp +++ b/velox/common/base/AdmissionController.cpp @@ -32,12 +32,10 @@ void AdmissionController::accept(uint64_t resourceUnits) { { std::lock_guard l(mu_); if (unitsUsed_ + resourceUnits > config_.maxLimit) { - auto [unblockPromise, unblockFuture] = makeVeloxContinuePromiseContract(); Request req; req.unitsRequested = resourceUnits; - req.promise = std::move(unblockPromise); + future = req.promise.getSemiFuture(); queue_.push_back(std::move(req)); - future = std::move(unblockFuture); } else { updatedValue = unitsUsed_ += resourceUnits; } diff --git a/velox/common/base/AsyncSource.h b/velox/common/base/AsyncSource.h index 76740dd8681a..02901068c1da 100644 --- a/velox/common/base/AsyncSource.h +++ b/velox/common/base/AsyncSource.h @@ -56,8 +56,8 @@ class AsyncSource { ~AsyncSource() { VELOX_CHECK( - moved_ || closed_, - "AsyncSource should be properly consumed or closed."); + moved_ || closed_ || (cancelled_ && !making_), + "AsyncSource should be properly consumed, closed, or cancelled."); } // Makes an item if it is not already made. To be called on a background @@ -121,7 +121,7 @@ class AsyncSource { return nullptr; } if (making_) { - promise_ = std::make_unique(); + promise_ = std::make_unique("AsyncSource::move"); wait = promise_->getSemiFuture(); } else { if (!make_) { @@ -163,6 +163,18 @@ class AsyncSource { return timing_; } + /// Cancels the task if it hasn't started yet. + /// If the task has already started, the task will continue but AsyncSource + /// is marked as cancelled to allow proper cleanup in destructor. + void cancel() { + std::lock_guard l(mutex_); + cancelled_ = true; + if (make_ == nullptr) { + return; + } + make_ = nullptr; + } + /// This function assists the caller in ensuring that resources allocated in /// AsyncSource are promptly released: /// 1. Waits for the completion of the 'make_' function if it is executing @@ -178,7 +190,7 @@ class AsyncSource { { std::lock_guard l(mutex_); if (making_) { - promise_ = std::make_unique(); + promise_ = std::make_unique("AsyncSource::close"); wait = promise_->getSemiFuture(); } else if (make_) { make_ = nullptr; @@ -217,5 +229,6 @@ class AsyncSource { CpuWallTiming timing_; bool closed_{false}; bool moved_{false}; + bool cancelled_{false}; }; } // namespace facebook::velox diff --git a/velox/common/base/BigintIdMap.h b/velox/common/base/BigintIdMap.h index e63445f4d127..97d3abfdb514 100644 --- a/velox/common/base/BigintIdMap.h +++ b/velox/common/base/BigintIdMap.h @@ -31,8 +31,10 @@ class BigintIdMap { static constexpr int64_t kMaxCapacity = 1 << 30; // 1G entries, 12GB BigintIdMap(int32_t capacity, memory::MemoryPool& pool) : pool_(pool) { - makeTable(std::max( - 2 * sizeof(xsimd::batch), bits::nextPowerOfTwo(capacity))); + makeTable( + std::max( + 2 * sizeof(xsimd::batch), + bits::nextPowerOfTwo(capacity))); } BigintIdMap(const BigintIdMap& other) = delete; diff --git a/velox/common/base/BitUtil.h b/velox/common/base/BitUtil.h index be159ad2ea51..3257cbe4b0f0 100644 --- a/velox/common/base/BitUtil.h +++ b/velox/common/base/BitUtil.h @@ -48,6 +48,8 @@ namespace facebook { namespace velox { namespace bits { +inline constexpr uint64_t kNullHash = 1; + template inline bool isBitSet(const T* bits, uint64_t idx) { return bits[idx / (sizeof(bits[0]) * 8)] & @@ -92,6 +94,11 @@ inline void setBit(T* bits, uint64_t idx, bool value) { value ? setBit(bits, idx) : clearBit(bits, idx); } +inline void negateBit(void* bits, uint64_t idx) { + auto* bitsAs8Bit = reinterpret_cast(bits); + bitsAs8Bit[idx / 8] ^= (1 << (idx % 8)); +} + inline void negate(uint64_t* bits, int32_t size) { int32_t i = 0; for (; i + 64 <= size; i += 64) { diff --git a/velox/common/base/BloomFilter.h b/velox/common/base/BloomFilter.h index 1d23e382834b..6fab5a4376cf 100644 --- a/velox/common/base/BloomFilter.h +++ b/velox/common/base/BloomFilter.h @@ -61,6 +61,20 @@ class BloomFilter { return test(bits_.data(), bits_.size(), value); } + /// Tests an input value directly on a serialized bloom filter. + /// For the implementation V1, this API involves no copy of the bloom + /// filter data. + static bool mayContain(const char* serializedBloom, uint64_t value) { + common::InputByteStream stream(serializedBloom); + const auto version = stream.read(); + VELOX_USER_CHECK_EQ(kBloomFilterV1, version); + const auto size = stream.read(); + VELOX_USER_CHECK_GT(size, 0); + const uint64_t* bloomBits = + reinterpret_cast(serializedBloom + stream.offset()); + return test(bloomBits, size, value); + } + void merge(const char* serialized) { common::InputByteStream stream(serialized); auto version = stream.read(); @@ -127,7 +141,7 @@ class BloomFilter { return mask == (bloom[index] & mask); } - const int8_t kBloomFilterV1 = 1; + static constexpr int8_t kBloomFilterV1 = 1; std::vector bits_; }; diff --git a/velox/common/base/CMakeLists.txt b/velox/common/base/CMakeLists.txt index 58cd1f38ed2e..3774a4215429 100644 --- a/velox/common/base/CMakeLists.txt +++ b/velox/common/base/CMakeLists.txt @@ -12,16 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -velox_add_library(velox_exception Exceptions.cpp VeloxException.cpp - Exceptions.h) +velox_add_library(velox_exception Exceptions.cpp VeloxException.cpp Exceptions.h) velox_link_libraries( velox_exception - PUBLIC velox_flag_definitions - velox_process - Folly::folly - fmt::fmt - gflags::gflags - glog::glog) + PUBLIC velox_flag_definitions velox_process Folly::folly fmt::fmt gflags::gflags glog::glog +) velox_add_library( velox_common_base @@ -36,14 +31,17 @@ velox_add_library( SkewedPartitionBalancer.cpp SpillConfig.cpp SpillStats.cpp + SplitBlockBloomFilter.cpp StatsReporter.cpp SuccinctPrinter.cpp - TraceConfig.cpp) + TraceConfig.cpp +) velox_link_libraries( velox_common_base PUBLIC velox_exception Folly::folly fmt::fmt xsimd - PRIVATE velox_common_compression velox_process velox_test_util glog::glog) + PRIVATE velox_caching velox_common_compression velox_process velox_test_util glog::glog +) if(${VELOX_BUILD_TESTING}) add_subdirectory(tests) @@ -61,10 +59,8 @@ velox_link_libraries( velox_process glog::glog Folly::folly - fmt::fmt) + fmt::fmt +) velox_add_library(velox_status Status.cpp) -velox_link_libraries( - velox_status - PUBLIC fmt::fmt Folly::folly - PRIVATE glog::glog) +velox_link_libraries(velox_status PUBLIC fmt::fmt Folly::folly PRIVATE glog::glog) diff --git a/velox/common/base/CoalesceIo.h b/velox/common/base/CoalesceIo.h index f5f58ba1092b..7ad405cf2908 100644 --- a/velox/common/base/CoalesceIo.h +++ b/velox/common/base/CoalesceIo.h @@ -65,14 +65,13 @@ CoalesceIoStats coalesceIo( AddRanges addRanges, SkipRange skipRange, IoFunc ioFunc) { - std::vector buffers; int32_t startItem = 0; auto startOffset = offsetFunc(startItem); auto lastEndOffset = startOffset; std::vector ranges; CoalesceIoStats result; for (int32_t i = 0; i < items.size(); ++i) { - auto& item = items[i]; + const auto& item = items[i]; const auto itemOffset = offsetFunc(i); const auto itemSize = sizeFunc(i); result.payloadBytes += itemSize; diff --git a/velox/common/base/CompareFlags.h b/velox/common/base/CompareFlags.h index 265f84652f81..f7873ab330df 100644 --- a/velox/common/base/CompareFlags.h +++ b/velox/common/base/CompareFlags.h @@ -84,7 +84,7 @@ struct CompareFlags { /// ex: (null, 1) = (null, 1) is indeterminate. /// /// - If all fields compare results are true, then result is true. - /// ex: (1, 1) = (1, 1) is indeterminate. + /// ex: (1, 1) = (1, 1) is true. /// /// 4. Maps: /// - Keys are compared first, if keys are not equal values are not @@ -171,6 +171,13 @@ struct CompareFlags { } } + /// Returns a copy of the flags with the ascending flag flipped. + static CompareFlags reverseDirection(const CompareFlags& flags) { + CompareFlags result = flags; + result.ascending = !result.ascending; + return result; + } + std::string toString() const { return fmt::format( "[NullFirst[{}] Ascending[{}] EqualsOnly[{}] NullHandleMode[{}]]", diff --git a/velox/common/base/CountBits.h b/velox/common/base/CountBits.h index b267d2f636ef..f40fb95355e5 100644 --- a/velox/common/base/CountBits.h +++ b/velox/common/base/CountBits.h @@ -16,6 +16,8 @@ #pragma once +#include + namespace facebook::velox { // Copied from format.h of fmt. diff --git a/velox/common/base/Counters.cpp b/velox/common/base/Counters.cpp index ed7de5f18959..6e88aa653e86 100644 --- a/velox/common/base/Counters.cpp +++ b/velox/common/base/Counters.cpp @@ -45,6 +45,9 @@ void registerVeloxMetrics() { DEFINE_HISTOGRAM_METRIC( kMetricTaskBarrierProcessTimeMs, 1'000, 0, 30'000, 50, 90, 99, 100); + // Tracks the total number of splits received by all tasks. + DEFINE_METRIC(kMetricTaskSplitsCount, facebook::velox::StatType::COUNT); + /// ================== Cache Counters ================= // Tracks hive handle generation latency in range of [0, 100s] and reports @@ -75,6 +78,14 @@ void registerVeloxMetrics() { DEFINE_METRIC( kMetricMemoryAllocatorMappedBytes, facebook::velox::StatType::AVG); + // Number of bytes allocated and explicitly mmap'd by the application via + // allocateContiguous, outside of 'sizeClasses'. These pages are counted in + // 'kMetricMemoryAllocatorAllocatedBytes' and + // 'kMetricMemoryAllocatorMappedBytes'. + DEFINE_METRIC( + kMetricMemoryAllocatorExternalMappedBytes, + facebook::velox::StatType::AVG); + // Number of bytes currently allocated (used) from MemoryAllocator in the form // of 'Allocation' or 'ContiguousAllocation'. DEFINE_METRIC( @@ -84,13 +95,6 @@ void registerVeloxMetrics() { DEFINE_METRIC( kMetricMemoryAllocatorTotalUsedBytes, facebook::velox::StatType::AVG); - // Number of bytes currently mapped in MmapAllocator, in the form of - // 'ContiguousAllocation'. - // - // NOTE: This applies only to MmapAllocator - DEFINE_METRIC( - kMetricMmapAllocatorExternalMappedBytes, facebook::velox::StatType::AVG); - // Number of bytes currently allocated from MmapAllocator directly from raw // allocateBytes() interface, and internally allocated by malloc. Only small // chunks of memory are delegated to malloc. @@ -106,8 +110,13 @@ void registerVeloxMetrics() { // was opened to load the cache. DEFINE_METRIC(kMetricCacheMaxAgeSecs, facebook::velox::StatType::AVG); - // Total number of cache entries. - DEFINE_METRIC(kMetricMemoryCacheNumEntries, facebook::velox::StatType::AVG); + // Total number of tiny cache entries. + DEFINE_METRIC( + kMetricMemoryCacheNumTinyEntries, facebook::velox::StatType::AVG); + + // Total number of large cache entries. + DEFINE_METRIC( + kMetricMemoryCacheNumLargeEntries, facebook::velox::StatType::AVG); // Total number of cache entries that do not cache anything. DEFINE_METRIC( @@ -248,6 +257,10 @@ void registerVeloxMetrics() { // Total number of error while writing to SSD cache files. DEFINE_METRIC(kMetricSsdCacheWriteSsdErrors, facebook::velox::StatType::SUM); + // Total number of errors due to SSD no space for writes. + DEFINE_METRIC( + kMetricSsdCacheWriteNoSpaceErrors, facebook::velox::StatType::SUM); + // Total number of errors while writing SSD checkpoint file. DEFINE_METRIC( kMetricSsdCacheWriteCheckpointErrors, facebook::velox::StatType::SUM); @@ -255,6 +268,10 @@ void registerVeloxMetrics() { // Total number of writes dropped due to no cache space. DEFINE_METRIC(kMetricSsdCacheWriteSsdDropped, facebook::velox::StatType::SUM); + // Total number of writes dropped due to entry limit exceeded. + DEFINE_METRIC( + kMetricSsdCacheWriteExceedEntryLimit, facebook::velox::StatType::SUM); + // Total number of errors while reading from SSD cache files. DEFINE_METRIC(kMetricSsdCacheReadSsdErrors, facebook::velox::StatType::SUM); @@ -364,9 +381,6 @@ void registerVeloxMetrics() { kMetricTaskMemoryReclaimWaitTimeoutCount, facebook::velox::StatType::COUNT); - // Tracks the total number of splits received by all tasks. - DEFINE_METRIC(kMetricTaskSplitsCount, facebook::velox::StatType::COUNT); - // The number of times that the memory reclaim fails because the operator is // executing a non-reclaimable section where it is expected to have reserved // enough memory to execute without asking for more. Therefore, it is an @@ -643,16 +657,17 @@ void registerVeloxMetrics() { DEFINE_HISTOGRAM_METRIC( kMetricIndexLookupBlockedWaitTimeMs, 32, 0, 16L << 10, 50, 90, 99, 100); + // The number of index lookup results with error. + DEFINE_METRIC( + kMetricIndexLookupErrorResultCount, facebook::velox::StatType::COUNT); + /// ================== Table Scan Counters ================= - // The time distribution of table scan batch processing time in range of [0, - // 16s] with 512 buckets and reports P50, P90, P99, and P100. - DEFINE_HISTOGRAM_METRIC( - kMetricTableScanBatchProcessTimeMs, 32, 0, 16L << 10, 50, 90, 99, 100); + // Tracks the averaged table scan batch processing time in milliseconds. + DEFINE_METRIC( + kMetricTableScanBatchProcessTimeMs, facebook::velox::StatType::AVG); - // The size distribution of table scan output batch in range of [0, 512MB] - // with 512 buckets and reports P50, P90, P99, and P100 - DEFINE_HISTOGRAM_METRIC( - kMetricTableScanBatchBytes, 1L << 20, 0, 512L << 20, 50, 90, 99, 100); + // Tracks the averaged table scan output batch size in bytes. + DEFINE_METRIC(kMetricTableScanBatchBytes, facebook::velox::StatType::AVG); /// ================== Storage Counters ================= diff --git a/velox/common/base/Counters.h b/velox/common/base/Counters.h index 78703336eb68..3a144f1bc881 100644 --- a/velox/common/base/Counters.h +++ b/velox/common/base/Counters.h @@ -23,373 +23,374 @@ namespace facebook::velox { /// Velox metrics Registration. void registerVeloxMetrics(); -constexpr folly::StringPiece kMetricHiveFileHandleGenerateLatencyMs{ +constexpr std::string_view kMetricHiveFileHandleGenerateLatencyMs{ "velox.hive_file_handle_generate_latency_ms"}; -constexpr folly::StringPiece kMetricCacheShrinkCount{ - "velox.cache_shrink_count"}; +constexpr std::string_view kMetricCacheShrinkCount{"velox.cache_shrink_count"}; -constexpr folly::StringPiece kMetricCacheShrinkTimeMs{"velox.cache_shrink_ms"}; +constexpr std::string_view kMetricCacheShrinkTimeMs{"velox.cache_shrink_ms"}; -constexpr folly::StringPiece kMetricMaxSpillLevelExceededCount{ +constexpr std::string_view kMetricMaxSpillLevelExceededCount{ "velox.spill_max_level_exceeded_count"}; -constexpr folly::StringPiece kMetricQueryMemoryReclaimTimeMs{ +constexpr std::string_view kMetricQueryMemoryReclaimTimeMs{ "velox.query_memory_reclaim_time_ms"}; -constexpr folly::StringPiece kMetricQueryMemoryReclaimedBytes{ +constexpr std::string_view kMetricQueryMemoryReclaimedBytes{ "velox.query_memory_reclaim_bytes"}; -constexpr folly::StringPiece kMetricQueryMemoryReclaimCount{ +constexpr std::string_view kMetricQueryMemoryReclaimCount{ "velox.query_memory_reclaim_count"}; -constexpr folly::StringPiece kMetricTaskMemoryReclaimCount{ +constexpr std::string_view kMetricTaskMemoryReclaimCount{ "velox.task_memory_reclaim_count"}; -constexpr folly::StringPiece kMetricTaskMemoryReclaimWaitTimeMs{ +constexpr std::string_view kMetricTaskMemoryReclaimWaitTimeMs{ "velox.task_memory_reclaim_wait_ms"}; -constexpr folly::StringPiece kMetricTaskMemoryReclaimExecTimeMs{ +constexpr std::string_view kMetricTaskMemoryReclaimExecTimeMs{ "velox.task_memory_reclaim_exec_ms"}; -constexpr folly::StringPiece kMetricTaskMemoryReclaimWaitTimeoutCount{ +constexpr std::string_view kMetricTaskMemoryReclaimWaitTimeoutCount{ "velox.task_memory_reclaim_wait_timeout_count"}; -constexpr folly::StringPiece kMetricTaskSplitsCount{"velox.task_splits_count"}; +constexpr std::string_view kMetricTaskSplitsCount{"velox.task_splits_count"}; -constexpr folly::StringPiece kMetricOpMemoryReclaimTimeMs{ +constexpr std::string_view kMetricOpMemoryReclaimTimeMs{ "velox.op_memory_reclaim_time_ms"}; -constexpr folly::StringPiece kMetricOpMemoryReclaimedBytes{ +constexpr std::string_view kMetricOpMemoryReclaimedBytes{ "velox.op_memory_reclaim_bytes"}; -constexpr folly::StringPiece kMetricOpMemoryReclaimCount{ +constexpr std::string_view kMetricOpMemoryReclaimCount{ "velox.op_memory_reclaim_count"}; -constexpr folly::StringPiece kMetricMemoryNonReclaimableCount{ +constexpr std::string_view kMetricMemoryNonReclaimableCount{ "velox.memory_non_reclaimable_count"}; -constexpr folly::StringPiece kMetricMemoryPoolInitialCapacityBytes{ +constexpr std::string_view kMetricMemoryPoolInitialCapacityBytes{ "velox.memory_pool_initial_capacity_bytes"}; -constexpr folly::StringPiece kMetricMemoryPoolCapacityGrowCount{ +constexpr std::string_view kMetricMemoryPoolCapacityGrowCount{ "velox.memory_pool_capacity_growth_count"}; -constexpr folly::StringPiece kMetricMemoryPoolUsageLeakBytes{ +constexpr std::string_view kMetricMemoryPoolUsageLeakBytes{ "velox.memory_pool_usage_leak_bytes"}; -constexpr folly::StringPiece kMetricMemoryPoolReservationLeakBytes{ +constexpr std::string_view kMetricMemoryPoolReservationLeakBytes{ "velox.memory_pool_reservation_leak_bytes"}; -constexpr folly::StringPiece kMetricMemoryAllocatorDoubleFreeCount{ +constexpr std::string_view kMetricMemoryAllocatorDoubleFreeCount{ "velox.memory_allocator_double_free_count"}; -constexpr folly::StringPiece kMetricArbitratorLocalArbitrationCount{ +constexpr std::string_view kMetricArbitratorLocalArbitrationCount{ "velox.arbitrator_local_arbitration_count"}; -constexpr folly::StringPiece kMetricArbitratorGlobalArbitrationCount{ +constexpr std::string_view kMetricArbitratorGlobalArbitrationCount{ "velox.arbitrator_global_arbitration_count"}; -constexpr folly::StringPiece - kMetricArbitratorGlobalArbitrationNumReclaimVictims{ - "velox.arbitrator_global_arbitration_num_reclaim_victims"}; +constexpr std::string_view kMetricArbitratorGlobalArbitrationNumReclaimVictims{ + "velox.arbitrator_global_arbitration_num_reclaim_victims"}; -constexpr folly::StringPiece - kMetricArbitratorGlobalArbitrationFailedVictimCount{ - "velox.arbitrator_global_arbitration_failed_victim_count"}; +constexpr std::string_view kMetricArbitratorGlobalArbitrationFailedVictimCount{ + "velox.arbitrator_global_arbitration_failed_victim_count"}; -constexpr folly::StringPiece kMetricArbitratorGlobalArbitrationBytes{ +constexpr std::string_view kMetricArbitratorGlobalArbitrationBytes{ "velox.arbitrator_global_arbitration_bytes"}; -constexpr folly::StringPiece kMetricArbitratorGlobalArbitrationTimeMs{ +constexpr std::string_view kMetricArbitratorGlobalArbitrationTimeMs{ "velox.arbitrator_global_arbitration_time_ms"}; -constexpr folly::StringPiece kMetricArbitratorGlobalArbitrationWaitCount{ +constexpr std::string_view kMetricArbitratorGlobalArbitrationWaitCount{ "velox.arbitrator_global_arbitration_wait_count"}; -constexpr folly::StringPiece kMetricArbitratorGlobalArbitrationWaitTimeMs{ +constexpr std::string_view kMetricArbitratorGlobalArbitrationWaitTimeMs{ "velox.arbitrator_global_arbitration_wait_time_ms"}; -constexpr folly::StringPiece kMetricArbitratorAbortedCount{ +constexpr std::string_view kMetricArbitratorAbortedCount{ "velox.arbitrator_aborted_count"}; -constexpr folly::StringPiece kMetricArbitratorFailuresCount{ +constexpr std::string_view kMetricArbitratorFailuresCount{ "velox.arbitrator_failures_count"}; -constexpr folly::StringPiece kMetricArbitratorOpExecTimeMs{ +constexpr std::string_view kMetricArbitratorOpExecTimeMs{ "velox.arbitrator_op_exec_time_ms"}; -constexpr folly::StringPiece kMetricArbitratorFreeCapacityBytes{ +constexpr std::string_view kMetricArbitratorFreeCapacityBytes{ "velox.arbitrator_free_capacity_bytes"}; -constexpr folly::StringPiece kMetricArbitratorFreeReservedCapacityBytes{ +constexpr std::string_view kMetricArbitratorFreeReservedCapacityBytes{ "velox.arbitrator_free_reserved_capacity_bytes"}; -constexpr folly::StringPiece kMetricDriverYieldCount{ - "velox.driver_yield_count"}; +constexpr std::string_view kMetricDriverYieldCount{"velox.driver_yield_count"}; -constexpr folly::StringPiece kMetricDriverQueueTimeMs{ +constexpr std::string_view kMetricDriverQueueTimeMs{ "velox.driver_queue_time_ms"}; -constexpr folly::StringPiece kMetricDriverExecTimeMs{ - "velox.driver_exec_time_ms"}; +constexpr std::string_view kMetricDriverExecTimeMs{"velox.driver_exec_time_ms"}; -constexpr folly::StringPiece kMetricSpilledInputBytes{ - "velox.spill_input_bytes"}; +constexpr std::string_view kMetricSpilledInputBytes{"velox.spill_input_bytes"}; -constexpr folly::StringPiece kMetricSpilledBytes{"velox.spill_bytes"}; +constexpr std::string_view kMetricSpilledBytes{"velox.spill_bytes"}; -constexpr folly::StringPiece kMetricSpilledRowsCount{"velox.spill_rows_count"}; +constexpr std::string_view kMetricSpilledRowsCount{"velox.spill_rows_count"}; -constexpr folly::StringPiece kMetricSpilledFilesCount{ - "velox.spill_files_count"}; +constexpr std::string_view kMetricSpilledFilesCount{"velox.spill_files_count"}; -constexpr folly::StringPiece kMetricSpillFillTimeMs{"velox.spill_fill_time_ms"}; +constexpr std::string_view kMetricSpillFillTimeMs{"velox.spill_fill_time_ms"}; -constexpr folly::StringPiece kMetricSpillSortTimeMs{"velox.spill_sort_time_ms"}; +constexpr std::string_view kMetricSpillSortTimeMs{"velox.spill_sort_time_ms"}; -constexpr folly::StringPiece kMetricSpillExtractVectorTimeMs{ +constexpr std::string_view kMetricSpillExtractVectorTimeMs{ "velox.spill_extract_vector_time_ms"}; -constexpr folly::StringPiece kMetricSpillSerializationTimeMs{ +constexpr std::string_view kMetricSpillSerializationTimeMs{ "velox.spill_serialization_time_ms"}; -constexpr folly::StringPiece kMetricSpillWritesCount{ - "velox.spill_writes_count"}; +constexpr std::string_view kMetricSpillWritesCount{"velox.spill_writes_count"}; -constexpr folly::StringPiece kMetricSpillFlushTimeMs{ - "velox.spill_flush_time_ms"}; +constexpr std::string_view kMetricSpillFlushTimeMs{"velox.spill_flush_time_ms"}; -constexpr folly::StringPiece kMetricSpillWriteTimeMs{ - "velox.spill_write_time_ms"}; +constexpr std::string_view kMetricSpillWriteTimeMs{"velox.spill_write_time_ms"}; -constexpr folly::StringPiece kMetricSpillMemoryBytes{ - "velox.spill_memory_bytes"}; +constexpr std::string_view kMetricSpillMemoryBytes{"velox.spill_memory_bytes"}; -constexpr folly::StringPiece kMetricSpillPeakMemoryBytes{ +constexpr std::string_view kMetricSpillPeakMemoryBytes{ "velox.spill_peak_memory_bytes"}; -constexpr folly::StringPiece kMetricFileWriterEarlyFlushedRawBytes{ +constexpr std::string_view kMetricFileWriterEarlyFlushedRawBytes{ "velox.file_writer_early_flushed_raw_bytes"}; -constexpr folly::StringPiece kMetricHiveSortWriterFinishTimeMs{ +constexpr std::string_view kMetricHiveSortWriterFinishTimeMs{ "velox.hive_sort_writer_finish_time_ms"}; -constexpr folly::StringPiece kMetricArbitratorRequestsCount{ +constexpr std::string_view kMetricArbitratorRequestsCount{ "velox.arbitrator_requests_count"}; -constexpr folly::StringPiece kMetricMemoryAllocatorMappedBytes{ +constexpr std::string_view kMetricMemoryAllocatorMappedBytes{ "velox.memory_allocator_mapped_bytes"}; -constexpr folly::StringPiece kMetricMemoryAllocatorAllocatedBytes{ +constexpr std::string_view kMetricMemoryAllocatorExternalMappedBytes{ + "velox.memory_allocator_external_mapped_bytes"}; + +constexpr std::string_view kMetricMemoryAllocatorAllocatedBytes{ "velox.memory_allocator_allocated_bytes"}; -constexpr folly::StringPiece kMetricMemoryAllocatorTotalUsedBytes{ +constexpr std::string_view kMetricMemoryAllocatorTotalUsedBytes{ "velox.memory_allocator_total_used_bytes"}; -constexpr folly::StringPiece kMetricMmapAllocatorExternalMappedBytes{ - "velox.mmap_allocator_external_mapped_bytes"}; - -constexpr folly::StringPiece kMetricMmapAllocatorDelegatedAllocatedBytes{ +constexpr std::string_view kMetricMmapAllocatorDelegatedAllocatedBytes{ "velox.mmap_allocator_delegated_allocated_bytes"}; -constexpr folly::StringPiece kMetricCacheMaxAgeSecs{"velox.cache_max_age_secs"}; +constexpr std::string_view kMetricCacheMaxAgeSecs{"velox.cache_max_age_secs"}; -constexpr folly::StringPiece kMetricMemoryCacheNumEntries{ - "velox.memory_cache_num_entries"}; +constexpr std::string_view kMetricMemoryCacheNumTinyEntries{ + "velox.memory_cache_num_tiny_entries"}; -constexpr folly::StringPiece kMetricMemoryCacheNumEmptyEntries{ +constexpr std::string_view kMetricMemoryCacheNumLargeEntries{ + "velox.memory_cache_num_large_entries"}; + +constexpr std::string_view kMetricMemoryCacheNumEmptyEntries{ "velox.memory_cache_num_empty_entries"}; -constexpr folly::StringPiece kMetricMemoryCacheNumSharedEntries{ +constexpr std::string_view kMetricMemoryCacheNumSharedEntries{ "velox.memory_cache_num_shared_entries"}; -constexpr folly::StringPiece kMetricMemoryCacheNumExclusiveEntries{ +constexpr std::string_view kMetricMemoryCacheNumExclusiveEntries{ "velox.memory_cache_num_exclusive_entries"}; -constexpr folly::StringPiece kMetricMemoryCacheNumPrefetchedEntries{ +constexpr std::string_view kMetricMemoryCacheNumPrefetchedEntries{ "velox.memory_cache_num_prefetched_entries"}; -constexpr folly::StringPiece kMetricMemoryCacheTotalTinyBytes{ +constexpr std::string_view kMetricMemoryCacheTotalTinyBytes{ "velox.memory_cache_total_tiny_bytes"}; -constexpr folly::StringPiece kMetricMemoryCacheTotalLargeBytes{ +constexpr std::string_view kMetricMemoryCacheTotalLargeBytes{ "velox.memory_cache_total_large_bytes"}; -constexpr folly::StringPiece kMetricMemoryCacheTotalTinyPaddingBytes{ +constexpr std::string_view kMetricMemoryCacheTotalTinyPaddingBytes{ "velox.memory_cache_total_tiny_padding_bytes"}; -constexpr folly::StringPiece kMetricMemoryCacheTotalLargePaddingBytes{ +constexpr std::string_view kMetricMemoryCacheTotalLargePaddingBytes{ "velox.memory_cache_total_large_padding_bytes"}; -constexpr folly::StringPiece kMetricMemoryCacheTotalPrefetchBytes{ +constexpr std::string_view kMetricMemoryCacheTotalPrefetchBytes{ "velox.memory_cache_total_prefetched_bytes"}; -constexpr folly::StringPiece kMetricMemoryCacheSumEvictScore{ +constexpr std::string_view kMetricMemoryCacheSumEvictScore{ "velox.memory_cache_sum_evict_score"}; -constexpr folly::StringPiece kMetricMemoryCacheNumHits{ +constexpr std::string_view kMetricMemoryCacheNumHits{ "velox.memory_cache_num_hits"}; -constexpr folly::StringPiece kMetricMemoryCacheHitBytes{ +constexpr std::string_view kMetricMemoryCacheHitBytes{ "velox.memory_cache_hit_bytes"}; -constexpr folly::StringPiece kMetricMemoryCacheNumNew{ +constexpr std::string_view kMetricMemoryCacheNumNew{ "velox.memory_cache_num_new"}; -constexpr folly::StringPiece kMetricMemoryCacheNumEvicts{ +constexpr std::string_view kMetricMemoryCacheNumEvicts{ "velox.memory_cache_num_evicts"}; -constexpr folly::StringPiece kMetricMemoryCacheNumSavableEvicts{ +constexpr std::string_view kMetricMemoryCacheNumSavableEvicts{ "velox.memory_cache_num_savable_evicts"}; -constexpr folly::StringPiece kMetricMemoryCacheNumEvictChecks{ +constexpr std::string_view kMetricMemoryCacheNumEvictChecks{ "velox.memory_cache_num_evict_checks"}; -constexpr folly::StringPiece kMetricMemoryCacheNumWaitExclusive{ +constexpr std::string_view kMetricMemoryCacheNumWaitExclusive{ "velox.memory_cache_num_wait_exclusive"}; -constexpr folly::StringPiece kMetricMemoryCacheNumAllocClocks{ +constexpr std::string_view kMetricMemoryCacheNumAllocClocks{ "velox.memory_cache_num_alloc_clocks"}; -constexpr folly::StringPiece kMetricMemoryCacheNumAgedOutEntries{ +constexpr std::string_view kMetricMemoryCacheNumAgedOutEntries{ "velox.memory_cache_num_aged_out_entries"}; -constexpr folly::StringPiece kMetricMemoryCacheNumStaleEntries{ +constexpr std::string_view kMetricMemoryCacheNumStaleEntries{ "velox.memory_cache_num_stale_entries"}; -constexpr folly::StringPiece kMetricSsdCacheCachedRegions{ +constexpr std::string_view kMetricSsdCacheCachedRegions{ "velox.ssd_cache_cached_regions"}; -constexpr folly::StringPiece kMetricSsdCacheCachedEntries{ +constexpr std::string_view kMetricSsdCacheCachedEntries{ "velox.ssd_cache_cached_entries"}; -constexpr folly::StringPiece kMetricSsdCacheCachedBytes{ +constexpr std::string_view kMetricSsdCacheCachedBytes{ "velox.ssd_cache_cached_bytes"}; -constexpr folly::StringPiece kMetricSsdCacheReadEntries{ +constexpr std::string_view kMetricSsdCacheReadEntries{ "velox.ssd_cache_read_entries"}; -constexpr folly::StringPiece kMetricSsdCacheReadBytes{ +constexpr std::string_view kMetricSsdCacheReadBytes{ "velox.ssd_cache_read_bytes"}; -constexpr folly::StringPiece kMetricSsdCacheWrittenEntries{ +constexpr std::string_view kMetricSsdCacheWrittenEntries{ "velox.ssd_cache_written_entries"}; -constexpr folly::StringPiece kMetricSsdCacheWrittenBytes{ +constexpr std::string_view kMetricSsdCacheWrittenBytes{ "velox.ssd_cache_written_bytes"}; -constexpr folly::StringPiece kMetricSsdCacheAgedOutEntries{ +constexpr std::string_view kMetricSsdCacheAgedOutEntries{ "velox.ssd_cache_aged_out_entries"}; -constexpr folly::StringPiece kMetricSsdCacheAgedOutRegions{ +constexpr std::string_view kMetricSsdCacheAgedOutRegions{ "velox.ssd_cache_aged_out_regions"}; -constexpr folly::StringPiece kMetricSsdCacheOpenSsdErrors{ +constexpr std::string_view kMetricSsdCacheOpenSsdErrors{ "velox.ssd_cache_open_ssd_errors"}; -constexpr folly::StringPiece kMetricSsdCacheOpenCheckpointErrors{ +constexpr std::string_view kMetricSsdCacheOpenCheckpointErrors{ "velox.ssd_cache_open_checkpoint_errors"}; -constexpr folly::StringPiece kMetricSsdCacheOpenLogErrors{ +constexpr std::string_view kMetricSsdCacheOpenLogErrors{ "velox.ssd_cache_open_log_errors"}; -constexpr folly::StringPiece kMetricSsdCacheMetaFileDeleteErrors{ +constexpr std::string_view kMetricSsdCacheMetaFileDeleteErrors{ "velox.ssd_cache_delete_meta_file_errors"}; -constexpr folly::StringPiece kMetricSsdCacheGrowFileErrors{ +constexpr std::string_view kMetricSsdCacheGrowFileErrors{ "velox.ssd_cache_grow_file_errors"}; -constexpr folly::StringPiece kMetricSsdCacheWriteSsdErrors{ +constexpr std::string_view kMetricSsdCacheWriteSsdErrors{ "velox.ssd_cache_write_ssd_errors"}; -constexpr folly::StringPiece kMetricSsdCacheWriteSsdDropped{ +constexpr std::string_view kMetricSsdCacheWriteNoSpaceErrors{ + "velox.ssd_cache_write_no_space_errors"}; + +constexpr std::string_view kMetricSsdCacheWriteSsdDropped{ "velox.ssd_cache_write_ssd_dropped"}; -constexpr folly::StringPiece kMetricSsdCacheWriteCheckpointErrors{ +constexpr std::string_view kMetricSsdCacheWriteExceedEntryLimit{ + "velox.ssd_cache_write_exceed_entry_limit"}; + +constexpr std::string_view kMetricSsdCacheWriteCheckpointErrors{ "velox.ssd_cache_write_checkpoint_errors"}; -constexpr folly::StringPiece kMetricSsdCacheReadCorruptions{ +constexpr std::string_view kMetricSsdCacheReadCorruptions{ "velox.ssd_cache_read_corruptions"}; -constexpr folly::StringPiece kMetricSsdCacheReadSsdErrors{ +constexpr std::string_view kMetricSsdCacheReadSsdErrors{ "velox.ssd_cache_read_ssd_errors"}; -constexpr folly::StringPiece kMetricSsdCacheReadCheckpointErrors{ +constexpr std::string_view kMetricSsdCacheReadCheckpointErrors{ "velox.ssd_cache_read_checkpoint_errors"}; -constexpr folly::StringPiece kMetricSsdCacheReadWithoutChecksum{ +constexpr std::string_view kMetricSsdCacheReadWithoutChecksum{ "velox.ssd_cache_read_without_checksum"}; -constexpr folly::StringPiece kMetricSsdCacheCheckpointsRead{ +constexpr std::string_view kMetricSsdCacheCheckpointsRead{ "velox.ssd_cache_checkpoints_read"}; -constexpr folly::StringPiece kMetricSsdCacheCheckpointsWritten{ +constexpr std::string_view kMetricSsdCacheCheckpointsWritten{ "velox.ssd_cache_checkpoints_written"}; -constexpr folly::StringPiece kMetricSsdCacheRegionsEvicted{ +constexpr std::string_view kMetricSsdCacheRegionsEvicted{ "velox.ssd_cache_regions_evicted"}; -constexpr folly::StringPiece kMetricSsdCacheRecoveredEntries{ +constexpr std::string_view kMetricSsdCacheRecoveredEntries{ "velox.ssd_cache_recovered_entries"}; -constexpr folly::StringPiece kMetricExchangeTransactionCreateDelay{ +constexpr std::string_view kMetricExchangeTransactionCreateDelay{ "velox.exchange.transaction_create_delay_ms"}; -constexpr folly::StringPiece kMetricExchangeDataTimeMs{ +constexpr std::string_view kMetricExchangeDataTimeMs{ "velox.exchange_data_time_ms"}; -constexpr folly::StringPiece kMetricExchangeDataBytes{ +constexpr std::string_view kMetricExchangeDataBytes{ "velox.exchange_data_bytes"}; -constexpr folly::StringPiece kMetricExchangeDataSize{ - "velox.exchange_data_size"}; +constexpr std::string_view kMetricExchangeDataSize{"velox.exchange_data_size"}; -constexpr folly::StringPiece kMetricExchangeDataCount{ +constexpr std::string_view kMetricExchangeDataCount{ "velox.exchange_data_count"}; -constexpr folly::StringPiece kMetricExchangeDataSizeTimeMs{ +constexpr std::string_view kMetricExchangeDataSizeTimeMs{ "velox.exchange_data_size_time_ms"}; -constexpr folly::StringPiece kMetricExchangeDataSizeCount{ +constexpr std::string_view kMetricExchangeDataSizeCount{ "velox.exchange_data_size_count"}; -constexpr folly::StringPiece kMetricStorageThrottledDurationMs{ +constexpr std::string_view kMetricStorageThrottledDurationMs{ "velox.storage_throttled_duration_ms"}; -constexpr folly::StringPiece kMetricStorageLocalThrottled{ +constexpr std::string_view kMetricStorageLocalThrottled{ "velox.storage_local_throttled_count"}; -constexpr folly::StringPiece kMetricStorageGlobalThrottled{ +constexpr std::string_view kMetricStorageGlobalThrottled{ "velox.storage_global_throttled_count"}; -constexpr folly::StringPiece kMetricStorageNetworkThrottled{ +constexpr std::string_view kMetricStorageNetworkThrottled{ "velox.storage_network_throttled_count"}; -constexpr folly::StringPiece kMetricIndexLookupResultRawBytes{ +constexpr std::string_view kMetricIndexLookupResultRawBytes{ "velox.index_lookup_result_raw_bytes"}; -constexpr folly::StringPiece kMetricIndexLookupResultBytes{ +constexpr std::string_view kMetricIndexLookupResultBytes{ "velox.index_lookup_result_bytes"}; -constexpr folly::StringPiece kMetricIndexLookupTimeMs{ +constexpr std::string_view kMetricIndexLookupTimeMs{ "velox.index_lookup_time_ms"}; -constexpr folly::StringPiece kMetricIndexLookupWaitTimeMs{ +constexpr std::string_view kMetricIndexLookupWaitTimeMs{ "velox.index_lookup_wait_time_ms"}; -constexpr folly::StringPiece kMetricIndexLookupBlockedWaitTimeMs{ +constexpr std::string_view kMetricIndexLookupBlockedWaitTimeMs{ "velox.index_lookup_blocked_wait_time_ms"}; -constexpr folly::StringPiece kMetricTableScanBatchProcessTimeMs{ +constexpr std::string_view kMetricIndexLookupErrorResultCount{ + "velox.index_lookup_error_result_count"}; + +constexpr std::string_view kMetricTableScanBatchProcessTimeMs{ "velox.table_scan_batch_process_time_ms"}; -constexpr folly::StringPiece kMetricTableScanBatchBytes{ +constexpr std::string_view kMetricTableScanBatchBytes{ "velox.table_scan_batch_bytes"}; -constexpr folly::StringPiece kMetricTaskBatchProcessTimeMs{ +constexpr std::string_view kMetricTaskBatchProcessTimeMs{ "velox.task_batch_process_time_ms"}; -constexpr folly::StringPiece kMetricTaskBarrierProcessTimeMs{ +constexpr std::string_view kMetricTaskBarrierProcessTimeMs{ "velox.task_barrier_process_time_ms"}; + } // namespace facebook::velox diff --git a/velox/common/base/ExceptionHelper.h b/velox/common/base/ExceptionHelper.h new file mode 100644 index 000000000000..193641ce63ae --- /dev/null +++ b/velox/common/base/ExceptionHelper.h @@ -0,0 +1,56 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace facebook::velox { + +struct CompileTimeEmptyString { + CompileTimeEmptyString() = default; + constexpr operator const char*() const { + return ""; + } + constexpr operator std::string_view() const { + return {}; + } + operator std::string() const { + return {}; + } +}; + +// When there is no message passed, we can statically detect this case +// and avoid passing even a single unnecessary argument pointer, +// minimizing size and thus maximizing eligibility for inlining. +inline CompileTimeEmptyString errorMessage() { + return {}; +} + +inline const char* errorMessage(const char* s) { + return s; +} + +inline std::string errorMessage(const std::string& str) { + return str; +} + +template +std::string errorMessage(fmt::string_view fmt, const Args&... args) { + return fmt::vformat(fmt, fmt::make_format_args(args...)); +} +} // namespace facebook::velox diff --git a/velox/common/base/Exceptions.h b/velox/common/base/Exceptions.h index 11fa7ae9ec59..4485b8b006a4 100644 --- a/velox/common/base/Exceptions.h +++ b/velox/common/base/Exceptions.h @@ -19,10 +19,10 @@ #include #include -#include #include #include +#include "velox/common/base/ExceptionHelper.h" #include "velox/common/base/FmtStdFormatters.h" #include "velox/common/base/VeloxException.h" @@ -39,19 +39,6 @@ struct VeloxCheckFailArgs { bool isRetriable; }; -struct CompileTimeEmptyString { - CompileTimeEmptyString() = default; - constexpr operator const char*() const { - return ""; - } - constexpr operator std::string_view() const { - return {}; - } - operator std::string() const { - return {}; - } -}; - // veloxCheckFail is defined as a separate helper function rather than // a macro or inline `throw` expression to allow the compiler *not* to // inline it when it is large. Having an out-of-line error path helps @@ -131,46 +118,26 @@ struct VeloxCheckFailStringType { template void veloxCheckFail( \ const VeloxCheckFailArgs& args, const std::string&); -// When there is no message passed, we can statically detect this case -// and avoid passing even a single unnecessary argument pointer, -// minimizing size and thus maximizing eligibility for inlining. -inline CompileTimeEmptyString errorMessage() { - return {}; -} - -inline const char* errorMessage(const char* s) { - return s; -} - -inline std::string errorMessage(const std::string& str) { - return str; -} - -template -std::string errorMessage(fmt::string_view fmt, const Args&... args) { - return fmt::vformat(fmt, fmt::make_format_args(args...)); -} - } // namespace detail -#define _VELOX_THROW_IMPL( \ - exception, exprStr, errorSource, errorCode, isRetriable, ...) \ - do { \ - /* GCC 9.2.1 doesn't accept this code with constexpr. */ \ - static const ::facebook::velox::detail::VeloxCheckFailArgs \ - veloxCheckFailArgs = { \ - __FILE__, \ - __LINE__, \ - __FUNCTION__, \ - exprStr, \ - errorSource, \ - errorCode, \ - isRetriable}; \ - auto message = ::facebook::velox::detail::errorMessage(__VA_ARGS__); \ - ::facebook::velox::detail::veloxCheckFail< \ - exception, \ - typename ::facebook::velox::detail::VeloxCheckFailStringType< \ - decltype(message)>::type>(veloxCheckFailArgs, message); \ +#define _VELOX_THROW_IMPL( \ + exception, exprStr, errorSource, errorCode, isRetriable, ...) \ + do { \ + /* GCC 9.2.1 doesn't accept this code with constexpr. */ \ + static const ::facebook::velox::detail::VeloxCheckFailArgs \ + veloxCheckFailArgs = { \ + __FILE__, \ + __LINE__, \ + __FUNCTION__, \ + exprStr, \ + errorSource, \ + errorCode, \ + isRetriable}; \ + auto message = ::facebook::velox::errorMessage(__VA_ARGS__); \ + ::facebook::velox::detail::veloxCheckFail< \ + exception, \ + typename ::facebook::velox::detail::VeloxCheckFailStringType< \ + decltype(message)>::type>(veloxCheckFailArgs, message); \ } while (0) #define _VELOX_CHECK_AND_THROW_IMPL( \ diff --git a/velox/common/base/Macros.h b/velox/common/base/Macros.h index 48bc0f38dcdf..62664a488f53 100644 --- a/velox/common/base/Macros.h +++ b/velox/common/base/Macros.h @@ -32,3 +32,13 @@ // Need this extra layer to expand __COUNTER__. #define VELOX_VARNAME_IMPL(x, y) VELOX_CONCAT(x, y) #define VELOX_VARNAME(x) VELOX_VARNAME_IMPL(x, __COUNTER__) + +// Workaround for GCC bug, it was fixed only in GCC 13. +// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=93413 +// TLDR: GCC 12 and earlier do not support constexpr static variables for +// non-template classes with virtual default destructors. +#if !defined(__clang__) && defined(__GNUC__) && __GNUC__ < 13 +#define VELOX_CONSTEXPR_SINGLETON static const +#else +#define VELOX_CONSTEXPR_SINGLETON static constexpr +#endif diff --git a/velox/common/base/Nulls.h b/velox/common/base/Nulls.h index b3d7e7d5a7fa..cc84ea7efc96 100644 --- a/velox/common/base/Nulls.h +++ b/velox/common/base/Nulls.h @@ -34,28 +34,29 @@ constexpr char kNotNullByte = 0xff; constexpr uint64_t kNull64 = 0UL; constexpr uint64_t kNotNull64 = (~0UL); -inline bool isBitNull(const uint64_t* bits, int32_t index) { +inline bool isBitNull(const uint64_t* bits, uint32_t index) { return isBitSet(bits, index) == kNull; } -inline void setNull(uint64_t* bits, int32_t index) { +inline void setNull(uint64_t* bits, uint32_t index) { clearBit(bits, index); } -inline void clearNull(uint64_t* bits, int32_t index) { +inline void clearNull(uint64_t* bits, uint32_t index) { setBit(bits, index); } -inline void setNull(uint64_t* bits, int32_t index, bool isNull) { +inline void setNull(uint64_t* bits, uint32_t index, bool isNull) { setBit(bits, index, !isNull); } inline uint64_t -countNonNulls(const uint64_t* nulls, int32_t begin, int32_t end) { +countNonNulls(const uint64_t* nulls, uint32_t begin, uint32_t end) { return countBits(nulls, begin, end); } -inline uint64_t countNulls(const uint64_t* nulls, int32_t begin, int32_t end) { +inline uint64_t +countNulls(const uint64_t* nulls, uint32_t begin, uint32_t end) { return (end - begin) - countNonNulls(nulls, begin, end); } diff --git a/velox/common/base/PeriodicStatsReporter.cpp b/velox/common/base/PeriodicStatsReporter.cpp index ba9c3134dbed..fb07906dcfe8 100644 --- a/velox/common/base/PeriodicStatsReporter.cpp +++ b/velox/common/base/PeriodicStatsReporter.cpp @@ -112,21 +112,22 @@ void PeriodicStatsReporter::reportAllocatorStats() { RECORD_METRIC_VALUE( kMetricMemoryAllocatorMappedBytes, (velox::memory::AllocationTraits::pageBytes(allocator_->numMapped()))); + RECORD_METRIC_VALUE( + kMetricMemoryAllocatorExternalMappedBytes, + (velox::memory::AllocationTraits::pageBytes( + allocator_->numExternalMapped()))); RECORD_METRIC_VALUE( kMetricMemoryAllocatorAllocatedBytes, (velox::memory::AllocationTraits::pageBytes(allocator_->numAllocated()))); RECORD_METRIC_VALUE( kMetricMemoryAllocatorTotalUsedBytes, (allocator_->totalUsedBytes())); + // TODO(jtan6): Remove condition after T150019700 is done if (auto* mmapAllocator = dynamic_cast(allocator_)) { RECORD_METRIC_VALUE( kMetricMmapAllocatorDelegatedAllocatedBytes, (mmapAllocator->numMallocBytes())); - RECORD_METRIC_VALUE( - kMetricMmapAllocatorExternalMappedBytes, - velox::memory::AllocationTraits::pageBytes( - (mmapAllocator->numExternalMapped()))); } // TODO(xiaoxmeng): add memory allocation size stats. } @@ -138,7 +139,10 @@ void PeriodicStatsReporter::reportCacheStats() { const auto cacheStats = cache_->refreshStats(); // Memory cache snapshot stats. - RECORD_METRIC_VALUE(kMetricMemoryCacheNumEntries, cacheStats.numEntries); + RECORD_METRIC_VALUE( + kMetricMemoryCacheNumTinyEntries, cacheStats.numTinyEntries); + RECORD_METRIC_VALUE( + kMetricMemoryCacheNumLargeEntries, cacheStats.numLargeEntries); RECORD_METRIC_VALUE( kMetricMemoryCacheNumEmptyEntries, cacheStats.numEmptyEntries); RECORD_METRIC_VALUE(kMetricMemoryCacheNumSharedEntries, cacheStats.numShared); @@ -207,8 +211,13 @@ void PeriodicStatsReporter::reportCacheStats() { kMetricSsdCacheGrowFileErrors, deltaSsdStats.growFileErrors); REPORT_IF_NOT_ZERO( kMetricSsdCacheWriteSsdErrors, deltaSsdStats.writeSsdErrors); + REPORT_IF_NOT_ZERO( + kMetricSsdCacheWriteNoSpaceErrors, deltaSsdStats.writeSsdNoSpaceErrors); REPORT_IF_NOT_ZERO( kMetricSsdCacheWriteSsdDropped, deltaSsdStats.writeSsdDropped); + REPORT_IF_NOT_ZERO( + kMetricSsdCacheWriteExceedEntryLimit, + deltaSsdStats.writeSsdExceedEntryLimit); REPORT_IF_NOT_ZERO( kMetricSsdCacheWriteCheckpointErrors, deltaSsdStats.writeCheckpointErrors); diff --git a/velox/common/base/Portability.h b/velox/common/base/Portability.h index 60049fcc54c7..85e890585e58 100644 --- a/velox/common/base/Portability.h +++ b/velox/common/base/Portability.h @@ -19,6 +19,7 @@ #include #include #include +#include #include inline size_t count_trailing_zeros(uint64_t x) { diff --git a/velox/common/base/RuntimeMetrics.cpp b/velox/common/base/RuntimeMetrics.cpp index 691eb9f4a403..3130cbcabef1 100644 --- a/velox/common/base/RuntimeMetrics.cpp +++ b/velox/common/base/RuntimeMetrics.cpp @@ -48,7 +48,7 @@ void RuntimeMetric::merge(const RuntimeMetric& other) max = std::max(max, other.max); } -void RuntimeMetric::printMetric(std::stringstream& stream) const { +void RuntimeMetric::printMetric(std::ostream& stream) const { switch (unit) { case RuntimeCounter::Unit::kNanos: stream << " sum: " << succinctNanos(sum) << ", count: " << count @@ -80,7 +80,6 @@ std::string RuntimeMetric::toString() const { succinctNanos(min), succinctNanos(max), succinctNanos(count == 0 ? 0 : sum / count)); - break; case RuntimeCounter::Unit::kBytes: return fmt::format( "sum:{}, count:{}, min:{}, max:{}, avg: {}", @@ -89,7 +88,6 @@ std::string RuntimeMetric::toString() const { succinctBytes(min), succinctBytes(max), succinctBytes(count == 0 ? 0 : sum / count)); - break; case RuntimeCounter::Unit::kNone: [[fallthrough]]; default: diff --git a/velox/common/base/RuntimeMetrics.h b/velox/common/base/RuntimeMetrics.h index 55b4d027da9d..6402719fa970 100644 --- a/velox/common/base/RuntimeMetrics.h +++ b/velox/common/base/RuntimeMetrics.h @@ -19,9 +19,6 @@ #include #include #include -#include - -#include "velox/common/base/SuccinctPrinter.h" namespace facebook::velox { @@ -51,13 +48,21 @@ struct RuntimeMetric { RuntimeCounter::Unit _unit = RuntimeCounter::Unit::kNone) : unit(_unit), sum{value}, count{1}, min{value}, max{value} {} + explicit RuntimeMetric( + int64_t _sum, + int64_t _count, + int64_t _min, + int64_t _max, + RuntimeCounter::Unit _unit = RuntimeCounter::Unit::kNone) + : unit(_unit), sum{_sum}, count{_count}, min{_min}, max{_max} {} + void addValue(int64_t value); /// Aggregate sets 'min' and 'max' to 'sum', also sets 'count' to 1 if /// positive. void aggregate(); - void printMetric(std::stringstream& stream) const; + void printMetric(std::ostream& stream) const; void merge(const RuntimeMetric& other); @@ -81,11 +86,10 @@ class BaseRuntimeStatWriter { /// thread. /// NOTE: This is only used by the Velox Driver at the moment, which ensures the /// active Operator is being used by the writer. -void setThreadLocalRunTimeStatWriter( - BaseRuntimeStatWriter* FOLLY_NULLABLE writer); +void setThreadLocalRunTimeStatWriter(BaseRuntimeStatWriter* writer); /// Retrives the current runtime stats writer. -BaseRuntimeStatWriter* FOLLY_NULLABLE getThreadLocalRunTimeStatWriter(); +BaseRuntimeStatWriter* getThreadLocalRunTimeStatWriter(); /// Writes runtime counter to the current Operator running on that thread. void addThreadLocalRuntimeStat( @@ -95,8 +99,7 @@ void addThreadLocalRuntimeStat( /// Scope guard to conveniently set and revert back the current stat writer. class RuntimeStatWriterScopeGuard { public: - explicit RuntimeStatWriterScopeGuard( - BaseRuntimeStatWriter* FOLLY_NULLABLE writer) + explicit RuntimeStatWriterScopeGuard(BaseRuntimeStatWriter* writer) : prevWriter_(getThreadLocalRunTimeStatWriter()) { setThreadLocalRunTimeStatWriter(writer); } @@ -106,7 +109,7 @@ class RuntimeStatWriterScopeGuard { } private: - BaseRuntimeStatWriter* const FOLLY_NULLABLE prevWriter_; + BaseRuntimeStatWriter* const prevWriter_; }; } // namespace facebook::velox diff --git a/velox/common/base/Semaphore.h b/velox/common/base/Semaphore.h index 760a07262735..a18acc28ed8d 100644 --- a/velox/common/base/Semaphore.h +++ b/velox/common/base/Semaphore.h @@ -63,8 +63,8 @@ class Semaphore { private: std::mutex mutex_; std::condition_variable cv_; - volatile int32_t count_; - volatile int32_t numWaiting_{0}; + int32_t count_; + int32_t numWaiting_{0}; }; } // namespace facebook::velox diff --git a/velox/common/base/SimdUtil-inl.h b/velox/common/base/SimdUtil-inl.h index ba5593e96e1a..937c86d9f7ee 100644 --- a/velox/common/base/SimdUtil-inl.h +++ b/velox/common/base/SimdUtil-inl.h @@ -25,6 +25,13 @@ XSIMD_DECLARE_SIMD_REGISTER( } // namespace xsimd::types #endif +#if XSIMD_WITH_SVE +#include +namespace xsimd::types { +XSIMD_DECLARE_SIMD_REGISTER(bool, sve, detail::sve_vector_type); +} +#endif + namespace facebook::velox::simd { namespace detail { @@ -91,6 +98,10 @@ struct BitMask { return (vaddv_u8(vget_high_u8(vmask)) << 8) | vaddv_u8(vget_low_u8(vmask)); } #endif + + static int toBitMask(xsimd::batch_bool mask, const xsimd::generic&) { + return genericToBitMask(mask); + } }; template @@ -290,9 +301,21 @@ template <> inline xsimd::batch_bool leadingMask( int i, const xsimd::default_arch&) { + /* + With GCC builds, compiler throws an error "invalid cast" on reintepreting to + the same data type, in SVE 256's case, svbool_t + __attribute__((arm_sve_vector_bits(256))). + So this is a workaround for now. Can be updated once the bug in GCC is + resolved in future GCC versions. + */ + +#if XSIMD_WITH_SVE && defined(__GNUC__) && !defined(__clang__) + return xsimd::batch_bool(leadingMask32[i].data); +#else return reinterpret_cast< xsimd::batch_bool::register_type>( leadingMask32[i].data); +#endif } template <> @@ -306,9 +329,21 @@ template <> inline xsimd::batch_bool leadingMask( int i, const xsimd::default_arch&) { + /* + With GCC builds, compiler throws an error "invalid cast" on reintepreting to + the same data type, in SVE 256's case, svbool_t + __attribute__((arm_sve_vector_bits(256))). + So this is a workaround for now. Can be updated once the bug in GCC is + resolved in future GCC versions. + */ + +#if XSIMD_WITH_SVE && defined(__GNUC__) && !defined(__clang__) + return xsimd::batch_bool(leadingMask64[i].data); +#else return reinterpret_cast< xsimd::batch_bool::register_type>( leadingMask64[i].data); +#endif } } // namespace detail @@ -341,7 +376,7 @@ struct CopyWord, A> { // sizeof(T). Returns false if 'bytes' went to 0. template inline bool copyNextWord(void*& to, const void*& from, int64_t& bytes) { - if (bytes >= sizeof(T)) { + if (bytes >= static_cast(sizeof(T))) { CopyWord::apply(to, from); bytes -= sizeof(T); if (!bytes) { @@ -362,7 +397,7 @@ inline void memcpy(void* to, const void* from, int64_t bytes, const A& arch) { return; } } - while (bytes >= sizeof(int64_t)) { + while (bytes >= static_cast(sizeof(int64_t))) { if (!detail::copyNextWord(to, from, bytes)) { return; } @@ -417,7 +452,7 @@ void memset(void* to, char data, int32_t bytes, const A& arch) { } } int64_t data64 = *reinterpret_cast(&v); - while (bytes >= sizeof(int64_t)) { + while (bytes >= static_cast(sizeof(int64_t))) { if (!detail::setNextWord(to, data64, bytes, arch)) { return; } @@ -525,6 +560,18 @@ struct Gather { return maskApply(src, mask, base, loadIndices(indices, arch), arch); } +#if XSIMD_WITH_SVE + template + static xsimd::batch maskApply( + xsimd::batch src, + xsimd::batch_bool mask, + const T* base, + const int32_t* indices, + const xsimd::sve& arch) { + return genericMaskGather(src, mask, base, indices); + } +#endif + template static xsimd::batch maskApply( xsimd::batch src, @@ -582,6 +629,22 @@ struct Gather { return Batch64::load_unaligned(indices); } +#if (XSIMD_WITH_SVE && SVE_BITS == 128) + + static Batch64 loadIndices( + const int32_t* indices, + const xsimd::sve&) { + return Batch64::load_unaligned(indices); + } +#endif +#if (XSIMD_WITH_SVE && SVE_BITS == 256) + static Batch128 loadIndices( + const int32_t* indices, + const xsimd::sve&) { + return Batch128::load_unaligned(indices); + } +#endif + static Batch64 loadIndices( const int32_t* indices, const xsimd::neon&) { @@ -602,6 +665,34 @@ struct Gather { return genericGather(base, indices); } +#if (XSIMD_WITH_SVE && SVE_BITS == 256) + template + static xsimd::batch + apply(const T* base, Batch128 vindex, const xsimd::sve&) { + constexpr int N = xsimd::batch::size; + alignas(A::alignment()) T dst[N]; + auto bytes = reinterpret_cast(base); + for (int i = 0; i < N; ++i) { + dst[i] = *reinterpret_cast(bytes + vindex.data[i] * kScale); + } + return xsimd::load_aligned(dst); + } +#endif + +#if (XSIMD_WITH_SVE && SVE_BITS == 128) + template + static xsimd::batch + apply(const T* base, Batch64 vindex, const xsimd::sve&) { + constexpr int N = xsimd::batch::size; + alignas(A::alignment()) T dst[N]; + auto bytes = reinterpret_cast(base); + for (int i = 0; i < N; ++i) { + dst[i] = *reinterpret_cast(bytes + vindex.data[i] * kScale); + } + return xsimd::load_aligned(dst); + } +#endif + #if XSIMD_WITH_AVX2 template static xsimd::batch apply( @@ -624,6 +715,48 @@ struct Gather { return genericMaskGather(src, mask, base, indices); } +#if XSIMD_WITH_SVE + template + static xsimd::batch maskApply( + xsimd::batch src, + xsimd::batch_bool mask, + const T* base, + const int32_t* indices, + const xsimd::sve& arch) { + return genericMaskGather(src, mask, base, indices); + } +#endif + +#if (XSIMD_WITH_SVE && SVE_BITS == 128) + template + static xsimd::batch maskApply( + xsimd::batch src, + xsimd::batch_bool mask, + const T* base, + Batch64 vindex, + const xsimd::sve& arch) { + constexpr int N = Batch64::size; + alignas(A::alignment()) int32_t indices[N]; + vindex.store_unaligned(indices); + return maskApply(src, mask, base, indices, arch); + } +#endif + +#if (XSIMD_WITH_SVE && SVE_BITS == 256) + template + static xsimd::batch maskApply( + xsimd::batch src, + xsimd::batch_bool mask, + const T* base, + Batch128 vindex, + const xsimd::sve& arch) { + constexpr int N = Batch128::size; + alignas(A::alignment()) int32_t indices[N]; + vindex.store_unaligned(indices); + return maskApply(src, mask, base, indices, arch); + } +#endif + #if XSIMD_WITH_AVX2 template static xsimd::batch maskApply( @@ -697,6 +830,18 @@ struct Gather { return maskApply(src, mask, base, loadIndices(indices, arch), arch); } +#if XSIMD_WITH_SVE + template + static xsimd::batch maskApply( + xsimd::batch src, + xsimd::batch_bool mask, + const T* base, + const int64_t* indices, + const xsimd::sve& arch) { + return genericMaskGather(src, mask, base, indices); + } +#endif + #if XSIMD_WITH_AVX2 template static xsimd::batch maskApply( @@ -756,6 +901,26 @@ xsimd::batch pack32( } #endif +template +xsimd::batch pack32( + xsimd::batch x, + xsimd::batch y, + const xsimd::generic&) { + constexpr std::size_t size = xsimd::batch::size; + alignas(A) int32_t xArr[size]; + alignas(A) int32_t yArr[size]; + alignas(A) int16_t resultArr[2 * size]; + + x.store_unaligned(xArr); + y.store_unaligned(yArr); + + for (std::size_t i = 0; i < size; ++i) { + resultArr[i] = static_cast(xArr[i]); + resultArr[i + size] = static_cast(yArr[i]); + } + return xsimd::batch::load_unaligned(resultArr); +} + #if XSIMD_WITH_AVX2 template xsimd::batch pack32( @@ -795,6 +960,16 @@ template Batch64 genericPermute(Batch64 data, Batch64 idx) { static_assert(data.size >= idx.size); Batch64 ans; + for (size_t i = 0; i < idx.size; ++i) { + ans.data[i] = data.data[idx.data[i]]; + } + return ans; +} + +template +Batch128 genericPermute(Batch128 data, Batch128 idx) { + static_assert(data.size >= idx.size); + Batch128 ans; for (int i = 0; i < idx.size; ++i) { ans.data[i] = data.data[idx.data[i]]; } @@ -865,7 +1040,8 @@ xsimd::batch gather( } else { second = xsimd::batch::broadcast(0); } - return detail::pack32(first, second, arch); + auto packed = detail::pack32(first, second, arch); + return packed; } namespace detail { @@ -988,6 +1164,25 @@ struct GetHalf { } #endif + template + static xsimd::batch apply( + xsimd::batch data, + const xsimd::generic&) { + constexpr std::size_t input_size = xsimd::batch::size; + constexpr std::size_t half_size = input_size / 2; + + std::array input_buffer; + data.store_aligned(input_buffer.data()); + + std::array output_buffer; + for (std::size_t i = 0; i < half_size; ++i) { + output_buffer[i] = static_cast( + kSecond ? input_buffer[i + half_size] : input_buffer[i]); + } + + return xsimd::load_aligned(output_buffer.data()); + } + #if XSIMD_WITH_NEON template static xsimd::batch apply( @@ -1039,6 +1234,23 @@ struct GetHalf { return vmovl_u32(vreinterpret_u32_s32(half)); } #endif + + template + static xsimd::batch apply( + xsimd::batch data, + const xsimd::generic&) { + constexpr std::size_t input_size = xsimd::batch::size; + constexpr std::size_t half_size = input_size / 2; + std::array input_buffer; + data.store_aligned(input_buffer.data()); + std::array output_buffer; + for (std::size_t i = 0; i < half_size; ++i) { + output_buffer[i] = static_cast( + kSecond ? static_cast(input_buffer[i + half_size]) + : static_cast(input_buffer[i])); + } + return xsimd::load_aligned(output_buffer.data()); + } }; } // namespace detail @@ -1082,6 +1294,21 @@ struct Filter { return ans; } #endif + +#if XSIMD_WITH_SVE + static xsimd::batch + apply(xsimd::batch data, int mask, const xsimd::sve& arch) { + int lane_count = svcntb() / sizeof(T); + T compressed[lane_count]; + int idx = 0; + for (int i = 0; i < lane_count; i++) { + if (mask & (1 << i)) { + compressed[idx++] = data.get(i); + } + } + return xsimd::load_unaligned(compressed); + } +#endif }; template @@ -1132,6 +1359,15 @@ struct Crc32 { } #endif +#if XSIMD_WITH_SVE + static uint32_t apply(uint32_t checksum, uint64_t value, const xsimd::sve&) { + __asm__("crc32cx %w[c], %w[c], %x[v]" + : [c] "+r"(checksum) + : [v] "r"(value)); + return checksum; + } +#endif + #if XSIMD_WITH_NEON static uint32_t apply(uint32_t checksum, uint64_t value, const xsimd::neon&) { __asm__("crc32cx %w[c], %w[c], %x[v]" @@ -1157,6 +1393,20 @@ xsimd::batch iota(const A&) { namespace detail { +#if (XSIMD_WITH_SVE && SVE_BITS == 256) +template +struct HalfBatchImpl>> { + using Type = Batch128; +}; +#endif + +#if (XSIMD_WITH_SVE && SVE_BITS == 128) +template +struct HalfBatchImpl>> { + using Type = Batch64; +}; +#endif + template struct HalfBatchImpl>> { using Type = xsimd::batch; @@ -1194,10 +1444,18 @@ struct ReinterpretBatch { } }; -#if XSIMD_WITH_NEON || XSIMD_WITH_NEON64 +#if XSIMD_WITH_NEON || XSIMD_WITH_NEON64 || XSIMD_WITH_SVE template struct ReinterpretBatch { +#if XSIMD_WITH_SVE + static xsimd::batch apply( + xsimd::batch data, + const xsimd::sve&) { + return svreinterpret_u8_s8(data.data); + } +#endif + #if XSIMD_WITH_NEON static xsimd::batch apply( xsimd::batch data, @@ -1217,6 +1475,14 @@ struct ReinterpretBatch { template struct ReinterpretBatch { +#if XSIMD_WITH_SVE + static xsimd::batch apply( + xsimd::batch data, + const xsimd::sve&) { + return svreinterpret_s8_u8(data.data); + } +#endif + #if XSIMD_WITH_NEON static xsimd::batch apply( xsimd::batch data, @@ -1236,6 +1502,14 @@ struct ReinterpretBatch { template struct ReinterpretBatch { +#if XSIMD_WITH_SVE + static xsimd::batch apply( + xsimd::batch data, + const xsimd::sve&) { + return svreinterpret_u16_s16(data.data); + } +#endif + #if XSIMD_WITH_NEON static xsimd::batch apply( xsimd::batch data, @@ -1255,6 +1529,14 @@ struct ReinterpretBatch { template struct ReinterpretBatch { +#if XSIMD_WITH_SVE + static xsimd::batch apply( + xsimd::batch data, + const xsimd::sve&) { + return svreinterpret_s16_u16(data.data); + } +#endif + #if XSIMD_WITH_NEON static xsimd::batch apply( xsimd::batch data, @@ -1274,6 +1556,14 @@ struct ReinterpretBatch { template struct ReinterpretBatch { +#if XSIMD_WITH_SVE + static xsimd::batch apply( + xsimd::batch data, + const xsimd::sve&) { + return svreinterpret_u32_s32(data.data); + } +#endif + #if XSIMD_WITH_NEON static xsimd::batch apply( xsimd::batch data, @@ -1293,6 +1583,14 @@ struct ReinterpretBatch { template struct ReinterpretBatch { +#if XSIMD_WITH_SVE + static xsimd::batch apply( + xsimd::batch data, + const xsimd::sve&) { + return svreinterpret_s32_u32(data.data); + } +#endif + #if XSIMD_WITH_NEON static xsimd::batch apply( xsimd::batch data, @@ -1312,6 +1610,14 @@ struct ReinterpretBatch { template struct ReinterpretBatch { +#if XSIMD_WITH_SVE + static xsimd::batch apply( + xsimd::batch data, + const xsimd::sve&) { + return svreinterpret_u64_u32(data.data); + } +#endif + #if XSIMD_WITH_NEON static xsimd::batch apply( xsimd::batch data, @@ -1331,6 +1637,14 @@ struct ReinterpretBatch { template struct ReinterpretBatch { +#if XSIMD_WITH_SVE + static xsimd::batch apply( + xsimd::batch data, + const xsimd::sve&) { + return svreinterpret_u64_s64(data.data); + } +#endif + #if XSIMD_WITH_NEON static xsimd::batch apply( xsimd::batch data, @@ -1350,6 +1664,14 @@ struct ReinterpretBatch { template struct ReinterpretBatch { +#if XSIMD_WITH_SVE + static xsimd::batch apply( + xsimd::batch data, + const xsimd::sve&) { + return svreinterpret_u32_s64(data.data); + } +#endif + #if XSIMD_WITH_NEON static xsimd::batch apply( xsimd::batch data, @@ -1369,6 +1691,14 @@ struct ReinterpretBatch { template struct ReinterpretBatch { +#if XSIMD_WITH_SVE + static xsimd::batch apply( + xsimd::batch data, + const xsimd::sve&) { + return svreinterpret_s64_u64(data.data); + } +#endif + #if XSIMD_WITH_NEON static xsimd::batch apply( xsimd::batch data, @@ -1388,6 +1718,14 @@ struct ReinterpretBatch { template struct ReinterpretBatch { +#if XSIMD_WITH_SVE + static xsimd::batch apply( + xsimd::batch data, + const xsimd::sve&) { + return svreinterpret_u32_u64(data.data); + } +#endif + #if XSIMD_WITH_NEON static xsimd::batch apply( xsimd::batch data, @@ -1444,6 +1782,9 @@ namespace detail { #if XSIMD_WITH_AVX2 using CharVector = xsimd::batch; #define VELOX_SIMD_STRSTR 1 +#elif XSIMD_WITH_SVE +using CharVector = xsimd::batch; +#define VELOX_SIMD_STRSTR 1 #elif XSIMD_WITH_NEON using CharVector = xsimd::batch; #define VELOX_SIMD_STRSTR 1 @@ -1458,7 +1799,7 @@ size_t FOLLY_ALWAYS_INLINE smidStrstrMemcmp( const char* s, int64_t n, const char* needle, - size_t needleSize) { + int64_t needleSize) { static_assert(kNeedleSize >= 2); VELOX_DCHECK_GT(needleSize, 1); VELOX_DCHECK_GT(n, 0); @@ -1506,7 +1847,7 @@ size_t FOLLY_ALWAYS_INLINE smidStrstrMemcmp( } return std::string::npos; -}; +} #endif diff --git a/velox/common/base/SimdUtil.h b/velox/common/base/SimdUtil.h index edc7e06c2dfc..d853931776e3 100644 --- a/velox/common/base/SimdUtil.h +++ b/velox/common/base/SimdUtil.h @@ -131,6 +131,50 @@ struct Batch64 { } }; +template +struct Batch128 { + static constexpr size_t size = [] { + static_assert(16 % sizeof(T) == 0); + return 16 / sizeof(T); + }(); + + T data[size]; + + static Batch128 from(std::initializer_list values) { + VELOX_DCHECK_EQ(values.size(), size); + Batch128 ans; + for (int i = 0; i < size; ++i) { + ans.data[i] = *(values.begin() + i); + } + return ans; + } + + void store_unaligned(T* out) const { + std::copy(std::begin(data), std::end(data), out); + } + + static Batch128 load_aligned(const T* mem) { + return load_unaligned(mem); + } + + static Batch128 load_unaligned(const T* mem) { + Batch128 ans; + std::copy(mem, mem + size, ans.data); + return ans; + } + + friend Batch128 operator+(Batch128 x, T y) { + for (int i = 0; i < size; ++i) { + x.data[i] += y; + } + return x; + } + + friend Batch128 operator-(Batch128 x, T y) { + return x + (-y); + } +}; + namespace detail { template struct Gather; @@ -178,6 +222,17 @@ gather(const T* base, Batch64 vindex, const A& arch = {}) { return Impl::template apply(base, vindex.data, arch); } +template < + typename T, + typename IndexType, + int kScale = sizeof(T), + typename A = xsimd::default_arch> +xsimd::batch +gather(const T* base, Batch128 vindex, const A& arch = {}) { + using Impl = detail::Gather; + return Impl::template apply(base, vindex.data, arch); +} + // Same as 'gather' above except the indices are read from memory. template < typename T, @@ -223,6 +278,21 @@ xsimd::batch maskGather( return Impl::template maskApply(src, mask, base, vindex.data, arch); } +template < + typename T, + typename IndexType, + int kScale = sizeof(T), + typename A = xsimd::default_arch> +xsimd::batch maskGather( + xsimd::batch src, + xsimd::batch_bool mask, + const T* base, + Batch128 vindex, + const A& arch = {}) { + using Impl = detail::Gather; + return Impl::template maskApply(src, mask, base, vindex.data, arch); +} + // Same as 'maskGather' above but read indices from memory. template < typename T, @@ -360,6 +430,7 @@ uint32_t crc32U64(uint32_t checksum, uint64_t value, const A& arch = {}) { template xsimd::batch iota(const A& = {}); +#ifdef VELOX_ENABLE_LOAD_SIMD_VALUE_BUFFER // Returns a batch with all elements set to value. For batch we // use one bit to represent one element. template @@ -375,6 +446,7 @@ xsimd::batch setAll(T value, const A& = {}) { return xsimd::broadcast(value); } } +#endif // Stores 'data' into 'destination' for the lanes in 'mask'. 'mask' is expected // to specify contiguous lower lanes of 'batch'. For non-SIMD cases, 'mask' is diff --git a/velox/common/base/SkewedPartitionBalancer.cpp b/velox/common/base/SkewedPartitionBalancer.cpp index 24cb1511469e..7a1b4a6efa34 100644 --- a/velox/common/base/SkewedPartitionBalancer.cpp +++ b/velox/common/base/SkewedPartitionBalancer.cpp @@ -30,9 +30,10 @@ SkewedPartitionRebalancer::SkewedPartitionRebalancer( numTasks_(numTasks), minProcessedBytesRebalanceThresholdPerPartition_( minProcessedBytesRebalanceThresholdPerPartition), - minProcessedBytesRebalanceThreshold_(std::max( - minProcessedBytesRebalanceThreshold, - minProcessedBytesRebalanceThresholdPerPartition_)), + minProcessedBytesRebalanceThreshold_( + std::max( + minProcessedBytesRebalanceThreshold, + minProcessedBytesRebalanceThresholdPerPartition_)), partitionRowCount_(numPartitions_), partitionAssignments_(numPartitions_) { VELOX_CHECK_GT(numPartitions_, 0); @@ -174,9 +175,10 @@ void SkewedPartitionRebalancer::rebalanceBasedOnTaskSkewness( const uint32_t totalAssignedTasks = partitionAssignments_[maxPartition].size(); - if (partitionBytes_[maxPartition] < - (minProcessedBytesRebalanceThresholdPerPartition_ * - totalAssignedTasks)) { + if (partitionBytes_[maxPartition] == 0 || + (partitionBytes_[maxPartition] < + (minProcessedBytesRebalanceThresholdPerPartition_ * + totalAssignedTasks))) { break; } diff --git a/velox/common/base/SpillConfig.cpp b/velox/common/base/SpillConfig.cpp index dd428a41ec7b..caf55d773865 100644 --- a/velox/common/base/SpillConfig.cpp +++ b/velox/common/base/SpillConfig.cpp @@ -34,8 +34,10 @@ SpillConfig::SpillConfig( uint64_t _maxSpillRunRows, uint64_t _writerFlushThresholdSize, const std::string& _compressionKind, + uint32_t _numMaxMergeFiles, std::optional _prefixSortConfig, - const std::string& _fileCreateConfig) + const std::string& _fileCreateConfig, + uint32_t _windowMinReadBatchRows) : getSpillDirPathCb(std::move(_getSpillDirPathCb)), updateAndCheckSpillLimitCb(std::move(_updateAndCheckSpillLimitCb)), fileNamePrefix(std::move(_fileNamePrefix)), @@ -53,12 +55,18 @@ SpillConfig::SpillConfig( maxSpillRunRows(_maxSpillRunRows), writerFlushThresholdSize(_writerFlushThresholdSize), compressionKind(common::stringToCompressionKind(_compressionKind)), + numMaxMergeFiles(_numMaxMergeFiles), prefixSortConfig(_prefixSortConfig), - fileCreateConfig(_fileCreateConfig) { + fileCreateConfig(_fileCreateConfig), + windowMinReadBatchRows(_windowMinReadBatchRows) { VELOX_USER_CHECK_GE( spillableReservationGrowthPct, minSpillableReservationPct, "Spillable memory reservation growth pct should not be lower than minimum available pct"); + VELOX_CHECK_NE( + numMaxMergeFiles, + 1, + "NumMaxMergeFiles should not be 1 as merging should take at least 2 files to make progress"); } int32_t SpillConfig::spillLevel(uint8_t startBitOffset) const { diff --git a/velox/common/base/SpillConfig.h b/velox/common/base/SpillConfig.h index 7f30bc6e614f..8c77e9270f8c 100644 --- a/velox/common/base/SpillConfig.h +++ b/velox/common/base/SpillConfig.h @@ -52,6 +52,13 @@ using GetSpillDirectoryPathCB = std::function; /// bytes exceed the set limit. using UpdateAndCheckSpillLimitCB = std::function; +/// Specifies the options for spill to disk. +struct SpillDiskOptions { + std::string spillDirPath; + bool spillDirCreated{true}; + std::function spillDirCreateCb{nullptr}; +}; + /// Specifies the config for spilling. struct SpillConfig { SpillConfig() = default; @@ -71,8 +78,10 @@ struct SpillConfig { uint64_t _maxSpillRunRows, uint64_t _writerFlushThresholdSize, const std::string& _compressionKind, + uint32_t numMaxMergeFiles, std::optional _prefixSortConfig = std::nullopt, - const std::string& _fileCreateConfig = {}); + const std::string& _fileCreateConfig = {}, + uint32_t _windowMinReadBatchRows = 1'000); /// Returns the spilling level with given 'startBitOffset' and /// 'numPartitionBits'. @@ -151,11 +160,22 @@ struct SpillConfig { /// CompressionKind when spilling, CompressionKind_NONE means no compression. common::CompressionKind compressionKind; + /// The max number of files to merge at a time when merging sorted files into + /// a single ordered stream. 0 means unlimited. This is used to reduce memory + /// pressure by capping the number of open files when merging spilled sorted + /// files to avoid using too much memory and causing OOM. Note that this is + /// only applicable for ordered spill, is not applicable for spill scenarios + /// that don't need sorting, e.g. HashJoin. + uint32_t numMaxMergeFiles; + /// Prefix sort config when spilling, enable prefix sort when this config is /// set, otherwise, fallback to timsort. std::optional prefixSortConfig; /// Custom options passed to velox::FileSystem to create spill WriteFile. std::string fileCreateConfig; + + /// The minimum number of rows to read when processing spilled window data. + uint32_t windowMinReadBatchRows; }; } // namespace facebook::velox::common diff --git a/velox/common/base/SpillStats.cpp b/velox/common/base/SpillStats.cpp index 5c5e51014b78..caa8cc66ab07 100644 --- a/velox/common/base/SpillStats.cpp +++ b/velox/common/base/SpillStats.cpp @@ -15,6 +15,7 @@ */ #include "velox/common/base/SpillStats.h" +#include #include "velox/common/base/Counters.h" #include "velox/common/base/Exceptions.h" #include "velox/common/base/StatsReporter.h" @@ -24,7 +25,7 @@ namespace facebook::velox::common { namespace { std::vector>& allSpillStats() { static std::vector> spillStatsList( - std::thread::hardware_concurrency()); + folly::hardware_concurrency()); return spillStatsList; } @@ -124,98 +125,6 @@ SpillStats SpillStats::operator-(const SpillStats& other) const { return result; } -bool SpillStats::operator<(const SpillStats& other) const { - uint32_t gtCount{0}; - uint32_t ltCount{0}; -#define UPDATE_COUNTER(counter) \ - do { \ - if (counter < other.counter) { \ - ++ltCount; \ - } else if (counter > other.counter) { \ - ++gtCount; \ - } \ - } while (0) - - UPDATE_COUNTER(spillRuns); - UPDATE_COUNTER(spilledInputBytes); - UPDATE_COUNTER(spilledBytes); - UPDATE_COUNTER(spilledRows); - UPDATE_COUNTER(spilledPartitions); - UPDATE_COUNTER(spilledFiles); - UPDATE_COUNTER(spillFillTimeNanos); - UPDATE_COUNTER(spillSortTimeNanos); - UPDATE_COUNTER(spillExtractVectorTimeNanos); - UPDATE_COUNTER(spillSerializationTimeNanos); - UPDATE_COUNTER(spillWrites); - UPDATE_COUNTER(spillFlushTimeNanos); - UPDATE_COUNTER(spillWriteTimeNanos); - UPDATE_COUNTER(spillMaxLevelExceededCount); - UPDATE_COUNTER(spillReadBytes); - UPDATE_COUNTER(spillReads); - UPDATE_COUNTER(spillReadTimeNanos); - UPDATE_COUNTER(spillDeserializationTimeNanos); -#undef UPDATE_COUNTER - VELOX_CHECK( - !((gtCount > 0) && (ltCount > 0)), - "gtCount {} ltCount {}", - gtCount, - ltCount); - return ltCount > 0; -} - -bool SpillStats::operator>(const SpillStats& other) const { - return !(*this < other) && (*this != other); -} - -bool SpillStats::operator>=(const SpillStats& other) const { - return !(*this < other); -} - -bool SpillStats::operator<=(const SpillStats& other) const { - return !(*this > other); -} - -bool SpillStats::operator==(const SpillStats& other) const { - return std::tie( - spillRuns, - spilledInputBytes, - spilledBytes, - spilledRows, - spilledPartitions, - spilledFiles, - spillFillTimeNanos, - spillSortTimeNanos, - spillExtractVectorTimeNanos, - spillSerializationTimeNanos, - spillWrites, - spillFlushTimeNanos, - spillWriteTimeNanos, - spillMaxLevelExceededCount, - spillReadBytes, - spillReads, - spillReadTimeNanos, - spillDeserializationTimeNanos) == - std::tie( - other.spillRuns, - other.spilledInputBytes, - other.spilledBytes, - other.spilledRows, - other.spilledPartitions, - other.spilledFiles, - other.spillFillTimeNanos, - other.spillSortTimeNanos, - other.spillExtractVectorTimeNanos, - other.spillSerializationTimeNanos, - other.spillWrites, - other.spillFlushTimeNanos, - other.spillWriteTimeNanos, - spillMaxLevelExceededCount, - spillReadBytes, - spillReads, - spillReadTimeNanos, - spillDeserializationTimeNanos); -} - void SpillStats::reset() { spillRuns = 0; spilledInputBytes = 0; diff --git a/velox/common/base/SpillStats.h b/velox/common/base/SpillStats.h index da9c5fe92184..36e3c7b33e70 100644 --- a/velox/common/base/SpillStats.h +++ b/velox/common/base/SpillStats.h @@ -15,8 +15,9 @@ */ #pragma once -#include -#include +#include +#include +#include #include @@ -97,14 +98,7 @@ struct SpillStats { SpillStats& operator+=(const SpillStats& other); SpillStats operator-(const SpillStats& other) const; - bool operator==(const SpillStats& other) const; - bool operator!=(const SpillStats& other) const { - return !(*this == other); - } - bool operator>(const SpillStats& other) const; - bool operator<(const SpillStats& other) const; - bool operator>=(const SpillStats& other) const; - bool operator<=(const SpillStats& other) const; + bool operator==(const SpillStats& other) const = default; void reset(); diff --git a/velox/common/base/SplitBlockBloomFilter.cpp b/velox/common/base/SplitBlockBloomFilter.cpp new file mode 100644 index 000000000000..ef4d5e87c564 --- /dev/null +++ b/velox/common/base/SplitBlockBloomFilter.cpp @@ -0,0 +1,66 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/common/base/SplitBlockBloomFilter.h" + +#include "velox/common/base/BitUtil.h" +#include "velox/common/base/Exceptions.h" + +#include +#include + +namespace facebook::velox { + +int64_t SplitBlockBloomFilter::numBlocks( + int64_t numElements, + double falsePositive) { + constexpr int K = xsimd::batch::size; + int64_t numBits = std::ceil( + -K * numElements / std::log(1 - std::pow(falsePositive, 1.0 / K))); + return bits::divRoundUp(numBits, 8 * sizeof(Block)); +} + +SplitBlockBloomFilter::SplitBlockBloomFilter(const std::span& blocks) + : blocks_(blocks) { + VELOX_CHECK_EQ(reinterpret_cast(blocks.data()) % sizeof(Block), 0); +} + +std::string SplitBlockBloomFilter::debugString() const { + std::ostringstream out; + out << "numBlocks=" << blocks_.size() << '\n'; + int64_t byBlockSetBits[1 + 8 * sizeof(Block)]{}; + int64_t byBitPosition[8 * sizeof(Block)]{}; + for (auto& block : blocks_) { + int numSetBits = 0; + for (int i = 0; i < xsimd::batch::size; ++i) { + auto n = std::popcount(block.data[i]); + numSetBits += n; + for (int j = 0; j < 32; ++j) { + byBitPosition[j + i * 32] += 1 & (block.data[i] >> j); + } + } + ++byBlockSetBits[numSetBits]; + } + for (int i = 0; i <= 8 * sizeof(Block); ++i) { + out << "Block set bits " << i << ": " << byBlockSetBits[i] << '\n'; + } + for (int i = 0; i < 8 * sizeof(Block); ++i) { + out << "Bit " << i << ": " << byBitPosition[i] << '\n'; + } + return out.str(); +} + +} // namespace facebook::velox diff --git a/velox/common/base/SplitBlockBloomFilter.h b/velox/common/base/SplitBlockBloomFilter.h new file mode 100644 index 000000000000..b5a40f88c3be --- /dev/null +++ b/velox/common/base/SplitBlockBloomFilter.h @@ -0,0 +1,137 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "velox/common/base/SimdUtil.h" + +#include +#include + +namespace facebook::velox { + +/// SIMDized bloom filter implementation. We take 8 or 4 (depending on the SIMD +/// register size) bits from the inserted value, and split them into a block. A +/// block is the same size of a SIMD register. +/// +/// This data structure is not responsible for memory management or hashing. +/// +/// A detailed explanation about how the data structure works can be found here: +/// https://parquet.apache.org/docs/file-format/bloomfilter/ +class SplitBlockBloomFilter { + public: + /// A block is basically a SIMD register. Made public so user can calculate + /// the size needed for memory allocation; otherwise it's implementation + /// detail. + struct alignas(sizeof(xsimd::batch)) Block { + uint32_t data[xsimd::batch::size]; + }; + + /// Calculate the number of blocks needed to satisfy certain number of inserts + /// and false positive rate. A rough estimation between the memory usage and + /// false positive rate can be found below: + /// + /// | Bits of space per insert | False positive probability | + /// |--------------------------|----------------------------| + /// | 6.0 | 10% | + /// | 10.5 | 1% | + /// | 16.9 | 0.1% | + /// | 26.4 | 0.01% | + /// | 41 | 0.001% | + static int64_t numBlocks(int64_t numElements, double falsePositive); + + /// Construct the bloom filter using the blocks memory passed as parameter. + /// The block memory address must be aligned with block size (i.e. SIMD + /// register size). It is recommended to use numBlocks() to calculate the + /// number of block for allocation. + explicit SplitBlockBloomFilter(const std::span& blocks); + + /// Delete copy constructor and assignment to avoid accidentally mutating the + /// same underlying block data from multiple places. + SplitBlockBloomFilter(const SplitBlockBloomFilter&) = delete; + SplitBlockBloomFilter& operator=(const SplitBlockBloomFilter&) = delete; + + SplitBlockBloomFilter(SplitBlockBloomFilter&&) = default; + + /// Insert a hash into the bloom filter. The function used to generate this + /// hash should be avalanching. + void insert(uint64_t hash) { + auto mask = makeMask(hash); + auto* block = blocks_[blockIndex(hash)].data; + (xsimd::load_aligned(block) | mask).store_aligned(block); + } + + /// Check whether a hash has been inserted before. Could return true when it + /// has not been inserted. Never return false when the hash has been + /// inserted. + bool mayContain(uint64_t hash) const { + auto mask = makeMask(hash); + auto block = xsimd::load_aligned(blocks_[blockIndex(hash)].data); +#if XSIMD_WITH_AVX + return _mm256_testc_si256(block, mask); +#else + return xsimd::all(xsimd::bitwise_andnot(mask, block) == 0); +#endif + } + + /// Return the block index of the given hash. + uint64_t blockIndex(uint64_t hash) const { + return ((hash >> 32) * blocks_.size()) >> 32; + } + + std::string debugString() const; + + private: + static_assert(64 % sizeof(Block) == 0); + + template + static xsimd::batch makeSaltsVec() { + constexpr uint32_t kSalts[] = { + 0x2df1424bU, + 0x44974d91U, + 0x47b6137bU, + 0x5c6bfb31U, + 0x705495c7U, + 0x8824ad5bU, + 0x9efc4947U, + 0xa2b7289dU, + }; + if constexpr (xsimd::batch::size == 8) { + return xsimd::batch( + kSalts[0], + kSalts[1], + kSalts[2], + kSalts[3], + kSalts[4], + kSalts[5], + kSalts[6], + kSalts[7]); + } else { + static_assert(xsimd::batch::size == 4); + return xsimd::batch( + kSalts[0], kSalts[2], kSalts[4], kSalts[6]); + } + } + + static xsimd::batch makeMask(uint32_t hash) { + auto shifts = (makeSaltsVec() * xsimd::broadcast(hash)) >> 27; + return xsimd::broadcast(1) << shifts; + } + + std::span const blocks_; +}; + +} // namespace facebook::velox diff --git a/velox/common/base/StatsReporter.h b/velox/common/base/StatsReporter.h index d306c31cde18..9df453a75db1 100644 --- a/velox/common/base/StatsReporter.h +++ b/velox/common/base/StatsReporter.h @@ -17,7 +17,6 @@ #pragma once #include -#include /// StatsReporter designed to assist in reporting various metrics of the /// application that uses velox library. The library itself does not implement @@ -64,7 +63,7 @@ enum class StatType { HISTOGRAM, }; -inline std::string statTypeString(StatType stat) { +inline std::string_view statTypeString(StatType stat) { switch (stat) { case StatType::AVG: return "Avg"; @@ -77,7 +76,7 @@ inline std::string statTypeString(StatType stat) { case StatType::HISTOGRAM: return "Histogram"; default: - return fmt::format("UNKNOWN: {}", static_cast(stat)); + return "Unknown"; } } @@ -85,17 +84,17 @@ inline std::string statTypeString(StatType stat) { /// different implementations. class BaseStatsReporter { public: - virtual ~BaseStatsReporter() {} + virtual ~BaseStatsReporter() = default; /// Register a stat of the given stat type. /// @param key The key to identify the stat. /// @param statType How the stat is aggregated. virtual void registerMetricExportType(const char* key, StatType statType) - const = 0; + const {} virtual void registerMetricExportType( folly::StringPiece key, - StatType statType) const = 0; + StatType statType) const {} /// Register a histogram with a list of percentiles defined. /// @param key The key to identify the histogram. @@ -108,35 +107,98 @@ class BaseStatsReporter { int64_t bucketWidth, int64_t min, int64_t max, - const std::vector& pcts) const = 0; + const std::vector& pcts) const {} virtual void registerHistogramMetricExportType( folly::StringPiece key, int64_t bucketWidth, int64_t min, int64_t max, - const std::vector& pcts) const = 0; + const std::vector& pcts) const {} + + /// Register a quantile metric for quantile stats with export types, + /// quantiles, and sliding window periods. + /// @param key The key to identify the stat. + /// @param statTypes The list of stat types to export (e.g., AVG, SUM, COUNT). + /// @param pcts The quantile percentiles to track (as values between 0.0 + /// and 1.0, e.g., 0.5 for 50th percentile, 0.95 for 95th percentile). + /// @param slidingWindowsSeconds The sliding window periods in seconds. + virtual void registerQuantileMetricExportType( + const char* key, + const std::vector& statTypes, + const std::vector& pcts, + const std::vector& slidingWindowsSeconds = {60}) const {} + + virtual void registerQuantileMetricExportType( + folly::StringPiece key, + const std::vector& statTypes, + const std::vector& pcts, + const std::vector& slidingWindowsSeconds = {60}) const {} + + /// Register a dynamic quantile metric with a template key pattern that + /// supports runtime substitution. + /// @param keyPattern The key pattern with {} placeholders for substitution. + /// @param statTypes The list of stat types to export. + /// @param pcts The quantile percentiles to track. + /// @param slidingWindowsSeconds The sliding window periods in seconds. + virtual void registerDynamicQuantileMetricExportType( + const char* keyPattern, + const std::vector& statTypes, + const std::vector& pcts, + const std::vector& slidingWindowsSeconds = {60}) const {} + + virtual void registerDynamicQuantileMetricExportType( + folly::StringPiece keyPattern, + const std::vector& statTypes, + const std::vector& pcts, + const std::vector& slidingWindowsSeconds = {60}) const {} /// Add the given value to the stat. - virtual void addMetricValue(const std::string& key, size_t value = 1) - const = 0; + virtual void addMetricValue(const std::string& key, size_t value = 1) const {} - virtual void addMetricValue(const char* key, size_t value = 1) const = 0; + virtual void addMetricValue(const char* key, size_t value = 1) const {} - virtual void addMetricValue(folly::StringPiece key, size_t value = 1) - const = 0; + virtual void addMetricValue(folly::StringPiece key, size_t value = 1) const {} /// Add the given value to the histogram. virtual void addHistogramMetricValue(const std::string& key, size_t value) - const = 0; + const {} - virtual void addHistogramMetricValue(const char* key, size_t value) const = 0; + virtual void addHistogramMetricValue(const char* key, size_t value) const {} virtual void addHistogramMetricValue(folly::StringPiece key, size_t value) - const = 0; + const {} + + /// Add the given value to a quantile metric. + virtual void addQuantileMetricValue(const std::string& key, size_t value = 1) + const {} + + virtual void addQuantileMetricValue(const char* key, size_t value = 1) const { + } + + virtual void addQuantileMetricValue(folly::StringPiece key, size_t value = 1) + const {} + + /// Add the given value to a quantile metric. + virtual void addDynamicQuantileMetricValue( + const std::string& key, + folly::Range subkeys, + size_t value = 1) const {} + + virtual void addDynamicQuantileMetricValue( + const char* key, + folly::Range subkeys, + size_t value = 1) const {} + + virtual void addDynamicQuantileMetricValue( + folly::StringPiece key, + folly::Range subkeys, + size_t value = 1) const {} /// Return the aggregated metrics in a serialized string format. - virtual std::string fetchMetrics() = 0; + virtual std::string fetchMetrics() { + return ""; + } static bool registered; }; @@ -165,6 +227,30 @@ class DummyStatsReporter : public BaseStatsReporter { int64_t /* max */, const std::vector& /* pcts */) const override {} + void registerQuantileMetricExportType( + const char* /* key */, + const std::vector& /* statTypes */, + const std::vector& /* pcts */, + const std::vector& /* slidingWindowsSeconds */) const override {} + + void registerQuantileMetricExportType( + folly::StringPiece /* key */, + const std::vector& /* statTypes */, + const std::vector& /* pcts */, + const std::vector& /* slidingWindowsSeconds */) const override {} + + void registerDynamicQuantileMetricExportType( + const char* /* keyPattern */, + const std::vector& /* statTypes */, + const std::vector& /* pcts */, + const std::vector& /* slidingWindowsSeconds */) const override {} + + void registerDynamicQuantileMetricExportType( + folly::StringPiece /* keyPattern */, + const std::vector& /* statTypes */, + const std::vector& /* pcts */, + const std::vector& /* slidingWindowsSeconds */) const override {} + void addMetricValue(const std::string& /* key */, size_t /* value */) const override {} @@ -183,11 +269,82 @@ class DummyStatsReporter : public BaseStatsReporter { void addHistogramMetricValue(folly::StringPiece /* key */, size_t /* value */) const override {} + void addQuantileMetricValue(const std::string& /* key */, size_t /* value */) + const override {} + + void addQuantileMetricValue(const char* /* key */, size_t /* value */) + const override {} + + void addQuantileMetricValue(folly::StringPiece /* key */, size_t /* value */) + const override {} + + void addDynamicQuantileMetricValue( + const std::string& /* key */, + folly::Range /* subkeys */, + size_t /* value */) const override {} + + void addDynamicQuantileMetricValue( + const char* /* key */, + folly::Range /* subkeys */, + size_t /* value */) const override {} + + void addDynamicQuantileMetricValue( + folly::StringPiece /* key */, + folly::Range /* subkeys */, + size_t /* value */) const override {} + std::string fetchMetrics() override { return ""; } }; +/// Helper functions to create vectors from variadic arguments, reducing +/// boilerplate in quantile stat definitions. + +/// Create a vector of StatTypes from variadic arguments. +/// Usage: statTypes(StatType::AVG, StatType::COUNT, StatType::SUM) +template +std::vector statTypes(Args... args) { + return std::vector{args...}; +} + +/// Create a vector of percentiles from variadic arguments. +/// Usage: percentiles(0.5, 0.95, 0.99) +template +std::vector percentiles(Args... args) { + return std::vector{static_cast(args)...}; +} + +/// Create a vector of sliding window periods in seconds from variadic +/// arguments. Usage: slidingWindowsSeconds(60, 600, 3600) +template +std::vector slidingWindowsSeconds(Args... args) { + return std::vector{static_cast(args)...}; +} + +/// Helper class that stores subkeys in a member array and converts to +/// folly::Range. This is a temporary object that lives just for the duration of +/// the macro call. +template +class subkeys { + std::array pieces_; + + public: + template + subkeys(Args&&... args) + : pieces_{folly::StringPiece(std::forward(args))...} {} + + /// Conversion operator to folly::Range + operator folly::Range() const { + return folly::Range( + pieces_.data(), pieces_.size()); + } +}; + +/// Template deduction guide for subkeys class +template +subkeys(Args&&...) -> subkeys; + #define DEFINE_METRIC(key, type) \ { \ if (::facebook::velox::BaseStatsReporter::registered) { \ @@ -236,4 +393,52 @@ class DummyStatsReporter : public BaseStatsReporter { } \ } \ } + +#define DEFINE_QUANTILE_STAT(key, statTypes, percentiles, slidingWindows) \ + { \ + if (::facebook::velox::BaseStatsReporter::registered) { \ + auto reporter = folly::Singleton< \ + facebook::velox::BaseStatsReporter>::try_get_fast(); \ + if (FOLLY_LIKELY(reporter != nullptr)) { \ + reporter->registerQuantileMetricExportType( \ + (key), (statTypes), (percentiles), (slidingWindows)); \ + } \ + } \ + } + +#define RECORD_QUANTILE_STAT_VALUE(key, ...) \ + { \ + if (::facebook::velox::BaseStatsReporter::registered) { \ + auto reporter = folly::Singleton< \ + facebook::velox::BaseStatsReporter>::try_get_fast(); \ + if (FOLLY_LIKELY(reporter != nullptr)) { \ + reporter->addQuantileMetricValue((key), ##__VA_ARGS__); \ + } \ + } \ + } + +#define DEFINE_DYNAMIC_QUANTILE_STAT( \ + keyPattern, statTypes, percentiles, slidingWindows) \ + { \ + if (::facebook::velox::BaseStatsReporter::registered) { \ + auto reporter = folly::Singleton< \ + facebook::velox::BaseStatsReporter>::try_get_fast(); \ + if (FOLLY_LIKELY(reporter != nullptr)) { \ + reporter->registerDynamicQuantileMetricExportType( \ + (keyPattern), (statTypes), (percentiles), (slidingWindows)); \ + } \ + } \ + } + +#define RECORD_DYNAMIC_QUANTILE_STAT_VALUE(keyPattern, subkeys, ...) \ + { \ + if (::facebook::velox::BaseStatsReporter::registered) { \ + auto reporter = folly::Singleton< \ + facebook::velox::BaseStatsReporter>::try_get_fast(); \ + if (FOLLY_LIKELY(reporter != nullptr)) { \ + reporter->addDynamicQuantileMetricValue( \ + (keyPattern), (subkeys), ##__VA_ARGS__); \ + } \ + } \ + } } // namespace facebook::velox diff --git a/velox/common/base/Status.cpp b/velox/common/base/Status.cpp index a170d3d9053f..46a5f603846a 100644 --- a/velox/common/base/Status.cpp +++ b/velox/common/base/Status.cpp @@ -115,4 +115,19 @@ void Status::warn(const std::string_view& message) const { LOG(WARNING) << message << ": " << toString(); } +std::string internal::generateError(std::string message, std::string exprStr) { + std::string elaborateMessage; + if (!message.empty()) { + elaborateMessage += "Reason: "; + elaborateMessage += message; + elaborateMessage += '\n'; + } + if (!exprStr.empty()) { + elaborateMessage += "Expression: "; + elaborateMessage += exprStr; + elaborateMessage += '\n'; + } + return elaborateMessage; +} + } // namespace facebook::velox diff --git a/velox/common/base/Status.h b/velox/common/base/Status.h index f5fd816d9469..1cf42673bc82 100644 --- a/velox/common/base/Status.h +++ b/velox/common/base/Status.h @@ -26,6 +26,8 @@ #include #include +#include "velox/common/base/ExceptionHelper.h" + namespace facebook::velox { /// The Status object is an object holding the outcome of an operation (success @@ -213,9 +215,6 @@ class [[nodiscard]] Status { inline Status& operator=(Status&& s) noexcept; inline bool operator==(const Status& other) const noexcept; - inline bool operator!=(const Status& other) const noexcept { - return !(*this == other); - } // AND the statuses. inline Status operator&(const Status& s) const noexcept; @@ -528,6 +527,77 @@ void Status::moveFrom(Status& s) { VELOX_RETURN_IF(!__s.ok(), __s); \ } while (false) +#define _VELOX_RETURN_IMPL(expr, exprStr, error, ...) \ + do { \ + if (FOLLY_UNLIKELY(expr)) { \ + if (::facebook::velox::threadSkipErrorDetails()) { \ + return error(); \ + } \ + auto message = ::facebook::velox::errorMessage(__VA_ARGS__); \ + return error( \ + ::facebook::velox::internal::generateError(message, exprStr)); \ + } \ + } while (0) + +/// If the caller passes a custom message (4 *or more* arguments), we +/// have to construct a format string from ours ("({} vs. {})") plus +/// theirs by adding a space and shuffling arguments. If they don't (exactly 3 +/// arguments), we can just pass our own format string and arguments straight +/// through. +#define _VELOX_RETURN_OP_WITH_USER_FMT_HELPER( \ + implmacro, expr1, expr2, op, user_fmt, ...) \ + implmacro( \ + (expr1)op(expr2), \ + #expr1 " " #op " " #expr2, \ + "({} vs. {}) " user_fmt, \ + expr1, \ + expr2, \ + ##__VA_ARGS__) + +#define _VELOX_RETURN_OP_HELPER(implmacro, expr1, expr2, op, ...) \ + do { \ + if constexpr (FOLLY_PP_DETAIL_NARGS(__VA_ARGS__) > 0) { \ + _VELOX_RETURN_OP_WITH_USER_FMT_HELPER( \ + implmacro, expr1, expr2, op, __VA_ARGS__); \ + } else { \ + implmacro( \ + (expr1)op(expr2), \ + #expr1 " " #op " " #expr2, \ + "({} vs. {})", \ + expr1, \ + expr2); \ + } \ + } while (0) + +#define _VELOX_USER_RETURN_IMPL(expr, exprStr, ...) \ + _VELOX_RETURN_IMPL( \ + expr, exprStr, ::facebook::velox::Status::UserError, ##__VA_ARGS__) + +#define _VELOX_USER_RETURN_OP(expr1, expr2, op, ...) \ + _VELOX_RETURN_OP_HELPER( \ + _VELOX_USER_RETURN_IMPL, expr1, expr2, op, ##__VA_ARGS__) + +// For all below macros, an additional message can be passed using a +// format string and arguments, as with `fmt::format`. +#define VELOX_USER_RETURN(expr, ...) \ + _VELOX_USER_RETURN_IMPL(expr, #expr, ##__VA_ARGS__) +#define VELOX_USER_RETURN_GT(e1, e2, ...) \ + _VELOX_USER_RETURN_OP(e1, e2, >, ##__VA_ARGS__) +#define VELOX_USER_RETURN_GE(e1, e2, ...) \ + _VELOX_USER_RETURN_OP(e1, e2, >=, ##__VA_ARGS__) +#define VELOX_USER_RETURN_LT(e1, e2, ...) \ + _VELOX_USER_RETURN_OP(e1, e2, <, ##__VA_ARGS__) +#define VELOX_USER_RETURN_LE(e1, e2, ...) \ + _VELOX_USER_RETURN_OP(e1, e2, <=, ##__VA_ARGS__) +#define VELOX_USER_RETURN_EQ(e1, e2, ...) \ + _VELOX_USER_RETURN_OP(e1, e2, ==, ##__VA_ARGS__) +#define VELOX_USER_RETURN_NE(e1, e2, ...) \ + _VELOX_USER_RETURN_OP(e1, e2, !=, ##__VA_ARGS__) +#define VELOX_USER_RETURN_NULL(e, ...) \ + VELOX_USER_RETURN(e == nullptr, ##__VA_ARGS__) +#define VELOX_USER_RETURN_NOT_NULL(e, ...) \ + VELOX_USER_RETURN(e != nullptr, ##__VA_ARGS__) + namespace internal { /// Common API for extracting Status from either Status or Result (the latter @@ -540,6 +610,8 @@ inline Status genericToStatus(Status&& st) { return std::move(st); } +std::string generateError(std::string message, std::string exprStr); + } // namespace internal /// Holds a result or an error. Designed to be used by APIs that do not throw. diff --git a/velox/common/base/TraceConfig.cpp b/velox/common/base/TraceConfig.cpp index d88a529ac7cd..5007be9a48c7 100644 --- a/velox/common/base/TraceConfig.cpp +++ b/velox/common/base/TraceConfig.cpp @@ -22,14 +22,16 @@ namespace facebook::velox { TraceConfig::TraceConfig( - std::unordered_set _queryNodeIds, + std::string _queryNodeId, std::string _queryTraceDir, UpdateAndCheckTraceLimitCB _updateAndCheckTraceLimitCB, - std::string _taskRegExp) - : queryNodes(std::move(_queryNodeIds)), + std::string _taskRegExp, + bool _dryRun) + : queryNodeId(std::move(_queryNodeId)), queryTraceDir(std::move(_queryTraceDir)), updateAndCheckTraceLimitCB(std::move(_updateAndCheckTraceLimitCB)), - taskRegExp(std::move(_taskRegExp)) { - VELOX_CHECK(!queryNodes.empty(), "Query trace nodes cannot be empty"); + taskRegExp(std::move(_taskRegExp)), + dryRun(_dryRun) { + VELOX_CHECK(!queryNodeId.empty(), "The query trace node cannot be empty"); } } // namespace facebook::velox diff --git a/velox/common/base/TraceConfig.h b/velox/common/base/TraceConfig.h index 231733facb2e..0eb84070f227 100644 --- a/velox/common/base/TraceConfig.h +++ b/velox/common/base/TraceConfig.h @@ -38,18 +38,23 @@ namespace facebook::velox { using UpdateAndCheckTraceLimitCB = std::function; struct TraceConfig { - /// Target query trace nodes. - std::unordered_set queryNodes; + /// Target query trace node id. + std::string queryNodeId; /// Base dir of query trace. std::string queryTraceDir; UpdateAndCheckTraceLimitCB updateAndCheckTraceLimitCB; /// The trace task regexp. std::string taskRegExp; + /// If true, we only collect operator input trace without the actual + /// execution. This is used by crash debugging so that we can collect the + /// input that triggers the crash. + bool dryRun{false}; TraceConfig( - std::unordered_set _queryNodeIds, - std::string _queryTraceDir, - UpdateAndCheckTraceLimitCB _updateAndCheckTraceLimitCB, - std::string _taskRegExp); + std::string queryNodeId, + std::string queryTraceDir, + UpdateAndCheckTraceLimitCB updateAndCheckTraceLimitCB, + std::string taskRegExp, + bool dryRun); }; } // namespace facebook::velox diff --git a/velox/exec/TreeOfLosers.h b/velox/common/base/TreeOfLosers.h similarity index 100% rename from velox/exec/TreeOfLosers.h rename to velox/common/base/TreeOfLosers.h diff --git a/velox/common/base/VeloxException.cpp b/velox/common/base/VeloxException.cpp index 735953d169a8..380659e60a5f 100644 --- a/velox/common/base/VeloxException.cpp +++ b/velox/common/base/VeloxException.cpp @@ -169,6 +169,18 @@ bool isStackTraceEnabled(VeloxException::Type type) { return last->compare_exchange_strong(latest, now, std::memory_order_relaxed); } +void stringAppendNumber(std::string& str, size_t number) { + // Manual implementation of itoa to avoid the cost of std::to_string. + const auto numberStartOffset = str.end() - + str.begin(); // Not `size()`. We need distance. The type is different. + size_t remaining = number; + do { + str += static_cast('0' + remaining % 10); + remaining /= 10; + } while (remaining); + reverse(str.begin() + numberStartOffset, str.end()); +} + } // namespace template @@ -248,13 +260,7 @@ void VeloxException::State::finalize() const { if (line) { elaborateMessage += "Line: "; - auto len = elaborateMessage.size(); - size_t t = line; - do { - elaborateMessage += static_cast('0' + t % 10); - t /= 10; - } while (t); - reverse(elaborateMessage.begin() + len, elaborateMessage.end()); + stringAppendNumber(elaborateMessage, line); elaborateMessage += '\n'; } diff --git a/velox/common/base/VeloxException.h b/velox/common/base/VeloxException.h index 758eb1b6e6d8..b9ffb813f4f5 100644 --- a/velox/common/base/VeloxException.h +++ b/velox/common/base/VeloxException.h @@ -53,6 +53,10 @@ inline constexpr auto kErrorSourceRuntime = "RUNTIME"_fs; /// Errors where the root cause of the problem is some unreliable aspect of the /// system are classified with source SYSTEM. inline constexpr auto kErrorSourceSystem = "SYSTEM"_fs; + +/// Errors where the root cause of the problem is some external dependency (e.g. +/// storage) +inline constexpr auto kErrorSourceExternal = "EXTERNAL"_fs; } // namespace error_source namespace error_code { @@ -438,7 +442,7 @@ ExceptionContext& getExceptionContext(); /// exception context with the previous context held by the thread_local /// variable to allow retrieving the top-level context when there is an /// exception context hierarchy. -class ExceptionContextSetter { +class [[nodiscard]] ExceptionContextSetter { public: explicit ExceptionContextSetter(ExceptionContext value) : prev_{getExceptionContext()} { diff --git a/velox/common/base/benchmarks/CMakeLists.txt b/velox/common/base/benchmarks/CMakeLists.txt index a64e83974b6e..b4ffaf2f60a7 100644 --- a/velox/common/base/benchmarks/CMakeLists.txt +++ b/velox/common/base/benchmarks/CMakeLists.txt @@ -16,27 +16,37 @@ add_executable(velox_common_base_benchmarks BitUtilBenchmark.cpp) target_link_libraries( velox_common_base_benchmarks PUBLIC Folly::follybenchmark - PRIVATE velox_common_base Folly::folly) + PRIVATE velox_common_base Folly::folly +) add_executable(velox_common_stringsearch_benchmarks StringSearchBenchmark.cpp) target_link_libraries( velox_common_stringsearch_benchmarks PUBLIC Folly::follybenchmark - PRIVATE velox_common_base Folly::folly) + PRIVATE velox_common_base Folly::folly +) -add_executable(velox_common_indexed_priority_queue_benchmark - IndexedPriorityQueueBenchmark.cpp) +add_executable(velox_common_indexed_priority_queue_benchmark IndexedPriorityQueueBenchmark.cpp) target_link_libraries( velox_common_indexed_priority_queue_benchmark PUBLIC Folly::follybenchmark - PRIVATE velox_common_base Folly::folly) + PRIVATE velox_common_base Folly::folly +) -add_executable(velox_common_sorting_network_benchmark - SortingNetworkBenchmark.cpp) +add_executable(velox_common_sorting_network_benchmark SortingNetworkBenchmark.cpp) target_link_libraries( velox_common_sorting_network_benchmark PUBLIC Folly::follybenchmark - PRIVATE velox_common_base Folly::folly) + PRIVATE velox_common_base Folly::folly +) + +add_executable(velox_common_split_block_bloom_filter_benchmark SplitBlockBloomFilterBenchmark.cpp) + +target_link_libraries( + velox_common_split_block_bloom_filter_benchmark + PUBLIC Folly::follybenchmark + PRIVATE velox_common_base Folly::folly +) diff --git a/velox/common/base/benchmarks/SplitBlockBloomFilterBenchmark.cpp b/velox/common/base/benchmarks/SplitBlockBloomFilterBenchmark.cpp new file mode 100644 index 000000000000..6706be84ab24 --- /dev/null +++ b/velox/common/base/benchmarks/SplitBlockBloomFilterBenchmark.cpp @@ -0,0 +1,155 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/common/base/SplitBlockBloomFilter.h" + +#include +#include +#include + +#define VELOX_BENCHMARK(_make, _name, ...) \ + [[maybe_unused]] auto _name = _make(FOLLY_PP_STRINGIZE(_name), __VA_ARGS__) + +namespace facebook::velox { +namespace { + +template +class SplitBlockBloomFilterBenchmark { + public: + SplitBlockBloomFilterBenchmark( + const char* name, + Hasher hasher, + double falsePositive, + int numInserts, + int numTests) + : hasher_(std::move(hasher)), + numTests_(numTests), + blocks_(SplitBlockBloomFilter::numBlocks(numInserts, falsePositive)), + filter_(blocks_) { + for (int i = 0; i < numInserts; ++i) { + filter_.insert(hasher_(generateValue())); + } + folly::addBenchmark(__FILE__, name, [this] { return run(); }); + } + + private: + static T generateValue() { + if constexpr (sizeof(T) == 8) { + return folly::Random::rand64(); + } else { + static_assert(sizeof(T) == 4); + return folly::Random::rand32(); + } + } + + unsigned run() const { + int numHits = 0; + for (int i = 0; i < numTests_; ++i) { + numHits += filter_.mayContain(hasher_(generateValue())); + } + folly::doNotOptimizeAway(numHits); + return numTests_; + } + + const Hasher hasher_; + const double numTests_; + std::vector blocks_; + SplitBlockBloomFilter filter_; +}; + +template +SplitBlockBloomFilterBenchmark makeBenchmark( + const char* name, + Hasher hasher, + double falsePositive, + int numInserts, + int numTests) { + return SplitBlockBloomFilterBenchmark( + name, std::move(hasher), falsePositive, numInserts, numTests); +} + +} // namespace +} // namespace facebook::velox + +int main(int argc, char* argv[]) { + using namespace facebook::velox; + folly::Init follyInit(&argc, &argv); + VELOX_BENCHMARK( + makeBenchmark, + int32, + folly::hasher(), + 0.01, + 5'000'000, + 10'000'000); + VELOX_BENCHMARK( + makeBenchmark, + int64, + folly::hasher(), + 0.01, + 5'000'000, + 10'000'000); + VELOX_BENCHMARK( + makeBenchmark, + int64_nohash, + folly::identity, + 0.01, + 5'000'000, + 10'000'000); + VELOX_BENCHMARK( + makeBenchmark, + int32_small, + folly::hasher(), + 0.01, + 500'000, + 1'000'000); + VELOX_BENCHMARK( + makeBenchmark, + int64_small, + folly::hasher(), + 0.01, + 500'000, + 1'000'000); + VELOX_BENCHMARK( + makeBenchmark, + int64_nohash_small, + folly::identity, + 0.01, + 500'000, + 1'000'000); + VELOX_BENCHMARK( + makeBenchmark, + int32_large, + folly::hasher(), + 0.01, + 50'000'000, + 100'000'000); + VELOX_BENCHMARK( + makeBenchmark, + int64_large, + folly::hasher(), + 0.01, + 50'000'000, + 100'000'000); + VELOX_BENCHMARK( + makeBenchmark, + int64_nohash_large, + folly::identity, + 0.01, + 50'000'000, + 100'000'000); + folly::runBenchmarks(); + return 0; +} diff --git a/velox/common/base/benchmarks/StringSearchBenchmark.cpp b/velox/common/base/benchmarks/StringSearchBenchmark.cpp index 5fee996709a3..cad51cb56fa7 100644 --- a/velox/common/base/benchmarks/StringSearchBenchmark.cpp +++ b/velox/common/base/benchmarks/StringSearchBenchmark.cpp @@ -105,11 +105,12 @@ class TestStringSearch { void runSearching(size_t iters) const { if constexpr (alg == SIMD) { FOR_EACH_RANGE (i, 0, iters) - doNotOptimizeAway(simd::simdStrstr( - heyStack_.data(), - heyStack_.size(), - needle_.data(), - needle_.size())); + doNotOptimizeAway( + simd::simdStrstr( + heyStack_.data(), + heyStack_.size(), + needle_.data(), + needle_.size())); } else if constexpr (alg == STD) { FOR_EACH_RANGE (i, 0, iters) doNotOptimizeAway( diff --git a/velox/common/base/tests/AsyncSourceTest.cpp b/velox/common/base/tests/AsyncSourceTest.cpp index 657a7ba8e08e..ce20becfeb4f 100644 --- a/velox/common/base/tests/AsyncSourceTest.cpp +++ b/velox/common/base/tests/AsyncSourceTest.cpp @@ -296,3 +296,134 @@ TEST(AsyncSourceTest, setContexts) { verifyContexts("test2", "task_id2"); } + +TEST(AsyncSourceTest, cancel) { + DataCounter::reset(); + + // Cancel before prepare() - task should not run and resources cleaned up + { + auto dataCounter = std::make_shared(); + auto asyncSource = std::make_shared>([dataCounter]() { + return std::make_unique(dataCounter->objectNumber()); + }); + dataCounter.reset(); + + asyncSource->cancel(); + EXPECT_EQ(DataCounter::numDeletedDataCounters(), 1); + } + DataCounter::reset(); + + // Cancel while task is running in prepare() - should not affect the + // prepare() result + { + folly::Baton<> startBaton; + folly::Baton<> finishBaton; + auto asyncSource = std::make_shared>( + [&startBaton, &finishBaton]() { + startBaton.post(); + finishBaton.wait(); + return std::make_unique(); + }); + + auto thread = std::thread([&asyncSource] { asyncSource->prepare(); }); + EXPECT_TRUE( + startBaton.try_wait_for(1s)); // Make sure prepare() gets lock first + + asyncSource->cancel(); // Should be no-op + + finishBaton.post(); + thread.join(); + + EXPECT_EQ(DataCounter::numCreatedDataCounters(), 1); + EXPECT_TRUE(asyncSource->hasValue()); + } + DataCounter::reset(); + + // Cancel after prepare() completes - should not destroy the result + { + auto asyncSource = std::make_shared>( + []() { return std::make_unique(); }); + asyncSource->prepare(); + + asyncSource->cancel(); // Should be no-op since make_ was taken + + EXPECT_TRUE(asyncSource->hasValue()); + EXPECT_NE(asyncSource->move(), nullptr); + } + DataCounter::reset(); + + // prepare() and move() are no-ops after cancel() + { + std::atomic taskExecuted{false}; + auto asyncSource = + std::make_shared>([&taskExecuted]() { + taskExecuted = true; + return std::make_unique(); + }); + + asyncSource->cancel(); + asyncSource->prepare(); // No-op + EXPECT_FALSE(taskExecuted); + EXPECT_FALSE(asyncSource->hasValue()); + + EXPECT_EQ(asyncSource->move(), nullptr); // No-op + EXPECT_FALSE(taskExecuted); + } + + // Multiple cancel calls are idempotent + { + auto dataCounter = std::make_shared(); + auto asyncSource = std::make_shared>([dataCounter]() { + return std::make_unique(dataCounter->objectNumber()); + }); + dataCounter.reset(); + + asyncSource->cancel(); + asyncSource->cancel(); // Should be safe + EXPECT_EQ(DataCounter::numDeletedDataCounters(), 1); + } + DataCounter::reset(); + + // Cancel called during move() execution - should be no-op + { + folly::Baton<> moveStarted; + folly::Baton<> continueMove; + auto asyncSource = std::make_shared>( + [&moveStarted, &continueMove]() { + moveStarted.post(); + continueMove.wait(); + return std::make_unique(); + }); + + // move() will execute the lambda inline since prepare() wasn't called + auto moveThread = std::thread([&asyncSource] { + auto result = asyncSource->move(); + EXPECT_NE(result, nullptr); + }); + + EXPECT_TRUE(moveStarted.try_wait_for(1s)); // Wait for move to start + + asyncSource->cancel(); // Should be no-op - make_ already taken + + continueMove.post(); + moveThread.join(); + EXPECT_EQ(DataCounter::numCreatedDataCounters(), 1); + } + DataCounter::reset(); + + // Cancel called after move() completes - should be no-op + { + auto asyncSource = std::make_shared>( + []() { return std::make_unique(); }); + + auto result = asyncSource->move(); // Complete move + EXPECT_NE(result, nullptr); + EXPECT_EQ(DataCounter::numCreatedDataCounters(), 1); + + asyncSource->cancel(); // Should be no-op - moved_ is true, make_ is null + + // Item was already consumed + EXPECT_EQ(DataCounter::numCreatedDataCounters(), 1); + } + DataCounter::reset(); +} diff --git a/velox/common/base/tests/BitUtilTest.cpp b/velox/common/base/tests/BitUtilTest.cpp index 5ee8e4ce5066..4450a5145ca1 100644 --- a/velox/common/base/tests/BitUtilTest.cpp +++ b/velox/common/base/tests/BitUtilTest.cpp @@ -541,6 +541,21 @@ TEST_F(BitUtilTest, getAndClearLastSetBit) { EXPECT_EQ(bits, 0); } +TEST_F(BitUtilTest, negateBit) { + char data[35]; + for (int32_t i = 0; i < 100; i++) { + setBit(data, i, true); + } + std::vector indices = {0, 1, 2, 3, 4, 5, 6, 7}; + for (auto i : indices) { + negateBit(data, i); + EXPECT_EQ(isBitSet(data, i), false); + } + for (int32_t i = 8; i < 100; i++) { + EXPECT_EQ(isBitSet(data, i), true); + } +} + TEST_F(BitUtilTest, negate) { char data[35]; for (int32_t i = 0; i < 100; i++) { diff --git a/velox/common/base/tests/BloomFilterTest.cpp b/velox/common/base/tests/BloomFilterTest.cpp index fdc15090dd1d..7b1833042ea2 100644 --- a/velox/common/base/tests/BloomFilterTest.cpp +++ b/velox/common/base/tests/BloomFilterTest.cpp @@ -63,6 +63,38 @@ TEST_F(BloomFilterTest, serialize) { EXPECT_EQ(bloom.serializedSize(), deserialized.serializedSize()); } +TEST_F(BloomFilterTest, staticMayContain) { + constexpr int32_t kSize = 1024; + std::string serializedBloom; + BloomFilter bloom; + bloom.reset(kSize); + for (auto i = 0; i < kSize; ++i) { + bloom.insert(folly::hasher()(i)); + } + serializedBloom.resize(bloom.serializedSize()); + bloom.serialize(serializedBloom.data()); + int32_t numFalsePositives = 0; + for (auto i = 0; i < kSize; ++i) { + EXPECT_TRUE( + BloomFilter<>::mayContain( + serializedBloom.data(), folly::hasher()(i))); + + const uint64_t smallValueHash = folly::hasher()(i + kSize); + const bool isFalsePositiveForSmallValue = + BloomFilter<>::mayContain(serializedBloom.data(), smallValueHash); + EXPECT_EQ(isFalsePositiveForSmallValue, bloom.mayContain(smallValueHash)); + numFalsePositives += isFalsePositiveForSmallValue; + + const uint64_t largeValueHash = + folly::hasher()((i + kSize) * 123451); + const bool isFalsePositiveForLargeValue = + BloomFilter<>::mayContain(serializedBloom.data(), largeValueHash); + EXPECT_EQ(isFalsePositiveForLargeValue, bloom.mayContain(largeValueHash)); + numFalsePositives += isFalsePositiveForLargeValue; + } + EXPECT_GT(2, 100 * numFalsePositives / kSize); +} + TEST_F(BloomFilterTest, merge) { constexpr int32_t kSize = 10; BloomFilter bloom; diff --git a/velox/common/base/tests/CMakeLists.txt b/velox/common/base/tests/CMakeLists.txt index 6f1a7ab4ef20..0a85dca57ce1 100644 --- a/velox/common/base/tests/CMakeLists.txt +++ b/velox/common/base/tests/CMakeLists.txt @@ -32,9 +32,11 @@ add_executable( SkewedPartitionBalancerTest.cpp SpillConfigTest.cpp SpillStatsTest.cpp + SplitBlockBloomFilterTest.cpp StatsReporterTest.cpp StatusTest.cpp - SuccinctPrinterTest.cpp) + SuccinctPrinterTest.cpp +) add_test(velox_base_test velox_base_test) @@ -55,7 +57,8 @@ target_link_libraries( gflags::gflags GTest::gtest GTest::gmock - GTest::gtest_main) + GTest::gtest_main +) add_executable(velox_id_map_test IdMapTest.cpp) @@ -70,14 +73,11 @@ target_link_libraries( glog::glog GTest::gtest GTest::gtest_main - pthread) + pthread +) add_executable(velox_memcpy_meter Memcpy.cpp) target_link_libraries( velox_memcpy_meter - PRIVATE - velox_common_base - velox_exception - velox_time - Folly::folly - gflags::gflags) + PRIVATE velox_common_base velox_exception velox_time Folly::folly gflags::gflags +) diff --git a/velox/common/base/tests/ConcurrentCounterTest.cpp b/velox/common/base/tests/ConcurrentCounterTest.cpp index c96fc1ac9576..07967c0f965d 100644 --- a/velox/common/base/tests/ConcurrentCounterTest.cpp +++ b/velox/common/base/tests/ConcurrentCounterTest.cpp @@ -18,6 +18,7 @@ #include #include +#include #include #include "velox/common/base/tests/GTestUtils.h" @@ -48,7 +49,7 @@ class ConcurrentCounterTest : public testing::TestWithParam { void setupCounter() { counter_ = std::make_unique>( - std::thread::hardware_concurrency()); + folly::hardware_concurrency()); } const bool useUpdateFn_{GetParam()}; @@ -72,8 +73,8 @@ TEST_P(ConcurrentCounterTest, multithread) { const int32_t numUpdatesPerThread = 5'000; std::vector numThreads; numThreads.push_back(1); - numThreads.push_back(std::thread::hardware_concurrency()); - numThreads.push_back(std::thread::hardware_concurrency() * 2); + numThreads.push_back(folly::hardware_concurrency()); + numThreads.push_back(folly::hardware_concurrency() * 2); for (int numThreads : numThreads) { SCOPED_TRACE(fmt::format("numThreads: {}", numThreads)); counter_->testingClear(); diff --git a/velox/common/base/tests/ExceptionTest.cpp b/velox/common/base/tests/ExceptionTest.cpp index 9386b8cb672e..ff00f02489d2 100644 --- a/velox/common/base/tests/ExceptionTest.cpp +++ b/velox/common/base/tests/ExceptionTest.cpp @@ -41,7 +41,7 @@ struct fmt::formatter { template auto format(const Counter& c, FormatContext& ctx) const { auto x = c.counter++; - return format_to(ctx.out(), "{}", x); + return fmt::format_to(ctx.out(), "{}", x); } }; @@ -65,7 +65,7 @@ void verifyVeloxException( std::function f, const std::string& messagePrefix) { verifyException(f, [&messagePrefix](const auto& e) { - EXPECT_TRUE(folly::StringPiece{e.what()}.startsWith(messagePrefix)) + EXPECT_TRUE(std::string_view{e.what()}.starts_with(messagePrefix)) << "\nException message prefix mismatch.\n\nExpected prefix: " << messagePrefix << "\n\nActual message: " << e.what(); }); @@ -106,11 +106,12 @@ void testExceptionTraceCollectionControl(bool userException, bool enabled) { false); } } catch (VeloxException& e) { - SCOPED_TRACE(fmt::format( - "enabled: {}, user flag: {}, sys flag: {}", - enabled, - FLAGS_velox_exception_user_stacktrace_enabled, - FLAGS_velox_exception_system_stacktrace_enabled)); + SCOPED_TRACE( + fmt::format( + "enabled: {}, user flag: {}, sys flag: {}", + enabled, + FLAGS_velox_exception_user_stacktrace_enabled, + FLAGS_velox_exception_system_stacktrace_enabled)); ASSERT_EQ(userException, e.exceptionType() == VeloxException::Type::kUser); ASSERT_EQ(enabled, e.stackTrace() != nullptr); } @@ -170,12 +171,13 @@ void testExceptionTraceCollectionRateControl( false); } } catch (VeloxException& e) { - SCOPED_TRACE(fmt::format( - "userException: {}, hasRateLimit: {}, user limit: {}ms, sys limit: {}ms", - userException, - hasRateLimit, - FLAGS_velox_exception_user_stacktrace_rate_limit_ms, - FLAGS_velox_exception_system_stacktrace_rate_limit_ms)); + SCOPED_TRACE( + fmt::format( + "userException: {}, hasRateLimit: {}, user limit: {}ms, sys limit: {}ms", + userException, + hasRateLimit, + FLAGS_velox_exception_user_stacktrace_rate_limit_ms, + FLAGS_velox_exception_system_stacktrace_rate_limit_ms)); ASSERT_EQ( userException, e.exceptionType() == VeloxException::Type::kUser); ASSERT_EQ(!hasRateLimit || ((iter % 2) == 0), e.stackTrace() != nullptr); @@ -1015,6 +1017,6 @@ TEST(ExceptionTest, exceptionMacroInlining) { try { VELOX_USER_FAIL(errorStr, "definitely"); } catch (const std::exception& e) { - ASSERT_TRUE(folly::StringPiece{e.what()}.startsWith("argument not found")); + ASSERT_TRUE(std::string_view{e.what()}.starts_with("argument not found")); } } diff --git a/velox/common/base/tests/GTestUtils.h b/velox/common/base/tests/GTestUtils.h index 1f733808740a..42ec23513c4b 100644 --- a/velox/common/base/tests/GTestUtils.h +++ b/velox/common/base/tests/GTestUtils.h @@ -61,13 +61,15 @@ facebook::velox::VeloxRuntimeError, _expression, _errorMessage) #define VELOX_ASSERT_ERROR_STATUS(_expression, _statusCode, _errorMessage) \ - const auto status = (_expression); \ - ASSERT_TRUE(status.code() == _statusCode) \ - << "Expected error code to be '" << toString(_statusCode) \ - << "', but received '" << toString(status.code()) << "'."; \ - ASSERT_TRUE(status.message().find(_errorMessage) != std::string::npos) \ - << "Expected error message to contain '" << (_errorMessage) \ - << "', but received '" << status.message() << "'." + { \ + const auto status = (_expression); \ + ASSERT_TRUE(status.code() == _statusCode) \ + << "Expected error code to be '" << toString(_statusCode) \ + << "', but received '" << toString(status.code()) << "'."; \ + ASSERT_TRUE(status.message().find(_errorMessage) != std::string::npos) \ + << "Expected error message to contain '" << (_errorMessage) \ + << "', but received '" << status.message() << "'."; \ + } #define VELOX_ASSERT_ERROR_CODE_IMPL( \ _type, _expression, _errorCode, _errorMessage) \ @@ -99,12 +101,46 @@ _errorCode, \ _errorMessage) +#define VELOX_EXPECT_EQ_TYPES(actual, expected) \ + { \ + auto _actualType = (actual); \ + auto _expectedType = (expected); \ + if (_expectedType != nullptr) { \ + ASSERT_TRUE(_actualType != nullptr) \ + << "Expected: " << _expectedType->toString() << ", got null"; \ + EXPECT_EQ(*_actualType, *_expectedType) \ + << "Expected: " << _expectedType->toString() << ", got " \ + << _actualType->toString(); \ + } else { \ + EXPECT_EQ(_actualType, nullptr) \ + << "Expected null, got " << _actualType->toString(); \ + } \ + } + +#define VELOX_ASSERT_EQ_TYPES(actual, expected) \ + { \ + auto _actualType = (actual); \ + auto _expectedType = (expected); \ + if (_expectedType != nullptr) { \ + ASSERT_TRUE(_actualType != nullptr) \ + << "Expected: " << _expectedType->toString() << ", got null"; \ + ASSERT_EQ(*_actualType, *_expectedType) \ + << "Expected: " << _expectedType->toString() << ", got " \ + << _actualType->toString(); \ + } else { \ + ASSERT_EQ(_actualType, nullptr) \ + << "Expected null, got " << _actualType->toString(); \ + } \ + } + #ifndef NDEBUG #define DEBUG_ONLY_TEST(test_fixture, test_name) TEST(test_fixture, test_name) #define DEBUG_ONLY_TEST_F(test_fixture, test_name) \ TEST_F(test_fixture, test_name) #define DEBUG_ONLY_TEST_P(test_fixture, test_name) \ TEST_P(test_fixture, test_name) +#define DEBUG_ONLY_CO_TEST_F(test_fixture, test_name) \ + CO_TEST_F(test_fixture, test_name) #else #define DEBUG_ONLY_TEST(test_fixture, test_name) \ TEST(test_fixture, DISABLED_##test_name) @@ -112,4 +148,6 @@ TEST_F(test_fixture, DISABLED_##test_name) #define DEBUG_ONLY_TEST_P(test_fixture, test_name) \ TEST_P(test_fixture, DISABLED_##test_name) +#define DEBUG_ONLY_CO_TEST_F(test_fixture, test_name) \ + CO_TEST_F(test_fixture, DISABLED_test_name) #endif diff --git a/velox/common/base/tests/Memcpy.cpp b/velox/common/base/tests/Memcpy.cpp index 2ab51f87815e..18c7150759be 100644 --- a/velox/common/base/tests/Memcpy.cpp +++ b/velox/common/base/tests/Memcpy.cpp @@ -74,7 +74,7 @@ int main(int argc, char** argv) { Semaphore sem(0); std::vector ops; ops.resize(FLAGS_threads); - volatile uint64_t totalSum = 0; + uint64_t totalSum = 0; uint64_t totalUsec = 0; for (auto repeat = 0; repeat < FLAGS_repeats; ++repeat) { // Read once through 'other' to clear cache effects. diff --git a/velox/common/base/tests/SimdUtilTest.cpp b/velox/common/base/tests/SimdUtilTest.cpp index 447cc55a6e2f..4b932d9a1b3e 100644 --- a/velox/common/base/tests/SimdUtilTest.cpp +++ b/velox/common/base/tests/SimdUtilTest.cpp @@ -126,6 +126,7 @@ class SimdUtilTest : public testing::Test { folly::Random::DefaultGenerator rng_; }; +#ifdef VELOX_ENABLE_LOAD_SIMD_VALUE_BUFFER TEST_F(SimdUtilTest, setAll) { auto bits = simd::setAll(true); auto words = reinterpret_cast(&bits); @@ -133,6 +134,7 @@ TEST_F(SimdUtilTest, setAll) { EXPECT_EQ(words[i], -1ll); } } +#endif TEST_F(SimdUtilTest, bitIndices) { testIndices(1); diff --git a/velox/common/base/tests/SkewedPartitionBalancerTest.cpp b/velox/common/base/tests/SkewedPartitionBalancerTest.cpp index 5edb57563c60..9df4a2be7d73 100644 --- a/velox/common/base/tests/SkewedPartitionBalancerTest.cpp +++ b/velox/common/base/tests/SkewedPartitionBalancerTest.cpp @@ -409,8 +409,9 @@ TEST_F(SkewedPartitionRebalancerTest, concurrentFuzz) { threads.emplace_back([&]() { std::mt19937 localRng{200}; for (int iteration = 0; iteration < 1'000; ++iteration) { - SCOPED_TRACE(fmt::format( - "taskCount {}, iteration {}", taskCount, iteration)); + SCOPED_TRACE( + fmt::format( + "taskCount {}, iteration {}", taskCount, iteration)); const uint64_t processedBytes = 1 + folly::Random::rand32(512, localRng); balancer->addProcessedBytes(processedBytes); diff --git a/velox/common/base/tests/SpillConfigTest.cpp b/velox/common/base/tests/SpillConfigTest.cpp index 9949486a4862..30251da5c9bb 100644 --- a/velox/common/base/tests/SpillConfigTest.cpp +++ b/velox/common/base/tests/SpillConfigTest.cpp @@ -48,6 +48,7 @@ TEST_P(SpillConfigTest, spillLevel) { 0, 0, "none", + 0, prefixSortConfig_); struct { uint8_t bitOffset; @@ -134,6 +135,7 @@ TEST_P(SpillConfigTest, spillLevelLimit) { 0, 0, "none", + 0, prefixSortConfig_); ASSERT_EQ( @@ -181,6 +183,7 @@ TEST_P(SpillConfigTest, spillableReservationPercentages) { 1'000'000, 0, "none", + 0, prefixSortConfig_); }; diff --git a/velox/common/base/tests/SpillStatsTest.cpp b/velox/common/base/tests/SpillStatsTest.cpp index a9630c2bc972..965640114825 100644 --- a/velox/common/base/tests/SpillStatsTest.cpp +++ b/velox/common/base/tests/SpillStatsTest.cpp @@ -62,19 +62,11 @@ TEST(SpillStatsTest, spillStats) { stats2.spillReads = 10; stats2.spillReadTimeNanos = 100; stats2.spillDeserializationTimeNanos = 100; - ASSERT_TRUE(stats1 < stats2); - ASSERT_TRUE(stats1 <= stats2); - ASSERT_FALSE(stats1 > stats2); - ASSERT_FALSE(stats1 >= stats2); ASSERT_TRUE(stats1 != stats2); ASSERT_FALSE(stats1 == stats2); ASSERT_TRUE(stats1 == stats1); ASSERT_FALSE(stats1 != stats1); - ASSERT_FALSE(stats1 > stats1); - ASSERT_TRUE(stats1 >= stats1); - ASSERT_FALSE(stats1 < stats1); - ASSERT_TRUE(stats1 <= stats1); SpillStats delta = stats2 - stats1; ASSERT_EQ(delta.spilledInputBytes, 0); @@ -114,10 +106,6 @@ TEST(SpillStatsTest, spillStats) { stats1.spilledInputBytes = 2060; stats1.spilledBytes = 1030; stats1.spillReadBytes = 4096; - VELOX_ASSERT_THROW(stats1 < stats2, ""); - VELOX_ASSERT_THROW(stats1 > stats2, ""); - VELOX_ASSERT_THROW(stats1 <= stats2, ""); - VELOX_ASSERT_THROW(stats1 >= stats2, ""); ASSERT_TRUE(stats1 != stats2); ASSERT_FALSE(stats1 == stats2); const SpillStats zeroStats; diff --git a/velox/common/base/tests/SplitBlockBloomFilterTest.cpp b/velox/common/base/tests/SplitBlockBloomFilterTest.cpp new file mode 100644 index 000000000000..2a98ea756348 --- /dev/null +++ b/velox/common/base/tests/SplitBlockBloomFilterTest.cpp @@ -0,0 +1,122 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/common/base/SplitBlockBloomFilter.h" +#include "velox/common/testutil/RandomSeed.h" + +#include +#include + +#include + +namespace facebook::velox::test { +namespace { + +template +SplitBlockBloomFilter makeFilter( + const folly::F14FastSet& values, + const Hasher& hasher, + std::vector& blocks) { + blocks.resize(SplitBlockBloomFilter::numBlocks(values.size(), 0.01)); + bzero(blocks.data(), blocks.size() * sizeof(SplitBlockBloomFilter::Block)); + SplitBlockBloomFilter filter(blocks); + for (auto& value : values) { + filter.insert(hasher(value)); + } + return filter; +} + +TEST(SplitBlockBloomFilterTest, numBlocks) { + ASSERT_EQ( + SplitBlockBloomFilter::numBlocks(50'000'000, 0.01) * + sizeof(SplitBlockBloomFilter::Block), + xsimd::batch::size == 8 ? 60509568 : 65766912); + ASSERT_EQ( + SplitBlockBloomFilter::numBlocks(45'523'964, 0.1) * + sizeof(SplitBlockBloomFilter::Block), + xsimd::batch::size == 8 ? 32848640 : 27546352); +} + +TEST(SplitBlockBloomFilterTest, contiguous) { + constexpr int kSize = 100'000; + std::default_random_engine gen(common::testutil::getRandomSeed(42)); + std::uniform_int_distribution<> dist(0, 9); + folly::F14FastSet values; + values.reserve(kSize / 10); + for (int i = 0; i < kSize; ++i) { + if (dist(gen) == 0) { + values.insert(i); + } + } + std::vector blocks; + auto test = [&](auto hasher) { + auto filter = makeFilter(values, hasher, blocks); + int numFalsePositive = 0; + for (int i = 0; i < kSize; ++i) { + if (values.contains(i)) { + ASSERT_TRUE(filter.mayContain(hasher(i))); + } else { + numFalsePositive += filter.mayContain(hasher(i)); + } + } + ASSERT_LT(1.0 * numFalsePositive / kSize, 0.03); + }; + { + SCOPED_TRACE("Folly"); + test(folly::hasher()); + } + { + SCOPED_TRACE("Multiplication"); + test([](auto x) { return x * 0xc6a4a7935bd1e995L; }); + } +} + +TEST(SplitBlockBloomFilterTest, random) { + constexpr int kSize = 100'000; + std::default_random_engine gen(common::testutil::getRandomSeed(42)); + std::uniform_int_distribution dist; + folly::F14FastSet values; + values.reserve(kSize); + for (int i = 0; i < kSize; ++i) { + values.insert(dist(gen)); + } + std::vector blocks; + auto test = [&](auto hasher) { + auto filter = makeFilter(values, hasher, blocks); + for (auto value : values) { + ASSERT_TRUE(filter.mayContain(hasher(value))); + } + int numFalsePositive = 0; + for (int i = 0; i < kSize; ++i) { + auto value = dist(gen); + if (!values.contains(value)) { + numFalsePositive += filter.mayContain(hasher(value)); + } + } + ASSERT_LT(1.0 * numFalsePositive / kSize, 0.03); + }; + { + SCOPED_TRACE("Folly"); + test(folly::hasher()); + } + { + SCOPED_TRACE("Multiplication"); + test([](auto x) { return x * 0xc6a4a7935bd1e995L; }); + } +} + +} // namespace +} // namespace facebook::velox::test diff --git a/velox/common/base/tests/StatsReporterTest.cpp b/velox/common/base/tests/StatsReporterTest.cpp index 1fed4b9eb948..d76827d32f97 100644 --- a/velox/common/base/tests/StatsReporterTest.cpp +++ b/velox/common/base/tests/StatsReporterTest.cpp @@ -15,7 +15,6 @@ */ #include "velox/common/base/StatsReporter.h" -#include #include #include #include @@ -23,114 +22,81 @@ #include "velox/common/base/Counters.h" #include "velox/common/base/PeriodicStatsReporter.h" #include "velox/common/base/tests/GTestUtils.h" +#include "velox/common/base/tests/StatsReporterUtils.h" #include "velox/common/caching/AsyncDataCache.h" #include "velox/common/caching/CacheTTLController.h" #include "velox/common/caching/SsdCache.h" #include "velox/common/memory/MmapAllocator.h" -namespace facebook::velox { +namespace facebook::velox::test { -class TestReporter : public BaseStatsReporter { - public: - mutable std::mutex m; - mutable std::map counterMap; - mutable std::unordered_map statTypeMap; - mutable std::unordered_map> - histogramPercentilesMap; - - void clear() { - std::lock_guard l(m); - counterMap.clear(); - statTypeMap.clear(); - histogramPercentilesMap.clear(); - } - - void registerMetricExportType(const char* key, StatType statType) - const override { - statTypeMap[key] = statType; - } - - void registerMetricExportType(folly::StringPiece key, StatType statType) - const override { - statTypeMap[key.str()] = statType; - } - - void registerHistogramMetricExportType( - const char* key, - int64_t /* bucketWidth */, - int64_t /* min */, - int64_t /* max */, - const std::vector& pcts) const override { - histogramPercentilesMap[key] = pcts; - } - - void registerHistogramMetricExportType( - folly::StringPiece key, - int64_t /* bucketWidth */, - int64_t /* min */, - int64_t /* max */, - const std::vector& pcts) const override { - histogramPercentilesMap[key.str()] = pcts; - } - - void addMetricValue(const std::string& key, const size_t value) - const override { - std::lock_guard l(m); - counterMap[key] += value; - } - - void addMetricValue(const char* key, const size_t value) const override { - std::lock_guard l(m); - counterMap[key] += value; - } - - void addMetricValue(folly::StringPiece key, size_t value) const override { - std::lock_guard l(m); - counterMap[key.str()] += value; - } - - void addHistogramMetricValue(const std::string& key, size_t value) - const override { - std::lock_guard l(m); - counterMap[key] = std::max(counterMap[key], value); - } - - void addHistogramMetricValue(const char* key, size_t value) const override { - std::lock_guard l(m); - counterMap[key] = std::max(counterMap[key], value); - } - - void addHistogramMetricValue(folly::StringPiece key, size_t value) - const override { - std::lock_guard l(m); - counterMap[key.str()] = std::max(counterMap[key.str()], value); - } - - std::string fetchMetrics() override { - std::stringstream ss; - ss << "["; - auto sep = ""; - for (const auto& [key, value] : counterMap) { - ss << sep << key << ":" << value; - sep = ","; - } - ss << "]"; - return ss.str(); - } +struct QuantileConfig { + std::vector statTypes; + std::vector percentiles; + std::vector slidingWindows; }; +inline QuantileConfig createStandardConfig() { + return {{StatType::AVG, StatType::COUNT}, {0.5, 0.95}, {60, 600}}; +} + class StatsReporterTest : public testing::Test { protected: void SetUp() override { + // Set the registered flag to true so macros will work + BaseStatsReporter::registered = true; reporter_ = std::dynamic_pointer_cast( folly::Singleton::try_get()); reporter_->clear(); } void TearDown() override { reporter_->clear(); + // Reset the registered flag + BaseStatsReporter::registered = false; } + public: std::shared_ptr reporter_; + + void verifyQuantileRegistration( + const std::string& key, + const QuantileConfig& config) { + EXPECT_TRUE(reporter_->quantileStatTypesMap.count(key)); + EXPECT_EQ(config.statTypes, reporter_->quantileStatTypesMap[key]); + EXPECT_EQ(config.percentiles, reporter_->quantilePercentilesMap[key]); + EXPECT_EQ(config.slidingWindows, reporter_->quantileSlidingWindowsMap[key]); + } + + void verifyDynamicQuantileRegistration( + const std::string& pattern, + const QuantileConfig& config) { + EXPECT_TRUE(reporter_->dynamicQuantileStatTypesMap.count(pattern)); + EXPECT_EQ( + config.statTypes, reporter_->dynamicQuantileStatTypesMap[pattern]); + EXPECT_EQ( + config.percentiles, reporter_->dynamicQuantilePercentilesMap[pattern]); + EXPECT_EQ( + config.slidingWindows, + reporter_->dynamicQuantileSlidingWindowsMap[pattern]); + } + + template + void registerAndVerifyQuantile( + KeyType key, + const QuantileConfig& config = createStandardConfig()) { + reporter_->registerQuantileMetricExportType( + key, config.statTypes, config.percentiles, config.slidingWindows); + verifyQuantileRegistration(std::string(key), config); + } + + template + void registerAndVerifyDynamicQuantile( + KeyType pattern, + const QuantileConfig& config = createStandardConfig()) { + reporter_->registerDynamicQuantileMetricExportType( + pattern, config.statTypes, config.percentiles, config.slidingWindows); + verifyDynamicQuantileRegistration(std::string(pattern), config); + } }; TEST_F(StatsReporterTest, trivialReporter) { @@ -193,7 +159,7 @@ class TestStatsReportMmapAllocator : public memory::MmapAllocator { return numMallocBytes_; } - memory::MachinePageCount numExternalMapped() const { + memory::MachinePageCount numExternalMapped() const override { return numExternalMapped_; } @@ -443,69 +409,109 @@ TEST_F(PeriodicStatsReporterTest, basic) { const auto& counterMap = reporter_->counterMap; { std::lock_guard l(reporter_->m); - ASSERT_EQ(counterMap.count(kMetricArbitratorFreeCapacityBytes.str()), 1); - ASSERT_EQ( - counterMap.count(kMetricArbitratorFreeReservedCapacityBytes.str()), 1); - ASSERT_EQ(counterMap.count(kMetricMemoryCacheNumEntries.str()), 1); - ASSERT_EQ(counterMap.count(kMetricMemoryCacheNumEmptyEntries.str()), 1); - ASSERT_EQ(counterMap.count(kMetricMemoryCacheNumSharedEntries.str()), 1); - ASSERT_EQ(counterMap.count(kMetricMemoryCacheNumExclusiveEntries.str()), 1); - ASSERT_EQ( - counterMap.count(kMetricMemoryCacheNumPrefetchedEntries.str()), 1); - ASSERT_EQ(counterMap.count(kMetricMemoryCacheTotalTinyBytes.str()), 1); - ASSERT_EQ(counterMap.count(kMetricMemoryCacheTotalLargeBytes.str()), 1); - ASSERT_EQ( - counterMap.count(kMetricMemoryCacheTotalTinyPaddingBytes.str()), 1); - ASSERT_EQ( - counterMap.count(kMetricMemoryCacheTotalLargePaddingBytes.str()), 1); - ASSERT_EQ(counterMap.count(kMetricMemoryCacheTotalPrefetchBytes.str()), 1); - ASSERT_EQ(counterMap.count(kMetricSsdCacheCachedEntries.str()), 1); - ASSERT_EQ(counterMap.count(kMetricSsdCacheCachedRegions.str()), 1); - ASSERT_EQ(counterMap.count(kMetricSsdCacheCachedBytes.str()), 1); - ASSERT_EQ(counterMap.count(kMetricCacheMaxAgeSecs.str()), 1); - ASSERT_EQ(counterMap.count(kMetricMemoryAllocatorMappedBytes.str()), 1); - ASSERT_EQ(counterMap.count(kMetricMemoryAllocatorAllocatedBytes.str()), 1); - ASSERT_EQ( - counterMap.count(kMetricMmapAllocatorDelegatedAllocatedBytes.str()), 1); - ASSERT_EQ( - counterMap.count(kMetricMmapAllocatorExternalMappedBytes.str()), 1); - ASSERT_EQ(counterMap.count(kMetricSpillMemoryBytes.str()), 1); - ASSERT_EQ(counterMap.count(kMetricSpillPeakMemoryBytes.str()), 1); - ASSERT_EQ(counterMap.count(kMetricMemoryAllocatorTotalUsedBytes.str()), 1); + ASSERT_EQ( + counterMap.count(std::string(kMetricArbitratorFreeCapacityBytes)), 1); + ASSERT_EQ( + counterMap.count( + std::string(kMetricArbitratorFreeReservedCapacityBytes)), + 1); + ASSERT_EQ( + counterMap.count(std::string(kMetricMemoryCacheNumTinyEntries)), 1); + ASSERT_EQ( + counterMap.count(std::string(kMetricMemoryCacheNumLargeEntries)), 1); + ASSERT_EQ( + counterMap.count(std::string(kMetricMemoryCacheNumEmptyEntries)), 1); + ASSERT_EQ( + counterMap.count(std::string(kMetricMemoryCacheNumSharedEntries)), 1); + ASSERT_EQ( + counterMap.count(std::string(kMetricMemoryCacheNumExclusiveEntries)), + 1); + ASSERT_EQ( + counterMap.count(std::string(kMetricMemoryCacheNumPrefetchedEntries)), + 1); + ASSERT_EQ( + counterMap.count(std::string(kMetricMemoryCacheTotalTinyBytes)), 1); + ASSERT_EQ( + counterMap.count(std::string(kMetricMemoryCacheTotalLargeBytes)), 1); + ASSERT_EQ( + counterMap.count(std::string(kMetricMemoryCacheTotalTinyPaddingBytes)), + 1); + ASSERT_EQ( + counterMap.count(std::string(kMetricMemoryCacheTotalLargePaddingBytes)), + 1); + ASSERT_EQ( + counterMap.count(std::string(kMetricMemoryCacheTotalPrefetchBytes)), 1); + ASSERT_EQ(counterMap.count(std::string(kMetricSsdCacheCachedEntries)), 1); + ASSERT_EQ(counterMap.count(std::string(kMetricSsdCacheCachedRegions)), 1); + ASSERT_EQ(counterMap.count(std::string(kMetricSsdCacheCachedBytes)), 1); + ASSERT_EQ(counterMap.count(std::string(kMetricCacheMaxAgeSecs)), 1); + ASSERT_EQ( + counterMap.count(std::string(kMetricMemoryAllocatorMappedBytes)), 1); + ASSERT_EQ( + counterMap.count(std::string(kMetricMemoryAllocatorAllocatedBytes)), 1); + ASSERT_EQ( + counterMap.count( + std::string(kMetricMmapAllocatorDelegatedAllocatedBytes)), + 1); + ASSERT_EQ( + counterMap.count( + std::string(kMetricMemoryAllocatorExternalMappedBytes)), + 1); + ASSERT_EQ(counterMap.count(std::string(kMetricSpillMemoryBytes)), 1); + ASSERT_EQ(counterMap.count(std::string(kMetricSpillPeakMemoryBytes)), 1); + ASSERT_EQ( + counterMap.count(std::string(kMetricMemoryAllocatorTotalUsedBytes)), 1); // Check deltas are not reported - ASSERT_EQ(counterMap.count(kMetricMemoryCacheNumHits.str()), 0); - ASSERT_EQ(counterMap.count(kMetricMemoryCacheHitBytes.str()), 0); - ASSERT_EQ(counterMap.count(kMetricMemoryCacheNumNew.str()), 0); - ASSERT_EQ(counterMap.count(kMetricMemoryCacheNumEvicts.str()), 0); - ASSERT_EQ(counterMap.count(kMetricMemoryCacheNumSavableEvicts.str()), 0); - ASSERT_EQ(counterMap.count(kMetricMemoryCacheNumEvictChecks.str()), 0); - ASSERT_EQ(counterMap.count(kMetricMemoryCacheNumWaitExclusive.str()), 0); - ASSERT_EQ(counterMap.count(kMetricMemoryCacheNumAllocClocks.str()), 0); - ASSERT_EQ(counterMap.count(kMetricMemoryCacheNumAgedOutEntries.str()), 0); - ASSERT_EQ(counterMap.count(kMetricMemoryCacheSumEvictScore.str()), 0); - ASSERT_EQ(counterMap.count(kMetricSsdCacheReadEntries.str()), 0); - ASSERT_EQ(counterMap.count(kMetricSsdCacheReadBytes.str()), 0); - ASSERT_EQ(counterMap.count(kMetricSsdCacheWrittenEntries.str()), 0); - ASSERT_EQ(counterMap.count(kMetricSsdCacheWrittenBytes.str()), 0); - ASSERT_EQ(counterMap.count(kMetricSsdCacheOpenSsdErrors.str()), 0); - ASSERT_EQ(counterMap.count(kMetricSsdCacheOpenCheckpointErrors.str()), 0); - ASSERT_EQ(counterMap.count(kMetricSsdCacheOpenLogErrors.str()), 0); - ASSERT_EQ(counterMap.count(kMetricSsdCacheMetaFileDeleteErrors.str()), 0); - ASSERT_EQ(counterMap.count(kMetricSsdCacheGrowFileErrors.str()), 0); - ASSERT_EQ(counterMap.count(kMetricSsdCacheWriteSsdErrors.str()), 0); - ASSERT_EQ(counterMap.count(kMetricSsdCacheWriteSsdDropped.str()), 0); - ASSERT_EQ(counterMap.count(kMetricSsdCacheWriteCheckpointErrors.str()), 0); - ASSERT_EQ(counterMap.count(kMetricSsdCacheReadSsdErrors.str()), 0); - ASSERT_EQ(counterMap.count(kMetricSsdCacheReadCorruptions.str()), 0); - ASSERT_EQ(counterMap.count(kMetricSsdCacheReadCheckpointErrors.str()), 0); - ASSERT_EQ(counterMap.count(kMetricSsdCacheCheckpointsRead.str()), 0); - ASSERT_EQ(counterMap.count(kMetricSsdCacheCheckpointsWritten.str()), 0); - ASSERT_EQ(counterMap.count(kMetricSsdCacheRegionsEvicted.str()), 0); - ASSERT_EQ(counterMap.count(kMetricSsdCacheAgedOutEntries.str()), 0); - ASSERT_EQ(counterMap.count(kMetricSsdCacheAgedOutRegions.str()), 0); - ASSERT_EQ(counterMap.count(kMetricSsdCacheRecoveredEntries.str()), 0); - ASSERT_EQ(counterMap.count(kMetricSsdCacheReadWithoutChecksum.str()), 0); - ASSERT_EQ(counterMap.size(), 23); + ASSERT_EQ(counterMap.count(std::string(kMetricMemoryCacheNumHits)), 0); + ASSERT_EQ(counterMap.count(std::string(kMetricMemoryCacheHitBytes)), 0); + ASSERT_EQ(counterMap.count(std::string(kMetricMemoryCacheNumNew)), 0); + ASSERT_EQ(counterMap.count(std::string(kMetricMemoryCacheNumEvicts)), 0); + ASSERT_EQ( + counterMap.count(std::string(kMetricMemoryCacheNumSavableEvicts)), 0); + ASSERT_EQ( + counterMap.count(std::string(kMetricMemoryCacheNumEvictChecks)), 0); + ASSERT_EQ( + counterMap.count(std::string(kMetricMemoryCacheNumWaitExclusive)), 0); + ASSERT_EQ( + counterMap.count(std::string(kMetricMemoryCacheNumAllocClocks)), 0); + ASSERT_EQ( + counterMap.count(std::string(kMetricMemoryCacheNumAgedOutEntries)), 0); + ASSERT_EQ( + counterMap.count(std::string(kMetricMemoryCacheSumEvictScore)), 0); + ASSERT_EQ(counterMap.count(std::string(kMetricSsdCacheReadEntries)), 0); + ASSERT_EQ(counterMap.count(std::string(kMetricSsdCacheReadBytes)), 0); + ASSERT_EQ(counterMap.count(std::string(kMetricSsdCacheWrittenEntries)), 0); + ASSERT_EQ(counterMap.count(std::string(kMetricSsdCacheWrittenBytes)), 0); + ASSERT_EQ(counterMap.count(std::string(kMetricSsdCacheOpenSsdErrors)), 0); + ASSERT_EQ( + counterMap.count(std::string(kMetricSsdCacheOpenCheckpointErrors)), 0); + ASSERT_EQ(counterMap.count(std::string(kMetricSsdCacheOpenLogErrors)), 0); + ASSERT_EQ( + counterMap.count(std::string(kMetricSsdCacheMetaFileDeleteErrors)), 0); + ASSERT_EQ(counterMap.count(std::string(kMetricSsdCacheGrowFileErrors)), 0); + ASSERT_EQ(counterMap.count(std::string(kMetricSsdCacheWriteSsdErrors)), 0); + ASSERT_EQ( + counterMap.count(std::string(kMetricSsdCacheWriteNoSpaceErrors)), 0); + ASSERT_EQ(counterMap.count(std::string(kMetricSsdCacheWriteSsdDropped)), 0); + ASSERT_EQ( + counterMap.count(std::string(kMetricSsdCacheWriteExceedEntryLimit)), 0); + ASSERT_EQ( + counterMap.count(std::string(kMetricSsdCacheWriteCheckpointErrors)), 0); + ASSERT_EQ(counterMap.count(std::string(kMetricSsdCacheReadSsdErrors)), 0); + ASSERT_EQ(counterMap.count(std::string(kMetricSsdCacheReadCorruptions)), 0); + ASSERT_EQ( + counterMap.count(std::string(kMetricSsdCacheReadCheckpointErrors)), 0); + ASSERT_EQ(counterMap.count(std::string(kMetricSsdCacheCheckpointsRead)), 0); + ASSERT_EQ( + counterMap.count(std::string(kMetricSsdCacheCheckpointsWritten)), 0); + ASSERT_EQ(counterMap.count(std::string(kMetricSsdCacheRegionsEvicted)), 0); + ASSERT_EQ(counterMap.count(std::string(kMetricSsdCacheAgedOutEntries)), 0); + ASSERT_EQ(counterMap.count(std::string(kMetricSsdCacheAgedOutRegions)), 0); + ASSERT_EQ( + counterMap.count(std::string(kMetricSsdCacheRecoveredEntries)), 0); + ASSERT_EQ( + counterMap.count(std::string(kMetricSsdCacheReadWithoutChecksum)), 0); + ASSERT_EQ(counterMap.size(), 24); } // Update stats @@ -526,7 +532,9 @@ TEST_F(PeriodicStatsReporterTest, basic) { newSsdStats->deleteMetaFileErrors = 10; newSsdStats->growFileErrors = 10; newSsdStats->writeSsdErrors = 10; + newSsdStats->writeSsdNoSpaceErrors = 1; newSsdStats->writeSsdDropped = 10; + newSsdStats->writeSsdExceedEntryLimit = 10; newSsdStats->writeCheckpointErrors = 10; newSsdStats->readSsdErrors = 10; newSsdStats->readSsdCorruptions = 10; @@ -545,8 +553,9 @@ TEST_F(PeriodicStatsReporterTest, basic) { .allocClocks = 10, .sumEvictScore = 10, .ssdStats = newSsdStats}); - arbitrator.updateStats(memory::MemoryArbitrator::Stats( - 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10)); + arbitrator.updateStats( + memory::MemoryArbitrator::Stats( + 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10)); std::this_thread::sleep_for(std::chrono::milliseconds(4'000)); // Stop right after sufficient wait to ensure the following reads from main @@ -556,39 +565,56 @@ TEST_F(PeriodicStatsReporterTest, basic) { // Check delta stats are reported { std::lock_guard l(reporter_->m); - ASSERT_EQ(counterMap.count(kMetricMemoryCacheNumHits.str()), 1); - ASSERT_EQ(counterMap.count(kMetricMemoryCacheHitBytes.str()), 1); - ASSERT_EQ(counterMap.count(kMetricMemoryCacheNumNew.str()), 1); - ASSERT_EQ(counterMap.count(kMetricMemoryCacheNumEvicts.str()), 1); - ASSERT_EQ(counterMap.count(kMetricMemoryCacheNumSavableEvicts.str()), 1); - ASSERT_EQ(counterMap.count(kMetricMemoryCacheNumEvictChecks.str()), 1); - ASSERT_EQ(counterMap.count(kMetricMemoryCacheNumWaitExclusive.str()), 1); - ASSERT_EQ(counterMap.count(kMetricMemoryCacheNumAllocClocks.str()), 1); - ASSERT_EQ(counterMap.count(kMetricMemoryCacheNumAgedOutEntries.str()), 1); - ASSERT_EQ(counterMap.count(kMetricMemoryCacheSumEvictScore.str()), 1); - ASSERT_EQ(counterMap.count(kMetricSsdCacheReadEntries.str()), 1); - ASSERT_EQ(counterMap.count(kMetricSsdCacheReadBytes.str()), 1); - ASSERT_EQ(counterMap.count(kMetricSsdCacheWrittenEntries.str()), 1); - ASSERT_EQ(counterMap.count(kMetricSsdCacheWrittenBytes.str()), 1); - ASSERT_EQ(counterMap.count(kMetricSsdCacheOpenSsdErrors.str()), 1); - ASSERT_EQ(counterMap.count(kMetricSsdCacheOpenCheckpointErrors.str()), 1); - ASSERT_EQ(counterMap.count(kMetricSsdCacheOpenLogErrors.str()), 1); - ASSERT_EQ(counterMap.count(kMetricSsdCacheMetaFileDeleteErrors.str()), 1); - ASSERT_EQ(counterMap.count(kMetricSsdCacheGrowFileErrors.str()), 1); - ASSERT_EQ(counterMap.count(kMetricSsdCacheWriteSsdErrors.str()), 1); - ASSERT_EQ(counterMap.count(kMetricSsdCacheWriteSsdDropped.str()), 1); - ASSERT_EQ(counterMap.count(kMetricSsdCacheWriteCheckpointErrors.str()), 1); - ASSERT_EQ(counterMap.count(kMetricSsdCacheReadSsdErrors.str()), 1); - ASSERT_EQ(counterMap.count(kMetricSsdCacheReadCorruptions.str()), 1); - ASSERT_EQ(counterMap.count(kMetricSsdCacheReadCheckpointErrors.str()), 1); - ASSERT_EQ(counterMap.count(kMetricSsdCacheCheckpointsRead.str()), 1); - ASSERT_EQ(counterMap.count(kMetricSsdCacheCheckpointsWritten.str()), 1); - ASSERT_EQ(counterMap.count(kMetricSsdCacheRegionsEvicted.str()), 1); - ASSERT_EQ(counterMap.count(kMetricSsdCacheAgedOutEntries.str()), 1); - ASSERT_EQ(counterMap.count(kMetricSsdCacheAgedOutRegions.str()), 1); - ASSERT_EQ(counterMap.count(kMetricSsdCacheRecoveredEntries.str()), 1); - ASSERT_EQ(counterMap.count(kMetricSsdCacheReadWithoutChecksum.str()), 1); - ASSERT_EQ(counterMap.size(), 55); + ASSERT_EQ(counterMap.count(std::string(kMetricMemoryCacheNumHits)), 1); + ASSERT_EQ(counterMap.count(std::string(kMetricMemoryCacheHitBytes)), 1); + ASSERT_EQ(counterMap.count(std::string(kMetricMemoryCacheNumNew)), 1); + ASSERT_EQ(counterMap.count(std::string(kMetricMemoryCacheNumEvicts)), 1); + ASSERT_EQ( + counterMap.count(std::string(kMetricMemoryCacheNumSavableEvicts)), 1); + ASSERT_EQ( + counterMap.count(std::string(kMetricMemoryCacheNumEvictChecks)), 1); + ASSERT_EQ( + counterMap.count(std::string(kMetricMemoryCacheNumWaitExclusive)), 1); + ASSERT_EQ( + counterMap.count(std::string(kMetricMemoryCacheNumAllocClocks)), 1); + ASSERT_EQ( + counterMap.count(std::string(kMetricMemoryCacheNumAgedOutEntries)), 1); + ASSERT_EQ( + counterMap.count(std::string(kMetricMemoryCacheSumEvictScore)), 1); + ASSERT_EQ(counterMap.count(std::string(kMetricSsdCacheReadEntries)), 1); + ASSERT_EQ(counterMap.count(std::string(kMetricSsdCacheReadBytes)), 1); + ASSERT_EQ(counterMap.count(std::string(kMetricSsdCacheWrittenEntries)), 1); + ASSERT_EQ(counterMap.count(std::string(kMetricSsdCacheWrittenBytes)), 1); + ASSERT_EQ(counterMap.count(std::string(kMetricSsdCacheOpenSsdErrors)), 1); + ASSERT_EQ( + counterMap.count(std::string(kMetricSsdCacheOpenCheckpointErrors)), 1); + ASSERT_EQ(counterMap.count(std::string(kMetricSsdCacheOpenLogErrors)), 1); + ASSERT_EQ( + counterMap.count(std::string(kMetricSsdCacheMetaFileDeleteErrors)), 1); + ASSERT_EQ(counterMap.count(std::string(kMetricSsdCacheGrowFileErrors)), 1); + ASSERT_EQ(counterMap.count(std::string(kMetricSsdCacheWriteSsdErrors)), 1); + ASSERT_EQ( + counterMap.count(std::string(kMetricSsdCacheWriteNoSpaceErrors)), 1); + ASSERT_EQ(counterMap.count(std::string(kMetricSsdCacheWriteSsdDropped)), 1); + ASSERT_EQ( + counterMap.count(std::string(kMetricSsdCacheWriteExceedEntryLimit)), 1); + ASSERT_EQ( + counterMap.count(std::string(kMetricSsdCacheWriteCheckpointErrors)), 1); + ASSERT_EQ(counterMap.count(std::string(kMetricSsdCacheReadSsdErrors)), 1); + ASSERT_EQ(counterMap.count(std::string(kMetricSsdCacheReadCorruptions)), 1); + ASSERT_EQ( + counterMap.count(std::string(kMetricSsdCacheReadCheckpointErrors)), 1); + ASSERT_EQ(counterMap.count(std::string(kMetricSsdCacheCheckpointsRead)), 1); + ASSERT_EQ( + counterMap.count(std::string(kMetricSsdCacheCheckpointsWritten)), 1); + ASSERT_EQ(counterMap.count(std::string(kMetricSsdCacheRegionsEvicted)), 1); + ASSERT_EQ(counterMap.count(std::string(kMetricSsdCacheAgedOutEntries)), 1); + ASSERT_EQ(counterMap.count(std::string(kMetricSsdCacheAgedOutRegions)), 1); + ASSERT_EQ( + counterMap.count(std::string(kMetricSsdCacheRecoveredEntries)), 1); + ASSERT_EQ( + counterMap.count(std::string(kMetricSsdCacheReadWithoutChecksum)), 1); + ASSERT_EQ(counterMap.size(), 58); } } @@ -612,12 +638,332 @@ TEST_F(PeriodicStatsReporterTest, allNullOption) { ASSERT_NO_THROW(stopPeriodicStatsReporter()); } +TEST_F(StatsReporterTest, registerQuantileMetricExportType) { + QuantileConfig config = { + {StatType::AVG, StatType::SUM, StatType::COUNT}, + {0.5, 0.95, 0.99}, + {60, 600}}; + registerAndVerifyQuantile("test_quantile_stat", config); +} + +TEST_F(StatsReporterTest, quantileRegistrationWithValues) { + // Test registration and value addition (covers all key types via templates) + const char* charKey = "test_quantile_char"; + std::string stringKey = "test_quantile_string"; + folly::StringPiece spKey("test_quantile_sp"); + + QuantileConfig config1 = { + {StatType::AVG, StatType::COUNT}, {0.5, 0.95}, {60, 600}}; + QuantileConfig config2 = { + {StatType::SUM, StatType::COUNT}, {0.75, 0.99}, {300}}; + QuantileConfig config3 = {{StatType::RATE}, {0.75, 0.90}, {3600}}; + + registerAndVerifyQuantile(charKey, config1); + registerAndVerifyQuantile(stringKey, config2); + registerAndVerifyQuantile(spKey, config3); + + // Test value addition + reporter_->addQuantileMetricValue(charKey, 100); + reporter_->addQuantileMetricValue(charKey, 50); + EXPECT_EQ(150, reporter_->counterMap[std::string(charKey)]); + + reporter_->addQuantileMetricValue(stringKey, 200); + reporter_->addQuantileMetricValue(stringKey, 300); + EXPECT_EQ(500, reporter_->counterMap[stringKey]); + + reporter_->addQuantileMetricValue(spKey, 25); + reporter_->addQuantileMetricValue(spKey, 75); + EXPECT_EQ(100, reporter_->counterMap[std::string(spKey)]); +} + +TEST_F(StatsReporterTest, multipleQuantileStats) { + QuantileConfig config1 = {{StatType::AVG}, {0.5}, {60}}; + QuantileConfig config2 = { + {StatType::COUNT, StatType::SUM}, {0.95, 0.99, 0.999}, {300, 900}}; + + registerAndVerifyQuantile("metric1", config1); + registerAndVerifyQuantile("metric2", config2); +} + +TEST_F(StatsReporterTest, emptyVectors) { + // Test registration with empty vectors (should be allowed) + std::vector emptyStatTypes; + std::vector emptyPercentiles; + std::vector emptySlidingWindows; + + EXPECT_NO_THROW(reporter_->registerQuantileMetricExportType( + "empty_metric", emptyStatTypes, emptyPercentiles, emptySlidingWindows)); + + EXPECT_TRUE(reporter_->quantileStatTypesMap.count("empty_metric")); + EXPECT_TRUE(reporter_->quantileStatTypesMap["empty_metric"].empty()); + EXPECT_TRUE(reporter_->quantilePercentilesMap["empty_metric"].empty()); + EXPECT_TRUE(reporter_->quantileSlidingWindowsMap["empty_metric"].empty()); +} + +TEST_F(StatsReporterTest, quantileStatMacros) { + DEFINE_QUANTILE_STAT( + "macro_test_stat", + statTypes(StatType::AVG, StatType::COUNT), + percentiles(0.5, 0.95, 0.99), + slidingWindowsSeconds(60, 600)); + + QuantileConfig expectedConfig = { + {StatType::AVG, StatType::COUNT}, {0.5, 0.95, 0.99}, {60, 600}}; + verifyQuantileRegistration("macro_test_stat", expectedConfig); + + RECORD_QUANTILE_STAT_VALUE("macro_test_stat", 100); + RECORD_QUANTILE_STAT_VALUE("macro_test_stat", 50); + RECORD_QUANTILE_STAT_VALUE("macro_test_stat"); // Default value of 1 + + EXPECT_EQ(151, reporter_->counterMap["macro_test_stat"]); +} + +TEST_F(StatsReporterTest, dynamicQuantileRegistrationWithValues) { + // Test registration and value addition for different key types + const char* charPattern = "test_metric_char.{}.{}"; + std::string stringPattern = "test_metric_string.{}.{}"; + folly::StringPiece spPattern("test_metric_sp.{}.{}"); + + QuantileConfig config1 = {{StatType::AVG}, {0.95}, {300}}; + QuantileConfig config2 = {{StatType::COUNT}, {0.5}, {60}}; + QuantileConfig config3 = {{StatType::SUM}, {0.5, 0.99}, {60, 600}}; + + registerAndVerifyDynamicQuantile(charPattern, config1); + registerAndVerifyDynamicQuantile(stringPattern, config2); + registerAndVerifyDynamicQuantile(spPattern, config3); + + auto subkeys1 = subkeys("db1", "table1"); + auto subkeys2 = subkeys("server1", "endpoint1"); + auto subkeys3 = subkeys("region1", "service1"); + + reporter_->addDynamicQuantileMetricValue(charPattern, subkeys1, 50); + reporter_->addDynamicQuantileMetricValue(charPattern, subkeys1, 150); + EXPECT_EQ(200, reporter_->counterMap["test_metric_char.db1.table1"]); + + reporter_->addDynamicQuantileMetricValue(stringPattern, subkeys2, 100); + reporter_->addDynamicQuantileMetricValue(stringPattern, subkeys2, 200); + EXPECT_EQ(300, reporter_->counterMap["test_metric_string.server1.endpoint1"]); + + reporter_->addDynamicQuantileMetricValue(spPattern, subkeys3, 75); + reporter_->addDynamicQuantileMetricValue(spPattern, subkeys3, 125); + EXPECT_EQ(200, reporter_->counterMap["test_metric_sp.region1.service1"]); +} + +TEST_F(StatsReporterTest, dynamicQuantileRegistrationOnly) { + // Test registration without values (registration-only test) + QuantileConfig config = { + {StatType::AVG, StatType::COUNT}, {0.5, 0.95}, {60, 600}}; + const char* pattern = "registration_test.{}.{}"; + + registerAndVerifyDynamicQuantile(pattern, config); + + // Verify registration worked but no values recorded yet + EXPECT_EQ(0, reporter_->counterMap.count("registration_test.key1.key2")); +} + +TEST_F(StatsReporterTest, dynamicQuantileMetricUnregisteredPattern) { + // Try to add a value for a dynamic quantile metric that was never registered + // This should silently ignore the value (not crash) + auto subkeyValues = subkeys("unregistered", "pattern"); + EXPECT_NO_THROW(reporter_->addDynamicQuantileMetricValue( + "unregistered_pattern.{}.{}", subkeyValues, 42)); + + // Verify no value was recorded + EXPECT_EQ( + 0, + reporter_->counterMap.count("unregistered_pattern.unregistered.pattern")); +} + +struct DynamicQuantilePatternTestCase { + std::string testName; + std::function testFunc; +}; + +class DynamicQuantilePatternTest + : public StatsReporterTest, + public testing::WithParamInterface {}; + +TEST_P(DynamicQuantilePatternTest, PatternScenarios) { + const auto& testCase = GetParam(); + testCase.testFunc(this); +} + +INSTANTIATE_TEST_SUITE_P( + PatternScenarios, + DynamicQuantilePatternTest, + testing::Values( + DynamicQuantilePatternTestCase{ + "DefaultValue", + [](StatsReporterTest* test) { + test->reporter_->registerDynamicQuantileMetricExportType( + "default_value_metric.{}", + statTypes(StatType::COUNT), + percentiles(0.5), + slidingWindowsSeconds(60)); + + auto subkeyValues = subkeys("test"); + test->reporter_->addDynamicQuantileMetricValue( + "default_value_metric.{}", subkeyValues, 1); + test->reporter_->addDynamicQuantileMetricValue( + "default_value_metric.{}", subkeyValues, 1); + test->reporter_->addDynamicQuantileMetricValue( + "default_value_metric.{}", subkeyValues, 1); + + EXPECT_EQ( + 3, test->reporter_->counterMap["default_value_metric.test"]); + }}, + DynamicQuantilePatternTestCase{ + "MultiplePatterns", + [](StatsReporterTest* test) { + test->reporter_->registerDynamicQuantileMetricExportType( + "pattern1.{}.{}", + statTypes(StatType::AVG), + percentiles(0.5), + slidingWindowsSeconds(60)); + + test->reporter_->registerDynamicQuantileMetricExportType( + "pattern2.{}.{}.{}", + statTypes(StatType::COUNT, StatType::SUM), + percentiles(0.95, 0.99, 0.999), + slidingWindowsSeconds(300, 900)); + + auto subkeys2 = subkeys("key1", "key2"); + auto subkeys3 = subkeys("key1", "key2", "key3"); + + test->reporter_->addDynamicQuantileMetricValue( + "pattern1.{}.{}", subkeys2, 100); + test->reporter_->addDynamicQuantileMetricValue( + "pattern2.{}.{}.{}", subkeys3, 200); + + EXPECT_EQ(100, test->reporter_->counterMap["pattern1.key1.key2"]); + EXPECT_EQ( + 200, test->reporter_->counterMap["pattern2.key1.key2.key3"]); + }}, + DynamicQuantilePatternTestCase{ + "ComplexPattern", + [](StatsReporterTest* test) { + test->reporter_->registerDynamicQuantileMetricExportType( + "complex.{}.latency.{}.p95", + statTypes(StatType::AVG, StatType::COUNT), + percentiles(0.95, 0.99), + slidingWindowsSeconds(60, 300, 3600)); + + auto subkeyValues = subkeys("http_server", "endpoint_api"); + test->reporter_->addDynamicQuantileMetricValue( + "complex.{}.latency.{}.p95", subkeyValues, 150); + test->reporter_->addDynamicQuantileMetricValue( + "complex.{}.latency.{}.p95", subkeyValues, 200); + + EXPECT_EQ( + 350, + test->reporter_->counterMap + ["complex.http_server.latency.endpoint_api.p95"]); + }}, + DynamicQuantilePatternTestCase{ + "SingleSubkey", + [](StatsReporterTest* test) { + test->reporter_->registerDynamicQuantileMetricExportType( + "single_sub.{}", + statTypes(StatType::RATE), + percentiles(0.75), + slidingWindowsSeconds(600)); + + auto subkeyValues = subkeys("single_value"); + test->reporter_->addDynamicQuantileMetricValue( + "single_sub.{}", subkeyValues, 300); + + EXPECT_EQ( + 300, test->reporter_->counterMap["single_sub.single_value"]); + }}, + DynamicQuantilePatternTestCase{ + "MultipleInstances", + [](StatsReporterTest* test) { + test->reporter_->registerDynamicQuantileMetricExportType( + "instance_test.{}.{}", + statTypes(StatType::SUM), + percentiles(0.5), + slidingWindowsSeconds(60)); + + auto subkeys1 = subkeys("instance1", "metric1"); + auto subkeys2 = subkeys("instance2", "metric2"); + auto subkeys3 = subkeys("instance1", "metric2"); + + test->reporter_->addDynamicQuantileMetricValue( + "instance_test.{}.{}", subkeys1, 100); + test->reporter_->addDynamicQuantileMetricValue( + "instance_test.{}.{}", subkeys2, 200); + test->reporter_->addDynamicQuantileMetricValue( + "instance_test.{}.{}", subkeys3, 300); + test->reporter_->addDynamicQuantileMetricValue( + "instance_test.{}.{}", subkeys1, 50); + + EXPECT_EQ( + 150, + test->reporter_ + ->counterMap["instance_test.instance1.metric1"]); + EXPECT_EQ( + 200, + test->reporter_ + ->counterMap["instance_test.instance2.metric2"]); + EXPECT_EQ( + 300, + test->reporter_ + ->counterMap["instance_test.instance1.metric2"]); + }})); + +TEST_F(StatsReporterTest, dynamicQuantileMetricEmptyVectors) { + // Test registration with empty vectors (should be allowed) + std::vector emptyStatTypes; + std::vector emptyPercentiles; + std::vector emptySlidingWindows; + + EXPECT_NO_THROW(reporter_->registerDynamicQuantileMetricExportType( + "empty_metric.{}", + emptyStatTypes, + emptyPercentiles, + emptySlidingWindows)); + + EXPECT_TRUE(reporter_->dynamicQuantileStatTypesMap.count("empty_metric.{}")); + EXPECT_TRUE( + reporter_->dynamicQuantileStatTypesMap["empty_metric.{}"].empty()); + EXPECT_TRUE( + reporter_->dynamicQuantilePercentilesMap["empty_metric.{}"].empty()); + EXPECT_TRUE( + reporter_->dynamicQuantileSlidingWindowsMap["empty_metric.{}"].empty()); + + // Try to use the empty pattern + auto subkeyValues = subkeys("test"); + EXPECT_NO_THROW(reporter_->addDynamicQuantileMetricValue( + "empty_metric.{}", subkeyValues, 100)); +} + +TEST_F(StatsReporterTest, dynamicQuantileStatMacros) { + DEFINE_DYNAMIC_QUANTILE_STAT( + "macro_test.{}.{}", + statTypes(StatType::AVG, StatType::COUNT), + percentiles(0.5, 0.95, 0.99), + slidingWindowsSeconds(60, 600)); + + QuantileConfig expectedConfig = { + {StatType::AVG, StatType::COUNT}, {0.5, 0.95, 0.99}, {60, 600}}; + verifyDynamicQuantileRegistration("macro_test.{}.{}", expectedConfig); + + RECORD_DYNAMIC_QUANTILE_STAT_VALUE( + "macro_test.{}.{}", subkeys("service", "method"), 100); + RECORD_DYNAMIC_QUANTILE_STAT_VALUE( + "macro_test.{}.{}", subkeys("service", "method"), 50); + RECORD_DYNAMIC_QUANTILE_STAT_VALUE( + "macro_test.{}.{}", subkeys("service", "method")); // Default value of 1 + + EXPECT_EQ(151, reporter_->counterMap["macro_test.service.method"]); +} + // Registering to folly Singleton with intended reporter type folly::Singleton reporter([]() { return new TestReporter(); }); -} // namespace facebook::velox +} // namespace facebook::velox::test int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); diff --git a/velox/common/base/tests/StatsReporterUtils.h b/velox/common/base/tests/StatsReporterUtils.h new file mode 100644 index 000000000000..f58111c3e29a --- /dev/null +++ b/velox/common/base/tests/StatsReporterUtils.h @@ -0,0 +1,249 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include "velox/common/base/StatsReporter.h" + +namespace facebook::velox::test { + +/// A test implementation of BaseStatsReporter for use in unit tests. +/// This class provides a mock implementation that captures all metric +/// registrations and values for verification in tests. +class TestReporter : public BaseStatsReporter { + public: + mutable std::mutex m; + mutable std::map counterMap; + mutable std::unordered_map statTypeMap; + mutable std::unordered_map> + histogramPercentilesMap; + + mutable std::unordered_map> + quantileStatTypesMap; + mutable std::unordered_map> + quantilePercentilesMap; + mutable std::unordered_map> + quantileSlidingWindowsMap; + + mutable std::unordered_map> + dynamicQuantileStatTypesMap; + mutable std::unordered_map> + dynamicQuantilePercentilesMap; + mutable std::unordered_map> + dynamicQuantileSlidingWindowsMap; + + void clear() { + std::lock_guard l(m); + counterMap.clear(); + statTypeMap.clear(); + histogramPercentilesMap.clear(); + quantileStatTypesMap.clear(); + quantilePercentilesMap.clear(); + quantileSlidingWindowsMap.clear(); + dynamicQuantileStatTypesMap.clear(); + dynamicQuantilePercentilesMap.clear(); + dynamicQuantileSlidingWindowsMap.clear(); + } + + void registerMetricExportType(const char* key, StatType statType) + const override { + statTypeMap[key] = statType; + } + + void registerMetricExportType(folly::StringPiece key, StatType statType) + const override { + statTypeMap[std::string(key)] = statType; + } + + void registerHistogramMetricExportType( + const char* key, + int64_t /* bucketWidth */, + int64_t /* min */, + int64_t /* max */, + const std::vector& pcts) const override { + histogramPercentilesMap[key] = pcts; + } + + void registerHistogramMetricExportType( + folly::StringPiece key, + int64_t /* bucketWidth */, + int64_t /* min */, + int64_t /* max */, + const std::vector& pcts) const override { + histogramPercentilesMap[std::string(key)] = pcts; + } + + void registerQuantileMetricExportType( + const char* key, + const std::vector& statTypes, + const std::vector& pcts, + const std::vector& slidingWindowsSeconds) const override { + std::lock_guard l(m); + quantileStatTypesMap[key] = statTypes; + quantilePercentilesMap[key] = pcts; + quantileSlidingWindowsMap[key] = slidingWindowsSeconds; + } + + void registerQuantileMetricExportType( + folly::StringPiece key, + const std::vector& statTypes, + const std::vector& pcts, + const std::vector& slidingWindowsSeconds) const override { + std::lock_guard l(m); + quantileStatTypesMap[std::string(key)] = statTypes; + quantilePercentilesMap[std::string(key)] = pcts; + quantileSlidingWindowsMap[std::string(key)] = slidingWindowsSeconds; + } + + void addMetricValue(const std::string& key, const size_t value) + const override { + std::lock_guard l(m); + counterMap[key] += value; + } + + void addMetricValue(const char* key, const size_t value) const override { + std::lock_guard l(m); + counterMap[key] += value; + } + + void addMetricValue(folly::StringPiece key, size_t value) const override { + std::lock_guard l(m); + counterMap[std::string(key)] += value; + } + + void addHistogramMetricValue(const std::string& key, size_t value) + const override { + std::lock_guard l(m); + counterMap[key] = std::max(counterMap[key], value); + } + + void addHistogramMetricValue(const char* key, size_t value) const override { + std::lock_guard l(m); + counterMap[key] = std::max(counterMap[key], value); + } + + void addHistogramMetricValue(folly::StringPiece key, size_t value) + const override { + std::lock_guard l(m); + counterMap[std::string(key)] = + std::max(counterMap[std::string(key)], value); + } + + void addQuantileMetricValue(const std::string& key, size_t value) + const override { + std::lock_guard l(m); + counterMap[key] += value; + } + + void addQuantileMetricValue(const char* key, size_t value) const override { + std::lock_guard l(m); + counterMap[key] += value; + } + + void addQuantileMetricValue(folly::StringPiece key, size_t value) + const override { + std::lock_guard l(m); + counterMap[std::string(key)] += value; + } + + void registerDynamicQuantileMetricExportType( + const char* keyPattern, + const std::vector& statTypes, + const std::vector& pcts, + const std::vector& slidingWindowsSeconds) const override { + std::lock_guard l(m); + dynamicQuantileStatTypesMap[keyPattern] = statTypes; + dynamicQuantilePercentilesMap[keyPattern] = pcts; + dynamicQuantileSlidingWindowsMap[keyPattern] = slidingWindowsSeconds; + } + + void registerDynamicQuantileMetricExportType( + folly::StringPiece keyPattern, + const std::vector& statTypes, + const std::vector& pcts, + const std::vector& slidingWindowsSeconds) const override { + std::lock_guard l(m); + dynamicQuantileStatTypesMap[std::string(keyPattern)] = statTypes; + dynamicQuantilePercentilesMap[std::string(keyPattern)] = pcts; + dynamicQuantileSlidingWindowsMap[std::string(keyPattern)] = + slidingWindowsSeconds; + } + + void addDynamicQuantileMetricValue( + const std::string& key, + folly::Range subkeys, + size_t value) const override { + std::lock_guard l(m); + // Check if the pattern was registered, if not silently ignore + if (dynamicQuantileStatTypesMap.find(key) == + dynamicQuantileStatTypesMap.end()) { + return; + } + + // Substitute placeholders in the key pattern with subkeys using fmt::format + std::string formattedKey; + fmt::dynamic_format_arg_store store; + for (const auto& subkey : subkeys) { + store.push_back(subkey); + } + formattedKey = fmt::vformat(key, store); + counterMap[formattedKey] += value; + } + + void addDynamicQuantileMetricValue( + const char* key, + folly::Range subkeys, + size_t value) const override { + addDynamicQuantileMetricValue(std::string(key), subkeys, value); + } + + void addDynamicQuantileMetricValue( + folly::StringPiece key, + folly::Range subkeys, + size_t value) const override { + addDynamicQuantileMetricValue(std::string(key), subkeys, value); + } + + std::string fetchMetrics() override { + std::stringstream ss; + ss << "["; + auto sep = ""; + for (const auto& [key, value] : counterMap) { + ss << sep << key << ":" << value; + sep = ","; + } + ss << "]"; + return ss.str(); + } + + // Get the current counter value for a specific key. + // Returns 0 if the key doesn't exist. + size_t getCounterValue(const std::string& key) const { + std::lock_guard l(m); + auto it = counterMap.find(key); + return it != counterMap.end() ? it->second : 0; + } +}; + +} // namespace facebook::velox::test diff --git a/velox/common/base/tests/StatusTest.cpp b/velox/common/base/tests/StatusTest.cpp index d5caa276a606..1501c18dd303 100644 --- a/velox/common/base/tests/StatusTest.cpp +++ b/velox/common/base/tests/StatusTest.cpp @@ -118,6 +118,36 @@ Status returnNotOk(Status s) { return Status::Invalid("invalid"); } +#define STATUS_MACRO_TEST(name, macro) \ + Status returnMacro##name() { \ + macro; \ + return Status::OK(); \ + } + +STATUS_MACRO_TEST(EmptyMessage, VELOX_USER_RETURN_GT(2, 1)); +STATUS_MACRO_TEST(Format, VELOX_USER_RETURN_GT(2, 1, "Occurred {} times.", 5)); +STATUS_MACRO_TEST(GT, VELOX_USER_RETURN_GT(2, 1, "User error occurred.")); +STATUS_MACRO_TEST(GE, VELOX_USER_RETURN_GE(2, 1, "User error occurred.")); +STATUS_MACRO_TEST(LT, VELOX_USER_RETURN_LT(1, 2, "User error occurred.")); +STATUS_MACRO_TEST(LE, VELOX_USER_RETURN_LE(1, 2, "User error occurred.")); +STATUS_MACRO_TEST(EQ, VELOX_USER_RETURN_EQ(1, 1, "User error occurred.")); +STATUS_MACRO_TEST(NE, VELOX_USER_RETURN_NE(1, 3, "User error occurred.")); +STATUS_MACRO_TEST( + NULL, + VELOX_USER_RETURN_NULL(nullptr, "User error occurred.")); + +Status returnNotNull(Status* status) { + VELOX_USER_RETURN_NOT_NULL(status, "User error occurred."); + return Status::OK(); +} + +Status returnMacroCheck() { + Status status = Status::OK(); + VELOX_USER_RETURN( + status.code() != StatusCode::kCancelled, "User error occurred."); + return Status::OK(); +} + TEST(StatusTest, statusMacros) { ASSERT_EQ(returnIf(true), Status::Invalid("error")); ASSERT_EQ(returnIf(false), Status::OK()); @@ -134,6 +164,67 @@ TEST(StatusTest, statusMacros) { didThrow = true; } ASSERT_TRUE(didThrow) << "VELOX_CHECK_OK did not throw"; + + ASSERT_EQ( + returnMacroCheck(), + Status::UserError( + "Reason: User error occurred.\nExpression: status.code() != StatusCode::kCancelled\n")); + ASSERT_EQ( + returnMacroEmptyMessage(), + Status::UserError("Reason: (2 vs. 1)\nExpression: 2 > 1\n")); + ASSERT_EQ( + returnMacroFormat(), + Status::UserError( + "Reason: (2 vs. 1) Occurred 5 times.\nExpression: 2 > 1\n")); + ASSERT_EQ( + returnMacroGT(), + Status::UserError( + "Reason: (2 vs. 1) User error occurred.\nExpression: 2 > 1\n")); + ASSERT_EQ( + returnMacroGE(), + Status::UserError( + "Reason: (2 vs. 1) User error occurred.\nExpression: 2 >= 1\n")); + ASSERT_EQ( + returnMacroLT(), + Status::UserError( + "Reason: (1 vs. 2) User error occurred.\nExpression: 1 < 2\n")); + ASSERT_EQ( + returnMacroLE(), + Status::UserError( + "Reason: (1 vs. 2) User error occurred.\nExpression: 1 <= 2\n")); + ASSERT_EQ( + returnMacroEQ(), + Status::UserError( + "Reason: (1 vs. 1) User error occurred.\nExpression: 1 == 1\n")); + ASSERT_EQ( + returnMacroNE(), + Status::UserError( + "Reason: (1 vs. 3) User error occurred.\nExpression: 1 != 3\n")); + ASSERT_EQ( + returnMacroNULL(), + Status::UserError( + "Reason: User error occurred.\nExpression: nullptr == nullptr\n")); + Status status = Status::OK(); + ASSERT_EQ( + returnNotNull(&status), + Status::UserError( + "Reason: User error occurred.\nExpression: status != nullptr\n")); +} + +TEST(StatusTest, statusMacrosSkipDetails) { + ScopedThreadSkipErrorDetails skipErrorDetails(true); + ASSERT_EQ(returnMacroCheck(), Status::UserError()); + ASSERT_EQ(returnMacroEmptyMessage(), Status::UserError()); + ASSERT_EQ(returnMacroFormat(), Status::UserError()); + ASSERT_EQ(returnMacroGT(), Status::UserError()); + ASSERT_EQ(returnMacroGE(), Status::UserError()); + ASSERT_EQ(returnMacroLT(), Status::UserError()); + ASSERT_EQ(returnMacroLE(), Status::UserError()); + ASSERT_EQ(returnMacroEQ(), Status::UserError()); + ASSERT_EQ(returnMacroNE(), Status::UserError()); + ASSERT_EQ(returnMacroNULL(), Status::UserError()); + Status status = Status::OK(); + ASSERT_EQ(returnNotNull(&status), Status::UserError()); } Expected modulo(int a, int b) { diff --git a/velox/common/caching/AsyncDataCache.cpp b/velox/common/caching/AsyncDataCache.cpp index 717107416d51..b54b4fa6979f 100644 --- a/velox/common/caching/AsyncDataCache.cpp +++ b/velox/common/caching/AsyncDataCache.cpp @@ -23,7 +23,6 @@ #include "velox/common/base/Exceptions.h" #include "velox/common/base/StatsReporter.h" #include "velox/common/base/SuccinctPrinter.h" -#include "velox/common/caching/FileIds.h" #define VELOX_CACHE_ERROR(errorMessage) \ _VELOX_THROW( \ @@ -129,10 +128,11 @@ void AsyncDataCacheEntry::initialize(FileCacheKey key) { } else { // No memory to cover 'this'. release(); - VELOX_CACHE_ERROR(fmt::format( - "Failed to allocate {} pages for cache: {}", - sizePages, - cache->allocator()->getAndClearFailureMessage())); + VELOX_CACHE_ERROR( + fmt::format( + "Failed to allocate {} pages for cache: {}", + sizePages, + cache->allocator()->getAndClearFailureMessage())); } } } @@ -151,7 +151,7 @@ std::string AsyncDataCacheEntry::toString() const { numPins_); } -std::unique_ptr CacheShard::getFreeEntry() { +std::unique_ptr CacheShard::getFreeEntryLocked() { std::unique_ptr newEntry; if (freeEntries_.empty()) { newEntry = std::make_unique(this); @@ -176,7 +176,7 @@ CachePin CacheShard::findOrCreate( if (foundEntry->isExclusive()) { ++numWaitExclusive_; if (wait != nullptr) { - *wait = foundEntry->getFuture(); + *wait = foundEntry->getFutureLocked(); } return CachePin(); } @@ -212,7 +212,7 @@ CachePin CacheShard::findOrCreate( entryMap_.erase(it); } - auto newEntry = getFreeEntry(); + auto newEntry = getFreeEntryLocked(); // Initialize the members that must be set inside 'mutex_'. newEntry->numPins_ = AsyncDataCacheEntry::kExclusive; newEntry->promise_ = nullptr; @@ -336,7 +336,7 @@ std::unique_ptr> CacheShard::removeEntry( removeEntryLocked(entry); // After the entry is removed from the hash table, a promise can no longer // be made. It is safe to move the promise and realize it. - return entry->movePromise(); + return entry->movePromiseLocked(); } void CacheShard::removeEntryLocked(AsyncDataCacheEntry* entry) { @@ -409,7 +409,7 @@ uint64_t CacheShard::evict( eventCounter_ > entries_.size() / 4 || numChecked > entries_.size() / 8) { now = accessTime(); - calibrateThreshold(); + calibrateThresholdLocked(); numChecked = 0; eventCounter_ = 0; } @@ -490,8 +490,8 @@ void CacheShard::freeAllocations(std::vector& allocations) { allocations.clear(); } -void CacheShard::calibrateThreshold() { - auto numSamples = std::min(10, entries_.size()); +void CacheShard::calibrateThresholdLocked() { + auto numSamples = std::min(kMaxEvictionSamples, entries_.size()); auto now = accessTime(); auto entryIndex = (clockHand_ % entries_.size()); auto step = entries_.size() / numSamples; @@ -510,22 +510,32 @@ void CacheShard::calibrateThreshold() { return score; }, numSamples, - 80); + kEvictionPercentile); } void CacheShard::updateStats(CacheStats& stats) { std::lock_guard l(mutex_); for (auto& entry : entries_) { - if (!entry || !entry->key_.fileNum.hasValue()) { + if (!entry) { ++stats.numEmptyEntries; continue; } if (entry->isExclusive()) { - stats.exclusivePinnedBytes += - entry->data().byteSize() + entry->tinyData_.capacity(); + // We cannot read data() or tinyData_ which are being allocated during + // initialize(). Use size_ as an approximation of the pinned bytes. + stats.exclusivePinnedBytes += entry->size_; ++stats.numExclusive; - } else if (entry->isShared()) { + // Skip rest of the field accesses while entry is being initialized. + continue; + } + + if (!entry->key_.fileNum.hasValue()) { + ++stats.numEmptyEntries; + continue; + } + + if (entry->isShared()) { stats.sharedPinnedBytes += entry->data().byteSize() + entry->tinyData_.capacity(); ++stats.numShared; @@ -537,11 +547,15 @@ void CacheShard::updateStats(CacheStats& stats) { } ++stats.numEntries; - stats.tinySize += entry->tinyData_.size(); - stats.tinyPadding += entry->tinyData_.capacity() - entry->tinyData_.size(); if (entry->tinyData_.empty()) { stats.largeSize += entry->size_; stats.largePadding += entry->data_.byteSize() - entry->size_; + ++stats.numLargeEntries; + } else { + stats.tinySize += entry->tinyData_.size(); + stats.tinyPadding += + entry->tinyData_.capacity() - entry->tinyData_.size(); + ++stats.numTinyEntries; } } stats.numHit += numHit_; @@ -658,7 +672,7 @@ CacheStats CacheStats::operator-(const CacheStats& other) const { AsyncDataCache::AsyncDataCache( memory::MemoryAllocator* allocator, std::unique_ptr ssdCache) - : AsyncDataCache({}, allocator, std::move(ssdCache)){}; + : AsyncDataCache({}, allocator, std::move(ssdCache)) {} AsyncDataCache::AsyncDataCache( const Options& options, @@ -835,7 +849,6 @@ uint64_t AsyncDataCache::shrink(uint64_t targetBytes) { LOG(INFO) << "Try to shrink cache to free up " << velox::succinctBytes(targetBytes) << " memory"; - const uint64_t minBytesToEvict = 8UL << 20; uint64_t evictedBytes{0}; uint64_t shrinkTimeUs{0}; { @@ -843,7 +856,8 @@ uint64_t AsyncDataCache::shrink(uint64_t targetBytes) { for (int shard = 0; shard < shards_.size(); ++shard) { memory::Allocation unused; evictedBytes += shards_[shardCounter_++ & (kShardMask)]->evict( - std::max(minBytesToEvict, targetBytes - evictedBytes), + std::max( + CacheShard::kMinBytesToEvict, targetBytes - evictedBytes), // Cache shrink is triggered when server is under low memory pressure // so need to free up memory as soon as possible. So we always avoid // triggering ssd save to accelerate the cache evictions. @@ -875,7 +889,7 @@ bool AsyncDataCache::canTryAllocate( return true; } return numPages - acquired.numPages() <= - (memory::AllocationTraits::numPages(allocator_->capacity())) - + memory::AllocationTraits::numPages(allocator_->capacity()) - allocator_->numAllocated(); } @@ -1064,8 +1078,9 @@ CoalesceIoStats readPins( [&](int32_t size, std::vector>& ranges) { // This hack allows us to store the size of the gap in the Range, // without actually allocating a buffer for it. - ranges.push_back(folly::Range( - nullptr, reinterpret_cast(static_cast(size)))); + ranges.push_back( + folly::Range( + nullptr, reinterpret_cast(static_cast(size)))); }, std::move(readFunc)); } diff --git a/velox/common/caching/AsyncDataCache.h b/velox/common/caching/AsyncDataCache.h index 5f0bf7bf6088..0baf904deb03 100644 --- a/velox/common/caching/AsyncDataCache.h +++ b/velox/common/caching/AsyncDataCache.h @@ -263,22 +263,21 @@ class AsyncDataCacheEntry { /// Sets access stats so that this is immediately evictable. void makeEvictable(); - /// Moves the promise out of 'this'. Used in order to handle the - /// promise within the lock of the cache shard, so not within private - /// methods of 'this'. - std::unique_ptr> movePromise() { - return std::move(promise_); - } - std::string toString() const; private: void release(); void addReference(); + // Moves the promise out of 'this'. Must be called inside the mutex of + // 'shard_'. + std::unique_ptr> movePromiseLocked() { + return std::move(promise_); + } + // Returns a future that will be realized when a caller can retry getting // 'this'. Must be called inside the mutex of 'shard_'. - folly::SemiFuture getFuture() { + folly::SemiFuture getFutureLocked() { if (promise_ == nullptr) { promise_ = std::make_unique>(); } @@ -308,7 +307,7 @@ class AsyncDataCacheEntry { // True if 'this' is speculatively loaded. This is reset on first hit. Allows // catching a situation where prefetched entries get evicted before they are // hit. - bool isPrefetch_{false}; + tsan_atomic isPrefetch_{false}; // Sets after first use of a prefetched entry. Cleared by // getAndClearFirstUseFlag(). Does not require synchronization since used for @@ -496,6 +495,10 @@ struct CacheStats { int64_t largePadding{0}; /// Total number of entries. int32_t numEntries{0}; + /// Total number of tiny entries. + int32_t numTinyEntries{0}; + /// Total number of large entries. + int32_t numLargeEntries{0}; /// Number of entries that do not cache anything. int32_t numEmptyEntries{0}; /// Number of entries pinned for shared access. @@ -556,6 +559,8 @@ struct CacheStats { /// and other housekeeping. class CacheShard { public: + static constexpr uint64_t kMinBytesToEvict = 8UL << 20; // 8MB + CacheShard(AsyncDataCache* cache, double maxWriteRatio) : cache_(cache), maxWriteRatio_(maxWriteRatio) {} @@ -629,8 +634,10 @@ class CacheShard { private: static constexpr uint32_t kMaxFreeEntries = 1 << 10; static constexpr int32_t kNoThreshold = std::numeric_limits::max(); + static constexpr int32_t kMaxEvictionSamples = 10; + static constexpr int32_t kEvictionPercentile = 80; - void calibrateThreshold(); + void calibrateThresholdLocked(); void removeEntryLocked(AsyncDataCacheEntry* entry); @@ -638,7 +645,7 @@ class CacheShard { // // TODO: consider to pass a size hint so as to select the a free entry which // already has the right amount of memory associated with it. - std::unique_ptr getFreeEntry(); + std::unique_ptr getFreeEntryLocked(); CachePin initEntry(RawFileCacheKey key, AsyncDataCacheEntry* entry); @@ -703,7 +710,7 @@ class AsyncDataCache : public memory::Cache { int32_t _minSsdSavableBytes = 1 << 24) : maxWriteRatio(_maxWriteRatio), ssdSavableRatio(_ssdSavableRatio), - minSsdSavableBytes(_minSsdSavableBytes){}; + minSsdSavableBytes(_minSsdSavableBytes) {} /// The max ratio of the number of in-memory cache entries being written to /// SSD cache over the total number of cache entries. This is to control SSD @@ -794,8 +801,7 @@ class AsyncDataCache : public memory::Cache { #endif /// Returns snapshot of the aggregated stats from all shards and the stats of /// SSD cache if used. - virtual CacheStats - refreshStats() const; + virtual CacheStats refreshStats() const; /// If 'details' is true, returns the stats of the backing memory allocator /// and ssd cache. Otherwise, only returns the cache stats. @@ -845,7 +851,7 @@ class AsyncDataCache : public memory::Cache { const std::vector& keys, const SizeFunc& sizeFunc, const ProcessPin& processPin) { - for (auto i = 0; i < keys.size(); ++i) { + for (size_t i = 0; i < keys.size(); ++i) { auto pin = findOrCreate(keys[i], sizeFunc(i), nullptr); if (pin.empty() || pin.checkedEntry()->isShared()) { continue; @@ -876,9 +882,12 @@ class AsyncDataCache : public memory::Cache { void clear(); private: - static constexpr int32_t kNumShards = 4; // Must be power of 2. + // Must be power of 2. + static constexpr int32_t kNumShards = 4; static constexpr int32_t kShardMask = kNumShards - 1; + static_assert((kNumShards & kShardMask) == 0); + // True if 'acquired' has more pages than 'numPages' or allocator has space // for numPages - acquired pages of more allocation. bool canTryAllocate( diff --git a/velox/common/caching/CMakeLists.txt b/velox/common/caching/CMakeLists.txt index 7e49a752868a..bc09b5552c86 100644 --- a/velox/common/caching/CMakeLists.txt +++ b/velox/common/caching/CMakeLists.txt @@ -21,18 +21,21 @@ velox_add_library( SsdCache.cpp SsdFile.cpp SsdFileTracker.cpp - StringIdMap.cpp) + StringIdMap.cpp +) velox_link_libraries( velox_caching - PUBLIC velox_common_base - velox_exception - velox_file - velox_memory - velox_process - velox_time - Folly::folly - fmt::fmt - PRIVATE velox_time) + PUBLIC + velox_common_base + velox_exception + velox_file + velox_memory + velox_process + velox_time + Folly::folly + fmt::fmt + PRIVATE velox_time +) if(${VELOX_BUILD_TESTING}) add_subdirectory(tests) diff --git a/velox/common/caching/SsdCache.cpp b/velox/common/caching/SsdCache.cpp index 2347b45982fd..50d32d6ba3ec 100644 --- a/velox/common/caching/SsdCache.cpp +++ b/velox/common/caching/SsdCache.cpp @@ -32,7 +32,8 @@ SsdCache::SsdCache(const Config& config) : filePrefix_(config.filePrefix), numShards_(config.numShards), groupStats_(std::make_unique()), - executor_(config.executor) { + executor_(config.executor), + maxEntries_(config.maxEntries) { // Make sure the given path of Ssd files has the prefix for local file system. // Local file system would be derived based on the prefix. VELOX_CHECK( @@ -58,6 +59,9 @@ SsdCache::SsdCache(const Config& config) const uint64_t sizeQuantum = numShards_ * SsdFile::kRegionSize; const int32_t fileMaxRegions = bits::roundUp(config.maxBytes, sizeQuantum) / sizeQuantum; + // Distribute maxEntries across shards + const uint64_t maxEntriesPerShard = + maxEntries_ == 0 ? 0 : bits::divRoundUp(maxEntries_, numShards_); for (auto i = 0; i < numShards_; ++i) { const auto fileConfig = SsdFile::Config( fmt::format("{}{}", filePrefix_, i), @@ -67,6 +71,7 @@ SsdCache::SsdCache(const Config& config) config.disableFileCow, config.checksumEnabled, checksumReadVerificationEnabled, + maxEntriesPerShard, executor_); files_.push_back(std::make_unique(fileConfig)); } @@ -90,7 +95,8 @@ bool SsdCache::startWrite() { } void SsdCache::write(std::vector pins) { - VELOX_CHECK_EQ(numShards_, writesInProgress_); + VELOX_CHECK_EQ( + numShards_, writesInProgress_, "startWrite() have not been called"); TestValue::adjust("facebook::velox::cache::SsdCache::write", this); @@ -98,7 +104,7 @@ void SsdCache::write(std::vector pins) { uint64_t bytes = 0; std::vector> shards(numShards_); - for (auto& pin : pins) { + for (const auto& pin : pins) { bytes += pin.checkedEntry()->size(); const auto& target = file(pin.checkedEntry()->key().fileNum.id()); shards[target.shardId()].push_back(std::move(pin)); @@ -135,9 +141,10 @@ void SsdCache::write(std::vector pins) { // Typically occurs every few GB. Allows detecting unusually slow rates // from failing devices. VELOX_SSD_CACHE_LOG(INFO) << fmt::format( - "Wrote {}, {} bytes/s", + "Wrote {} to SSD, {} bytes/s", succinctBytes(bytes), - static_cast(bytes) / (getCurrentTimeMicro() - startTimeUs)); + static_cast(bytes) * 1'000'000 / + (getCurrentTimeMicro() - startTimeUs)); } }); } @@ -191,7 +198,11 @@ std::string SsdCache::toString() const { out << "Ssd cache IO: Write " << succinctBytes(data.bytesWritten) << " read " << succinctBytes(data.bytesRead) << " Size " << succinctBytes(capacity) << " Occupied " << succinctBytes(data.bytesCached); - out << " " << (data.entriesCached >> 10) << "K entries."; + out << " " << (data.entriesCached >> 10) << "K entries"; + if (maxEntries_ > 0) { + out << " (max " << (maxEntries_ >> 10) << "K)"; + } + out << "."; out << "\nGroupStats: " << groupStats_->toString(capacity); return out.str(); } @@ -207,7 +218,7 @@ void SsdCache::shutdown() { VELOX_SSD_CACHE_LOG(INFO) << "SSD cache is shutting down"; while (writesInProgress_) { - std::this_thread::sleep_for(std::chrono::milliseconds(100)); // NOLINT + std::this_thread::sleep_for(kWriteWaitMs); } for (auto& file : files_) { file->checkpoint(true); @@ -223,7 +234,7 @@ void SsdCache::clear() { void SsdCache::waitForWriteToFinish() { while (writesInProgress_ != 0) { - std::this_thread::sleep_for(std::chrono::milliseconds(100)); // NOLINT + std::this_thread::sleep_for(kWriteWaitMs); } } diff --git a/velox/common/caching/SsdCache.h b/velox/common/caching/SsdCache.h index e46705949481..2ec837958c14 100644 --- a/velox/common/caching/SsdCache.h +++ b/velox/common/caching/SsdCache.h @@ -23,6 +23,8 @@ namespace facebook::velox::cache { #define VELOX_SSD_CACHE_LOG_PREFIX "[SSDCA] " #define VELOX_SSD_CACHE_LOG(severity) \ LOG(severity) << VELOX_SSD_CACHE_LOG_PREFIX +#define VELOX_SSD_CACHE_LOG_EVERY_MS(severity, ms) \ + FB_LOG_EVERY_MS(severity, ms) << VELOX_SSD_CACHE_LOG_PREFIX namespace test { class SsdCacheTestHelper; @@ -41,7 +43,8 @@ class SsdCache { uint64_t _checkpointIntervalBytes = 0, bool _disableFileCow = false, bool _checksumEnabled = false, - bool _checksumReadVerificationEnabled = false) + bool _checksumReadVerificationEnabled = false, + uint64_t _maxEntries = 0) : filePrefix(_filePrefix), maxBytes(_maxBytes), numShards(_numShards), @@ -49,7 +52,8 @@ class SsdCache { disableFileCow(_disableFileCow), checksumEnabled(_checksumEnabled), checksumReadVerificationEnabled(_checksumReadVerificationEnabled), - executor(_executor){}; + executor(_executor), + maxEntries(_maxEntries) {} std::string filePrefix; uint64_t maxBytes; @@ -71,6 +75,10 @@ class SsdCache { /// Executor for async fsync in checkpoint. folly::Executor* executor; + /// Maximum number of SSD cache entries allowed. A value of 0 means no + /// limit. When the limit is reached, new entry writes will be skipped. + uint64_t maxEntries; + std::string toString() const { return fmt::format( "{} shards, capacity {}, checkpoint size {}, file cow {}, checksum {}, read verification {}", @@ -169,6 +177,9 @@ class SsdCache { void waitForWriteToFinish(); private: + // Polling interval for waiting on write completion + static constexpr auto kWriteWaitMs = std::chrono::milliseconds(100); + void checkNotShutdownLocked() { VELOX_CHECK( !shutdown_, "Unexpected write after SSD cache has been shutdown"); @@ -179,6 +190,9 @@ class SsdCache { // Stats for selecting entries to save from AsyncDataCache. const std::unique_ptr groupStats_; folly::Executor* const executor_; + // Maximum number of SSD cache entries allowed. 0 means no limit. + const uint64_t maxEntries_; + mutable std::mutex mutex_; std::vector> files_; diff --git a/velox/common/caching/SsdFile.cpp b/velox/common/caching/SsdFile.cpp index fa5c46f22dae..caf78d29beb7 100644 --- a/velox/common/caching/SsdFile.cpp +++ b/velox/common/caching/SsdFile.cpp @@ -112,9 +112,10 @@ SsdFile::SsdFile(const Config& config) checksumReadVerificationEnabled_( config.checksumEnabled && config.checksumReadVerificationEnabled), shardId_(config.shardId), + maxEntries_(config.maxEntries), + executor_(config.executor), fs_(filesystems::getFileSystem(fileName_, nullptr)), - checkpointIntervalBytes_(config.checkpointIntervalBytes), - executor_(config.executor) { + checkpointIntervalBytes_(config.checkpointIntervalBytes) { process::TraceContext trace("SsdFile::SsdFile"); filesystems::FileOptions fileOptions; fileOptions.shouldThrowOnFileAlreadyExists = false; @@ -307,8 +308,19 @@ bool SsdFile::growOrEvictLocked() { } } - auto candidates = - tracker_.findEvictionCandidates(3, numRegions_, regionPins_); + // If SSD is in no space state and future eviction logging cannot go through, + // skip eviction, to avoid data inconsistency for checkpointing. Eviction log + // is up to date. Separately when SSD is in no space state, + // growOrEvictLocked() would not be invoked by write() in the first place. + if (state_.load() == State::kNoSpace) { + VELOX_SSD_CACHE_LOG_EVERY_MS(WARNING, 1'000) + << "Failed to grow cache file " << fileName_ + << " due to SSD in no space state."; + return false; + } + + auto candidates = tracker_.findEvictionCandidates( + kNumEvictionCandidates, numRegions_, regionPins_); if (candidates.empty()) { suspended_ = true; return false; @@ -345,6 +357,23 @@ void SsdFile::clearRegionEntriesLocked(const std::vector& regions) { void SsdFile::write(std::vector& pins) { process::TraceContext trace("SsdFile::write"); + + if (state_.load() == State::kNoSpace) { + ++stats_.writeSsdDropped; + VELOX_SSD_CACHE_LOG_EVERY_MS(WARNING, 10'000) + << "SSD file write is dropped in no space state."; + return; + } + + // Check entry count limit before writing + if (maxEntries_ > 0) { + std::shared_lock l(mutex_); + if (entries_.size() + pins.size() >= maxEntries_) { + ++stats_.writeSsdExceedEntryLimit; + return; + } + } + // Sorts the pins by their file/offset. In this way what is adjacent in // storage is likely adjacent on SSD. std::sort(pins.begin(), pins.end()); @@ -438,16 +467,29 @@ bool SsdFile::write( int64_t offset, int64_t length, const std::vector& iovecs) { + VELOX_DCHECK_NE(state_.load(), State::kNoSpace); + try { writeFile_->write(iovecs, offset, length); return true; } catch (const std::exception&) { + const int err = errno; VELOX_SSD_CACHE_LOG(ERROR) << "Failed to write to SSD, file name: " << fileName_ << ", size: " << iovecs.size() << ", offset: " << offset - << ", error code: " << errno - << ", error string: " << folly::errnoStr(errno); + << ", error code: " << err + << ", error string: " << folly::errnoStr(err); ++stats_.writeSsdErrors; + + if (err == ENOSPC) { + if (state_.exchange(State::kNoSpace) != State::kNoSpace) { + VELOX_SSD_CACHE_LOG(WARNING) + << "State of cache file " << fileName_ << " transits to " + << stateString(State::kNoSpace); + } + ++stats_.writeSsdNoSpaceErrors; + } + return false; } } @@ -520,6 +562,9 @@ void SsdFile::updateStats(SsdCacheStats& stats) const { stats.deleteMetaFileErrors += stats_.deleteMetaFileErrors; stats.growFileErrors += stats_.growFileErrors; stats.writeSsdErrors += stats_.writeSsdErrors; + stats.writeSsdNoSpaceErrors += stats_.writeSsdNoSpaceErrors; + stats.writeSsdDropped += stats_.writeSsdDropped; + stats.writeSsdExceedEntryLimit += stats_.writeSsdExceedEntryLimit; stats.writeCheckpointErrors += stats_.writeCheckpointErrors; stats.readSsdErrors += stats_.readSsdErrors; stats.readCheckpointErrors += stats_.readCheckpointErrors; @@ -711,9 +756,10 @@ void SsdFile::maybeFlushCheckpointBuffer(uint32_t appendBytes, bool force) { (force || checkpointBufferedDataSize_ + appendBytes >= kCheckpointBufferSize)) { VELOX_CHECK_NOT_NULL(checkpointBuffer_); - checkpointWriteFile_->append(std::string_view( - static_cast(checkpointBuffer_), - checkpointBufferedDataSize_)); + checkpointWriteFile_->append( + std::string_view( + static_cast(checkpointBuffer_), + checkpointBufferedDataSize_)); checkpointBufferedDataSize_ = 0; } } @@ -830,6 +876,23 @@ void SsdFile::checkpoint(bool force) { } } +// static +std::string SsdFile::stateString(State state) { + switch (state) { + case State::kActive: + return "Active"; + case State::kNoSpace: + return "NoSpace"; + default: + return fmt::format("UNKNOWN: {}", static_cast(state)); + } +} + +std::ostream& operator<<(std::ostream& out, const SsdFile::State& state) { + out << SsdFile::stateString(state); + return out; +} + void SsdFile::initializeCheckpoint() { if (!checkpointEnabled()) { return; @@ -964,7 +1027,7 @@ void SsdFile::readCheckpoint() { auto checkpointReadFile = fs_->openFileForRead(checkpointPath); stream = std::make_unique( std::move(checkpointReadFile), - 1 << 20, + kCheckpointReadBufferSize, memory::memoryManager()->cachePool()); } catch (std::exception& e) { ++stats_.openCheckpointErrors; diff --git a/velox/common/caching/SsdFile.h b/velox/common/caching/SsdFile.h index c9ab1f5b3016..eb5ecc861821 100644 --- a/velox/common/caching/SsdFile.h +++ b/velox/common/caching/SsdFile.h @@ -32,16 +32,18 @@ class SsdFileTestHelper; class SsdCacheTestHelper; } // namespace test -/// A 64 bit word describing a SSD cache entry in an SsdFile. The low 23 bits -/// are the size, for a maximum entry size of 8MB. The high bits are the offset. +/// The 'fileBits_' field is a 64 bit word describing a SSD cache entry in an +/// SsdFile. The low 23 bits are the size, for a maximum entry size of 8MB. The +/// high 41 bits are the offset. The 'checksum_' field is optional and is used +/// only when the checksum feature is enabled, otherwise, its value is always 0. class SsdRun { public: static constexpr int32_t kSizeBits = 23; - SsdRun() : fileBits_(0) {} + SsdRun() = default; SsdRun(uint64_t offset, uint32_t size, uint32_t checksum) - : fileBits_((offset << kSizeBits) | ((size - 1))), checksum_(checksum) { + : fileBits_((offset << kSizeBits) | (size - 1)), checksum_(checksum) { VELOX_CHECK_LT(offset, 1L << (64 - kSizeBits)); VELOX_CHECK_NE(size, 0); VELOX_CHECK_LE(size, 1 << kSizeBits); @@ -58,9 +60,11 @@ class SsdRun { checksum_ = other.checksum_; } - void operator=(SsdRun&& other) { + void operator=(SsdRun&& other) noexcept { fileBits_ = other.fileBits_; checksum_ = other.checksum_; + other.fileBits_ = 0; + other.checksum_ = 0; } uint64_t offset() const { @@ -83,8 +87,8 @@ class SsdRun { private: // Contains the file offset and size. - uint64_t fileBits_; - uint32_t checksum_; + uint64_t fileBits_{0}; + uint32_t checksum_{0}; }; /// Represents an SsdFile entry that is planned for load or being loaded. This @@ -164,7 +168,9 @@ struct SsdCacheStats { deleteMetaFileErrors = tsanAtomicValue(other.deleteMetaFileErrors); growFileErrors = tsanAtomicValue(other.growFileErrors); writeSsdErrors = tsanAtomicValue(other.writeSsdErrors); + writeSsdNoSpaceErrors = tsanAtomicValue(other.writeSsdNoSpaceErrors); writeSsdDropped = tsanAtomicValue(other.writeSsdDropped); + writeSsdExceedEntryLimit = tsanAtomicValue(other.writeSsdExceedEntryLimit); writeCheckpointErrors = tsanAtomicValue(other.writeCheckpointErrors); readSsdErrors = tsanAtomicValue(other.readSsdErrors); readCheckpointErrors = tsanAtomicValue(other.readCheckpointErrors); @@ -193,7 +199,11 @@ struct SsdCacheStats { deleteMetaFileErrors - other.deleteMetaFileErrors; result.growFileErrors = growFileErrors - other.growFileErrors; result.writeSsdErrors = writeSsdErrors - other.writeSsdErrors; + result.writeSsdNoSpaceErrors = + writeSsdNoSpaceErrors - other.writeSsdNoSpaceErrors; result.writeSsdDropped = writeSsdDropped - other.writeSsdDropped; + result.writeSsdExceedEntryLimit = + writeSsdExceedEntryLimit - other.writeSsdExceedEntryLimit; result.writeCheckpointErrors = writeCheckpointErrors - other.writeCheckpointErrors; result.readSsdCorruptions = readSsdCorruptions - other.readSsdCorruptions; @@ -232,7 +242,9 @@ struct SsdCacheStats { tsan_atomic deleteMetaFileErrors{0}; tsan_atomic growFileErrors{0}; tsan_atomic writeSsdErrors{0}; + tsan_atomic writeSsdNoSpaceErrors{0}; tsan_atomic writeSsdDropped{0}; + tsan_atomic writeSsdExceedEntryLimit{0}; tsan_atomic writeCheckpointErrors{0}; tsan_atomic readSsdErrors{0}; tsan_atomic readCheckpointErrors{0}; @@ -257,6 +269,7 @@ class SsdFile { bool _disableFileCow = false, bool _checksumEnabled = false, bool _checksumReadVerificationEnabled = false, + uint64_t _maxEntries = 0, folly::Executor* _executor = nullptr) : fileName(_fileName), shardId(_shardId), @@ -266,7 +279,8 @@ class SsdFile { checksumEnabled(_checksumEnabled), checksumReadVerificationEnabled( _checksumEnabled && _checksumReadVerificationEnabled), - executor(_executor){}; + maxEntries(_maxEntries), + executor(_executor) {} /// Name of cache file, used as prefix for checkpoint files. const std::string fileName; @@ -279,19 +293,28 @@ class SsdFile { /// Checkpoint after every 'checkpointIntervalBytes' written into this /// file. 0 means no checkpointing. This is set to 0 if checkpointing fails. - uint64_t checkpointIntervalBytes; + const uint64_t checkpointIntervalBytes; /// True if copy on write should be disabled. - bool disableFileCow; + const bool disableFileCow; /// If true, checksum write to SSD is enabled. - bool checksumEnabled; + const bool checksumEnabled; /// If true, checksum read verification from SSD is enabled. - bool checksumReadVerificationEnabled; + const bool checksumReadVerificationEnabled; + + /// Maximum number of SSD cache entries allowed. A value of 0 means no + /// limit. When the limit is reached, new entry writes will be skipped. + const uint64_t maxEntries; /// Executor for async fsync in checkpoint. - folly::Executor* executor; + folly::Executor* const executor; + }; + + enum class State : uint8_t { + kActive, + kNoSpace, }; static constexpr uint64_t kRegionSize = 1 << 26; // 64MB @@ -300,6 +323,9 @@ class SsdFile { /// filename. SsdFile(const Config& config); + /// Convert State to std::string. + static std::string stateString(State state); + /// Adds entries of 'pins' to this file. 'pins' must be in read mode and /// those pins that are successfully added to SSD are marked as being on SSD. /// The file of the entries must be a file that is backed by 'this'. @@ -386,8 +412,19 @@ class SsdFile { // Magic number at end of completed checkpoint file. static constexpr int64_t kCheckpointEndMarker = 0xcbedf11e; + // Maximum percentage of erased entries in a region before it becomes + // eligible for clearing and reuse. When more than 50% of a region's + // entries have been erased (e.g., via TTL eviction), the region can be + // cleared and added back to the writable regions pool. static constexpr int kMaxErasedSizePct = 50; + // Number of eviction candidates to consider when selecting regions to + // evict. + static constexpr int32_t kNumEvictionCandidates = 3; + + // Buffer size for reading checkpoint files during recovery. + static constexpr int32_t kCheckpointReadBufferSize = 1 << 20; // 1MB + // Updates the read count of a region. void regionRead(int32_t region, int32_t size) { tracker_.regionRead(region, size); @@ -473,6 +510,10 @@ class SsdFile { if (!checkpointEnabled()) { return false; } + // Once no SSD space, skip the subsequent checkpointing. + if (state_.load() == State::kNoSpace) { + return false; + } return force || (bytesAfterCheckpoint_ >= checkpointIntervalBytes_); } @@ -548,6 +589,12 @@ class SsdFile { // Shard index within 'cache_'. const int32_t shardId_; + // Maximum number of SSD cache entries allowed in this file. 0 means no limit. + const uint64_t maxEntries_; + + // Executor for async fsync in checkpoint. + folly::Executor* const executor_; + // Serializes access to all private data members. mutable std::shared_mutex mutex_; @@ -578,6 +625,8 @@ class SsdFile { // Map of file number and offset to location in file. folly::F14FastMap entries_; + std::atomic state_{State::kActive}; + // File system. std::shared_ptr fs_; @@ -603,9 +652,6 @@ class SsdFile { // means no checkpointing. This is set to 0 if checkpointing fails. int64_t checkpointIntervalBytes_{0}; - // Executor for async fsync in checkpoint. - folly::Executor* executor_; - // Count of bytes written after last checkpoint. std::atomic bytesAfterCheckpoint_{0}; @@ -622,4 +668,16 @@ class SsdFile { friend class test::SsdCacheTestHelper; }; +std::ostream& operator<<(std::ostream& out, const SsdFile::State& state); + } // namespace facebook::velox::cache + +template <> +struct fmt::formatter + : formatter { + auto format(facebook::velox::cache::SsdFile::State state, format_context& ctx) + const { + return formatter::format( + facebook::velox::cache::SsdFile::stateString(state), ctx); + } +}; diff --git a/velox/common/caching/StringIdMap.cpp b/velox/common/caching/StringIdMap.cpp index c8c88542da1e..e991c5a750ea 100644 --- a/velox/common/caching/StringIdMap.cpp +++ b/velox/common/caching/StringIdMap.cpp @@ -31,8 +31,8 @@ void StringIdMap::release(uint64_t id) { std::lock_guard l(mutex_); auto it = idToEntry_.find(id); if (it != idToEntry_.end()) { - VELOX_CHECK_LT( - 0, it->second.numInUse, "Extra release of id in StringIdMap"); + VELOX_CHECK_GT( + it->second.numInUse, 0, "Extra release of id in StringIdMap"); if (--it->second.numInUse == 0) { pinnedSize_ -= it->second.string.size(); auto strIter = stringToId_.find(it->second.string); @@ -60,11 +60,11 @@ uint64_t StringIdMap::makeId(std::string_view string) { if (it != stringToId_.end()) { auto entry = idToEntry_.find(it->second); VELOX_CHECK(entry != idToEntry_.end()); - if (++entry->second.numInUse == 1) { - pinnedSize_ += entry->second.string.size(); - } + VELOX_CHECK_GE(entry->second.numInUse, 1); + ++entry->second.numInUse; return it->second; } + Entry entry; entry.string = string; // Check that we do not use an id twice. In practice this never @@ -91,9 +91,8 @@ uint64_t StringIdMap::recoverId(uint64_t id, std::string_view string) { id, it->second, "Multiple recover ids assigned to {}", string); auto entry = idToEntry_.find(it->second); VELOX_CHECK(entry != idToEntry_.end()); - if (++entry->second.numInUse == 1) { - pinnedSize_ += entry->second.string.size(); - } + VELOX_CHECK_GE(entry->second.numInUse, 1); + ++entry->second.numInUse; return id; } diff --git a/velox/common/caching/tests/AsyncDataCacheTest.cpp b/velox/common/caching/tests/AsyncDataCacheTest.cpp index f60da8f4ba56..00ef2c6a80b9 100644 --- a/velox/common/caching/tests/AsyncDataCacheTest.cpp +++ b/velox/common/caching/tests/AsyncDataCacheTest.cpp @@ -670,7 +670,7 @@ TEST_P(AsyncDataCacheTest, pin) { EXPECT_LT(0, cache_->incrementPrefetchPages(0)); auto stats = cache_->refreshStats(); EXPECT_EQ(1, stats.numExclusive); - EXPECT_LE(kSize, stats.largeSize); + EXPECT_EQ(0, stats.largeSize); CachePin otherPin; EXPECT_THROW(otherPin = pin, VeloxException); @@ -1138,13 +1138,13 @@ TEST_P(AsyncDataCacheTest, shrinkCache) { pins.push_back(std::move(largePin)); } auto stats = cache_->refreshStats(); - ASSERT_EQ(stats.numEntries, numEntries * 2); + ASSERT_EQ(stats.numEntries, 0); ASSERT_EQ(stats.numEmptyEntries, 0); ASSERT_EQ(stats.numExclusive, numEntries * 2); ASSERT_EQ(stats.numEvict, 0); ASSERT_EQ(stats.numHit, 0); - ASSERT_EQ(stats.tinySize, kTinyDataSize * numEntries); - ASSERT_EQ(stats.largeSize, kLargeDataSize * numEntries); + ASSERT_EQ(stats.tinySize, 0); + ASSERT_EQ(stats.largeSize, 0); ASSERT_EQ(stats.sharedPinnedBytes, 0); ASSERT_GE( stats.exclusivePinnedBytes, @@ -1402,9 +1402,10 @@ TEST_P(AsyncDataCacheTest, makeEvictable) { std::vector keys; keys.reserve(cachePins.size()); for (const auto& pin : cachePins) { - keys.push_back(RawFileCacheKey{ - pin.checkedEntry()->key().fileNum.id(), - pin.checkedEntry()->key().offset}); + keys.push_back( + RawFileCacheKey{ + pin.checkedEntry()->key().fileNum.id(), + pin.checkedEntry()->key().offset}); } cachePins.clear(); for (const auto& key : keys) { diff --git a/velox/common/caching/tests/CMakeLists.txt b/velox/common/caching/tests/CMakeLists.txt index 78cdf56656f8..c80dacb33969 100644 --- a/velox/common/caching/tests/CMakeLists.txt +++ b/velox/common/caching/tests/CMakeLists.txt @@ -16,12 +16,8 @@ add_executable(simple_lru_cache_test SimpleLRUCacheTest.cpp) add_test(simple_lru_cache_test simple_lru_cache_test) target_link_libraries( simple_lru_cache_test - PRIVATE - Folly::folly - velox_time - glog::glog - GTest::gtest - GTest::gtest_main) + PRIVATE velox_common_base Folly::folly velox_time glog::glog GTest::gtest GTest::gtest_main +) add_executable( velox_cache_test @@ -29,7 +25,8 @@ add_executable( CacheTTLControllerTest.cpp SsdFileTest.cpp SsdFileTrackerTest.cpp - StringIdMapTest.cpp) + StringIdMapTest.cpp +) add_test(velox_cache_test velox_cache_test) target_link_libraries( velox_cache_test @@ -43,16 +40,12 @@ target_link_libraries( Folly::folly glog::glog GTest::gtest - GTest::gtest_main) + GTest::gtest_main +) add_executable(cached_factory_test CachedFactoryTest.cpp) add_test(cached_factory_test cached_factory_test) target_link_libraries( cached_factory_test - PRIVATE - velox_process - Folly::folly - velox_time - glog::glog - GTest::gtest - GTest::gtest_main) + PRIVATE velox_common_base Folly::folly velox_time glog::glog GTest::gtest GTest::gtest_main +) diff --git a/velox/common/caching/tests/CacheTestUtil.h b/velox/common/caching/tests/CacheTestUtil.h index 731da4e876a8..0be663d07970 100644 --- a/velox/common/caching/tests/CacheTestUtil.h +++ b/velox/common/caching/tests/CacheTestUtil.h @@ -59,6 +59,18 @@ class SsdFileTestHelper { return ssdFile_->checksumReadVerificationEnabled_; } + SsdFile::State state() const { + return ssdFile_->state_.load(); + } + + void setState(SsdFile::State state) { + ssdFile_->state_ = state; + } + + int32_t maxRegions() const { + return ssdFile_->maxRegions_; + } + /// Deletes the backing file. void deleteFile() { process::TraceContext trace("SsdFile::testingDeleteFile"); diff --git a/velox/common/caching/tests/SsdFileTest.cpp b/velox/common/caching/tests/SsdFileTest.cpp index bf4f2a94fbf6..a5ef1f6d0c2b 100644 --- a/velox/common/caching/tests/SsdFileTest.cpp +++ b/velox/common/caching/tests/SsdFileTest.cpp @@ -96,7 +96,8 @@ class SsdFileTest : public testing::Test { uint64_t checkpointIntervalBytes = 0, bool checksumEnabled = false, bool checksumReadVerificationEnabled = false, - bool disableFileCow = false) { + bool disableFileCow = false, + uint64_t maxEntries = 0) { SsdFile::Config config( fmt::format("{}/ssdtest", tempDirectory_->getPath()), 0, // shardId @@ -105,6 +106,7 @@ class SsdFileTest : public testing::Test { disableFileCow, checksumEnabled, checksumReadVerificationEnabled, + maxEntries, ssdExecutor()); ssdFile_ = std::make_unique(config); if (ssdFile_ != nullptr) { @@ -231,8 +233,9 @@ class SsdFileTest : public testing::Test { std::vector ssdPins; ssdPins.reserve(pins.size()); for (auto& pin : pins) { - ssdPins.push_back(ssdFile_->find(RawFileCacheKey{ - pin.entry()->key().fileNum.id(), pin.entry()->key().offset})); + ssdPins.push_back(ssdFile_->find( + RawFileCacheKey{ + pin.entry()->key().fileNum.id(), pin.entry()->key().offset})); EXPECT_FALSE(ssdPins.back().empty()); } ssdFile_->load(ssdPins, pins); @@ -820,6 +823,110 @@ TEST_F(SsdFileTest, ssdReadWithoutChecksumCheck) { #endif } +TEST_F(SsdFileTest, writeInNoSpaceState) { + constexpr int64_t kSsdSize = 16 * SsdFile::kRegionSize; + initializeCache(kSsdSize); + + // Verify the initial state is kActive. + EXPECT_EQ(ssdFileHelper_->state(), SsdFile::State::kActive); + + // Verify the cache write is successful in the initial kActiveState. + auto pins = makePins(fileName_.id(), 0, 4096, 4096, 4096 * 10); + ssdFile_->write(pins); + SsdCacheStats statsBeforeNoSpace; + ssdFile_->updateStats(statsBeforeNoSpace); + EXPECT_GT(statsBeforeNoSpace.entriesWritten, 0); + EXPECT_EQ(statsBeforeNoSpace.writeSsdDropped, 0); + + // Set the state to kNoSpace to simulate the SSD running out of space. + ssdFileHelper_->setState(SsdFile::State::kNoSpace); + EXPECT_EQ(ssdFileHelper_->state(), SsdFile::State::kNoSpace); + + // Verify that writes are dropped and no new entries are written in the + // kNoSpace state. + auto morePins = makePins(fileName_.id(), 4096 * 10, 4096, 4096, 4096 * 5); + ssdFile_->write(morePins); + SsdCacheStats statsAfterNoSpace; + ssdFile_->updateStats(statsAfterNoSpace); + EXPECT_GT( + statsAfterNoSpace.writeSsdDropped, statsBeforeNoSpace.writeSsdDropped); + EXPECT_EQ( + statsAfterNoSpace.entriesWritten, statsBeforeNoSpace.entriesWritten); + + // Verify none of the new entries have ssdFile set. + for (const auto& pin : morePins) { + EXPECT_EQ(pin.entry()->ssdFile(), nullptr); + } +} + +TEST_F(SsdFileTest, checkpointInNoSpaceState) { + constexpr int64_t kSsdSize = 4 * SsdFile::kRegionSize; + const uint64_t checkpointIntervalBytes = 2 * SsdFile::kRegionSize; + initializeCache(kSsdSize, checkpointIntervalBytes); + + // Write some entries to trigger checkpoint eligibility. + auto pins = makePins(fileName_.id(), 0, 4096, 4096, 3 * SsdFile::kRegionSize); + ssdFile_->write(pins); + + SsdCacheStats statsBeforeNoSpace; + ssdFile_->updateStats(statsBeforeNoSpace); + + // Set the state to kNoSpace. + ssdFileHelper_->setState(SsdFile::State::kNoSpace); + EXPECT_EQ(ssdFileHelper_->state(), SsdFile::State::kNoSpace); + + // Verify checkpointing is skipped in the kNoSpace state. + ssdFile_->checkpoint(true); + + SsdCacheStats statsAfterNoSpace; + ssdFile_->updateStats(statsAfterNoSpace); + EXPECT_EQ( + statsAfterNoSpace.checkpointsWritten, + statsBeforeNoSpace.checkpointsWritten); +} + +TEST_F(SsdFileTest, growOrEvictBlockedInNoSpaceState) { + constexpr int64_t kSsdSize = 4 * SsdFile::kRegionSize; + const uint64_t checkpointIntervalBytes = kSsdSize; + initializeCache(kSsdSize, checkpointIntervalBytes); + + // Fill up the SSD cache to trigger eviction on subsequent writes. + for (auto startOffset = 0; startOffset <= kSsdSize - SsdFile::kRegionSize; + startOffset += SsdFile::kRegionSize) { + auto pins = makePins( + fileName_.id(), startOffset, 4096, 4096, SsdFile::kRegionSize - 1024); + ssdFile_->write(pins); + } + + // The SSD cache is at max regins. + SsdCacheStats statsBeforeNoSpace; + ssdFile_->updateStats(statsBeforeNoSpace); + EXPECT_EQ(statsBeforeNoSpace.regionsCached, ssdFileHelper_->maxRegions()); + + ssdFileHelper_->setState(SsdFile::State::kNoSpace); + EXPECT_EQ(ssdFileHelper_->state(), SsdFile::State::kNoSpace); + + // Verify the eviction should not happen since write was dropped before + // eviction. + auto newPins = makePins(fileName_.id(), kSsdSize * 2, 4096, 4096, 4096 * 5); + ssdFile_->write(newPins); + + SsdCacheStats statsAfterNoSpace; + ssdFile_->updateStats(statsAfterNoSpace); + + EXPECT_EQ( + statsAfterNoSpace.regionsEvicted, statsBeforeNoSpace.regionsEvicted); + EXPECT_GT( + statsAfterNoSpace.writeSsdDropped, statsBeforeNoSpace.writeSsdDropped); + EXPECT_EQ( + statsAfterNoSpace.entriesWritten, statsBeforeNoSpace.entriesWritten); + + // Verify none of the new entries have ssdFile set. + for (const auto& pin : newPins) { + EXPECT_EQ(pin.entry()->ssdFile(), nullptr); + } +} + TEST_F(SsdFileTest, dataFileErrorInjection) { constexpr int64_t kSsdSize = 16 * SsdFile::kRegionSize; initializeCache(kSsdSize, 0, false, false, false, true); @@ -868,8 +975,9 @@ TEST_F(SsdFileTest, dataFileErrorInjection) { std::vector ssdPins; ssdPins.reserve(pins.size()); for (auto& pin : pins) { - ssdPins.push_back(ssdFile_->find(RawFileCacheKey{ - pin.entry()->key().fileNum.id(), pin.entry()->key().offset})); + ssdPins.push_back(ssdFile_->find( + RawFileCacheKey{ + pin.entry()->key().fileNum.id(), pin.entry()->key().offset})); } SsdCacheStats statsWithReadErrorInjected; @@ -954,6 +1062,65 @@ TEST_F(SsdFileTest, evictlogFileErrorInjection) { ASSERT_GT(statsAfterRecovery.readCheckpointErrors, 0); } +TEST_F(SsdFileTest, maxEntriesLimit) { + constexpr int64_t kSsdSize = 16 * SsdFile::kRegionSize; + constexpr uint64_t kMaxEntries = 100; + FLAGS_velox_ssd_verify_write = true; + + initializeCache(kSsdSize); + // Re-initialize SSD file with maxEntries limit + initializeSsdFile(kSsdSize, 0, false, false, false, kMaxEntries); + + // Write more entries than the limit + auto pins = makePins(fileName_.id(), 0, 4096, 2048 * 1025, 62 * kMB); + ASSERT_GT(pins.size(), kMaxEntries); + + ssdFile_->write(pins); + + SsdCacheStats stats; + ssdFile_->updateStats(stats); + + // The SSD file should have at most maxEntries + EXPECT_LE(stats.entriesCached, kMaxEntries); + // Some writes should have been dropped due to the entry limit + EXPECT_GT(stats.writeSsdExceedEntryLimit, 0); +} + +TEST_F(SsdFileTest, noWritesDroppedWithinMaxEntriesLimit) { + constexpr int64_t kSsdSize = 16 * SsdFile::kRegionSize; + constexpr uint64_t kMaxEntries = 200; + FLAGS_velox_ssd_verify_write = true; + + initializeCache(kSsdSize); + // Re-initialize SSD file with maxEntries limit + initializeSsdFile(kSsdSize, 0, false, false, false, kMaxEntries); + + // Write fewer entries than the limit + auto pins = makePins(fileName_.id(), 0, 4096, 4096, 4096 * 50); + ASSERT_LT(pins.size(), kMaxEntries); + const auto numPins = pins.size(); + + ssdFile_->write(pins); + + SsdCacheStats stats; + ssdFile_->updateStats(stats); + + // All entries should be cached since we're within the limit + EXPECT_EQ(stats.entriesCached, numPins); + // No writes should be dropped due to exceeding entry limit + EXPECT_EQ(stats.writeSsdExceedEntryLimit, 0); + // No writes should be dropped due to lack of space + EXPECT_EQ(stats.writeSsdDropped, 0); + // All entries should have been written + EXPECT_EQ(stats.entriesWritten, numPins); + + // Verify all entries are readable + for (auto& pin : pins) { + EXPECT_EQ(ssdFile_.get(), pin.entry()->ssdFile()); + } + readAndCheckPins(pins); +} + #ifdef VELOX_SSD_FILE_TEST_SET_NO_COW_FLAG TEST_F(SsdFileTest, disabledCow) { constexpr int64_t kSsdSize = 16 * SsdFile::kRegionSize; diff --git a/velox/common/caching/tests/StringIdMapTest.cpp b/velox/common/caching/tests/StringIdMapTest.cpp index 59d95af2e88a..1a1d006748aa 100644 --- a/velox/common/caching/tests/StringIdMapTest.cpp +++ b/velox/common/caching/tests/StringIdMapTest.cpp @@ -22,7 +22,7 @@ using namespace facebook::velox; TEST(StringIdMapTest, basic) { - constexpr const char* kFile1 = "file_1"; + constexpr std::string_view kFile1 = "file_1"; StringIdMap map; uint64_t id = 0; { @@ -33,7 +33,7 @@ TEST(StringIdMapTest, basic) { id = lease2.id(); lease1 = lease2; EXPECT_EQ(id, lease1.id()); - EXPECT_EQ(strlen(kFile1), map.pinnedSize()); + EXPECT_EQ(kFile1.size(), map.pinnedSize()); } StringIdLease lease3(map, kFile1); EXPECT_NE(lease3.id(), id); @@ -56,50 +56,48 @@ TEST(StringIdMapTest, rehash) { } TEST(StringIdMapTest, recover) { - constexpr const char* kRecoverFile1 = "file_1"; - constexpr const char* kRecoverFile2 = "file_2"; - constexpr const char* kRecoverFile3 = "file_3"; + constexpr std::string_view kRecoverFile1("file_1"); + constexpr std::string_view kRecoverFile2("file_2"); + constexpr std::string_view kRecoverFile3("file_3"); StringIdMap map; const uint64_t recoverId1{10}; const uint64_t recoverId2{20}; { StringIdLease lease(map, recoverId1, kRecoverFile1); ASSERT_TRUE(lease.hasValue()); - ASSERT_EQ(map.pinnedSize(), ::strlen(kRecoverFile1)); + ASSERT_EQ(map.pinnedSize(), kRecoverFile1.size()); ASSERT_EQ(map.testingLastId(), recoverId1); VELOX_ASSERT_THROW( - std::make_unique(map, recoverId1, kRecoverFile2), + StringIdLease(map, recoverId1, kRecoverFile2), "(1 vs. 0) Reused recover id 10 assigned to file_2"); VELOX_ASSERT_THROW( - std::make_unique(map, recoverId2, kRecoverFile1), + StringIdLease(map, recoverId2, kRecoverFile1), "(20 vs. 10) Multiple recover ids assigned to file_1"); } ASSERT_EQ(map.pinnedSize(), 0); StringIdLease lease1(map, kRecoverFile1); - ASSERT_EQ(map.pinnedSize(), ::strlen(kRecoverFile1)); + ASSERT_EQ(map.pinnedSize(), kRecoverFile1.size()); ASSERT_EQ(map.testingLastId(), recoverId1 + 1); { StringIdLease lease(map, recoverId2, kRecoverFile2); ASSERT_TRUE(lease.hasValue()); ASSERT_EQ(lease.id(), recoverId2); - ASSERT_EQ( - map.pinnedSize(), ::strlen(kRecoverFile1) + ::strlen(kRecoverFile2)); + ASSERT_EQ(map.pinnedSize(), kRecoverFile1.size() + kRecoverFile2.size()); ASSERT_EQ(map.testingLastId(), recoverId2); VELOX_ASSERT_THROW( - std::make_unique(map, recoverId2, kRecoverFile3), + StringIdLease(map, recoverId2, kRecoverFile3), "(1 vs. 0) Reused recover id 20 assigned to file_3"); VELOX_ASSERT_THROW( - std::make_unique(map, recoverId2, kRecoverFile1), + StringIdLease(map, recoverId2, kRecoverFile1), "(20 vs. 11) Multiple recover ids assigned to file_1"); StringIdLease dupLease(map, recoverId2, kRecoverFile2); ASSERT_TRUE(lease.hasValue()); ASSERT_EQ(lease.id(), recoverId2); - ASSERT_EQ( - map.pinnedSize(), ::strlen(kRecoverFile1) + ::strlen(kRecoverFile2)); + ASSERT_EQ(map.pinnedSize(), kRecoverFile1.size() + kRecoverFile2.size()); } ASSERT_EQ(map.testingLastId(), recoverId2); - ASSERT_EQ(map.pinnedSize(), ::strlen(kRecoverFile1)); + ASSERT_EQ(map.pinnedSize(), kRecoverFile1.size()); } diff --git a/velox/common/compression/CMakeLists.txt b/velox/common/compression/CMakeLists.txt index d6ff2579ea8b..d8743ce38603 100644 --- a/velox/common/compression/CMakeLists.txt +++ b/velox/common/compression/CMakeLists.txt @@ -20,12 +20,11 @@ velox_add_library(velox_common_compression Compression.cpp LzoDecompressor.cpp) velox_link_libraries( velox_common_compression PUBLIC velox_status Folly::folly - PRIVATE velox_exception) + PRIVATE velox_exception +) if(VELOX_ENABLE_COMPRESSION_LZ4) - velox_sources(velox_common_compression PRIVATE Lz4Compression.cpp - HadoopCompressionFormat.cpp) + velox_sources(velox_common_compression PRIVATE Lz4Compression.cpp HadoopCompressionFormat.cpp) velox_link_libraries(velox_common_compression PUBLIC lz4::lz4) - velox_compile_definitions(velox_common_compression - PRIVATE VELOX_ENABLE_COMPRESSION_LZ4) + velox_compile_definitions(velox_common_compression PRIVATE VELOX_ENABLE_COMPRESSION_LZ4) endif() diff --git a/velox/common/compression/Compression.cpp b/velox/common/compression/Compression.cpp index 8473b4fe119a..4dde4b95a89b 100644 --- a/velox/common/compression/Compression.cpp +++ b/velox/common/compression/Compression.cpp @@ -122,9 +122,10 @@ Expected> Codec::create( const CodecOptions& codecOptions) { if (!isAvailable(kind)) { auto name = compressionKindToString(kind); - return folly::makeUnexpected(Status::Invalid( - "Support for codec '{}' is either not built or not implemented.", - name)); + return folly::makeUnexpected( + Status::Invalid( + "Support for codec '{}' is either not built or not implemented.", + name)); } auto compressionLevel = codecOptions.compressionLevel; @@ -155,9 +156,10 @@ Expected> Codec::create( } VELOX_RETURN_UNEXPECTED_IF( codec == nullptr, - Status::Invalid(fmt::format( - "Support for codec '{}' is either not built or not implemented.", - compressionKindToString(kind)))); + Status::Invalid( + fmt::format( + "Support for codec '{}' is either not built or not implemented.", + compressionKindToString(kind)))); VELOX_RETURN_UNEXPECTED_NOT_OK(codec->init()); @@ -184,8 +186,9 @@ bool Codec::isAvailable(CompressionKind kind) { Expected Codec::getUncompressedLength( const uint8_t* input, uint64_t inputLength) const { - return folly::makeUnexpected(Status::Invalid( - "getUncompressedLength is unsupported with {} format.", name())); + return folly::makeUnexpected( + Status::Invalid( + "getUncompressedLength is unsupported with {} format.", name())); } Expected Codec::compressFixedLength( @@ -203,14 +206,16 @@ bool Codec::supportsStreamingCompression() const { Expected> Codec::makeStreamingCompressor() { - return folly::makeUnexpected(Status::Invalid( - "Streaming compression is unsupported with {} format.", name())); + return folly::makeUnexpected( + Status::Invalid( + "Streaming compression is unsupported with {} format.", name())); } Expected> Codec::makeStreamingDecompressor() { - return folly::makeUnexpected(Status::Invalid( - "Streaming decompression is unsupported with {} format.", name())); + return folly::makeUnexpected( + Status::Invalid( + "Streaming decompression is unsupported with {} format.", name())); } int32_t Codec::compressionLevel() const { diff --git a/velox/common/compression/Lz4Compression.cpp b/velox/common/compression/Lz4Compression.cpp index 111b34bf7695..58c6ed6662f5 100644 --- a/velox/common/compression/Lz4Compression.cpp +++ b/velox/common/compression/Lz4Compression.cpp @@ -223,7 +223,7 @@ Status LZ4Compressor::init() { firstTime_ = true; ret = LZ4F_createCompressionContext(&ctx_, LZ4F_VERSION); - VELOX_RETURN_IF(LZ4F_isError(ret), lz4Error("LZ4 init failed: ", ret)); + VELOX_RETURN_IF(LZ4F_isError(ret), lz4Error("LZ4 init failed: {}", ret)); return Status::OK(); } @@ -255,7 +255,7 @@ Expected LZ4Compressor::compress( ctx_, output, outputSize, input, inputSize, nullptr /* options */); VELOX_RETURN_UNEXPECTED_IF( LZ4F_isError(numBytesOrError), - lz4Error("LZ4 compress updated failed: ", numBytesOrError)); + lz4Error("LZ4 compress updated failed: {}", numBytesOrError)); bytesWritten += static_cast(numBytesOrError); VELOX_DCHECK_LE(bytesWritten, outputSize); @@ -287,7 +287,7 @@ Expected LZ4Compressor::flush( LZ4F_flush(ctx_, output, outputSize, nullptr /* options */); VELOX_RETURN_UNEXPECTED_IF( LZ4F_isError(numBytesOrError), - lz4Error("LZ4 flush failed: ", numBytesOrError)); + lz4Error("LZ4 flush failed: {}", numBytesOrError)); bytesWritten += static_cast(numBytesOrError); VELOX_DCHECK_LE(bytesWritten, outputLength); @@ -319,7 +319,7 @@ Expected LZ4Compressor::finalize( LZ4F_compressEnd(ctx_, output, outputSize, nullptr /* options */); VELOX_RETURN_UNEXPECTED_IF( LZ4F_isError(numBytesOrError), - lz4Error("LZ4 end failed: ", numBytesOrError)); + lz4Error("LZ4 finalize failed: {}", numBytesOrError)); bytesWritten += static_cast(numBytesOrError); VELOX_DCHECK_LE(bytesWritten, outputLength); @@ -333,7 +333,7 @@ Status LZ4Compressor::compressBegin( auto numBytesOrError = LZ4F_compressBegin(ctx_, output, outputLen, &prefs_); VELOX_RETURN_IF( LZ4F_isError(numBytesOrError), - lz4Error("LZ4 compress begin failed: ", numBytesOrError)); + lz4Error("LZ4 compress begin failed: {}", numBytesOrError)); firstTime_ = false; output += numBytesOrError; outputLen -= numBytesOrError; @@ -344,7 +344,7 @@ Status LZ4Compressor::compressBegin( Status LZ4Decompressor::init() { finished_ = false; auto ret = LZ4F_createDecompressionContext(&ctx_, LZ4F_VERSION); - VELOX_RETURN_IF(LZ4F_isError(ret), lz4Error("LZ4 init failed: ", ret)); + VELOX_RETURN_IF(LZ4F_isError(ret), lz4Error("LZ4 init failed: {}", ret)); return Status::OK(); } @@ -378,7 +378,7 @@ Expected LZ4Decompressor::decompress( auto ret = LZ4F_decompress( ctx_, output, &outputSize, input, &inputSize, nullptr /* options */); VELOX_RETURN_UNEXPECTED_IF( - LZ4F_isError(ret), lz4Error("LZ4 decompress failed: ", ret)); + LZ4F_isError(ret), lz4Error("LZ4 decompression failed: {}", ret)); finished_ = (ret == 0); return DecompressResult{ static_cast(inputSize), @@ -443,7 +443,7 @@ Expected Lz4FrameCodec::compress( static_cast(inputLength), &prefs_); VELOX_RETURN_UNEXPECTED_IF( - LZ4F_isError(ret), lz4Error("Lz4 compression failure: ", ret)); + LZ4F_isError(ret), lz4Error("LZ4 compression failed: {}", ret)); return static_cast(ret); } @@ -470,12 +470,11 @@ Expected Lz4FrameCodec::decompress( bytesWritten += result.bytesWritten; VELOX_RETURN_UNEXPECTED_IF( result.outputTooSmall, - Status::IOError("Lz4 decompression buffer too small.")); + Status::IOError("LZ4 decompression buffer too small.")); } VELOX_RETURN_UNEXPECTED_IF( !decompressor->isFinished() || inputLength != 0, - Status::IOError( - "Lz4 compressed input contains less than one frame.")); + Status::IOError("LZ4 decompression failed.")); return bytesWritten; }); } @@ -534,7 +533,7 @@ Expected Lz4RawCodec::compress( compressionLevel_); } VELOX_RETURN_UNEXPECTED_IF( - compressedSize == 0, Status::IOError("Lz4 compression failure.")); + compressedSize == 0, Status::IOError("LZ4 compression failed.")); return static_cast(compressedSize); } @@ -551,7 +550,7 @@ Expected Lz4RawCodec::decompress( static_cast(inputLength), static_cast(outputLength)); VELOX_RETURN_UNEXPECTED_IF( - decompressedSize < 0, Status::IOError("Lz4 decompression failure.")); + decompressedSize < 0, Status::IOError("LZ4 decompression failed.")); return static_cast(decompressedSize); } diff --git a/velox/common/compression/tests/CMakeLists.txt b/velox/common/compression/tests/CMakeLists.txt index 3b75b6fea698..5f7a19e7906a 100644 --- a/velox/common/compression/tests/CMakeLists.txt +++ b/velox/common/compression/tests/CMakeLists.txt @@ -17,5 +17,5 @@ add_test(velox_common_compression_test velox_common_compression_test) target_link_libraries( velox_common_compression_test PUBLIC velox_link_libs - PRIVATE velox_common_compression velox_exception GTest::gtest - GTest::gtest_main) + PRIVATE velox_common_compression velox_exception GTest::gtest GTest::gtest_main +) diff --git a/velox/common/compression/tests/CompressionTest.cpp b/velox/common/compression/tests/CompressionTest.cpp index 846d606e4499..fbaef2253da7 100644 --- a/velox/common/compression/tests/CompressionTest.cpp +++ b/velox/common/compression/tests/CompressionTest.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #include #include @@ -99,18 +100,18 @@ void checkCodecRoundtrip( Codec* c1, Codec* c2, const std::vector& data) { - auto maxCompressedLen = + auto maxCompressedLength = static_cast(c1->maxCompressedLength(data.size())); - std::vector compressed(maxCompressedLen); + std::vector compressed(maxCompressedLength); // Allocate at least 1 byte to ensure data.get() is not nullptr. std::vector decompressed(data.size() == 0 ? 1 : data.size()); // Compress with codec c1. - auto compressionLength = + auto compressedLength = c1->compress( - data.data(), data.size(), compressed.data(), maxCompressedLen) + data.data(), data.size(), compressed.data(), maxCompressedLength) .thenOrThrow(folly::identity, throwsNotOk); - compressed.resize(compressionLength); + compressed.resize(compressedLength); // Decompress with codec c2. auto decompressedLength = c2->decompress( @@ -122,6 +123,39 @@ void checkCodecRoundtrip( decompressed.resize(data.size()); ASSERT_EQ(data, decompressed); ASSERT_EQ(data.size(), decompressedLength); + + // Compress with codec c1 with a smaller output buffer to test compression + // failure. + static const std::unordered_map + compressionFailures = { + {"lz4", "LZ4 compression failed: ERROR_dstMaxSize_tooSmall"}, + {"lz4_raw", "LZ4 compression failed"}, + {"lz4_hadoop", "LZ4 compression failed"}}; + VELOX_ASSERT_ERROR_STATUS( + c1->compress( + data.data(), data.size(), compressed.data(), compressedLength - 1) + .error(), + StatusCode::kIOError, + compressionFailures.at(c1->name())); + + // Decompress corrupted data. + std::vector corruptedData = compressed; + corruptedData.resize(compressed.size() + 1); + + static const std::unordered_map + decompressionFailures = { + {"lz4", "LZ4 decompression failed."}, + {"lz4_raw", "LZ4 decompression failed."}, + {"lz4_hadoop", "LZ4 decompression failed."}}; + VELOX_ASSERT_ERROR_STATUS( + c2->decompress( + corruptedData.data(), + corruptedData.size(), + decompressed.data(), + decompressed.size()) + .error(), + StatusCode::kIOError, + decompressionFailures.at(c2->name())); } // Use same codec for both compression and decompression. diff --git a/velox/common/config/CMakeLists.txt b/velox/common/config/CMakeLists.txt index 9639a2c8b6f7..c2d2acd82b93 100644 --- a/velox/common/config/CMakeLists.txt +++ b/velox/common/config/CMakeLists.txt @@ -17,7 +17,4 @@ if(${VELOX_BUILD_TESTING}) endif() velox_add_library(velox_common_config Config.cpp) -velox_link_libraries( - velox_common_config - PUBLIC velox_common_base velox_exception - PRIVATE re2::re2) +velox_link_libraries(velox_common_config PUBLIC velox_common_base velox_exception PRIVATE re2::re2) diff --git a/velox/common/config/Config.cpp b/velox/common/config/Config.cpp index 18cc705949cd..987dbdb3dc05 100644 --- a/velox/common/config/Config.cpp +++ b/velox/common/config/Config.cpp @@ -138,13 +138,12 @@ std::unordered_map ConfigBase::rawConfigsCopy() return configs_; } -folly::Optional ConfigBase::get(const std::string& key) const { - folly::Optional val; - std::shared_lock l(mutex_); - auto it = configs_.find(key); - if (it != configs_.end()) { - val = it->second; +std::optional ConfigBase::access(const std::string& key) const { + std::shared_lock l{mutex_}; + if (auto it = configs_.find(key); it != configs_.end()) { + return it->second; } - return val; + return std::nullopt; } + } // namespace facebook::velox::config diff --git a/velox/common/config/Config.h b/velox/common/config/Config.h index 96f77d59cd7b..16031818d9b2 100644 --- a/velox/common/config/Config.h +++ b/velox/common/config/Config.h @@ -23,6 +23,7 @@ #include "folly/Conv.h" #include "velox/common/base/Exceptions.h" +#include "velox/common/config/IConfig.h" namespace facebook::velox::config { @@ -47,7 +48,7 @@ std::chrono::duration toDuration(const std::string& str); /// The concrete config class should inherit the config base and define all the /// entries. -class ConfigBase { +class ConfigBase : public IConfig { public: template struct Entry { @@ -111,49 +112,20 @@ class ConfigBase { : entry.defaultVal; } - template - folly::Optional get( - const std::string& key, - std::function toT = [](auto /* unused */, - auto value) { - return folly::to(value); - }) const { - auto val = get(key); - if (val.hasValue()) { - return toT(key, val.value()); - } else { - return folly::none; - } - } - - template - T get( - const std::string& key, - const T& defaultValue, - std::function toT = [](auto /* unused */, - auto value) { - return folly::to(value); - }) const { - auto val = get(key); - if (val.hasValue()) { - return toT(key, val.value()); - } else { - return defaultValue; - } - } + using IConfig::get; bool valueExists(const std::string& key) const; const std::unordered_map& rawConfigs() const; - std::unordered_map rawConfigsCopy() const; + std::unordered_map rawConfigsCopy() const final; protected: mutable std::shared_mutex mutex_; std::unordered_map configs_; private: - folly::Optional get(const std::string& key) const; + std::optional access(const std::string& key) const final; const bool mutable_; }; diff --git a/velox/common/config/IConfig.h b/velox/common/config/IConfig.h new file mode 100644 index 000000000000..06d8d8bc3f5e --- /dev/null +++ b/velox/common/config/IConfig.h @@ -0,0 +1,70 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include + +namespace facebook::velox::config { + +/// IConfig - Read-only config interface +/// for accessing key-value parameters. +/// Supports value retrieval by key and +/// duplication of the raw configuration data. +/// Can be used by velox::QueryConfig to access +/// externally managed system configuration. +class IConfig { + public: + template + std::optional get( + const std::string& key, + const std::function& toT = + [](auto /* unused */, auto value) { + return folly::to(value); + }) const { + if (auto val = access(key)) { + return toT(key, *val); + } + return std::nullopt; + } + + template + T get( + const std::string& key, + const T& defaultValue, + const std::function& toT = + [](auto /* unused */, auto value) { + return folly::to(value); + }) const { + if (auto val = access(key)) { + return toT(key, *val); + } + return defaultValue; + } + + virtual std::unordered_map rawConfigsCopy() + const = 0; + + virtual ~IConfig() = default; + + private: + virtual std::optional access(const std::string& key) const = 0; +}; + +} // namespace facebook::velox::config diff --git a/velox/common/config/tests/CMakeLists.txt b/velox/common/config/tests/CMakeLists.txt index 83de01821eea..feb2ae1222b0 100644 --- a/velox/common/config/tests/CMakeLists.txt +++ b/velox/common/config/tests/CMakeLists.txt @@ -17,4 +17,5 @@ add_test(velox_config_test velox_config_test) target_link_libraries( velox_config_test PUBLIC Folly::folly - PRIVATE velox_common_config GTest::gtest GTest::gtest_main) + PRIVATE velox_common_config GTest::gtest GTest::gtest_main +) diff --git a/velox/common/dynamic_registry/tests/CMakeLists.txt b/velox/common/dynamic_registry/tests/CMakeLists.txt index 58132e596e1f..40ccba9a9895 100644 --- a/velox/common/dynamic_registry/tests/CMakeLists.txt +++ b/velox/common/dynamic_registry/tests/CMakeLists.txt @@ -17,47 +17,40 @@ # VELOX_TEST_DYNAMIC_LIBRARY_PATH macro to locate the .so binary. add_library(velox_function_dynamic SHARED DynamicFunction.cpp) -add_library(velox_overwrite_int_function_dynamic SHARED - DynamicIntFunctionOverwrite.cpp) -add_library(velox_overwrite_varchar_function_dynamic SHARED - DynamicVarcharFunctionOverwrite.cpp) +add_library(velox_overwrite_int_function_dynamic SHARED DynamicIntFunctionOverwrite.cpp) +add_library(velox_overwrite_varchar_function_dynamic SHARED DynamicVarcharFunctionOverwrite.cpp) add_library(velox_function_err_dynamic SHARED DynamicErrFunction.cpp) -add_library(velox_overload_int_function_dynamic SHARED - DynamicIntFunctionOverload.cpp) -add_library(velox_overload_varchar_function_dynamic SHARED - DynamicVarcharFunctionOverload.cpp) -add_library(velox_function_non_default_dynamic SHARED - DynamicFunctionNonDefault.cpp) +add_library(velox_overload_int_function_dynamic SHARED DynamicIntFunctionOverload.cpp) +add_library(velox_overload_varchar_function_dynamic SHARED DynamicVarcharFunctionOverload.cpp) +add_library(velox_function_non_default_dynamic SHARED DynamicFunctionNonDefault.cpp) set(CMAKE_DYLIB_TEST_LINK_LIBRARIES fmt::fmt Folly::folly glog::glog xsimd) -target_link_libraries( - velox_function_dynamic - PRIVATE ${CMAKE_DYLIB_TEST_LINK_LIBRARIES}) +target_link_libraries(velox_function_dynamic PRIVATE ${CMAKE_DYLIB_TEST_LINK_LIBRARIES}) target_link_libraries( velox_overwrite_int_function_dynamic - PRIVATE ${CMAKE_DYLIB_TEST_LINK_LIBRARIES}) + PRIVATE ${CMAKE_DYLIB_TEST_LINK_LIBRARIES} +) target_link_libraries( velox_overwrite_varchar_function_dynamic - PRIVATE ${CMAKE_DYLIB_TEST_LINK_LIBRARIES}) + PRIVATE ${CMAKE_DYLIB_TEST_LINK_LIBRARIES} +) -target_link_libraries( - velox_function_err_dynamic - PRIVATE ${CMAKE_DYLIB_TEST_LINK_LIBRARIES}) +target_link_libraries(velox_function_err_dynamic PRIVATE ${CMAKE_DYLIB_TEST_LINK_LIBRARIES}) target_link_libraries( velox_overload_int_function_dynamic - PRIVATE ${CMAKE_DYLIB_TEST_LINK_LIBRARIES}) + PRIVATE ${CMAKE_DYLIB_TEST_LINK_LIBRARIES} +) target_link_libraries( velox_overload_varchar_function_dynamic - PRIVATE ${CMAKE_DYLIB_TEST_LINK_LIBRARIES}) + PRIVATE ${CMAKE_DYLIB_TEST_LINK_LIBRARIES} +) -target_link_libraries( - velox_function_non_default_dynamic - PRIVATE ${CMAKE_DYLIB_TEST_LINK_LIBRARIES}) +target_link_libraries(velox_function_non_default_dynamic PRIVATE ${CMAKE_DYLIB_TEST_LINK_LIBRARIES}) if(APPLE) set(COMMON_LIBRARY_LINK_OPTIONS "-Wl,-undefined,dynamic_lookup") @@ -68,20 +61,13 @@ else() set(COMMON_LIBRARY_LINK_OPTIONS "-Wl,--exclude-libs,ALL") endif() -target_link_options(velox_function_dynamic PRIVATE - ${COMMON_LIBRARY_LINK_OPTIONS}) -target_link_options(velox_overwrite_int_function_dynamic PRIVATE - ${COMMON_LIBRARY_LINK_OPTIONS}) -target_link_options(velox_overwrite_varchar_function_dynamic PRIVATE - ${COMMON_LIBRARY_LINK_OPTIONS}) -target_link_options(velox_function_err_dynamic PRIVATE - ${COMMON_LIBRARY_LINK_OPTIONS}) -target_link_options(velox_overload_int_function_dynamic PRIVATE - ${COMMON_LIBRARY_LINK_OPTIONS}) -target_link_options(velox_overload_varchar_function_dynamic PRIVATE - ${COMMON_LIBRARY_LINK_OPTIONS}) -target_link_options(velox_function_non_default_dynamic PRIVATE - ${COMMON_LIBRARY_LINK_OPTIONS}) +target_link_options(velox_function_dynamic PRIVATE ${COMMON_LIBRARY_LINK_OPTIONS}) +target_link_options(velox_overwrite_int_function_dynamic PRIVATE ${COMMON_LIBRARY_LINK_OPTIONS}) +target_link_options(velox_overwrite_varchar_function_dynamic PRIVATE ${COMMON_LIBRARY_LINK_OPTIONS}) +target_link_options(velox_function_err_dynamic PRIVATE ${COMMON_LIBRARY_LINK_OPTIONS}) +target_link_options(velox_overload_int_function_dynamic PRIVATE ${COMMON_LIBRARY_LINK_OPTIONS}) +target_link_options(velox_overload_varchar_function_dynamic PRIVATE ${COMMON_LIBRARY_LINK_OPTIONS}) +target_link_options(velox_function_non_default_dynamic PRIVATE ${COMMON_LIBRARY_LINK_OPTIONS}) # Here's the actual test which will dynamically load the library defined above. add_executable(velox_function_dynamic_link_test DynamicLinkTest.cpp) @@ -94,15 +80,16 @@ target_link_libraries( xsimd GTest::gmock GTest::gtest - GTest::gtest_main) + GTest::gtest_main +) target_compile_definitions( velox_function_dynamic_link_test - PRIVATE VELOX_TEST_DYNAMIC_LIBRARY_PATH="${CMAKE_CURRENT_BINARY_DIR}") + PRIVATE VELOX_TEST_DYNAMIC_LIBRARY_PATH="${CMAKE_CURRENT_BINARY_DIR}" +) target_compile_definitions( velox_function_dynamic_link_test - PRIVATE - VELOX_TEST_DYNAMIC_LIBRARY_PATH_SUFFIX="${CMAKE_SHARED_LIBRARY_SUFFIX}") + PRIVATE VELOX_TEST_DYNAMIC_LIBRARY_PATH_SUFFIX="${CMAKE_SHARED_LIBRARY_SUFFIX}" +) -add_test(NAME velox_function_dynamic_link_test - COMMAND velox_function_dynamic_link_test) +add_test(NAME velox_function_dynamic_link_test COMMAND velox_function_dynamic_link_test) diff --git a/velox/common/dynamic_registry/tests/DynamicErrFunction.cpp b/velox/common/dynamic_registry/tests/DynamicErrFunction.cpp index f97444f465e9..f3e8ef3826f5 100644 --- a/velox/common/dynamic_registry/tests/DynamicErrFunction.cpp +++ b/velox/common/dynamic_registry/tests/DynamicErrFunction.cpp @@ -18,8 +18,8 @@ // This file defines a mock function that will be dynamically linked and // registered. There are no restrictions as to how the function needs to be -// defined, but the library (.so) needs to provide a `void registry()` C -// function in the top-level namespace. +// defined, but the library (.so) needs to provide a `void registerExtensions()` +// C function in the top-level namespace. // // (note the extern "C" directive to prevent the compiler from mangling the // symbol name). diff --git a/velox/common/dynamic_registry/tests/DynamicFunction.cpp b/velox/common/dynamic_registry/tests/DynamicFunction.cpp index a41a216521d2..dcc852d956d8 100644 --- a/velox/common/dynamic_registry/tests/DynamicFunction.cpp +++ b/velox/common/dynamic_registry/tests/DynamicFunction.cpp @@ -18,8 +18,8 @@ // This file defines a mock function that will be dynamically linked and // registered. There are no restrictions as to how the function needs to be -// defined, but the library (.so) needs to provide a `void registry()` C -// function in the top-level namespace. +// defined, but the library (.so) needs to provide a `void registerExtensions()` +// C function in the top-level namespace. // // (note the extern "C" directive to prevent the compiler from mangling the // symbol name). diff --git a/velox/common/dynamic_registry/tests/DynamicFunctionNonDefault.cpp b/velox/common/dynamic_registry/tests/DynamicFunctionNonDefault.cpp index 8489c88a9e4f..aebff84ce74f 100644 --- a/velox/common/dynamic_registry/tests/DynamicFunctionNonDefault.cpp +++ b/velox/common/dynamic_registry/tests/DynamicFunctionNonDefault.cpp @@ -18,8 +18,8 @@ // This file defines a mock function that will be dynamically linked and // registered. There are no restrictions as to how the function needs to be -// defined, but the library (.so) needs to provide a `void registry()` C -// function in the top-level namespace. +// defined, but the library (.so) needs to provide a `void registerExtensions()` +// C function in the top-level namespace. // // (note the extern "C" directive to prevent the compiler from mangling the // symbol name). diff --git a/velox/common/dynamic_registry/tests/DynamicIntFunctionOverload.cpp b/velox/common/dynamic_registry/tests/DynamicIntFunctionOverload.cpp index f0be371c1201..5129f9cff1d3 100644 --- a/velox/common/dynamic_registry/tests/DynamicIntFunctionOverload.cpp +++ b/velox/common/dynamic_registry/tests/DynamicIntFunctionOverload.cpp @@ -18,8 +18,8 @@ // This file defines a mock function that will be dynamically linked and // registered. There are no restrictions as to how the function needs to be -// defined, but the library (.so) needs to provide a `void registry()` C -// function in the top-level namespace. +// defined, but the library (.so) needs to provide a `void registerExtensions()` +// C function in the top-level namespace. // // (note the extern "C" directive to prevent the compiler from mangling the // symbol name). diff --git a/velox/common/dynamic_registry/tests/DynamicIntFunctionOverwrite.cpp b/velox/common/dynamic_registry/tests/DynamicIntFunctionOverwrite.cpp index dbb9623a2604..804edeec30bd 100644 --- a/velox/common/dynamic_registry/tests/DynamicIntFunctionOverwrite.cpp +++ b/velox/common/dynamic_registry/tests/DynamicIntFunctionOverwrite.cpp @@ -18,8 +18,8 @@ // This file defines a mock function that will be dynamically linked and // registered. There are no restrictions as to how the function needs to be -// defined, but the library (.so) needs to provide a `void registry()` C -// function in the top-level namespace. +// defined, but the library (.so) needs to provide a `void registerExtensions()` +// C function in the top-level namespace. // // (note the extern "C" directive to prevent the compiler from mangling the // symbol name). diff --git a/velox/common/dynamic_registry/tests/DynamicLinkTest.cpp b/velox/common/dynamic_registry/tests/DynamicLinkTest.cpp index 97672761bb88..fbade514dbdf 100644 --- a/velox/common/dynamic_registry/tests/DynamicLinkTest.cpp +++ b/velox/common/dynamic_registry/tests/DynamicLinkTest.cpp @@ -152,8 +152,8 @@ TEST_F(DynamicLinkTest, dynamicLoadErrFunc) { dynamicFunctionFail(0, 0), "Scalar function signature is not supported: dynamic_err(BIGINT). Supported signatures: (array(bigint)) -> bigint."); - auto check = makeRowVector( - {makeNullableArrayVector(std::vector>>{ + auto check = makeRowVector({makeNullableArrayVector( + std::vector>>{ {0, 1, 3, 4, 5, 6, 7, 8, 9}})}); // Expecting a success because we are passing in an array. diff --git a/velox/common/dynamic_registry/tests/DynamicVarcharFunctionOverload.cpp b/velox/common/dynamic_registry/tests/DynamicVarcharFunctionOverload.cpp index 250eea92a1c6..1f6c4788ff5b 100644 --- a/velox/common/dynamic_registry/tests/DynamicVarcharFunctionOverload.cpp +++ b/velox/common/dynamic_registry/tests/DynamicVarcharFunctionOverload.cpp @@ -18,8 +18,8 @@ // This file defines a mock function that will be dynamically linked and // registered. There are no restrictions as to how the function needs to be -// defined, but the library (.so) needs to provide a `void registry()` C -// function in the top-level namespace. +// defined, but the library (.so) needs to provide a `void registerExtensions()` +// C function in the top-level namespace. // // (note the extern "C" directive to prevent the compiler from mangling the // symbol name). diff --git a/velox/common/dynamic_registry/tests/DynamicVarcharFunctionOverwrite.cpp b/velox/common/dynamic_registry/tests/DynamicVarcharFunctionOverwrite.cpp index 1678a16e0b98..adbdd299a238 100644 --- a/velox/common/dynamic_registry/tests/DynamicVarcharFunctionOverwrite.cpp +++ b/velox/common/dynamic_registry/tests/DynamicVarcharFunctionOverwrite.cpp @@ -18,8 +18,8 @@ // This file defines a mock function that will be dynamically linked and // registered. There are no restrictions as to how the function needs to be -// defined, but the library (.so) needs to provide a `void registry()` C -// function in the top-level namespace. +// defined, but the library (.so) needs to provide a `void registerExtensions()` +// C function in the top-level namespace. // // (note the extern "C" directive to prevent the compiler from mangling the // symbol name). diff --git a/velox/common/encode/Base32.cpp b/velox/common/encode/Base32.cpp new file mode 100644 index 000000000000..b43a2b03f0a3 --- /dev/null +++ b/velox/common/encode/Base32.cpp @@ -0,0 +1,192 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/common/encode/Base32.h" + +#include +#include +#include + +#include "velox/common/base/CheckedArithmetic.h" +#include "velox/common/base/Exceptions.h" + +namespace facebook::velox::encoding { + +// Reverse lookup table for decoding. 255 means invalid character. +// Only uppercase letters (A-Z) and digits 2-7 are valid per RFC 4648. +// Lowercase letters are NOT supported (matching Google Guava's +// BaseEncoding.base32()). +constexpr const Base32::ReverseIndex kBase32ReverseIndexTable = { + 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, // 0-15 + 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, // 16-31 + 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, // 32-47 + 255, 255, 26, 27, 28, 29, 30, 31, 255, + 255, 255, 255, 255, 255, 255, 255, // 48-63 ('2'-'7') + 255, 0, 1, 2, 3, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, // 64-79 ('A'-'O') + 15, 16, 17, 18, 19, 20, 21, 22, 23, + 24, 25, 255, 255, 255, 255, 255, // 80-95 ('P'-'Z') + 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, // 96-111 (lowercase 'a'-'o' - INVALID) + 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, // 112-127 (lowercase 'p'-'z' - INVALID) + 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, // 128-143 + 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, // 144-159 + 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, // 160-175 + 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, // 176-191 + 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, // 192-207 + 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, // 208-223 + 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, // 224-239 + 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255 // 240-255 +}; + +// static +folly::Expected Base32::base32ReverseLookup( + char encodedChar, + const ReverseIndex& reverseIndex) { + auto index = reverseIndex[static_cast(encodedChar)]; + if (index >= 32) { + return folly::makeUnexpected( + Status::UserError("Unrecognized character: {}", encodedChar)); + } + return index; +} + +// static +folly::Expected Base32::calculateDecodedSize( + const char* input, + const size_t inputSize) { + if (inputSize == 0) { + return 0; + } + + // Count valid (non-padding, non-whitespace) characters and validate them + size_t validCharCount = 0; + for (size_t i = 0; i < inputSize; ++i) { + char c = input[i]; + if (c == Base32::kPadding || std::isspace(static_cast(c))) { + continue; + } + + // Validate character first + auto index = kBase32ReverseIndexTable[static_cast(c)]; + if (index >= 32) { + return folly::makeUnexpected( + Status::UserError("Unrecognized character: {}", c)); + } + validCharCount++; + } + + // Validate input length matches Google Guava's Base32 behavior. + // Base32 encoding groups characters into quantums of 8 characters (40 bits). + // Valid character counts (mod 8) are: 0, 2, 4, 5, 7 + // Invalid character counts (mod 8) are: 1, 3, 6 + // These invalid counts leave too many incomplete bits that cannot form + // complete bytes. + if (validCharCount > 0) { + size_t remainder = validCharCount % 8; + if (remainder == 1 || remainder == 3 || remainder == 6) { + return folly::makeUnexpected( + Status::UserError("Invalid input length {}", validCharCount)); + } + } + + // Calculate decoded size + // Each base32 character represents 5 bits + // We need to convert to bytes (8 bits each) + size_t totalBits = checkedMultiply(validCharCount, size_t(5)); + size_t decodedSize = totalBits / 8; + + return decodedSize; +} + +// static +Status Base32::decode( + const char* input, + size_t inputSize, + char* outputBuffer, + size_t outputSize) { + auto decodedSize = decodeImpl( + input, inputSize, outputBuffer, outputSize, kBase32ReverseIndexTable); + if (decodedSize.hasError()) { + return decodedSize.error(); + } + return Status::OK(); +} + +// Decodes Base32 input using the provided reverse lookup table. +// This is the core decoding implementation that accumulates 5-bit values +// from Base32 characters and outputs 8-bit bytes. +// static +folly::Expected Base32::decodeImpl( + const char* input, + size_t inputSize, + char* outputBuffer, + size_t outputSize, + const ReverseIndex& reverseIndex) { + if (inputSize == 0) { + return 0; + } + + size_t outputPos = 0; + uint64_t accumulator = 0; + size_t bitsAccumulated = 0; + + for (size_t i = 0; i < inputSize; ++i) { + char c = input[i]; + + // Skip padding and whitespace (RFC 4648 allows whitespace in encoded data) + if (c == Base32::kPadding || std::isspace(static_cast(c))) { + continue; + } + + // Validate and convert character to 5-bit value + auto value = base32ReverseLookup(c, reverseIndex); + if (value.hasError()) { + return folly::makeUnexpected(value.error()); + } + + // Accumulate 5 bits from each Base32 character + // Each character contributes 5 bits to the bit accumulator + accumulator = (accumulator << 5) | value.value(); + bitsAccumulated += 5; + + // Extract full bytes (8 bits) when we have accumulated enough bits + if (bitsAccumulated >= 8) { + if (outputPos >= outputSize) { + return folly::makeUnexpected( + Status::UserError("Output buffer too small")); + } + outputBuffer[outputPos++] = + static_cast((accumulator >> (bitsAccumulated - 8)) & 0xFF); + bitsAccumulated -= 8; + } + } + + return outputPos; +} + +} // namespace facebook::velox::encoding diff --git a/velox/common/encode/Base32.h b/velox/common/encode/Base32.h new file mode 100644 index 000000000000..246a60d7994f --- /dev/null +++ b/velox/common/encode/Base32.h @@ -0,0 +1,80 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include +#include + +#include "velox/common/base/Status.h" + +namespace facebook::velox::encoding { + +class Base32 { + public: + static const size_t kCharsetSize = 32; + static const size_t kReverseIndexSize = 256; + + /// Character set used for Base32 encoding. + /// Contains specific characters that form the encoding scheme. + using Charset = std::array; + + /// Reverse lookup table for decoding. + /// Maps each possible encoded character to its corresponding numeric value + /// within the encoding base. + using ReverseIndex = std::array; + + /// Decodes the specified number of characters from the 'input' and writes the + /// result to the 'outputBuffer'. + static Status decode( + const char* input, + size_t inputSize, + char* outputBuffer, + size_t outputSize); + + /// Calculates the decoded size based on encoded input. + static folly::Expected calculateDecodedSize( + const char* input, + const size_t inputSize); + + // Padding character used in encoding. + static const char kPadding = '='; + + // Constants defining the size in bytes of binary and encoded blocks for + // Base32 encoding. Size of a binary block in bytes (5 bytes = 40 bits) + static const int kBinaryBlockByteSize = 5; + // Size of an encoded block in bytes (8 bytes = 40 bits) + static const int kEncodedBlockByteSize = 8; + + private: + // Reverse lookup helper function to get the original index of a Base32 + // character. + static folly::Expected base32ReverseLookup( + char encodedChar, + const ReverseIndex& reverseIndex); + + // Decodes the specified data using the provided reverse lookup table. + static folly::Expected decodeImpl( + const char* input, + size_t inputSize, + char* outputBuffer, + size_t outputSize, + const ReverseIndex& reverseIndex); +}; + +} // namespace facebook::velox::encoding diff --git a/velox/common/encode/Base64.cpp b/velox/common/encode/Base64.cpp index cc62a8f7a0e8..2b87fd141ab5 100644 --- a/velox/common/encode/Base64.cpp +++ b/velox/common/encode/Base64.cpp @@ -187,14 +187,13 @@ size_t Base64::calculateEncodedSize(size_t inputSize, bool withPadding) { // static void Base64::encode(const char* input, size_t inputSize, char* output) { - encodeImpl( - folly::StringPiece(input, inputSize), kBase64Charset, true, output); + encodeImpl(std::string_view(input, inputSize), kBase64Charset, true, output); } // static void Base64::encodeUrl(const char* input, size_t inputSize, char* output) { encodeImpl( - folly::StringPiece(input, inputSize), kBase64UrlCharset, true, output); + std::string_view(input, inputSize), kBase64UrlCharset, true, output); } // static @@ -249,13 +248,13 @@ void Base64::encodeImpl( } // static -std::string Base64::encode(folly::StringPiece text) { +std::string Base64::encode(std::string_view text) { return encodeImpl(text, kBase64Charset, true); } // static std::string Base64::encode(const char* input, size_t inputSize) { - return encode(folly::StringPiece(input, inputSize)); + return encode(std::string_view(input, inputSize)); } namespace { @@ -308,7 +307,7 @@ std::string Base64::encode(const folly::IOBuf* inputBuffer) { } // static -std::string Base64::decode(folly::StringPiece encodedText) { +std::string Base64::decode(std::string_view encodedText) { std::string decodedResult; decode(std::make_pair(encodedText.data(), encodedText.size()), decodedResult); return decodedResult; @@ -346,9 +345,10 @@ Expected Base64::base64ReverseLookup( const ReverseIndex& reverseIndex) { auto reverseLookupValue = reverseIndex[static_cast(encodedChar)]; if (reverseLookupValue >= 0x40) { - return folly::makeUnexpected(Status::UserError( - "decode() - invalid input string: invalid character '{}'", - encodedChar)); + return folly::makeUnexpected( + Status::UserError( + "decode() - invalid input string: invalid character '{}'", + encodedChar)); } return reverseLookupValue; } @@ -381,8 +381,9 @@ Expected Base64::calculateDecodedSize( // block size if (inputSize % kEncodedBlockByteSize != 0) { return folly::makeUnexpected( - Status::UserError("Base64::decode() - invalid input string: " - "string length is not a multiple of 4.")); + Status::UserError( + "Base64::decode() - invalid input string: " + "string length is not a multiple of 4.")); } auto decodedSize = @@ -403,9 +404,10 @@ Expected Base64::calculateDecodedSize( // Adjust the needed size for extra bytes, if present if (extraBytes) { if (extraBytes == 1) { - return folly::makeUnexpected(Status::UserError( - "Base64::decode() - invalid input string: " - "string length cannot be 1 more than a multiple of 4.")); + return folly::makeUnexpected( + Status::UserError( + "Base64::decode() - invalid input string: " + "string length cannot be 1 more than a multiple of 4.")); } decodedSize += (extraBytes * kBinaryBlockByteSize) / kEncodedBlockByteSize; } @@ -431,8 +433,9 @@ Expected Base64::decodeImpl( if (outputSize < decodedSize.value()) { return folly::makeUnexpected( - Status::UserError("Base64::decode() - invalid output string: " - "output string is too small.")); + Status::UserError( + "Base64::decode() - invalid output string: " + "output string is too small.")); } outputSize = decodedSize.value(); @@ -492,13 +495,13 @@ Expected Base64::decodeImpl( } // static -std::string Base64::encodeUrl(folly::StringPiece text) { +std::string Base64::encodeUrl(std::string_view text) { return encodeImpl(text, kBase64UrlCharset, false); } // static std::string Base64::encodeUrl(const char* input, size_t inputSize) { - return encodeUrl(folly::StringPiece(input, inputSize)); + return encodeUrl(std::string_view(input, inputSize)); } // static @@ -521,7 +524,7 @@ Status Base64::decodeUrl( } // static -std::string Base64::decodeUrl(folly::StringPiece encodedText) { +std::string Base64::decodeUrl(std::string_view encodedText) { std::string decodedOutput; decodeUrl( std::make_pair(encodedText.data(), encodedText.size()), decodedOutput); @@ -546,4 +549,189 @@ void Base64::decodeUrl( decodedOutput.resize(decodedSize.value()); } +// static +Status Base64::decodeMime(const char* input, size_t inputSize, char* output) { + if (inputSize == 0) { + return Status::OK(); + } + + // 24-bit buffer. + uint32_t accumulator = 0; + // Next shift amount. + int bitsNeeded = 18; + size_t idx = 0; + char* outPtr = output; + + while (idx < inputSize) { + unsigned char c = static_cast(input[idx++]); + int val = kBase64ReverseIndexTable[c]; + + // Padding character. + if (c == kPadding) { + // If we see '=' too early or only one '=' when two are needed → error. + if (bitsNeeded == 18 || + (bitsNeeded == 6 && (idx == inputSize || input[idx++] != kPadding))) { + return Status::UserError( + "Input byte array has wrong 4-byte ending unit."); + } + break; + } + + // Skip whitespace or other non-Base64 chars. + if (val < 0 || val >= 0x40) { + continue; + } + + // Accumulate 6 bits + accumulator |= (static_cast(val) << bitsNeeded); + bitsNeeded -= 6; + + // If we've collected 24 bits, write out 3 bytes. + if (bitsNeeded < 0) { + *outPtr++ = static_cast((accumulator >> 16) & 0xFF); + *outPtr++ = static_cast((accumulator >> 8) & 0xFF); + *outPtr++ = static_cast(accumulator & 0xFF); + accumulator = 0; + bitsNeeded = 18; + } + } + + // Handle any remaining bits (1 or 2 bytes). + if (bitsNeeded == 0) { + *outPtr++ = static_cast((accumulator >> 16) & 0xFF); + *outPtr++ = static_cast((accumulator >> 8) & 0xFF); + } else if (bitsNeeded == 6) { + *outPtr++ = static_cast((accumulator >> 16) & 0xFF); + } else if (bitsNeeded == 12) { + return Status::UserError("Last unit does not have enough valid bits."); + } + + // Verify no illegal trailing Base64 data. + while (idx < inputSize) { + unsigned char c = static_cast(input[idx++]); + int val = kBase64ReverseIndexTable[c]; + // Valid data after completion is an error. + if (val >= 0 && val < 0x40) { + return Status::UserError("Input byte array has incorrect ending."); + } + // '=' padding beyond handled ones is OK; other negatives are skips. + } + + return Status::OK(); +} + +// static +Expected Base64::calculateMimeDecodedSize( + const char* input, + const size_t inputSize) { + if (inputSize == 0) { + return 0; + } + if (inputSize < 2) { + if (kBase64ReverseIndexTable[static_cast(input[0])] >= 0x40) { + return 0; + } + return folly::makeUnexpected( + Status::UserError( + "Input should at least have 2 bytes for base64 bytes.")); + } + auto decodedSize = inputSize; + // Compute how many true Base64 chars. + for (size_t i = 0; i < inputSize; ++i) { + auto c = input[i]; + if (c == kPadding) { + decodedSize -= inputSize - i; + break; + } + if (kBase64ReverseIndexTable[static_cast(c)] >= 0x40) { + decodedSize--; + } + } + // If no explicit padding but validChars ≢ 0 mod 4, infer missing '='. + size_t paddings = 0; + if ((decodedSize & 0x3) != 0) { + paddings = 4 - (decodedSize & 0x3); + } + // Each 4-char block yields 3 bytes; subtract padding. + decodedSize = 3 * ((decodedSize + 3) / 4) - paddings; + return decodedSize; +} + +// static +void Base64::encodeMime(const char* input, size_t inputSize, char* output) { + // If there's nothing to encode, do nothing. + if (inputSize == 0) { + return; + } + + const char* readPtr = input; + char* writePtr = output; + // Bytes per 76-char line. + const size_t bytesPerLine = (kMaxLineLength / 4) * 3; + size_t remaining = inputSize; + + // Process full lines of up to 'bytesPerLine' bytes. + while (remaining >= 3) { + // Round down to a multiple of 3, but not more than one line. + size_t chunk = std::min(bytesPerLine, (remaining / 3) * 3); + const char* chunkEnd = readPtr + chunk; + + // Encode each group of 3 bytes into 4 Base64 characters. + while (readPtr + 2 < chunkEnd) { + // Read three bytes separately to avoid undefined behavior. + uint8_t b0 = static_cast(*readPtr++); + uint8_t b1 = static_cast(*readPtr++); + uint8_t b2 = static_cast(*readPtr++); + // Pack into a 24-bit value. + uint32_t trio = (static_cast(b0) << 16) | + (static_cast(b1) << 8) | static_cast(b2); + // Emit four Base64 characters. + *writePtr++ = kBase64Charset[(trio >> 18) & 0x3F]; + *writePtr++ = kBase64Charset[(trio >> 12) & 0x3F]; + *writePtr++ = kBase64Charset[(trio >> 6) & 0x3F]; + *writePtr++ = kBase64Charset[trio & 0x3F]; + } + + remaining -= chunk; + + // Insert CRLF if we filled exactly one line and still have more data. + if (chunk == bytesPerLine && remaining > 0) { + *writePtr++ = kNewline[0]; + *writePtr++ = kNewline[1]; + } + } + + // Handle the final 1 or 2 leftover bytes with padding. + if (remaining > 0) { + uint8_t b0 = static_cast(*readPtr++); + // First Base64 character from the high 6 bits. + *writePtr++ = kBase64Charset[b0 >> 2]; + + if (remaining == 1) { + // Only one byte remains: produce two chars + two '=' paddings. + *writePtr++ = kBase64Charset[(b0 & 0x03) << 4]; + *writePtr++ = kPadding; + *writePtr = kPadding; + } else { + // Two bytes remain: produce three chars + one '=' padding. + uint8_t b1 = static_cast(*readPtr); + *writePtr++ = kBase64Charset[((b0 & 0x03) << 4) | (b1 >> 4)]; + *writePtr++ = kBase64Charset[(b1 & 0x0F) << 2]; + *writePtr = kPadding; + } + } +} + +// static +size_t Base64::calculateMimeEncodedSize(size_t inputSize) { + if (inputSize == 0) { + return 0; + } + + size_t encodedSize = calculateEncodedSize(inputSize, true); + // Add CRLFs: one per full kMaxLineLength block. + encodedSize += (encodedSize - 1) / kMaxLineLength * kNewline.size(); + return encodedSize; +} + } // namespace facebook::velox::encoding diff --git a/velox/common/encode/Base64.h b/velox/common/encode/Base64.h index e6f802489b27..7dca7d2fdbce 100644 --- a/velox/common/encode/Base64.h +++ b/velox/common/encode/Base64.h @@ -16,7 +16,6 @@ #pragma once -#include #include #include @@ -45,7 +44,7 @@ class Base64 { static std::string encode(const char* input, size_t inputSize); /// Encodes the specified text. - static std::string encode(folly::StringPiece text); + static std::string encode(std::string_view text); /// Encodes the specified IOBuf data. static std::string encode(const folly::IOBuf* inputBuffer); @@ -60,7 +59,7 @@ class Base64 { static std::string encodeUrl(const char* input, size_t inputSize); /// Encodes the specified text using URL encoding. - static std::string encodeUrl(folly::StringPiece text); + static std::string encodeUrl(std::string_view text); /// Encodes the specified IOBuf data using URL encoding. static std::string encodeUrl(const folly::IOBuf* inputBuffer); @@ -72,7 +71,7 @@ class Base64 { encodeUrl(const char* input, size_t inputSize, char* outputBuffer); /// Decodes the input Base64 encoded string. - static std::string decode(folly::StringPiece encodedText); + static std::string decode(std::string_view encodedText); /// Decodes the specified encoded payload and writes the result to the /// 'output'. @@ -94,7 +93,7 @@ class Base64 { size_t outputSize); /// Decodes the input Base64 URL encoded string. - static std::string decodeUrl(folly::StringPiece encodedText); + static std::string decodeUrl(std::string_view encodedText); /// Decodes the specified URL encoded payload and writes the result to the /// 'output'. @@ -110,6 +109,16 @@ class Base64 { char* outputBuffer, size_t outputSize); + /// Decodes a Base64 MIME‐mode buffer back to binary. + /// Skips any non-Base64 chars (e.g. CR/LF). + static Status + decodeMime(const char* input, size_t inputSize, char* outputBuffer); + + /// Encodes the input buffer into Base64 MIME format. + /// Inserts a CRLF every kMaxLineLength output characters. + static void + encodeMime(const char* input, size_t inputSize, char* outputBuffer); + /// Calculates the encoded size based on input 'inputSize'. static size_t calculateEncodedSize(size_t inputSize, bool withPadding = true); @@ -119,10 +128,25 @@ class Base64 { const char* input, size_t& inputSize); + /// Calculates the decoded binary size of a MIME‐mode Base64 input, + /// accounting for padding and ignoring whitespace. + static Expected calculateMimeDecodedSize( + const char* input, + const size_t inputSize); + + /// Computes the exact output length for MIME‐mode Base64 encoding, + /// including required CRLF line breaks. + static size_t calculateMimeEncodedSize(size_t inputSize); + private: // Padding character used in encoding. static const char kPadding = '='; + // Soft Line breaks used in mime encoding as defined in RFC 2045, section 6.8: + // https://www.rfc-editor.org/rfc/rfc2045#section-6.8 + inline static const std::string kNewline{"\r\n"}; + static const size_t kMaxLineLength = 76; + // Checks if the input Base64 string is padded. static inline bool isPadded(const char* input, size_t inputSize) { return (inputSize > 0 && input[inputSize - 1] == kPadding); diff --git a/velox/common/strings/ByteStream.h b/velox/common/encode/ByteStream.h similarity index 91% rename from velox/common/strings/ByteStream.h rename to velox/common/encode/ByteStream.h index 0a02b7a6a31a..9ebbee1054d7 100644 --- a/velox/common/strings/ByteStream.h +++ b/velox/common/encode/ByteStream.h @@ -13,13 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + +#pragma once + /** * Interfaces for efficient stream-like I/O. */ -#ifndef COMMON_STRINGS_BYTESTREAM_H_ -#define COMMON_STRINGS_BYTESTREAM_H_ - #include #include #include @@ -27,10 +27,8 @@ #include #include -#include -namespace facebook { -namespace strings { +namespace facebook::velox::strings { const size_t kSizeMax = std::numeric_limits::max(); @@ -88,28 +86,28 @@ class ByteSink { * append() will return 0). In particular, this may not be used for * non-blocking behavior. */ - virtual size_t append(folly::StringPiece str) = 0; + virtual size_t append(std::string_view str) = 0; size_t append(const void* data, size_t size) { - return append(folly::StringPiece(static_cast(data), size)); + return append(std::string_view(static_cast(data), size)); } /** - * Append the given string to this ByteSink. The string must remain + * Append the given string to this ByteSink. The string must remain * allocated (and unchanged) until the ByteSink is destroyed. */ - virtual size_t appendAllocated(folly::StringPiece str) { + virtual size_t appendAllocated(std::string_view str) { return append(str); } /** * Convenience function that appends the bitwise representation of count - * objects starting at address obj. The usual caveats about endianness, + * objects starting at address obj. The usual caveats about endianness, * padding apply. */ template size_t appendBitwise(const T* obj, size_t count) { const size_t sz = count * sizeof(T); - return append(folly::StringPiece(reinterpret_cast(obj), sz)); + return append(std::string_view(reinterpret_cast(obj), sz)); } /** @@ -187,8 +185,8 @@ class SByteSink : public ByteSink { public: explicit SByteSink(S* str) : str_(str) {} - size_t append(folly::StringPiece s) override { - str_->append(s.start(), s.size()); + size_t append(std::string_view s) override { + str_->append(s.data(), s.size()); return s.size(); } @@ -239,7 +237,7 @@ class ByteSource { * next() will return false, but bad() will also return false. On error, * next() returns false, and bad() returns true. */ - virtual bool next(folly::StringPiece* chunk) = 0; + virtual bool next(std::string_view* chunk) = 0; /** * Push back the last numBytes returned by the last next() call, so @@ -318,7 +316,7 @@ class ByteSourceBuffer : public std::basic_streambuf { class StringByteSource : public ByteSource { public: explicit StringByteSource( - const folly::StringPiece& str, + const std::string_view& str, size_t maxBytes = kSizeMax) : str_(str), offset_(0), @@ -328,15 +326,17 @@ class StringByteSource : public ByteSource { bool bad() const override { return false; } - bool next(folly::StringPiece* chunk) override { + + bool next(std::string_view* chunk) override { if (offset_ == str_.size()) { return false; } size_t len = std::min(str_.size() - offset_, maxBytes_); - chunk->reset(str_.start() + offset_, len); + *chunk = std::string_view(str_.data() + offset_, len); offset_ += len; return true; } + void backUp(size_t numBytes) override { CHECK_LE(numBytes, maxBytes_); CHECK_GE(offset_, numBytes); @@ -344,12 +344,9 @@ class StringByteSource : public ByteSource { } private: - folly::StringPiece str_; + std::string_view str_; size_t offset_; size_t maxBytes_; }; -} // namespace strings -} // namespace facebook - -#endif /* __COMMON_STRINGS_BYTESTREAM_H_ */ +} // namespace facebook::velox::strings diff --git a/velox/common/encode/CMakeLists.txt b/velox/common/encode/CMakeLists.txt index 15acfcc232a2..32875adc48e7 100644 --- a/velox/common/encode/CMakeLists.txt +++ b/velox/common/encode/CMakeLists.txt @@ -16,5 +16,5 @@ if(${VELOX_BUILD_TESTING}) add_subdirectory(tests) endif() -velox_add_library(velox_encode Base64.cpp) +velox_add_library(velox_encode Base64.cpp Base32.cpp) velox_link_libraries(velox_encode PUBLIC velox_status Folly::folly) diff --git a/velox/common/encode/Coding.h b/velox/common/encode/Coding.h index 284c226da901..e4edaac4bcde 100644 --- a/velox/common/encode/Coding.h +++ b/velox/common/encode/Coding.h @@ -13,22 +13,22 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + +#pragma once + // // Miscellaneous number encoding/decoding routines // - Varint coding // - ZigZag coding -#ifndef FACEBOOK_COMMON_ENCODE_CODING_H_ -#define FACEBOOK_COMMON_ENCODE_CODING_H_ - #include #include #include -#include + +#include "velox/common/encode/ByteStream.h" #include "velox/common/encode/UInt128.h" -#include "velox/common/strings/ByteStream.h" -namespace facebook { +namespace facebook::velox { // Variable-length integer encoding, using a little-endian, base-128 // representation. @@ -49,6 +49,7 @@ class Varint { static UInt128 shift(uint64_t a, uint64_t b, uint32_t bBits) { return (UInt128(a) << bBits) | b; } + static std::pair unshift(UInt128 val, uint32_t bBits) { std::pair p; p.second = val.lo() & ((static_cast(1) << bBits) - 1); @@ -108,19 +109,20 @@ class Varint { char buf[kMaxSize64]; char* p = buf; encode(val, &p); - sink->append(folly::StringPiece(buf, p - buf)); + sink->append(std::string_view(buf, p - buf)); } + static void encode128ToByteSink(UInt128 val, strings::ByteSink* sink) { char buf[kMaxSize128]; char* p = buf; encode128(val, &p); - sink->append(folly::StringPiece(buf, p - buf)); + sink->append(std::string_view(buf, p - buf)); } // Returns true if decode can be called without causing a CHECK failure. // The pointers are not adjusted at all - static bool canDecode(folly::StringPiece src) { - src = src.subpiece(0, kMaxSize64); + static bool canDecode(std::string_view src) { + src = src.substr(0, kMaxSize64); return std::any_of( src.begin(), src.end(), [](char v) { return ~v & 0x80; }); } @@ -163,6 +165,7 @@ class Varint { *src = p; return val; } + static UInt128 decode128(const char** src, int64_t max_size) { if (max_size > kMaxSize128) { // Varint-encoded numbers are at most 19 bytes, and we want to catch @@ -184,17 +187,18 @@ class Varint { return val; } - // Decode a value from a StringPiece, and advance the StringPiece. - static uint64_t decode(folly::StringPiece* data) { - const char* p = data->start(); + // Decode a value from a string_view, and advance it. + static uint64_t decode(std::string_view* data) { + const char* p = data->data(); uint64_t val = decode(&p, data->size()); - data->advance(p - data->start()); + data->remove_prefix(p - data->data()); return val; } - static UInt128 decode128(folly::StringPiece* data) { - const char* p = data->start(); + + static UInt128 decode128(std::string_view* data) { + const char* p = data->data(); UInt128 val = decode128(&p, data->size()); - data->advance(p - data->start()); + data->remove_prefix(p - data->data()); return val; } @@ -203,13 +207,13 @@ class Varint { uint64_t val = 0; int32_t shift = 0; int32_t max_size = kMaxSize64; - folly::StringPiece chunk; + std::string_view chunk; int32_t remaining = 0; const char* p = nullptr; for (;;) { if (remaining == 0) { CHECK(src->next(&chunk)); - p = chunk.start(); + p = chunk.data(); remaining = chunk.size(); DCHECK_GT(remaining, 0); } @@ -229,17 +233,18 @@ class Varint { } return val; } + static UInt128 decode128FromByteSource(strings::ByteSource* src) { UInt128 val = 0; int32_t shift = 0; int32_t max_size = kMaxSize128; - folly::StringPiece chunk; + std::string_view chunk; int32_t remaining = 0; const char* p = nullptr; for (;;) { if (remaining == 0) { CHECK(src->next(&chunk)); - p = chunk.start(); + p = chunk.data(); remaining = chunk.size(); DCHECK_GT(remaining, 0); } @@ -283,32 +288,30 @@ class ZigZag { } }; -namespace internal { +namespace detail { class ByteSinkAppender { public: /* implicit */ ByteSinkAppender(strings::ByteSink* out) : out_(out) {} - void operator()(folly::StringPiece sp) { + void operator()(std::string_view sp) { out_->append(sp.data(), sp.size()); } private: strings::ByteSink* out_; }; -} // namespace internal +} // namespace detail // Import GroupVarint encoding / decoding code from folly typedef folly::GroupVarint32 GroupVarint32; typedef folly::GroupVarint64 GroupVarint64; -typedef folly::GroupVarintEncoder +typedef folly::GroupVarintEncoder GroupVarint32Encoder; -typedef folly::GroupVarintEncoder +typedef folly::GroupVarintEncoder GroupVarint64Encoder; typedef folly::GroupVarintDecoder GroupVarint32Decoder; typedef folly::GroupVarintDecoder GroupVarint64Decoder; -} // namespace facebook - -#endif /* FACEBOOK_COMMON_ENCODE_CODING_H_ */ +} // namespace facebook::velox diff --git a/velox/common/encode/UInt128.h b/velox/common/encode/UInt128.h index 8167e1fe5eee..45c879e51525 100644 --- a/velox/common/encode/UInt128.h +++ b/velox/common/encode/UInt128.h @@ -108,12 +108,7 @@ class UInt128 { return UInt128(~hi_, ~lo_); } - bool operator==(UInt128 other) const { - return hi_ == other.hi_ && lo_ == other.lo_; - } - bool operator!=(UInt128 other) const { - return !(*this == other); - } + bool operator==(const UInt128& other) const = default; private: uint64_t hi_; diff --git a/velox/common/encode/tests/Base64Test.cpp b/velox/common/encode/tests/Base64Test.cpp index 37ddcf8240f8..bc78bd6c5019 100644 --- a/velox/common/encode/tests/Base64Test.cpp +++ b/velox/common/encode/tests/Base64Test.cpp @@ -25,27 +25,20 @@ namespace facebook::velox::encoding { class Base64Test : public ::testing::Test {}; TEST_F(Base64Test, fromBase64) { - EXPECT_EQ( - "Hello, World!", - Base64::decode(folly::StringPiece("SGVsbG8sIFdvcmxkIQ=="))); + EXPECT_EQ("Hello, World!", Base64::decode("SGVsbG8sIFdvcmxkIQ==")); EXPECT_EQ( "Base64 encoding is fun.", - Base64::decode(folly::StringPiece("QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4="))); - EXPECT_EQ( - "Simple text", Base64::decode(folly::StringPiece("U2ltcGxlIHRleHQ="))); - EXPECT_EQ( - "1234567890", Base64::decode(folly::StringPiece("MTIzNDU2Nzg5MA=="))); + Base64::decode("QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4=")); + EXPECT_EQ("Simple text", Base64::decode("U2ltcGxlIHRleHQ=")); + EXPECT_EQ("1234567890", Base64::decode("MTIzNDU2Nzg5MA==")); // Check encoded strings without padding - EXPECT_EQ( - "Hello, World!", - Base64::decode(folly::StringPiece("SGVsbG8sIFdvcmxkIQ"))); + EXPECT_EQ("Hello, World!", Base64::decode("SGVsbG8sIFdvcmxkIQ")); EXPECT_EQ( "Base64 encoding is fun.", - Base64::decode(folly::StringPiece("QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4"))); - EXPECT_EQ( - "Simple text", Base64::decode(folly::StringPiece("U2ltcGxlIHRleHQ"))); - EXPECT_EQ("1234567890", Base64::decode(folly::StringPiece("MTIzNDU2Nzg5MA"))); + Base64::decode("QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4")); + EXPECT_EQ("Simple text", Base64::decode("U2ltcGxlIHRleHQ")); + EXPECT_EQ("1234567890", Base64::decode("MTIzNDU2Nzg5MA")); } TEST_F(Base64Test, calculateDecodedSizeProperSize) { @@ -107,4 +100,75 @@ TEST_F(Base64Test, countsPaddingCorrectly) { EXPECT_EQ(1, Base64::numPadding("ABC=", 4)); EXPECT_EQ(2, Base64::numPadding("AB==", 4)); } + +TEST_F(Base64Test, calculateMimeDecodedSize) { + EXPECT_EQ(0, Base64::calculateMimeDecodedSize("", 0).value()); + EXPECT_EQ(0, Base64::calculateMimeDecodedSize("#", 1).value()); + EXPECT_EQ(3, Base64::calculateMimeDecodedSize("TWFu", 4).value()); + EXPECT_EQ(1, Base64::calculateMimeDecodedSize("AQ==", 4).value()); + EXPECT_EQ(2, Base64::calculateMimeDecodedSize("TWE=", 4).value()); + EXPECT_EQ(3, Base64::calculateMimeDecodedSize("TWFu\r\n", 6).value()); + EXPECT_EQ(3, Base64::calculateMimeDecodedSize("!TW!Fu!", 7).value()); + EXPECT_EQ(1, Base64::calculateMimeDecodedSize("TQ", 2).value()); + EXPECT_EQ( + Base64::calculateMimeDecodedSize("A", 1).error(), + Status::UserError( + "Input should at least have 2 bytes for base64 bytes.")); +} + +TEST_F(Base64Test, decodeMime) { + auto decodeMime = [](const std::string& in) { + size_t decSize = + Base64::calculateMimeDecodedSize(in.data(), in.size()).value(); + std::string out(decSize, '\0'); + auto result = Base64::decodeMime(in.data(), in.size(), out.data()); + if (!result.ok()) { + VELOX_USER_FAIL(result.message()); + } + return out; + }; + EXPECT_EQ("", decodeMime("")); + EXPECT_EQ("Man", decodeMime("TWFu")); + EXPECT_EQ("ManMan", decodeMime("TWFu\r\nTWFu")); + EXPECT_EQ("\x01", decodeMime("AQ==")); + EXPECT_EQ("\xff\xee", decodeMime("/+4=")); + VELOX_ASSERT_USER_THROW( + decodeMime("QUFBx"), "Last unit does not have enough valid bits"); + VELOX_ASSERT_USER_THROW( + decodeMime("xx=y"), "Input byte array has wrong 4-byte ending unit"); + VELOX_ASSERT_USER_THROW( + decodeMime("xx="), "Input byte array has wrong 4-byte ending unit"); + VELOX_ASSERT_USER_THROW( + decodeMime("QUFB="), "Input byte array has wrong 4-byte ending unit"); + VELOX_ASSERT_USER_THROW( + decodeMime("AQ==y"), "Input byte array has incorrect ending"); +} + +TEST_F(Base64Test, calculateMimeEncodedSize) { + EXPECT_EQ(0, Base64::calculateMimeEncodedSize(0)); + EXPECT_EQ(8, Base64::calculateMimeEncodedSize(4)); + EXPECT_EQ(76, Base64::calculateMimeEncodedSize(57)); + EXPECT_EQ(82, Base64::calculateMimeEncodedSize(58)); + EXPECT_EQ(274, Base64::calculateMimeEncodedSize(200)); +} + +TEST_F(Base64Test, encodeMime) { + auto encodeMime = [](const std::string& in) { + size_t len = Base64::calculateMimeEncodedSize(in.size()); + std::string out(len, '\0'); + Base64::encodeMime(in.data(), in.size(), out.data()); + return out; + }; + EXPECT_EQ("", encodeMime("")); + EXPECT_EQ("TWFu", encodeMime("Man")); + EXPECT_EQ("AQ==", encodeMime("\x01")); + EXPECT_EQ("/+4=", encodeMime("\xff\xee")); + EXPECT_EQ( + "QUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFB", + encodeMime(std::string(57, 'A'))); + EXPECT_EQ( + "QUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFB\r\nQQ==", + encodeMime(std::string(58, 'A'))); +} + } // namespace facebook::velox::encoding diff --git a/velox/common/encode/tests/CMakeLists.txt b/velox/common/encode/tests/CMakeLists.txt index 5a72d9c55382..9acd50159563 100644 --- a/velox/common/encode/tests/CMakeLists.txt +++ b/velox/common/encode/tests/CMakeLists.txt @@ -17,9 +17,5 @@ add_test(velox_common_encode_test velox_common_encode_test) target_link_libraries( velox_common_encode_test PUBLIC Folly::folly - PRIVATE - velox_encode - velox_exception - velox_status - GTest::gtest - GTest::gtest_main) + PRIVATE velox_encode velox_exception velox_status velox_common_base GTest::gtest GTest::gtest_main +) diff --git a/velox/common/encode/tests/ZigZagTest.cpp b/velox/common/encode/tests/ZigZagTest.cpp index a69d6326056d..e7a02add7608 100644 --- a/velox/common/encode/tests/ZigZagTest.cpp +++ b/velox/common/encode/tests/ZigZagTest.cpp @@ -15,7 +15,6 @@ */ #include -#include "velox/common/base/tests/GTestUtils.h" #include "velox/common/encode/Coding.h" #include "velox/type/HugeInt.h" diff --git a/velox/common/file/CMakeLists.txt b/velox/common/file/CMakeLists.txt index 31d9f1ebe0ff..b89a4e3a38e8 100644 --- a/velox/common/file/CMakeLists.txt +++ b/velox/common/file/CMakeLists.txt @@ -14,16 +14,12 @@ # for generated headers include_directories(.) -velox_add_library( - velox_file - File.cpp - FileInputStream.cpp - FileSystems.cpp - Utils.cpp) +velox_add_library(velox_file File.cpp FileInputStream.cpp FileSystems.cpp Utils.cpp) velox_link_libraries( velox_file PUBLIC velox_exception Folly::folly - PRIVATE velox_buffer velox_common_base fmt::fmt glog::glog) + PRIVATE velox_buffer velox_common_base fmt::fmt glog::glog +) if(${VELOX_BUILD_TESTING} OR ${VELOX_BUILD_TEST_UTILS}) add_subdirectory(tests) diff --git a/velox/common/file/File.cpp b/velox/common/file/File.cpp index c1f1e8982737..668cba32cede 100644 --- a/velox/common/file/File.cpp +++ b/velox/common/file/File.cpp @@ -60,10 +60,10 @@ T getAttribute( std::string ReadFile::pread( uint64_t offset, uint64_t length, - filesystems::File::IoStats* stats) const { + const FileStorageContext& fileStorageContext) const { std::string buf; buf.resize(length); - auto res = pread(offset, length, buf.data(), stats); + auto res = pread(offset, length, buf.data(), fileStorageContext); buf.resize(res.size()); return buf; } @@ -71,7 +71,7 @@ std::string ReadFile::pread( uint64_t ReadFile::preadv( uint64_t offset, const std::vector>& buffers, - filesystems::File::IoStats* stats) const { + const FileStorageContext& fileStorageContext) const { auto fileSize = size(); uint64_t numRead = 0; if (offset >= fileSize) { @@ -81,7 +81,7 @@ uint64_t ReadFile::preadv( auto copySize = std::min(range.size(), fileSize - offset); // NOTE: skip the gap in case of coalesce io. if (range.data() != nullptr) { - pread(offset, copySize, range.data(), stats); + pread(offset, copySize, range.data(), fileStorageContext); } offset += copySize; numRead += copySize; @@ -92,18 +92,21 @@ uint64_t ReadFile::preadv( uint64_t ReadFile::preadv( folly::Range regions, folly::Range iobufs, - filesystems::File::IoStats* stats) const { + const FileStorageContext& fileStorageContext) const { VELOX_CHECK_EQ(regions.size(), iobufs.size()); uint64_t length = 0; for (size_t i = 0; i < regions.size(); ++i) { const auto& region = regions[i]; auto& output = iobufs[i]; output = folly::IOBuf(folly::IOBuf::CREATE, region.length); - pread(region.offset, region.length, output.writableData(), stats); + pread( + region.offset, + region.length, + output.writableData(), + fileStorageContext); output.append(region.length); length += region.length; } - return length; } @@ -111,7 +114,7 @@ std::string_view InMemoryReadFile::pread( uint64_t offset, uint64_t length, void* buf, - filesystems::File::IoStats* stats) const { + const FileStorageContext& fileStorageContext) const { bytesRead_ += length; memcpy(buf, file_.data() + offset, length); return {static_cast(buf), length}; @@ -120,7 +123,7 @@ std::string_view InMemoryReadFile::pread( std::string InMemoryReadFile::pread( uint64_t offset, uint64_t length, - filesystems::File::IoStats* stats) const { + const FileStorageContext& fileStorageContext) const { bytesRead_ += length; return std::string(file_.data() + offset, length); } @@ -202,7 +205,7 @@ std::string_view LocalReadFile::pread( uint64_t offset, uint64_t length, void* buf, - filesystems::File::IoStats* stats) const { + const FileStorageContext& fileStorageContext) const { preadInternal(offset, length, static_cast(buf)); return {static_cast(buf), length}; } @@ -210,7 +213,7 @@ std::string_view LocalReadFile::pread( uint64_t LocalReadFile::preadv( uint64_t offset, const std::vector>& buffers, - filesystems::File::IoStats* stats) const { + const FileStorageContext& fileStorageContext) const { // Dropped bytes sized so that a typical dropped range of 50K is not // too many iovecs. static thread_local std::vector droppedBytes(16 * 1024); @@ -267,17 +270,18 @@ uint64_t LocalReadFile::preadv( folly::SemiFuture LocalReadFile::preadvAsync( uint64_t offset, const std::vector>& buffers, - filesystems::File::IoStats* stats) const { + const FileStorageContext& fileStorageContext) const { if (!executor_) { - return ReadFile::preadvAsync(offset, buffers, stats); + return ReadFile::preadvAsync(offset, buffers, fileStorageContext); } auto [promise, future] = folly::makePromiseContract(); executor_->add([this, _promise = std::move(promise), _offset = offset, _buffers = buffers, - _stats = stats]() mutable { - auto delegateFuture = ReadFile::preadvAsync(_offset, _buffers, _stats); + _fileStorageContext = fileStorageContext]() mutable { + auto delegateFuture = + ReadFile::preadvAsync(_offset, _buffers, _fileStorageContext); _promise.setTry(std::move(delegateFuture).getTry()); }); return std::move(future); diff --git a/velox/common/file/File.h b/velox/common/file/File.h index 18d1c264ca7a..6089915dffac 100644 --- a/velox/common/file/File.h +++ b/velox/common/file/File.h @@ -37,14 +37,32 @@ #include #include #include +#include #include "velox/common/base/Exceptions.h" #include "velox/common/file/FileSystems.h" #include "velox/common/file/Region.h" #include "velox/common/io/IoStatistics.h" +#include + namespace facebook::velox { +struct FileStorageContext { + /// Stats for IO operations + filesystems::File::IoStats* ioStats{nullptr}; + + /// Options for file read operations + folly::F14FastMap fileReadOps; + + FileStorageContext() = default; + + FileStorageContext( + filesystems::File::IoStats* stats, + folly::F14FastMap fileReadOps = {}) + : ioStats(stats), fileReadOps(std::move(fileReadOps)) {} +}; + // A read-only file. All methods in this object should be thread safe. class ReadFile { public: @@ -52,16 +70,12 @@ class ReadFile { // Reads the data at [offset, offset + length) into the provided pre-allocated // buffer 'buf'. The bytes are returned as a string_view pointing to 'buf'. - // - // 'stats' is an IoStatistics pointer passed in by the caller to collect stats - // for this read operation. - // // This method should be thread safe. virtual std::string_view pread( uint64_t offset, uint64_t length, void* buf, - filesystems::File::IoStats* stats = nullptr) const = 0; + const FileStorageContext& fileStorageContext = {}) const = 0; // Same as above, but returns owned data directly. // @@ -69,20 +83,16 @@ class ReadFile { virtual std::string pread( uint64_t offset, uint64_t length, - filesystems::File::IoStats* stats = nullptr) const; + const FileStorageContext& fileStorageContext = {}) const; // Reads starting at 'offset' into the memory referenced by the // Ranges in 'buffers'. The buffers are filled left to right. A // buffer with nullptr data will cause its size worth of bytes to be skipped. - // - // 'stats' is an IoStatistics pointer passed in by the caller to collect stats - // for this read operation. - // // This method should be thread safe. virtual uint64_t preadv( uint64_t /*offset*/, const std::vector>& /*buffers*/, - filesystems::File::IoStats* stats = nullptr) const; + const FileStorageContext& fileStorageContext = {}) const; // Vectorized read API. Implementations can coalesce and parallelize. // The offsets don't need to be sorted. @@ -93,30 +103,23 @@ class ReadFile { // by the preadv. // Returns the total number of bytes read, which might be different than the // sum of all buffer sizes (for example, if coalescing was used). - // - // 'stats' is an IoStatistics pointer passed in by the caller to collect stats - // for this read operation. - // // This method should be thread safe. virtual uint64_t preadv( folly::Range regions, folly::Range iobufs, - filesystems::File::IoStats* stats = nullptr) const; + const FileStorageContext& fileStorageContext = {}) const; /// Like preadv but may execute asynchronously and returns the read size or /// exception via SemiFuture. Use hasPreadvAsync() to check if the /// implementation is in fact asynchronous. - /// - /// 'stats' is an IoStatistics pointer passed in by the caller to collect - /// stats for this read operation. - /// /// This method should be thread safe. virtual folly::SemiFuture preadvAsync( uint64_t offset, const std::vector>& buffers, - filesystems::File::IoStats* stats = nullptr) const { + const FileStorageContext& fileStorageContext = {}) const { try { - return folly::SemiFuture(preadv(offset, buffers, stats)); + return folly::SemiFuture( + preadv(offset, buffers, fileStorageContext)); } catch (const std::exception& e) { return folly::makeSemiFuture(e); } @@ -240,12 +243,12 @@ class InMemoryReadFile : public ReadFile { uint64_t offset, uint64_t length, void* buf, - filesystems::File::IoStats* stats = nullptr) const override; + const FileStorageContext& fileStorageContext = {}) const override; std::string pread( uint64_t offset, uint64_t length, - filesystems::File::IoStats* stats) const override; + const FileStorageContext& fileStorageContext = {}) const override; uint64_t size() const final { return file_.size(); @@ -311,19 +314,19 @@ class LocalReadFile final : public ReadFile { uint64_t offset, uint64_t length, void* buf, - filesystems::File::IoStats* stats = nullptr) const final; + const FileStorageContext& fileStorageContext = {}) const final; uint64_t size() const final; uint64_t preadv( uint64_t offset, const std::vector>& buffers, - filesystems::File::IoStats* stats = nullptr) const final; + const FileStorageContext& fileStorageContext = {}) const final; folly::SemiFuture preadvAsync( uint64_t offset, const std::vector>& buffers, - filesystems::File::IoStats* stats = nullptr) const override; + const FileStorageContext& fileStorageContext = {}) const override; bool hasPreadvAsync() const override { return executor_ != nullptr; diff --git a/velox/common/file/FileSystems.cpp b/velox/common/file/FileSystems.cpp index aa738b6a43ba..d5d07ceb943b 100644 --- a/velox/common/file/FileSystems.cpp +++ b/velox/common/file/FileSystems.cpp @@ -17,6 +17,7 @@ #include "velox/common/file/FileSystems.h" #include #include +#include #include "velox/common/base/Exceptions.h" #include "velox/common/file/File.h" @@ -90,7 +91,7 @@ class LocalFileSystem : public FileSystem { std::max( 1, static_cast( - std::thread::hardware_concurrency() / 2)), + folly::hardware_concurrency() / 2)), std::make_shared( "LocalReadahead")) : nullptr) {} @@ -229,7 +230,7 @@ class LocalFileSystem : public FileSystem { // Note: presto behavior is to prefix local paths with 'file:'. // Check for that prefix and prune to absolute regular paths as needed. return [](std::string_view filePath) { - return filePath.find("/") == 0 || filePath.find(kFileScheme) == 0; + return filePath.starts_with('/') || filePath.starts_with(kFileScheme); }; } diff --git a/velox/common/file/FileSystems.h b/velox/common/file/FileSystems.h index a1173190ba56..4362dcea4914 100644 --- a/velox/common/file/FileSystems.h +++ b/velox/common/file/FileSystems.h @@ -17,6 +17,7 @@ #include "velox/common/base/Exceptions.h" #include "velox/common/base/RuntimeMetrics.h" +#include "velox/common/file/TokenProvider.h" #include "velox/common/memory/MemoryPool.h" #include @@ -47,10 +48,10 @@ struct FileOptions { /// etc. static constexpr folly::StringPiece kFileCreateConfig{"file-create-config"}; - std::unordered_map values; + std::unordered_map values{}; memory::MemoryPool* pool{nullptr}; /// If specified then can be trusted to be the file size. - std::optional fileSize; + std::optional fileSize{}; /// Whether to create parent directories if they don't exist. /// @@ -84,6 +85,13 @@ struct FileOptions { /// A hint to the file system for which region size of the file should be /// read. Specifically, the read length. std::optional readRangeHint{std::nullopt}; + + /// A token provider that can be used to get tokens for accessing the file. + std::shared_ptr tokenProvider{nullptr}; + + /// File read operations metadata that can be passed to the underlying file + /// system for tracking and logging purposes. + folly::F14FastMap fileReadOps{}; }; /// Defines directory options diff --git a/velox/common/file/PlainUserNameTokenProvider.h b/velox/common/file/PlainUserNameTokenProvider.h new file mode 100644 index 000000000000..acbc8c38748c --- /dev/null +++ b/velox/common/file/PlainUserNameTokenProvider.h @@ -0,0 +1,55 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include "velox/common/file/TokenProvider.h" + +namespace facebook::velox::filesystems { + +class PlainUserNameAccessToken : public filesystems::AccessToken { + public: + explicit PlainUserNameAccessToken(const std::string& user) : user_(user) {} + ~PlainUserNameAccessToken() override = default; + std::string getUser() const { + return user_; + } + + private: + std::string user_; +}; + +class PlainUserNameTokenProvider : public filesystems::TokenProvider { + public: + explicit PlainUserNameTokenProvider(const std::string& user) : user_(user) {} + bool equals(const TokenProvider& other) const override { + auto* o = dynamic_cast(&other); + return o && o->user_ == user_; + } + size_t hash() const override { + return std::hash()(user_); + } + std::shared_ptr getToken( + const filesystems::AccessTokenKey& /*key*/) const override { + return std::make_shared(user_); + } + + private: + std::string user_; +}; + +} // namespace facebook::velox::filesystems diff --git a/velox/common/file/Region.h b/velox/common/file/Region.h index ae2ef9c8a2b2..5d588479b769 100644 --- a/velox/common/file/Region.h +++ b/velox/common/file/Region.h @@ -17,8 +17,13 @@ #pragma once #include +#include #include +#include + +#include "velox/common/base/SuccinctPrinter.h" + namespace facebook::velox::common { /// Defines a disk region to read. @@ -35,6 +40,14 @@ struct Region { return offset < other.offset || (offset == other.offset && length < other.length); } + + std::string toString() const { + return fmt::format( + "Region{{offset: {}, length: {}, label: {}}}", + offset, + succinctBytes(length), + label); + } }; } // namespace facebook::velox::common diff --git a/velox/common/file/TokenProvider.h b/velox/common/file/TokenProvider.h new file mode 100644 index 000000000000..e262fe8d0b99 --- /dev/null +++ b/velox/common/file/TokenProvider.h @@ -0,0 +1,51 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace facebook::velox::filesystems { + +/// Identifier for the file systems to implement to differentiate different +/// tokens needed in the same query (user). Such information usually needs to +/// be passed down and stored in the ReadFile/WriteFile object of the specific +/// file system. +class AccessTokenKey { + public: + virtual ~AccessTokenKey() = default; +}; + +/// Abstract token each file system can implement and cast. +class AccessToken { + public: + virtual ~AccessToken() = default; +}; + +/// Interface for providing access tokens to file systems. +class TokenProvider { + public: + virtual ~TokenProvider() = default; + + virtual bool equals(const TokenProvider& other) const = 0; + virtual size_t hash() const = 0; + + virtual std::shared_ptr getToken( + const AccessTokenKey& key) const = 0; +}; + +} // namespace facebook::velox::filesystems diff --git a/velox/common/file/Utils.h b/velox/common/file/Utils.h index b19468fbe1fe..1d04d5608066 100644 --- a/velox/common/file/Utils.h +++ b/velox/common/file/Utils.h @@ -44,10 +44,6 @@ class CoalesceRegions { return lhs.begin_ == rhs.begin_ && lhs.end_ == rhs.end_; } - friend bool operator!=(const Iter& lhs, const Iter& rhs) { - return !(lhs == rhs); - } - std::pair operator*() const { return {begin_, end_}; } diff --git a/velox/common/file/tests/CMakeLists.txt b/velox/common/file/tests/CMakeLists.txt index fb9ef6f73521..47b3c11e3d2b 100644 --- a/velox/common/file/tests/CMakeLists.txt +++ b/velox/common/file/tests/CMakeLists.txt @@ -12,15 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -add_library(velox_file_test_utils TestUtils.cpp FaultyFile.cpp - FaultyFileSystem.cpp) +add_library(velox_file_test_utils TestUtils.cpp FaultyFile.cpp FaultyFileSystem.cpp) -target_link_libraries( - velox_file_test_utils - PUBLIC velox_file) +target_link_libraries(velox_file_test_utils PUBLIC velox_file) -add_executable(velox_file_test FileTest.cpp FileInputStreamTest.cpp - UtilsTest.cpp) +add_executable(velox_file_test FileTest.cpp FileInputStreamTest.cpp UtilsTest.cpp) add_test(velox_file_test velox_file_test) target_link_libraries( velox_file_test @@ -31,4 +27,5 @@ target_link_libraries( velox_temp_path GTest::gmock GTest::gtest - GTest::gtest_main) + GTest::gtest_main +) diff --git a/velox/common/file/tests/FaultyFile.cpp b/velox/common/file/tests/FaultyFile.cpp index 17897fa99214..434028a6750d 100644 --- a/velox/common/file/tests/FaultyFile.cpp +++ b/velox/common/file/tests/FaultyFile.cpp @@ -34,7 +34,7 @@ std::string_view FaultyReadFile::pread( uint64_t offset, uint64_t length, void* buf, - filesystems::File::IoStats* stats) const { + const FileStorageContext& fileStorageContext) const { if (injectionHook_ != nullptr) { FaultFileReadOperation op(path_, offset, length, buf); injectionHook_(&op); @@ -42,13 +42,13 @@ std::string_view FaultyReadFile::pread( return std::string_view(static_cast(op.buf), op.length); } } - return delegatedFile_->pread(offset, length, buf, stats); + return delegatedFile_->pread(offset, length, buf, fileStorageContext); } uint64_t FaultyReadFile::preadv( uint64_t offset, const std::vector>& buffers, - filesystems::File::IoStats* stats) const { + const FileStorageContext& fileStorageContext) const { if (injectionHook_ != nullptr) { FaultFileReadvOperation op(path_, offset, buffers); injectionHook_(&op); @@ -56,16 +56,16 @@ uint64_t FaultyReadFile::preadv( return op.readBytes; } } - return delegatedFile_->preadv(offset, buffers, stats); + return delegatedFile_->preadv(offset, buffers, fileStorageContext); } folly::SemiFuture FaultyReadFile::preadvAsync( uint64_t offset, const std::vector>& buffers, - filesystems::File::IoStats* stats) const { + const FileStorageContext& fileStorageContext) const { // TODO: add fault injection for async read later. if (delegatedFile_->hasPreadvAsync() || executor_ == nullptr) { - return delegatedFile_->preadvAsync(offset, buffers, stats); + return delegatedFile_->preadvAsync(offset, buffers, fileStorageContext); } auto promise = std::make_unique>(); folly::SemiFuture future = promise->getSemiFuture(); @@ -73,9 +73,9 @@ folly::SemiFuture FaultyReadFile::preadvAsync( _promise = std::move(promise), _offset = offset, _buffers = buffers, - _stats = stats]() { + _fileStorageContext = fileStorageContext]() { auto delegateFuture = - delegatedFile_->preadvAsync(_offset, _buffers, _stats); + delegatedFile_->preadvAsync(_offset, _buffers, _fileStorageContext); _promise->setValue(delegateFuture.wait().value()); }); return future; diff --git a/velox/common/file/tests/FaultyFile.h b/velox/common/file/tests/FaultyFile.h index 2b4818bd7a16..8fac903b0f6e 100644 --- a/velox/common/file/tests/FaultyFile.h +++ b/velox/common/file/tests/FaultyFile.h @@ -29,7 +29,7 @@ class FaultyReadFile : public ReadFile { FileFaultInjectionHook injectionHook, folly::Executor* executor); - ~FaultyReadFile() override{}; + ~FaultyReadFile() override {} uint64_t size() const override { return delegatedFile_->size(); @@ -39,12 +39,12 @@ class FaultyReadFile : public ReadFile { uint64_t offset, uint64_t length, void* buf, - filesystems::File::IoStats* stats = nullptr) const override; + const FileStorageContext& fileStorageContext = {}) const override; uint64_t preadv( uint64_t offset, const std::vector>& buffers, - filesystems::File::IoStats* stats = nullptr) const override; + const FileStorageContext& fileStorageContext = {}) const override; uint64_t memoryUsage() const override { return delegatedFile_->memoryUsage(); @@ -72,7 +72,7 @@ class FaultyReadFile : public ReadFile { folly::SemiFuture preadvAsync( uint64_t offset, const std::vector>& buffers, - filesystems::File::IoStats* stats = nullptr) const override; + const FileStorageContext& fileStorageContext = {}) const override; private: const std::string path_; @@ -88,7 +88,7 @@ class FaultyWriteFile : public WriteFile { std::shared_ptr delegatedFile, FileFaultInjectionHook injectionHook); - ~FaultyWriteFile() override{}; + ~FaultyWriteFile() override {} void append(std::string_view data) override; diff --git a/velox/common/file/tests/FaultyFileSystem.h b/velox/common/file/tests/FaultyFileSystem.h index 1dbd698df66b..c14cc7201ee1 100644 --- a/velox/common/file/tests/FaultyFileSystem.h +++ b/velox/common/file/tests/FaultyFileSystem.h @@ -176,7 +176,7 @@ class FaultyFileSystem : public FileSystem { mutable std::mutex mu_; std::optional fileInjections_; std::optional fsInjections_; - folly::Executor* executor_; + folly::Executor* executor_{nullptr}; }; /// Registers the faulty filesystem. diff --git a/velox/common/file/tests/FileTest.cpp b/velox/common/file/tests/FileTest.cpp index 6194fbcc35ba..fb819d63e90c 100644 --- a/velox/common/file/tests/FileTest.cpp +++ b/velox/common/file/tests/FileTest.cpp @@ -16,6 +16,7 @@ #include #include +#include #include "velox/common/base/tests/GTestUtils.h" #include "velox/common/file/File.h" @@ -168,8 +169,9 @@ TEST(InMemoryFile, preadv) { std::vector values; values.reserve(iobufs.size()); for (auto& iobuf : iobufs) { - values.push_back(std::string{ - reinterpret_cast(iobuf.data()), iobuf.length()}); + values.push_back( + std::string{ + reinterpret_cast(iobuf.data()), iobuf.length()}); } EXPECT_EQ(expected, values); @@ -187,9 +189,7 @@ class LocalFileTest : public ::testing::TestWithParam { const bool useFaultyFs_; const std::unique_ptr executor_ = std::make_unique( - std::max( - 1, - static_cast(std::thread::hardware_concurrency() / 2)), + std::max(1, static_cast(folly::hardware_concurrency() / 2)), std::make_shared( "LocalFileReadAheadTest")); }; diff --git a/velox/common/file/tests/PlainUserNameTokenProviderTest.cpp b/velox/common/file/tests/PlainUserNameTokenProviderTest.cpp new file mode 100644 index 000000000000..a49c749bdb93 --- /dev/null +++ b/velox/common/file/tests/PlainUserNameTokenProviderTest.cpp @@ -0,0 +1,56 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/common/file/PlainUserNameTokenProvider.h" +#include +#include "velox/common/file/TokenProvider.h" + +using namespace ::testing; +namespace facebook::velox::filesystems { +TEST(PlainUserNameTokenProviderTest, testTokenProviderUserName) { + filesystems::AccessTokenKey key; + auto tokenProvider = + std::make_shared("test_user"); + auto baseToken = tokenProvider->getToken(key); + auto userToken = + std::dynamic_pointer_cast(baseToken); + ASSERT_EQ(userToken->getUser(), "test_user"); +} + +TEST(PlainUserNameTokenProviderTest, testTokenProviderEquals) { + auto tokenProvider1 = + std::make_shared("test_user_1"); + auto tokenProvider2 = tokenProvider1; + auto tokenProvider3 = + std::make_shared("test_user_1"); + auto tokenProvider4 = + std::make_shared("test_user_2"); + ASSERT_TRUE(tokenProvider1->equals(*tokenProvider2)); + ASSERT_TRUE(tokenProvider1->equals(*tokenProvider3)); + ASSERT_FALSE(tokenProvider1->equals(*tokenProvider4)); +} + +TEST(PlainUserNameTokenProviderTest, testTokenProviderHash) { + auto tokenProvider1 = + std::make_shared("test_user_1"); + auto tokenProvider2 = + std::make_shared("test_user_1"); + auto tokenProvider3 = + std::make_shared("test_user_2"); + ASSERT_EQ(tokenProvider1->hash(), tokenProvider2->hash()); + ASSERT_NE(tokenProvider1->hash(), tokenProvider3->hash()); +} +} // namespace facebook::velox::filesystems diff --git a/velox/common/file/tests/UtilsTest.cpp b/velox/common/file/tests/UtilsTest.cpp index a2b2d6ee5a04..f7d2e1ffa8a3 100644 --- a/velox/common/file/tests/UtilsTest.cpp +++ b/velox/common/file/tests/UtilsTest.cpp @@ -17,6 +17,7 @@ #include #include +#include "velox/common/file/Region.h" #include "velox/common/file/Utils.h" #include "velox/common/file/tests/TestUtils.h" @@ -98,11 +99,12 @@ auto getReader( /* minTailRoom*/ 0); for (size_t i = 1; i < size; ++i) { - head->appendToChain(folly::IOBuf::copyBuffer( - buf.data() + offset + i, - /* size */ 1, - /* headroom */ 0, - /* minTailRoom*/ 0)); + head->appendToChain( + folly::IOBuf::copyBuffer( + buf.data() + offset + i, + /* size */ 1, + /* headroom */ 0, + /* minTailRoom*/ 0)); } return head; } else { @@ -289,3 +291,19 @@ INSTANTIATE_TEST_SUITE_P( ReadToIOBufsTest, ValuesIn( std::vector({false, true}))); + +TEST(RegionTest, toString) { + EXPECT_EQ(Region(0, 0).toString(), "Region{offset: 0, length: 0B, label: }"); + EXPECT_EQ( + Region(100, 256).toString(), + "Region{offset: 100, length: 256B, label: }"); + EXPECT_EQ( + Region(1024, 1024, "test").toString(), + "Region{offset: 1024, length: 1.00KB, label: test}"); + EXPECT_EQ( + Region(0, 1'048'576, "stream").toString(), + "Region{offset: 0, length: 1.00MB, label: stream}"); + EXPECT_EQ( + Region(12345, 1'073'741'824).toString(), + "Region{offset: 12345, length: 1.00GB, label: }"); +} diff --git a/scripts/docs-requirements.txt b/velox/common/future/CMakeLists.txt similarity index 91% rename from scripts/docs-requirements.txt rename to velox/common/future/CMakeLists.txt index 955db01b172d..a598690b32e5 100644 --- a/scripts/docs-requirements.txt +++ b/velox/common/future/CMakeLists.txt @@ -11,8 +11,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -sphinx -sphinx-tabs -breathe -sphinx_rtd_theme -chardet +velox_install_library_headers() diff --git a/velox/common/future/VeloxPromise.h b/velox/common/future/VeloxPromise.h index ea0d02fbef9c..7ac170a2fc11 100644 --- a/velox/common/future/VeloxPromise.h +++ b/velox/common/future/VeloxPromise.h @@ -28,12 +28,15 @@ class VeloxPromise : public folly::Promise { VeloxPromise() : folly::Promise() {} explicit VeloxPromise(const std::string& context) - : folly::Promise(), context_(context) {} + : folly::Promise(), context_(context) { + if (context.empty()) { + LOG(WARNING) + << "PROMISE: VeloxPromise must be constructed with a context."; + } + } - VeloxPromise( - folly::futures::detail::EmptyConstruct, - const std::string& context) noexcept - : folly::Promise(folly::Promise::makeEmpty()), context_(context) {} + explicit VeloxPromise(folly::futures::detail::EmptyConstruct) noexcept + : folly::Promise(folly::Promise::makeEmpty()) {} ~VeloxPromise() { if (!this->isFulfilled()) { @@ -42,7 +45,7 @@ class VeloxPromise : public folly::Promise { } } - explicit VeloxPromise(VeloxPromise&& other) noexcept + VeloxPromise(VeloxPromise&& other) noexcept : folly::Promise(std::move(other)), context_(std::move(other.context_)) {} @@ -52,8 +55,8 @@ class VeloxPromise : public folly::Promise { return *this; } - static VeloxPromise makeEmpty(const std::string& context = "") noexcept { - return VeloxPromise(folly::futures::detail::EmptyConstruct{}, context); + static VeloxPromise makeEmpty() noexcept { + return VeloxPromise(folly::futures::detail::EmptyConstruct{}); } private: @@ -65,8 +68,14 @@ using ContinuePromise = VeloxPromise; using ContinueFuture = folly::SemiFuture; /// Equivalent of folly's makePromiseContract for VeloxPromise. +/// +/// NOTE: When you already have a valid promise, just call +/// Promise::getSemiFuture() on it to get the future, instead of using this +/// function to overwrite the promise. Overwriting valid promise would cause +/// exception throwing and stack unwinding thus performance issue. See +/// https://github.com/prestodb/presto/issues/26094 for details. static inline std::pair -makeVeloxContinuePromiseContract(const std::string& promiseContext = "") { +makeVeloxContinuePromiseContract(const std::string& promiseContext) { auto p = ContinuePromise(promiseContext); auto f = p.getSemiFuture(); return std::make_pair(std::move(p), std::move(f)); diff --git a/velox/common/fuzzer/CMakeLists.txt b/velox/common/fuzzer/CMakeLists.txt index 1ecff4348f55..d2e3663b73d0 100644 --- a/velox/common/fuzzer/CMakeLists.txt +++ b/velox/common/fuzzer/CMakeLists.txt @@ -23,7 +23,8 @@ velox_link_libraries( Folly::folly velox_type velox_common_fuzzer_util - velox_functions_prestosql) + velox_functions_prestosql +) if(${VELOX_BUILD_TESTING}) add_subdirectory(tests) diff --git a/velox/common/fuzzer/ConstrainedGenerators.cpp b/velox/common/fuzzer/ConstrainedGenerators.cpp index c5c65dbc5031..2e5daddc1742 100644 --- a/velox/common/fuzzer/ConstrainedGenerators.cpp +++ b/velox/common/fuzzer/ConstrainedGenerators.cpp @@ -14,11 +14,11 @@ * limitations under the License. */ -#include -#include - #include "velox/common/fuzzer/ConstrainedGenerators.h" +#include #include "velox/common/fuzzer/Utils.h" +#include "velox/common/memory/HashStringAllocator.h" +#include "velox/functions/lib/SetDigest.h" #include "velox/functions/lib/TDigest.h" #include "velox/functions/prestosql/types/BingTileType.h" @@ -49,13 +49,21 @@ folly::json::serialization_opts getSerializationOptions( folly::json::serialization_opts opts; opts.allow_non_string_keys = true; opts.allow_nan_inf = true; + opts.sort_keys = true; if (makeRandomVariation) { opts.convert_int_keys = rand(rng); opts.pretty_formatting = rand(rng); opts.pretty_formatting_indent_width = rand(rng, 0, 4); opts.encode_non_ascii = rand(rng); - opts.sort_keys = rand(rng); opts.skip_invalid_utf8 = rand(rng); + + // With 50% chance, sort object keys in reverse order. + if (rand(rng)) { + opts.sort_keys_by = [](folly::dynamic const& left, + folly::dynamic const& right) { + return right < left; + }; + } } return opts; } @@ -104,9 +112,9 @@ folly::dynamic JsonInputGenerator::convertVariantToDynamic( case TypeKind::BIGINT: return convertVariantToDynamicPrimitive(object); case TypeKind::REAL: - return convertVariantToDynamicPrimitive(object); + return convertVariantToDynamicFloatingPoint(object); case TypeKind::DOUBLE: - return convertVariantToDynamicPrimitive(object); + return convertVariantToDynamicFloatingPoint(object); case TypeKind::VARCHAR: return convertVariantToDynamicPrimitive(object); case TypeKind::VARBINARY: @@ -261,22 +269,94 @@ variant TDigestInputGenerator::generate() { return variant::null(type_->kind()); } velox::functions::TDigest<> digest; - double compression = rand(rng_, 10.0, 100.0); + double compression = rand(rng_, 10.0, 1000.0); digest.setCompression(compression); std::vector positions; - for (int i = 0; i < 10; i++) { - double value = rand(rng_, 0.0, 100.0); - int64_t weight = rand(rng_, 1, 100); + static boost::random::uniform_real_distribution valueDist( + 0.0, 1000.0); + static boost::random::uniform_int_distribution weightDist(1, 1000); + for (int i = 0; i < 100; i++) { + double value = valueDist(rng_); + int64_t weight = weightDist(rng_); digest.add(positions, value, weight); } digest.compress(positions); size_t byteSize = digest.serializedByteSize(); std::string serializedDigest(byteSize, '\0'); digest.serialize(&serializedDigest[0]); - StringView serializedView(serializedDigest.data(), serializedDigest.size()); return variant::create(serializedDigest); } +SetDigestInputGenerator::SetDigestInputGenerator( + size_t seed, + const TypePtr& type, + double nullRatio) + : AbstractInputGenerator(seed, type, nullptr, nullRatio), + pool_(velox::memory::memoryManager()->addLeafPool()), + allocator_(std::make_unique(pool_.get())) { + // SetDigest supports int64_t and StringView + static const std::vector kBaseTypes{BIGINT(), VARCHAR()}; + baseType_ = kBaseTypes[rand(rng_, 0, kBaseTypes.size() - 1)]; +} + +SetDigestInputGenerator::~SetDigestInputGenerator() = default; + +template +variant SetDigestInputGenerator::generateTyped() { + velox::functions::SetDigest digest(allocator_.get()); + + int numValues = rand(rng_, 10, 100); + for (int i = 0; i < numValues; ++i) { + int64_t value = rand(rng_); + digest.add(value); + } + + size_t byteSize = digest.estimatedSerializedSize(); + std::string serializedDigest(byteSize, '\0'); + digest.serialize(&serializedDigest[0]); + return variant::create(serializedDigest); +} + +template <> +variant SetDigestInputGenerator::generateTyped() { + velox::functions::SetDigest digest(allocator_.get()); + + int numValues = rand(rng_, 10, 100); + static const std::vector encodings{ + UTF8CharList::ASCII, + UTF8CharList::UNICODE_CASE_SENSITIVE, + UTF8CharList::EXTENDED_UNICODE, + UTF8CharList::MATHEMATICAL_SYMBOLS}; +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wdeprecated-declarations" + std::wstring_convert, char16_t> converter; +#pragma GCC diagnostic pop + + for (int i = 0; i < numValues; ++i) { + auto size = rand(rng_, 0, 100); + std::string result; + auto str = randString(rng_, size, encodings, result, converter); + digest.add(StringView(str)); + } + + size_t byteSize = digest.estimatedSerializedSize(); + std::string serializedDigest(byteSize, '\0'); + digest.serialize(&serializedDigest[0]); + return variant::create(serializedDigest); +} + +variant SetDigestInputGenerator::generate() { + if (coinToss(rng_, nullRatio_)) { + return variant::null(type_->kind()); + } + + if (baseType_->isBigint()) { + return generateTyped(); + } else { + return generateTyped(); + } +} + // BingTileInputGenerator BingTileInputGenerator::BingTileInputGenerator( @@ -303,6 +383,40 @@ int64_t BingTileInputGenerator::generateImpl() { return static_cast(BingTileType::bingTileCoordsToInt(x, y, zoom)); } +QDigestInputGenerator::QDigestInputGenerator( + size_t seed, + const TypePtr& type, + double nullRatio, + const TypePtr& qdigestType) + : AbstractInputGenerator(seed, type, nullptr, nullRatio), + qdigestType{qdigestType} {} + +QDigestInputGenerator::~QDigestInputGenerator() = default; + +variant QDigestInputGenerator::generate() { + constexpr double kAccuracy = 1.0E-3; + + if (coinToss(rng_, nullRatio_)) { + return variant::null(type_->kind()); + } + + size_t len = rand(rng_, 1, 1000); + + std::string serializedStr = [&]() { + switch (qdigestType->kind()) { + case TypeKind::BIGINT: + return createSerializedDigest(len, kAccuracy); + case TypeKind::DOUBLE: + return createSerializedDigest(len, kAccuracy); + case TypeKind::REAL: + return createSerializedDigest(len, kAccuracy); + default: + VELOX_FAIL("Unsupported type for QDigest: {}", qdigestType->toString()); + } + }(); + return variant::create(std::move(serializedStr)); +} + // Utility functions template std::unique_ptr getRandomInputGeneratorPrimitive( @@ -494,7 +608,6 @@ CastVarcharInputGenerator::CastVarcharInputGenerator( CastVarcharInputGenerator::~CastVarcharInputGenerator() = default; -#pragma GCC diagnostic ignored "-Wdeprecated-declarations" std::string CastVarcharInputGenerator::generateValidPrimitiveAsString() { switch (castToType_->kind()) { case TypeKind::BOOLEAN: { @@ -532,7 +645,10 @@ std::string CastVarcharInputGenerator::generateValidPrimitiveAsString() { case TypeKind::VARCHAR: { // Generate random string. std::string input; +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wdeprecated-declarations" std::wstring_convert, char16_t> converter; +#pragma GCC diagnostic pop auto randomStr = randString( rng_, rand(rng_, 0, 20), @@ -543,9 +659,10 @@ std::string CastVarcharInputGenerator::generateValidPrimitiveAsString() { } default: // cast from varchar doesn't support complex types - VELOX_FAIL_UNSUPPORTED_INPUT_UNCATCHABLE(fmt::format( - "Type `{}` not supported for cast varchar custom generator", - castToType_->kind())); + VELOX_FAIL_UNSUPPORTED_INPUT_UNCATCHABLE( + fmt::format( + "Type `{}` not supported for cast varchar custom generator", + castToType_->kind())); } } @@ -588,4 +705,216 @@ variant CastVarcharInputGenerator::generate() { return variant(input); } +// URLInputGenerator creates URL input data for URL functions. +URLInputGenerator::URLInputGenerator( + size_t seed, + const TypePtr& type, + double nullRatio, + std::string functionName, + std::vector functionsToSkipForMailTo, + std::vector functionsToSkipForTruncate) + : AbstractInputGenerator(seed, type, nullptr, nullRatio), + functionName_{std::move(functionName)}, + functionsToSkipForMailTo_{std::move(functionsToSkipForMailTo)}, + functionsToSkipForTruncate_{std::move(functionsToSkipForTruncate)} {} + +URLInputGenerator::~URLInputGenerator() = default; + +std::shared_ptr URLInputGenerator::generateURLRules() { + auto url = RuleList({ + std::make_shared( + rng_, + std::make_shared(std::vector>{ + std::make_shared("http"), + std::make_shared("https"), + })), + std::make_shared("://"), + std::make_shared(rng_), // domain + std::make_shared("."), + std::make_shared(rng_, 2, 3, true), + std::make_shared( // port + rng_, + std::make_shared(std::vector>{ + std::make_shared(":"), + std::make_shared(rng_, 3, 7, true)})), + std::make_shared( // path + rng_, + std::make_shared(std::vector>{ + std::make_shared("/"), + std::make_shared(rng_), + }), + 0, + 5), + std::make_shared( // query + rng_, + std::make_shared(std::vector>{ + std::make_shared("?"), + std::make_shared(rng_), + std::make_shared("="), + std::make_shared(rng_), + std::make_shared( + rng_, + std::make_shared(std::vector>{ + std::make_shared("&"), + std::make_shared(rng_), + std::make_shared("="), + std::make_shared(rng_)}), + 1, + 3), + std::make_shared( + rng_, + std::make_shared(std::vector>{ + std::make_shared("["), + std::make_shared("]")}), + 0, + 3), + })), + std::make_shared( // fragment + rng_, + std::make_shared(std::vector>{ + std::make_shared("#"), + std::make_shared(rng_), + std::make_shared( + rng_, + std::make_shared(std::vector>{ + std::make_shared("["), + std::make_shared("]")}), + 0, + 3), + })), + }); + + return std::make_shared(url); +} + +std::shared_ptr URLInputGenerator::generateChromeExtensionRules() { + auto chrome_extension = RuleList({ + std::make_shared("chrome-extension:/"), + std::make_shared( + rng_, + std::make_shared(std::vector>{ + std::make_shared("/"), + std::make_shared(rng_)}), + 1, + 3), + std::make_shared("/"), + std::make_shared(rng_), + std::make_shared(".html"), + }); + + return std::make_shared(chrome_extension); +} + +std::shared_ptr URLInputGenerator::generateMailToRules() { + auto mailTo = RuleList({ + std::make_shared("mailto:"), + std::make_shared(rng_), + std::make_shared("@"), + std::make_shared(rng_), + std::make_shared("."), + std::make_shared(rng_, 2, 3, true), + std::make_shared( // subject + rng_, + std::make_shared(std::vector>{ + std::make_shared("&"), + std::make_shared("subject"), + std::make_shared("="), + std::make_shared(rng_), + std::make_shared( + rng_, + std::make_shared(std::vector>{ + std::make_shared("\%20"), + std::make_shared(rng_)}), + 1, + 3)})), + std::make_shared( // recipients + rng_, + std::make_shared(std::vector>{ + std::make_shared("&"), + std::make_shared("cc"), + std::make_shared("="), + std::make_shared(rng_), + std::make_shared("@"), + std::make_shared(rng_), + std::make_shared("."), + std::make_shared(rng_, 2, 3, true), + std::make_shared( + rng_, + std::make_shared(std::vector>{ + std::make_shared("&"), + std::make_shared("bcc"), + std::make_shared("="), + std::make_shared(rng_), + std::make_shared("@"), + std::make_shared(rng_), + std::make_shared("."), + std::make_shared(rng_, 2, 3, true), + }), + 1, + 3)})), + std::make_shared( // body + rng_, + std::make_shared(std::vector>{ + std::make_shared("&"), + std::make_shared("body"), + std::make_shared("="), + std::make_shared(rng_), + std::make_shared( + rng_, + std::make_shared(std::vector>{ + std::make_shared("\%20"), + std::make_shared(rng_)}), + 1, + 3)})), + }); + + return std::make_shared(mailTo); +} + +variant URLInputGenerator::generate() { + // Randomly add nulls. + if (coinToss(rng_, nullRatio_)) { + return variant::null(type_->kind()); + } + + std::vector> rules; + rules.push_back(generateURLRules()); + rules.push_back(generateChromeExtensionRules()); + if (std::find( + functionsToSkipForMailTo_.begin(), + functionsToSkipForMailTo_.end(), + functionName_) == functionsToSkipForMailTo_.end()) { + rules.push_back(generateMailToRules()); + } + + auto choices = ChoiceRule(rng_, std::make_shared(rules)); + + auto input = choices.generate(); + + // Make additional random variations to valid input data to see how these + // functions process them. + if (coinToss(rng_, 0.2)) { + makeRandomStrVariation( + input, + rng_, + RandomStrVariationOptions{ + 0.0, + 0.1, + std::find( + functionsToSkipForTruncate_.begin(), + functionsToSkipForTruncate_.end(), + functionName_) != functionsToSkipForTruncate_.end() + ? 0.0 + : 0.1}); + } + + // Intentionally ignore these cases due to intentional differences, see + // https://github.com/facebookincubator/velox/issues/14204 + if (input == "https://" || input == "https:") { + return variant::null(type_->kind()); + } + + return variant(input); +} + } // namespace facebook::velox::fuzzer diff --git a/velox/common/fuzzer/ConstrainedGenerators.h b/velox/common/fuzzer/ConstrainedGenerators.h index 70bc1026d9ea..930fd2bade6a 100644 --- a/velox/common/fuzzer/ConstrainedGenerators.h +++ b/velox/common/fuzzer/ConstrainedGenerators.h @@ -21,6 +21,8 @@ #include "folly/json.h" #include "velox/common/fuzzer/Utils.h" +#include "velox/common/memory/HashStringAllocator.h" +#include "velox/functions/lib/QuantileDigest.h" #include "velox/type/Type.h" #include "velox/type/Variant.h" @@ -81,7 +83,7 @@ template class RandomInputGenerator>> : public AbstractInputGenerator { public: - RandomInputGenerator>>( + RandomInputGenerator( size_t seed, const TypePtr& type, double nullRatio, @@ -97,8 +99,7 @@ class RandomInputGenerator>> encodings_{encodings}, randomStrVariationOptions_{randomStrVariationOptions} {} - ~RandomInputGenerator>>() - override = default; + ~RandomInputGenerator() override = default; variant generate() override { if (coinToss(rng_, nullRatio_)) { @@ -125,7 +126,7 @@ template class RandomInputGenerator>> : public AbstractInputGenerator { public: - RandomInputGenerator>>( + RandomInputGenerator( size_t seed, const TypePtr& type, double nullRatio, @@ -142,8 +143,7 @@ class RandomInputGenerator>> containAtIndex_{containAtIndex}, containGenerator_{std::move(containGenerator)} {} - ~RandomInputGenerator>>() - override = default; + ~RandomInputGenerator() override = default; variant generate() override { if (coinToss(rng_, nullRatio_)) { @@ -179,7 +179,7 @@ template class RandomInputGenerator>> : public AbstractInputGenerator { public: - RandomInputGenerator>>( + RandomInputGenerator( size_t seed, const TypePtr& type, double nullRatio, @@ -206,8 +206,7 @@ class RandomInputGenerator>> } } - ~RandomInputGenerator>>() - override = default; + ~RandomInputGenerator() override = default; variant generate() override { if (coinToss(rng_, nullRatio_)) { @@ -247,7 +246,7 @@ template class RandomInputGenerator>> : public AbstractInputGenerator { public: - RandomInputGenerator>>( + RandomInputGenerator( size_t seed, const TypePtr& type, std::vector> fieldGenerators, @@ -266,8 +265,7 @@ class RandomInputGenerator>> } } - ~RandomInputGenerator>>() - override = default; + ~RandomInputGenerator() override = default; variant generate() override { if (coinToss(rng_, nullRatio_)) { @@ -383,6 +381,21 @@ class JsonInputGenerator : public AbstractInputGenerator { return folly::dynamic(value); } + // Presto and Velox JSON parser have different behavior for floating point + // with magnitudes greater than equal to 10^-3 and less than 10^7. Clamp + // values to avoid scientific notation when comparing JSON parse results. + template + folly::dynamic convertVariantToDynamicFloatingPoint(const variant& v) { + using T = typename TypeTraits::DeepCopiedType; + VELOX_CHECK(v.isSet()); + const T value = v.value(); + const T absValue = std::abs(value); + const T sign = value < 0 ? static_cast(-1.0) : static_cast(1.0); + const T clampedValue = + std::clamp(absValue, static_cast(1e-3), static_cast(1e7 - 1)); + return folly::dynamic(sign * clampedValue); + } + folly::dynamic convertVariantToDynamic(const variant& object); void makeRandomVariation(std::string& json); @@ -424,7 +437,6 @@ class PhoneNumberInputGenerator : public AbstractInputGenerator { /// them: /// - On strings, arrays, and objects: .length() /// - On arrays: [begin:end:step] - class JsonPathGenerator : public AbstractInputGenerator { public: JsonPathGenerator( @@ -480,6 +492,33 @@ class CastVarcharInputGenerator : public AbstractInputGenerator { std::string generateValidPrimitiveAsString(); }; +class URLInputGenerator : public AbstractInputGenerator { + public: + URLInputGenerator( + size_t seed, + const TypePtr& type, + double nullRatio, + std::string functionName, + std::vector functionsToSkipForMailTo, + std::vector functionsToSkipForTruncate); + + ~URLInputGenerator() override; + + variant generate() override; + + private: + std::shared_ptr generateURLRules(); + std::shared_ptr generateMailToRules(); + std::shared_ptr generateChromeExtensionRules(); + + const std::string functionName_; + // Particular UDFs are known to have mismatches for mailto and trucated input. + // Let's skip those test cases for now. More info can be found in + // https://github.com/facebookincubator/velox/issues/14204. + const std::vector functionsToSkipForMailTo_; + const std::vector functionsToSkipForTruncate_; +}; + class TDigestInputGenerator : public AbstractInputGenerator { public: TDigestInputGenerator(size_t seed, const TypePtr& type, double nullRatio); @@ -489,6 +528,23 @@ class TDigestInputGenerator : public AbstractInputGenerator { variant generate() override; }; +class SetDigestInputGenerator : public AbstractInputGenerator { + public: + SetDigestInputGenerator(size_t seed, const TypePtr& type, double nullRatio); + + ~SetDigestInputGenerator() override; + + variant generate() override; + + private: + template + variant generateTyped(); + + TypePtr baseType_; + std::shared_ptr pool_; + std::unique_ptr allocator_; +}; + class BingTileInputGenerator : public AbstractInputGenerator { public: BingTileInputGenerator(size_t seed, const TypePtr& type, double nullRatio); @@ -500,4 +556,59 @@ class BingTileInputGenerator : public AbstractInputGenerator { private: int64_t generateImpl(); }; + +class QDigestInputGenerator : public AbstractInputGenerator { + public: + QDigestInputGenerator( + size_t seed, + const TypePtr& type, + double nullRatio, + const TypePtr& qdigestType); + + ~QDigestInputGenerator() override; + + variant generate() override; + + private: + const TypePtr qdigestType; + + template + std::vector generateRandomValue(size_t len) { + std::vector values; + values.reserve(len); + + auto makeDist = []() { + if constexpr (std::is_integral_v) { + return std::uniform_int_distribution(0, 10000); + } else { + return std::uniform_real_distribution(0.0, 10000.0); + } + }; + + auto dist = makeDist(); + for (size_t i = 0; i < len; ++i) { + values.push_back(dist(rng_)); + } + return values; + } + + template + std::string createSerializedDigest(size_t len, double accuracy) { + using facebook::velox::functions::qdigest::QuantileDigest; + + std::allocator allocator; + QuantileDigest> digest(allocator, accuracy); + + const auto input = generateRandomValue(len); + auto dist = boost::random::uniform_real_distribution(1.0, 100.0); + for (const auto& value : input) { + digest.add(value, dist(rng_)); + } + const auto serializedSize = digest.serializedByteSize(); + std::vector serializedData(serializedSize); + digest.serialize(serializedData.data()); + + return std::string(serializedData.begin(), serializedData.end()); + } +}; } // namespace facebook::velox::fuzzer diff --git a/velox/common/fuzzer/Utils.cpp b/velox/common/fuzzer/Utils.cpp index 7cdb14e9ca4f..5d8394d84de0 100644 --- a/velox/common/fuzzer/Utils.cpp +++ b/velox/common/fuzzer/Utils.cpp @@ -15,6 +15,8 @@ */ #include "velox/common/fuzzer/Utils.h" +#include +#include "velox/type/Time.h" namespace facebook::velox::fuzzer { @@ -152,6 +154,15 @@ static const std::vector>> {u'\u27C0', u'\u27EF'}, // Math Symbols. {u'\u2A00', u'\u2AFF'}, // Supplemental. }, + // UTF8CharList::ALPHABETIC + { + {u'A', u'Z'}, // Uppercase alphabetic characters. + {u'a', u'z'}, // Lowercase alphabetic characters. + }, + // UTF8CharList::NUMERIC + { + {u'0', u'9'}, // Numeric characters. + }, }; FOLLY_ALWAYS_INLINE char16_t getRandomChar( @@ -193,4 +204,38 @@ std::string randString( } #pragma GCC diagnostic pop +int16_t generateRandomTimezoneOffset( + FuzzerGenerator& rng, + double frequentlyUsedProbability) { + // 25% probability: pick from frequently used offsets + // 75% probability: generate random offset in range [-840, 840] + if (coinToss(rng, frequentlyUsedProbability)) { + auto index = + rand(rng, 0, kFrequentlyUsedTimezoneOffsets.size() - 1); + return kFrequentlyUsedTimezoneOffsets[index]; + } else { + return rand(rng, -util::kTimeZoneBias, util::kTimeZoneBias); + } +} + +std::string timezoneOffsetToString(int16_t offsetMinutes) { + // Validate range [-840, 840] + VELOX_USER_CHECK( + offsetMinutes >= -util::kTimeZoneBias && + offsetMinutes <= util::kTimeZoneBias, + "Timezone offset {} minutes is out of range [-840, 840]", + offsetMinutes); + + // Determine sign + char sign = (offsetMinutes >= 0) ? '+' : '-'; + + // Calculate hours and minutes using absolute value + int16_t absOffset = std::abs(offsetMinutes); + int16_t hours = absOffset / util::kMinutesInHour; + int16_t minutes = absOffset % util::kMinutesInHour; + + // Format as "+HH:mm" or "-HH:mm" with zero-padding + return fmt::format("{}{:02d}:{:02d}", sign, hours, minutes); +} + } // namespace facebook::velox::fuzzer diff --git a/velox/common/fuzzer/Utils.h b/velox/common/fuzzer/Utils.h index 951ef7564fd0..5c011321b223 100644 --- a/velox/common/fuzzer/Utils.h +++ b/velox/common/fuzzer/Utils.h @@ -31,11 +31,24 @@ namespace facebook::velox::fuzzer { using FuzzerGenerator = folly::detail::DefaultGenerator; +// Frequently used timezone offsets in minutes (US timezones including DST) +// -240 = UTC-4:00 (EDT) +// -300 = UTC-5:00 (EST/CDT) +// -360 = UTC-6:00 (CST/MDT) +// -420 = UTC-7:00 (MST/PDT) +// -480 = UTC-8:00 (PST) +// -540 = UTC-9:00 (AKST) +// -600 = UTC-10:00 (HST) +constexpr std::array kFrequentlyUsedTimezoneOffsets = + {-240, -300, -360, -420, -480, -540, -600}; + enum UTF8CharList { ASCII = 0, // Ascii character set. UNICODE_CASE_SENSITIVE = 1, // Unicode scripts that support case. EXTENDED_UNICODE = 2, // Extended Unicode: Arabic, Devanagiri etc - MATHEMATICAL_SYMBOLS = 3 // Mathematical Symbols. + MATHEMATICAL_SYMBOLS = 3, // Mathematical Symbols. + ALPHABETIC = 4, // Alphabetic Symbols. + NUMERIC = 5 // Numeric Symbols. }; bool coinToss(FuzzerGenerator& rng, double threshold); @@ -165,6 +178,33 @@ inline Timestamp rand(FuzzerGenerator& rng, DataSpec /*dataSpec*/) { int32_t randDate(FuzzerGenerator& rng); +/// Generate random timezone offset using biased distribution +/// 25% probability: picks from frequently used offsets +/// 75% probability: generates random offset from [-840, 840] +/// +/// @param rng Random number generator +/// @param frequentlyUsedProbability Probability of selecting from frequently +/// used offsets (default 0.25 for 25%) +/// @return Timezone offset in minutes [-840, 840] +int16_t generateRandomTimezoneOffset( + FuzzerGenerator& rng, + double frequentlyUsedProbability = 0.25); + +/// Convert timezone offset in minutes to "+HH:mm" or "-HH:mm" format +/// Always uses Presto-compatible +HH:mm format (never +HH or +HHmm) +/// +/// Examples: +/// - 0 → "+00:00" +/// - 330 → "+05:30" +/// - -300 → "-05:00" +/// - 840 → "+14:00" +/// - -840 → "-14:00" +/// +/// @param offsetMinutes Timezone offset in minutes [-840, 840] +/// @return Timezone offset string in "+HH:mm" or "-HH:mm" format +/// @throws VeloxException if offsetMinutes is out of range [-840, 840] +std::string timezoneOffsetToString(int16_t offsetMinutes); + template < typename T, typename std::enable_if_t, int> = 0> @@ -188,4 +228,215 @@ std::string randString( std::wstring_convert, char16_t>& converter); #pragma GCC diagnostic pop +// Beginning of rule or grammar based input generation. The following set of +// classes will allow us to use a rule based approach to generate random inputs. + +// Class rule defines abstract function generate which outputs our input string +// dependent on the defined rule instantiation. +class Rule { + public: + Rule() = default; + + virtual std::string generate() { + return ""; + } + + virtual ~Rule() = default; +}; + +// List of Rules. +class RuleList : public Rule { + public: + explicit RuleList(std::vector> rules) + : rules_{std::move(rules)} {} + + std::string generate() override { + std::string string; + for (auto& rule : rules_) { + string += rule->generate(); + } + return string; + } + + size_t size() { + return rules_.size(); + } + + std::shared_ptr operator[](size_t index) const { + if (index >= rules_.size()) { + throw std::out_of_range("Index out of range"); + } + return rules_[index]; + } + + virtual ~RuleList() = default; + + private: + std::vector> rules_; +}; + +// Rules that are randomly included. +class OptionalRule : public Rule { + public: + OptionalRule(FuzzerGenerator& rng, std::shared_ptr rule) + : rng_(rng), rule_(std::move(rule)) {} + + std::string generate() override { + return coinToss(rng_, .5) ? rule_->generate() : ""; + } + + virtual ~OptionalRule() = default; + + private: + FuzzerGenerator& rng_; + std::shared_ptr rule_; +}; + +// Rules that can repeat 'count' times. +class RepeatingRule : public Rule { + public: + RepeatingRule( + FuzzerGenerator& rng, + std::shared_ptr rule, + size_t minCount, + size_t maxCount) + : rng_(rng), + rule_(std::move(rule)), + minCount_(minCount), + maxCount_(maxCount) {} + + std::string generate() override { + auto repetitions = rand(rng_, minCount_, maxCount_); + std::string string; + for (size_t i = 0; i < repetitions; ++i) { + string += rule_->generate(); + } + return string; + } + + virtual ~RepeatingRule() = default; + + private: + FuzzerGenerator& rng_; + std::shared_ptr rule_; + size_t minCount_, maxCount_; +}; + +// Rule randomly chosen from list. +class ChoiceRule : public Rule { + public: + ChoiceRule(FuzzerGenerator& rng, std::shared_ptr rules) + : rng_(rng), rules_(std::move(rules)) {} + + std::string generate() override { + auto index = rand(rng_, 0, rules_->size() - 1); + return (*rules_)[index]->generate(); + } + + virtual ~ChoiceRule() = default; + + private: + FuzzerGenerator& rng_; + std::shared_ptr rules_; +}; + +// Simple rule that is just a constant string. +class ConstantRule : public Rule { + public: + explicit ConstantRule(std::string constant) + : constant_(std::move(constant)) {} + + std::string generate() override { + return constant_; + } + + virtual ~ConstantRule() = default; + + private: + std::string constant_; +}; + +// String of 1 to 20 characters long that generates ASCII characters. Size and +// character list can be modified. +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wdeprecated-declarations" +class StringRule : public Rule { + public: + explicit StringRule( + FuzzerGenerator& rng, + std::vector encodings = {UTF8CharList::ASCII}, + size_t minLength = 1, + size_t maxLength = 20, + bool flexibleLength = true) + : rng_(rng), + encodings_(std::move(encodings)), + minLength_(minLength), + maxLength_(maxLength), + flexibleLength_(flexibleLength) {} + + std::string generate() override { + auto length = maxLength_; + if (flexibleLength_) { + length = rand(rng_, minLength_, maxLength_); + } + return randString(rng_, length, encodings_, buf_, converter_); + } + + virtual ~StringRule() = default; + + private: + FuzzerGenerator& rng_; + std::vector encodings_; + size_t minLength_; + size_t maxLength_; + bool flexibleLength_; + + std::wstring_convert, char16_t> converter_; + std::string buf_; +}; +#pragma GCC diagnostic pop + +// Extends from StringRule, uses only alphabetic characters +class WordRule : public StringRule { + public: + explicit WordRule(FuzzerGenerator& rng) + : StringRule(rng, {UTF8CharList::ALPHABETIC}) {} + + WordRule( + FuzzerGenerator& rng, + size_t minLength, + size_t maxLength, + bool flexibleLength) + : StringRule( + rng, + {UTF8CharList::ALPHABETIC}, + minLength, + maxLength, + flexibleLength) {} + + virtual ~WordRule() = default; +}; + +// Extends from StringRule, uses only numeric characters for random number +// generation with a particular size. +class NumRule : public StringRule { + public: + explicit NumRule(FuzzerGenerator& rng) + : StringRule(rng, {UTF8CharList::NUMERIC}) {} + + NumRule( + FuzzerGenerator& rng, + size_t minLength, + size_t maxLength, + bool flexibleLength) + : StringRule( + rng, + {UTF8CharList::NUMERIC}, + minLength, + maxLength, + flexibleLength) {} + + virtual ~NumRule() = default; +}; + } // namespace facebook::velox::fuzzer diff --git a/velox/common/fuzzer/tests/CMakeLists.txt b/velox/common/fuzzer/tests/CMakeLists.txt index 88e9e2875acb..9278ec7fb8d4 100644 --- a/velox/common/fuzzer/tests/CMakeLists.txt +++ b/velox/common/fuzzer/tests/CMakeLists.txt @@ -12,13 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -add_executable(velox_constrained_input_generators_test - ConstrainedGeneratorsTest.cpp) +add_executable(velox_constrained_input_generators_test ConstrainedGeneratorsTest.cpp) add_test( NAME velox_constrained_input_generators_test COMMAND velox_constrained_input_generators_test - WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} +) target_link_libraries( velox_constrained_input_generators_test @@ -28,4 +28,5 @@ target_link_libraries( velox_type GTest::gtest GTest::gtest_main - GTest::gmock) + GTest::gmock +) diff --git a/velox/common/fuzzer/tests/ConstrainedGeneratorsTest.cpp b/velox/common/fuzzer/tests/ConstrainedGeneratorsTest.cpp index b36801de5695..1defe5db7b2a 100644 --- a/velox/common/fuzzer/tests/ConstrainedGeneratorsTest.cpp +++ b/velox/common/fuzzer/tests/ConstrainedGeneratorsTest.cpp @@ -21,6 +21,8 @@ #include "velox/common/memory/Memory.h" #include "velox/functions/prestosql/json/JsonExtractor.h" #include "velox/functions/prestosql/types/JsonType.h" +#include "velox/functions/prestosql/types/QDigestType.h" +#include "velox/functions/prestosql/types/SetDigestType.h" #include "velox/functions/prestosql/types/TDigestType.h" #include "velox/type/Variant.h" @@ -404,9 +406,10 @@ TEST_F(ConstrainedGeneratorsTest, jsonPath) { const auto jsonPath = jsonPathGenerator->generate(); if (jsonPath.hasValue()) { if (json.hasValue()) { - EXPECT_NO_THROW(functions::jsonExtract( - json.value(), - jsonPath.value())); + EXPECT_NO_THROW( + functions::jsonExtract( + json.value(), + jsonPath.value())); } } else { hasNull = true; @@ -422,4 +425,29 @@ TEST_F(ConstrainedGeneratorsTest, tdigest) { EXPECT_EQ(value.kind(), TypeKind::VARBINARY); } +TEST_F(ConstrainedGeneratorsTest, setdigest) { + std::unique_ptr generator = + std::make_unique(0, SETDIGEST(), 0.4); + auto value = generator->generate(); + EXPECT_EQ(value.kind(), TypeKind::VARBINARY); +} + +TEST_F(ConstrainedGeneratorsTest, qdigest) { + std::unique_ptr generator = + std::make_unique( + 0, QDIGEST(DOUBLE()), 0.4, DOUBLE()); + auto value = generator->generate(); + EXPECT_EQ(value.kind(), TypeKind::VARBINARY); + + generator = + std::make_unique(0, QDIGEST(REAL()), 0.4, REAL()); + value = generator->generate(); + EXPECT_EQ(value.kind(), TypeKind::VARBINARY); + + generator = std::make_unique( + 0, QDIGEST(BIGINT()), 0.4, BIGINT()); + value = generator->generate(); + EXPECT_EQ(value.kind(), TypeKind::VARBINARY); +} + } // namespace facebook::velox::fuzzer::test diff --git a/velox/common/fuzzer/tests/UtilsTest.cpp b/velox/common/fuzzer/tests/UtilsTest.cpp new file mode 100644 index 000000000000..5e82bdc85d62 --- /dev/null +++ b/velox/common/fuzzer/tests/UtilsTest.cpp @@ -0,0 +1,189 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/common/fuzzer/Utils.h" + +#include + +#include +#include "velox/type/Variant.h" + +namespace facebook::velox::fuzzer::test { + +class UtilsTest : public testing::Test {}; + +TEST_F(UtilsTest, testRuleList) { + auto simple = RuleList( + std::vector>{ + std::make_shared("Hello"), + std::make_shared(","), + std::make_shared(" "), + std::make_shared("world"), + std::make_shared("!"), + }); + ASSERT_EQ(simple.generate(), "Hello, world!"); + + FuzzerGenerator rng; + auto fuzz = RuleList( + std::vector>{ + std::make_shared("Hello"), + std::make_shared(","), + std::make_shared(" "), + std::make_shared(rng), + std::make_shared("!"), + }); + ASSERT_TRUE( + std::regex_match(fuzz.generate(), std::regex("Hello, \\w{1,20}!"))); +} + +TEST_F(UtilsTest, testOptionalRule) { + FuzzerGenerator rng; + auto simple = std::make_shared( + rng, + std::make_shared(std::vector>{ + std::make_shared("a"), + })); + ASSERT_TRUE(std::regex_match(simple->generate(), std::regex("^(|a)$"))); +} + +TEST_F(UtilsTest, testRepeatingRule) { + FuzzerGenerator rng; + auto simple = std::make_shared( + rng, + std::make_shared(std::vector>{ + std::make_shared("a"), + }), + 2, + 5); + ASSERT_TRUE(std::regex_match(simple->generate(), std::regex("^a{2,5}$"))); + + auto fuzz = std::make_shared( + rng, + std::make_shared(std::vector>{ + std::make_shared("a"), + std::make_shared(rng, 1, 1, false), + }), + 2, + 5); + ASSERT_TRUE(std::regex_match(fuzz->generate(), std::regex("^\\w{4,10}$"))); +} + +TEST_F(UtilsTest, testConstantRule) { + auto rule = std::make_shared("a"); + ASSERT_EQ(rule->generate(), "a"); + + auto rule_list = RuleList( + std::vector>{ + std::make_shared("a"), + std::make_shared("b"), + std::make_shared("c")}); + ASSERT_EQ(rule_list.generate(), "abc"); +} + +TEST_F(UtilsTest, testStringRule) { + FuzzerGenerator rng; + auto simple = std::make_shared(rng); + ASSERT_TRUE( + std::regex_match( + simple->generate(), std::regex("^[\x21-\x7F]+$"))); // printable ascii + ASSERT_FALSE(std::regex_match(simple->generate(), std::regex("^\\w+$"))); + ASSERT_FALSE(std::regex_match(simple->generate(), std::regex("^\\d+$"))); + + auto specified_flexible = std::make_shared( + rng, std::vector{UTF8CharList::ASCII}, 3, 7, true); + ASSERT_TRUE( + std::regex_match( + specified_flexible->generate(), std::regex("^[\\x21-\\x7F]{3,7}$"))); + ASSERT_FALSE( + std::regex_match( + specified_flexible->generate(), std::regex("^[\\x21-\\x7F]{0,2}$"))); + ASSERT_FALSE( + std::regex_match( + specified_flexible->generate(), std::regex("^[\\x21-\\x7F]{8,}$"))); + + auto specified_strict = std::make_shared( + rng, std::vector{UTF8CharList::ASCII}, 3, 7, false); + ASSERT_TRUE( + std::regex_match( + specified_strict->generate(), std::regex("^[\x21-\x7F]{7}$"))); + ASSERT_FALSE( + std::regex_match( + specified_strict->generate(), std::regex("^[\x21-\x7F]{8}$"))); +} + +TEST_F(UtilsTest, testWordRule) { + FuzzerGenerator rng; + auto simple = std::make_shared(rng); + ASSERT_TRUE(std::regex_match(simple->generate(), std::regex("^[a-zA-Z]+$"))); + ASSERT_TRUE(std::regex_match(simple->generate(), std::regex("^\\w+$"))); + ASSERT_FALSE(std::regex_match(simple->generate(), std::regex("^\\d+$"))); + ASSERT_FALSE(std::regex_match(simple->generate(), std::regex("^\\W+$"))); + + auto specified_flexible = std::make_shared(rng, 3, 7, true); + ASSERT_TRUE( + std::regex_match( + specified_flexible->generate(), std::regex("^[a-zA-Z]{3,7}$"))); + ASSERT_TRUE( + std::regex_match( + specified_flexible->generate(), std::regex("^\\w{3,7}$"))); + ASSERT_FALSE( + std::regex_match( + specified_flexible->generate(), std::regex("^\\d{3,7}$"))); + ASSERT_FALSE( + std::regex_match( + specified_flexible->generate(), std::regex("^\\w{0,2}$"))); + ASSERT_FALSE( + std::regex_match( + specified_flexible->generate(), std::regex("^\\w{8,}$"))); + + auto specified_strict = std::make_shared(rng, 3, 7, false); + ASSERT_TRUE( + std::regex_match(specified_strict->generate(), std::regex("^\\w{7}$"))); + ASSERT_FALSE( + std::regex_match(specified_strict->generate(), std::regex("^\\w{8}$"))); +} + +TEST_F(UtilsTest, testNumRule) { + FuzzerGenerator rng; + auto simple = std::make_shared(rng); + ASSERT_TRUE(std::regex_match(simple->generate(), std::regex("^\\d+$"))); + ASSERT_FALSE(std::regex_match(simple->generate(), std::regex("^\\D+$"))); + + auto specified_flexible = std::make_shared(rng, 3, 7, true); + ASSERT_TRUE( + std::regex_match( + specified_flexible->generate(), std::regex("^\\d{3,7}$"))); + ASSERT_TRUE( + std::regex_match( + specified_flexible->generate(), std::regex("^\\w{3,7}$"))); + ASSERT_FALSE( + std::regex_match( + specified_flexible->generate(), std::regex("^\\D{3,7}$"))); + ASSERT_FALSE( + std::regex_match( + specified_flexible->generate(), std::regex("^\\d{0,2}$"))); + ASSERT_FALSE( + std::regex_match( + specified_flexible->generate(), std::regex("^\\d{8,}$"))); + + auto specified_strict = std::make_shared(rng, 3, 7, false); + ASSERT_TRUE( + std::regex_match(specified_strict->generate(), std::regex("^\\d{7}$"))); + ASSERT_FALSE( + std::regex_match(specified_strict->generate(), std::regex("^\\d{8}$"))); +} + +} // namespace facebook::velox::fuzzer::test diff --git a/CMake/resolve_dependency_modules/cpr/cpr-remove-sancheck.patch b/velox/common/geospatial/CMakeLists.txt similarity index 66% rename from CMake/resolve_dependency_modules/cpr/cpr-remove-sancheck.patch rename to velox/common/geospatial/CMakeLists.txt index 4fca92831a20..515f0021d152 100644 --- a/CMake/resolve_dependency_modules/cpr/cpr-remove-sancheck.patch +++ b/velox/common/geospatial/CMakeLists.txt @@ -11,14 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# This hangs on CI and is not needed #9116 ---- a/CMakeLists.txt -+++ b/CMakeLists.txt -@@ -84,7 +84,6 @@ endif() - include(GNUInstallDirs) - include(FetchContent) - include(cmake/code_coverage.cmake) --include(cmake/sanitizer.cmake) - include(cmake/clear_variable.cmake) - # So CMake can find FindMbedTLS.cmake +if(VELOX_ENABLE_GEO) + velox_add_library(velox_common_geospatial_serde GeometrySerde.cpp) + velox_link_libraries(velox_common_geospatial_serde velox_expression GEOS::geos) +endif() + +velox_install_library_headers() + +if(${VELOX_BUILD_TESTING}) + add_subdirectory(tests) +endif() diff --git a/velox/common/geospatial/GeometryConstants.h b/velox/common/geospatial/GeometryConstants.h new file mode 100644 index 000000000000..d598f29e0262 --- /dev/null +++ b/velox/common/geospatial/GeometryConstants.h @@ -0,0 +1,54 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +// This file contains constats for working with geospatial queries. +// They _must not_ require the GEOS library (or any 3p library). + +namespace facebook::velox::common::geospatial { + +enum class GeometrySerializationType : uint8_t { + POINT = 0, + MULTI_POINT = 1, + LINE_STRING = 2, + MULTI_LINE_STRING = 3, + POLYGON = 4, + MULTI_POLYGON = 5, + GEOMETRY_COLLECTION = 6, + ENVELOPE = 7 +}; + +enum class EsriShapeType : uint32_t { + POINT = 1, + POLYLINE = 3, + POLYGON = 5, + MULTI_POINT = 8 +}; + +/// Latitude/Longitude range constraints for spherical coordinates. +constexpr double kMinLatitude = -90.0; +constexpr double kMaxLatitude = 90.0; +constexpr double kMinLongitude = -180.0; +constexpr double kMaxLongitude = 180.0; + +/// BingTile-specific latitude constraints (narrower than standard lat/long). +constexpr double kMinBingTileLatitude = -85.05112878; +constexpr double kMaxBingTileLatitude = 85.05112878; + +} // namespace facebook::velox::common::geospatial diff --git a/velox/common/geospatial/GeometrySerde.cpp b/velox/common/geospatial/GeometrySerde.cpp new file mode 100644 index 000000000000..cda4e5195cf2 --- /dev/null +++ b/velox/common/geospatial/GeometrySerde.cpp @@ -0,0 +1,323 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include + +#include "velox/common/base/IOUtils.h" +#include "velox/common/geospatial/GeometrySerde.h" + +using facebook::velox::common::InputByteStream; +using facebook::velox::common::geospatial::EsriShapeType; +using facebook::velox::common::geospatial::GeometrySerializationType; + +namespace facebook::velox::common::geospatial { + +geos::geom::GeometryFactory* GeometryDeserializer::getGeometryFactory() { + thread_local static geos::geom::GeometryFactory::Ptr geometryFactory = + geos::geom::GeometryFactory::create(); + return geometryFactory.get(); +} + +std::unique_ptr GeometryDeserializer::readGeometry( + velox::common::InputByteStream& stream, + size_t size) { + auto geometryType = static_cast( + stream.read()); + switch (geometryType) { + case GeometrySerializationType::POINT: + return readPoint(stream); + case GeometrySerializationType::MULTI_POINT: + return readMultiPoint(stream); + case GeometrySerializationType::LINE_STRING: + return readPolyline(stream, false); + case GeometrySerializationType::MULTI_LINE_STRING: + return readPolyline(stream, true); + case GeometrySerializationType::POLYGON: + return readPolygon(stream, false); + case GeometrySerializationType::MULTI_POLYGON: + return readPolygon(stream, true); + case GeometrySerializationType::ENVELOPE: + return readEnvelope(stream); + case GeometrySerializationType::GEOMETRY_COLLECTION: + return readGeometryCollection(stream, size); + default: + VELOX_FAIL( + "Unrecognized geometry type: {}", static_cast(geometryType)); + } +} + +const std::unique_ptr +GeometryDeserializer::deserializeEnvelope(const StringView& geometry) { + velox::common::InputByteStream inputStream(geometry.data()); + auto geometryType = inputStream.read(); + + switch (geometryType) { + case GeometrySerializationType::POINT: + return std::make_unique( + *readPoint(inputStream)->getEnvelopeInternal()); + case GeometrySerializationType::MULTI_POINT: + case GeometrySerializationType::LINE_STRING: + case GeometrySerializationType::MULTI_LINE_STRING: + case GeometrySerializationType::POLYGON: + case GeometrySerializationType::MULTI_POLYGON: + skipEsriType(inputStream); + return deserializeEnvelope(inputStream); + case GeometrySerializationType::ENVELOPE: + return deserializeEnvelope(inputStream); + case GeometrySerializationType::GEOMETRY_COLLECTION: + return std::make_unique( + *readGeometryCollection(inputStream, geometry.size()) + ->getEnvelopeInternal()); + default: + VELOX_FAIL( + "Unrecognized geometry type: {}", static_cast(geometryType)); + } +} + +std::unique_ptr GeometryDeserializer::deserializeEnvelope( + velox::common::InputByteStream& input) { + auto xMin = input.read(); + auto yMin = input.read(); + auto xMax = input.read(); + auto yMax = input.read(); + + if (FOLLY_UNLIKELY( + isEsriNaN(xMin) || isEsriNaN(yMin) || isEsriNaN(xMax) || + isEsriNaN(yMax))) { + return std::make_unique(); + } + + return std::make_unique(xMin, xMax, yMin, yMax); +} + +geos::geom::Coordinate GeometryDeserializer::readCoordinate( + velox::common::InputByteStream& input) { + auto x = input.read(); + auto y = input.read(); + return {x, y}; +} + +std::unique_ptr +GeometryDeserializer::readCoordinates( + velox::common::InputByteStream& input, + size_t count) { + auto coords = std::make_unique(count, 2); + for (size_t i = 0; i < count; ++i) { + // TODO: Consider using setOrdinate if there's a performance issue. + coords->setAt(readCoordinate(input), i); + } + return coords; +} + +std::unique_ptr GeometryDeserializer::readPoint( + velox::common::InputByteStream& input) { + geos::geom::Coordinate coordinate = readCoordinate(input); + if (std::isnan(coordinate.x) || std::isnan(coordinate.y)) { + return getGeometryFactory()->createPoint(); + } + return std::unique_ptr( + getGeometryFactory()->createPoint(coordinate)); +} + +std::unique_ptr GeometryDeserializer::readMultiPoint( + velox::common::InputByteStream& input) { + skipEsriType(input); + skipEnvelope(input); + size_t pointCount = input.read(); + auto coords = readCoordinates(input, pointCount); + std::vector> points; + points.reserve(coords->size()); + for (size_t i = 0; i < coords->size(); ++i) { + points.push_back( + std::unique_ptr(getGeometryFactory()->createPoint( + geos::geom::Coordinate(coords->getX(i), coords->getY(i))))); + } + return getGeometryFactory()->createMultiPoint(std::move(points)); +} + +std::unique_ptr GeometryDeserializer::readPolyline( + velox::common::InputByteStream& input, + bool multiType) { + skipEsriType(input); + skipEnvelope(input); + size_t partCount = input.read(); + size_t pointCount = input.read(); + + if (partCount == 0) { + if (multiType) { + return getGeometryFactory()->createMultiLineString(); + } + return getGeometryFactory()->createLineString(); + } + + std::vector startIndexes(partCount); + for (size_t i = 0; i < partCount; ++i) { + startIndexes[i] = input.read(); + } + + std::vector partLengths(partCount); + if (partCount > 1) { + partLengths[0] = startIndexes[1]; + for (size_t i = 1; i < partCount - 1; ++i) { + partLengths[i] = startIndexes[i + 1] - startIndexes[i]; + } + } + partLengths[partCount - 1] = pointCount - startIndexes[partCount - 1]; + + std::vector> lineStrings; + lineStrings.reserve(partCount); + for (size_t i = 0; i < partCount; ++i) { + lineStrings.push_back( + getGeometryFactory()->createLineString( + readCoordinates(input, partLengths[i]))); + } + + if (multiType) { + return getGeometryFactory()->createMultiLineString(std::move(lineStrings)); + } + + if (lineStrings.size() != 1) { + VELOX_FAIL("Expected a single LineString for non-multiType polyline."); + } + + return std::move(lineStrings[0]); +} + +std::unique_ptr GeometryDeserializer::readPolygon( + velox::common::InputByteStream& input, + bool multiType) { + skipEsriType(input); + skipEnvelope(input); + + size_t partCount = input.read(); + size_t pointCount = input.read(); + if (partCount == 0) { + if (multiType) { + return getGeometryFactory()->createMultiPolygon(); + } + return getGeometryFactory()->createPolygon(); + } + + std::vector startIndexes(partCount); + for (size_t i = 0; i < partCount; i++) { + startIndexes[i] = input.read(); + } + + std::vector partLengths(partCount); + if (partCount > 1) { + partLengths[0] = startIndexes[1]; + for (size_t i = 1; i < partCount - 1; i++) { + partLengths[i] = startIndexes[i + 1] - startIndexes[i]; + } + } + partLengths[partCount - 1] = pointCount - startIndexes[partCount - 1]; + + std::unique_ptr shell = nullptr; + std::vector> holes; + std::vector> polygons; + + // Shells _should_ be clockwise and holes _should_ be counter-clockwise, + // but this doesn't always happen for single Polygons. For single Polygons, + // we read the first ring as a shell and the rest as holes. For MultiPolygons, + // we read the first ring as a shell, and any counter-clockwise rings as + // holes, then push a polygon and reset if a clockwise ring is encountered. + for (size_t i = 0; i < partCount; i++) { + auto coordinates = readCoordinates(input, partLengths[i]); + + if (multiType) { + bool clockwiseFlag = + GeometrySerializer::isClockwise(coordinates, 0, coordinates->size()); + if (shell && clockwiseFlag) { + // next polygon has started + polygons.push_back( + getGeometryFactory()->createPolygon( + std::move(shell), std::move(holes))); + holes.clear(); + shell = nullptr; + } + } + + auto ring = getGeometryFactory()->createLinearRing(std::move(coordinates)); + if (shell == nullptr) { + shell = std::move(ring); + } else { + holes.push_back(std::move(ring)); + } + } + + polygons.push_back( + getGeometryFactory()->createPolygon(std::move(shell), std::move(holes))); + + if (multiType) { + return getGeometryFactory()->createMultiPolygon(std::move(polygons)); + } + + if (polygons.size() != 1) { + VELOX_FAIL("Expected exactly one polygon, but found multiple."); + } + return std::move(polygons[0]); +} + +std::unique_ptr GeometryDeserializer::readEnvelope( + velox::common::InputByteStream& input) { + auto xMin = input.read(); + auto yMin = input.read(); + auto xMax = input.read(); + auto yMax = input.read(); + + if (isEsriNaN(xMin) || isEsriNaN(yMin) || isEsriNaN(xMax) || + isEsriNaN(yMax)) { + return getGeometryFactory()->createPolygon(); + } + + auto coordinates = std::make_unique(); + coordinates->add(geos::geom::Coordinate(xMin, yMin)); + coordinates->add(geos::geom::Coordinate(xMin, yMax)); + coordinates->add(geos::geom::Coordinate(xMax, yMax)); + coordinates->add(geos::geom::Coordinate(xMax, yMin)); + coordinates->add(geos::geom::Coordinate(xMin, yMin)); // Close the ring + + auto shell = getGeometryFactory()->createLinearRing(std::move(coordinates)); + return getGeometryFactory()->createPolygon(std::move(shell), {}); +} + +std::unique_ptr +GeometryDeserializer::readGeometryCollection( + velox::common::InputByteStream& input, + size_t size) { + std::vector> geometries; + + auto offset = input.offset(); + while (size - offset > 0) { + // Skip the length field + input.read(); + geometries.push_back(readGeometry(input, size)); + offset = input.offset(); + } + std::vector rawGeometries; + rawGeometries.reserve(geometries.size()); + for (const auto& geometry : geometries) { + rawGeometries.push_back(geometry.get()); + } + + return std::unique_ptr( + getGeometryFactory()->createGeometryCollection(rawGeometries)); +} + +} // namespace facebook::velox::common::geospatial diff --git a/velox/common/geospatial/GeometrySerde.h b/velox/common/geospatial/GeometrySerde.h new file mode 100644 index 000000000000..3b622b570761 --- /dev/null +++ b/velox/common/geospatial/GeometrySerde.h @@ -0,0 +1,518 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include + +#include "velox/common/base/IOUtils.h" +#include "velox/common/geospatial/GeometryConstants.h" +#include "velox/expression/ComplexViewTypes.h" +#include "velox/type/StringView.h" + +namespace facebook::velox::common::geospatial { + +/** + * VarbinaryWriter is a utility for serializing raw binary data to a + * generic writer interface. It supports writing either raw byte arrays + * or trivially copyable types. + * + * @tparam StringWriter A type that provides an `append(std::string_view)` + * method, used to consume the binary output. Examples include `std::string` or + * `core::StringWriter`. + */ +template +class VarbinaryWriter { + public: + /* implicit */ VarbinaryWriter(StringWriter& stringWriter) + : stringWriter_(stringWriter) {} + VarbinaryWriter() = delete; + + void write(const char* data, size_t size) { + stringWriter_.append(std::string_view(data, size)); + } + + template + void write(const T& value) { + static_assert( + std::is_trivially_copyable_v, "T must be trivially copyable"); + stringWriter_.append( + std::string_view(reinterpret_cast(&value), sizeof(T))); + } + + private: + StringWriter& stringWriter_; +}; + +class GeometrySerializer { + public: + /// Serialize geometry into Velox's internal format. Do not call this within + /// GEOS_TRY macro: it will catch the exceptions that need to bubble up. + template + static void serialize( + const geos::geom::Geometry& geometry, + StringWriter& stringWriter) { + VarbinaryWriter writer(stringWriter); + writeGeometry(geometry, writer); + } + + template + static void serializeEnvelope( + double xMin, + double yMin, + double xMax, + double yMax, + StringWriter& stringWriter) { + VarbinaryWriter writer(stringWriter); + writer.write(static_cast(GeometrySerializationType::ENVELOPE)); + writer.write(xMin); + writer.write(yMin); + writer.write(xMax); + writer.write(yMax); + } + + template + static void serializeEnvelope( + geos::geom::Envelope& envelope, + StringWriter& stringWriter) { + if (FOLLY_UNLIKELY(envelope.isNull())) { + serializeEnvelope( + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + stringWriter); + } else { + serializeEnvelope( + envelope.getMinX(), + envelope.getMinY(), + envelope.getMaxX(), + envelope.getMaxY(), + stringWriter); + } + } + + /// Determines if a ring of coordinates (from `start` to `end`) is oriented + /// clockwise. + FOLLY_ALWAYS_INLINE static bool isClockwise( + const std::unique_ptr& coordinates, + size_t start, + size_t end) { + double sum = 0.0; + for (size_t i = start; i < end - 1; i++) { + const auto& p1 = coordinates->getAt(i); + const auto& p2 = coordinates->getAt(i + 1); + sum += (p2.x - p1.x) * (p2.y + p1.y); + } + return sum > 0.0; + } + + /// Reverses the order of coordinates in the sequence between `start` and + /// `end` + FOLLY_ALWAYS_INLINE static void reverse( + const std::unique_ptr& coordinates, + size_t start, + size_t end) { + for (size_t i = 0; i < (end - start) / 2; ++i) { + auto temp = coordinates->getAt(start + i); + coordinates->setAt(coordinates->getAt(end - 1 - i), start + i); + coordinates->setAt(temp, end - 1 - i); + } + } + + /// Ensures that a polygon ring has the canonical orientation: + /// - Exterior rings (shells) must be clockwise. + /// - Interior rings (holes) must be counter-clockwise. + FOLLY_ALWAYS_INLINE static void canonicalizePolygonCoordinates( + const std::unique_ptr& coordinates, + size_t start, + size_t end, + bool isShell) { + bool isClockwiseFlag = isClockwise(coordinates, start, end); + + if ((isShell && !isClockwiseFlag) || (!isShell && isClockwiseFlag)) { + reverse(coordinates, start, end); + } + } + + /// Applies `canonicalizePolygonCoordinates` to all rings in a polygon. + FOLLY_ALWAYS_INLINE static void canonicalizePolygonCoordinates( + const std::unique_ptr& coordinates, + const std::vector& partIndexes, + const std::vector& shellPart) { + for (size_t part = 0; part < partIndexes.size() - 1; part++) { + canonicalizePolygonCoordinates( + coordinates, + partIndexes[part], + partIndexes[part + 1], + shellPart[part]); + } + if (!partIndexes.empty()) { + canonicalizePolygonCoordinates( + coordinates, + partIndexes.back(), + coordinates->size(), + shellPart.back()); + } + } + + private: + template + static void writeGeometry( + const geos::geom::Geometry& geometry, + VarbinaryWriter& writer) { + auto geometryType = geometry.getGeometryTypeId(); + switch (geometryType) { + case geos::geom::GEOS_POINT: + writePoint(geometry, writer); + break; + case geos::geom::GEOS_MULTIPOINT: + writeMultiPoint(geometry, writer); + break; + case geos::geom::GEOS_LINESTRING: + case geos::geom::GEOS_LINEARRING: + writePolyline(geometry, writer, false); + break; + case geos::geom::GEOS_MULTILINESTRING: + writePolyline(geometry, writer, true); + break; + case geos::geom::GEOS_POLYGON: + writePolygon(geometry, writer, false); + break; + case geos::geom::GEOS_MULTIPOLYGON: + writePolygon(geometry, writer, true); + break; + case geos::geom::GEOS_GEOMETRYCOLLECTION: + writeGeometryCollection(geometry, writer); + break; + default: + VELOX_FAIL( + "Unrecognized geometry type: {}", + static_cast(geometryType)); + break; + } + } + + template + static void writeEnvelope( + const geos::geom::Geometry& geometry, + VarbinaryWriter& writer) { + if (geometry.isEmpty()) { + writer.write(std::numeric_limits::quiet_NaN()); + writer.write(std::numeric_limits::quiet_NaN()); + writer.write(std::numeric_limits::quiet_NaN()); + writer.write(std::numeric_limits::quiet_NaN()); + return; + } + + auto envelope = geometry.getEnvelopeInternal(); + writer.write(envelope->getMinX()); + writer.write(envelope->getMinY()); + writer.write(envelope->getMaxX()); + writer.write(envelope->getMaxY()); + } + + template + static void writeCoordinates( + const std::unique_ptr& coords, + VarbinaryWriter& writer) { + for (size_t i = 0; i < coords->size(); ++i) { + writer.write(coords->getX(i)); + writer.write(coords->getY(i)); + } + } + + template + static void writePoint( + const geos::geom::Geometry& point, + VarbinaryWriter& writer) { + writer.write(static_cast(GeometrySerializationType::POINT)); + if (!point.isEmpty()) { + writeCoordinates(point.getCoordinates(), writer); + } else { + writer.write(std::numeric_limits::quiet_NaN()); + writer.write(std::numeric_limits::quiet_NaN()); + } + } + + template + static void writeMultiPoint( + const geos::geom::Geometry& geometry, + VarbinaryWriter& writer) { + writer.write(static_cast(GeometrySerializationType::MULTI_POINT)); + writer.write(static_cast(EsriShapeType::MULTI_POINT)); + writeEnvelope(geometry, writer); + writer.write(static_cast(geometry.getNumPoints())); + writeCoordinates(geometry.getCoordinates(), writer); + } + + template + static void writePolyline( + const geos::geom::Geometry& geometry, + VarbinaryWriter& writer, + bool multiType) { + size_t numParts; + size_t numPoints = geometry.getNumPoints(); + + if (multiType) { + numParts = geometry.getNumGeometries(); + writer.write( + static_cast(GeometrySerializationType::MULTI_LINE_STRING)); + } else { + numParts = (numPoints > 0) ? 1 : 0; + writer.write( + static_cast(GeometrySerializationType::LINE_STRING)); + } + + writer.write(static_cast(EsriShapeType::POLYLINE)); + + writeEnvelope(geometry, writer); + + writer.write(static_cast(numParts)); + writer.write(static_cast(numPoints)); + + size_t partIndex = 0; + for (size_t geomIdx = 0; geomIdx < numParts; ++geomIdx) { + writer.write(static_cast(partIndex)); + partIndex += geometry.getGeometryN(geomIdx)->getNumPoints(); + } + + if (multiType) { + for (size_t partIdx = 0; partIdx < numParts; ++partIdx) { + const auto* part = geometry.getGeometryN(partIdx); + writeCoordinates(part->getCoordinates(), writer); + } + } else { + writeCoordinates(geometry.getCoordinates(), writer); + } + } + + template + static void writePolygon( + const geos::geom::Geometry& geometry, + VarbinaryWriter& writer, + bool multiType) { + size_t numGeometries = geometry.getNumGeometries(); + size_t numParts = 0; + size_t numPoints = geometry.getNumPoints(); + + for (size_t geomIdx = 0; geomIdx < numGeometries; geomIdx++) { + auto polygon = dynamic_cast( + geometry.getGeometryN(geomIdx)); + if (polygon && polygon->getNumPoints() > 0) { + numParts += polygon->getNumInteriorRing() + 1; + } + } + + if (multiType) { + writer.write( + static_cast(GeometrySerializationType::MULTI_POLYGON)); + } else { + writer.write(static_cast(GeometrySerializationType::POLYGON)); + } + + writer.write(static_cast(EsriShapeType::POLYGON)); + writeEnvelope(geometry, writer); + + writer.write(static_cast(numParts)); + writer.write(static_cast(numPoints)); + + if (numParts == 0) { + return; + } + + std::vector partIndexes(numParts); + std::vector shellPart(numParts); + + size_t currentPart = 0; + size_t currentPoint = 0; + for (size_t geomIdx = 0; geomIdx < numGeometries; geomIdx++) { + const geos::geom::Polygon* polygon = + dynamic_cast( + geometry.getGeometryN(geomIdx)); + + partIndexes[currentPart] = currentPoint; + shellPart[currentPart] = true; + currentPart++; + currentPoint += polygon->getExteriorRing()->getNumPoints(); + + size_t holesCount = polygon->getNumInteriorRing(); + for (size_t holeIndex = 0; holeIndex < holesCount; holeIndex++) { + partIndexes[currentPart] = currentPoint; + shellPart[currentPart] = false; + currentPart++; + currentPoint += polygon->getInteriorRingN(holeIndex)->getNumPoints(); + } + } + + for (size_t partIndex : partIndexes) { + writer.write(static_cast(partIndex)); + } + + auto coordinates = geometry.getCoordinates(); + canonicalizePolygonCoordinates(coordinates, partIndexes, shellPart); + writeCoordinates(coordinates, writer); + } + + template + static void writeGeometryCollection( + const geos::geom::Geometry& collection, + VarbinaryWriter& writer) { + writer.write( + static_cast(GeometrySerializationType::GEOMETRY_COLLECTION)); + + for (size_t geometryIndex = 0; + geometryIndex < collection.getNumGeometries(); + ++geometryIndex) { + auto* geometry = collection.getGeometryN(geometryIndex); + // Use a temporary buffer to serialize the geometry and calculate its + // length + std::string tempBuffer; + VarbinaryWriter tempOutput(tempBuffer); + + writeGeometry(*geometry, tempOutput); + + int32_t length = static_cast(tempBuffer.size()); + writer.write(length); + writer.write(tempBuffer.data(), tempBuffer.size()); + } + } +}; + +class GeometryDeserializer { + public: + /// Deserialize Velox's internal format to a geometry. Do not call this + /// within GEOS_TRY macro: it will catch the exceptions that need to bubble + /// up. + static std::unique_ptr deserialize( + const StringView& geometryString) { + velox::common::InputByteStream inputStream(geometryString.data()); + return readGeometry(inputStream, geometryString.size()); + } + + static const std::unique_ptr deserializeEnvelope( + const StringView& geometry); + + template + static std::unique_ptr + deserializePointsToCoordinate( + const exec::ArrayView& input, + const std::string& functionName, + bool forbidDuplicates) { + std::unique_ptr coords = + std::make_unique(input.size(), 2); + + double lastX = std::numeric_limits::signaling_NaN(); + double lastY = std::numeric_limits::signaling_NaN(); + for (int i = 0; i < input.size(); i++) { + if (!input[i].has_value()) { + VELOX_USER_FAIL( + fmt::format( + "Invalid input to {}: input array contains null at index {}.", + functionName, + i)); + } + + StringView view = *input[i]; + + velox::common::InputByteStream inputStream(view.data()); + auto geometryType = inputStream.read(); + if (geometryType != GeometrySerializationType::POINT) { + VELOX_USER_FAIL( + fmt::format( + "Non-point geometry in {} input at index {}.", + functionName, + i)); + } + auto x = inputStream.read(); + auto y = inputStream.read(); + if (std::isnan(x) || std::isnan(y)) { + VELOX_USER_FAIL( + fmt::format( + "Empty point in {} input at index {}.", functionName, i)); + } + if (forbidDuplicates && x == lastX && y == lastY) { + VELOX_USER_FAIL( + fmt::format( + "Repeated point sequence in {}: point {},{} at index {}.", + functionName, + x, + y, + i)); + } + lastX = x; + lastY = y; + coords->setAt({x, y}, i); + } + return coords; + } + + /// Returns the thread-local GEOS geometry factory. + static geos::geom::GeometryFactory* getGeometryFactory(); + + private: + static std::unique_ptr readGeometry( + velox::common::InputByteStream& stream, + size_t size); + + static bool isEsriNaN(double d) { + return std::isnan(d) || d < -1.0E38; + } + + static void skipEsriType(velox::common::InputByteStream& input) { + input.read(); // Esri type is an integer + } + + static void skipEnvelope(velox::common::InputByteStream& input) { + input.read(4); // Envelopes are 4 doubles (minX, minY, maxX, maxY) + } + + static std::unique_ptr deserializeEnvelope( + velox::common::InputByteStream& input); + + static geos::geom::Coordinate readCoordinate( + velox::common::InputByteStream& input); + + static std::unique_ptr readCoordinates( + velox::common::InputByteStream& input, + size_t count); + + static std::unique_ptr readPoint( + velox::common::InputByteStream& input); + + static std::unique_ptr readMultiPoint( + velox::common::InputByteStream& input); + + static std::unique_ptr readPolyline( + velox::common::InputByteStream& input, + bool multiType); + + static std::unique_ptr readPolygon( + velox::common::InputByteStream& input, + bool multiType); + + static std::unique_ptr readEnvelope( + velox::common::InputByteStream& input); + + static std::unique_ptr readGeometryCollection( + velox::common::InputByteStream& input, + size_t size); +}; + +} // namespace facebook::velox::common::geospatial diff --git a/velox/common/geospatial/tests/CMakeLists.txt b/velox/common/geospatial/tests/CMakeLists.txt new file mode 100644 index 000000000000..55b2f76cde56 --- /dev/null +++ b/velox/common/geospatial/tests/CMakeLists.txt @@ -0,0 +1,29 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +if(VELOX_ENABLE_GEO) + add_executable(velox_common_geospatial_serde_test GeometrySerdeTest.cpp) + + add_test(velox_common_geospatial_serde_test velox_common_geospatial_serde_test) + + target_link_libraries( + velox_common_geospatial_serde_test + velox_common_geospatial_serde + GTest::gtest + GTest::gtest_main + GTest::gmock + GTest::gmock_main + GEOS::geos + ) +endif() diff --git a/velox/common/geospatial/tests/GeometrySerdeTest.cpp b/velox/common/geospatial/tests/GeometrySerdeTest.cpp new file mode 100644 index 000000000000..14ee6186a682 --- /dev/null +++ b/velox/common/geospatial/tests/GeometrySerdeTest.cpp @@ -0,0 +1,107 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/common/geospatial/GeometrySerde.h" +#include +#include +#include +#include +#include "velox/type/StringView.h" + +using namespace ::testing; + +using namespace facebook::velox::common::geospatial; + +void assertRoundtrip(const std::string& wkt) { + geos::io::WKTReader reader; + geos::io::WKTWriter writer; + std::unique_ptr geometry = reader.read(wkt); + + std::string buffer; + GeometrySerializer::serialize(*geometry, buffer); + facebook::velox::StringView readBuffer(buffer); + auto deserialized = GeometryDeserializer::deserialize(readBuffer); + + EXPECT_TRUE(geometry->equals(deserialized.get())) + << std::endl + << "Input:" << std::endl + << wkt << std::endl + << "Output:" << std::endl + << writer.write(deserialized.get()); +} + +TEST(GeometrySerdeTest, testBasicSerde) { + assertRoundtrip("POINT EMPTY"); + assertRoundtrip("POINT (1 2)"); + assertRoundtrip("MULTIPOINT EMPTY"); + assertRoundtrip("MULTIPOINT (1 2)"); + assertRoundtrip("MULTIPOINT (1 2, 1 0)"); + + assertRoundtrip("LINESTRING EMPTY"); + assertRoundtrip("LINESTRING (1 2, 1 0)"); + assertRoundtrip("LINESTRING (1 0, 2 0, 2 1, 1 1, 1 0)"); + assertRoundtrip("MULTILINESTRING EMPTY"); + assertRoundtrip("MULTILINESTRING ((1 2, 1 0))"); + assertRoundtrip("MULTILINESTRING ((1 2, 1 0), (10 11, 12 13))"); + + assertRoundtrip("POLYGON EMPTY"); + assertRoundtrip("POLYGON ((1 0, 2 0, 2 1, 1 1, 1 0))"); + assertRoundtrip( + "POLYGON ((0 0, 10 0, 10 10, 0 10, 0 0), (1 1, 2 1, 2 2, 1 2, 1 1))"); + assertRoundtrip("MULTIPOLYGON EMPTY"); + assertRoundtrip("MULTIPOLYGON (((1 0, 2 0, 2 1, 1 1, 1 0)))"); + assertRoundtrip( + "MULTIPOLYGON ( ((10 0, 20 0, 20 10, 10 10, 10 0)), ((0 0, 4 0, 4 4, 0 4, 0 0), (1 1, 2 1, 2 2, 1 2, 1 1)) )"); +} + +TEST(GeometrySerdeTest, testGeometryCollectionSerde) { + assertRoundtrip("GEOMETRYCOLLECTION EMPTY"); + assertRoundtrip("GEOMETRYCOLLECTION (POINT EMPTY)"); + assertRoundtrip("GEOMETRYCOLLECTION (POINT (0 0))"); + assertRoundtrip("GEOMETRYCOLLECTION (POINT (0 0), POINT EMPTY)"); + assertRoundtrip("GEOMETRYCOLLECTION (POINT (0 0), POINT (0 0))"); + assertRoundtrip("GEOMETRYCOLLECTION (POINT (0 0), POINT (1 1))"); + assertRoundtrip("GEOMETRYCOLLECTION (MULTIPOINT EMPTY)"); + assertRoundtrip( + "GEOMETRYCOLLECTION (MULTIPOINT (0 0, 1 2), POINT (1 1), MULTIPOINT EMPTY)"); + + assertRoundtrip("GEOMETRYCOLLECTION (LINESTRING EMPTY)"); + assertRoundtrip("GEOMETRYCOLLECTION (MULTILINESTRING EMPTY)"); + assertRoundtrip( + "GEOMETRYCOLLECTION (MULTILINESTRING ((0 1, 2 3, 0 3, 0 1), (10 10, 10 12, 12 10)), POINT EMPTY, LINESTRING (0 0, -1 -1, 2 0))"); + + assertRoundtrip("GEOMETRYCOLLECTION (POLYGON EMPTY)"); + assertRoundtrip("GEOMETRYCOLLECTION (MULTIPOLYGON EMPTY)"); + assertRoundtrip("GEOMETRYCOLLECTION (GEOMETRYCOLLECTION EMPTY)"); + assertRoundtrip( + "GEOMETRYCOLLECTION (POINT (1 2), LINESTRING (8 4, 5 7), POLYGON EMPTY)"); + assertRoundtrip( + "GEOMETRYCOLLECTION (GEOMETRYCOLLECTION ( MULTIPOINT (1 2) ))"); +} + +TEST(GeometrySerdeTest, testComplexSerde) { + assertRoundtrip("GEOMETRYCOLLECTION ( MULTIPOINT EMPTY, MULTIPOINT (1 1) )"); + assertRoundtrip("GEOMETRYCOLLECTION (POLYGON EMPTY, POINT (1 2))"); + assertRoundtrip( + "GEOMETRYCOLLECTION (POLYGON EMPTY, MULTIPOINT (1 2), GEOMETRYCOLLECTION ( MULTIPOINT (3 4) ))"); + assertRoundtrip( + "GEOMETRYCOLLECTION (POLYGON EMPTY, GEOMETRYCOLLECTION ( POINT (1 2), POLYGON ((0 0, 4 0, 4 4, 0 4, 0 0), (1 1, 2 1, 2 2, 1 2, 1 1)), GEOMETRYCOLLECTION EMPTY, MULTIPOLYGON ( ((10 10, 14 10, 14 14, 10 14, 10 10), (11 11, 12 11, 12 12, 11 12, 11 11)), ((-1 -1, -2 -2, -1 -2, -1 -1)) ) ))"); +} + +TEST(GeometrySerdeTest, testSmallAreaRing) { + assertRoundtrip( + "MULTIPOLYGON (((18.6317421 49.9605785, 18.6318832 49.9607979, 18.6324683 49.9607312, 18.6332842 49.9605658, 18.6332003 49.9603557, 18.6339711 49.9602283, 18.6341994 49.9601905, 18.6343455 49.96016, 18.6344167 49.9601452, 18.6346696 49.9600919, 18.6349643 49.9600567, 18.6352271 49.9601455, 18.6354493 49.9600501, 18.6358024 49.9601071, 18.6358911 49.9600263, 18.6336542 49.9592453, 18.6334794 49.9591838, 18.6337483 49.9581339, 18.6335303 49.9580562, 18.6331284 49.9579122, 18.6324931 49.9576885, 18.6322503 49.9575998, 18.6321381 49.9581593, 18.6321172 49.9582692, 18.6324683 49.9583852, 18.6325255 49.9584004, 18.6327588 49.958489, 18.6324792 49.9588351, 18.6323941 49.9588049, 18.6323261 49.9587807, 18.6320354 49.9586789, 18.6319443 49.9592903, 18.6326731 49.9595648, 18.6331388 49.9594836, 18.6335981 49.959673, 18.6333065 49.9597934, 18.6328096 49.9600844, 18.6330209 49.9601348, 18.633424 49.9602597, 18.6332263 49.960317, 18.6315633 49.9597642, 18.6309331 49.9600741, 18.6317421 49.9605785)), ((18.6298591 49.9606201, 18.6298592 49.96062, 18.6298589 49.9606193, 18.6298591 49.9606201)))"); +} diff --git a/velox/common/hyperloglog/CMakeLists.txt b/velox/common/hyperloglog/CMakeLists.txt index 34ff61b18d44..649d1bc9851c 100644 --- a/velox/common/hyperloglog/CMakeLists.txt +++ b/velox/common/hyperloglog/CMakeLists.txt @@ -16,12 +16,10 @@ velox_add_library( BiasCorrection.cpp DenseHll.cpp SparseHll.cpp - Murmur3Hash128.cpp) + Murmur3Hash128.cpp +) -velox_link_libraries( - velox_common_hyperloglog - PUBLIC velox_memory - PRIVATE velox_exception) +velox_link_libraries(velox_common_hyperloglog PUBLIC velox_memory PRIVATE velox_exception) if(${VELOX_BUILD_TESTING}) add_subdirectory(tests) diff --git a/velox/common/hyperloglog/DenseHll.cpp b/velox/common/hyperloglog/DenseHll.cpp index f3a3f54b2d2d..2245a3f69096 100644 --- a/velox/common/hyperloglog/DenseHll.cpp +++ b/velox/common/hyperloglog/DenseHll.cpp @@ -15,13 +15,13 @@ */ #include "velox/common/hyperloglog/DenseHll.h" -#include -#include +#include "velox/common/base/BitUtil.h" #include "velox/common/base/IOUtils.h" #include "velox/common/hyperloglog/BiasCorrection.h" #include "velox/common/hyperloglog/HllUtils.h" namespace facebook::velox::common::hll { + namespace { const int kBitsPerBucket = 4; const int8_t kMaxDelta = (1 << kBitsPerBucket) - 1; @@ -119,14 +119,26 @@ double correctBias(double rawEstimate, int8_t indexBitLength) { } } // namespace -DenseHll::DenseHll(int8_t indexBitLength, HashStringAllocator* allocator) - : deltas_{StlAllocator(allocator)}, - overflowBuckets_{StlAllocator(allocator)}, - overflowValues_{StlAllocator(allocator)} { +template +DenseHll::DenseHll(int8_t indexBitLength, TAllocator* allocator) + : allocator_(allocator), + deltas_{TStlAllocator(allocator)}, + overflowBuckets_{TStlAllocator(allocator)}, + overflowValues_{TStlAllocator(allocator)} { initialize(indexBitLength); } -void DenseHll::initialize(int8_t indexBitLength) { +template +DenseHll::DenseHll(TAllocator* allocator) + : indexBitLength_(-1), + baselineCount_(0), + allocator_(allocator), + deltas_{TStlAllocator(allocator)}, + overflowBuckets_{TStlAllocator(allocator)}, + overflowValues_{TStlAllocator(allocator)} {} + +template +void DenseHll::initialize(int8_t indexBitLength) { VELOX_CHECK_GE(indexBitLength, 4, "indexBitLength must be in [4, 16] range"); VELOX_CHECK_LE(indexBitLength, 16, "indexBitLength must be in [4, 16] range"); @@ -137,13 +149,15 @@ void DenseHll::initialize(int8_t indexBitLength) { deltas_.resize(numBuckets * kBitsPerBucket / 8); } -void DenseHll::insertHash(uint64_t hash) { +template +void DenseHll::insertHash(uint64_t hash) { auto index = computeIndex(hash, indexBitLength_); auto value = numberOfLeadingZeros(hash, indexBitLength_) + 1; insert(index, value); } -void DenseHll::insert(int32_t index, int8_t value) { +template +void DenseHll::insert(int32_t index, int8_t value) { auto delta = value - baseline_; auto oldDelta = getDelta(index); @@ -261,7 +275,8 @@ DenseHllView deserialize(const char* serialized) { } } // namespace -int64_t DenseHll::cardinality() const { +template +int64_t DenseHll::cardinality() const { DenseHllView hll{ indexBitLength_, baseline_, @@ -272,18 +287,14 @@ int64_t DenseHll::cardinality() const { return cardinalityImpl(hll); } -// static -int64_t DenseHll::cardinality(const char* serialized) { - auto hll = deserialize(serialized); - return cardinalityImpl(hll); -} - -int8_t DenseHll::getDelta(int32_t index) const { +template +int8_t DenseHll::getDelta(int32_t index) const { int slot = index >> 1; return (deltas_[slot] >> shiftForBucket(index)) & kBucketMask; } -void DenseHll::setDelta(int32_t index, int8_t value) { +template +void DenseHll::setDelta(int32_t index, int8_t value) { int slot = index >> 1; // Clear the old value. @@ -295,12 +306,14 @@ void DenseHll::setDelta(int32_t index, int8_t value) { deltas_[slot] |= setMask; } -int8_t DenseHll::getOverflow(int32_t index) const { +template +int8_t DenseHll::getOverflow(int32_t index) const { return getOverflowImpl( index, overflows_, overflowBuckets_.data(), overflowValues_.data()); } -int DenseHll::findOverflowEntry(int32_t index) const { +template +int DenseHll::findOverflowEntry(int32_t index) const { for (auto i = 0; i < overflows_; i++) { if (overflowBuckets_[i] == index) { return i; @@ -309,7 +322,8 @@ int DenseHll::findOverflowEntry(int32_t index) const { return -1; } -void DenseHll::adjustBaselineIfNeeded() { +template +void DenseHll::adjustBaselineIfNeeded() { auto numBuckets = 1 << indexBitLength_; while (baselineCount_ == 0) { @@ -359,7 +373,8 @@ void DenseHll::adjustBaselineIfNeeded() { } } -void DenseHll::sortOverflows() { +template +void DenseHll::sortOverflows() { // traditional insertion sort (ok for small arrays) for (int i = 1; i < overflows_; i++) { auto bucket = overflowBuckets_[i]; @@ -385,7 +400,8 @@ void DenseHll::sortOverflows() { } } -int32_t DenseHll::serializedSize() const { +template +int32_t DenseHll::serializedSize() const { return 1 /* type + version */ + 1 /* indexBitLength */ + 1 /* baseline */ @@ -395,13 +411,17 @@ int32_t DenseHll::serializedSize() const { + overflows_ /* overflow bucket values */; } -// static -bool DenseHll::canDeserialize(const char* input) { +int64_t DenseHlls::cardinality(const char* serialized) { + auto hll = deserialize(serialized); + return cardinalityImpl(hll); +} + +bool DenseHlls::canDeserialize(const char* input) { return *reinterpret_cast(input) == kPrestoDenseV2; } // static -bool DenseHll::canDeserialize(const char* input, int size) { +bool DenseHlls::canDeserialize(const char* input, int size) { if (size < 5) { // Min serialized sparse HLL size is 5 bytes. return false; @@ -459,22 +479,23 @@ bool DenseHll::canDeserialize(const char* input, int size) { return true; } -// static -int8_t DenseHll::deserializeIndexBitLength(const char* input) { +int8_t DenseHlls::deserializeIndexBitLength(const char* input) { common::InputByteStream stream(input); stream.read(); return stream.read(); } -// static -int32_t DenseHll::estimateInMemorySize(int8_t indexBitLength) { +int32_t DenseHlls::estimateInMemorySize(int8_t indexBitLength) { // Note: we don't take into account overflow entries since their number can // vary. - return sizeof(indexBitLength_) + sizeof(baseline_) + sizeof(baselineCount_) + + // return sizeof(indexBitLength_) + sizeof(baseline_) + + // sizeof(baselineCount_) + (1 << indexBitLength) / 2; + return sizeof(int8_t) + sizeof(int8_t) + sizeof(int32_t) + (1 << indexBitLength) / 2; } -void DenseHll::serialize(char* output) { +template +void DenseHll::serialize(char* output) { // sort overflow arrays to get consistent serialization for equivalent HLLs sortOverflows(); @@ -492,10 +513,12 @@ void DenseHll::serialize(char* output) { } } -DenseHll::DenseHll(const char* serialized, HashStringAllocator* allocator) - : deltas_{StlAllocator(allocator)}, - overflowBuckets_{StlAllocator(allocator)}, - overflowValues_{StlAllocator(allocator)} { +template +DenseHll::DenseHll(const char* serialized, TAllocator* allocator) + : allocator_(allocator), + deltas_{TStlAllocator(allocator)}, + overflowBuckets_{TStlAllocator(allocator)}, + overflowValues_{TStlAllocator(allocator)} { auto hll = deserialize(serialized); initialize(hll.indexBitLength); baseline_ = hll.baseline; @@ -525,7 +548,8 @@ DenseHll::DenseHll(const char* serialized, HashStringAllocator* allocator) } } -void DenseHll::mergeWith(const DenseHll& other) { +template +void DenseHll::mergeWith(const DenseHll& other) { VELOX_CHECK_EQ( indexBitLength_, other.indexBitLength_, @@ -539,7 +563,8 @@ void DenseHll::mergeWith(const DenseHll& other) { other.overflowValues_.data()}); } -void DenseHll::mergeWith(const char* serialized) { +template +void DenseHll::mergeWith(const char* serialized) { common::InputByteStream stream(serialized); auto version = stream.read(); @@ -561,7 +586,8 @@ void DenseHll::mergeWith(const char* serialized) { mergeWith({baseline, deltas, overflows, overflowBuckets, overflowValues}); } -std::pair DenseHll::computeNewValue( +template +std::pair DenseHll::computeNewValue( int8_t delta, int8_t otherDelta, int32_t bucket, @@ -585,7 +611,8 @@ std::pair DenseHll::computeNewValue( return {std::max(value1, value2), overflowEntry}; } -void DenseHll::mergeWith(const HllView& other) { +template +void DenseHll::mergeWith(const HllView& other) { // Number of 'delta' bytes that fit in a single SIMD batch. Each 'delta' byte // stores 2 4-bit deltas. constexpr auto batchSize = xsimd::batch::size; @@ -611,7 +638,10 @@ void DenseHll::mergeWith(const HllView& other) { adjustBaselineIfNeeded(); } -int32_t DenseHll::mergeWithSimd(const HllView& other, int8_t newBaseline) { +template +int32_t DenseHll::mergeWithSimd( + const HllView& other, + int8_t newBaseline) { const auto batchSize = xsimd::batch::size; const auto bucketMaskBatch = xsimd::broadcast(kBucketMask); @@ -751,7 +781,10 @@ int32_t DenseHll::mergeWithSimd(const HllView& other, int8_t newBaseline) { return baselineCount; } -int32_t DenseHll::mergeWithScalar(const HllView& other, int8_t newBaseline) { +template +int32_t DenseHll::mergeWithScalar( + const HllView& other, + int8_t newBaseline) { int32_t baselineCount = 0; int bucket = 0; @@ -787,8 +820,11 @@ int32_t DenseHll::mergeWithScalar(const HllView& other, int8_t newBaseline) { return baselineCount; } -int8_t -DenseHll::updateOverflow(int32_t index, int overflowEntry, int8_t delta) { +template +int8_t DenseHll::updateOverflow( + int32_t index, + int overflowEntry, + int8_t delta) { if (delta > kMaxDelta) { if (overflowEntry != -1) { // update existing overflow @@ -804,7 +840,8 @@ DenseHll::updateOverflow(int32_t index, int overflowEntry, int8_t delta) { return delta; } -void DenseHll::addOverflow(int32_t index, int8_t overflow) { +template +void DenseHll::addOverflow(int32_t index, int8_t overflow) { overflowBuckets_.resize(overflows_ + 1); overflowValues_.resize(overflows_ + 1); @@ -813,10 +850,17 @@ void DenseHll::addOverflow(int32_t index, int8_t overflow) { overflows_++; } -void DenseHll::removeOverflow(int overflowEntry) { +template +void DenseHll::removeOverflow(int overflowEntry) { // Remove existing overflow. overflowBuckets_[overflowEntry] = overflowBuckets_[overflows_ - 1]; overflowValues_[overflowEntry] = overflowValues_[overflows_ - 1]; overflows_--; } + +// Explicit template instantiation for both HashStringAllocator (default) and +// memory::MemoryPool +template class DenseHll; +template class DenseHll; + } // namespace facebook::velox::common::hll diff --git a/velox/common/hyperloglog/DenseHll.h b/velox/common/hyperloglog/DenseHll.h index b6b5f03f8cdf..6936b6d8085b 100644 --- a/velox/common/hyperloglog/DenseHll.h +++ b/velox/common/hyperloglog/DenseHll.h @@ -17,7 +17,45 @@ #include "velox/common/memory/HashStringAllocator.h" namespace facebook::velox::common::hll { -class SparseHll; + +class DenseHlls { + public: + /// Returns cardinality estimate from the specified serialized digest. + /// @param serialized Pointer to serialized DenseHll data + /// @return Estimated cardinality of the HyperLogLog + static int64_t cardinality(const char* serialized); + + /// Returns true if 'input' contains Presto DenseV2 format indicator. + /// @param input Pointer to serialized data to check + /// @return True if the data is in DenseV2 format, false otherwise + static bool canDeserialize(const char* input); + + /// Returns true if 'input' contains Presto DenseV2 format indicator and the + /// rest of the data matches HLL format: + /// 1 byte for version + /// 1 byte for index bit length, index bit length must be in [4,16] + /// 1 byte for baseline value + /// 2^(n-1) bytes for buckets, values in buckets must be in [0,63] + /// 2 bytes for # overflow buckets + /// 3 * #overflow buckets bytes for overflow buckets/values + /// More information here: + /// https://engineering.fb.com/2018/12/13/data-infrastructure/hyperloglog/ + /// @param input Pointer to serialized data to validate + /// @param size Size of the serialized data in bytes + /// @return True if the data is valid DenseV2 format, false otherwise + static bool canDeserialize(const char* input, int size); + + /// Extracts the index bit length from serialized DenseHll data. + /// @param input Pointer to serialized DenseHll data + /// @return The index bit length used in the serialized HLL + static int8_t deserializeIndexBitLength(const char* input); + + /// Returns an estimate of memory usage for DenseHll instance with the + /// specified number of bits per bucket. + /// @param indexBitLength Number of bits per bucket (must be in [4,16]) + /// @return Estimated memory usage in bytes + static int32_t estimateInMemorySize(int8_t indexBitLength); +}; /// HyperLogLog implementation using dense storage layout. /// The number of bits to use as bucket (indexBitLength) is specified by the @@ -26,18 +64,19 @@ class SparseHll; /// /// Memory usage: 2 ^ (indexBitLength - 1) bytes. 2KB for indexBitLength of 12 /// which provides max standard error of 0.023. +template class DenseHll { public: - DenseHll(int8_t indexBitLength, HashStringAllocator* allocator); + template + using TStlAllocator = typename TAllocator::template TStlAllocator; + + DenseHll(int8_t indexBitLength, TAllocator* allocator); - DenseHll(const char* serialized, HashStringAllocator* allocator); + DenseHll(const char* serialized, TAllocator* allocator); /// Creates an uninitialized instance that doesn't allcate any significant /// memory. The caller must call initialize before using the HLL. - explicit DenseHll(HashStringAllocator* allocator) - : deltas_{StlAllocator(allocator)}, - overflowBuckets_{StlAllocator(allocator)}, - overflowValues_{StlAllocator(allocator)} {} + explicit DenseHll(TAllocator* allocator); /// Allocates memory that can fit 2 ^ indexBitLength buckets. void initialize(int8_t indexBitLength); @@ -55,28 +94,9 @@ class DenseHll { int64_t cardinality() const; - static int64_t cardinality(const char* serialized); - /// Serializes internal state using Presto DenseV2 format. void serialize(char* output); - /// Returns true if 'input' contains Presto DenseV2 format indicator. - static bool canDeserialize(const char* input); - - /// Returns true if 'input' contains Presto DenseV2 format indicator and the - /// rest of the data matches HLL format: - /// 1 byte for version - /// 1 byte for index bit length, index bit length must be in [4,16] - /// 1 byte for baseline value - /// 2^(n-1) bytes for buckets, values in buckets must be in [0,63] - /// 2 bytes for # overflow buckets - /// 3 * #overflow buckets bytes for overflow buckets/values - /// More information here: - /// https://engineering.fb.com/2018/12/13/data-infrastructure/hyperloglog/ - static bool canDeserialize(const char* input, int size); - - static int8_t deserializeIndexBitLength(const char* input); - /// Returns the size of the serialized state without serialising. int32_t serializedSize() const; @@ -86,10 +106,6 @@ class DenseHll { void mergeWith(const char* serialized); - /// Returns an estimate of memory usage for DenseHll instance with the - /// specified number of bits per bucket. - static int32_t estimateInMemorySize(int8_t indexBitLength); - private: int8_t getDelta(int32_t index) const; @@ -147,20 +163,19 @@ class DenseHll { /// Number of zero deltas. int32_t baselineCount_; + TAllocator* allocator_; + /// Per-bucket values represented as deltas from the baseline_. Each entry /// stores 2 values, 4 bits each. The maximum value that can be stored is 15. /// Larger values are stored in a separate overflow list. - std::vector> deltas_; - - /// Number of overflowing values, e.g. values where delta from baseline is - /// greater than 15. + std::vector> deltas_; int16_t overflows_{0}; /// List of buckets with overflowing values. - std::vector> overflowBuckets_; + std::vector> overflowBuckets_; /// Overflowing values stored as deltas from the deltas: value - 15 - /// baseline. - std::vector> overflowValues_; + std::vector> overflowValues_; }; } // namespace facebook::velox::common::hll diff --git a/velox/common/hyperloglog/HllAccumulator.h b/velox/common/hyperloglog/HllAccumulator.h new file mode 100644 index 000000000000..4f47b627c18a --- /dev/null +++ b/velox/common/hyperloglog/HllAccumulator.h @@ -0,0 +1,244 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#define XXH_INLINE_ALL + +#include +#include + +#include "velox/common/base/Exceptions.h" +#include "velox/common/hyperloglog/DenseHll.h" +#include "velox/common/hyperloglog/Murmur3Hash128.h" +#include "velox/common/hyperloglog/SparseHll.h" +#include "velox/common/memory/HashStringAllocator.h" + +namespace facebook::velox::common::hll { + +namespace detail { +template +inline uint64_t hashOne(const T& value) { + if constexpr (std::is_same_v) { + return hashOne(value.toMillis()); + } + if constexpr (HllAsFinalResult) { + if constexpr (std::is_same_v) { + return common::hll::Murmur3Hash128::hash64ForLong(value, 0); + } else if constexpr (std::is_same_v) { + return common::hll::Murmur3Hash128::hash64ForLong( + *reinterpret_cast(&value), 0); + } + return common::hll::Murmur3Hash128::hash64(&value, sizeof(T), 0); + } else { + return XXH64(&value, sizeof(T), 0); + } +} + +template <> +inline uint64_t hashOne(const StringView& value) { + return XXH64(value.data(), value.size(), 0); +} + +template <> +inline uint64_t hashOne(const StringView& value) { + return common::hll::Murmur3Hash128::hash64(value.data(), value.size(), 0); +} + +} // namespace detail + +template < + typename T, + bool HllAsFinalResult, + typename TAllocator = HashStringAllocator> +struct HllAccumulator { + explicit HllAccumulator(TAllocator* allocator) + : sparseHll_{allocator}, denseHll_{allocator} {} + + explicit HllAccumulator(int8_t indexBitLength, TAllocator* allocator) + : isSparse_(true), + indexBitLength_(indexBitLength), + sparseHll_(allocator), + denseHll_(allocator) { + // Set soft memory limit for sparse HLL to convert to dense when exceeded. + sparseHll_.setSoftMemoryLimit( + DenseHlls::estimateInMemorySize(indexBitLength_)); + } + + void setIndexBitLength(int8_t indexBitLength) { + indexBitLength_ = indexBitLength; + sparseHll_.setSoftMemoryLimit( + DenseHlls::estimateInMemorySize(indexBitLength_)); + } + + /// Creates an HllAccumulator instance from serialized data. + static std::unique_ptr deserialize( + const char* data, + TAllocator* allocator) { + if (SparseHlls::canDeserialize(data)) { + int8_t indexBitLength = SparseHlls::deserializeIndexBitLength(data); + auto wrapper = + std::make_unique(indexBitLength, allocator); + wrapper->sparseHll_ = SparseHll(data, allocator); + wrapper->sparseHll_.setSoftMemoryLimit( + DenseHlls::estimateInMemorySize(indexBitLength)); + return wrapper; + } else if (DenseHlls::canDeserialize(data)) { + int8_t indexBitLength = DenseHlls::deserializeIndexBitLength(data); + auto wrapper = + std::make_unique(indexBitLength, allocator); + wrapper->denseHll_ = DenseHll(data, allocator); + wrapper->isSparse_ = false; + return wrapper; + } else { + VELOX_FAIL("Cannot deserialize HyperLogLog"); + } + } + + void append(T value) { + const auto hash = detail::hashOne(value); + insertHash(hash); + } + + void insertHash(uint64_t hash) { + if (isSparse_) { + // insertHash returns true if soft memory limit exceeded + if (sparseHll_.insertHash(hash)) { + toDense(); + } + } else { + denseHll_.insertHash(hash); + } + } + + int64_t cardinality() const { + return isSparse_ ? sparseHll_.cardinality() : denseHll_.cardinality(); + } + + void mergeWith(StringView serialized, TAllocator* allocator) { + auto input = serialized.data(); + if (indexBitLength_ < 0) { + // deserializeIndexBitLength is the same between Dense and Sparse HLL + setIndexBitLength(DenseHlls::deserializeIndexBitLength(input)); + } + + if (SparseHlls::canDeserialize(input)) { + SparseHll other{input, allocator}; + mergeWithSparse(other); + } else if (DenseHlls::canDeserialize(input)) { + DenseHll other{input, allocator}; + mergeWithDense(other); + } else { + VELOX_USER_FAIL("Unexpected type of HLL"); + } + } + + void mergeWith(const HllAccumulator& other) { + if (indexBitLength_ < 0) { + setIndexBitLength(other.indexBitLength_); + } + if (other.isSparse_) { + mergeWithSparse(other.sparseHll_); + } else { + mergeWithDense(other.denseHll_); + } + } + + int32_t serializedSize() { + return isSparse_ ? sparseHll_.serializedSize() : denseHll_.serializedSize(); + } + + void serialize(char* outputBuffer) { + return isSparse_ ? sparseHll_.serialize(indexBitLength_, outputBuffer) + : denseHll_.serialize(outputBuffer); + } + + bool isSparse() const { + return isSparse_; + } + + private: + void toDense() { + isSparse_ = false; + denseHll_.initialize(indexBitLength_); + sparseHll_.toDense(denseHll_); + sparseHll_.reset(); + } + + void mergeWithSparse(const SparseHll& other) { + if (isSparse_) { + sparseHll_.mergeWith(other); + if (sparseHll_.overLimit()) { + toDense(); + } + } else { + other.toDense(denseHll_); + } + } + + void mergeWithDense(const DenseHll& other) { + if (isSparse_) { + toDense(); + } + denseHll_.mergeWith(other); + } + + bool isSparse_{true}; + int8_t indexBitLength_{-1}; + SparseHll sparseHll_; + DenseHll denseHll_; +}; + +template <> +struct HllAccumulator { + explicit HllAccumulator(HashStringAllocator* /*allocator*/) {} + + void append(bool value) { + approxDistinctState_ |= (1 << value); + } + + int64_t cardinality() const { + return (approxDistinctState_ & 1) + ((approxDistinctState_ & 2) >> 1); + } + + void mergeWith( + StringView /*serialized*/, + HashStringAllocator* /*allocator*/) { + VELOX_UNREACHABLE( + "APPROX_DISTINCT unsupported mergeWith(StringView, HashStringAllocator*)"); + } + + void mergeWith(int8_t data) { + approxDistinctState_ |= data; + } + + int32_t serializedSize() const { + return sizeof(int8_t); + } + + void serialize(char* /*outputBuffer*/) { + VELOX_UNREACHABLE("APPROX_DISTINCT unsupported serialize(char*)"); + } + + void setIndexBitLength(int8_t /*indexBitLength*/) {} + + int8_t getState() const { + return approxDistinctState_; + } + + private: + int8_t approxDistinctState_{0}; +}; +} // namespace facebook::velox::common::hll diff --git a/velox/common/hyperloglog/KHyperLogLog.h b/velox/common/hyperloglog/KHyperLogLog.h new file mode 100644 index 000000000000..6f4eca2c3286 --- /dev/null +++ b/velox/common/hyperloglog/KHyperLogLog.h @@ -0,0 +1,173 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include + +#include +#include "velox/common/base/Status.h" +#include "velox/common/hyperloglog/HllAccumulator.h" + +namespace facebook::velox::common::hll { + +/// See "KHyperLogLog: EstimatingReidentifiability and Joinability of Large Data +/// at Scale" by Chia et al., 2019. +/// https://research.google/pubs/khyperloglog-estimating-reidentifiability-and-joinability-of-large-data-at-scale/ +template +class KHyperLogLog { + public: + template + using TStlAllocator = typename TAllocator::template TStlAllocator; + + static constexpr int32_t kDefaultHllBuckets = 256; + static constexpr int32_t kDefaultMaxSize = 4096; + + explicit KHyperLogLog(TAllocator* allocator) + : KHyperLogLog(kDefaultMaxSize, kDefaultHllBuckets, allocator) {} + + KHyperLogLog(int maxSize, int hllBuckets, TAllocator* allocator) + : maxSize_(maxSize), + hllBuckets_(hllBuckets), + allocator_(allocator), + minhash_( + TStlAllocator>>(allocator)) { + VELOX_CHECK( + hllBuckets > 0 && (hllBuckets & (hllBuckets - 1)) == 0, + "hllBuckets must be a power of 2"); + // numBuckets = 2^bits + indexBitLength_ = static_cast(std::log2(hllBuckets)); + } + + /// Creates a KHyperLogLog instance from serialized data. + /// Returns an error status if the data is invalid. + static Expected>> + deserialize(const char* data, size_t size, TAllocator* allocator); + + /// Serializes the KHyperLogLog state to a varbinary. + /// The caller must ensure there is enough space at the buffer location, using + /// estimatedSerializedSize(). Note: Not const because DenseHll::serialize() + /// may reorder internal overflow arrays for deterministic serialization. + void serialize(char* outputBuffer); + + /// Returns an estimate of the size in bytes of the serialized + /// representation. + size_t estimatedSerializedSize() const; + + /// Convert JoinKey to int64_t, hash, then add with its associated User + /// Identifying Information (UII) to the existing KHLL. + template + void add(TJoinKey joinKey, TUii uii); + + /// Returns the estimated cardinality (number of distinct join keys). + /// When isExact() is true, returns exact count. Otherwise, uses min-hash + /// density estimation to extrapolate to the full hash space. + int64_t cardinality() const; + + /// Returns whether the KHyperLogLog is in exact mode. + /// Exact mode means the number of distinct values is less than maxSize, + /// so all values are being tracked exactly without approximation. + bool isExact() const; + + size_t minhashSize() const; + + /// Counts the exact number of common join keys between two KHyperLogLogs. + /// Both instances must be in exact mode (isExact() == true), otherwise + /// throws. + static int64_t exactIntersectionCardinality( + const KHyperLogLog& left, + const KHyperLogLog& right); + + /// Computes the Jaccard index (similarity coefficient) between two + /// KHyperLogLogs. Uses the min-hash approach: looks at the first + /// min(|left.minhash|, |right.minhash|) entries in the sorted union of keys + /// and calculates the proportion that appear in both sets. + static double jaccardIndex( + const KHyperLogLog& left, + const KHyperLogLog& right); + + /// Merges two KHyperLogLog instances into a new instance. + /// Uses the one with the smaller maxSize as the base to avoid losing + /// resolution. Takes ownership of the input unique_ptrs and returns one + /// of them to avoid copying. + static std::unique_ptr> merge( + std::unique_ptr> left, + std::unique_ptr> right); + + /// Calculates the reidentification potential, which is the proportion of + /// values with cardinality (number of distinct UIIs) at or below the given + /// threshold. A higher value indicates more values are highly unique and + /// could potentially be used to reidentify individuals. + double reidentificationPotential(int64_t threshold) const; + + /// Returns a histogram of the uniqueness distribution with specified number + /// of buckets. For each value, determines its cardinality (number of distinct + /// UIIs) and increments the corresponding bucket. For hash values larger than + /// the histogram size, the cardinality will be added to the last bucket at + /// histogramSize. + folly::F14FastMap uniquenessDistribution( + int64_t histogramSize) const; + + /// Merges another KHyperLogLog instance into this one (modifies this + /// instance). + void mergeWith(const KHyperLogLog& other); + + /// Merges another serialized KHyperLogLog into this one. + /// The serialized data is assumed to be valid. + void mergeWith(StringView serialized, TAllocator* allocator) { + auto other = common::hll::KHyperLogLog::deserialize( + serialized.data(), serialized.size(), allocator); + VELOX_CHECK(other.hasValue(), "Failed to deserialize KHyperLogLog"); + mergeWith(*other.value()); + } + + private: + void update(int64_t hash, TUii uii); + + void removeOverflowEntries(); + + void increaseTotalHllSize(HllAccumulator& hll); + + void decreaseTotalHllSize(HllAccumulator& hll); + + int64_t maxKey() const { + if (minhash_.empty()) { + return INT64_MIN; + } + return minhash_.rbegin()->first; + } + + int32_t maxSize_; + int32_t hllBuckets_; + int8_t indexBitLength_; + TAllocator* allocator_; + + std::map< + int64_t, + HllAccumulator, + std::less, + TStlAllocator< + std::pair>>> + minhash_; + + size_t hllsTotalEstimatedSerializedSize_{0}; +}; + +} // namespace facebook::velox::common::hll + +#include "velox/common/hyperloglog/KHyperLogLogImpl.h" diff --git a/velox/common/hyperloglog/KHyperLogLogImpl.h b/velox/common/hyperloglog/KHyperLogLogImpl.h new file mode 100644 index 000000000000..757ed4fcf298 --- /dev/null +++ b/velox/common/hyperloglog/KHyperLogLogImpl.h @@ -0,0 +1,438 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "velox/common/base/Exceptions.h" +#include "velox/common/base/IOUtils.h" +#include "velox/common/hyperloglog/Murmur3Hash128.h" +#include "velox/type/HugeInt.h" + +namespace facebook::velox::common::hll { + +namespace detail { +constexpr uint8_t kVersionByte = 1; +constexpr int64_t kHashOutputHalfRange = INT64_MAX; +constexpr size_t kHeaderSize = sizeof(uint8_t) // version + + 4 * sizeof(int32_t); // maxSize, hllBuckets, minhashSize, hllsTotalSize; + +template +static int64_t hashKey(TJoinKey joinKey) { + int64_t result; + if constexpr (std::is_same_v) { + result = joinKey.toMillis(); + } else if constexpr ( + std::is_same_v || + std::is_same_v) { + return Murmur3Hash128::hash64(&joinKey, sizeof(joinKey), 0); + } else if constexpr (std::is_integral_v) { + result = static_cast(joinKey); + } else if constexpr (std::is_same_v) { + // Cast to double first, then extract bits, based on implicit coercion + double dbl = static_cast(joinKey); + std::memcpy(&result, &dbl, sizeof(result)); + } else if constexpr (std::is_same_v) { + std::memcpy(&result, &joinKey, sizeof(result)); + } else if constexpr (std::is_same_v) { + result = + common::hll::Murmur3Hash128::hash64(joinKey.data(), joinKey.size(), 0); + } else { + VELOX_UNREACHABLE("Unsupported input type: {}", typeid(TJoinKey).name()); + } + + return Murmur3Hash128::hash64(&result, sizeof(result), 0); +} +} // namespace detail + +template +Expected>> +KHyperLogLog::deserialize( + const char* data, + size_t size, + TAllocator* allocator) { + VELOX_RETURN_UNEXPECTED_IF( + size < sizeof(uint8_t), + Status::UserError("Invalid KHyperLogLog data: too small")); + + InputByteStream stream(data); + + uint8_t version = stream.read(); + VELOX_RETURN_UNEXPECTED_IF( + version != detail::kVersionByte, + Status::UserError("Unsupported KHyperLogLog version")); + + VELOX_RETURN_UNEXPECTED_IF( + size < detail::kHeaderSize, + Status::UserError("Invalid KHyperLogLog data: insufficient header size")); + + // Header values + int32_t maxSize = stream.read(); + int32_t hllBuckets = stream.read(); + int32_t minhashSize = stream.read(); + int32_t totalHllSize = stream.read(); + + if (minhashSize == 0) { + VELOX_RETURN_UNEXPECTED_IF( + totalHllSize != 0, + Status::UserError( + "Invalid KHyperLogLog data: minhashSize is 0 but totalHllSize is not 0")); + size_t remainingSize = size - stream.offset(); + VELOX_RETURN_UNEXPECTED_IF( + remainingSize != 0, + Status::UserError( + "Invalid KHyperLogLog data: minhashSize is 0 but extra data remains")); + return std::make_unique>( + maxSize, hllBuckets, allocator); + } + + // Validate remaining size. + size_t expectedRemainingSize = + minhashSize * (sizeof(int32_t) + sizeof(int64_t)) + totalHllSize; + size_t remainingSize = size - stream.offset(); + VELOX_RETURN_UNEXPECTED_IF( + remainingSize < expectedRemainingSize, + Status::UserError( + "Invalid KHyperLogLog data: insufficient data for minhash and HLLs")); + + // Read HLL sizes. + std::vector hllSizes(minhashSize); + stream.copyTo(hllSizes.data(), minhashSize); + + // Read keys. + std::vector keys(minhashSize); + stream.copyTo(keys.data(), minhashSize); + + // Create KHyperLogLog instance. + auto result = std::make_unique>( + maxSize, hllBuckets, allocator); + + // Deserialize HLLs. + const char* hllData = data + stream.offset(); + remainingSize = size - stream.offset(); + + size_t sumOfHllSizes = 0; + for (int32_t hllSize : hllSizes) { + sumOfHllSizes += hllSize; + } + + VELOX_RETURN_UNEXPECTED_IF( + remainingSize < sumOfHllSizes, + Status::UserError( + "Invalid KHyperLogLog data: insufficient data for HLLs")); + + for (int32_t i = 0; i < minhashSize; ++i) { + int32_t hllSize = hllSizes[i]; + int64_t key = keys[i]; + + auto hll = + HllAccumulator::deserialize(hllData, allocator); + result->increaseTotalHllSize(*hll); + result->minhash_.emplace(key, std::move(*hll)); + + hllData += hllSize; + remainingSize -= hllSize; + } + + return result; +} + +template +void KHyperLogLog::serialize(char* outputBuffer) { + OutputByteStream stream(outputBuffer); + + // Write version. + stream.appendOne(detail::kVersionByte); + + // Write header. + stream.appendOne(maxSize_); + stream.appendOne(hllBuckets_); + stream.appendOne(static_cast(minhash_.size())); + + // Write the sum of all HLL sizes. + int32_t totalHllSize = 0; + for (auto& [key, hll] : minhash_) { + totalHllSize += hll.serializedSize(); + } + stream.appendOne(totalHllSize); + + // Write HLL sizes in sorted key order. minHash_ is a sorted map, so no + // additional sorting is necessary for deterministic serialization. + int32_t maxSerializedSize = 0; + for (auto& [key, hll] : minhash_) { + int32_t hllSerializedSize = hll.serializedSize(); + stream.appendOne(hllSerializedSize); + maxSerializedSize = std::max(maxSerializedSize, hllSerializedSize); + } + + // Write keys in sorted order. + for (const auto& [key, hll] : minhash_) { + stream.appendOne(key); + } + + // Write serialized HLLs in sorted key order. + std::string hllBuffer(maxSerializedSize, '\0'); + for (auto& [key, hll] : minhash_) { + const_cast&>(hll).serialize( + hllBuffer.data()); + stream.append(hllBuffer.data(), hll.serializedSize()); + } +} + +template +size_t KHyperLogLog::estimatedSerializedSize() const { + return detail::kHeaderSize + // header: version, maxSize, hllBuckets, + // minhashSize, totalHllSize + minhash_.size() * sizeof(int32_t) + // individual HLL sizes + + minhash_.size() * sizeof(int64_t) + // minhash keys + + hllsTotalEstimatedSerializedSize_; // sum of all HLL serialized sizes +} + +template +template +void KHyperLogLog::add(TJoinKey joinKey, TUii uii) { + update(detail::hashKey(joinKey), uii); +} + +template +void KHyperLogLog::update(int64_t hash, TUii uii) { + auto it = minhash_.end(); + if (!(isExact() || (minhash_.size() > 0 && hash < maxKey()) || + ((it = minhash_.find(hash)) != minhash_.end()))) { + return; + } + + // Get or create HLL for this hash. + HllAccumulator* hll; + if (it == minhash_.end()) { + auto [iterator, inserted] = minhash_.emplace( + hash, + HllAccumulator(indexBitLength_, allocator_)); + hll = &iterator->second; + } else { + hll = &it->second; + decreaseTotalHllSize(*hll); + } + + hll->append(uii); + + increaseTotalHllSize(*hll); + removeOverflowEntries(); +} + +template +int64_t KHyperLogLog::cardinality() const { + if (isExact()) { + return static_cast(minhash_.size()); + } + + // Estimate cardinality by calculating the average spacing (density) between + // stored hash values, then extrapolating to the full hash space. + // The hash range is the full 64-bit hash range (2^64) which does not fit in a + // double, so both the hash range and the density are halved to keep + // proportions mathematically correct. The "-1" is a statistical bias + // correction from "On Synopses for Distinct-Value Estimation Under Multiset + // Operations" by Beyer et. al. + // Use unsigned arithmetic to avoid overflow for large maxKey values. + uint64_t hashesRange = + static_cast(maxKey()) - static_cast(INT64_MIN); + double halfDensity = + static_cast(hashesRange) / (minhash_.size() - 1) / 2.0; + return static_cast( + static_cast(detail::kHashOutputHalfRange) / halfDensity); +} + +template +bool KHyperLogLog::isExact() const { + return static_cast(minhash_.size()) < maxSize_; +} + +template +size_t KHyperLogLog::minhashSize() const { + return minhash_.size(); +} + +template +int64_t KHyperLogLog::exactIntersectionCardinality( + const KHyperLogLog& left, + const KHyperLogLog& right) { + VELOX_CHECK( + left.isExact(), + "exactIntersectionCardinality cannot operate on approximate sets"); + VELOX_CHECK( + right.isExact(), + "exactIntersectionCardinality cannot operate on approximate sets"); + + // Optimize by iterating through the smaller map and checking the larger one. + const auto& smaller = left.minhash_.size() <= right.minhash_.size() + ? left.minhash_ + : right.minhash_; + const auto& larger = left.minhash_.size() <= right.minhash_.size() + ? right.minhash_ + : left.minhash_; + + // Count intersection of keys. + int64_t intersectionCnt = 0; + for (const auto& [key, hll] : smaller) { + if (larger.contains(key)) { + intersectionCnt++; + } + } + + return intersectionCnt; +} + +template +double KHyperLogLog::jaccardIndex( + const KHyperLogLog& left, + const KHyperLogLog& right) { + if (left.minhash_.empty() && right.minhash_.empty()) { + return 1.0; + } + auto smallerSize = std::min(left.minhash_.size(), right.minhash_.size()); + + if (smallerSize == 0) { + return 0.0; + } + + auto itA = left.minhash_.begin(); + auto itB = right.minhash_.begin(); + + auto intersectionCnt = 0; + auto unionCnt = 0; + + // Merge the two sorted sequences, counting intersection along the way. + while (itA != left.minhash_.end() && itB != right.minhash_.end() && + unionCnt < smallerSize) { + if (itA->first < itB->first) { + ++itA; + } else if (itB->first < itA->first) { + ++itB; + } else { + intersectionCnt++; + ++itA; + ++itB; + } + unionCnt++; + } + + return static_cast(intersectionCnt) / smallerSize; +} + +template +std::unique_ptr> +KHyperLogLog::merge( + std::unique_ptr> left, + std::unique_ptr> right) { + // The KHLL with the smaller K will be used as base. This is because if a KHLL + // with a smaller K is merged into a KHLL with a larger K, the smaller KHLL's + // minhash will not cover all of the larger minhash's range. Instead, we want + // to keep only the smallest K number of HLLs in the new KHLL. + if (left->maxSize_ <= right->maxSize_) { + left->mergeWith(*right); + return std::move(left); + } else { + right->mergeWith(*left); + return std::move(right); + } +} + +template +void KHyperLogLog::mergeWith( + const KHyperLogLog& other) { + for (const auto& [key, otherHll] : other.minhash_) { + auto it = minhash_.find(key); + if (it != minhash_.end()) { + decreaseTotalHllSize(it->second); + it->second.mergeWith(otherHll); + increaseTotalHllSize(it->second); + } else { + HllAccumulator newHll( + indexBitLength_, allocator_); + newHll.mergeWith(otherHll); + increaseTotalHllSize(newHll); + minhash_.emplace(key, std::move(newHll)); + } + } + + removeOverflowEntries(); +} + +template +double KHyperLogLog::reidentificationPotential( + int64_t threshold) const { + if (minhash_.empty()) { + return 0.0; + } + int64_t highlyUniqueValues = 0; + + for (const auto& [key, hll] : minhash_) { + if (hll.cardinality() <= threshold) { + highlyUniqueValues++; + } + } + + return static_cast(highlyUniqueValues) / minhash_.size(); +} + +template +folly::F14FastMap +KHyperLogLog::uniquenessDistribution( + int64_t histogramSize) const { + folly::F14FastMap out; + + for (int64_t i = 1; i <= histogramSize; ++i) { + out[i] = 0.0; + } + + int64_t size = minhash_.size(); + if (size == 0) { + return out; + } + + double bucketScale = 1.0 / static_cast(size); + for (const auto& [key, hll] : minhash_) { + int64_t cardinality = hll.cardinality(); + int64_t bucket = std::min(cardinality, histogramSize); + out[bucket] += bucketScale; + } + + return out; +} + +template +void KHyperLogLog::removeOverflowEntries() { + while (static_cast(minhash_.size()) > maxSize_) { + auto maxIt = std::prev(minhash_.end()); + if (maxIt != minhash_.end()) { + decreaseTotalHllSize(maxIt->second); + minhash_.erase(maxIt); + } + } +} + +template +void KHyperLogLog::increaseTotalHllSize( + HllAccumulator& hll) { + hllsTotalEstimatedSerializedSize_ += hll.serializedSize(); +} + +template +void KHyperLogLog::decreaseTotalHllSize( + HllAccumulator& hll) { + hllsTotalEstimatedSerializedSize_ -= hll.serializedSize(); +} + +} // namespace facebook::velox::common::hll diff --git a/velox/common/hyperloglog/SparseHll.cpp b/velox/common/hyperloglog/SparseHll.cpp index 27d290dd10ef..ed1fc10800ea 100644 --- a/velox/common/hyperloglog/SparseHll.cpp +++ b/velox/common/hyperloglog/SparseHll.cpp @@ -34,9 +34,8 @@ inline uint32_t decodeValue(uint32_t entry) { return entry & ((1 << kValueBitLength) - 1); } -int searchIndex( - uint32_t index, - const std::vector>& entries) { +template +int searchIndex(uint32_t index, const VectorType& entries) { int low = 0; int high = entries.size() - 1; @@ -69,7 +68,65 @@ common::InputByteStream initializeInputStream(const char* serialized) { } } // namespace -bool SparseHll::insertHash(uint64_t hash) { +// Static utility functions implementation +int64_t SparseHlls::cardinality(const char* serialized) { + static const int kTotalBuckets = 1 << kIndexBitLength; + + auto stream = initializeInputStream(serialized); + auto size = stream.read(); + + int zeroBuckets = kTotalBuckets - size; + return std::round(linearCounting(zeroBuckets, kTotalBuckets)); +} + +std::string SparseHlls::serializeEmpty(int8_t indexBitLength) { + static const size_t kSize = 4; + + std::string serialized; + serialized.resize(kSize); + + common::OutputByteStream stream(serialized.data()); + stream.appendOne(kPrestoSparseV2); + stream.appendOne(indexBitLength); + stream.appendOne(static_cast(0)); + return serialized; +} + +bool SparseHlls::canDeserialize(const char* input) { + return *reinterpret_cast(input) == kPrestoSparseV2; +} + +int8_t SparseHlls::deserializeIndexBitLength(const char* input) { + common::InputByteStream stream(input); + stream.read(); // Skip version + return stream.read(); // Return indexBitLength +} + +// Template method implementations +template +SparseHll::SparseHll(TAllocator* allocator) + : allocator_(allocator), entries_{TStlAllocator(allocator)} {} + +template +SparseHll::SparseHll(const char* serialized, TAllocator* allocator) + : allocator_(allocator), entries_{TStlAllocator(allocator)} { + common::InputByteStream stream(serialized); + auto version = stream.read(); + VELOX_CHECK_EQ(kPrestoSparseV2, version); + + // Skip indexBitLength from serialized data - we use fixed kIndexBitLength + // internally + stream.read(); + + auto size = stream.read(); + entries_.resize(size); + for (auto i = 0; i < size; i++) { + entries_[i] = stream.read(); + } +} + +template +bool SparseHll::insertHash(uint64_t hash) { auto index = computeIndex(hash, kIndexBitLength); auto value = numberOfLeadingZeros(hash, kIndexBitLength); @@ -88,29 +145,21 @@ bool SparseHll::insertHash(uint64_t hash) { return overLimit(); } -int64_t SparseHll::cardinality() const { +template +int64_t SparseHll::cardinality() const { // Estimate the cardinality using linear counting over the theoretical // 2^kIndexBitLength buckets available due to the fact that we're // recording the raw leading kIndexBitLength of the hash. This produces // much better precision while in the sparse regime. - static const int kTotalBuckets = 1 << kIndexBitLength; + const int kTotalBuckets = 1 << kIndexBitLength; int zeroBuckets = kTotalBuckets - entries_.size(); return std::round(linearCounting(zeroBuckets, kTotalBuckets)); } -// static -int64_t SparseHll::cardinality(const char* serialized) { - static const int kTotalBuckets = 1 << kIndexBitLength; - - auto stream = initializeInputStream(serialized); - auto size = stream.read(); - - int zeroBuckets = kTotalBuckets - size; - return std::round(linearCounting(zeroBuckets, kTotalBuckets)); -} - -void SparseHll::serialize(int8_t indexBitLength, char* output) const { +template +void SparseHll::serialize(int8_t indexBitLength, char* output) + const { common::OutputByteStream stream(output); stream.appendOne(kPrestoSparseV2); stream.appendOne(indexBitLength); @@ -120,75 +169,54 @@ void SparseHll::serialize(int8_t indexBitLength, char* output) const { } } -// static -std::string SparseHll::serializeEmpty(int8_t indexBitLength) { - static const size_t kSize = 4; - - std::string serialized; - serialized.resize(kSize); - - common::OutputByteStream stream(serialized.data()); - stream.appendOne(kPrestoSparseV2); - stream.appendOne(indexBitLength); - stream.appendOne(static_cast(0)); - return serialized; -} - -// static -bool SparseHll::canDeserialize(const char* input) { - return *reinterpret_cast(input) == kPrestoSparseV2; -} - -int32_t SparseHll::serializedSize() const { +template +int32_t SparseHll::serializedSize() const { return 1 /* version */ + 1 /* indexBitLength */ + 2 /* number of entries */ + entries_.size() * 4; } -int32_t SparseHll::inMemorySize() const { +template +int32_t SparseHll::inMemorySize() const { return sizeof(uint32_t) * entries_.size(); } -SparseHll::SparseHll(const char* serialized, HashStringAllocator* allocator) - : entries_{StlAllocator(allocator)} { - auto stream = initializeInputStream(serialized); - - auto size = stream.read(); - entries_.resize(size); - for (auto i = 0; i < size; i++) { - entries_[i] = stream.read(); - } -} - -void SparseHll::mergeWith(const SparseHll& other) { +template +void SparseHll::mergeWith(const SparseHll& other) { auto size = other.entries_.size(); // This check prevents merge aggregation from being performed on - // empty_approx_set(), an empty HyperLogLog. The merge function typically does - // not take an empty HyperLogLog structure as an argument. + // empty_approx_set(), an empty HyperLogLog. The merge function typically + // does not take an empty HyperLogLog structure as an argument. if (size) { mergeWith(size, other.entries_.data()); } } -void SparseHll::mergeWith(const char* serialized) { +template +void SparseHll::mergeWith(const char* serialized) { auto stream = initializeInputStream(serialized); auto size = stream.read(); // This check prevents merge aggregation from being performed on - // empty_approx_set(), an empty HyperLogLog. The merge function typically does - // not take an empty HyperLogLog structure as an argument. + // empty_approx_set(), an empty HyperLogLog. The merge function typically + // does not take an empty HyperLogLog structure as an argument. if (size) { mergeWith( size, reinterpret_cast(serialized + stream.offset())); } } -void SparseHll::mergeWith(size_t otherSize, const uint32_t* otherEntries) { +template +void SparseHll::mergeWith( + size_t otherSize, + const uint32_t* otherEntries) { VELOX_CHECK_GT(otherSize, 0); auto size = entries_.size(); - std::vector merged(size + otherSize); + + auto merged = std::vector>( + size + otherSize, TStlAllocator(allocator_)); int pos = 0; int leftPos = 0; @@ -223,7 +251,8 @@ void SparseHll::mergeWith(size_t otherSize, const uint32_t* otherEntries) { } } -void SparseHll::verify() const { +template +void SparseHll::verify() const { if (entries_.size() <= 1) { return; } @@ -236,11 +265,11 @@ void SparseHll::verify() const { } } -void SparseHll::toDense(DenseHll& denseHll) const { +template +void SparseHll::toDense(DenseHll& denseHll) const { auto indexBitLength = denseHll.indexBitLength(); - for (auto i = 0; i < entries_.size(); i++) { - auto entry = entries_[i]; + for (auto entry : entries_) { auto index = entry >> (32 - indexBitLength); auto shiftedValue = entry << indexBitLength; auto zeros = shiftedValue == 0 ? 32 : __builtin_clz(shiftedValue); @@ -257,4 +286,9 @@ void SparseHll::toDense(DenseHll& denseHll) const { } } +// Explicit template instantiation for HashStringAllocator (default) +template class SparseHll; +// Explicit template instantiation for memory::MemoryPool +template class SparseHll; + } // namespace facebook::velox::common::hll diff --git a/velox/common/hyperloglog/SparseHll.h b/velox/common/hyperloglog/SparseHll.h index e881b5776270..61a3cb8cb29a 100644 --- a/velox/common/hyperloglog/SparseHll.h +++ b/velox/common/hyperloglog/SparseHll.h @@ -18,15 +18,42 @@ #include "velox/common/memory/HashStringAllocator.h" namespace facebook::velox::common::hll { + +class SparseHlls { + public: + /// Returns cardinality estimate from the specified serialized digest. + /// @param serialized Pointer to serialized SparseHll data + /// @return Estimated cardinality of the HyperLogLog + static int64_t cardinality(const char* serialized); + + /// Returns true if 'input' has Presto SparseV2 format. + /// @param input Pointer to serialized data to check + /// @return True if the data is in SparseV2 format, false otherwise + static bool canDeserialize(const char* input); + + /// Creates an empty serialized SparseHll with the specified index bit length. + /// @param indexBitLength Number of bits for indexing (must be in [4,16]) + /// @return Serialized empty SparseHll as a string + static std::string serializeEmpty(int8_t indexBitLength); + + /// Extracts the index bit length from serialized SparseHll data. + /// @param input Pointer to serialized SparseHll data + /// @return The index bit length used in the serialized HLL + static int8_t deserializeIndexBitLength(const char* input); +}; + /// HyperLogLog implementation using sparse storage layout. /// It uses 26-bit buckets and provides high accuracy for low cardinalities. /// Memory usage: 4 bytes for each observed bucket. +template class SparseHll { public: - explicit SparseHll(HashStringAllocator* allocator) - : entries_{StlAllocator(allocator)} {} + template + using TStlAllocator = typename TAllocator::template TStlAllocator; + + explicit SparseHll(TAllocator* allocator); - SparseHll(const char* serialized, HashStringAllocator* allocator); + SparseHll(const char* serialized, TAllocator* allocator); void setSoftMemoryLimit(uint32_t softMemoryLimit) { softNumEntriesLimit_ = softMemoryLimit / 4; @@ -42,17 +69,9 @@ class SparseHll { int64_t cardinality() const; - /// Returns cardinality estimate from the specified serialized digest. - static int64_t cardinality(const char* serialized); - /// Serializes internal state using Presto SparseV2 format. void serialize(int8_t indexBitLength, char* output) const; - static std::string serializeEmpty(int8_t indexBitLength); - - /// Returns true if 'input' has Presto SparseV2 format. - static bool canDeserialize(const char* input); - /// Returns the size of the serialized state without serialising. int32_t serializedSize() const; @@ -63,7 +82,7 @@ class SparseHll { void mergeWith(const char* serialized); /// Merges state into provided instance of DenseHll. - void toDense(DenseHll& denseHll) const; + void toDense(DenseHll& denseHll) const; /// Returns current memory usage. int32_t inMemorySize() const; @@ -84,8 +103,8 @@ class SparseHll { /// A list of observed buckets. Each entry is a 32 bit integer encoding 26-bit /// bucket and 6-bit value (number of zeros in the input hash after the bucket /// + 1). - std::vector> entries_; - + TAllocator* allocator_; + std::vector> entries_; /// Number of entries that can be stored before reaching soft memory limit. uint32_t softNumEntriesLimit_{0}; }; diff --git a/velox/common/hyperloglog/benchmarks/CMakeLists.txt b/velox/common/hyperloglog/benchmarks/CMakeLists.txt index f18f8a0a2b7b..c0b49671fd66 100644 --- a/velox/common/hyperloglog/benchmarks/CMakeLists.txt +++ b/velox/common/hyperloglog/benchmarks/CMakeLists.txt @@ -15,5 +15,7 @@ add_executable(velox_common_hyperloglog_dense_hll_bm DenseHll.cpp) target_link_libraries( - velox_common_hyperloglog_dense_hll_bm velox_common_hyperloglog - Folly::follybenchmark) + velox_common_hyperloglog_dense_hll_bm + velox_common_hyperloglog + Folly::follybenchmark +) diff --git a/velox/common/hyperloglog/benchmarks/DenseHll.cpp b/velox/common/hyperloglog/benchmarks/DenseHll.cpp index 7233280f1d66..0bd112e0a720 100644 --- a/velox/common/hyperloglog/benchmarks/DenseHll.cpp +++ b/velox/common/hyperloglog/benchmarks/DenseHll.cpp @@ -49,7 +49,7 @@ class DenseHllBenchmark { folly::BenchmarkSuspender suspender; HashStringAllocator allocator(pool_); - common::hll::DenseHll hll(hashBits, &allocator); + common::hll::DenseHll<> hll(hashBits, &allocator); suspender.dismiss(); @@ -61,7 +61,7 @@ class DenseHllBenchmark { private: std::string makeSerializedHll(int hashBits, int32_t step) { HashStringAllocator allocator(pool_); - common::hll::DenseHll hll(hashBits, &allocator); + common::hll::DenseHll<> hll(hashBits, &allocator); for (int32_t i = 0; i < 1'000'000; ++i) { auto hash = hashOne(i * step); hll.insertHash(hash); @@ -69,7 +69,7 @@ class DenseHllBenchmark { return serialize(hll); } - static std::string serialize(common::hll::DenseHll& denseHll) { + static std::string serialize(common::hll::DenseHll<>& denseHll) { auto size = denseHll.serializedSize(); std::string serialized; serialized.resize(size); diff --git a/velox/common/hyperloglog/tests/CMakeLists.txt b/velox/common/hyperloglog/tests/CMakeLists.txt index c05a810796fe..cfa5230b68ef 100644 --- a/velox/common/hyperloglog/tests/CMakeLists.txt +++ b/velox/common/hyperloglog/tests/CMakeLists.txt @@ -12,11 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -add_executable(velox_common_hyperloglog_test DenseHllTest.cpp SparseHllTest.cpp) +add_executable( + velox_common_hyperloglog_test + DenseHllTest.cpp + SparseHllTest.cpp + HllAccumulatorTest.cpp +) -add_test(NAME velox_common_hyperloglog_test - COMMAND velox_common_hyperloglog_test) +add_test(NAME velox_common_hyperloglog_test COMMAND velox_common_hyperloglog_test) target_link_libraries( velox_common_hyperloglog_test - PRIVATE velox_common_hyperloglog velox_encode GTest::gtest GTest::gtest_main) + PRIVATE velox_common_hyperloglog velox_encode GTest::gtest GTest::gtest_main +) diff --git a/velox/common/hyperloglog/tests/DenseHllTest.cpp b/velox/common/hyperloglog/tests/DenseHllTest.cpp index c688a6918c47..99eb261dbfb2 100644 --- a/velox/common/hyperloglog/tests/DenseHllTest.cpp +++ b/velox/common/hyperloglog/tests/DenseHllTest.cpp @@ -15,9 +15,8 @@ */ #include "velox/common/hyperloglog/DenseHll.h" +#include #include -#include -#include #include #define XXH_INLINE_ALL @@ -34,22 +33,27 @@ uint64_t hashOne(T value) { return XXH64(&value, sizeof(value), 0); } -class DenseHllTest : public ::testing::TestWithParam { +template +class DenseHllTest : public ::testing::Test { protected: static void SetUpTestCase() { memory::MemoryManager::testingSetInstance(memory::MemoryManager::Options{}); } - DenseHll roundTrip(DenseHll& hll) { - auto size = hll.serializedSize(); - std::string serialized; - serialized.resize(size); - hll.serialize(serialized.data()); + void SetUp() override { + if constexpr (std::is_same_v) { + allocator_ = &hsa_; + } else { + allocator_ = pool_.get(); + } + } - return DenseHll(serialized.data(), &allocator_); + DenseHll roundTrip(DenseHll& hll) { + auto serialized = this->serialize(hll); + return DenseHll(serialized.data(), allocator_); } - std::string serialize(DenseHll& denseHll) { + std::string serialize(DenseHll& denseHll) { auto size = denseHll.serializedSize(); std::string serialized; serialized.resize(size); @@ -58,23 +62,19 @@ class DenseHllTest : public ::testing::TestWithParam { } template - void testMergeWith( - int8_t indexBitLength, - const std::vector& left, - const std::vector& right) { - testMergeWith(indexBitLength, left, right, false); - testMergeWith(indexBitLength, left, right, true); + void testMergeWith(const std::vector& left, const std::vector& right) { + testMergeWith(left, right, false); + testMergeWith(left, right, true); } template void testMergeWith( - int8_t indexBitLength, const std::vector& left, const std::vector& right, bool serialized) { - DenseHll hllLeft{indexBitLength, &allocator_}; - DenseHll hllRight{indexBitLength, &allocator_}; - DenseHll expected{indexBitLength, &allocator_}; + DenseHll hllLeft{11, allocator_}; + DenseHll hllRight{11, allocator_}; + DenseHll expected{11, allocator_}; for (auto value : left) { auto hash = hashOne(value); @@ -89,30 +89,51 @@ class DenseHllTest : public ::testing::TestWithParam { } if (serialized) { - auto serializedRight = serialize(hllRight); + auto serializedRight = this->serialize(hllRight); hllLeft.mergeWith(serializedRight.data()); } else { hllLeft.mergeWith(hllRight); } ASSERT_EQ(hllLeft.cardinality(), expected.cardinality()); - ASSERT_EQ(serialize(hllLeft), serialize(expected)); + ASSERT_EQ(this->serialize(hllLeft), this->serialize(expected)); - auto hllLeftSerialized = serialize(hllLeft); + auto hllLeftSerialized = this->serialize(hllLeft); ASSERT_EQ( - DenseHll::cardinality(hllLeftSerialized.data()), + DenseHlls::cardinality(hllLeftSerialized.data()), expected.cardinality()); } std::shared_ptr pool_{ memory::memoryManager()->addLeafPool()}; - HashStringAllocator allocator_{pool_.get()}; + HashStringAllocator hsa_{pool_.get()}; + TAllocator* allocator_; }; -TEST_P(DenseHllTest, basic) { - int8_t indexBitLength = GetParam(); +using AllocatorTypes = + ::testing::Types; + +class NameGenerator { + public: + template + static std::string GetName(int) { + if constexpr (std::is_same_v) { + return "hsa"; + } else if constexpr (std::is_same_v) { + return "pool"; + } else { + VELOX_UNREACHABLE( + "Only HashStringAllocator and MemoryPool are supported allocator types."); + } + } +}; + +TYPED_TEST_SUITE(DenseHllTest, AllocatorTypes, NameGenerator); + +TYPED_TEST(DenseHllTest, basic) { + int8_t indexBitLength = 11; + DenseHll denseHll{indexBitLength, this->allocator_}; - DenseHll denseHll{indexBitLength, &allocator_}; for (int i = 0; i < 1'000; i++) { auto value = i % 17; auto hash = hashOne(value); @@ -131,31 +152,29 @@ TEST_P(DenseHllTest, basic) { ASSERT_EQ(expectedCardinality, denseHll.cardinality()); - DenseHll deserialized = roundTrip(denseHll); + DenseHll deserialized = this->roundTrip(denseHll); ASSERT_EQ(expectedCardinality, deserialized.cardinality()); - auto serialized = serialize(denseHll); - ASSERT_EQ(expectedCardinality, DenseHll::cardinality(serialized.data())); + auto serialized = this->serialize(denseHll); + ASSERT_EQ(expectedCardinality, DenseHlls::cardinality(serialized.data())); } -TEST_P(DenseHllTest, highCardinality) { - int8_t indexBitLength = GetParam(); +TYPED_TEST(DenseHllTest, highCardinality) { + int8_t indexBitLength = 11; + DenseHll denseHll{indexBitLength, this->allocator_}; - DenseHll denseHll{indexBitLength, &allocator_}; for (int i = 0; i < 10'000'000; i++) { auto hash = hashOne(i); denseHll.insertHash(hash); } - if (indexBitLength >= 11) { - ASSERT_NEAR(10'000'000, denseHll.cardinality(), 150'000); - } + ASSERT_NEAR(10'000'000, denseHll.cardinality(), 150'000); - DenseHll deserialized = roundTrip(denseHll); + auto deserialized = this->roundTrip(denseHll); ASSERT_EQ(denseHll.cardinality(), deserialized.cardinality()); - auto serialized = serialize(denseHll); - ASSERT_EQ(denseHll.cardinality(), DenseHll::cardinality(serialized.data())); + auto serialized = this->serialize(denseHll); + ASSERT_EQ(denseHll.cardinality(), DenseHlls::cardinality(serialized.data())); } namespace { @@ -170,62 +189,234 @@ std::vector sequence(T start, T end) { } } // namespace -TEST_P(DenseHllTest, canDeserialize) { +TYPED_TEST(DenseHllTest, mergeWith) { + // small, non-overlapping + this->testMergeWith(sequence(0, 100), sequence(100, 200)); + this->testMergeWith(sequence(100, 200), sequence(0, 100)); + + // small, overlapping + this->testMergeWith(sequence(0, 100), sequence(50, 150)); + this->testMergeWith(sequence(50, 150), sequence(0, 100)); + + // small, same + this->testMergeWith(sequence(0, 100), sequence(0, 100)); + + // large, non-overlapping + this->testMergeWith(sequence(0, 20'000), sequence(20'000, 40'000)); + this->testMergeWith(sequence(20'000, 40'000), sequence(0, 20'000)); + + // large, overlapping + this->testMergeWith(sequence(0, 2'000'000), sequence(1'000'000, 3'000'000)); + this->testMergeWith(sequence(1'000'000, 3'000'000), sequence(0, 2'000'000)); + + // large, same + this->testMergeWith(sequence(0, 2'000'000), sequence(0, 2'000'000)); +} + +// Separate test class for testing various index bit lengths +template +struct AllocatorWithIndexBits { + using AllocatorType = TAllocator; + static constexpr int8_t indexBitLength() { + return IndexBitLength; + } +}; + +template +class DenseHllMergeTest : public ::testing::Test { + protected: + static void SetUpTestCase() { + memory::MemoryManager::testingSetInstance(memory::MemoryManager::Options{}); + } + + void SetUp() override { + if constexpr (std::is_same_v< + typename TParam::AllocatorType, + HashStringAllocator>) { + allocator_ = &hsa_; + } else { + allocator_ = pool_.get(); + } + } + + std::string serialize(DenseHll& denseHll) { + auto size = denseHll.serializedSize(); + std::string serialized; + serialized.resize(size); + denseHll.serialize(serialized.data()); + return serialized; + } + + template + void testMergeWith( + int8_t indexBitLength, + const std::vector& left, + const std::vector& right) { + testMergeWith(indexBitLength, left, right, false); + testMergeWith(indexBitLength, left, right, true); + } + + template + void testMergeWith( + int8_t indexBitLength, + const std::vector& left, + const std::vector& right, + bool serialized) { + DenseHll hllLeft{indexBitLength, allocator_}; + DenseHll hllRight{indexBitLength, allocator_}; + DenseHll expected{indexBitLength, allocator_}; + + for (auto value : left) { + auto hash = hashOne(value); + hllLeft.insertHash(hash); + expected.insertHash(hash); + } + + for (auto value : right) { + auto hash = hashOne(value); + hllRight.insertHash(hash); + expected.insertHash(hash); + } + + if (serialized) { + auto serializedRight = this->serialize(hllRight); + hllLeft.mergeWith(serializedRight.data()); + } else { + hllLeft.mergeWith(hllRight); + } + + ASSERT_EQ(hllLeft.cardinality(), expected.cardinality()); + ASSERT_EQ(this->serialize(hllLeft), this->serialize(expected)); + + auto hllLeftSerialized = this->serialize(hllLeft); + ASSERT_EQ( + DenseHlls::cardinality(hllLeftSerialized.data()), + expected.cardinality()); + } + + std::shared_ptr pool_{ + memory::memoryManager()->addLeafPool()}; + HashStringAllocator hsa_{pool_.get()}; + typename TParam::AllocatorType* allocator_; +}; + +using DenseHllMergeTestParams = ::testing::Types< + // HashStringAllocator with all index bit lengths + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + // MemoryPool with all index bit lengths + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits>; + +class ComprehensiveNameGenerator { + public: + template + static std::string GetName(int) { + std::string allocatorName; + if constexpr (std::is_same_v< + typename TParam::AllocatorType, + HashStringAllocator>) { + allocatorName = "hsa"; + } else if constexpr (std::is_same_v< + typename TParam::AllocatorType, + memory::MemoryPool>) { + allocatorName = "pool"; + } else { + VELOX_UNREACHABLE( + "Only HashStringAllocator and MemoryPool are supported allocator types."); + } + return fmt::format("{}_{}", allocatorName, TParam::indexBitLength()); + } +}; + +TYPED_TEST_SUITE( + DenseHllMergeTest, + DenseHllMergeTestParams, + ComprehensiveNameGenerator); + +class DenseHllCanDeserializeTest : public ::testing::Test { + protected: + static void SetUpTestCase() { + memory::MemoryManager::testingSetInstance(memory::MemoryManager::Options{}); + } +}; + +TEST_F(DenseHllCanDeserializeTest, canDeserialize) { // These are not valid HLL but all pass canDeserialize version only check. - std::vector invalidStrings{ + std::vector invalidStrings{ "AxIRESUhEzNBFCQWYxEjIzI1ISURNidCMlViIjOSNyATBhYSIiJDUyMBIlcSMDUiEUEiESM1ITckQkQTMSMhMyQx", "Aw==", "AyAAABEQAgAlAgAQAQAlMgAhQQABAwAAERAAAAEQACA=", "AwDuUjGFaQ==", "AwkLD8BYTA9BXyg="}; - for (folly::StringPiece& invalidString : invalidStrings) { + for (const auto invalidString : invalidStrings) { auto invalidHll = Base64::decode(invalidString); - EXPECT_TRUE(DenseHll::canDeserialize(invalidHll.c_str())); + EXPECT_TRUE(DenseHlls::canDeserialize(invalidHll.c_str())); EXPECT_FALSE( - DenseHll::canDeserialize(invalidHll.c_str(), invalidHll.length())); + DenseHlls::canDeserialize(invalidHll.c_str(), invalidHll.length())); } - std::vector validStrings{ + std::vector validStrings{ "AwwAQSVCQ4QUNDJkIzMjaSQmaVUxRSVDQ1FaIiJEYkNxNTEzWBQ0M0IhQSRDQ0RkYkRXMSJjM0MSJWMkQlNUNCJHVIVUM1QzQTVEUyE0ExMyV0NYQSR0NFaSFXI1IzKEJkMjEydDUzOVAjJFIkSTUREzM2MjEkg2U7MjIiIkRyJhMzMyMiFEQ1IyIlMyZFMkIyBzNSRGUUMiLMMTQzNDUiEmM0JxY1IjRFRUNlJjJFY0UxMjMWQkNSFRVBQlM1IzNFIiVDIiMhJiMSQVIjQpMjdWJnM0QzKCIRMjNFQnkyRkRjZFVCYjJScVZnM1QzMSYkMXIyIzQTUzMzQjNSRBVkITMEmAM1JiRAMyUzUXU0RDNkMSJBNSNCIiFCVkQxNDQiEkZFdGE5JEUio1VDRzJTJnMkQ0VDM1UTcFQTQSeCJDNCIiNiljJiMUNSdHSkEiNSMjJpJDIlUxElJiVXRCSDU1ERgQM7MyJDUiQmMUUVFSUhIjATJUJCWGYVUkc0NTISMUd4NUNTQzRiMjU0MzUjEkIzIyhRMkEQEnMiNFJGRSUjE2JVNUJTJEN0MiY2IiJGBXQ1QlUCZCcyIyJTQyRENhIzMjSEJEtnYiMjAjQzNEUyUiNJlDQ2gyImMzRDNWJRMhIVFHJUdkNjYSZDE1VlJAQ0IyNRJCMzVSQXFycjJTRlMVhjVDIxOhIiYWUTZmJSRDJFRjJEQ0gCMzYlN1dDVYIpMwMWdERkVzQxIlcjMiZkEEFCEiNCJnMhFjJiMkUmIqNCNSA1MyZiI2NCRkRkJzM1UiFBM1MzVFEoeDNGBhOTMTZEQiUVQiUjQiN0oldrMSRDEzcjNiI1K2REQlIUMiRUF1MzRmJDNDVTUSM3UFQkI1UyYiQRMxdoIlIjZUM1RSckYmMnMjNhc0RCQkFTMoMjISNyJCQSIlclUUU0IVUVQyNCAjNBUjUyIzRFFEBEMlJjEyM1IlRZIjJEUkZiZCFiNEQjU0FVIRMzNEMyIjQxIzMiIzJDM1lkQSMWYSFBMmQTZkQiYSZRhhNiQiMDcxImIzZWclJaRCCBcyYUVRVVRRZlERNBETQiIyRVBRNTMzVCQyJDVTUjJDUyJQI6iAYlMlNDISETVCZiIyZyRDIkREFEMmJEU4JDQkM0F1RSM4MyQjJSMUIjNHMkJCNBNTNiVJJTQVIjFzFhQTcyRTNERCJEU0MiRlYkIjMkO0YiISKFIyExVjYmE0QxWCMRU0NDJTUylDZCNUE0VAUSEWMjISNDU1VzUkMzM1MiQhYzhxNEQTlENHY0MSckRHEmNiFyNjY1QyAkNDM0JCI3JRIzRlEhREZTJhZkUldVI1YzJ2MzIkUzRRQTNDKUNDI1OIESUjUlNQMmNSUzZCM0FDJrMjEiFVSmMkM0VCYzMiQyYjYhBSM0RTYyISQnEzWjNCNIRkg0Q1M4EnNGYzMkoSVEU1MjJWQSRDVXVCMjMjMiIyUyM2NTN0chNSNUMkNHJzITImUjNIRDYTQ0Mi6UZyNUMYIzNGMDcWYkRUN2EzElMyEyMiEzESQjNUIxYzYhMkdjYVMhRCGEEiJDIkF0RCNTUlQiUUITgjUxWSNBRSJjNFYzJEM0UzRoY0IzM0KkJGIjUyITVUQqQzJDUqFDIzIiYxRjMidEVkJUFyc0Q0I3k2MWIyMkZnNiMRVRIjY1MyNEMxIsMmRYEkJnI1MyNCEjNlFyVDUyckVUU3WTMjdFMkFCRDRCIXNFMyElZUExEkMzIiURQiJRQkJGNDM0YUI0ExFRJiRWQVAiJTpENHZERDMyE3OFIkgiRTUlFGQxZCGFUkVDU0QjMiMSMUF1EiKmRDUVZlFUMVR1MjRVUyYxVSQzJDVTImRDJBUjEzUjMTYTUjFkUUIyEiMiYSE5STI0QTNURGIyNkMTNBVFMiMzU0UkNHY0QyJCcxJBZVRlIRIVJWgyUUdVMzQ1UWQ1KDUgJjIyMENDQgBUQiAiRGIzVSUjRmExJDMYRCYRMxJHKSkTIjEzJRA6NDdDUjZzEiMkBCUlM0IlFBZGIiRDMxJlRyQxNkJSJBQiIzIyczIzOTJhQkMjNyRCQ1MiIiOFMSNyVTJlJkE1UkMzUURmJWEYJzMyNDNAIiQDZTITJEImFiIUJEggIUMzE0k2FDVSczRRJEQTM2QSEUFHM2MSVDKEMkZAYUVhJGOERCV2EiYjJTUxRCoxI4NChBRnNSIiNFU4MhIjRoNSM0UkISYiMhIwgXRBQldUglUyVFQyEwVBIhOjNjMRNDd0hoUiJUoyJDMiJTNVQlMSMTpRkzNCNTJlMiNUEkMlYjRGJ4KUszNRISQTIBImIyFCIEFlQVc0MSIiI1JjRjNhI0YkUkRVI0UxVGSRFDgzIkUxRiNgElMiNLJaMyBGc0MzIQRFUyNEMyQnVxNUMSM1U0EzJTNENDdJMRUyMVEyNDFDSUQVRjNVIyE0RTRnMkJkYwEnJyEnUCFDMxRhUiVEMTIwUmNiFENTMyRRdFMjQRIohiJDM2KDMjM2NUNCZlIkJUMzNCMhQjMxEnREESZFUDZ0M0MTJRMkImQxYjcyQ0IjcoIxYyIzonc2JDRCUlA1JEFGJkRHVHMzI2IjUmRTMVJCMyUnJSUzNFQ1QiRjFEMTNmIUckMzRFVBIjU0UzIxI0JTI4JCM0FVYnlkNDFRJEJUQlEiI0ImNCJUIiaGFDYkQ2QSQxdFUjQnVHIzMmghQlNSZUIiRRQ0MAAA==", }; - for (folly::StringPiece& validString : validStrings) { + for (const auto validString : validStrings) { auto validHll = Base64::decode(validString); - EXPECT_TRUE(DenseHll::canDeserialize(validHll.c_str())); - EXPECT_TRUE(DenseHll::canDeserialize(validHll.c_str(), validHll.length())); + EXPECT_TRUE(DenseHlls::canDeserialize(validHll.c_str())); + EXPECT_TRUE(DenseHlls::canDeserialize(validHll.c_str(), validHll.length())); } } -TEST_P(DenseHllTest, mergeWith) { - int8_t indexBitLength = GetParam(); +TYPED_TEST(DenseHllMergeTest, mergeWith) { + int8_t indexBitLength = TypeParam::indexBitLength(); // small, non-overlapping - testMergeWith(indexBitLength, sequence(0, 100), sequence(100, 200)); - testMergeWith(indexBitLength, sequence(100, 200), sequence(0, 100)); + this->testMergeWith(indexBitLength, sequence(0, 100), sequence(100, 200)); + this->testMergeWith(indexBitLength, sequence(100, 200), sequence(0, 100)); // small, overlapping - testMergeWith(indexBitLength, sequence(0, 100), sequence(50, 150)); - testMergeWith(indexBitLength, sequence(50, 150), sequence(0, 100)); + this->testMergeWith(indexBitLength, sequence(0, 100), sequence(50, 150)); + this->testMergeWith(indexBitLength, sequence(50, 150), sequence(0, 100)); // small, same - testMergeWith(indexBitLength, sequence(0, 100), sequence(0, 100)); + this->testMergeWith(indexBitLength, sequence(0, 100), sequence(0, 100)); // large, non-overlapping - testMergeWith(indexBitLength, sequence(0, 20'000), sequence(20'000, 40'000)); - testMergeWith(indexBitLength, sequence(20'000, 40'000), sequence(0, 20'000)); + this->testMergeWith( + indexBitLength, sequence(0, 20'000), sequence(20'000, 40'000)); + this->testMergeWith( + indexBitLength, sequence(20'000, 40'000), sequence(0, 20'000)); // large, overlapping - testMergeWith( + this->testMergeWith( indexBitLength, sequence(0, 2'000'000), sequence(1'000'000, 3'000'000)); - testMergeWith( + this->testMergeWith( indexBitLength, sequence(1'000'000, 3'000'000), sequence(0, 2'000'000)); // large, same - testMergeWith(indexBitLength, sequence(0, 2'000'000), sequence(0, 2'000'000)); + this->testMergeWith( + indexBitLength, sequence(0, 2'000'000), sequence(0, 2'000'000)); } - -INSTANTIATE_TEST_SUITE_P( - DenseHllTest, - DenseHllTest, - ::testing::Values(4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16)); diff --git a/velox/common/hyperloglog/tests/HllAccumulatorTest.cpp b/velox/common/hyperloglog/tests/HllAccumulatorTest.cpp new file mode 100644 index 000000000000..02136168ebaf --- /dev/null +++ b/velox/common/hyperloglog/tests/HllAccumulatorTest.cpp @@ -0,0 +1,332 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/common/hyperloglog/HllAccumulator.h" + +#define XXH_INLINE_ALL +#include + +#include +#include + +using namespace facebook::velox; +using namespace facebook::velox::common::hll; + +namespace { +const int8_t kDefaultIndexBitLength = 11; +const double kDefaultStandardError = + 1.04 / std::sqrt(1 << kDefaultIndexBitLength); +} // namespace + +template +class HllAccumulatorTest : public ::testing::Test { + protected: + static void SetUpTestCase() { + memory::MemoryManager::testingSetInstance(memory::MemoryManager::Options{}); + } + + void SetUp() override { + if constexpr (std::is_same_v) { + allocator_ = &hsa_; + } else { + allocator_ = pool_.get(); + } + } + + std::shared_ptr pool_{ + memory::memoryManager()->addLeafPool()}; + HashStringAllocator hsa_{pool_.get()}; + TAllocator* allocator_{}; +}; + +using AllocatorTypes = + ::testing::Types; + +class NameGenerator { + public: + template + static std::string GetName(int) { + if constexpr (std::is_same_v) { + return "hsa"; + } else if constexpr (std::is_same_v) { + return "pool"; + } else { + VELOX_UNREACHABLE( + "Only HashStringAllocator and MemoryPool are supported allocator types."); + } + } +}; + +TYPED_TEST_SUITE(HllAccumulatorTest, AllocatorTypes, NameGenerator); + +TYPED_TEST(HllAccumulatorTest, basicInt64) { + // Test SparseHLL. + HllAccumulator accumulator( + kDefaultIndexBitLength, this->allocator_); + + constexpr int64_t numValues = 100; + for (int64_t i = 0; i < numValues; i++) { + accumulator.append(i); + } + + // Sparse HLL should be exact. + EXPECT_EQ(accumulator.cardinality(), numValues); + + // Test DenseHLL. + constexpr int64_t numValuesDense = 10000; + for (int64_t i = 0; i < numValuesDense; i++) { + accumulator.append(i); + } + EXPECT_NEAR( + accumulator.cardinality(), + numValuesDense, + numValuesDense * kDefaultStandardError); +} + +TYPED_TEST(HllAccumulatorTest, basicDouble) { + // Test SparseHLL. + HllAccumulator accumulator( + kDefaultIndexBitLength, this->allocator_); + + constexpr int numValues = 150; + for (int i = 0; i < numValues; i++) { + accumulator.append(static_cast(i) * 1.5); + } + EXPECT_EQ(accumulator.cardinality(), numValues); + + // Test DenseHLL. + constexpr int numValuesDense = 15000; + for (int i = numValues; i < numValuesDense; i++) { + accumulator.append(static_cast(i) * 1.5); + } + EXPECT_NEAR( + accumulator.cardinality(), + numValuesDense, + numValuesDense * kDefaultStandardError); +} + +TYPED_TEST(HllAccumulatorTest, basicStringView) { + // Test SparseHLL. + HllAccumulator accumulator( + kDefaultIndexBitLength, this->allocator_); + + constexpr int numValues = 100; + std::vector strings; + strings.reserve(numValues); + for (int i = 0; i < numValues; i++) { + strings.push_back("value_" + std::to_string(i)); + } + for (const auto& str : strings) { + accumulator.append(StringView(str)); + } + EXPECT_EQ(accumulator.cardinality(), numValues); + + // Test DenseHLL. + constexpr int numValuesDense = 10000; + for (int i = numValues; i < numValuesDense; i++) { + strings.push_back("value_" + std::to_string(i)); + } + for (int i = numValues; i < numValuesDense; i++) { + accumulator.append(StringView(strings[i])); + } + EXPECT_NEAR( + accumulator.cardinality(), + numValuesDense, + numValuesDense * kDefaultStandardError); +} + +TYPED_TEST(HllAccumulatorTest, serde) { + HllAccumulator accumulator( + kDefaultIndexBitLength, this->allocator_); + + constexpr int64_t numValues = 200; + for (int64_t i = 0; i < numValues; i++) { + accumulator.append(i); + } + + auto size = accumulator.serializedSize(); + std::string serialized(size, '\0'); + accumulator.serialize(serialized.data()); + + auto deserialized = HllAccumulator::deserialize( + serialized.data(), this->allocator_); + + EXPECT_EQ(deserialized->cardinality(), numValues); + + // Test round trip + std::string reserialized(deserialized->serializedSize(), '\0'); + deserialized->serialize(reserialized.data()); + EXPECT_EQ(reserialized, serialized); +} + +TYPED_TEST(HllAccumulatorTest, mergeWithBothSparse) { + HllAccumulator accumulator1( + kDefaultIndexBitLength, this->allocator_); + HllAccumulator accumulator2( + kDefaultIndexBitLength, this->allocator_); + + // Add non-overlapping values that keep both sparse. + for (int64_t i = 0; i < 100; i++) { + accumulator1.append(i); + } + for (int64_t i = 100; i < 200; i++) { + accumulator2.append(i); + } + + accumulator1.mergeWith(accumulator2); + + // Resulting accumulator should be sparse and exact. + EXPECT_TRUE(accumulator1.isSparse()); + EXPECT_EQ(accumulator1.cardinality(), 200); +} + +TYPED_TEST(HllAccumulatorTest, mergeWithBothDense) { + HllAccumulator accumulator1( + kDefaultIndexBitLength, this->allocator_); + HllAccumulator accumulator2( + kDefaultIndexBitLength, this->allocator_); + + // Add non-overlapping values that trigger dense mode. + constexpr int64_t numValues = 10000; + for (int64_t i = 0; i < 5000; i++) { + accumulator1.append(i); + } + for (int64_t i = 5000; i < numValues; i++) { + accumulator2.append(i); + } + + accumulator1.mergeWith(accumulator2); + EXPECT_FALSE(accumulator1.isSparse()); + + EXPECT_NEAR( + accumulator1.cardinality(), numValues, numValues * kDefaultStandardError); +} + +TYPED_TEST(HllAccumulatorTest, mergeSparseWithDense) { + HllAccumulator sparseAccumulator( + kDefaultIndexBitLength, this->allocator_); + HllAccumulator denseAccumulator( + kDefaultIndexBitLength, this->allocator_); + + for (int64_t i = 0; i < 100; i++) { + sparseAccumulator.append(i); + } + + constexpr int64_t numValuesDense = 5000; + for (int64_t i = 100; i < numValuesDense; i++) { + denseAccumulator.append(i); + } + + sparseAccumulator.mergeWith(denseAccumulator); + + // mergeWith should convert any sparse accumulator to dense if either is + // dense. Result should be dense and approximate. + + EXPECT_FALSE(sparseAccumulator.isSparse()); + EXPECT_NEAR( + sparseAccumulator.cardinality(), + numValuesDense, + numValuesDense * kDefaultStandardError); +} + +TYPED_TEST(HllAccumulatorTest, mergeDenseWithSparse) { + HllAccumulator sparseAccumulator( + kDefaultIndexBitLength, this->allocator_); + HllAccumulator denseAccumulator( + kDefaultIndexBitLength, this->allocator_); + + for (int64_t i = 0; i < 100; i++) { + sparseAccumulator.append(i); + } + + constexpr int64_t numValuesDense = 5000; + for (int64_t i = 100; i < numValuesDense; i++) { + denseAccumulator.append(i); + } + + denseAccumulator.mergeWith(sparseAccumulator); + + // Result should be dense and approximate. + EXPECT_FALSE(denseAccumulator.isSparse()); + EXPECT_NEAR( + denseAccumulator.cardinality(), + numValuesDense, + numValuesDense * kDefaultStandardError); +} + +TYPED_TEST(HllAccumulatorTest, mergeWithSerializedDataSparse) { + HllAccumulator accumulator1( + kDefaultIndexBitLength, this->allocator_); + HllAccumulator accumulator2( + kDefaultIndexBitLength, this->allocator_); + + for (int64_t i = 0; i < 100; i++) { + accumulator1.append(i); + } + for (int64_t i = 100; i < 200; i++) { + accumulator2.append(i); + } + + auto size = accumulator2.serializedSize(); + std::string buffer(size, '\0'); + accumulator2.serialize(buffer.data()); + + // Merge with serialized data (should remain sparse). + accumulator1.mergeWith(StringView(buffer), this->allocator_); + + // Should be sparse and exact. + EXPECT_TRUE(accumulator1.isSparse()); + EXPECT_EQ(accumulator1.cardinality(), 200); +} + +TYPED_TEST(HllAccumulatorTest, mergeWithOverlappingDataSparse) { + HllAccumulator accumulator1( + kDefaultIndexBitLength, this->allocator_); + HllAccumulator accumulator2( + kDefaultIndexBitLength, this->allocator_); + + for (int64_t i = 0; i < 100; i++) { + accumulator1.append(i); + } + for (int64_t i = 50; i < 150; i++) { + accumulator2.append(i); + } + + accumulator1.mergeWith(accumulator2); + + // Sparse HLL should be sparse and exact: union of [0, 100) and [50, 150) = + // [0, 150) + EXPECT_TRUE(accumulator1.isSparse()); + EXPECT_EQ(accumulator1.cardinality(), 150); +} + +TYPED_TEST(HllAccumulatorTest, mergeUninitializedAccumulator) { + HllAccumulator accumulator(this->allocator_); + HllAccumulator initialized( + kDefaultIndexBitLength, this->allocator_); + + constexpr int64_t numValues = 100; + for (int64_t i = 0; i < numValues; i++) { + initialized.append(i); + } + + auto size = initialized.serializedSize(); + std::string buffer(size, '\0'); + initialized.serialize(buffer.data()); + + accumulator.mergeWith(StringView(buffer), this->allocator_); + + EXPECT_EQ(accumulator.cardinality(), numValues); +} diff --git a/velox/common/hyperloglog/tests/KHyperLogLogTest.cpp b/velox/common/hyperloglog/tests/KHyperLogLogTest.cpp new file mode 100644 index 000000000000..0303cf1c25fd --- /dev/null +++ b/velox/common/hyperloglog/tests/KHyperLogLogTest.cpp @@ -0,0 +1,705 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/common/hyperloglog/KHyperLogLog.h" +#include "velox/common/memory/Memory.h" + +#include +#include +#include +#include +#include "velox/type/Timestamp.h" + +using namespace facebook::velox::common::hll; +using namespace facebook::velox::memory; +using namespace facebook::velox; + +namespace { +const int32_t kDefaultNumBuckets = 256; +// Theoretical relative standard error formula from the HyperLogLog paper +// (Flajolet et al.): 1.04 / sqrt(num buckets) +const double kDefaultStandardError = 1.04 / std::sqrt(kDefaultNumBuckets); +} // namespace + +class KHyperLogLogTest : public ::testing::Test { + public: + void SetUp() override { + facebook::velox::memory::MemoryManager::initialize({}); + pool_ = facebook::velox::memory::memoryManager()->addLeafPool(); + hsa_ = std::make_unique(pool_.get()); + allocator_ = hsa_.get(); + } + + void TearDown() override { + // Clean up allocator before pool. + hsa_.reset(); + pool_.reset(); + } + + protected: + std::shared_ptr pool_; + std::unique_ptr hsa_; + HashStringAllocator* allocator_{}; + + // Helper function to generate random values with quadratic distribution. + // Squaring a uniform random [0.0, 1.0] heavily skews results toward zero + // (e.g., ~70% of values fall in the lower half). This simulates realistic + // data where most values have low cardinality and few have high cardinality, + // which is critical for testing KHLL's privacy features under conditions + // where re-identification risk is highest. + int64_t randomLong(int64_t range) { + double random = folly::Random::randDouble01(); + return static_cast(std::pow(random, 2.0) * range); + } + + // Creates a KHyperLogLog with specific values for testing. + std::unique_ptr> createKHLL( + const std::vector& values, + const std::vector& uiis) { + EXPECT_EQ(values.size(), uiis.size()); + auto khll = std::make_unique>( + allocator_); + + for (size_t i = 0; i < values.size(); ++i) { + khll->add(values[i], uiis[i]); + } + + return khll; + } +}; + +TEST_F(KHyperLogLogTest, basicCardinality) { + auto khll = + std::make_unique>(allocator_); + + // Empty KHLL should have cardinality 0. + EXPECT_EQ(0, khll->cardinality()); + EXPECT_TRUE(khll->isExact()); + + // Add some values. + for (int64_t i = 0; i < 100; ++i) { + khll->add(i, randomLong(100)); + } + + // Should be exact since it is under the default max size. + EXPECT_TRUE(khll->isExact()); + EXPECT_EQ(100, khll->cardinality()); +} + +TEST_F(KHyperLogLogTest, cardinalityAccuracy) { + // Test representative precision levels (low and high) to ensure + // accuracy across different bucket configurations: 4 (16 buckets) and + // 12 (4096 buckets). + const int trials = 30; + + for (int indexBits : {4, 12}) { + const int numberOfBuckets = 1 << indexBits; + const int maxCardinality = numberOfBuckets * 2; + + std::vector errors; + + for (int trial = 0; trial < trials; ++trial) { + auto khll = std::make_unique>( + KHyperLogLog::kDefaultMaxSize, + numberOfBuckets, + allocator_); + + for (int cardinality = 1; cardinality <= maxCardinality; cardinality++) { + khll->add(folly::Random::rand64(), 0L); + + // Sample every 20% of bucket count (vs Java's 10%) to reduce test + // time + if (cardinality % (numberOfBuckets / 5) == 0) { + double error = + (static_cast(khll->cardinality()) - cardinality) / + cardinality; + errors.push_back(std::abs(error)); + } + } + } + + // Calculate standard deviation + double mean = 0.0; + for (double error : errors) { + mean += error; + } + mean /= errors.size(); + + double variance = 0.0; + for (double error : errors) { + variance += (error - mean) * (error - mean); + } + variance /= errors.size(); + double stdDev = std::sqrt(variance); + + EXPECT_LE(stdDev, kDefaultStandardError) + << "Cardinality mismatch at indexBits " << indexBits << ", bucket " + << numberOfBuckets; + } +} + +TEST_F(KHyperLogLogTest, mergeWith) { + auto khll1 = createKHLL( + std::vector{0, 1, 2, 3, 4, 5}, + std::vector{10, 11, 12, 13, 14, 15}); + + auto khll2 = createKHLL( + std::vector{3, 4, 5, 6, 7, 8}, + std::vector{13, 14, 15, 16, 17, 18}); + + auto expected = createKHLL( + std::vector{0, 1, 2, 3, 4, 5, 6, 7, 8}, + std::vector{10, 11, 12, 13, 14, 15, 16, 17, 18}); + + khll1->mergeWith(*khll2); + + EXPECT_EQ(expected->cardinality(), khll1->cardinality()); + EXPECT_EQ( + expected->reidentificationPotential(10), + khll1->reidentificationPotential(10)); +} + +TEST_F(KHyperLogLogTest, merge) { + // Helpers to create a KHLL with given maxSize and data + auto createSmaller = [&]() { + const auto smallerSize = 5; + auto khll = std::make_unique>( + smallerSize, kDefaultNumBuckets, allocator_); + for (size_t i = 0; i < smallerSize; ++i) { + khll->add(i, i); + } + return khll; + }; + + auto createLarger = [&]() { + const auto largerSize = 10; + auto khll = std::make_unique>( + largerSize, kDefaultNumBuckets, allocator_); + for (size_t i = 0; i < largerSize; ++i) { + khll->add(i, i); + } + return khll; + }; + + // Test merge(left, right) + auto smallerFirst = KHyperLogLog::merge( + createSmaller(), createLarger()); + + // Test merge(right, left) - should produce same result + auto largerFist = KHyperLogLog::merge( + createLarger(), createSmaller()); + + EXPECT_EQ(smallerFirst->cardinality(), largerFist->cardinality()); + + // Explicitly merging smaller into larger should produce different results. + auto larger = createLarger(); + larger->mergeWith(*createSmaller()); + EXPECT_NE(larger->cardinality(), largerFist->cardinality()); +} + +TEST_F(KHyperLogLogTest, serde) { + // Test small serialization + std::vector values; + std::vector uiis; + + for (int64_t i = 0; i < 1000; ++i) { + values.push_back(i); + uiis.push_back(randomLong(100)); + } + + auto khll = createKHLL(values, uiis); + + size_t totalSize = khll->estimatedSerializedSize(); + std::string outputBuffer(totalSize, '\0'); + khll->serialize(outputBuffer.data()); + auto deserializedResult = + KHyperLogLog::deserialize( + outputBuffer.data(), outputBuffer.size(), allocator_); + ASSERT_TRUE(deserializedResult.hasValue()); + auto& deserialized = deserializedResult.value(); + + EXPECT_EQ(khll->cardinality(), deserialized->cardinality()); + EXPECT_EQ( + khll->reidentificationPotential(10), + deserialized->reidentificationPotential(10)); + + // Test round-trip + std::string reserializeBuffer(totalSize, '\0'); + deserialized->serialize(reserializeBuffer.data()); + EXPECT_EQ(outputBuffer, reserializeBuffer); + + // Test empty KHLL round-trip + auto emptyKhll = createKHLL({}, {}); + size_t emptySize = emptyKhll->estimatedSerializedSize(); + std::string emptyOutputBuffer(emptySize, '\0'); + emptyKhll->serialize(emptyOutputBuffer.data()); + auto deserializedEmptyResult = + KHyperLogLog::deserialize( + emptyOutputBuffer.data(), emptyOutputBuffer.size(), allocator_); + ASSERT_TRUE(deserializedEmptyResult.hasValue()); + auto& deserializedEmpty = deserializedEmptyResult.value(); + EXPECT_EQ(deserializedEmpty->cardinality(), 0); + std::string reserializedEmptyOutputBuffer(emptySize, '\0'); + deserializedEmpty->serialize(reserializedEmptyOutputBuffer.data()); + EXPECT_EQ(emptyOutputBuffer, reserializedEmptyOutputBuffer); +} + +TEST_F(KHyperLogLogTest, uniquenessDistribution) { + const int histogramSize = 256; + const int count = 1000; + + auto khll = + std::make_unique>(allocator_); + std::map> valueToUiis; + + for (int i = 0; i < count; ++i) { + int64_t uii = randomLong(histogramSize); + int64_t value = randomLong(count); + khll->add(value, uii); + valueToUiis[value].insert(uii); + } + auto khllHistogram = khll->uniquenessDistribution(histogramSize); + + // Build the actual histogram + std::map actualHistogram; + int size = valueToUiis.size(); + + for (const auto& [value, uiiSet] : valueToUiis) { + int64_t bucket = std::min( + static_cast(uiiSet.size()), + static_cast(histogramSize)); + actualHistogram[bucket] += 1.0 / size; + } + + // Verify histogram accuracy (with some tolerance for approximation) + for (int64_t i = 1; i < histogramSize; ++i) { + double expected = actualHistogram.count(i) ? actualHistogram[i] : 0.0; + double khllEstimated = khllHistogram.count(i) ? khllHistogram[i] : 0.0; + + // Use 10% tolerance since the values of uniqueness distribution are a sum + // of 1 / size of minHash, and not the cardinality estimates. + EXPECT_NEAR(khllEstimated, expected, expected * 0.1) + << "Histogram mismatch at bucket " << i; + } +} + +TEST_F(KHyperLogLogTest, reidentificationPotential) { + auto khll = + std::make_unique>(allocator_); + std::map> valueToUiis; + + const int count = 1000; + for (int i = 0; i < count; ++i) { + int64_t uii = randomLong(100); + int64_t value = randomLong(count); + khll->add(value, uii); + valueToUiis[value].insert(uii); + } + + // Test different thresholds + for (int threshold = 1; threshold < 10; ++threshold) { + double khllEstimated = khll->reidentificationPotential(threshold); + + // Calculate the actual reidentification potential + int highlyUniqueCount = 0; + for (const auto& [value, uiiSet] : valueToUiis) { + if (static_cast(uiiSet.size()) <= threshold) { + highlyUniqueCount++; + } + } + + double expected = + static_cast(highlyUniqueCount) / valueToUiis.size(); + + if (expected > 0) { + EXPECT_NEAR(khllEstimated, expected, expected * kDefaultStandardError) + << "Reidentification potential mismatch for threshold " << threshold; + } + } +} + +TEST_F(KHyperLogLogTest, exactIntersectionCardinality) { + auto khll1 = createKHLL( + std::vector{1, 2, 3, 4, 5}, + std::vector{10, 20, 30, 40, 50}); + + auto khll2 = createKHLL( + std::vector{3, 4, 5, 6, 7}, + std::vector{30, 40, 50, 60, 70}); + + EXPECT_TRUE(khll1->isExact()); + EXPECT_TRUE(khll2->isExact()); + + int64_t intersection = + KHyperLogLog::exactIntersectionCardinality( + *khll1, *khll2); + // Values 3, 4, 5 are in both + EXPECT_EQ(3, intersection); +} + +TEST_F(KHyperLogLogTest, jaccardIndex) { + // Test with larger datasets. + // Create two KHLLs where one is a subset of the other + const int64_t set1Size = 100000; + const int64_t set2Size = 150000; + + auto khll1 = + std::make_unique>(allocator_); + auto khll2 = + std::make_unique>(allocator_); + + // Add values 0 to 99,999 to khll1 + for (int64_t i = 0; i < set1Size; ++i) { + khll1->add(i, randomLong(100)); + } + + // Add values 0 to 149,999 to khll2 (includes all of khll1) + for (int64_t i = 0; i < set2Size; ++i) { + khll2->add(i, randomLong(100)); + } + + double jaccard = + KHyperLogLog::jaccardIndex(*khll1, *khll2); + + // Expected Jaccard = |intersection| / |union| + // Intersection: 100,000 (all of set1) + // Union: 150,000 (all of set2) + // Jaccard = 100,000 / 150,000 = 2/3 ≈ 0.6667 + double expectedJaccard = static_cast(set1Size) / set2Size; + + EXPECT_NEAR( + jaccard, expectedJaccard, expectedJaccard * kDefaultStandardError); +} + +TEST_F(KHyperLogLogTest, largeDataset) { + auto khll = + std::make_unique>(allocator_); + + const int count = 200000; + std::unordered_set uniqueValues; + + for (int i = 0; i < count; ++i) { + int64_t value = folly::Random::rand64(); + int64_t uii = randomLong(100); + khll->add(value, uii); + uniqueValues.insert(value); + } + + EXPECT_FALSE(khll->isExact()); + + int64_t expected = static_cast(uniqueValues.size()); + int64_t khllEstimated = khll->cardinality(); + + EXPECT_NEAR(expected, khllEstimated, khllEstimated * kDefaultStandardError); +} + +TEST_F(KHyperLogLogTest, minhashSize) { + auto khll = + std::make_unique>(allocator_); + + EXPECT_EQ(0, khll->minhashSize()); + + khll->add(1, 10); + khll->add(2, 20); + khll->add(3, 30); + + EXPECT_EQ(3, khll->minhashSize()); +} + +TEST_F(KHyperLogLogTest, estimatedSizes) { + auto khll = + std::make_unique>(allocator_); + + size_t initialEstimatedSerSize = khll->estimatedSerializedSize(); + + for (int64_t i = 0; i < 100; ++i) { + khll->add(i, randomLong(100)); + } + + size_t finalEstimatedSerSize = khll->estimatedSerializedSize(); + + // Sizes should increase after adding data. + EXPECT_GT(finalEstimatedSerSize, initialEstimatedSerSize); + + // Verify the estimated size is accurate. + std::string serializedBuffer(finalEstimatedSerSize, '\0'); + khll->serialize(serializedBuffer.data()); + size_t actualSerSize = serializedBuffer.size(); + + EXPECT_LE(actualSerSize, finalEstimatedSerSize) + << "Actual serialized size exceeds estimate - potential buffer overflow"; + + EXPECT_NEAR( + actualSerSize, + finalEstimatedSerSize, + finalEstimatedSerSize * kDefaultStandardError); +} + +TEST_F(KHyperLogLogTest, differentJoinKeyUIITypes) { + // Test different TJoinKey, TUii combinations: + // int32_t TJoinKey, int32_t TUii + { + auto khll = std::make_unique>( + allocator_); + for (int32_t i = 0; i < 100; ++i) { + khll->add(i, i); + } + EXPECT_TRUE(khll->isExact()); + EXPECT_EQ(100, khll->cardinality()); + + // Test round trip + size_t totalSize = khll->estimatedSerializedSize(); + std::string outputBuffer(totalSize, '\0'); + khll->serialize(outputBuffer.data()); + auto deserializedResult = + KHyperLogLog::deserialize( + outputBuffer.data(), outputBuffer.size(), allocator_); + ASSERT_TRUE(deserializedResult.hasValue()); + auto& deserialized = deserializedResult.value(); + std::string reserializeBuffer(totalSize, '\0'); + deserialized->serialize(reserializeBuffer.data()); + EXPECT_EQ(outputBuffer, reserializeBuffer); + } + + // uint32_t TJoinKey, uint32_t TUii + { + auto khll = std::make_unique>( + allocator_); + for (uint32_t i = 0; i < 100; ++i) { + khll->add(i % 10, i); + } + EXPECT_TRUE(khll->isExact()); + EXPECT_EQ(10, khll->cardinality()); + + // Test round trip + size_t totalSize = khll->estimatedSerializedSize(); + std::string outputBuffer(totalSize, '\0'); + khll->serialize(outputBuffer.data()); + auto deserializedResult = + KHyperLogLog::deserialize( + outputBuffer.data(), outputBuffer.size(), allocator_); + ASSERT_TRUE(deserializedResult.hasValue()); + auto& deserialized = deserializedResult.value(); + std::string reserializeBuffer(totalSize, '\0'); + deserialized->serialize(reserializeBuffer.data()); + EXPECT_EQ(outputBuffer, reserializeBuffer); + } + + // int16_t TJoinKey, int16_t TUii + { + auto khll = std::make_unique>( + allocator_); + for (int16_t i = 0; i < 10000; ++i) { + khll->add(i % 100, i); + } + EXPECT_TRUE(khll->isExact()); + EXPECT_EQ(100, khll->cardinality()); + + // Test round trip + size_t totalSize = khll->estimatedSerializedSize(); + std::string outputBuffer(totalSize, '\0'); + khll->serialize(outputBuffer.data()); + auto deserialized = KHyperLogLog::deserialize( + outputBuffer.data(), outputBuffer.size(), allocator_); + ASSERT_TRUE(deserialized.hasValue()); + std::string reserializeBuffer(totalSize, '\0'); + deserialized.value()->serialize(reserializeBuffer.data()); + EXPECT_EQ(outputBuffer, reserializeBuffer); + } + + // uint16_t TJoinKey, uint16_t TUii + { + auto khll = std::make_unique>( + allocator_); + for (uint16_t i = 0; i < 100; ++i) { + khll->add(i, i); + } + EXPECT_TRUE(khll->isExact()); + EXPECT_EQ(100, khll->cardinality()); + + // Test round trip + size_t totalSize = khll->estimatedSerializedSize(); + std::string outputBuffer(totalSize, '\0'); + khll->serialize(outputBuffer.data()); + auto deserialized = KHyperLogLog::deserialize( + outputBuffer.data(), outputBuffer.size(), allocator_); + ASSERT_TRUE(deserialized.hasValue()); + std::string reserializeBuffer(totalSize, '\0'); + deserialized.value()->serialize(reserializeBuffer.data()); + EXPECT_EQ(outputBuffer, reserializeBuffer); + } + + // int8_t TJoinKey, int8_t TUii + { + auto khll = + std::make_unique>(allocator_); + for (int8_t i = -100; i < 100; ++i) { + khll->add(i, i); + } + EXPECT_TRUE(khll->isExact()); + EXPECT_EQ(200, khll->cardinality()); + + // Test round trip + size_t totalSize = khll->estimatedSerializedSize(); + std::string outputBuffer(totalSize, '\0'); + khll->serialize(outputBuffer.data()); + auto deserialized = KHyperLogLog::deserialize( + outputBuffer.data(), outputBuffer.size(), allocator_); + ASSERT_TRUE(deserialized.hasValue()); + std::string reserializeBuffer(totalSize, '\0'); + deserialized.value()->serialize(reserializeBuffer.data()); + EXPECT_EQ(outputBuffer, reserializeBuffer); + } + + // uint8_t TJoinKey, uint8_t TUii + { + auto khll = std::make_unique>( + allocator_); + for (uint8_t i = 0; i < 100; ++i) { + khll->add(i, i); + } + EXPECT_TRUE(khll->isExact()); + EXPECT_EQ(100, khll->cardinality()); + + // Test round trip + size_t totalSize = khll->estimatedSerializedSize(); + std::string outputBuffer(totalSize, '\0'); + khll->serialize(outputBuffer.data()); + auto deserialized = KHyperLogLog::deserialize( + outputBuffer.data(), outputBuffer.size(), allocator_); + ASSERT_TRUE(deserialized.hasValue()); + std::string reserializeBuffer(totalSize, '\0'); + deserialized.value()->serialize(reserializeBuffer.data()); + EXPECT_EQ(outputBuffer, reserializeBuffer); + } + + // float TJoinKey, float TUii + { + auto khll = + std::make_unique>(allocator_); + for (int i = 0; i < 10; ++i) { + khll->add(static_cast(i), static_cast(i)); + } + EXPECT_TRUE(khll->isExact()); + EXPECT_EQ(10, khll->cardinality()); + + // Test round trip + size_t totalSize = khll->estimatedSerializedSize(); + std::string outputBuffer(totalSize, '\0'); + khll->serialize(outputBuffer.data()); + auto deserialized = KHyperLogLog::deserialize( + outputBuffer.data(), outputBuffer.size(), allocator_); + ASSERT_TRUE(deserialized.hasValue()); + std::string reserializeBuffer(totalSize, '\0'); + deserialized.value()->serialize(reserializeBuffer.data()); + EXPECT_EQ(outputBuffer, reserializeBuffer); + } + + // double TJoinKey, double TUii + { + auto khll = + std::make_unique>(allocator_); + for (int i = 0; i < 100; ++i) { + khll->add(static_cast(i), static_cast(i)); + } + EXPECT_TRUE(khll->isExact()); + EXPECT_EQ(100, khll->cardinality()); + + // Test round trip + size_t totalSize = khll->estimatedSerializedSize(); + std::string outputBuffer(totalSize, '\0'); + khll->serialize(outputBuffer.data()); + auto deserialized = KHyperLogLog::deserialize( + outputBuffer.data(), outputBuffer.size(), allocator_); + ASSERT_TRUE(deserialized.hasValue()); + std::string reserializeBuffer(totalSize, '\0'); + deserialized.value()->serialize(reserializeBuffer.data()); + EXPECT_EQ(outputBuffer, reserializeBuffer); + } + + // StringView TJoinKey, StringView TUii + { + auto khll = std::make_unique>( + allocator_); + std::vector strings; + strings.reserve(100); + for (int i = 0; i < 100; ++i) { + strings.push_back("key_" + std::to_string(i)); + } + for (int i = 0; i < 100; ++i) { + khll->add(StringView(strings[i]), StringView(strings[i])); + } + EXPECT_TRUE(khll->isExact()); + EXPECT_EQ(100, khll->cardinality()); + + // Test round trip + size_t totalSize = khll->estimatedSerializedSize(); + std::string outputBuffer(totalSize, '\0'); + khll->serialize(outputBuffer.data()); + auto deserialized = KHyperLogLog::deserialize( + outputBuffer.data(), outputBuffer.size(), allocator_); + ASSERT_TRUE(deserialized.hasValue()); + std::string reserializeBuffer(totalSize, '\0'); + deserialized.value()->serialize(reserializeBuffer.data()); + EXPECT_EQ(outputBuffer, reserializeBuffer); + } + + // Timestamp TJoinKey, Timestamp TUii + { + auto khll = std::make_unique>( + allocator_); + for (int i = 0; i < 100; ++i) { + Timestamp ts(i * 1000, 0); + khll->add(ts, ts); + } + EXPECT_TRUE(khll->isExact()); + EXPECT_EQ(100, khll->cardinality()); + + // Test round trip + size_t totalSize = khll->estimatedSerializedSize(); + std::string outputBuffer(totalSize, '\0'); + khll->serialize(outputBuffer.data()); + auto deserialized = KHyperLogLog::deserialize( + outputBuffer.data(), outputBuffer.size(), allocator_); + ASSERT_TRUE(deserialized.hasValue()); + std::string reserializeBuffer(totalSize, '\0'); + deserialized.value()->serialize(reserializeBuffer.data()); + EXPECT_EQ(outputBuffer, reserializeBuffer); + } + + // int128_t TJoinKey, int128_t TUii + { + auto khll = std::make_unique>( + allocator_); + for (int i = 0; i < 100; ++i) { + // Create 128-bit values with different upper and lower 64 bits + int128_t value = (static_cast(i) << 64) | (i + 1000); + khll->add(value, value); + } + EXPECT_TRUE(khll->isExact()); + EXPECT_EQ(100, khll->cardinality()); + + // Test round trip + size_t totalSize = khll->estimatedSerializedSize(); + std::string outputBuffer(totalSize, '\0'); + khll->serialize(outputBuffer.data()); + auto deserialized = KHyperLogLog::deserialize( + outputBuffer.data(), outputBuffer.size(), allocator_); + ASSERT_TRUE(deserialized.hasValue()); + std::string reserializeBuffer(totalSize, '\0'); + deserialized.value()->serialize(reserializeBuffer.data()); + EXPECT_EQ(outputBuffer, reserializeBuffer); + } +} diff --git a/velox/common/hyperloglog/tests/SparseHllTest.cpp b/velox/common/hyperloglog/tests/SparseHllTest.cpp index 299e2c8aebd9..9d2ae9238d0d 100644 --- a/velox/common/hyperloglog/tests/SparseHllTest.cpp +++ b/velox/common/hyperloglog/tests/SparseHllTest.cpp @@ -18,6 +18,7 @@ #define XXH_INLINE_ALL #include +#include #include using namespace facebook::velox; @@ -28,12 +29,21 @@ uint64_t hashOne(T value) { return XXH64(&value, sizeof(value), 0); } +template class SparseHllTest : public ::testing::Test { protected: static void SetUpTestCase() { memory::MemoryManager::testingSetInstance(memory::MemoryManager::Options{}); } + void SetUp() override { + if constexpr (std::is_same_v) { + allocator_ = &hsa_; + } else { + allocator_ = pool_.get(); + } + } + template void testMergeWith(const std::vector& left, const std::vector& right) { testMergeWith(left, right, false); @@ -45,9 +55,9 @@ class SparseHllTest : public ::testing::Test { const std::vector& left, const std::vector& right, bool serialized) { - SparseHll hllLeft{&allocator_}; - SparseHll hllRight{&allocator_}; - SparseHll expected{&allocator_}; + SparseHll hllLeft{allocator_}; + SparseHll hllRight{allocator_}; + SparseHll expected{allocator_}; for (auto value : left) { auto hash = hashOne(value); @@ -77,16 +87,20 @@ class SparseHllTest : public ::testing::Test { auto hllLeftSerialized = serialize(11, hllLeft); ASSERT_EQ( - SparseHll::cardinality(hllLeftSerialized.data()), + SparseHlls::cardinality(hllLeftSerialized.data()), expected.cardinality()); } - SparseHll roundTrip(SparseHll& hll) { - auto serialized = serialize(11, hll); - return SparseHll(serialized.data(), &allocator_); + SparseHll roundTrip( + SparseHll& hll, + int8_t indexBitLength = 11) { + auto serialized = serialize(indexBitLength, hll); + return SparseHll(serialized.data(), allocator_); } - std::string serialize(int8_t indexBitLength, const SparseHll& sparseHll) { + std::string serialize( + int8_t indexBitLength, + const SparseHll& sparseHll) { auto size = sparseHll.serializedSize(); std::string serialized; serialized.resize(size); @@ -94,7 +108,7 @@ class SparseHllTest : public ::testing::Test { return serialized; } - std::string serialize(DenseHll& denseHll) { + std::string serialize(DenseHll& denseHll) { auto size = denseHll.serializedSize(); std::string serialized; serialized.resize(size); @@ -104,11 +118,32 @@ class SparseHllTest : public ::testing::Test { std::shared_ptr pool_{ memory::memoryManager()->addLeafPool()}; - HashStringAllocator allocator_{pool_.get()}; + HashStringAllocator hsa_{pool_.get()}; + TAllocator* allocator_; +}; + +using AllocatorTypes = + ::testing::Types; + +class NameGenerator { + public: + template + static std::string GetName(int) { + if constexpr (std::is_same_v) { + return "hsa"; + } else if constexpr (std::is_same_v) { + return "pool"; + } else { + VELOX_UNREACHABLE( + "Only HashStringAllocator and MemoryPool are supported allocator types."); + } + } }; -TEST_F(SparseHllTest, basic) { - SparseHll sparseHll{&allocator_}; +TYPED_TEST_SUITE(SparseHllTest, AllocatorTypes, NameGenerator); + +TYPED_TEST(SparseHllTest, basic) { + SparseHll sparseHll{this->allocator_}; for (int i = 0; i < 1'000; i++) { auto value = i % 17; auto hash = hashOne(value); @@ -118,16 +153,16 @@ TEST_F(SparseHllTest, basic) { sparseHll.verify(); ASSERT_EQ(17, sparseHll.cardinality()); - auto deserialized = roundTrip(sparseHll); + auto deserialized = this->roundTrip(sparseHll); deserialized.verify(); ASSERT_EQ(17, deserialized.cardinality()); - auto serialized = serialize(11, sparseHll); - ASSERT_EQ(17, SparseHll::cardinality(serialized.data())); + auto serialized = this->serialize(11, sparseHll); + ASSERT_EQ(17, SparseHlls::cardinality(serialized.data())); } -TEST_F(SparseHllTest, highCardinality) { - SparseHll sparseHll{&allocator_}; +TYPED_TEST(SparseHllTest, highCardinality) { + SparseHll sparseHll{this->allocator_}; for (int i = 0; i < 1'000; i++) { auto hash = hashOne(i); sparseHll.insertHash(hash); @@ -136,12 +171,12 @@ TEST_F(SparseHllTest, highCardinality) { sparseHll.verify(); ASSERT_EQ(1'000, sparseHll.cardinality()); - auto deserialized = roundTrip(sparseHll); + auto deserialized = this->roundTrip(sparseHll); deserialized.verify(); ASSERT_EQ(1'000, deserialized.cardinality()); - auto serialized = serialize(11, sparseHll); - ASSERT_EQ(1'000, SparseHll::cardinality(serialized.data())); + auto serialized = this->serialize(11, sparseHll); + ASSERT_EQ(1'000, SparseHlls::cardinality(serialized.data())); } namespace { @@ -156,30 +191,80 @@ std::vector sequence(T start, T end) { } } // namespace -TEST_F(SparseHllTest, mergeWith) { +TYPED_TEST(SparseHllTest, mergeWith) { // with overlap - testMergeWith(sequence(0, 100), sequence(50, 150)); - testMergeWith(sequence(50, 150), sequence(0, 100)); + this->testMergeWith(sequence(0, 100), sequence(50, 150)); + this->testMergeWith(sequence(50, 150), sequence(0, 100)); // no overlap - testMergeWith(sequence(0, 100), sequence(200, 300)); - testMergeWith(sequence(200, 300), sequence(0, 100)); + this->testMergeWith(sequence(0, 100), sequence(200, 300)); + this->testMergeWith(sequence(200, 300), sequence(0, 100)); // idempotent - testMergeWith(sequence(0, 100), sequence(0, 100)); + this->testMergeWith(sequence(0, 100), sequence(0, 100)); // empty sequence - testMergeWith(sequence(0, 100), {}); - testMergeWith({}, sequence(100, 300)); + this->testMergeWith(sequence(0, 100), {}); + this->testMergeWith({}, sequence(100, 300)); +} + +TYPED_TEST(SparseHllTest, toDense) { + int8_t indexBitLength = 11; + + SparseHll sparseHll{this->allocator_}; + DenseHll expectedHll{indexBitLength, this->allocator_}; + for (int i = 0; i < 1'000; i++) { + auto hash = hashOne(i); + sparseHll.insertHash(hash); + expectedHll.insertHash(hash); + } + + DenseHll denseHll{indexBitLength, this->allocator_}; + sparseHll.toDense(denseHll); + ASSERT_EQ(denseHll.cardinality(), expectedHll.cardinality()); + ASSERT_EQ(this->serialize(denseHll), this->serialize(expectedHll)); +} + +TYPED_TEST(SparseHllTest, testNumberOfZeros) { + int8_t indexBitLength = 11; + for (int i = 0; i < 64 - indexBitLength; ++i) { + auto hash = 1ull << i; + SparseHll sparseHll(this->allocator_); + sparseHll.insertHash(hash); + DenseHll expectedHll(indexBitLength, this->allocator_); + expectedHll.insertHash(hash); + DenseHll denseHll(indexBitLength, this->allocator_); + sparseHll.toDense(denseHll); + ASSERT_EQ(this->serialize(denseHll), this->serialize(expectedHll)); + } } -class SparseHllToDenseTest : public ::testing::TestWithParam { +template +struct AllocatorWithIndexBits { + using AllocatorType = TAllocator; + static constexpr int8_t indexBitLength() { + return IndexBitLength; + } +}; + +template +class SparseHllToDenseTest : public ::testing::Test { protected: static void SetUpTestCase() { memory::MemoryManager::testingSetInstance(memory::MemoryManager::Options{}); } - std::string serialize(DenseHll& denseHll) { + void SetUp() override { + if constexpr (std::is_same_v< + typename TParam::AllocatorType, + HashStringAllocator>) { + allocator_ = &hsa_; + } else { + allocator_ = pool_.get(); + } + } + + std::string serialize(DenseHll& denseHll) { auto size = denseHll.serializedSize(); std::string serialized; serialized.resize(size); @@ -189,41 +274,93 @@ class SparseHllToDenseTest : public ::testing::TestWithParam { std::shared_ptr pool_{ memory::memoryManager()->addLeafPool()}; - HashStringAllocator allocator_{pool_.get()}; + HashStringAllocator hsa_{pool_.get()}; + typename TParam::AllocatorType* allocator_; }; -TEST_P(SparseHllToDenseTest, toDense) { - int8_t indexBitLength = GetParam(); +using SparseHllToDenseTestParams = ::testing::Types< + // HashStringAllocator with various index bit lengths + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + // MemoryPool with various index bit lengths + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits, + AllocatorWithIndexBits>; + +class ToDenseNameGenerator { + public: + template + static std::string GetName(int) { + std::string allocatorName; + if constexpr (std::is_same_v< + typename TParam::AllocatorType, + HashStringAllocator>) { + allocatorName = "hsa"; + } else if constexpr (std::is_same_v< + typename TParam::AllocatorType, + memory::MemoryPool>) { + allocatorName = "pool"; + } else { + VELOX_UNREACHABLE( + "Only HashStringAllocator and MemoryPool are supported allocator types."); + } + return fmt::format("{}_{}", allocatorName, TParam::indexBitLength()); + } +}; - SparseHll sparseHll{&allocator_}; - DenseHll expectedHll{indexBitLength, &allocator_}; +TYPED_TEST_SUITE( + SparseHllToDenseTest, + SparseHllToDenseTestParams, + ToDenseNameGenerator); + +TYPED_TEST(SparseHllToDenseTest, toDense) { + int8_t indexBitLength = TypeParam::indexBitLength(); + + SparseHll sparseHll{this->allocator_}; + DenseHll expectedHll{indexBitLength, this->allocator_}; for (int i = 0; i < 1'000; i++) { auto hash = hashOne(i); sparseHll.insertHash(hash); expectedHll.insertHash(hash); } - DenseHll denseHll{indexBitLength, &allocator_}; + DenseHll denseHll{indexBitLength, this->allocator_}; sparseHll.toDense(denseHll); ASSERT_EQ(denseHll.cardinality(), expectedHll.cardinality()); - ASSERT_EQ(serialize(denseHll), serialize(expectedHll)); + ASSERT_EQ(this->serialize(denseHll), this->serialize(expectedHll)); } -TEST_P(SparseHllToDenseTest, testNumberOfZeros) { - auto indexBitLength = GetParam(); +TYPED_TEST(SparseHllToDenseTest, testNumberOfZeros) { + auto indexBitLength = TypeParam::indexBitLength(); for (int i = 0; i < 64 - indexBitLength; ++i) { auto hash = 1ull << i; - SparseHll sparseHll(&allocator_); + SparseHll sparseHll(this->allocator_); sparseHll.insertHash(hash); - DenseHll expectedHll(indexBitLength, &allocator_); + DenseHll expectedHll(indexBitLength, this->allocator_); expectedHll.insertHash(hash); - DenseHll denseHll(indexBitLength, &allocator_); + DenseHll denseHll(indexBitLength, this->allocator_); sparseHll.toDense(denseHll); - ASSERT_EQ(serialize(denseHll), serialize(expectedHll)); + ASSERT_EQ(this->serialize(denseHll), this->serialize(expectedHll)); } } - -INSTANTIATE_TEST_SUITE_P( - SparseHllToDenseTest, - SparseHllToDenseTest, - ::testing::Values(4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16)); diff --git a/velox/common/io/IoStatistics.h b/velox/common/io/IoStatistics.h index 2111a8877b47..1371d977a594 100644 --- a/velox/common/io/IoStatistics.h +++ b/velox/common/io/IoStatistics.h @@ -43,6 +43,24 @@ struct OperationCounters { class IoCounter { public: + IoCounter& operator=(const IoCounter& other) noexcept { + if (this != &other) { + count_.store( + other.count_.load(std::memory_order_relaxed), + std::memory_order_relaxed); + sum_.store( + other.sum_.load(std::memory_order_relaxed), + std::memory_order_relaxed); + min_.store( + other.min_.load(std::memory_order_relaxed), + std::memory_order_relaxed); + max_.store( + other.max_.load(std::memory_order_relaxed), + std::memory_order_relaxed); + } + return *this; + } + uint64_t count() const { return count_; } diff --git a/velox/common/memory/Allocation.cpp b/velox/common/memory/Allocation.cpp index 884af7c82cf5..63fb4db4baa7 100644 --- a/velox/common/memory/Allocation.cpp +++ b/velox/common/memory/Allocation.cpp @@ -35,10 +35,11 @@ void Allocation::append(uint8_t* address, MachinePageCount numPages) { runs_.empty() || address != runs_.back().data(), "Appending a duplicate address into a PageRun"); if (FOLLY_UNLIKELY(numPages > Allocation::PageRun::kMaxPagesInRun)) { - VELOX_MEM_ALLOC_ERROR(fmt::format( - "The number of pages to append {} exceeds the PageRun limit {}", - numPages, - Allocation::PageRun::kMaxPagesInRun)); + VELOX_MEM_ALLOC_ERROR( + fmt::format( + "The number of pages to append {} exceeds the PageRun limit {}", + numPages, + Allocation::PageRun::kMaxPagesInRun)); } numPages_ += numPages; runs_.emplace_back(address, numPages); diff --git a/velox/common/memory/ArbitrationOperation.h b/velox/common/memory/ArbitrationOperation.h index e494d93d2b85..d9159c583c87 100644 --- a/velox/common/memory/ArbitrationOperation.h +++ b/velox/common/memory/ArbitrationOperation.h @@ -154,7 +154,7 @@ struct fmt::formatter : formatter { auto format( facebook::velox::memory::ArbitrationOperation::State state, - format_context& ctx) { + format_context& ctx) const { return formatter::format( facebook::velox::memory::ArbitrationOperation::stateName(state), ctx); } diff --git a/velox/common/memory/ArbitrationParticipant.cpp b/velox/common/memory/ArbitrationParticipant.cpp index 1b29c2107121..7d74c98c7311 100644 --- a/velox/common/memory/ArbitrationParticipant.cpp +++ b/velox/common/memory/ArbitrationParticipant.cpp @@ -31,7 +31,7 @@ using namespace facebook::velox::memory; std::string ArbitrationParticipant::Config::toString() const { return fmt::format( - "initCapacity {}, minCapacity {}, fastExponentialGrowthCapacityLimit {}, slowCapacityGrowRatio {}, minFreeCapacity {}, minFreeCapacityRatio {}, minReclaimBytes {}, minReclaimPct {}, abortCapacityLimit {}", + "initCapacity {}, minCapacity {}, fastExponentialGrowthCapacityLimit {}, slowCapacityGrowRatio {}, minFreeCapacity {}, minFreeCapacityRatio {}, minReclaimBytes {}, minReclaimPct {}", succinctBytes(initCapacity), succinctBytes(minCapacity), succinctBytes(fastExponentialGrowthCapacityLimit), @@ -39,8 +39,7 @@ std::string ArbitrationParticipant::Config::toString() const { succinctBytes(minFreeCapacity), minFreeCapacityRatio, succinctBytes(minReclaimBytes), - minReclaimPct, - succinctBytes(abortCapacityLimit)); + minReclaimPct); } ArbitrationParticipant::Config::Config( @@ -51,8 +50,7 @@ ArbitrationParticipant::Config::Config( uint64_t _minFreeCapacity, double _minFreeCapacityRatio, uint64_t _minReclaimBytes, - double _minReclaimPct, - uint64_t _abortCapacityLimit) + double _minReclaimPct) : initCapacity(_initCapacity), minCapacity(_minCapacity), fastExponentialGrowthCapacityLimit(_fastExponentialGrowthCapacityLimit), @@ -60,8 +58,7 @@ ArbitrationParticipant::Config::Config( minFreeCapacity(_minFreeCapacity), minFreeCapacityRatio(_minFreeCapacityRatio), minReclaimBytes(_minReclaimBytes), - minReclaimPct(_minReclaimPct), - abortCapacityLimit(_abortCapacityLimit) { + minReclaimPct(_minReclaimPct) { VELOX_CHECK_GE(slowCapacityGrowRatio, 0); VELOX_CHECK_EQ( fastExponentialGrowthCapacityLimit == 0, @@ -82,10 +79,6 @@ ArbitrationParticipant::Config::Config( "adjustment.", minFreeCapacity, minFreeCapacityRatio); - VELOX_CHECK( - bits::isPowerOfTwo(abortCapacityLimit), - "abortCapacityLimit {} not a power of two", - abortCapacityLimit); VELOX_CHECK( 0 <= minReclaimPct && minReclaimPct <= 1, "minReclaimPct {} must be in [0, 1]", @@ -109,8 +102,7 @@ ArbitrationParticipant::ArbitrationParticipant( pool_(pool.get()), config_(config), maxCapacity_(pool_->maxCapacity()), - createTimeNs_(getCurrentTimeNano()), - poolPriority_(pool_->poolPriority()) { + createTimeNs_(getCurrentTimeNano()) { VELOX_CHECK_LE( config_->minCapacity, maxCapacity_, @@ -363,7 +355,9 @@ uint64_t ArbitrationParticipant::abortLocked( if (aborted_) { return 0; } + aborted_ = true; } + try { VELOX_MEM_LOG(WARNING) << "Memory pool " << pool_->name() << " is being aborted"; @@ -378,8 +372,6 @@ uint64_t ArbitrationParticipant::abortLocked( VELOX_CHECK(pool_->aborted()); std::lock_guard l(stateLock_); - VELOX_CHECK(!aborted_); - aborted_ = true; return shrinkLocked(/*reclaimAll=*/true); } @@ -449,8 +441,9 @@ ArbitrationTimedLock::ArbitrationTimedLock( uint64_t timeoutNs) : mutex_(mutex) { if (!mutex_.try_lock_for(std::chrono::nanoseconds(timeoutNs))) { - VELOX_MEM_ARBITRATION_TIMEOUT(fmt::format( - "Memory arbitration lock timed out when reclaiming from arbitration participant.")); + VELOX_MEM_ARBITRATION_TIMEOUT( + fmt::format( + "Memory arbitration lock timed out when reclaiming from arbitration participant.")); } } diff --git a/velox/common/memory/ArbitrationParticipant.h b/velox/common/memory/ArbitrationParticipant.h index 4eb7763c8edb..530b04b9dda0 100644 --- a/velox/common/memory/ArbitrationParticipant.h +++ b/velox/common/memory/ArbitrationParticipant.h @@ -119,17 +119,6 @@ class ArbitrationParticipant uint64_t minReclaimBytes; double minReclaimPct; - /// Specifies the starting memory capacity limit for global arbitration to - /// search for victim participant to reclaim used memory by abort. For - /// participants with capacity larger than the limit, the global arbitration - /// choose to abort the youngest participant which has the largest - /// participant id. This helps to let the old queries to run to completion. - /// The abort capacity limit is reduced by half if couldn't find a victim - /// participant until reaches to zero. - /// - /// NOTE: the limit must be zero or a power of 2. - uint64_t abortCapacityLimit; - Config( uint64_t _initCapacity, uint64_t _minCapacity, @@ -138,8 +127,7 @@ class ArbitrationParticipant uint64_t _minFreeCapacity, double _minFreeCapacityRatio, uint64_t _minReclaimBytes, - double _minReclaimPct, - uint64_t _abortCapacityLimit); + double _minReclaimPct); std::string toString() const; }; @@ -173,11 +161,6 @@ class ArbitrationParticipant return config_->minCapacity; } - /// Returns the priority of the underlying query memory pool. - uint32_t poolPriority() const { - return poolPriority_; - } - /// Returns the duration of this arbitration participant since its creation. uint64_t durationNs() const { const auto now = getCurrentTimeNano(); @@ -349,7 +332,6 @@ class ArbitrationParticipant const Config* const config_; const uint64_t maxCapacity_; const uint64_t createTimeNs_; - const uint32_t poolPriority_; mutable std::mutex stateLock_; bool aborted_{false}; diff --git a/velox/common/memory/ByteStream.cpp b/velox/common/memory/ByteStream.cpp index 4fe93b569216..c910f32c9c90 100644 --- a/velox/common/memory/ByteStream.cpp +++ b/velox/common/memory/ByteStream.cpp @@ -16,19 +16,21 @@ #include "velox/common/memory/ByteStream.h" +#include + namespace facebook::velox { +static ByteRange convByteRange(folly::ByteRange br) { + return {const_cast(br.data()), folly::to_signed(br.size()), 0}; +} + std::vector byteRangesFromIOBuf(folly::IOBuf* iobuf) { if (iobuf == nullptr) { return {}; } std::vector byteRanges; - auto* current = iobuf; - do { - byteRanges.push_back( - {current->writableData(), static_cast(current->length()), 0}); - current = current->next(); - } while (current != iobuf); + auto dst = std::back_inserter(byteRanges); + std::transform(iobuf->begin(), iobuf->end(), dst, convByteRange); return byteRanges; } diff --git a/velox/common/memory/ByteStream.h b/velox/common/memory/ByteStream.h index fee24a1f2e87..1063e71ee096 100644 --- a/velox/common/memory/ByteStream.h +++ b/velox/common/memory/ByteStream.h @@ -361,9 +361,10 @@ class ByteOutputStream { } if (current_->position + sizeof(T) * values.size() > current_->size) { - appendStringView(std::string_view( - reinterpret_cast(&values[0]), - values.size() * sizeof(T))); + appendStringView( + std::string_view( + reinterpret_cast(&values[0]), + values.size() * sizeof(T))); return; } auto* target = current_->buffer + current_->position; @@ -537,9 +538,10 @@ class AppendWindow { ~AppendWindow() noexcept { if (scratchPtr_.size()) { try { - stream_.appendStringView(std::string_view( - reinterpret_cast(scratchPtr_.get()), - scratchPtr_.size() * sizeof(T))); + stream_.appendStringView( + std::string_view( + reinterpret_cast(scratchPtr_.get()), + scratchPtr_.size() * sizeof(T))); } catch (const std::exception& e) { // This is impossible because construction ensures there is space for // the bytes in the stream. diff --git a/velox/common/memory/CMakeLists.txt b/velox/common/memory/CMakeLists.txt index f3b39127cfd3..a36e0f06a6be 100644 --- a/velox/common/memory/CMakeLists.txt +++ b/velox/common/memory/CMakeLists.txt @@ -32,18 +32,21 @@ velox_add_library( MmapArena.cpp RawVector.cpp SharedArbitrator.cpp - StreamArena.cpp) + StreamArena.cpp +) velox_link_libraries( velox_memory - PUBLIC velox_common_base - velox_common_config - velox_exception - velox_flag_definitions - velox_time - velox_type - Folly::folly - fmt::fmt - gflags::gflags - glog::glog - PRIVATE velox_test_util re2::re2) + PUBLIC + velox_common_base + velox_common_config + velox_exception + velox_flag_definitions + velox_time + velox_type + Folly::folly + fmt::fmt + gflags::gflags + glog::glog + PRIVATE velox_test_util re2::re2 +) diff --git a/velox/common/memory/HashStringAllocator.h b/velox/common/memory/HashStringAllocator.h index fd9fadc1c2b2..6e153894dfc8 100644 --- a/velox/common/memory/HashStringAllocator.h +++ b/velox/common/memory/HashStringAllocator.h @@ -19,7 +19,6 @@ #include "velox/common/memory/AllocationPool.h" #include "velox/common/memory/ByteStream.h" #include "velox/common/memory/CompactDoubleList.h" -#include "velox/common/memory/Memory.h" #include "velox/common/memory/StreamArena.h" #include "velox/type/StringView.h" @@ -27,6 +26,9 @@ namespace facebook::velox { +template +struct StlAllocator; + /// Implements an arena backed by memory::Allocation. This is for backing /// ByteOutputStream or for allocating single blocks. Blocks can be individually /// freed. Adjacent frees are coalesced and free blocks are kept in a free list. @@ -41,6 +43,9 @@ namespace facebook::velox { /// backing a HashStringAllocator is set to kArenaEnd. class HashStringAllocator : public StreamArena { public: + template + using TStlAllocator = StlAllocator; + /// The minimum allocation must have space after the header for the free list /// pointers and the trailing length. static constexpr int32_t kMinAlloc = @@ -659,8 +664,8 @@ class RowSizeTracker { counter_(counter) {} ~RowSizeTracker() { - auto delta = allocator_->currentBytes() - size_; - if (delta) { + const auto delta = allocator_->currentBytes() - size_; + if (delta != 0) { saturatingIncrement(&counter_, delta); } } @@ -668,7 +673,7 @@ class RowSizeTracker { private: // Increments T at *pointer without wrapping around at overflow. void saturatingIncrement(T* pointer, int64_t delta) { - auto value = *reinterpret_cast(pointer) + delta; + const auto value = *reinterpret_cast(pointer) + delta; *reinterpret_cast(pointer) = std::min(value, std::numeric_limits::max()); } @@ -685,11 +690,15 @@ struct StlAllocator { explicit StlAllocator(HashStringAllocator* allocator) : allocator_{allocator} { - VELOX_CHECK(allocator); + VELOX_CHECK_NOT_NULL(allocator); } + // We can use "explicit" here based on the C++ standard. But + // libstdc++ 12 or older doesn't work for std::vector and + // "explicit". We can avoid it by not using "explicit" here. + // See also: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=115854 template - explicit StlAllocator(const StlAllocator& allocator) + StlAllocator(const StlAllocator& allocator) : allocator_{allocator.allocator()} { VELOX_CHECK_NOT_NULL(allocator_); } @@ -713,16 +722,10 @@ struct StlAllocator { return allocator_; } - friend bool operator==(const StlAllocator& lhs, const StlAllocator& rhs) { - return lhs.allocator_ == rhs.allocator_; - } - - friend bool operator!=(const StlAllocator& lhs, const StlAllocator& rhs) { - return !(lhs == rhs); - } + bool operator==(const StlAllocator& other) const = default; private: - HashStringAllocator* allocator_; + HashStringAllocator* const allocator_; }; /// An allocator backed by HashStringAllocator that guaratees a configurable @@ -803,12 +806,6 @@ struct AlignedStlAllocator { return lhs.allocator_ == rhs.allocator_; } - friend bool operator!=( - const AlignedStlAllocator& lhs, - const AlignedStlAllocator& rhs) { - return !(lhs == rhs); - } - private: // Pad the memory user requested by some padding to facilitate memory // alignment later. Memory layout: @@ -833,7 +830,7 @@ struct AlignedStlAllocator { return reinterpret_cast(alignedPtr); } - HashStringAllocator* allocator_; + HashStringAllocator* const allocator_; const bool poolAligned_; }; diff --git a/velox/common/memory/MallocAllocator.cpp b/velox/common/memory/MallocAllocator.cpp index ff44791763ac..775401ebe99a 100644 --- a/velox/common/memory/MallocAllocator.cpp +++ b/velox/common/memory/MallocAllocator.cpp @@ -15,6 +15,7 @@ */ #include "velox/common/memory/MallocAllocator.h" +#include #include "velox/common/memory/Memory.h" #include @@ -33,7 +34,7 @@ MallocAllocator::MallocAllocator(size_t capacity, uint32_t reservationByteLimit) decrementUsageWithReservationFunc(counter, decrement, lock); return true; }), - reservations_(std::thread::hardware_concurrency()) {} + reservations_(folly::hardware_concurrency()) {} MallocAllocator::~MallocAllocator() { // TODO: Remove the check when memory leak issue is resolved. @@ -58,9 +59,10 @@ bool MallocAllocator::allocateNonContiguousWithoutRetry( !incrementUsage(totalBytes)) { const auto errorMsg = fmt::format( "Exceeded memory allocator limit when allocating {} new pages" - ", the memory allocator capacity is {}", + ", the memory allocator capacity is {}, the allocated bytes is {}", sizeMix.totalPages, - succinctBytes(capacity_)); + succinctBytes(capacity_), + succinctBytes(allocatedBytes_)); VELOX_MEM_LOG_EVERY_MS(WARNING, 1000) << errorMsg; setAllocatorFailureMessage(errorMsg); return false; @@ -150,6 +152,7 @@ bool MallocAllocator::allocateContiguousImpl( } numMapped_.fetch_sub(numContiguousCollateralPages); numAllocated_.fetch_sub(numContiguousCollateralPages); + numExternalMapped_.fetch_sub(numContiguousCollateralPages); decrementUsage(AllocationTraits::pageBytes(numContiguousCollateralPages)); allocation.clear(); } @@ -171,6 +174,7 @@ bool MallocAllocator::allocateContiguousImpl( } numAllocated_.fetch_add(numPages); numMapped_.fetch_add(numPages); + numExternalMapped_.fetch_add(numPages); void* data = ::mmap( nullptr, AllocationTraits::pageBytes(maxPages), @@ -228,6 +232,7 @@ void MallocAllocator::freeContiguousImpl(ContiguousAllocation& allocation) { } numMapped_.fetch_sub(numPages); numAllocated_.fetch_sub(numPages); + numExternalMapped_.fetch_sub(numPages); decrementUsage(bytes); allocation.clear(); } @@ -248,6 +253,7 @@ bool MallocAllocator::growContiguousWithoutRetry( } numAllocated_ += increment; numMapped_ += increment; + numExternalMapped_ += increment; allocation.set( allocation.data(), allocation.size() + AllocationTraits::kPageSize * increment, diff --git a/velox/common/memory/MallocAllocator.h b/velox/common/memory/MallocAllocator.h index ade9536f0fb6..5e3e7508fdc4 100644 --- a/velox/common/memory/MallocAllocator.h +++ b/velox/common/memory/MallocAllocator.h @@ -77,6 +77,10 @@ class MallocAllocator : public MemoryAllocator { return numMapped_; } + MachinePageCount numExternalMapped() const override { + return numExternalMapped_; + } + bool checkConsistency() const override; std::string toString() const override; @@ -202,11 +206,12 @@ class MallocAllocator : public MemoryAllocator { if (originalBytes - bytes < 0) { // In case of inconsistency while freeing memory, do not revert in this // case because free is guaranteed to happen. - VELOX_MEM_ALLOC_ERROR(fmt::format( - "Trying to free {} bytes, which is larger than current allocated " - "bytes {}", - bytes, - originalBytes)) + VELOX_MEM_ALLOC_ERROR( + fmt::format( + "Trying to free {} bytes, which is larger than current allocated " + "bytes {}", + bytes, + originalBytes)) } } diff --git a/velox/common/memory/Memory.cpp b/velox/common/memory/Memory.cpp index 3b9f3cbf716a..01de413c830d 100644 --- a/velox/common/memory/Memory.cpp +++ b/velox/common/memory/Memory.cpp @@ -74,36 +74,6 @@ std::unique_ptr createArbitrator( .extraConfigs = options.extraArbitratorConfigs}); } -std::shared_ptr createAllocator( - const MemoryManagerOptions& options) { - if (options.useMmapAllocator) { - MmapAllocator::Options mmapOptions; - mmapOptions.capacity = options.allocatorCapacity; - mmapOptions.largestSizeClass = options.largestSizeClassPages; - mmapOptions.useMmapArena = options.useMmapArena; - mmapOptions.mmapArenaCapacityRatio = options.mmapArenaCapacityRatio; - return std::make_shared(mmapOptions); - } else { - return std::make_shared( - options.allocatorCapacity, - options.allocationSizeThresholdWithReservation); - } -} - -std::unique_ptr createArbitrator( - const MemoryManagerOptions& options) { - // TODO: consider to reserve a small amount of memory to compensate for the - // non-reclaimable cache memory which are pinned by query accesses if - // enabled. - - return MemoryArbitrator::create( - {.kind = options.arbitratorKind, - .capacity = - std::min(options.arbitratorCapacity, options.allocatorCapacity), - .arbitrationStateCheckCb = options.arbitrationStateCheckCb, - .extraConfigs = options.extraArbitratorConfigs}); -} - std::vector> createSharedLeafMemoryPools( MemoryPool& sysPool) { VELOX_CHECK_EQ(sysPool.name(), kSysRootName); @@ -117,6 +87,50 @@ std::vector> createSharedLeafMemoryPools( } return leafPools; } + +// Used by sys root memory pool for use case that expect a memory reclaimer to +// set like QueryCtx. +class SysMemoryReclaimer : public memory::MemoryReclaimer { + public: + static std::unique_ptr create() { + return std::unique_ptr(new SysMemoryReclaimer()); + } + + uint64_t reclaim( + memory::MemoryPool* pool, + uint64_t targetBytes, + uint64_t maxWaitMs, + memory::MemoryReclaimer::Stats& stats) override { + return 0; + } + + void enterArbitration() override {} + + void leaveArbitration() noexcept override {} + + int32_t priority() const override { + return 0; + } + + bool reclaimableBytes(const MemoryPool& pool, uint64_t& reclaimableBytes) + const override { + return false; + } + + /// Invoked by the memory arbitrator to abort memory 'pool' and the associated + /// query execution when encounters non-recoverable memory reclaim error or + /// fails to reclaim enough free capacity. The abort is a synchronous + /// operation and we expect most of used memory to be freed after the abort + /// completes. 'error' should be passed in as the direct cause of the + /// abortion. It will be propagated all the way to task level for accurate + /// error exposure. + void abort(MemoryPool* pool, const std::exception_ptr& error) override { + VELOX_UNSUPPORTED("SysMemoryReclaimer::abort is not supported"); + } + + private: + SysMemoryReclaimer() : MemoryReclaimer{0} {}; +}; } // namespace MemoryManager::MemoryManager(const MemoryManager::Options& options) @@ -147,50 +161,7 @@ MemoryManager::MemoryManager(const MemoryManager::Options& options) cachePool_{addLeafPool("__sys_caching__")}, tracePool_{addLeafPool("__sys_tracing__")}, sharedLeafPools_(createSharedLeafMemoryPools(*sysRoot_)) { - VELOX_CHECK_NOT_NULL(allocator_); - VELOX_CHECK_NOT_NULL(arbitrator_); - VELOX_USER_CHECK_GE(capacity(), 0); - VELOX_CHECK_GE(allocator_->capacity(), arbitrator_->capacity()); - MemoryAllocator::alignmentCheck(0, alignment_); - const bool ret = sysRoot_->grow(sysRoot_->maxCapacity(), 0); - VELOX_CHECK( - ret, - "Failed to set max capacity {} for {}", - succinctBytes(sysRoot_->maxCapacity()), - sysRoot_->name()); - VELOX_CHECK_EQ( - sharedLeafPools_.size(), - std::max(1, FLAGS_velox_memory_num_shared_leaf_pools)); -} - -MemoryManager::MemoryManager(const MemoryManagerOptions& options) - : allocator_{createAllocator(options)}, - arbitrator_(createArbitrator(options)), - alignment_(std::max(MemoryAllocator::kMinAlignment, options.alignment)), - checkUsageLeak_(options.checkUsageLeak), - coreOnAllocationFailureEnabled_(options.coreOnAllocationFailureEnabled), - disableMemoryPoolTracking_(options.disableMemoryPoolTracking), - getPreferredSize_(options.getPreferredSize), - poolDestructionCb_([&](MemoryPool* pool) { dropPool(pool); }), - sysRoot_{std::make_shared( - this, - std::string(kSysRootName), - MemoryPool::Kind::kAggregate, - nullptr, - nullptr, - // NOTE: the default root memory pool has no capacity limit, and it is - // used for system usage in production such as disk spilling. - MemoryPool::Options{ - .alignment = alignment_, - .maxCapacity = kMaxMemory, - .trackUsage = options.trackDefaultUsage, - .coreOnAllocationFailureEnabled = - options.coreOnAllocationFailureEnabled, - .getPreferredSize = getPreferredSize_})}, - spillPool_{addLeafPool("__sys_spilling__")}, - cachePool_{addLeafPool("__sys_caching__")}, - tracePool_{addLeafPool("__sys_tracing__")}, - sharedLeafPools_(createSharedLeafMemoryPools(*sysRoot_)) { + sysRoot_->setReclaimer(SysMemoryReclaimer::create()); VELOX_CHECK_NOT_NULL(allocator_); VELOX_CHECK_NOT_NULL(arbitrator_); VELOX_USER_CHECK_GE(capacity(), 0); @@ -256,19 +227,6 @@ void MemoryManager::initialize(const MemoryManager::Options& options) { state.instance.store(instance, std::memory_order_release); } -// static -void MemoryManager::initialize(const MemoryManagerOptions& options) { - auto& state = singletonState(); - std::lock_guard l(state.mutex); - auto* instance = state.instance.load(std::memory_order_acquire); - VELOX_CHECK_NULL( - instance, - "The memory manager has already been set: {}", - instance->toString()); - instance = new MemoryManager(options); - state.instance.store(instance, std::memory_order_release); -} - // static. MemoryManager* MemoryManager::getInstance() { auto* instance = singletonState().instance.load(std::memory_order_acquire); @@ -292,15 +250,6 @@ MemoryManager& MemoryManager::testingSetInstance( return *instance; } -MemoryManager& MemoryManager::testingSetInstance( - const MemoryManagerOptions& options) { - auto& state = singletonState(); - std::lock_guard l(state.mutex); - auto* instance = new MemoryManager(options); - delete state.instance.exchange(instance, std::memory_order_acq_rel); - return *instance; -} - int64_t MemoryManager::capacity() const { return allocator_->capacity(); } @@ -331,8 +280,7 @@ std::shared_ptr MemoryManager::addRootPool( const std::string& name, int64_t maxCapacity, std::unique_ptr reclaimer, - const std::optional& poolDebugOpts, - uint32_t poolPriority) { + const std::optional& poolDebugOpts) { std::string poolName = name; if (poolName.empty()) { static std::atomic poolId{0}; @@ -346,7 +294,6 @@ std::shared_ptr MemoryManager::addRootPool( options.coreOnAllocationFailureEnabled = coreOnAllocationFailureEnabled_; options.getPreferredSize = getPreferredSize_; options.debugOptions = poolDebugOpts; - options.poolPriority = poolPriority; auto pool = createRootPool(poolName, reclaimer, options); if (!disableMemoryPoolTracking_) { @@ -474,10 +421,6 @@ void initializeMemoryManager(const MemoryManager::Options& options) { MemoryManager::initialize(options); } -void initializeMemoryManager(const MemoryManagerOptions& options) { - MemoryManager::initialize(options); -} - MemoryManager* memoryManager() { return MemoryManager::getInstance(); } @@ -497,6 +440,10 @@ MemoryPool& deprecatedSharedLeafPool() { return deprecatedDefaultMemoryManager().deprecatedSharedLeafPool(); } +MemoryPool& deprecatedRootPool() { + return deprecatedDefaultMemoryManager().deprecatedSysRootPool(); +} + memory::MemoryPool* spillMemoryPool() { return memory::MemoryManager::getInstance()->spillPool(); } diff --git a/velox/common/memory/Memory.h b/velox/common/memory/Memory.h index 317c32935927..3e40d2ff9aac 100644 --- a/velox/common/memory/Memory.h +++ b/velox/common/memory/Memory.h @@ -41,7 +41,6 @@ #include "velox/common/memory/MemoryPool.h" DECLARE_bool(velox_memory_leak_check_enabled); -DECLARE_bool(velox_memory_pool_debug_enabled); DECLARE_bool(velox_enable_memory_usage_track_in_default_memory_pool); namespace facebook::velox::memory { @@ -59,118 +58,6 @@ namespace facebook::velox::memory { "{}", \ errorMessage); -/// TODO(wangke): deprecate when MemoryManagerOptions is fully migrated. -struct MemoryManagerOptions { - /// Specifies the default memory allocation alignment. - uint16_t alignment{MemoryAllocator::kMaxAlignment}; - - /// If true, enable memory usage tracking in the default memory pool. - bool trackDefaultUsage{ - FLAGS_velox_enable_memory_usage_track_in_default_memory_pool}; - - /// If true, check the memory pool and usage leaks on destruction. - /// - /// TODO: deprecate this flag after all the existing memory leak use cases - /// have been fixed. - bool checkUsageLeak{FLAGS_velox_memory_leak_check_enabled}; - - /// Terminates the process and generates a core file on an allocation failure - bool coreOnAllocationFailureEnabled{false}; - - /// Disables the memory manager's tracking on memory pools. - bool disableMemoryPoolTracking{false}; - - /// ================== 'MemoryAllocator' settings ================== - - /// Specifies the max memory allocation capacity in bytes enforced by - /// MemoryAllocator, default unlimited. - int64_t allocatorCapacity{kMaxMemory}; - - /// If true, uses MmapAllocator for memory allocation which manages the - /// physical memory allocation on its own through std::mmap techniques. If - /// false, use MallocAllocator which delegates the memory allocation to - /// std::malloc. - bool useMmapAllocator{false}; - - /// Number of pages in the largest size class in MmapAllocator. - int32_t largestSizeClassPages{256}; - - /// If true, allocations larger than largest size class size will be delegated - /// to ManagedMmapArena. Otherwise a system mmap call will be issued for each - /// such allocation. - /// - /// NOTE: this only applies for MmapAllocator. - bool useMmapArena{false}; - - /// Used to determine MmapArena capacity. The ratio represents - /// 'allocatorCapacity' to single MmapArena capacity ratio. - /// - /// NOTE: this only applies for MmapAllocator. - int32_t mmapArenaCapacityRatio{10}; - - /// If not zero, reserve 'smallAllocationReservePct'% of space from - /// 'allocatorCapacity' for ad hoc small allocations. And those allocations - /// are delegated to std::malloc. If 'maxMallocBytes' is 0, this value will be - /// disregarded. - /// - /// NOTE: this only applies for MmapAllocator. - uint32_t smallAllocationReservePct{0}; - - /// The allocation threshold less than which an allocation is delegated to - /// std::malloc(). If it is zero, then we don't delegate any allocation - /// std::malloc, and 'smallAllocationReservePct' will be automatically set to - /// 0 disregarding any passed in value. - /// - /// NOTE: this only applies for MmapAllocator. - int32_t maxMallocBytes{3072}; - - /// The memory allocations with size smaller than this threshold check the - /// capacity with local sharded counter to reduce the lock contention on the - /// global allocation counter. The sharded local counters reserve/release - /// memory capacity from the global counter in batch. With this optimization, - /// we don't have to update the global counter for each individual small - /// memory allocation. If it is zero, then this optimization is disabled. The - /// default is 1MB. - /// - /// NOTE: this only applies for MallocAllocator. - uint32_t allocationSizeThresholdWithReservation{1 << 20}; - - /// ================== 'MemoryArbitrator' settings ================= - - /// Memory capacity available for query/task memory pools. This capacity - /// setting should be equal or smaller than 'allocatorCapacity'. The - /// difference between 'allocatorCapacity' and 'arbitratorCapacity' is - /// reserved for system usage such as cache and spilling. - /// - /// NOTE: - /// - if 'arbitratorCapacity' is greater than 'allocatorCapacity', the - /// behavior will be equivalent to as if they are equal, meaning no - /// reservation capacity for system usage. - int64_t arbitratorCapacity{kMaxMemory}; - - /// The string kind of memory arbitrator used in the memory manager. - /// - /// NOTE: the arbitrator will only be created if its kind is set explicitly. - /// Otherwise MemoryArbitrator::create returns a nullptr. - std::string arbitratorKind{}; - - /// Provided by the query system to validate the state after a memory pool - /// enters arbitration if not null. For instance, Prestissimo provides - /// callback to check if a memory arbitration request is issued from a driver - /// thread, then the driver should be put in suspended state to avoid the - /// potential deadlock when reclaim memory from the task of the request memory - /// pool. - MemoryArbitrationStateCheckCB arbitrationStateCheckCb{nullptr}; - - /// Additional configs that are arbitrator implementation specific. - std::unordered_map extraArbitratorConfigs{}; - - /// Provides the customized get preferred size function for memory pool - /// allocation. It returns the actual allocation size for a given input size. - /// If not set, uses the memory pool's default get preferred size function. - std::function getPreferredSize{nullptr}; -}; - /// 'MemoryManager' is responsible for creating allocator, arbitrator and /// managing the memory pools. class MemoryManager { @@ -291,18 +178,12 @@ class MemoryManager { explicit MemoryManager(const Options& options = Options{}); - /// TODO(wangke): deprecate when MemoryManagerOptions is fully migrated. - explicit MemoryManager(const MemoryManagerOptions& options); - ~MemoryManager(); /// Creates process-wide memory manager using specified options. Throws if /// memory manager has already been created by an easier call. static void initialize(const Options& options); - /// TODO(wangke): deprecate when MemoryManagerOptions is fully migrated. - static void initialize(const MemoryManagerOptions& options); - /// Returns process-wide memory manager. Throws if 'initialize' hasn't been /// called yet. static MemoryManager* getInstance(); @@ -319,9 +200,6 @@ class MemoryManager { /// Used by test to override the process-wide memory manager. static MemoryManager& testingSetInstance(const Options& options); - /// TODO(wangke): deprecate when MemoryManagerOptions is fully migrated. - static MemoryManager& testingSetInstance(const MemoryManagerOptions& options); - /// Returns the memory capacity of this memory manager which puts a hard cap /// on memory usage, and any allocation that exceeds this capacity throws. int64_t capacity() const; @@ -337,8 +215,7 @@ class MemoryManager { int64_t maxCapacity = kMaxMemory, std::unique_ptr reclaimer = nullptr, const std::optional& poolDebugOpts = - std::nullopt, - uint32_t poolPriority = 0); + std::nullopt); /// Creates a leaf memory pool for direct memory allocation use with specified /// 'name'. If 'name' is missing, the memory manager generates a default name @@ -390,8 +267,8 @@ class MemoryManager { std::string toString(bool detail = false) const; /// Returns the memory manger's internal default root memory pool for testing - /// purpose. - MemoryPool& testingDefaultRoot() const { + /// purpose and legacy use cases. + MemoryPool& deprecatedSysRootPool() const { return *sysRoot_; } @@ -458,9 +335,6 @@ class MemoryManager { /// the function throws. void initializeMemoryManager(const MemoryManager::Options& options); -/// TODO(wangke): deprecate when MemoryManagerOptions is fully migrated. -void initializeMemoryManager(const MemoryManagerOptions& options); - /// Returns the process-wide memory manager. /// /// NOTE: user should have already initialized memory manager by calling. @@ -486,6 +360,9 @@ std::shared_ptr deprecatedAddDefaultLeafMemoryPool( /// lifecycle of the allocated memory pools properly. MemoryPool& deprecatedSharedLeafPool(); +/// Returns the sys root memory pool from the default memory manager. +MemoryPool& deprecatedRootPool(); + /// Returns the system-wide memory pool for spilling memory usage. memory::MemoryPool* spillMemoryPool(); diff --git a/velox/common/memory/MemoryAllocator.h b/velox/common/memory/MemoryAllocator.h index 4a7ac706c715..621fe2e2a884 100644 --- a/velox/common/memory/MemoryAllocator.h +++ b/velox/common/memory/MemoryAllocator.h @@ -87,7 +87,7 @@ struct Stats { /// power of 2 >= the allocation size. static constexpr int32_t kNumSizes = 20; Stats() { - for (auto i = 0; i < sizes.size(); ++i) { + for (size_t i = 0; i < sizes.size(); ++i) { sizes[i].size = 1 << i; } } @@ -341,6 +341,8 @@ class MemoryAllocator : public std::enable_shared_from_this { virtual MachinePageCount numMapped() const = 0; + virtual MachinePageCount numExternalMapped() const = 0; + virtual Stats stats() const { return stats_; } @@ -502,6 +504,14 @@ class MemoryAllocator : public std::enable_shared_from_this { // system by 'this' (via madvise calls). std::atomic numMapped_{0}; + // Number of pages allocated and explicitly mmap'd by the + // application via allocateContiguous, outside of + // 'sizeClasses'. These pages are counted in 'numAllocated_' and + // 'numMapped_'. Allocation requests are decided against + // 'numAllocated_' and 'numMapped_'. This counter is informational + // only. + std::atomic numExternalMapped_{0}; + // Indicates if the failure injection is persistent or transient. // // NOTE: this is only used for testing purpose. diff --git a/velox/common/memory/MemoryArbitrator.cpp b/velox/common/memory/MemoryArbitrator.cpp index 1c40c8bf82df..b4b6af0a7ea1 100644 --- a/velox/common/memory/MemoryArbitrator.cpp +++ b/velox/common/memory/MemoryArbitrator.cpp @@ -99,7 +99,11 @@ class NoopArbitrator : public MemoryArbitrator { } void removePool(MemoryPool* pool) override { - VELOX_CHECK_EQ(pool->reservedBytes(), 0); + VELOX_CHECK_EQ( + pool->reservedBytes(), + 0, + "Memory pool has unexpected reserved bytes on removal: {}", + pool->name()); } // Noop arbitrator has no memory capacity limit so no operation needed for @@ -263,9 +267,10 @@ uint64_t MemoryReclaimer::reclaim( nonReclaimableCandidates.push_back(Candidate{std::move(child), 0}); continue; } - candidates.push_back(Candidate{ - std::move(child), - static_cast(reclaimableBytesOpt.value())}); + candidates.push_back( + Candidate{ + std::move(child), + static_cast(reclaimableBytesOpt.value())}); } } } @@ -321,19 +326,6 @@ void MemoryReclaimer::Stats::reset() { reclaimWaitTimeUs = 0; } -bool MemoryReclaimer::Stats::operator==( - const MemoryReclaimer::Stats& other) const { - return numNonReclaimableAttempts == other.numNonReclaimableAttempts && - reclaimExecTimeUs == other.reclaimExecTimeUs && - reclaimedBytes == other.reclaimedBytes && - reclaimWaitTimeUs == other.reclaimWaitTimeUs; -} - -bool MemoryReclaimer::Stats::operator!=( - const MemoryReclaimer::Stats& other) const { - return !(*this == other); -} - MemoryReclaimer::Stats& MemoryReclaimer::Stats::operator+=( const MemoryReclaimer::Stats& other) { numNonReclaimableAttempts += other.numNonReclaimableAttempts; @@ -425,11 +417,8 @@ bool MemoryArbitrator::Stats::operator==(const Stats& other) const { other.numNonReclaimableAttempts); } -bool MemoryArbitrator::Stats::operator!=(const Stats& other) const { - return !(*this == other); -} - -bool MemoryArbitrator::Stats::operator<(const Stats& other) const { +std::strong_ordering MemoryArbitrator::Stats::operator<=>( + const Stats& other) const { uint32_t gtCount{0}; uint32_t ltCount{0}; #define UPDATE_COUNTER(counter) \ @@ -454,19 +443,9 @@ bool MemoryArbitrator::Stats::operator<(const Stats& other) const { "gtCount {} ltCount {}", gtCount, ltCount); - return ltCount > 0; -} - -bool MemoryArbitrator::Stats::operator>(const Stats& other) const { - return !(*this < other) && (*this != other); -} - -bool MemoryArbitrator::Stats::operator>=(const Stats& other) const { - return !(*this < other); -} - -bool MemoryArbitrator::Stats::operator<=(const Stats& other) const { - return !(*this > other); + return ltCount > 0 ? std::strong_ordering::less + : gtCount > 0 ? std::strong_ordering::greater + : std::strong_ordering::equal; } MemoryArbitrationContext::MemoryArbitrationContext(const MemoryPool* requestor) diff --git a/velox/common/memory/MemoryArbitrator.h b/velox/common/memory/MemoryArbitrator.h index b5419050f9a6..9cb833e7c85c 100644 --- a/velox/common/memory/MemoryArbitrator.h +++ b/velox/common/memory/MemoryArbitrator.h @@ -241,11 +241,7 @@ class MemoryArbitrator { Stats operator-(const Stats& other) const; bool operator==(const Stats& other) const; - bool operator!=(const Stats& other) const; - bool operator<(const Stats& other) const; - bool operator>(const Stats& other) const; - bool operator>=(const Stats& other) const; - bool operator<=(const Stats& other) const; + std::strong_ordering operator<=>(const Stats& other) const; bool empty() const { return numRequests == 0; @@ -321,8 +317,7 @@ class MemoryReclaimer { void reset(); - bool operator==(const Stats& other) const; - bool operator!=(const Stats& other) const; + bool operator==(const Stats& other) const = default; Stats& operator+=(const Stats& other); }; @@ -368,7 +363,7 @@ class MemoryReclaimer { /// rec3 -> rec6 -> rec7 virtual int32_t priority() const { return priority_; - }; + } /// Invoked by the memory arbitrator to get the amount of memory bytes that /// can be reclaimed from 'pool'. The function returns true if 'pool' is @@ -404,7 +399,7 @@ class MemoryReclaimer { virtual void abort(MemoryPool* pool, const std::exception_ptr& error); protected: - explicit MemoryReclaimer(int32_t priority) : priority_(priority){}; + explicit MemoryReclaimer(int32_t priority) : priority_(priority) {} private: const int32_t priority_; diff --git a/velox/common/memory/MemoryPool.cpp b/velox/common/memory/MemoryPool.cpp index f1c01835bebb..66ba6804cc8e 100644 --- a/velox/common/memory/MemoryPool.cpp +++ b/velox/common/memory/MemoryPool.cpp @@ -17,8 +17,8 @@ #include "velox/common/memory/MemoryPool.h" #include -#include +#include "velox/common/Casts.h" #include "velox/common/base/Counters.h" #include "velox/common/base/StatsReporter.h" #include "velox/common/base/SuccinctPrinter.h" @@ -39,14 +39,14 @@ using facebook::velox::common::testutil::TestValue; namespace facebook::velox::memory { namespace { // Check if memory operation is allowed and increment the named stats. -#define CHECK_AND_INC_MEM_OP_STATS(stats) \ +#define CHECK_AND_INC_MEM_OP_STATS(pool, stats) \ do { \ - if (FOLLY_UNLIKELY(kind_ != Kind::kLeaf)) { \ + if (FOLLY_UNLIKELY(pool->kind_ != Kind::kLeaf)) { \ VELOX_FAIL( \ "Memory operation is only allowed on leaf memory pool: {}", \ - toString()); \ + pool->toString()); \ } \ - ++num##stats##_; \ + ++pool->num##stats##_; \ } while (0) // Check if memory operation is allowed and increment the named stats. @@ -154,9 +154,9 @@ std::string capacityToString(int64_t capacity) { return capacity == kMaxMemory ? "UNLIMITED" : succinctBytes(capacity); } -#define DEBUG_RECORD_ALLOC(...) \ - if (FOLLY_UNLIKELY(debugEnabled())) { \ - recordAllocDbg(__VA_ARGS__); \ +#define DEBUG_RECORD_ALLOC(pool, ...) \ + if (FOLLY_UNLIKELY(pool->debugEnabled())) { \ + pool->recordAllocDbg(__VA_ARGS__); \ } #define DEBUG_RECORD_FREE(...) \ if (FOLLY_UNLIKELY(debugEnabled())) { \ @@ -227,7 +227,6 @@ MemoryPool::MemoryPool( trackUsage_(options.trackUsage), threadSafe_(options.threadSafe), debugOptions_(options.debugOptions), - poolPriority_(options.poolPriority), coreOnAllocationFailureEnabled_(options.coreOnAllocationFailureEnabled), getPreferredSize_( options.getPreferredSize == nullptr @@ -385,7 +384,7 @@ void MemoryPool::dropChild(const MemoryPool* child) { 1, "Child memory pool {} doesn't exist in {}", child->name(), - toString()); + name()); } bool MemoryPool::aborted() const { @@ -523,60 +522,63 @@ void* MemoryPoolImpl::allocate( } } - CHECK_AND_INC_MEM_OP_STATS(Allocs); + CHECK_AND_INC_MEM_OP_STATS(this, Allocs); const auto alignedSize = sizeAlign(size); reserve(alignedSize); void* buffer = allocator_->allocateBytes(alignedSize, alignment_); if (FOLLY_UNLIKELY(buffer == nullptr)) { release(alignedSize); - handleAllocationFailure(fmt::format( - "{} failed with {} from {} {}", - __FUNCTION__, - succinctBytes(size), - toString(), - allocator_->getAndClearFailureMessage())); - } - DEBUG_RECORD_ALLOC(buffer, size); + handleAllocationFailure( + fmt::format( + "{} failed with {} from {} {}", + __FUNCTION__, + succinctBytes(size), + toString(), + allocator_->getAndClearFailureMessage())); + } + DEBUG_RECORD_ALLOC(this, buffer, size); return buffer; } void* MemoryPoolImpl::allocateZeroFilled(int64_t numEntries, int64_t sizeEach) { - CHECK_AND_INC_MEM_OP_STATS(Allocs); + CHECK_AND_INC_MEM_OP_STATS(this, Allocs); const auto size = sizeEach * numEntries; const auto alignedSize = sizeAlign(size); reserve(alignedSize); void* buffer = allocator_->allocateZeroFilled(alignedSize); if (FOLLY_UNLIKELY(buffer == nullptr)) { release(alignedSize); - handleAllocationFailure(fmt::format( - "{} failed with {} entries and {} each from {} {}", - __FUNCTION__, - numEntries, - succinctBytes(sizeEach), - toString(), - allocator_->getAndClearFailureMessage())); - } - DEBUG_RECORD_ALLOC(buffer, size); + handleAllocationFailure( + fmt::format( + "{} failed with {} entries and {} each from {} {}", + __FUNCTION__, + numEntries, + succinctBytes(sizeEach), + toString(), + allocator_->getAndClearFailureMessage())); + } + DEBUG_RECORD_ALLOC(this, buffer, size); return buffer; } void* MemoryPoolImpl::reallocate(void* p, int64_t size, int64_t newSize) { - CHECK_AND_INC_MEM_OP_STATS(Allocs); + CHECK_AND_INC_MEM_OP_STATS(this, Allocs); const auto alignedNewSize = sizeAlign(newSize); reserve(alignedNewSize); void* newP = allocator_->allocateBytes(alignedNewSize, alignment_); if (FOLLY_UNLIKELY(newP == nullptr)) { release(alignedNewSize); - handleAllocationFailure(fmt::format( - "{} failed with new {} and old {} from {} {}", - __FUNCTION__, - succinctBytes(newSize), - succinctBytes(size), - toString(), - allocator_->getAndClearFailureMessage())); - } - DEBUG_RECORD_ALLOC(newP, newSize); + handleAllocationFailure( + fmt::format( + "{} failed with new {} and old {} from {} {}", + __FUNCTION__, + succinctBytes(newSize), + succinctBytes(size), + toString(), + allocator_->getAndClearFailureMessage())); + } + DEBUG_RECORD_ALLOC(this, newP, newSize); if (p != nullptr) { ::memcpy(newP, p, std::min(size, newSize)); free(p, size); @@ -585,18 +587,40 @@ void* MemoryPoolImpl::reallocate(void* p, int64_t size, int64_t newSize) { } void MemoryPoolImpl::free(void* p, int64_t size) { - CHECK_AND_INC_MEM_OP_STATS(Frees); + CHECK_AND_INC_MEM_OP_STATS(this, Frees); const auto alignedSize = sizeAlign(size); DEBUG_RECORD_FREE(p, size); allocator_->freeBytes(p, alignedSize); release(alignedSize); } +bool MemoryPoolImpl::transferTo(MemoryPool* dest, void* buffer, uint64_t size) { + if (!isLeaf() || !dest->isLeaf()) { + return false; + } + VELOX_CHECK_NOT_NULL(dest); + auto* destImpl = checkedPointerCast(dest); + if (allocator_ != destImpl->allocator_) { + return false; + } + + CHECK_AND_INC_MEM_OP_STATS(destImpl, Allocs); + const auto alignedSize = sizeAlign(size); + destImpl->reserve(alignedSize); + DEBUG_RECORD_ALLOC(destImpl, buffer, size); + + CHECK_AND_INC_MEM_OP_STATS(this, Frees); + DEBUG_RECORD_FREE(buffer, size); + release(alignedSize); + + return true; +} + void MemoryPoolImpl::allocateNonContiguous( MachinePageCount numPages, Allocation& out, MachinePageCount minSizeClass) { - CHECK_AND_INC_MEM_OP_STATS(Allocs); + CHECK_AND_INC_MEM_OP_STATS(this, Allocs); if (!out.empty()) { INC_MEM_OP_STATS(Frees); } @@ -617,21 +641,22 @@ void MemoryPoolImpl::allocateNonContiguous( }, minSizeClass)) { VELOX_CHECK(out.empty()); - handleAllocationFailure(fmt::format( - "{} failed with {} pages from {} {}", - __FUNCTION__, - numPages, - toString(), - allocator_->getAndClearFailureMessage())); - } - DEBUG_RECORD_ALLOC(out); + handleAllocationFailure( + fmt::format( + "{} failed with {} pages from {} {}", + __FUNCTION__, + numPages, + toString(), + allocator_->getAndClearFailureMessage())); + } + DEBUG_RECORD_ALLOC(this, out); VELOX_CHECK(!out.empty()); VELOX_CHECK_NULL(out.pool()); out.setPool(this); } void MemoryPoolImpl::freeNonContiguous(Allocation& allocation) { - CHECK_AND_INC_MEM_OP_STATS(Frees); + CHECK_AND_INC_MEM_OP_STATS(this, Frees); DEBUG_RECORD_FREE(allocation); const int64_t freedBytes = allocator_->freeNonContiguous(allocation); VELOX_CHECK(allocation.empty()); @@ -650,7 +675,7 @@ void MemoryPoolImpl::allocateContiguous( MachinePageCount numPages, ContiguousAllocation& out, MachinePageCount maxPages) { - CHECK_AND_INC_MEM_OP_STATS(Allocs); + CHECK_AND_INC_MEM_OP_STATS(this, Allocs); if (!out.empty()) { INC_MEM_OP_STATS(Frees); } @@ -669,21 +694,22 @@ void MemoryPoolImpl::allocateContiguous( }, maxPages)) { VELOX_CHECK(out.empty()); - handleAllocationFailure(fmt::format( - "{} failed with {} pages from {} {}", - __FUNCTION__, - numPages, - toString(), - allocator_->getAndClearFailureMessage())); - } - DEBUG_RECORD_ALLOC(out); + handleAllocationFailure( + fmt::format( + "{} failed with {} pages from {} {}", + __FUNCTION__, + numPages, + toString(), + allocator_->getAndClearFailureMessage())); + } + DEBUG_RECORD_ALLOC(this, out); VELOX_CHECK(!out.empty()); VELOX_CHECK_NULL(out.pool()); out.setPool(this); } void MemoryPoolImpl::freeContiguous(ContiguousAllocation& allocation) { - CHECK_AND_INC_MEM_OP_STATS(Frees); + CHECK_AND_INC_MEM_OP_STATS(this, Frees); const int64_t bytesToFree = allocation.size(); DEBUG_RECORD_FREE(allocation); allocator_->freeContiguous(allocation); @@ -702,12 +728,13 @@ void MemoryPoolImpl::growContiguous( release(allocBytes); } })) { - handleAllocationFailure(fmt::format( - "{} failed with {} pages from {} {}", - __FUNCTION__, - increment, - toString(), - allocator_->getAndClearFailureMessage())); + handleAllocationFailure( + fmt::format( + "{} failed with {} pages from {} {}", + __FUNCTION__, + increment, + toString(), + allocator_->getAndClearFailureMessage())); } if (FOLLY_UNLIKELY(debugEnabled())) { recordGrowDbg(allocation.data(), allocation.size()); @@ -777,7 +804,7 @@ std::shared_ptr MemoryPoolImpl::genChild( } bool MemoryPoolImpl::maybeReserve(uint64_t increment) { - CHECK_AND_INC_MEM_OP_STATS(Reserves); + CHECK_AND_INC_MEM_OP_STATS(this, Reserves); TestValue::adjust( "facebook::velox::common::memory::MemoryPoolImpl::maybeReserve", this); // TODO: make this a configurable memory pool option. @@ -806,9 +833,6 @@ void MemoryPoolImpl::reserve(uint64_t size, bool reserveOnly) { reserveNonThreadSafe(size, reserveOnly); } } - if (reserveOnly) { - return; - } } void MemoryPoolImpl::reserveThreadSafe(uint64_t size, bool reserveOnly) { @@ -886,9 +910,16 @@ void MemoryPoolImpl::growCapacity(MemoryPool* requestor, uint64_t size) { VELOX_CHECK(requestor->isLeaf()); ++numCapacityGrowths_; - { + try { MemoryPoolArbitrationSection arbitrationSection(requestor); arbitrator_->growCapacity(this, size); + } catch (const VeloxRuntimeError& veloxError) { + if (FOLLY_UNLIKELY( + debugEnabled() && + veloxError.errorCode() == error_code::kMemCapExceeded)) { + std::rethrow_exception(wrapExceptionDbg(veloxError)); + } + throw; } // The memory pool might have been aborted during the time it leaves the // arbitration no matter the arbitration succeed or not. @@ -928,7 +959,7 @@ void MemoryPoolImpl::incrementReservationLocked(uint64_t bytes) { } void MemoryPoolImpl::release() { - CHECK_AND_INC_MEM_OP_STATS(Releases); + CHECK_AND_INC_MEM_OP_STATS(this, Releases); release(0, true); } @@ -985,6 +1016,21 @@ void MemoryPoolImpl::decrementReservation(uint64_t size) noexcept { sanityCheckLocked(); } +std::string MemoryPoolImpl::toString(bool detail) const { + std::string result; + { + std::lock_guard l(mutex_); + result = toStringLocked(); + } + if (detail) { + result += "\n" + treeMemoryUsage(); + } + if (FOLLY_UNLIKELY(debugEnabled())) { + result += "\n" + dumpRecordsDbg(); + } + return result; +} + std::string MemoryPoolImpl::treeMemoryUsage(bool skipEmptyPool) const { if (parent_ != nullptr) { return parent_->treeMemoryUsage(skipEmptyPool); @@ -1204,10 +1250,31 @@ void MemoryPoolImpl::recordAllocDbg(const void* addr, uint64_t size) { if (!needRecordDbg(true)) { return; } - std::lock_guard l(debugAllocMutex_); - debugAllocRecords_.emplace( - reinterpret_cast(addr), - AllocationRecord{size, process::StackTrace()}); + AllocationRecord allocationRecord{size, process::StackTrace()}; + std::lock_guard debugAllocLock(debugAllocMutex_); + auto [it, inserted] = debugAllocRecords_.try_emplace( + reinterpret_cast(addr), std::move(allocationRecord)); + VELOX_CHECK(inserted); + if (debugOptions_->debugPoolWarnThresholdBytes == 0 || + debugWarnThresholdExceeded_) { + return; + } + const auto usedBytes = reservedBytes(); + if (usedBytes >= debugOptions_->debugPoolWarnThresholdBytes) { + debugWarnThresholdExceeded_ = true; + VELOX_MEM_LOG(WARNING) << fmt::format( + "[MemoryPool] Memory pool '{}' exceeded warning threshold of {} with allocation of {}, resulting in total size of {}.\n" + "======== Allocation Stack ========\n" + "{}\n" + "======= Current Allocations ======\n" + "{}", + name_, + succinctBytes(debugOptions_->debugPoolWarnThresholdBytes), + succinctBytes(size), + succinctBytes(usedBytes), + it->second.callStack.toString(), + dumpRecordsDbgLocked()); + } } void MemoryPoolImpl::recordAllocDbg(const Allocation& allocation) { @@ -1240,16 +1307,17 @@ void MemoryPoolImpl::recordFreeDbg(const void* addr, uint64_t size) { const auto allocRecord = allocResult->second; if (allocRecord.size != size) { const auto freeStackTrace = process::StackTrace().toString(); - VELOX_FAIL(fmt::format( - "[MemoryPool] Trying to free {} bytes on an allocation of {} bytes.\n" - "======== Allocation Stack ========\n" - "{}\n" - "============ Free Stack ==========\n" - "{}\n", - size, - allocRecord.size, - allocRecord.callStack.toString(), - freeStackTrace)); + VELOX_FAIL( + fmt::format( + "[MemoryPool] Trying to free {} bytes on an allocation of {} bytes.\n" + "======== Allocation Stack ========\n" + "{}\n" + "============ Free Stack ==========\n" + "{}\n", + size, + allocRecord.size, + allocRecord.callStack.toString(), + freeStackTrace)); } debugAllocRecords_.erase(addrUint64); } @@ -1289,10 +1357,79 @@ void MemoryPoolImpl::leakCheckDbg() { if (debugAllocRecords_.empty()) { return; } - std::stringbuf buf; - std::ostream oss(&buf); - oss << "[MemoryPool] : " << name_ << " - Detected total of " - << debugAllocRecords_.size() << " leaked allocations:\n"; + VELOX_FAIL( + fmt::format( + "[MemoryPool] Leak check failed for '{}' pool - {}", + name_, + dumpRecordsDbg())); +} + +void MemoryPoolImpl::treeAllocationRecordsDbg( + std::vector& poolDumps) const { + VELOX_CHECK(debugEnabled()); + { + std::lock_guard debugAllocLock(debugAllocMutex_); + if (!debugAllocRecords_.empty()) { + MemoryPoolDump dump{ + .dumpedRecords = fmt::format( + "Memory pool '{}' - {}", name(), dumpRecordsDbgLocked()), + .bytes = reservedBytes(), + }; + poolDumps.emplace_back(std::move(dump)); + } + } + if (isLeaf()) { + return; + } + visitChildren([&poolDumps](MemoryPool* pool) { + toImpl(pool)->treeAllocationRecordsDbg(poolDumps); + return true; + }); +} + +std::exception_ptr MemoryPoolImpl::wrapExceptionDbg( + const VeloxRuntimeError& veloxError) const { + VELOX_CHECK(debugEnabled()); + VELOX_CHECK(isRoot()); + std::vector poolAllocationsSorted; + treeAllocationRecordsDbg(poolAllocationsSorted); + std::stringstream oss; + if (poolAllocationsSorted.empty()) { + oss << "No allocation records found."; + } else { + std::sort( + poolAllocationsSorted.begin(), + poolAllocationsSorted.end(), + [](const auto& a, const auto& b) { return a.bytes > b.bytes; }); + for (const auto& record : poolAllocationsSorted) { + oss << record.dumpedRecords << "\n\n"; + } + } + const auto wrappedMessage = fmt::format( + "{}\n\n" + "======= Current Allocations ======\n" + "{}", + veloxError.message(), + oss.str()); + return std::make_exception_ptr(VeloxRuntimeError( + veloxError.file(), + veloxError.line(), + veloxError.function(), + veloxError.failingExpression(), + wrappedMessage, + veloxError.errorSource(), + veloxError.errorCode(), + veloxError.isRetriable(), + veloxError.exceptionName())); +} + +std::string MemoryPoolImpl::dumpRecordsDbgLocked() const { + VELOX_CHECK(debugEnabled()); + std::stringstream oss; + oss << fmt::format( + "Found {} allocations with {} total size:\n", + debugAllocRecords_.size(), + succinctBytes(reservedBytes())); struct AllocationStats { uint64_t size{0}; uint64_t numAllocations{0}; @@ -1317,12 +1454,13 @@ void MemoryPoolImpl::leakCheckDbg() { return a.second.size > b.second.size; }); for (const auto& pair : sortedRecords) { - oss << "======== Leaked memory from " << pair.second.numAllocations - << " total allocations of " << succinctBytes(pair.second.size) - << " total size ========\n" - << pair.first << "\n"; + oss << fmt::format( + "======== {} allocations of {} total size ========\n{}\n", + pair.second.numAllocations, + succinctBytes(pair.second.size), + pair.first); } - VELOX_FAIL(buf.str()); + return oss.str(); } void MemoryPoolImpl::handleAllocationFailure( diff --git a/velox/common/memory/MemoryPool.h b/velox/common/memory/MemoryPool.h index cbb8a86cd23f..38b9febcccf0 100644 --- a/velox/common/memory/MemoryPool.h +++ b/velox/common/memory/MemoryPool.h @@ -29,7 +29,6 @@ #include "velox/common/memory/MemoryArbitrator.h" DECLARE_bool(velox_memory_leak_check_enabled); -DECLARE_bool(velox_memory_pool_debug_enabled); DECLARE_bool(velox_memory_pool_capacity_transfer_across_tasks); namespace facebook::velox::exec { @@ -42,6 +41,9 @@ class MemoryManager; constexpr int64_t kMaxMemory = std::numeric_limits::max(); +template +class StlAllocator; + /// This class provides the memory allocation interfaces for a query execution. /// Each query execution entity creates a dedicated memory pool object. The /// memory pool objects from a query are organized as a tree with four levels @@ -91,6 +93,9 @@ constexpr int64_t kMaxMemory = std::numeric_limits::max(); /// also provides memory usage accounting. class MemoryPool : public std::enable_shared_from_this { public: + template + using TStlAllocator = StlAllocator; + /// Defines the kinds of a memory pool. enum class Kind { /// The leaf memory pool is used for memory allocation. User can allocate @@ -112,6 +117,12 @@ class MemoryPool : public std::enable_shared_from_this { /// memory pools whose name matches the specified regular expression. Empty /// string means no match for all. std::string debugPoolNameRegex; + + /// Warning threshold in bytes for debug memory pools. When set to a + /// non-zero value, a warning will be logged once per memory pool when + /// allocations cause the pool to exceed this threshold. A value of + /// 0 means no warning threshold is enforced. + uint64_t debugPoolWarnThresholdBytes{0}; }; struct Options { @@ -152,11 +163,6 @@ class MemoryPool : public std::enable_shared_from_this { /// If non-empty, enables debug mode for the created memory pool. std::optional debugOptions{std::nullopt}; - - /// Sets the priority of the memory pool. The priority is used for - /// determining which pools to abort when the system is out of memory. - /// higher poolPriority value respresents higher priority and vice-versa. - uint32_t poolPriority{0}; }; /// Constructs a named memory pool with specified 'name', 'parent' and 'kind'. @@ -245,6 +251,13 @@ class MemoryPool : public std::enable_shared_from_this { /// Frees an allocated buffer. virtual void free(void* p, int64_t size) = 0; + /// Transfer the ownership of memory at 'buffer' for 'size' bytes to the + /// memory pool 'dest'. Returns true if the transfer succeeds. + virtual bool + transferTo(MemoryPool* /*dest*/, void* /*buffer*/, uint64_t /*size*/) { + return false; + } + /// Allocates one or more runs that add up to at least 'numPages', with the /// smallest run being at least 'minSizeClass' pages. 'minSizeClass' must be /// <= the size of the largest size class. The new memory is returned in 'out' @@ -302,11 +315,6 @@ class MemoryPool : public std::enable_shared_from_this { return alignment_; } - /// Returns the priority of this pool. - uint32_t poolPriority() const { - return poolPriority_; - } - /// Resource governing methods used to track and limit the memory usage /// through this memory pool object. @@ -556,7 +564,6 @@ class MemoryPool : public std::enable_shared_from_this { const bool trackUsage_; const bool threadSafe_; const std::optional debugOptions_; - const uint32_t poolPriority_; const bool coreOnAllocationFailureEnabled_; std::function getPreferredSize_; @@ -615,6 +622,8 @@ class MemoryPoolImpl : public MemoryPool { void free(void* p, int64_t size) override; + bool transferTo(MemoryPool* dest, void* buffer, uint64_t size) override; + void allocateNonContiguous( MachinePageCount numPages, Allocation& out, @@ -678,17 +687,7 @@ class MemoryPoolImpl : public MemoryPool { void setDestructionCallback(const DestructionCallback& callback); - std::string toString(bool detail = false) const override { - std::string result; - { - std::lock_guard l(mutex_); - result = toStringLocked(); - } - if (detail) { - result += "\n" + treeMemoryUsage(); - } - return result; - } + std::string toString(bool detail = false) const override; /// Detailed debug pool state printout by traversing the pool structure from /// the root memory pool. @@ -1008,6 +1007,33 @@ class MemoryPoolImpl : public MemoryPool { // pool is enabled. void leakCheckDbg(); + // Holds formatted string of dumped allocation records for a leaf memory pool, + // along with the total pool size in bytes. + struct MemoryPoolDump { + std::string dumpedRecords; + int64_t bytes; + }; + + // Recursively collects 'MemoryPoolDump' records for this memory pool and + // all its descendants in the tree. Called during memory capacity-exceeded + // exceptions to extend the error message with additional debug information. + void treeAllocationRecordsDbg(std::vector& poolDumps) const; + + // Wraps the message of a memory capacity exceeded exception with debug + // allocation records from all memory pools in the subtree. This function is + // called from the root memory pool. + std::exception_ptr wrapExceptionDbg( + const VeloxRuntimeError& veloxError) const; + + // Dump the recorded call sites of the memory allocations in + // 'debugAllocRecords_' to the string. + std::string dumpRecordsDbgLocked() const; + + std::string dumpRecordsDbg() const { + std::lock_guard l(debugAllocMutex_); + return dumpRecordsDbgLocked(); + } + void handleAllocationFailure(const std::string& failureMessage); MemoryManager* const manager_; @@ -1071,10 +1097,13 @@ class MemoryPoolImpl : public MemoryPool { std::atomic_uint64_t numCapacityGrowths_{0}; // Mutex for 'debugAllocRecords_'. - std::mutex debugAllocMutex_; + mutable std::mutex debugAllocMutex_; // Map from address to 'AllocationRecord'. std::unordered_map debugAllocRecords_; + + // Flag to track if warning threshold has been exceeded once for this pool. + bool debugWarnThresholdExceeded_{false}; }; /// An Allocator backed by a memory pool for STL containers. @@ -1086,6 +1115,8 @@ class StlAllocator { /* implicit */ StlAllocator(MemoryPool& pool) : pool{pool} {} + explicit StlAllocator(MemoryPool* pool) : pool{*pool} {} + template /* implicit */ StlAllocator(const StlAllocator& a) : pool{a.pool} {} @@ -1104,11 +1135,6 @@ class StlAllocator { } return false; } - - template - bool operator!=(const StlAllocator& rhs) const { - return !(*this == rhs); - } }; } // namespace facebook::velox::memory diff --git a/velox/common/memory/MmapAllocator.cpp b/velox/common/memory/MmapAllocator.cpp index 8ab440e21252..54f80b202829 100644 --- a/velox/common/memory/MmapAllocator.cpp +++ b/velox/common/memory/MmapAllocator.cpp @@ -33,9 +33,11 @@ MmapAllocator::MmapAllocator(const Options& options) maxMallocBytes_ == 0 ? 0 : options.capacity * options.smallAllocationReservePct / 100), - capacity_(bits::roundUp( - AllocationTraits::numPages(options.capacity - mallocReservedBytes_), - 64 * sizeClassSizes_.back())) { + capacity_( + bits::roundUp( + AllocationTraits::numPages( + options.capacity - mallocReservedBytes_), + 64 * sizeClassSizes_.back())) { for (const auto& size : sizeClassSizes_) { sizeClasses_.push_back(std::make_unique(capacity_ / size, size)); } @@ -293,10 +295,11 @@ bool MmapAllocator::allocateContiguousImpl( const std::string errorMsg = fmt::format( "Exceeded memory allocator limit when allocating {} new pages for " "total allocation of {} pages, the memory allocator capacity is" - " {} pages", + " {} pages, the allocated pages is {}", newPages, numPages, - capacity_); + capacity_, + numAllocated_); VELOX_MEM_LOG_EVERY_MS(WARNING, 1000) << errorMsg; setAllocatorFailureMessage(errorMsg); rollbackAllocation(0); @@ -395,10 +398,11 @@ bool MmapAllocator::growContiguousWithoutRetry( const std::string errorMsg = fmt::format( "Exceeded memory allocator limit when allocating {} new pages for " "total allocation of {} pages, the memory allocator capacity is" - " {} pages", + " {} pages, the allocated pages is {}", increment, allocation.numPages(), - capacity_); + capacity_, + numAllocated_); VELOX_MEM_LOG_EVERY_MS(WARNING, 1000) << errorMsg; setAllocatorFailureMessage(errorMsg); numAllocated_ -= increment; diff --git a/velox/common/memory/MmapAllocator.h b/velox/common/memory/MmapAllocator.h index fd17ee8fa765..f771db0071dc 100644 --- a/velox/common/memory/MmapAllocator.h +++ b/velox/common/memory/MmapAllocator.h @@ -145,7 +145,7 @@ class MmapAllocator : public MemoryAllocator { return numMapped_; } - MachinePageCount numExternalMapped() const { + MachinePageCount numExternalMapped() const override { return numExternalMapped_; } @@ -382,14 +382,6 @@ class MmapAllocator : public MemoryAllocator { // Serializes moving capacity between size classes std::mutex sizeClassBalanceMutex_; - // Number of pages allocated and explicitly mmap'd by the - // application via allocateContiguous, outside of - // 'sizeClasses'. These pages are counted in 'numAllocated_' and - // 'numMapped_'. Allocation requests are decided against - // 'numAllocated_' and 'numMapped_'. This counter is informational - // only. - std::atomic numExternalMapped_{0}; - // Allocations smaller than 'maxMallocBytes' will be delegated to // std::malloc(). // diff --git a/velox/common/memory/RawVector.h b/velox/common/memory/RawVector.h index 437ba52120d5..a61703fe8173 100644 --- a/velox/common/memory/RawVector.h +++ b/velox/common/memory/RawVector.h @@ -36,10 +36,18 @@ class raw_vector { static_assert( std::is_trivially_destructible_v && std::is_trivially_copyable_v); - explicit raw_vector(memory::MemoryPool* pool = nullptr) : pool_(pool) {} + explicit raw_vector() {} - explicit raw_vector(int64_t size, memory::MemoryPool* pool = nullptr) - : pool_(pool) { + explicit raw_vector(int64_t size) { + resize(size); + } + + explicit raw_vector(memory::MemoryPool* pool) : pool_(pool) { + VELOX_CHECK_NOT_NULL(pool); + } + + explicit raw_vector(int64_t size, memory::MemoryPool* pool) : pool_(pool) { + VELOX_CHECK_NOT_NULL(pool); resize(size); } @@ -204,9 +212,12 @@ class raw_vector { // Clear the word below the pointer so that we do not get read of // uninitialized when reading a partial word that extends below // the pointer. + // Suppress GCC14 warning. "error: writing 8 bytes into a region of size 0" + VELOX_SUPPRESS_STRINGOP_OVERFLOW_WARNING *reinterpret_cast( reinterpret_cast(getDataFromBuffer(buffer)) - sizeof(int64_t)) = 0; + VELOX_UNSUPPRESS_STRINGOP_OVERFLOW_WARNING return getDataFromBuffer(buffer); } diff --git a/velox/common/memory/Scratch.h b/velox/common/memory/Scratch.h index dcde26e90d91..bed226a0f944 100644 --- a/velox/common/memory/Scratch.h +++ b/velox/common/memory/Scratch.h @@ -79,14 +79,17 @@ class Scratch { // stringop-overflow warning when 'newCapacity' is 0. folly::assume(capacity_ >= 0); if (newCapacity > capacity_) { - Item* newItems = - reinterpret_cast(::malloc(sizeof(Item) * newCapacity)); + auto* newItems = + reinterpret_cast(::malloc(sizeof(Item) * newCapacity)); if (fill_ > 0) { ::memcpy(newItems, items_, fill_ * sizeof(Item)); } - ::memset(newItems + fill_, 0, (newCapacity - fill_) * sizeof(Item)); + ::memset( + newItems + fill_ * sizeof(Item), + 0, + (newCapacity - fill_) * sizeof(Item)); ::free(items_); - items_ = newItems; + items_ = reinterpret_cast(newItems); capacity_ = newCapacity; } fill_ = std::min(fill_, newCapacity); diff --git a/velox/common/memory/SharedArbitrator.cpp b/velox/common/memory/SharedArbitrator.cpp index b694b7aa4a75..4d4c74540b11 100644 --- a/velox/common/memory/SharedArbitrator.cpp +++ b/velox/common/memory/SharedArbitrator.cpp @@ -15,6 +15,7 @@ */ #include "velox/common/memory/SharedArbitrator.h" +#include #include #include #include @@ -58,25 +59,28 @@ namespace { } #define MEM_POOL_CAP_EXCEEDED(errorMessage, requestPool) \ - VELOX_MEM_POOL_CAP_EXCEEDED(fmt::format( \ - "Exceeded memory pool capacity. {}\n{}\n\n{}", \ - errorMessage, \ - this->toString(), \ - requestPool->toString(true))); + VELOX_MEM_POOL_CAP_EXCEEDED( \ + fmt::format( \ + "Exceeded memory pool capacity. {}\n{}\n\n{}", \ + errorMessage, \ + this->toString(), \ + requestPool->toString(true))); #define LOCAL_MEM_ARBITRATION_FAILED(errorMessage, requestPool) \ - VELOX_MEM_ARBITRATION_FAILED(fmt::format( \ - "Local arbitration failure. {}\n{}\n\n{}", \ - errorMessage, \ - this->toString(), \ - requestPool->toString(true))); + VELOX_MEM_ARBITRATION_FAILED( \ + fmt::format( \ + "Local arbitration failure. {}\n{}\n\n{}", \ + errorMessage, \ + this->toString(), \ + requestPool->toString(true))); #define GLOBAL_MEM_ARBITRATION_FAILED(errorMessage, requestPool) \ - VELOX_MEM_ARBITRATION_FAILED(fmt::format( \ - "Global arbitration failure. {}\n{}\n\n{}", \ - errorMessage, \ - this->toString(), \ - requestPool->toString(true))); + VELOX_MEM_ARBITRATION_FAILED( \ + fmt::format( \ + "Global arbitration failure. {}\n{}\n\n{}", \ + errorMessage, \ + this->toString(), \ + requestPool->toString(true))); template T getConfig( @@ -126,10 +130,11 @@ uint64_t SharedArbitrator::ExtraConfig::memoryPoolReservedCapacity( uint64_t SharedArbitrator::ExtraConfig::maxMemoryArbitrationTimeNs( const std::unordered_map& configs) { return std::chrono::duration_cast( - config::toDuration(getConfig( - configs, - kMaxMemoryArbitrationTime, - std::string(kDefaultMaxMemoryArbitrationTime)))) + config::toDuration( + getConfig( + configs, + kMaxMemoryArbitrationTime, + std::string(kDefaultMaxMemoryArbitrationTime)))) .count(); } @@ -167,6 +172,16 @@ double SharedArbitrator::ExtraConfig::memoryPoolMinReclaimPct( configs, kMemoryPoolMinReclaimPct, kDefaultMemoryPoolMinReclaimPct); } +uint64_t SharedArbitrator::ExtraConfig::memoryPoolSpillCapacityLimit( + const std::unordered_map& configs) { + return config::toCapacity( + getConfig( + configs, + kMemoryPoolSpillCapacityLimit, + std::string(kDefaultMemoryPoolSpillCapacityLimit)), + config::CapacityUnit::BYTE); +} + uint64_t SharedArbitrator::ExtraConfig::memoryPoolAbortCapacityLimit( const std::unordered_map& configs) { return config::toCapacity( @@ -253,8 +268,7 @@ SharedArbitrator::SharedArbitrator(const Config& config) ExtraConfig::memoryPoolMinFreeCapacity(config.extraConfigs), ExtraConfig::memoryPoolMinFreeCapacityPct(config.extraConfigs), ExtraConfig::memoryPoolMinReclaimBytes(config.extraConfigs), - ExtraConfig::memoryPoolMinReclaimPct(config.extraConfigs), - ExtraConfig::memoryPoolAbortCapacityLimit(config.extraConfigs)), + ExtraConfig::memoryPoolMinReclaimPct(config.extraConfigs)), memoryReclaimThreadsHwMultiplier_( ExtraConfig::memoryReclaimThreadsHwMultiplier(config.extraConfigs)), globalArbitrationEnabled_( @@ -283,15 +297,16 @@ SharedArbitrator::SharedArbitrator(const Config& config) "memoryReclaimThreadsHwMultiplier_ needs to be positive"); const uint64_t numReclaimThreads = std::max( - 1, - std::thread::hardware_concurrency() * memoryReclaimThreadsHwMultiplier_); + 1, folly::hardware_concurrency() * memoryReclaimThreadsHwMultiplier_); memoryReclaimExecutor_ = std::make_unique( numReclaimThreads, std::make_shared("MemoryReclaim")); VELOX_MEM_LOG(INFO) << "Start memory reclaim executor with " << numReclaimThreads << " threads"; - setupGlobalArbitration(); + setupGlobalArbitration( + ExtraConfig::memoryPoolSpillCapacityLimit(config.extraConfigs), + ExtraConfig::memoryPoolAbortCapacityLimit(config.extraConfigs)); VELOX_MEM_LOG(INFO) << "Shared arbitrator created with " << succinctBytes(capacity_) << " capacity, " @@ -332,20 +347,37 @@ void SharedArbitrator::shutdown() { memoryReclaimExecutor_.reset(); VELOX_MEM_LOG(INFO) << "Memory reclaim executor stopped"; - VELOX_CHECK_EQ( - participants_.size(), 0, "Unexpected alive participants on destruction"); + if (!participants_.empty()) { + std::vector participantNames; + participantNames.reserve(participants_.size()); + for (const auto& partitionEntry : participants_) { + participantNames.push_back(partitionEntry.second->name()); + } + VELOX_FAIL( + "Unexpected alive participants on destruction: {}", + folly::join(",", participantNames)); + } } -void SharedArbitrator::setupGlobalArbitration() { +void SharedArbitrator::setupGlobalArbitration( + uint64_t spillCapacityLimit, + uint64_t abortCapacityLimit) { if (!globalArbitrationEnabled_) { return; } + VELOX_CHECK( + bits::isPowerOfTwo(spillCapacityLimit), + "spillCapacityLimit {} not a power of two", + spillCapacityLimit); + VELOX_CHECK( + bits::isPowerOfTwo(abortCapacityLimit), + "abortCapacityLimit {} not a power of two", + abortCapacityLimit); VELOX_CHECK_NULL(globalArbitrationController_); - const uint64_t minAbortCapacity = 32 << 20; - for (auto abortLimit = participantConfig_.abortCapacityLimit; abortLimit >= - std::max(minAbortCapacity, - folly::nextPowTwo(participantConfig_.minCapacity)); + const auto minAbortCapacity = std::max( + 32 << 20, folly::nextPowTwo(participantConfig_.minCapacity)); + for (auto abortLimit = abortCapacityLimit; abortLimit >= minAbortCapacity; abortLimit /= 2) { globalArbitrationAbortCapacityLimits_.push_back(abortLimit); } @@ -355,6 +387,22 @@ void SharedArbitrator::setupGlobalArbitration() { << folly::join( ",", globalArbitrationAbortCapacityLimits_); + VELOX_CHECK_GE(spillCapacityLimit, participantConfig_.minReclaimBytes); + const auto minSpillCapacity = std::max( + folly::nextPowTwo(participantConfig_.minReclaimBytes), + folly::nextPowTwo(participantConfig_.minCapacity)); + for (auto spillLimit = spillCapacityLimit; spillLimit >= minSpillCapacity; + spillLimit /= 2) { + globalArbitrationSpillCapacityLimits_.push_back(spillLimit); + } + if (globalArbitrationSpillCapacityLimits_.empty()) { + globalArbitrationSpillCapacityLimits_.push_back(minSpillCapacity); + } + + VELOX_MEM_LOG(INFO) << "Global arbitration spill capacity limits: " + << folly::join( + ",", globalArbitrationSpillCapacityLimits_); + globalArbitrationController_ = std::make_unique([&]() { folly::setThreadName("GlobalArbitrationController"); globalArbitrationMain(); @@ -368,6 +416,7 @@ void SharedArbitrator::shutdownGlobalArbitration() { } VELOX_CHECK(!globalArbitrationAbortCapacityLimits_.empty()); + VELOX_CHECK(!globalArbitrationSpillCapacityLimits_.empty()); VELOX_CHECK_NOT_NULL(globalArbitrationController_); VELOX_MEM_LOG(INFO) << "Stopping global arbitration controller"; @@ -494,14 +543,14 @@ void SharedArbitrator::addPool(const std::shared_ptr& pool) { } void SharedArbitrator::removePool(MemoryPool* pool) { - VELOX_CHECK_EQ(pool->reservedBytes(), 0); + VELOX_CHECK_EQ(pool->reservedBytes(), 0, "{}", pool->name()); const uint64_t freedBytes = shrinkPool(pool, 0); - VELOX_CHECK_EQ(pool->capacity(), 0); + VELOX_CHECK_EQ(pool->capacity(), 0, "{}", pool->name()); freeCapacity(freedBytes); std::unique_lock guard{participantLock_}; const auto ret = participants_.erase(pool->name()); - VELOX_CHECK_EQ(ret, 1); + VELOX_CHECK_EQ(ret, 1, "{}", pool->name()); } std::vector SharedArbitrator::getCandidates( @@ -519,6 +568,7 @@ std::vector SharedArbitrator::getCandidates( return candidates; } +// static void SharedArbitrator::sortCandidatesByReclaimableFreeCapacity( std::vector& candidates) { std::sort( @@ -532,8 +582,9 @@ void SharedArbitrator::sortCandidatesByReclaimableFreeCapacity( &candidates); } -void SharedArbitrator::sortCandidatesByReclaimableUsedCapacity( - std::vector& candidates) { +std::vector> +SharedArbitrator::sortAndGroupSpillCandidates( + std::vector&& candidates) { std::sort( candidates.begin(), candidates.end(), @@ -541,9 +592,95 @@ void SharedArbitrator::sortCandidatesByReclaimableUsedCapacity( return lhs.reclaimableUsedCapacity > rhs.reclaimableUsedCapacity; }); + const auto numCandidates = candidates.size(); + std::vector> candidateGroups; + candidateGroups.reserve(globalArbitrationSpillCapacityLimits_.size()); + uint32_t candidateIdx{0}; + for (const auto& capacityLimit : globalArbitrationSpillCapacityLimits_) { + if (candidateIdx >= numCandidates) { + break; + } + candidateGroups.emplace_back(); + candidateGroups.back().reserve(numCandidates - candidateIdx); + for (; candidateIdx < numCandidates; ++candidateIdx) { + if (candidates[candidateIdx].reclaimableUsedCapacity < capacityLimit) { + break; + } + candidateGroups.back().push_back(std::move(candidates[candidateIdx])); + } + if (candidateGroups.back().empty()) { + candidateGroups.pop_back(); + } + } + + // Sort candidates in each group by priority and reclaimable used capacity. + for (auto& candidateGroup : candidateGroups) { + std::sort( + candidateGroup.begin(), + candidateGroup.end(), + [](const ArbitrationCandidate& lhs, const ArbitrationCandidate& rhs) { + const auto* lhsReclaimer = lhs.participant->pool()->reclaimer(); + const auto* rhsReclaimer = rhs.participant->pool()->reclaimer(); + if (FOLLY_UNLIKELY( + lhsReclaimer == nullptr || rhsReclaimer == nullptr)) { + VELOX_FAIL( + "Spill candidates must have memory reclaimer set. Left '{}', right '{}'", + lhs.participant->name(), + rhs.participant->name()); + } + if (lhsReclaimer->priority() == rhsReclaimer->priority()) { + return lhs.reclaimableUsedCapacity > rhs.reclaimableUsedCapacity; + } + return lhsReclaimer->priority() > rhsReclaimer->priority(); + }); + } + TestValue::adjust( - "facebook::velox::memory::SharedArbitrator::sortCandidatesByReclaimableUsedCapacity", - &candidates); + "facebook::velox::memory::SharedArbitrator::sortAndGroupSpillCandidates", + &candidateGroups); + return candidateGroups; +} + +// static +std::vector> +SharedArbitrator::sortAndGroupAbortCandidates( + std::vector&& candidates) { + std::sort( + candidates.begin(), + candidates.end(), + [](const ArbitrationCandidate& lhs, const ArbitrationCandidate& rhs) { + const auto* lhsReclaimer = lhs.participant->pool()->reclaimer(); + const auto* rhsReclaimer = rhs.participant->pool()->reclaimer(); + if (lhsReclaimer == nullptr || rhsReclaimer == nullptr) { + // Participants without reclaimer are treated as low priority, putting + // them in front. + return (lhsReclaimer == nullptr) > (rhsReclaimer == nullptr); + } + return lhsReclaimer->priority() > rhsReclaimer->priority(); + }); + + std::vector> candidateGroups; + std::optional prevPriority; + for (auto i = 0; i < candidates.size(); ++i) { + const auto* curReclaimer = candidates[i].participant->pool()->reclaimer(); + const auto curPriority = curReclaimer == nullptr + ? std::nullopt + : std::optional(curReclaimer->priority()); + if (i == 0) { + prevPriority = curPriority; + candidateGroups.emplace_back( + std::vector{std::move(candidates[i])}); + continue; + } + if (curPriority != prevPriority) { + prevPriority = curPriority; + candidateGroups.emplace_back( + std::vector{std::move(candidates[i])}); + } else { + candidateGroups.back().push_back(std::move(candidates[i])); + } + } + return candidateGroups; } std::optional SharedArbitrator::findAbortCandidate( @@ -561,47 +698,34 @@ std::optional SharedArbitrator::findAbortCandidate( return std::nullopt; } - // Returns if other candidate should be chosen for abort. - // With the same capacity size bucket, we favor highest priority followed - // by oldest participant to not to be killed. This allows long running - // highest priority query to proceed first. - auto chooseAnotherCandidateForAbort = [&](const ArbitrationCandidate& current, - const ArbitrationCandidate& other) { - if (current.participant->poolPriority() < - other.participant->poolPriority()) { - return false; - } else if ( - current.participant->poolPriority() == - other.participant->poolPriority() && - current.participant->id() > other.participant->id()) { - return false; - } - return true; - }; - - for (uint64_t capacityLimit : globalArbitrationAbortCapacityLimits_) { - int32_t candidateIdx{-1}; - for (int32_t i = 0; i < candidates.size(); ++i) { - if (candidates[i].participant->aborted()) { - continue; - } - if (candidates[i].currentCapacity < capacityLimit || - candidates[i].currentCapacity == 0) { - continue; - } - if (candidateIdx == -1) { - candidateIdx = i; - continue; + auto candidateGroups = sortAndGroupAbortCandidates(std::move(candidates)); + + for (auto& candidateGroup : candidateGroups) { + for (uint64_t capacityLimit : globalArbitrationAbortCapacityLimits_) { + int32_t candidateIdx{-1}; + for (int32_t i = 0; i < candidateGroup.size(); ++i) { + if (candidateGroup[i].participant->aborted()) { + continue; + } + if (candidateGroup[i].currentCapacity < capacityLimit || + candidateGroup[i].currentCapacity == 0) { + continue; + } + if (candidateIdx == -1) { + candidateIdx = i; + continue; + } + // With the same capacity size bucket, we favor the old participant to + // not to be killed, to let long running query proceed first. + if (candidateGroup[candidateIdx].participant->id() < + candidateGroup[i].participant->id()) { + candidateIdx = i; + } } - - if (chooseAnotherCandidateForAbort( - candidates[candidateIdx], candidates[i])) { - candidateIdx = i; + if (candidateIdx != -1) { + return candidateGroup[candidateIdx]; } } - if (candidateIdx != -1) { - return candidates[candidateIdx]; - } } if (!force) { @@ -609,22 +733,23 @@ std::optional SharedArbitrator::findAbortCandidate( return std::nullopt; } - // Can't find an eligible abort candidate and then return the lowest priority - // youngest candidate which has the largest participant id. - int32_t candidateIdx{-1}; - for (auto i = 0; i < candidates.size(); ++i) { - if (candidateIdx == -1) { - candidateIdx = i; - } else if (chooseAnotherCandidateForAbort( - candidates[candidateIdx], candidates[i])) { + // Can't find an eligible abort candidate and then return the youngest + // candidate (which has the largest participant id) in the lowest priority + // bucket. + VELOX_CHECK(!candidateGroups.empty() && !candidateGroups[0].empty()); + int32_t candidateIdx{0}; + for (auto i = 0; i < candidateGroups[0].size(); ++i) { + if (candidateGroups[0][i].participant->id() > + candidateGroups[0][candidateIdx].participant->id()) { candidateIdx = i; } } - VELOX_CHECK_NE(candidateIdx, -1); + VELOX_MEM_LOG(WARNING) - << "Can't find an eligible abort victim and force to abort the youngest participant " - << candidates[candidateIdx].participant->name(); - return candidates[candidateIdx]; + << "Can't find an eligible abort victim and force to abort the youngest " + "participant " + << candidateGroups[0][candidateIdx].participant->name(); + return candidateGroups[0][candidateIdx]; } void SharedArbitrator::updateArbitrationRequestStats() { @@ -772,19 +897,21 @@ void SharedArbitrator::growCapacity(ArbitrationOperation& op) { RETURN_IF_TRUE(maybeGrowFromSelf(op)); if (!ensureCapacity(op)) { + const auto maxCapacity = op.participant()->maxCapacity(); MEM_POOL_CAP_EXCEEDED( fmt::format( - "Can't grow {} capacity with {}. This will exceed its max capacity " + "Can't grow {} capacity with {}. This will exceed its {} " "{}, current capacity {}.", op.participant()->name(), succinctBytes(op.requestBytes()), - succinctBytes(op.participant()->maxCapacity()), + capacity_ < maxCapacity ? "arbitrator capacity" + : "memory pool capacity", + succinctBytes(std::min(capacity_, maxCapacity)), succinctBytes(op.participant()->capacity())), op.participant()->pool()); } checkIfAborted(op); - checkIfTimeout(op); RETURN_IF_TRUE(maybeGrowFromSelf(op)); @@ -804,6 +931,9 @@ void SharedArbitrator::growCapacity(ArbitrationOperation& op) { succinctBytes(participantConfig_.minReclaimBytes)), op.participant()->pool()); } + + checkIfTimeout(op); + // After failing to acquire enough free capacity to fulfil this capacity // growth request, we will try to reclaim from the participant itself before // failing this operation. We only do this if global memory arbitration is @@ -1016,10 +1146,11 @@ void SharedArbitrator::checkIfAborted(ArbitrationOperation& op) { void SharedArbitrator::checkIfTimeout(ArbitrationOperation& op) { if (FOLLY_UNLIKELY(op.hasTimeout())) { - VELOX_MEM_ARBITRATION_TIMEOUT(fmt::format( - "Memory arbitration timed out on memory pool: {} after running {}", - op.participant()->name(), - succinctNanos(op.executionTimeNs()))); + VELOX_MEM_ARBITRATION_TIMEOUT( + fmt::format( + "Memory arbitration timed out on memory pool: {} after running {}", + op.participant()->name(), + succinctNanos(op.executionTimeNs()))); } } @@ -1137,31 +1268,34 @@ uint64_t SharedArbitrator::reclaimUsedMemoryBySpill( allParticipantsReclaimed = true; const uint64_t prevReclaimedBytes = reclaimedUsedBytes_; - auto candidates = getCandidates(); - sortCandidatesByReclaimableUsedCapacity(candidates); + auto candidates = getCandidates(); std::vector victims; victims.reserve(candidates.size()); + auto candidateGroups = sortAndGroupSpillCandidates(std::move(candidates)); + uint64_t bytesToReclaim{0}; - for (auto& candidate : candidates) { - if (candidate.reclaimableUsedCapacity < - participantConfig_.minReclaimBytes) { - break; - } - if (failedParticipants.count(candidate.participant->id()) != 0) { - VELOX_CHECK_EQ( - reclaimedParticipants.count(candidate.participant->id()), 1); - continue; - } - if (bytesToReclaim >= targetBytes) { - if (reclaimedParticipants.count(candidate.participant->id()) == 0) { - allParticipantsReclaimed = false; + for (auto& candidateGroup : candidateGroups) { + for (auto& candidate : candidateGroup) { + if (candidate.reclaimableUsedCapacity < + participantConfig_.minReclaimBytes) { + continue; } - continue; + if (failedParticipants.count(candidate.participant->id()) != 0) { + VELOX_CHECK_EQ( + reclaimedParticipants.count(candidate.participant->id()), 1); + continue; + } + if (bytesToReclaim >= targetBytes) { + if (reclaimedParticipants.count(candidate.participant->id()) == 0) { + allParticipantsReclaimed = false; + } + continue; + } + bytesToReclaim += candidate.reclaimableUsedCapacity; + reclaimedParticipants.insert(candidate.participant->id()); + victims.push_back(std::move(candidate)); } - bytesToReclaim += candidate.reclaimableUsedCapacity; - reclaimedParticipants.insert(candidate.participant->id()); - victims.push_back(std::move(candidate)); } if (victims.empty()) { FB_LOG_EVERY_MS(WARNING, 1'000) @@ -1228,15 +1362,15 @@ uint64_t SharedArbitrator::reclaimUsedMemoryByAbort(bool force) { // after abort operation. const auto currentCapacity = victim.participant->pool()->capacity(); try { - VELOX_MEM_POOL_ABORTED(fmt::format( - "Memory pool aborted to reclaim used memory, current capacity {}, " - "requesting capacity from global arbitration {} memory pool " - "priority:{}\nstats:\n{}\n{}", - succinctBytes(currentCapacity), - succinctBytes(victim.participant->globalArbitrationGrowCapacity()), - victim.participant->pool()->poolPriority(), - victim.participant->pool()->toString(), - victim.participant->pool()->treeMemoryUsage())); + VELOX_MEM_POOL_ABORTED( + fmt::format( + "Memory pool aborted to reclaim used memory, current capacity {}, " + "requesting capacity from global arbitration {} memory pool " + "stats:\n{}\n{}", + succinctBytes(currentCapacity), + succinctBytes(victim.participant->globalArbitrationGrowCapacity()), + victim.participant->pool()->toString(), + victim.participant->pool()->treeMemoryUsage())); } catch (VeloxRuntimeError&) { abort(victim.participant, std::current_exception()); return currentCapacity; @@ -1270,7 +1404,6 @@ uint64_t SharedArbitrator::reclaim( if (participant->aborted()) { removeGlobalArbitrationWaiter(participant->id()); } - freeCapacity(reclaimedBytes); updateMemoryReclaimStats( reclaimedBytes, reclaimTimeNs, localArbitration, stats); @@ -1282,6 +1415,8 @@ uint64_t SharedArbitrator::reclaim( << " stats " << succinctBytes(stats.reclaimedBytes) << " numNonReclaimableAttempts " << stats.numNonReclaimableAttempts; + + freeCapacity(reclaimedBytes); if (reclaimedBytes == 0) { FB_LOG_EVERY_MS(WARNING, 1'000) << fmt::format( "Nothing reclaimed from memory pool {} with reclaim target {}, memory pool stats:\n{}\n{}", diff --git a/velox/common/memory/SharedArbitrator.h b/velox/common/memory/SharedArbitrator.h index 5a0238e460dc..16efd247110c 100644 --- a/velox/common/memory/SharedArbitrator.h +++ b/velox/common/memory/SharedArbitrator.h @@ -135,13 +135,29 @@ class SharedArbitrator : public memory::MemoryArbitrator { static double memoryPoolMinReclaimPct( const std::unordered_map& configs); + /// Specifies the starting memory capacity limit for global arbitration to + /// search for victim participant to reclaim used memory by spill. For + /// participants with reclaimable used capacity larger than the limit, the + /// global arbitration choose to spill the lowest priority participant with + /// highest reclaimable used capacity. The spill capacity limit is reduced + /// by half if couldn't find a victim participant until reaches to zero. + /// + /// NOTE: the limit must be zero or a power of 2. + static constexpr std::string_view kMemoryPoolSpillCapacityLimit{ + "memory-pool-spill-capacity-limit"}; + static constexpr std::string_view kDefaultMemoryPoolSpillCapacityLimit{ + "4GB"}; + static uint64_t memoryPoolSpillCapacityLimit( + const std::unordered_map& configs); + /// Specifies the starting memory capacity limit for global arbitration to /// search for victim participant to reclaim used memory by abort. For /// participants with capacity larger than the limit, the global arbitration - /// choose to abort the youngest participant which has the largest - /// participant id. This helps to let the old queries to run to completion. - /// The abort capacity limit is reduced by half if could not find a victim - /// participant until this reaches to zero. + /// choose to abort the participant that has lowest priority and shortest + /// execution time (largest participant id). This helps to let the low + /// priority queries to be aborted first, as well as old queries to run to + /// completion. The abort capacity limit is reduced by half if couldn't find + /// a victim participant until reaches to zero. /// /// NOTE: the limit must be zero or a power of 2. static constexpr std::string_view kMemoryPoolAbortCapacityLimit{ @@ -387,7 +403,9 @@ class SharedArbitrator : public memory::MemoryArbitrator { // Invoked to initialize the global arbitration on arbitrator start-up. It // starts the background threads to used memory from running queries // on-demand. - void setupGlobalArbitration(); + void setupGlobalArbitration( + uint64_t spillCapacityLimit, + uint64_t abortCapacityLimit); // Invoked to stop the global arbitration threads on shut-down. void shutdownGlobalArbitration(); @@ -468,10 +486,6 @@ class SharedArbitrator : public memory::MemoryArbitrator { uint64_t reclaimUsedMemoryBySpill(uint64_t targetBytes); - // Sorts 'candidates' based on reclaimable used capacity in descending order. - static void sortCandidatesByReclaimableUsedCapacity( - std::vector& candidates); - // Invoked to reclaim the used memory capacity to abort the participant with // the largest capacity to free up memory. The function returns the actually // reclaimed capacity in bytes. @@ -483,6 +497,21 @@ class SharedArbitrator : public memory::MemoryArbitrator { // abort if there is no eligible one. uint64_t reclaimUsedMemoryByAbort(bool force); + // Sorts and groups 'candidates' for spilling. The sort firstly groups + // candidates first based on 'globalArbitrationSpillCapacityLimits_', larger + // bucket in the front. Then order each group based on priority and + // reclaimable used capacity in descending order, with lower priority (higher + // priority value) and higher reclaimable used capacity ones in front. + // Priority takes precedence over reclaimable used capacity. + std::vector> sortAndGroupSpillCandidates( + std::vector&& candidates); + + // Sorts 'candidates' based on participant's reclaimer priority in descending + // order, putting lower priority ones (with higher priority value) first, and + // high priority ones (with lower priority value) later. + static std::vector> + sortAndGroupAbortCandidates(std::vector&& candidates); + // Finds the participant victim to abort to free used memory based on the // participant's memory capacity and age. The function returns std::nullopt if // there is no eligible candidate. If 'force' is true, it picks up the @@ -636,6 +665,13 @@ class SharedArbitrator : public memory::MemoryArbitrator { // If there is no such participant, it goes to the next limit and so on. std::vector globalArbitrationAbortCapacityLimits_; + // The spill capacity limits listed in descending order. It is used by global + // arbitration to choose the victim to spill. It starts with the largest limit + // and spill the largest reclaimable used capacity participant whose + // reclaimable used capacity is larger than the limit. If there is no such + // participant, it goes to the next limit and so on. + std::vector globalArbitrationSpillCapacityLimits_; + // The global arbitration control thread which runs the global arbitration at // the background, and dispatch the actual memory reclaim work on different // participants to 'globalArbitrationExecutor_' and collects the results back. diff --git a/velox/common/memory/StreamArena.h b/velox/common/memory/StreamArena.h index c6bc95d011c2..b00b9dff950b 100644 --- a/velox/common/memory/StreamArena.h +++ b/velox/common/memory/StreamArena.h @@ -62,6 +62,10 @@ class StreamArena { /// serilizers. virtual void clear(); + memory::MachinePageCount testingAllocationQuantum() const { + return allocationQuantum_; + } + private: memory::MemoryPool* const pool_; diff --git a/velox/common/memory/tests/ArbitrationParticipantTest.cpp b/velox/common/memory/tests/ArbitrationParticipantTest.cpp index 8d7e70c9f587..b3abd4b27929 100644 --- a/velox/common/memory/tests/ArbitrationParticipantTest.cpp +++ b/velox/common/memory/tests/ArbitrationParticipantTest.cpp @@ -103,7 +103,6 @@ constexpr double kMemoryPoolMinFreeCapacityRatio = 0.25; constexpr uint64_t kFastExponentialGrowthCapacityLimit = 256 * MB; constexpr uint64_t kMemoryPoolMinReclaimBytes = 0; constexpr double kMemoryPoolMinReclaimPct = 0; -constexpr uint64_t kMemoryPoolAbortCapacityLimit = 0; constexpr double kSlowCapacityGrowRatio = 0.25; class MemoryReclaimer; @@ -234,11 +233,12 @@ class MockTask : public std::enable_shared_from_this { ReclaimInjectionCallback reclaimInjectCb, ArbitrationInjectionCallback arbitrationInjectCb) { root_->setReclaimer(RootMemoryReclaimer::create(shared_from_this())); - pool_->setReclaimer(std::make_unique( - shared_from_this(), - reclaimable, - std::move(reclaimInjectCb), - std::move(arbitrationInjectCb))); + pool_->setReclaimer( + std::make_unique( + shared_from_this(), + reclaimable, + std::move(reclaimInjectCb), + std::move(arbitrationInjectCb))); } const std::shared_ptr& pool() const { @@ -415,8 +415,7 @@ static ArbitrationParticipant::Config arbitrationConfig( uint64_t minFreeCapacity = kMemoryPoolMinFreeCapacity, double minFreeCapacityRatio = kMemoryPoolMinFreeCapacityRatio, uint64_t minReclaimBytes = kMemoryPoolMinReclaimBytes, - double minReclaimPct = kMemoryPoolMinReclaimPct, - uint64_t abortCapacityLimit = kMemoryPoolAbortCapacityLimit) { + double minReclaimPct = kMemoryPoolMinReclaimPct) { return ArbitrationParticipant::Config{ 0, minCapacity, @@ -425,8 +424,7 @@ static ArbitrationParticipant::Config arbitrationConfig( minFreeCapacity, minFreeCapacityRatio, minReclaimBytes, - minReclaimPct, - abortCapacityLimit}; + minReclaimPct}; } TEST_F(ArbitrationParticipantTest, config) { @@ -439,13 +437,12 @@ TEST_F(ArbitrationParticipantTest, config) { double minFreeCapacityRatio; uint64_t minReclaimBytes; double minReclaimPct; - uint64_t abortCapacityLimit; bool expectedError; std::string expectedToString; std::string debugString() const { return fmt::format( - "initCapacity {}, minCapacity {}, fastExponentialGrowthCapacityLimit {}, slowCapacityGrowRatio {}, minFreeCapacity {}, minFreeCapacityRatio {}, minReclaimBytes {}, minReclaimPct {}, abortCapacityLimit {}, expectedError {}, expectedToString: {}", + "initCapacity {}, minCapacity {}, fastExponentialGrowthCapacityLimit {}, slowCapacityGrowRatio {}, minFreeCapacity {}, minFreeCapacityRatio {}, minReclaimBytes {}, minReclaimPct {}, expectedError {}, expectedToString: {}", succinctBytes(initCapacity), succinctBytes(minCapacity), succinctBytes(fastExponentialGrowthCapacityLimit), @@ -454,7 +451,6 @@ TEST_F(ArbitrationParticipantTest, config) { minFreeCapacityRatio, succinctBytes(minReclaimBytes), minReclaimPct, - succinctBytes(abortCapacityLimit), expectedError, expectedToString); } @@ -467,9 +463,8 @@ TEST_F(ArbitrationParticipantTest, config) { 0.1, 1, 0.25, - 2, false, - "initCapacity 1B, minCapacity 1B, fastExponentialGrowthCapacityLimit 1B, slowCapacityGrowRatio 0.1, minFreeCapacity 1B, minFreeCapacityRatio 0.1, minReclaimBytes 1B, minReclaimPct 0.25, abortCapacityLimit 2B"}, + "initCapacity 1B, minCapacity 1B, fastExponentialGrowthCapacityLimit 1B, slowCapacityGrowRatio 0.1, minFreeCapacity 1B, minFreeCapacityRatio 0.1, minReclaimBytes 1B, minReclaimPct 0.25"}, {0, 1, 0, @@ -478,9 +473,8 @@ TEST_F(ArbitrationParticipantTest, config) { 0.1, 1, 0.5, - 0, false, - "initCapacity 0B, minCapacity 1B, fastExponentialGrowthCapacityLimit 0B, slowCapacityGrowRatio 0, minFreeCapacity 1B, minFreeCapacityRatio 0.1, minReclaimBytes 1B, minReclaimPct 0.5, abortCapacityLimit 0B"}, + "initCapacity 0B, minCapacity 1B, fastExponentialGrowthCapacityLimit 0B, slowCapacityGrowRatio 0, minFreeCapacity 1B, minFreeCapacityRatio 0.1, minReclaimBytes 1B, minReclaimPct 0.5"}, {0, 1, 0, @@ -489,9 +483,8 @@ TEST_F(ArbitrationParticipantTest, config) { 0, 1, 0.5, - 0, false, - "initCapacity 0B, minCapacity 1B, fastExponentialGrowthCapacityLimit 0B, slowCapacityGrowRatio 0, minFreeCapacity 0B, minFreeCapacityRatio 0, minReclaimBytes 1B, minReclaimPct 0.5, abortCapacityLimit 0B"}, + "initCapacity 0B, minCapacity 1B, fastExponentialGrowthCapacityLimit 0B, slowCapacityGrowRatio 0, minFreeCapacity 0B, minFreeCapacityRatio 0, minReclaimBytes 1B, minReclaimPct 0.5"}, {1, 1, 0, @@ -500,9 +493,8 @@ TEST_F(ArbitrationParticipantTest, config) { 0.1, 1, 0.5, - 0, false, - "initCapacity 1B, minCapacity 1B, fastExponentialGrowthCapacityLimit 0B, slowCapacityGrowRatio 0, minFreeCapacity 1B, minFreeCapacityRatio 0.1, minReclaimBytes 1B, minReclaimPct 0.5, abortCapacityLimit 0B"}, + "initCapacity 1B, minCapacity 1B, fastExponentialGrowthCapacityLimit 0B, slowCapacityGrowRatio 0, minFreeCapacity 1B, minFreeCapacityRatio 0.1, minReclaimBytes 1B, minReclaimPct 0.5"}, {1, 0, 1, @@ -511,9 +503,8 @@ TEST_F(ArbitrationParticipantTest, config) { 0.1, 1, 0.5, - 0, false, - "initCapacity 1B, minCapacity 0B, fastExponentialGrowthCapacityLimit 1B, slowCapacityGrowRatio 0.1, minFreeCapacity 1B, minFreeCapacityRatio 0.1, minReclaimBytes 1B, minReclaimPct 0.5, abortCapacityLimit 0B"}, + "initCapacity 1B, minCapacity 0B, fastExponentialGrowthCapacityLimit 1B, slowCapacityGrowRatio 0.1, minFreeCapacity 1B, minFreeCapacityRatio 0.1, minReclaimBytes 1B, minReclaimPct 0.5"}, {1, 0, 0, @@ -522,9 +513,8 @@ TEST_F(ArbitrationParticipantTest, config) { 0.1, 0, 0.5, - 1, false, - "initCapacity 1B, minCapacity 0B, fastExponentialGrowthCapacityLimit 0B, slowCapacityGrowRatio 0, minFreeCapacity 1B, minFreeCapacityRatio 0.1, minReclaimBytes 0B, minReclaimPct 0.5, abortCapacityLimit 1B"}, + "initCapacity 1B, minCapacity 0B, fastExponentialGrowthCapacityLimit 0B, slowCapacityGrowRatio 0, minFreeCapacity 1B, minFreeCapacityRatio 0.1, minReclaimBytes 0B, minReclaimPct 0.5"}, {0, 0, 0, @@ -533,9 +523,8 @@ TEST_F(ArbitrationParticipantTest, config) { 0, 1, 0.5, - 0, false, - "initCapacity 0B, minCapacity 0B, fastExponentialGrowthCapacityLimit 0B, slowCapacityGrowRatio 0, minFreeCapacity 0B, minFreeCapacityRatio 0, minReclaimBytes 1B, minReclaimPct 0.5, abortCapacityLimit 0B"}, + "initCapacity 0B, minCapacity 0B, fastExponentialGrowthCapacityLimit 0B, slowCapacityGrowRatio 0, minFreeCapacity 0B, minFreeCapacityRatio 0, minReclaimBytes 1B, minReclaimPct 0.5"}, {0, 0, 0, @@ -544,12 +533,11 @@ TEST_F(ArbitrationParticipantTest, config) { 0.1, 1, 0.5, - 0, false, - "initCapacity 0B, minCapacity 0B, fastExponentialGrowthCapacityLimit 0B, slowCapacityGrowRatio 0, minFreeCapacity 1B, minFreeCapacityRatio 0.1, minReclaimBytes 1B, minReclaimPct 0.5, abortCapacityLimit 0B"}, - {0, 1, 0, 0.1, 1, 0.1, 1, 0.5, 2, true, ""}, - {0, 1, 1, 0.1, 0, 0.1, 1, 0.5, 2, true, ""}, - {0, 1, 1, 0.1, 1, 0, 1, 0.5, 2, true, ""}, + "initCapacity 0B, minCapacity 0B, fastExponentialGrowthCapacityLimit 0B, slowCapacityGrowRatio 0, minFreeCapacity 1B, minFreeCapacityRatio 0.1, minReclaimBytes 1B, minReclaimPct 0.5"}, + {0, 1, 0, 0.1, 1, 0.1, 1, 0.5, true, ""}, + {0, 1, 1, 0.1, 0, 0.1, 1, 0.5, true, ""}, + {0, 1, 1, 0.1, 1, 0, 1, 0.5, true, ""}, {1, 1, 1, @@ -558,15 +546,12 @@ TEST_F(ArbitrationParticipantTest, config) { 0.1, 0, 0.5, - 0, false, - "initCapacity 1B, minCapacity 1B, fastExponentialGrowthCapacityLimit 1B, slowCapacityGrowRatio 2, minFreeCapacity 1B, minFreeCapacityRatio 0.1, minReclaimBytes 0B, minReclaimPct 0.5, abortCapacityLimit 0B"}, - {0, 1, 1, -1, 1, 0.1, 1, 0.5, 0, true, ""}, - {0, 1, 1, 0.1, 1, 2, 1, 0.5, 0, true, ""}, - {0, 1, 1, 0.1, 1, -1, 1, 0.5, 0, true, ""}, - {0, 0, 0, 0, 1, 0.1, 0, 0.5, 3, true, ""}, - {0, 0, 0, 0, 1, 0.1, 1, 0.5, 3, true, ""}, - {0, 0, 0, 0, 1, 0.1, 1, 1.5, 0, true, ""}}; + "initCapacity 1B, minCapacity 1B, fastExponentialGrowthCapacityLimit 1B, slowCapacityGrowRatio 2, minFreeCapacity 1B, minFreeCapacityRatio 0.1, minReclaimBytes 0B, minReclaimPct 0.5"}, + {0, 1, 1, -1, 1, 0.1, 1, 0.5, true, ""}, + {0, 1, 1, 0.1, 1, 2, 1, 0.5, true, ""}, + {0, 1, 1, 0.1, 1, -1, 1, 0.5, true, ""}, + {0, 0, 0, 0, 1, 0.1, 1, 1.5, true, ""}}; for (const auto& testData : testSettings) { SCOPED_TRACE(testData.debugString()); @@ -580,8 +565,7 @@ TEST_F(ArbitrationParticipantTest, config) { testData.minFreeCapacity, testData.minFreeCapacityRatio, testData.minReclaimBytes, - testData.minReclaimPct, - testData.abortCapacityLimit), + testData.minReclaimPct), ""); continue; } @@ -593,8 +577,7 @@ TEST_F(ArbitrationParticipantTest, config) { testData.minFreeCapacity, testData.minFreeCapacityRatio, testData.minReclaimBytes, - testData.minReclaimPct, - testData.abortCapacityLimit); + testData.minReclaimPct); ASSERT_EQ(testData.initCapacity, config.initCapacity); ASSERT_EQ(testData.minCapacity, config.minCapacity); ASSERT_EQ( @@ -605,7 +588,6 @@ TEST_F(ArbitrationParticipantTest, config) { ASSERT_EQ(testData.minFreeCapacityRatio, config.minFreeCapacityRatio); ASSERT_EQ(testData.minReclaimBytes, config.minReclaimBytes); ASSERT_EQ(testData.minReclaimPct, config.minReclaimPct); - ASSERT_EQ(testData.abortCapacityLimit, config.abortCapacityLimit); ASSERT_EQ(config.toString(), testData.expectedToString); } } @@ -772,7 +754,7 @@ TEST_F(ArbitrationParticipantTest, getGrowTargets) { auto participant = ArbitrationParticipant::create(10, task->pool(), &config); auto scopedParticipant = participant->lock().value(); - scopedParticipant->shrink(/*reclaimFromAll=*/true); + scopedParticipant->shrink(/*reclaimAll=*/true); ASSERT_EQ(scopedParticipant->capacity(), 0); void* buffer = task->allocate(testData.capacity); SCOPE_EXIT { @@ -879,7 +861,7 @@ TEST_F(ArbitrationParticipantTest, reclaimableFreeCapacityAndShrink) { ASSERT_EQ(scopedParticipant->pool()->peakBytes(), testData.peakBytes); } - scopedParticipant->shrink(/*reclaimFromAll=*/true); + scopedParticipant->shrink(/*reclaimAll=*/true); scopedParticipant->grow(testData.capacity, 0); ASSERT_EQ(scopedParticipant->capacity(), testData.capacity); @@ -1039,7 +1021,7 @@ TEST_F(ArbitrationParticipantTest, reclaimableUsedCapacityAndReclaim) { ASSERT_EQ(scopedParticipant->pool()->peakBytes(), testData.peakBytes); } - scopedParticipant->shrink(/*reclaimFromAll=*/true); + scopedParticipant->shrink(/*reclaimAll=*/true); scopedParticipant->grow(testData.capacity, 0); ASSERT_EQ(scopedParticipant->capacity(), testData.capacity); @@ -1354,7 +1336,7 @@ TEST_F(ArbitrationParticipantTest, abort) { const std::string abortReason = "test abort"; try { VELOX_FAIL(abortReason); - } catch (const VeloxRuntimeError& e) { + } catch (const VeloxRuntimeError&) { ASSERT_EQ( scopedParticipant->abort(std::current_exception()), testData.expectedReclaimCapacity); @@ -1370,7 +1352,7 @@ TEST_F(ArbitrationParticipantTest, abort) { try { VELOX_FAIL(abortReason); - } catch (const VeloxRuntimeError& e) { + } catch (const VeloxRuntimeError&) { ASSERT_EQ(scopedParticipant->abort(std::current_exception()), 0); } ASSERT_EQ(scopedParticipant->stats().numShrinks, prevNumShrunks + 1); @@ -1450,7 +1432,7 @@ DEBUG_ONLY_TEST_F(ArbitrationParticipantTest, reclaimLock) { const std::string abortReason = "test abort"; try { VELOX_FAIL(abortReason); - } catch (const VeloxRuntimeError& e) { + } catch (const VeloxRuntimeError&) { ASSERT_EQ(scopedParticipant->abort(std::current_exception()), 32 * MB); } abortCompletedFlag = true; @@ -1527,7 +1509,7 @@ DEBUG_ONLY_TEST_F(ArbitrationParticipantTest, abortedCheck) { const std::string abortReason = "test abort1"; try { VELOX_FAIL(abortReason); - } catch (const VeloxRuntimeError& e) { + } catch (const VeloxRuntimeError&) { ASSERT_EQ(scopedParticipant->abort(std::current_exception()), MB); } }); @@ -1536,7 +1518,7 @@ DEBUG_ONLY_TEST_F(ArbitrationParticipantTest, abortedCheck) { const std::string abortReason = "test abort2"; try { VELOX_FAIL(abortReason); - } catch (const VeloxRuntimeError& e) { + } catch (const VeloxRuntimeError&) { ASSERT_EQ(scopedParticipant->abort(std::current_exception()), 0); } }); @@ -1552,6 +1534,73 @@ DEBUG_ONLY_TEST_F(ArbitrationParticipantTest, abortedCheck) { VELOX_ASSERT_THROW(task->allocate(MB), "test abort1"); } +DEBUG_ONLY_TEST_F(ArbitrationParticipantTest, concurrentAbort) { + auto task = createTask(kMemoryCapacity); + const auto config = arbitrationConfig(); + auto participant = ArbitrationParticipant::create(10, task->pool(), &config); + task->allocate(32 * MB); + auto scopedParticipant = participant->lock().value(); + + std::atomic_bool firstAbortStarted{false}; + std::atomic_bool secondAbortWaitFlag{true}; + folly::EventCount secondAbortWait; + std::atomic_bool firstAbortWaitFlag{true}; + folly::EventCount firstAbortWait; + + SCOPED_TESTVALUE_SET( + "facebook::velox::memory::ArbitrationParticipant::reclaim", + std::function( + ([&](ArbitrationParticipant* /*unused*/) { + VELOX_FAIL("reclaim abort message"); + }))); + + SCOPED_TESTVALUE_SET( + "facebook::velox::memory::ArbitrationParticipant::abortLocked", + std::function( + ([&](ArbitrationParticipant* /*unused*/) { + if (!firstAbortStarted.exchange(true)) { + // First abort thread signals it started and waits for second + // abort to finish + secondAbortWaitFlag = false; + secondAbortWait.notifyAll(); + firstAbortWait.await( + [&]() { return !firstAbortWaitFlag.load(); }); + } + }))); + + // First abort thread through reclaim - will wait at the test value + std::thread reclaimAbortThread([&]() { + memory::MemoryReclaimer::Stats stats; + scopedParticipant->reclaim(32 * MB, 1'000'000'000'000, stats); + }); + + // Wait for first abort to start + secondAbortWait.await([&]() { return !secondAbortWaitFlag.load(); }); + + // Second abort uses main thread. + try { + VELOX_FAIL("test abort 2"); + } catch (const VeloxRuntimeError&) { + ASSERT_EQ(scopedParticipant->abort(std::current_exception()), 32 * MB); + } + + // Signal first abort to continue + firstAbortWaitFlag = false; + firstAbortWait.notifyAll(); + + // Wait for first abort thread to complete + reclaimAbortThread.join(); + + // Verify the pool is aborted + ASSERT_TRUE(task->pool()->aborted()); + ASSERT_TRUE(scopedParticipant->aborted()); + ASSERT_EQ(scopedParticipant->capacity(), 0); + + // Verify the error message is from the second abort + VELOX_ASSERT_THROW( + std::rethrow_exception(task->abortError()), "test abort 2"); +} + TEST_F(ArbitrationParticipantTest, capacityCheck) { auto task = createTask(256 << 20); const auto config = arbitrationConfig(512 << 20); @@ -1684,7 +1733,7 @@ TEST_F(ArbitrationParticipantTest, arbitrationOperation) { ASSERT_FALSE(abortOp.aborted()); try { VELOX_FAIL("abort op"); - } catch (const VeloxRuntimeError& e) { + } catch (const VeloxRuntimeError&) { ASSERT_EQ(scopedParticipant->abort(std::current_exception()), 0); } ASSERT_TRUE(abortOp.aborted()); @@ -1897,8 +1946,7 @@ TEST_F(ArbitrationParticipantTest, arbitrationOperationState) { #ifndef TSAN_BUILD TEST_F(ArbitrationParticipantTest, arbitrationOperationTimedLock) { auto participantPool = manager_->addRootPool("arbitrationOperationTimedLock"); - auto config = - ArbitrationParticipant::Config(0, 1024, 0, 0, 0, 0, 128, 0, 512); + auto config = ArbitrationParticipant::Config(0, 1024, 0, 0, 0, 0, 128, 0); auto participant = ArbitrationParticipant::create( folly::Random::rand64(), participantPool, &config); diff --git a/velox/common/memory/tests/ByteStreamTest.cpp b/velox/common/memory/tests/ByteStreamTest.cpp index 7ef7eff64b3e..e77ffb2b5fa4 100644 --- a/velox/common/memory/tests/ByteStreamTest.cpp +++ b/velox/common/memory/tests/ByteStreamTest.cpp @@ -248,10 +248,11 @@ TEST_F(ByteStreamTest, newRangeAllocation) { byteStream.startWrite(0); for (int i = 0; i < testData.newRangeSizes.size(); ++i) { const auto newRangeSize = testData.newRangeSizes[i]; - SCOPED_TRACE(fmt::format( - "iteration {} allocation size {}", - i, - succinctBytes(testData.newRangeSizes[i]))); + SCOPED_TRACE( + fmt::format( + "iteration {} allocation size {}", + i, + succinctBytes(testData.newRangeSizes[i]))); std::string value(newRangeSize, 'a'); byteStream.appendStringView(value); ASSERT_EQ(arena->size(), testData.expectedArenaAllocationSizes[i]); @@ -470,8 +471,9 @@ class InputByteStreamTest : public ByteStreamTest, fmt::format("{}/{}", tempDirPath_->getPath(), fileId_++); auto writeFile = fs_->openFileForWrite(filePath); for (auto& byteRange : byteRanges) { - writeFile->append(std::string_view( - reinterpret_cast(byteRange.buffer), byteRange.size)); + writeFile->append( + std::string_view( + reinterpret_cast(byteRange.buffer), byteRange.size)); } writeFile->close(); return std::make_unique( diff --git a/velox/common/memory/tests/CMakeLists.txt b/velox/common/memory/tests/CMakeLists.txt index dc93cc7368f9..bb07501386c8 100644 --- a/velox/common/memory/tests/CMakeLists.txt +++ b/velox/common/memory/tests/CMakeLists.txt @@ -30,7 +30,8 @@ add_executable( RawVectorTest.cpp ScratchTest.cpp SharedArbitratorTest.cpp - StreamArenaTest.cpp) + StreamArenaTest.cpp +) target_link_libraries( velox_memory_test @@ -51,7 +52,8 @@ target_link_libraries( GTest::gmock GTest::gtest GTest::gtest_main - re2::re2) + re2::re2 +) gtest_add_tests(velox_memory_test "" AUTO) @@ -59,12 +61,11 @@ if(VELOX_ENABLE_BENCHMARKS) add_executable(velox_fragmentation_benchmark FragmentationBenchmark.cpp) target_link_libraries( - velox_fragmentation_benchmark PRIVATE velox_memory Folly::folly - gflags::gflags glog::glog) + velox_fragmentation_benchmark + PRIVATE velox_memory Folly::folly gflags::gflags glog::glog + ) - add_executable(velox_concurrent_allocation_benchmark - ConcurrentAllocationBenchmark.cpp) + add_executable(velox_concurrent_allocation_benchmark ConcurrentAllocationBenchmark.cpp) - target_link_libraries(velox_concurrent_allocation_benchmark - PRIVATE velox_memory velox_time) + target_link_libraries(velox_concurrent_allocation_benchmark PRIVATE velox_memory velox_time) endif() diff --git a/velox/common/memory/tests/MemoryAllocatorTest.cpp b/velox/common/memory/tests/MemoryAllocatorTest.cpp index 7b9f123f96a6..cc64ad8ddde1 100644 --- a/velox/common/memory/tests/MemoryAllocatorTest.cpp +++ b/velox/common/memory/tests/MemoryAllocatorTest.cpp @@ -258,8 +258,9 @@ class MemoryAllocatorTest : public testing::TestWithParam { allocations.clear(); } - void clearAllocations(std::vector>>& - allocationsVector) { + void clearAllocations( + std::vector>>& + allocationsVector) { for (auto& allocations : allocationsVector) { for (auto& allocation : allocations) { instance_->freeNonContiguous(*allocation); @@ -831,65 +832,66 @@ TEST_P(MemoryAllocatorTest, nonContiguousFailure) { static_cast(numNewPages), injectedFailure); } - } testSettings[] = {// Cap failure injection. - {0, 100, MemoryAllocator::InjectedFailure::kCap}, - {0, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kCap}, - {0, - Allocation::PageRun::kMaxPagesInRun, - MemoryAllocator::InjectedFailure::kCap}, - {100, 100, MemoryAllocator::InjectedFailure::kCap}, - {Allocation::PageRun::kMaxPagesInRun / 2, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kCap}, - {Allocation::PageRun::kMaxPagesInRun, - Allocation::PageRun::kMaxPagesInRun, - MemoryAllocator::InjectedFailure::kCap}, - {200, 100, MemoryAllocator::InjectedFailure::kCap}, - {Allocation::PageRun::kMaxPagesInRun / 2 + 100, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kCap}, - {Allocation::PageRun::kMaxPagesInRun, - Allocation::PageRun::kMaxPagesInRun - 1, - MemoryAllocator::InjectedFailure::kCap}, - {Allocation::PageRun::kMaxPagesInRun, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kCap}, - // Allocate failure injection. - {0, 100, MemoryAllocator::InjectedFailure::kAllocate}, - {0, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kAllocate}, - {0, - Allocation::PageRun::kMaxPagesInRun, - MemoryAllocator::InjectedFailure::kCap}, - {100, 100, MemoryAllocator::InjectedFailure::kAllocate}, - {Allocation::PageRun::kMaxPagesInRun / 2, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kAllocate}, - {Allocation::PageRun::kMaxPagesInRun, - Allocation::PageRun::kMaxPagesInRun, - MemoryAllocator::InjectedFailure::kAllocate}, - {200, 100, MemoryAllocator::InjectedFailure::kAllocate}, - {Allocation::PageRun::kMaxPagesInRun / 2 + 100, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kAllocate}, - {Allocation::PageRun::kMaxPagesInRun, - Allocation::PageRun::kMaxPagesInRun - 1, - MemoryAllocator::InjectedFailure::kAllocate}, - {Allocation::PageRun::kMaxPagesInRun, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kAllocate}, - // Madvise failure injection. - {0, 100, MemoryAllocator::InjectedFailure::kMadvise}, - {0, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kMadvise}, - {0, - Allocation::PageRun::kMaxPagesInRun, - MemoryAllocator::InjectedFailure::kMadvise}, - {200, 100, MemoryAllocator::InjectedFailure::kMadvise}}; + } testSettings[] = { + // Cap failure injection. + {0, 100, MemoryAllocator::InjectedFailure::kCap}, + {0, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kCap}, + {0, + Allocation::PageRun::kMaxPagesInRun, + MemoryAllocator::InjectedFailure::kCap}, + {100, 100, MemoryAllocator::InjectedFailure::kCap}, + {Allocation::PageRun::kMaxPagesInRun / 2, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kCap}, + {Allocation::PageRun::kMaxPagesInRun, + Allocation::PageRun::kMaxPagesInRun, + MemoryAllocator::InjectedFailure::kCap}, + {200, 100, MemoryAllocator::InjectedFailure::kCap}, + {Allocation::PageRun::kMaxPagesInRun / 2 + 100, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kCap}, + {Allocation::PageRun::kMaxPagesInRun, + Allocation::PageRun::kMaxPagesInRun - 1, + MemoryAllocator::InjectedFailure::kCap}, + {Allocation::PageRun::kMaxPagesInRun, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kCap}, + // Allocate failure injection. + {0, 100, MemoryAllocator::InjectedFailure::kAllocate}, + {0, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kAllocate}, + {0, + Allocation::PageRun::kMaxPagesInRun, + MemoryAllocator::InjectedFailure::kCap}, + {100, 100, MemoryAllocator::InjectedFailure::kAllocate}, + {Allocation::PageRun::kMaxPagesInRun / 2, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kAllocate}, + {Allocation::PageRun::kMaxPagesInRun, + Allocation::PageRun::kMaxPagesInRun, + MemoryAllocator::InjectedFailure::kAllocate}, + {200, 100, MemoryAllocator::InjectedFailure::kAllocate}, + {Allocation::PageRun::kMaxPagesInRun / 2 + 100, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kAllocate}, + {Allocation::PageRun::kMaxPagesInRun, + Allocation::PageRun::kMaxPagesInRun - 1, + MemoryAllocator::InjectedFailure::kAllocate}, + {Allocation::PageRun::kMaxPagesInRun, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kAllocate}, + // Madvise failure injection. + {0, 100, MemoryAllocator::InjectedFailure::kMadvise}, + {0, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kMadvise}, + {0, + Allocation::PageRun::kMaxPagesInRun, + MemoryAllocator::InjectedFailure::kMadvise}, + {200, 100, MemoryAllocator::InjectedFailure::kMadvise}}; std::unordered_map expectedErrorMsg = { {MemoryAllocator::InjectedFailure::kAllocate, diff --git a/velox/common/memory/tests/MemoryArbitratorTest.cpp b/velox/common/memory/tests/MemoryArbitratorTest.cpp index 4cf4c4920591..5fd819033feb 100644 --- a/velox/common/memory/tests/MemoryArbitratorTest.cpp +++ b/velox/common/memory/tests/MemoryArbitratorTest.cpp @@ -713,12 +713,16 @@ TEST_F(MemoryReclaimerTest, scopedReclaimedBytesRecorder) { auto childPool = root->addLeafChild("memoryReclaimRecorder", true); ASSERT_EQ(childPool->reservedBytes(), 0); int64_t reclaimedBytes{0}; - { ScopedReclaimedBytesRecorder recorder(childPool.get(), &reclaimedBytes); } + { + ScopedReclaimedBytesRecorder recorder(childPool.get(), &reclaimedBytes); + } ASSERT_EQ(reclaimedBytes, 0); void* buffer = childPool->allocate(1 << 20); ASSERT_EQ(childPool->reservedBytes(), 1 << 20); - { ScopedReclaimedBytesRecorder recorder(childPool.get(), &reclaimedBytes); } + { + ScopedReclaimedBytesRecorder recorder(childPool.get(), &reclaimedBytes); + } ASSERT_EQ(reclaimedBytes, 0); reclaimedBytes = 0; diff --git a/velox/common/memory/tests/MemoryCapExceededTest.cpp b/velox/common/memory/tests/MemoryCapExceededTest.cpp index f44823aded0f..38c739f27f16 100644 --- a/velox/common/memory/tests/MemoryCapExceededTest.cpp +++ b/velox/common/memory/tests/MemoryCapExceededTest.cpp @@ -69,7 +69,7 @@ TEST_P(MemoryCapExceededTest, singleDriver) { // why). std::vector expectedTexts = { "Can't grow ", - "capacity with 2.00MB. This will exceed its max capacity 5.00MB, current " + "capacity with 2.00MB. This will exceed its memory pool capacity 5.00MB, current " "capacity 5.00MB.\n" "ARBITRATOR[SHARED CAPACITY[6.00GB] STATS[numRequests 1 numRunning 1 " "numSucceded 0 numAborted 0 numFailures 0 numNonReclaimableAttempts 0 " @@ -92,7 +92,7 @@ TEST_P(MemoryCapExceededTest, singleDriver) { "node.1 usage 12.00KB reserved 1.00MB peak 1.00MB", "op.1.0.0.FilterProject usage 12.00KB reserved 1.00MB peak 12.00KB", "node.2 usage 3.74MB reserved 4.00MB peak 4.00MB", - "op.2.0.0.Aggregation usage 3.74MB reserved 4.00MB peak 3.74MB", + "op.2.0.0.Aggregation usage 3.74MB reserved 4.00MB peak 3.76MB", "Top 2 leaf memory pool usages:"}; std::vector data; @@ -113,8 +113,9 @@ TEST_P(MemoryCapExceededTest, singleDriver) { .orderBy({"c0"}, false) .planNode(); auto queryCtx = core::QueryCtx::create(executor_.get()); - queryCtx->testingOverrideMemoryPool(memory::memoryManager()->addRootPool( - queryCtx->queryId(), kMaxBytes, exec::MemoryReclaimer::create())); + queryCtx->testingOverrideMemoryPool( + memory::memoryManager()->addRootPool( + queryCtx->queryId(), kMaxBytes, exec::MemoryReclaimer::create())); CursorParameters params; params.planNode = plan; params.queryCtx = queryCtx; @@ -171,8 +172,9 @@ TEST_P(MemoryCapExceededTest, multipleDrivers) { .singleAggregation({"c0"}, {"sum(c1)"}) .planNode(); auto queryCtx = core::QueryCtx::create(executor_.get()); - queryCtx->testingOverrideMemoryPool(memory::memoryManager()->addRootPool( - queryCtx->queryId(), kMaxBytes, exec::MemoryReclaimer::create())); + queryCtx->testingOverrideMemoryPool( + memory::memoryManager()->addRootPool( + queryCtx->queryId(), kMaxBytes, exec::MemoryReclaimer::create())); const int32_t numDrivers = 10; CursorParameters params; diff --git a/velox/common/memory/tests/MemoryManagerTest.cpp b/velox/common/memory/tests/MemoryManagerTest.cpp index 9a3460bd6432..f1862a03673b 100644 --- a/velox/common/memory/tests/MemoryManagerTest.cpp +++ b/velox/common/memory/tests/MemoryManagerTest.cpp @@ -31,13 +31,14 @@ DECLARE_bool(velox_enable_memory_usage_track_in_default_memory_pool); using namespace ::testing; namespace facebook::velox::memory { - namespace { -constexpr folly::StringPiece kSysRootName{"__sys_root__"}; + +constexpr std::string_view kSysRootName{"__sys_root__"}; MemoryManager& toMemoryManager(MemoryManager& manager) { return *static_cast(&manager); } + } // namespace class MemoryManagerTest : public testing::Test { @@ -57,11 +58,29 @@ TEST_F(MemoryManagerTest, ctor) { ASSERT_EQ(manager.capacity(), kMaxMemory); ASSERT_EQ(0, manager.getTotalBytes()); ASSERT_EQ(manager.alignment(), MemoryAllocator::kMaxAlignment); - ASSERT_EQ(manager.testingDefaultRoot().alignment(), manager.alignment()); - ASSERT_EQ(manager.testingDefaultRoot().capacity(), kMaxMemory); - ASSERT_EQ(manager.testingDefaultRoot().maxCapacity(), kMaxMemory); + ASSERT_EQ(manager.deprecatedSysRootPool().alignment(), manager.alignment()); + ASSERT_EQ(manager.deprecatedSysRootPool().capacity(), kMaxMemory); + ASSERT_EQ(manager.deprecatedSysRootPool().maxCapacity(), kMaxMemory); ASSERT_EQ(manager.arbitrator()->kind(), "NOOP"); + auto sysPool = manager.deprecatedSysRootPool().shared_from_this(); + ASSERT_NE(sysPool->reclaimer(), nullptr); + try { + VELOX_FAIL("Trigger Error"); + } catch (const velox::VeloxRuntimeError&) { + VELOX_ASSERT_THROW( + sysPool->reclaimer()->abort( + &manager.deprecatedSysRootPool(), std::current_exception()), + "SysMemoryReclaimer::abort is not supported"); + } + ASSERT_EQ(sysPool->reclaimer()->priority(), 0); + memory::MemoryReclaimer::Stats stats; + ASSERT_EQ( + sysPool->reclaimer()->reclaim(sysPool.get(), 1'000, 1'000, stats), 0); + uint64_t reclaimableBytes{0}; + ASSERT_FALSE( + sysPool->reclaimer()->reclaimableBytes(*sysPool, reclaimableBytes)); } + { const auto kCapacity = 8L * 1024 * 1024; MemoryManager::Options options; @@ -70,7 +89,7 @@ TEST_F(MemoryManagerTest, ctor) { MemoryManager manager{options}; ASSERT_EQ(kCapacity, manager.capacity()); ASSERT_EQ(manager.numPools(), 3); - ASSERT_EQ(manager.testingDefaultRoot().alignment(), manager.alignment()); + ASSERT_EQ(manager.deprecatedSysRootPool().alignment(), manager.alignment()); } { const auto kCapacity = 8L * 1024 * 1024; @@ -81,10 +100,10 @@ TEST_F(MemoryManagerTest, ctor) { MemoryManager manager{options}; ASSERT_EQ(manager.alignment(), MemoryAllocator::kMinAlignment); - ASSERT_EQ(manager.testingDefaultRoot().alignment(), manager.alignment()); + ASSERT_EQ(manager.deprecatedSysRootPool().alignment(), manager.alignment()); // TODO: replace with root pool memory tracker quota check. ASSERT_EQ( - kSharedPoolCount + 3, manager.testingDefaultRoot().getChildCount()); + kSharedPoolCount + 3, manager.deprecatedSysRootPool().getChildCount()); ASSERT_EQ(kCapacity, manager.capacity()); ASSERT_EQ(0, manager.getTotalBytes()); } @@ -203,14 +222,18 @@ TEST_F(MemoryManagerTest, addPool) { auto rootPool = manager.addRootPool("duplicateRootPool", kMaxMemory); ASSERT_EQ(rootPool->capacity(), kMaxMemory); ASSERT_EQ(rootPool->maxCapacity(), kMaxMemory); - { ASSERT_ANY_THROW(manager.addRootPool("duplicateRootPool", kMaxMemory)); } + { + ASSERT_ANY_THROW(manager.addRootPool("duplicateRootPool", kMaxMemory)); + } auto threadSafeLeafPool = manager.addLeafPool("leafPool", true); ASSERT_EQ(threadSafeLeafPool->capacity(), kMaxMemory); ASSERT_EQ(threadSafeLeafPool->maxCapacity(), kMaxMemory); auto nonThreadSafeLeafPool = manager.addLeafPool("duplicateLeafPool", true); ASSERT_EQ(nonThreadSafeLeafPool->capacity(), kMaxMemory); ASSERT_EQ(nonThreadSafeLeafPool->maxCapacity(), kMaxMemory); - { ASSERT_ANY_THROW(manager.addLeafPool("duplicateLeafPool")); } + { + ASSERT_ANY_THROW(manager.addLeafPool("duplicateLeafPool")); + } const int64_t poolCapacity = 1 << 20; auto rootPoolWithMaxCapacity = manager.addRootPool("rootPoolWithCapacity", poolCapacity); @@ -255,7 +278,9 @@ TEST_F(MemoryManagerTest, addPoolWithArbitrator) { auto nonThreadSafeLeafPool = manager.addLeafPool("duplicateLeafPool", true); ASSERT_EQ(nonThreadSafeLeafPool->capacity(), kMaxMemory); ASSERT_EQ(nonThreadSafeLeafPool->maxCapacity(), kMaxMemory); - { ASSERT_ANY_THROW(manager.addLeafPool("duplicateLeafPool")); } + { + ASSERT_ANY_THROW(manager.addLeafPool("duplicateLeafPool")); + } const int64_t poolCapacity = 1 << 30; auto rootPoolWithMaxCapacity = manager.addRootPool( "rootPoolWithCapacity", poolCapacity, MemoryReclaimer::create()); @@ -275,18 +300,18 @@ TEST_F(MemoryManagerTest, defaultMemoryManager) { auto& managerB = toMemoryManager(deprecatedDefaultMemoryManager()); const auto kSharedPoolCount = FLAGS_velox_memory_num_shared_leaf_pools + 3; ASSERT_EQ(managerA.numPools(), 3); - ASSERT_EQ(managerA.testingDefaultRoot().getChildCount(), kSharedPoolCount); + ASSERT_EQ(managerA.deprecatedSysRootPool().getChildCount(), kSharedPoolCount); ASSERT_EQ(managerB.numPools(), 3); - ASSERT_EQ(managerB.testingDefaultRoot().getChildCount(), kSharedPoolCount); + ASSERT_EQ(managerB.deprecatedSysRootPool().getChildCount(), kSharedPoolCount); auto child1 = managerA.addLeafPool("child_1"); - ASSERT_EQ(child1->parent()->name(), managerA.testingDefaultRoot().name()); + ASSERT_EQ(child1->parent()->name(), managerA.deprecatedSysRootPool().name()); auto child2 = managerB.addLeafPool("child_2"); - ASSERT_EQ(child2->parent()->name(), managerA.testingDefaultRoot().name()); + ASSERT_EQ(child2->parent()->name(), managerA.deprecatedSysRootPool().name()); EXPECT_EQ( - kSharedPoolCount + 2, managerA.testingDefaultRoot().getChildCount()); + kSharedPoolCount + 2, managerA.deprecatedSysRootPool().getChildCount()); EXPECT_EQ( - kSharedPoolCount + 2, managerB.testingDefaultRoot().getChildCount()); + kSharedPoolCount + 2, managerB.deprecatedSysRootPool().getChildCount()); ASSERT_EQ(managerA.numPools(), 5); ASSERT_EQ(managerB.numPools(), 5); auto pool = managerB.addRootPool(); @@ -300,9 +325,9 @@ TEST_F(MemoryManagerTest, defaultMemoryManager) { "Memory Manager[capacity UNLIMITED alignment 64B usedBytes 0B number of pools 6\nList of root pools:\n\t__sys_root__\n\tdefault_root_0\n\trefcount 2\nMemory Allocator[MALLOC capacity UNLIMITED allocated bytes 0 allocated pages 0 mapped pages 0]\nARBIRTATOR[NOOP CAPACITY[UNLIMITED]]]"); child1.reset(); EXPECT_EQ( - kSharedPoolCount + 1, managerA.testingDefaultRoot().getChildCount()); + kSharedPoolCount + 1, managerA.deprecatedSysRootPool().getChildCount()); child2.reset(); - EXPECT_EQ(kSharedPoolCount, managerB.testingDefaultRoot().getChildCount()); + EXPECT_EQ(kSharedPoolCount, managerB.deprecatedSysRootPool().getChildCount()); ASSERT_EQ(managerA.numPools(), 4); ASSERT_EQ(managerB.numPools(), 4); pool.reset(); @@ -328,8 +353,9 @@ TEST_F(MemoryManagerTest, defaultMemoryManager) { for (int i = 0; i < 32; ++i) { ASSERT_THAT( managerA.toString(true), - testing::HasSubstr(fmt::format( - "__sys_shared_leaf__{} usage 0B reserved 0B peak 0B\n", i))); + testing::HasSubstr( + fmt::format( + "__sys_shared_leaf__{} usage 0B reserved 0B peak 0B\n", i))); } } @@ -337,32 +363,35 @@ TEST_F(MemoryManagerTest, defaultMemoryManager) { TEST(MemoryHeaderTest, addDefaultLeafMemoryPool) { auto& manager = toMemoryManager(deprecatedDefaultMemoryManager()); const auto kSharedPoolCount = FLAGS_velox_memory_num_shared_leaf_pools + 3; - ASSERT_EQ(manager.testingDefaultRoot().getChildCount(), kSharedPoolCount); + ASSERT_EQ(manager.deprecatedSysRootPool().getChildCount(), kSharedPoolCount); { auto poolA = deprecatedAddDefaultLeafMemoryPool(); ASSERT_EQ(poolA->kind(), MemoryPool::Kind::kLeaf); auto poolB = deprecatedAddDefaultLeafMemoryPool(); ASSERT_EQ(poolB->kind(), MemoryPool::Kind::kLeaf); EXPECT_EQ( - kSharedPoolCount + 2, manager.testingDefaultRoot().getChildCount()); + kSharedPoolCount + 2, manager.deprecatedSysRootPool().getChildCount()); { auto poolC = deprecatedAddDefaultLeafMemoryPool(); ASSERT_EQ(poolC->kind(), MemoryPool::Kind::kLeaf); EXPECT_EQ( - kSharedPoolCount + 3, manager.testingDefaultRoot().getChildCount()); + kSharedPoolCount + 3, + manager.deprecatedSysRootPool().getChildCount()); { auto poolD = deprecatedAddDefaultLeafMemoryPool(); ASSERT_EQ(poolD->kind(), MemoryPool::Kind::kLeaf); EXPECT_EQ( - kSharedPoolCount + 4, manager.testingDefaultRoot().getChildCount()); + kSharedPoolCount + 4, + manager.deprecatedSysRootPool().getChildCount()); } EXPECT_EQ( - kSharedPoolCount + 3, manager.testingDefaultRoot().getChildCount()); + kSharedPoolCount + 3, + manager.deprecatedSysRootPool().getChildCount()); } EXPECT_EQ( - kSharedPoolCount + 2, manager.testingDefaultRoot().getChildCount()); + kSharedPoolCount + 2, manager.deprecatedSysRootPool().getChildCount()); } - EXPECT_EQ(kSharedPoolCount, manager.testingDefaultRoot().getChildCount()); + EXPECT_EQ(kSharedPoolCount, manager.deprecatedSysRootPool().getChildCount()); auto namedPool = deprecatedAddDefaultLeafMemoryPool("namedPool"); ASSERT_EQ(namedPool->name(), "namedPool"); @@ -402,7 +431,7 @@ TEST_F(MemoryManagerTest, memoryPoolManagement) { if (i % 2) { ASSERT_EQ(pool->kind(), MemoryPool::Kind::kLeaf); userLeafPools.push_back(pool); - ASSERT_EQ(pool->parent()->name(), manager.testingDefaultRoot().name()); + ASSERT_EQ(pool->parent()->name(), manager.deprecatedSysRootPool().name()); } else { ASSERT_EQ(pool->kind(), MemoryPool::Kind::kAggregate); ASSERT_EQ(pool->parent(), nullptr); @@ -442,12 +471,12 @@ TEST_F(MemoryManagerTest, globalMemoryManager) { auto* managerII = memoryManager(); const auto kSharedPoolCount = FLAGS_velox_memory_num_shared_leaf_pools + 3; { - auto& rootI = manager->testingDefaultRoot(); + auto& rootI = manager->deprecatedSysRootPool(); const std::string childIName("some_child"); auto childI = rootI.addLeafChild(childIName); ASSERT_EQ(rootI.getChildCount(), kSharedPoolCount + 1); - auto& rootII = managerII->testingDefaultRoot(); + auto& rootII = managerII->deprecatedSysRootPool(); ASSERT_EQ(kSharedPoolCount + 1, rootII.getChildCount()); std::vector pools{}; rootII.visitChildren([&pools](MemoryPool* child) { @@ -466,7 +495,7 @@ TEST_F(MemoryManagerTest, globalMemoryManager) { auto childII = manager->addLeafPool("another_child"); ASSERT_EQ(childII->kind(), MemoryPool::Kind::kLeaf); ASSERT_EQ(rootI.getChildCount(), kSharedPoolCount + 2); - ASSERT_EQ(childII->parent()->name(), kSysRootName.str()); + ASSERT_EQ(childII->parent()->name(), kSysRootName); childII.reset(); ASSERT_EQ(rootI.getChildCount(), kSharedPoolCount + 1); ASSERT_EQ(rootII.getChildCount(), kSharedPoolCount + 1); @@ -511,7 +540,7 @@ TEST_F(MemoryManagerTest, alignmentOptionCheck) { manager.alignment(), std::max(testData.alignment, MemoryAllocator::kMinAlignment)); ASSERT_EQ( - manager.testingDefaultRoot().alignment(), + manager.deprecatedSysRootPool().alignment(), std::max(testData.alignment, MemoryAllocator::kMinAlignment)); auto leafPool = manager.addLeafPool("leafPool"); ASSERT_EQ( diff --git a/velox/common/memory/tests/MemoryPoolBenchmark.cpp b/velox/common/memory/tests/MemoryPoolBenchmark.cpp index 82296fb7e74f..077a3873e9a5 100644 --- a/velox/common/memory/tests/MemoryPoolBenchmark.cpp +++ b/velox/common/memory/tests/MemoryPoolBenchmark.cpp @@ -35,7 +35,7 @@ class BenchmarkHelper { std::vector findLeaves() { std::vector leaves; - findLeavesOf(manager_.testingDefaultRoot(), leaves); + findLeavesOf(manager_.deprecatedSysRootPool(), leaves); return leaves; } @@ -158,7 +158,7 @@ void addNLeaves(MemoryPool& pool, size_t n) { BENCHMARK(FlatTree, iters) { folly::BenchmarkSuspender suspender; MemoryManager manager{}; - addNLeaves(manager.testingDefaultRoot(), 10 * 20); + addNLeaves(manager.deprecatedSysRootPool(), 10 * 20); BenchmarkHelper helper{manager}; suspender.dismiss(); helper.runForEachPool([iters](MemoryPool& pool) { diff --git a/velox/common/memory/tests/MemoryPoolTest.cpp b/velox/common/memory/tests/MemoryPoolTest.cpp index d0b6a8bbd9cf..42f8c6f2d900 100644 --- a/velox/common/memory/tests/MemoryPoolTest.cpp +++ b/velox/common/memory/tests/MemoryPoolTest.cpp @@ -31,7 +31,6 @@ #include "velox/common/testutil/TestValue.h" DECLARE_bool(velox_memory_leak_check_enabled); -DECLARE_bool(velox_memory_pool_debug_enabled); DECLARE_int32(velox_memory_num_shared_leaf_pools); using namespace ::testing; @@ -61,15 +60,16 @@ struct TestParam { class MemoryPoolTest : public testing::TestWithParam { public: static const std::vector getTestParams() { - std::vector params; - params.push_back({true, true, false}); - params.push_back({true, false, false}); - params.push_back({false, true, false}); - params.push_back({false, false, false}); - params.push_back({true, true, true}); - params.push_back({true, false, true}); - params.push_back({false, true, true}); - params.push_back({false, false, true}); + std::vector params = { + {true, true, false}, + {true, false, false}, + {false, true, false}, + {false, false, false}, + {true, true, true}, + {true, false, true}, + {false, true, true}, + {false, false, true}, + }; return params; } @@ -165,13 +165,14 @@ TEST_P(MemoryPoolTest, ctor) { auto fakeRoot = std::make_shared( &manager, "fake_root", MemoryPool::Kind::kAggregate, nullptr, nullptr); // We can't construct an aggregate memory pool with non-thread safe. - ASSERT_ANY_THROW(std::make_shared( - &manager, - "fake_root", - MemoryPool::Kind::kAggregate, - nullptr, - nullptr, - MemoryPool::Options{.threadSafe = false})); + ASSERT_ANY_THROW( + std::make_shared( + &manager, + "fake_root", + MemoryPool::Kind::kAggregate, + nullptr, + nullptr, + MemoryPool::Options{.threadSafe = false})); ASSERT_EQ("fake_root", fakeRoot->name()); ASSERT_EQ( static_cast(root.get())->testingAllocator(), @@ -784,7 +785,7 @@ TEST_P(MemoryPoolTest, memoryCapExceptions) { "1, frees 0, reserves 0, releases 0, collisions 0])> " "Exceeded memory allocator limit when allocating 32769 " "new pages for total allocation of 32769 pages, the memory" - " allocator capacity is 32768 pages", + " allocator capacity is 32768 pages, the allocated pages is 32769", isLeafThreadSafe_ ? "thread-safe" : "non-thread-safe"), ex.message()); } @@ -945,7 +946,7 @@ TEST_P(MemoryPoolTest, childUsageTest) { TEST_P(MemoryPoolTest, getPreferredSize) { MemoryManager& manager = *getMemoryManager(); - auto& pool = dynamic_cast(manager.testingDefaultRoot()); + auto& pool = dynamic_cast(manager.deprecatedSysRootPool()); // size < 8 EXPECT_EQ(8, pool.preferredSize(1)); @@ -1047,7 +1048,7 @@ TEST_P(MemoryPoolTest, customizedGetPreferredSize) { TEST_P(MemoryPoolTest, getPreferredSizeOverflow) { MemoryManager& manager = *getMemoryManager(); - auto& pool = dynamic_cast(manager.testingDefaultRoot()); + auto& pool = dynamic_cast(manager.deprecatedSysRootPool()); EXPECT_EQ(1ULL << 32, pool.preferredSize((1ULL << 32) - 1)); EXPECT_EQ(1ULL << 63, pool.preferredSize((1ULL << 62) - 1 + (1ULL << 62))); @@ -1055,7 +1056,7 @@ TEST_P(MemoryPoolTest, getPreferredSizeOverflow) { TEST_P(MemoryPoolTest, allocatorOverflow) { MemoryManager& manager = *getMemoryManager(); - auto& pool = dynamic_cast(manager.testingDefaultRoot()); + auto& pool = dynamic_cast(manager.deprecatedSysRootPool()); StlAllocator alloc(pool); EXPECT_THROW(alloc.allocate(1ULL << 62), VeloxException); EXPECT_THROW(alloc.deallocate(nullptr, 1ULL << 62), VeloxException); @@ -1402,71 +1403,73 @@ TEST_P(MemoryPoolTest, persistentNonContiguousAllocateFailure) { static_cast(numNewPages), injectedFailure); } - } testSettings[] = {// Cap failure injection. - {0, 100, MemoryAllocator::InjectedFailure::kCap}, - {0, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kCap}, - {0, - Allocation::PageRun::kMaxPagesInRun, - MemoryAllocator::InjectedFailure::kCap}, - {100, 100, MemoryAllocator::InjectedFailure::kCap}, - {Allocation::PageRun::kMaxPagesInRun / 2, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kCap}, - {Allocation::PageRun::kMaxPagesInRun, - Allocation::PageRun::kMaxPagesInRun, - MemoryAllocator::InjectedFailure::kCap}, - {200, 100, MemoryAllocator::InjectedFailure::kCap}, - {Allocation::PageRun::kMaxPagesInRun / 2 + 100, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kCap}, - {Allocation::PageRun::kMaxPagesInRun, - Allocation::PageRun::kMaxPagesInRun - 1, - MemoryAllocator::InjectedFailure::kCap}, - {Allocation::PageRun::kMaxPagesInRun, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kCap}, - // Allocate failure injection. - {0, 100, MemoryAllocator::InjectedFailure::kAllocate}, - {0, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kAllocate}, - {0, - Allocation::PageRun::kMaxPagesInRun, - MemoryAllocator::InjectedFailure::kCap}, - {100, 100, MemoryAllocator::InjectedFailure::kAllocate}, - {Allocation::PageRun::kMaxPagesInRun / 2, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kAllocate}, - {Allocation::PageRun::kMaxPagesInRun, - Allocation::PageRun::kMaxPagesInRun, - MemoryAllocator::InjectedFailure::kAllocate}, - {200, 100, MemoryAllocator::InjectedFailure::kAllocate}, - {Allocation::PageRun::kMaxPagesInRun / 2 + 100, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kAllocate}, - {Allocation::PageRun::kMaxPagesInRun, - Allocation::PageRun::kMaxPagesInRun - 1, - MemoryAllocator::InjectedFailure::kAllocate}, - {Allocation::PageRun::kMaxPagesInRun, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kAllocate}, - // Madvise failure injection. - {0, 100, MemoryAllocator::InjectedFailure::kMadvise}, - {0, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kMadvise}, - {0, - Allocation::PageRun::kMaxPagesInRun, - MemoryAllocator::InjectedFailure::kMadvise}, - {200, 100, MemoryAllocator::InjectedFailure::kMadvise}}; + } testSettings[] = { + // Cap failure injection. + {0, 100, MemoryAllocator::InjectedFailure::kCap}, + {0, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kCap}, + {0, + Allocation::PageRun::kMaxPagesInRun, + MemoryAllocator::InjectedFailure::kCap}, + {100, 100, MemoryAllocator::InjectedFailure::kCap}, + {Allocation::PageRun::kMaxPagesInRun / 2, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kCap}, + {Allocation::PageRun::kMaxPagesInRun, + Allocation::PageRun::kMaxPagesInRun, + MemoryAllocator::InjectedFailure::kCap}, + {200, 100, MemoryAllocator::InjectedFailure::kCap}, + {Allocation::PageRun::kMaxPagesInRun / 2 + 100, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kCap}, + {Allocation::PageRun::kMaxPagesInRun, + Allocation::PageRun::kMaxPagesInRun - 1, + MemoryAllocator::InjectedFailure::kCap}, + {Allocation::PageRun::kMaxPagesInRun, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kCap}, + // Allocate failure injection. + {0, 100, MemoryAllocator::InjectedFailure::kAllocate}, + {0, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kAllocate}, + {0, + Allocation::PageRun::kMaxPagesInRun, + MemoryAllocator::InjectedFailure::kCap}, + {100, 100, MemoryAllocator::InjectedFailure::kAllocate}, + {Allocation::PageRun::kMaxPagesInRun / 2, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kAllocate}, + {Allocation::PageRun::kMaxPagesInRun, + Allocation::PageRun::kMaxPagesInRun, + MemoryAllocator::InjectedFailure::kAllocate}, + {200, 100, MemoryAllocator::InjectedFailure::kAllocate}, + {Allocation::PageRun::kMaxPagesInRun / 2 + 100, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kAllocate}, + {Allocation::PageRun::kMaxPagesInRun, + Allocation::PageRun::kMaxPagesInRun - 1, + MemoryAllocator::InjectedFailure::kAllocate}, + {Allocation::PageRun::kMaxPagesInRun, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kAllocate}, + // Madvise failure injection. + {0, 100, MemoryAllocator::InjectedFailure::kMadvise}, + {0, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kMadvise}, + {0, + Allocation::PageRun::kMaxPagesInRun, + MemoryAllocator::InjectedFailure::kMadvise}, + {200, 100, MemoryAllocator::InjectedFailure::kMadvise}}; for (const auto& testData : testSettings) { - SCOPED_TRACE(fmt::format( - "{}, useMmap:{}, useCache:{}", - testData.debugString(), - useMmap_, - useCache_)); + SCOPED_TRACE( + fmt::format( + "{}, useMmap:{}, useCache:{}", + testData.debugString(), + useMmap_, + useCache_)); if ((testData.injectedFailure != MemoryAllocator::InjectedFailure::kAllocate) && !useMmap_) { @@ -1584,11 +1587,12 @@ TEST_P(MemoryPoolTest, transientNonContiguousAllocateFailure) { MemoryAllocator::InjectedFailure::kMadvise}, {200, 100, 100, MemoryAllocator::InjectedFailure::kMadvise}}; for (const auto& testData : testSettings) { - SCOPED_TRACE(fmt::format( - "{}, useMmap:{}, useCache:{}", - testData.debugString(), - useMmap_, - useCache_)); + SCOPED_TRACE( + fmt::format( + "{}, useMmap:{}, useCache:{}", + testData.debugString(), + useMmap_, + useCache_)); if ((testData.injectedFailure != MemoryAllocator::InjectedFailure::kAllocate) && !useMmap_) { @@ -1672,81 +1676,82 @@ TEST_P(MemoryPoolTest, persistentContiguousAllocateFailure) { numNewPages, injectedFailure); } - } testSettings[] = {// Cap failure injection. - {0, 100, MemoryAllocator::InjectedFailure::kCap}, - {0, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kCap}, - {0, - Allocation::PageRun::kMaxPagesInRun, - MemoryAllocator::InjectedFailure::kCap}, - {100, 100, MemoryAllocator::InjectedFailure::kCap}, - {Allocation::PageRun::kMaxPagesInRun / 2, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kCap}, - {Allocation::PageRun::kMaxPagesInRun, - Allocation::PageRun::kMaxPagesInRun, - MemoryAllocator::InjectedFailure::kCap}, - {200, 100, MemoryAllocator::InjectedFailure::kCap}, - {Allocation::PageRun::kMaxPagesInRun / 2 + 100, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kCap}, - {Allocation::PageRun::kMaxPagesInRun, - Allocation::PageRun::kMaxPagesInRun - 1, - MemoryAllocator::InjectedFailure::kCap}, - {Allocation::PageRun::kMaxPagesInRun, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kCap}, - // Mmap failure injection. - {0, 100, MemoryAllocator::InjectedFailure::kMmap}, - {0, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kMmap}, - {0, - Allocation::PageRun::kMaxPagesInRun, - MemoryAllocator::InjectedFailure::kCap}, - {100, 100, MemoryAllocator::InjectedFailure::kMmap}, - {Allocation::PageRun::kMaxPagesInRun / 2, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kMmap}, - {Allocation::PageRun::kMaxPagesInRun, - Allocation::PageRun::kMaxPagesInRun, - MemoryAllocator::InjectedFailure::kMmap}, - {200, 100, MemoryAllocator::InjectedFailure::kMmap}, - {Allocation::PageRun::kMaxPagesInRun / 2 + 100, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kMmap}, - {Allocation::PageRun::kMaxPagesInRun, - Allocation::PageRun::kMaxPagesInRun - 1, - MemoryAllocator::InjectedFailure::kMmap}, - {Allocation::PageRun::kMaxPagesInRun, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kMmap}, - // Madvise failure injection. - {0, 100, MemoryAllocator::InjectedFailure::kMadvise}, - {0, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kMadvise}, - {0, - Allocation::PageRun::kMaxPagesInRun, - MemoryAllocator::InjectedFailure::kMadvise}, - {100, 100, MemoryAllocator::InjectedFailure::kMadvise}, - {Allocation::PageRun::kMaxPagesInRun / 2, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kMadvise}, - {Allocation::PageRun::kMaxPagesInRun, - Allocation::PageRun::kMaxPagesInRun, - MemoryAllocator::InjectedFailure::kMadvise}, - {200, 100, MemoryAllocator::InjectedFailure::kMadvise}, - {Allocation::PageRun::kMaxPagesInRun / 2 + 100, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kMadvise}, - {Allocation::PageRun::kMaxPagesInRun, - Allocation::PageRun::kMaxPagesInRun - 1, - MemoryAllocator::InjectedFailure::kMadvise}, - {Allocation::PageRun::kMaxPagesInRun, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kMadvise}}; + } testSettings[] = { + // Cap failure injection. + {0, 100, MemoryAllocator::InjectedFailure::kCap}, + {0, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kCap}, + {0, + Allocation::PageRun::kMaxPagesInRun, + MemoryAllocator::InjectedFailure::kCap}, + {100, 100, MemoryAllocator::InjectedFailure::kCap}, + {Allocation::PageRun::kMaxPagesInRun / 2, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kCap}, + {Allocation::PageRun::kMaxPagesInRun, + Allocation::PageRun::kMaxPagesInRun, + MemoryAllocator::InjectedFailure::kCap}, + {200, 100, MemoryAllocator::InjectedFailure::kCap}, + {Allocation::PageRun::kMaxPagesInRun / 2 + 100, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kCap}, + {Allocation::PageRun::kMaxPagesInRun, + Allocation::PageRun::kMaxPagesInRun - 1, + MemoryAllocator::InjectedFailure::kCap}, + {Allocation::PageRun::kMaxPagesInRun, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kCap}, + // Mmap failure injection. + {0, 100, MemoryAllocator::InjectedFailure::kMmap}, + {0, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kMmap}, + {0, + Allocation::PageRun::kMaxPagesInRun, + MemoryAllocator::InjectedFailure::kCap}, + {100, 100, MemoryAllocator::InjectedFailure::kMmap}, + {Allocation::PageRun::kMaxPagesInRun / 2, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kMmap}, + {Allocation::PageRun::kMaxPagesInRun, + Allocation::PageRun::kMaxPagesInRun, + MemoryAllocator::InjectedFailure::kMmap}, + {200, 100, MemoryAllocator::InjectedFailure::kMmap}, + {Allocation::PageRun::kMaxPagesInRun / 2 + 100, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kMmap}, + {Allocation::PageRun::kMaxPagesInRun, + Allocation::PageRun::kMaxPagesInRun - 1, + MemoryAllocator::InjectedFailure::kMmap}, + {Allocation::PageRun::kMaxPagesInRun, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kMmap}, + // Madvise failure injection. + {0, 100, MemoryAllocator::InjectedFailure::kMadvise}, + {0, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kMadvise}, + {0, + Allocation::PageRun::kMaxPagesInRun, + MemoryAllocator::InjectedFailure::kMadvise}, + {100, 100, MemoryAllocator::InjectedFailure::kMadvise}, + {Allocation::PageRun::kMaxPagesInRun / 2, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kMadvise}, + {Allocation::PageRun::kMaxPagesInRun, + Allocation::PageRun::kMaxPagesInRun, + MemoryAllocator::InjectedFailure::kMadvise}, + {200, 100, MemoryAllocator::InjectedFailure::kMadvise}, + {Allocation::PageRun::kMaxPagesInRun / 2 + 100, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kMadvise}, + {Allocation::PageRun::kMaxPagesInRun, + Allocation::PageRun::kMaxPagesInRun - 1, + MemoryAllocator::InjectedFailure::kMadvise}, + {Allocation::PageRun::kMaxPagesInRun, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kMadvise}}; for (const auto& testData : testSettings) { SCOPED_TRACE(testData.debugString()); if (!useMmap_) { @@ -1790,87 +1795,89 @@ TEST_P(MemoryPoolTest, transientContiguousAllocateFailure) { numNewPages, injectedFailure); } - } testSettings[] = {// Cap failure injection. - {0, 100, MemoryAllocator::InjectedFailure::kCap}, - {0, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kCap}, - {0, - Allocation::PageRun::kMaxPagesInRun, - MemoryAllocator::InjectedFailure::kCap}, - {100, 100, MemoryAllocator::InjectedFailure::kCap}, - {Allocation::PageRun::kMaxPagesInRun / 2, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kCap}, - {Allocation::PageRun::kMaxPagesInRun, - Allocation::PageRun::kMaxPagesInRun, - MemoryAllocator::InjectedFailure::kCap}, - {200, 100, MemoryAllocator::InjectedFailure::kCap}, - {Allocation::PageRun::kMaxPagesInRun / 2 + 100, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kCap}, - {Allocation::PageRun::kMaxPagesInRun, - Allocation::PageRun::kMaxPagesInRun - 1, - MemoryAllocator::InjectedFailure::kCap}, - {Allocation::PageRun::kMaxPagesInRun, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kCap}, - // Mmap failure injection. - {0, 100, MemoryAllocator::InjectedFailure::kMmap}, - {0, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kMmap}, - {0, - Allocation::PageRun::kMaxPagesInRun, - MemoryAllocator::InjectedFailure::kCap}, - {100, 100, MemoryAllocator::InjectedFailure::kMmap}, - {Allocation::PageRun::kMaxPagesInRun / 2, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kMmap}, - {Allocation::PageRun::kMaxPagesInRun, - Allocation::PageRun::kMaxPagesInRun, - MemoryAllocator::InjectedFailure::kMmap}, - {200, 100, MemoryAllocator::InjectedFailure::kMmap}, - {Allocation::PageRun::kMaxPagesInRun / 2 + 100, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kMmap}, - {Allocation::PageRun::kMaxPagesInRun, - Allocation::PageRun::kMaxPagesInRun - 1, - MemoryAllocator::InjectedFailure::kMmap}, - {Allocation::PageRun::kMaxPagesInRun, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kMmap}, - // Madvise failure injection. - {0, 100, MemoryAllocator::InjectedFailure::kMadvise}, - {0, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kMadvise}, - {0, - Allocation::PageRun::kMaxPagesInRun, - MemoryAllocator::InjectedFailure::kMadvise}, - {100, 100, MemoryAllocator::InjectedFailure::kMadvise}, - {Allocation::PageRun::kMaxPagesInRun / 2, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kMadvise}, - {Allocation::PageRun::kMaxPagesInRun, - Allocation::PageRun::kMaxPagesInRun, - MemoryAllocator::InjectedFailure::kMadvise}, - {200, 100, MemoryAllocator::InjectedFailure::kMadvise}, - {Allocation::PageRun::kMaxPagesInRun / 2 + 100, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kMadvise}, - {Allocation::PageRun::kMaxPagesInRun, - Allocation::PageRun::kMaxPagesInRun - 1, - MemoryAllocator::InjectedFailure::kMadvise}, - {Allocation::PageRun::kMaxPagesInRun, - Allocation::PageRun::kMaxPagesInRun / 2, - MemoryAllocator::InjectedFailure::kMadvise}}; + } testSettings[] = { + // Cap failure injection. + {0, 100, MemoryAllocator::InjectedFailure::kCap}, + {0, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kCap}, + {0, + Allocation::PageRun::kMaxPagesInRun, + MemoryAllocator::InjectedFailure::kCap}, + {100, 100, MemoryAllocator::InjectedFailure::kCap}, + {Allocation::PageRun::kMaxPagesInRun / 2, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kCap}, + {Allocation::PageRun::kMaxPagesInRun, + Allocation::PageRun::kMaxPagesInRun, + MemoryAllocator::InjectedFailure::kCap}, + {200, 100, MemoryAllocator::InjectedFailure::kCap}, + {Allocation::PageRun::kMaxPagesInRun / 2 + 100, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kCap}, + {Allocation::PageRun::kMaxPagesInRun, + Allocation::PageRun::kMaxPagesInRun - 1, + MemoryAllocator::InjectedFailure::kCap}, + {Allocation::PageRun::kMaxPagesInRun, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kCap}, + // Mmap failure injection. + {0, 100, MemoryAllocator::InjectedFailure::kMmap}, + {0, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kMmap}, + {0, + Allocation::PageRun::kMaxPagesInRun, + MemoryAllocator::InjectedFailure::kCap}, + {100, 100, MemoryAllocator::InjectedFailure::kMmap}, + {Allocation::PageRun::kMaxPagesInRun / 2, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kMmap}, + {Allocation::PageRun::kMaxPagesInRun, + Allocation::PageRun::kMaxPagesInRun, + MemoryAllocator::InjectedFailure::kMmap}, + {200, 100, MemoryAllocator::InjectedFailure::kMmap}, + {Allocation::PageRun::kMaxPagesInRun / 2 + 100, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kMmap}, + {Allocation::PageRun::kMaxPagesInRun, + Allocation::PageRun::kMaxPagesInRun - 1, + MemoryAllocator::InjectedFailure::kMmap}, + {Allocation::PageRun::kMaxPagesInRun, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kMmap}, + // Madvise failure injection. + {0, 100, MemoryAllocator::InjectedFailure::kMadvise}, + {0, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kMadvise}, + {0, + Allocation::PageRun::kMaxPagesInRun, + MemoryAllocator::InjectedFailure::kMadvise}, + {100, 100, MemoryAllocator::InjectedFailure::kMadvise}, + {Allocation::PageRun::kMaxPagesInRun / 2, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kMadvise}, + {Allocation::PageRun::kMaxPagesInRun, + Allocation::PageRun::kMaxPagesInRun, + MemoryAllocator::InjectedFailure::kMadvise}, + {200, 100, MemoryAllocator::InjectedFailure::kMadvise}, + {Allocation::PageRun::kMaxPagesInRun / 2 + 100, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kMadvise}, + {Allocation::PageRun::kMaxPagesInRun, + Allocation::PageRun::kMaxPagesInRun - 1, + MemoryAllocator::InjectedFailure::kMadvise}, + {Allocation::PageRun::kMaxPagesInRun, + Allocation::PageRun::kMaxPagesInRun / 2, + MemoryAllocator::InjectedFailure::kMadvise}}; for (const auto& testData : testSettings) { - SCOPED_TRACE(fmt::format( - "{}, useCache:{} , useMmap:{}", - testData.debugString(), - useCache_, - useMmap_)); + SCOPED_TRACE( + fmt::format( + "{}, useCache:{} , useMmap:{}", + testData.debugString(), + useCache_, + useMmap_)); if (!useMmap_) { // No failure injections supported for contiguous allocation of // MallocAllocator. @@ -1944,33 +1951,34 @@ TEST_P(MemoryPoolTest, persistentContiguousGrowAllocateFailure) { injectedFailure, expectedErrorMessage); } - } testSettings[] = {// Cap failure injection. - {10, - 100, - MemoryAllocator::InjectedFailure::kCap, - "growContiguous failed with 100 pages from Memory Pool"}, - {100, - 10, - MemoryAllocator::InjectedFailure::kCap, - "growContiguous failed with 10 pages from Memory Pool"}, - // Mmap failure injection. - {10, - 100, - MemoryAllocator::InjectedFailure::kMmap, - "growContiguous failed with 100 pages from Memory Pool"}, - {100, - 10, - MemoryAllocator::InjectedFailure::kMmap, - "growContiguous failed with 10 pages from Memory Pool"}, - // Madvise failure injection. - {10, - 100, - MemoryAllocator::InjectedFailure::kMadvise, - "growContiguous failed with 100 pages from Memory Pool"}, - {100, - 10, - MemoryAllocator::InjectedFailure::kMadvise, - "growContiguous failed with 10 pages from Memory Pool"}}; + } testSettings[] = { + // Cap failure injection. + {10, + 100, + MemoryAllocator::InjectedFailure::kCap, + "growContiguous failed with 100 pages from Memory Pool"}, + {100, + 10, + MemoryAllocator::InjectedFailure::kCap, + "growContiguous failed with 10 pages from Memory Pool"}, + // Mmap failure injection. + {10, + 100, + MemoryAllocator::InjectedFailure::kMmap, + "growContiguous failed with 100 pages from Memory Pool"}, + {100, + 10, + MemoryAllocator::InjectedFailure::kMmap, + "growContiguous failed with 10 pages from Memory Pool"}, + // Madvise failure injection. + {10, + 100, + MemoryAllocator::InjectedFailure::kMadvise, + "growContiguous failed with 100 pages from Memory Pool"}, + {100, + 10, + MemoryAllocator::InjectedFailure::kMadvise, + "growContiguous failed with 10 pages from Memory Pool"}}; for (const auto& testData : testSettings) { SCOPED_TRACE(testData.debugString()); if (!useMmap_) { @@ -2014,21 +2022,23 @@ TEST_P(MemoryPoolTest, transientContiguousGrowAllocateFailure) { numGrowPages, injectedFailure); } - } testSettings[] = {// Cap failure injection. - {10, 100, MemoryAllocator::InjectedFailure::kCap}, - {100, 10, MemoryAllocator::InjectedFailure::kCap}, - // Mmap failure injection. - {10, 100, MemoryAllocator::InjectedFailure::kMmap}, - {100, 10, MemoryAllocator::InjectedFailure::kMmap}, - // Madvise failure injection. - {10, 100, MemoryAllocator::InjectedFailure::kMadvise}, - {100, 10, MemoryAllocator::InjectedFailure::kMadvise}}; + } testSettings[] = { + // Cap failure injection. + {10, 100, MemoryAllocator::InjectedFailure::kCap}, + {100, 10, MemoryAllocator::InjectedFailure::kCap}, + // Mmap failure injection. + {10, 100, MemoryAllocator::InjectedFailure::kMmap}, + {100, 10, MemoryAllocator::InjectedFailure::kMmap}, + // Madvise failure injection. + {10, 100, MemoryAllocator::InjectedFailure::kMadvise}, + {100, 10, MemoryAllocator::InjectedFailure::kMadvise}}; for (const auto& testData : testSettings) { - SCOPED_TRACE(fmt::format( - "{}, useCache:{} , useMmap:{}", - testData.debugString(), - useCache_, - useMmap_)); + SCOPED_TRACE( + fmt::format( + "{}, useCache:{} , useMmap:{}", + testData.debugString(), + useCache_, + useMmap_)); if (!useMmap_) { // No failure injections supported for contiguous allocation of // MallocAllocator. @@ -2520,7 +2530,6 @@ TEST_P(MemoryPoolTest, concurrentUpdateToDifferentPools) { } TEST_P(MemoryPoolTest, concurrentUpdatesToTheSamePool) { - FLAGS_velox_memory_pool_debug_enabled = true; if (!isLeafThreadSafe_) { return; } @@ -2706,7 +2715,6 @@ TEST(MemoryPoolTest, visitChildren) { } TEST(MemoryPoolTest, debugMode) { - FLAGS_velox_memory_pool_debug_enabled = true; constexpr int64_t kMaxMemory = 10 * GB; constexpr int64_t kNumIterations = 100; const std::vector kAllocSizes = {128, 8 * KB, 2 * MB}; @@ -2726,12 +2734,23 @@ TEST(MemoryPoolTest, debugMode) { ->addLeafChild("child"); const auto& allocRecords = std::dynamic_pointer_cast(pool) ->testingDebugAllocRecords(); + std::vector smallAllocs; + smallAllocs.reserve(kNumIterations); for (int32_t i = 0; i < kNumIterations; i++) { smallAllocs.push_back(pool->allocate(kAllocSizes[0])); } EXPECT_EQ(allocRecords.size(), kNumIterations); checkAllocs(allocRecords, kAllocSizes[0]); + + // Check toString() works with debug mode enabled + const auto poolString = pool->toString(); + EXPECT_FALSE(poolString.empty()); + EXPECT_TRUE( + poolString.find( + "======== 100 allocations of 12.50KB total size ========") != + std::string::npos); + for (int32_t i = 0; i < kNumIterations; i++) { pool->free(smallAllocs[i], kAllocSizes[0]); } @@ -2779,14 +2798,16 @@ TEST(MemoryPoolTest, debugModeWithFilter) { "root0", kMaxMemory, nullptr, - debugEnabled ? std::optional(MemoryPool::DebugOptions{ - .debugPoolNameRegex = "NO-MATCH"}) - : std::nullopt); + debugEnabled + ? std::optional( + MemoryPool::DebugOptions{.debugPoolNameRegex = "NO-MATCH"}) + : std::nullopt); auto pool0_0 = root0->addLeafChild("PartialAggregation.0.0"); auto* buffer0 = pool0_0->allocate(1 * KB); - EXPECT_TRUE(std::dynamic_pointer_cast(pool0_0) - ->testingDebugAllocRecords() - .empty()); + EXPECT_TRUE( + std::dynamic_pointer_cast(pool0_0) + ->testingDebugAllocRecords() + .empty()); pool0_0->free(buffer0, 1 * KB); // leaf child created from MemoryPool, match filter @@ -2794,8 +2815,9 @@ TEST(MemoryPoolTest, debugModeWithFilter) { "root1", kMaxMemory, nullptr, - debugEnabled ? std::optional(MemoryPool::DebugOptions{ - .debugPoolNameRegex = ".*PartialAggregation.*"}) + debugEnabled ? std::optional( + MemoryPool::DebugOptions{ + .debugPoolNameRegex = ".*PartialAggregation.*"}) : std::nullopt); auto pool1_0 = root1->addLeafChild("PartialAggregation.0.1"); auto* buffer1 = pool1_0->allocate(1 * KB); @@ -2816,9 +2838,10 @@ TEST(MemoryPoolTest, debugModeWithFilter) { // old pool from root0 should not be affected by root1 buffer0 = pool0_0->allocate(1 * KB); - EXPECT_TRUE(std::dynamic_pointer_cast(pool0_0) - ->testingDebugAllocRecords() - .empty()); + EXPECT_TRUE( + std::dynamic_pointer_cast(pool0_0) + ->testingDebugAllocRecords() + .empty()); pool0_0->free(buffer0, 1 * KB); // leaf child created from MemoryPool, match filter @@ -2826,9 +2849,10 @@ TEST(MemoryPoolTest, debugModeWithFilter) { "root2", kMaxMemory, nullptr, - debugEnabled ? std::optional(MemoryPool::DebugOptions{ - .debugPoolNameRegex = ".*OrderBy.*"}) - : std::nullopt); + debugEnabled + ? std::optional( + MemoryPool::DebugOptions{.debugPoolNameRegex = ".*OrderBy.*"}) + : std::nullopt); auto pool2_0 = root2->addLeafChild("OrderBy.0.0"); auto* buffer2 = pool2_0->allocate(1 * KB); if (!debugEnabled) { @@ -2879,13 +2903,51 @@ TEST(MemoryPoolTest, debugModeWithFilter) { // leaf child created from MemoryManager, not match filter auto sysLeaf = manager.addLeafPool("Arbitrator.0.0"); auto* buffer5 = sysLeaf->allocate(1 * KB); - EXPECT_TRUE(std::dynamic_pointer_cast(sysLeaf) - ->testingDebugAllocRecords() - .empty()); + EXPECT_TRUE( + std::dynamic_pointer_cast(sysLeaf) + ->testingDebugAllocRecords() + .empty()); sysLeaf->free(buffer5, 1 * KB); } } +TEST_P(MemoryPoolTest, debugModeWrapCapException) { + const uint64_t kMaxCap = 128L * MB; + MemoryManager::Options options; + options.allocatorCapacity = kMaxCap; + options.arbitratorCapacity = kMaxCap; + options.extraArbitratorConfigs = { + {std::string(SharedArbitrator::ExtraConfig::kReservedCapacity), + folly::to(kMaxCap / 2) + "B"}}; + setupMemory(options); + auto manager = getMemoryManager(); + auto root = + manager->addRootPool("MemoryCapExceptions", kMaxCap, nullptr, {{".*"}}); + auto pool1 = root->addLeafChild("static_quota_1", isLeafThreadSafe_); + auto pool2 = root->addLeafChild("static_quota_2", isLeafThreadSafe_); + { + std::vector buffers{ + pool1->allocate(64L * MB), pool1->allocate(64L * MB)}; + try { + pool2->allocate(1L * MB); + } catch (const velox::VeloxRuntimeError& ex) { + ASSERT_EQ(error_source::kErrorSourceRuntime.c_str(), ex.errorSource()); + ASSERT_EQ(error_code::kMemCapExceeded.c_str(), ex.errorCode()); + EXPECT_TRUE( + ex.message().find( + "Exceeded memory pool capacity.\n\n" + "======= Current Allocations ======\n" + "Memory pool 'static_quota_1' - Found 2 allocations with 128.00MB total size:\n" + "======== 2 allocations of 128.00MB total size ========") != + std::string::npos) + << "Actual error message: " << ex.message(); + } + for (auto buffer : buffers) { + pool1->free(buffer, 64L * MB); + } + } +} + TEST_P(MemoryPoolTest, shrinkAndGrowAPIs) { MemoryManager& manager = *getMemoryManager(); std::vector capacities = {kMaxMemory, 128 * MB}; @@ -3968,9 +4030,13 @@ TEST_P(MemoryPoolTest, abort) { ASSERT_TRUE(rootPool->aborted()); // Allocate more buffer to trigger reservation increment at the root. - { VELOX_ASSERT_THROW(leafPool->allocate(capacity / 2), ""); } + { + VELOX_ASSERT_THROW(leafPool->allocate(capacity / 2), ""); + } // Allocate more buffer to trigger memory arbitration at the root. - { VELOX_ASSERT_THROW(leafPool->allocate(capacity * 2), ""); } + { + VELOX_ASSERT_THROW(leafPool->allocate(capacity * 2), ""); + } // Allocate without trigger memory reservation increment. void* buf2 = leafPool->allocate(128); ASSERT_EQ(leafPool->usedBytes(), 256); @@ -4037,6 +4103,213 @@ TEST_P(MemoryPoolTest, allocationWithCoveredCollateral) { pool->freeContiguous(contiguousAllocation); } +TEST_P(MemoryPoolTest, transferTo) { + MemoryManager::Options options; + options.alignment = MemoryAllocator::kMinAlignment; + options.allocatorCapacity = kDefaultCapacity; + setupMemory(options); + auto manager = getMemoryManager(); + + auto largestSizeClass = manager->allocator()->largestSizeClass(); + std::vector pageCounts{ + largestSizeClass, + largestSizeClass + 1, + largestSizeClass / 10, + 1, + largestSizeClass * 2, + largestSizeClass * 3 + 1}; + + auto assertEqualBytes = [](const memory::MemoryPool* pool, + int64_t usedBytes, + int64_t peakBytes, + int64_t reservedBytes) { + EXPECT_EQ(pool->usedBytes(), usedBytes); + EXPECT_EQ(pool->peakBytes(), peakBytes); + EXPECT_EQ(pool->reservedBytes(), reservedBytes); + }; + + auto assertZeroByte = [](const memory::MemoryPool* pool) { + EXPECT_EQ(pool->usedBytes(), 0); + EXPECT_EQ(pool->reservedBytes(), 0); + }; + + auto getMemoryBytes = [](const memory::MemoryPool* pool) { + return std::make_tuple( + pool->usedBytes(), pool->peakBytes(), pool->reservedBytes()); + }; + + auto createPools = [&manager](bool betweenDifferentRoots) { + auto root1 = manager->addRootPool("root1"); + auto root2 = manager->addRootPool("root2"); + std::shared_ptr from; + std::shared_ptr to; + if (betweenDifferentRoots) { + from = root1->addLeafChild("from"); + to = root2->addLeafChild("to"); + } else { + from = root1->addLeafChild("from"); + to = root1->addLeafChild("to"); + } + return std::make_tuple(root1, root2, from, to); + }; + + auto testTransferAllocate = [&assertZeroByte, + &assertEqualBytes, + &getMemoryBytes, + &createPools](bool betweenDifferentRoots) { + auto [root1, root2, from, to] = createPools(betweenDifferentRoots); + assertZeroByte(from.get()); + assertZeroByte(to.get()); + assertZeroByte(from->root()); + assertZeroByte(to->root()); + + const auto kSize = 1024; + int64_t usedBytes, rootUsedBytes; + int64_t peakBytes, rootPeakBytes; + int64_t reservedBytes, rootReservedBytes; + auto buffer = from->allocate(kSize); + // Transferring between non-leaf pools is not allowed. + EXPECT_FALSE(from->root()->transferTo(to.get(), buffer, kSize)); + EXPECT_FALSE(from->transferTo(to->root(), buffer, kSize)); + + std::tie(usedBytes, peakBytes, reservedBytes) = getMemoryBytes(from.get()); + std::tie(rootUsedBytes, rootPeakBytes, rootReservedBytes) = + getMemoryBytes(from->root()); + from->transferTo(to.get(), buffer, kSize); + assertEqualBytes(to.get(), usedBytes, peakBytes, reservedBytes); + if (from->root() == to->root()) { + rootPeakBytes *= 2; + } + assertEqualBytes( + to->root(), rootUsedBytes, rootPeakBytes, rootReservedBytes); + to->free(buffer, kSize); + assertZeroByte(from.get()); + assertZeroByte(to.get()); + assertZeroByte(from->root()); + assertZeroByte(to->root()); + }; + + auto testTransferAllocateZeroFilled = + [&assertZeroByte, &assertEqualBytes, &getMemoryBytes, &createPools]( + bool betweenDifferentRoots) { + auto [root1, root2, from, to] = createPools(betweenDifferentRoots); + assertZeroByte(from.get()); + assertZeroByte(to.get()); + assertZeroByte(from->root()); + assertZeroByte(to->root()); + + const auto kSize = 1024; + int64_t usedBytes, rootUsedBytes; + int64_t peakBytes, rootPeakBytes; + int64_t reservedBytes, rootReservedBytes; + auto buffer = from->allocateZeroFilled(8, kSize / 8); + std::tie(usedBytes, peakBytes, reservedBytes) = + getMemoryBytes(from.get()); + std::tie(rootUsedBytes, rootPeakBytes, rootReservedBytes) = + getMemoryBytes(from->root()); + from->transferTo(to.get(), buffer, kSize); + assertEqualBytes(to.get(), usedBytes, peakBytes, reservedBytes); + if (from->root() == to->root()) { + rootPeakBytes *= 2; + } + assertEqualBytes( + to->root(), rootUsedBytes, rootPeakBytes, rootReservedBytes); + to->free(buffer, kSize); + assertZeroByte(from.get()); + assertZeroByte(to.get()); + assertZeroByte(from->root()); + assertZeroByte(to->root()); + }; + + auto testTransferAllocateContiguous = + [&assertZeroByte, &assertEqualBytes, &getMemoryBytes, &createPools]( + uint64_t pageCount, bool betweenDifferentRoots) { + auto [root1, root2, from, to] = createPools(betweenDifferentRoots); + assertZeroByte(from.get()); + assertZeroByte(to.get()); + assertZeroByte(from->root()); + assertZeroByte(to->root()); + + int64_t usedBytes, rootUsedBytes; + int64_t peakBytes, rootPeakBytes; + int64_t reservedBytes, rootReservedBytes; + ContiguousAllocation out; + from->allocateContiguous(pageCount, out); + std::tie(usedBytes, peakBytes, reservedBytes) = + getMemoryBytes(from.get()); + std::tie(rootUsedBytes, rootPeakBytes, rootReservedBytes) = + getMemoryBytes(from->root()); + from->transferTo(to.get(), out.data(), out.size()); + assertEqualBytes(to.get(), usedBytes, peakBytes, reservedBytes); + if (from->root() == to->root()) { + rootPeakBytes *= 2; + } + assertEqualBytes( + to->root(), rootUsedBytes, rootPeakBytes, rootReservedBytes); + to->freeContiguous(out); + assertZeroByte(from.get()); + assertZeroByte(to.get()); + assertZeroByte(from->root()); + assertZeroByte(to->root()); + }; + + auto testTransferAllocateNonContiguous = + [&assertZeroByte, &assertEqualBytes, &getMemoryBytes, &createPools]( + uint64_t pageCount, bool betweenDifferentRoots) { + auto [root1, root2, from, to] = createPools(betweenDifferentRoots); + assertZeroByte(from.get()); + assertZeroByte(to.get()); + assertZeroByte(from->root()); + assertZeroByte(to->root()); + + int64_t usedBytes, rootUsedBytes; + int64_t peakBytes, rootPeakBytes; + int64_t reservedBytes, rootReservedBytes; + Allocation out; + from->allocateNonContiguous(pageCount, out); + std::tie(usedBytes, peakBytes, reservedBytes) = + getMemoryBytes(from.get()); + std::tie(rootUsedBytes, rootPeakBytes, rootReservedBytes) = + getMemoryBytes(from->root()); + for (auto i = 0; i < out.numRuns(); ++i) { + const auto& run = out.runAt(i); + from->transferTo(to.get(), run.data(), run.numBytes()); + } + assertEqualBytes(to.get(), usedBytes, peakBytes, reservedBytes); + if (from->root() == to->root()) { + EXPECT_EQ(to->root()->usedBytes(), rootUsedBytes); + // We reserve and release memory run-by-run, so the peak bytes would + // be no greater than twice of the original peak bytes. + EXPECT_LE(to->root()->peakBytes(), rootPeakBytes * 2); + EXPECT_EQ(to->root()->reservedBytes(), rootReservedBytes); + } else { + assertEqualBytes( + to->root(), rootUsedBytes, rootPeakBytes, rootReservedBytes); + } + to->freeNonContiguous(out); + assertZeroByte(from.get()); + assertZeroByte(to.get()); + assertZeroByte(from->root()); + assertZeroByte(to->root()); + }; + + // Test transfer between siblings of the same root pool. + testTransferAllocate(false); + testTransferAllocateZeroFilled(false); + for (auto pageCount : pageCounts) { + testTransferAllocateContiguous(pageCount, false); + testTransferAllocateNonContiguous(pageCount, false); + } + + // Test transfer between different root pools. + testTransferAllocate(true); + testTransferAllocateZeroFilled(true); + for (auto pageCount : pageCounts) { + testTransferAllocateContiguous(pageCount, true); + testTransferAllocateNonContiguous(pageCount, true); + } +} + VELOX_INSTANTIATE_TEST_SUITE_P( MemoryPoolTestSuite, MemoryPoolTest, diff --git a/velox/common/memory/tests/MockSharedArbitratorTest.cpp b/velox/common/memory/tests/MockSharedArbitratorTest.cpp index 732136a4768e..a1569f547cf7 100644 --- a/velox/common/memory/tests/MockSharedArbitratorTest.cpp +++ b/velox/common/memory/tests/MockSharedArbitratorTest.cpp @@ -92,12 +92,13 @@ class MockTask : public std::enable_shared_from_this { class MemoryReclaimer : public memory::MemoryReclaimer { public: - MemoryReclaimer(const std::shared_ptr& task) - : memory::MemoryReclaimer(0), task_(task) {} + MemoryReclaimer(const std::shared_ptr& task, int32_t priority) + : memory::MemoryReclaimer(priority), task_(task) {} static std::unique_ptr create( - const std::shared_ptr& task) { - return std::make_unique(task); + const std::shared_ptr& task, + int32_t priority) { + return std::make_unique(task, priority); } void abort(MemoryPool* pool, const std::exception_ptr& error) override { @@ -113,16 +114,12 @@ class MockTask : public std::enable_shared_from_this { std::weak_ptr task_; }; - void initTaskPool( - MemoryManager* manager, - uint64_t capacity, - uint32_t taskPriority = 0) { + void + initTaskPool(MemoryManager* manager, uint64_t capacity, int32_t priority) { root_ = manager->addRootPool( fmt::format("RootPool-{}", poolId_++), capacity, - MemoryReclaimer::create(shared_from_this()), - std::nullopt, - taskPriority); + MemoryReclaimer::create(shared_from_this(), priority)); } MemoryPool* pool() const { @@ -439,75 +436,99 @@ class MockSharedArbitrationTest : public testing::Test { } void SetUp() override { - setupMemory(); + setupMemory({}); } void TearDown() override { clearTasks(); } - void setupMemory( - int64_t memoryCapacity = kMemoryCapacity, - int64_t reservedMemoryCapacity = 0, - uint64_t memoryPoolInitCapacity = 0, - uint64_t memoryPoolReserveCapacity = 0, - uint64_t fastExponentialGrowthCapacityLimit = 0, - double slowCapacityGrowPct = 0, - uint64_t memoryPoolMinFreeCapacity = 0, - double memoryPoolMinFreeCapacityPct = 0, - uint64_t memoryPoolMinReclaimBytes = 0, - double memoryPoolMinReclaimPct = 0, - uint64_t memoryPoolAbortCapacityLimit = 0, - double globalArbitrationReclaimPct = 0, - double memoryReclaimThreadsHwMultiplier = - kMemoryReclaimThreadsHwMultiplier, - std::function arbitrationStateCheckCb = nullptr, - bool globalArtbitrationEnabled = true, - uint64_t arbitrationTimeoutNs = 5 * 60 * 1'000'000'000UL, - bool globalArbitrationWithoutSpill = false, - // Set the globalArbitrationAbortTimeRatio to be very small so that the - // query can be aborted sooner and the test would not timeout. - double globalArbitrationAbortTimeRatio = 0.005) { + struct ArbitratorOptions { + int64_t memoryCapacity{kMemoryCapacity}; + int64_t reservedMemoryCapacity{0}; + uint64_t memoryPoolInitCapacity{0}; + uint64_t memoryPoolReserveCapacity{0}; + uint64_t fastExponentialGrowthCapacityLimit{0}; + + double slowCapacityGrowPct{0}; + uint64_t memoryPoolMinFreeCapacity{0}; + double memoryPoolMinFreeCapacityPct{0}; + uint64_t memoryPoolMinReclaimBytes{0}; + double memoryPoolMinReclaimPct{0}; + + uint64_t memoryPoolSpillCapacityLimit{0}; + uint64_t memoryPoolAbortCapacityLimit{0}; + double globalArbitrationReclaimPct{0}; + double memoryReclaimThreadsHwMultiplier{kMemoryReclaimThreadsHwMultiplier}; + std::function arbitrationStateCheckCb{nullptr}; + + bool globalArtbitrationEnabled{true}; + uint64_t arbitrationTimeoutNs{5 * 60 * 1'000'000'000UL}; + bool globalArbitrationWithoutSpill{false}; + // Set the globalArbitrationAbortTimeRatio to be very small so that the + // query can be aborted sooner and the test would not timeout. + double globalArbitrationAbortTimeRatio{0.005}; + }; + + void setupMemory(ArbitratorOptions arbitratorOptions) { MemoryManager::Options options; - options.allocatorCapacity = memoryCapacity; + options.allocatorCapacity = arbitratorOptions.memoryCapacity; std::string arbitratorKind = "SHARED"; options.arbitratorKind = arbitratorKind; using ExtraConfig = SharedArbitrator::ExtraConfig; options.extraArbitratorConfigs = { {std::string(ExtraConfig::kReservedCapacity), - folly::to(reservedMemoryCapacity) + "B"}, + folly::to(arbitratorOptions.reservedMemoryCapacity) + + "B"}, {std::string(ExtraConfig::kMemoryPoolInitialCapacity), - folly::to(memoryPoolInitCapacity) + "B"}, + folly::to(arbitratorOptions.memoryPoolInitCapacity) + + "B"}, {std::string(ExtraConfig::kMemoryPoolReservedCapacity), - folly::to(memoryPoolReserveCapacity) + "B"}, + folly::to(arbitratorOptions.memoryPoolReserveCapacity) + + "B"}, {std::string(ExtraConfig::kFastExponentialGrowthCapacityLimit), - folly::to(fastExponentialGrowthCapacityLimit) + "B"}, + folly::to( + arbitratorOptions.fastExponentialGrowthCapacityLimit) + + "B"}, {std::string(ExtraConfig::kSlowCapacityGrowPct), - folly::to(slowCapacityGrowPct)}, + folly::to(arbitratorOptions.slowCapacityGrowPct)}, {std::string(ExtraConfig::kMemoryPoolMinFreeCapacity), - folly::to(memoryPoolMinFreeCapacity) + "B"}, + folly::to(arbitratorOptions.memoryPoolMinFreeCapacity) + + "B"}, {std::string(ExtraConfig::kMemoryPoolMinFreeCapacityPct), - folly::to(memoryPoolMinFreeCapacityPct)}, + folly::to( + arbitratorOptions.memoryPoolMinFreeCapacityPct)}, {std::string(ExtraConfig::kMemoryPoolMinReclaimBytes), - folly::to(memoryPoolMinReclaimBytes) + "B"}, + folly::to(arbitratorOptions.memoryPoolMinReclaimBytes) + + "B"}, {std::string(ExtraConfig::kMemoryPoolMinReclaimPct), - folly::to(memoryPoolMinReclaimPct)}, + folly::to(arbitratorOptions.memoryPoolMinReclaimPct)}, + {std::string(ExtraConfig::kMemoryPoolSpillCapacityLimit), + folly::to( + arbitratorOptions.memoryPoolSpillCapacityLimit) + + "B"}, {std::string(ExtraConfig::kMemoryPoolAbortCapacityLimit), - folly::to(memoryPoolAbortCapacityLimit) + "B"}, + folly::to( + arbitratorOptions.memoryPoolAbortCapacityLimit) + + "B"}, {std::string(ExtraConfig::kGlobalArbitrationMemoryReclaimPct), - folly::to(globalArbitrationReclaimPct)}, + folly::to(arbitratorOptions.globalArbitrationReclaimPct)}, {std::string(ExtraConfig::kMemoryReclaimThreadsHwMultiplier), - folly::to(memoryReclaimThreadsHwMultiplier)}, + folly::to( + arbitratorOptions.memoryReclaimThreadsHwMultiplier)}, {std::string(ExtraConfig::kMaxMemoryArbitrationTime), - folly::to(arbitrationTimeoutNs) + "ns"}, + folly::to(arbitratorOptions.arbitrationTimeoutNs) + "ns"}, {std::string(ExtraConfig::kGlobalArbitrationEnabled), - folly::to(globalArtbitrationEnabled)}, + folly::to(arbitratorOptions.globalArtbitrationEnabled)}, {std::string(ExtraConfig::kGlobalArbitrationWithoutSpill), - folly::to(globalArbitrationWithoutSpill)}, + folly::to( + arbitratorOptions.globalArbitrationWithoutSpill)}, {std::string(ExtraConfig::kGlobalArbitrationAbortTimeRatio), - folly::to(globalArbitrationAbortTimeRatio)}}; - options.arbitrationStateCheckCb = std::move(arbitrationStateCheckCb); + folly::to( + arbitratorOptions.globalArbitrationAbortTimeRatio)}}; + options.arbitrationStateCheckCb = + std::move(arbitratorOptions.arbitrationStateCheckCb); options.checkUsageLeak = true; manager_ = std::make_unique(options); ASSERT_EQ(manager_->arbitrator()->kind(), arbitratorKind); @@ -516,9 +537,9 @@ class MockSharedArbitrationTest : public testing::Test { std::shared_ptr addTask( int64_t capacity = kMaxMemory, - uint32_t taskPriority = 0) { + int32_t priority = 0) { auto task = std::make_shared(); - task->initTaskPool(manager_.get(), capacity, taskPriority); + task->initTaskPool(manager_.get(), capacity, priority); return task; } @@ -821,40 +842,26 @@ TEST_F(MockSharedArbitrationTest, extraConfigs) { "Failed while parsing SharedArbitrator configs"); // Invalid memory reclaim executor hw multiplier. VELOX_ASSERT_THROW( - setupMemory(kMemoryCapacity, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1), + setupMemory({.memoryReclaimThreadsHwMultiplier = -1}), "memoryReclaimThreadsHwMultiplier_ needs to be positive"); // Invalid global arbitration reclaim pct. VELOX_ASSERT_THROW( - setupMemory(kMemoryCapacity, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 200), + setupMemory({.globalArbitrationReclaimPct = 200}), "(200 vs. 100) Invalid globalArbitrationMemoryReclaimPct"); // Invalid max memory arbitration time. VELOX_ASSERT_THROW( setupMemory( - kMemoryCapacity, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - nullptr, - false, - 0), + {.memoryReclaimThreadsHwMultiplier = 0, + .globalArtbitrationEnabled = false, + .arbitrationTimeoutNs = 0}), "(0 vs. 0) maxArbitrationTimeNs can't be zero"); } TEST_F(MockSharedArbitrationTest, constructor) { setupMemory( - kMemoryCapacity, - kReservedMemoryCapacity, - kMemoryPoolInitCapacity, - kMemoryPoolReservedCapacity); + {.reservedMemoryCapacity = kReservedMemoryCapacity, + .memoryPoolInitCapacity = kMemoryPoolInitCapacity, + .memoryPoolReserveCapacity = kMemoryPoolReservedCapacity}); const int reservedCapacity = arbitrator_->stats().freeReservedCapacityBytes; const int nonReservedCapacity = arbitrator_->stats().freeCapacityBytes - reservedCapacity; @@ -890,7 +897,10 @@ TEST_F(MockSharedArbitrationTest, arbitrationStateCheck) { ASSERT_TRUE(RE2::FullMatch(pool.name(), re)) << pool.name(); ++checkCount; }; - setupMemory(memCapacity, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0, checkCountCb); + setupMemory( + {.memoryCapacity = memCapacity, + .memoryReclaimThreadsHwMultiplier = 1.0, + .arbitrationStateCheckCb = checkCountCb}); const int numTasks{5}; std::vector> tasks; @@ -915,7 +925,10 @@ TEST_F(MockSharedArbitrationTest, arbitrationStateCheck) { MemoryArbitrationStateCheckCB badCheckCb = [&](MemoryPool& /*unused*/) { VELOX_FAIL("bad check"); }; - setupMemory(memCapacity, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0, badCheckCb); + setupMemory( + {.memoryCapacity = memCapacity, + .memoryReclaimThreadsHwMultiplier = 1.0, + .arbitrationStateCheckCb = badCheckCb}); std::shared_ptr task = addTask(kMemoryCapacity); ASSERT_EQ(task->capacity(), 0); MockMemoryOperator* memOp = task->addMemoryOp(); @@ -925,7 +938,9 @@ TEST_F(MockSharedArbitrationTest, arbitrationStateCheck) { TEST_F(MockSharedArbitrationTest, asyncArbitrationWork) { const int memoryCapacity = 512 * MB; const int poolCapacity = 256 * MB; - setupMemory(memoryCapacity, 0, poolCapacity, 0); + setupMemory( + {.memoryCapacity = memoryCapacity, + .memoryPoolInitCapacity = poolCapacity}); std::atomic_int reclaimedCount{0}; std::shared_ptr task = addTask(poolCapacity); @@ -1231,8 +1246,8 @@ TEST_F(MockSharedArbitrationTest, shrinkPools) { memoryPoolCapacity, memoryPoolCapacity, false}, - // Global arbitration choose to abort the younger participant with same - // capacity bucket. + // Global arbitration choose to abort the younger participant with + // same capacity bucket. {memoryPoolCapacity, false, memoryPoolCapacity, 0, 0, true}, {memoryPoolCapacity, true, memoryPoolCapacity / 2, 0, 0, false}, {memoryPoolCapacity, true, memoryPoolCapacity / 2, 0, 0, false}}, @@ -1274,8 +1289,8 @@ TEST_F(MockSharedArbitrationTest, shrinkPools) { false}, {memoryPoolCapacity, true, memoryPoolCapacity, 0, 0, true}, {memoryPoolCapacity, false, memoryPoolCapacity / 2, 0, 0, true}, - // Global arbitration choose to abort the younger participant with same - // capacity bucket. + // Global arbitration choose to abort the younger participant with + // same capacity bucket. {memoryPoolCapacity, false, memoryPoolCapacity / 2, 0, 0, true}}, memoryPoolCapacity, memoryCapacity / 2 + memoryPoolCapacity / 2, @@ -1288,8 +1303,8 @@ TEST_F(MockSharedArbitrationTest, shrinkPools) { {{memoryPoolCapacity, true, memoryPoolCapacity, 0, 0, true}, {memoryPoolCapacity, true, memoryPoolCapacity, 0, 0, true}, {memoryPoolCapacity, false, memoryPoolCapacity / 2, 0, 0, true}, - // Global arbitration choose to abort the younger participant with same - // capacity bucket. + // Global arbitration choose to abort the younger participant with + // same capacity bucket. {memoryPoolCapacity, false, memoryPoolCapacity / 2, 0, 0, true}}, memoryPoolCapacity, 0, @@ -1320,7 +1335,9 @@ TEST_F(MockSharedArbitrationTest, shrinkPools) { SCOPED_TRACE(testData.debugString()); // Make simple settings to focus shrink capacity logic testing. - setupMemory(memoryCapacity, 0, testData.memoryPoolInitCapacity); + setupMemory( + {.memoryCapacity = memoryCapacity, + .memoryPoolInitCapacity = testData.memoryPoolInitCapacity}); std::vector taskContainers; for (const auto& testTask : testData.testTasks) { auto task = addTask(testTask.capacity); @@ -1361,7 +1378,7 @@ TEST_F(MockSharedArbitrationTest, shrinkPools) { TEST_F(MockSharedArbitrationTest, shrinkPoolsDelayedAbort) { const int64_t memoryCapacity = 256 * MB; - setupMemory(memoryCapacity); + setupMemory({.memoryCapacity = memoryCapacity}); // Create first task using half the memory auto task1 = addTask(128 * MB); @@ -1404,7 +1421,7 @@ TEST_F(MockSharedArbitrationTest, shrinkPoolsDelayedAbort) { // serial execution mode. DEBUG_ONLY_TEST_F(MockSharedArbitrationTest, localArbitrationsFromSameQuery) { const int64_t memoryCapacity = 256 << 20; - setupMemory(memoryCapacity); + setupMemory({.memoryCapacity = memoryCapacity}); auto runTask = addTask(memoryCapacity); auto* runPool = runTask->addMemoryOp(true); auto* waitPool = runTask->addMemoryOp(true); @@ -1484,7 +1501,7 @@ DEBUG_ONLY_TEST_F( localArbitrationsFromDifferentQueries) { const int64_t memoryCapacity = 512 << 20; const uint64_t memoryPoolCapacity = memoryCapacity / 2; - setupMemory(memoryCapacity); + setupMemory({.memoryCapacity = memoryCapacity}); auto task1 = addTask(memoryPoolCapacity); auto* op1 = task1->addMemoryOp(true); @@ -1692,7 +1709,9 @@ TEST_F(MockSharedArbitrationTest, badNonReclaimableQuery) { SCOPED_TRACE(testData.debugString()); // Make simple settings to focus shrink capacity logic testing. - setupMemory(memoryCapacity, 0, 0, 0, 0, 0, 0, 0, 0, 0, memoryCapacity / 8); + setupMemory( + {.memoryCapacity = memoryCapacity, + .memoryPoolAbortCapacityLimit = memoryCapacity / 8}); std::vector taskContainers; for (const auto& testTask : testData.testTasks) { auto task = addTask(memoryCapacity); @@ -1739,7 +1758,9 @@ DEBUG_ONLY_TEST_F( const uint64_t memoryPoolReservedCapacity = 8 << 20; const uint64_t reservedMemoryCapacity = 64 << 20; setupMemory( - memoryCapacity, reservedMemoryCapacity, 0, memoryPoolReservedCapacity); + {.memoryCapacity = memoryCapacity, + .reservedMemoryCapacity = reservedMemoryCapacity, + .memoryPoolReserveCapacity = memoryPoolReservedCapacity}); auto globalArbitrationTriggerThread = std::thread([&]() { std::unordered_map runtimeStats; @@ -1816,7 +1837,9 @@ DEBUG_ONLY_TEST_F( const uint64_t memoryPoolCapacity = 64 << 20; const uint64_t memoryPoolReservedCapacity = 8 << 20; setupMemory( - memoryCapacity, reservedMemoryCapacity, 0, memoryPoolReservedCapacity); + {.memoryCapacity = memoryCapacity, + .reservedMemoryCapacity = reservedMemoryCapacity, + .memoryPoolReserveCapacity = memoryPoolReservedCapacity}); auto localArbitrationTask = addTask(memoryPoolCapacity); auto* localArbitrationOp = localArbitrationTask->addMemoryOp(true); @@ -1908,24 +1931,11 @@ DEBUG_ONLY_TEST_F(MockSharedArbitrationTest, globalArbitrationAbortTimeRatio) { const uint64_t abortTimeThresholdNs = maxArbitrationTimeNs * globalArbitrationAbortTimeRatio; setupMemory( - memoryCapacity, - 0, - memoryPoolInitCapacity, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - kMemoryReclaimThreadsHwMultiplier, - nullptr, - true, - maxArbitrationTimeNs, - false, - globalArbitrationAbortTimeRatio); + {.memoryCapacity = memoryCapacity, + .memoryPoolInitCapacity = memoryPoolInitCapacity, + .memoryReclaimThreadsHwMultiplier = kMemoryReclaimThreadsHwMultiplier, + .arbitrationTimeoutNs = maxArbitrationTimeNs, + .globalArbitrationAbortTimeRatio = globalArbitrationAbortTimeRatio}); test::SharedArbitratorTestHelper arbitratorHelper(arbitrator_); @@ -1989,23 +1999,11 @@ TEST_F(MockSharedArbitrationTest, globalArbitrationWithoutSpill) { const int64_t memoryCapacity = 512 << 20; const uint64_t memoryPoolInitCapacity = memoryCapacity / 2; setupMemory( - memoryCapacity, - 0, - memoryPoolInitCapacity, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - kMemoryReclaimThreadsHwMultiplier, - nullptr, - true, - 5 * 60 * 1'000'000'000UL, - true); + {.memoryCapacity = memoryCapacity, + .memoryPoolInitCapacity = memoryPoolInitCapacity, + .memoryReclaimThreadsHwMultiplier = kMemoryReclaimThreadsHwMultiplier, + .arbitrationTimeoutNs = 5 * 60 * 1'000'000'000UL, + .globalArbitrationWithoutSpill = true}); auto triggerTask = addTask(memoryCapacity); auto* triggerOp = triggerTask->addMemoryOp(false); triggerOp->allocate(memoryCapacity / 2); @@ -2037,28 +2035,16 @@ TEST_F(MockSharedArbitrationTest, globalArbitrationWithoutSpill) { } TEST_F(MockSharedArbitrationTest, globalArbitrationSmallParticipantLargeGrow) { - // This test tests global arbitration takes into consideration the additional - // attempting grow capacity when selecting abort partitipants. + // This test tests global arbitration takes into consideration the + // additional attempting grow capacity when selecting abort partitipants. const int64_t kMemoryCapacity = 512 << 20; const uint64_t kMemoryPoolInitCapacity = kMemoryCapacity / 2; setupMemory( - kMemoryCapacity, - 0, - kMemoryPoolInitCapacity, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - kMemoryCapacity, // Set abort capacity limit to differenciate capacity. - 0, - kMemoryReclaimThreadsHwMultiplier, - nullptr, - true, - 5 * 60 * 1'000'000'000UL, - true); + {.memoryCapacity = kMemoryCapacity, + .memoryPoolInitCapacity = kMemoryPoolInitCapacity, + // Set abort capacity limit to differenciate capacity. + .memoryPoolAbortCapacityLimit = kMemoryCapacity, + .globalArbitrationWithoutSpill = true}); auto task0 = addTask(kMemoryCapacity); auto* op0 = task0->addMemoryOp(false); @@ -2095,71 +2081,159 @@ TEST_F(MockSharedArbitrationTest, globalArbitrationSmallParticipantLargeGrow) { "Memory pool aborted to reclaim used memory"); } -TEST_F(MockSharedArbitrationTest, globalArbitrationWithMemoryPoolPriority) { - // This test tests global arbitration takes into consideration query priority - // attempting to grow capacity when selecting abort partitipants. +TEST_F(MockSharedArbitrationTest, globalArbitrationBySpillWithPriority) { + const int64_t memoryCapacity = 512 << 20; + + struct TaskData { + uint64_t capacity; + int32_t priority; + bool expectSpill; + std::string debugString() const { + return fmt::format( + "capacity {}, priority {}, expectSpill {}", + succinctBytes(capacity), + priority, + expectSpill); + } + }; + + struct TestData { + std::string testName; + uint64_t shrinkBytes; + uint64_t spillCapacityLimit; + uint64_t spillCapacityLowerBound; + std::vector tasks; + std::string debugString() const { + std::stringstream ss; + for (const auto& task : tasks) { + ss << task.debugString() << ", "; + } + return fmt::format( + "testName {}, shrinkBytes {}, tasks [{}]", + testName, + shrinkBytes, + ss.str()); + } + }; + + std::vector testSettings = { + {"test-0", 64 << 20, 512 << 20, 0, {}}, + + {"test-1", + 64 << 20, + 512 << 20, + 0, + {{256 << 20, 0, true}, {192 << 20, 1, false}, {64 << 20, 2, false}}}, + + {"test-2", + 64 << 20, + 512 << 20, + 0, + {{192 << 20, 2, true}, {192 << 20, 1, false}, {128 << 20, 0, false}}}, + + {"test-3", + 256 << 20, + 512 << 20, + 0, + {{144 << 20, 1, true}, + {136 << 20, 1, true}, + {128 << 20, 1, false}, + {104 << 20, 0, false}}}, + + {"test-4", + 256 << 20, + 512 << 20, + 128 << 20, + {{144 << 20, 1, true}, + {120 << 20, 1, false}, + {120 << 20, 1, false}, + {104 << 20, 0, false}}}, + + {"test-5", + 230 << 20, + 512 << 20, + 0, + { + {48 << 20, 3, true}, + {44 << 20, 3, true}, + {40 << 20, 3, true}, + {52 << 20, 2, true}, + {48 << 20, 2, true}, + {44 << 20, 2, false}, + {56 << 20, 1, false}, + {52 << 20, 1, false}, + {48 << 20, 1, false}, + {28 << 20, 4, false}, + {28 << 20, 4, false}, + {24 << 20, 4, false}, + }}, + }; + + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + setupMemory( + {.memoryCapacity = memoryCapacity, + .memoryPoolMinReclaimBytes = testData.spillCapacityLowerBound, + .memoryPoolSpillCapacityLimit = testData.spillCapacityLimit, + .globalArbitrationWithoutSpill = true}); + + std::vector> tasks; + for (const auto& taskData : testData.tasks) { + tasks.push_back(addTask(taskData.capacity, taskData.priority)); + auto* op = tasks.back()->addMemoryOp(true); + op->allocate(taskData.capacity); + } + + arbitrator_->shrinkCapacity(testData.shrinkBytes, true, false); + for (auto i = 0; i < tasks.size(); ++i) { + ASSERT_EQ( + tasks[i]->capacity() < testData.tasks[i].capacity, + testData.tasks[i].expectSpill); + } + } +} + +TEST_F(MockSharedArbitrationTest, globalArbitrationByAbortWithPriority) { + // This test tests global arbitration takes into consideration query + // priority attempting to grow capacity when selecting abort partitipants. const int64_t memoryCapacity = 512 << 20; const uint64_t memoryPoolInitCapacity = memoryCapacity / 2; setupMemory( - memoryCapacity, - 0, - memoryPoolInitCapacity, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - memoryCapacity, // Set abort capacity limit to differenciate capacity. - 0, - kMemoryReclaimThreadsHwMultiplier, - nullptr, - true, - 5 * 60 * 1'000'000'000UL, - true); + {.memoryCapacity = memoryCapacity, + .memoryPoolInitCapacity = memoryPoolInitCapacity, + // Set abort capacity limit to differenciate capacity. + .memoryPoolAbortCapacityLimit = memoryCapacity, + .globalArbitrationWithoutSpill = true}); - // task0 is normal priority with 256MB capacity with initial allocation of - // 256MB - auto task0 = addTask(memoryCapacity / 2, 100); + auto task0 = addTask(384 << 20, 1); auto* op0 = task0->addMemoryOp(false); - op0->allocate(memoryCapacity / 2); + op0->allocate(384 << 20); - // task1 is low priority with 256MB capacity with initial allocation of 256MB - auto task1 = addTask(memoryCapacity / 2, 10); - auto* op1 = task1->addMemoryOp(true); - op1->allocate(memoryCapacity / 2); + auto task1 = addTask(64 << 20, 1); + auto* op1 = task1->addMemoryOp(false); + op1->allocate(64 << 20); - // task2 is normal priority in lower bucket has 256MB capacity with 0 - // allocation - auto task2 = addTask(memoryCapacity / 2, 999); - auto* op2 = task2->addMemoryOp(true); - - std::unordered_map runtimeStats; - auto statsWriter = std::make_unique(runtimeStats); - setThreadLocalRunTimeStatWriter(statsWriter.get()); + auto task2 = addTask(64 << 20, 2); + auto* op2 = task2->addMemoryOp(false); + op2->allocate(64 << 20); // At this point, memory pool is full ASSERT_EQ(manager_->capacity(), manager_->getTotalBytes()); - // Next allocation should succeed with side effect of lowest priority - // query getting killed. - op2->allocate(memoryCapacity / 2); - - ASSERT_EQ( - runtimeStats[SharedArbitrator::kMemoryArbitrationWallNanos].count, 1); - ASSERT_GT(runtimeStats[SharedArbitrator::kMemoryArbitrationWallNanos].sum, 0); - ASSERT_EQ( - runtimeStats[SharedArbitrator::kGlobalArbitrationWaitCount].count, 1); - ASSERT_EQ(runtimeStats[SharedArbitrator::kGlobalArbitrationWaitCount].sum, 1); - ASSERT_EQ(runtimeStats[SharedArbitrator::kLocalArbitrationCount].count, 0); - - // task1 gets aborted since its lowest priority compared to task0 - // task2 is younger in same bucket but survives due to priority. + arbitrator_->shrinkCapacity(64 << 20, false, true); ASSERT_TRUE(task0->error() == nullptr); - ASSERT_TRUE(task1->error() != nullptr); - ASSERT_TRUE(task2->error() == nullptr); + ASSERT_TRUE(task1->error() == nullptr); + VELOX_ASSERT_THROW( + std::rethrow_exception(task2->error()), + "Memory pool aborted to reclaim used memory"); + + arbitrator_->shrinkCapacity(64 << 20, false, true); + VELOX_ASSERT_THROW( + std::rethrow_exception(task0->error()), + "Memory pool aborted to reclaim used memory"); + ASSERT_TRUE(task1->error() == nullptr); + arbitrator_->shrinkCapacity(64 << 20, false, true); VELOX_ASSERT_THROW( std::rethrow_exception(task1->error()), "Memory pool aborted to reclaim used memory"); @@ -2168,7 +2242,9 @@ TEST_F(MockSharedArbitrationTest, globalArbitrationWithMemoryPoolPriority) { DEBUG_ONLY_TEST_F(MockSharedArbitrationTest, multipleGlobalRuns) { const int64_t memoryCapacity = 512 << 20; const uint64_t memoryPoolInitCapacity = memoryCapacity / 2; - setupMemory(memoryCapacity, 0, memoryPoolInitCapacity, 0); + setupMemory( + {.memoryCapacity = memoryCapacity, + .memoryPoolInitCapacity = memoryPoolInitCapacity}); auto runTask = addTask(memoryCapacity); auto* runPool = runTask->addMemoryOp(true); runPool->allocate(memoryCapacity / 2); @@ -2253,21 +2329,15 @@ TEST_F(MockSharedArbitrationTest, globalArbitrationEnableCheck) { const int64_t memoryCapacity = 512 << 20; const uint64_t memoryPoolInitCapacity = memoryCapacity / 2; setupMemory( - memoryCapacity, - 0, - memoryPoolInitCapacity, - 0, - kFastExponentialGrowthCapacityLimit, - kSlowCapacityGrowPct, - kMemoryPoolMinFreeCapacity, - kMemoryPoolMinFreeCapacityPct, - 0, - 0, - 0, - kGlobalArbitrationReclaimPct, - kMemoryReclaimThreadsHwMultiplier, - nullptr, - globalArbitrationEnabled); + {.memoryCapacity = memoryCapacity, + .memoryPoolInitCapacity = memoryPoolInitCapacity, + .fastExponentialGrowthCapacityLimit = + kFastExponentialGrowthCapacityLimit, + .slowCapacityGrowPct = kSlowCapacityGrowPct, + .memoryPoolMinFreeCapacity = kMemoryPoolMinFreeCapacity, + .memoryPoolMinFreeCapacityPct = kMemoryPoolMinFreeCapacityPct, + .globalArbitrationReclaimPct = kGlobalArbitrationReclaimPct, + .globalArtbitrationEnabled = globalArbitrationEnabled}); test::SharedArbitratorTestHelper arbitratorHelper(arbitrator_); ASSERT_EQ( @@ -2327,27 +2397,24 @@ TEST_F(MockSharedArbitrationTest, singlePoolShrinkWithoutArbitration) { if (testParam.expectThrow) { VELOX_ASSERT_THROW( setupMemory( - memoryCapacity, - 0, - memoryCapacity, - 0, - 0, - 0, - testParam.memoryPoolMinFreeCapacity, - testParam.memoryPoolMinFreeCapacityPct), + {memoryCapacity, + 0, + memoryCapacity, + 0, + 0, + 0, + testParam.memoryPoolMinFreeCapacity, + testParam.memoryPoolMinFreeCapacityPct}), "both need to be set (non-zero) at the same time to enable shrink " "capacity adjustment."); continue; } else { setupMemory( - memoryCapacity, - 0, - memoryCapacity, - 0, - 0, - 0, - testParam.memoryPoolMinFreeCapacity, - testParam.memoryPoolMinFreeCapacityPct); + {.memoryCapacity = memoryCapacity, + .memoryPoolInitCapacity = memoryCapacity, + .memoryPoolMinFreeCapacity = testParam.memoryPoolMinFreeCapacity, + .memoryPoolMinFreeCapacityPct = + testParam.memoryPoolMinFreeCapacityPct}); } auto task = addTask(); @@ -2387,12 +2454,11 @@ TEST_F(MockSharedArbitrationTest, singlePoolGrowWithoutArbitration) { for (const auto& testParam : testParams) { SCOPED_TRACE(testParam.debugString()); setupMemory( - memoryCapacity, - 0, - memoryPoolInitCapacity, - 0, - testParam.fastExponentialGrowthCapacityLimit, - testParam.slowCapacityGrowPct); + {.memoryCapacity = memoryCapacity, + .memoryPoolInitCapacity = memoryPoolInitCapacity, + .fastExponentialGrowthCapacityLimit = + testParam.fastExponentialGrowthCapacityLimit, + .slowCapacityGrowPct = testParam.slowCapacityGrowPct}); auto* memOp = addMemoryOp(); const int allocateSize = 1 * MB; @@ -2434,8 +2500,8 @@ TEST_F(MockSharedArbitrationTest, singlePoolGrowWithoutArbitration) { TEST_F(MockSharedArbitrationTest, maxCapacityReserve) { struct { - uint64_t memCapacity; - uint64_t reservedCapacity; + int64_t memCapacity; + int64_t reservedCapacity; uint64_t poolInitCapacity; uint64_t poolReservedCapacity; uint64_t poolMaxCapacity; @@ -2466,10 +2532,10 @@ TEST_F(MockSharedArbitrationTest, maxCapacityReserve) { for (const auto& testData : testSettings) { SCOPED_TRACE(testData.debugString()); setupMemory( - testData.memCapacity, - testData.reservedCapacity, - testData.poolInitCapacity, - testData.poolReservedCapacity); + {.memoryCapacity = testData.memCapacity, + .reservedMemoryCapacity = testData.reservedCapacity, + .memoryPoolInitCapacity = testData.poolInitCapacity, + .memoryPoolReserveCapacity = testData.poolReservedCapacity}); if (testData.expectedError) { VELOX_ASSERT_THROW(addTask(testData.poolMaxCapacity), ""); continue; @@ -2607,14 +2673,11 @@ TEST_F(MockSharedArbitrationTest, ensureMemoryPoolMaxCapacity) { for (const auto& testData : testSettings) { SCOPED_TRACE(testData.debugString()); setupMemory( - memCapacity, - 0, - poolInitCapacity, - 0, - kFastExponentialGrowthCapacityLimit, - kSlowCapacityGrowPct, - 0, - 0); + {.memoryCapacity = memCapacity, + .memoryPoolInitCapacity = poolInitCapacity, + .fastExponentialGrowthCapacityLimit = + kFastExponentialGrowthCapacityLimit, + .slowCapacityGrowPct = kSlowCapacityGrowPct}); auto requestor = addTask(testData.poolMaxCapacity); auto* requestorOp = addMemoryOp(requestor, testData.isReclaimable); @@ -2655,7 +2718,7 @@ TEST_F(MockSharedArbitrationTest, ensureMemoryPoolMaxCapacity) { TEST_F(MockSharedArbitrationTest, ensureNodeMaxCapacity) { struct { - uint64_t nodeCapacity; + int64_t nodeCapacity; uint64_t poolMaxCapacity; bool isReclaimable; uint64_t allocatedBytes; @@ -2688,7 +2751,7 @@ TEST_F(MockSharedArbitrationTest, ensureNodeMaxCapacity) { for (const auto& testData : testSettings) { SCOPED_TRACE(testData.debugString()); - setupMemory(testData.nodeCapacity, 0, 0, 0); + setupMemory({.memoryCapacity = testData.nodeCapacity}); auto requestor = addTask(testData.poolMaxCapacity); auto* requestorOp = addMemoryOp(requestor, testData.isReclaimable); @@ -2711,8 +2774,8 @@ TEST_F(MockSharedArbitrationTest, ensureNodeMaxCapacity) { } DEBUG_ONLY_TEST_F(MockSharedArbitrationTest, arbitrationAbort) { - uint64_t memoryCapacity = 256 * MB; - setupMemory(memoryCapacity); + int64_t memoryCapacity = 256 * MB; + setupMemory({.memoryCapacity = memoryCapacity}); std::shared_ptr task1 = addTask(memoryCapacity); auto* op1 = task1->addMemoryOp(true, [&](MemoryPool* /*unsed*/, uint64_t /*unsed*/) { @@ -2759,8 +2822,8 @@ DEBUG_ONLY_TEST_F(MockSharedArbitrationTest, arbitrationAbort) { } TEST_F(MockSharedArbitrationTest, shutdown) { - uint64_t memoryCapacity = 256 * MB; - setupMemory(memoryCapacity); + int64_t memoryCapacity = 256 * MB; + setupMemory({.memoryCapacity = memoryCapacity}); arbitrator_->shutdown(); // double shutdown. arbitrator_->shutdown(); @@ -2779,24 +2842,11 @@ TEST_F(MockSharedArbitrationTest, shutdown) { } DEBUG_ONLY_TEST_F(MockSharedArbitrationTest, shutdownWait) { - uint64_t memoryCapacity = 256 * MB; + int64_t memoryCapacity = 256 * MB; setupMemory( - memoryCapacity, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 1.0, - nullptr, - true, - 2'000'000'000UL); + {.memoryCapacity = memoryCapacity, + .memoryReclaimThreadsHwMultiplier = 1.0, + .arbitrationTimeoutNs = 2'000'000'000UL}); std::shared_ptr task1 = addTask(memoryCapacity); auto* op1 = task1->addMemoryOp(true); op1->allocate(memoryCapacity / 2); @@ -2956,17 +3006,9 @@ TEST_F(MockSharedArbitrationTest, memoryPoolAbortCapacityLimit) { for (const auto& testData : testSettings) { SCOPED_TRACE(testData.debugString()); setupMemory( - memoryCapacity, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - testData.memoryPoolAbortCapacityLimit); + {.memoryCapacity = memoryCapacity, + .memoryPoolAbortCapacityLimit = + testData.memoryPoolAbortCapacityLimit}); std::vector taskContainers; for (const auto& testTask : testData.testTasks) { @@ -2998,8 +3040,8 @@ TEST_F(MockSharedArbitrationTest, memoryPoolAbortCapacityLimit) { DEBUG_ONLY_TEST_F( MockSharedArbitrationTest, globalArbitrationWaitReturnEarlyWithFreeCapacity) { - uint64_t memoryCapacity = 256 * MB; - setupMemory(memoryCapacity); + int64_t memoryCapacity = 256 * MB; + setupMemory({.memoryCapacity = memoryCapacity}); std::shared_ptr task1 = addTask(memoryCapacity); auto* op1 = task1->addMemoryOp(true); op1->allocate(memoryCapacity / 2); @@ -3051,24 +3093,11 @@ DEBUG_ONLY_TEST_F( } DEBUG_ONLY_TEST_F(MockSharedArbitrationTest, globalArbitrationTimeout) { - uint64_t memoryCapacity = 256 * MB; + int64_t memoryCapacity = 256 * MB; setupMemory( - memoryCapacity, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 1.0, - nullptr, - true, - 1'000'000'000UL); + {.memoryCapacity = memoryCapacity, + .memoryReclaimThreadsHwMultiplier = 1.0, + .arbitrationTimeoutNs = 1'000'000'000UL}); std::shared_ptr task1 = addTask(memoryCapacity); auto* op1 = task1->addMemoryOp(true); op1->allocate(memoryCapacity / 2); @@ -3109,31 +3138,18 @@ DEBUG_ONLY_TEST_F(MockSharedArbitrationTest, globalArbitrationTimeout) { } DEBUG_ONLY_TEST_F(MockSharedArbitrationTest, localArbitrationTimeout) { - uint64_t memoryCapacity = 256 * MB; + int64_t memoryCapacity = 256 * MB; setupMemory( - memoryCapacity, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 1.0, - nullptr, - true, - 1'000'000'000UL); + {.memoryCapacity = memoryCapacity, + .memoryReclaimThreadsHwMultiplier = 1.0, + .arbitrationTimeoutNs = 1'000'000'000UL}); std::shared_ptr task = addTask(memoryCapacity); ASSERT_EQ(task->capacity(), 0); auto* op = task->addMemoryOp(true); op->allocate(memoryCapacity / 2); SCOPED_TESTVALUE_SET( - "facebook::velox::memory::ArbitrationParticipant::reclaim", + "facebook::velox::memory::SharedArbitrator::growCapacity", std::function( ([&](const ArbitrationParticipant* /*unused*/) { std::this_thread::sleep_for(std::chrono::seconds(2)); // NOLINT @@ -3147,30 +3163,18 @@ DEBUG_ONLY_TEST_F(MockSharedArbitrationTest, localArbitrationTimeout) { testing::HasSubstr("Memory arbitration timed out on memory pool")); } - // Reclaim happened before timeout check. - ASSERT_EQ(task->capacity(), 0); + // Timeout check happened before reclaim. + ASSERT_EQ(task->capacity(), memoryCapacity / 2); } DEBUG_ONLY_TEST_F(MockSharedArbitrationTest, reclaimLockTimeout) { const uint64_t memoryCapacity = 256 * MB; const uint64_t arbitrationTimeoutMs = 1'000; setupMemory( - memoryCapacity, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 1.0, - nullptr, - false, - arbitrationTimeoutMs); + {.memoryCapacity = memoryCapacity, + .memoryReclaimThreadsHwMultiplier = 1.0, + .globalArtbitrationEnabled = false, + .arbitrationTimeoutNs = arbitrationTimeoutMs}); std::shared_ptr task = addTask(memoryCapacity); ASSERT_EQ(task->capacity(), 0); auto* op = task->addMemoryOp(true); @@ -3187,8 +3191,8 @@ DEBUG_ONLY_TEST_F(MockSharedArbitrationTest, reclaimLockTimeout) { "facebook::velox::memory::ArbitrationParticipant::reclaim", std::function( ([&](const ArbitrationParticipant* /*unused*/) { - // Timeout shall be enforced at lock level. We don't expect code to - // execute pass the lock in reclaim method. + // Timeout shall be enforced at lock level. We don't expect code + // to execute pass the lock in reclaim method. FAIL(); }))); @@ -3207,24 +3211,11 @@ DEBUG_ONLY_TEST_F(MockSharedArbitrationTest, reclaimLockTimeout) { } DEBUG_ONLY_TEST_F(MockSharedArbitrationTest, localArbitrationQueueTimeout) { - uint64_t memoryCapacity = 256 * MB; + int64_t memoryCapacity = 256 * MB; setupMemory( - memoryCapacity, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 1.0, - nullptr, - true, - 1'000'000'000UL); + {.memoryCapacity = memoryCapacity, + .memoryReclaimThreadsHwMultiplier = 1.0, + .arbitrationTimeoutNs = 1'000'000'000UL}); std::shared_ptr task = addTask(memoryCapacity); ASSERT_EQ(task->capacity(), 0); auto* op = task->addMemoryOp(true); @@ -3417,17 +3408,10 @@ TEST_F(MockSharedArbitrationTest, minReclaimBytes) { SCOPED_TRACE(testData.debugString()); // Make simple settings to focus shrink capacity logic testing. setupMemory( - memoryCapacity, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - testData.minReclaimBytes, - 0, - memoryCapacity); + {.memoryCapacity = memoryCapacity, + .memoryPoolMinReclaimBytes = testData.minReclaimBytes, + .memoryPoolSpillCapacityLimit = memoryCapacity, + .memoryPoolAbortCapacityLimit = memoryCapacity}); std::vector taskContainers; for (const auto& testTask : testData.testTasks) { auto task = addTask(); @@ -3559,7 +3543,8 @@ TEST_F(MockSharedArbitrationTest, globalArbitrationReclaimPct) { SCOPED_TRACE(testData.debugString()); setupMemory( - memoryCapacity, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, testData.reclaimPct); + {.memoryCapacity = memoryCapacity, + .globalArbitrationReclaimPct = testData.reclaimPct}); std::vector taskContainers; for (const auto& testTask : testData.testTasks) { auto task = addTask(); @@ -3591,8 +3576,11 @@ TEST_F(MockSharedArbitrationTest, globalArbitrationReclaimPct) { } TEST_F(MockSharedArbitrationTest, noEligibleAbortCandidate) { - uint64_t memoryCapacity = 256 * MB; - setupMemory(memoryCapacity, memoryCapacity / 2, 0, memoryCapacity / 4); + int64_t memoryCapacity = 256 * MB; + setupMemory( + {.memoryCapacity = memoryCapacity, + .reservedMemoryCapacity = memoryCapacity / 2, + .memoryPoolReserveCapacity = static_cast(memoryCapacity / 4)}); std::shared_ptr task = addTask(memoryCapacity); ASSERT_EQ(task->capacity(), memoryCapacity / 4); auto* op = task->addMemoryOp(true); @@ -3604,7 +3592,9 @@ TEST_F(MockSharedArbitrationTest, growWithArbitrationAbort) { const int memCapacity = 256 * MB; const int initPoolCapacity = 8 * MB; setupMemory( - memCapacity, 0, initPoolCapacity, 0, 0, 0, 0, 0, 0, 0, memCapacity); + {.memoryCapacity = memCapacity, + .memoryPoolInitCapacity = initPoolCapacity, + .memoryPoolAbortCapacityLimit = memCapacity}); auto* reclaimableOp = addMemoryOp(nullptr, true); ASSERT_EQ(reclaimableOp->capacity(), initPoolCapacity); @@ -3640,10 +3630,10 @@ TEST_F(MockSharedArbitrationTest, growWithArbitrationAbort) { TEST_F(MockSharedArbitrationTest, singlePoolGrowCapacityWithArbitration) { const std::vector isLeafReclaimables = {false, true}; - const uint64_t memoryCapacity = 128 * MB; + const int64_t memoryCapacity = 128 * MB; for (const auto isLeafReclaimable : isLeafReclaimables) { SCOPED_TRACE(fmt::format("isLeafReclaimable {}", isLeafReclaimable)); - setupMemory(memoryCapacity); + setupMemory({.memoryCapacity = memoryCapacity}); auto* op = addMemoryOp(nullptr, isLeafReclaimable); op->allocate(memoryCapacity); verifyArbitratorStats(arbitrator_->stats(), memoryCapacity, 0, 0, 1); @@ -3676,15 +3666,17 @@ TEST_F(MockSharedArbitrationTest, singlePoolGrowCapacityWithArbitration) { } } -// This test verifies if a single memory pool fails to grow capacity because of -// reserved capacity. +// This test verifies if a single memory pool fails to grow capacity because +// of reserved capacity. // TODO: add reserved capacity check in ensure capacity. TEST_F(MockSharedArbitrationTest, singlePoolGrowCapacityFailedWithAbort) { const uint64_t memoryCapacity = 128 * MB; const uint64_t reservedMemoryCapacity = 64 * MB; const uint64_t memoryPoolReservedCapacity = 64 * MB; setupMemory( - memoryCapacity, reservedMemoryCapacity, 0, memoryPoolReservedCapacity); + {.memoryCapacity = memoryCapacity, + .reservedMemoryCapacity = reservedMemoryCapacity, + .memoryPoolReserveCapacity = memoryPoolReservedCapacity}); auto* op = addMemoryOp(nullptr, true); op->allocate(memoryCapacity - reservedMemoryCapacity); verifyArbitratorStats( @@ -3716,7 +3708,7 @@ TEST_F(MockSharedArbitrationTest, arbitrateWithCapacityShrink) { const std::vector isLeafReclaimables = {true, false}; for (const auto isLeafReclaimable : isLeafReclaimables) { SCOPED_TRACE(fmt::format("isLeafReclaimable {}", isLeafReclaimable)); - setupMemory(); + setupMemory({}); auto* reclaimedOp = addMemoryOp(nullptr, isLeafReclaimable); const int reclaimedOpCapacity = kMemoryCapacity * 2 / 3; const int allocateSize = 32 * MB; @@ -3754,17 +3746,10 @@ TEST_F(MockSharedArbitrationTest, arbitrateWithMemoryReclaim) { for (const auto isLeafReclaimable : isLeafReclaimables) { SCOPED_TRACE(fmt::format("isLeafReclaimable {}", isLeafReclaimable)); setupMemory( - memoryCapacity, - reservedMemoryCapacity, - 0, - reservedPoolCapacity, - 0, - 0, - 0, - 0, - 0, - 0, - memoryPoolAbortCapacityLimit); + {.memoryCapacity = memoryCapacity, + .reservedMemoryCapacity = reservedMemoryCapacity, + .memoryPoolReserveCapacity = reservedPoolCapacity, + .memoryPoolAbortCapacityLimit = memoryPoolAbortCapacityLimit}); auto* reclaimedOp = addMemoryOp(nullptr, isLeafReclaimable); reclaimedOp->allocate( memoryCapacity - reservedMemoryCapacity - reservedPoolCapacity); @@ -3796,22 +3781,9 @@ DEBUG_ONLY_TEST_F(MockSharedArbitrationTest, abortWithNoCandidate) { const uint64_t memoryCapacity = 256 * MB; const uint64_t maxArbitrationTimeNs = 1'000'000'000UL; setupMemory( - memoryCapacity, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 1.0, - nullptr, - true, - maxArbitrationTimeNs); + {.memoryCapacity = memoryCapacity, + .memoryReclaimThreadsHwMultiplier = 1.0, + .arbitrationTimeoutNs = maxArbitrationTimeNs}); auto* reclaimedOp1 = addMemoryOp(nullptr, false); reclaimedOp1->allocate(memoryCapacity / 2); auto* reclaimedOp2 = addMemoryOp(nullptr, false); @@ -3863,22 +3835,9 @@ DEBUG_ONLY_TEST_F(MockSharedArbitrationTest, reclaimWithNoCandidate) { const uint64_t memoryCapacity = 256 * MB; const uint64_t maxArbitrationTimeNs = 1'000'000'000UL; setupMemory( - memoryCapacity, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 1.0, - nullptr, - true, - maxArbitrationTimeNs); + {.memoryCapacity = memoryCapacity, + .memoryReclaimThreadsHwMultiplier = 1.0, + .arbitrationTimeoutNs = maxArbitrationTimeNs}); auto* reclaimedOp1 = addMemoryOp(nullptr, true); reclaimedOp1->allocate(memoryCapacity / 2); auto* reclaimedOp2 = addMemoryOp(nullptr, true); @@ -3930,21 +3889,11 @@ TEST_F(MockSharedArbitrationTest, arbitrateBySelfMemoryReclaim) { const uint64_t reservedCapacity = 8 * MB; const uint64_t poolReservedCapacity = 4 * MB; setupMemory( - memCapacity, - reservedCapacity, - reservedCapacity, - poolReservedCapacity, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - kMemoryReclaimThreadsHwMultiplier, - nullptr, - false); + {.memoryCapacity = memCapacity, + .reservedMemoryCapacity = reservedCapacity, + .memoryPoolInitCapacity = reservedCapacity, + .memoryPoolReserveCapacity = poolReservedCapacity, + .globalArtbitrationEnabled = false}); std::shared_ptr task = addTask(kMemoryCapacity); auto* memOp = addMemoryOp(task, isLeafReclaimable); const int allocateSize = 8 * MB; @@ -3992,7 +3941,7 @@ TEST_F(MockSharedArbitrationTest, noAbortOnRequestWhenArbitrationFails) { for (const auto& testData : testSettings) { SCOPED_TRACE(testData.debugString()); - setupMemory(memCapacity, 0); + setupMemory({.memoryCapacity = memCapacity}); std::shared_ptr task = addTask(kMemoryCapacity); auto* memOp = addMemoryOp(task, false); if (testData.initialAllocationSize != 0) { @@ -4023,7 +3972,7 @@ DEBUG_ONLY_TEST_F(MockSharedArbitrationTest, orderedArbitration) { } }))); SCOPED_TESTVALUE_SET( - "facebook::velox::memory::SharedArbitrator::sortCandidatesByReclaimableUsedCapacity", + "facebook::velox::memory::SharedArbitrator::sortSpillCandidates", std::function*)>( ([&](const std::vector* candidates) { for (int i = 1; i < candidates->size(); ++i) { @@ -4035,7 +3984,7 @@ DEBUG_ONLY_TEST_F(MockSharedArbitrationTest, orderedArbitration) { folly::Random::DefaultGenerator rng; rng.seed(512); - const uint64_t memCapacity = 512 * MB; + const int64_t memCapacity = 512 * MB; const uint64_t reservedMemCapacity = 128 * MB; const uint64_t initPoolCapacity = 32 * MB; const uint64_t reservedPoolCapacity = 8 * MB; @@ -4056,10 +4005,10 @@ DEBUG_ONLY_TEST_F(MockSharedArbitrationTest, orderedArbitration) { SCOPED_TRACE(testData.debugString()); setupMemory( - memCapacity, - reservedMemCapacity, - initPoolCapacity, - reservedPoolCapacity); + {.memoryCapacity = memCapacity, + .reservedMemoryCapacity = reservedMemCapacity, + .memoryPoolInitCapacity = initPoolCapacity, + .memoryPoolReserveCapacity = reservedPoolCapacity}); std::vector memOps; for (int i = 0; i < numTasks; ++i) { auto* memOp = addMemoryOp(); @@ -4092,7 +4041,9 @@ DEBUG_ONLY_TEST_F(MockSharedArbitrationTest, orderedArbitration) { TEST_F(MockSharedArbitrationTest, enterArbitrationException) { const uint64_t memCapacity = 128 * MB; const uint64_t initPoolCapacity = memCapacity; - setupMemory(memCapacity, 0, initPoolCapacity, 0); + setupMemory( + {.memoryCapacity = memCapacity, + .memoryPoolInitCapacity = initPoolCapacity}); auto* reclaimedOp = addMemoryOp(); ASSERT_EQ(reclaimedOp->capacity(), memCapacity); const int allocationSize = 8 * MB; @@ -4158,14 +4109,11 @@ TEST_F(MockSharedArbitrationTest, noArbitratiognFromAbortedPool) { TEST_F(MockSharedArbitrationTest, memoryReclaimeFailureTriggeredAbort) { setupMemory( - kMemoryCapacity, - 0, - kMemoryPoolInitCapacity, - 0, - kFastExponentialGrowthCapacityLimit, - kSlowCapacityGrowPct, - 0, - 0); + {.memoryCapacity = kMemoryCapacity, + .memoryPoolInitCapacity = kMemoryPoolInitCapacity, + .fastExponentialGrowthCapacityLimit = + kFastExponentialGrowthCapacityLimit, + .slowCapacityGrowPct = kSlowCapacityGrowPct}); const int numTasks = 4; const int smallTaskMemoryCapacity = kMemoryCapacity / 8; const int largeTaskMemoryCapacity = kMemoryCapacity / 2; @@ -4212,7 +4160,7 @@ TEST_F(MockSharedArbitrationTest, memoryReclaimeFailureTriggeredAbort) { // This test makes sure the memory capacity grows as expected. DEBUG_ONLY_TEST_F(MockSharedArbitrationTest, concurrentArbitrationRequests) { - setupMemory(kMemoryCapacity); + setupMemory({.memoryCapacity = kMemoryCapacity}); std::shared_ptr task = addTask(); MockMemoryOperator* op1 = addMemoryOp(task); MockMemoryOperator* op2 = addMemoryOp(task); @@ -4248,7 +4196,7 @@ DEBUG_ONLY_TEST_F(MockSharedArbitrationTest, concurrentArbitrationRequests) { DEBUG_ONLY_TEST_F( MockSharedArbitrationTest, freeUnusedCapacityWhenReclaimMemoryPool) { - setupMemory(kMemoryCapacity); + setupMemory({.memoryCapacity = kMemoryCapacity}); const int allocationSize = kMemoryCapacity / 4; std::shared_ptr reclaimedTask = addTask(); MockMemoryOperator* reclaimedTaskOp = addMemoryOp(reclaimedTask); @@ -4263,7 +4211,7 @@ DEBUG_ONLY_TEST_F( folly::EventCount reclaimBlock; std::atomic_bool reclaimBlockFlag{true}; SCOPED_TESTVALUE_SET( - "facebook::velox::memory::SharedArbitrator::sortCandidatesByReclaimableUsedCapacity", + "facebook::velox::memory::SharedArbitrator::sortAndGroupSpillCandidates", std::function(([&](const MemoryPool* /*unsed*/) { reclaimWaitFlag = false; reclaimWait.notifyAll(); @@ -4294,7 +4242,7 @@ DEBUG_ONLY_TEST_F( TEST_F(MockSharedArbitrationTest, arbitrationFailure) { int64_t maxCapacity = 128 * MB; - int64_t initialCapacity = 0 * MB; + uint64_t initialCapacity = 0 * MB; struct { int64_t requestorCapacity; int64_t requestorRequestBytes; @@ -4320,7 +4268,9 @@ TEST_F(MockSharedArbitrationTest, arbitrationFailure) { for (const auto& testData : testSettings) { SCOPED_TRACE(testData.debugString()); - setupMemory(maxCapacity, 0, initialCapacity); + setupMemory( + {.memoryCapacity = maxCapacity, + .memoryPoolInitCapacity = initialCapacity}); std::shared_ptr requestorTask = addTask(); MockMemoryOperator* requestorOp = addMemoryOp(requestorTask, false); requestorOp->allocate(testData.requestorCapacity); @@ -4371,21 +4321,11 @@ TEST_F( // Set min reclaim bytes to avoid reclaim from itself before fail the // arbitration. setupMemory( - memoryCapacity, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - hasMinReclaimBytes ? MB : 0, - 0, - 0, - 0, - 1.0, - nullptr, - false); + {.memoryCapacity = memoryCapacity, + .memoryPoolMinReclaimBytes = + (hasMinReclaimBytes ? static_cast(MB) : 0UL), + .memoryReclaimThreadsHwMultiplier = 1.0, + .globalArtbitrationEnabled = false}); std::shared_ptr task1 = addTask(); MockMemoryOperator* op1 = task1->addMemoryOp(false); op1->allocate(memoryCapacity / 4 * 3); @@ -4405,7 +4345,9 @@ TEST_F( reclaimBeforeReachCapacityLimitWhenGlobalArbitrationDisabled) { const int64_t memoryCapacity = 128 * MB; setupMemory( - memoryCapacity, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0, nullptr, false); + {.memoryCapacity = memoryCapacity, + .memoryReclaimThreadsHwMultiplier = 1.0, + .globalArtbitrationEnabled = false}); std::shared_ptr task1 = addTask(); MockMemoryOperator* op1 = task1->addMemoryOp(true); op1->allocate(memoryCapacity / 2); diff --git a/velox/common/memory/tests/RawVectorTest.cpp b/velox/common/memory/tests/RawVectorTest.cpp index 3d28f136d8cd..cfd9611b57c1 100644 --- a/velox/common/memory/tests/RawVectorTest.cpp +++ b/velox/common/memory/tests/RawVectorTest.cpp @@ -48,8 +48,19 @@ class RawVectorTest : public testing::WithParamInterface, std::shared_ptr pool_; }; +raw_vector makeRawVector( + int64_t initialCapacity = 0, + memory::MemoryPool* pool = nullptr) { + if (pool != nullptr) { + return raw_vector(initialCapacity, pool); + } else { + return raw_vector(initialCapacity); + } +} + TEST_P(RawVectorTest, basic) { - raw_vector ints(pool_.get()); + raw_vector ints = + makeRawVector(0, GetParam().useMemoryPool ? pool_.get() : nullptr); EXPECT_TRUE(ints.empty()); EXPECT_EQ(0, ints.capacity()); EXPECT_EQ(0, ints.size()); @@ -60,7 +71,8 @@ TEST_P(RawVectorTest, basic) { } TEST_P(RawVectorTest, padding) { - raw_vector ints(1000, pool_.get()); + raw_vector ints = + makeRawVector(1000, GetParam().useMemoryPool ? pool_.get() : nullptr); EXPECT_EQ(1000, ints.size()); // Check padding. Write a vector right below start and right after // capacity. These should fit and give no error with asan. @@ -71,7 +83,8 @@ TEST_P(RawVectorTest, padding) { } TEST_P(RawVectorTest, resize) { - raw_vector ints(1000, pool_.get()); + raw_vector ints = + makeRawVector(1000, GetParam().useMemoryPool ? pool_.get() : nullptr); ints.resize(ints.capacity()); auto size = ints.size(); ints[size - 1] = 12345; @@ -105,17 +118,24 @@ TEST_P(RawVectorTest, copyAndMove) { {leaf0.get(), nullptr}, {nullptr, leaf0.get()}}; for (auto& data : testData) { - raw_vector ints(1000, data.sourcePool); + raw_vector ints = makeRawVector(1000, data.sourcePool); // a raw_vector is intentionally not initialized. memset(ints.data(), 11, ints.size() * sizeof(int32_t)); ints[ints.size() - 1] = 12345; - raw_vector intsCopy(data.destPool); + raw_vector intsCopy; + if (data.destPool) { + intsCopy = raw_vector(data.destPool); + } intsCopy = ints; EXPECT_EQ( 0, memcmp(ints.data(), intsCopy.data(), ints.size() * sizeof(int32_t))); - raw_vector intsMoved(data.destPool); + raw_vector intsMoved; + if (data.destPool) { + intsMoved = raw_vector(data.destPool); + } intsMoved = std::move(ints); + // NOLINTNEXTLINE(bugprone-use-after-move) EXPECT_TRUE(ints.empty()); EXPECT_EQ( @@ -128,7 +148,8 @@ TEST_P(RawVectorTest, copyAndMove) { } TEST_P(RawVectorTest, iota) { - raw_vector storage(pool_.get()); + raw_vector storage = + makeRawVector(0, GetParam().useMemoryPool ? pool_.get() : nullptr); // Small sizes are preallocated. EXPECT_EQ(11, iota(12, storage)[11]); EXPECT_TRUE(storage.empty()); @@ -138,7 +159,8 @@ TEST_P(RawVectorTest, iota) { } TEST_P(RawVectorTest, iterator) { - raw_vector data(pool_.get()); + raw_vector data = + makeRawVector(0, GetParam().useMemoryPool ? pool_.get() : nullptr); data.push_back(11); data.push_back(22); data.push_back(33); @@ -150,7 +172,8 @@ TEST_P(RawVectorTest, iterator) { } TEST_P(RawVectorTest, toStdVector) { - raw_vector data(pool_.get()); + raw_vector data = + makeRawVector(0, GetParam().useMemoryPool ? pool_.get() : nullptr); data.push_back(11); data.push_back(22); data.push_back(33); diff --git a/velox/common/memory/tests/SharedArbitratorTest.cpp b/velox/common/memory/tests/SharedArbitratorTest.cpp index f01cc043da5c..96b052d9c52c 100644 --- a/velox/common/memory/tests/SharedArbitratorTest.cpp +++ b/velox/common/memory/tests/SharedArbitratorTest.cpp @@ -365,11 +365,12 @@ DEBUG_ONLY_TEST_P( .queryCtx(queryCtx) .spillDirectory(spillDirectory->getPath()) .config(core::QueryConfig::kSpillEnabled, "true") - .plan(PlanBuilder() - .values(vectors) - .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) - .capturePlanNodeId(aggregationNodeId) - .planNode()) + .plan( + PlanBuilder() + .values(vectors) + .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) + .capturePlanNodeId(aggregationNodeId) + .planNode()) .assertResults("SELECT c0, c1, array_agg(c2) FROM tmp GROUP BY c0, c1"); ASSERT_TRUE(queryCtxStateChecked); ASSERT_FALSE(queryCtx->testingUnderArbitration()); @@ -590,11 +591,12 @@ DEBUG_ONLY_TEST_P(SharedArbitrationTestWithThreadingModes, reclaimToOrderBy) { newQueryBuilder() .queryCtx(orderByQueryCtx) .serialExecution(isSerialExecutionMode_) - .plan(PlanBuilder() - .values(vectors) - .orderBy({"c0 ASC NULLS LAST"}, false) - .capturePlanNodeId(orderByNodeId) - .planNode()) + .plan( + PlanBuilder() + .values(vectors) + .orderBy({"c0 ASC NULLS LAST"}, false) + .capturePlanNodeId(orderByNodeId) + .planNode()) .assertResults("SELECT * FROM tmp ORDER BY c0 ASC NULLS LAST"); auto taskStats = exec::toPlanStats(task->taskStats()); auto& stats = taskStats.at(orderByNodeId); @@ -607,12 +609,13 @@ DEBUG_ONLY_TEST_P(SharedArbitrationTestWithThreadingModes, reclaimToOrderBy) { newQueryBuilder() .queryCtx(fakeMemoryQueryCtx) .serialExecution(isSerialExecutionMode_) - .plan(PlanBuilder() - .values(vectors) - .addNode([&](std::string id, core::PlanNodePtr input) { - return std::make_shared(id, input); - }) - .planNode()) + .plan( + PlanBuilder() + .values(vectors) + .addNode([&](std::string id, core::PlanNodePtr input) { + return std::make_shared(id, input); + }) + .planNode()) .assertResults("SELECT * FROM tmp"); }); @@ -691,11 +694,12 @@ DEBUG_ONLY_TEST_P( newQueryBuilder() .queryCtx(aggregationQueryCtx) .serialExecution(isSerialExecutionMode_) - .plan(PlanBuilder() - .values(vectors) - .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) - .capturePlanNodeId(aggregationNodeId) - .planNode()) + .plan( + PlanBuilder() + .values(vectors) + .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) + .capturePlanNodeId(aggregationNodeId) + .planNode()) .assertResults( "SELECT c0, c1, array_agg(c2) FROM tmp GROUP BY c0, c1"); auto taskStats = exec::toPlanStats(task->taskStats()); @@ -709,12 +713,13 @@ DEBUG_ONLY_TEST_P( newQueryBuilder() .queryCtx(fakeMemoryQueryCtx) .serialExecution(isSerialExecutionMode_) - .plan(PlanBuilder() - .values(vectors) - .addNode([&](std::string id, core::PlanNodePtr input) { - return std::make_shared(id, input); - }) - .planNode()) + .plan( + PlanBuilder() + .values(vectors) + .addNode([&](std::string id, core::PlanNodePtr input) { + return std::make_shared(id, input); + }) + .planNode()) .assertResults("SELECT * FROM tmp"); }); @@ -755,7 +760,7 @@ DEBUG_ONLY_TEST_P( folly::EventCount taskPauseWait; auto taskPauseWaitKey = taskPauseWait.prepareWait(); - const auto fakeAllocationSize = kMemoryCapacity - (32L << 20); + const auto fakeAllocationSize = kMemoryCapacity - (2L << 20); std::atomic injectAllocationOnce{true}; fakeOperatorFactory_->setAllocationCallback([&](Operator* op) { @@ -822,12 +827,13 @@ DEBUG_ONLY_TEST_P( newQueryBuilder() .queryCtx(fakeMemoryQueryCtx) .serialExecution(isSerialExecutionMode_) - .plan(PlanBuilder() - .values(vectors) - .addNode([&](std::string id, core::PlanNodePtr input) { - return std::make_shared(id, input); - }) - .planNode()) + .plan( + PlanBuilder() + .values(vectors) + .addNode([&](std::string id, core::PlanNodePtr input) { + return std::make_shared(id, input); + }) + .planNode()) .assertResults("SELECT * FROM tmp"); }); @@ -950,12 +956,13 @@ DEBUG_ONLY_TEST_P( .config(core::QueryConfig::kJoinSpillEnabled, "true") .config(core::QueryConfig::kSpillNumPartitionBits, "2") .maxDrivers(numDrivers) - .plan(PlanBuilder() - .values(vectors) - .localPartition({"c0", "c1"}) - .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) - .localPartition(std::vector{}) - .planNode()) + .plan( + PlanBuilder() + .values(vectors) + .localPartition({"c0", "c1"}) + .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) + .localPartition(std::vector{}) + .planNode()) .assertResults( "SELECT c0, c1, array_agg(c2) FROM tmp GROUP BY c0, c1"), "Aborted for external error"); @@ -1022,12 +1029,13 @@ DEBUG_ONLY_TEST_P( .config(core::QueryConfig::kSpillEnabled, "true") .config(core::QueryConfig::kJoinSpillEnabled, "true") .config(core::QueryConfig::kSpillNumPartitionBits, "2") - .plan(PlanBuilder() - .values(vectors) - .localPartition({"c0", "c1"}) - .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) - .localPartition(std::vector{}) - .planNode()) + .plan( + PlanBuilder() + .values(vectors) + .localPartition({"c0", "c1"}) + .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) + .localPartition(std::vector{}) + .planNode()) .assertResults( "SELECT c0, c1, array_agg(c2) FROM tmp GROUP BY c0, c1"); }); @@ -1194,23 +1202,25 @@ DEBUG_ONLY_TEST_P( if (sameDriver) { task = newQueryBuilder() .queryCtx(queryCtx) - .plan(PlanBuilder() - .values(vectors) - .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) - .capturePlanNodeId(aggregationNodeId) - .localPartition(std::vector{}) - .planNode()) + .plan( + PlanBuilder() + .values(vectors) + .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) + .capturePlanNodeId(aggregationNodeId) + .localPartition(std::vector{}) + .planNode()) .assertResults( "SELECT c0, c1, array_agg(c2) FROM tmp GROUP BY c0, c1"); } else { task = newQueryBuilder() .queryCtx(queryCtx) - .plan(PlanBuilder() - .values(vectors) - .localPartition({"c0", "c1"}) - .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) - .capturePlanNodeId(aggregationNodeId) - .planNode()) + .plan( + PlanBuilder() + .values(vectors) + .localPartition({"c0", "c1"}) + .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) + .capturePlanNodeId(aggregationNodeId) + .planNode()) .assertResults( "SELECT c0, c1, array_agg(c2) FROM tmp GROUP BY c0, c1"); } @@ -1369,6 +1379,7 @@ TEST_P( if (e.errorCode() != error_code::kMemCapExceeded.c_str() && e.errorCode() != error_code::kMemAborted.c_str() && e.errorCode() != error_code::kMemAllocError.c_str() && + e.errorCode() != error_code::kMemArbitrationTimeout.c_str() && (e.message() != "Aborted for external error")) { std::rethrow_exception(std::current_exception()); } @@ -1424,12 +1435,20 @@ TEST_P(SharedArbitrationTestWithThreadingModes, reserveReleaseCounters) { VELOX_INSTANTIATE_TEST_SUITE_P( SharedArbitrationTest, SharedArbitrationTestWithParallelExecutionModeOnly, - testing::ValuesIn(std::vector{{false}})); + testing::ValuesIn(std::vector{{false}}), + [](const testing::TestParamInfo& info) { + return fmt::format( + "{}", info.param.isSerialExecutionMode ? "serial" : "parallel"); + }); VELOX_INSTANTIATE_TEST_SUITE_P( SharedArbitrationTest, SharedArbitrationTestWithThreadingModes, - testing::ValuesIn(std::vector{{false}, {true}})); + testing::ValuesIn(std::vector{{false}, {true}}), + [](const testing::TestParamInfo& info) { + return fmt::format( + "{}", info.param.isSerialExecutionMode ? "serial" : "parallel"); + }); } // namespace facebook::velox::memory int main(int argc, char** argv) { diff --git a/velox/common/memory/tests/SharedArbitratorTestUtil.h b/velox/common/memory/tests/SharedArbitratorTestUtil.h index d8124599606d..8af5e3e05726 100644 --- a/velox/common/memory/tests/SharedArbitratorTestUtil.h +++ b/velox/common/memory/tests/SharedArbitratorTestUtil.h @@ -109,43 +109,4 @@ class ArbitrationParticipantTestHelper { ArbitrationParticipant* const participant_; }; -struct ArbitrationTestStructs { - ArbitrationParticipant::Config config; - std::shared_ptr participant{nullptr}; - std::shared_ptr operation{nullptr}; - - static ArbitrationTestStructs createArbitrationTestStructs( - const std::shared_ptr& pool, - uint64_t initCapacity = 1024, - uint64_t minCapacity = 128, - uint64_t fastExponentialGrowthCapacityLimit = 0, - double slowCapacityGrowRatio = 0, - uint64_t minFreeCapacity = 0, - double minFreeCapacityRatio = 0, - uint64_t minReclaimBytes = 128, - double minReclaimPct = 0, - uint64_t abortCapacityLimit = 512, - uint64_t requestBytes = 128, - uint64_t maxArbitrationTimeNs = 1'000'000'000'000UL /* 1'000s */) { - ArbitrationTestStructs ret{ - .config = ArbitrationParticipant::Config( - initCapacity, - minCapacity, - fastExponentialGrowthCapacityLimit, - slowCapacityGrowRatio, - minFreeCapacity, - minFreeCapacityRatio, - minReclaimBytes, - minReclaimPct, - abortCapacityLimit)}; - ret.participant = ArbitrationParticipant::create( - folly::Random::rand64(), pool, &ret.config); - ret.operation = std::make_shared( - ScopedArbitrationParticipant(ret.participant, pool), - requestBytes, - maxArbitrationTimeNs); - return ret; - } -}; - } // namespace facebook::velox::memory::test diff --git a/velox/common/process/CMakeLists.txt b/velox/common/process/CMakeLists.txt index 1ac1085e98e7..2fb856caa543 100644 --- a/velox/common/process/CMakeLists.txt +++ b/velox/common/process/CMakeLists.txt @@ -18,19 +18,22 @@ velox_add_library( StackTrace.cpp ThreadDebugInfo.cpp TraceContext.cpp - TraceHistory.cpp) + TraceHistory.cpp +) velox_link_libraries( velox_process PUBLIC velox_file velox_flag_definitions Folly::folly - PRIVATE fmt::fmt gflags::gflags) + PRIVATE fmt::fmt gflags::gflags +) # Profiler need not be part of the core Velox library add_library(velox_profiler OBJECT Profiler.cpp) target_link_libraries( velox_profiler PUBLIC velox_flag_definitions Folly::folly - PRIVATE fmt::fmt gflags::gflags glog::glog) + PRIVATE fmt::fmt gflags::gflags glog::glog +) if(${VELOX_BUILD_TESTING}) add_subdirectory(tests) diff --git a/velox/common/process/Profiler.cpp b/velox/common/process/Profiler.cpp index 82a5e8365269..012c3ce35242 100644 --- a/velox/common/process/Profiler.cpp +++ b/velox/common/process/Profiler.cpp @@ -153,12 +153,13 @@ void Profiler::copyToResult(const std::string* data) { auto now = nowSeconds(); auto elapsed = (now - sampleStartTime_); auto cpu = cpuSeconds(); - out->append(fmt::format( - "Profile from {} to {} at {}% CPU\n\n", + out->append( + fmt::format( + "Profile from {} to {} at {}% CPU\n\n", - timeString(sampleStartTime_), - timeString(now), - 100 * (cpu - cpuAtSampleStart_) / std::max(1, elapsed))); + timeString(sampleStartTime_), + timeString(now), + 100 * (cpu - cpuAtSampleStart_) / std::max(1, elapsed))); out->append(std::string_view(buffer, resultSize)); if (extraReport_) { std::string extra = extraReport_(); @@ -191,18 +192,19 @@ std::thread Profiler::startSample() { // and killing it with SIGINT produces a corrupt perf.data // file. The perf.data file generated when called via system() is // good, though. Unsolved mystery. - system(fmt::format( - "(cd {}; /usr/bin/perf record --pid {} {};" - "perf report --sort symbol > perf ;" - "sed --in-place 's/ / /'g perf;" - "sed --in-place 's/ / /'g perf; date) " - ">> {}/perftrace 2>>{}/perftrace2", - FLAGS_profiler_tmp_dir, - getpid(), - FLAGS_profiler_perf_flags, - FLAGS_profiler_tmp_dir, - FLAGS_profiler_tmp_dir) - .c_str()); // NOLINT + system( + fmt::format( + "(cd {}; /usr/bin/perf record --pid {} {};" + "perf report --sort symbol > perf ;" + "sed --in-place 's/ / /'g perf;" + "sed --in-place 's/ / /'g perf; date) " + ">> {}/perftrace 2>>{}/perftrace2", + FLAGS_profiler_tmp_dir, + getpid(), + FLAGS_profiler_perf_flags, + FLAGS_profiler_tmp_dir, + FLAGS_profiler_tmp_dir) + .c_str()); // NOLINT if (shouldSaveResult_) { copyToResult(); } diff --git a/velox/common/process/StackTrace.cpp b/velox/common/process/StackTrace.cpp index 96cf0a91a25e..fb8ebc3db1df 100644 --- a/velox/common/process/StackTrace.cpp +++ b/velox/common/process/StackTrace.cpp @@ -46,14 +46,14 @@ StackTrace::StackTrace(int32_t skipFrames) { } StackTrace::StackTrace(const StackTrace& other) { - bt_pointers_ = other.bt_pointers_; - if (folly::test_once(other.bt_vector_flag_)) { - bt_vector_ = other.bt_vector_; - folly::call_once(bt_vector_flag_, [] {}); // Set the flag. + btPtrs_ = other.btPtrs_; + if (folly::test_once(other.btVectorFlag_)) { + btVector_ = other.btVector_; + folly::call_once(btVectorFlag_, [] {}); // Set the flag. } - if (folly::test_once(other.bt_flag_)) { + if (folly::test_once(other.btFlag_)) { bt_ = other.bt_; - folly::call_once(bt_flag_, [] {}); // Set the flag. + folly::call_once(btFlag_, [] {}); // Set the flag. } } @@ -69,9 +69,9 @@ void StackTrace::create(int32_t skipFrames) { const int32_t kDefaultSkipFrameAdjust = 2; // ::create(), ::StackTrace() const int32_t kMaxFrames = 75; - bt_pointers_.clear(); - uintptr_t btpointers[kMaxFrames]; - ssize_t framecount = folly::symbolizer::getStackTrace(btpointers, kMaxFrames); + btPtrs_.clear(); + uintptr_t btPtrs[kMaxFrames]; + ssize_t framecount = folly::symbolizer::getStackTrace(btPtrs, kMaxFrames); if (framecount <= 0) { return; } @@ -79,9 +79,9 @@ void StackTrace::create(int32_t skipFrames) { framecount = std::min(framecount, static_cast(kMaxFrames)); skipFrames = std::max(skipFrames + kDefaultSkipFrameAdjust, 0); - bt_pointers_.reserve(framecount - skipFrames); + btPtrs_.reserve(framecount - skipFrames); for (int32_t i = skipFrames; i < framecount; i++) { - bt_pointers_.push_back(reinterpret_cast(btpointers[i])); + btPtrs_.push_back(reinterpret_cast(btPtrs[i])); } } @@ -89,32 +89,32 @@ void StackTrace::create(int32_t skipFrames) { // reporting functions const std::vector& StackTrace::toStrVector() const { - folly::call_once(bt_vector_flag_, [&] { + folly::call_once(btVectorFlag_, [&] { size_t frame = 0; static folly::Indestructible myname{ folly::demangle(typeid(decltype(*this))) + "::"}; - bt_vector_.reserve(bt_pointers_.size()); - for (auto ptr : bt_pointers_) { + btVector_.reserve(btPtrs_.size()); + for (auto ptr : btPtrs_) { auto framename = translateFrame(ptr); - if (folly::StringPiece(framename).startsWith(*myname)) { + if (framename.starts_with(*myname)) { continue; // ignore frames in the StackTrace class } - bt_vector_.push_back(fmt::format("# {:<2d} {}", frame++, framename)); + btVector_.push_back(fmt::format("# {:<2d} {}", frame++, framename)); } }); - return bt_vector_; + return btVector_; } const std::string& StackTrace::toString() const { - folly::call_once(bt_flag_, [&] { + folly::call_once(btFlag_, [&] { const auto& vec = toStrVector(); size_t needed = 0; for (const auto& frame : vec) { needed += frame.size() + 1; } bt_.reserve(needed); - for (const auto& frame_title : vec) { - bt_ += frame_title; + for (const auto& frameTitle : vec) { + bt_ += frameTitle; bt_ += '\n'; } }); diff --git a/velox/common/process/StackTrace.h b/velox/common/process/StackTrace.h index 1632d2468151..de114c1d9d32 100644 --- a/velox/common/process/StackTrace.h +++ b/velox/common/process/StackTrace.h @@ -25,65 +25,47 @@ namespace facebook::velox::process { /////////////////////////////////////////////////////////////////////////////// -// TODO: Deprecate in favor of folly::symbolizer. +/// TODO: Deprecate in favor of folly::symbolizer. class StackTrace { public: - /** - * Translate a frame pointer to file name and line number pair. - */ + /// Translate a frame pointer to file name and line number pair. static std::string translateFrame(void* framePtr, bool lineNumbers = true); - /** - * Demangle a function name. - */ + /// Demangle a function name. static std::string demangle(const char* mangled); - public: - /** - * Constructor -- saves the current stack trace. By default, we skip the - * frames for StackTrace::StackTrace. If you want those, you can pass - * '-2' to skipFrames. - */ + /// Constructor -- saves the current stack trace. By default, we skip the + /// frames for StackTrace::StackTrace. If you want those, you can pass '-2' + /// to skipFrames. explicit StackTrace(int32_t skipFrames = 0); StackTrace(const StackTrace& other); StackTrace& operator=(const StackTrace& other); - /** - * Generate an output of the written stack trace. - */ + /// Generate an output of the written stack trace. const std::string& toString() const; - /** - * Generate a vector that for each position has the title of the frame. - */ + /// Generate a vector that for each position has the title of the frame. const std::vector& toStrVector() const; - /** - * Return the raw stack pointers. - */ + /// Return the raw stack pointers. const std::vector& getStack() const { - return bt_pointers_; + return btPtrs_; } - /** - * Log stacktrace into a file under /tmp. If "out" is not null, - * also store translated stack trace into the variable. - * Returns the name of the generated file. - */ + /// Log stacktrace into a file under /tmp. If "out" is not null, also store + /// translated stack trace into the variable. Returns the name of the + /// generated file. std::string log(const char* errorType, std::string* out = nullptr) const; private: - /** - * Record bt pointers. - */ + // Record bt pointers. void create(int32_t skipFrames); - private: - std::vector bt_pointers_; - mutable folly::once_flag bt_vector_flag_; - mutable std::vector bt_vector_; - mutable folly::once_flag bt_flag_; + std::vector btPtrs_; + mutable folly::once_flag btVectorFlag_; + mutable std::vector btVector_; + mutable folly::once_flag btFlag_; mutable std::string bt_; }; diff --git a/velox/common/process/tests/CMakeLists.txt b/velox/common/process/tests/CMakeLists.txt index 5f36a9bacf94..6200f5474fd8 100644 --- a/velox/common/process/tests/CMakeLists.txt +++ b/velox/common/process/tests/CMakeLists.txt @@ -12,17 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -add_executable(velox_process_test ProfilerTest.cpp ThreadLocalRegistryTest.cpp - TraceContextTest.cpp TraceHistoryTest.cpp) +add_executable( + velox_process_test + ProfilerTest.cpp + ThreadLocalRegistryTest.cpp + TraceContextTest.cpp + TraceHistoryTest.cpp +) add_test(velox_process_test velox_process_test) target_link_libraries( velox_process_test - PRIVATE - velox_process - velox_profiler - fmt::fmt - velox_time - GTest::gtest - GTest::gtest_main) + PRIVATE velox_process velox_profiler fmt::fmt velox_time GTest::gtest GTest::gtest_main +) diff --git a/velox/common/serialization/CMakeLists.txt b/velox/common/serialization/CMakeLists.txt index c818597442d2..772987a61f35 100644 --- a/velox/common/serialization/CMakeLists.txt +++ b/velox/common/serialization/CMakeLists.txt @@ -14,8 +14,7 @@ velox_add_library(velox_serialization DeserializationRegistry.cpp) -velox_link_libraries(velox_serialization PUBLIC velox_exception Folly::folly - glog::glog) +velox_link_libraries(velox_serialization PUBLIC velox_exception Folly::folly glog::glog) if(${VELOX_BUILD_TESTING}) add_subdirectory(tests) diff --git a/velox/common/serialization/Serializable.h b/velox/common/serialization/Serializable.h index befe526f4b84..148d12ae2c7c 100644 --- a/velox/common/serialization/Serializable.h +++ b/velox/common/serialization/Serializable.h @@ -143,9 +143,9 @@ class ISerializable { template < typename T, std::enable_if_t< - std::is_same_v>>> - static folly::dynamic serialize(const folly::Optional& val) { - if (!val.hasValue()) { + std::is_same_v>>> + static folly::dynamic serialize(const std::optional& val) { + if (!val.has_value()) { return nullptr; } @@ -250,16 +250,16 @@ class ISerializable { template < typename T, typename = std::enable_if_t< - std::is_same_v>>> - static folly::Optional< + std::is_same_v>>> + static std::optional< decltype(ISerializable::deserialize( std::declval()))> deserialize(const folly::dynamic& obj, void* context = nullptr) { if (obj.isNull()) { - return folly::none; + return std::nullopt; } auto val = deserialize(obj); - return folly::Optional< + return std::optional< decltype(ISerializable::deserialize( std::declval()))>(move(val)); } @@ -298,8 +298,7 @@ class ISerializable { } template < - template - typename TMap, + template typename TMap, typename TKey, typename TMapped, typename... TArgs, diff --git a/velox/common/serialization/tests/CMakeLists.txt b/velox/common/serialization/tests/CMakeLists.txt index 74b2baf630ea..a61f9252cace 100644 --- a/velox/common/serialization/tests/CMakeLists.txt +++ b/velox/common/serialization/tests/CMakeLists.txt @@ -17,10 +17,5 @@ add_test(velox_serialization_test velox_serialization_test) target_link_libraries( velox_serialization_test - PRIVATE - velox_exception - velox_serialization - Folly::folly - glog::glog - GTest::gtest - GTest::gtest_main) + PRIVATE velox_exception velox_serialization Folly::folly glog::glog GTest::gtest GTest::gtest_main +) diff --git a/velox/common/serialization/tests/SerializableTest.cpp b/velox/common/serialization/tests/SerializableTest.cpp index 90b60852b9b7..a9692a8da161 100644 --- a/velox/common/serialization/tests/SerializableTest.cpp +++ b/velox/common/serialization/tests/SerializableTest.cpp @@ -160,8 +160,7 @@ TEST(SerializableTest, context) { } template < - template - typename TMap, + template typename TMap, typename TKey, typename TMapped, typename TIt, diff --git a/velox/common/testutil/CMakeLists.txt b/velox/common/testutil/CMakeLists.txt index 6e8e05b849ff..5b5dca326d9d 100644 --- a/velox/common/testutil/CMakeLists.txt +++ b/velox/common/testutil/CMakeLists.txt @@ -12,13 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -velox_add_library(velox_test_util ScopedTestTime.cpp TestValue.cpp - RandomSeed.cpp) +velox_add_library(velox_test_util ScopedTestTime.cpp TestValue.cpp RandomSeed.cpp) -velox_link_libraries( - velox_test_util - PUBLIC velox_exception - PRIVATE glog::glog Folly::folly) +velox_link_libraries(velox_test_util PUBLIC velox_exception PRIVATE glog::glog Folly::folly) if(${VELOX_BUILD_TESTING}) add_subdirectory(tests) diff --git a/velox/common/testutil/OptionalEmpty.h b/velox/common/testutil/OptionalEmpty.h index dc7a22029289..ace3e3ad0688 100644 --- a/velox/common/testutil/OptionalEmpty.h +++ b/velox/common/testutil/OptionalEmpty.h @@ -30,4 +30,13 @@ struct OptionalEmptyT { inline constexpr OptionalEmptyT optionalEmpty; +struct EmptyT { + template + operator std::vector() const { + return {}; + } +}; + +inline constexpr EmptyT Empty; + } // namespace facebook::velox::common::testutil diff --git a/velox/common/testutil/TestValue.cpp b/velox/common/testutil/TestValue.cpp index b833b980ac42..bf14f66d8070 100644 --- a/velox/common/testutil/TestValue.cpp +++ b/velox/common/testutil/TestValue.cpp @@ -20,7 +20,7 @@ namespace facebook::velox::common::testutil { std::mutex TestValue::mutex_; bool TestValue::enabled_ = false; -std::unordered_map TestValue::injectionMap_; +folly::F14FastMap TestValue::injectionMap_; #ifndef NDEBUG void TestValue::enable() { @@ -38,12 +38,12 @@ bool TestValue::enabled() { return enabled_; } -void TestValue::clear(const std::string& injectionPoint) { +void TestValue::clear(std::string_view injectionPoint) { std::lock_guard l(mutex_); injectionMap_.erase(injectionPoint); } -void TestValue::adjust(const std::string& injectionPoint, void* testData) { +void TestValue::adjust(std::string_view injectionPoint, void* testData) { Callback injectionCb; { std::lock_guard l(mutex_); @@ -60,7 +60,7 @@ void TestValue::disable() {} bool TestValue::enabled() { return false; } -void TestValue::clear(const std::string& injectionPoint) {} +void TestValue::clear(std::string_view injectionPoint) {} #endif } // namespace facebook::velox::common::testutil diff --git a/velox/common/testutil/TestValue.h b/velox/common/testutil/TestValue.h index ba988a6935bb..321936bd67cd 100644 --- a/velox/common/testutil/TestValue.h +++ b/velox/common/testutil/TestValue.h @@ -15,9 +15,10 @@ */ #pragma once +#include #include #include -#include +#include #include "velox/common/base/Exceptions.h" #include "velox/common/base/Macros.h" @@ -45,24 +46,24 @@ class TestValue { /// injected callback hook. template static void set( - const std::string& injectionPoint, + std::string_view injectionPoint, std::function injectionCb); /// Invoked by the test code to unregister a callback hook at the specified /// execution point. - static void clear(const std::string& injectionPoint); + static void clear(std::string_view injectionPoint); /// Invoked by the production code to try to invoke the test callback hook /// with 'testData' if there is one registered at the specified execution /// point. 'testData' capture the mutable production execution state. - static void adjust(const std::string& injectionPoint, void* testData); + static void adjust(std::string_view injectionPoint, void* testData); private: using Callback = std::function; static std::mutex mutex_; static bool enabled_; - static std::unordered_map injectionMap_; + static folly::F14FastMap injectionMap_; }; class ScopedTestValue { @@ -79,13 +80,13 @@ class ScopedTestValue { } private: - const std::string point_; + std::string point_; }; #ifndef NDEBUG template void TestValue::set( - const std::string& injectionPoint, + std::string_view injectionPoint, std::function injectionCb) { std::lock_guard l(mutex_); if (!enabled_) { @@ -99,15 +100,14 @@ void TestValue::set( #else template void TestValue::set( - const std::string& injectionPoint, + std::string_view injectionPoint, std::function injectionCb) {} #endif #ifdef NDEBUG // Keep the definition in header so that it can be inlined (elided). -inline void TestValue::adjust( - const std::string& injectionPoint, - void* testData) {} +inline void TestValue::adjust(std::string_view injectionPoint, void* testData) { +} #endif #define SCOPED_TESTVALUE_SET(point, ...) \ diff --git a/velox/common/testutil/tests/CMakeLists.txt b/velox/common/testutil/tests/CMakeLists.txt index d164600656a3..727b492bf564 100644 --- a/velox/common/testutil/tests/CMakeLists.txt +++ b/velox/common/testutil/tests/CMakeLists.txt @@ -12,15 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. include(GoogleTest) -add_executable(velox_test_util_test TestScopedTestTime.cpp TestValueTest.cpp) +add_executable(velox_test_util_test CastsTest.cpp TestScopedTestTime.cpp TestValueTest.cpp) + gtest_add_tests(velox_test_util_test "" AUTO) target_link_libraries( velox_test_util_test - PRIVATE - velox_test_util - velox_exception - velox_exec - velox_time - GTest::gtest - GTest::gtest_main) + PRIVATE velox_test_util velox_exception velox_exec velox_time GTest::gtest GTest::gtest_main +) diff --git a/velox/common/testutil/tests/CastsTest.cpp b/velox/common/testutil/tests/CastsTest.cpp new file mode 100644 index 000000000000..b5a9cf600527 --- /dev/null +++ b/velox/common/testutil/tests/CastsTest.cpp @@ -0,0 +1,291 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include "velox/common/Casts.h" +#include "velox/common/base/tests/GTestUtils.h" + +namespace facebook::velox { + +namespace { + +// Test classes for inheritance hierarchy +class BaseClass { + public: + virtual ~BaseClass() = default; + virtual int getValue() const { + return 42; + } +}; + +class DerivedClass : public BaseClass { + public: + int getValue() const override { + return 100; + } + + int getDerivedValue() const { + return 200; + } +}; + +class AnotherDerivedClass : public BaseClass { + public: + int getValue() const override { + return 300; + } +}; + +class UnrelatedClass { + public: + virtual ~UnrelatedClass() = default; + int getValue() const { + return 999; + } +}; + +} // namespace + +class CastsTest : public ::testing::Test { + protected: + void SetUp() override { + basePtr_ = std::make_shared(); + derivedPtr_ = std::make_shared(); + anotherDerivedPtr_ = std::make_shared(); + unrelatedPtr_ = std::make_shared(); + + baseUniquePtr_ = std::make_unique(); + derivedUniquePtr_ = std::make_unique(); + + baseRawPtr_ = new BaseClass(); + derivedRawPtr_ = new DerivedClass(); + } + + void TearDown() override { + delete baseRawPtr_; + delete derivedRawPtr_; + } + + std::shared_ptr basePtr_; + std::shared_ptr derivedPtr_; + std::shared_ptr anotherDerivedPtr_; + std::shared_ptr unrelatedPtr_; + + std::unique_ptr baseUniquePtr_; + std::unique_ptr derivedUniquePtr_; + + BaseClass* baseRawPtr_; + DerivedClass* derivedRawPtr_; +}; + +// Tests for checkedPointerCast with shared_ptr +TEST_F(CastsTest, checkedPointerCastSharedPtrSuccess) { + // Cast derived to base (should always work) + auto result = checkedPointerCast(derivedPtr_); + EXPECT_NE(result, nullptr); + EXPECT_EQ(result->getValue(), 100); + + // Cast base to derived when it actually is derived + std::shared_ptr basePtrToDerived = derivedPtr_; + auto derivedResult = checkedPointerCast(basePtrToDerived); + EXPECT_NE(derivedResult, nullptr); + EXPECT_EQ(derivedResult->getValue(), 100); + EXPECT_EQ(derivedResult->getDerivedValue(), 200); +} + +TEST_F(CastsTest, checkedPointerCastSharedPtrFailure) { + // Try to cast base to derived when it's not actually derived + VELOX_ASSERT_THROW(checkedPointerCast(basePtr_), ""); + + // Try to cast to unrelated class + VELOX_ASSERT_THROW(checkedPointerCast(derivedPtr_), ""); +} + +TEST_F(CastsTest, checkedPointerCastSharedPtrNullInput) { + std::shared_ptr nullPtr; + VELOX_ASSERT_THROW(checkedPointerCast(nullPtr), ""); +} + +// Tests for checkedPointerCast with unique_ptr +TEST_F(CastsTest, checkedPointerCastUniquePtrSuccess) { + // Cast derived to base + auto derivedForCast = std::make_unique(); + auto result = checkedPointerCast(std::move(derivedForCast)); + EXPECT_NE(result, nullptr); + EXPECT_EQ(result->getValue(), 100); + + // Cast base to derived when it actually is derived + std::unique_ptr basePtrToDerived = + std::make_unique(); + auto derivedResult = + checkedPointerCast(std::move(basePtrToDerived)); + EXPECT_NE(derivedResult, nullptr); + EXPECT_EQ(derivedResult->getValue(), 100); + EXPECT_EQ(derivedResult->getDerivedValue(), 200); +} + +TEST_F(CastsTest, checkedPointerCastUniquePtrFailure) { + // Try to cast base to derived when it's not actually derived + auto baseForCast = std::make_unique(); + VELOX_ASSERT_THROW( + checkedPointerCast(std::move(baseForCast)), ""); +} + +TEST_F(CastsTest, checkedPointerCastUniquePtrNullInput) { + std::unique_ptr nullPtr; + VELOX_ASSERT_THROW(checkedPointerCast(std::move(nullPtr)), ""); +} + +// Tests for checkedPointerCast with raw pointers +TEST_F(CastsTest, checkedPointerCastRawPtrSuccess) { + // Cast derived to base + auto result = checkedPointerCast(derivedRawPtr_); + EXPECT_NE(result, nullptr); + EXPECT_EQ(result->getValue(), 100); + + // Cast base to derived when it actually is derived + BaseClass* basePtrToDerived = derivedRawPtr_; + auto derivedResult = checkedPointerCast(basePtrToDerived); + EXPECT_NE(derivedResult, nullptr); + EXPECT_EQ(derivedResult->getValue(), 100); + EXPECT_EQ(derivedResult->getDerivedValue(), 200); +} + +TEST_F(CastsTest, checkedPointerCastRawPtrFailure) { + // Try to cast base to derived when it's not actually derived + VELOX_ASSERT_THROW(checkedPointerCast(baseRawPtr_), ""); +} + +TEST_F(CastsTest, checkedPointerCastRawPtrNullInput) { + BaseClass* nullPtr = nullptr; + VELOX_ASSERT_THROW(checkedPointerCast(nullPtr), ""); +} + +// Tests for staticUniquePointerCast +TEST_F(CastsTest, staticUniquePointerCastSuccess) { + // Create a unique_ptr to derived and cast to base + auto derivedForCast = std::make_unique(); + auto originalPtr = derivedForCast.get(); + auto result = staticUniquePointerCast(std::move(derivedForCast)); + + EXPECT_NE(result, nullptr); + EXPECT_EQ(result.get(), originalPtr); // Should be the same pointer + EXPECT_EQ(result->getValue(), 100); +} + +TEST_F(CastsTest, staticUniquePointerCastNullInput) { + std::unique_ptr nullPtr; + VELOX_ASSERT_THROW( + staticUniquePointerCast(std::move(nullPtr)), ""); +} + +// Tests for isInstanceOf with shared_ptr +TEST_F(CastsTest, isInstanceOfSharedPtr) { + // Test positive cases + EXPECT_TRUE(isInstanceOf(derivedPtr_)); + EXPECT_TRUE(isInstanceOf(derivedPtr_)); + EXPECT_TRUE(isInstanceOf(anotherDerivedPtr_)); + EXPECT_TRUE(isInstanceOf(anotherDerivedPtr_)); + + // Test negative cases + EXPECT_FALSE(isInstanceOf(basePtr_)); + EXPECT_FALSE(isInstanceOf(derivedPtr_)); + EXPECT_FALSE(isInstanceOf(anotherDerivedPtr_)); + EXPECT_FALSE(isInstanceOf(derivedPtr_)); +} + +TEST_F(CastsTest, isInstanceOfSharedPtrNullInput) { + std::shared_ptr nullPtr; + VELOX_ASSERT_THROW(isInstanceOf(nullPtr), ""); +} + +// Tests for isInstanceOf with unique_ptr +TEST_F(CastsTest, isInstanceOfUniquePtr) { + // Test positive cases + EXPECT_TRUE(isInstanceOf(derivedUniquePtr_)); + EXPECT_TRUE(isInstanceOf(derivedUniquePtr_)); + + // Test negative cases + EXPECT_FALSE(isInstanceOf(baseUniquePtr_)); + EXPECT_FALSE(isInstanceOf(derivedUniquePtr_)); +} + +TEST_F(CastsTest, isInstanceOfUniquePtrNullInput) { + std::unique_ptr nullPtr; + VELOX_ASSERT_THROW(isInstanceOf(nullPtr), ""); +} + +// Tests for isInstanceOf with raw pointers +TEST_F(CastsTest, isInstanceOfRawPtr) { + // Test positive cases + EXPECT_TRUE(isInstanceOf(derivedRawPtr_)); + EXPECT_TRUE(isInstanceOf(derivedRawPtr_)); + + // Test negative cases + EXPECT_FALSE(isInstanceOf(baseRawPtr_)); + EXPECT_FALSE(isInstanceOf(derivedRawPtr_)); +} + +TEST_F(CastsTest, isInstanceOfRawPtrNullInput) { + BaseClass* nullPtr = nullptr; + VELOX_ASSERT_THROW(isInstanceOf(nullPtr), ""); +} + +// Test error messages contain useful information +TEST_F(CastsTest, errorMessageContent) { + try { + checkedPointerCast(basePtr_); + FAIL() << "Expected VeloxException to be thrown"; + } catch (const VeloxException& e) { + const std::string& message = e.message(); + // Check that the error message contains type information + EXPECT_TRUE(message.find("Failed to cast") != std::string::npos); + EXPECT_TRUE(message.find("BaseClass") != std::string::npos); + EXPECT_TRUE(message.find("DerivedClass") != std::string::npos); + } +} + +// Test that successful casts preserve object identity +TEST_F(CastsTest, objectIdentityPreserved) { + // For shared_ptr + std::shared_ptr basePtrToDerived = derivedPtr_; + auto castedShared = checkedPointerCast(basePtrToDerived); + EXPECT_EQ(castedShared.get(), derivedPtr_.get()); + + // For raw ptr + BaseClass* basePtrToDerivedRaw = derivedRawPtr_; + auto castedRaw = checkedPointerCast(basePtrToDerivedRaw); + EXPECT_EQ(castedRaw, derivedRawPtr_); +} + +// Test exception safety for unique_ptr casting +TEST_F(CastsTest, uniquePtrExceptionSafety) { + auto baseForCast = std::make_unique(); + + try { + checkedPointerCast(std::move(baseForCast)); + FAIL() << "Expected VeloxException to be thrown"; + } catch (const VeloxException&) { + // The unique_ptr should have been restored and the object should still + // exist We can't directly test this without access to the internal + // implementation, but the test framework will detect memory leaks if the + // object was lost + } +} + +} // namespace facebook::velox diff --git a/velox/common/time/CMakeLists.txt b/velox/common/time/CMakeLists.txt index 8c9e39f51835..46c878968f03 100644 --- a/velox/common/time/CMakeLists.txt +++ b/velox/common/time/CMakeLists.txt @@ -16,5 +16,4 @@ if(${VELOX_BUILD_TESTING}) endif() velox_add_library(velox_time CpuWallTimer.cpp Timer.cpp) -velox_link_libraries(velox_time PUBLIC velox_process velox_test_util - Folly::folly fmt::fmt) +velox_link_libraries(velox_time PUBLIC velox_process velox_test_util Folly::folly fmt::fmt) diff --git a/velox/common/time/CpuWallTimer.cpp b/velox/common/time/CpuWallTimer.cpp index e21944b05d1c..93537bb13d62 100644 --- a/velox/common/time/CpuWallTimer.cpp +++ b/velox/common/time/CpuWallTimer.cpp @@ -18,15 +18,19 @@ namespace facebook::velox { -CpuWallTimer::CpuWallTimer(CpuWallTiming& timing) : timing_(timing) { +CpuWallTimer::CpuWallTimer(CpuWallTiming& timing) + : wallTimeStart_(std::chrono::steady_clock::now()), + cpuTimeStart_(process::threadCpuNanos()), + timing_(timing) { ++timing_.count; - cpuTimeStart_ = process::threadCpuNanos(); - wallTimeStart_ = std::chrono::steady_clock::now(); } CpuWallTimer::~CpuWallTimer() { + // NOTE: End the cpu-time timing first, and then end the wall-time timing, + // so as to avoid the counter-intuitive phenomenon that the final calculated + // cpu-time is slightly larger than the wall-time. timing_.cpuNanos += process::threadCpuNanos() - cpuTimeStart_; - auto duration = std::chrono::duration_cast( + const auto duration = std::chrono::duration_cast( std::chrono::steady_clock::now() - wallTimeStart_); timing_.wallNanos += duration.count(); } diff --git a/velox/common/time/CpuWallTimer.h b/velox/common/time/CpuWallTimer.h index 430ec614b1e6..06feb989848c 100644 --- a/velox/common/time/CpuWallTimer.h +++ b/velox/common/time/CpuWallTimer.h @@ -18,19 +18,25 @@ #include #include +#include "velox/common/base/Macros.h" #include "velox/common/base/SuccinctPrinter.h" #include "velox/common/process/ProcessBase.h" namespace facebook::velox { -// Tracks call count and elapsed CPU and wall time for a repeating operation. +/// Tracks call count and elapsed CPU and wall time for a repeating operation. struct CpuWallTiming { uint64_t count = 0; uint64_t wallNanos = 0; uint64_t cpuNanos = 0; + auto operator<=>(const CpuWallTiming&) const = default; + void add(const CpuWallTiming& other) { + // Suppress spurious warnings in GCC 13. + VELOX_SUPPRESS_STRINGOP_OVERFLOW_WARNING count += other.count; + VELOX_UNSUPPRESS_STRINGOP_OVERFLOW_WARNING cpuNanos += other.cpuNanos; wallNanos += other.wallNanos; } @@ -50,15 +56,17 @@ struct CpuWallTiming { } }; -// Adds elapsed CPU and wall time to a CpuWallTiming. +/// Adds elapsed CPU and wall time to a CpuWallTiming. class CpuWallTimer { public: explicit CpuWallTimer(CpuWallTiming& timing); ~CpuWallTimer(); private: - uint64_t cpuTimeStart_; - std::chrono::steady_clock::time_point wallTimeStart_; + // NOTE: Put `wallTimeStart_` before `cpuTimeStart_`, so that wall-time starts + // counting earlier than cpu-time. + const std::chrono::steady_clock::time_point wallTimeStart_; + const uint64_t cpuTimeStart_; CpuWallTiming& timing_; }; @@ -73,8 +81,8 @@ class DeltaCpuWallTimeStopWatch { // NOTE: End the cpu-time timing first, and then end the wall-time timing, // so as to avoid the counter-intuitive phenomenon that the final calculated // cpu-time is slightly larger than the wall-time. - uint64_t cpuTimeDuration = process::threadCpuNanos() - cpuTimeStart_; - uint64_t wallTimeDuration = + const uint64_t cpuTimeDuration = process::threadCpuNanos() - cpuTimeStart_; + const uint64_t wallTimeDuration = std::chrono::duration_cast( std::chrono::steady_clock::now() - wallTimeStart_) .count(); diff --git a/velox/common/time/tests/CMakeLists.txt b/velox/common/time/tests/CMakeLists.txt index 2b2bc5caa9dc..cb602e651a35 100644 --- a/velox/common/time/tests/CMakeLists.txt +++ b/velox/common/time/tests/CMakeLists.txt @@ -15,8 +15,6 @@ include(GoogleTest) add_executable(velox_time_test CpuWallTimerTest.cpp) -target_link_libraries( - velox_time_test - PRIVATE velox_time glog::glog GTest::gtest GTest::gtest_main) +target_link_libraries(velox_time_test PRIVATE velox_time glog::glog GTest::gtest GTest::gtest_main) gtest_add_tests(velox_time_test "" AUTO) diff --git a/velox/connectors/CMakeLists.txt b/velox/connectors/CMakeLists.txt index 743ecbcaac00..4568f83650d9 100644 --- a/velox/connectors/CMakeLists.txt +++ b/velox/connectors/CMakeLists.txt @@ -29,6 +29,10 @@ if(${VELOX_ENABLE_TPCH_CONNECTOR}) add_subdirectory(tpch) endif() +if(${VELOX_ENABLE_TPCDS_CONNECTOR}) + add_subdirectory(tpcds) +endif() + if(${VELOX_BUILD_TESTING}) add_subdirectory(tests) endif() diff --git a/velox/connectors/Connector.cpp b/velox/connectors/Connector.cpp index cddbc31d8de1..ba9d1a9c4154 100644 --- a/velox/connectors/Connector.cpp +++ b/velox/connectors/Connector.cpp @@ -18,12 +18,6 @@ namespace facebook::velox::connector { namespace { -std::unordered_map>& -connectorFactories() { - static std::unordered_map> - factories; - return factories; -} std::unordered_map>& connectors() { static std::unordered_map> connectors; @@ -43,35 +37,6 @@ std::string DataSink::Stats::toString() const { spillStats.toString()); } -bool registerConnectorFactory(std::shared_ptr factory) { - bool ok = - connectorFactories().insert({factory->connectorName(), factory}).second; - VELOX_CHECK( - ok, - "ConnectorFactory with name '{}' is already registered", - factory->connectorName()); - return true; -} - -bool hasConnectorFactory(const std::string& connectorName) { - return connectorFactories().count(connectorName) == 1; -} - -bool unregisterConnectorFactory(const std::string& connectorName) { - auto count = connectorFactories().erase(connectorName); - return count == 1; -} - -std::shared_ptr getConnectorFactory( - const std::string& connectorName) { - auto it = connectorFactories().find(connectorName); - VELOX_CHECK( - it != connectorFactories().end(), - "ConnectorFactory with name '{}' not registered", - connectorName); - return it->second; -} - bool registerConnector(std::shared_ptr connector) { bool ok = connectors().insert({connector->connectorId(), connector}).second; VELOX_CHECK( @@ -95,6 +60,10 @@ std::shared_ptr getConnector(const std::string& connectorId) { return it->second; } +bool hasConnector(const std::string& connectorId) { + return connectors().find(connectorId) != connectors().end(); +} + const std::unordered_map>& getAllConnectors() { return connectors(); @@ -130,27 +99,18 @@ std::shared_ptr Connector::getTracker( }); } -std::string commitStrategyToString(CommitStrategy commitStrategy) { - switch (commitStrategy) { - case CommitStrategy::kNoCommit: - return "NO_COMMIT"; - case CommitStrategy::kTaskCommit: - return "TASK_COMMIT"; - default: - VELOX_UNREACHABLE( - "UNKOWN COMMIT STRATEGY: {}", static_cast(commitStrategy)); - } +namespace { +const folly::F14FastMap& +commitStrategyNames() { + static const folly::F14FastMap kNames = { + {CommitStrategy::kNoCommit, "NO_COMMIT"}, + {CommitStrategy::kTaskCommit, "TASK_COMMIT"}, + }; + return kNames; } +} // namespace -CommitStrategy stringToCommitStrategy(const std::string& strategy) { - if (strategy == "NO_COMMIT") { - return CommitStrategy::kNoCommit; - } else if (strategy == "TASK_COMMIT") { - return CommitStrategy::kTaskCommit; - } else { - VELOX_UNREACHABLE("UNKOWN COMMIT STRATEGY: {}", strategy); - } -} +VELOX_DEFINE_ENUM_NAME(CommitStrategy, commitStrategyNames); folly::dynamic ColumnHandle::serializeBase(std::string_view name) { folly::dynamic obj = folly::dynamic::object; @@ -173,4 +133,5 @@ folly::dynamic ConnectorTableHandle::serializeBase( folly::dynamic ConnectorTableHandle::serialize() const { return serializeBase("ConnectorTableHandle"); } + } // namespace facebook::velox::connector diff --git a/velox/connectors/Connector.h b/velox/connectors/Connector.h index 96b0e045830d..2eb85594c09c 100644 --- a/velox/connectors/Connector.h +++ b/velox/connectors/Connector.h @@ -16,6 +16,7 @@ #pragma once #include "folly/CancellationToken.h" +#include "velox/common/Enums.h" #include "velox/common/base/AsyncSource.h" #include "velox/common/base/PrefixSortConfig.h" #include "velox/common/base/RuntimeMetrics.h" @@ -23,9 +24,11 @@ #include "velox/common/base/SpillStats.h" #include "velox/common/caching/AsyncDataCache.h" #include "velox/common/caching/ScanTracker.h" +#include "velox/common/file/TokenProvider.h" #include "velox/common/future/VeloxPromise.h" #include "velox/core/ExpressionEvaluator.h" -#include "velox/type/Subfield.h" +#include "velox/core/QueryConfig.h" +#include "velox/type/Filter.h" #include "velox/vector/ComplexVector.h" #include @@ -36,9 +39,6 @@ class Config; namespace facebook::velox::wave { class WaveDataSource; } -namespace facebook::velox::common { -class Filter; -} namespace facebook::velox::config { class ConfigBase; } @@ -77,7 +77,15 @@ struct ConnectorSplit : public ISerializable { return nullptr; } - virtual ~ConnectorSplit() {} + virtual uint64_t size() const { + return 0; + } + + virtual ~ConnectorSplit() { + if (dataSource) { + dataSource->close(); + } + } virtual std::string toString() const { return fmt::format( @@ -92,8 +100,10 @@ class ColumnHandle : public ISerializable { public: virtual ~ColumnHandle() = default; - virtual const std::string& name() const { - VELOX_UNSUPPORTED(); + virtual const std::string& name() const = 0; + + virtual std::string toString() const { + return name(); } folly::dynamic serialize() const override; @@ -104,6 +114,12 @@ class ColumnHandle : public ISerializable { using ColumnHandlePtr = std::shared_ptr; +using ColumnHandleMap = + std::unordered_map; + +class ConnectorTableHandle; +using ConnectorTableHandlePtr = std::shared_ptr; + class ConnectorTableHandle : public ISerializable { public: explicit ConnectorTableHandle(std::string connectorId) @@ -111,26 +127,22 @@ class ConnectorTableHandle : public ISerializable { virtual ~ConnectorTableHandle() = default; - virtual std::string toString() const { - VELOX_NYI(); - } - const std::string& connectorId() const { return connectorId_; } - /// Returns the connector-dependent table name. Used with - /// ConnectorMetadata. Implementations need to supply a definition - /// to work with metadata. - virtual const std::string& name() const { - VELOX_UNSUPPORTED(); - } + /// Returns the table name. + virtual const std::string& name() const = 0; /// Returns true if the connector table handle supports index lookup. virtual bool supportsIndexLookup() const { return false; } + virtual std::string toString() const { + return name(); + } + virtual folly::dynamic serialize() const override; protected: @@ -140,8 +152,6 @@ class ConnectorTableHandle : public ISerializable { const std::string connectorId_; }; -using ConnectorTableHandlePtr = std::shared_ptr; - /// Represents a request for writing to connector class ConnectorInsertTableHandle : public ISerializable { public: @@ -160,26 +170,18 @@ class ConnectorInsertTableHandle : public ISerializable { } }; +using ConnectorInsertTableHandlePtr = + std::shared_ptr; + /// Represents the commit strategy for writing to connector. enum class CommitStrategy { /// No more commit actions are needed. kNoCommit, /// Task level commit is needed. - kTaskCommit + kTaskCommit, }; -/// Return a string encoding of the given commit strategy. -std::string commitStrategyToString(CommitStrategy commitStrategy); - -FOLLY_ALWAYS_INLINE std::ostream& operator<<( - std::ostream& os, - CommitStrategy strategy) { - os << commitStrategyToString(strategy); - return os; -} - -/// Return a commit strategy of the given string encoding. -CommitStrategy stringToCommitStrategy(const std::string& strategy); +VELOX_DECLARE_ENUM_NAME(CommitStrategy); /// Writes data received from table writer operator into different partitions /// based on the specific table layout. The actual implementation doesn't need @@ -225,6 +227,10 @@ class DataSink { /// Returns the stats of this data sink. virtual Stats stats() const = 0; + + virtual std::unordered_map runtimeStats() const { + return {}; + } }; class DataSource { @@ -247,6 +253,10 @@ class DataSource { uint64_t size, velox::ContinueFuture& future) = 0; + virtual const common::SubfieldFilters* getFilters() const { + return nullptr; + } + /// Add dynamically generated filter. /// @param outputChannel index into outputType specified in /// Connector::createDataSource() that identifies the column this filter @@ -261,7 +271,9 @@ class DataSource { /// Returns the number of input rows processed so far. virtual uint64_t getCompletedRows() = 0; - virtual std::unordered_map runtimeStats() = 0; + virtual std::unordered_map getRuntimeStats() { + return {}; + } /// Returns true if 'this' has initiated all the prefetch this will initiate. /// This means that the caller should schedule next splits to prefetch in the @@ -361,6 +373,11 @@ class IndexSource { public: virtual ~LookupResultIterator() = default; + /// Invoked to check if there are more lookup results available to fetch. + /// Returns true if there are more results, false otherwise. This allows + /// the caller to determine whether to continue calling 'next()'. + virtual bool hasNext() = 0; + /// Invoked to fetch up to 'size' number of output rows. Returns nullptr if /// all the lookup results have been fetched. Returns std::nullopt and sets /// the 'future' if started asynchronous work and needs to wait for it to @@ -399,7 +416,8 @@ class ConnectorQueryCtx { int driverId, const std::string& sessionTimezone, bool adjustTimestampToTimezone = false, - folly::CancellationToken cancellationToken = {}) + folly::CancellationToken cancellationToken = {}, + std::shared_ptr tokenProvider = {}) : operatorPool_(operatorPool), connectorPool_(connectorPool), sessionProperties_(sessionProperties), @@ -414,7 +432,8 @@ class ConnectorQueryCtx { planNodeId_(planNodeId), sessionTimezone_(sessionTimezone), adjustTimestampToTimezone_(adjustTimestampToTimezone), - cancellationToken_(std::move(cancellationToken)) { + cancellationToken_(std::move(cancellationToken)), + fsTokenProvider_(std::move(tokenProvider)) { VELOX_CHECK_NOT_NULL(sessionProperties); } @@ -502,6 +521,18 @@ class ConnectorQueryCtx { selectiveNimbleReaderEnabled_ = value; } + core::QueryConfig::RowSizeTrackingMode rowSizeTrackingMode() const { + return rowSizeTrackingEnabled_; + } + + void setRowSizeTrackingMode(core::QueryConfig::RowSizeTrackingMode value) { + rowSizeTrackingEnabled_ = value; + } + + std::shared_ptr fsTokenProvider() const { + return fsTokenProvider_; + } + private: memory::MemoryPool* const operatorPool_; memory::MemoryPool* const connectorPool_; @@ -518,14 +549,40 @@ class ConnectorQueryCtx { const std::string sessionTimezone_; const bool adjustTimestampToTimezone_; const folly::CancellationToken cancellationToken_; + const std::shared_ptr fsTokenProvider_; bool selectiveNimbleReaderEnabled_{false}; + core::QueryConfig::RowSizeTrackingMode rowSizeTrackingEnabled_{ + core::QueryConfig::RowSizeTrackingMode::ENABLED_FOR_ALL}; }; -class ConnectorMetadata; +class Connector; + +class ConnectorFactory { + public: + explicit ConnectorFactory(const char* name) : name_(name) {} + + virtual ~ConnectorFactory() = default; + + const std::string& connectorName() const { + return name_; + } + + virtual std::shared_ptr newConnector( + const std::string& id, + std::shared_ptr config, + folly::Executor* ioExecutor = nullptr, + folly::Executor* cpuExecutor = nullptr) = 0; + + private: + const std::string name_; +}; class Connector { public: - explicit Connector(const std::string& id) : id_(id) {} + explicit Connector( + const std::string& id, + std::shared_ptr config = nullptr) + : id_(id), config_(std::move(config)) {} virtual ~Connector() = default; @@ -533,9 +590,8 @@ class Connector { return id_; } - virtual const std::shared_ptr& connectorConfig() - const { - VELOX_NYI("connectorConfig is not supported yet"); + const std::shared_ptr& connectorConfig() const { + return config_; } /// Returns true if this connector would accept a filter dynamically @@ -544,25 +600,17 @@ class Connector { return false; } - /// Returns a ConnectorMetadata for accessing table - /// information. - virtual ConnectorMetadata* metadata() const { - VELOX_UNSUPPORTED(); - } - virtual std::unique_ptr createDataSource( const RowTypePtr& outputType, - const std::shared_ptr& tableHandle, - const std::unordered_map< - std::string, - std::shared_ptr>& columnHandles, + const ConnectorTableHandlePtr& tableHandle, + const connector::ColumnHandleMap& columnHandles, ConnectorQueryCtx* connectorQueryCtx) = 0; /// Returns true if addSplit of DataSource can use 'dataSource' from /// ConnectorSplit in addSplit(). If so, TableScan can preload splits /// so that file opening and metadata operations are off the Driver' /// thread. - virtual bool supportsSplitPreload() { + virtual bool supportsSplitPreload() const { return false; } @@ -613,10 +661,8 @@ class Connector { const std::vector>& joinConditions, const RowTypePtr& outputType, - const std::shared_ptr& tableHandle, - const std::unordered_map< - std::string, - std::shared_ptr>& columnHandles, + const ConnectorTableHandlePtr& tableHandle, + const connector::ColumnHandleMap& columnHandles, ConnectorQueryCtx* connectorQueryCtx) { VELOX_UNSUPPORTED( "Connector {} does not support index source", connectorId()); @@ -624,7 +670,7 @@ class Connector { virtual std::unique_ptr createDataSink( RowTypePtr inputType, - std::shared_ptr connectorInsertTableHandle, + ConnectorInsertTableHandlePtr connectorInsertTableHandle, ConnectorQueryCtx* connectorQueryCtx, CommitStrategy commitStrategy) = 0; @@ -636,65 +682,43 @@ class Connector { const std::string& scanId, int32_t loadQuantum); + /// Returns the IOExecutor used by the connector. It is used to run async IO + /// operations by the connector. + virtual folly::Executor* ioExecutor() const { + return nullptr; + } + + // This is for backward compatibility, todo: remove after verax repo is + // updated virtual folly::Executor* executor() const { return nullptr; } + /// The name of the common runtime stats collected and reported by connector + /// data/index sources. + static inline const std::string kTotalRemainingFilterTime{ + "totalRemainingFilterWallNanos"}; + private: static void unregisterTracker(cache::ScanTracker* tracker); const std::string id_; + const std::shared_ptr config_; static folly::Synchronized< std::unordered_map>> trackers_; }; -class ConnectorFactory { - public: - explicit ConnectorFactory(const char* name) : name_(name) {} - - virtual ~ConnectorFactory() = default; - - const std::string& connectorName() const { - return name_; - } - - virtual std::shared_ptr newConnector( - const std::string& id, - std::shared_ptr config, - folly::Executor* ioExecutor = nullptr, - folly::Executor* cpuExecutor = nullptr) = 0; - - private: - const std::string name_; -}; - -/// Adds a factory for creating connectors to the registry using connector -/// name as the key. Throws if factor with the same name is already present. -/// Always returns true. The return value makes it easy to use with -/// FB_ANONYMOUS_VARIABLE. -bool registerConnectorFactory(std::shared_ptr factory); - -/// Returns true if a connector with the specified name has been registered, -/// false otherwise. -bool hasConnectorFactory(const std::string& connectorName); - -/// Unregister a connector factory by name. -/// Returns true if a connector with the specified name has been -/// unregistered, false otherwise. -bool unregisterConnectorFactory(const std::string& connectorName); - -/// Returns a factory for creating connectors with the specified name. -/// Throws if factory doesn't exist. -std::shared_ptr getConnectorFactory( - const std::string& connectorName); - /// Adds connector instance to the registry using connector ID as the key. /// Throws if connector with the same ID is already present. Always returns /// true. The return value makes it easy to use with FB_ANONYMOUS_VARIABLE. bool registerConnector(std::shared_ptr connector); +/// Returns true if a connector with the specified ID has been registered, false +/// otherwise. +bool hasConnector(const std::string& connectorId); + /// Removes the connector with specified ID from the registry. Returns true /// if connector was removed and false if connector didn't exist. bool unregisterConnector(const std::string& connectorId); diff --git a/velox/connectors/clp/CMakeLists.txt b/velox/connectors/clp/CMakeLists.txt index 6f5d8633a7fd..ecfdc19fe06d 100644 --- a/velox/connectors/clp/CMakeLists.txt +++ b/velox/connectors/clp/CMakeLists.txt @@ -19,11 +19,10 @@ velox_add_library( ClpConfig.cpp ClpConnector.cpp ClpDataSource.cpp - ClpTableHandle.cpp) + ClpTableHandle.cpp +) -velox_link_libraries(velox_clp_connector - PRIVATE clp-s-search simdjson::simdjson velox_connector) -target_compile_features(velox_clp_connector PRIVATE cxx_std_20) +velox_link_libraries(velox_clp_connector PRIVATE clp-s-search simdjson::simdjson velox_connector) if(${VELOX_BUILD_TESTING}) add_subdirectory(tests) diff --git a/velox/connectors/clp/ClpColumnHandle.h b/velox/connectors/clp/ClpColumnHandle.h index d69efdd1749f..6b0493f789bd 100644 --- a/velox/connectors/clp/ClpColumnHandle.h +++ b/velox/connectors/clp/ClpColumnHandle.h @@ -30,6 +30,10 @@ class ClpColumnHandle : public ColumnHandle { originalColumnName_(originalColumnName), columnType_(columnType) {} + const std::string& name() const override { + return columnName_; + } + const std::string& columnName() const { return columnName_; } diff --git a/velox/connectors/clp/ClpConnector.cpp b/velox/connectors/clp/ClpConnector.cpp index 25f90171c271..1c4e3647ce7e 100644 --- a/velox/connectors/clp/ClpConnector.cpp +++ b/velox/connectors/clp/ClpConnector.cpp @@ -14,8 +14,6 @@ * limitations under the License. */ -#include "clp_s/TimestampPattern.hpp" - #include "velox/connectors/clp/ClpConnector.h" #include "velox/connectors/clp/ClpDataSource.h" @@ -28,10 +26,8 @@ ClpConnector::ClpConnector( std::unique_ptr ClpConnector::createDataSource( const RowTypePtr& outputType, - const std::shared_ptr& tableHandle, - const std::unordered_map< - std::string, - std::shared_ptr>& columnHandles, + const ConnectorTableHandlePtr& tableHandle, + const connector::ColumnHandleMap& columnHandles, ConnectorQueryCtx* connectorQueryCtx) { return std::make_unique( outputType, @@ -43,20 +39,16 @@ std::unique_ptr ClpConnector::createDataSource( std::unique_ptr ClpConnector::createDataSink( RowTypePtr inputType, - std::shared_ptr connectorInsertTableHandle, + ConnectorInsertTableHandlePtr connectorInsertTableHandle, ConnectorQueryCtx* connectorQueryCtx, CommitStrategy commitStrategy) { VELOX_NYI("createDataSink for ClpConnector is not implemented!"); } ClpConnectorFactory::ClpConnectorFactory() - : ConnectorFactory(kClpConnectorName) { - clp_s::TimestampPattern::init(); -} + : ConnectorFactory(kClpConnectorName) {} ClpConnectorFactory::ClpConnectorFactory(const char* connectorName) - : ConnectorFactory(connectorName) { - clp_s::TimestampPattern::init(); -} + : ConnectorFactory(connectorName) {} } // namespace facebook::velox::connector::clp diff --git a/velox/connectors/clp/ClpConnector.h b/velox/connectors/clp/ClpConnector.h index 59efe8d8bb7b..ddccfdbbd3ec 100644 --- a/velox/connectors/clp/ClpConnector.h +++ b/velox/connectors/clp/ClpConnector.h @@ -28,7 +28,7 @@ class ClpConnector : public Connector { std::shared_ptr config); [[nodiscard]] const std::shared_ptr& - connectorConfig() const override { + connectorConfig() const { return config_->config(); } @@ -38,19 +38,17 @@ class ClpConnector : public Connector { std::unique_ptr createDataSource( const RowTypePtr& outputType, - const std::shared_ptr& tableHandle, - const std::unordered_map< - std::string, - std::shared_ptr>& columnHandles, + const ConnectorTableHandlePtr& tableHandle, + const connector::ColumnHandleMap& columnHandles, ConnectorQueryCtx* connectorQueryCtx) override; - bool supportsSplitPreload() override { + bool supportsSplitPreload() const override { return false; } std::unique_ptr createDataSink( RowTypePtr inputType, - std::shared_ptr connectorInsertTableHandle, + ConnectorInsertTableHandlePtr connectorInsertTableHandle, ConnectorQueryCtx* connectorQueryCtx, CommitStrategy commitStrategy) override; diff --git a/velox/connectors/clp/ClpDataSource.cpp b/velox/connectors/clp/ClpDataSource.cpp index ac038869bf93..40e09296f956 100644 --- a/velox/connectors/clp/ClpDataSource.cpp +++ b/velox/connectors/clp/ClpDataSource.cpp @@ -29,14 +29,13 @@ namespace facebook::velox::connector::clp { ClpDataSource::ClpDataSource( const RowTypePtr& outputType, - const std::shared_ptr& tableHandle, - const std::unordered_map< - std::string, - std::shared_ptr>& columnHandles, + const ConnectorTableHandlePtr& tableHandle, + const connector::ColumnHandleMap& columnHandles, velox::memory::MemoryPool* pool, std::shared_ptr& clpConfig) : pool_(pool), outputType_(outputType) { - auto clpTableHandle = std::dynamic_pointer_cast(tableHandle); + auto clpTableHandle = + std::dynamic_pointer_cast(tableHandle); storageType_ = clpConfig->storageType(); s3AuthProvider_ = clpConfig->s3AuthProvider(); @@ -47,7 +46,7 @@ ClpDataSource::ClpDataSource( "ColumnHandle not found for output name: {}", outputName); auto clpColumnHandle = - std::dynamic_pointer_cast(columnHandle->second); + std::dynamic_pointer_cast(columnHandle->second); VELOX_CHECK_NOT_NULL( clpColumnHandle, "ColumnHandle must be an instance of ClpColumnHandle for output name: {}", diff --git a/velox/connectors/clp/ClpDataSource.h b/velox/connectors/clp/ClpDataSource.h index a500219f3f82..66efbd47081c 100644 --- a/velox/connectors/clp/ClpDataSource.h +++ b/velox/connectors/clp/ClpDataSource.h @@ -34,10 +34,8 @@ class ClpDataSource : public DataSource { public: ClpDataSource( const RowTypePtr& outputType, - const std::shared_ptr& tableHandle, - const std::unordered_map< - std::string, - std::shared_ptr>& columnHandles, + const ConnectorTableHandlePtr& tableHandle, + const connector::ColumnHandleMap& columnHandles, velox::memory::MemoryPool* pool, std::shared_ptr& clpConfig); @@ -60,7 +58,7 @@ class ClpDataSource : public DataSource { return completedRows_; } - std::unordered_map runtimeStats() override { + std::unordered_map getRuntimeStats() override { return {}; } diff --git a/velox/connectors/clp/ClpTableHandle.h b/velox/connectors/clp/ClpTableHandle.h index 357fcd102c9a..7ef0b0526cbd 100644 --- a/velox/connectors/clp/ClpTableHandle.h +++ b/velox/connectors/clp/ClpTableHandle.h @@ -25,6 +25,10 @@ class ClpTableHandle : public ConnectorTableHandle { ClpTableHandle(const std::string& connectorId, const std::string& tableName) : ConnectorTableHandle(connectorId), tableName_(tableName) {} + const std::string& name() const override { + return tableName_; + } + [[nodiscard]] const std::string& tableName() const { return tableName_; } diff --git a/velox/connectors/clp/search_lib/CMakeLists.txt b/velox/connectors/clp/search_lib/CMakeLists.txt index 69451d2fc79f..f675e0b0a423 100644 --- a/velox/connectors/clp/search_lib/CMakeLists.txt +++ b/velox/connectors/clp/search_lib/CMakeLists.txt @@ -20,14 +20,10 @@ velox_add_library( ClpPackageS3AuthProvider.h ClpS3AuthProviderBase.cpp ClpS3AuthProviderBase.h - ClpTimestampsUtils.h) + ClpTimestampsUtils.h +) add_subdirectory(archive) add_subdirectory(ir) -velox_link_libraries( - clp-s-search - PUBLIC clp-s-archive-search - PUBLIC clp-s-ir-search) - -target_compile_features(clp-s-search PRIVATE cxx_std_20) +velox_link_libraries(clp-s-search PUBLIC clp-s-archive-search PUBLIC clp-s-ir-search) diff --git a/velox/connectors/clp/search_lib/archive/CMakeLists.txt b/velox/connectors/clp/search_lib/archive/CMakeLists.txt index d1739d9e2efb..806350505c1f 100644 --- a/velox/connectors/clp/search_lib/archive/CMakeLists.txt +++ b/velox/connectors/clp/search_lib/archive/CMakeLists.txt @@ -17,11 +17,11 @@ velox_add_library( ClpArchiveCursor.cpp ClpArchiveJsonStringVectorLoader.cpp ClpArchiveVectorLoader.cpp - ClpQueryRunner.cpp) + ClpQueryRunner.cpp +) velox_link_libraries( clp-s-archive-search PUBLIC clp_s::archive_reader velox_vector - PRIVATE clp_s::clp_dependencies clp_s::io clp_s::search clp_s::search::kql) - -target_compile_features(clp-s-archive-search PRIVATE cxx_std_20) + PRIVATE clp_s::clp_dependencies clp_s::io clp_s::search clp_s::search::kql +) diff --git a/velox/connectors/clp/search_lib/archive/ClpArchiveVectorLoader.cpp b/velox/connectors/clp/search_lib/archive/ClpArchiveVectorLoader.cpp index 14a51d6674be..eb2434726ae0 100644 --- a/velox/connectors/clp/search_lib/archive/ClpArchiveVectorLoader.cpp +++ b/velox/connectors/clp/search_lib/archive/ClpArchiveVectorLoader.cpp @@ -72,7 +72,7 @@ void ClpArchiveVectorLoader::populateTimestampData( case clp_s::NodeType::FormattedFloat: case clp_s::NodeType::DictionaryFloat: case clp_s::NodeType::Integer: - case clp_s::NodeType::DateString: + case clp_s::NodeType::DeprecatedDateString: supportedNodeType = true; break; default: @@ -115,7 +115,8 @@ void ClpArchiveVectorLoader::populateTimestampData( convertToVeloxTimestamp( std::get(reader->extract_value(messageIndex)))); } else { - auto reader = static_cast(columnReader_); + auto reader = + static_cast(columnReader_); vector->set( vectorIndex, convertToVeloxTimestamp(reader->get_encoded_time(messageIndex))); @@ -214,8 +215,9 @@ void ClpArchiveVectorLoader::loadInternal( populateTimestampData(rows, timestampVector); } else if ( nullptr != - dynamic_cast(columnReader_)) { - populateTimestampData( + dynamic_cast( + columnReader_)) { + populateTimestampData( rows, timestampVector); } else if ( nullptr != dynamic_cast(columnReader_)) { @@ -261,8 +263,8 @@ template void ClpArchiveVectorLoader::populateTimestampData( RowSet rows, FlatVector* vector); -template void -ClpArchiveVectorLoader::populateTimestampData( +template void ClpArchiveVectorLoader::populateTimestampData< + clp_s::NodeType::DeprecatedDateString>( RowSet rows, FlatVector* vector); template void diff --git a/velox/connectors/clp/search_lib/ir/CMakeLists.txt b/velox/connectors/clp/search_lib/ir/CMakeLists.txt index 568492f385a4..964c24a66faf 100644 --- a/velox/connectors/clp/search_lib/ir/CMakeLists.txt +++ b/velox/connectors/clp/search_lib/ir/CMakeLists.txt @@ -16,7 +16,8 @@ velox_add_library( STATIC ClpIrCursor.cpp ClpIrJsonStringVectorLoader.cpp - ClpIrVectorLoader.cpp) + ClpIrVectorLoader.cpp +) velox_link_libraries( clp-s-ir-search @@ -29,4 +30,5 @@ velox_link_libraries( clp_s::search clp_s::search::ast clp_s::search::kql - velox_vector) + velox_vector +) diff --git a/velox/connectors/clp/tests/CMakeLists.txt b/velox/connectors/clp/tests/CMakeLists.txt index 5f7fd1f57c88..2b26c3a4dcf6 100644 --- a/velox/connectors/clp/tests/CMakeLists.txt +++ b/velox/connectors/clp/tests/CMakeLists.txt @@ -21,7 +21,7 @@ target_link_libraries( velox_exec_test_lib velox_vector_test_lib GTest::gtest - GTest::gtest_main) + GTest::gtest_main +) -file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/examples - DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) +file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/examples DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) diff --git a/velox/connectors/clp/tests/ClpConnectorTest.cpp b/velox/connectors/clp/tests/ClpConnectorTest.cpp index 4f2963fe97bc..114d5e50cbdd 100644 --- a/velox/connectors/clp/tests/ClpConnectorTest.cpp +++ b/velox/connectors/clp/tests/ClpConnectorTest.cpp @@ -45,22 +45,18 @@ class ClpConnectorTest : public exec::test::OperatorTestBase { void SetUp() override { OperatorTestBase::SetUp(); - connector::registerConnectorFactory( - std::make_shared()); - auto clpConnector = - connector::getConnectorFactory( - connector::clp::ClpConnectorFactory::kClpConnectorName) - ->newConnector( - kClpConnectorId, - std::make_shared( - std::unordered_map{})); + connector::clp::ClpConnectorFactory factory; + auto clpConnector = factory.newConnector( + kClpConnectorId, + std::make_shared( + std::unordered_map{}), + nullptr, + nullptr); connector::registerConnector(clpConnector); } void TearDown() override { connector::unregisterConnector(kClpConnectorId); - connector::unregisterConnectorFactory( - connector::clp::ClpConnectorFactory::kClpConnectorName); OperatorTestBase::TearDown(); } @@ -68,8 +64,9 @@ class ClpConnectorTest : public exec::test::OperatorTestBase { const std::string& splitPath, ClpConnectorSplit::SplitType type, std::shared_ptr kqlQuery) { - return exec::Split(std::make_shared( - kClpConnectorId, splitPath, static_cast(type), kqlQuery)); + return exec::Split( + std::make_shared( + kClpConnectorId, splitPath, static_cast(type), kqlQuery)); } RowVectorPtr getResults( @@ -88,27 +85,28 @@ class ClpConnectorTest : public exec::test::OperatorTestBase { TEST_F(ClpConnectorTest, test1NoPushdown) { const std::shared_ptr kqlQuery = nullptr; - auto plan = PlanBuilder() - .startTableScan() - .outputType( - ROW({"requestId", "userId", "method"}, - {VARCHAR(), VARCHAR(), VARCHAR()})) - .tableHandle(std::make_shared( - kClpConnectorId, "test_1")) - .assignments({ - {"requestId", - std::make_shared( - "requestId", "requestId", VARCHAR())}, - {"userId", - std::make_shared( - "userId", "userId", VARCHAR())}, - {"method", - std::make_shared( - "method", "method", VARCHAR())}, - }) - .endTableScan() - .filter("method = 'GET'") - .planNode(); + auto plan = + PlanBuilder() + .startTableScan() + .outputType( + ROW({"requestId", "userId", "method"}, + {VARCHAR(), VARCHAR(), VARCHAR()})) + .tableHandle( + std::make_shared(kClpConnectorId, "test_1")) + .assignments({ + {"requestId", + std::make_shared( + "requestId", "requestId", VARCHAR())}, + {"userId", + std::make_shared( + "userId", "userId", VARCHAR())}, + {"method", + std::make_shared( + "method", "method", VARCHAR())}, + }) + .endTableScan() + .filter("method = 'GET'") + .planNode(); auto output = getResults( plan, @@ -189,13 +187,13 @@ TEST_F(ClpConnectorTest, test1Pushdown) { getExampleFilePath("test_1.clps"), ClpConnectorSplit::SplitType::kArchive, kqlQuery)}); - auto expected = - makeRowVector({// requestId - makeFlatVector({"req-106"}), - // userId - makeNullableFlatVector({std::nullopt}), - // path - makeFlatVector({"/auth/login"})}); + auto expected = makeRowVector( + {// requestId + makeFlatVector({"req-106"}), + // userId + makeNullableFlatVector({std::nullopt}), + // path + makeFlatVector({"/auth/login"})}); test::assertEqualVectors(expected, output); auto irOutput = getResults( @@ -209,27 +207,28 @@ TEST_F(ClpConnectorTest, test1Pushdown) { TEST_F(ClpConnectorTest, test1JsonString) { const std::shared_ptr kqlQuery = nullptr; - auto plan = PlanBuilder() - .startTableScan() - .outputType( - ROW({"requestId", "__json_string", "method"}, - {VARCHAR(), VARCHAR(), VARCHAR()})) - .tableHandle(std::make_shared( - kClpConnectorId, "test_1")) - .assignments({ - {"requestId", - std::make_shared( - "requestId", "requestId", VARCHAR())}, - {"__json_string", - std::make_shared( - "__json_string", "__json_string", VARCHAR())}, - {"method", - std::make_shared( - "method", "method", VARCHAR())}, - }) - .endTableScan() - .filter("method = 'GET'") - .planNode(); + auto plan = + PlanBuilder() + .startTableScan() + .outputType( + ROW({"requestId", "__json_string", "method"}, + {VARCHAR(), VARCHAR(), VARCHAR()})) + .tableHandle( + std::make_shared(kClpConnectorId, "test_1")) + .assignments({ + {"requestId", + std::make_shared( + "requestId", "requestId", VARCHAR())}, + {"__json_string", + std::make_shared( + "__json_string", "__json_string", VARCHAR())}, + {"method", + std::make_shared( + "method", "method", VARCHAR())}, + }) + .endTableScan() + .filter("method = 'GET'") + .planNode(); const auto methodVector = makeFlatVector({ "GET", @@ -318,19 +317,19 @@ TEST_F(ClpConnectorTest, test2NoPushdown) { getExampleFilePath("test_2.clps"), ClpConnectorSplit::SplitType::kArchive, kqlQuery)}); - auto expected = - makeRowVector({// timestamp - makeFlatVector({Timestamp( - kTestTimestampSeconds, kTestTimestampNanoseconds)}), - // event - makeRowVector({ - // event.type - makeFlatVector({"storage"}), - // event.subtype - makeFlatVector({"disk_usage"}), - // event.severity - makeFlatVector({"WARNING"}), - })}); + auto expected = makeRowVector( + {// timestamp + makeFlatVector( + {Timestamp(kTestTimestampSeconds, kTestTimestampNanoseconds)}), + // event + makeRowVector({ + // event.type + makeFlatVector({"storage"}), + // event.subtype + makeFlatVector({"disk_usage"}), + // event.severity + makeFlatVector({"WARNING"}), + })}); test::assertEqualVectors(expected, output); auto irOutput = getResults( @@ -347,27 +346,28 @@ TEST_F(ClpConnectorTest, test2Pushdown) { "(event.severity: \"WARNING\" OR event.severity: \"ERROR\") AND " "((event.type: \"network\" AND event.subtype: \"connection\") OR " "(event.type: \"storage\" AND event.subtype: \"disk*\"))"); - auto plan = PlanBuilder() - .startTableScan() - .outputType( - ROW({"timestamp", "event"}, - {TIMESTAMP(), - ROW({"type", "subtype", "severity"}, - {VARCHAR(), VARCHAR(), VARCHAR()})})) - .tableHandle(std::make_shared( - kClpConnectorId, "test_2")) - .assignments( - {{"timestamp", - std::make_shared( - "timestamp", "timestamp", TIMESTAMP())}, - {"event", - std::make_shared( - "event", - "event", - ROW({"type", "subtype", "severity"}, - {VARCHAR(), VARCHAR(), VARCHAR()}))}}) - .endTableScan() - .planNode(); + auto plan = + PlanBuilder() + .startTableScan() + .outputType( + ROW({"timestamp", "event"}, + {TIMESTAMP(), + ROW({"type", "subtype", "severity"}, + {VARCHAR(), VARCHAR(), VARCHAR()})})) + .tableHandle( + std::make_shared(kClpConnectorId, "test_2")) + .assignments( + {{"timestamp", + std::make_shared( + "timestamp", "timestamp", TIMESTAMP())}, + {"event", + std::make_shared( + "event", + "event", + ROW({"type", "subtype", "severity"}, + {VARCHAR(), VARCHAR(), VARCHAR()}))}}) + .endTableScan() + .planNode(); auto output = getResults( plan, @@ -375,19 +375,19 @@ TEST_F(ClpConnectorTest, test2Pushdown) { getExampleFilePath("test_2.clps"), ClpConnectorSplit::SplitType::kArchive, kqlQuery)}); - auto expected = - makeRowVector({// timestamp - makeFlatVector({Timestamp( - kTestTimestampSeconds, kTestTimestampNanoseconds)}), - // event - makeRowVector({ - // event.type - makeFlatVector({"storage"}), - // event.subtype - makeFlatVector({"disk_usage"}), - // event.severity - makeFlatVector({"WARNING"}), - })}); + auto expected = makeRowVector( + {// timestamp + makeFlatVector( + {Timestamp(kTestTimestampSeconds, kTestTimestampNanoseconds)}), + // event + makeRowVector({ + // event.type + makeFlatVector({"storage"}), + // event.subtype + makeFlatVector({"disk_usage"}), + // event.severity + makeFlatVector({"WARNING"}), + })}); test::assertEqualVectors(expected, output); auto irOutput = getResults( @@ -438,15 +438,16 @@ TEST_F(ClpConnectorTest, test2Hybrid) { makeFlatVector( {Timestamp(kTestTimestampSeconds, kTestTimestampNanoseconds)}), // event - makeRowVector({// event.type - makeFlatVector({"storage"}), - // event.subtype - makeFlatVector({"disk_usage"}), - // event.severity - makeFlatVector({"WARNING"}), - // event.tags - makeArrayVector( - {{"\"filesystem\"", "\"monitoring\""}})}) + makeRowVector( + {// event.type + makeFlatVector({"storage"}), + // event.subtype + makeFlatVector({"disk_usage"}), + // event.severity + makeFlatVector({"WARNING"}), + // event.tags + makeArrayVector( + {{"\"filesystem\"", "\"monitoring\""}})}) }); test::assertEqualVectors(expected, output); @@ -465,21 +466,22 @@ TEST_F(ClpConnectorTest, test2JsonString) { "(event.severity: \"WARNING\" OR event.severity: \"ERROR\") AND " "((event.type: \"network\" AND event.subtype: \"connection\") OR " "(event.type: \"storage\" AND event.subtype: \"disk*\"))"); - auto plan = PlanBuilder() - .startTableScan() - .outputType(ROW( - {"timestamp", "__json_string"}, {TIMESTAMP(), VARCHAR()})) - .tableHandle(std::make_shared( - kClpConnectorId, "test_2")) - .assignments( - {{"timestamp", - std::make_shared( - "timestamp", "timestamp", TIMESTAMP())}, - {"__json_string", - std::make_shared( - "__json_string", "__json_string", VARCHAR())}}) - .endTableScan() - .planNode(); + auto plan = + PlanBuilder() + .startTableScan() + .outputType( + ROW({"timestamp", "__json_string"}, {TIMESTAMP(), VARCHAR()})) + .tableHandle( + std::make_shared(kClpConnectorId, "test_2")) + .assignments( + {{"timestamp", + std::make_shared( + "timestamp", "timestamp", TIMESTAMP())}, + {"__json_string", + std::make_shared( + "__json_string", "__json_string", VARCHAR())}}) + .endTableScan() + .planNode(); auto output = getResults( plan, @@ -515,17 +517,18 @@ TEST_F(ClpConnectorTest, test2JsonString) { TEST_F(ClpConnectorTest, test3TimestampMarshalling) { const std::shared_ptr kqlQuery = nullptr; - auto plan = PlanBuilder(pool_.get()) - .startTableScan() - .outputType(ROW({"timestamp"}, {TIMESTAMP()})) - .tableHandle(std::make_shared( - kClpConnectorId, "test_3")) - .assignments( - {{"timestamp", - std::make_shared( - "timestamp", "timestamp", TIMESTAMP())}}) - .endTableScan() - .planNode(); + auto plan = + PlanBuilder(pool_.get()) + .startTableScan() + .outputType(ROW({"timestamp"}, {TIMESTAMP()})) + .tableHandle( + std::make_shared(kClpConnectorId, "test_3")) + .assignments( + {{"timestamp", + std::make_shared( + "timestamp", "timestamp", TIMESTAMP())}}) + .endTableScan() + .planNode(); auto output = getResults( plan, @@ -546,18 +549,19 @@ TEST_F(ClpConnectorTest, test3TimestampMarshalling) { TEST_F(ClpConnectorTest, test4IrTimestampNoPushdown) { const std::shared_ptr kqlQuery = nullptr; - auto plan = PlanBuilder(pool_.get()) - .startTableScan() - .outputType(ROW({"timestamp"}, {TIMESTAMP()})) - .tableHandle(std::make_shared( - kClpConnectorId, "test_4")) - .assignments( - {{"timestamp", - std::make_shared( - "timestamp", "timestamp", TIMESTAMP())}}) - .endTableScan() - .filter("\"timestamp\" < timestamp '2025-08-24 02:36:45'") - .planNode(); + auto plan = + PlanBuilder(pool_.get()) + .startTableScan() + .outputType(ROW({"timestamp"}, {TIMESTAMP()})) + .tableHandle( + std::make_shared(kClpConnectorId, "test_4")) + .assignments( + {{"timestamp", + std::make_shared( + "timestamp", "timestamp", TIMESTAMP())}}) + .endTableScan() + .filter("\"timestamp\" < timestamp '2025-08-24 02:36:45'") + .planNode(); auto output = getResults( plan, @@ -578,17 +582,18 @@ TEST_F(ClpConnectorTest, test4IrTimestampPushdown) { // which is not supported yet so the value will be NULL. const std::shared_ptr kqlQuery = std::make_shared("(timestamp < 1756003005000000)"); - auto plan = PlanBuilder(pool_.get()) - .startTableScan() - .outputType(ROW({"timestamp"}, {TIMESTAMP()})) - .tableHandle(std::make_shared( - kClpConnectorId, "test_4")) - .assignments( - {{"timestamp", - std::make_shared( - "timestamp", "timestamp", TIMESTAMP())}}) - .endTableScan() - .planNode(); + auto plan = + PlanBuilder(pool_.get()) + .startTableScan() + .outputType(ROW({"timestamp"}, {TIMESTAMP()})) + .tableHandle( + std::make_shared(kClpConnectorId, "test_4")) + .assignments( + {{"timestamp", + std::make_shared( + "timestamp", "timestamp", TIMESTAMP())}}) + .endTableScan() + .planNode(); auto output = getResults( plan, @@ -633,23 +638,24 @@ TEST_F(ClpConnectorTest, test5FloatTimestampNoPushdown) { getExampleFilePath("test_5.clps"), ClpConnectorSplit::SplitType::kArchive, kqlQuery)}); - auto expected = makeRowVector({// timestamp - makeFlatVector( - {Timestamp(1746003005, 124000000), - Timestamp(1746003005, 124100000), - Timestamp(1746003005, 125000000), - Timestamp(1746003005, 126000000), - Timestamp(1746003005, 127000000), - Timestamp(1746003060, 0), - Timestamp(1746003065, 0)}), - makeFlatVector( - {1.2345678912345E9, - 1E16, - 1.234567891234567E9, - 1.234567891234567E9, - -1.234567891234567E-9, - 1234567891.234567, - -1234567891.234567})}); + auto expected = makeRowVector( + {// timestamp + makeFlatVector( + {Timestamp(1746003005, 124000000), + Timestamp(1746003005, 124100000), + Timestamp(1746003005, 125000000), + Timestamp(1746003005, 126000000), + Timestamp(1746003005, 127000000), + Timestamp(1746003060, 0), + Timestamp(1746003065, 0)}), + makeFlatVector( + {1.2345678912345E9, + 1E16, + 1.234567891234567E9, + 1.234567891234567E9, + -1.234567891234567E-9, + 1234567891.234567, + -1234567891.234567})}); test::assertEqualVectors(expected, output); } @@ -683,17 +689,18 @@ TEST_F(ClpConnectorTest, test5FloatTimestampPushdown) { getExampleFilePath("test_5.clps"), ClpConnectorSplit::SplitType::kArchive, kqlQuery)}); - auto expected = makeRowVector({// timestamp - makeFlatVector( - {Timestamp(1746003005, 124000000), - Timestamp(1746003005, 124100000), - Timestamp(1746003005, 125000000), - Timestamp(1746003005, 126000000)}), - makeFlatVector( - {1.234567891234500E9, - 1E16, - 1.234567891234567E9, - 1.234567891234567E9})}); + auto expected = makeRowVector( + {// timestamp + makeFlatVector( + {Timestamp(1746003005, 124000000), + Timestamp(1746003005, 124100000), + Timestamp(1746003005, 125000000), + Timestamp(1746003005, 126000000)}), + makeFlatVector( + {1.234567891234500E9, + 1E16, + 1.234567891234567E9, + 1.234567891234567E9})}); test::assertEqualVectors(expected, output); } @@ -725,29 +732,30 @@ TEST_F(ClpConnectorTest, test5FormattedFloatNoPushdown) { getExampleFilePath("test_5.clps"), ClpConnectorSplit::SplitType::kArchive, kqlQuery)}); - auto expected = makeRowVector({// timestamp - makeFlatVector( - {Timestamp(1746003005, 123457000), - Timestamp(1746003115, 0), - Timestamp(1746003120, 0), - Timestamp(1746003125, 0), - Timestamp(1746003130, 0), - Timestamp(1746003135, 0), - Timestamp(1746003140, 0), - Timestamp(1746003145, 0), - Timestamp(1746003185, 0), - Timestamp(1746003190, 0)}), - makeFlatVector( - {1.2345678912345E-29, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0})}); + auto expected = makeRowVector( + {// timestamp + makeFlatVector( + {Timestamp(1746003005, 123457000), + Timestamp(1746003115, 0), + Timestamp(1746003120, 0), + Timestamp(1746003125, 0), + Timestamp(1746003130, 0), + Timestamp(1746003135, 0), + Timestamp(1746003140, 0), + Timestamp(1746003145, 0), + Timestamp(1746003185, 0), + Timestamp(1746003190, 0)}), + makeFlatVector( + {1.2345678912345E-29, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0})}); test::assertEqualVectors(expected, output); } @@ -778,29 +786,30 @@ TEST_F(ClpConnectorTest, test5FormattedFloatPushdown) { getExampleFilePath("test_5.clps"), ClpConnectorSplit::SplitType::kArchive, kqlQuery)}); - auto expected = makeRowVector({// timestamp - makeFlatVector( - {Timestamp(1746003005, 123457000), - Timestamp(1746003115, 0), - Timestamp(1746003120, 0), - Timestamp(1746003125, 0), - Timestamp(1746003130, 0), - Timestamp(1746003135, 0), - Timestamp(1746003140, 0), - Timestamp(1746003145, 0), - Timestamp(1746003185, 0), - Timestamp(1746003190, 0)}), - makeFlatVector( - {1.2345678912345E-29, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0})}); + auto expected = makeRowVector( + {// timestamp + makeFlatVector( + {Timestamp(1746003005, 123457000), + Timestamp(1746003115, 0), + Timestamp(1746003120, 0), + Timestamp(1746003125, 0), + Timestamp(1746003130, 0), + Timestamp(1746003135, 0), + Timestamp(1746003140, 0), + Timestamp(1746003145, 0), + Timestamp(1746003185, 0), + Timestamp(1746003190, 0)}), + makeFlatVector( + {1.2345678912345E-29, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0})}); test::assertEqualVectors(expected, output); } @@ -831,11 +840,11 @@ TEST_F(ClpConnectorTest, test5DictionaryFloatNoPushdown) { getExampleFilePath("test_5.clps"), ClpConnectorSplit::SplitType::kArchive, kqlQuery)}); - auto expected = - makeRowVector({// timestamp - makeFlatVector( - {Timestamp(1746003195, 0), Timestamp(1746003200, 0)}), - makeFlatVector({2, 2})}); + auto expected = makeRowVector( + {// timestamp + makeFlatVector( + {Timestamp(1746003195, 0), Timestamp(1746003200, 0)}), + makeFlatVector({2, 2})}); test::assertEqualVectors(expected, output); } @@ -866,11 +875,11 @@ TEST_F(ClpConnectorTest, test5DictionaryFloatPushdown) { getExampleFilePath("test_5.clps"), ClpConnectorSplit::SplitType::kArchive, kqlQuery)}); - auto expected = - makeRowVector({// timestamp - makeFlatVector( - {Timestamp(1746003195, 0), Timestamp(1746003200, 0)}), - makeFlatVector({2, 2})}); + auto expected = makeRowVector( + {// timestamp + makeFlatVector( + {Timestamp(1746003195, 0), Timestamp(1746003200, 0)}), + makeFlatVector({2, 2})}); test::assertEqualVectors(expected, output); } @@ -901,16 +910,17 @@ TEST_F(ClpConnectorTest, test5HybridNoPushdown) { getExampleFilePath("test_5.clps"), ClpConnectorSplit::SplitType::kArchive, kqlQuery)}); - auto expected = makeRowVector({// timestamp - makeFlatVector({ - Timestamp(1746003105, 0), - Timestamp(1746003110, 0), - Timestamp(1746003150, 0), - Timestamp(1746003155, 0), - Timestamp(1746003160, 0), - Timestamp(1746003205, 0), - }), - makeFlatVector({1, 1, 1, 1, 1, 1})}); + auto expected = makeRowVector( + {// timestamp + makeFlatVector({ + Timestamp(1746003105, 0), + Timestamp(1746003110, 0), + Timestamp(1746003150, 0), + Timestamp(1746003155, 0), + Timestamp(1746003160, 0), + Timestamp(1746003205, 0), + }), + makeFlatVector({1, 1, 1, 1, 1, 1})}); test::assertEqualVectors(expected, output); } @@ -941,16 +951,17 @@ TEST_F(ClpConnectorTest, test5HybridPushdown) { getExampleFilePath("test_5.clps"), ClpConnectorSplit::SplitType::kArchive, kqlQuery)}); - auto expected = makeRowVector({// timestamp - makeFlatVector({ - Timestamp(1746003105, 0), - Timestamp(1746003110, 0), - Timestamp(1746003150, 0), - Timestamp(1746003155, 0), - Timestamp(1746003160, 0), - Timestamp(1746003205, 0), - }), - makeFlatVector({1, 1, 1, 1, 1, 1})}); + auto expected = makeRowVector( + {// timestamp + makeFlatVector({ + Timestamp(1746003105, 0), + Timestamp(1746003110, 0), + Timestamp(1746003150, 0), + Timestamp(1746003155, 0), + Timestamp(1746003160, 0), + Timestamp(1746003205, 0), + }), + makeFlatVector({1, 1, 1, 1, 1, 1})}); test::assertEqualVectors(expected, output); } diff --git a/velox/connectors/fuzzer/CMakeLists.txt b/velox/connectors/fuzzer/CMakeLists.txt index a777f8a29a4f..8e21030f99db 100644 --- a/velox/connectors/fuzzer/CMakeLists.txt +++ b/velox/connectors/fuzzer/CMakeLists.txt @@ -14,8 +14,7 @@ add_library(velox_fuzzer_connector OBJECT FuzzerConnector.cpp) -target_link_libraries( - velox_fuzzer_connector velox_connector velox_vector_fuzzer) +target_link_libraries(velox_fuzzer_connector velox_connector velox_vector_fuzzer) if(${VELOX_BUILD_TESTING}) add_subdirectory(tests) diff --git a/velox/connectors/fuzzer/FuzzerConnector.cpp b/velox/connectors/fuzzer/FuzzerConnector.cpp index 53f0c1843664..f17b8faefd9c 100644 --- a/velox/connectors/fuzzer/FuzzerConnector.cpp +++ b/velox/connectors/fuzzer/FuzzerConnector.cpp @@ -20,12 +20,12 @@ namespace facebook::velox::connector::fuzzer { FuzzerDataSource::FuzzerDataSource( - const std::shared_ptr& outputType, - const std::shared_ptr& tableHandle, + const RowTypePtr& outputType, + const connector::ConnectorTableHandlePtr& tableHandle, velox::memory::MemoryPool* pool) : outputType_(outputType), pool_(pool) { auto fuzzerTableHandle = - std::dynamic_pointer_cast(tableHandle); + std::dynamic_pointer_cast(tableHandle); VELOX_CHECK_NOT_NULL( fuzzerTableHandle, "TableHandle must be an instance of FuzzerTableHandle"); diff --git a/velox/connectors/fuzzer/FuzzerConnector.h b/velox/connectors/fuzzer/FuzzerConnector.h index 64477b73ea36..5b3f9bf74e22 100644 --- a/velox/connectors/fuzzer/FuzzerConnector.h +++ b/velox/connectors/fuzzer/FuzzerConnector.h @@ -34,7 +34,7 @@ namespace facebook::velox::connector::fuzzer { class FuzzerTableHandle : public ConnectorTableHandle { public: - explicit FuzzerTableHandle( + FuzzerTableHandle( std::string connectorId, VectorFuzzer::Options options, size_t fuzzerSeed = 0) @@ -42,21 +42,20 @@ class FuzzerTableHandle : public ConnectorTableHandle { fuzzerOptions(options), fuzzerSeed(fuzzerSeed) {} - ~FuzzerTableHandle() override {} - - std::string toString() const override { - return "fuzzer-mock-table"; + const std::string& name() const override { + static const std::string kName = "fuzzer-mock-table"; + return kName; } const VectorFuzzer::Options fuzzerOptions; - size_t fuzzerSeed; + const size_t fuzzerSeed; }; class FuzzerDataSource : public DataSource { public: FuzzerDataSource( - const std::shared_ptr& outputType, - const std::shared_ptr& tableHandle, + const RowTypePtr& outputType, + const connector::ConnectorTableHandlePtr& tableHandle, velox::memory::MemoryPool* pool); void addSplit(std::shared_ptr split) override; @@ -78,7 +77,7 @@ class FuzzerDataSource : public DataSource { return completedBytes_; } - std::unordered_map runtimeStats() override { + std::unordered_map getRuntimeStats() override { // TODO: Which stats do we want to expose here? return {}; } @@ -109,11 +108,9 @@ class FuzzerConnector final : public Connector { : Connector(id) {} std::unique_ptr createDataSource( - const std::shared_ptr& outputType, - const std::shared_ptr& tableHandle, - const std::unordered_map< - std::string, - std::shared_ptr>& /*columnHandles*/, + const RowTypePtr& outputType, + const ConnectorTableHandlePtr& tableHandle, + const connector::ColumnHandleMap& /*columnHandles*/, ConnectorQueryCtx* connectorQueryCtx) override final { return std::make_unique( outputType, tableHandle, connectorQueryCtx->memoryPool()); @@ -121,8 +118,7 @@ class FuzzerConnector final : public Connector { std::unique_ptr createDataSink( RowTypePtr /*inputType*/, - std::shared_ptr< - ConnectorInsertTableHandle> /*connectorInsertTableHandle*/, + ConnectorInsertTableHandlePtr /*connectorInsertTableHandle*/, ConnectorQueryCtx* /*connectorQueryCtx*/, CommitStrategy /*commitStrategy*/) override final { VELOX_NYI("FuzzerConnector does not support data sink."); diff --git a/velox/connectors/fuzzer/tests/CMakeLists.txt b/velox/connectors/fuzzer/tests/CMakeLists.txt index 803619339097..6f0dbf3b6472 100644 --- a/velox/connectors/fuzzer/tests/CMakeLists.txt +++ b/velox/connectors/fuzzer/tests/CMakeLists.txt @@ -22,4 +22,5 @@ target_link_libraries( velox_exec_test_lib velox_aggregates GTest::gtest - GTest::gtest_main) + GTest::gtest_main +) diff --git a/velox/connectors/fuzzer/tests/FuzzerConnectorTest.cpp b/velox/connectors/fuzzer/tests/FuzzerConnectorTest.cpp index 5b5fe277ac59..a9900512cd4c 100644 --- a/velox/connectors/fuzzer/tests/FuzzerConnectorTest.cpp +++ b/velox/connectors/fuzzer/tests/FuzzerConnectorTest.cpp @@ -124,13 +124,13 @@ TEST_F(FuzzerConnectorTest, reproducible) { auto plan1 = PlanBuilder() .startTableScan() .outputType(type) - .tableHandle(makeFuzzerTableHandle(/*fuzerSeed=*/1234)) + .tableHandle(makeFuzzerTableHandle(/*fuzzerSeed=*/1234)) .endTableScan() .planNode(); auto plan2 = PlanBuilder() .startTableScan() .outputType(type) - .tableHandle(makeFuzzerTableHandle(/*fuzerSeed=*/1234)) + .tableHandle(makeFuzzerTableHandle(/*fuzzerSeed=*/1234)) .endTableScan() .planNode(); diff --git a/velox/connectors/fuzzer/tests/FuzzerConnectorTestBase.h b/velox/connectors/fuzzer/tests/FuzzerConnectorTestBase.h index 971b54dbe9b8..b0c703f54bc2 100644 --- a/velox/connectors/fuzzer/tests/FuzzerConnectorTestBase.h +++ b/velox/connectors/fuzzer/tests/FuzzerConnectorTestBase.h @@ -26,20 +26,13 @@ class FuzzerConnectorTestBase : public exec::test::OperatorTestBase { void SetUp() override { OperatorTestBase::SetUp(); - connector::registerConnectorFactory( - std::make_shared()); - std::shared_ptr config; - auto fuzzerConnector = - connector::getConnectorFactory( - connector::fuzzer::FuzzerConnectorFactory::kFuzzerConnectorName) - ->newConnector(kFuzzerConnectorId, config); + connector::fuzzer::FuzzerConnectorFactory factory; + auto fuzzerConnector = factory.newConnector(kFuzzerConnectorId, nullptr); connector::registerConnector(fuzzerConnector); } void TearDown() override { connector::unregisterConnector(kFuzzerConnectorId); - connector::unregisterConnectorFactory( - connector::fuzzer::FuzzerConnectorFactory::kFuzzerConnectorName); OperatorTestBase::TearDown(); } diff --git a/velox/connectors/hive/BufferedInputBuilder.cpp b/velox/connectors/hive/BufferedInputBuilder.cpp new file mode 100644 index 000000000000..f4fb6c1a828c --- /dev/null +++ b/velox/connectors/hive/BufferedInputBuilder.cpp @@ -0,0 +1,47 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/BufferedInputBuilder.h" +#include "velox/connectors/hive/HiveConnectorUtil.h" + +namespace facebook::velox::connector::hive { + +class DefaultBufferInputBuilder : public BufferedInputBuilder { + public: + std::unique_ptr create( + const FileHandle& fileHandle, + const dwio::common::ReaderOptions& readerOpts, + const ConnectorQueryCtx* connectorQueryCtx, + std::shared_ptr ioStats, + std::shared_ptr fsStats, + folly::Executor* executor, + const folly::F14FastMap& fileReadOps) override { + return createBufferedInput( + fileHandle, + readerOpts, + connectorQueryCtx, + ioStats, + fsStats, + executor, + fileReadOps); + } +}; + +// static +std::shared_ptr BufferedInputBuilder::builder_ = + std::make_shared(); + +} // namespace facebook::velox::connector::hive diff --git a/velox/connectors/hive/BufferedInputBuilder.h b/velox/connectors/hive/BufferedInputBuilder.h new file mode 100644 index 000000000000..d23049fed6bb --- /dev/null +++ b/velox/connectors/hive/BufferedInputBuilder.h @@ -0,0 +1,56 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include + +#include "velox/connectors/Connector.h" +#include "velox/connectors/hive/FileHandle.h" +#include "velox/dwio/common/BufferedInput.h" +#include "velox/dwio/common/Reader.h" + +namespace facebook::velox::connector::hive { + +/// Registering a different implementation of BufferedInput is allowed using +/// 'registerBuilder' API. +class BufferedInputBuilder { + public: + virtual ~BufferedInputBuilder() = default; + + static const std::shared_ptr& getInstance() { + VELOX_CHECK_NOT_NULL(builder_, "Builder is not registered"); + return builder_; + } + + static void registerBuilder(std::shared_ptr builder) { + VELOX_CHECK_NOT_NULL(builder); + builder_ = std::move(builder); + } + + virtual std::unique_ptr create( + const FileHandle& fileHandle, + const dwio::common::ReaderOptions& readerOpts, + const ConnectorQueryCtx* connectorQueryCtx, + std::shared_ptr ioStats, + std::shared_ptr fsStats, + folly::Executor* executor, + const folly::F14FastMap& fileReadOps = {}) = 0; + + private: + static std::shared_ptr builder_; +}; + +} // namespace facebook::velox::connector::hive diff --git a/velox/connectors/hive/CMakeLists.txt b/velox/connectors/hive/CMakeLists.txt index 6c8aa05487b4..e6f980d8e3fc 100644 --- a/velox/connectors/hive/CMakeLists.txt +++ b/velox/connectors/hive/CMakeLists.txt @@ -20,6 +20,7 @@ add_subdirectory(iceberg) velox_add_library( velox_hive_connector OBJECT + BufferedInputBuilder.cpp FileHandle.cpp HiveConfig.cpp HiveConnector.cpp @@ -27,16 +28,17 @@ velox_add_library( HiveConnectorSplit.cpp HiveDataSink.cpp HiveDataSource.cpp - HivePartitionUtil.cpp + HivePartitionName.cpp PartitionIdGenerator.cpp SplitReader.cpp - TableHandle.cpp) + TableHandle.cpp +) velox_link_libraries( velox_hive_connector PUBLIC velox_hive_iceberg_splitreader - PRIVATE velox_common_io velox_connector velox_dwio_catalog_fbhive - velox_hive_partition_function) + PRIVATE velox_common_io velox_connector velox_dwio_catalog_fbhive velox_hive_partition_function +) velox_add_library(velox_hive_partition_function HivePartitionFunction.cpp) diff --git a/velox/connectors/hive/FileHandle.cpp b/velox/connectors/hive/FileHandle.cpp index a75225cb4ee6..413ca86bef46 100644 --- a/velox/connectors/hive/FileHandle.cpp +++ b/velox/connectors/hive/FileHandle.cpp @@ -40,7 +40,7 @@ std::string groupName(const std::string& filename) { } // namespace std::unique_ptr FileHandleGenerator::operator()( - const std::string& filename, + const FileHandleKey& key, const FileProperties* properties, filesystems::File::IoStats* stats) { // We have seen cases where drivers are stuck when creating file handles. @@ -53,11 +53,14 @@ std::unique_ptr FileHandleGenerator::operator()( fileHandle = std::make_unique(); filesystems::FileOptions options; options.stats = stats; + options.tokenProvider = key.tokenProvider; if (properties) { options.fileSize = properties->fileSize; options.readRangeHint = properties->readRangeHint; options.extraFileInfo = properties->extraFileInfo; + options.fileReadOps = properties->fileReadOps; } + const auto& filename = key.filename; fileHandle->file = filesystems::getFileSystem(filename, properties_) ->openFileForRead(filename, options); fileHandle->uuid = StringIdLease(fileIds(), filename); diff --git a/velox/connectors/hive/FileHandle.h b/velox/connectors/hive/FileHandle.h index 0a053742ca56..6f9b4050c31e 100644 --- a/velox/connectors/hive/FileHandle.h +++ b/velox/connectors/hive/FileHandle.h @@ -25,10 +25,12 @@ #pragma once +#include "velox/common/base/BitUtil.h" #include "velox/common/caching/CachedFactory.h" #include "velox/common/caching/FileIds.h" #include "velox/common/config/Config.h" #include "velox/common/file/File.h" +#include "velox/common/file/TokenProvider.h" #include "velox/connectors/hive/FileProperties.h" namespace facebook::velox { @@ -59,7 +61,44 @@ struct FileHandleSizer { uint64_t operator()(const FileHandle& a); }; -using FileHandleCache = SimpleLRUCache; +struct FileHandleKey { + std::string filename; + std::shared_ptr tokenProvider{nullptr}; + + bool operator==(const FileHandleKey& other) const { + if (filename != other.filename) { + return false; + } + + if (tokenProvider == other.tokenProvider) { + return true; + } + + if (!tokenProvider || !other.tokenProvider) { + return false; + } + + return tokenProvider->equals(*other.tokenProvider); + } +}; + +} // namespace facebook::velox + +namespace std { +template <> +struct hash { + size_t operator()(const facebook::velox::FileHandleKey& key) const noexcept { + size_t filenameHash = std::hash()(key.filename); + return key.tokenProvider ? facebook::velox::bits::hashMix( + filenameHash, key.tokenProvider->hash()) + : filenameHash; + } +}; +} // namespace std + +namespace facebook::velox { +using FileHandleCache = + SimpleLRUCache; // Creates FileHandles via the Generator interface the CachedFactory requires. class FileHandleGenerator { @@ -68,7 +107,7 @@ class FileHandleGenerator { FileHandleGenerator(std::shared_ptr properties) : properties_(std::move(properties)) {} std::unique_ptr operator()( - const std::string& filename, + const FileHandleKey& filename, const FileProperties* properties, filesystems::File::IoStats* stats); @@ -77,14 +116,14 @@ class FileHandleGenerator { }; using FileHandleFactory = CachedFactory< - std::string, + FileHandleKey, FileHandle, FileHandleGenerator, FileProperties, filesystems::File::IoStats, FileHandleSizer>; -using FileHandleCachedPtr = CachedPtr; +using FileHandleCachedPtr = CachedPtr; using FileHandleCacheStats = SimpleLRUCacheStats; diff --git a/velox/connectors/hive/FileProperties.h b/velox/connectors/hive/FileProperties.h index d3ed9e3cbd6b..a6158e1fec65 100644 --- a/velox/connectors/hive/FileProperties.h +++ b/velox/connectors/hive/FileProperties.h @@ -25,6 +25,7 @@ #pragma once +#include #include namespace facebook::velox { @@ -34,6 +35,7 @@ struct FileProperties { std::optional modificationTime; std::optional readRangeHint{std::nullopt}; std::shared_ptr extraFileInfo{nullptr}; + folly::F14FastMap fileReadOps{}; }; } // namespace facebook::velox diff --git a/velox/connectors/hive/HiveConfig.cpp b/velox/connectors/hive/HiveConfig.cpp index 5eaab02e648b..8b354e2ed7d4 100644 --- a/velox/connectors/hive/HiveConfig.cpp +++ b/velox/connectors/hive/HiveConfig.cpp @@ -67,6 +67,11 @@ uint32_t HiveConfig::maxPartitionsPerWriters( config_->get(kMaxPartitionsPerWriters, 128)); } +uint32_t HiveConfig::maxBucketCount(const config::ConfigBase* session) const { + return session->get( + kMaxBucketCountSession, config_->get(kMaxBucketCount, 100'000)); +} + bool HiveConfig::immutablePartitions() const { return config_->get(kImmutablePartitions, false); } @@ -88,6 +93,11 @@ std::optional HiveConfig::gcsMaxRetryTime() const { config_->get(kGcsMaxRetryTime)); } +std::optional HiveConfig::gcsAuthAccessTokenProvider() const { + return static_cast>( + config_->get(kGcsAuthAccessTokenProvider)); +} + bool HiveConfig::isOrcUseColumnNames(const config::ConfigBase* session) const { return session->get( kOrcUseColumnNamesSession, config_->get(kOrcUseColumnNames, false)); @@ -149,6 +159,15 @@ int32_t HiveConfig::prefetchRowGroups() const { return config_->get(kPrefetchRowGroups, 1); } +size_t HiveConfig::parallelUnitLoadCount( + const config::ConfigBase* session) const { + auto count = session->get( + kParallelUnitLoadCountSession, + config_->get(kParallelUnitLoadCount, 0)); + VELOX_CHECK_LE(count, 100, "parallelUnitLoadCount too large: {}", count); + return count; +} + int32_t HiveConfig::loadQuantum(const config::ConfigBase* session) const { return session->get( kLoadQuantumSession, config_->get(kLoadQuantum, 8 << 20)); @@ -233,4 +252,25 @@ std::string HiveConfig::hiveLocalFileFormat() const { return config_->get(kLocalFileFormat, ""); } +bool HiveConfig::preserveFlatMapsInMemory( + const config::ConfigBase* session) const { + return session->get( + kPreserveFlatMapsInMemorySession, + config_->get(kPreserveFlatMapsInMemory, false)); +} + +std::string HiveConfig::user(const config::ConfigBase* session) const { + return session->get(kUser, config_->get(kUser, "")); +} + +std::string HiveConfig::source(const config::ConfigBase* session) const { + return session->get( + kSource, config_->get(kSource, "")); +} + +std::string HiveConfig::schema(const config::ConfigBase* session) const { + return session->get( + kSchema, config_->get(kSchema, "")); +} + } // namespace facebook::velox::connector::hive diff --git a/velox/connectors/hive/HiveConfig.h b/velox/connectors/hive/HiveConfig.h index 04dbd754b81a..f69da84ff0c1 100644 --- a/velox/connectors/hive/HiveConfig.h +++ b/velox/connectors/hive/HiveConfig.h @@ -49,6 +49,10 @@ class HiveConfig { static constexpr const char* kMaxPartitionsPerWritersSession = "max_partitions_per_writers"; + /// Maximum number of buckets allowed to output by the table writers. + static constexpr const char* kMaxBucketCount = "hive.max-bucket-count"; + static constexpr const char* kMaxBucketCountSession = "hive.max_bucket_count"; + /// Whether new data can be inserted into an unpartition table. /// Velox currently does not support appending data to existing partitions. static constexpr const char* kImmutablePartitions = @@ -67,12 +71,13 @@ class HiveConfig { /// The GCS maximum time allowed to retry transient errors. static constexpr const char* kGcsMaxRetryTime = "hive.gcs.max-retry-time"; + static constexpr const char* kGcsAuthAccessTokenProvider = + "hive.gcs.auth.access-token-provider"; + /// Maps table field names to file field names using names, not indices. - // TODO: remove hive_orc_use_column_names since it doesn't exist in presto, - // right now this is only used for testing. static constexpr const char* kOrcUseColumnNames = "hive.orc.use-column-names"; static constexpr const char* kOrcUseColumnNamesSession = - "hive_orc_use_column_names"; + "orc_use_column_names"; /// Maps table field names to file field names using names, not indices. static constexpr const char* kParquetUseColumnNames = @@ -139,6 +144,15 @@ class HiveConfig { /// meta data together. Optimization to decrease the small IO requests static constexpr const char* kFilePreloadThreshold = "file-preload-threshold"; + /// When set to be larger than 0, parallel unit loader feature is enabled and + /// it configures how many units (e.g., stripes) we load in parallel. + /// When set to 0, parallel unit loader feature is disabled and on demand unit + /// loader would be used. + static constexpr const char* kParallelUnitLoadCount = + "parallel-unit-load-count"; + static constexpr const char* kParallelUnitLoadCountSession = + "parallel_unit_load_count"; + /// Config used to create write files. This config is provided to underlying /// file system through hive connector and data sink. The config is free form. /// The form should be defined by the underlying file system. @@ -183,11 +197,24 @@ class HiveConfig { static constexpr const char* kLocalDataPath = "hive_local_data_path"; static constexpr const char* kLocalFileFormat = "hive_local_file_format"; + /// Whether to preserve flat maps in memory as FlatMapVectors instead of + /// converting them to MapVectors. + static constexpr const char* kPreserveFlatMapsInMemory = + "hive.preserve-flat-maps-in-memory"; + static constexpr const char* kPreserveFlatMapsInMemorySession = + "hive.preserve_flat_maps_in_memory"; + + static constexpr const char* kUser = "user"; + static constexpr const char* kSource = "source"; + static constexpr const char* kSchema = "schema"; + InsertExistingPartitionsBehavior insertExistingPartitionsBehavior( const config::ConfigBase* session) const; uint32_t maxPartitionsPerWriters(const config::ConfigBase* session) const; + uint32_t maxBucketCount(const config::ConfigBase* session) const; + bool immutablePartitions() const; std::string gcsEndpoint() const; @@ -198,6 +225,8 @@ class HiveConfig { std::optional gcsMaxRetryTime() const; + std::optional gcsAuthAccessTokenProvider() const; + bool isOrcUseColumnNames(const config::ConfigBase* session) const; bool isParquetUseColumnNames(const config::ConfigBase* session) const; @@ -217,6 +246,8 @@ class HiveConfig { int32_t prefetchRowGroups() const; + size_t parallelUnitLoadCount(const config::ConfigBase* session) const; + int32_t loadQuantum(const config::ConfigBase* session) const; int32_t numCacheFileHandles() const; @@ -261,6 +292,19 @@ class HiveConfig { /// hiveLocalDataPath(). std::string hiveLocalFileFormat() const; + /// Whether to preserve flat maps in memory as FlatMapVectors instead of + /// converting them to MapVectors. + bool preserveFlatMapsInMemory(const config::ConfigBase* session) const; + + /// User of the query. Used for storage logging. + std::string user(const config::ConfigBase* session) const; + + /// Source of the query. Used for storage access and logging. + std::string source(const config::ConfigBase* session) const; + + /// Schema of the query. Used for storage logging. + std::string schema(const config::ConfigBase* session) const; + HiveConfig(std::shared_ptr config) { VELOX_CHECK_NOT_NULL( config, "Config is null for HiveConfig initialization"); diff --git a/velox/connectors/hive/HiveConnector.cpp b/velox/connectors/hive/HiveConnector.cpp index 4ef9e8f06179..e04828e83aaa 100644 --- a/velox/connectors/hive/HiveConnector.cpp +++ b/velox/connectors/hive/HiveConnector.cpp @@ -16,13 +16,10 @@ #include "velox/connectors/hive/HiveConnector.h" -#include "velox/common/base/Fs.h" #include "velox/connectors/hive/HiveConfig.h" #include "velox/connectors/hive/HiveDataSink.h" #include "velox/connectors/hive/HiveDataSource.h" #include "velox/connectors/hive/HivePartitionFunction.h" -#include "velox/expression/ExprToSubfieldFilter.h" -#include "velox/expression/FieldReference.h" #include #include @@ -31,27 +28,19 @@ using namespace facebook::velox::exec; namespace facebook::velox::connector::hive { -namespace { -std::vector>& -hiveConnectorMetadataFactories() { - static std::vector> factories; - return factories; -} -} // namespace - HiveConnector::HiveConnector( const std::string& id, std::shared_ptr config, - folly::Executor* executor) - : Connector(id), - hiveConfig_(std::make_shared(config)), + folly::Executor* ioExecutor) + : Connector(id, std::move(config)), + hiveConfig_(std::make_shared(connectorConfig())), fileHandleFactory_( hiveConfig_->isFileHandleCacheEnabled() - ? std::make_unique>( + ? std::make_unique>( hiveConfig_->numCacheFileHandles()) : nullptr, - std::make_unique(config)), - executor_(executor) { + std::make_unique(hiveConfig_->config())), + ioExecutor_(ioExecutor) { if (hiveConfig_->isFileHandleCacheEnabled()) { LOG(INFO) << "Hive connector " << connectorId() << " created with maximum of " @@ -62,38 +51,31 @@ HiveConnector::HiveConnector( LOG(INFO) << "Hive connector " << connectorId() << " created with file handle cache disabled"; } - for (auto& factory : hiveConnectorMetadataFactories()) { - metadata_ = factory->create(this); - if (metadata_ != nullptr) { - break; - } - } } std::unique_ptr HiveConnector::createDataSource( const RowTypePtr& outputType, - const std::shared_ptr& tableHandle, - const std::unordered_map< - std::string, - std::shared_ptr>& columnHandles, + const ConnectorTableHandlePtr& tableHandle, + const std::unordered_map& columnHandles, ConnectorQueryCtx* connectorQueryCtx) { return std::make_unique( outputType, tableHandle, columnHandles, &fileHandleFactory_, - executor_, + ioExecutor_, connectorQueryCtx, hiveConfig_); } std::unique_ptr HiveConnector::createDataSink( RowTypePtr inputType, - std::shared_ptr connectorInsertTableHandle, + ConnectorInsertTableHandlePtr connectorInsertTableHandle, ConnectorQueryCtx* connectorQueryCtx, CommitStrategy commitStrategy) { - auto hiveInsertHandle = std::dynamic_pointer_cast( - connectorInsertTableHandle); + auto hiveInsertHandle = + std::dynamic_pointer_cast( + connectorInsertTableHandle); VELOX_CHECK_NOT_NULL( hiveInsertHandle, "Hive connector expecting hive write handle!"); return std::make_unique( @@ -104,6 +86,19 @@ std::unique_ptr HiveConnector::createDataSink( hiveConfig_); } +// static +void HiveConnector::registerSerDe() { + HiveTableHandle::registerSerDe(); + HiveColumnHandle::registerSerDe(); + HiveConnectorSplit::registerSerDe(); + HiveInsertTableHandle::registerSerDe(); + HiveInsertFileNameGenerator::registerSerDe(); + LocationHandle::registerSerDe(); + HiveBucketProperty::registerSerDe(); + HiveSortingColumn::registerSerDe(); + HivePartitionFunctionSpec::registerSerDe(); +} + std::unique_ptr HivePartitionFunctionSpec::create( int numPartitions, bool localExchange) const { @@ -125,7 +120,7 @@ std::unique_ptr HivePartitionFunctionSpec::create( std::mt19937{0}); } } - return std::make_unique( + return std::make_unique( numBuckets_, bucketToPartition_.empty() ? std::move(bucketToPartitions) : bucketToPartition_, @@ -190,16 +185,11 @@ core::PartitionFunctionSpecPtr HivePartitionFunctionSpec::deserialize( std::move(constValues)); } -void registerHivePartitionFunctionSerDe() { +// static +void HivePartitionFunctionSpec::registerSerDe() { auto& registry = DeserializationWithContextRegistryForSharedPtr(); registry.Register( "HivePartitionFunctionSpec", HivePartitionFunctionSpec::deserialize); } -bool registerHiveConnectorMetadataFactory( - std::unique_ptr factory) { - hiveConnectorMetadataFactories().push_back(std::move(factory)); - return true; -} - } // namespace facebook::velox::connector::hive diff --git a/velox/connectors/hive/HiveConnector.h b/velox/connectors/hive/HiveConnector.h index 937fdd850edf..845461ea412e 100644 --- a/velox/connectors/hive/HiveConnector.h +++ b/velox/connectors/hive/HiveConnector.h @@ -34,40 +34,28 @@ class HiveConnector : public Connector { std::shared_ptr config, folly::Executor* executor); - const std::shared_ptr& connectorConfig() - const override { - return hiveConfig_->config(); - } - bool canAddDynamicFilter() const override { return true; } - ConnectorMetadata* metadata() const override { - VELOX_CHECK_NOT_NULL(metadata_); - return metadata_.get(); - } - std::unique_ptr createDataSource( const RowTypePtr& outputType, - const std::shared_ptr& tableHandle, - const std::unordered_map< - std::string, - std::shared_ptr>& columnHandles, + const ConnectorTableHandlePtr& tableHandle, + const connector::ColumnHandleMap& columnHandles, ConnectorQueryCtx* connectorQueryCtx) override; - bool supportsSplitPreload() override { + bool supportsSplitPreload() const override { return true; } std::unique_ptr createDataSink( RowTypePtr inputType, - std::shared_ptr connectorInsertTableHandle, + ConnectorInsertTableHandlePtr connectorInsertTableHandle, ConnectorQueryCtx* connectorQueryCtx, CommitStrategy commitStrategy) override; - folly::Executor* executor() const override { - return executor_; + folly::Executor* ioExecutor() const override { + return ioExecutor_; } FileHandleCacheStats fileHandleCacheStats() { @@ -80,11 +68,12 @@ class HiveConnector : public Connector { return fileHandleFactory_.clearCache(); } + static void registerSerDe(); + protected: const std::shared_ptr hiveConfig_; FileHandleFactory fileHandleFactory_; - folly::Executor* executor_; - std::shared_ptr metadata_; + folly::Executor* ioExecutor_; }; class HiveConnectorFactory : public ConnectorFactory { @@ -100,7 +89,7 @@ class HiveConnectorFactory : public ConnectorFactory { const std::string& id, std::shared_ptr config, folly::Executor* ioExecutor = nullptr, - folly::Executor* cpuExecutor = nullptr) override { + [[maybe_unused]] folly::Executor* cpuExecutor = nullptr) override { return std::make_shared(id, config, ioExecutor); } }; @@ -147,6 +136,8 @@ class HivePartitionFunctionSpec : public core::PartitionFunctionSpec { const folly::dynamic& obj, void* context); + static void registerSerDe(); + private: const int numBuckets_; const std::vector bucketToPartition_; @@ -154,23 +145,4 @@ class HivePartitionFunctionSpec : public core::PartitionFunctionSpec { const std::vector constValues_; }; -void registerHivePartitionFunctionSerDe(); - -/// Hook for connecting metadata functions to a HiveConnector. Each registered -/// factory is called after initializing a HiveConnector until one of these -/// returns a ConnectorMetadata instance. -class HiveConnectorMetadataFactory { - public: - virtual ~HiveConnectorMetadataFactory() = default; - - /// Returns a ConnectorMetadata to complete'hiveConnector' if 'this' - /// recognizes a data source, e.g. local file system or remote metadata - /// service associated to configs in 'hiveConnector'. - virtual std::shared_ptr create( - HiveConnector* connector) = 0; -}; - -bool registerHiveConnectorMetadataFactory( - std::unique_ptr); - } // namespace facebook::velox::connector::hive diff --git a/velox/connectors/hive/HiveConnectorSplit.cpp b/velox/connectors/hive/HiveConnectorSplit.cpp index fba4902b180c..e7e32686692e 100644 --- a/velox/connectors/hive/HiveConnectorSplit.cpp +++ b/velox/connectors/hive/HiveConnectorSplit.cpp @@ -30,6 +30,10 @@ std::string HiveConnectorSplit::toString() const { return fmt::format("Hive: {} {} - {}", filePath, start, length); } +uint64_t HiveConnectorSplit::size() const { + return length; +} + std::string HiveConnectorSplit::getFileName() const { const auto i = filePath.rfind('/'); return i == std::string::npos ? filePath : filePath.substr(i + 1); @@ -144,8 +148,10 @@ std::shared_ptr HiveConnectorSplit::create( std::vector> bucketColumnHandles; for (const auto& bucketColumnHandleObj : bucketConversionObj["bucketColumnHandles"]) { - bucketColumnHandles.push_back(std::const_pointer_cast( - ISerializable::deserialize(bucketColumnHandleObj))); + bucketColumnHandles.push_back( + std::const_pointer_cast( + ISerializable::deserialize( + bucketColumnHandleObj))); } bucketConversion = HiveBucketConversion{ .tableBucketCount = static_cast( diff --git a/velox/connectors/hive/HiveConnectorSplit.h b/velox/connectors/hive/HiveConnectorSplit.h index 16a50c42abd3..3485c2330fa5 100644 --- a/velox/connectors/hive/HiveConnectorSplit.h +++ b/velox/connectors/hive/HiveConnectorSplit.h @@ -106,42 +106,9 @@ struct HiveConnectorSplit : public connector::ConnectorSplit { rowIdProperties(_rowIdProperties), bucketConversion(_bucketConversion) {} -#ifdef VELOX_ENABLE_BACKWARD_COMPATIBILITY - HiveConnectorSplit( - const std::string& connectorId, - const std::string& _filePath, - dwio::common::FileFormat _fileFormat, - uint64_t _start, - uint64_t _length, - const std::unordered_map>& - _partitionKeys, - std::optional _tableBucketNumber, - const std::unordered_map& _customSplitInfo, - const std::shared_ptr& _extraFileInfo, - const std::unordered_map& _serdeParameters, - const std::unordered_map& _storageParameters, - int64_t splitWeight = 0, - bool cacheable = true, - const std::unordered_map& _infoColumns = {}, - std::optional _properties = std::nullopt, - std::optional _rowIdProperties = std::nullopt, - const std::optional& _bucketConversion = - std::nullopt) - : ConnectorSplit(connectorId, splitWeight, cacheable), - filePath(_filePath), - fileFormat(_fileFormat), - start(_start), - length(_length), - partitionKeys(_partitionKeys), - tableBucketNumber(_tableBucketNumber), - customSplitInfo(_customSplitInfo), - extraFileInfo(_extraFileInfo), - serdeParameters(_serdeParameters), - infoColumns(_infoColumns), - properties(_properties), - rowIdProperties(_rowIdProperties), - bucketConversion(_bucketConversion) {} -#endif + ~HiveConnectorSplit() = default; + + uint64_t size() const override; std::string toString() const override; diff --git a/velox/connectors/hive/HiveConnectorUtil.cpp b/velox/connectors/hive/HiveConnectorUtil.cpp index 1fd83c483537..0110f3be394a 100644 --- a/velox/connectors/hive/HiveConnectorUtil.cpp +++ b/velox/connectors/hive/HiveConnectorUtil.cpp @@ -21,6 +21,7 @@ #include "velox/dwio/common/CachedBufferedInput.h" #include "velox/dwio/common/DirectBufferedInput.h" #include "velox/expression/Expr.h" +#include "velox/expression/ExprConstants.h" #include "velox/expression/ExprToSubfieldFilter.h" namespace facebook::velox::connector::hive { @@ -73,14 +74,15 @@ std::unique_ptr makeFloatingPointMapKeyFilter( if (lowerUnbounded && upperUnbounded) { continue; } - filters.push_back(std::make_unique>( - lower, - lowerUnbounded, - lowerExclusive, - upper, - upperUnbounded, - upperExclusive, - false)); + filters.push_back( + std::make_unique>( + lower, + lowerUnbounded, + lowerExclusive, + upper, + upperUnbounded, + upperExclusive, + false)); } if (filters.size() == 1) { return std::move(filters[0]); @@ -110,8 +112,7 @@ void addSubfields( folly::F14FastMap> required; for (auto& subfield : subfields) { auto* element = subfield.subfield->path()[level].get(); - auto* nestedField = - dynamic_cast(element); + auto* nestedField = element->as(); VELOX_CHECK( nestedField, "Unsupported for row subfields pruning: {}", @@ -150,20 +151,18 @@ void addSubfields( std::vector longSubscripts; for (auto& subfield : subfields) { auto* element = subfield.subfield->path()[level].get(); - if (dynamic_cast(element)) { + if (element->is(common::SubfieldKind::kAllSubscripts)) { return; } if (stringKey) { - auto* subscript = - dynamic_cast(element); + auto* subscript = element->as(); VELOX_CHECK( subscript, "Unsupported for string map pruning: {}", element->toString()); stringSubscripts.push_back(subscript->index()); } else { - auto* subscript = - dynamic_cast(element); + auto* subscript = element->as(); VELOX_CHECK( subscript, "Unsupported for long map pruning: {}", @@ -242,8 +241,7 @@ inline uint8_t parseDelimiter(const std::string& delim) { inline bool isSynthesizedColumn( const std::string& name, - const std::unordered_map>& - infoColumns) { + const std::unordered_map& infoColumns) { return infoColumns.count(name) != 0; } @@ -263,7 +261,7 @@ const std::string& getColumnName(const common::Subfield& subfield) { return field->name(); } -void checkColumnNameLowerCase(const std::shared_ptr& type) { +void checkColumnNameLowerCase(const TypePtr& type) { switch (type->kind()) { case TypeKind::ARRAY: checkColumnNameLowerCase(type->asArray().elementType()); @@ -289,8 +287,7 @@ void checkColumnNameLowerCase(const std::shared_ptr& type) { void checkColumnNameLowerCase( const common::SubfieldFilters& filters, - const std::unordered_map>& - infoColumns) { + const std::unordered_map& infoColumns) { for (const auto& filterIt : filters) { const auto name = filterIt.first.toString(); if (isSynthesizedColumn(name, infoColumns)) { @@ -330,7 +327,11 @@ void processFieldSpec( if (type.isMap() && !spec.isConstant()) { auto* keys = spec.childByName(common::ScanSpec::kMapKeysFieldName); VELOX_CHECK_NOT_NULL(keys); - keys->addFilter(common::IsNotNull()); + if (keys->filter()) { + VELOX_CHECK(!keys->filter()->testNull()); + } else { + keys->setFilter(std::make_shared()); + } } }); if (dataColumns) { @@ -351,10 +352,8 @@ std::shared_ptr makeScanSpec( outputSubfields, const common::SubfieldFilters& filters, const RowTypePtr& dataColumns, - const std::unordered_map>& - partitionKeys, - const std::unordered_map>& - infoColumns, + const std::unordered_map& partitionKeys, + const std::unordered_map& infoColumns, const SpecialColumnNames& specialColumns, bool disableStatsBasedFilterReorder, memory::MemoryPool* pool) { @@ -443,7 +442,8 @@ std::shared_ptr makeScanSpec( continue; } auto fieldSpec = spec->getOrCreateChild(pair.first); - fieldSpec->addFilter(*pair.second); + VELOX_CHECK_NULL(spec->filter()); + fieldSpec->setFilter(pair.second); } if (disableStatsBasedFilterReorder) { @@ -614,6 +614,7 @@ void configureRowReaderOptions( const std::shared_ptr& hiveSplit, const std::shared_ptr& hiveConfig, const config::ConfigBase* sessionProperties, + folly::Executor* const ioExecutor, dwio::common::RowReaderOptions& rowReaderOptions) { auto skipRowsIt = tableParameters.find(dwio::common::TableParameter::kSkipHeaderLineCount); @@ -621,12 +622,18 @@ void configureRowReaderOptions( rowReaderOptions.setSkipRows(folly::to(skipRowsIt->second)); } rowReaderOptions.setScanSpec(scanSpec); + rowReaderOptions.setIOExecutor(ioExecutor); rowReaderOptions.setMetadataFilter(std::move(metadataFilter)); rowReaderOptions.setRequestedType(rowType); rowReaderOptions.range(hiveSplit->start, hiveSplit->length); if (hiveConfig && sessionProperties) { - rowReaderOptions.setTimestampPrecision(static_cast( - hiveConfig->readTimestampUnit(sessionProperties))); + rowReaderOptions.setTimestampPrecision( + static_cast( + hiveConfig->readTimestampUnit(sessionProperties))); + rowReaderOptions.setPreserveFlatMapsInMemory( + hiveConfig->preserveFlatMapsInMemory(sessionProperties)); + rowReaderOptions.setParallelUnitLoadCount( + hiveConfig->parallelUnitLoadCount(sessionProperties)); } rowReaderOptions.setSerdeParameters(hiveSplit->serdeParameters); } @@ -637,7 +644,7 @@ bool applyPartitionFilter( const TypePtr& type, const std::string& partitionValue, bool isPartitionDateDaysSinceEpoch, - common::Filter* filter, + const common::Filter* filter, bool asLocalTime) { if (type->isDate()) { int32_t result = 0; @@ -646,7 +653,7 @@ bool applyPartitionFilter( if (isPartitionDateDaysSinceEpoch) { result = folly::to(partitionValue); } else { - result = DATE()->toDays(static_cast(partitionValue)); + result = DATE()->toDays(partitionValue); } return applyFilter(*filter, result); } @@ -691,7 +698,7 @@ bool testFilters( const std::string& filePath, const std::unordered_map>& partitionKeys, - const std::unordered_map>& + const std::unordered_map& partitionKeysHandle, bool asLocalTime) { const auto totalRows = reader->numberOfRows(); @@ -753,32 +760,50 @@ std::unique_ptr createBufferedInput( const ConnectorQueryCtx* connectorQueryCtx, std::shared_ptr ioStats, std::shared_ptr fsStats, - folly::Executor* executor) { + folly::Executor* executor, + const folly::F14FastMap& fileReadOps) { if (connectorQueryCtx->cache()) { return std::make_unique( fileHandle.file, dwio::common::MetricsLog::voidLog(), - fileHandle.uuid.id(), + fileHandle.uuid, connectorQueryCtx->cache(), Connector::getTracker( connectorQueryCtx->scanId(), readerOpts.loadQuantum()), - fileHandle.groupId.id(), + fileHandle.groupId, ioStats, std::move(fsStats), executor, - readerOpts); + readerOpts, + fileReadOps); + } + if (readerOpts.fileFormat() == dwio::common::FileFormat::NIMBLE) { + // Nimble streams (in case of single chunk) are compressed as whole and need + // to be fully fetched in order to do decompression, so there is no point to + // fetch them by quanta. Just use BufferedInput to fetch streams as whole + // to reduce memory footprint. + return std::make_unique( + fileHandle.file, + readerOpts.memoryPool(), + dwio::common::MetricsLog::voidLog(), + ioStats.get(), + fsStats.get(), + dwio::common::BufferedInput::kMaxMergeDistance, + std::nullopt, + fileReadOps); } return std::make_unique( fileHandle.file, dwio::common::MetricsLog::voidLog(), - fileHandle.uuid.id(), + fileHandle.uuid, Connector::getTracker( connectorQueryCtx->scanId(), readerOpts.loadQuantum()), - fileHandle.groupId.id(), + fileHandle.groupId, std::move(ioStats), std::move(fsStats), executor, - readerOpts); + readerOpts, + fileReadOps); } namespace { @@ -848,8 +873,6 @@ double getPrestoSampleRate( return std::max(0.0, std::min(1.0, rate->value().value())); } -} // namespace - core::TypedExprPtr extractFiltersFromRemainingFilter( const core::TypedExprPtr& expr, core::ExpressionEvaluator* evaluator, @@ -862,10 +885,10 @@ core::TypedExprPtr extractFiltersFromRemainingFilter( } common::Filter* oldFilter = nullptr; try { - common::Subfield subfield; - if (auto filter = exec::ExprToSubfieldFilterParser::getInstance() - ->leafCallToSubfieldFilter( - *call, subfield, evaluator, negated)) { + if (auto subfieldAndFilter = + exec::ExprToSubfieldFilterParser::getInstance() + ->leafCallToSubfieldFilter(*call, evaluator, negated)) { + auto& [subfield, filter] = subfieldAndFilter.value(); if (auto it = filters.find(subfield); it != filters.end()) { oldFilter = it->second.get(); filter = filter->mergeWith(oldFilter); @@ -887,20 +910,73 @@ core::TypedExprPtr extractFiltersFromRemainingFilter( return inner ? replaceInputs(call, {inner}) : nullptr; } - if ((call->name() == "and" && !negated) || - (call->name() == "or" && negated)) { - auto lhs = extractFiltersFromRemainingFilter( - call->inputs()[0], evaluator, negated, filters, sampleRate); - auto rhs = extractFiltersFromRemainingFilter( - call->inputs()[1], evaluator, negated, filters, sampleRate); - if (!lhs) { - return rhs; + if ((call->name() == expression::kAnd && !negated) || + (call->name() == expression::kOr && negated)) { + std::vector args; + args.reserve(call->inputs().size()); + for (const auto& input : call->inputs()) { + if (auto arg = extractFiltersFromRemainingFilter( + input, evaluator, negated, filters, sampleRate)) { + args.push_back(std::move(arg)); + } + // If extractFiltersFromRemainingFilter returns nullptr, it means + // everything in input is converted to filters. + } + if (args.empty()) { + return nullptr; } - if (!rhs) { - return lhs; + if (args.size() == 1) { + return std::move(args[0]); } - return replaceInputs(call, {lhs, rhs}); + return replaceInputs(call, std::move(args)); } + + if ((call->name() == expression::kAnd && negated) || + (call->name() == expression::kOr && !negated)) { + std::vector> disjuncts; + common::Subfield subfield; + + for (const auto& input : call->inputs()) { + common::SubfieldFilters tmpFilters; + double tmpSampleRate = 1; + auto tmpRemaining = extractFiltersFromRemainingFilter( + input, evaluator, negated, tmpFilters, tmpSampleRate); + + if (tmpRemaining != nullptr || tmpSampleRate != 1 || + tmpFilters.size() != 1) { + disjuncts.clear(); + break; + } + + if (disjuncts.empty()) { + subfield = tmpFilters.begin()->first.clone(); + } else if (!(subfield == tmpFilters.begin()->first)) { + disjuncts.clear(); + break; + } + + disjuncts.push_back(tmpFilters.begin()->second->clone()); + } + + if (!disjuncts.empty()) { + auto filter = + exec::ExprToSubfieldFilterParser::makeOrFilter(std::move(disjuncts)); + + if (filter == nullptr) { + return expr; + } + + auto it = filters.find(subfield); + if (it != filters.end()) { + filter = filter->mergeWith(it->second.get()); + } + + filters.insert_or_assign(std::move(subfield), std::move(filter)); + + return nullptr; + } + } + if (!negated) { double rate = getPrestoSampleRate(expr, call, evaluator); if (rate != -1) { @@ -908,6 +984,18 @@ core::TypedExprPtr extractFiltersFromRemainingFilter( return nullptr; } } + return expr; } +} // namespace + +core::TypedExprPtr extractFiltersFromRemainingFilter( + const core::TypedExprPtr& expr, + core::ExpressionEvaluator* evaluator, + common::SubfieldFilters& filters, + double& sampleRate) { + return extractFiltersFromRemainingFilter( + expr, evaluator, /*negated=*/false, filters, sampleRate); +} + } // namespace facebook::velox::connector::hive diff --git a/velox/connectors/hive/HiveConnectorUtil.h b/velox/connectors/hive/HiveConnectorUtil.h index 0209a50848e2..a1e35cdcd27a 100644 --- a/velox/connectors/hive/HiveConnectorUtil.h +++ b/velox/connectors/hive/HiveConnectorUtil.h @@ -32,12 +32,13 @@ struct HiveConnectorSplit; const std::string& getColumnName(const common::Subfield& subfield); -void checkColumnNameLowerCase(const std::shared_ptr& type); +void checkColumnNameLowerCase(const TypePtr& type); void checkColumnNameLowerCase( const common::SubfieldFilters& filters, - const std::unordered_map>& - infoColumns); + const std::unordered_map< + std::string, + std::shared_ptr>& infoColumns); void checkColumnNameLowerCase(const core::TypedExprPtr& typeExpr); @@ -52,10 +53,12 @@ std::shared_ptr makeScanSpec( outputSubfields, const common::SubfieldFilters& filters, const RowTypePtr& dataColumns, - const std::unordered_map>& - partitionKeys, - const std::unordered_map>& - infoColumns, + const std::unordered_map< + std::string, + std::shared_ptr>& partitionKeys, + const std::unordered_map< + std::string, + std::shared_ptr>& infoColumns, const SpecialColumnNames& specialColumns, bool disableStatsBasedFilterReorder, memory::MemoryPool* pool); @@ -83,6 +86,7 @@ void configureRowReaderOptions( const std::shared_ptr& hiveSplit, const std::shared_ptr& hiveConfig, const config::ConfigBase* sessionProperties, + folly::Executor* ioExecutor, dwio::common::RowReaderOptions& rowReaderOptions); bool testFilters( @@ -91,8 +95,9 @@ bool testFilters( const std::string& filePath, const std::unordered_map>& partitionKey, - const std::unordered_map>& - partitionKeysHandle, + const std::unordered_map< + std::string, + std::shared_ptr>& partitionKeysHandle, bool asLocalTime); std::unique_ptr createBufferedInput( @@ -101,12 +106,49 @@ std::unique_ptr createBufferedInput( const ConnectorQueryCtx* connectorQueryCtx, std::shared_ptr ioStats, std::shared_ptr fsStats, - folly::Executor* executor); - + folly::Executor* executor, + const folly::F14FastMap& fileReadOps = {}); + +/// Given a boolean expression, breaks it up into conjuncts and sorts these into +/// single-column comparisons with constants (filters), rand() < sampleRate, and +/// the rest (return value). +/// +/// Multiple rand() < K conjuncts are combined into a single sampleRate by +/// multiplying individual sample rates. rand() < 0.1 and rand() < 0.2 produces +/// sampleRate = 0.02. +/// +/// Multiple single-column comparisons with constants that reference the same +/// column or subfield are combined into a single filter. Pre-existing entries +/// in 'filters' are preserved and combined with the ones extracted from the +/// 'expr'. +/// +/// NOT(x OR y) is converted to (NOT x) AND (NOT y). +/// +/// @param expr Boolean expression to break up. +/// @param evaluator Expression evaluator to use. +/// @param filters Mapping from a column or a subfield to comparison with +/// constant. +/// @param sampleRate Sample rate extracted from rand() < sampleRate conjuncts. +/// @return Expression with filters and rand() < sampleRate conjuncts removed. +/// +/// Examples: +/// expr := a = 1 AND b > 0 +/// filters := {a: eq(1), b: gt(0)} +// sampleRate left unmodified +// return value is nullptr +/// +/// expr: not (a > 0 or b > 10) +/// filters := {a: le(0), b: le(10)} +/// sampleRate left unmodified +/// return value is nullptr +/// +/// expr := a > 0 AND a < b AND rand() < 0.1 +/// filters := {a: gt(0)} +/// sampleRate := 0.1 +/// return value is a < b core::TypedExprPtr extractFiltersFromRemainingFilter( const core::TypedExprPtr& expr, core::ExpressionEvaluator* evaluator, - bool negated, common::SubfieldFilters& filters, double& sampleRate); diff --git a/velox/connectors/hive/HiveDataSink.cpp b/velox/connectors/hive/HiveDataSink.cpp index f6f3c8d8c656..16eb602724e8 100644 --- a/velox/connectors/hive/HiveDataSink.cpp +++ b/velox/connectors/hive/HiveDataSink.cpp @@ -24,7 +24,6 @@ #include "velox/connectors/hive/HiveConnectorUtil.h" #include "velox/connectors/hive/HivePartitionFunction.h" #include "velox/connectors/hive/TableHandle.h" -#include "velox/core/ITypedExpr.h" #include "velox/dwio/common/Options.h" #include "velox/dwio/common/SortingWriter.h" #include "velox/exec/OperatorUtils.h" @@ -78,36 +77,23 @@ RowVectorPtr makeDataInput( input->getNullCount()); } -// Returns a subset of column indices corresponding to partition keys. -std::vector getPartitionChannels( - const std::shared_ptr& insertTableHandle) { - std::vector channels; - - for (column_index_t i = 0; i < insertTableHandle->inputColumns().size(); - i++) { - if (insertTableHandle->inputColumns()[i]->isPartitionKey()) { - channels.push_back(i); - } - } - - return channels; -} - -// Returns the column indices of non-partition data columns. -std::vector getNonPartitionChannels( - const std::vector& partitionChannels, - const column_index_t childrenSize) { - std::vector dataChannels; - dataChannels.reserve(childrenSize - partitionChannels.size()); - - for (column_index_t i = 0; i < childrenSize; i++) { - if (std::find(partitionChannels.cbegin(), partitionChannels.cend(), i) == - partitionChannels.cend()) { - dataChannels.push_back(i); - } +// Creates a PartitionIdGenerator if the table is partitioned, otherwise returns +// nullptr. +std::unique_ptr createPartitionIdGenerator( + const RowTypePtr& inputType, + const std::shared_ptr& insertTableHandle, + const std::shared_ptr& hiveConfig, + const ConnectorQueryCtx* connectorQueryCtx) { + auto partitionChannels = insertTableHandle->partitionChannels(); + if (partitionChannels.empty()) { + return nullptr; } - - return dataChannels; + return std::make_unique( + inputType, + partitionChannels, + hiveConfig->maxPartitionsPerWriters( + connectorQueryCtx->sessionProperties()), + connectorQueryCtx->memoryPool()); } std::string makePartitionDirectory( @@ -166,9 +152,10 @@ std::unique_ptr createBucketFunction( std::string computeBucketedFileName( const std::string& queryId, + uint32_t maxBucketCount, uint32_t bucket) { - static const uint32_t kMaxBucketCountPadding = - std::to_string(HiveDataSink::maxBucketCount() - 1).size(); + const uint32_t kMaxBucketCountPadding = + std::to_string(maxBucketCount - 1).size(); const std::string bucketValueStr = std::to_string(bucket); return fmt::format( "0{:0>{}}_0_{}", bucketValueStr, kMaxBucketCountPadding, queryId); @@ -200,6 +187,29 @@ FOLLY_ALWAYS_INLINE int32_t getBucketCount(const HiveBucketProperty* bucketProperty) { return bucketProperty == nullptr ? 0 : bucketProperty->bucketCount(); } + +std::vector computePartitionChannels( + const std::vector>& inputColumns) { + std::vector channels; + for (auto i = 0; i < inputColumns.size(); i++) { + if (inputColumns[i]->isPartitionKey()) { + channels.push_back(i); + } + } + return channels; +} + +std::vector computeNonPartitionChannels( + const std::vector>& inputColumns) { + std::vector channels; + for (auto i = 0; i < inputColumns.size(); i++) { + if (!inputColumns[i]->isPartitionKey()) { + channels.push_back(i); + } + } + return channels; +} + } // namespace const HiveWriterId& HiveWriterId::unpartitionedId() { @@ -365,6 +375,52 @@ std::string HiveBucketProperty::toString() const { return out.str(); } +HiveInsertTableHandle::HiveInsertTableHandle( + std::vector> inputColumns, + std::shared_ptr locationHandle, + dwio::common::FileFormat storageFormat, + std::shared_ptr bucketProperty, + std::optional compressionKind, + const std::unordered_map& serdeParameters, + const std::shared_ptr& writerOptions, + // When this option is set the HiveDataSink will always write a file even + // if there's no data. This is useful when the table is bucketed, but the + // engine handles ensuring a 1 to 1 mapping from task to bucket. + const bool ensureFiles, + std::shared_ptr fileNameGenerator) + : inputColumns_(std::move(inputColumns)), + locationHandle_(std::move(locationHandle)), + storageFormat_(storageFormat), + bucketProperty_(std::move(bucketProperty)), + compressionKind_(compressionKind), + serdeParameters_(serdeParameters), + writerOptions_(writerOptions), + ensureFiles_(ensureFiles), + fileNameGenerator_(std::move(fileNameGenerator)), + partitionChannels_(computePartitionChannels(inputColumns_)), + nonPartitionChannels_(computeNonPartitionChannels(inputColumns_)) { + if (compressionKind.has_value()) { + VELOX_CHECK( + compressionKind.value() != common::CompressionKind_MAX, + "Unsupported compression type: CompressionKind_MAX"); + } + + if (ensureFiles_) { + // If ensureFiles is set and either the bucketProperty is set or some + // partition keys are in the data, there is not a 1:1 mapping from Task to + // files so we can't proactively create writers. + VELOX_CHECK( + bucketProperty_ == nullptr || bucketProperty_->bucketCount() == 0, + "ensureFiles is not supported with bucketing"); + + for (const auto& inputColumn : inputColumns_) { + VELOX_CHECK( + !inputColumn->isPartitionKey(), + "ensureFiles is not supported with partition keys in the data"); + } + } +} + HiveDataSink::HiveDataSink( RowTypePtr inputType, std::shared_ptr insertTableHandle, @@ -382,7 +438,14 @@ HiveDataSink::HiveDataSink( ? createBucketFunction( *insertTableHandle->bucketProperty(), inputType) - : nullptr) {} + : nullptr, + insertTableHandle->partitionChannels(), + insertTableHandle->nonPartitionChannels(), + createPartitionIdGenerator( + inputType, + insertTableHandle, + hiveConfig, + connectorQueryCtx)) {} HiveDataSink::HiveDataSink( RowTypePtr inputType, @@ -391,7 +454,10 @@ HiveDataSink::HiveDataSink( CommitStrategy commitStrategy, const std::shared_ptr& hiveConfig, uint32_t bucketCount, - std::unique_ptr bucketFunction) + std::unique_ptr bucketFunction, + const std::vector& partitionChannels, + const std::vector& dataChannels, + std::unique_ptr partitionIdGenerator) : inputType_(std::move(inputType)), insertTableHandle_(std::move(insertTableHandle)), connectorQueryCtx_(connectorQueryCtx), @@ -400,19 +466,9 @@ HiveDataSink::HiveDataSink( updateMode_(getUpdateMode()), maxOpenWriters_(hiveConfig_->maxPartitionsPerWriters( connectorQueryCtx->sessionProperties())), - partitionChannels_(getPartitionChannels(insertTableHandle_)), - partitionIdGenerator_( - !partitionChannels_.empty() - ? std::make_unique( - inputType_, - partitionChannels_, - maxOpenWriters_, - connectorQueryCtx_->memoryPool(), - hiveConfig_->isPartitionPathAsLowerCase( - connectorQueryCtx->sessionProperties())) - : nullptr), - dataChannels_( - getNonPartitionChannels(partitionChannels_, inputType_->size())), + partitionChannels_(partitionChannels), + partitionIdGenerator_(std::move(partitionIdGenerator)), + dataChannels_(dataChannels), bucketCount_(static_cast(bucketCount)), bucketFunction_(std::move(bucketFunction)), writerFactory_( @@ -421,16 +477,21 @@ HiveDataSink::HiveDataSink( sortWriterFinishTimeSliceLimitMs_(getFinishTimeSliceLimitMsFromHiveConfig( hiveConfig_, connectorQueryCtx->sessionProperties())), + partitionKeyAsLowerCase_(hiveConfig_->isPartitionPathAsLowerCase( + connectorQueryCtx_->sessionProperties())), fileNameGenerator_(insertTableHandle_->fileNameGenerator()) { + fileSystemStats_ = std::make_unique(); if (isBucketed()) { VELOX_USER_CHECK_LT( - bucketCount_, maxBucketCount(), "bucketCount exceeds the limit"); + bucketCount_, + hiveConfig_->maxBucketCount(connectorQueryCtx->sessionProperties()), + "bucketCount exceeds the limit"); } VELOX_USER_CHECK( (commitStrategy_ == CommitStrategy::kNoCommit) || (commitStrategy_ == CommitStrategy::kTaskCommit), "Unsupported commit strategy: {}", - commitStrategyToString(commitStrategy_)); + CommitStrategyName::toName(commitStrategy_)); if (insertTableHandle_->ensureFiles()) { VELOX_CHECK( @@ -465,12 +526,16 @@ HiveDataSink::HiveDataSink( bool HiveDataSink::canReclaim() const { // Currently, we only support memory reclaim on dwrf file writer. return (spillConfig_ != nullptr) && - (insertTableHandle_->storageFormat() == dwio::common::FileFormat::DWRF); + (insertTableHandle_->storageFormat() == dwio::common::FileFormat::DWRF || + insertTableHandle_->storageFormat() == dwio::common::FileFormat::NIMBLE); } void HiveDataSink::appendData(RowVectorPtr input) { checkRunning(); + // Lazy load all the input columns. + input->loadedVector(); + // Write to unpartitioned (and unbucketed) table. if (!isPartitioned() && !isBucketed()) { const auto index = ensureWriter(HiveWriterId::unpartitionedId()); @@ -481,11 +546,6 @@ void HiveDataSink::appendData(RowVectorPtr input) { // Compute partition and bucket numbers. computePartitionAndBucketIds(input); - // Lazy load all the input columns. - for (column_index_t i = 0; i < input->childrenSize(); ++i) { - input->childAt(i)->loadedVector(); - } - // All inputs belong to a single non-bucketed partition. The partition id // must be zero. if (!isBucketed() && partitionIdGenerator_->numPartitions() == 1) { @@ -590,6 +650,19 @@ DataSink::Stats HiveDataSink::stats() const { return stats; } +std::unordered_map HiveDataSink::runtimeStats() + const { + std::unordered_map runtimeStats; + + const auto fsStatsMap = fileSystemStats_->stats(); + for (const auto& [statName, statValue] : fsStatsMap) { + runtimeStats.emplace( + statName, RuntimeCounter(statValue.sum, statValue.unit)); + } + + return runtimeStats; +} + std::shared_ptr HiveDataSink::createWriterPool( const HiveWriterId& writerId) { auto* connectorPool = connectorQueryCtx_->connectorMemoryPool(); @@ -667,7 +740,10 @@ bool HiveDataSink::finish() { std::vector HiveDataSink::close() { setState(State::kClosed); closeInternal(); + return commitMessage(); +} +std::vector HiveDataSink::commitMessage() const { std::vector partitionUpdates; partitionUpdates.reserve(writerInfo_.size()); for (int i = 0; i < writerInfo_.size(); ++i) { @@ -730,38 +806,8 @@ uint32_t HiveDataSink::ensureWriter(const HiveWriterId& id) { return appendWriter(id); } -uint32_t HiveDataSink::appendWriter(const HiveWriterId& id) { - // Check max open writers. - VELOX_USER_CHECK_LE( - writers_.size(), maxOpenWriters_, "Exceeded open writer limit"); - VELOX_CHECK_EQ(writers_.size(), writerInfo_.size()); - VELOX_CHECK_EQ(writerIndexMap_.size(), writerInfo_.size()); - - std::optional partitionName; - if (isPartitioned()) { - partitionName = - partitionIdGenerator_->partitionName(id.partitionId.value()); - } - - // Without explicitly setting flush policy, the default memory based flush - // policy is used. - auto writerParameters = getWriterParameters(partitionName, id.bucketId); - const auto writePath = fs::path(writerParameters.writeDirectory()) / - writerParameters.writeFileName(); - auto writerPool = createWriterPool(id); - auto sinkPool = createSinkPool(writerPool); - std::shared_ptr sortPool{nullptr}; - if (sortWrite()) { - sortPool = createSortPool(writerPool); - } - writerInfo_.emplace_back(std::make_shared( - std::move(writerParameters), - std::move(writerPool), - std::move(sinkPool), - std::move(sortPool))); - ioStats_.emplace_back(std::make_shared()); - setMemoryReclaimers(writerInfo_.back().get(), ioStats_.back().get()); - +std::shared_ptr HiveDataSink::createWriterOptions() + const { // Take the writer options provided by the user as a starting point, or // allocate a new one. auto options = insertTableHandle_->writerOptions(); @@ -789,10 +835,11 @@ uint32_t HiveDataSink::appendWriter(const HiveWriterId& id) { options->spillConfig = spillConfig_; } - if (options->nonReclaimableSection == nullptr) { - options->nonReclaimableSection = - writerInfo_.back()->nonReclaimableSectionHolder.get(); - } + // Always set nonReclaimableSection to the current writer's holder. + // Since insertTableHandle_->writerOptions() returns a shared_ptr, we need + // to ensure each writer has its own nonReclaimableSection pointer. + options->nonReclaimableSection = + writerInfo_.back()->nonReclaimableSectionHolder.get(); if (options->memoryReclaimerFactory == nullptr || options->memoryReclaimerFactory() == nullptr) { @@ -811,6 +858,43 @@ uint32_t HiveDataSink::appendWriter(const HiveWriterId& id) { options->adjustTimestampToTimezone = connectorQueryCtx_->adjustTimestampToTimezone(); options->processConfigs(*hiveConfig_->config(), *connectorSessionProperties); + return options; +} + +uint32_t HiveDataSink::appendWriter(const HiveWriterId& id) { + // Check max open writers. + VELOX_USER_CHECK_LE( + writers_.size(), maxOpenWriters_, "Exceeded open writer limit"); + VELOX_CHECK_EQ(writers_.size(), writerInfo_.size()); + VELOX_CHECK_EQ(writerIndexMap_.size(), writerInfo_.size()); + + std::optional partitionName; + if (isPartitioned()) { + partitionName = getPartitionName(id.partitionId.value()); + } + + // Without explicitly setting flush policy, the default memory based flush + // policy is used. + auto writerParameters = getWriterParameters(partitionName, id.bucketId); + const auto writePath = fs::path(writerParameters.writeDirectory()) / + writerParameters.writeFileName(); + auto writerPool = createWriterPool(id); + auto sinkPool = createSinkPool(writerPool); + std::shared_ptr sortPool{nullptr}; + if (sortWrite()) { + sortPool = createSortPool(writerPool); + } + writerInfo_.emplace_back( + std::make_shared( + std::move(writerParameters), + std::move(writerPool), + std::move(sinkPool), + std::move(sortPool))); + ioStats_.emplace_back(std::make_unique()); + + setMemoryReclaimers(writerInfo_.back().get(), ioStats_.back().get()); + + auto options = createWriterOptions(); // Prevents the memory allocation during the writer creation. WRITER_NON_RECLAIMABLE_SECTION_GUARD(writerInfo_.size() - 1); @@ -824,10 +908,16 @@ uint32_t HiveDataSink::appendWriter(const HiveWriterId& id) { .pool = writerInfo_.back()->sinkPool.get(), .metricLogger = dwio::common::MetricsLog::voidLog(), .stats = ioStats_.back().get(), + .fileSystemStats = fileSystemStats_.get(), }), options); writer = maybeCreateBucketSortWriter(std::move(writer)); writers_.emplace_back(std::move(writer)); + addThreadLocalRuntimeStat( + fmt::format( + "{}WriterCount", + dwio::common::toString(insertTableHandle_->storageFormat())), + RuntimeCounter(1)); // Extends the buffer used for partition rows calculations. partitionSizes_.emplace_back(0); partitionRows_.emplace_back(nullptr); @@ -837,6 +927,15 @@ uint32_t HiveDataSink::appendWriter(const HiveWriterId& id) { return writerIndexMap_[id]; } +std::string HiveDataSink::getPartitionName(uint32_t partitionId) const { + VELOX_CHECK_NOT_NULL(partitionIdGenerator_); + + return HivePartitionName::partitionName( + partitionId, + partitionIdGenerator_->partitionValues(), + partitionKeyAsLowerCase_); +} + std::unique_ptr HiveDataSink::maybeCreateBucketSortWriter( std::unique_ptr writer) { @@ -878,6 +977,24 @@ HiveWriterId HiveDataSink::getWriterId(size_t row) const { return HiveWriterId{partitionId, bucketId}; } +void HiveDataSink::updatePartitionRows( + uint32_t index, + vector_size_t numRows, + vector_size_t row) { + VELOX_DCHECK_LT(index, partitionSizes_.size()); + VELOX_DCHECK_EQ(partitionSizes_.size(), partitionRows_.size()); + VELOX_DCHECK_EQ(partitionRows_.size(), rawPartitionRows_.size()); + if (FOLLY_UNLIKELY(partitionRows_[index] == nullptr) || + (partitionRows_[index]->capacity() < numRows * sizeof(vector_size_t))) { + partitionRows_[index] = + allocateIndices(numRows, connectorQueryCtx_->memoryPool()); + rawPartitionRows_[index] = + partitionRows_[index]->asMutable(); + } + rawPartitionRows_[index][partitionSizes_[index]] = row; + ++partitionSizes_[index]; +} + void HiveDataSink::splitInputRowsAndEnsureWriters() { VELOX_CHECK(isPartitioned() || isBucketed()); if (isBucketed() && isPartitioned()) { @@ -891,19 +1008,7 @@ void HiveDataSink::splitInputRowsAndEnsureWriters() { for (auto row = 0; row < numRows; ++row) { const auto id = getWriterId(row); const uint32_t index = ensureWriter(id); - - VELOX_DCHECK_LT(index, partitionSizes_.size()); - VELOX_DCHECK_EQ(partitionSizes_.size(), partitionRows_.size()); - VELOX_DCHECK_EQ(partitionRows_.size(), rawPartitionRows_.size()); - if (FOLLY_UNLIKELY(partitionRows_[index] == nullptr) || - (partitionRows_[index]->capacity() < numRows * sizeof(vector_size_t))) { - partitionRows_[index] = - allocateIndices(numRows, connectorQueryCtx_->memoryPool()); - rawPartitionRows_[index] = - partitionRows_[index]->asMutable(); - } - rawPartitionRows_[index][partitionSizes_[index]] = row; - ++partitionSizes_[index]; + updatePartitionRows(index, numRows, row); } for (uint32_t i = 0; i < partitionSizes_.size(); ++i) { @@ -932,6 +1037,17 @@ HiveWriterParameters HiveDataSink::getWriterParameters( std::pair HiveDataSink::getWriterFileNames( std::optional bucketId) const { + if (auto hiveInsertFileNameGenerator = + std::dynamic_pointer_cast( + fileNameGenerator_)) { + return hiveInsertFileNameGenerator->gen( + bucketId, + insertTableHandle_, + *connectorQueryCtx_, + hiveConfig_, + isCommitRequired()); + } + return fileNameGenerator_->gen( bucketId, insertTableHandle_, *connectorQueryCtx_, isCommitRequired()); } @@ -941,13 +1057,33 @@ std::pair HiveInsertFileNameGenerator::gen( const std::shared_ptr insertTableHandle, const ConnectorQueryCtx& connectorQueryCtx, bool commitRequired) const { + auto defaultHiveConfig = + std::make_shared(std::make_shared( + std::unordered_map())); + + return this->gen( + bucketId, + insertTableHandle, + connectorQueryCtx, + defaultHiveConfig, + commitRequired); +} + +std::pair HiveInsertFileNameGenerator::gen( + std::optional bucketId, + const std::shared_ptr insertTableHandle, + const ConnectorQueryCtx& connectorQueryCtx, + const std::shared_ptr& hiveConfig, + bool commitRequired) const { auto targetFileName = insertTableHandle->locationHandle()->targetFileName(); const bool generateFileName = targetFileName.empty(); if (bucketId.has_value()) { VELOX_CHECK(generateFileName); // TODO: add hive.file_renaming_enabled support. - targetFileName = - computeBucketedFileName(connectorQueryCtx.queryId(), bucketId.value()); + targetFileName = computeBucketedFileName( + connectorQueryCtx.queryId(), + hiveConfig->maxBucketCount(connectorQueryCtx.sessionProperties()), + bucketId.value()); } else if (generateFileName) { // targetFileName includes planNodeId and Uuid. As a result, different // table writers run by the same task driver or the same table writer diff --git a/velox/connectors/hive/HiveDataSink.h b/velox/connectors/hive/HiveDataSink.h index c1354b1ca6b3..bf1ad3b6c39f 100644 --- a/velox/connectors/hive/HiveDataSink.h +++ b/velox/connectors/hive/HiveDataSink.h @@ -18,6 +18,7 @@ #include "velox/common/compression/Compression.h" #include "velox/connectors/Connector.h" #include "velox/connectors/hive/HiveConfig.h" +#include "velox/connectors/hive/HivePartitionName.h" #include "velox/connectors/hive/PartitionIdGenerator.h" #include "velox/connectors/hive/TableHandle.h" #include "velox/dwio/common/Options.h" @@ -25,10 +26,6 @@ #include "velox/dwio/common/WriterFactory.h" #include "velox/exec/MemoryReclaimer.h" -namespace facebook::velox::dwrf { -class Writer; -} - namespace facebook::velox::connector::hive { class LocationHandle; @@ -217,6 +214,15 @@ class HiveInsertFileNameGenerator : public FileNameGenerator { const ConnectorQueryCtx& connectorQueryCtx, bool commitRequired) const override; + /// Version of file generation that takes hiveConfig into account when + /// generating file names + std::pair gen( + std::optional bucketId, + const std::shared_ptr insertTableHandle, + const ConnectorQueryCtx& connectorQueryCtx, + const std::shared_ptr& hiveConfig, + bool commitRequired) const; + static void registerSerDe(); folly::dynamic serialize() const override; @@ -245,37 +251,7 @@ class HiveInsertTableHandle : public ConnectorInsertTableHandle { // engine handles ensuring a 1 to 1 mapping from task to bucket. const bool ensureFiles = false, std::shared_ptr fileNameGenerator = - std::make_shared()) - : inputColumns_(std::move(inputColumns)), - locationHandle_(std::move(locationHandle)), - storageFormat_(storageFormat), - bucketProperty_(std::move(bucketProperty)), - compressionKind_(compressionKind), - serdeParameters_(serdeParameters), - writerOptions_(writerOptions), - ensureFiles_(ensureFiles), - fileNameGenerator_(std::move(fileNameGenerator)) { - if (compressionKind.has_value()) { - VELOX_CHECK( - compressionKind.value() != common::CompressionKind_MAX, - "Unsupported compression type: CompressionKind_MAX"); - } - - if (ensureFiles_) { - // If ensureFiles is set and either the bucketProperty is set or some - // partition keys are in the data, there is not a 1:1 mapping from Task to - // files so we can't proactively create writers. - VELOX_CHECK( - bucketProperty_ == nullptr || bucketProperty_->bucketCount() == 0, - "ensureFiles is not supported with bucketing"); - - for (const auto& inputColumn : inputColumns_) { - VELOX_CHECK( - !inputColumn->isPartitionKey(), - "ensureFiles is not supported with partition keys in the data"); - } - } - } + std::make_shared()); virtual ~HiveInsertTableHandle() = default; @@ -324,6 +300,16 @@ class HiveInsertTableHandle : public ConnectorInsertTableHandle { bool isExistingTable() const; + /// Returns a subset of column indices corresponding to partition keys. + const std::vector& partitionChannels() const { + return partitionChannels_; + } + + /// Returns the column indices of non-partition data columns. + const std::vector& nonPartitionChannels() const { + return nonPartitionChannels_; + } + folly::dynamic serialize() const override; static HiveInsertTableHandlePtr create(const folly::dynamic& obj); @@ -332,9 +318,11 @@ class HiveInsertTableHandle : public ConnectorInsertTableHandle { std::string toString() const override; - private: + protected: const std::vector> inputColumns_; const std::shared_ptr locationHandle_; + + private: const dwio::common::FileFormat storageFormat_; const std::shared_ptr bucketProperty_; const std::optional compressionKind_; @@ -342,6 +330,8 @@ class HiveInsertTableHandle : public ConnectorInsertTableHandle { const std::shared_ptr writerOptions_; const bool ensureFiles_; const std::shared_ptr fileNameGenerator_; + const std::vector partitionChannels_; + const std::vector nonPartitionChannels_; }; /// Parameters for Hive writers. @@ -509,6 +499,17 @@ class HiveDataSink : public DataSink { }; static std::string stateString(State state); + /// Creates a HiveDataSink for writing data to Hive table files. + /// + /// @param inputType The schema of input data rows to be written. + /// @param insertTableHandle Metadata about the table write operation, + /// including storage format, compression, bucketing, and partitioning + /// configuration. + /// @param connectorQueryCtx Query context with session properties, memory + /// pools, and spill configuration. + /// @param commitStrategy Strategy for committing written data (kNoCommit or + /// kTaskCommit). + /// @param hiveConfig Hive connector configuration. HiveDataSink( RowTypePtr inputType, std::shared_ptr insertTableHandle, @@ -516,6 +517,31 @@ class HiveDataSink : public DataSink { CommitStrategy commitStrategy, const std::shared_ptr& hiveConfig); + /// Constructor with explicit bucketing and partitioning parameters. + /// + /// @param inputType The schema of input data rows to be written. + /// @param insertTableHandle Metadata about the table write operation, + /// including storage format, compression, location, and serialization + /// parameters. + /// @param connectorQueryCtx Query context with session properties, memory + /// pools, and spill configuration. + /// @param commitStrategy Strategy for committing written data (kNoCommit or + /// kTaskCommit). Determines whether temporary files need to be renamed on + /// commit. + /// @param hiveConfig Hive connector configuration with settings for max + /// partitions, bucketing limits etc. + /// @param bucketCount Number of buckets for bucketed tables (0 if not + /// bucketed). Must be less than the configured max bucket count. + /// @param bucketFunction Function to compute bucket IDs from row data + /// (nullptr if not bucketed). Used to distribute rows across buckets. + /// @param partitionChannels Column indices used for partitioning (empty if + /// not partitioned). These columns are extracted to determine partition + /// directories. + /// @param dataChannels Column indices for the actual data columns to be + /// written. + /// @param partitionIdGenerator Generates partition IDs from partition column + /// values (nullptr if not partitioned). Compute partition key combinations to + /// unique IDs. HiveDataSink( RowTypePtr inputType, std::shared_ptr insertTableHandle, @@ -523,12 +549,10 @@ class HiveDataSink : public DataSink { CommitStrategy commitStrategy, const std::shared_ptr& hiveConfig, uint32_t bucketCount, - std::unique_ptr bucketFunction); - - static uint32_t maxBucketCount() { - static const uint32_t kMaxBucketCount = 100'000; - return kMaxBucketCount; - } + std::unique_ptr bucketFunction, + const std::vector& partitionChannels, + const std::vector& dataChannels, + std::unique_ptr partitionIdGenerator); void appendData(RowVectorPtr input) override; @@ -536,17 +560,27 @@ class HiveDataSink : public DataSink { Stats stats() const override; + std::unordered_map runtimeStats() const override; + std::vector close() override; void abort() override; bool canReclaim() const; - private: + protected: // Validates the state transition from 'oldState' to 'newState'. void checkStateTransition(State oldState, State newState); void setState(State newState); + // Generates commit messages for all writers containing metadata about written + // files. Creates a JSON object for each writer with partition name, + // file paths, file names, data sizes, and row counts. This metadata is used + // by the coordinator to commit the transaction and update the metastore. + // + // @return Vector of JSON strings, one per writer. + virtual std::vector commitMessage() const; + class WriterReclaimer : public exec::MemoryReclaimer { public: static std::unique_ptr create( @@ -609,11 +643,11 @@ class HiveDataSink : public DataSink { io::IoStatistics* ioStats); // Compute the partition id and bucket id for each row in 'input'. - void computePartitionAndBucketIds(const RowVectorPtr& input); + virtual void computePartitionAndBucketIds(const RowVectorPtr& input); // Get the HiveWriter corresponding to the row // from partitionIds and bucketIds. - FOLLY_ALWAYS_INLINE HiveWriterId getWriterId(size_t row) const; + HiveWriterId getWriterId(size_t row) const; // Computes the number of input rows as well as the actual input row indices // to each corresponding (bucketed) partition based on the partition and @@ -623,16 +657,36 @@ class HiveDataSink : public DataSink { // Makes sure to create one writer for the given writer id. The function // returns the corresponding index in 'writers_'. - uint32_t ensureWriter(const HiveWriterId& id); + virtual uint32_t ensureWriter(const HiveWriterId& id); // Appends a new writer for the given 'id'. The function returns the index of // the newly created writer in 'writers_'. uint32_t appendWriter(const HiveWriterId& id); + // Creates and configures WriterOptions based on file format. + // Sets up compression, schema, and other writer configuration based on the + // insert table handle and connector settings. + virtual std::shared_ptr createWriterOptions() + const; + + // Returns the Hive partition directory name for the given partition ID. + // Converts the partition values associated with the partition ID into a + // Hive-formatted directory path. Returns std::nullopt if the table is + // unpartitioned. Should be called only when writing to a partitioned table. + virtual std::string getPartitionName(uint32_t partitionId) const; + std::unique_ptr maybeCreateBucketSortWriter( std::unique_ptr writer); + // Records a row index for a specific partition. This method maintains the + // mapping of which input rows belong to which partition by storing row + // indices in partition-specific buffers. If the buffer for the partition + // doesn't exist or is too small, it allocates/reallocates the buffer to + // accommodate all rows. + void + updatePartitionRows(uint32_t index, vector_size_t numRows, vector_size_t row); + HiveWriterParameters getWriterParameters( const std::optional& partition, std::optional bucketId) const; @@ -658,6 +712,15 @@ class HiveDataSink : public DataSink { void closeInternal(); + // IMPORTANT NOTE: these are passed to writers as raw pointers. HiveDataSink + // owns the lifetime of these objects, and therefore must destroy them last. + // Additionally, we must assume that no objects which hold a reference to + // these stats will outlive the HiveDataSink instance. This is a reasonable + // assumption given the semantics of these stats objects. + std::vector> ioStats_; + // Generic filesystem stats, exposed as RuntimeStats + std::unique_ptr fileSystemStats_; + const RowTypePtr inputType_; const std::shared_ptr insertTableHandle_; const ConnectorQueryCtx* const connectorQueryCtx_; @@ -674,6 +737,7 @@ class HiveDataSink : public DataSink { const std::shared_ptr writerFactory_; const common::SpillConfig* const spillConfig_; const uint64_t sortWriterFinishTimeSliceLimitMs_{0}; + const bool partitionKeyAsLowerCase_; std::vector sortColumnIndices_; std::vector sortCompareFlags_; @@ -690,8 +754,6 @@ class HiveDataSink : public DataSink { // writers_ are both indexed by partitionId. std::vector> writerInfo_; std::vector> writers_; - // IO statistics collected for each writer. - std::vector> ioStats_; // Below are structures updated when processing current input. partitionIds_ // are indexed by the row of input_. partitionRows_, rawPartitionRows_ and diff --git a/velox/connectors/hive/HiveDataSource.cpp b/velox/connectors/hive/HiveDataSource.cpp index 5ee0a1fc9b94..8ae61574eb5a 100644 --- a/velox/connectors/hive/HiveDataSource.cpp +++ b/velox/connectors/hive/HiveDataSource.cpp @@ -22,16 +22,12 @@ #include "velox/common/testutil/TestValue.h" #include "velox/connectors/hive/HiveConfig.h" -#include "velox/dwio/common/ReaderFactory.h" #include "velox/expression/FieldReference.h" using facebook::velox::common::testutil::TestValue; namespace facebook::velox::connector::hive { -class HiveTableHandle; -class HiveColumnHandle; - namespace { bool isMember( @@ -54,56 +50,99 @@ bool shouldEagerlyMaterialize( return false; } +void checkColumnHandleConsistent( + const HiveColumnHandle& x, + const HiveColumnHandle& y) { + VELOX_CHECK( + x.columnType() == y.columnType(), + "Inconsistent column handle: {}", + x.name()); + VELOX_CHECK_EQ( + *x.dataType(), *y.dataType(), "Inconsistent column handle: {}", x.name()); + VELOX_CHECK_EQ( + *x.hiveType(), *y.hiveType(), "Inconsistent column handle: {}", x.name()); +} + } // namespace +void HiveDataSource::processColumnHandle(const HiveColumnHandlePtr& handle) { + switch (handle->columnType()) { + case HiveColumnHandle::ColumnType::kRegular: + break; + case HiveColumnHandle::ColumnType::kPartitionKey: + partitionKeys_.emplace(handle->name(), handle); + break; + case HiveColumnHandle::ColumnType::kSynthesized: + infoColumns_.emplace(handle->name(), handle); + break; + case HiveColumnHandle::ColumnType::kRowIndex: + specialColumns_.rowIndex = handle->name(); + break; + case HiveColumnHandle::ColumnType::kRowId: + specialColumns_.rowId = handle->name(); + break; + } +} + HiveDataSource::HiveDataSource( const RowTypePtr& outputType, - const std::shared_ptr& tableHandle, - const std::unordered_map< - std::string, - std::shared_ptr>& columnHandles, + const connector::ConnectorTableHandlePtr& tableHandle, + const connector::ColumnHandleMap& assignments, FileHandleFactory* fileHandleFactory, - folly::Executor* executor, + folly::Executor* ioExecutor, const ConnectorQueryCtx* connectorQueryCtx, const std::shared_ptr& hiveConfig) : fileHandleFactory_(fileHandleFactory), - executor_(executor), + ioExecutor_(ioExecutor), connectorQueryCtx_(connectorQueryCtx), hiveConfig_(hiveConfig), pool_(connectorQueryCtx->memoryPool()), outputType_(outputType), expressionEvaluator_(connectorQueryCtx->expressionEvaluator()) { + hiveTableHandle_ = + std::dynamic_pointer_cast(tableHandle); + VELOX_CHECK_NOT_NULL( + hiveTableHandle_, "TableHandle must be an instance of HiveTableHandle"); + + folly::F14FastMap columnHandles; // Column handled keyed on the column alias, the name used in the query. - for (const auto& [canonicalizedName, columnHandle] : columnHandles) { - auto handle = std::dynamic_pointer_cast(columnHandle); + for (const auto& [canonicalizedName, columnHandle] : assignments) { + auto handle = + std::dynamic_pointer_cast(columnHandle); VELOX_CHECK_NOT_NULL( handle, "ColumnHandle must be an instance of HiveColumnHandle for {}", canonicalizedName); - switch (handle->columnType()) { - case HiveColumnHandle::ColumnType::kRegular: - break; - case HiveColumnHandle::ColumnType::kPartitionKey: - partitionKeys_.emplace(handle->name(), handle); - break; - case HiveColumnHandle::ColumnType::kSynthesized: - infoColumns_.emplace(handle->name(), handle); - break; - case HiveColumnHandle::ColumnType::kRowIndex: - specialColumns_.rowIndex = handle->name(); - break; - case HiveColumnHandle::ColumnType::kRowId: - specialColumns_.rowId = handle->name(); - break; + const auto [it, unique] = + columnHandles.emplace(handle->name(), handle.get()); + if (!unique) { + // This should not happen normally, but there is some bug in Presto DELETE + // queries that sometimes we do get duplicate assignments for partitioning + // columns. + checkColumnHandleConsistent(*handle, *it->second); + VELOX_CHECK( + handle->columnType() == HiveColumnHandle::ColumnType::kPartitionKey, + "Cannot map from same table column to different outputs in table scan; a project node should be used instead: {}", + handle->name()); + continue; + } + processColumnHandle(handle); + } + for (auto& handle : hiveTableHandle_->filterColumnHandles()) { + auto it = columnHandles.find(handle->name()); + if (it != columnHandles.end()) { + checkColumnHandleConsistent(*handle, *it->second); + continue; } + processColumnHandle(handle); } std::vector readColumnNames; auto readColumnTypes = outputType_->children(); for (const auto& outputName : outputType_->names()) { - auto it = columnHandles.find(outputName); + auto it = assignments.find(outputName); VELOX_CHECK( - it != columnHandles.end(), + it != assignments.end(), "ColumnHandle is missing for output column: {}", outputName); @@ -116,11 +155,9 @@ HiveDataSource::HiveDataSource( "Required subfield does not match column name"); subfields_[handle->name()].push_back(&subfield); } + columnPostProcessors_.push_back(handle->postProcessor()); } - hiveTableHandle_ = std::dynamic_pointer_cast(tableHandle); - VELOX_CHECK_NOT_NULL( - hiveTableHandle_, "TableHandle must be an instance of HiveTableHandle"); if (hiveConfig_->isFileColumnNamesReadAsLowerCase( connectorQueryCtx->sessionProperties())) { checkColumnNameLowerCase(outputType_); @@ -129,13 +166,12 @@ HiveDataSource::HiveDataSource( } for (const auto& [k, v] : hiveTableHandle_->subfieldFilters()) { - filters_.emplace(k.clone(), v->clone()); + filters_.emplace(k.clone(), v); } - double sampleRate = 1; + double sampleRate = hiveTableHandle_->sampleRate(); auto remainingFilter = extractFiltersFromRemainingFilter( hiveTableHandle_->remainingFilter(), expressionEvaluator_, - false, filters_, sampleRate); if (sampleRate != 1) { @@ -217,7 +253,7 @@ std::unique_ptr HiveDataSource::createSplitReader() { ioStats_, fsStats_, fileHandleFactory_, - executor_, + ioExecutor_, scanSpec_); } @@ -403,8 +439,11 @@ std::optional HiveDataSource::next( // don't need to reallocate the result for every batch. child->disableMemo(); } - outputColumns.emplace_back( - exec::wrapChild(rowsRemaining, remainingIndices, child)); + auto column = exec::wrapChild(rowsRemaining, remainingIndices, child); + if (columnPostProcessors_[i]) { + columnPostProcessors_[i](column); + } + outputColumns.push_back(std::move(column)); } return std::make_shared( @@ -415,67 +454,80 @@ void HiveDataSource::addDynamicFilter( column_index_t outputChannel, const std::shared_ptr& filter) { auto& fieldSpec = scanSpec_->getChildByChannel(outputChannel); - fieldSpec.addFilter(*filter); + fieldSpec.setFilter(filter); scanSpec_->resetCachedValues(true); if (splitReader_) { splitReader_->resetFilterCaches(); } } -std::unordered_map HiveDataSource::runtimeStats() { - auto res = runtimeStats_.toMap(); +std::unordered_map +HiveDataSource::getRuntimeStats() { + auto res = runtimeStats_.toRuntimeMetricMap(); res.insert( - {{"numPrefetch", RuntimeCounter(ioStats_->prefetch().count())}, + {{"numPrefetch", RuntimeMetric(ioStats_->prefetch().count())}, {"prefetchBytes", - RuntimeCounter( - ioStats_->prefetch().sum(), RuntimeCounter::Unit::kBytes)}, + RuntimeMetric( + ioStats_->prefetch().sum(), + ioStats_->prefetch().count(), + ioStats_->prefetch().min(), + ioStats_->prefetch().max(), + RuntimeCounter::Unit::kBytes)}, {"totalScanTime", - RuntimeCounter( - ioStats_->totalScanTime(), RuntimeCounter::Unit::kNanos)}, - {"totalRemainingFilterTime", - RuntimeCounter( + RuntimeMetric(ioStats_->totalScanTime(), RuntimeCounter::Unit::kNanos)}, + {Connector::kTotalRemainingFilterTime, + RuntimeMetric( totalRemainingFilterTime_.load(std::memory_order_relaxed), RuntimeCounter::Unit::kNanos)}, {"ioWaitWallNanos", - RuntimeCounter( + RuntimeMetric( ioStats_->queryThreadIoLatency().sum() * 1000, - RuntimeCounter::Unit::kNanos)}, - {"maxSingleIoWaitWallNanos", - RuntimeCounter( + ioStats_->queryThreadIoLatency().count(), + ioStats_->queryThreadIoLatency().min() * 1000, ioStats_->queryThreadIoLatency().max() * 1000, RuntimeCounter::Unit::kNanos)}, {"overreadBytes", - RuntimeCounter( + RuntimeMetric( ioStats_->rawOverreadBytes(), RuntimeCounter::Unit::kBytes)}}); if (ioStats_->read().count() > 0) { - res.insert({"numStorageRead", RuntimeCounter(ioStats_->read().count())}); res.insert( {"storageReadBytes", - RuntimeCounter(ioStats_->read().sum(), RuntimeCounter::Unit::kBytes)}); + RuntimeMetric( + ioStats_->read().sum(), + ioStats_->read().count(), + ioStats_->read().min(), + ioStats_->read().max(), + RuntimeCounter::Unit::kBytes)}); } if (ioStats_->ssdRead().count() > 0) { - res.insert({"numLocalRead", RuntimeCounter(ioStats_->ssdRead().count())}); + res.insert({"numLocalRead", RuntimeMetric(ioStats_->ssdRead().count())}); res.insert( {"localReadBytes", - RuntimeCounter( - ioStats_->ssdRead().sum(), RuntimeCounter::Unit::kBytes)}); + RuntimeMetric( + ioStats_->ssdRead().sum(), + ioStats_->ssdRead().count(), + ioStats_->ssdRead().min(), + ioStats_->ssdRead().max(), + RuntimeCounter::Unit::kBytes)}); } if (ioStats_->ramHit().count() > 0) { - res.insert({"numRamRead", RuntimeCounter(ioStats_->ramHit().count())}); + res.insert({"numRamRead", RuntimeMetric(ioStats_->ramHit().count())}); res.insert( {"ramReadBytes", - RuntimeCounter( - ioStats_->ramHit().sum(), RuntimeCounter::Unit::kBytes)}); + RuntimeMetric( + ioStats_->ramHit().sum(), + ioStats_->ramHit().count(), + ioStats_->ramHit().min(), + ioStats_->ramHit().max(), + RuntimeCounter::Unit::kBytes)}); } if (numBucketConversion_ > 0) { - res.insert({"numBucketConversion", RuntimeCounter(numBucketConversion_)}); + res.insert({"numBucketConversion", RuntimeMetric(numBucketConversion_)}); } const auto fsStats = fsStats_->stats(); for (const auto& storageStats : fsStats) { - res.emplace( - storageStats.first, - RuntimeCounter(storageStats.second.sum, storageStats.second.unit)); + res.emplace(storageStats.first, storageStats.second); } return res; } @@ -492,6 +544,7 @@ void HiveDataSource::setFromDataSource( readerOutputType_ = std::move(source->readerOutputType_); source->scanSpec_->moveAdaptationFrom(*scanSpec_); scanSpec_ = std::move(source->scanSpec_); + metadataFilter_ = std::move(source->metadataFilter_); splitReader_ = std::move(source->splitReader_); splitReader_->setConnectorQueryCtx(connectorQueryCtx_); // New io will be accounted on the stats of 'source'. Add the existing @@ -550,7 +603,7 @@ std::shared_ptr HiveDataSource::toWaveDataSource() { readerOutputType_, &partitionKeys_, fileHandleFactory_, - executor_, + ioExecutor_, connectorQueryCtx_, hiveConfig_, ioStats_, diff --git a/velox/connectors/hive/HiveDataSource.h b/velox/connectors/hive/HiveDataSource.h index 29b31dcee1e6..b019421eccf7 100644 --- a/velox/connectors/hive/HiveDataSource.h +++ b/velox/connectors/hive/HiveDataSource.h @@ -36,12 +36,10 @@ class HiveDataSource : public DataSource { public: HiveDataSource( const RowTypePtr& outputType, - const std::shared_ptr& tableHandle, - const std::unordered_map< - std::string, - std::shared_ptr>& columnHandles, + const connector::ConnectorTableHandlePtr& tableHandle, + const connector::ColumnHandleMap& assignments, FileHandleFactory* fileHandleFactory, - folly::Executor* executor, + folly::Executor* ioExecutor, const ConnectorQueryCtx* connectorQueryCtx, const std::shared_ptr& hiveConfig); @@ -62,7 +60,7 @@ class HiveDataSource : public DataSource { return completedRows_; } - std::unordered_map runtimeStats() override; + std::unordered_map getRuntimeStats() override; bool allPrefetchIssued() const override { return splitReader_ && splitReader_->allPrefetchIssued(); @@ -72,15 +70,18 @@ class HiveDataSource : public DataSource { int64_t estimatedRowSize() override; + const common::SubfieldFilters* getFilters() const override { + return &filters_; + } + std::shared_ptr toWaveDataSource() override; using WaveDelegateHookFunction = std::function( - const std::shared_ptr& hiveTableHandle, + const HiveTableHandlePtr& hiveTableHandle, const std::shared_ptr& scanSpec, const RowTypePtr& readerOutputType, - std::unordered_map>* - partitionKeys, + std::unordered_map* partitionKeys, FileHandleFactory* fileHandleFactory, folly::Executor* executor, const ConnectorQueryCtx* connectorQueryCtx, @@ -101,13 +102,13 @@ class HiveDataSource : public DataSource { virtual std::unique_ptr createSplitReader(); FileHandleFactory* const fileHandleFactory_; - folly::Executor* const executor_; + folly::Executor* const ioExecutor_; const ConnectorQueryCtx* const connectorQueryCtx_; const std::shared_ptr hiveConfig_; memory::MemoryPool* const pool_; std::shared_ptr split_; - std::shared_ptr hiveTableHandle_; + HiveTableHandlePtr hiveTableHandle_; std::shared_ptr scanSpec_; VectorPtr output_; std::unique_ptr splitReader_; @@ -119,8 +120,7 @@ class HiveDataSource : public DataSource { // Column handles for the partition key columns keyed on partition key column // name. - std::unordered_map> - partitionKeys_; + std::unordered_map partitionKeys_; std::shared_ptr ioStats_; std::shared_ptr fsStats_; @@ -147,17 +147,21 @@ class HiveDataSource : public DataSource { return emptyOutput_; } + // Add the information from column handle to the corresponding fields in this + // object. + void processColumnHandle(const HiveColumnHandlePtr& handle); + // The row type for the data source output, not including filter-only columns const RowTypePtr outputType_; core::ExpressionEvaluator* const expressionEvaluator_; // Column handles for the Split info columns keyed on their column names. - std::unordered_map> - infoColumns_; + std::unordered_map infoColumns_; SpecialColumnNames specialColumns_{}; std::vector remainingFilterSubfields_; folly::F14FastMap> subfields_; + std::vector> columnPostProcessors_; common::SubfieldFilters filters_; std::shared_ptr metadataFilter_; std::unique_ptr remainingFilterExprSet_; diff --git a/velox/connectors/hive/HivePartitionFunction.cpp b/velox/connectors/hive/HivePartitionFunction.cpp index d273cc8163e3..548aff99e15d 100644 --- a/velox/connectors/hive/HivePartitionFunction.cpp +++ b/velox/connectors/hive/HivePartitionFunction.cpp @@ -15,6 +15,8 @@ */ #include "velox/connectors/hive/HivePartitionFunction.h" +#include + namespace facebook::velox::connector::hive { namespace { @@ -31,8 +33,7 @@ int32_t hashInt64(int64_t value) { __attribute__((no_sanitize("integer"))) #endif #endif -uint32_t -hashBytes(StringView bytes, int32_t initialValue) { +uint32_t hashBytes(StringView bytes, int32_t initialValue) { uint32_t hash = initialValue; auto* data = bytes.data(); for (auto i = 0; i < bytes.size(); ++i) { @@ -461,7 +462,7 @@ HivePartitionFunction::HivePartitionFunction( std::vector keyChannels, const std::vector& constValues) : numBuckets_{numBuckets}, - bucketToPartition_{bucketToPartition}, + bucketToPartition_{std::move(bucketToPartition)}, keyChannels_{std::move(keyChannels)} { precomputedHashes_.resize(keyChannels_.size()); size_t constChannel{0}; @@ -495,7 +496,7 @@ std::optional HivePartitionFunction::partition( } } - static const int32_t kInt32Max = std::numeric_limits::max(); + static constexpr int32_t kInt32Max = std::numeric_limits::max(); if (bucketToPartition_.empty()) { // NOTE: if bucket to partition mapping is empty, then we do diff --git a/velox/connectors/hive/HivePartitionName.cpp b/velox/connectors/hive/HivePartitionName.cpp new file mode 100644 index 000000000000..2d7e7b9b665b --- /dev/null +++ b/velox/connectors/hive/HivePartitionName.cpp @@ -0,0 +1,102 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/HivePartitionName.h" +#include "velox/common/encode/Base64.h" +#include "velox/dwio/catalog/fbhive/FileUtils.h" +#include "velox/type/DecimalUtil.h" + +namespace facebook::velox::connector::hive { + +using namespace facebook::velox::dwio::catalog::fbhive; + +namespace { + +template +std::string formatDecimal(T value, const TypePtr& type) { + const auto& [p, s] = getDecimalPrecisionScale(*type); + const auto& maxSize = DecimalUtil::maxStringViewSize(p, s); + std::string buffer(maxSize, '\0'); + const auto& actualSize = + DecimalUtil::castToString(value, s, maxSize, buffer.data()); + buffer.resize(actualSize); + return buffer; +} + +} // namespace + +std::string HivePartitionName::toName(int32_t value, const TypePtr& type) { + if (type->isDate()) { + return DateType::toIso8601(value); + } + return fmt::to_string(value); +} + +std::string HivePartitionName::toName(int64_t value, const TypePtr& type) { + if (type->isShortDecimal()) { + return formatDecimal(value, type); + } + return fmt::to_string(value); +} + +std::string HivePartitionName::toName(int128_t value, const TypePtr& type) { + if (type->isLongDecimal()) { + return formatDecimal(value, type); + } + return fmt::to_string(value); +} + +std::string HivePartitionName::toName(Timestamp value, const TypePtr& type) { + value.toTimezone(Timestamp::defaultTimezone()); + TimestampToStringOptions options; + options.dateTimeSeparator = ' '; + // Set the precision to milliseconds, and enable the skipTrailingZeros match + // the timestamp precision and truncation behavior of Presto. + options.precision = TimestampPrecision::kMilliseconds; + options.skipTrailingZeros = true; + + auto result = value.toString(options); + + // Presto's java.sql.Timestamp.toString() always keeps at least one decimal + // place even when all fractional seconds are zero. + // If skipTrailingZeros removed all fractional digits, add back ".0" to match + // Presto's behavior. + if (auto dotPos = result.find_last_of('.'); dotPos == std::string::npos) { + // No decimal point found, add ".0" + result += ".0"; + } + + return result; +} + +std::string HivePartitionName::partitionName( + uint32_t partitionId, + const RowVectorPtr& partitionValues, + bool partitionKeyAsLowerCase) { + auto toPartitionName = + [](auto value, const TypePtr& type, int /*columnIndex*/) { + return HivePartitionName::toName(value, type); + }; + return FileUtils::makePartName( + partitionKeyValues( + partitionId, + partitionValues, + /*nullValueString=*/"", + toPartitionName), + partitionKeyAsLowerCase); +} + +} // namespace facebook::velox::connector::hive diff --git a/velox/connectors/hive/HivePartitionName.h b/velox/connectors/hive/HivePartitionName.h new file mode 100644 index 000000000000..3e519528866c --- /dev/null +++ b/velox/connectors/hive/HivePartitionName.h @@ -0,0 +1,172 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include "velox/vector/ComplexVector.h" +#include "velox/vector/SimpleVector.h" + +namespace facebook::velox::connector::hive { + +/// Converting partition values to their string representations. +/// Provides template methods for formatting different data types according to +/// Hive partitioning conventions. +class HivePartitionName { + public: + /// Generic template for formatting partition values to strings using + /// fmt::to_string. Specialized for types that need special handling + /// (int32_t, int64_t, int128_t, Timestamp). + template + FOLLY_ALWAYS_INLINE static std::string toName(T value, const TypePtr& type) { + return fmt::to_string(value); + } + + /// Format int32_t partition values. Specialized to handle DATE type which + /// requires ISO-8601 formatting (YYYY-MM-DD) instead of raw integer value. + static std::string toName(int32_t value, const TypePtr& type); + + /// Format int64_t partition values. Specialized to handle short DECIMAL type + /// which requires decimal string formatting with proper precision and scale + /// instead of raw integer value. + static std::string toName(int64_t value, const TypePtr& type); + + /// Format int128_t partition values. Specialized to handle long DECIMAL type + /// which requires decimal string formatting with proper precision and scale + /// instead of raw integer value. + static std::string toName(int128_t value, const TypePtr& type); + + /// Format Timestamp partition values. Specialized to: + /// 1. Convert to default timezone + /// 2. Use space as date-time separator (not 'T') + /// 3. Use millisecond precision with trailing zeros skipped + /// 4. Always keep at least ".0" for fractional seconds (Presto compatibility) + static std::string toName(Timestamp value, const TypePtr& type); + + /// Build partition key-value pairs from partition values. + /// Returns a vector of (key, value) pairs for all partition columns. + /// @tparam F A callable that converts a value to a partition string. + /// Takes (value, type, columnIndex) and returns string. + /// @param partitionId The partition ID (row index) to extract values from. + /// @param partitionValues RowVector containing partition values. + /// @param nullValueString The string to use for null values. + /// @param toPartitionName Callable to convert a value to a string. + template + static std::vector> partitionKeyValues( + uint32_t partitionId, + const RowVectorPtr& partitionValues, + const std::string& nullValueString, + const F& toPartitionName); + + /// Generate a Hive partition directory name from partition values for + /// partitionId. + /// + /// @param partitionId The row index in partitionValues to extract values + /// from. + /// @param partitionValues RowVector containing partition values. Each + /// child vector represents a partition column, and the row at + /// partitionId contains the values for this partition. + /// @param partitionKeyAsLowerCase Controls whether partition column names + /// should be converted to lowercase in the output. When true, column + /// names are lowercased (e.g., "year=2025"); when false, original + /// casing is preserved (e.g., "Year=2025"). + /// @return A formatted partition directory name string. Null values are + /// represented as __HIVE_DEFAULT_PARTITION__. + static std::string partitionName( + uint32_t partitionId, + const RowVectorPtr& partitionValues, + bool partitionKeyAsLowerCase); +}; + +namespace detail { + +// Unified template function to extract partition key-value string from a +// vector. Used by both Hive and Iceberg partition name generators. +// +// @tparam Kind The TypeKind of the partition column. +// @tparam F A callable that converts a value to a partition string. +// @param partitionVector The vector containing partition values. +// @param row The row index to extract the value from. +// @param type The type of the partition column. +// @param columnIndex The column index in the partition values. +// @param toPartitionName Callable to convert a value to a partition string. +// @return A pair of (column_name, formatted_value). +template +std::string makePartitionKeyValueString( + const BaseVector& partitionVector, + vector_size_t row, + const TypePtr& type, + int columnIndex, + const F& toPartitionName) { + using T = typename TypeTraits::NativeType; + + return toPartitionName( + partitionVector.as>()->valueAt(row), type, columnIndex); +} + +#define PARTITION_TYPE_DISPATCH(TEMPLATE_FUNC, typeKind, ...) \ + [&]() { \ + switch (typeKind) { \ + case TypeKind::BOOLEAN: \ + case TypeKind::TINYINT: \ + case TypeKind::SMALLINT: \ + case TypeKind::INTEGER: \ + case TypeKind::BIGINT: \ + case TypeKind::VARCHAR: \ + case TypeKind::VARBINARY: \ + case TypeKind::TIMESTAMP: \ + return VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH( \ + TEMPLATE_FUNC, typeKind, __VA_ARGS__); \ + default: \ + VELOX_UNSUPPORTED( \ + "Unsupported partition type: {}", TypeKindName::toName(typeKind)); \ + } \ + }() + +} // namespace detail + +template +std::vector> +HivePartitionName::partitionKeyValues( + uint32_t partitionId, + const RowVectorPtr& partitionValues, + const std::string& nullValueString, + const F& toPartitionName) { + std::vector> partitionKeyValuePairs; + for (auto i = 0; i < partitionValues->childrenSize(); i++) { + const auto& child = partitionValues->childAt(i); + const auto& name = partitionValues->rowType()->nameOf(i); + if (child->isNullAt(partitionId)) { + partitionKeyValuePairs.emplace_back( + std::make_pair(name, nullValueString)); + continue; + } + + partitionKeyValuePairs.emplace_back( + std::make_pair( + name, + PARTITION_TYPE_DISPATCH( + detail::makePartitionKeyValueString, + child->typeKind(), + *child->loadedVector(), + partitionId, + child->type(), + i, + toPartitionName))); + } + return partitionKeyValuePairs; +} + +} // namespace facebook::velox::connector::hive diff --git a/velox/connectors/hive/HivePartitionUtil.cpp b/velox/connectors/hive/HivePartitionUtil.cpp deleted file mode 100644 index cbc53c79b5ea..000000000000 --- a/velox/connectors/hive/HivePartitionUtil.cpp +++ /dev/null @@ -1,90 +0,0 @@ -/* - * Copyright (c) Facebook, Inc. and its affiliates. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "velox/connectors/hive/HivePartitionUtil.h" - -namespace facebook::velox::connector::hive { - -#define PARTITION_TYPE_DISPATCH(TEMPLATE_FUNC, typeKind, ...) \ - [&]() { \ - switch (typeKind) { \ - case TypeKind::BOOLEAN: \ - case TypeKind::TINYINT: \ - case TypeKind::SMALLINT: \ - case TypeKind::INTEGER: \ - case TypeKind::BIGINT: \ - case TypeKind::VARCHAR: \ - case TypeKind::VARBINARY: \ - return VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH( \ - TEMPLATE_FUNC, typeKind, __VA_ARGS__); \ - default: \ - VELOX_UNSUPPORTED( \ - "Unsupported partition type: {}", mapTypeKindToName(typeKind)); \ - } \ - }() - -namespace { -template -inline std::string makePartitionValueString(T value) { - return folly::to(value); -} - -template <> -inline std::string makePartitionValueString(bool value) { - return value ? "true" : "false"; -} - -template -std::pair makePartitionKeyValueString( - const BaseVector* partitionVector, - vector_size_t row, - const std::string& name, - bool isDate) { - using T = typename TypeTraits::NativeType; - if (partitionVector->as>()->isNullAt(row)) { - return std::make_pair(name, ""); - } - if (isDate) { - return std::make_pair( - name, - DATE()->toString( - partitionVector->as>()->valueAt(row))); - } - return std::make_pair( - name, - makePartitionValueString( - partitionVector->as>()->valueAt(row))); -} - -} // namespace - -std::vector> extractPartitionKeyValues( - const RowVectorPtr& partitionsVector, - vector_size_t row) { - std::vector> partitionKeyValues; - for (auto i = 0; i < partitionsVector->childrenSize(); i++) { - partitionKeyValues.push_back(PARTITION_TYPE_DISPATCH( - makePartitionKeyValueString, - partitionsVector->childAt(i)->typeKind(), - partitionsVector->childAt(i)->loadedVector(), - row, - asRowType(partitionsVector->type())->nameOf(i), - partitionsVector->childAt(i)->type()->isDate())); - } - return partitionKeyValues; -} - -} // namespace facebook::velox::connector::hive diff --git a/velox/connectors/hive/PartitionIdGenerator.cpp b/velox/connectors/hive/PartitionIdGenerator.cpp index deec8dc5b005..a4773da58a10 100644 --- a/velox/connectors/hive/PartitionIdGenerator.cpp +++ b/velox/connectors/hive/PartitionIdGenerator.cpp @@ -16,37 +16,30 @@ #include "velox/connectors/hive/PartitionIdGenerator.h" -#include "velox/connectors/hive/HivePartitionUtil.h" -#include "velox/dwio/catalog/fbhive/FileUtils.h" - -using namespace facebook::velox::dwio::catalog::fbhive; - namespace facebook::velox::connector::hive { PartitionIdGenerator::PartitionIdGenerator( const RowTypePtr& inputType, std::vector partitionChannels, uint32_t maxPartitions, - memory::MemoryPool* pool, - bool partitionPathAsLowerCase) - : partitionChannels_(std::move(partitionChannels)), - maxPartitions_(maxPartitions), - partitionPathAsLowerCase_(partitionPathAsLowerCase) { + memory::MemoryPool* pool) + : pool_(pool), + partitionChannels_(std::move(partitionChannels)), + maxPartitions_(maxPartitions) { VELOX_USER_CHECK( !partitionChannels_.empty(), "There must be at least one partition key."); for (auto channel : partitionChannels_) { hashers_.emplace_back( exec::VectorHasher::create(inputType->childAt(channel), channel)); + VELOX_USER_CHECK( + hashers_.back()->typeSupportsValueIds(), + "Unsupported partition type: {}.", + inputType->childAt(channel)->toString()); } std::vector partitionKeyTypes; std::vector partitionKeyNames; for (auto channel : partitionChannels_) { - VELOX_USER_CHECK( - exec::VectorHasher::typeKindSupportsValueIds( - inputType->childAt(channel)->kind()), - "Unsupported partition type: {}.", - inputType->childAt(channel)->toString()); partitionKeyTypes.push_back(inputType->childAt(channel)); partitionKeyNames.push_back(inputType->nameOf(channel)); } @@ -96,12 +89,6 @@ void PartitionIdGenerator::run( } } -std::string PartitionIdGenerator::partitionName(uint64_t partitionId) const { - return FileUtils::makePartName( - extractPartitionKeyValues(partitionValues_, partitionId), - partitionPathAsLowerCase_); -} - void PartitionIdGenerator::computeValueIds( const RowVectorPtr& input, raw_vector& valueIds) { @@ -154,7 +141,7 @@ void PartitionIdGenerator::updateValueToPartitionIdMapping() { partitionIds_.clear(); - raw_vector newValueIds(numPartitions); + raw_vector newValueIds(numPartitions, pool_); SelectivityVector rows(numPartitions); for (auto i = 0; i < hashers_.size(); ++i) { auto& hasher = hashers_[i]; diff --git a/velox/connectors/hive/PartitionIdGenerator.h b/velox/connectors/hive/PartitionIdGenerator.h index 01b638c0f3ad..0a53252829c8 100644 --- a/velox/connectors/hive/PartitionIdGenerator.h +++ b/velox/connectors/hive/PartitionIdGenerator.h @@ -29,14 +29,11 @@ class PartitionIdGenerator { /// @param maxPartitions The max number of distinct partitions. /// @param pool Memory pool. Used to allocate memory for storing unique /// partition key values. - /// @param partitionPathAsLowerCase Used to control whether the partition path - /// need to convert to lower case. PartitionIdGenerator( const RowTypePtr& inputType, std::vector partitionChannels, uint32_t maxPartitions, - memory::MemoryPool* pool, - bool partitionPathAsLowerCase); + memory::MemoryPool* pool); /// Generate sequential partition IDs for input vector. /// @param input Input RowVector. @@ -48,11 +45,16 @@ class PartitionIdGenerator { return partitionIds_.size(); } - /// Return partition name for the given partition id in the typical Hive - /// style. It is derived from the partitionValues_ at index partitionId. - /// Partition keys appear in the order of partition columns in the table - /// schema. - std::string partitionName(uint64_t partitionId) const; + /// Returns the RowVector containing transformed partition keys. + /// Each row in this vector corresponds to a partition ID (row index = + /// partition ID). + /// Should be called after calling run() method. + /// + /// @return RowVector with one column per partition column, columns in same + /// order as partitionChannels_. + const RowVectorPtr& partitionValues() const { + return partitionValues_; + } private: static constexpr const int32_t kHasherReservePct = 20; @@ -75,12 +77,12 @@ class PartitionIdGenerator { const RowVectorPtr& input, vector_size_t row); + memory::MemoryPool* const pool_; + const std::vector partitionChannels_; const uint32_t maxPartitions_; - const bool partitionPathAsLowerCase_; - std::vector> hashers_; bool hasMultiplierSet_ = false; diff --git a/velox/connectors/hive/SplitReader.cpp b/velox/connectors/hive/SplitReader.cpp index 0bd741efda8c..8913044903c8 100644 --- a/velox/connectors/hive/SplitReader.cpp +++ b/velox/connectors/hive/SplitReader.cpp @@ -17,76 +17,88 @@ #include "velox/connectors/hive/SplitReader.h" #include "velox/common/caching/CacheTTLController.h" +#include "velox/connectors/hive/BufferedInputBuilder.h" #include "velox/connectors/hive/HiveConfig.h" #include "velox/connectors/hive/HiveConnectorSplit.h" #include "velox/connectors/hive/HiveConnectorUtil.h" #include "velox/connectors/hive/TableHandle.h" #include "velox/connectors/hive/iceberg/IcebergSplitReader.h" -#include "velox/dwio/common/CachedBufferedInput.h" #include "velox/dwio/common/ReaderFactory.h" -#include "velox/type/TimestampConversion.h" namespace facebook::velox::connector::hive { namespace { template -VectorPtr newConstantFromString( +VectorPtr newConstantFromStringImpl( const TypePtr& type, const std::optional& value, - vector_size_t size, velox::memory::MemoryPool* pool, - const std::string& sessionTimezone, - bool asLocalTime, - bool isPartitionDateDaysSinceEpoch = false) { + bool isLocalTimestamp, + bool isDaysSinceEpoch) { using T = typename TypeTraits::NativeType; if (!value.has_value()) { - return std::make_shared>(pool, size, true, type, T()); + return std::make_shared>(pool, 1, true, type, T()); } if (type->isDate()) { int32_t days = 0; // For Iceberg, the date partition values are already in daysSinceEpoch // form. - if (isPartitionDateDaysSinceEpoch) { + if (isDaysSinceEpoch) { days = folly::to(value.value()); } else { - days = DATE()->toDays(static_cast(value.value())); + days = DATE()->toDays(value.value()); } return std::make_shared>( - pool, size, false, type, std::move(days)); + pool, 1, false, type, std::move(days)); } if constexpr (std::is_same_v) { return std::make_shared>( - pool, size, false, type, StringView(value.value())); + pool, 1, false, type, StringView(value.value())); } else { auto copy = velox::util::Converter::tryCast(value.value()) .thenOrThrow(folly::identity, [&](const Status& status) { VELOX_USER_FAIL("{}", status.message()); }); if constexpr (kind == TypeKind::TIMESTAMP) { - if (asLocalTime) { + if (isLocalTimestamp) { copy.toGMT(Timestamp::defaultTimezone()); } } return std::make_shared>( - pool, size, false, type, std::move(copy)); + pool, 1, false, type, std::move(copy)); } } } // namespace +VectorPtr newConstantFromString( + const TypePtr& type, + const std::optional& value, + velox::memory::MemoryPool* pool, + bool isLocalTimestamp, + bool isDaysSinceEpoch) { + return VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH_ALL( + newConstantFromStringImpl, + type->kind(), + type, + value, + pool, + isLocalTimestamp, + isDaysSinceEpoch); +} + std::unique_ptr SplitReader::create( const std::shared_ptr& hiveSplit, - const std::shared_ptr& hiveTableHandle, - const std::unordered_map>* - partitionKeys, + const HiveTableHandlePtr& hiveTableHandle, + const std::unordered_map* partitionKeys, const ConnectorQueryCtx* connectorQueryCtx, const std::shared_ptr& hiveConfig, const RowTypePtr& readerOutputType, const std::shared_ptr& ioStats, const std::shared_ptr& fsStats, FileHandleFactory* fileHandleFactory, - folly::Executor* executor, + folly::Executor* ioExecutor, const std::shared_ptr& scanSpec) { // Create the SplitReader based on hiveSplit->customSplitInfo["table_format"] if (hiveSplit->customSplitInfo.count("table_format") > 0 && @@ -101,7 +113,7 @@ std::unique_ptr SplitReader::create( ioStats, fsStats, fileHandleFactory, - executor, + ioExecutor, scanSpec); } else { return std::unique_ptr(new SplitReader( @@ -114,23 +126,22 @@ std::unique_ptr SplitReader::create( ioStats, fsStats, fileHandleFactory, - executor, + ioExecutor, scanSpec)); } } SplitReader::SplitReader( const std::shared_ptr& hiveSplit, - const std::shared_ptr& hiveTableHandle, - const std::unordered_map>* - partitionKeys, + const HiveTableHandlePtr& hiveTableHandle, + const std::unordered_map* partitionKeys, const ConnectorQueryCtx* connectorQueryCtx, const std::shared_ptr& hiveConfig, const RowTypePtr& readerOutputType, const std::shared_ptr& ioStats, const std::shared_ptr& fsStats, FileHandleFactory* fileHandleFactory, - folly::Executor* executor, + folly::Executor* ioExecutor, const std::shared_ptr& scanSpec) : hiveSplit_(hiveSplit), hiveTableHandle_(hiveTableHandle), @@ -141,7 +152,7 @@ SplitReader::SplitReader( ioStats_(ioStats), fsStats_(fsStats), fileHandleFactory_(fileHandleFactory), - executor_(executor), + ioExecutor_(ioExecutor), pool_(connectorQueryCtx->memoryPool()), scanSpec_(scanSpec), baseReaderOpts_(connectorQueryCtx->memoryPool()), @@ -162,8 +173,9 @@ void SplitReader::configureReaderOptions( void SplitReader::prepareSplit( std::shared_ptr metadataFilter, - dwio::common::RuntimeStatistics& runtimeStats) { - createReader(); + dwio::common::RuntimeStatistics& runtimeStats, + const folly::F14FastMap& fileReadOps) { + createReader(fileReadOps); if (emptySplit_) { return; } @@ -174,7 +186,7 @@ void SplitReader::prepareSplit( return; } - createRowReader(std::move(metadataFilter), std::move(rowType)); + createRowReader(std::move(metadataFilter), std::move(rowType), std::nullopt); } void SplitReader::setBucketConversion( @@ -223,7 +235,7 @@ uint64_t SplitReader::next(uint64_t size, VectorPtr& output) { mutation.randomSkip = baseReaderOpts_.randomSkip().get(); numScanned = baseRowReader_->next(size, output, &mutation); } - if (partitionFunction_) { + if (numScanned > 0 && output->size() > 0 && partitionFunction_) { applyBucketConversion( output, bucketConversionRows(*output->asChecked())); } @@ -285,16 +297,22 @@ std::string SplitReader::toString() const { static_cast(baseRowReader_.get())); } -void SplitReader::createReader() { +void SplitReader::createReader( + const folly::F14FastMap& fileReadOps) { VELOX_CHECK_NE( baseReaderOpts_.fileFormat(), dwio::common::FileFormat::UNKNOWN); FileHandleCachedPtr fileHandleCachePtr; + FileHandleKey fileHandleKey{ + .filename = hiveSplit_->filePath, + .tokenProvider = connectorQueryCtx_->fsTokenProvider()}; + + auto fileProperties = hiveSplit_->properties.value_or(FileProperties{}); + fileProperties.fileReadOps = fileReadOps; + try { fileHandleCachePtr = fileHandleFactory_->generate( - hiveSplit_->filePath, - hiveSplit_->properties.has_value() ? &*hiveSplit_->properties : nullptr, - fsStats_ ? fsStats_.get() : nullptr); + fileHandleKey, &fileProperties, fsStats_ ? fsStats_.get() : nullptr); VELOX_CHECK_NOT_NULL(fileHandleCachePtr.get()); } catch (const VeloxRuntimeError& e) { if (e.errorCode() == error_code::kFileNotFound && @@ -313,13 +331,14 @@ void SplitReader::createReader() { if (auto* cacheTTLController = cache::CacheTTLController::getInstance()) { cacheTTLController->addOpenFileInfo(fileHandleCachePtr->uuid.id()); } - auto baseFileInput = createBufferedInput( + auto baseFileInput = BufferedInputBuilder::getInstance()->create( *fileHandleCachePtr, baseReaderOpts_, connectorQueryCtx_, ioStats_, fsStats_, - executor_); + ioExecutor_, + fileReadOps); baseReader_ = dwio::common::getReaderFactory(baseReaderOpts_.fileFormat()) ->createReader(std::move(baseFileInput), baseReaderOpts_); @@ -369,7 +388,8 @@ bool SplitReader::checkIfSplitIsEmpty( void SplitReader::createRowReader( std::shared_ptr metadataFilter, - RowTypePtr rowType) { + RowTypePtr rowType, + std::optional rowSizeTrackingEnabled) { VELOX_CHECK_NULL(baseRowReader_); configureRowReaderOptions( hiveTableHandle_->tableParameters(), @@ -379,7 +399,13 @@ void SplitReader::createRowReader( hiveSplit_, hiveConfig_, connectorQueryCtx_->sessionProperties(), + ioExecutor_, baseRowReaderOpts_); + baseRowReaderOpts_.setTrackRowSize( + rowSizeTrackingEnabled.has_value() + ? *rowSizeTrackingEnabled + : connectorQueryCtx_->rowSizeTrackingMode() != + core::QueryConfig::RowSizeTrackingMode::DISABLED); baseRowReader_ = baseReader_->createRowReader(baseRowReaderOpts_); } @@ -401,16 +427,13 @@ std::vector SplitReader::adaptColumns( iter != hiveSplit_->infoColumns.end()) { auto infoColumnType = readerOutputType_->childAt(readerOutputType_->getChildIdx(fieldName)); - auto constant = VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH_ALL( - newConstantFromString, - infoColumnType->kind(), + auto constant = newConstantFromString( infoColumnType, iter->second, - 1, connectorQueryCtx_->memoryPool(), - connectorQueryCtx_->sessionTimezone(), hiveConfig_->readTimestampPartitionValueAsLocalTime( - connectorQueryCtx_->sessionProperties())); + connectorQueryCtx_->sessionProperties()), + false); childSpec->setConstantValue(constant); } else if ( childSpec->columnType() == common::ScanSpec::ColumnType::kRegular) { @@ -418,10 +441,11 @@ std::vector SplitReader::adaptColumns( if (!fileTypeIdx.has_value()) { // Column is missing. Most likely due to schema evolution. VELOX_CHECK(tableSchema, "Unable to resolve column '{}'", fieldName); - childSpec->setConstantValue(BaseVector::createNullConstant( - tableSchema->findChild(fieldName), - 1, - connectorQueryCtx_->memoryPool())); + childSpec->setConstantValue( + BaseVector::createNullConstant( + tableSchema->findChild(fieldName), + 1, + connectorQueryCtx_->memoryPool())); } else { // Column no longer missing, reset constant value set on the spec. childSpec->setConstantValue(nullptr); @@ -457,14 +481,10 @@ void SplitReader::setPartitionValue( "ColumnHandle is missing for partition key {}", partitionKey); auto type = it->second->dataType(); - auto constant = VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH_ALL( - newConstantFromString, - type->kind(), + auto constant = newConstantFromString( type, value, - 1, connectorQueryCtx_->memoryPool(), - connectorQueryCtx_->sessionTimezone(), hiveConfig_->readTimestampPartitionValueAsLocalTime( connectorQueryCtx_->sessionProperties()), it->second->isPartitionDateValueDaysSinceEpoch()); diff --git a/velox/connectors/hive/SplitReader.h b/velox/connectors/hive/SplitReader.h index 21c77ea2014a..46ecccb4eca7 100644 --- a/velox/connectors/hive/SplitReader.h +++ b/velox/connectors/hive/SplitReader.h @@ -25,7 +25,6 @@ namespace facebook::velox { class BaseVector; -class variant; using VectorPtr = std::shared_ptr; } // namespace facebook::velox @@ -48,6 +47,36 @@ class MemoryPool; namespace facebook::velox::connector::hive { +/// Creates a constant vector of size 1 from a string representation of a value. +/// +/// Used to materialize partition column values and info columns (e.g., $path, +/// $file_size) when reading Hive and Iceberg tables. Partition values are +/// stored as strings in HiveConnectorSplit::partitionKeys and need to be +/// converted to their appropriate types. +/// +/// @param type The target Velox type for the constant vector. Supports all +/// scalar types including primitives, dates, timestamps. +/// @param value The string representation of the value to convert, formatted +/// the same way as CAST(x as VARCHAR). Date values must be formatted using ISO +/// 8601 as YYYY-MM-DD. If nullopt, creates a null constant vector. +/// @param pool Memory pool for allocating the constant vector. +/// @param isLocalTimestamp If true and type is TIMESTAMP, interprets the string +/// value as local time and converts it to GMT. If false, treats the value +/// as already in GMT. +/// @param isDaysSinceEpoch If true and type is DATE, treats the string value as +/// an integer representing days since epoch (used by Iceberg). If false, parses +/// the string as a date string in ISO 8601 format (used by Hive). +/// +/// @return A constant vector of size 1 containing the converted value, or a +/// null constant if value is nullopt. +/// @throws VeloxUserError if the string cannot be converted to the target type. +VectorPtr newConstantFromString( + const TypePtr& type, + const std::optional& value, + velox::memory::MemoryPool* pool, + bool isLocalTimestamp, + bool isDaysSinceEpoch); + struct HiveConnectorSplit; class HiveTableHandle; class HiveColumnHandle; @@ -58,15 +87,16 @@ class SplitReader { static std::unique_ptr create( const std::shared_ptr& hiveSplit, const std::shared_ptr& hiveTableHandle, - const std::unordered_map>* - partitionKeys, + const std::unordered_map< + std::string, + std::shared_ptr>* partitionKeys, const ConnectorQueryCtx* connectorQueryCtx, const std::shared_ptr& hiveConfig, const RowTypePtr& readerOutputType, const std::shared_ptr& ioStats, const std::shared_ptr& fsStats, FileHandleFactory* fileHandleFactory, - folly::Executor* executor, + folly::Executor* ioExecutor, const std::shared_ptr& scanSpec); virtual ~SplitReader() = default; @@ -80,7 +110,8 @@ class SplitReader { /// would be called only once per incoming split virtual void prepareSplit( std::shared_ptr metadataFilter, - dwio::common::RuntimeStatistics& runtimeStats); + dwio::common::RuntimeStatistics& runtimeStats, + const folly::F14FastMap& fileReadOps = {}); virtual uint64_t next(uint64_t size, VectorPtr& output); @@ -110,8 +141,9 @@ class SplitReader { SplitReader( const std::shared_ptr& hiveSplit, const std::shared_ptr& hiveTableHandle, - const std::unordered_map>* - partitionKeys, + const std::unordered_map< + std::string, + std::shared_ptr>* partitionKeys, const ConnectorQueryCtx* connectorQueryCtx, const std::shared_ptr& hiveConfig, const RowTypePtr& readerOutputType, @@ -123,7 +155,8 @@ class SplitReader { /// Create the dwio::common::Reader object baseReader_, which will be used to /// read the data file's metadata and schema - void createReader(); + void createReader( + const folly::F14FastMap& fileReadOps = {}); // Adjust the scan spec according to the current split, then return the // adapted row type. @@ -146,7 +179,8 @@ class SplitReader { /// ColumnReaders that will be used to read the data void createRowReader( std::shared_ptr metadataFilter, - RowTypePtr rowType); + RowTypePtr rowType, + std::optional rowSizeTrackingEnabled); const folly::F14FastSet& bucketChannels() const { return bucketChannels_; @@ -159,24 +193,31 @@ class SplitReader { VectorPtr& output, const std::vector& ranges); - private: - /// Different table formats may have different meatadata columns. - /// This function will be used to update the scanSpec for these columns. - std::vector adaptColumns( - const RowTypePtr& fileType, - const std::shared_ptr& tableSchema) const; - + /// Sets a constant partition value on the scanSpec for a partition column. + /// Converts the partition key string value to the appropriate type and sets + /// it as a constant value in the scanSpec, so the column will be filled + /// with this constant value. + /// + /// @param spec The scan spec to set the constant value on. + /// @param partitionKey The name of the partition column. void setPartitionValue( common::ScanSpec* spec, const std::string& partitionKey, const std::optional& value) const; + private: + /// Different table formats may have different meatadata columns. + /// This function will be used to update the scanSpec for these columns. + virtual std::vector adaptColumns( + const RowTypePtr& fileType, + const RowTypePtr& tableSchema) const; + protected: std::shared_ptr hiveSplit_; const std::shared_ptr hiveTableHandle_; const std::unordered_map< std::string, - std::shared_ptr>* const partitionKeys_; + std::shared_ptr>* const partitionKeys_; const ConnectorQueryCtx* connectorQueryCtx_; const std::shared_ptr hiveConfig_; @@ -184,7 +225,7 @@ class SplitReader { const std::shared_ptr ioStats_; const std::shared_ptr fsStats_; FileHandleFactory* const fileHandleFactory_; - folly::Executor* const executor_; + folly::Executor* const ioExecutor_; memory::MemoryPool* const pool_; std::shared_ptr scanSpec_; diff --git a/velox/connectors/hive/TableHandle.cpp b/velox/connectors/hive/TableHandle.cpp index 3f7c8b6f93d5..85f4a2f3ada8 100644 --- a/velox/connectors/hive/TableHandle.cpp +++ b/velox/connectors/hive/TableHandle.cpp @@ -26,6 +26,7 @@ columnTypeNames() { {HiveColumnHandle::ColumnType::kRegular, "Regular"}, {HiveColumnHandle::ColumnType::kSynthesized, "Synthesized"}, {HiveColumnHandle::ColumnType::kRowIndex, "RowIndex"}, + {HiveColumnHandle::ColumnType::kRowId, "RowId"}, }; } @@ -110,14 +111,21 @@ HiveTableHandle::HiveTableHandle( common::SubfieldFilters subfieldFilters, const core::TypedExprPtr& remainingFilter, const RowTypePtr& dataColumns, - const std::unordered_map& tableParameters) + const std::unordered_map& tableParameters, + std::vector filterColumnHandles, + double sampleRate) : ConnectorTableHandle(std::move(connectorId)), tableName_(tableName), filterPushdownEnabled_(filterPushdownEnabled), subfieldFilters_(std::move(subfieldFilters)), remainingFilter_(remainingFilter), + sampleRate_{sampleRate}, dataColumns_(dataColumns), - tableParameters_(tableParameters) {} + tableParameters_(tableParameters), + filterColumnHandles_(std::move(filterColumnHandles)) { + VELOX_CHECK_GT(sampleRate_, 0.0, "Sample rate must be positive"); + VELOX_CHECK_LE(sampleRate_, 1.0, "Sample rate must not exceed 1.0"); +} std::string HiveTableHandle::toString() const { std::stringstream out; @@ -139,6 +147,9 @@ std::string HiveTableHandle::toString() const { } out << "]"; } + if (sampleRate_ < 1.0) { + out << ", sample rate: " << sampleRate_; + } if (remainingFilter_) { out << ", remaining filter: (" << remainingFilter_->toString() << ")"; } @@ -159,6 +170,19 @@ std::string HiveTableHandle::toString() const { } out << "]"; } + if (!filterColumnHandles_.empty()) { + out << ", filter column handles: ["; + bool first = true; + for (auto& handle : filterColumnHandles_) { + if (first) { + first = false; + } else { + out << ", "; + } + out << handle->toString(); + } + out << "]"; + } return out.str(); } @@ -179,6 +203,11 @@ folly::dynamic HiveTableHandle::serialize() const { if (remainingFilter_) { obj["remainingFilter"] = remainingFilter_->serialize(); } + + if (sampleRate_ < 1.0) { + obj["sampleRate"] = sampleRate_; + } + if (dataColumns_) { obj["dataColumns"] = dataColumns_->serialize(); } @@ -187,6 +216,13 @@ folly::dynamic HiveTableHandle::serialize() const { tableParameters[param.first] = param.second; } obj["tableParameters"] = tableParameters; + if (!filterColumnHandles_.empty()) { + folly::dynamic filterColumnHandles = folly::dynamic::array; + for (const auto& handle : filterColumnHandles_) { + filterColumnHandles.push_back(handle->serialize()); + } + obj["filterColumnHandles"] = filterColumnHandles; + } return obj; } @@ -214,6 +250,11 @@ ConnectorTableHandlePtr HiveTableHandle::create( filter->clone(); } + double sampleRate = 1.0; + if (obj.count("sampleRate")) { + sampleRate = obj["sampleRate"].asDouble(); + } + RowTypePtr dataColumns; if (auto it = obj.find("dataColumns"); it != obj.items().end()) { dataColumns = ISerializable::deserialize(it->second, context); @@ -226,6 +267,14 @@ ConnectorTableHandlePtr HiveTableHandle::create( tableParameters.emplace(key.asString(), value.asString()); } + std::vector filterColumnHandles; + if (auto it = obj.find("filterColumnHandles"); it != obj.items().end()) { + for (const auto& handle : it->second) { + filterColumnHandles.push_back( + ISerializable::deserialize(handle, context)); + } + } + return std::make_shared( connectorId, tableName, @@ -233,7 +282,9 @@ ConnectorTableHandlePtr HiveTableHandle::create( std::move(subfieldFilters), remainingFilter, dataColumns, - tableParameters); + tableParameters, + std::move(filterColumnHandles), + sampleRate); } void HiveTableHandle::registerSerDe() { diff --git a/velox/connectors/hive/TableHandle.h b/velox/connectors/hive/TableHandle.h index 2711da55a37a..a59e371a84fd 100644 --- a/velox/connectors/hive/TableHandle.h +++ b/velox/connectors/hive/TableHandle.h @@ -25,6 +25,8 @@ namespace facebook::velox::connector::hive { class HiveColumnHandle : public ColumnHandle { public: + /// NOTE: Make sure to update the mapping in columnTypeNames() when modifying + /// this. enum class ColumnType { kPartitionKey, kRegular, @@ -53,13 +55,15 @@ class HiveColumnHandle : public ColumnHandle { TypePtr dataType, TypePtr hiveType, std::vector requiredSubfields = {}, - ColumnParseParameters columnParseParameters = {}) + ColumnParseParameters columnParseParameters = {}, + std::function postProcessor = {}) : name_(name), columnType_(columnType), dataType_(std::move(dataType)), hiveType_(std::move(hiveType)), requiredSubfields_(std::move(requiredSubfields)), - columnParseParameters_(columnParseParameters) { + columnParseParameters_(columnParseParameters), + postProcessor_(std::move(postProcessor)) { VELOX_USER_CHECK( dataType_->equivalent(*hiveType_), "data type {} and hive type {} do not match", @@ -110,7 +114,24 @@ class HiveColumnHandle : public ColumnHandle { ColumnParseParameters::kDaysSinceEpoch; } - std::string toString() const; + /// Apply some row-wise post processing to this column when it is present in + /// output. + /// + /// It's not allowed to change the size of the vector in the processor. The + /// top level vector is guaranteed to be safe to change. Any inner vectors + /// and buffers need to check the reference count before doing any change in + /// place, otherwise you need to allocate new vectors and buffers. + /// + /// For lazy vector, this will be applied after the lazy vector is loaded. + /// This is only applied after all the filtering is done; the filters (both + /// subfield filters and remaining filter) still apply to values before post + /// processing. ValueHook usage will be disabled if a post processor is + /// present. + const std::function& postProcessor() const { + return postProcessor_; + } + + std::string toString() const override; folly::dynamic serialize() const override; @@ -130,10 +151,17 @@ class HiveColumnHandle : public ColumnHandle { const TypePtr hiveType_; const std::vector requiredSubfields_; const ColumnParseParameters columnParseParameters_; + const std::function postProcessor_; }; +using HiveColumnHandlePtr = std::shared_ptr; +using HiveColumnHandleMap = + std::unordered_map; + class HiveTableHandle : public ConnectorTableHandle { public: + /// @param sampleRate Sampling rate in (0, 1] range. 0.1 means 10% sampling. + /// 1.0 means no sampling. Default is no sampling. HiveTableHandle( std::string connectorId, const std::string& tableName, @@ -141,7 +169,9 @@ class HiveTableHandle : public ConnectorTableHandle { common::SubfieldFilters subfieldFilters, const core::TypedExprPtr& remainingFilter, const RowTypePtr& dataColumns = nullptr, - const std::unordered_map& tableParameters = {}); + const std::unordered_map& tableParameters = {}, + std::vector filterColumnHandles = {}, + double sampleRate = 1.0); const std::string& tableName() const { return tableName_; @@ -151,27 +181,53 @@ class HiveTableHandle : public ConnectorTableHandle { return tableName(); } - bool isFilterPushdownEnabled() const { + [[deprecated]] bool isFilterPushdownEnabled() const { return filterPushdownEnabled_; } + /// Single field filters that can be applied efficiently during file reading. const common::SubfieldFilters& subfieldFilters() const { return subfieldFilters_; } + /// Everything else that cannot be converted into subfield filters, but still + /// require the data source to filter out. This is usually less efficient + /// than subfield filters but supports arbitrary boolean expression. const core::TypedExprPtr& remainingFilter() const { return remainingFilter_; } - // Schema of the table. Need this for reading TEXTFILE. + /// Sampling rate between 0 and 1 (excluding 0). 0.1 means 10% + /// sampling. 1.0 means no sampling. + double sampleRate() const { + return sampleRate_; + } + + /// Subset of schema of the table that we store in file (i.e., + /// non-partitioning columns). This must be in the exact order as columns in + /// file (except trailing columns), but with the table schema during read + /// time. + /// + /// This is needed for multiple purposes, including reading TEXTFILE and + /// handling schema evolution. const RowTypePtr& dataColumns() const { return dataColumns_; } + /// Extra parameters to pass down to file format reader layer. Keys should be + /// in dwio::common::TableParameter. const std::unordered_map& tableParameters() const { return tableParameters_; } + /// Extra columns that are used in filters and remaining filters, but not in + /// the output. If there is overlap with data source assignments parameter, + /// the name and types should be the same (the required subfields are taken + /// from assignments). + const std::vector filterColumnHandles() const { + return filterColumnHandles_; + } + std::string toString() const override; folly::dynamic serialize() const override; @@ -187,8 +243,12 @@ class HiveTableHandle : public ConnectorTableHandle { const bool filterPushdownEnabled_; const common::SubfieldFilters subfieldFilters_; const core::TypedExprPtr remainingFilter_; + const double sampleRate_; const RowTypePtr dataColumns_; const std::unordered_map tableParameters_; + const std::vector filterColumnHandles_; }; +using HiveTableHandlePtr = std::shared_ptr; + } // namespace facebook::velox::connector::hive diff --git a/velox/connectors/hive/benchmarks/CMakeLists.txt b/velox/connectors/hive/benchmarks/CMakeLists.txt index 38c15bb9c057..aa82fc2edd0b 100644 --- a/velox/connectors/hive/benchmarks/CMakeLists.txt +++ b/velox/connectors/hive/benchmarks/CMakeLists.txt @@ -11,8 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -add_executable(velox_hive_partition_function_benchmark - HivePartitionFunctionBenchmark.cpp) +add_executable(velox_hive_partition_function_benchmark HivePartitionFunctionBenchmark.cpp) target_link_libraries( velox_hive_partition_function_benchmark @@ -31,4 +30,5 @@ target_link_libraries( velox_memory Folly::folly Folly::follybenchmark - fmt::fmt) + fmt::fmt +) diff --git a/velox/connectors/hive/benchmarks/HivePartitionFunctionBenchmark.cpp b/velox/connectors/hive/benchmarks/HivePartitionFunctionBenchmark.cpp index f503cb6e2e8f..b55c02d7c073 100644 --- a/velox/connectors/hive/benchmarks/HivePartitionFunctionBenchmark.cpp +++ b/velox/connectors/hive/benchmarks/HivePartitionFunctionBenchmark.cpp @@ -97,7 +97,7 @@ class HivePartitionFunctionBenchmark void run(HivePartitionFunction* function) { if (rowVectors_.find(KIND) == rowVectors_.end()) { throw std::runtime_error( - fmt::format("Unsupported type {}.", mapTypeKindToName(KIND))); + fmt::format("Unsupported type {}.", TypeKindName::toName(KIND))); } function->partition(*rowVectors_[KIND], partitions_); } diff --git a/velox/connectors/hive/iceberg/CMakeLists.txt b/velox/connectors/hive/iceberg/CMakeLists.txt index bc78005c91bb..fe20d16bdc44 100644 --- a/velox/connectors/hive/iceberg/CMakeLists.txt +++ b/velox/connectors/hive/iceberg/CMakeLists.txt @@ -12,11 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. -velox_add_library(velox_hive_iceberg_splitreader IcebergSplitReader.cpp - IcebergSplit.cpp PositionalDeleteFileReader.cpp) +velox_add_library( + velox_hive_iceberg_splitreader + IcebergConnector.cpp + IcebergDataSink.cpp + IcebergPartitionName.cpp + IcebergSplit.cpp + IcebergSplitReader.cpp + PartitionSpec.cpp + PositionalDeleteFileReader.cpp + TransformEvaluator.cpp + TransformExprBuilder.cpp +) -velox_link_libraries(velox_hive_iceberg_splitreader velox_connector - Folly::folly) +velox_link_libraries( + velox_hive_iceberg_splitreader + velox_connector + velox_functions_iceberg + Folly::folly +) if(${VELOX_BUILD_TESTING}) add_subdirectory(tests) diff --git a/velox/connectors/hive/iceberg/IcebergConnector.cpp b/velox/connectors/hive/iceberg/IcebergConnector.cpp new file mode 100644 index 000000000000..cec2422d36d0 --- /dev/null +++ b/velox/connectors/hive/iceberg/IcebergConnector.cpp @@ -0,0 +1,71 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/iceberg/IcebergConnector.h" + +#include "velox/connectors/hive/HiveConnector.h" +#include "velox/connectors/hive/iceberg/IcebergDataSink.h" + +namespace facebook::velox::connector::hive::iceberg { + +const std::string_view kIcebergFunctionPrefixConfig{"presto.iceberg-namespace"}; +const std::string_view kDefaultIcebergFunctionPrefix{"$internal$.iceberg."}; + +namespace { + +// Registers Iceberg partition transform functions with prefix. +// NOTE: These functions are registered for internal transform usage only. +// Upstream engines such as Prestissimo and Gluten should register the same +// functions with different prefixes to avoid conflicts. +void registerIcebergInternalFunctions(const std::string& prefix) { + static std::once_flag registerFlag; + + std::call_once(registerFlag, [prefix]() { + functions::iceberg::registerFunctions(prefix); + }); +} + +} // namespace + +IcebergConnector::IcebergConnector( + const std::string& id, + std::shared_ptr config, + folly::Executor* ioExecutor) + : HiveConnector(id, config, ioExecutor), + functionPrefix_(config->get( + std::string(kIcebergFunctionPrefixConfig), + std::string(kDefaultIcebergFunctionPrefix))) { + registerIcebergInternalFunctions(functionPrefix_); +} + +std::unique_ptr IcebergConnector::createDataSink( + RowTypePtr inputType, + ConnectorInsertTableHandlePtr connectorInsertTableHandle, + ConnectorQueryCtx* connectorQueryCtx, + CommitStrategy commitStrategy) { + auto icebergInsertHandle = checkedPointerCast( + connectorInsertTableHandle); + + return std::make_unique( + inputType, + icebergInsertHandle, + connectorQueryCtx, + commitStrategy, + hiveConfig_, + functionPrefix_); +} + +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/IcebergConnector.h b/velox/connectors/hive/iceberg/IcebergConnector.h new file mode 100644 index 000000000000..ffe0397fa640 --- /dev/null +++ b/velox/connectors/hive/iceberg/IcebergConnector.h @@ -0,0 +1,88 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "velox/connectors/hive/HiveConnector.h" + +namespace facebook::velox::connector::hive::iceberg { + +/// TODO Add IcebergConfig class and Move these configuration properties to +/// IcebergConfig.h +extern const std::string_view kIcebergFunctionPrefixConfig; +extern const std::string_view kDefaultIcebergFunctionPrefix; + +/// Provides Iceberg table format support. +/// - Creates HiveDataSource instances that use IcebergSplitReader for reading +/// Iceberg tables with support for delete files and schema evolution. +/// - Creates IcebergDataSink instances for writing data with Iceberg-specific +/// partition transforms and commit metadata. +class IcebergConnector final : public HiveConnector { + public: + IcebergConnector( + const std::string& id, + std::shared_ptr config, + folly::Executor* ioExecutor); + + /// Creates IcebergDataSink for writing to Iceberg tables. + /// + /// @param inputType The schema of the input data to write. + /// @param connectorInsertTableHandle Must be an IcebergInsertTableHandle + /// containing Iceberg-specific write configuration. + /// @param connectorQueryCtx Query context for the write operation. + /// @param commitStrategy Strategy for committing the write operation. Only + /// CommitStrategy::kNoCommit is supported for Iceberg tables. Files + /// are written directly with their final names and commit metadata is + /// returned for the coordinator to update the Iceberg metadata tables. + /// @return IcebergDataSink instance configured for the write operation. + std::unique_ptr createDataSink( + RowTypePtr inputType, + ConnectorInsertTableHandlePtr connectorInsertTableHandle, + ConnectorQueryCtx* connectorQueryCtx, + CommitStrategy commitStrategy) override; + + private: + const std::string functionPrefix_; +}; + +class IcebergConnectorFactory final : public ConnectorFactory { + public: + static constexpr const char* kIcebergConnectorName = "iceberg"; + + IcebergConnectorFactory() : ConnectorFactory(kIcebergConnectorName) {} + + /// Creates a new IcebergConnector instance. + /// + /// @param id Unique identifier for this connector instance (typically the + /// catalog name). + /// @param config Connector configuration properties + /// @param ioExecutor Optional executor for asynchronous I/O operations such + /// as split preloading and file prefetching. When provided, enables + /// background file operations off the main driver thread. If nullptr, I/O + /// operations run synchronously. + /// @param cpuExecutor ConnectorFactory interface to support other connector + /// types that may need CPU-bound async work. Currently unused by + /// IcebergConnector. + /// @return Shared pointer to the newly created IcebergConnector instance + std::shared_ptr newConnector( + const std::string& id, + std::shared_ptr config, + folly::Executor* ioExecutor = nullptr, + [[maybe_unused]] folly::Executor* cpuExecutor = nullptr) override { + return std::make_shared(id, config, ioExecutor); + } +}; + +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/IcebergDataSink.cpp b/velox/connectors/hive/iceberg/IcebergDataSink.cpp new file mode 100644 index 000000000000..7297174bc650 --- /dev/null +++ b/velox/connectors/hive/iceberg/IcebergDataSink.cpp @@ -0,0 +1,408 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/iceberg/IcebergDataSink.h" + +#include +#include +#include + +#include "velox/common/base/Fs.h" +#include "velox/connectors/hive/PartitionIdGenerator.h" +#include "velox/connectors/hive/iceberg/TransformExprBuilder.h" +#include "velox/exec/OperatorUtils.h" + +namespace facebook::velox::connector::hive::iceberg { + +namespace { + +template +folly::dynamic extractPartitionValue( + const VectorPtr& child, + vector_size_t row) { + using T = typename TypeTraits::NativeType; + return child->as>()->valueAt(row); +} + +template <> +folly::dynamic extractPartitionValue( + const VectorPtr& child, + vector_size_t row) { + return child->as>()->valueAt(row).str(); +} + +template <> +folly::dynamic extractPartitionValue( + const VectorPtr& child, + vector_size_t row) { + return child->as>()->valueAt(row).str(); +} + +template <> +folly::dynamic extractPartitionValue( + const VectorPtr& child, + vector_size_t row) { + return child->as>()->valueAt(row).toMicros(); +} + +class IcebergFileNameGenerator : public FileNameGenerator { + public: + std::pair gen( + std::optional bucketId, + const std::shared_ptr insertTableHandle, + const ConnectorQueryCtx& connectorQueryCtx, + bool commitRequired) const override; + + folly::dynamic serialize() const override; + + std::string toString() const override; +}; + +std::string makeUuid() { + return boost::lexical_cast(boost::uuids::random_generator()()); +} + +std::pair IcebergFileNameGenerator::gen( + std::optional bucketId, + const std::shared_ptr insertTableHandle, + const ConnectorQueryCtx& connectorQueryCtx, + bool commitRequired) const { + auto targetFileName = insertTableHandle->locationHandle()->targetFileName(); + if (targetFileName.empty()) { + targetFileName = fmt::format("{}", makeUuid()); + } + auto fileFormat = dwio::common::toString(insertTableHandle->storageFormat()); + auto fileName = fmt::format("{}.{}", targetFileName, fileFormat); + return {fileName, fileName}; +} + +folly::dynamic IcebergFileNameGenerator::serialize() const { + folly::dynamic obj = folly::dynamic::object; + obj["name"] = "IcebergFileNameGenerator"; + return obj; +} + +std::string IcebergFileNameGenerator::toString() const { + return "IcebergFileNameGenerator"; +} + +} // namespace + +IcebergInsertTableHandle::IcebergInsertTableHandle( + std::vector inputColumns, + LocationHandlePtr locationHandle, + dwio::common::FileFormat tableStorageFormat, + IcebergPartitionSpecPtr partitionSpec, + std::optional compressionKind, + const std::unordered_map& serdeParameters) + : HiveInsertTableHandle( + std::move(inputColumns), + std::move(locationHandle), + tableStorageFormat, + nullptr, + compressionKind, + serdeParameters, + nullptr, + false, + std::make_shared()), + partitionSpec_(partitionSpec) { + VELOX_USER_CHECK( + !inputColumns_.empty(), + "Input columns cannot be empty for Iceberg tables."); + VELOX_USER_CHECK_NOT_NULL( + locationHandle_, "Location handle is required for Iceberg tables."); + VELOX_USER_CHECK_EQ( + tableStorageFormat, + dwio::common::FileFormat::PARQUET, + "Only Parquet file format is supported when writing Iceberg tables."); +} + +namespace { + +// Creates partition channels by mapping partition spec fields to input column +// indices. For each field in the partition spec, finds the corresponding +// partition key column in the input columns and records its index. +// +// @param inputColumns The input columns from the insert table handle. +// @param partitionSpec The Iceberg partition specification, or nullptr if +// unpartitioned. +// @return A vector of column indices representing the partition channels. Each +// index corresponds to a partition field in the spec and points to the +// matching partition key column in the input. Returns an empty vector if +// partitionSpec is nullptr. +std::vector createPartitionChannels( + const std::vector& inputColumns, + const IcebergPartitionSpecPtr& partitionSpec) { + std::vector channels; + if (!partitionSpec) { + return channels; + } + + // Build a map from partition key column names to their indices in the input. + std::unordered_map partitionKeyMap; + for (auto i = 0; i < inputColumns.size(); ++i) { + if (inputColumns[i]->isPartitionKey()) { + partitionKeyMap[inputColumns[i]->name()] = i; + } + } + + // For each field in the partition spec, find its corresponding input column + // index. + channels.reserve(partitionSpec->fields.size()); + for (const auto& field : partitionSpec->fields) { + if (auto it = partitionKeyMap.find(field.name); + it != partitionKeyMap.end()) { + channels.push_back(it->second); + } + } + + return channels; +} + +std::vector createDataChannels( + const IcebergInsertTableHandlePtr& insertTableHandle) { + std::vector dataChannels( + insertTableHandle->inputColumns().size()); + std::iota(dataChannels.begin(), dataChannels.end(), 0); + return dataChannels; +} + +// Creates a RowType schema for transformed partition values based on the +// partition specification. This RowType is used to wrap the transformed +// partition columns before passing them to the partition ID generator. +// +// For each partition field in the spec: +// - The column type is the result type of the partition transform (e.g., +// INTEGER for year transform, DATE for day transform). +// - The column name is the source column name for identity transforms, or +// "columnName_transformName" for non-identity transforms (e.g., "birth_year" +// for a year transform on a birth column). +// +// @param partitionSpec The Iceberg partition specification, or nullptr if +// unpartitioned. +// @return A RowType containing one column per partition field with appropriate +// names and types. Returns nullptr if partitionSpec is nullptr. +RowTypePtr createPartitionRowType( + const IcebergPartitionSpecPtr& partitionSpec) { + if (!partitionSpec) { + return nullptr; + } + + std::vector partitionKeyTypes; + std::vector partitionKeyNames; + + // Build column names and types for each partition field. + // Identity transforms use the source column name directly. + // Non-identity transforms use "columnName_transformName" format. + for (const auto& field : partitionSpec->fields) { + partitionKeyTypes.emplace_back(field.resultType()); + std::string key = field.transformType == TransformType::kIdentity + ? field.name + : fmt::format( + "{}_{}", + field.name, + TransformTypeName::toName(field.transformType)); + partitionKeyNames.emplace_back(std::move(key)); + } + + return ROW(std::move(partitionKeyNames), std::move(partitionKeyTypes)); +} + +} // namespace + +IcebergDataSink::IcebergDataSink( + RowTypePtr inputType, + IcebergInsertTableHandlePtr insertTableHandle, + const ConnectorQueryCtx* connectorQueryCtx, + CommitStrategy commitStrategy, + const std::shared_ptr& hiveConfig, + const std::string& functionPrefix) + : IcebergDataSink( + std::move(inputType), + insertTableHandle, + connectorQueryCtx, + commitStrategy, + hiveConfig, + createPartitionChannels( + insertTableHandle->inputColumns(), + insertTableHandle->partitionSpec()), + createDataChannels(insertTableHandle), + createPartitionRowType(insertTableHandle->partitionSpec()), + functionPrefix) {} + +IcebergDataSink::IcebergDataSink( + RowTypePtr inputType, + IcebergInsertTableHandlePtr insertTableHandle, + const ConnectorQueryCtx* connectorQueryCtx, + CommitStrategy commitStrategy, + const std::shared_ptr& hiveConfig, + const std::vector& partitionChannels, + const std::vector& dataChannels, + RowTypePtr partitionRowType, + const std::string& functionPrefix) + : HiveDataSink( + inputType, + insertTableHandle, + connectorQueryCtx, + commitStrategy, + hiveConfig, + 0, + nullptr, + partitionChannels, + dataChannels, + !partitionChannels.empty() + ? std::make_unique( + partitionRowType, + [&partitionChannels]() { + std::vector transformedChannels( + partitionChannels.size()); + std::iota( + transformedChannels.begin(), + transformedChannels.end(), + 0); + return transformedChannels; + }(), + hiveConfig->maxPartitionsPerWriters( + connectorQueryCtx->sessionProperties()), + connectorQueryCtx->memoryPool()) + : nullptr), + partitionSpec_(insertTableHandle->partitionSpec()), + transformEvaluator_( + !partitionChannels.empty() ? std::make_unique( + TransformExprBuilder::toExpressions( + partitionSpec_, + partitionChannels_, + inputType_, + functionPrefix), + connectorQueryCtx_) + : nullptr), + icebergPartitionName_( + partitionSpec_ != nullptr + ? std::make_unique(partitionSpec_) + : nullptr), + partitionRowType_(std::move(partitionRowType)) { + commitPartitionValue_.resize(maxOpenWriters_); +} + +std::vector IcebergDataSink::commitMessage() const { + std::vector commitTasks; + commitTasks.reserve(writerInfo_.size()); + + auto icebergInsertTableHandle = + std::dynamic_pointer_cast( + insertTableHandle_); + + for (auto i = 0; i < writerInfo_.size(); ++i) { + const auto& info = writerInfo_.at(i); + VELOX_CHECK_NOT_NULL(info); + // Following metadata (json format) is consumed by Presto CommitTaskData. + // It contains the minimal subset of metadata. + // TODO: Complete metrics is missing now and this could lead to suboptimal + // query plan, will collect full iceberg metrics in following PR. + // clang-format off + folly::dynamic commitData = folly::dynamic::object( + "path", (fs::path(info->writerParameters.writeDirectory()) / + info->writerParameters.writeFileName()).string()) + ("fileSizeInBytes", ioStats_.at(i)->rawBytesWritten()) + ("metrics", + folly::dynamic::object("recordCount", info->numWrittenRows)) + ("partitionSpecJson", + icebergInsertTableHandle->partitionSpec() ? icebergInsertTableHandle->partitionSpec()->specId : 0) + ("fileFormat", "PARQUET") + ("content", "DATA"); + // clang-format on + if (!commitPartitionValue_.empty() && !commitPartitionValue_[i].isNull()) { + commitData["partitionDataJson"] = folly::toJson( + folly::dynamic::object("partitionValues", commitPartitionValue_[i])); + } + auto commitDataJson = folly::toJson(commitData); + commitTasks.push_back(commitDataJson); + } + return commitTasks; +} + +void IcebergDataSink::computePartitionAndBucketIds(const RowVectorPtr& input) { + VELOX_CHECK(isPartitioned()); + VELOX_CHECK_NOT_NULL(transformEvaluator_); + VELOX_CHECK_NOT_NULL(partitionIdGenerator_); + // Step 1: Apply transforms to input partition columns. + auto transformedColumns = transformEvaluator_->evaluate(input); + + // Step 2: Create RowVector based on transformed columns. + const auto& transformedRowVector = std::make_shared( + connectorQueryCtx_->memoryPool(), + partitionRowType_, + nullptr, + input->size(), + std::move(transformedColumns)); + partitionIdGenerator_->run(transformedRowVector, partitionIds_); +} + +std::string IcebergDataSink::getPartitionName(uint32_t partitionId) const { + VELOX_CHECK_NOT_NULL(icebergPartitionName_); + + return icebergPartitionName_->partitionName( + partitionId, + partitionIdGenerator_->partitionValues(), + partitionKeyAsLowerCase_); +} + +uint32_t IcebergDataSink::ensureWriter(const HiveWriterId& id) { + auto writerId = HiveDataSink::ensureWriter(id); + if (commitPartitionValue_[writerId].isNull()) { + commitPartitionValue_[writerId] = makeCommitPartitionValue(writerId); + } + return writerId; +} + +std::shared_ptr +IcebergDataSink::createWriterOptions() const { + auto options = HiveDataSink::createWriterOptions(); + // Per Iceberg specification (https://iceberg.apache.org/spec/#parquet): + // - Timestamps must be stored with microsecond precision. + // - Timestamps must NOT be adjusted to UTC timezone; they should be written + // as-is without timezone conversion (empty string disables conversion). + // + // These settings are passed via serdeParameters to avoid including + // parquet-specific headers. The keys must match kParquetSerdeTimestampUnit + // and kParquetSerdeTimestampTimezone defined in + // velox/dwio/parquet/writer/Writer.h. The value "6" represents microseconds + // (TimestampPrecision::kMicroseconds). + options->serdeParameters["parquet.writer.timestamp.unit"] = "6"; + options->serdeParameters["parquet.writer.timestamp.timezone"] = ""; + // Re-process configs to apply the serde parameters we just set. + options->processConfigs( + *hiveConfig_->config(), *connectorQueryCtx_->sessionProperties()); + return options; +} + +folly::dynamic IcebergDataSink::makeCommitPartitionValue( + uint32_t writerIndex) const { + folly::dynamic partitionValues = folly::dynamic::array(); + const auto& transformedValues = partitionIdGenerator_->partitionValues(); + for (auto i = 0; i < partitionChannels_.size(); ++i) { + const auto& child = transformedValues->childAt(i); + if (child->isNullAt(writerIndex)) { + partitionValues.push_back(nullptr); + } else { + partitionValues.push_back(VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH( + extractPartitionValue, child->typeKind(), child, writerIndex)); + } + } + return partitionValues; +} + +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/IcebergDataSink.h b/velox/connectors/hive/iceberg/IcebergDataSink.h new file mode 100644 index 000000000000..022bd0d6b23a --- /dev/null +++ b/velox/connectors/hive/iceberg/IcebergDataSink.h @@ -0,0 +1,224 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "velox/connectors/hive/HiveDataSink.h" +#include "velox/connectors/hive/iceberg/IcebergPartitionName.h" +#include "velox/connectors/hive/iceberg/PartitionSpec.h" +#include "velox/connectors/hive/iceberg/TransformEvaluator.h" +#include "velox/functions/iceberg/Register.h" + +namespace facebook::velox::connector::hive::iceberg { + +/// Represents a request for Iceberg write. +class IcebergInsertTableHandle final : public HiveInsertTableHandle { + public: + /// @param inputColumns Columns from the table schema to write. + /// The input RowVector must have the same number of columns and matching + /// types in the same order. + /// Column names in the RowVector may differ from those in inputColumns, + /// only position and type must align. All columns present in the input + /// data must be included, mismatches can lead to write failure. + /// @param locationHandle Contains the target location information including: + /// - Base directory path where data files will be written. + /// - File naming scheme and temporary directory paths. + /// @param tableStorageFormat File format to use for writing data files. + /// @param partitionSpec Optional partition specification defining how to + /// partition the data. If nullptr, the table is unpartitioned and all data + /// is written to a single directory. + /// @param compressionKind Optional compression to apply to data files. + /// @param serdeParameters Additional serialization/deserialization parameters + /// for the file format. + IcebergInsertTableHandle( + std::vector inputColumns, + LocationHandlePtr locationHandle, + dwio::common::FileFormat tableStorageFormat, + IcebergPartitionSpecPtr partitionSpec, + std::optional compressionKind = {}, + const std::unordered_map& serdeParameters = {}); + +#ifdef VELOX_ENABLE_BACKWARD_COMPATIBILITY + IcebergInsertTableHandle( + std::vector inputColumns, + LocationHandlePtr locationHandle, + dwio::common::FileFormat tableStorageFormat, + std::optional compressionKind = {}, + const std::unordered_map& serdeParameters = {}) + : IcebergInsertTableHandle( + inputColumns, + locationHandle, + tableStorageFormat, + nullptr, + compressionKind, + serdeParameters) {} +#endif + + /// Returns the Iceberg partition specification that defines how the table + /// is partitioned. + const IcebergPartitionSpecPtr& partitionSpec() const { + return partitionSpec_; + } + + private: + const IcebergPartitionSpecPtr partitionSpec_; +}; + +using IcebergInsertTableHandlePtr = + std::shared_ptr; + +class IcebergDataSink : public HiveDataSink { + public: + IcebergDataSink( + RowTypePtr inputType, + IcebergInsertTableHandlePtr insertTableHandle, + const ConnectorQueryCtx* connectorQueryCtx, + CommitStrategy commitStrategy, + const std::shared_ptr& hiveConfig, + const std::string& functionPrefix); + + /// Generates Iceberg-specific commit messages for all writers containing + /// metadata about written files. Creates a JSON object for each writer + /// in the format expected by Presto and Spark for Iceberg tables. + /// + /// Each commit message contains: + /// - path: full file path where data was written. + /// - fileSizeInBytes: raw bytes written to disk. + /// - metrics: object with recordCount (number of rows written). + /// - partitionSpecJson: partition specification. + /// - fileFormat: storage format (e.g., "PARQUET"). + /// - content: file content type ("DATA" for data files). + /// + /// See + /// https://github.com/prestodb/presto/blob/master/presto-iceberg/src/main/java/com/facebook/presto/iceberg/CommitTaskData.java + /// + /// Note: Complete Iceberg metrics are not yet implemented, which results in + /// incomplete manifest files that may lead to suboptimal query planning. + /// + /// @return Vector of JSON strings, one per writer, formatted according to + /// Presto and Spark Iceberg commit protocol. + std::vector commitMessage() const override; + + private: + IcebergDataSink( + RowTypePtr inputType, + IcebergInsertTableHandlePtr insertTableHandle, + const ConnectorQueryCtx* connectorQueryCtx, + CommitStrategy commitStrategy, + const std::shared_ptr& hiveConfig, + const std::vector& partitionChannels, + const std::vector& dataChannels, + RowTypePtr partitionRowType, + const std::string& functionPrefix); + + // Computes partition IDs for each row in the input batch by applying Iceberg + // partition transforms and generating unique partition identifiers. + // + // Performs a two-step process: + // 1. Applies Iceberg partition transforms (e.g., year, month, day, hour, + // bucket, truncate) to the input partition columns using + // transformEvaluator_ to produce transformed partition values. + // 2. Wraps the transformed columns in a RowVector with partitionRowType_ + // schema and passes it to partitionIdGenerator_ to compute partition IDs. + // + // The resulting partition IDs are stored in partitionIds_ buffer, where each + // element corresponds to a row in the input. These IDs are used to: + // - Route rows to the appropriate writer (one writer per unique partition). + // - Generate partition directory names via getPartitionName(). + // + // Note: Iceberg does not support bucketing, so this method only computes + // partition IDs, not bucket IDs. + // + // @param input The input RowVector containing rows to be partitioned. + void computePartitionAndBucketIds(const RowVectorPtr& input) override; + + // Returns the Iceberg partition directory name for the given partition ID. + // Converts the transformed partition values associated with the partition ID + // into an Iceberg compliant directory path + // (e.g., "date_year=2023/id_bucket=5"). + std::string getPartitionName(uint32_t partitionId) const override; + + // Ensures a writer exists for the given writer ID and returns its index. + // If the writer doesn't exist, creates it by calling appendWriter(). + // Additionally, extracts and stores the transformed partition values for + // the writer in commitPartitionValue_ if not already set, which will be + // included in the commit message as "partitionDataJson". + uint32_t ensureWriter(const HiveWriterId& id) override; + + // Creates writer options configured for Iceberg table writes. Extends the + // base HiveDataSink writer options with Iceberg-specific settings: + // - Sets timestamp timezone to nullopt (UTC) for Iceberg compliance. + // - Sets timestamp precision to microseconds. + std::shared_ptr createWriterOptions() + const override; + + // Extracts partition values for a specific writer to be included in the + // commit message. Converts the transformed partition values from columnar + // storage (partitionIdGenerator_->partitionValues() where each partition + // field is a separate column) to row storage (a folly::dynamic array of + // values for the given writer index) for JSON serialization. + // Returns nullptr for null partition values. + folly::dynamic makeCommitPartitionValue(uint32_t writerIndex) const; + + // Iceberg partition specification defining how the table is partitioned. + // Contains partition fields with source column names, transform types + // (e.g., identity, year, month, day, hour, bucket, truncate), transform + // parameters, and result types. Null if the table is unpartitioned. + const IcebergPartitionSpecPtr partitionSpec_; + + // Evaluates Iceberg partition transforms on input rows to produce transformed + // partition keys. Applies transforms defined in partitionSpec_ (e.g., + // year(date_col), bucket(id, 16)) to the corresponding input columns and + // returns a vector of transformed columns. The transformed keys are then + // wrapped in a RowVector and passed to IcebergPartitionIdGenerator. + // Null if the table is unpartitioned. + const std::unique_ptr transformEvaluator_; + + // Generates Iceberg compliant partition directory names from partition IDs. + // Converts transformed partition values to human-readable strings based on + // their transform types (e.g., year -> "2025", month -> "2025-11", hour -> + // "2025-11-12-13") and constructs URL-encoded partition paths. + // Null if the table is unpartitioned. + const std::unique_ptr icebergPartitionName_; + + // RowType schema for the transformed partition values RowVector. + // Contains one column per partition field in partitionSpec, where each + // column has: + // - Type: The result type of the partition transform (e.g., INTEGER for year + // transform, DATE for day transform). + // - Name: Source column name for identity transforms, or + // "columnName_transformName" for non-identity transforms (e.g., + // "date_year"). + // Used to construct the RowVector that wraps the transformed partition + // columns before passing them to IcebergPartitionIdGenerator for partition ID + // generation and to IcebergPartitionNameGenerator for partition path name + // generation. + RowTypePtr partitionRowType_; + + // Stores the transformed partition values for each writer to be included in + // the commit message sent to Presto. Indexed by writer index. Each entry + // contains the transformed partition values (as a folly::dynamic array) for + // that writer's partition, which are serialized to JSON as + // "partitionDataJson" in the commit protocol. These values represent the same + // transformed partition data as partitionIdGenerator_->partitionValues(), but + // converted from columnar storage (where each partition field is a separate + // column in the RowVector) to row storage (where each writer has a + // folly::dynamic array of values across all partition fields), ready for JSON + // serialization. + std::vector commitPartitionValue_; +}; + +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/IcebergPartitionName.cpp b/velox/connectors/hive/iceberg/IcebergPartitionName.cpp new file mode 100644 index 000000000000..97a0f565b8b7 --- /dev/null +++ b/velox/connectors/hive/iceberg/IcebergPartitionName.cpp @@ -0,0 +1,137 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/connectors/hive/iceberg/IcebergPartitionName.h" +#include "velox/common/encode/Base64.h" +#include "velox/dwio/catalog/fbhive/FileUtils.h" +#include "velox/functions/prestosql/URLFunctions.h" + +namespace facebook::velox::connector::hive::iceberg { + +namespace { + +std::string escapePathName(const std::string& name) { + std::string encoded; + // Pre-allocate for worst case: every byte is invalid UTF-8. + // urlEscape() writes directly into the pre-allocated buffer and + // calls resize() at the end to shrink to the actual size used. + encoded.resize(name.size() * 9); + functions::detail::urlEscape(encoded, name); + return encoded; +} + +} // namespace + +IcebergPartitionName::IcebergPartitionName( + const IcebergPartitionSpecPtr& partitionSpec) { + VELOX_CHECK_NOT_NULL(partitionSpec); + transformTypes_.reserve(partitionSpec->fields.size()); + for (const auto& field : partitionSpec->fields) { + transformTypes_.emplace_back(field.transformType); + } +} + +std::string IcebergPartitionName::partitionName( + uint32_t partitionId, + const RowVectorPtr& partitionValues, + bool partitionKeyAsLowerCase) const { + auto toPartitionName = [this]( + auto value, const TypePtr& type, int columnIndex) { + return IcebergPartitionName::toName( + value, type, transformTypes_[columnIndex]); + }; + + return dwio::catalog::fbhive::FileUtils::makePartName( + HivePartitionName::partitionKeyValues( + partitionId, + partitionValues, + /*nullValueString=*/"null", + toPartitionName), + partitionKeyAsLowerCase, + /*useDefaultPartitionValue=*/false, + escapePathName); +} + +std::string IcebergPartitionName::toName( + int32_t value, + const TypePtr& type, + TransformType transformType) { + constexpr int32_t kEpochYear = 1970; + switch (transformType) { + case TransformType::kIdentity: { + if (type->isDate()) { + return DateType::toIso8601(value); + } + return fmt::to_string(value); + } + case TransformType::kDay: + return DATE()->toString(value); + case TransformType::kYear: + return fmt::format("{:04d}", kEpochYear + value); + case TransformType::kMonth: { + int32_t year = kEpochYear + value / 12; + int32_t month = 1 + value % 12; + if (month <= 0) { + month += 12; + year -= 1; + } + return fmt::format("{:04d}-{:02d}", year, month); + } + case TransformType::kHour: { + int64_t seconds = static_cast(value) * 3600; + std::tm tmValue; + VELOX_USER_CHECK( + Timestamp::epochToCalendarUtc(seconds, tmValue), + "Failed to convert seconds to time: {}", + seconds); + return fmt::format( + "{:04d}-{:02d}-{:02d}-{:02d}", + tmValue.tm_year + 1900, + tmValue.tm_mon + 1, + tmValue.tm_mday, + tmValue.tm_hour); + } + default: + return fmt::to_string(value); + } +} + +std::string IcebergPartitionName::toName( + Timestamp value, + const TypePtr& type, + TransformType transformType) { + VELOX_CHECK(transformType == TransformType::kIdentity); + TimestampToStringOptions options; + options.precision = TimestampPrecision::kMilliseconds; + options.zeroPaddingYear = true; + options.skipTrailingZeros = true; + options.leadingPositiveSign = true; + return value.toString(options); +} + +std::string IcebergPartitionName::toName( + StringView value, + const TypePtr& type, + TransformType transformType) { + VELOX_CHECK( + transformType == TransformType::kIdentity || + transformType == TransformType::kTruncate); + if (type->isVarbinary()) { + return encoding::Base64::encode(value.data(), value.size()); + } + return std::string(value); +} + +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/IcebergPartitionName.h b/velox/connectors/hive/iceberg/IcebergPartitionName.h new file mode 100644 index 000000000000..751f3620c80d --- /dev/null +++ b/velox/connectors/hive/iceberg/IcebergPartitionName.h @@ -0,0 +1,114 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "velox/connectors/hive/HivePartitionName.h" +#include "velox/connectors/hive/iceberg/PartitionSpec.h" + +namespace facebook::velox::connector::hive::iceberg { + +/// Generates Iceberg-compliant partition path names. +/// Converts partition keys to human-readable strings based on their transform +/// types (e.g., year, month, day, hour, identity, truncate) and constructs +/// URL-encoded partition paths in the format "key1=value1/key2=value2/...". +class IcebergPartitionName { + public: + /// @param partitionSpec Iceberg partition specification containing transform + /// definitions for each partition field. Used to get transform type and call + /// different format functions to convert transformed partition values to + /// human-readable strings. + IcebergPartitionName(const IcebergPartitionSpecPtr& partitionSpec); + + /// Generates an Iceberg compliant partition path string for the given + /// partition ID. + /// + /// Constructs a partition path in the format "key1=value1/key2=value2/..." + /// where: + /// - Keys are partition column names for identity transforms, or + /// "columnName_transformName" for non-identity transforms (e.g., + /// "date_year") + /// - Values are human-readable string representations of the transformed + /// partition keys, formatted according to their transform types + /// - Both keys and values are URL-encoded per java.net.URLEncoder.encode() + /// + /// Example: "store_id=123/date_year=2025/address_bucket=1" + /// + /// Typically called once per partition ID when creating a new writer for that + /// partition. + /// + /// @param partitionId Sequential partition ID (0-based) used as the row index + /// into partitionValues. Must be less than partitionValues->size(). + /// @param partitionValues RowVector containing transformed partition keys + /// for all partitions. Each row represents one unique partition, with + /// columns corresponding to partition fields in partitionSpec. Row at + /// partitionId contains the keys for this specific partition. + /// @param partitionKeyAsLowerCase Whether to convert partition keys to + /// lowercase in the generated partition path. When true, partition keys like + /// "Year" become "year" in the path "year=2025/...". + /// @return URL-encoded partition path string suitable for use in file paths. + std::string partitionName( + uint32_t partitionId, + const RowVectorPtr& partitionValues, + bool partitionKeyAsLowerCase) const; + + /// Generic template for formatting simple types that just need string + /// conversion. Specialized for types that need special handling. + template + FOLLY_ALWAYS_INLINE static std::string + toName(T value, const TypePtr& type, TransformType transformType) { + return HivePartitionName::toName(value, type); + } + + /// Converts an int32_t partition key to its string representation based on + /// the transform type: + /// - kIdentity: For DATE type return "YYYY-MM-DD" format (e.g., + /// "2025-11-07"). + /// For other types return the value as-is (e.g., "-123"). + /// - kDay: Returns date in "YYYY-MM-DD" format (e.g., "2025-11-07"). + /// - kYear: Returns 4-digit year "YYYY" (e.g., "2025"). + /// - kMonth: Returns "YYYY-MM" format (e.g., "2025-01"). + /// - kHour: Returns "YYYY-MM-DD-HH" format (e.g., "2025-11-07-21"). + static std::string + toName(int32_t value, const TypePtr& type, TransformType transformType); + + /// Returns timestamp formatted with milliseconds precision, zero-padded + /// year, trailing zeros skipped, and leading positive sign for years >= + /// 10000. Examples: + /// - Timestamp(0, 0) -> "1970-01-01T00:00:00". + /// - Timestamp(1609459200, 999000000) -> "2021-01-01T00:00:00.999". + /// - Timestamp(1640995200, 500000000) -> "2022-01-01T00:00:00.5". + /// - Timestamp(-1, 999000000) -> "1969-12-31T23:59:59.999". + /// - Timestamp(253402300800, 100000000) -> "+10000-01-01T00:00:00.1". + static std::string + toName(Timestamp value, const TypePtr& type, TransformType transformType); + + /// Converts a StringView partition key to its string representation. + /// - For VARBINARY type returns Base64-encoded string. + /// - For VARCHAR type returns the string value as-is. + static std::string + toName(StringView value, const TypePtr& type, TransformType transformType); + + private: + // Cached transform types, one per partition column. Created once in + // constructor and reused for all formatting operations. Index corresponds to + // column index in partitionSpec_->fields. + std::vector transformTypes_; +}; + +using IcebergPartitionNamePtr = std::shared_ptr; + +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/IcebergSplitReader.cpp b/velox/connectors/hive/iceberg/IcebergSplitReader.cpp index 84c85d3dce0b..76b4c1fbacf1 100644 --- a/velox/connectors/hive/iceberg/IcebergSplitReader.cpp +++ b/velox/connectors/hive/iceberg/IcebergSplitReader.cpp @@ -26,9 +26,8 @@ namespace facebook::velox::connector::hive::iceberg { IcebergSplitReader::IcebergSplitReader( const std::shared_ptr& hiveSplit, - const std::shared_ptr& hiveTableHandle, - const std::unordered_map>* - partitionKeys, + const HiveTableHandlePtr& hiveTableHandle, + const std::unordered_map* partitionKeys, const ConnectorQueryCtx* connectorQueryCtx, const std::shared_ptr& hiveConfig, const RowTypePtr& readerOutputType, @@ -55,7 +54,8 @@ IcebergSplitReader::IcebergSplitReader( void IcebergSplitReader::prepareSplit( std::shared_ptr metadataFilter, - dwio::common::RuntimeStatistics& runtimeStats) { + dwio::common::RuntimeStatistics& runtimeStats, + const folly::F14FastMap& fileReadOps) { createReader(); if (emptySplit_) { return; @@ -67,7 +67,7 @@ void IcebergSplitReader::prepareSplit( return; } - createRowReader(std::move(metadataFilter), std::move(rowType)); + createRowReader(std::move(metadataFilter), std::move(rowType), std::nullopt); std::shared_ptr icebergSplit = std::dynamic_pointer_cast(hiveSplit_); @@ -85,7 +85,7 @@ void IcebergSplitReader::prepareSplit( hiveSplit_->filePath, fileHandleFactory_, connectorQueryCtx_, - executor_, + ioExecutor_, hiveConfig_, ioStats_, fsStats_, @@ -110,7 +110,7 @@ uint64_t IcebergSplitReader::next(uint64_t size, VectorPtr& output) { } const auto actualSize = baseRowReader_->nextReadSize(size); - + baseReadOffset_ = baseRowReader_->nextRowNumber() - splitOffset_; if (actualSize == dwio::common::RowReader::kAtEnd) { return 0; } @@ -137,9 +137,61 @@ uint64_t IcebergSplitReader::next(uint64_t size, VectorPtr& output) { : nullptr; auto rowsScanned = baseRowReader_->next(actualSize, output, &mutation); - baseReadOffset_ += rowsScanned; return rowsScanned; } +std::vector IcebergSplitReader::adaptColumns( + const RowTypePtr& fileType, + const RowTypePtr& tableSchema) const { + std::vector columnTypes = fileType->children(); + auto& childrenSpecs = scanSpec_->children(); + // Iceberg table stores all column's data in data file. + for (const auto& childSpec : childrenSpecs) { + const std::string& fieldName = childSpec->fieldName(); + if (auto iter = hiveSplit_->infoColumns.find(fieldName); + iter != hiveSplit_->infoColumns.end()) { + auto infoColumnType = readerOutputType_->findChild(fieldName); + auto constant = newConstantFromString( + infoColumnType, + iter->second, + connectorQueryCtx_->memoryPool(), + hiveConfig_->readTimestampPartitionValueAsLocalTime( + connectorQueryCtx_->sessionProperties()), + false); + childSpec->setConstantValue(constant); + } else { + auto fileTypeIdx = fileType->getChildIdxIfExists(fieldName); + auto outputTypeIdx = readerOutputType_->getChildIdxIfExists(fieldName); + if (outputTypeIdx.has_value() && fileTypeIdx.has_value()) { + childSpec->setConstantValue(nullptr); + auto& outputType = readerOutputType_->childAt(*outputTypeIdx); + columnTypes[*fileTypeIdx] = outputType; + } else if (!fileTypeIdx.has_value()) { + // Handle columns missing from the data file in two scenarios: + // 1. Schema evolution: Column was added after the data file was + // written and doesn't exist in older data files. + // 2. Partition columns: Hive migrated table. In Hive-written data + // files, partition column values are stored in partition metadata + // rather than in the data file itself, following Hive's partitioning + // convention. + if (auto it = hiveSplit_->partitionKeys.find(fieldName); + it != hiveSplit_->partitionKeys.end()) { + setPartitionValue(childSpec.get(), fieldName, it->second); + } else { + childSpec->setConstantValue( + BaseVector::createNullConstant( + tableSchema->findChild(fieldName), + 1, + connectorQueryCtx_->memoryPool())); + } + } + } + } + + scanSpec_->resetCachedValues(false); + + return columnTypes; +} + } // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/IcebergSplitReader.h b/velox/connectors/hive/iceberg/IcebergSplitReader.h index 795912159b96..1b56aa077433 100644 --- a/velox/connectors/hive/iceberg/IcebergSplitReader.h +++ b/velox/connectors/hive/iceberg/IcebergSplitReader.h @@ -28,9 +28,8 @@ class IcebergSplitReader : public SplitReader { public: IcebergSplitReader( const std::shared_ptr& hiveSplit, - const std::shared_ptr& hiveTableHandle, - const std::unordered_map>* - partitionKeys, + const HiveTableHandlePtr& hiveTableHandle, + const std::unordered_map* partitionKeys, const ConnectorQueryCtx* connectorQueryCtx, const std::shared_ptr& hiveConfig, const RowTypePtr& readerOutputType, @@ -44,11 +43,56 @@ class IcebergSplitReader : public SplitReader { void prepareSplit( std::shared_ptr metadataFilter, - dwio::common::RuntimeStatistics& runtimeStats) override; + dwio::common::RuntimeStatistics& runtimeStats, + const folly::F14FastMap& fileReadOps = {}) + override; uint64_t next(uint64_t size, VectorPtr& output) override; private: + /// Adapts the data file schema to match the table schema expected by the + /// query. + /// + /// This method reconciles differences between the physical data file schema + /// and the logical table schema, handling various scenarios where columns may + /// be missing, added, or need special treatment. + /// + /// @param fileType The schema read from the data file's metadata. This + /// represents the actual columns physically present in the Parquet/ORC file. + /// @param tableSchema The logical schema defined in the catalog (e.g., from + /// DDL). This represents the current table schema that queries expect. + /// + /// @return A vector of column types adapted to match the query's + /// expectations, with appropriate type conversions and constant values set + /// for missing or special columns. + /// + /// The method handles the following scenarios for each column in the scan + /// spec: + /// + /// 1. Info columns (e.g., $path, $data_sequence_number, $deleted) + /// These are virtual columns that provide metadata about the file itself. + /// Values are read from hiveSplit_->infoColumns map and set as constant + /// values in the scanSpec so they're materialized for all rows. + /// + /// 2. Regular columns present in File: + /// Column exists in both fileType and readerOutputType. Type is adapted + /// from fileType to match the expected output type, handling schema + /// evolution where column types may have changed. + /// + /// 3. Columns missing from File: + /// a) Partition columns (hive-migrated tables): + /// Column is marked as partition key in hiveSplit_->partitionKeys. + /// In Hive-written Iceberg tables, partition column values are stored + /// in partition metadata, not in the data file itself. Value is read + /// from partition metadata and set as a constant. + /// b) Schema evolution (newly added columns): + /// Column was added to the table schema after this data file was + /// written. Set as NULL constant since the old file doesn't contain + /// this column. + std::vector adaptColumns( + const RowTypePtr& fileType, + const RowTypePtr& tableSchema) const override; + // The read offset to the beginning of the split in number of rows for the // current batch for the base data file uint64_t baseReadOffset_; diff --git a/velox/connectors/hive/iceberg/PartitionSpec.cpp b/velox/connectors/hive/iceberg/PartitionSpec.cpp new file mode 100644 index 000000000000..4c9fae472ed0 --- /dev/null +++ b/velox/connectors/hive/iceberg/PartitionSpec.cpp @@ -0,0 +1,157 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/iceberg/PartitionSpec.h" + +#include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h" + +namespace facebook::velox::connector::hive::iceberg { + +namespace { + +TransformCategory getTransformCategory(TransformType transformType) { + switch (transformType) { + case TransformType::kIdentity: + return TransformCategory::kIdentity; + case TransformType::kYear: + case TransformType::kMonth: + case TransformType::kDay: + case TransformType::kHour: + return TransformCategory::kTemporal; + case TransformType::kBucket: + return TransformCategory::kBucket; + case TransformType::kTruncate: + return TransformCategory::kTruncate; + default: + VELOX_UNREACHABLE("Unknown transform type"); + } +} + +bool isValidPartitionType(const TypePtr& type) { + return !( + type->isRow() || type->isArray() || type->isMap() || type->isDouble() || + type->isReal() || isTimestampWithTimeZoneType(type)); +} + +bool canTransform(TransformType transformType, const TypePtr& type) { + switch (transformType) { + case TransformType::kIdentity: + return type->isTinyint() || type->isSmallint() || type->isInteger() || + type->isBigint() || type->isBoolean() || type->isDecimal() || + type->isDate() || type->isTimestamp() || type->isVarchar() || + type->isVarbinary(); + case TransformType::kYear: + case TransformType::kMonth: + case TransformType::kDay: + return type->isDate() || type->isTimestamp(); + case TransformType::kHour: + return type->isTimestamp(); + case TransformType::kBucket: + return type->isInteger() || type->isBigint() || type->isDecimal() || + type->isVarchar() || type->isVarbinary() || type->isDate() || + type->isTimestamp(); + case TransformType::kTruncate: + return type->isInteger() || type->isBigint() || type->isDecimal() || + type->isVarchar() || type->isVarbinary(); + default: + VELOX_UNREACHABLE("Unsupported partition transform type."); + } +} + +const auto& transformTypeNames() { + static const folly::F14FastMap + kTransformNames = { + {TransformType::kIdentity, "identity"}, + {TransformType::kHour, "hour"}, + {TransformType::kDay, "day"}, + {TransformType::kMonth, "month"}, + {TransformType::kYear, "year"}, + {TransformType::kBucket, "bucket"}, + {TransformType::kTruncate, "trunc"}, + }; + return kTransformNames; +} + +const auto& transformCategoryNames() { + static const folly::F14FastMap + kTransformCategoryNames = { + {TransformCategory::kIdentity, "Identity"}, + {TransformCategory::kBucket, "Bucket"}, + {TransformCategory::kTruncate, "Truncate"}, + {TransformCategory::kTemporal, "Temporal"}, + }; + return kTransformCategoryNames; +} + +} // namespace + +VELOX_DEFINE_ENUM_NAME(TransformType, transformTypeNames); + +VELOX_DEFINE_ENUM_NAME(TransformCategory, transformCategoryNames); + +void IcebergPartitionSpec::checkCompatibility() const { + folly::F14FastMap> + columnTransforms; + + for (const auto& field : fields) { + const auto& type = field.type; + const auto& name = field.name; + VELOX_USER_CHECK( + isValidPartitionType(type), + "Type is not supported as a partition column: {}", + type->name()); + + VELOX_USER_CHECK( + canTransform(field.transformType, type), + "Transform is not supported for partition column. Column: '{}', Type: '{}', Transform: '{}'.", + name, + type->name(), + TransformTypeName::toName(field.transformType)); + + columnTransforms[name].emplace_back(field.transformType); + } + + // Check for duplicate transform categories per column. + std::vector errors; + for (const auto& [columnName, transforms] : columnTransforms) { + folly::F14FastSet seenCategories; + for (const auto& transform : transforms) { + auto category = getTransformCategory(transform); + if (!seenCategories.insert(category).second) { + std::vector transformNames; + for (const auto& t : transforms) { + transformNames.emplace_back( + std::string(TransformTypeName::toName(t))); + } + errors.emplace_back( + fmt::format( + "Column: '{}', Category: {}, Transforms: [{}]", + columnName, + TransformCategoryName::toName(category), + folly::join(", ", transformNames))); + break; + } + } + } + + VELOX_USER_CHECK( + errors.empty(), + "Multiple transforms of the same category on a column are not allowed. " + "Each transform category can appear at most once per column. {}", + folly::join("; ", errors)); +} + +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/PartitionSpec.h b/velox/connectors/hive/iceberg/PartitionSpec.h new file mode 100644 index 000000000000..99297da0edc3 --- /dev/null +++ b/velox/connectors/hive/iceberg/PartitionSpec.h @@ -0,0 +1,146 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "velox/type/Type.h" + +namespace facebook::velox::connector::hive::iceberg { + +/// Partition transform types. +/// Defines how source column values are converted into partition keys. +/// See https://iceberg.apache.org/spec/#partition-transforms. +enum class TransformType { + /// Use the source value as-is (no transformation). + kIdentity, + /// Extract a timestamp hour, as hours from 1970-01-01 00:00:00. + kHour, + /// Extract a date or timestamp day, as days from 1970-01-01. + kDay, + /// Extract a date or timestamp month, as months from 1970-01. + kMonth, + /// Extract a date or timestamp year, as years from 1970. + kYear, + /// Hash the value into N buckets for even distribution. Requires an integer + /// parameter specifying the bucket count. + kBucket, + /// Truncate strings or numbers to a specified width. Requires an integer + /// parameter specifying the truncate width. + kTruncate +}; + +VELOX_DECLARE_ENUM_NAME(TransformType); + +/// A single column can be used to produce multiple partition keys, but with +/// following restrictions: +/// - Transforms are organized into 4 categories: Identity, Temporal, +/// Bucket, and Truncate. +/// - Each category can appear at most once per column. +/// - Sample valid specs on same column: ['truncate(a,2)', 'bucket(a,16)', 'a'] +/// or ['year(b)', 'bucket(b, 16)', 'b'] +enum class TransformCategory { + kIdentity, + /// Year/Month/Day/Hour + kTemporal, + kBucket, + kTruncate, +}; + +VELOX_DECLARE_ENUM_NAME(TransformCategory); + +/// Represents how to produce partition data for an Iceberg table. +/// +/// This structure corresponds to the Iceberg Java PartitionSpec class but +/// contains only the necessary fields for Velox. Partition keys are computed +/// by transforming columns in a table. +/// +/// The upstream engine processes this specification through the Iceberg Java +/// library to validate column types, detect duplicates, and generate the +/// partition spec that is passed to Velox. +/// +/// IMPORTANT: Iceberg spec uses field IDs to identify source columns, but +/// Velox RowType only supports matching fields by name. Therefore, Velox uses +/// the partition field name to match against the table schema column names. +/// Callers must ensure that partition field names exactly match the column +/// names in the table schema. +/// +/// The partition spec contains: +/// - Unique ID for versioning and evolution. +/// - Which source columns in current table schema to use for partitioning +/// (identified by field name, not field ID as in the Iceberg spec). +/// - What transforms to apply (identity, bucket, truncate etc.). +/// - Transform parameters (e.g., bucket count, truncate width). +struct IcebergPartitionSpec { + struct Field { + /// Column name as defined in table schema. This column's value is used to + /// compute partition key by applying 'transformType' transformation. + const std::string name; + + /// Column type. + const TypePtr type; + + /// Transform to apply. Callers must ensure the transform is compatible with + /// the column type. + const TransformType transformType; + + /// Optional parameter for transforms that require configuration. + const std::optional parameter; + + /// Returns the result type after applying this transform. + TypePtr resultType() const { + switch (transformType) { + case TransformType::kBucket: + case TransformType::kYear: + case TransformType::kMonth: + case TransformType::kHour: + return INTEGER(); + case TransformType::kDay: + return DATE(); + case TransformType::kIdentity: + case TransformType::kTruncate: + return type; + } + VELOX_UNREACHABLE("Unknown transform type"); + } + }; + + const int32_t specId; + const std::vector fields; + + /// Constructor with validation that: + /// - Each field's type is supported for partitioning. + /// - Each field's transform type is compatible with its data type. + /// - No transform category appears more than once per column (Identity, + /// Temporal, Bucket, and Truncate are separate categories). + /// + /// @param _specId Partition specification ID. + /// @param _fields Vector of partition fields. When empty indicates no + /// partition. + /// @throws VeloxUserError if validation fails. + IcebergPartitionSpec(int32_t _specId, std::vector _fields) + : specId(_specId), fields(std::move(_fields)) { + checkCompatibility(); + } + + private: + // Validates partition fields for correctness. + // Checks type/transform compatibility and transform combination rules. + void checkCompatibility() const; +}; + +using IcebergPartitionSpecPtr = std::shared_ptr; + +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/PositionalDeleteFileReader.cpp b/velox/connectors/hive/iceberg/PositionalDeleteFileReader.cpp index 8f0d7ad13750..346ac8fed72a 100644 --- a/velox/connectors/hive/iceberg/PositionalDeleteFileReader.cpp +++ b/velox/connectors/hive/iceberg/PositionalDeleteFileReader.cpp @@ -16,6 +16,7 @@ #include "velox/connectors/hive/iceberg/PositionalDeleteFileReader.h" +#include "velox/connectors/hive/BufferedInputBuilder.h" #include "velox/connectors/hive/HiveConnectorUtil.h" #include "velox/connectors/hive/TableHandle.h" #include "velox/connectors/hive/iceberg/IcebergDeleteFile.h" @@ -63,8 +64,9 @@ PositionalDeleteFileReader::PositionalDeleteFileReader( auto scanSpec = std::make_shared(""); scanSpec->addField(posColumn_->name, 0); auto* pathSpec = scanSpec->getOrCreateChild(filePathColumn_->name); - pathSpec->setFilter(std::make_unique( - std::vector({baseFilePath_}), false)); + pathSpec->setFilter( + std::make_unique( + std::vector({baseFilePath_}), false)); // Create the file schema (in RowType) and split that will be used by readers std::vector deleteColumnNames( @@ -92,9 +94,11 @@ PositionalDeleteFileReader::PositionalDeleteFileReader( /*tableParameters=*/{}, deleteReaderOpts); - auto deleteFileHandleCachePtr = - fileHandleFactory_->generate(deleteFile_.filePath); - auto deleteFileInput = createBufferedInput( + const FileHandleKey fileHandleKey{ + .filename = deleteFile_.filePath, + .tokenProvider = connectorQueryCtx_->fsTokenProvider()}; + auto deleteFileHandleCachePtr = fileHandleFactory_->generate(fileHandleKey); + auto deleteFileInput = BufferedInputBuilder::getInstance()->create( *deleteFileHandleCachePtr, deleteReaderOpts, connectorQueryCtx, @@ -135,6 +139,7 @@ PositionalDeleteFileReader::PositionalDeleteFileReader( deleteSplit_, nullptr, nullptr, + nullptr, deleteRowReaderOpts); deleteRowReader_.reset(); @@ -249,15 +254,17 @@ void PositionalDeleteFileReader::updateDeleteBitmap( deletePositionsOffset_++; } - deleteBitmapBuffer->setSize(std::max( - static_cast(deleteBitmapBuffer->size()), - deletePositionsOffset_ == 0 || - (deletePositionsOffset_ < deletePositionsVector->size() && - deletePositions[deletePositionsOffset_] >= rowNumberUpperBound) - ? 0 - : bits::nbytes( - deletePositions[deletePositionsOffset_ - 1] + 1 - - rowNumberLowerBound))); + deleteBitmapBuffer->setSize( + std::max( + static_cast(deleteBitmapBuffer->size()), + deletePositionsOffset_ == 0 || + (deletePositionsOffset_ < deletePositionsVector->size() && + deletePositions[deletePositionsOffset_] >= + rowNumberUpperBound) + ? 0 + : bits::nbytes( + deletePositions[deletePositionsOffset_ - 1] + 1 - + rowNumberLowerBound))); } bool PositionalDeleteFileReader::readFinishedForBatch( diff --git a/velox/connectors/hive/iceberg/TransformEvaluator.cpp b/velox/connectors/hive/iceberg/TransformEvaluator.cpp new file mode 100644 index 000000000000..2744bddafa75 --- /dev/null +++ b/velox/connectors/hive/iceberg/TransformEvaluator.cpp @@ -0,0 +1,47 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/iceberg/TransformEvaluator.h" + +#include "velox/expression/Expr.h" + +namespace facebook::velox::connector::hive::iceberg { + +TransformEvaluator::TransformEvaluator( + const std::vector& expressions, + const ConnectorQueryCtx* connectorQueryCtx) + : connectorQueryCtx_(connectorQueryCtx) { + VELOX_CHECK_NOT_NULL(connectorQueryCtx_); + exprSet_ = connectorQueryCtx_->expressionEvaluator()->compile(expressions); + VELOX_CHECK_NOT_NULL(exprSet_); +} + +std::vector TransformEvaluator::evaluate( + const RowVectorPtr& input) const { + const auto numRows = input->size(); + const auto numExpressions = exprSet_->exprs().size(); + + std::vector results(numExpressions); + SelectivityVector rows(numRows); + + // Evaluate all expressions in one pass. + connectorQueryCtx_->expressionEvaluator()->evaluate( + exprSet_.get(), rows, *input, results); + + return results; +} + +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/TransformEvaluator.h b/velox/connectors/hive/iceberg/TransformEvaluator.h new file mode 100644 index 000000000000..ee7b26f7db8c --- /dev/null +++ b/velox/connectors/hive/iceberg/TransformEvaluator.h @@ -0,0 +1,65 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "velox/connectors/Connector.h" +#include "velox/core/QueryCtx.h" +#include "velox/expression/Expr.h" + +namespace facebook::velox::connector::hive::iceberg { + +/// Evaluates multiple expressions efficiently using batch evaluation. +/// Expressions are compiled once in the constructor and reused across multiple +/// input batches. +class TransformEvaluator { + public: + /// Creates an evaluator with the given expressions and connector query + /// context. Compiles the expressions once for reuse across multiple + /// evaluations. + /// + /// @param expressions Vector of typed expressions to evaluate. These are + /// typically built using TransformExprBuilder::toExpressions() for Iceberg + /// partition transforms, but can be any valid Velox expressions. The + /// expressions are compiled once during construction. + /// @param connectorQueryCtx Connector query context providing access to the + /// expression evaluator (for compilation and evaluation) and memory pool. + /// Must remain valid for the lifetime of this TransformEvaluator. + TransformEvaluator( + const std::vector& expressions, + const ConnectorQueryCtx* connectorQueryCtx); + + /// Evaluates all expressions on the input data in a single pass. + /// Uses the pre-compiled ExprSet from the constructor for efficiency. + /// + /// The input RowType must match the RowType used when building the + /// expressions (passed to TransformExprBuilder::toExpressions). The column + /// positions, names and types must align. Create new TransformEvaluator for + /// input that has different RowType with the one when building the + /// expressions. + /// + /// @param input Input row vector containing the source data. Must have the + /// same RowType (column positions, names and types) as used when building the + /// expressions in the constructor. + /// @return Vector of result columns, one for each expression, in the same + /// order as the expressions provided to the constructor. + std::vector evaluate(const RowVectorPtr& input) const; + + private: + const ConnectorQueryCtx* connectorQueryCtx_; + std::unique_ptr exprSet_; +}; + +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/TransformExprBuilder.cpp b/velox/connectors/hive/iceberg/TransformExprBuilder.cpp new file mode 100644 index 000000000000..4befbfb50b08 --- /dev/null +++ b/velox/connectors/hive/iceberg/TransformExprBuilder.cpp @@ -0,0 +1,111 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/connectors/hive/iceberg/TransformExprBuilder.h" +#include "velox/core/Expressions.h" + +namespace facebook::velox::connector::hive::iceberg { + +namespace { + +/// Converts a single partition field to a typed expression. +/// +/// Builds an expression tree for one partition transform. Identity transforms +/// become FieldAccessTypedExpr, while other transforms (bucket, truncate, +/// year, month, day, hour) become CallTypedExpr with appropriate function +/// names and parameters. +/// +/// @param field Partition field containing transform type, source column +/// type, and optional parameter (e.g., bucket count, truncate width). +/// @param inputFieldName Name of the source column in the input RowVector. +/// @param icebergFuncPrefix Prefix of iceberg transform function names. +/// @return Typed expression representing the transform. +core::TypedExprPtr toExpression( + const IcebergPartitionSpec::Field& field, + const std::string& inputFieldName, + const std::string& icebergFuncPrefix) { + // For identity transform, just return a field access expression. + if (field.transformType == TransformType::kIdentity) { + return std::make_shared( + field.type, inputFieldName); + } + + // For other transforms, build a CallTypedExpr with the appropriate function. + std::string functionName; + switch (field.transformType) { + case TransformType::kBucket: + functionName = icebergFuncPrefix + "bucket"; + break; + case TransformType::kTruncate: + functionName = icebergFuncPrefix + "truncate"; + break; + case TransformType::kYear: + functionName = icebergFuncPrefix + "years"; + break; + case TransformType::kMonth: + functionName = icebergFuncPrefix + "months"; + break; + case TransformType::kDay: + functionName = icebergFuncPrefix + "days"; + break; + case TransformType::kHour: + functionName = icebergFuncPrefix + "hours"; + break; + case TransformType::kIdentity: + break; + } + + // Build the expression arguments. + std::vector exprArgs; + if (field.parameter.has_value()) { + exprArgs.emplace_back( + std::make_shared( + INTEGER(), Variant(field.parameter.value()))); + } + exprArgs.emplace_back( + std::make_shared(field.type, inputFieldName)); + + return std::make_shared( + field.resultType(), std::move(exprArgs), functionName); +} + +} // namespace + +std::vector TransformExprBuilder::toExpressions( + const IcebergPartitionSpecPtr& partitionSpec, + const std::vector& partitionChannels, + const RowTypePtr& inputType, + const std::string& icebergFuncPrefix) { + VELOX_CHECK_EQ( + partitionSpec->fields.size(), + partitionChannels.size(), + "Number of partition fields must match number of partition channels"); + + const auto numTransforms = partitionChannels.size(); + std::vector transformExprs; + transformExprs.reserve(numTransforms); + + for (auto i = 0; i < numTransforms; i++) { + const auto channel = partitionChannels[i]; + transformExprs.emplace_back(toExpression( + partitionSpec->fields.at(i), + inputType->nameOf(channel), + icebergFuncPrefix)); + } + + return transformExprs; +} + +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/TransformExprBuilder.h b/velox/connectors/hive/iceberg/TransformExprBuilder.h new file mode 100644 index 000000000000..b583adcf97df --- /dev/null +++ b/velox/connectors/hive/iceberg/TransformExprBuilder.h @@ -0,0 +1,51 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "velox/connectors/hive/iceberg/PartitionSpec.h" +#include "velox/expression/Expr.h" + +namespace facebook::velox::connector::hive::iceberg { + +/// Converts Iceberg partition specification to Velox expressions. +class TransformExprBuilder { + public: + /// Converts partition specification to a list of typed expressions. + /// + /// @param partitionSpec Iceberg partition specification containing transform + /// definitions for each partition field. + /// @param partitionChannels Column indices (0-based) in the input RowVector + /// that correspond to each partition field. Must have the same size as + /// partitionSpec->fields. Provides the positional mapping from partition spec + /// fields to input RowVector columns. + /// @param inputType The row type of the input data. This is necessary for + /// building expressions because the column names in partitionSpec reference + /// table schema names, which might not match the column names in inputType + /// (e.g., inputType may use generated names like c0, c1, c2). The + /// FieldAccessTypedExpr must be built using the actual column names from + /// inputType that will be present at runtime. The partitionChannels provide + /// the positional mapping to locate the correct columns. + /// @param icebergFuncPrefix Prefix for Iceberg transform function names. + /// @return Vector of typed expressions, one for each partition field. + static std::vector toExpressions( + const IcebergPartitionSpecPtr& partitionSpec, + const std::vector& partitionChannels, + const RowTypePtr& inputType, + const std::string& icebergFuncPrefix); +}; + +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/tests/CMakeLists.txt b/velox/connectors/hive/iceberg/tests/CMakeLists.txt index 51d1116d1845..883503bbef70 100644 --- a/velox/connectors/hive/iceberg/tests/CMakeLists.txt +++ b/velox/connectors/hive/iceberg/tests/CMakeLists.txt @@ -11,8 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -add_library(velox_dwio_iceberg_reader_benchmark_lib - IcebergSplitReaderBenchmark.cpp) +add_library(velox_dwio_iceberg_reader_benchmark_lib IcebergSplitReaderBenchmark.cpp) target_link_libraries( velox_dwio_iceberg_reader_benchmark_lib velox_exec_test_lib @@ -20,11 +19,11 @@ target_link_libraries( velox_hive_connector Folly::folly Folly::follybenchmark - ${TEST_LINK_LIBS}) + ${TEST_LINK_LIBS} +) if(VELOX_ENABLE_BENCHMARKS) - add_executable(velox_dwio_iceberg_reader_benchmark - IcebergSplitReaderBenchmarkMain.cpp) + add_executable(velox_dwio_iceberg_reader_benchmark IcebergSplitReaderBenchmarkMain.cpp) target_link_libraries( velox_dwio_iceberg_reader_benchmark velox_dwio_iceberg_reader_benchmark_lib @@ -33,13 +32,12 @@ if(VELOX_ENABLE_BENCHMARKS) velox_hive_connector Folly::folly Folly::follybenchmark - ${TEST_LINK_LIBS}) + ${TEST_LINK_LIBS} + ) endif() if(NOT VELOX_DISABLE_GOOGLETEST) - - add_executable(velox_hive_iceberg_test IcebergReadTest.cpp - IcebergSplitReaderBenchmarkTest.cpp) + add_executable(velox_hive_iceberg_test IcebergReadTest.cpp IcebergSplitReaderBenchmarkTest.cpp) add_test(velox_hive_iceberg_test velox_hive_iceberg_test) target_link_libraries( @@ -56,6 +54,43 @@ if(NOT VELOX_DISABLE_GOOGLETEST) Folly::folly Folly::follybenchmark GTest::gtest - GTest::gtest_main) + GTest::gtest_main + ) + + add_executable( + velox_hive_iceberg_insert_test + IcebergConnectorTest.cpp + IcebergInsertTest.cpp + IcebergTestBase.cpp + Main.cpp + PartitionNameTest.cpp + PartitionSpecTest.cpp + PartitionValueFormatterTest.cpp + TransformE2ETest.cpp + TransformTest.cpp + ) + + add_test(velox_hive_iceberg_insert_test velox_hive_iceberg_insert_test) + + target_link_libraries( + velox_hive_iceberg_insert_test + velox_exec_test_lib + velox_hive_connector + velox_hive_iceberg_splitreader + velox_vector_fuzzer + GTest::gtest + GTest::gtest_main + ) + + if(VELOX_ENABLE_PARQUET) + target_link_libraries(velox_hive_iceberg_test velox_dwio_parquet_reader) + + file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/examples DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) + target_link_libraries( + velox_hive_iceberg_insert_test + velox_dwio_parquet_reader + velox_dwio_parquet_writer + ) + endif() endif() diff --git a/velox/connectors/hive/iceberg/tests/IcebergConnectorTest.cpp b/velox/connectors/hive/iceberg/tests/IcebergConnectorTest.cpp new file mode 100644 index 000000000000..3966ccd8307b --- /dev/null +++ b/velox/connectors/hive/iceberg/tests/IcebergConnectorTest.cpp @@ -0,0 +1,70 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/iceberg/IcebergConnector.h" +#include +#include "velox/connectors/hive/HiveConfig.h" +#include "velox/connectors/hive/iceberg/tests/IcebergTestBase.h" + +namespace facebook::velox::connector::hive::iceberg { + +namespace { + +class IcebergConnectorTest : public test::IcebergTestBase { + protected: + static void resetIcebergConnector( + const std::shared_ptr& config) { + unregisterConnector(test::kIcebergConnectorId); + + IcebergConnectorFactory factory; + auto icebergConnector = + factory.newConnector(test::kIcebergConnectorId, config); + registerConnector(icebergConnector); + } +}; + +TEST_F(IcebergConnectorTest, connectorConfiguration) { + auto customConfig = std::make_shared( + std::unordered_map{ + {hive::HiveConfig::kEnableFileHandleCache, "true"}, + {hive::HiveConfig::kNumCacheFileHandles, "1000"}}); + + resetIcebergConnector(customConfig); + + // Verify connector was registered successfully with custom config. + auto icebergConnector = getConnector(test::kIcebergConnectorId); + ASSERT_NE(icebergConnector, nullptr); + + auto config = icebergConnector->connectorConfig(); + ASSERT_NE(config, nullptr); + + hive::HiveConfig hiveConfig(config); + ASSERT_TRUE(hiveConfig.isFileHandleCacheEnabled()); + ASSERT_EQ(hiveConfig.numCacheFileHandles(), 1000); +} + +TEST_F(IcebergConnectorTest, connectorProperties) { + auto icebergConnector = getConnector(test::kIcebergConnectorId); + ASSERT_NE(icebergConnector, nullptr); + + ASSERT_TRUE(icebergConnector->canAddDynamicFilter()); + ASSERT_TRUE(icebergConnector->supportsSplitPreload()); + ASSERT_NE(icebergConnector->ioExecutor(), nullptr); +} + +} // namespace + +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/tests/IcebergInsertTest.cpp b/velox/connectors/hive/iceberg/tests/IcebergInsertTest.cpp new file mode 100644 index 000000000000..459a1bf83f6d --- /dev/null +++ b/velox/connectors/hive/iceberg/tests/IcebergInsertTest.cpp @@ -0,0 +1,247 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/iceberg/IcebergConnector.h" +#include "velox/connectors/hive/iceberg/tests/IcebergTestBase.h" +#include "velox/exec/tests/utils/AssertQueryBuilder.h" +#include "velox/exec/tests/utils/PlanBuilder.h" + +namespace facebook::velox::connector::hive::iceberg { +namespace { + +#ifdef VELOX_ENABLE_PARQUET + +class IcebergInsertTest : public test::IcebergTestBase { + protected: + void test(const RowTypePtr& rowType, double nullRatio = 0.0) { + const auto outputDirectory = exec::test::TempDirectoryPath::create(); + const auto dataPath = outputDirectory->getPath(); + constexpr int32_t numBatches = 10; + constexpr int32_t vectorSize = 5'000; + const auto vectors = + createTestData(rowType, numBatches, vectorSize, nullRatio); + const auto dataSink = createDataSinkAndAppendData(vectors, dataPath); + const auto commitTasks = dataSink->close(); + + auto splits = createSplitsForDirectory(dataPath); + ASSERT_EQ(splits.size(), commitTasks.size()); + auto plan = exec::test::PlanBuilder() + .startTableScan() + .connectorId(test::kIcebergConnectorId) + .outputType(rowType) + .endTableScan() + .planNode(); + exec::test::AssertQueryBuilder(plan).splits(splits).assertResults(vectors); + } +}; + +TEST_F(IcebergInsertTest, basic) { + auto rowType = + ROW({"c1", "c2", "c3", "c4", "c5", "c6", "c7", "c8", "c9", "c10", "c11"}, + {BIGINT(), + INTEGER(), + SMALLINT(), + BOOLEAN(), + REAL(), + DECIMAL(18, 5), + VARCHAR(), + VARBINARY(), + DATE(), + TIMESTAMP(), + ROW({"id", "name"}, {INTEGER(), VARCHAR()})}); + test(rowType, 0.2); +} + +TEST_F(IcebergInsertTest, mapAndArray) { + auto rowType = + ROW({"c1", "c2"}, {MAP(INTEGER(), VARCHAR()), ARRAY(VARCHAR())}); + test(rowType); +} + +TEST_F(IcebergInsertTest, bigDecimal) { + auto rowType = ROW({"c1"}, {DECIMAL(38, 5)}); + fileFormat_ = dwio::common::FileFormat::PARQUET; + test(rowType); +} + +TEST_F(IcebergInsertTest, singleColumnPartition) { + struct TestCase { + std::string name; + TypePtr type; + }; + + std::vector testCases = { + {"c1", BIGINT()}, + {"c2", INTEGER()}, + {"c3", SMALLINT()}, + {"c4", DECIMAL(18, 5)}, + {"c5", BOOLEAN()}, + {"c6", VARCHAR()}, + {"c7", DATE()}, + {"c8", TIMESTAMP()}}; + + for (const auto& testCase : testCases) { + const auto outputDirectory = exec::test::TempDirectoryPath::create(); + constexpr int32_t numBatches = 2; + constexpr int32_t vectorSize = 50; + auto rowType = ROW({testCase.name}, {testCase.type}); + + const auto vectors = createTestData(rowType, numBatches, vectorSize, 0.5); + std::vector partitionTransforms = { + {0, TransformType::kIdentity, std::nullopt}}; + const auto dataSink = createDataSinkAndAppendData( + vectors, outputDirectory->getPath(), partitionTransforms); + const auto commitTasks = dataSink->close(); + auto splits = createSplitsForDirectory(outputDirectory->getPath()); + + ASSERT_GT(commitTasks.size(), 0); + ASSERT_EQ(splits.size(), commitTasks.size()); + + for (const auto& task : commitTasks) { + auto taskJson = folly::parseJson(task); + ASSERT_TRUE(taskJson.count("partitionDataJson") > 0); + } + + auto plan = exec::test::PlanBuilder() + .startTableScan() + .connectorId(test::kIcebergConnectorId) + .outputType(rowType) + .endTableScan() + .planNode(); + exec::test::AssertQueryBuilder(plan).splits(splits).assertResults(vectors); + } +} + +TEST_F(IcebergInsertTest, partitionNullColumn) { + struct TestCase { + std::string name; + TypePtr type; + }; + + std::vector testCases = { + {"c1", BIGINT()}, + {"c2", INTEGER()}, + {"c3", SMALLINT()}, + {"c4", DECIMAL(18, 5)}, + {"c5", BOOLEAN()}, + {"c6", VARCHAR()}, + {"c7", DATE()}, + {"c8", TIMESTAMP()}}; + + for (const auto& testCase : testCases) { + const auto outputDirectory = exec::test::TempDirectoryPath::create(); + constexpr int32_t numBatches = 2; + constexpr int32_t vectorSize = 100; + auto rowType = ROW({testCase.name}, {testCase.type}); + // nullRatio = 1.0 + const auto vectors = createTestData(rowType, numBatches, vectorSize, 1.0); + + std::vector partitionTransforms = { + {0, TransformType::kIdentity, std::nullopt}}; + const auto dataSink = createDataSinkAndAppendData( + vectors, outputDirectory->getPath(), partitionTransforms); + + const auto commitTasks = dataSink->close(); + ASSERT_EQ(1, commitTasks.size()); + auto taskJson = folly::parseJson(commitTasks.at(0)); + ASSERT_EQ(1, taskJson.count("partitionDataJson")); + auto partitionData = + folly::parseJson(taskJson["partitionDataJson"].asString()); + ASSERT_EQ(1, partitionData.count("partitionValues")); + auto partitionValues = partitionData["partitionValues"]; + ASSERT_TRUE(partitionValues.isArray()); + ASSERT_TRUE(partitionValues[0].isNull()); + + auto files = listFiles(outputDirectory->getPath()); + ASSERT_EQ(files.size(), 1); + + for (const auto& file : files) { + auto partitionKeys = extractPartitionKeys(file); + ASSERT_EQ(partitionKeys.size(), 1); + ASSERT_TRUE(partitionKeys.contains(testCase.name)); + ASSERT_FALSE(partitionKeys.at(testCase.name).has_value()); + } + } +} + +TEST_F(IcebergInsertTest, partitionMultiColumns) { + auto rowType = + ROW({"c1", "c2", "c3", "c4"}, + { + BIGINT(), + INTEGER(), + SMALLINT(), + DECIMAL(18, 5), + }); + std::vector> columnCombinations = { + {0, 1}, // BIGINT, INTEGER. + {2, 1}, // SMALLINT, INTEGER. + {2, 3}, // SMALLINT, DECIMAL. + {0, 2, 1} // BIGINT, SMALLINT, INTEGER. + }; + + for (const auto& combination : columnCombinations) { + const auto outputDirectory = exec::test::TempDirectoryPath::create(); + constexpr int32_t numBatches = 2; + constexpr int32_t vectorSize = 50; + + std::vector vectors; + vectors.reserve(numBatches); + for (int32_t batch = 0; batch < numBatches; ++batch) { + vectors.push_back(makeRowVector( + rowType->names(), + { + makeFlatVector( + vectorSize, [](auto row) { return row * 100; }), + makeFlatVector( + vectorSize, [](auto row) { return row * 10; }), + makeFlatVector(vectorSize, [](auto row) { return row; }), + makeFlatVector( + vectorSize, + [](auto row) { return (row * 1000); }, + nullptr, + DECIMAL(18, 5)), + })); + } + + std::vector partitionTransforms; + for (auto colIndex : combination) { + partitionTransforms.push_back( + {colIndex, TransformType::kIdentity, std::nullopt}); + } + + const auto dataSink = createDataSinkAndAppendData( + vectors, outputDirectory->getPath(), partitionTransforms); + + const auto commitTasks = dataSink->close(); + auto splits = createSplitsForDirectory(outputDirectory->getPath()); + + ASSERT_EQ(commitTasks.size(), vectorSize); + ASSERT_EQ(splits.size(), commitTasks.size()); + + auto plan = exec::test::PlanBuilder() + .startTableScan() + .connectorId(test::kIcebergConnectorId) + .outputType(rowType) + .endTableScan() + .planNode(); + exec::test::AssertQueryBuilder(plan).splits(splits).assertResults(vectors); + } +} +#endif + +} // namespace +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/tests/IcebergReadTest.cpp b/velox/connectors/hive/iceberg/tests/IcebergReadTest.cpp index 5ed64da321f2..34e87fc25257 100644 --- a/velox/connectors/hive/iceberg/tests/IcebergReadTest.cpp +++ b/velox/connectors/hive/iceberg/tests/IcebergReadTest.cpp @@ -14,17 +14,22 @@ * limitations under the License. */ +#include #include "velox/common/base/tests/GTestUtils.h" #include "velox/common/file/FileSystems.h" #include "velox/connectors/hive/HiveConnectorSplit.h" #include "velox/connectors/hive/iceberg/IcebergDeleteFile.h" #include "velox/connectors/hive/iceberg/IcebergMetadataColumns.h" #include "velox/connectors/hive/iceberg/IcebergSplit.h" +#include "velox/dwio/common/tests/utils/DataFiles.h" #include "velox/exec/PlanNodeStats.h" +#include "velox/exec/tests/utils/AssertQueryBuilder.h" #include "velox/exec/tests/utils/HiveConnectorTestBase.h" #include "velox/exec/tests/utils/PlanBuilder.h" -#include +#ifdef VELOX_ENABLE_PARQUET +#include "velox/dwio/parquet/RegisterParquetReader.h" +#endif using namespace facebook::velox::exec::test; using namespace facebook::velox::exec; @@ -35,6 +40,13 @@ namespace facebook::velox::connector::hive::iceberg { class HiveIcebergTest : public HiveConnectorTestBase { public: + void SetUp() override { + HiveConnectorTestBase::SetUp(); +#ifdef VELOX_ENABLE_PARQUET + parquet::registerParquetReaderFactory(); +#endif + } + HiveIcebergTest() : config_{std::make_shared()} { // Make the writers flush per batch so that we can create non-aligned @@ -144,7 +156,7 @@ class HiveIcebergTest : public HiveConnectorTestBase { return values; } - std::vector makeContinuousIncreasingValues( + static std::vector makeContinuousIncreasingValues( int64_t begin, int64_t end) { std::vector values; @@ -223,13 +235,13 @@ class HiveIcebergTest : public HiveConnectorTestBase { std::string duckdbSql = getDuckDBQuery(rowGroupSizesForFiles, deleteFilesForBaseDatafiles); - auto plan = tableScanNode(); + auto plan = PlanBuilder().tableScan(ROW({"c0"}, {BIGINT()})).planNode(); auto task = HiveConnectorTestBase::assertQuery( plan, splits, duckdbSql, numPrefetchSplits); auto planStats = toPlanStats(task->taskStats()); - auto scanNodeId = plan->id(); - auto it = planStats.find(scanNodeId); + + auto it = planStats.find(plan->id()); ASSERT_TRUE(it != planStats.end()); ASSERT_TRUE(it->second.peakMemoryBytes > 0); } @@ -252,27 +264,100 @@ class HiveIcebergTest : public HiveConnectorTestBase { auto file = filesystems::getFileSystem(dataFilePath, nullptr) ->openFileForRead(dataFilePath); const int64_t fileSize = file->size(); - std::vector> splits; const uint64_t splitSize = std::floor((fileSize) / splitCount); + std::vector> splits; + splits.reserve(splitCount); + for (int i = 0; i < splitCount; ++i) { - splits.emplace_back(std::make_shared( - kHiveConnectorId, - dataFilePath, - fileFomat_, - i * splitSize, - splitSize, - partitionKeys, - std::nullopt, - customSplitInfo, - nullptr, - /*cacheable=*/true, - deleteFiles)); + splits.emplace_back( + std::make_shared( + kHiveConnectorId, + dataFilePath, + fileFomat_, + i * splitSize, + splitSize, + partitionKeys, + std::nullopt, + customSplitInfo, + nullptr, + /*cacheable=*/true, + deleteFiles)); } return splits; } + ColumnHandleMap makeColumnHandles( + const RowTypePtr& rowType, + const std::unordered_set& partitionIndices = {}) { + ColumnHandleMap assignments; + for (auto i = 0; i < rowType->size(); ++i) { + const auto& columnName = rowType->nameOf(i); + const auto& columnType = rowType->childAt(i); + auto columnHandleType = partitionIndices.contains(i) + ? HiveColumnHandle::ColumnType::kPartitionKey + : HiveColumnHandle::ColumnType::kRegular; + + assignments.insert( + {columnName, + std::make_shared( + columnName, + columnHandleType, + columnType, + columnType, + std::vector{})}); + } + + return assignments; + } + +#ifdef VELOX_ENABLE_PARQUET + std::vector> createParquetDeleteFileAndSplits( + const std::string& path, + const std::vector& deletePositionsVec, + int32_t deletedPositionSize, + const std::shared_ptr& deleteFilePath) { + writeToFile( + deleteFilePath->getPath(), + {makeRowVector( + {pathColumn_->name, posColumn_->name}, + { + makeFlatVector( + static_cast(deletedPositionSize), + [&](vector_size_t) { return path; }), + makeFlatVector(deletePositionsVec), + })}); + + IcebergDeleteFile icebergDeleteFile( + FileContent::kPositionalDeletes, + deleteFilePath->getPath(), + fileFomat_, + deletedPositionSize, + testing::internal::GetFileSize( + std::fopen(deleteFilePath->getPath().c_str(), "r"))); + auto fileSize = filesystems::getFileSystem(path, nullptr) + ->openFileForRead(path) + ->size(); + + std::unordered_map customSplitInfo{ + {"table_format", "hive-iceberg"}}; + std::unordered_map> partitionKeys; + return {std::make_shared( + kHiveConnectorId, + path, + dwio::common::FileFormat::PARQUET, + 0, + fileSize, + partitionKeys, + std::nullopt, + customSplitInfo, + nullptr, + /*cacheable=*/true, + std::vector{icebergDeleteFile})}; + } +#endif + private: std::map> writeDataFiles( std::map> rowGroupSizesForFiles) { @@ -289,11 +374,7 @@ class HiveIcebergTest : public HiveConnectorTestBase { // files. This is to make constructing DuckDB queries easier std::vector dataVectors = makeVectors(dataFile.second, startingValue); - writeToFile( - dataFilePaths[dataFile.first]->getPath(), - dataVectors, - config_, - flushPolicyFactory_); + writeToFile(dataFilePaths[dataFile.first]->getPath(), dataVectors); for (int i = 0; i < dataVectors.size(); i++) { dataVectorsJoined.push_back(dataVectors[i]); @@ -350,11 +431,7 @@ class HiveIcebergTest : public HiveConnectorTestBase { totalPositionsInDeleteFile += positionsInRowGroup.size(); } - writeToFile( - deleteFilePath->getPath(), - deleteFileVectors, - config_, - flushPolicyFactory_); + writeToFile(deleteFilePath->getPath(), deleteFileVectors); deleteFilePaths[deleteFileName] = std::make_pair(totalPositionsInDeleteFile, deleteFilePath); @@ -373,8 +450,7 @@ class HiveIcebergTest : public HiveConnectorTestBase { for (int j = 0; j < vectorSizes.size(); j++) { auto data = makeContinuousIncreasingValues( startingValue, startingValue + vectorSizes[j]); - VectorPtr c0 = makeFlatVector(data); - vectors.push_back(makeRowVector({"c0"}, {c0})); + vectors.push_back(makeRowVector({makeFlatVector(data)})); startingValue += vectorSizes[j]; } @@ -401,9 +477,9 @@ class HiveIcebergTest : public HiveConnectorTestBase { // Group the delete vectors by baseFileName std::map>> deletePosVectorsForAllBaseFiles; - for (auto deleteFile : deleteFilesForBaseDatafiles) { + for (auto& deleteFile : deleteFilesForBaseDatafiles) { auto deleteFileContent = deleteFile.second; - for (auto rowGroup : deleteFileContent) { + for (auto& rowGroup : deleteFileContent) { auto baseFileName = rowGroup.first; deletePosVectorsForAllBaseFiles[baseFileName].push_back( rowGroup.second); @@ -416,7 +492,7 @@ class HiveIcebergTest : public HiveConnectorTestBase { std::map> flattenedDeletePosVectorsForAllBaseFiles; int64_t totalNumDeletePositions = 0; - for (auto deleteVectorsForBaseFile : deletePosVectorsForAllBaseFiles) { + for (auto& deleteVectorsForBaseFile : deletePosVectorsForAllBaseFiles) { auto baseFileName = deleteVectorsForBaseFile.first; auto deletePositionVectors = deleteVectorsForBaseFile.second; std::vector deletePositionVector = @@ -429,14 +505,18 @@ class HiveIcebergTest : public HiveConnectorTestBase { // Now build the DuckDB queries if (totalNumDeletePositions == 0) { return "SELECT * FROM tmp"; - } else if (totalNumDeletePositions >= totalNumRowsInAllBaseFiles) { + } + + if (totalNumDeletePositions >= totalNumRowsInAllBaseFiles) { return "SELECT * FROM tmp WHERE 1 = 0"; - } else { + } + + { // Convert the delete positions in all base files into column values std::vector allDeleteValues; int64_t numRowsInPreviousBaseFiles = 0; - for (auto baseFileSize : baseFileSizes) { + for (auto& baseFileSize : baseFileSizes) { auto deletePositions = flattenedDeletePosVectorsForAllBaseFiles[baseFileSize.first]; @@ -456,7 +536,7 @@ class HiveIcebergTest : public HiveConnectorTestBase { return fmt::format( "SELECT * FROM tmp WHERE c0 NOT IN ({})", - makeNotInList(allDeleteValues)); + folly::join(", ", allDeleteValues)); } } @@ -464,7 +544,7 @@ class HiveIcebergTest : public HiveConnectorTestBase { const std::vector>& deletePositionVectors, int64_t baseFileSize) { std::vector deletePositionVector; - for (auto vec : deletePositionVectors) { + for (auto& vec : deletePositionVectors) { for (auto pos : vec) { if (pos >= 0 && pos < baseFileSize) { deletePositionVector.push_back(pos); @@ -480,29 +560,11 @@ class HiveIcebergTest : public HiveConnectorTestBase { return deletePositionVector; } - std::string makeNotInList(const std::vector& deletePositionVector) { - if (deletePositionVector.empty()) { - return ""; - } - - return std::accumulate( - deletePositionVector.begin() + 1, - deletePositionVector.end(), - std::to_string(deletePositionVector[0]), - [](const std::string& a, int64_t b) { - return a + ", " + std::to_string(b); - }); - } - - core::PlanNodePtr tableScanNode() { - return PlanBuilder(pool_.get()).tableScan(rowType_).planNode(); - } - dwio::common::FileFormat fileFomat_{dwio::common::FileFormat::DWRF}; - RowTypePtr rowType_{ROW({"c0"}, {BIGINT()})}; std::shared_ptr pathColumn_ = IcebergMetadataColumn::icebergDeleteFilePathColumn(); + std::shared_ptr posColumn_ = IcebergMetadataColumn::icebergDeletePosColumn(); }; @@ -532,7 +594,7 @@ TEST_F(HiveIcebergTest, singleBaseFileSinglePositionalDeleteFile) { /// delete positions. The parameter passed to /// assertSingleBaseFileSingleDeleteFile is the delete positions.for the middle /// base file. -TEST_F(HiveIcebergTest, MultipleBaseFilesSinglePositionalDeleteFile) { +TEST_F(HiveIcebergTest, multipleBaseFilesSinglePositionalDeleteFile) { folly::SingletonVault::singleton()->registrationComplete(); assertMultipleBaseFileSingleDeleteFile({0, 1, 2, 3}); @@ -695,7 +757,7 @@ TEST_F(HiveIcebergTest, positionalDeletesMultipleSplits) { assertMultipleSplits({1000, 9000, 20000}, 1, 0, 20000, 3); } -TEST_F(HiveIcebergTest, testPartitionedRead) { +TEST_F(HiveIcebergTest, partitionedRead) { RowTypePtr rowType{ROW({"c0", "ds"}, {BIGINT(), DateType::get()})}; std::unordered_map> partitionKeys; // Iceberg API sets partition values for dates to daysSinceEpoch, so @@ -713,43 +775,19 @@ TEST_F(HiveIcebergTest, testPartitionedRead) { auto dataFilePath = TempFilePath::create(); dataFilePaths.push_back(dataFilePath); - writeToFile( - dataFilePath->getPath(), dataVectors, config_, flushPolicyFactory_); + writeToFile(dataFilePath->getPath(), dataVectors); partitionKeys["ds"] = std::to_string(daysSinceEpoch); auto icebergSplits = makeIcebergSplits(dataFilePath->getPath(), {}, partitionKeys); splits.insert(splits.end(), icebergSplits.begin(), icebergSplits.end()); } - std::unordered_map> - assignments; - assignments.insert( - {"c0", - std::make_shared( - "c0", - HiveColumnHandle::ColumnType::kRegular, - rowType->childAt(0), - rowType->childAt(0))}); - - std::vector requiredSubFields; - HiveColumnHandle::ColumnParseParameters columnParseParameters; - columnParseParameters.partitionDateValueFormat = - HiveColumnHandle::ColumnParseParameters::kDaysSinceEpoch; - assignments.insert( - {"ds", - std::make_shared( - "ds", - HiveColumnHandle::ColumnType::kPartitionKey, - rowType->childAt(1), - rowType->childAt(1), - std::move(requiredSubFields), - columnParseParameters)}); - - auto plan = PlanBuilder(pool_.get()) - .tableScan(rowType, {}, "", nullptr, assignments) - .planNode(); + auto assignments = makeColumnHandles(rowType, {1}); - HiveConnectorTestBase::assertQuery( + auto plan = + PlanBuilder().tableScan(rowType, {}, "", nullptr, assignments).planNode(); + + assertQuery( plan, splits, "SELECT * FROM (VALUES (0, '2018-04-06'), (1, '2018-04-07'))", @@ -757,11 +795,11 @@ TEST_F(HiveIcebergTest, testPartitionedRead) { // Test filter on non-partitioned non-date column std::vector nonPartitionFilters = {"c0 = 1"}; - plan = PlanBuilder(pool_.get()) + plan = PlanBuilder() .tableScan(rowType, nonPartitionFilters, "", nullptr, assignments) .planNode(); - HiveConnectorTestBase::assertQuery(plan, splits, "SELECT 1, '2018-04-07'"); + assertQuery(plan, splits, "SELECT 1, '2018-04-07'"); // Test filter on non-partitioned date column std::vector filters = {"ds = date'2018-04-06'"}; @@ -773,6 +811,140 @@ TEST_F(HiveIcebergTest, testPartitionedRead) { splits.insert(splits.end(), icebergSplits.begin(), icebergSplits.end()); } - HiveConnectorTestBase::assertQuery(plan, splits, "SELECT 0, '2018-04-06'"); + assertQuery(plan, splits, "SELECT 0, '2018-04-06'"); +} + +TEST_F(HiveIcebergTest, schemaEvolutionRemoveColumn) { + auto oldRowType = ROW({"c0", "c1", "c2"}, {BIGINT(), INTEGER(), VARCHAR()}); + auto newRowType = ROW({"c0", "c2"}, {BIGINT(), VARCHAR()}); + + // Write data file with old schema (c0, c1, c2). + std::vector dataVectors; + dataVectors.push_back(makeRowVector( + oldRowType->names(), + { + makeFlatVector({1, 2, 3, 4, 5}), + makeFlatVector({10, 20, 30, 40, 50}), + makeFlatVector({"a", "b", "c", "d", "e"}), + })); + + auto dataFilePath = TempFilePath::create(); + writeToFile(dataFilePath->getPath(), dataVectors); + + auto icebergSplits = makeIcebergSplits(dataFilePath->getPath()); + + // Expected result: c0 and c2 have values, c1 is not present. + std::vector expectedVectors; + expectedVectors.push_back(makeRowVector( + newRowType->names(), + { + dataVectors[0]->childAt(0), + dataVectors[0]->childAt(2), + })); + + // Read with new schema (c0 and c2 only, c1 removed). + auto plan = PlanBuilder().tableScan(newRowType).planNode(); + AssertQueryBuilder(plan).splits(icebergSplits).assertResults(expectedVectors); +} + +TEST_F(HiveIcebergTest, schemaEvolutionAddColumns) { + auto oldRowType = ROW({"c0"}, {BIGINT()}); + auto newRowType = ROW({"c0", "c1", "c2"}, {BIGINT(), INTEGER(), VARCHAR()}); + + // Write data file with old schema (only c0). + std::vector dataVectors; + dataVectors.push_back(makeRowVector({ + makeFlatVector({100, 200, 300}), + })); + auto dataFilePath = TempFilePath::create(); + writeToFile(dataFilePath->getPath(), dataVectors); + auto icebergSplits = makeIcebergSplits(dataFilePath->getPath()); + + // Expected result: c0 has values, c1 and c2 are NULL. + std::vector expectedVectors; + expectedVectors.push_back(makeRowVector({ + dataVectors[0]->childAt(0), + makeNullConstant(TypeKind::INTEGER, 3), + makeNullConstant(TypeKind::VARCHAR, 3), + })); + + // Read with new schema (c0, c1, and c2). + auto plan = + PlanBuilder().tableScan(newRowType, {}, "", newRowType).planNode(); + AssertQueryBuilder(plan).splits(icebergSplits).assertResults(expectedVectors); +} + +// Test reading partition columns from Hive-migrated tables. +// This tests the adaptColumns method handling partition columns that are not +// stored in the data file but provided via partitionKeys map. +// This scenario occurs when reading Hive-written data files where partition +// column values are stored in partition metadata rather than in the data file. +TEST_F(HiveIcebergTest, partitionColumnsFromHive) { + auto fileRowType = ROW({"c0", "c1"}, {BIGINT(), INTEGER()}); + auto tableRowType = + ROW({"c0", "c1", "region", "year"}, + {BIGINT(), INTEGER(), VARCHAR(), INTEGER()}); + + // Write data file with only non-partition columns (c0, c1). + std::vector dataVectors; + dataVectors.push_back(makeRowVector({ + makeFlatVector({1, 2, 3}), + makeFlatVector({10, 20, 30}), + })); + auto dataFilePath = TempFilePath::create(); + writeToFile(dataFilePath->getPath(), dataVectors); + + // Set partition keys for region and year. + std::unordered_map> partitionKeys; + partitionKeys["region"] = "US"; + partitionKeys["year"] = "2025"; + + auto icebergSplits = + makeIcebergSplits(dataFilePath->getPath(), {}, partitionKeys); + auto assignments = makeColumnHandles(tableRowType, {2, 3}); + + // Expected result: c0 and c1 from file, region and year from partition keys. + std::vector expectedVectors; + expectedVectors.push_back(makeRowVector( + tableRowType->names(), + { + dataVectors[0]->childAt(0), + dataVectors[0]->childAt(1), + makeFlatVector({"US", "US", "US"}), + makeFlatVector({2025, 2025, 2025}), + })); + + // Read with table schema including partition columns. + auto plan = PlanBuilder() + .tableScan(tableRowType, {}, "", tableRowType, assignments) + .planNode(); + AssertQueryBuilder(plan).splits(icebergSplits).assertResults(expectedVectors); +} + +#ifdef VELOX_ENABLE_PARQUET +TEST_F(HiveIcebergTest, positionalDeleteFileWithRowGroupFilter) { + // This file contains three row groups, each with about 100 rows. + // Each row group has min/max values: [200, 299], [0, 99], [100, 199]. + // The filter here is id >= 100, which will cause the parquet reader to filter + // out the middle row group ([0, 99]). This can lead to a mismatch between the + // baseReadOffset tracked by Iceberg's split reader and the actual offset, + // resulting in records in the position delete file being mapped to incorrect + // rows. + auto path = test::getDataFilePath( + "velox/connectors/hive/iceberg/test", "examples/three_groups.parquet"); + const auto deletedPositionSize = 100; + std::vector deletePositionsVec( + deletedPositionSize); // allocate 100 elements, [100, 199]. + std::iota(deletePositionsVec.begin(), deletePositionsVec.end(), 100); + auto deleteFilePath = TempFilePath::create(); + HiveConnectorTestBase::assertQuery( + PlanBuilder() + .tableScan(ROW({"id"}, {BIGINT()}), {"id >= 100"}) + .planNode(), + createParquetDeleteFileAndSplits( + path, deletePositionsVec, deletedPositionSize, deleteFilePath), + "SELECT i AS id FROM range(100, 300) AS t(i)", + 0); } +#endif } // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/tests/IcebergSplitReaderBenchmark.cpp b/velox/connectors/hive/iceberg/tests/IcebergSplitReaderBenchmark.cpp index bc9d2b4ad266..2c2c26297fd2 100644 --- a/velox/connectors/hive/iceberg/tests/IcebergSplitReaderBenchmark.cpp +++ b/velox/connectors/hive/iceberg/tests/IcebergSplitReaderBenchmark.cpp @@ -215,7 +215,7 @@ std::shared_ptr IcebergSplitReaderBenchmark::createScanSpec( RowTypePtr& rowType, const std::vector& filterSpecs, std::vector& hitRows, - std::unordered_map>& filters) { + SubfieldFilters& filters) { std::unique_ptr filterGenerator = std::make_unique(rowType, 0); filters = filterGenerator->makeSubfieldFilters( @@ -271,7 +271,7 @@ void IcebergSplitReaderBenchmark::readSingleColumn( createFilterSpec(columnName, startPct, selectPct, rowType, false, false)); std::vector hitRows; - std::unordered_map> filters; + SubfieldFilters filters; auto scanSpec = createScanSpec(*batches, rowType, filterSpecs, hitRows, filters); @@ -324,7 +324,7 @@ void IcebergSplitReaderBenchmark::readSingleColumn( ""); FileHandleFactory fileHandleFactory( - std::make_unique>( + std::make_unique>( hiveConfig->numCacheFileHandles()), std::make_unique(connectorSessionProperties_)); diff --git a/velox/connectors/hive/iceberg/tests/IcebergSplitReaderBenchmark.h b/velox/connectors/hive/iceberg/tests/IcebergSplitReaderBenchmark.h index 9b6d5df9bf61..3408fa4ce837 100644 --- a/velox/connectors/hive/iceberg/tests/IcebergSplitReaderBenchmark.h +++ b/velox/connectors/hive/iceberg/tests/IcebergSplitReaderBenchmark.h @@ -70,9 +70,7 @@ class IcebergSplitReaderBenchmark { RowTypePtr& rowType, const std::vector& filterSpecs, std::vector& hitRows, - std::unordered_map< - facebook::velox::common::Subfield, - std::unique_ptr>& filters); + common::SubfieldFilters& filters); int read( const RowTypePtr& rowType, diff --git a/velox/connectors/hive/iceberg/tests/IcebergTestBase.cpp b/velox/connectors/hive/iceberg/tests/IcebergTestBase.cpp new file mode 100644 index 000000000000..bc8c34e761a4 --- /dev/null +++ b/velox/connectors/hive/iceberg/tests/IcebergTestBase.cpp @@ -0,0 +1,287 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/iceberg/tests/IcebergTestBase.h" + +#include +#include "velox/connectors/hive/iceberg/IcebergConnector.h" +#include "velox/connectors/hive/iceberg/IcebergSplit.h" +#include "velox/connectors/hive/iceberg/PartitionSpec.h" +#include "velox/expression/Expr.h" + +namespace facebook::velox::connector::hive::iceberg::test { + +const std::string kIcebergConnectorId{"test-iceberg"}; + +void IcebergTestBase::SetUp() { + HiveConnectorTestBase::SetUp(); +#ifdef VELOX_ENABLE_PARQUET + parquet::registerParquetReaderFactory(); + parquet::registerParquetWriterFactory(); +#endif + Type::registerSerDe(); + + // Register IcebergConnector. + IcebergConnectorFactory icebergFactory; + auto icebergConnector = icebergFactory.newConnector( + kIcebergConnectorId, + std::make_shared( + std::unordered_map()), + ioExecutor_.get()); + registerConnector(icebergConnector); + + connectorSessionProperties_ = std::make_shared( + std::unordered_map(), true); + + connectorConfig_ = + std::make_shared(std::make_shared( + std::unordered_map())); + + setupMemoryPools(); + + fuzzerOptions_.vectorSize = 100; + fuzzerOptions_.nullRatio = 0.1; + fuzzer_ = std::make_unique(fuzzerOptions_, opPool_.get(), 1); +} + +void IcebergTestBase::TearDown() { + fuzzer_.reset(); + connectorQueryCtx_.reset(); + connectorPool_.reset(); + opPool_.reset(); + root_.reset(); + queryCtx_.reset(); + unregisterConnector(kIcebergConnectorId); + HiveConnectorTestBase::TearDown(); +} + +void IcebergTestBase::setupMemoryPools() { + root_.reset(); + opPool_.reset(); + connectorPool_.reset(); + connectorQueryCtx_.reset(); + queryCtx_.reset(); + + root_ = memory::memoryManager()->addRootPool( + "IcebergTest", 1L << 30, exec::MemoryReclaimer::create()); + opPool_ = root_->addLeafChild("operator"); + connectorPool_ = + root_->addAggregateChild("connector", exec::MemoryReclaimer::create()); + + queryCtx_ = core::QueryCtx::create(nullptr, core::QueryConfig({})); + auto expressionEvaluator = std::make_unique( + queryCtx_.get(), opPool_.get()); + + connectorQueryCtx_ = std::make_unique( + opPool_.get(), + connectorPool_.get(), + connectorSessionProperties_.get(), + nullptr, + common::PrefixSortConfig(), + std::move(expressionEvaluator), + nullptr, + "query.IcebergTest", + "task.IcebergTest", + "planNodeId.IcebergTest", + 0, + ""); +} + +std::vector IcebergTestBase::createTestData( + RowTypePtr rowType, + int32_t numBatches, + vector_size_t rowsPerBatch, + double nullRatio) { + std::vector vectors; + vectors.reserve(numBatches); + + fuzzerOptions_.nullRatio = nullRatio; + fuzzerOptions_.allowDictionaryVector = false; + fuzzerOptions_.timestampPrecision = + fuzzer::FuzzerTimestampPrecision::kMilliSeconds; + fuzzer_->setOptions(fuzzerOptions_); + + for (auto i = 0; i < numBatches; ++i) { + vectors.push_back(fuzzer_->fuzzRow(rowType, rowsPerBatch, false)); + } + + return vectors; +} + +std::shared_ptr IcebergTestBase::createPartitionSpec( + const RowTypePtr& rowType, + const std::vector& partitionFields) { + std::vector fields; + for (const auto& partitionField : partitionFields) { + fields.push_back( + IcebergPartitionSpec::Field{ + rowType->nameOf(partitionField.id), + rowType->childAt(partitionField.id), + partitionField.type, + partitionField.parameter}); + } + + return fields.empty() ? nullptr + : std::make_shared(1, fields); +} + +void addColumnHandles( + const RowTypePtr& rowType, + const std::vector& partitionFields, + std::vector& columnHandles) { + std::unordered_set partitionColumnIds; + for (const auto& field : partitionFields) { + partitionColumnIds.insert(field.id); + } + + columnHandles.reserve(rowType->size()); + for (auto i = 0; i < rowType->size(); ++i) { + const auto columnType = partitionColumnIds.contains(i) + ? HiveColumnHandle::ColumnType::kPartitionKey + : HiveColumnHandle::ColumnType::kRegular; + + columnHandles.push_back( + std::make_shared( + rowType->nameOf(i), + columnType, + rowType->childAt(i), + rowType->childAt(i))); + } +} + +IcebergInsertTableHandlePtr IcebergTestBase::createInsertTableHandle( + const RowTypePtr& rowType, + const std::string& outputDirectoryPath, + const std::vector& partitionFields) { + std::vector columnHandles; + addColumnHandles(rowType, partitionFields, columnHandles); + + auto locationHandle = std::make_shared( + outputDirectoryPath, + outputDirectoryPath, + LocationHandle::TableType::kNew); + + auto partitionSpec = createPartitionSpec(rowType, partitionFields); + + return std::make_shared( + /*inputColumns=*/columnHandles, + locationHandle, + /*tableStorageFormat=*/fileFormat_, + partitionSpec, + /*compressionKind=*/common::CompressionKind::CompressionKind_ZSTD); +} + +std::shared_ptr IcebergTestBase::createDataSink( + const RowTypePtr& rowType, + const std::string& outputDirectoryPath, + const std::vector& partitionFields) { + auto tableHandle = + createInsertTableHandle(rowType, outputDirectoryPath, partitionFields); + return std::make_shared( + rowType, + tableHandle, + connectorQueryCtx_.get(), + CommitStrategy::kNoCommit, + connectorConfig_, + std::string(kDefaultIcebergFunctionPrefix)); +} + +std::shared_ptr IcebergTestBase::createDataSinkAndAppendData( + const std::vector& vectors, + const std::string& dataPath, + const std::vector& partitionFields) { + VELOX_CHECK(!vectors.empty(), "vectors cannot be empty"); + + auto rowType = vectors.front()->rowType(); + auto dataSink = createDataSink(rowType, dataPath, partitionFields); + + for (const auto& vector : vectors) { + dataSink->appendData(vector); + } + EXPECT_TRUE(dataSink->finish()); + return dataSink; +} + +std::vector IcebergTestBase::listFiles( + const std::string& dirPath) { + std::vector files; + if (!std::filesystem::exists(dirPath)) { + return files; + } + + for (auto& dirEntry : + std::filesystem::recursive_directory_iterator(dirPath)) { + if (dirEntry.is_regular_file()) { + files.push_back(dirEntry.path().string()); + } + } + return files; +} + +std::unordered_map> +IcebergTestBase::extractPartitionKeys(const std::string& filePath) { + std::unordered_map> partitionKeys; + + std::vector pathComponents; + folly::split("/", filePath, pathComponents); + for (const auto& component : pathComponents) { + if (component.find('=') != std::string::npos) { + std::vector keys; + folly::split('=', component, keys); + if (keys.size() == 2) { + if (keys[1] == "null") { + partitionKeys[keys[0]] = std::nullopt; + } else { + partitionKeys[keys[0]] = keys[1]; + } + } + } + } + + return partitionKeys; +} + +std::vector> +IcebergTestBase::createSplitsForDirectory(const std::string& directory) { + std::vector> splits; + std::unordered_map customSplitInfo; + customSplitInfo["table_format"] = "hive-iceberg"; + + auto files = listFiles(directory); + for (const auto& filePath : files) { + auto partitionKeys = extractPartitionKeys(filePath); + + const auto file = filesystems::getFileSystem(filePath, nullptr) + ->openFileForRead(filePath); + splits.push_back( + std::make_shared( + kIcebergConnectorId, + filePath, + fileFormat_, + 0, + file->size(), + partitionKeys, + std::nullopt, + customSplitInfo, + nullptr, + /*cacheable=*/true, + std::vector())); + } + + return splits; +} + +} // namespace facebook::velox::connector::hive::iceberg::test diff --git a/velox/connectors/hive/iceberg/tests/IcebergTestBase.h b/velox/connectors/hive/iceberg/tests/IcebergTestBase.h new file mode 100644 index 000000000000..c9dd8cacdc34 --- /dev/null +++ b/velox/connectors/hive/iceberg/tests/IcebergTestBase.h @@ -0,0 +1,104 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include "velox/connectors/hive/iceberg/IcebergDataSink.h" +#include "velox/exec/tests/utils/HiveConnectorTestBase.h" +#include "velox/exec/tests/utils/TempDirectoryPath.h" +#include "velox/vector/fuzzer/VectorFuzzer.h" +#ifdef VELOX_ENABLE_PARQUET +#include "velox/dwio/parquet/RegisterParquetWriter.h" +#include "velox/dwio/parquet/reader/ParquetReader.h" +#endif + +namespace facebook::velox::connector::hive::iceberg::test { + +extern const std::string kIcebergConnectorId; + +struct PartitionField { + // 0-based column index. + int32_t id; + TransformType type; + std::optional parameter; +}; + +class IcebergTestBase : public exec::test::HiveConnectorTestBase { + protected: + void SetUp() override; + + void TearDown() override; + + std::vector createTestData( + RowTypePtr rowType, + int32_t numBatches, + vector_size_t rowsPerBatch, + double nullRatio = 0.0); + + std::shared_ptr createDataSink( + const RowTypePtr& rowType, + const std::string& outputDirectoryPath, + const std::vector& partitionFields = {}); + + std::shared_ptr createDataSinkAndAppendData( + const std::vector& vectors, + const std::string& dataPath, + const std::vector& partitionFields = {}); + + std::vector> createSplitsForDirectory( + const std::string& directory); + + std::vector listFiles(const std::string& dirPath); + + std::shared_ptr createPartitionSpec( + const RowTypePtr& rowType, + const std::vector& partitionFields); + + /// Extracts partition key-value pairs from a file path. + /// Returns a map where keys are partition column names and values are + /// partition values (std::nullopt for null values). + /// Example: "/path/to/c1=10/c2=null/file.parquet" returns + /// {{"c1", "10"}, {"c2", std::nullopt}}. + static std::unordered_map> + extractPartitionKeys(const std::string& filePath); + + dwio::common::FileFormat fileFormat_{dwio::common::FileFormat::PARQUET}; + std::shared_ptr opPool_; + std::unique_ptr connectorQueryCtx_; + + private: + IcebergInsertTableHandlePtr createInsertTableHandle( + const RowTypePtr& rowType, + const std::string& outputDirectoryPath, + const std::vector& partitionFields = {}); + + std::vector listPartitionDirectories( + const std::string& dataPath); + + void setupMemoryPools(); + + std::shared_ptr root_; + std::shared_ptr connectorPool_; + std::shared_ptr connectorSessionProperties_; + std::shared_ptr connectorConfig_; + VectorFuzzer::Options fuzzerOptions_; + std::unique_ptr fuzzer_; + std::shared_ptr queryCtx_; +}; + +} // namespace facebook::velox::connector::hive::iceberg::test diff --git a/velox/connectors/hive/iceberg/tests/Main.cpp b/velox/connectors/hive/iceberg/tests/Main.cpp new file mode 100644 index 000000000000..3c9dd6615055 --- /dev/null +++ b/velox/connectors/hive/iceberg/tests/Main.cpp @@ -0,0 +1,29 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/common/process/ThreadDebugInfo.h" + +#include +#include + +// This main is needed for some tests on linux. +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + // Signal handler required for ThreadDebugInfoTest + facebook::velox::process::addDefaultFatalSignalHandler(); + folly::Init init(&argc, &argv, false); + return RUN_ALL_TESTS(); +} diff --git a/velox/connectors/hive/iceberg/tests/PartitionNameTest.cpp b/velox/connectors/hive/iceberg/tests/PartitionNameTest.cpp new file mode 100644 index 000000000000..30b74fb5ad6f --- /dev/null +++ b/velox/connectors/hive/iceberg/tests/PartitionNameTest.cpp @@ -0,0 +1,482 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/iceberg/IcebergConnector.h" + +#include + +#include "velox/common/encode/Base64.h" +#include "velox/connectors/hive/iceberg/IcebergPartitionName.h" +#include "velox/connectors/hive/iceberg/TransformEvaluator.h" +#include "velox/connectors/hive/iceberg/TransformExprBuilder.h" +#include "velox/connectors/hive/iceberg/tests/IcebergTestBase.h" + +namespace facebook::velox::connector::hive::iceberg { + +using namespace facebook::velox; + +namespace { + +class PartitionNameTest : public test::IcebergTestBase { + protected: + // Generates partition IDs for the input rows and verifies that the resulting + // partition paths match the expected paths. Each row is processed + // independently, and its generated partition path is compared against the + // corresponding entry in expectedPaths. + // + // @param input Input data to generate partition IDs from. Must have the + // same size as expectedPaths. + // @param partitionSpec The IcebergPartitionSpec defining the partition + // transforms. The partition channels are determined by matching field names + // from the spec to column names in the input's type. + // @param expectedPaths Expected partition path strings, one per row. Each + // path should be the complete partition directory name (e.g., "col1=val1"). + void verifyPartitionPaths( + const RowVectorPtr& input, + const std::shared_ptr& partitionSpec, + const std::vector& expectedPaths) const { + ASSERT_EQ(expectedPaths.size(), input->size()); + std::vector partitionChannels(partitionSpec->fields.size()); + auto rowType = input->rowType(); + for (auto i = 0; i < partitionSpec->fields.size(); ++i) { + partitionChannels[i] = + rowType->getChildIdx(partitionSpec->fields[i].name); + } + + // Step 1: Build transform expressions and create evaluator. + auto transformExpressions = TransformExprBuilder::toExpressions( + partitionSpec, + partitionChannels, + rowType, + std::string(kDefaultIcebergFunctionPrefix)); + auto transformEvaluator = std::make_unique( + transformExpressions, connectorQueryCtx_.get()); + + // Step 2: Apply transforms to input partition columns. + auto transformedColumns = transformEvaluator->evaluate(input); + + std::vector partitionKeyTypes; + std::vector partitionKeyNames; + for (const auto& field : partitionSpec->fields) { + partitionKeyTypes.emplace_back(field.resultType()); + std::string key = field.transformType == TransformType::kIdentity + ? field.name + : fmt::format( + "{}_{}", + field.name, + TransformTypeName::toName(field.transformType)); + partitionKeyNames.emplace_back(std::move(key)); + } + + auto partitionRowType = + ROW(std::move(partitionKeyNames), std::move(partitionKeyTypes)); + // Step 3: Create RowVector based on transformed columns. + auto transformedRowVector = std::make_shared( + connectorQueryCtx_->memoryPool(), + partitionRowType, + nullptr, + input->size(), + std::move(transformedColumns)); + + // Step 4: Generate partition IDs from transformed data. + // The transformed row vector has columns in the same order as partition + // spec fields, so channels are sequential: 0, 1, 2, ... + std::vector transformedChannels( + partitionSpec->fields.size()); + std::iota(transformedChannels.begin(), transformedChannels.end(), 0); + + auto idGenerator = std::make_unique( + partitionRowType, + transformedChannels, + /*maxPartitions=*/128, + connectorQueryCtx_->memoryPool()); + + auto nameGenerator = std::make_unique(partitionSpec); + + raw_vector partitionIds(input->size()); + idGenerator->run(transformedRowVector, partitionIds); + + for (auto i = 0; i < input->size(); ++i) { + std::string partitionName = nameGenerator->partitionName( + partitionIds[i], idGenerator->partitionValues(), false); + ASSERT_EQ(partitionName, expectedPaths[i]); + } + } +}; + +TEST_F(PartitionNameTest, identity) { + std::vector> input = { + {INTEGER(), makeConstant(42, 1), "42"}, + {BIGINT(), makeConstant(9'876'543'210, 1), "9876543210"}, + {VARCHAR(), + makeConstant("test string partition column name", 1), + "test+string+partition+column+name"}, + {VARBINARY(), + makeConstant("\x48\x65\x6c\x6c\x6f", 1, VARBINARY()), + "SGVsbG8%3D"}, + {DECIMAL(18, 4), + makeConstant(12'345'678'901'234, 1, DECIMAL(18, 4)), + "1234567890.1234"}, + {BOOLEAN(), makeConstant(true, 1), "true"}, + {DATE(), makeConstant(18'262, 1, DATE()), "2020-01-01"}, + }; + + for (const auto& [type, value, expectedValue] : input) { + const auto& partitionSpec = createPartitionSpec( + ROW({"c0"}, {type}), {{0, TransformType::kIdentity, std::nullopt}}); + + verifyPartitionPaths( + makeRowVector({value}), + partitionSpec, + {fmt::format("c0={}", expectedValue)}); + } +} + +TEST_F(PartitionNameTest, timestamp) { + std::vector timestamps = { + Timestamp(253402300800, 100000000), // +10000-01-01T00:00:00.1. + Timestamp(-62170000000, 0), // -0001-11-29T19:33:20. + Timestamp(-62135577748, 999000000), // 0001-01-01T05:17:32.999. + Timestamp(0, 0), // 1970-01-01T00:00. + Timestamp(1609459200, 999000000), // 2021-01-01T00:00. + Timestamp(1640995200, 500000000), // 2022-01-01T00:00:00.5. + Timestamp(1672531200, 123000000), // 2023-01-01T00:00:00.123. + Timestamp(-1, 999000000), // 1969-12-31T23:59:59.999. + Timestamp(1, 1000000), // 1970-01-01T00:00:01.001. + Timestamp(-62167219199, 0), // 0000-01-01T00:00:01. + Timestamp(-377716279140, 321000000), // -10000-01-01T01:01:00.321. + Timestamp(253402304660, 321000000), // +10000-01-01T01:01:00.321. + Timestamp(951782400, 0), // 2000-02-29T00:00:00 (leap year). + Timestamp(4107456000, 0), // 2100-02-28T00:00:00. + Timestamp(-86400, 0), // 1969-12-31T00:00:00. + }; + + std::vector expectedPartitionNames = { + "c0=%2B10000-01-01T00%3A00%3A00.1", + "c0=-0001-11-29T19%3A33%3A20", + "c0=0001-01-01T05%3A17%3A32.999", + "c0=1970-01-01T00%3A00%3A00", + "c0=2021-01-01T00%3A00%3A00.999", + "c0=2022-01-01T00%3A00%3A00.5", + "c0=2023-01-01T00%3A00%3A00.123", + "c0=1969-12-31T23%3A59%3A59.999", + "c0=1970-01-01T00%3A00%3A01.001", + "c0=0000-01-01T00%3A00%3A01", + "c0=-10000-08-24T19%3A21%3A00.321", + "c0=%2B10000-01-01T01%3A04%3A20.321", + "c0=2000-02-29T00%3A00%3A00", + "c0=2100-02-28T00%3A00%3A00", + "c0=1969-12-31T00%3A00%3A00", + }; + + const auto& partitionSpec = createPartitionSpec( + ROW({"c0"}, {TIMESTAMP()}), + {{0, TransformType::kIdentity, std::nullopt}}); + + verifyPartitionPaths( + makeRowVector({makeFlatVector(timestamps)}), + partitionSpec, + expectedPartitionNames); +} + +TEST_F(PartitionNameTest, null) { + std::vector< + std::tuple, VectorPtr>> + input = { + {INTEGER(), + TransformType::kBucket, + 32, + makeConstant(std::nullopt, 1)}, + {VARCHAR(), + TransformType::kTruncate, + 100, + makeConstant(std::nullopt, 1)}, + {DECIMAL(18, 3), + TransformType::kIdentity, + std::nullopt, + makeConstant(std::nullopt, 1, DECIMAL(18, 3))}, + {TIMESTAMP(), + TransformType::kYear, + std::nullopt, + makeConstant(std::nullopt, 1)}, + {TIMESTAMP(), + TransformType::kMonth, + std::nullopt, + makeConstant(std::nullopt, 1)}, + {DATE(), + TransformType::kDay, + std::nullopt, + makeConstant(std::nullopt, 1, DATE())}, + {TIMESTAMP(), + TransformType::kHour, + std::nullopt, + makeConstant(std::nullopt, 1)}, + }; + + for (const auto& [type, transformType, parameter, value] : input) { + auto rowType = ROW({"c0"}, {type}); + const auto& partitionSpec = + createPartitionSpec(rowType, {{0, transformType, parameter}}); + if (transformType == TransformType::kIdentity) { + verifyPartitionPaths(makeRowVector({value}), partitionSpec, {"c0=null"}); + } else { + verifyPartitionPaths( + makeRowVector({value}), + partitionSpec, + {fmt::format( + "c0_{}=null", TransformTypeName::toName(transformType))}); + } + } +} + +// test both partition column name and partition key encoding. +TEST_F(PartitionNameTest, specialChars) { + std::vector> inputs = { + {"abc123", "abc123"}, + {"ABC123", "ABC123"}, + {"a.b-c_d*e", "a.b-c_d*e"}, + {"space test", "space+test"}, + {"slash/test", "slash%2Ftest"}, + {"question?test", "question%3Ftest"}, + {"percent%test", "percent%25test"}, + {"hash#test", "hash%23test"}, + {"ampersand&test", "ampersand%26test"}, + {"equals=test", "equals%3Dtest"}, + {"plus+test", "plus%2Btest"}, + {"comma,test", "comma%2Ctest"}, + {"semicolon;test", "semicolon%3Btest"}, + {"at@test", "at%40test"}, + {"exclamation!test", "exclamation%21test"}, + {"dollar$test", "dollar%24test"}, + {"backslash\\test", "backslash%5Ctest"}, + {"quote\"test", "quote%22test"}, + {"apostrophe'test", "apostrophe%27test"}, + {"paren(test", "paren%28test"}, + {"paren)test", "paren%29test"}, + {"lessthan", "greater%3Ethan"}, + {"colon:test", "colon%3Atest"}, + {"pipe|test", "pipe%7Ctest"}, + {"bracket[test", "bracket%5Btest"}, + {"bracket]test", "bracket%5Dtest"}, + {"brace{test", "brace%7Btest"}, + {"brace}test", "brace%7Dtest"}, + {"caret^test", "caret%5Etest"}, + {"tilde~test", "tilde%7Etest"}, + {"backtick`test", "backtick%60test"}, + {"newline\ntest", "newline%0Atest"}, + {"carriage\rreturn", "carriage%0Dreturn"}, + {"tab\ttest", "tab%09test"}, + {"unicode\u00A9test", "unicode%C2%A9test"}, + {"email@example.com", "email%40example.com"}, + {"user:password@host:port/path", "user%3Apassword%40host%3Aport%2Fpath"}, + {"https://github.com/facebookincubator/velox", + "https%3A%2F%2Fgithub.com%2Ffacebookincubator%2Fvelox"}, + {"a+b=c&d=e+f", "a%2Bb%3Dc%26d%3De%2Bf"}, + {"a#b=c/d e", "a%23b%3Dc%2Fd+e"}, + {"special!@#$%^&*()_+", "special%21%40%23%24%25%5E%26*%28%29_%2B"}, + }; + + for (const auto& [input, encodedValue] : inputs) { + const auto& partitionSpec = createPartitionSpec( + ROW({input}, {VARCHAR()}), + {{0, TransformType::kIdentity, std::nullopt}}); + + verifyPartitionPaths( + makeRowVector( + {input}, {makeConstant(StringView(input), 1)}), + partitionSpec, + {fmt::format("{}={}", encodedValue, encodedValue)}); + } +} + +TEST_F(PartitionNameTest, multipleRows) { + const auto& partitionSpec = createPartitionSpec( + ROW({"c0", "c1"}, {INTEGER(), VARCHAR()}), + { + {0, TransformType::kBucket, 8}, + {1, TransformType::kIdentity, std::nullopt}, + }); + + verifyPartitionPaths( + makeRowVector({ + makeFlatVector({10, 20, 30, -100}), + makeFlatVector({"value1", "VALue2", "VALUE3", ""}), + }), + partitionSpec, + { + "c0_bucket=4/c1=value1", + "c0_bucket=3/c1=VALue2", + "c0_bucket=3/c1=VALUE3", + "c0_bucket=6/c1=", + }); +} + +TEST_F(PartitionNameTest, year) { + const auto& partitionSpec = createPartitionSpec( + ROW({"c0"}, {TIMESTAMP()}), {{0, TransformType::kYear, std::nullopt}}); + + std::vector timestamps = { + Timestamp(0, 0), + Timestamp(1609459200, 0), + Timestamp(1640995200, 0), + Timestamp(-31536000, 0), + Timestamp(253402300800, 0), + }; + + verifyPartitionPaths( + makeRowVector({makeFlatVector(timestamps)}), + partitionSpec, + { + "c0_year=1970", + "c0_year=2021", + "c0_year=2022", + "c0_year=1969", + "c0_year=10000", + }); +} + +TEST_F(PartitionNameTest, yearWithDate) { + const auto& partitionSpec = createPartitionSpec( + ROW({"c0"}, {DATE()}), {{0, TransformType::kYear, std::nullopt}}); + + verifyPartitionPaths( + makeRowVector({makeFlatVector({0, 365, 18262, -365}, DATE())}), + partitionSpec, + { + "c0_year=1970", + "c0_year=1971", + "c0_year=2020", + "c0_year=1969", + }); +} + +TEST_F(PartitionNameTest, month) { + const auto& partitionSpec = createPartitionSpec( + ROW({"c0"}, {TIMESTAMP()}), {{0, TransformType::kMonth, std::nullopt}}); + + verifyPartitionPaths( + makeRowVector({makeFlatVector({ + Timestamp(0, 0), + Timestamp(2678400, 0), + Timestamp(1609459200, 0), + Timestamp(1640995200, 0), + Timestamp(-2678400, 0), + })}), + partitionSpec, + { + "c0_month=1970-01", + "c0_month=1970-02", + "c0_month=2021-01", + "c0_month=2022-01", + "c0_month=1969-12", + }); +} + +TEST_F(PartitionNameTest, monthWithDate) { + const auto& partitionSpec = createPartitionSpec( + ROW({"c0"}, {DATE()}), {{0, TransformType::kMonth, std::nullopt}}); + + verifyPartitionPaths( + makeRowVector({makeFlatVector({0, 31, 365, -31}, DATE())}), + partitionSpec, + { + "c0_month=1970-01", + "c0_month=1970-02", + "c0_month=1971-01", + "c0_month=1969-12", + }); +} + +TEST_F(PartitionNameTest, day) { + const auto& partitionSpec = createPartitionSpec( + ROW({"c0"}, {TIMESTAMP()}), {{0, TransformType::kDay, std::nullopt}}); + + std::vector timestamps = { + Timestamp(0, 0), + Timestamp(86400, 0), + Timestamp(1577836800, 0), + Timestamp(-86400, 0), + }; + + verifyPartitionPaths( + makeRowVector({makeFlatVector(timestamps)}), + partitionSpec, + { + "c0_day=1970-01-01", + "c0_day=1970-01-02", + "c0_day=2020-01-01", + "c0_day=1969-12-31", + }); +} + +TEST_F(PartitionNameTest, dayWithDate) { + const auto& partitionSpec = createPartitionSpec( + ROW({"c0"}, {DATE()}), {{0, TransformType::kDay, std::nullopt}}); + + verifyPartitionPaths( + makeRowVector({makeFlatVector({0, 1, 18262, -1}, DATE())}), + partitionSpec, + { + "c0_day=1970-01-01", + "c0_day=1970-01-02", + "c0_day=2020-01-01", + "c0_day=1969-12-31", + }); +} + +TEST_F(PartitionNameTest, hour) { + const auto& partitionSpec = createPartitionSpec( + ROW({"c0"}, {TIMESTAMP()}), {{0, TransformType::kHour, std::nullopt}}); + + std::vector timestamps = { + Timestamp(0, 0), + Timestamp(3600, 0), + Timestamp(86400, 0), + Timestamp(1577836800, 0), + Timestamp(-3600, 0), + }; + + verifyPartitionPaths( + makeRowVector({makeFlatVector(timestamps)}), + partitionSpec, + { + "c0_hour=1970-01-01-00", + "c0_hour=1970-01-01-01", + "c0_hour=1970-01-02-00", + "c0_hour=2020-01-01-00", + "c0_hour=1969-12-31-23", + }); +} + +TEST_F(PartitionNameTest, multipleTransformsSameColumn) { + const auto& partitionSpec = createPartitionSpec( + ROW({"c0"}, {TIMESTAMP()}), + { + {0, TransformType::kIdentity, std::nullopt}, + {0, TransformType::kYear, std::nullopt}, + {0, TransformType::kBucket, 10}, + }); + + verifyPartitionPaths( + makeRowVector({makeFlatVector({Timestamp(1609459200, 0)})}), + partitionSpec, + {"c0=2021-01-01T00%3A00%3A00/c0_year=2021/c0_bucket=0"}); +} + +} // namespace + +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/tests/PartitionSpecTest.cpp b/velox/connectors/hive/iceberg/tests/PartitionSpecTest.cpp new file mode 100644 index 000000000000..23a06340aad3 --- /dev/null +++ b/velox/connectors/hive/iceberg/tests/PartitionSpecTest.cpp @@ -0,0 +1,134 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/iceberg/PartitionSpec.h" + +#include +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h" +#include "velox/type/Type.h" + +namespace facebook::velox::connector::hive::iceberg { + +namespace { + +TEST(PartitionSpecTest, invalidColumnType) { + auto makeSpec = [](const TypePtr& type) { + std::vector fields = { + {"c0", type, TransformType::kIdentity, std::nullopt}, + }; + return std::make_shared(1, fields); + }; + + VELOX_ASSERT_USER_THROW( + makeSpec(ROW({{"a", INTEGER()}})), + "Type is not supported as a partition column: ROW"); + VELOX_ASSERT_USER_THROW( + makeSpec(ARRAY(INTEGER())), + "Type is not supported as a partition column: ARRAY"); + VELOX_ASSERT_USER_THROW( + makeSpec(MAP(VARCHAR(), INTEGER())), + "Type is not supported as a partition column: MAP"); + VELOX_ASSERT_USER_THROW( + makeSpec(TIMESTAMP_WITH_TIME_ZONE()), + "Type is not supported as a partition column: TIMESTAMP WITH TIME ZONE"); +} + +TEST(PartitionSpecTest, invalidMultipleTransforms) { + { + std::vector fields = { + {"c0", VARCHAR(), TransformType::kIdentity, std::nullopt}, + {"c0", VARCHAR(), TransformType::kIdentity, std::nullopt}, + }; + VELOX_ASSERT_USER_THROW( + std::make_shared(1, fields), + "Column: 'c0', Category: Identity, Transforms: [identity, identity]"); + } + + { + std::vector fields = { + {"c0", VARCHAR(), TransformType::kBucket, 16}, + {"c0", VARCHAR(), TransformType::kBucket, 32}, + }; + VELOX_ASSERT_USER_THROW( + std::make_shared(1, fields), + "Column: 'c0', Category: Bucket, Transforms: [bucket, bucket]"); + } + + { + std::vector fields = { + {"c0", VARCHAR(), TransformType::kTruncate, 2}, + {"c0", VARCHAR(), TransformType::kTruncate, 5}, + }; + VELOX_ASSERT_USER_THROW( + std::make_shared(1, fields), + "Column: 'c0', Category: Truncate, Transforms: [trunc, trunc]"); + } + + { + std::vector fields4 = { + {"c0", TIMESTAMP(), TransformType::kYear, std::nullopt}, + {"c0", TIMESTAMP(), TransformType::kMonth, std::nullopt}, + {"c0", TIMESTAMP(), TransformType::kDay, std::nullopt}, + {"c0", TIMESTAMP(), TransformType::kHour, std::nullopt}, + }; + VELOX_ASSERT_USER_THROW( + std::make_shared(1, fields4), + "Column: 'c0', Category: Temporal, Transforms: [year, month, day, hour]"); + } +} + +TEST(PartitionSpecTest, invalidMultipleTransformsMultipleColumns) { + std::vector fields = { + {"c0", DATE(), TransformType::kYear, std::nullopt}, + {"c0", DATE(), TransformType::kMonth, std::nullopt}, + {"c1", VARCHAR(), TransformType::kBucket, 16}, + {"c1", VARCHAR(), TransformType::kBucket, 32}, + }; + // order may vary due to map iteration. + VELOX_ASSERT_USER_THROW( + std::make_shared(1, fields), + "Column: 'c0', Category: Temporal, Transforms: [year, month]"); + VELOX_ASSERT_USER_THROW( + std::make_shared(1, fields), + "Column: 'c1', Category: Bucket, Transforms: [bucket, bucket]"); +} + +TEST(PartitionSpecTest, validMultipleTransforms) { + { + std::vector fields = { + {"c0", VARCHAR(), TransformType::kIdentity, std::nullopt}, + {"c0", VARCHAR(), TransformType::kBucket, 16}, + {"c0", VARCHAR(), TransformType::kTruncate, 10}, + }; + auto spec = std::make_shared(1, fields); + EXPECT_EQ(spec->fields.size(), 3); + } + + { + std::vector fields = { + {"c0", DATE(), TransformType::kYear, std::nullopt}, + {"c0", DATE(), TransformType::kBucket, 16}, + {"c0", DATE(), TransformType::kIdentity, std::nullopt}, + }; + auto spec = std::make_shared(1, fields); + EXPECT_EQ(spec->fields.size(), 3); + } +} + +} // namespace + +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/tests/PartitionValueFormatterTest.cpp b/velox/connectors/hive/iceberg/tests/PartitionValueFormatterTest.cpp new file mode 100644 index 000000000000..13cda035cbfb --- /dev/null +++ b/velox/connectors/hive/iceberg/tests/PartitionValueFormatterTest.cpp @@ -0,0 +1,145 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "velox/connectors/hive/iceberg/IcebergPartitionName.h" +#include "velox/type/Type.h" + +namespace facebook::velox::connector::hive::iceberg { + +namespace { + +template +std::string toPath(TransformType transform, T value, const TypePtr& type) { + return IcebergPartitionName::toName(value, type, transform); +} + +std::string timestampToPath(const Timestamp& timestamp) { + return toPath(TransformType::kIdentity, timestamp, TIMESTAMP()); +} + +std::string testString( + const std::string& value, + const TypePtr& typePtr = VARCHAR()) { + auto identityResult = + toPath(TransformType::kIdentity, StringView(value), typePtr); + auto truncateResult = + toPath(TransformType::kTruncate, StringView(value), typePtr); + EXPECT_EQ(identityResult, truncateResult); + return identityResult; +} + +std::string testVarbinary(const std::string& value) { + return testString(value, VARBINARY()); +} + +std::string testInteger(int32_t value) { + auto identityResult = toPath(TransformType::kIdentity, value, INTEGER()); + auto bucketResult = toPath(TransformType::kBucket, value, INTEGER()); + auto truncResult = toPath(TransformType::kTruncate, value, INTEGER()); + EXPECT_EQ(identityResult, truncResult); + EXPECT_EQ(bucketResult, truncResult); + return truncResult; +} + +TEST(IcebergPartitionPathTest, integer) { + EXPECT_EQ(testInteger(0), "0"); + EXPECT_EQ(testInteger(1), "1"); + EXPECT_EQ(testInteger(100), "100"); + EXPECT_EQ(testInteger(-100), "-100"); + EXPECT_EQ(testInteger(128), "128"); + EXPECT_EQ(testInteger(1024), "1024"); +} + +TEST(IcebergPartitionPathTest, date) { + EXPECT_EQ(toPath(TransformType::kIdentity, 18'262, DATE()), "2020-01-01"); + EXPECT_EQ(toPath(TransformType::kIdentity, 0, DATE()), "1970-01-01"); + EXPECT_EQ(toPath(TransformType::kIdentity, -1, DATE()), "1969-12-31"); + EXPECT_EQ(toPath(TransformType::kIdentity, 2'932'897, DATE()), "10000-01-01"); +} + +TEST(IcebergPartitionPathTest, boolean) { + EXPECT_EQ(toPath(TransformType::kIdentity, true, BOOLEAN()), "true"); + EXPECT_EQ(toPath(TransformType::kIdentity, false, BOOLEAN()), "false"); +} + +TEST(IcebergPartitionPathTest, string) { + EXPECT_EQ(testString("a/b/c=d"), "a/b/c=d"); + EXPECT_EQ(testString(""), ""); + EXPECT_EQ(testString("abc"), "abc"); +} + +TEST(IcebergPartitionPathTest, varbinary) { + EXPECT_EQ(testVarbinary("\x48\x65\x6c\x6c\x6f"), "SGVsbG8="); + EXPECT_EQ(testVarbinary("\x1\x2\x3"), "AQID"); + EXPECT_EQ(testVarbinary(""), ""); +} + +TEST(IcebergPartitionPathTest, timestamp) { + EXPECT_EQ(timestampToPath(Timestamp(0, 0)), "1970-01-01T00:00:00"); + EXPECT_EQ( + timestampToPath(Timestamp(1'609'459'200, 999'000'000)), + "2021-01-01T00:00:00.999"); + EXPECT_EQ( + timestampToPath(Timestamp(1'640'995'200, 500'000'000)), + "2022-01-01T00:00:00.5"); + EXPECT_EQ( + timestampToPath(Timestamp(-1, 999'000'000)), "1969-12-31T23:59:59.999"); + EXPECT_EQ( + timestampToPath(Timestamp(253'402'300'800, 100'000'000)), + "+10000-01-01T00:00:00.1"); + EXPECT_EQ( + timestampToPath(Timestamp(-62'170'000'000, 0)), "-0001-11-29T19:33:20"); + EXPECT_EQ( + timestampToPath(Timestamp(-62'167'219'199, 0)), "0000-01-01T00:00:01"); +} + +TEST(IcebergPartitionPathTest, year) { + EXPECT_EQ(toPath(TransformType::kYear, 0, INTEGER()), "1970"); + EXPECT_EQ(toPath(TransformType::kYear, 1, INTEGER()), "1971"); + EXPECT_EQ(toPath(TransformType::kYear, 8'030, INTEGER()), "10000"); + EXPECT_EQ(toPath(TransformType::kYear, -1, INTEGER()), "1969"); + EXPECT_EQ(toPath(TransformType::kYear, -50, INTEGER()), "1920"); +} + +TEST(IcebergPartitionPathTest, month) { + EXPECT_EQ(toPath(TransformType::kMonth, 0, INTEGER()), "1970-01"); + EXPECT_EQ(toPath(TransformType::kMonth, 1, INTEGER()), "1970-02"); + EXPECT_EQ(toPath(TransformType::kMonth, 11, INTEGER()), "1970-12"); + EXPECT_EQ(toPath(TransformType::kMonth, 612, INTEGER()), "2021-01"); + EXPECT_EQ(toPath(TransformType::kMonth, -1, INTEGER()), "1969-12"); + EXPECT_EQ(toPath(TransformType::kMonth, -13, INTEGER()), "1968-12"); +} + +TEST(IcebergPartitionPathTest, day) { + EXPECT_EQ(toPath(TransformType::kDay, 0, DATE()), "1970-01-01"); + EXPECT_EQ(toPath(TransformType::kDay, 1, DATE()), "1970-01-02"); + EXPECT_EQ(toPath(TransformType::kDay, 18'262, DATE()), "2020-01-01"); + EXPECT_EQ(toPath(TransformType::kDay, -1, DATE()), "1969-12-31"); +} + +TEST(IcebergPartitionPathTest, hour) { + EXPECT_EQ(toPath(TransformType::kHour, 0, INTEGER()), "1970-01-01-00"); + EXPECT_EQ(toPath(TransformType::kHour, 1, INTEGER()), "1970-01-01-01"); + EXPECT_EQ(toPath(TransformType::kHour, 24, INTEGER()), "1970-01-02-00"); + EXPECT_EQ(toPath(TransformType::kHour, 438'288, INTEGER()), "2020-01-01-00"); + EXPECT_EQ(toPath(TransformType::kHour, -1, INTEGER()), "1969-12-31-23"); +} + +} // namespace + +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/tests/TransformE2ETest.cpp b/velox/connectors/hive/iceberg/tests/TransformE2ETest.cpp new file mode 100644 index 000000000000..a8fba90bef0b --- /dev/null +++ b/velox/connectors/hive/iceberg/tests/TransformE2ETest.cpp @@ -0,0 +1,650 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "velox/connectors/hive/iceberg/tests/IcebergTestBase.h" +#include "velox/exec/tests/utils/AssertQueryBuilder.h" +#include "velox/exec/tests/utils/PlanBuilder.h" +#include "velox/exec/tests/utils/TempDirectoryPath.h" +#include "velox/functions/iceberg/Murmur3Hash32.h" + +namespace facebook::velox::connector::hive::iceberg { + +using namespace facebook::velox::exec::test; + +namespace { + +#ifdef VELOX_ENABLE_PARQUET + +class TransformE2ETest : public test::IcebergTestBase { + protected: + static constexpr int32_t kDefaultNumBatches = 2; + static constexpr int32_t kDefaultRowsPerBatch = 100; + + std::vector createTimestampTestData() { + std::vector batches; + // 1. Aligned to hour boundaries (no minutes/seconds) since hour is the + // finest granularity transform tested. + // 2. Include negative epoch values (pre-1970) to test edge cases. + // 3. Span multiple years and months. + static const std::vector timestamps = { + // 1969-01-01 00:00:00 (negative epoch) + Timestamp(-31536000, 0), + // 1969-12-31 00:00:00 (day before epoch) + Timestamp(-86400, 0), + // 1970-01-01 00:00:00 + Timestamp(0, 0), + // 1970-01-01 01:00:00 + Timestamp(3600, 0), + // 1970-01-02 00:00:00 + Timestamp(86400, 0), + // 1970-01-31 00:00:00 + Timestamp(2592000, 0), + // 1971-01-01 00:00:00 + Timestamp(31536000, 0), + // 2021-01-01 00:00:00 + Timestamp(1609459200, 0), + // 2021-01-02 00:00:00 + Timestamp(1609545600, 0), + // 2021-02-01 00:00:00 + Timestamp(1612224000, 0), + // 2022-01-01 00:00:00 + Timestamp(1640995200, 0), + // 2023-01-01 00:00:00 + Timestamp(1672531200, 0), + // 2100-01-01 00:00:00 + Timestamp(4102444800, 0), + }; + + for (auto i = 0; i < kDefaultNumBatches; i++) { + auto timestampVector = makeFlatVector( + kDefaultRowsPerBatch, + [](auto row) { return timestamps[row % timestamps.size()]; }); + batches.push_back(makeRowVector({timestampVector})); + } + return batches; + } + + std::shared_ptr writeBatchesWithTransforms( + const std::vector& batches, + const std::vector& partitionFields) { + VELOX_CHECK(!batches.empty(), "input cannot be empty"); + + int64_t expectedRowCount = 0; + for (const auto& batch : batches) { + expectedRowCount += batch->size(); + } + + auto rowType = batches.front()->rowType(); + auto outputDirectory = TempDirectoryPath::create(); + const auto dataSink = createDataSinkAndAppendData( + batches, outputDirectory->getPath(), partitionFields); + dataSink->close(); + verifyTotalRowCount(rowType, outputDirectory->getPath(), expectedRowCount); + return outputDirectory; + } + + // Generate a key from a timestamp and transform type. + // The key format depends on the transform type: + // - kMonth: "YYYY-MM" + // - kDay: "YYYY-MM-DD" + // - kHour: "YYYY-MM-DD-HH" + static std::string timestampToKey( + const Timestamp& ts, + TransformType transformType) { + std::tm tm; + if (!Timestamp::epochToCalendarUtc(ts.getSeconds(), tm)) { + return ""; + } + + int32_t year = tm.tm_year + 1900; + int32_t month = tm.tm_mon + 1; + int32_t day = tm.tm_mday; + int32_t hour = tm.tm_hour; + + switch (transformType) { + case TransformType::kMonth: + return fmt::format("{:04d}-{:02d}", year, month); + case TransformType::kDay: + return fmt::format("{:04d}-{:02d}-{:02d}", year, month, day); + case TransformType::kHour: + return fmt::format( + "{:04d}-{:02d}-{:02d}-{:02d}", year, month, day, hour); + default: + VELOX_UNREACHABLE(); + } + } + + // Helper function to build expected counts map from timestamp batches. + // The key format depends on the transform type: + // - kMonth: "YYYY-MM" + // - kDay: "YYYY-MM-DD" + // - kHour: "YYYY-MM-DD-HH" + static std::unordered_map + buildExpectedCountsFromTimestamps( + const std::vector& batches, + TransformType transformType) { + std::unordered_map expectedCounts; + + for (const auto& batch : batches) { + auto timestampVector = batch->childAt(0)->as>(); + for (auto i = 0; i < batch->size(); i++) { + Timestamp ts = timestampVector->valueAt(i); + std::string key = timestampToKey(ts, transformType); + expectedCounts[key]++; + } + } + return expectedCounts; + } + + static std::string dirName(const std::string& path) { + return std::filesystem::path(path).filename().string(); + } + + static std::vector firstLevelDirectories( + const std::string& basePath) { + std::vector directories; + for (const auto& entry : std::filesystem::directory_iterator(basePath)) { + if (entry.is_directory()) { + directories.push_back(entry.path().string()); + } + } + return directories; + } + + static std::vector listDirectoriesRecursively( + const std::string& path) { + std::vector directories; + auto firstLevelDirs = firstLevelDirectories(path); + + for (const auto& dir : firstLevelDirs) { + directories.push_back(dirName(dir)); + auto subDirs = listDirectoriesRecursively(dir); + directories.insert(directories.end(), subDirs.begin(), subDirs.end()); + } + + return directories; + } + + static std::vector verifyPartitionCount( + const std::string& outputPath, + int32_t expectedCount) { + const auto partitionDirs = firstLevelDirectories(outputPath); + EXPECT_EQ(partitionDirs.size(), expectedCount); + for (const auto& dir : partitionDirs) { + const auto name = dirName(dir); + EXPECT_TRUE(name.find('=') != std::string::npos) + << "Partition directory " << name + << " does not follow Iceberg naming convention"; + } + return partitionDirs; + } + + // Verify the total row count across all partitions. + void verifyTotalRowCount( + const RowTypePtr& rowType, + const std::string& outputPath, + int32_t expectedRowCount) { + auto splits = createSplitsForDirectory(outputPath); + + const auto plan = PlanBuilder() + .startTableScan() + .connectorId(test::kIcebergConnectorId) + .outputType(rowType) + .endTableScan() + .planNode(); + + const auto actualRowCount = + AssertQueryBuilder(plan).splits(splits).countResults(); + + ASSERT_EQ(actualRowCount, expectedRowCount); + } + + // Verify data in a specific partition. + void verifyPartitionData( + const RowTypePtr& rowType, + const std::string& partitionPath, + const std::string& partitionFilter, + const int32_t expectedRowCount) { + auto splits = createSplitsForDirectory(partitionPath); + + auto scanPlan = PlanBuilder() + .startTableScan() + .connectorId(test::kIcebergConnectorId) + .outputType(rowType) + .endTableScan() + .planNode(); + + const auto actualRowCount = + AssertQueryBuilder(scanPlan).splits(splits).countResults(); + + ASSERT_EQ(actualRowCount, expectedRowCount); + + const auto filterPlan = PlanBuilder() + .startTableScan() + .connectorId(test::kIcebergConnectorId) + .outputType(rowType) + .endTableScan() + .filter(partitionFilter) + .planNode(); + const auto filteredRowCount = + AssertQueryBuilder(filterPlan).splits(splits).countResults(); + ASSERT_EQ(expectedRowCount, filteredRowCount); + } + + static std::pair parsePartitionDirName( + const std::string& name) { + auto eq = name.find('='); + VELOX_CHECK(eq != std::string::npos); + auto us = name.rfind('_', eq - 1); + auto columnName = name.substr(0, us); + auto value = name.substr(eq + 1); + return {columnName, value}; + } + + static int32_t computeBucketHash( + const StringView& value, + int32_t numBuckets) { + int32_t hash = functions::iceberg::Murmur3Hash32::hashBytes( + value.data(), value.size()); + return ((hash & 0x7FFFFFFF) % numBuckets); + } +}; + +TEST_F(TransformE2ETest, identity) { + constexpr auto rowsPerBatch = 10; + constexpr auto duplicates = 5; + auto rowType = ROW({"c0"}, {INTEGER()}); + auto baseVectors = createTestData(rowType, kDefaultNumBatches, rowsPerBatch); + + // Duplicate each row to create multiple rows with the same partition key. + std::vector vectors; + for (const auto& baseVector : baseVectors) { + auto duplicatedColumn = wrapInDictionary( + makeIndices( + baseVector->size() * duplicates, + [duplicates](auto row) { return row / duplicates; }), + baseVector->size() * duplicates, + baseVector->childAt(0)); + vectors.push_back(makeRowVector({duplicatedColumn})); + } + + auto outputDirectory = writeBatchesWithTransforms( + vectors, {{0, TransformType::kIdentity, std::nullopt}}); + + auto partitionDirs = verifyPartitionCount( + outputDirectory->getPath(), kDefaultNumBatches * rowsPerBatch); + + for (const auto& dir : partitionDirs) { + const auto name = dirName(dir); + // Each partition should have duplicates rows. + verifyPartitionData(rowType, dir, name, duplicates); + } +} + +TEST_F(TransformE2ETest, partitionNamingConventions) { + auto rowVector = makeRowVector( + { + "c_int", + "c_bigint", + "c_varchar", + "c_varchar2", + "c_decimal", + "c_varbinary", + }, + { + makeConstant(42, 1, INTEGER()), + makeConstant(static_cast(9'876'543'210), 1, BIGINT()), + makeConstant("test string", 1, VARCHAR()), + makeNullConstant(TypeKind::VARCHAR, 1), + makeConstant(static_cast(1'234'567'890), 1, DECIMAL(18, 3)), + makeConstant("binarydata\1\2\3", 1, VARBINARY()), + }); + + std::vector partitionTransforms = { + {0, TransformType::kIdentity, std::nullopt}, // c_int. + {1, TransformType::kIdentity, std::nullopt}, // c_bigint. + {2, TransformType::kIdentity, std::nullopt}, // c_varchar. + {4, TransformType::kIdentity, std::nullopt}, // c_decimal. + {5, TransformType::kIdentity, std::nullopt}, // c_varbinary. + {3, TransformType::kIdentity, std::nullopt}, // c_varchar2. + }; + + auto outputDirectory = + writeBatchesWithTransforms({rowVector}, partitionTransforms); + + const auto actualPartitionNames = + listDirectoriesRecursively(outputDirectory->getPath()); + + // Build expected partition folder names. + std::vector expectedPartitionNames = { + "c_int=42", + "c_bigint=9876543210", + "c_varchar=test+string", + "c_decimal=1234567.890", + "c_varbinary=YmluYXJ5ZGF0YQECAw%3D%3D", + "c_varchar2=null", + }; + + ASSERT_EQ(actualPartitionNames, expectedPartitionNames) + << "Partition folder names do not match expected values"; +} + +TEST_F(TransformE2ETest, bucket) { + constexpr int32_t numBuckets = 4; + auto rowType = ROW({"c_varchar"}, {VARCHAR()}); + auto vectors = + createTestData(rowType, kDefaultNumBatches, kDefaultRowsPerBatch); + + auto outputDirectory = writeBatchesWithTransforms( + vectors, {{0, TransformType::kBucket, numBuckets}}); + + const auto partitionDirs = + verifyPartitionCount(outputDirectory->getPath(), numBuckets); + + int32_t totalRows = 0; + for (const auto& dir : partitionDirs) { + auto splits = createSplitsForDirectory(dir); + auto countPlan = PlanBuilder() + .startTableScan() + .connectorId(test::kIcebergConnectorId) + .outputType(rowType) + .endTableScan() + .planNode(); + auto partitionRowCount = + AssertQueryBuilder(countPlan).splits(splits).countResults(); + + totalRows += partitionRowCount; + ASSERT_GT(partitionRowCount, 0); + } + + ASSERT_EQ(totalRows, kDefaultNumBatches * kDefaultRowsPerBatch); + + std::unordered_map valueToExpectedBucket; + + for (const auto& dir : partitionDirs) { + const auto name = dirName(dir); + const auto [k, v] = parsePartitionDirName(name); + const int32_t expectedBucket = std::stoi(v); + + auto dataPlan = PlanBuilder() + .startTableScan() + .connectorId(test::kIcebergConnectorId) + .outputType(rowType) + .endTableScan() + .project({"c_varchar"}) + .planNode(); + const auto& dataResult = AssertQueryBuilder(dataPlan) + .splits(createSplitsForDirectory(dir)) + .copyResults(opPool_.get()); + + auto varcharColumn = dataResult->childAt(0)->asFlatVector(); + for (auto i = 0; i < dataResult->size(); i++) { + auto value = varcharColumn->valueAt(i); + + auto computedBucket = computeBucketHash(value, numBuckets); + ASSERT_EQ(computedBucket, expectedBucket); + } + } +} + +TEST_F(TransformE2ETest, truncate) { + auto rowType = ROW({"c_int"}, {INTEGER()}); + + std::vector batches; + for (auto i = 0; i < kDefaultNumBatches; i++) { + std::vector columns; + columns.push_back( + makeFlatVector(50, [](auto row) { return row % 100; })); + auto vectors = makeRowVector(rowType->names(), columns); + batches.push_back(vectors); + } + + auto outputDirectory = + writeBatchesWithTransforms(batches, {{0, TransformType::kTruncate, 10}}); + + auto partitionDirs = verifyPartitionCount(outputDirectory->getPath(), 5); + + for (const auto& dir : partitionDirs) { + const std::string name = dirName(dir); + auto [c, v] = parsePartitionDirName(name); + const std::string filter = fmt::format( + "{}>={} AND {}<{}", c, v, c, std::to_string(std::stoi(v) + 10)); + + verifyPartitionData( + rowType, dir, filter, 20); // 10 values per batch * 2 batches. + } +} + +TEST_F(TransformE2ETest, year) { + auto rowType = ROW({"c_date"}, {DATE()}); + static const std::vector dates = { + 18'262, 18'628, 18'993, 19'358, 19'723, 20'181}; + std::vector batches; + for (auto i = 0; i < kDefaultNumBatches; i++) { + auto dateVector = makeFlatVector( + kDefaultRowsPerBatch, + [](auto row) { return dates[row % dates.size()]; }, + nullptr, + DATE()); + batches.emplace_back(makeRowVector(rowType->names(), {dateVector})); + } + + auto outputDirectory = writeBatchesWithTransforms( + batches, {{0, TransformType::kYear, std::nullopt}}); + + auto partitionDirs = verifyPartitionCount(outputDirectory->getPath(), 6); + + for (int32_t year = 2020; year <= 2025; year++) { + const auto expectedDirName = fmt::format("c_date_year={}", year); + bool foundPartition = false; + auto yearFilter = [](int32_t year) -> std::string { + return fmt::format("YEAR(DATE '{}-01-01')={}", year, year); + }; + + for (const auto& dir : partitionDirs) { + SCOPED_TRACE(year); + const auto name = dirName(dir); + if (name == expectedDirName) { + foundPartition = true; + auto datePlan = PlanBuilder() + .startTableScan() + .connectorId(test::kIcebergConnectorId) + .outputType(rowType) + .endTableScan() + .filter(yearFilter(year)) + .planNode(); + + auto splits = createSplitsForDirectory(dir); + auto partitionRowCount = + AssertQueryBuilder(datePlan).splits(splits).countResults(); + + auto countPlan = PlanBuilder() + .startTableScan() + .connectorId(test::kIcebergConnectorId) + .outputType(rowType) + .endTableScan() + .planNode(); + auto totalPartitionCount = + AssertQueryBuilder(countPlan).splits(splits).countResults(); + ASSERT_EQ(partitionRowCount, totalPartitionCount); + break; + } + } + ASSERT_TRUE(foundPartition); + } +} + +TEST_F(TransformE2ETest, month) { + auto batches = createTimestampTestData(); + + auto outputDirectory = writeBatchesWithTransforms( + batches, {{0, TransformType::kMonth, std::nullopt}}); + + auto expectedCounts = + buildExpectedCountsFromTimestamps(batches, TransformType::kMonth); + const auto partitionDirs = firstLevelDirectories(outputDirectory->getPath()); + + for (const auto& dir : partitionDirs) { + const auto name = dirName(dir); + auto [c, v] = parsePartitionDirName(name); + size_t dashPos = v.find('-'); + ASSERT_NE(dashPos, std::string::npos) << "Invalid month format: " << v; + + int32_t year = std::stoi(v.substr(0, dashPos)); + int32_t month = std::stoi(v.substr(dashPos + 1)); + std::string filter = + fmt::format("YEAR(c0) = {} AND MONTH(c0) = {}", year, month); + std::string monthKey = fmt::format("{:04d}-{:02d}", year, month); + verifyPartitionData( + ROW({"c0"}, {TIMESTAMP()}), dir, filter, expectedCounts[monthKey]); + } +} + +TEST_F(TransformE2ETest, day) { + auto batches = createTimestampTestData(); + + auto outputDirectory = writeBatchesWithTransforms( + batches, {{0, TransformType::kDay, std::nullopt}}); + + auto expectedCounts = + buildExpectedCountsFromTimestamps(batches, TransformType::kDay); + const auto partitionDirs = firstLevelDirectories(outputDirectory->getPath()); + + for (const auto& dir : partitionDirs) { + const auto name = dirName(dir); + auto [c, v] = parsePartitionDirName(name); + std::vector dateParts; + folly::split('-', v, dateParts); + ASSERT_EQ(dateParts.size(), 3) << "Invalid day format: " << v; + + int32_t year = std::stoi(dateParts[0]); + int32_t month = std::stoi(dateParts[1]); + int32_t day = std::stoi(dateParts[2]); + + std::string filter = fmt::format( + "YEAR(c0) = {} AND MONTH(c0) = {} AND DAY(c0) = {}", year, month, day); + std::string dayKey = fmt::format("{:04d}-{:02d}-{:02d}", year, month, day); + verifyPartitionData( + ROW({"c0"}, {TIMESTAMP()}), dir, filter, expectedCounts[dayKey]); + } +} + +TEST_F(TransformE2ETest, hour) { + auto batches = createTimestampTestData(); + + auto outputDirectory = writeBatchesWithTransforms( + batches, {{0, TransformType::kHour, std::nullopt}}); + + auto expectedCounts = + buildExpectedCountsFromTimestamps(batches, TransformType::kHour); + const auto partitionDirs = firstLevelDirectories(outputDirectory->getPath()); + + for (const auto& dir : partitionDirs) { + const auto name = dirName(dir); + auto [c, v] = parsePartitionDirName(name); + std::vector dateParts; + folly::split('-', v, dateParts); + ASSERT_EQ(dateParts.size(), 4) << "Invalid hour format: " << v; + + int32_t year = std::stoi(dateParts[0]); + int32_t month = std::stoi(dateParts[1]); + int32_t day = std::stoi(dateParts[2]); + int32_t hour = std::stoi(dateParts[3]); + + std::string filter = fmt::format( + "YEAR(c0) = {} AND MONTH(c0) = {} AND " + "DAY(c0) = {} AND HOUR(c0) = {}", + year, + month, + day, + hour); + std::string hourKey = + fmt::format("{:04d}-{:02d}-{:02d}-{:02d}", year, month, day, hour); + verifyPartitionData( + ROW({"c0"}, {TIMESTAMP()}), dir, filter, expectedCounts[hourKey]); + } +} + +TEST_F(TransformE2ETest, multipleTransformsOnSameColumn) { + auto rowType = ROW( + { + "c_int", + "c_bigint", + }, + { + INTEGER(), + BIGINT(), + }); + + auto vectors = createTestData(rowType, 2, 20); + auto outputDirectory = writeBatchesWithTransforms( + vectors, + { + {0, TransformType::kIdentity, std::nullopt}, // c_int. + {0, TransformType::kTruncate, 10}, // truncate(c_int, 10). + {0, TransformType::kBucket, 4}, // bucket(c_int, 4). + }); + + auto firstLevelDirs = firstLevelDirectories(outputDirectory->getPath()); + ASSERT_GT(firstLevelDirs.size(), 0); + for (const auto& dir : firstLevelDirs) { + const auto name = dirName(dir); + ASSERT_TRUE(name.find("c_int=") != std::string::npos) + << "First level directory " << name << " should use identity transform"; + + auto secondLevelDirs = firstLevelDirectories(dir); + ASSERT_GT(secondLevelDirs.size(), 0) + << "No second level directories found in " << dir; + + for (const auto& secondDir : secondLevelDirs) { + const auto secondName = dirName(secondDir); + ASSERT_TRUE(secondName.find("c_int_trunc=") != std::string::npos) + << "Second level directory " << secondName + << " should use truncate transform"; + + auto thirdLevelDirs = firstLevelDirectories(secondDir); + ASSERT_GT(thirdLevelDirs.size(), 0) + << "No third level directories found in " << secondDir; + + for (const auto& thirdDir : thirdLevelDirs) { + const auto thirdName = dirName(thirdDir); + ASSERT_TRUE(thirdName.find("c_int_bucket=") != std::string::npos) + << "Third level directory " << thirdName + << " should use bucket transform"; + + // Verify the partition has data. + auto splits = createSplitsForDirectory(thirdDir); + auto countPlan = PlanBuilder() + .startTableScan() + .connectorId(test::kIcebergConnectorId) + .outputType(rowType) + .endTableScan() + .planNode(); + auto rowCount = + AssertQueryBuilder(countPlan).splits(splits).countResults(); + ASSERT_GT(rowCount, 0) + << "Leaf partition directory " << thirdDir << " has no data"; + } + } + } +} +#endif + +} // namespace + +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/tests/TransformTest.cpp b/velox/connectors/hive/iceberg/tests/TransformTest.cpp new file mode 100644 index 000000000000..a0e19a916e62 --- /dev/null +++ b/velox/connectors/hive/iceberg/tests/TransformTest.cpp @@ -0,0 +1,333 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/iceberg/IcebergConnector.h" + +#include "velox/common/encode/Base64.h" +#include "velox/connectors/hive/iceberg/PartitionSpec.h" +#include "velox/connectors/hive/iceberg/TransformEvaluator.h" +#include "velox/connectors/hive/iceberg/TransformExprBuilder.h" +#include "velox/connectors/hive/iceberg/tests/IcebergTestBase.h" + +namespace facebook::velox::connector::hive::iceberg { + +namespace { + +class TransformTest : public test::IcebergTestBase { + protected: + void testTransform( + const IcebergPartitionSpecPtr& spec, + const RowVectorPtr& input, + const RowVectorPtr& expected) const { + std::vector partitionChannels; + for (const auto& field : spec->fields) { + partitionChannels.push_back(input->rowType()->getChildIdx(field.name)); + } + // Build and evaluate transform expressions. + auto transformExprs = TransformExprBuilder::toExpressions( + spec, + partitionChannels, + input->rowType(), + std::string(kDefaultIcebergFunctionPrefix)); + auto transformEvaluator = std::make_unique( + transformExprs, connectorQueryCtx_.get()); + auto result = transformEvaluator->evaluate(input); + + ASSERT_EQ(result.size(), expected->childrenSize()); + for (auto i = 0; i < result.size(); ++i) { + velox::test::assertEqualVectors(expected->childAt(i), result[i]); + } + } +}; + +TEST_F(TransformTest, identity) { + const auto& rowType = + ROW({"c0", "c1", "c2", "c3", "c4"}, + {INTEGER(), BIGINT(), VARCHAR(), VARBINARY(), TIMESTAMP()}); + const auto& partitionSpec = createPartitionSpec( + rowType, + { + {0, TransformType::kIdentity, std::nullopt}, + {1, TransformType::kIdentity, std::nullopt}, + {2, TransformType::kIdentity, std::nullopt}, + {3, TransformType::kIdentity, std::nullopt}, + {4, TransformType::kIdentity, std::nullopt}, + }); + + const std::vector input = { + makeFlatVector({1, -1}), + makeFlatVector({1L, -1L}), + makeFlatVector({("test data"), ("")}), + makeFlatVector({("\x01\x02\x03"), ("")}, VARBINARY()), + makeFlatVector({Timestamp(0, 0), Timestamp(1609459200, 0)}), + }; + + testTransform(partitionSpec, makeRowVector(input), makeRowVector(input)); +} + +TEST_F(TransformTest, nulls) { + const auto& rowType = + ROW({"c0", "c1", "c2", "c3", "c4", "c5", "c6"}, + {INTEGER(), + VARCHAR(), + VARBINARY(), + DATE(), + TIMESTAMP(), + TIMESTAMP(), + TIMESTAMP()}); + const auto& partitionSpec = createPartitionSpec( + rowType, + { + {0, TransformType::kIdentity, std::nullopt}, + {1, TransformType::kBucket, 8}, + {2, TransformType::kTruncate, 16}, + {3, TransformType::kYear, std::nullopt}, + {4, TransformType::kMonth, std::nullopt}, + {5, TransformType::kDay, std::nullopt}, + {6, TransformType::kHour, std::nullopt}, + }); + testTransform( + partitionSpec, + makeRowVector({ + makeNullableFlatVector({std::nullopt}), + makeNullableFlatVector({std::nullopt}), + makeNullableFlatVector({std::nullopt}, VARBINARY()), + makeNullableFlatVector({std::nullopt}, DATE()), + makeNullableFlatVector({std::nullopt}), + makeNullableFlatVector({std::nullopt}), + makeNullableFlatVector({std::nullopt}), + }), + makeRowVector({ + makeNullableFlatVector({std::nullopt}), + makeNullableFlatVector({std::nullopt}), + makeNullableFlatVector({std::nullopt}, VARBINARY()), + makeNullableFlatVector({std::nullopt}), + makeNullableFlatVector({std::nullopt}), + makeNullableFlatVector({std::nullopt}, DATE()), + makeNullableFlatVector({std::nullopt}), + })); +} + +TEST_F(TransformTest, bucket) { + const auto& rowType = + ROW({"c0", "c1", "c2", "c3", "c4", "c5"}, + {INTEGER(), BIGINT(), VARCHAR(), VARBINARY(), DATE(), TIMESTAMP()}); + const auto& partitionSpec = createPartitionSpec( + rowType, + { + {0, TransformType::kBucket, 4}, + {1, TransformType::kBucket, 8}, + {2, TransformType::kBucket, 16}, + {3, TransformType::kBucket, 32}, + {4, TransformType::kBucket, 10}, + {5, TransformType::kBucket, 8}, + }); + + testTransform( + partitionSpec, + makeRowVector({ + makeFlatVector({8, 34, 0}), + makeFlatVector({34L, 0L, -34L}), + makeFlatVector({"abcdefg", "测试", ""}), + makeFlatVector( + {"\x61\x62\x64\x00\x00", "\x01\x02\x03\x04", "\x00"}, + VARBINARY()), + makeFlatVector({0, 365, 18'262}), + makeFlatVector( + {Timestamp(0, 0), + Timestamp(-31536000, 0), + Timestamp(1612224000, 0)}), + }), + makeRowVector({ + makeFlatVector({3, 3, 0}), + makeFlatVector({3, 4, 5}), + makeFlatVector({6, 8, 0}), + makeFlatVector({26, 5, 0}), + makeFlatVector({6, 1, 3}), + makeFlatVector({4, 3, 5}), + })); +} + +TEST_F(TransformTest, year) { + const auto& rowType = ROW({"c0", "c1"}, {DATE(), TIMESTAMP()}); + const auto& partitionSpec = createPartitionSpec( + rowType, + { + {0, TransformType::kYear, std::nullopt}, + {1, TransformType::kYear, std::nullopt}, + }); + + testTransform( + partitionSpec, + makeRowVector({ + makeFlatVector({0, 18'262, -365}), + makeFlatVector( + {Timestamp(0, 0), + Timestamp(31536000, 0), + Timestamp(-31536000, 0)}), + }), + makeRowVector({ + makeFlatVector({0, 50, -1}), + makeFlatVector({0, 1, -1}), + })); +} + +TEST_F(TransformTest, month) { + const auto& rowType = ROW({"c0", "c1"}, {DATE(), TIMESTAMP()}); + const auto& partitionSpec = createPartitionSpec( + rowType, + { + {0, TransformType::kMonth, std::nullopt}, + {1, TransformType::kMonth, std::nullopt}, + }); + + testTransform( + partitionSpec, + makeRowVector({ + makeFlatVector({0, 18'262, -365}), + makeFlatVector( + {Timestamp(0, 0), + Timestamp(31536000, 0), + Timestamp(-2678400, 0)}), + }), + makeRowVector({ + makeFlatVector({0, 600, -12}), + makeFlatVector({0, 12, -1}), + })); +} + +TEST_F(TransformTest, day) { + const auto& rowType = ROW({"c0", "c1"}, {DATE(), TIMESTAMP()}); + const auto& partitionSpec = createPartitionSpec( + rowType, + { + {0, TransformType::kDay, std::nullopt}, + {1, TransformType::kDay, std::nullopt}, + }); + + testTransform( + partitionSpec, + makeRowVector({ + makeFlatVector({0, 17532, -1}, DATE()), + makeFlatVector( + {Timestamp(0, 0), + Timestamp(1514764800, 0), + Timestamp(-86400, 0)}), + }), + makeRowVector({ + makeFlatVector({0, 17532, -1}, DATE()), + makeFlatVector({0, 17532, -1}, DATE()), + })); +} + +TEST_F(TransformTest, hour) { + const auto& partitionSpec = createPartitionSpec( + ROW({"c0"}, {TIMESTAMP()}), {{0, TransformType::kHour, std::nullopt}}); + + testTransform( + partitionSpec, + makeRowVector({makeFlatVector({ + Timestamp(0, 0), + Timestamp(3600, 0), + Timestamp(-3600, 0), + })}), + makeRowVector({makeFlatVector({0, 1, -1})})); +} + +TEST_F(TransformTest, truncate) { + const auto& rowType = ROW( + {"c0", "c1", "c2", "c3"}, {INTEGER(), BIGINT(), VARCHAR(), VARBINARY()}); + const auto& partitionSpec = createPartitionSpec( + rowType, + { + {0, TransformType::kTruncate, 10}, + {1, TransformType::kTruncate, 100}, + {2, TransformType::kTruncate, 5}, + {3, TransformType::kTruncate, 3}, + }); + + testTransform( + partitionSpec, + makeRowVector({ + makeFlatVector({11, -11, 5}), + makeFlatVector({123L, -123L, 50L}), + makeFlatVector({"abcdefg", "测试data", "x"}), + makeFlatVector( + {"abcdefg", "\x01\x02\x03\x04", "\x05"}, VARBINARY()), + }), + makeRowVector({ + makeFlatVector({10, -20, 0}), + makeFlatVector({100L, -200L, 0L}), + makeFlatVector({"abcde", "测试dat", "x"}), + makeFlatVector( + {"abc", "\x01\x02\x03", "\x05"}, VARBINARY()), + })); +} + +TEST_F(TransformTest, multipleTransforms) { + const auto& rowType = ROW({"c0", "c1", "c2"}, {INTEGER(), DATE(), VARCHAR()}); + const auto& partitionSpec = createPartitionSpec( + rowType, + { + {0, TransformType::kBucket, 4}, + {1, TransformType::kYear, std::nullopt}, + {2, TransformType::kTruncate, 3}, + }); + + testTransform( + partitionSpec, + makeRowVector({ + makeFlatVector({8, 34}), + makeFlatVector({0, 17532}), + makeFlatVector({"abcdefg", "ab c"}), + }), + makeRowVector({ + makeFlatVector({3, 3}), + makeFlatVector({0, 48}), + makeFlatVector({"abc", "ab "}), + })); +} + +TEST_F(TransformTest, multipleTransformsOnSameColumn) { + const auto& rowType = ROW({"c0", "c1"}, {DATE(), VARCHAR()}); + const auto& partitionSpec = createPartitionSpec( + rowType, + { + {0, TransformType::kYear, std::nullopt}, + {0, TransformType::kBucket, 10}, + {1, TransformType::kTruncate, 5}, + {1, TransformType::kBucket, 8}, + }); + + testTransform( + partitionSpec, + makeRowVector( + rowType->names(), + { + makeFlatVector({0, 17532}), + makeFlatVector({"abcdefg", "test"}), + }), + makeRowVector({ + makeFlatVector({0, 48}), + makeFlatVector({6, 7}), + makeFlatVector({"abcde", "test"}), + makeFlatVector({6, 3}), + })); +} + +} // namespace + +} // namespace facebook::velox::connector::hive::iceberg diff --git a/velox/connectors/hive/iceberg/tests/examples/three_groups.parquet b/velox/connectors/hive/iceberg/tests/examples/three_groups.parquet new file mode 100644 index 000000000000..cfedba43474b Binary files /dev/null and b/velox/connectors/hive/iceberg/tests/examples/three_groups.parquet differ diff --git a/velox/connectors/hive/storage_adapters/abfs/AbfsConfig.cpp b/velox/connectors/hive/storage_adapters/abfs/AbfsConfig.cpp deleted file mode 100644 index 7604e3ef3b69..000000000000 --- a/velox/connectors/hive/storage_adapters/abfs/AbfsConfig.cpp +++ /dev/null @@ -1,201 +0,0 @@ -/* - * Copyright (c) Facebook, Inc. and its affiliates. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "velox/connectors/hive/storage_adapters/abfs/AbfsConfig.h" - -#include "velox/common/config/Config.h" -#include "velox/connectors/hive/storage_adapters/abfs/AbfsUtil.h" - -#include - -namespace facebook::velox::filesystems { - -std::function()> - AbfsConfig::testWriteClientFn_; - -class DataLakeFileClientWrapper final : public AzureDataLakeFileClient { - public: - DataLakeFileClientWrapper(std::unique_ptr client) - : client_(std::move(client)) {} - - void create() override { - client_->Create(); - } - - Azure::Storage::Files::DataLake::Models::PathProperties getProperties() - override { - return client_->GetProperties().Value; - } - - void append(const uint8_t* buffer, size_t size, uint64_t offset) override { - auto bodyStream = Azure::Core::IO::MemoryBodyStream(buffer, size); - client_->Append(bodyStream, offset); - } - - void flush(uint64_t position) override { - client_->Flush(position); - } - - void close() override { - // do nothing. - } - - std::string getUrl() const override { - return client_->GetUrl(); - } - - private: - const std::unique_ptr client_; -}; - -AbfsConfig::AbfsConfig( - std::string_view path, - const config::ConfigBase& config) { - std::string_view file; - isHttps_ = true; - if (path.find(kAbfssScheme) == 0) { - file = path.substr(kAbfssScheme.size()); - } else if (path.find(kAbfsScheme) == 0) { - file = path.substr(kAbfsScheme.size()); - isHttps_ = false; - } else { - VELOX_FAIL("Invalid ABFS Path {}", path); - } - - auto firstAt = file.find_first_of("@"); - fileSystem_ = file.substr(0, firstAt); - auto firstSep = file.find_first_of("/"); - filePath_ = file.substr(firstSep + 1); - accountNameWithSuffix_ = file.substr(firstAt + 1, firstSep - firstAt - 1); - - auto authTypeKey = - fmt::format("{}.{}", kAzureAccountAuthType, accountNameWithSuffix_); - authType_ = kAzureSharedKeyAuthType; - if (config.valueExists(authTypeKey)) { - authType_ = config.get(authTypeKey).value(); - } - if (authType_ == kAzureSharedKeyAuthType) { - auto credKey = - fmt::format("{}.{}", kAzureAccountKey, accountNameWithSuffix_); - VELOX_USER_CHECK( - config.valueExists(credKey), "Config {} not found", credKey); - auto firstDot = accountNameWithSuffix_.find_first_of("."); - auto accountName = accountNameWithSuffix_.substr(0, firstDot); - auto endpointSuffix = accountNameWithSuffix_.substr(firstDot + 5); - std::stringstream ss; - ss << "DefaultEndpointsProtocol=" << (isHttps_ ? "https" : "http"); - ss << ";AccountName=" << accountName; - ss << ";AccountKey=" << config.get(credKey).value(); - ss << ";EndpointSuffix=" << endpointSuffix; - - if (config.valueExists(kAzureBlobEndpoint)) { - ss << ";BlobEndpoint=" - << config.get(kAzureBlobEndpoint).value(); - } - ss << ";"; - connectionString_ = ss.str(); - } else if (authType_ == kAzureOAuthAuthType) { - auto clientIdKey = fmt::format( - "{}.{}", kAzureAccountOAuth2ClientId, accountNameWithSuffix_); - auto clientSecretKey = fmt::format( - "{}.{}", kAzureAccountOAuth2ClientSecret, accountNameWithSuffix_); - auto clientEndpointKey = fmt::format( - "{}.{}", kAzureAccountOAuth2ClientEndpoint, accountNameWithSuffix_); - VELOX_USER_CHECK( - config.valueExists(clientIdKey), "Config {} not found", clientIdKey); - VELOX_USER_CHECK( - config.valueExists(clientSecretKey), - "Config {} not found", - clientSecretKey); - VELOX_USER_CHECK( - config.valueExists(clientEndpointKey), - "Config {} not found", - clientEndpointKey); - auto clientEndpoint = config.get(clientEndpointKey).value(); - auto firstSep = clientEndpoint.find_first_of("/", /* https:// */ 8); - authorityHost_ = clientEndpoint.substr(0, firstSep + 1); - auto sedondSep = clientEndpoint.find_first_of("/", firstSep + 1); - tenentId_ = clientEndpoint.substr(firstSep + 1, sedondSep - firstSep - 1); - Azure::Identity::ClientSecretCredentialOptions options; - options.AuthorityHost = authorityHost_; - tokenCredential_ = - std::make_shared( - tenentId_, - config.get(clientIdKey).value(), - config.get(clientSecretKey).value(), - options); - } else if (authType_ == kAzureSASAuthType) { - auto sasKey = fmt::format("{}.{}", kAzureSASKey, accountNameWithSuffix_); - VELOX_USER_CHECK(config.valueExists(sasKey), "Config {} not found", sasKey); - sas_ = config.get(sasKey).value(); - } else { - VELOX_USER_FAIL( - "Unsupported auth type {}, supported auth types are SharedKey, OAuth and SAS.", - authType_); - } -} - -std::unique_ptr AbfsConfig::getReadFileClient() { - if (authType_ == kAzureSASAuthType) { - auto url = getUrl(true); - return std::make_unique(fmt::format("{}?{}", url, sas_)); - } else if (authType_ == kAzureOAuthAuthType) { - auto url = getUrl(true); - return std::make_unique(url, tokenCredential_); - } else { - return std::make_unique(BlobClient::CreateFromConnectionString( - connectionString_, fileSystem_, filePath_)); - } -} - -std::unique_ptr AbfsConfig::getWriteFileClient() { - if (testWriteClientFn_) { - return testWriteClientFn_(); - } - std::unique_ptr client; - if (authType_ == kAzureSASAuthType) { - auto url = getUrl(false); - client = - std::make_unique(fmt::format("{}?{}", url, sas_)); - } else if (authType_ == kAzureOAuthAuthType) { - auto url = getUrl(false); - client = std::make_unique(url, tokenCredential_); - } else { - client = std::make_unique( - DataLakeFileClient::CreateFromConnectionString( - connectionString_, fileSystem_, filePath_)); - } - return std::make_unique(std::move(client)); -} - -std::string AbfsConfig::getUrl(bool withblobSuffix) { - std::string accountNameWithSuffixForUrl(accountNameWithSuffix_); - if (withblobSuffix) { - // We should use correct suffix for blob client. - size_t start_pos = accountNameWithSuffixForUrl.find("dfs"); - if (start_pos != std::string::npos) { - accountNameWithSuffixForUrl.replace(start_pos, 3, "blob"); - } - } - return fmt::format( - "{}{}/{}/{}", - isHttps_ ? "https://" : "http://", - accountNameWithSuffixForUrl, - fileSystem_, - filePath_); -} - -} // namespace facebook::velox::filesystems diff --git a/velox/connectors/hive/storage_adapters/abfs/AbfsFileSystem.cpp b/velox/connectors/hive/storage_adapters/abfs/AbfsFileSystem.cpp index 18fbd3f2284a..e5877f74638b 100644 --- a/velox/connectors/hive/storage_adapters/abfs/AbfsFileSystem.cpp +++ b/velox/connectors/hive/storage_adapters/abfs/AbfsFileSystem.cpp @@ -20,198 +20,14 @@ #include #include -#include "velox/connectors/hive/storage_adapters/abfs/AbfsConfig.h" +#include "velox/connectors/hive/storage_adapters/abfs/AbfsPath.h" #include "velox/connectors/hive/storage_adapters/abfs/AbfsReadFile.h" #include "velox/connectors/hive/storage_adapters/abfs/AbfsUtil.h" #include "velox/connectors/hive/storage_adapters/abfs/AbfsWriteFile.h" +#include "velox/connectors/hive/storage_adapters/abfs/AzureClientProviderFactories.h" namespace facebook::velox::filesystems { -class AbfsReadFile::Impl { - constexpr static uint64_t kNaturalReadSize = 4 << 20; // 4M - constexpr static uint64_t kReadConcurrency = 8; - - public: - explicit Impl(std::string_view path, const config::ConfigBase& config) { - auto abfsConfig = AbfsConfig(path, config); - filePath_ = abfsConfig.filePath(); - fileClient_ = abfsConfig.getReadFileClient(); - } - - void initialize(const FileOptions& options) { - if (options.fileSize.has_value()) { - VELOX_CHECK_GE( - options.fileSize.value(), 0, "File size must be non-negative"); - length_ = options.fileSize.value(); - } - - if (length_ != -1) { - return; - } - - try { - auto properties = fileClient_->GetProperties(); - length_ = properties.Value.BlobSize; - } catch (Azure::Storage::StorageException& e) { - throwStorageExceptionWithOperationDetails("GetProperties", filePath_, e); - } - VELOX_CHECK_GE(length_, 0); - } - - std::string_view pread( - uint64_t offset, - uint64_t length, - void* buffer, - File::IoStats* stats) const { - preadInternal(offset, length, static_cast(buffer)); - return {static_cast(buffer), length}; - } - - std::string pread(uint64_t offset, uint64_t length, File::IoStats* stats) - const { - std::string result(length, 0); - preadInternal(offset, length, result.data()); - return result; - } - - uint64_t preadv( - uint64_t offset, - const std::vector>& buffers, - File::IoStats* stats) const { - size_t length = 0; - auto size = buffers.size(); - for (auto& range : buffers) { - length += range.size(); - } - std::string result(length, 0); - preadInternal(offset, length, static_cast(result.data())); - size_t resultOffset = 0; - for (auto range : buffers) { - if (range.data()) { - memcpy(range.data(), &(result.data()[resultOffset]), range.size()); - } - resultOffset += range.size(); - } - - return length; - } - - uint64_t preadv( - folly::Range regions, - folly::Range iobufs, - File::IoStats* stats) const { - size_t length = 0; - VELOX_CHECK_EQ(regions.size(), iobufs.size()); - for (size_t i = 0; i < regions.size(); ++i) { - const auto& region = regions[i]; - auto& output = iobufs[i]; - output = folly::IOBuf(folly::IOBuf::CREATE, region.length); - pread(region.offset, region.length, output.writableData(), stats); - output.append(region.length); - length += region.length; - } - - return length; - } - - uint64_t size() const { - return length_; - } - - uint64_t memoryUsage() const { - return 3 * sizeof(std::string) + sizeof(int64_t); - } - - bool shouldCoalesce() const { - return false; - } - - std::string getName() const { - return filePath_; - } - - uint64_t getNaturalReadSize() const { - return kNaturalReadSize; - } - - private: - void preadInternal(uint64_t offset, uint64_t length, char* position) const { - // Read the desired range of bytes. - Azure::Core::Http::HttpRange range; - range.Offset = offset; - range.Length = length; - - Azure::Storage::Blobs::DownloadBlobOptions blob; - blob.Range = range; - auto response = fileClient_->Download(blob); - response.Value.BodyStream->ReadToCount( - reinterpret_cast(position), length); - } - - std::string filePath_; - std::unique_ptr fileClient_; - int64_t length_ = -1; -}; - -AbfsReadFile::AbfsReadFile( - std::string_view path, - const config::ConfigBase& config) { - impl_ = std::make_shared(path, config); -} - -void AbfsReadFile::initialize(const FileOptions& options) { - return impl_->initialize(options); -} - -std::string_view AbfsReadFile::pread( - uint64_t offset, - uint64_t length, - void* buffer, - File::IoStats* stats) const { - return impl_->pread(offset, length, buffer, stats); -} - -std::string AbfsReadFile::pread( - uint64_t offset, - uint64_t length, - File::IoStats* stats) const { - return impl_->pread(offset, length, stats); -} - -uint64_t AbfsReadFile::preadv( - uint64_t offset, - const std::vector>& buffers, - File::IoStats* stats) const { - return impl_->preadv(offset, buffers, stats); -} - -uint64_t AbfsReadFile::preadv( - folly::Range regions, - folly::Range iobufs, - File::IoStats* stats) const { - return impl_->preadv(regions, iobufs, stats); -} - -uint64_t AbfsReadFile::size() const { - return impl_->size(); -} - -uint64_t AbfsReadFile::memoryUsage() const { - return impl_->memoryUsage(); -} - -bool AbfsReadFile::shouldCoalesce() const { - return false; -} - -std::string AbfsReadFile::getName() const { - return impl_->getName(); -} - -uint64_t AbfsReadFile::getNaturalReadSize() const { - return impl_->getNaturalReadSize(); -} - AbfsFileSystem::AbfsFileSystem(std::shared_ptr config) : FileSystem(config) { VELOX_CHECK_NOT_NULL(config.get()); diff --git a/velox/connectors/hive/storage_adapters/abfs/AbfsFileSystem.h b/velox/connectors/hive/storage_adapters/abfs/AbfsFileSystem.h index c0d3d60ccdee..8a838ee17719 100644 --- a/velox/connectors/hive/storage_adapters/abfs/AbfsFileSystem.h +++ b/velox/connectors/hive/storage_adapters/abfs/AbfsFileSystem.h @@ -19,7 +19,7 @@ namespace facebook::velox::filesystems { -/// Implementation of the ABS (Azure Blob Storage) filesystem and file +/// Implementation of the ABFS (Azure Blob File Storage) filesystem and file /// interface. We provide a registration method for reading and writing files so /// that the appropriate type of file can be constructed based on a filename. /// The supported schema is `abfs(s)://` to align with the valid scheme @@ -76,5 +76,4 @@ class AbfsFileSystem : public FileSystem { } }; -void registerAbfsFileSystem(); } // namespace facebook::velox::filesystems diff --git a/velox/connectors/hive/storage_adapters/abfs/AbfsPath.cpp b/velox/connectors/hive/storage_adapters/abfs/AbfsPath.cpp new file mode 100644 index 000000000000..b6ba75cbcbfe --- /dev/null +++ b/velox/connectors/hive/storage_adapters/abfs/AbfsPath.cpp @@ -0,0 +1,60 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/storage_adapters/abfs/AbfsPath.h" +#include "velox/connectors/hive/storage_adapters/abfs/AbfsUtil.h" + +namespace facebook::velox::filesystems { + +AbfsPath::AbfsPath(std::string_view path) { + std::string_view file; + isHttps_ = true; + if (path.find(kAbfssScheme) == 0) { + file = path.substr(kAbfssScheme.size()); + } else if (path.find(kAbfsScheme) == 0) { + file = path.substr(kAbfsScheme.size()); + isHttps_ = false; + } else { + VELOX_FAIL("Invalid ABFS Path {}", path); + } + + auto firstAt = file.find_first_of("@"); + fileSystem_ = file.substr(0, firstAt); + auto firstSep = file.find_first_of("/"); + filePath_ = file.substr(firstSep + 1); + accountNameWithSuffix_ = file.substr(firstAt + 1, firstSep - firstAt - 1); + auto firstDot = accountNameWithSuffix_.find_first_of("."); + accountName_ = accountNameWithSuffix_.substr(0, firstDot); +} + +std::string AbfsPath::getUrl(bool withblobSuffix) const { + std::string accountNameWithSuffixForUrl(accountNameWithSuffix_); + if (withblobSuffix) { + // We should use correct suffix for blob client. + size_t startPos = accountNameWithSuffixForUrl.find("dfs"); + if (startPos != std::string::npos) { + accountNameWithSuffixForUrl.replace(startPos, 3, "blob"); + } + } + return fmt::format( + "{}{}/{}/{}", + isHttps_ ? "https://" : "http://", + accountNameWithSuffixForUrl, + fileSystem_, + filePath_); +} + +} // namespace facebook::velox::filesystems diff --git a/velox/connectors/hive/storage_adapters/abfs/AbfsConfig.h b/velox/connectors/hive/storage_adapters/abfs/AbfsPath.h similarity index 65% rename from velox/connectors/hive/storage_adapters/abfs/AbfsConfig.h rename to velox/connectors/hive/storage_adapters/abfs/AbfsPath.h index af4b46708d86..3a4e0d99f097 100644 --- a/velox/connectors/hive/storage_adapters/abfs/AbfsConfig.h +++ b/velox/connectors/hive/storage_adapters/abfs/AbfsPath.h @@ -16,20 +16,13 @@ #pragma once -#include -#include #include #include #include -#include "velox/connectors/hive/storage_adapters/abfs/AzureDataLakeFileClient.h" using namespace Azure::Storage::Blobs; using namespace Azure::Storage::Files::DataLake; -namespace facebook::velox::config { -class ConfigBase; -} - namespace facebook::velox::filesystems { // This is used to specify the Azurite endpoint in testing. @@ -61,70 +54,48 @@ static constexpr const char* kAzureOAuthAuthType = "OAuth"; static constexpr const char* kAzureSASAuthType = "SAS"; -class AbfsConfig { - public: - explicit AbfsConfig(std::string_view path, const config::ConfigBase& config); - - std::unique_ptr getReadFileClient(); +// For performance, re - use SAS tokens until the expiry is within this number +// of seconds. +static constexpr const char* kAzureSasTokenRenewPeriod = + "fs.azure.sas.token.renew.period.for.streams"; - std::unique_ptr getWriteFileClient(); +// Helper class to parse and extract information from a given ABFS path. +class AbfsPath { + public: + AbfsPath(std::string_view path); - std::string filePath() const { - return filePath_; + bool isHttps() const { + return isHttps_; } - /// Test only. std::string fileSystem() const { return fileSystem_; } - /// Test only. - std::string connectionString() const { - return connectionString_; - } - - /// Test only. - std::string tenentId() const { - return tenentId_; + std::string filePath() const { + return filePath_; } - /// Test only. - std::string authorityHost() const { - return authorityHost_; + std::string accountName() const { + return accountName_; } - /// Test only. - static void setUpTestWriteClient( - std::function()> testClientFn) { - testWriteClientFn_ = testClientFn; + std::string accountNameWithSuffix() const { + return accountNameWithSuffix_; } - /// Test only. - static void tearDownTestWriteClient() { - testWriteClientFn_ = nullptr; - } + std::string getUrl(bool withblobSuffix) const; private: - std::string getUrl(bool withblobSuffix); - - std::string authType_; + bool isHttps_; // Container name is called FileSystem in some Azure API. std::string fileSystem_; std::string filePath_; std::string connectionString_; - bool isHttps_; + std::string accountName_; std::string accountNameWithSuffix_; - - std::string sas_; - - std::string tenentId_; - std::string authorityHost_; - std::shared_ptr tokenCredential_; - - static std::function()> - testWriteClientFn_; }; } // namespace facebook::velox::filesystems diff --git a/velox/connectors/hive/storage_adapters/abfs/AbfsReadFile.cpp b/velox/connectors/hive/storage_adapters/abfs/AbfsReadFile.cpp new file mode 100644 index 000000000000..575f1f2de572 --- /dev/null +++ b/velox/connectors/hive/storage_adapters/abfs/AbfsReadFile.cpp @@ -0,0 +1,216 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/storage_adapters/abfs/AbfsReadFile.h" +#include "velox/connectors/hive/storage_adapters/abfs/AbfsPath.h" +#include "velox/connectors/hive/storage_adapters/abfs/AbfsUtil.h" +#include "velox/connectors/hive/storage_adapters/abfs/AzureClientProviderFactories.h" + +namespace facebook::velox::filesystems { + +class AbfsReadFile::Impl { + constexpr static uint64_t kNaturalReadSize = 4 << 20; // 4M + constexpr static uint64_t kReadConcurrency = 8; + + public: + explicit Impl(std::string_view path, const config::ConfigBase& config) { + auto abfsPath = std::make_shared(path); + filePath_ = abfsPath->filePath(); + fileClient_ = + AzureClientProviderFactories::getReadFileClient(abfsPath, config); + } + + void initialize(const FileOptions& options) { + if (options.fileSize.has_value()) { + VELOX_CHECK_GE( + options.fileSize.value(), 0, "File size must be non-negative"); + length_ = options.fileSize.value(); + } + + if (length_ != -1) { + return; + } + + try { + auto properties = fileClient_->getProperties(); + length_ = properties.Value.BlobSize; + } catch (Azure::Storage::StorageException& e) { + throwStorageExceptionWithOperationDetails("GetProperties", filePath_, e); + } + VELOX_CHECK_GE(length_, 0); + } + + std::string_view pread( + uint64_t offset, + uint64_t length, + void* buffer, + const FileStorageContext& fileStorageContext) const { + preadInternal(offset, length, static_cast(buffer)); + return {static_cast(buffer), length}; + } + + std::string pread( + uint64_t offset, + uint64_t length, + const FileStorageContext& fileStorageContext) const { + std::string result(length, 0); + preadInternal(offset, length, result.data()); + return result; + } + + uint64_t preadv( + uint64_t offset, + const std::vector>& buffers, + const FileStorageContext& fileStorageContext) const { + size_t length = 0; + auto size = buffers.size(); + for (auto& range : buffers) { + length += range.size(); + } + std::string result(length, 0); + preadInternal(offset, length, static_cast(result.data())); + size_t resultOffset = 0; + for (auto range : buffers) { + if (range.data()) { + memcpy(range.data(), &(result.data()[resultOffset]), range.size()); + } + resultOffset += range.size(); + } + + return length; + } + + uint64_t preadv( + folly::Range regions, + folly::Range iobufs, + const FileStorageContext& fileStorageContext) const { + size_t length = 0; + VELOX_CHECK_EQ(regions.size(), iobufs.size()); + for (size_t i = 0; i < regions.size(); ++i) { + const auto& region = regions[i]; + auto& output = iobufs[i]; + output = folly::IOBuf(folly::IOBuf::CREATE, region.length); + pread( + region.offset, + region.length, + output.writableData(), + fileStorageContext); + output.append(region.length); + length += region.length; + } + + return length; + } + + uint64_t size() const { + return length_; + } + + uint64_t memoryUsage() const { + return 3 * sizeof(std::string) + sizeof(int64_t); + } + + bool shouldCoalesce() const { + return false; + } + + std::string getName() const { + return filePath_; + } + + uint64_t getNaturalReadSize() const { + return kNaturalReadSize; + } + + private: + void preadInternal(uint64_t offset, uint64_t length, char* position) const { + // Read the desired range of bytes. + Azure::Core::Http::HttpRange range; + range.Offset = offset; + range.Length = length; + + Azure::Storage::Blobs::DownloadBlobOptions blob; + blob.Range = range; + auto response = fileClient_->download(blob); + response.Value.BodyStream->ReadToCount( + reinterpret_cast(position), length); + } + + std::string filePath_; + std::unique_ptr fileClient_; + int64_t length_ = -1; +}; + +AbfsReadFile::AbfsReadFile( + std::string_view path, + const config::ConfigBase& config) { + impl_ = std::make_shared(path, config); +} + +void AbfsReadFile::initialize(const FileOptions& options) { + return impl_->initialize(options); +} + +std::string_view AbfsReadFile::pread( + uint64_t offset, + uint64_t length, + void* buffer, + const FileStorageContext& fileStorageContext) const { + return impl_->pread(offset, length, buffer, fileStorageContext); +} + +std::string AbfsReadFile::pread( + uint64_t offset, + uint64_t length, + const FileStorageContext& fileStorageContext) const { + return impl_->pread(offset, length, fileStorageContext); +} + +uint64_t AbfsReadFile::preadv( + uint64_t offset, + const std::vector>& buffers, + const FileStorageContext& fileStorageContext) const { + return impl_->preadv(offset, buffers, fileStorageContext); +} + +uint64_t AbfsReadFile::preadv( + folly::Range regions, + folly::Range iobufs, + const FileStorageContext& fileStorageContext) const { + return impl_->preadv(regions, iobufs, fileStorageContext); +} + +uint64_t AbfsReadFile::size() const { + return impl_->size(); +} + +uint64_t AbfsReadFile::memoryUsage() const { + return impl_->memoryUsage(); +} + +bool AbfsReadFile::shouldCoalesce() const { + return false; +} + +std::string AbfsReadFile::getName() const { + return impl_->getName(); +} + +uint64_t AbfsReadFile::getNaturalReadSize() const { + return impl_->getNaturalReadSize(); +} + +} // namespace facebook::velox::filesystems diff --git a/velox/connectors/hive/storage_adapters/abfs/AbfsReadFile.h b/velox/connectors/hive/storage_adapters/abfs/AbfsReadFile.h index 942439c06c1e..b682926ad1af 100644 --- a/velox/connectors/hive/storage_adapters/abfs/AbfsReadFile.h +++ b/velox/connectors/hive/storage_adapters/abfs/AbfsReadFile.h @@ -35,22 +35,22 @@ class AbfsReadFile final : public ReadFile { uint64_t offset, uint64_t length, void* buf, - File::IoStats* stats = nullptr) const final; + const FileStorageContext& fileStorageContext = {}) const final; std::string pread( uint64_t offset, uint64_t length, - File::IoStats* stats = nullptr) const final; + const FileStorageContext& fileStorageContext = {}) const final; uint64_t preadv( uint64_t offset, const std::vector>& buffers, - File::IoStats* stats = nullptr) const final; + const FileStorageContext& fileStorageContext = {}) const final; uint64_t preadv( folly::Range regions, folly::Range iobufs, - File::IoStats* stats = nullptr) const final; + const FileStorageContext& fileStorageContext = {}) const final; uint64_t size() const final; diff --git a/velox/connectors/hive/storage_adapters/abfs/AbfsUtil.cpp b/velox/connectors/hive/storage_adapters/abfs/AbfsUtil.cpp new file mode 100644 index 000000000000..5aefc0983867 --- /dev/null +++ b/velox/connectors/hive/storage_adapters/abfs/AbfsUtil.cpp @@ -0,0 +1,43 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/storage_adapters/abfs/AbfsUtil.h" +#include "velox/common/config/Config.h" +#include "velox/connectors/hive/storage_adapters/abfs/AbfsPath.h" + +namespace facebook::velox::filesystems { + +std::vector extractCacheKeyFromConfig( + const config::ConfigBase& config) { + std::vector cacheKeys; + constexpr std::string_view authTypePrefix{kAzureAccountAuthType}; + for (const auto& [key, value] : config.rawConfigs()) { + if (key.find(authTypePrefix) == 0) { + // Extract the accountName after "fs.azure.account.auth.type.". + auto remaining = std::string_view(key).substr(authTypePrefix.size() + 1); + auto dot = remaining.find("."); + VELOX_USER_CHECK_NE( + dot, + std::string_view::npos, + "Invalid Azure account auth type key: {}", + key); + cacheKeys.emplace_back(CacheKey{remaining.substr(0, dot), value}); + } + } + return cacheKeys; +} + +} // namespace facebook::velox::filesystems diff --git a/velox/connectors/hive/storage_adapters/abfs/AbfsUtil.h b/velox/connectors/hive/storage_adapters/abfs/AbfsUtil.h index 925c6f91ece9..1a6cf6e0a0e7 100644 --- a/velox/connectors/hive/storage_adapters/abfs/AbfsUtil.h +++ b/velox/connectors/hive/storage_adapters/abfs/AbfsUtil.h @@ -26,6 +26,16 @@ constexpr std::string_view kAbfsScheme{"abfs://"}; constexpr std::string_view kAbfssScheme{"abfss://"}; } // namespace +class ConfigBase; + +struct CacheKey { + const std::string accountName; + const std::string authType; + + CacheKey(std::string_view accountName, std::string_view authType) + : accountName(accountName), authType(authType) {} +}; + inline bool isAbfsFile(const std::string_view filename) { return filename.find(kAbfsScheme) == 0 || filename.find(kAbfssScheme) == 0; } @@ -45,4 +55,7 @@ inline std::string throwStorageExceptionWithOperationDetails( VELOX_FAIL(errMsg); } +std::vector extractCacheKeyFromConfig( + const config::ConfigBase& config); + } // namespace facebook::velox::filesystems diff --git a/velox/connectors/hive/storage_adapters/abfs/AbfsWriteFile.cpp b/velox/connectors/hive/storage_adapters/abfs/AbfsWriteFile.cpp index 2ba015743a0d..5db3b535b386 100644 --- a/velox/connectors/hive/storage_adapters/abfs/AbfsWriteFile.cpp +++ b/velox/connectors/hive/storage_adapters/abfs/AbfsWriteFile.cpp @@ -15,8 +15,10 @@ */ #include "velox/connectors/hive/storage_adapters/abfs/AbfsWriteFile.h" -#include "velox/connectors/hive/storage_adapters/abfs/AbfsConfig.h" + +#include "velox/connectors/hive/storage_adapters/abfs/AbfsPath.h" #include "velox/connectors/hive/storage_adapters/abfs/AbfsUtil.h" +#include "velox/connectors/hive/storage_adapters/abfs/AzureClientProviderFactories.h" namespace facebook::velox::filesystems { @@ -88,8 +90,9 @@ class AbfsWriteFile::Impl { AbfsWriteFile::AbfsWriteFile( std::string_view path, const config::ConfigBase& config) { - auto abfsConfig = AbfsConfig(path, config); - auto client = abfsConfig.getWriteFileClient(); + const auto abfsPath = std::make_shared(path); + auto client = + AzureClientProviderFactories::getWriteFileClient(abfsPath, config); impl_ = std::make_unique(path, client); } diff --git a/velox/connectors/hive/storage_adapters/abfs/AzureBlobClient.h b/velox/connectors/hive/storage_adapters/abfs/AzureBlobClient.h new file mode 100644 index 000000000000..d633ac7fdbae --- /dev/null +++ b/velox/connectors/hive/storage_adapters/abfs/AzureBlobClient.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace facebook::velox::filesystems { + +// Interface for Azure Blob Storage client operations. +class AzureBlobClient { + public: + virtual ~AzureBlobClient() {} + + virtual Azure::Response + getProperties() = 0; + + virtual Azure::Response + download(const Azure::Storage::Blobs::DownloadBlobOptions& options) = 0; + + virtual std::string getUrl() = 0; +}; + +} // namespace facebook::velox::filesystems diff --git a/velox/connectors/hive/storage_adapters/abfs/AzureClientProvider.h b/velox/connectors/hive/storage_adapters/abfs/AzureClientProvider.h new file mode 100644 index 000000000000..1a1a68f6d87f --- /dev/null +++ b/velox/connectors/hive/storage_adapters/abfs/AzureClientProvider.h @@ -0,0 +1,42 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "velox/common/config/Config.h" +#include "velox/connectors/hive/storage_adapters/abfs/AbfsPath.h" +#include "velox/connectors/hive/storage_adapters/abfs/AzureBlobClient.h" +#include "velox/connectors/hive/storage_adapters/abfs/AzureDataLakeFileClient.h" + +namespace facebook::velox::filesystems { + +// Provider interface for creating Azure Blob and Data Lake clients. +class AzureClientProvider { + public: + virtual ~AzureClientProvider() = default; + + // Creates AzureBlobClient for file read operations. + virtual std::unique_ptr getReadFileClient( + const std::shared_ptr& path, + const config::ConfigBase& config) = 0; + + // Creates AzureDataLakeFileClient for file write operations. + virtual std::unique_ptr getWriteFileClient( + const std::shared_ptr& path, + const config::ConfigBase& config) = 0; +}; + +} // namespace facebook::velox::filesystems diff --git a/velox/connectors/hive/storage_adapters/abfs/AzureClientProviderFactories.cpp b/velox/connectors/hive/storage_adapters/abfs/AzureClientProviderFactories.cpp new file mode 100644 index 000000000000..cbd0ae54859f --- /dev/null +++ b/velox/connectors/hive/storage_adapters/abfs/AzureClientProviderFactories.cpp @@ -0,0 +1,79 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/storage_adapters/abfs/AzureClientProviderFactories.h" +#include "velox/connectors/hive/storage_adapters/abfs/AzureClientProviderImpl.h" + +#include + +namespace facebook::velox::filesystems { + +namespace { + +folly::Synchronized< + std::unordered_map>& +azureClientFactoryRegistry() { + static folly::Synchronized< + std::unordered_map> + factories; + return factories; +} + +} // namespace + +void AzureClientProviderFactories::registerFactory( + const std::string& account, + const AzureClientProviderFactory& factory) { + azureClientFactoryRegistry().withWLock([&](auto& factories) { + auto [_, inserted] = factories.insert_or_assign(account, factory); + LOG_IF(INFO, !inserted) << "AzureClientProviderFactory for account '" + << account << "' has been overridden."; + }); +} + +AzureClientProviderFactory AzureClientProviderFactories::getClientFactory( + const std::string& account) { + return azureClientFactoryRegistry().withRLock( + [&](const auto& factories) -> AzureClientProviderFactory { + if (auto it = factories.find(account); it != factories.end()) { + return it->second; + } + VELOX_USER_FAIL( + "No AzureClientProviderFactory registered for account '{}'." + "Please use `registerAzureClientProvider` or " + "`registerAzureClientProviderFactory` to register a factory for " + "the account before using it.", + account); + }); +} + +std::unique_ptr +AzureClientProviderFactories::getReadFileClient( + const std::shared_ptr& abfsPath, + const config::ConfigBase& config) { + auto factory = getClientFactory(abfsPath->accountName()); + return factory(abfsPath->accountName())->getReadFileClient(abfsPath, config); +} + +std::unique_ptr +AzureClientProviderFactories::getWriteFileClient( + const std::shared_ptr& abfsPath, + const config::ConfigBase& config) { + auto factory = getClientFactory(abfsPath->accountName()); + return factory(abfsPath->accountName())->getWriteFileClient(abfsPath, config); +} + +} // namespace facebook::velox::filesystems diff --git a/velox/connectors/hive/storage_adapters/abfs/AzureClientProviderFactories.h b/velox/connectors/hive/storage_adapters/abfs/AzureClientProviderFactories.h new file mode 100644 index 000000000000..6c2af30aa985 --- /dev/null +++ b/velox/connectors/hive/storage_adapters/abfs/AzureClientProviderFactories.h @@ -0,0 +1,62 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "velox/common/config/Config.h" +#include "velox/connectors/hive/storage_adapters/abfs/AbfsPath.h" +#include "velox/connectors/hive/storage_adapters/abfs/AzureBlobClient.h" +#include "velox/connectors/hive/storage_adapters/abfs/AzureClientProvider.h" +#include "velox/connectors/hive/storage_adapters/abfs/AzureDataLakeFileClient.h" + +namespace facebook::velox::filesystems { + +using AzureClientProviderFactory = + std::function( + const std::string& account)>; + +/// Handles the registration of Azure client providers and the creation of +/// AzureBlobClient and AzureDataLakeFileClient instances. +class AzureClientProviderFactories { + public: + /// Registers a factory for creating AzureClientProvider instances. + /// Any existing factory registered for the specified account will be + /// overwritten by recalling this method with the same account name. + static void registerFactory( + const std::string& account, + const AzureClientProviderFactory& factory); + + /// Get the registered AzureClientProviderFactory for the specified + /// account. Throws exception if no factory is registered for the account. + static AzureClientProviderFactory getClientFactory( + const std::string& account); + + /// Uses the registered AzureClientProviderFactory to create an + /// AzureBlobClient for file read operations. Throws exception if no factory + /// is registered for the account specified in `abfsPath`. + static std::unique_ptr getReadFileClient( + const std::shared_ptr& abfsPath, + const config::ConfigBase& config); + + /// Uses the registered AzureClientProviderFactory to create an + /// AzureDataLakeFileClient for file write operations. Throws exception if no + /// factory is registered for the account specified in `abfsPath`. + static std::unique_ptr getWriteFileClient( + const std::shared_ptr& abfsPath, + const config::ConfigBase& config); +}; + +} // namespace facebook::velox::filesystems diff --git a/velox/connectors/hive/storage_adapters/abfs/AzureClientProviderImpl.cpp b/velox/connectors/hive/storage_adapters/abfs/AzureClientProviderImpl.cpp new file mode 100644 index 000000000000..eec80f08f6df --- /dev/null +++ b/velox/connectors/hive/storage_adapters/abfs/AzureClientProviderImpl.cpp @@ -0,0 +1,240 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/storage_adapters/abfs/AzureClientProviderImpl.h" + +#include + +namespace facebook::velox::filesystems { + +namespace { + +class DataLakeFileClientWrapper final : public AzureDataLakeFileClient { + public: + DataLakeFileClientWrapper(std::unique_ptr client) + : client_(std::move(client)) {} + + void create() override { + client_->Create(); + } + + Azure::Storage::Files::DataLake::Models::PathProperties getProperties() + override { + return client_->GetProperties().Value; + } + + void append(const uint8_t* buffer, size_t size, uint64_t offset) override { + auto bodyStream = Azure::Core::IO::MemoryBodyStream(buffer, size); + client_->Append(bodyStream, offset); + } + + void flush(uint64_t position) override { + client_->Flush(position); + } + + void close() override { + // do nothing. + } + + std::string getUrl() override { + return client_->GetUrl(); + } + + private: + const std::unique_ptr client_; +}; + +class BlobClientWrapper : public AzureBlobClient { + public: + BlobClientWrapper(std::unique_ptr client) { + blobClient_ = std::move(client); + } + + Azure::Response getProperties() + override { + return blobClient_->GetProperties(); + } + + Azure::Response download( + const Azure::Storage::Blobs::DownloadBlobOptions& options) override { + return blobClient_->Download(options); + } + + std::string getUrl() override { + return blobClient_->GetUrl(); + } + + private: + std::unique_ptr blobClient_; +}; + +} // namespace + +std::unique_ptr +SharedKeyAzureClientProvider::getReadFileClient( + const std::shared_ptr& abfsPath, + const config::ConfigBase& config) { + init(abfsPath, config); + auto client = + std::make_unique(BlobClient::CreateFromConnectionString( + connectionString_, abfsPath->fileSystem(), abfsPath->filePath())); + return std::make_unique(std::move(client)); +} + +std::unique_ptr +SharedKeyAzureClientProvider::getWriteFileClient( + const std::shared_ptr& abfsPath, + const config::ConfigBase& config) { + init(abfsPath, config); + auto client = std::make_unique( + DataLakeFileClient::CreateFromConnectionString( + connectionString_, abfsPath->fileSystem(), abfsPath->filePath())); + return std::make_unique(std::move(client)); +} + +std::string SharedKeyAzureClientProvider::connectionString( + const std::shared_ptr& abfsPath, + const config::ConfigBase& config) { + init(abfsPath, config); + return connectionString_; +} + +void SharedKeyAzureClientProvider::init( + const std::shared_ptr& abfsPath, + const config::ConfigBase& config) { + auto credKey = + fmt::format("{}.{}", kAzureAccountKey, abfsPath->accountNameWithSuffix()); + VELOX_USER_CHECK(config.valueExists(credKey), "Config {} not found", credKey); + auto firstDot = abfsPath->accountNameWithSuffix().find_first_of("."); + auto endpointSuffix = + abfsPath->accountNameWithSuffix().substr(firstDot + 5 /* .dfs. */); + std::stringstream ss; + ss << "DefaultEndpointsProtocol=" << (abfsPath->isHttps() ? "https" : "http"); + ss << ";AccountName=" << abfsPath->accountName(); + ss << ";AccountKey=" << config.get(credKey).value(); + ss << ";EndpointSuffix=" << endpointSuffix; + + if (config.valueExists(kAzureBlobEndpoint)) { + ss << ";BlobEndpoint=" + << config.get(kAzureBlobEndpoint).value(); + } + ss << ";"; + connectionString_ = ss.str(); +} + +std::unique_ptr OAuthAzureClientProvider::getReadFileClient( + const std::shared_ptr& abfsPath, + const config::ConfigBase& config) { + init(abfsPath, config); + const auto url = abfsPath->getUrl(true); + auto client = std::make_unique(url, tokenCredential_); + return std::make_unique(std::move(client)); +} + +std::unique_ptr +OAuthAzureClientProvider::getWriteFileClient( + const std::shared_ptr& abfsPath, + const config::ConfigBase& config) { + init(abfsPath, config); + const auto url = abfsPath->getUrl(false); + auto client = std::make_unique(url, tokenCredential_); + return std::make_unique(std::move(client)); +} + +std::pair +OAuthAzureClientProvider::tenantIdAndAuthorityHost( + const std::shared_ptr& abfsPath, + const config::ConfigBase& config) { + init(abfsPath, config); + return {tenentId_, authorityHost_}; +} + +void OAuthAzureClientProvider::init( + const std::shared_ptr& abfsPath, + const config::ConfigBase& config) { + auto clientIdKey = fmt::format( + "{}.{}", kAzureAccountOAuth2ClientId, abfsPath->accountNameWithSuffix()); + auto clientSecretKey = fmt::format( + "{}.{}", + kAzureAccountOAuth2ClientSecret, + abfsPath->accountNameWithSuffix()); + auto clientEndpointKey = fmt::format( + "{}.{}", + kAzureAccountOAuth2ClientEndpoint, + abfsPath->accountNameWithSuffix()); + VELOX_USER_CHECK( + config.valueExists(clientIdKey), "Config {} not found", clientIdKey); + VELOX_USER_CHECK( + config.valueExists(clientSecretKey), + "Config {} not found", + clientSecretKey); + VELOX_USER_CHECK( + config.valueExists(clientEndpointKey), + "Config {} not found", + clientEndpointKey); + auto clientEndpoint = config.get(clientEndpointKey).value(); + // Length of "https://". + static const std::size_t kHttpsPrefixLen = 8; + auto firstSep = clientEndpoint.find_first_of("/", kHttpsPrefixLen); + authorityHost_ = clientEndpoint.substr(0, firstSep + 1); + auto sedondSep = clientEndpoint.find_first_of("/", firstSep + 1); + tenentId_ = clientEndpoint.substr(firstSep + 1, sedondSep - firstSep - 1); + Azure::Identity::ClientSecretCredentialOptions options; + options.AuthorityHost = authorityHost_; + tokenCredential_ = std::make_shared( + tenentId_, + config.get(clientIdKey).value(), + config.get(clientSecretKey).value(), + options); +} + +std::unique_ptr FixedSasAzureClientProvider::getReadFileClient( + const std::shared_ptr& abfsPath, + const config::ConfigBase& config) { + init(abfsPath, config); + const auto url = abfsPath->getUrl(true); + auto client = std::make_unique(fmt::format("{}?{}", url, sas_)); + return std::make_unique(std::move(client)); +} + +std::unique_ptr +FixedSasAzureClientProvider::getWriteFileClient( + const std::shared_ptr& abfsPath, + const config::ConfigBase& config) { + init(abfsPath, config); + const auto url = abfsPath->getUrl(false); + auto client = + std::make_unique(fmt::format("{}?{}", url, sas_)); + return std::make_unique(std::move(client)); +} + +std::string FixedSasAzureClientProvider::sas( + const std::shared_ptr& abfsPath, + const config::ConfigBase& config) { + init(abfsPath, config); + return sas_; +} + +void FixedSasAzureClientProvider::init( + const std::shared_ptr& abfsPath, + const config::ConfigBase& config) { + auto sasKey = + fmt::format("{}.{}", kAzureSASKey, abfsPath->accountNameWithSuffix()); + VELOX_USER_CHECK(config.valueExists(sasKey), "Config {} not found", sasKey); + sas_ = config.get(sasKey).value(); +} + +} // namespace facebook::velox::filesystems diff --git a/velox/connectors/hive/storage_adapters/abfs/AzureClientProviderImpl.h b/velox/connectors/hive/storage_adapters/abfs/AzureClientProviderImpl.h new file mode 100644 index 000000000000..85293891fcb3 --- /dev/null +++ b/velox/connectors/hive/storage_adapters/abfs/AzureClientProviderImpl.h @@ -0,0 +1,101 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "velox/common/config/Config.h" +#include "velox/connectors/hive/storage_adapters/abfs/AbfsPath.h" +#include "velox/connectors/hive/storage_adapters/abfs/AzureBlobClient.h" +#include "velox/connectors/hive/storage_adapters/abfs/AzureClientProvider.h" +#include "velox/connectors/hive/storage_adapters/abfs/AzureDataLakeFileClient.h" + +namespace facebook::velox::filesystems { + +// AzureClientProvider for Shared Key authentication. +class SharedKeyAzureClientProvider final : public AzureClientProvider { + public: + std::unique_ptr getReadFileClient( + const std::shared_ptr& abfsPath, + const config::ConfigBase& config) override; + + std::unique_ptr getWriteFileClient( + const std::shared_ptr& abfsPath, + const config::ConfigBase& config) override; + + // Test only. + std::string connectionString( + const std::shared_ptr& abfsPath, + const config::ConfigBase& config); + + private: + void init( + const std::shared_ptr& abfsPath, + const config::ConfigBase& config); + + std::string connectionString_; +}; + +// AzureClientProvider for OAuth authentication. +class OAuthAzureClientProvider final : public AzureClientProvider { + public: + std::unique_ptr getReadFileClient( + const std::shared_ptr& abfsPath, + const config::ConfigBase& config) override; + + std::unique_ptr getWriteFileClient( + const std::shared_ptr& abfsPath, + const config::ConfigBase& config) override; + + // Test only. + std::pair tenantIdAndAuthorityHost( + const std::shared_ptr& abfsPath, + const config::ConfigBase& config); + + private: + void init( + const std::shared_ptr& abfsPath, + const config::ConfigBase& config); + + std::string tenentId_; + std::string authorityHost_; + std::shared_ptr tokenCredential_; +}; + +// AzureClientProvider for SAS authentication with a fixed SAS token. +class FixedSasAzureClientProvider final : public AzureClientProvider { + public: + std::unique_ptr getReadFileClient( + const std::shared_ptr& abfsPath, + const config::ConfigBase& config) override; + + std::unique_ptr getWriteFileClient( + const std::shared_ptr& abfsPath, + const config::ConfigBase& config) override; + + // Test only. + std::string sas( + const std::shared_ptr& abfsPath, + const config::ConfigBase& config); + + private: + void init( + const std::shared_ptr& abfsPath, + const config::ConfigBase& config); + + std::string sas_; +}; + +} // namespace facebook::velox::filesystems diff --git a/velox/connectors/hive/storage_adapters/abfs/AzureDataLakeFileClient.h b/velox/connectors/hive/storage_adapters/abfs/AzureDataLakeFileClient.h index abd607c0d1b2..3416bab7c5d8 100644 --- a/velox/connectors/hive/storage_adapters/abfs/AzureDataLakeFileClient.h +++ b/velox/connectors/hive/storage_adapters/abfs/AzureDataLakeFileClient.h @@ -44,6 +44,6 @@ class AzureDataLakeFileClient { virtual void append(const uint8_t* buffer, size_t size, uint64_t offset) = 0; virtual void flush(uint64_t position) = 0; virtual void close() = 0; - virtual std::string getUrl() const = 0; + virtual std::string getUrl() = 0; }; } // namespace facebook::velox::filesystems diff --git a/velox/connectors/hive/storage_adapters/abfs/CMakeLists.txt b/velox/connectors/hive/storage_adapters/abfs/CMakeLists.txt index 799e93830373..136db68d1afd 100644 --- a/velox/connectors/hive/storage_adapters/abfs/CMakeLists.txt +++ b/velox/connectors/hive/storage_adapters/abfs/CMakeLists.txt @@ -20,22 +20,30 @@ if(VELOX_ENABLE_ABFS) velox_sources( velox_abfs PRIVATE - AbfsFileSystem.cpp - AbfsConfig.cpp - AbfsWriteFile.cpp) + AbfsFileSystem.cpp + AbfsPath.cpp + AbfsReadFile.cpp + AbfsUtil.cpp + AbfsWriteFile.cpp + AzureClientProviderFactories.cpp + AzureClientProviderImpl.cpp + DynamicSasTokenClientProvider.cpp + ) velox_link_libraries( velox_abfs - PUBLIC velox_file - velox_core - velox_hive_config - velox_dwio_common_exception - Azure::azure-identity - Azure::azure-storage-blobs - Azure::azure-storage-files-datalake - Folly::folly - glog::glog - fmt::fmt) + PUBLIC + velox_file + velox_core + velox_hive_config + velox_dwio_common_exception + Azure::azure-identity + Azure::azure-storage-blobs + Azure::azure-storage-files-datalake + Folly::folly + glog::glog + fmt::fmt + ) if(${VELOX_BUILD_TESTING}) add_subdirectory(tests) diff --git a/velox/connectors/hive/storage_adapters/abfs/DynamicSasTokenClientProvider.cpp b/velox/connectors/hive/storage_adapters/abfs/DynamicSasTokenClientProvider.cpp new file mode 100644 index 000000000000..b98ae99d48ed --- /dev/null +++ b/velox/connectors/hive/storage_adapters/abfs/DynamicSasTokenClientProvider.cpp @@ -0,0 +1,225 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/storage_adapters/abfs/DynamicSasTokenClientProvider.h" + +#include + +namespace facebook::velox::filesystems { + +namespace { + +constexpr int64_t kDefaultSasTokenRenewPeriod = 120; // in seconds + +Azure::DateTime getExpiry(const std::string_view& token) { + if (token.empty()) { + return Azure::DateTime::clock::time_point::min(); + } + + static constexpr std::string_view kSignedExpiry{"se="}; + static constexpr int32_t kSignedExpiryLen = 3; + + auto start = token.find(kSignedExpiry); + if (start == std::string::npos) { + return Azure::DateTime::clock::time_point::min(); + } + start += kSignedExpiryLen; + + auto end = token.find("&", start); + auto seValue = (end == std::string::npos) + ? std::string(token.substr(start)) + : std::string(token.substr(start, end - start)); + + seValue = Azure::Core::Url::Decode(seValue); + auto seDate = + Azure::DateTime::Parse(seValue, Azure::DateTime::DateFormat::Rfc3339); + + static constexpr std::string_view kSignedKeyExpiry = "ske="; + static constexpr int32_t kSignedKeyExpiryLen = 4; + + start = token.find(kSignedKeyExpiry); + if (start == std::string::npos) { + return seDate; + } + start += kSignedKeyExpiryLen; + + end = token.find("&", start); + auto skeValue = (end == std::string::npos) + ? std::string(token.substr(start)) + : std::string(token.substr(start, end - start)); + + skeValue = Azure::Core::Url::Decode(skeValue); + auto skeDate = + Azure::DateTime::Parse(skeValue, Azure::DateTime::DateFormat::Rfc3339); + + return std::min(skeDate, seDate); +} + +bool isNearExpiry(Azure::DateTime expiration, int64_t minExpirationInSeconds) { + if (expiration == Azure::DateTime::clock::time_point::min()) { + return true; + } + auto remaining = std::chrono::duration_cast( + expiration - Azure::DateTime::clock::now()) + .count(); + return remaining <= minExpirationInSeconds; +} + +class DynamicSasTokenDataLakeFileClient final : public AzureDataLakeFileClient { + public: + DynamicSasTokenDataLakeFileClient( + const std::shared_ptr& abfsPath, + const std::shared_ptr& sasKeyGenerator, + int64_t sasTokenRenewPeriod) + : abfsPath_(abfsPath), + sasKeyGenerator_(sasKeyGenerator), + sasTokenRenewPeriod_(sasTokenRenewPeriod) {} + + void create() override { + getWriteClient()->Create(); + } + + Azure::Storage::Files::DataLake::Models::PathProperties getProperties() + override { + return getReadClient()->GetProperties().Value; + } + + void append(const uint8_t* buffer, size_t size, uint64_t offset) override { + auto bodyStream = Azure::Core::IO::MemoryBodyStream(buffer, size); + getWriteClient()->Append(bodyStream, offset); + } + + void flush(uint64_t position) override { + getWriteClient()->Flush(position); + } + + void close() override {} + + std::string getUrl() override { + return getWriteClient()->GetUrl(); + } + + private: + std::shared_ptr abfsPath_; + std::shared_ptr sasKeyGenerator_; + int64_t sasTokenRenewPeriod_; + + std::unique_ptr writeClient_{nullptr}; + Azure::DateTime writeSasExpiration_{ + Azure::DateTime::clock::time_point::min()}; + + std::unique_ptr readClient_{nullptr}; + Azure::DateTime readSasExpiration_{Azure::DateTime::clock::time_point::min()}; + + DataLakeFileClient* getWriteClient() { + if (writeClient_ == nullptr || + isNearExpiry(writeSasExpiration_, sasTokenRenewPeriod_)) { + const auto& sas = sasKeyGenerator_->getSasToken( + abfsPath_->fileSystem(), abfsPath_->filePath(), kAbfsWriteOperation); + writeSasExpiration_ = getExpiry(sas); + writeClient_ = std::make_unique( + fmt::format("{}?{}", abfsPath_->getUrl(false), sas)); + } + return writeClient_.get(); + } + + DataLakeFileClient* getReadClient() { + if (readClient_ == nullptr || + isNearExpiry(readSasExpiration_, sasTokenRenewPeriod_)) { + const auto& sas = sasKeyGenerator_->getSasToken( + abfsPath_->fileSystem(), abfsPath_->filePath(), kAbfsReadOperation); + readSasExpiration_ = getExpiry(sas); + readClient_ = std::make_unique( + fmt::format("{}?{}", abfsPath_->getUrl(false), sas)); + } + return readClient_.get(); + } +}; + +class DynamicSasTokenBlobClient : public AzureBlobClient { + public: + DynamicSasTokenBlobClient( + const std::shared_ptr& abfsPath, + const std::shared_ptr& sasTokenProvider, + int64_t sasTokenRenewPeriod) + : abfsPath_(abfsPath), + sasTokenProvider_(sasTokenProvider), + sasTokenRenewPeriod_(sasTokenRenewPeriod) {} + + Azure::Response getProperties() + override { + return getBlobClient()->GetProperties(); + } + + Azure::Response download( + const Azure::Storage::Blobs::DownloadBlobOptions& options) override { + return getBlobClient()->Download(options); + } + + std::string getUrl() override { + return getBlobClient()->GetUrl(); + } + + private: + std::shared_ptr abfsPath_; + std::shared_ptr sasTokenProvider_; + int64_t sasTokenRenewPeriod_; + + std::unique_ptr blobClient_{nullptr}; + Azure::DateTime sasExpiration_{Azure::DateTime::clock::time_point::min()}; + + BlobClient* getBlobClient() { + if (blobClient_ == nullptr || + isNearExpiry(sasExpiration_, sasTokenRenewPeriod_)) { + const auto& sas = sasTokenProvider_->getSasToken( + abfsPath_->fileSystem(), abfsPath_->filePath(), kAbfsReadOperation); + sasExpiration_ = getExpiry(sas); + blobClient_ = std::make_unique( + fmt::format("{}?{}", abfsPath_->getUrl(true), sas)); + } + return blobClient_.get(); + } +}; + +} // namespace + +DynamicSasTokenClientProvider::DynamicSasTokenClientProvider( + const std::shared_ptr& sasTokenProvider) + : AzureClientProvider(), sasTokenProvider_(sasTokenProvider) {} + +void DynamicSasTokenClientProvider::init(const config::ConfigBase& config) { + sasTokenRenewPeriod_ = config.get( + kAzureSasTokenRenewPeriod, kDefaultSasTokenRenewPeriod); +} + +std::unique_ptr +DynamicSasTokenClientProvider::getReadFileClient( + const std::shared_ptr& abfsPath, + const config::ConfigBase& config) { + init(config); + return std::make_unique( + abfsPath, sasTokenProvider_, sasTokenRenewPeriod_); +} + +std::unique_ptr +DynamicSasTokenClientProvider::getWriteFileClient( + const std::shared_ptr& abfsPath, + const config::ConfigBase& config) { + init(config); + return std::make_unique( + abfsPath, sasTokenProvider_, sasTokenRenewPeriod_); +} +} // namespace facebook::velox::filesystems diff --git a/velox/connectors/hive/storage_adapters/abfs/DynamicSasTokenClientProvider.h b/velox/connectors/hive/storage_adapters/abfs/DynamicSasTokenClientProvider.h new file mode 100644 index 000000000000..ab1d53f0045b --- /dev/null +++ b/velox/connectors/hive/storage_adapters/abfs/DynamicSasTokenClientProvider.h @@ -0,0 +1,73 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "velox/connectors/hive/storage_adapters/abfs/AzureBlobClient.h" +#include "velox/connectors/hive/storage_adapters/abfs/AzureClientProvider.h" +#include "velox/connectors/hive/storage_adapters/abfs/AzureClientProviderFactories.h" +#include "velox/connectors/hive/storage_adapters/abfs/AzureDataLakeFileClient.h" + +namespace facebook::velox::filesystems { + +/// SAS permissions reference: +/// https://learn.microsoft.com/en-us/rest/api/storageservices/create-service-sas#permissions-for-a-directory-container-or-blob +/// +/// ReadClient uses "read" permission for Download and GetProperties. +/// WriteClient uses "read" permission for GetProperties, and "write" permission +/// for other operations. +static const std::string kAbfsReadOperation{"read"}; +static const std::string kAbfsWriteOperation{"write"}; + +/// Interface for providing SAS tokens for ABFS file system operations. +/// Adapted from the Hadoop Azure implementation: +/// org.apache.hadoop.fs.azurebfs.extensions.SASTokenProvider +class SasTokenProvider { + public: + virtual ~SasTokenProvider() = default; + + virtual std::string getSasToken( + const std::string& fileSystem, + const std::string& path, + const std::string& operation) = 0; +}; + +/// Client provider that dynamically refreshes SAS tokens based on the +/// expiration time of the token. A SasTokenProvider for retrieving SAS tokens +/// must be provided to this class. Example for generating the SAS token can be +/// found in: +/// https://github.com/Azure/azure-sdk-for-cpp/blob/3d917e7c178f0a49b189395a907180084857cc70/sdk/storage/azure-storage-blobs/samples/blob_sas.cpp +class DynamicSasTokenClientProvider : public AzureClientProvider { + public: + explicit DynamicSasTokenClientProvider( + const std::shared_ptr& sasTokenProvider); + + std::unique_ptr getReadFileClient( + const std::shared_ptr& abfsPath, + const config::ConfigBase& config) override; + + std::unique_ptr getWriteFileClient( + const std::shared_ptr& abfsPath, + const config::ConfigBase& config) override; + + private: + void init(const config::ConfigBase& config); + + std::shared_ptr sasTokenProvider_; + int64_t sasTokenRenewPeriod_; +}; + +} // namespace facebook::velox::filesystems diff --git a/velox/connectors/hive/storage_adapters/abfs/RegisterAbfsFileSystem.cpp b/velox/connectors/hive/storage_adapters/abfs/RegisterAbfsFileSystem.cpp index 112aaf7f87f6..a87153928633 100644 --- a/velox/connectors/hive/storage_adapters/abfs/RegisterAbfsFileSystem.cpp +++ b/velox/connectors/hive/storage_adapters/abfs/RegisterAbfsFileSystem.cpp @@ -14,16 +14,22 @@ * limitations under the License. */ +#include "velox/connectors/hive/storage_adapters/abfs/RegisterAbfsFileSystem.h" // @manual + #ifdef VELOX_ENABLE_ABFS +#include "velox/common/base/Exceptions.h" #include "velox/common/config/Config.h" #include "velox/connectors/hive/storage_adapters/abfs/AbfsFileSystem.h" // @manual #include "velox/connectors/hive/storage_adapters/abfs/AbfsUtil.h" // @manual +#include "velox/connectors/hive/storage_adapters/abfs/AzureClientProviderFactories.h" // @manual +#include "velox/connectors/hive/storage_adapters/abfs/AzureClientProviderImpl.h" // @manual #include "velox/dwio/common/FileSink.h" #endif namespace facebook::velox::filesystems { #ifdef VELOX_ENABLE_ABFS + folly::once_flag abfsInitiationFlag; std::shared_ptr abfsFileSystemGenerator( @@ -60,4 +66,41 @@ void registerAbfsFileSystem() { #endif } +void registerAzureClientProvider(const config::ConfigBase& config) { +#ifdef VELOX_ENABLE_ABFS + + for (const auto& [accountName, authType] : + extractCacheKeyFromConfig(config)) { + if (authType == kAzureSharedKeyAuthType) { + AzureClientProviderFactories::registerFactory( + accountName, [](const std::string&) { + return std::make_unique(); + }); + } else if (authType == kAzureOAuthAuthType) { + AzureClientProviderFactories::registerFactory( + accountName, [](const std::string&) { + return std::make_unique(); + }); + } else if (authType == kAzureSASAuthType) { + AzureClientProviderFactories::registerFactory( + accountName, [](const std::string&) { + return std::make_unique(); + }); + } else { + VELOX_USER_FAIL( + "Unsupported auth type {}, supported auth types are SharedKey, OAuth and SAS.", + authType); + } + } +#endif +} + +void registerAzureClientProviderFactory( + const std::string& account, + const AzureClientProviderFactory& factory) { +#ifdef VELOX_ENABLE_ABFS + AzureClientProviderFactories::registerFactory(account, factory); +#endif +} + } // namespace facebook::velox::filesystems diff --git a/velox/connectors/hive/storage_adapters/abfs/RegisterAbfsFileSystem.h b/velox/connectors/hive/storage_adapters/abfs/RegisterAbfsFileSystem.h index cbe6758808a0..3d67c6ce667f 100644 --- a/velox/connectors/hive/storage_adapters/abfs/RegisterAbfsFileSystem.h +++ b/velox/connectors/hive/storage_adapters/abfs/RegisterAbfsFileSystem.h @@ -16,9 +16,37 @@ #pragma once +#include +#include +#include + +namespace facebook::velox::config { + +class ConfigBase; + +} // namespace facebook::velox::config + namespace facebook::velox::filesystems { +class AzureClientProvider; +class AbfsPath; + +using AzureClientProviderFactory = + std::function( + const std::string& account)>; + // Register the ABFS filesystem. void registerAbfsFileSystem(); +/// Register the AzureClientProvider implementation in `AzureClientProviders` +/// based on the configuration. +void registerAzureClientProvider(const config::ConfigBase& config); + +/// Registers a factory for creating AzureClientProvider instances. +/// Any existing factory registered for the specified account will be +/// overwritten by recalling this method with the same account name. +void registerAzureClientProviderFactory( + const std::string& account, + const AzureClientProviderFactory& factory); + } // namespace facebook::velox::filesystems diff --git a/velox/connectors/hive/storage_adapters/abfs/tests/AbfsFileSystemTest.cpp b/velox/connectors/hive/storage_adapters/abfs/tests/AbfsFileSystemTest.cpp index aef53c2dd68d..245c6931cb30 100644 --- a/velox/connectors/hive/storage_adapters/abfs/tests/AbfsFileSystemTest.cpp +++ b/velox/connectors/hive/storage_adapters/abfs/tests/AbfsFileSystemTest.cpp @@ -24,10 +24,15 @@ #include "velox/common/file/File.h" #include "velox/common/file/FileSystems.h" #include "velox/connectors/hive/FileHandle.h" -#include "velox/connectors/hive/storage_adapters/abfs/AbfsConfig.h" #include "velox/connectors/hive/storage_adapters/abfs/AbfsFileSystem.h" +#include "velox/connectors/hive/storage_adapters/abfs/AbfsPath.h" + +#include "connectors/hive/storage_adapters/abfs/AzureClientProviderFactories.h" +#include "connectors/hive/storage_adapters/abfs/AzureClientProviderImpl.h" +#include "connectors/hive/storage_adapters/abfs/RegisterAbfsFileSystem.h" #include "velox/connectors/hive/storage_adapters/abfs/AbfsReadFile.h" #include "velox/connectors/hive/storage_adapters/abfs/AbfsWriteFile.h" +#include "velox/connectors/hive/storage_adapters/abfs/RegisterAbfsFileSystem.h" #include "velox/connectors/hive/storage_adapters/abfs/tests/AzuriteServer.h" #include "velox/connectors/hive/storage_adapters/abfs/tests/MockDataLakeFileClient.h" #include "velox/dwio/common/FileSink.h" @@ -38,8 +43,34 @@ using namespace facebook::velox; using namespace facebook::velox::filesystems; using ::facebook::velox::common::Region; +namespace { + constexpr int kOneMB = 1 << 20; +class TestAzureClientProvider final : public AzureClientProvider { + public: + explicit TestAzureClientProvider() { + delegated_ = std::make_unique(); + } + + std::unique_ptr getReadFileClient( + const std::shared_ptr& path, + const config::ConfigBase& config) override { + return delegated_->getReadFileClient(path, config); + } + + std::unique_ptr getWriteFileClient( + const std::shared_ptr& path, + const config::ConfigBase& config) override { + return std::make_unique(); + } + + private: + std::unique_ptr delegated_; +}; + +} // namespace + class AbfsFileSystemTest : public testing::Test { public: std::shared_ptr azuriteServer_; @@ -47,12 +78,9 @@ class AbfsFileSystemTest : public testing::Test { static void SetUpTestCase() { registerAbfsFileSystem(); - AbfsConfig::setUpTestWriteClient( - []() { return std::make_unique(); }); - } - - static void TearDownTestSuite() { - AbfsConfig::tearDownTestWriteClient(); + registerAzureClientProviderFactory("test", [](const std::string&) { + return std::make_unique(); + }); } void SetUp() override { @@ -163,15 +191,14 @@ TEST_F(AbfsFileSystemTest, openFileForReadWithInvalidOptions) { TEST_F(AbfsFileSystemTest, fileHandleWithProperties) { FileHandleFactory factory( - std::make_unique>(1), + std::make_unique>(1), std::make_unique(azuriteServer_->hiveConfig())); FileProperties properties = {15 + kOneMB, 1}; - auto fileHandleProperties = - factory.generate(azuriteServer_->fileURI(), &properties); + FileHandleKey key{azuriteServer_->fileURI()}; + auto fileHandleProperties = factory.generate(key, &properties); readData(fileHandleProperties->file.get()); - auto fileHandleWithoutProperties = - factory.generate(azuriteServer_->fileURI()); + auto fileHandleWithoutProperties = factory.generate(key); readData(fileHandleWithoutProperties->file.get()); } @@ -265,12 +292,12 @@ TEST_F(AbfsFileSystemTest, notImplemented) { VELOX_ASSERT_THROW(abfs_->rmdir("dir"), "rmdir for abfs not implemented"); } -TEST_F(AbfsFileSystemTest, credNotFOund) { +TEST_F(AbfsFileSystemTest, clientProviderFactoryNotRegistered) { const std::string abfsFile = std::string("abfs://test@test1.dfs.core.windows.net/test"); VELOX_ASSERT_THROW( abfs_->openFileForRead(abfsFile), - "Config fs.azure.account.key.test1.dfs.core.windows.net not found"); + "No AzureClientProviderFactory registered for account 'test1'."); } TEST_F(AbfsFileSystemTest, registerAbfsFileSink) { diff --git a/velox/connectors/hive/storage_adapters/abfs/tests/AbfsUtilTest.cpp b/velox/connectors/hive/storage_adapters/abfs/tests/AbfsUtilTest.cpp new file mode 100644 index 000000000000..a8f9bd602745 --- /dev/null +++ b/velox/connectors/hive/storage_adapters/abfs/tests/AbfsUtilTest.cpp @@ -0,0 +1,30 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/storage_adapters/abfs/AbfsUtil.h" + +#include + +using namespace facebook::velox::filesystems; + +TEST(AbfsUtilsTest, isAbfsFile) { + EXPECT_FALSE(isAbfsFile("abfs:")); + EXPECT_FALSE(isAbfsFile("abfss:")); + EXPECT_FALSE(isAbfsFile("abfs:/")); + EXPECT_FALSE(isAbfsFile("abfss:/")); + EXPECT_TRUE(isAbfsFile("abfs://test@test.dfs.core.windows.net/test")); + EXPECT_TRUE(isAbfsFile("abfss://test@test.dfs.core.windows.net/test")); +} diff --git a/velox/connectors/hive/storage_adapters/abfs/tests/AzureClientProviderFactoriesTest.cpp b/velox/connectors/hive/storage_adapters/abfs/tests/AzureClientProviderFactoriesTest.cpp new file mode 100644 index 000000000000..c85c15ee5fbd --- /dev/null +++ b/velox/connectors/hive/storage_adapters/abfs/tests/AzureClientProviderFactoriesTest.cpp @@ -0,0 +1,148 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/connectors/hive/storage_adapters/abfs/AzureClientProviderFactories.h" +#include "velox/connectors/hive/storage_adapters/abfs/AzureClientProviderImpl.h" +#include "velox/connectors/hive/storage_adapters/abfs/RegisterAbfsFileSystem.h" + +using namespace facebook::velox; +using namespace facebook::velox::filesystems; + +namespace { + +class DummyAzureClientProvider final : public AzureClientProvider { + public: + std::unique_ptr getReadFileClient( + const std::shared_ptr& path, + const config::ConfigBase& config) override { + VELOX_FAIL("DummyAzureClientProvider: Not implemented."); + } + + std::unique_ptr getWriteFileClient( + const std::shared_ptr& path, + const config::ConfigBase& config) override { + VELOX_FAIL("DummyAzureClientProvider: Not implemented."); + } +}; + +} // namespace + +TEST(AzureClientProviderFactoriesTest, registerFromConfig) { + const auto abfsPath = std::make_shared( + "abfss://abc@efg.dfs.core.windows.net/file/test.txt"); + + { + // OAuth auth type. + const config::ConfigBase config( + {{"fs.azure.account.auth.type.efg.dfs.core.windows.net", "OAuth"}, + {"fs.azure.account.oauth2.client.id.efg.dfs.core.windows.net", "123"}, + {"fs.azure.account.oauth2.client.secret.efg.dfs.core.windows.net", + "456"}, + {"fs.azure.account.oauth2.client.endpoint.efg.dfs.core.windows.net", + "https://login.microsoftonline.com/{TENANTID}/oauth2/token"}}, + false); + registerAzureClientProvider(config); + + ASSERT_NE( + AzureClientProviderFactories::getReadFileClient(abfsPath, config), + nullptr); + ASSERT_NE( + AzureClientProviderFactories::getWriteFileClient(abfsPath, config), + nullptr); + } + + { + // SharedKey auth type. + const config::ConfigBase config( + {{"fs.azure.account.auth.type.efg.dfs.core.windows.net", "SharedKey"}, + {"fs.azure.account.key.efg.dfs.core.windows.net", "456"}}, + false); + registerAzureClientProvider(config); + + ASSERT_NE( + AzureClientProviderFactories::getReadFileClient(abfsPath, config), + nullptr); + ASSERT_NE( + AzureClientProviderFactories::getWriteFileClient(abfsPath, config), + nullptr); + } + + { + // SAS auth type. + const config::ConfigBase config( + {{"fs.azure.account.auth.type.efg.dfs.core.windows.net", "SAS"}, + {"fs.azure.sas.fixed.token.efg.dfs.core.windows.net", "456"}}, + false); + registerAzureClientProvider(config); + + ASSERT_NE( + AzureClientProviderFactories::getReadFileClient(abfsPath, config), + nullptr); + ASSERT_NE( + AzureClientProviderFactories::getWriteFileClient(abfsPath, config), + nullptr); + } + + { + // Invalid auth type. + const config::ConfigBase config( + {{"fs.azure.account.auth.type.efg.dfs.core.windows.net", "Custom"}, + {"fs.azure.account.key.efg.dfs.core.windows.net", "456"}}, + false); + VELOX_ASSERT_THROW( + registerAzureClientProvider(config), + "Unsupported auth type Custom, supported auth types are SharedKey, OAuth and SAS."); + } + + { + // Invalid config key. + const config::ConfigBase config( + {{"fs.azure.account.auth.type.efg", "SharedKey"}, + {"fs.azure.account.key.efg.dfs.core.windows.net", "456"}}, + false); + VELOX_ASSERT_THROW( + registerAzureClientProvider(config), + "Invalid Azure account auth type key: fs.azure.account.auth.type.efg"); + } +} + +TEST(AzureClientProviderFactoriesTest, registerCustomFactory) { + static const std::string path = "abfs://test@efg.dfs.core.windows.net/test"; + const auto abfsPath = std::make_shared(path); + + registerAzureClientProviderFactory( + "efg", + [](const std::string& account) -> std::unique_ptr { + return std::make_unique(); + }); + + ASSERT_NO_THROW(AzureClientProviderFactories::getClientFactory("efg")); + VELOX_ASSERT_THROW( + AzureClientProviderFactories::getReadFileClient( + abfsPath, config::ConfigBase({})), + "DummyAzureClientProvider: Not implemented."); + VELOX_ASSERT_THROW( + AzureClientProviderFactories::getWriteFileClient( + abfsPath, config::ConfigBase({})), + "DummyAzureClientProvider: Not implemented."); + + VELOX_ASSERT_THROW( + AzureClientProviderFactories::getClientFactory("efg2"), + "No AzureClientProviderFactory registered for account 'efg2'."); +} diff --git a/velox/connectors/hive/storage_adapters/abfs/tests/AbfsCommonTest.cpp b/velox/connectors/hive/storage_adapters/abfs/tests/AzureClientProvidersTest.cpp similarity index 50% rename from velox/connectors/hive/storage_adapters/abfs/tests/AbfsCommonTest.cpp rename to velox/connectors/hive/storage_adapters/abfs/tests/AzureClientProvidersTest.cpp index 53fd09323b41..b9fd4346b237 100644 --- a/velox/connectors/hive/storage_adapters/abfs/tests/AbfsCommonTest.cpp +++ b/velox/connectors/hive/storage_adapters/abfs/tests/AzureClientProvidersTest.cpp @@ -14,43 +14,21 @@ * limitations under the License. */ +#include "velox/connectors/hive/storage_adapters/abfs/AzureClientProviderImpl.h" + +#include "connectors/hive/storage_adapters/abfs/RegisterAbfsFileSystem.h" #include "velox/common/base/tests/GTestUtils.h" #include "velox/common/config/Config.h" -#include "velox/connectors/hive/storage_adapters/abfs/AbfsConfig.h" -#include "velox/connectors/hive/storage_adapters/abfs/AbfsUtil.h" +#include "velox/connectors/hive/storage_adapters/abfs/AbfsPath.h" #include "gtest/gtest.h" using namespace facebook::velox::filesystems; using namespace facebook::velox; -TEST(AbfsUtilsTest, isAbfsFile) { - EXPECT_FALSE(isAbfsFile("abfs:")); - EXPECT_FALSE(isAbfsFile("abfss:")); - EXPECT_FALSE(isAbfsFile("abfs:/")); - EXPECT_FALSE(isAbfsFile("abfss:/")); - EXPECT_TRUE(isAbfsFile("abfs://test@test.dfs.core.windows.net/test")); - EXPECT_TRUE(isAbfsFile("abfss://test@test.dfs.core.windows.net/test")); -} - -TEST(AbfsConfigTest, authType) { - const config::ConfigBase config( - {{"fs.azure.account.auth.type.efg.dfs.core.windows.net", "Custom"}, - {"fs.azure.account.key.efg.dfs.core.windows.net", "456"}}, - false); - VELOX_ASSERT_USER_THROW( - std::make_unique( - "abfss://foo@efg.dfs.core.windows.net/test.txt", config), - "Unsupported auth type Custom, supported auth types are SharedKey, OAuth and SAS."); -} - -TEST(AbfsConfigTest, clientSecretOAuth) { +TEST(AzureClientProvidersTest, clientSecretOAuth) { const config::ConfigBase config( - {{"fs.azure.account.auth.type.efg.dfs.core.windows.net", "OAuth"}, - {"fs.azure.account.auth.type.bar1.dfs.core.windows.net", "OAuth"}, - {"fs.azure.account.auth.type.bar2.dfs.core.windows.net", "OAuth"}, - {"fs.azure.account.auth.type.bar3.dfs.core.windows.net", "OAuth"}, - {"fs.azure.account.oauth2.client.id.efg.dfs.core.windows.net", "test"}, + {{"fs.azure.account.oauth2.client.id.efg.dfs.core.windows.net", "test"}, {"fs.azure.account.oauth2.client.secret.efg.dfs.core.windows.net", "test"}, {"fs.azure.account.oauth2.client.endpoint.efg.dfs.core.windows.net", @@ -60,27 +38,45 @@ TEST(AbfsConfigTest, clientSecretOAuth) { {"fs.azure.account.oauth2.client.secret.bar3.dfs.core.windows.net", "test"}}, false); + + auto clientProvider = OAuthAzureClientProvider(); + VELOX_ASSERT_USER_THROW( - std::make_unique( - "abfss://foo@bar1.dfs.core.windows.net/test.txt", config), + clientProvider.tenantIdAndAuthorityHost( + std::make_shared( + "abfss://foo@bar1.dfs.core.windows.net/test.txt"), + config), "Config fs.azure.account.oauth2.client.id.bar1.dfs.core.windows.net not found"); VELOX_ASSERT_USER_THROW( - std::make_unique( - "abfss://foo@bar2.dfs.core.windows.net/test.txt", config), + clientProvider.tenantIdAndAuthorityHost( + std::make_shared( + "abfss://foo@bar2.dfs.core.windows.net/test.txt"), + config), "Config fs.azure.account.oauth2.client.secret.bar2.dfs.core.windows.net not found"); VELOX_ASSERT_USER_THROW( - std::make_unique( - "abfss://foo@bar3.dfs.core.windows.net/test.txt", config), + clientProvider.tenantIdAndAuthorityHost( + std::make_shared( + "abfss://foo@bar3.dfs.core.windows.net/test.txt"), + config), "Config fs.azure.account.oauth2.client.endpoint.bar3.dfs.core.windows.net not found"); - auto abfsConfig = - AbfsConfig("abfss://abc@efg.dfs.core.windows.net/file/test.txt", config); - EXPECT_EQ(abfsConfig.tenentId(), "{TENANTID}"); - EXPECT_EQ(abfsConfig.authorityHost(), "https://login.microsoftonline.com/"); - auto readClient = abfsConfig.getReadFileClient(); + + const auto expectedTenantIdAndAuthorityHost = + std::make_pair( + "{TENANTID}", "https://login.microsoftonline.com/"); + EXPECT_EQ( + clientProvider.tenantIdAndAuthorityHost( + std::make_shared( + "abfss://abc@efg.dfs.core.windows.net/file/test.txt"), + config), + expectedTenantIdAndAuthorityHost); + + const auto abfsPath = std::make_shared( + "abfss://abc@efg.dfs.core.windows.net/file/test.txt"); + auto readClient = clientProvider.getReadFileClient(abfsPath, config); EXPECT_EQ( - readClient->GetUrl(), + readClient->getUrl(), "https://efg.blob.core.windows.net/abc/file/test.txt"); - auto writeClient = abfsConfig.getWriteFileClient(); + auto writeClient = clientProvider.getWriteFileClient(abfsPath, config); // GetUrl retrieves the value from the internal blob client, which represents // the blob's path as well. EXPECT_EQ( @@ -88,23 +84,27 @@ TEST(AbfsConfigTest, clientSecretOAuth) { "https://efg.blob.core.windows.net/abc/file/test.txt"); } -TEST(AbfsConfigTest, sasToken) { +TEST(AzureClientProviderTest, fixedSasToken) { const config::ConfigBase config( - {{"fs.azure.account.auth.type.efg.dfs.core.windows.net", "SAS"}, - {"fs.azure.account.auth.type.bar.dfs.core.windows.net", "SAS"}, - {"fs.azure.sas.fixed.token.bar.dfs.core.windows.net", "sas=test"}}, + {{"fs.azure.sas.fixed.token.bar.dfs.core.windows.net", "sas=test"}}, false); + + auto clientProvider = FixedSasAzureClientProvider(); + VELOX_ASSERT_USER_THROW( - std::make_unique( - "abfss://foo@efg.dfs.core.windows.net/test.txt", config), + clientProvider.sas( + std::make_shared( + "abfss://foo@efg.dfs.core.windows.net/test.txt"), + config), "Config fs.azure.sas.fixed.token.efg.dfs.core.windows.net not found"); - auto abfsConfig = - AbfsConfig("abfs://abc@bar.dfs.core.windows.net/file", config); - auto readClient = abfsConfig.getReadFileClient(); + + const auto abfsPath = + std::make_shared("abfs://abc@bar.dfs.core.windows.net/file"); + auto readClient = clientProvider.getReadFileClient(abfsPath, config); EXPECT_EQ( - readClient->GetUrl(), + readClient->getUrl(), "http://bar.blob.core.windows.net/abc/file?sas=test"); - auto writeClient = abfsConfig.getWriteFileClient(); + auto writeClient = clientProvider.getWriteFileClient(abfsPath, config); // GetUrl retrieves the value from the internal blob client, which represents // the blob's path as well. EXPECT_EQ( @@ -112,43 +112,47 @@ TEST(AbfsConfigTest, sasToken) { "http://bar.blob.core.windows.net/abc/file?sas=test"); } -TEST(AbfsConfigTest, sharedKey) { +TEST(AzureClientProviderTest, sharedKey) { const config::ConfigBase config( {{"fs.azure.account.key.efg.dfs.core.windows.net", "123"}, - {"fs.azure.account.auth.type.efg.dfs.core.windows.net", "SharedKey"}, {"fs.azure.account.key.foobar.dfs.core.windows.net", "456"}, {"fs.azure.account.key.bar.dfs.core.windows.net", "789"}}, false); - auto abfsConfig = - AbfsConfig("abfs://abc@efg.dfs.core.windows.net/file", config); - EXPECT_EQ(abfsConfig.fileSystem(), "abc"); - EXPECT_EQ(abfsConfig.filePath(), "file"); + const auto abfsPath = + std::make_shared("abfs://abc@efg.dfs.core.windows.net/file"); + EXPECT_EQ(abfsPath->fileSystem(), "abc"); + EXPECT_EQ(abfsPath->filePath(), "file"); + + auto clientProvider = SharedKeyAzureClientProvider(); EXPECT_EQ( - abfsConfig.connectionString(), + clientProvider.connectionString(abfsPath, config), "DefaultEndpointsProtocol=http;AccountName=efg;AccountKey=123;EndpointSuffix=core.windows.net;"); - auto abfssConfig = AbfsConfig( - "abfss://abc@foobar.dfs.core.windows.net/sf_1/store_sales/ss_sold_date_sk=2450816/part-00002-a29c25f1-4638-494e-8428-a84f51dcea41.c000.snappy.parquet", - config); - EXPECT_EQ(abfssConfig.fileSystem(), "abc"); + const auto abfssPath = std::make_shared( + "abfss://abc@foobar.dfs.core.windows.net/sf_1/store_sales/ss_sold_date_sk=2450816/part-00002-a29c25f1-4638-494e-8428-a84f51dcea41.c000.snappy.parquet"); + EXPECT_EQ(abfssPath->fileSystem(), "abc"); EXPECT_EQ( - abfssConfig.filePath(), + abfssPath->filePath(), "sf_1/store_sales/ss_sold_date_sk=2450816/part-00002-a29c25f1-4638-494e-8428-a84f51dcea41.c000.snappy.parquet"); EXPECT_EQ( - abfssConfig.connectionString(), + clientProvider.connectionString(abfssPath, config), "DefaultEndpointsProtocol=https;AccountName=foobar;AccountKey=456;EndpointSuffix=core.windows.net;"); // Test with special character space. - auto abfssConfigWithSpecialCharacters = AbfsConfig( - "abfss://foo@bar.dfs.core.windows.net/main@dir/sub dir/test.txt", config); - - EXPECT_EQ(abfssConfigWithSpecialCharacters.fileSystem(), "foo"); + const auto abfssPathWithSpecialCharacters = std::make_shared( + "abfss://foo@bar.dfs.core.windows.net/main@dir/sub dir/test.txt"); + EXPECT_EQ(abfssPathWithSpecialCharacters->fileSystem(), "foo"); + EXPECT_EQ( + abfssPathWithSpecialCharacters->filePath(), "main@dir/sub dir/test.txt"); EXPECT_EQ( - abfssConfigWithSpecialCharacters.filePath(), "main@dir/sub dir/test.txt"); + clientProvider.connectionString(abfssPathWithSpecialCharacters, config), + "DefaultEndpointsProtocol=https;AccountName=bar;AccountKey=789;EndpointSuffix=core.windows.net;"); VELOX_ASSERT_USER_THROW( - std::make_unique( - "abfss://foo@otheraccount.dfs.core.windows.net/test.txt", config), + clientProvider.connectionString( + std::make_shared( + "abfss://foo@otheraccount.dfs.core.windows.net/test.txt"), + config), "Config fs.azure.account.key.otheraccount.dfs.core.windows.net not found"); } diff --git a/velox/connectors/hive/storage_adapters/abfs/tests/AzuriteServer.cpp b/velox/connectors/hive/storage_adapters/abfs/tests/AzuriteServer.cpp index 726c2c531be8..9213068d926a 100644 --- a/velox/connectors/hive/storage_adapters/abfs/tests/AzuriteServer.cpp +++ b/velox/connectors/hive/storage_adapters/abfs/tests/AzuriteServer.cpp @@ -15,7 +15,8 @@ */ #include "velox/connectors/hive/storage_adapters/abfs/tests/AzuriteServer.h" -#include "velox/connectors/hive/storage_adapters/abfs/AbfsConfig.h" +#include "velox/connectors/hive/storage_adapters/abfs/AbfsPath.h" +#include "velox/connectors/hive/storage_adapters/abfs/AzureClientProviderImpl.h" namespace facebook::velox::filesystems { @@ -108,9 +109,10 @@ AzuriteServer::AzuriteServer(int64_t port) : port_(port) { } void AzuriteServer::addFile(std::string source) { - AbfsConfig conf(fileURI(), *hiveConfig()); + const auto abfsPath = std::make_shared(fileURI()); + auto clientProvider = SharedKeyAzureClientProvider(); auto containerClient = BlobContainerClient::CreateFromConnectionString( - conf.connectionString(), container_); + clientProvider.connectionString(abfsPath, *hiveConfig()), container_); containerClient.CreateIfNotExists(); auto blobClient = containerClient.GetBlockBlobClient(file_); blobClient.UploadFrom(source); diff --git a/velox/connectors/hive/storage_adapters/abfs/tests/CMakeLists.txt b/velox/connectors/hive/storage_adapters/abfs/tests/CMakeLists.txt index 2246dadd385f..c81471db9f68 100644 --- a/velox/connectors/hive/storage_adapters/abfs/tests/CMakeLists.txt +++ b/velox/connectors/hive/storage_adapters/abfs/tests/CMakeLists.txt @@ -12,12 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. -add_executable(velox_abfs_test AbfsFileSystemTest.cpp AbfsCommonTest.cpp - AzuriteServer.cpp MockDataLakeFileClient.cpp) +add_executable( + velox_abfs_test + AbfsFileSystemTest.cpp + AbfsUtilTest.cpp + AzureClientProvidersTest.cpp + AzureClientProviderFactoriesTest.cpp + DynamicSasTokenClientProviderTest.cpp + AzuriteServer.cpp + MockDataLakeFileClient.cpp +) add_test(velox_abfs_test velox_abfs_test) target_link_libraries( velox_abfs_test - PRIVATE velox_abfs velox_exec_test_lib GTest::gtest GTest::gtest_main) + PRIVATE velox_abfs velox_exec_test_lib GTest::gtest GTest::gtest_main +) target_compile_options(velox_abfs_test PRIVATE -Wno-deprecated-declarations) diff --git a/velox/connectors/hive/storage_adapters/abfs/tests/DynamicSasTokenClientProviderTest.cpp b/velox/connectors/hive/storage_adapters/abfs/tests/DynamicSasTokenClientProviderTest.cpp new file mode 100644 index 000000000000..8793aea080b9 --- /dev/null +++ b/velox/connectors/hive/storage_adapters/abfs/tests/DynamicSasTokenClientProviderTest.cpp @@ -0,0 +1,143 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/storage_adapters/abfs/DynamicSasTokenClientProvider.h" +#include "velox/connectors/hive/storage_adapters/abfs/AzureClientProviderFactories.h" +#include "velox/connectors/hive/storage_adapters/abfs/RegisterAbfsFileSystem.h" + +#include "gtest/gtest.h" + +#include +#include +#include + +using namespace facebook::velox::filesystems; +using namespace facebook::velox; + +namespace { + +class MyDynamicAbfsSasTokenProvider : public SasTokenProvider { + public: + MyDynamicAbfsSasTokenProvider(int64_t expiration) + : expirationSeconds_(expiration) {} + + std::string getSasToken( + const std::string& fileSystem, + const std::string& path, + const std::string& operation) override { + const auto lastSlash = path.find_last_of("/"); + const auto containerName = path.substr(0, lastSlash); + const auto blobName = path.substr(lastSlash + 1); + + Azure::Storage::Sas::BlobSasBuilder sasBuilder; + sasBuilder.ExpiresOn = Azure::DateTime::clock::now() + + std::chrono::seconds(expirationSeconds_); + sasBuilder.BlobContainerName = containerName; + sasBuilder.BlobName = blobName; + sasBuilder.Resource = Azure::Storage::Sas::BlobSasResource::Blob; + sasBuilder.SetPermissions( + Azure::Storage::Sas::BlobSasPermissions::Read & + Azure::Storage::Sas::BlobSasPermissions::Write); + + std::string sasToken = sasBuilder.GenerateSasToken( + Azure::Storage::StorageSharedKeyCredential( + "test", + "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==")); + + // Remove the leading '?' from the SAS token. + if (sasToken[0] == '?') { + sasToken = sasToken.substr(1); + } + + return sasToken; + } + + private: + int64_t expirationSeconds_; +}; + +} // namespace + +TEST(DynamicSasTokenClientProviderTest, dynamicSasToken) { + { + const std::string account = "account1"; + const config::ConfigBase config( + {{"fs.azure.account.auth.type.account1.dfs.core.windows.net", "SAS"}, + {"fs.azure.sas.token.renew.period.for.streams", "1"}}, + false); + registerAzureClientProviderFactory(account, [](const std::string&) { + auto sasTokenProvider = + std::make_shared(3); + return std::make_unique(sasTokenProvider); + }); + + auto abfsPath = std::make_shared( + fmt::format("abfs://abc@{}.dfs.core.windows.net/file", account)); + auto readClient = + AzureClientProviderFactories::getReadFileClient(abfsPath, config); + auto writeClient = + AzureClientProviderFactories::getWriteFileClient(abfsPath, config); + + auto readUrl = readClient->getUrl(); + auto writeUrl = writeClient->getUrl(); + + // Let the current time pass 3 seconds to ensure the SAS token is expired. + std::this_thread::sleep_for(std::chrono::seconds(3)); // NOLINT + + auto newReadUrl = readClient->getUrl(); + ASSERT_NE(readUrl, newReadUrl); + // The SAS token should be reused. + ASSERT_EQ(newReadUrl, readClient->getUrl()); + + auto newWriteUrl = writeClient->getUrl(); + ASSERT_NE(writeUrl, newWriteUrl); + // The SAS token should be reused. + ASSERT_EQ(newWriteUrl, writeClient->getUrl()); + } + + { + // SAS token expired by setting the renewal period to 120 seconds. + const std::string account = "account2"; + const config::ConfigBase config( + {{"fs.azure.account.auth.type.account2.dfs.core.windows.net", "SAS"}, + {"fs.azure.sas.token.renew.period.for.streams", "120"}}, + false); + registerAzureClientProviderFactory(account, [](const std::string&) { + auto sasTokenProvider = + std::make_shared(60); + return std::make_unique(sasTokenProvider); + }); + + auto abfsPath = std::make_shared( + fmt::format("abfs://abc@{}.dfs.core.windows.net/file", account)); + auto readClient = + AzureClientProviderFactories::getReadFileClient(abfsPath, config); + auto writeClient = + AzureClientProviderFactories::getWriteFileClient(abfsPath, config); + + auto readUrl = readClient->getUrl(); + auto writeUrl = writeClient->getUrl(); + + // Let the current time pass 3 seconds to ensure the timestamp in the SAS + // token is updated. + std::this_thread::sleep_for(std::chrono::seconds(3)); // NOLINT + + // Sas token should be renewed because the time left is less than the + // renewal period. + ASSERT_NE(readUrl, readClient->getUrl()); + ASSERT_NE(writeUrl, writeClient->getUrl()); + } +} diff --git a/velox/connectors/hive/storage_adapters/abfs/tests/MockDataLakeFileClient.h b/velox/connectors/hive/storage_adapters/abfs/tests/MockDataLakeFileClient.h index 5b874fff56b5..560294414c8a 100644 --- a/velox/connectors/hive/storage_adapters/abfs/tests/MockDataLakeFileClient.h +++ b/velox/connectors/hive/storage_adapters/abfs/tests/MockDataLakeFileClient.h @@ -46,7 +46,7 @@ class MockDataLakeFileClient : public AzureDataLakeFileClient { void close() override; - std::string getUrl() const override { + std::string getUrl() override { return "testUrl"; } diff --git a/velox/connectors/hive/storage_adapters/gcs/CMakeLists.txt b/velox/connectors/hive/storage_adapters/gcs/CMakeLists.txt index e9f5f98520c0..7e110edac193 100644 --- a/velox/connectors/hive/storage_adapters/gcs/CMakeLists.txt +++ b/velox/connectors/hive/storage_adapters/gcs/CMakeLists.txt @@ -17,9 +17,8 @@ velox_add_library(velox_gcs RegisterGcsFileSystem.cpp) if(VELOX_ENABLE_GCS) - velox_sources(velox_gcs PRIVATE GcsFileSystem.cpp GcsUtil.cpp) - velox_link_libraries(velox_gcs velox_dwio_common Folly::folly - google-cloud-cpp::storage) + velox_sources(velox_gcs PRIVATE GcsFileSystem.cpp GcsUtil.cpp GcsWriteFile.cpp GcsReadFile.cpp) + velox_link_libraries(velox_gcs velox_dwio_common Folly::folly google-cloud-cpp::storage) if(${VELOX_BUILD_TESTING}) add_subdirectory(tests) diff --git a/velox/connectors/hive/storage_adapters/gcs/GcsFileSystem.cpp b/velox/connectors/hive/storage_adapters/gcs/GcsFileSystem.cpp index a2dd4e77d985..8ec4873e28de 100644 --- a/velox/connectors/hive/storage_adapters/gcs/GcsFileSystem.cpp +++ b/velox/connectors/hive/storage_adapters/gcs/GcsFileSystem.cpp @@ -17,9 +17,10 @@ #include "velox/connectors/hive/storage_adapters/gcs/GcsFileSystem.h" #include "velox/common/base/Exceptions.h" #include "velox/common/config/Config.h" -#include "velox/common/file/File.h" #include "velox/connectors/hive/HiveConfig.h" +#include "velox/connectors/hive/storage_adapters/gcs/GcsReadFile.h" #include "velox/connectors/hive/storage_adapters/gcs/GcsUtil.h" +#include "velox/connectors/hive/storage_adapters/gcs/GcsWriteFile.h" #include "velox/core/QueryConfig.h" #include @@ -30,244 +31,50 @@ #include namespace facebook::velox { -namespace { +namespace filesystems { +using namespace connector::hive; namespace gcs = ::google::cloud::storage; namespace gc = ::google::cloud; -// Reference: https://github.com/apache/arrow/issues/29916 -// Change the default upload buffer size. In general, sending larger buffers is -// more efficient with GCS, as each buffer requires a roundtrip to the service. -// With formatted output (when using `operator<<`), keeping a larger buffer in -// memory before uploading makes sense. With unformatted output (the only -// choice given gcs::io::OutputStream's API) it is better to let the caller -// provide as large a buffer as they want. The GCS C++ client library will -// upload this buffer with zero copies if possible. -auto constexpr kUploadBufferSize = 256 * 1024; - -inline void checkGcsStatus( - const gc::Status outcome, - const std::string_view& errorMsgPrefix, - const std::string& bucket, - const std::string& key) { - if (!outcome.ok()) { - const auto errMsg = fmt::format( - "{} due to: Path:'{}', SDK Error Type:{}, GCS Status Code:{}, Message:'{}'", - errorMsgPrefix, - gcsURI(bucket, key), - outcome.error_info().domain(), - getErrorStringFromGcsError(outcome.code()), - outcome.message()); - if (outcome.code() == gc::StatusCode::kNotFound) { - VELOX_FILE_NOT_FOUND_ERROR(errMsg); - } - VELOX_FAIL(errMsg); - } -} - -class GcsReadFile final : public ReadFile { - public: - GcsReadFile(const std::string& path, std::shared_ptr client) - : client_(std::move(client)) { - // assumption it's a proper path - setBucketAndKeyFromGcsPath(path, bucket_, key_); - } - - // Gets the length of the file. - // Checks if there are any issues reading the file. - void initialize(const filesystems::FileOptions& options) { - if (options.fileSize.has_value()) { - VELOX_CHECK_GE( - options.fileSize.value(), 0, "File size must be non-negative"); - length_ = options.fileSize.value(); - } - - // Make it a no-op if invoked twice. - if (length_ != -1) { - return; - } - // get metadata and initialize length - auto metadata = client_->GetObjectMetadata(bucket_, key_); - if (!metadata.ok()) { - checkGcsStatus( - metadata.status(), - "Failed to get metadata for GCS object", - bucket_, - key_); - } - length_ = (*metadata).size(); - VELOX_CHECK_GE(length_, 0); - } - - std::string_view pread( - uint64_t offset, - uint64_t length, - void* buffer, - filesystems::File::IoStats* stats = nullptr) const override { - preadInternal(offset, length, static_cast(buffer)); - return {static_cast(buffer), length}; - } - - std::string pread( - uint64_t offset, - uint64_t length, - filesystems::File::IoStats* stats = nullptr) const override { - std::string result(length, 0); - char* position = result.data(); - preadInternal(offset, length, position); - return result; - } - - uint64_t preadv( - uint64_t offset, - const std::vector>& buffers, - filesystems::File::IoStats* stats = nullptr) const override { - // 'buffers' contains Ranges(data, size) with some gaps (data = nullptr) in - // between. This call must populate the ranges (except gap ranges) - // sequentially starting from 'offset'. If a range pointer is nullptr, the - // data from stream of size range.size() will be skipped. - size_t length = 0; - for (const auto range : buffers) { - length += range.size(); - } - std::string result(length, 0); - preadInternal(offset, length, static_cast(result.data())); - size_t resultOffset = 0; - for (auto range : buffers) { - if (range.data()) { - memcpy(range.data(), &(result.data()[resultOffset]), range.size()); - } - resultOffset += range.size(); - } - return length; - } - - uint64_t size() const override { - return length_; - } - uint64_t memoryUsage() const override { - return sizeof(GcsReadFile) // this class - + sizeof(gcs::Client) // pointee - + kUploadBufferSize; // buffer size - } - - bool shouldCoalesce() const final { - return false; - } - - std::string getName() const override { - return key_; - } - - uint64_t getNaturalReadSize() const override { - return kUploadBufferSize; - } - - private: - // The assumption here is that "position" has space for at least "length" - // bytes. - void preadInternal(uint64_t offset, uint64_t length, char* position) const { - gcs::ObjectReadStream stream = client_->ReadObject( - bucket_, key_, gcs::ReadRange(offset, offset + length)); - if (!stream) { - checkGcsStatus( - stream.status(), "Failed to get GCS object", bucket_, key_); - } - - stream.read(position, length); - if (!stream) { - checkGcsStatus( - stream.status(), "Failed to get read object", bucket_, key_); - } - bytesRead_ += length; - } - - std::shared_ptr client_; - std::string bucket_; - std::string key_; - std::atomic length_ = -1; -}; - -class GcsWriteFile final : public WriteFile { - public: - explicit GcsWriteFile( - const std::string& path, - std::shared_ptr client) - : client_(client) { - setBucketAndKeyFromGcsPath(path, bucket_, key_); - } - - ~GcsWriteFile() { - close(); - } - - void initialize() { - // Make it a no-op if invoked twice. - if (size_ != -1) { - return; - } - - // Check that it doesn't exist, if it does throw an error - auto object_metadata = client_->GetObjectMetadata(bucket_, key_); - VELOX_CHECK(!object_metadata.ok(), "File already exists"); - - auto stream = client_->WriteObject(bucket_, key_); - checkGcsStatus( - stream.last_status(), - "Failed to open GCS object for writing", - bucket_, - key_); - stream_ = std::move(stream); - size_ = 0; - } - - void append(const std::string_view data) override { - VELOX_CHECK(isFileOpen(), "File is not open"); - stream_ << data; - size_ += data.size(); - } - - void flush() override { - if (isFileOpen()) { - stream_.flush(); - } - } +namespace { - void close() override { - if (isFileOpen()) { - stream_.flush(); - stream_.Close(); - closed_ = true; - } - } +auto constexpr kGcsInvalidPath = "File {} is not a valid gcs file"; - uint64_t size() const override { - return size_; - } +folly::Synchronized< + std::unordered_map>& +credentialsProviderFactories() { + static folly::Synchronized< + std::unordered_map> + factories; + return factories; +} - private: - inline bool isFileOpen() { - return (!closed_ && stream_.IsOpen()); - } +std::shared_ptr getCredentialsProviderByName( + const std::string& providerName, + const std::shared_ptr& hiveConfig) { + VELOX_USER_CHECK( + !providerName.empty(), + "GcsOAuthCredentialsProviderFactory name cannot be empty"); + return credentialsProviderFactories().withRLock([&](const auto& factories) { + const auto it = factories.find(providerName); + VELOX_USER_CHECK( + it != factories.end(), + "GcsOAuthCredentialsProviderFactory for '{}' not registered", + providerName); + const auto& factory = it->second; + return factory(hiveConfig); + }); +} - gcs::ObjectWriteStream stream_; - std::shared_ptr client_; - std::string bucket_; - std::string key_; - std::atomic size_{-1}; - std::atomic closed_{false}; -}; } // namespace -namespace filesystems { -using namespace connector::hive; - -auto constexpr kGcsInvalidPath = "File {} is not a valid gcs file"; - class GcsFileSystem::Impl { public: - Impl(const config::ConfigBase* config) - : hiveConfig_(std::make_shared( - std::make_shared(config->rawConfigsCopy()))) {} + Impl(const std::string& bucket, const config::ConfigBase* config) + : bucket_(bucket), + hiveConfig_( + std::make_shared(std::make_shared( + config->rawConfigsCopy()))) {} ~Impl() = default; @@ -275,21 +82,28 @@ class GcsFileSystem::Impl { void initializeClient() { constexpr std::string_view kHttpsScheme{"https://"}; auto options = gc::Options{}; - auto endpointOverride = hiveConfig_->gcsEndpoint(); - // Use secure credentials by default. - if (!endpointOverride.empty()) { - options.set(endpointOverride); - // Use Google default credentials if endpoint has https scheme. - if (endpointOverride.find(kHttpsScheme) == 0) { - options.set( - gc::MakeGoogleDefaultCredentials()); + if (auto tokenProvider = hiveConfig_->gcsAuthAccessTokenProvider()) { + auto credentialsProvider = + getCredentialsProviderByName(tokenProvider.value(), hiveConfig_); + auto credentials = credentialsProvider->getCredentials(bucket_); + options.set(credentials); + } else { + auto endpointOverride = hiveConfig_->gcsEndpoint(); + // Use secure credentials by default. + if (!endpointOverride.empty()) { + options.set(endpointOverride); + // Use Google default credentials if endpoint has https scheme. + if (endpointOverride.find(kHttpsScheme) == 0) { + options.set( + gc::MakeGoogleDefaultCredentials()); + } else { + options.set( + gc::MakeInsecureCredentials()); + } } else { options.set( - gc::MakeInsecureCredentials()); + gc::MakeGoogleDefaultCredentials()); } - } else { - options.set( - gc::MakeGoogleDefaultCredentials()); } options.set(kUploadBufferSize); @@ -332,13 +146,16 @@ class GcsFileSystem::Impl { } private: + const std::string bucket_; const std::shared_ptr hiveConfig_; std::shared_ptr client_; }; -GcsFileSystem::GcsFileSystem(std::shared_ptr config) +GcsFileSystem::GcsFileSystem( + const std::string& bucket, + std::shared_ptr config) : FileSystem(config) { - impl_ = std::make_shared(config.get()); + impl_ = std::make_shared(bucket, config.get()); } void GcsFileSystem::initializeClient() { @@ -436,18 +253,133 @@ std::string GcsFileSystem::name() const { return "GCS"; } -void GcsFileSystem::rename(std::string_view, std::string_view, bool) { - VELOX_UNSUPPORTED("rename for GCS not implemented"); +void GcsFileSystem::rename( + std::string_view originPath, + std::string_view newPath, + bool overwrite) { + if (!isGcsFile(originPath)) { + VELOX_FAIL(kGcsInvalidPath, originPath); + } + + if (!isGcsFile(newPath)) { + VELOX_FAIL(kGcsInvalidPath, newPath); + } + + std::string originBucket; + std::string originObject; + const auto originFile = gcsPath(originPath); + setBucketAndKeyFromGcsPath(originFile, originBucket, originObject); + + std::string newBucket; + std::string newObject; + const auto newFile = gcsPath(newPath); + setBucketAndKeyFromGcsPath(newFile, newBucket, newObject); + + if (!overwrite) { + auto objects = list(newPath); + if (std::find(objects.begin(), objects.end(), newObject) != objects.end()) { + VELOX_USER_FAIL( + "Failed to rename object {} to {} with as {} exists.", + originObject, + newObject, + newObject); + return; + } + } + + // Copy the object to the new name. + auto copyStats = impl_->getClient()->CopyObject( + originBucket, originObject, newBucket, newObject); + if (!copyStats.ok()) { + checkGcsStatus( + copyStats.status(), + fmt::format( + "Failed to rename for GCS object {}/{}", + originBucket, + originObject), + originBucket, + originObject); + } + + // Delete the original object. + auto delStatus = impl_->getClient()->DeleteObject(originBucket, originObject); + if (!delStatus.ok()) { + checkGcsStatus( + delStatus, + fmt::format( + "Failed to delete for GCS object {}/{} after copy when renaming. And the copied object is at {}/{}", + originBucket, + originObject, + newBucket, + newObject), + originBucket, + originObject); + } } void GcsFileSystem::mkdir( std::string_view path, const DirectoryOptions& options) { - VELOX_UNSUPPORTED("mkdir for GCS not implemented"); + if (!isGcsFile(path)) { + VELOX_FAIL(kGcsInvalidPath, path); + } + + std::string bucket; + std::string object; + const auto file = gcsPath(path); + setBucketAndKeyFromGcsPath(file, bucket, object); + + // Create an empty object to represent the directory. + auto status = impl_->getClient()->InsertObject(bucket, object, ""); + + checkGcsStatus( + status.status(), + fmt::format("Failed to mkdir for GCS object {}/{}", bucket, object), + bucket, + object); } void GcsFileSystem::rmdir(std::string_view path) { - VELOX_UNSUPPORTED("rmdir for GCS not implemented"); + if (!isGcsFile(path)) { + VELOX_FAIL(kGcsInvalidPath, path); + } + + const auto file = gcsPath(path); + std::string bucket; + std::string object; + setBucketAndKeyFromGcsPath(file, bucket, object); + for (auto&& metadata : impl_->getClient()->ListObjects(bucket)) { + checkGcsStatus( + metadata.status(), + fmt::format("Failed to rmdir for GCS object {}/{}", bucket, object), + bucket, + object); + + auto status = impl_->getClient()->DeleteObject(bucket, metadata->name()); + checkGcsStatus( + metadata.status(), + fmt::format( + "Failed to delete for GCS object {}/{} when rmdir.", + bucket, + metadata->name()), + bucket, + metadata->name()); + } +} + +void registerOAuthCredentialsProvider( + const std::string& providerName, + const GcsOAuthCredentialsProviderFactory& factory) { + VELOX_CHECK( + !providerName.empty(), + "GcsOAuthCredentialsProviderFactory name cannot be empty"); + credentialsProviderFactories().withWLock([&](auto& factories) { + VELOX_CHECK( + factories.find(providerName) == factories.end(), + "GcsOAuthCredentialsProviderFactory '{}' already registered", + providerName); + factories.insert({providerName, factory}); + }); } } // namespace filesystems diff --git a/velox/connectors/hive/storage_adapters/gcs/GcsFileSystem.h b/velox/connectors/hive/storage_adapters/gcs/GcsFileSystem.h index 34daff8d6c64..eeef856bed05 100644 --- a/velox/connectors/hive/storage_adapters/gcs/GcsFileSystem.h +++ b/velox/connectors/hive/storage_adapters/gcs/GcsFileSystem.h @@ -17,6 +17,8 @@ #pragma once #include "velox/common/file/FileSystems.h" +#include "velox/connectors/hive/HiveConfig.h" +#include "velox/connectors/hive/storage_adapters/gcs/GcsOAuthCredentialsProvider.h" namespace facebook::velox::filesystems { @@ -26,7 +28,9 @@ namespace facebook::velox::filesystems { /// (register|generate)ReadFile and (register|generate)WriteFile functions. class GcsFileSystem : public FileSystem { public: - explicit GcsFileSystem(std::shared_ptr config); + explicit GcsFileSystem( + const std::string& bucket, + std::shared_ptr config); /// Initialize the google::cloud::storage::Client from the input Config /// parameters. @@ -66,7 +70,8 @@ class GcsFileSystem : public FileSystem { /// Returns the name of the adapter (GCS) std::string name() const override; - /// Unsupported + /// Removes the objects associated to a path by using + /// google::cloud::storage::Client::DeleteObject. void remove(std::string_view path) override; /// Check that the path exists by using @@ -77,14 +82,25 @@ class GcsFileSystem : public FileSystem { /// google::cloud::storage::Client::ListObjects std::vector list(std::string_view path) override; - /// Unsupported - void rename(std::string_view, std::string_view, bool) override; - - /// Unsupported + /// Renames the original object to the new object using + /// google::cloud::storage::Client::CopyObject and + /// google::cloud::storage::Client::DeleteObject. + /// Note that this process involves separate copy and delete operations, which + /// may lead to temporary inconsistencies if either operation fails or if + /// there is a delay between them. + void rename( + std::string_view originPath, + std::string_view newPath, + bool overwrite) override; + + /// Supports mkdir operation by using + /// google::cloud::storage::Client::InsertObject void mkdir(std::string_view path, const DirectoryOptions& options = {}) override; - /// Unsupported + /// Deletes the objects associated to a path using + /// google::cloud::storage::Client::ListObjects and + /// google::cloud::storage::Client::DeleteObjects void rmdir(std::string_view path) override; protected: @@ -92,4 +108,12 @@ class GcsFileSystem : public FileSystem { std::shared_ptr impl_; }; +using GcsOAuthCredentialsProviderFactory = + std::function( + const std::shared_ptr& hiveConfig)>; + +void registerOAuthCredentialsProvider( + const std::string& providerName, + const GcsOAuthCredentialsProviderFactory& factory); + } // namespace facebook::velox::filesystems diff --git a/velox/connectors/hive/storage_adapters/gcs/GcsOAuthCredentialsProvider.h b/velox/connectors/hive/storage_adapters/gcs/GcsOAuthCredentialsProvider.h new file mode 100644 index 000000000000..92a89d3e7321 --- /dev/null +++ b/velox/connectors/hive/storage_adapters/gcs/GcsOAuthCredentialsProvider.h @@ -0,0 +1,36 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace facebook::velox::filesystems { + +namespace gcs = ::google::cloud::storage; + +/// Interface for providing OAuth2 credentials for Google Cloud Storage (GCS). +/// Implementations should return a GCS OAuth2 credential used for creating the +/// GCS client for a specific bucket via `getCredentials`. +class GcsOAuthCredentialsProvider { + public: + virtual ~GcsOAuthCredentialsProvider() = default; + + virtual std::shared_ptr getCredentials( + const std::string& bucket) = 0; +}; + +} // namespace facebook::velox::filesystems diff --git a/velox/connectors/hive/storage_adapters/gcs/GcsReadFile.cpp b/velox/connectors/hive/storage_adapters/gcs/GcsReadFile.cpp new file mode 100644 index 000000000000..e8cf14dbb2f4 --- /dev/null +++ b/velox/connectors/hive/storage_adapters/gcs/GcsReadFile.cpp @@ -0,0 +1,194 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/storage_adapters/gcs/GcsReadFile.h" +#include "velox/connectors/hive/storage_adapters/gcs/GcsUtil.h" + +namespace facebook::velox::filesystems { + +namespace gcs = ::google::cloud::storage; + +class GcsReadFile::Impl { + public: + Impl(const std::string& path, std::shared_ptr client) + : client_(client) { + setBucketAndKeyFromGcsPath(path, bucket_, key_); + } + + // Gets the length of the file. + // Checks if there are any issues reading the file. + void initialize(const filesystems::FileOptions& options) { + if (options.fileSize.has_value()) { + VELOX_CHECK_GE( + options.fileSize.value(), 0, "File size must be non-negative"); + length_ = options.fileSize.value(); + } + + // Make it a no-op if invoked twice. + if (length_ != -1) { + return; + } + // get metadata and initialize length + auto metadata = client_->GetObjectMetadata(bucket_, key_); + if (!metadata.ok()) { + checkGcsStatus( + metadata.status(), + "Failed to get metadata for GCS object", + bucket_, + key_); + } + length_ = (*metadata).size(); + VELOX_CHECK_GE(length_, 0); + } + + std::string_view pread( + uint64_t offset, + uint64_t length, + void* buffer, + std::atomic& bytesRead, + const FileStorageContext& fileStorageContext) const { + preadInternal(offset, length, static_cast(buffer), bytesRead); + return {static_cast(buffer), length}; + } + + std::string pread( + uint64_t offset, + uint64_t length, + std::atomic& bytesRead, + const FileStorageContext& fileStorageContext) const { + std::string result(length, 0); + char* position = result.data(); + preadInternal(offset, length, position, bytesRead); + return result; + } + + uint64_t preadv( + uint64_t offset, + const std::vector>& buffers, + std::atomic& bytesRead, + const FileStorageContext& fileStorageContext) const { + // 'buffers' contains Ranges(data, size) with some gaps (data = nullptr) in + // between. This call must populate the ranges (except gap ranges) + // sequentially starting from 'offset'. If a range pointer is nullptr, the + // data from stream of size range.size() will be skipped. + size_t length = 0; + for (const auto range : buffers) { + length += range.size(); + } + std::string result(length, 0); + preadInternal(offset, length, static_cast(result.data()), bytesRead); + size_t resultOffset = 0; + for (auto range : buffers) { + if (range.data()) { + memcpy(range.data(), &(result.data()[resultOffset]), range.size()); + } + resultOffset += range.size(); + } + return length; + } + + uint64_t size() const { + return length_; + } + + uint64_t memoryUsage() const { + return sizeof(GcsReadFile) // this class + + sizeof(gcs::Client) // pointee + + kUploadBufferSize; // buffer size + } + + std::string getName() const { + return key_; + } + + private: + // The assumption here is that "position" has space for at least "length" + // bytes. + void preadInternal( + uint64_t offset, + uint64_t length, + char* position, + std::atomic& bytesRead_) const { + gcs::ObjectReadStream stream = client_->ReadObject( + bucket_, key_, gcs::ReadRange(offset, offset + length)); + if (!stream) { + checkGcsStatus( + stream.status(), "Failed to get GCS object", bucket_, key_); + } + + stream.read(position, length); + if (!stream) { + checkGcsStatus( + stream.status(), "Failed to get read object", bucket_, key_); + } + bytesRead_ += length; + } + + std::shared_ptr client_; + std::string bucket_; + std::string key_; + std::atomic length_ = -1; +}; + +GcsReadFile::GcsReadFile( + const std::string& path, + std::shared_ptr client) + : impl_(std::make_unique(path, client)) {} + +GcsReadFile::~GcsReadFile() = default; + +void GcsReadFile::initialize(const filesystems::FileOptions& options) { + impl_->initialize(options); +} + +std::string_view GcsReadFile::pread( + uint64_t offset, + uint64_t length, + void* buffer, + const FileStorageContext& fileStorageContext) const { + return impl_->pread(offset, length, buffer, bytesRead_, fileStorageContext); +} + +std::string GcsReadFile::pread( + uint64_t offset, + uint64_t length, + const FileStorageContext& fileStorageContext) const { + return impl_->pread(offset, length, bytesRead_, fileStorageContext); +} +uint64_t GcsReadFile::preadv( + uint64_t offset, + const std::vector>& buffers, + const FileStorageContext& fileStorageContext) const { + return impl_->preadv(offset, buffers, bytesRead_, fileStorageContext); +} + +uint64_t GcsReadFile::size() const { + return impl_->size(); +} + +uint64_t GcsReadFile::memoryUsage() const { + return impl_->memoryUsage(); +} + +std::string GcsReadFile::getName() const { + return impl_->getName(); +} + +uint64_t GcsReadFile::getNaturalReadSize() const { + return kUploadBufferSize; +} + +} // namespace facebook::velox::filesystems diff --git a/velox/connectors/hive/storage_adapters/gcs/GcsReadFile.h b/velox/connectors/hive/storage_adapters/gcs/GcsReadFile.h new file mode 100644 index 000000000000..6e79ee34afde --- /dev/null +++ b/velox/connectors/hive/storage_adapters/gcs/GcsReadFile.h @@ -0,0 +1,70 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include "velox/common/file/File.h" + +namespace facebook::velox::filesystems { + +/** + * Implementation of gcs read file. + */ +class GcsReadFile : public ReadFile { + public: + GcsReadFile( + const std::string& path, + std::shared_ptr<::google::cloud::storage::Client> client); + + ~GcsReadFile() override; + + void initialize(const filesystems::FileOptions& options); + + std::string_view pread( + uint64_t offset, + uint64_t length, + void* buffer, + const FileStorageContext& fileStorageContext = {}) const override; + + std::string pread( + uint64_t offset, + uint64_t length, + const FileStorageContext& fileStorageContext = {}) const override; + + uint64_t preadv( + uint64_t offset, + const std::vector>& buffers, + const FileStorageContext& fileStorageContext = {}) const override; + + uint64_t size() const override; + + uint64_t memoryUsage() const override; + + bool shouldCoalesce() const final { + return false; + } + + std::string getName() const override; + + uint64_t getNaturalReadSize() const override; + + protected: + class Impl; + std::shared_ptr impl_; +}; + +} // namespace facebook::velox::filesystems diff --git a/velox/connectors/hive/storage_adapters/gcs/GcsUtil.cpp b/velox/connectors/hive/storage_adapters/gcs/GcsUtil.cpp index 8035e4a1e29f..b9fffb3fa1d4 100644 --- a/velox/connectors/hive/storage_adapters/gcs/GcsUtil.cpp +++ b/velox/connectors/hive/storage_adapters/gcs/GcsUtil.cpp @@ -34,4 +34,24 @@ std::string getErrorStringFromGcsError(const google::cloud::StatusCode& code) { } } +void checkGcsStatus( + const google::cloud::Status outcome, + const std::string_view& errorMsgPrefix, + const std::string& bucket, + const std::string& key) { + if (!outcome.ok()) { + const auto errMsg = fmt::format( + "{} due to: Path:'{}', SDK Error Type:{}, GCS Status Code:{}, Message:'{}'", + errorMsgPrefix, + gcsURI(bucket, key), + outcome.error_info().domain(), + getErrorStringFromGcsError(outcome.code()), + outcome.message()); + if (outcome.code() == google::cloud::StatusCode::kNotFound) { + VELOX_FILE_NOT_FOUND_ERROR(errMsg); + } + VELOX_FAIL(errMsg); + } +} + } // namespace facebook::velox diff --git a/velox/connectors/hive/storage_adapters/gcs/GcsUtil.h b/velox/connectors/hive/storage_adapters/gcs/GcsUtil.h index e16736fb938f..ac02793ad111 100644 --- a/velox/connectors/hive/storage_adapters/gcs/GcsUtil.h +++ b/velox/connectors/hive/storage_adapters/gcs/GcsUtil.h @@ -20,6 +20,18 @@ namespace facebook::velox { +// Reference: https://github.com/apache/arrow/issues/29916 +// Change the default upload buffer size. In general, sending larger buffers is +// more efficient with GCS, as each buffer requires a roundtrip to the service. +// With formatted output (when using `operator<<`), keeping a larger buffer in +// memory before uploading makes sense. With unformatted output (the only +// choice given gcs::io::OutputStream's API) it is better to let the caller +// provide as large a buffer as they want. The GCS C++ client library will +// upload this buffer with zero copies if possible. +auto constexpr kUploadBufferSize = 256 * 1024; + +constexpr const char* kGcsDefaultCacheKeyPrefix = "gcs-default-key"; + namespace { constexpr const char* kSep{"/"}; constexpr std::string_view kGcsScheme{"gs://"}; @@ -58,4 +70,10 @@ inline std::string gcsPath(const std::string_view& path) { return std::string(path.substr(kGcsScheme.length())); } +void checkGcsStatus( + const google::cloud::Status outcome, + const std::string_view& errorMsgPrefix, + const std::string& bucket, + const std::string& key); + } // namespace facebook::velox diff --git a/velox/connectors/hive/storage_adapters/gcs/GcsWriteFile.cpp b/velox/connectors/hive/storage_adapters/gcs/GcsWriteFile.cpp new file mode 100644 index 000000000000..d47bc64db69c --- /dev/null +++ b/velox/connectors/hive/storage_adapters/gcs/GcsWriteFile.cpp @@ -0,0 +1,119 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/storage_adapters/gcs/GcsWriteFile.h" +#include "velox/connectors/hive/storage_adapters/gcs/GcsUtil.h" + +namespace facebook::velox::filesystems { + +namespace gcs = ::google::cloud::storage; + +class GcsWriteFile::Impl { + public: + Impl(const std::string& path, std::shared_ptr client) + : client_(client) { + setBucketAndKeyFromGcsPath(path, bucket_, key_); + } + + ~Impl() { + close(); + } + + void initialize() { + // Make it a no-op if invoked twice. + if (size_ != -1) { + return; + } + + // Check that it doesn't exist, if it does throw an error + auto object_metadata = client_->GetObjectMetadata(bucket_, key_); + VELOX_CHECK(!object_metadata.ok(), "File already exists"); + + auto stream = client_->WriteObject(bucket_, key_); + checkGcsStatus( + stream.last_status(), + "Failed to open GCS object for writing", + bucket_, + key_); + stream_ = std::move(stream); + size_ = 0; + } + + void append(const std::string_view data) { + VELOX_CHECK(isFileOpen(), "File is not open"); + stream_ << data; + size_ += data.size(); + } + + void flush() { + if (isFileOpen()) { + stream_.flush(); + } + } + + void close() { + if (isFileOpen()) { + stream_.flush(); + stream_.Close(); + closed_ = true; + } + } + + uint64_t size() const { + return size_; + } + + private: + inline bool isFileOpen() { + return (!closed_ && stream_.IsOpen()); + } + + gcs::ObjectWriteStream stream_; + std::shared_ptr client_; + std::string bucket_; + std::string key_; + std::atomic size_{-1}; + std::atomic closed_{false}; +}; + +GcsWriteFile::GcsWriteFile( + const std::string& path, + std::shared_ptr client) + : impl_(std::make_unique(path, client)) {} + +GcsWriteFile::~GcsWriteFile() = default; + +void GcsWriteFile::initialize() { + impl_->initialize(); +} + +void GcsWriteFile::append(const std::string_view data) { + impl_->append(data); +} + +void GcsWriteFile::flush() { + impl_->flush(); +} + +void GcsWriteFile::close() { + impl_->close(); +} + +uint64_t GcsWriteFile::size() const { + return impl_->size(); +} + +} // namespace facebook::velox::filesystems diff --git a/velox/connectors/hive/storage_adapters/gcs/GcsWriteFile.h b/velox/connectors/hive/storage_adapters/gcs/GcsWriteFile.h new file mode 100644 index 000000000000..3e6527c3cf6b --- /dev/null +++ b/velox/connectors/hive/storage_adapters/gcs/GcsWriteFile.h @@ -0,0 +1,54 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include "velox/common/file/File.h" + +namespace facebook::velox::filesystems { + +/** + * Implementation of gcs write file. + */ +class GcsWriteFile : public WriteFile { + public: + GcsWriteFile( + const std::string& path, + std::shared_ptr<::google::cloud::storage::Client> client); + + ~GcsWriteFile() override; + + void initialize(); + + /// Writes the data by append mode. + void append(std::string_view data) override; + + /// Flushs the data. + void flush() override; + + /// Closes the file. + void close() override; + + /// Gets the file size. + uint64_t size() const override; + + protected: + class Impl; + std::shared_ptr impl_; +}; + +} // namespace facebook::velox::filesystems diff --git a/velox/connectors/hive/storage_adapters/gcs/RegisterGcsFileSystem.cpp b/velox/connectors/hive/storage_adapters/gcs/RegisterGcsFileSystem.cpp index 1622e43f6085..d1ae189c7f8b 100644 --- a/velox/connectors/hive/storage_adapters/gcs/RegisterGcsFileSystem.cpp +++ b/velox/connectors/hive/storage_adapters/gcs/RegisterGcsFileSystem.cpp @@ -14,8 +14,11 @@ * limitations under the License. */ +#include "velox/connectors/hive/storage_adapters/gcs/RegisterGcsFileSystem.h" + #ifdef VELOX_ENABLE_GCS #include "velox/common/config/Config.h" +#include "velox/connectors/hive/HiveConfig.h" #include "velox/connectors/hive/storage_adapters/gcs/GcsFileSystem.h" // @manual #include "velox/connectors/hive/storage_adapters/gcs/GcsUtil.h" // @manual #include "velox/dwio/common/FileSink.h" @@ -24,7 +27,15 @@ namespace facebook::velox::filesystems { #ifdef VELOX_ENABLE_GCS -folly::once_flag GcsInstantiationFlag; + +using FileSystemMap = folly::Synchronized< + std::unordered_map>>; + +/// Multiple GCS filesystems are supported. +FileSystemMap& gcsFileSystems() { + static FileSystemMap instances; + return instances; +} std::function(std::shared_ptr, std::string_view)> @@ -32,24 +43,52 @@ gcsFileSystemGenerator() { static auto filesystemGenerator = [](std::shared_ptr properties, std::string_view filePath) { - // Only one instance of GCSFileSystem is supported for now (follow S3 - // for now). - // TODO: Support multiple GCSFileSystem instances using a cache - // Initialize on first access and reuse after that. - static std::shared_ptr gcsfs; - folly::call_once(GcsInstantiationFlag, [&properties]() { - std::shared_ptr fs; - if (properties != nullptr) { - fs = std::make_shared(properties); - } else { - fs = std::make_shared( - std::make_shared( - std::unordered_map())); - } - fs->initializeClient(); - gcsfs = fs; - }); - return gcsfs; + const auto file = gcsPath(filePath); + std::string bucket; + std::string object; + setBucketAndKeyFromGcsPath(file, bucket, object); + auto cacheKey = fmt::format( + "{}-{}", + properties->get( + connector::hive::HiveConfig::kGcsEndpoint, + kGcsDefaultCacheKeyPrefix), + bucket); + + // Check if an instance exists with a read lock (shared). + auto fs = gcsFileSystems().withRLock( + [&](auto& instanceMap) -> std::shared_ptr { + auto iterator = instanceMap.find(cacheKey); + if (iterator != instanceMap.end()) { + return iterator->second; + } + return nullptr; + }); + if (fs != nullptr) { + return fs; + } + + return gcsFileSystems().withWLock( + [&](auto& instanceMap) -> std::shared_ptr { + // Repeat the checks with a write lock. + auto iterator = instanceMap.find(cacheKey); + if (iterator != instanceMap.end()) { + return iterator->second; + } + + std::shared_ptr fs; + if (properties != nullptr) { + fs = std::make_shared(bucket, properties); + } else { + fs = std::make_shared( + bucket, + std::make_shared( + std::unordered_map())); + } + fs->initializeClient(); + + instanceMap.insert({cacheKey, fs}); + return fs; + }); }; return filesystemGenerator; } @@ -78,4 +117,12 @@ void registerGcsFileSystem() { #endif } +void registerGcsOAuthCredentialsProvider( + const std::string& providerName, + const GcsOAuthCredentialsProviderFactory& factory) { +#ifdef VELOX_ENABLE_GCS + registerOAuthCredentialsProvider(providerName, factory); +#endif +} + } // namespace facebook::velox::filesystems diff --git a/velox/connectors/hive/storage_adapters/gcs/RegisterGcsFileSystem.h b/velox/connectors/hive/storage_adapters/gcs/RegisterGcsFileSystem.h index b0f668d6f413..0a6deb48dc5a 100644 --- a/velox/connectors/hive/storage_adapters/gcs/RegisterGcsFileSystem.h +++ b/velox/connectors/hive/storage_adapters/gcs/RegisterGcsFileSystem.h @@ -16,9 +16,26 @@ #pragma once +#include +#include +#include + +namespace facebook::velox::connector::hive { +class HiveConfig; +} + namespace facebook::velox::filesystems { +class GcsOAuthCredentialsProvider; // Register the GCS filesystem. void registerGcsFileSystem(); +using GcsOAuthCredentialsProviderFactory = + std::function( + const std::shared_ptr& hiveConfig)>; + +void registerGcsOAuthCredentialsProvider( + const std::string& providerName, + const GcsOAuthCredentialsProviderFactory& factory); + } // namespace facebook::velox::filesystems diff --git a/velox/connectors/hive/storage_adapters/gcs/examples/CMakeLists.txt b/velox/connectors/hive/storage_adapters/gcs/examples/CMakeLists.txt index 1363d688da44..9f6da5dee157 100644 --- a/velox/connectors/hive/storage_adapters/gcs/examples/CMakeLists.txt +++ b/velox/connectors/hive/storage_adapters/gcs/examples/CMakeLists.txt @@ -21,4 +21,5 @@ target_link_libraries( velox_core velox_hive_connector velox_dwio_common_exception - velox_exec) + velox_exec +) diff --git a/velox/connectors/hive/storage_adapters/gcs/examples/GcsFileSystemExample.cpp b/velox/connectors/hive/storage_adapters/gcs/examples/GcsFileSystemExample.cpp index 08d9b83a0152..f5e19454ebd0 100644 --- a/velox/connectors/hive/storage_adapters/gcs/examples/GcsFileSystemExample.cpp +++ b/velox/connectors/hive/storage_adapters/gcs/examples/GcsFileSystemExample.cpp @@ -16,6 +16,7 @@ #include "velox/common/config/Config.h" #include "velox/common/file/File.h" #include "velox/connectors/hive/storage_adapters/gcs/GcsFileSystem.h" +#include "velox/connectors/hive/storage_adapters/gcs/GcsUtil.h" #include #include @@ -45,7 +46,10 @@ int main(int argc, char** argv) { gflags::ShowUsageWithFlags(argv[0]); return 1; } - filesystems::GcsFileSystem gcfs(newConfiguration()); + std::string bucket; + std::string object; + setBucketAndKeyFromGcsPath(FLAGS_gcs_path, bucket, object); + filesystems::GcsFileSystem gcfs(bucket, newConfiguration()); gcfs.initializeClient(); std::cout << "Opening file for read " << FLAGS_gcs_path << std::endl; std::unique_ptr file_read = gcfs.openFileForRead(FLAGS_gcs_path); diff --git a/velox/connectors/hive/storage_adapters/gcs/tests/CMakeLists.txt b/velox/connectors/hive/storage_adapters/gcs/tests/CMakeLists.txt index 6a775a72f07d..1d848f671f60 100644 --- a/velox/connectors/hive/storage_adapters/gcs/tests/CMakeLists.txt +++ b/velox/connectors/hive/storage_adapters/gcs/tests/CMakeLists.txt @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -add_executable(velox_gcsfile_test GcsUtilTest.cpp GcsFileSystemTest.cpp) -add_test(velox_gcsfile_test velox_gcsfile_test) +add_executable(velox_gcs_file_test GcsUtilTest.cpp GcsFileSystemTest.cpp) +add_test(velox_gcs_file_test velox_gcs_file_test) target_link_libraries( - velox_gcsfile_test + velox_gcs_file_test velox_core velox_dwio_common_exception velox_exec @@ -26,7 +26,8 @@ target_link_libraries( velox_temp_path GTest::gmock GTest::gtest - GTest::gtest_main) + GTest::gtest_main +) add_executable(velox_gcs_insert_test GcsInsertTest.cpp) add_test(velox_gcs_insert_test velox_gcs_insert_test) @@ -36,9 +37,30 @@ target_link_libraries( velox_gcs velox_hive_config velox_core + velox_dwio_parquet_reader + velox_dwio_parquet_writer velox_exec_test_lib velox_dwio_common_exception velox_exec GTest::gmock GTest::gtest - GTest::gtest_main) + GTest::gtest_main +) + +add_executable(velox_gcs_multiendpoints_test GcsMultipleEndpointsTest.cpp) +add_test(velox_gcs_multiendpoints_test velox_gcs_multiendpoints_test) +target_link_libraries( + velox_gcs_multiendpoints_test + velox_file + velox_gcs + velox_hive_config + velox_core + velox_exec_test_lib + velox_dwio_parquet_reader + velox_dwio_parquet_writer + velox_dwio_common_exception + velox_exec + GTest::gmock + GTest::gtest + GTest::gtest_main +) diff --git a/velox/connectors/hive/storage_adapters/gcs/tests/GcsEmulator.h b/velox/connectors/hive/storage_adapters/gcs/tests/GcsEmulator.h index 12db0a12808f..bea464f6fe02 100644 --- a/velox/connectors/hive/storage_adapters/gcs/tests/GcsEmulator.h +++ b/velox/connectors/hive/storage_adapters/gcs/tests/GcsEmulator.h @@ -22,6 +22,7 @@ #include "gtest/gtest.h" #include "velox/common/config/Config.h" +#include "velox/connectors/hive/HiveConfig.h" #include "velox/connectors/hive/storage_adapters/gcs/GcsUtil.h" #include "velox/exec/tests/utils/PortUtil.h" @@ -92,7 +93,7 @@ class GcsEmulator : public testing::Environment { const std::unordered_map configOverride = {}) const { std::unordered_map config( - {{"hive.gcs.endpoint", endpoint_}}); + {{connector::hive::HiveConfig::kGcsEndpoint, endpoint_}}); // Update the default config map with the supplied configOverride map for (const auto& [configName, configValue] : configOverride) { @@ -102,7 +103,7 @@ class GcsEmulator : public testing::Environment { return std::make_shared(std::move(config)); } - std::string_view preexistingBucketName() { + const std::string& preexistingBucketName() { return bucketName_; } diff --git a/velox/connectors/hive/storage_adapters/gcs/tests/GcsFileSystemTest.cpp b/velox/connectors/hive/storage_adapters/gcs/tests/GcsFileSystemTest.cpp index a80842208f89..4a465cacaefb 100644 --- a/velox/connectors/hive/storage_adapters/gcs/tests/GcsFileSystemTest.cpp +++ b/velox/connectors/hive/storage_adapters/gcs/tests/GcsFileSystemTest.cpp @@ -18,6 +18,7 @@ #include "velox/common/base/tests/GTestUtils.h" #include "velox/common/file/File.h" #include "velox/connectors/hive/storage_adapters/gcs/GcsUtil.h" +#include "velox/connectors/hive/storage_adapters/gcs/RegisterGcsFileSystem.h" #include "velox/connectors/hive/storage_adapters/gcs/tests/GcsEmulator.h" #include "velox/exec/tests/utils/TempFilePath.h" @@ -26,6 +27,16 @@ namespace facebook::velox::filesystems { namespace { +class DummyGcsOAuthCredentialsProvider : public GcsOAuthCredentialsProvider { + public: + explicit DummyGcsOAuthCredentialsProvider() : GcsOAuthCredentialsProvider() {} + + std::shared_ptr getCredentials( + const std::string&) override { + VELOX_FAIL("DummyGcsOAuthCredentialsProvider: Not implemented"); + } +}; + class GcsFileSystemTest : public testing::Test { public: void SetUp() { @@ -40,7 +51,8 @@ TEST_F(GcsFileSystemTest, readFile) { const auto gcsFile = gcsURI( emulator_->preexistingBucketName(), emulator_->preexistingObjectName()); - filesystems::GcsFileSystem gcfs(emulator_->hiveConfig()); + filesystems::GcsFileSystem gcfs( + emulator_->preexistingBucketName(), emulator_->hiveConfig()); gcfs.initializeClient(); auto readFile = gcfs.openFileForRead(gcsFile); std::int64_t size = readFile->size(); @@ -48,8 +60,8 @@ TEST_F(GcsFileSystemTest, readFile) { EXPECT_EQ(size, ref_size); EXPECT_EQ(readFile->pread(0, size), kLoremIpsum); - char buffer1[size]; - ASSERT_EQ(readFile->pread(0, size, &buffer1), kLoremIpsum); + std::vector buffer1(size); + ASSERT_EQ(readFile->pread(0, size, buffer1.data()), kLoremIpsum); ASSERT_EQ(readFile->size(), ref_size); char buffer2[50]; @@ -77,7 +89,8 @@ TEST_F(GcsFileSystemTest, writeAndReadFile) { const std::string_view newFile = "readWriteFile.txt"; const auto gcsFile = gcsURI(emulator_->preexistingBucketName(), newFile); - filesystems::GcsFileSystem gcfs(emulator_->hiveConfig()); + filesystems::GcsFileSystem gcfs( + emulator_->preexistingBucketName(), emulator_->hiveConfig()); gcfs.initializeClient(); auto writeFile = gcfs.openFileForWrite(gcsFile); std::string_view kDataContent = @@ -103,45 +116,84 @@ TEST_F(GcsFileSystemTest, writeAndReadFile) { EXPECT_EQ(readFile->pread(0, size), kDataContent); // Opening an existing file for write must be an error. - filesystems::GcsFileSystem newGcfs(emulator_->hiveConfig()); + filesystems::GcsFileSystem newGcfs( + emulator_->preexistingBucketName(), emulator_->hiveConfig()); newGcfs.initializeClient(); VELOX_ASSERT_THROW(newGcfs.openFileForWrite(gcsFile), "File already exists"); } -TEST_F(GcsFileSystemTest, renameNotImplemented) { - const std::string_view file = "newTest.txt"; - const auto gcsExistingFile = gcsURI( - emulator_->preexistingBucketName(), emulator_->preexistingObjectName()); - const auto gcsNewFile = gcsURI(emulator_->preexistingBucketName(), file); - filesystems::GcsFileSystem gcfs(emulator_->hiveConfig()); +TEST_F(GcsFileSystemTest, rename) { + filesystems::GcsFileSystem gcfs( + emulator_->preexistingBucketName(), emulator_->hiveConfig()); gcfs.initializeClient(); - gcfs.openFileForRead(gcsExistingFile); + + const std::string_view oldFile = "oldTest.txt"; + const std::string_view newFile = "newTest.txt"; + + const auto gcsExistingFile = + gcsURI(emulator_->preexistingBucketName(), oldFile); + auto writeFile = gcfs.openFileForWrite(gcsExistingFile); + std::string_view kDataContent = "GcsFileSystemTest rename operation test"; + writeFile->append(kDataContent.substr(0, 10)); + writeFile->flush(); + writeFile->close(); + + const auto gcsNewFile = gcsURI(emulator_->preexistingBucketName(), newFile); + VELOX_ASSERT_THROW( - gcfs.rename(gcsExistingFile, gcsNewFile, true), - "rename for GCS not implemented"); + gcfs.rename(gcsExistingFile, gcsExistingFile, false), + fmt::format( + "Failed to rename object {} to {} with as {} exists.", + oldFile, + oldFile, + oldFile)); + + gcfs.rename(gcsExistingFile, gcsNewFile, true); + + auto results = gcfs.list(gcsNewFile); + ASSERT_TRUE( + std::find(results.begin(), results.end(), oldFile) == results.end()); + ASSERT_TRUE( + std::find(results.begin(), results.end(), newFile) != results.end()); } -TEST_F(GcsFileSystemTest, mkdirNotImplemented) { +TEST_F(GcsFileSystemTest, mkdir) { const std::string_view dir = "newDirectory"; const auto gcsNewDirectory = gcsURI(emulator_->preexistingBucketName(), dir); - filesystems::GcsFileSystem gcfs(emulator_->hiveConfig()); + filesystems::GcsFileSystem gcfs( + emulator_->preexistingBucketName(), emulator_->hiveConfig()); gcfs.initializeClient(); - VELOX_ASSERT_THROW( - gcfs.mkdir(gcsNewDirectory), "mkdir for GCS not implemented"); + gcfs.mkdir(gcsNewDirectory); + const auto& results = gcfs.list(gcsNewDirectory); + ASSERT_TRUE(std::find(results.begin(), results.end(), dir) != results.end()); } -TEST_F(GcsFileSystemTest, rmdirNotImplemented) { +TEST_F(GcsFileSystemTest, rmdir) { const std::string_view dir = "Directory"; const auto gcsDirectory = gcsURI(emulator_->preexistingBucketName(), dir); - filesystems::GcsFileSystem gcfs(emulator_->hiveConfig()); + filesystems::GcsFileSystem gcfs( + emulator_->preexistingBucketName(), emulator_->hiveConfig()); gcfs.initializeClient(); - VELOX_ASSERT_THROW(gcfs.rmdir(gcsDirectory), "rmdir for GCS not implemented"); + + auto writeFile = gcfs.openFileForWrite(gcsDirectory); + std::string_view kDataContent = "GcsFileSystemTest rename operation test"; + writeFile->append(kDataContent.substr(0, 10)); + writeFile->flush(); + writeFile->close(); + + auto results = gcfs.list(gcsDirectory); + ASSERT_TRUE(std::find(results.begin(), results.end(), dir) != results.end()); + gcfs.rmdir(gcsDirectory); + + results = gcfs.list(gcsDirectory); + ASSERT_TRUE(std::find(results.begin(), results.end(), dir) == results.end()); } TEST_F(GcsFileSystemTest, missingFile) { const std::string_view file = "newTest.txt"; const auto gcsFile = gcsURI(emulator_->preexistingBucketName(), file); - filesystems::GcsFileSystem gcfs(emulator_->hiveConfig()); + filesystems::GcsFileSystem gcfs( + emulator_->preexistingBucketName(), emulator_->hiveConfig()); gcfs.initializeClient(); VELOX_ASSERT_RUNTIME_THROW_CODE( gcfs.openFileForRead(gcsFile), @@ -150,7 +202,8 @@ TEST_F(GcsFileSystemTest, missingFile) { } TEST_F(GcsFileSystemTest, missingBucket) { - filesystems::GcsFileSystem gcfs(emulator_->hiveConfig()); + filesystems::GcsFileSystem gcfs( + emulator_->preexistingBucketName(), emulator_->hiveConfig()); gcfs.initializeClient(); const std::string_view gcsFile = "gs://dummy/foo.txt"; VELOX_ASSERT_RUNTIME_THROW_CODE( @@ -208,12 +261,77 @@ TEST_F(GcsFileSystemTest, credentialsConfig) { {"hive.gcs.json-key-file-path", jsonFile->getPath()}}; auto hiveConfig = emulator_->hiveConfig(configOverride); - filesystems::GcsFileSystem gcfs(hiveConfig); + filesystems::GcsFileSystem gcfs( + emulator_->preexistingBucketName(), hiveConfig); gcfs.initializeClient(); const auto gcsFile = gcsURI( emulator_->preexistingBucketName(), emulator_->preexistingObjectName()); VELOX_ASSERT_THROW( gcfs.openFileForRead(gcsFile), "Invalid ServiceAccountCredentials"); } + +TEST_F(GcsFileSystemTest, credentialsProvider) { + const auto providerFactory = + [](const std::shared_ptr&) { + return std::make_shared(); + }; + registerGcsOAuthCredentialsProvider("dummy_provider", providerFactory); + + { + std::unordered_map configOverride = { + {"hive.gcs.auth.access-token-provider", "dummy_provider"}}; + auto hiveConfig = emulator_->hiveConfig(configOverride); + + filesystems::GcsFileSystem gcfs( + emulator_->preexistingBucketName(), hiveConfig); + VELOX_ASSERT_THROW( + gcfs.initializeClient(), + "DummyGcsOAuthCredentialsProvider: Not implemented"); + + VELOX_ASSERT_THROW( + registerGcsOAuthCredentialsProvider("", providerFactory), + "GcsOAuthCredentialsProviderFactory name cannot be empty"); + + VELOX_ASSERT_THROW( + registerGcsOAuthCredentialsProvider("dummy_provider", providerFactory), + "GcsOAuthCredentialsProviderFactory 'dummy_provider' already registered"); + } + + // Invalid provider name. + { + std::unordered_map configOverride = { + {"hive.gcs.auth.access-token-provider", ""}}; + auto hiveConfig = emulator_->hiveConfig(configOverride); + + filesystems::GcsFileSystem gcfs( + emulator_->preexistingBucketName(), hiveConfig); + + VELOX_ASSERT_THROW( + gcfs.initializeClient(), + "GcsOAuthCredentialsProviderFactory name cannot be empty"); + } +} + +TEST_F(GcsFileSystemTest, defaultCacheKey) { + registerGcsFileSystem(); + std::unordered_map configWithoutEndpoint = {}; + auto hiveConfigDefault = std::make_shared( + std::move(configWithoutEndpoint)); + const auto gcsFile1 = gcsURI( + emulator_->preexistingBucketName(), emulator_->preexistingObjectName()); + // FileSystem should be cached by the default key. + auto defaultGcs = filesystems::getFileSystem(gcsFile1, hiveConfigDefault); + + std::unordered_map configWithEndpoint = { + {connector::hive::HiveConfig::kGcsEndpoint, kGcsDefaultCacheKeyPrefix}}; + auto hiveConfigCustom = + std::make_shared(std::move(configWithEndpoint)); + const auto gcsFile2 = gcsURI(emulator_->preexistingBucketName(), "dummy.txt"); + auto customGcs = filesystems::getFileSystem(gcsFile2, hiveConfigCustom); + // The same FileSystem should be cached by the value of key + // kGcsDefaultCacheKeyPrefix. + ASSERT_EQ(customGcs, defaultGcs); +} + } // namespace } // namespace facebook::velox::filesystems diff --git a/velox/connectors/hive/storage_adapters/gcs/tests/GcsMultipleEndpointsTest.cpp b/velox/connectors/hive/storage_adapters/gcs/tests/GcsMultipleEndpointsTest.cpp new file mode 100644 index 000000000000..00bb3310a067 --- /dev/null +++ b/velox/connectors/hive/storage_adapters/gcs/tests/GcsMultipleEndpointsTest.cpp @@ -0,0 +1,204 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include "gtest/gtest.h" +#include "velox/connectors/hive/HiveConnector.h" +#include "velox/connectors/hive/HiveConnectorSplit.h" +#include "velox/connectors/hive/storage_adapters/gcs/RegisterGcsFileSystem.h" +#include "velox/connectors/hive/storage_adapters/gcs/tests/GcsEmulator.h" +#include "velox/dwio/parquet/RegisterParquetReader.h" +#include "velox/dwio/parquet/RegisterParquetWriter.h" +#include "velox/exec/TableWriter.h" +#include "velox/exec/tests/utils/AssertQueryBuilder.h" +#include "velox/exec/tests/utils/PlanBuilder.h" +#include "velox/vector/tests/utils/VectorTestBase.h" + +static const std::string_view kConnectorId1 = "test-hive1"; +static const std::string_view kConnectorId2 = "test-hive2"; +static const std::string_view kBucketName = "writedata"; + +using namespace facebook::velox::exec::test; + +namespace facebook::velox::filesystems { +namespace { + +class GcsMultipleEndpointsTest : public testing::Test, + public velox::test::VectorTestBase { + public: + static void SetUpTestCase() { + registerGcsFileSystem(); + memory::MemoryManager::testingSetInstance(memory::MemoryManager::Options{}); + } + + void SetUp() override { + gcsEmulatorOne_ = std::make_unique(); + gcsEmulatorOne_->bootstrap(); + gcsEmulatorTwo_ = std::make_unique(); + gcsEmulatorTwo_->bootstrap(); + + parquet::registerParquetReaderFactory(); + parquet::registerParquetWriterFactory(); + } + + void registerConnectors( + std::string_view connectorId1, + std::string_view connectorId2, + const std::unordered_map config1Override = {}, + const std::unordered_map config2Override = {}) { + connector::hive::HiveConnectorFactory factory; + auto hiveConnector1 = factory.newConnector( + std::string(connectorId1), + gcsEmulatorOne_->hiveConfig(config1Override), + ioExecutor_.get()); + auto hiveConnector2 = factory.newConnector( + std::string(connectorId2), + gcsEmulatorTwo_->hiveConfig(config2Override), + ioExecutor_.get()); + connector::registerConnector(hiveConnector1); + connector::registerConnector(hiveConnector2); + } + + void TearDown() override { + parquet::unregisterParquetReaderFactory(); + parquet::unregisterParquetWriterFactory(); + } + + folly::dynamic writeData( + const RowVectorPtr input, + const std::string& outputDirectory, + const std::string& connectorId) { + auto plan = PlanBuilder() + .values({input}) + .tableWrite( + outputDirectory.data(), + {}, + 0, + {}, + {}, + dwio::common::FileFormat::PARQUET, + {}, + connectorId) + .planNode(); + // Execute the write plan. + auto results = AssertQueryBuilder(plan).copyResults(pool()); + // Second column contains details about written files. + auto details = results->childAt(exec::TableWriteTraits::kFragmentChannel) + ->as>(); + folly::dynamic obj = folly::parseJson(details->valueAt(1)); + return obj["fileWriteInfos"]; + } + + std::shared_ptr createSplit( + folly::dynamic tableWriteInfo, + std::string outputDirectory, + std::string connectorId) { + auto writeFileName = tableWriteInfo[0]["writeFileName"].asString(); + auto filePath = fmt::format("{}{}", outputDirectory, writeFileName); + const int64_t fileSize = tableWriteInfo[0]["fileSize"].asInt(); + + return connector::hive::HiveConnectorSplitBuilder(filePath) + .connectorId(connectorId) + .fileFormat(dwio::common::FileFormat::PARQUET) + .length(fileSize) + .build(); + } + + void testJoin( + int numRows, + std::string_view outputDirectory, + std::string_view connectorId1, + std::string_view connectorId2) { + auto rowType1 = ROW( + {"a0", "a1", "a2", "a3"}, {BIGINT(), INTEGER(), SMALLINT(), DOUBLE()}); + auto rowType2 = ROW( + {"b0", "b1", "b2", "b3"}, {BIGINT(), INTEGER(), SMALLINT(), DOUBLE()}); + + auto input1 = makeRowVector( + rowType1->names(), + {makeFlatVector(numRows, [](auto row) { return row; }), + makeFlatVector(numRows, [](auto row) { return row; }), + makeFlatVector(numRows, [](auto row) { return row; }), + makeFlatVector(numRows, [](auto row) { return row; })}); + auto input2 = makeRowVector(rowType2->names(), input1->children()); + + // Insert input data into both tables. + auto table1WriteInfo = + writeData(input1, outputDirectory.data(), std::string(connectorId1)); + auto table2WriteInfo = + writeData(input2, outputDirectory.data(), std::string(connectorId2)); + + // Inner Join both the tables. + core::PlanNodeId scan1, scan2; + auto planNodeIdGenerator = std::make_shared(); + auto table1Scan = PlanBuilder(planNodeIdGenerator, pool()) + .startTableScan() + .tableName("hive_table1") + .outputType(rowType1) + .connectorId(std::string(connectorId1)) + .endTableScan() + .capturePlanNodeId(scan1) + .planNode(); + auto join = + PlanBuilder(planNodeIdGenerator, pool()) + .startTableScan() + .tableName("hive_table1") + .outputType(rowType2) + .connectorId(std::string(connectorId2)) + .endTableScan() + .capturePlanNodeId(scan2) + .hashJoin({"b0"}, {"a0"}, table1Scan, "", {"a0", "a1", "a2", "a3"}) + .planNode(); + + auto split1 = createSplit( + table1WriteInfo, outputDirectory.data(), std::string(connectorId1)); + auto split2 = createSplit( + table2WriteInfo, outputDirectory.data(), std::string(connectorId2)); + auto results = AssertQueryBuilder(join) + .split(scan1, split1) + .split(scan2, split2) + .copyResults(pool()); + assertEqualResults({input1}, {results}); + } + + std::unique_ptr gcsEmulatorOne_; + std::unique_ptr gcsEmulatorTwo_; + std::unique_ptr ioExecutor_; +}; +} // namespace + +TEST_F(GcsMultipleEndpointsTest, baseEndpoints) { + const int64_t kExpectedRows = 1'000; + + const auto gcsBucket = gcsURI(gcsEmulatorOne_->preexistingBucketName(), ""); + + registerConnectors(kConnectorId1, kConnectorId2); + + testJoin(kExpectedRows, gcsBucket, kConnectorId1, kConnectorId2); + + connector::unregisterConnector(std::string(kConnectorId1)); + connector::unregisterConnector(std::string(kConnectorId2)); +} + +} // namespace facebook::velox::filesystems + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + folly::Init init{&argc, &argv, false}; + return RUN_ALL_TESTS(); +} diff --git a/velox/connectors/hive/storage_adapters/hdfs/CMakeLists.txt b/velox/connectors/hive/storage_adapters/hdfs/CMakeLists.txt index 44aa7be3489c..8c49f3e11ed9 100644 --- a/velox/connectors/hive/storage_adapters/hdfs/CMakeLists.txt +++ b/velox/connectors/hive/storage_adapters/hdfs/CMakeLists.txt @@ -17,18 +17,8 @@ velox_add_library(velox_hdfs RegisterHdfsFileSystem.cpp) if(VELOX_ENABLE_HDFS) - velox_sources( - velox_hdfs - PRIVATE - HdfsFileSystem.cpp - HdfsReadFile.cpp - HdfsWriteFile.cpp) - velox_link_libraries( - velox_hdfs - velox_external_hdfs - velox_dwio_common - Folly::folly - xsimd) + velox_sources(velox_hdfs PRIVATE HdfsFileSystem.cpp HdfsReadFile.cpp HdfsWriteFile.cpp) + velox_link_libraries(velox_hdfs velox_external_hdfs velox_dwio_common Folly::folly xsimd) if(${VELOX_BUILD_TESTING}) add_subdirectory(tests) diff --git a/velox/connectors/hive/storage_adapters/hdfs/HdfsFileSystem.cpp b/velox/connectors/hive/storage_adapters/hdfs/HdfsFileSystem.cpp index 856f2b2526de..46522ab2a6f5 100644 --- a/velox/connectors/hive/storage_adapters/hdfs/HdfsFileSystem.cpp +++ b/velox/connectors/hive/storage_adapters/hdfs/HdfsFileSystem.cpp @@ -56,11 +56,29 @@ class HdfsFileSystem::Impl { } ~Impl() { - LOG(INFO) << "Disconnecting HDFS file system"; - int disconnectResult = driver_->Disconnect(hdfsClient_); - if (disconnectResult != 0) { - LOG(WARNING) << "hdfs disconnect failure in HdfsReadFile close: " - << errno; + if (!closed_) { + LOG(WARNING) + << "The HdfsFileSystem instance is not closed upon destruction. You must explicitly call the close() API before JVM termination to ensure proper disconnection."; + } + } + + // The HdfsFileSystem::Disconnect operation requires the JVM method + // definitions to be loaded within an active JVM process. + // Therefore, it must be invoked before the JVM shuts down. + + // To address this, we’ve introduced a new close() API that performs the + // disconnect operation. Third-party applications can call this close() method + // prior to JVM termination to ensure proper cleanup. + void close() { + if (!closed_) { + LOG(WARNING) << "Disconnecting HDFS file system"; + int disconnectResult = driver_->Disconnect(hdfsClient_); + if (disconnectResult != 0) { + LOG(WARNING) << "hdfs disconnect failure in HdfsReadFile close: " + << errno; + } + + closed_ = true; } } @@ -75,6 +93,7 @@ class HdfsFileSystem::Impl { private: hdfsFS hdfsClient_; filesystems::arrow::io::internal::LibHdfsShim* driver_; + bool closed_ = false; }; HdfsFileSystem::HdfsFileSystem( @@ -91,7 +110,7 @@ std::string HdfsFileSystem::name() const { std::unique_ptr HdfsFileSystem::openFileForRead( std::string_view path, const FileOptions& /*unused*/) { - // Only remove the schema for hdfs path. + // Only remove the scheme for hdfs path. if (path.find(kScheme) == 0) { path.remove_prefix(kScheme.length()); if (auto index = path.find('/')) { @@ -109,6 +128,10 @@ std::unique_ptr HdfsFileSystem::openFileForWrite( impl_->hdfsShim(), impl_->hdfsClient(), path); } +void HdfsFileSystem::close() { + impl_->close(); +} + bool HdfsFileSystem::isHdfsFile(const std::string_view filePath) { return (filePath.find(kScheme) == 0) || (filePath.find(kViewfsScheme) == 0); } @@ -129,11 +152,11 @@ HdfsServiceEndpoint HdfsFileSystem::getServiceEndpoint( // Fall back to get a fixed endpoint from config. auto hdfsHost = config->get("hive.hdfs.host"); VELOX_CHECK( - hdfsHost.hasValue(), + hdfsHost.has_value(), "hdfsHost is empty, configuration missing for hdfs host"); auto hdfsPort = config->get("hive.hdfs.port"); VELOX_CHECK( - hdfsPort.hasValue(), + hdfsPort.has_value(), "hdfsPort is empty, configuration missing for hdfs port"); return HdfsServiceEndpoint{*hdfsHost, *hdfsPort}; } @@ -152,7 +175,131 @@ HdfsServiceEndpoint HdfsFileSystem::getServiceEndpoint( } void HdfsFileSystem::remove(std::string_view path) { - VELOX_UNSUPPORTED("Does not support removing files from hdfs"); + // Only remove the scheme for hdfs path. + if (path.find(kScheme) == 0) { + path.remove_prefix(kScheme.length()); + if (auto index = path.find('/')) { + path.remove_prefix(index); + } + } + + VELOX_CHECK_EQ( + impl_->hdfsShim()->Delete(impl_->hdfsClient(), path.data(), 0), + 0, + "Cannot delete file : {} in HDFS, error is : {}", + path, + impl_->hdfsShim()->GetLastExceptionRootCause()); +} + +std::vector HdfsFileSystem::list(std::string_view path) { + // Only remove the scheme for hdfs path. + if (path.find(kScheme) == 0) { + path.remove_prefix(kScheme.length()); + if (auto index = path.find('/')) { + path.remove_prefix(index); + } + } + + std::vector result; + int numEntries; + + auto fileInfo = impl_->hdfsShim()->ListDirectory( + impl_->hdfsClient(), path.data(), &numEntries); + + VELOX_CHECK_NOT_NULL( + fileInfo, + "Unable to list the files in path {}. got error: {}", + path, + impl_->hdfsShim()->GetLastExceptionRootCause()); + + for (auto i = 0; i < numEntries; i++) { + result.emplace_back(fileInfo[i].mName); + } + + impl_->hdfsShim()->FreeFileInfo(fileInfo, numEntries); + + return result; +} + +bool HdfsFileSystem::exists(std::string_view path) { + // Only remove the scheme for hdfs path. + if (path.find(kScheme) == 0) { + path.remove_prefix(kScheme.length()); + if (auto index = path.find('/')) { + path.remove_prefix(index); + } + } + + return impl_->hdfsShim()->Exists(impl_->hdfsClient(), path.data()) == 0; +} + +void HdfsFileSystem::mkdir( + std::string_view path, + const DirectoryOptions& options) { + // Only remove the scheme for hdfs path. + if (path.find(kScheme) == 0) { + path.remove_prefix(kScheme.length()); + if (auto index = path.find('/')) { + path.remove_prefix(index); + } + } + + VELOX_CHECK_EQ( + impl_->hdfsShim()->MakeDirectory(impl_->hdfsClient(), path.data()), + 0, + "Cannot mkdir {} in HDFS, error is : {}", + path, + impl_->hdfsShim()->GetLastExceptionRootCause()); +} + +void HdfsFileSystem::rename( + std::string_view path, + std::string_view newPath, + bool overWrite) { + VELOX_CHECK_EQ( + overWrite, false, "HdfsFileSystem::rename doesn't support overwrite"); + // Only remove the scheme for hdfs path. + if (path.find(kScheme) == 0) { + path.remove_prefix(kScheme.length()); + if (auto index = path.find('/')) { + path.remove_prefix(index); + } + } + + // Only remove the scheme for hdfs path. + if (newPath.find(kScheme) == 0) { + newPath.remove_prefix(kScheme.length()); + if (auto index = newPath.find('/')) { + newPath.remove_prefix(index); + } + } + + VELOX_CHECK_EQ( + impl_->hdfsShim()->Rename( + impl_->hdfsClient(), path.data(), newPath.data()), + 0, + "Cannot rename file from {} to {} in HDFS, error is : {}", + path, + newPath, + impl_->hdfsShim()->GetLastExceptionRootCause()); +} + +void HdfsFileSystem::rmdir(std::string_view path) { + // Only remove the scheme for hdfs path. + if (path.find(kScheme) == 0) { + path.remove_prefix(kScheme.length()); + if (auto index = path.find('/')) { + path.remove_prefix(index); + } + } + + VELOX_CHECK_EQ( + impl_->hdfsShim()->Delete( + impl_->hdfsClient(), path.data(), /*recursive=*/true), + 0, + "Cannot remove directory {} recursively in HDFS, error is : {}", + path, + impl_->hdfsShim()->GetLastExceptionRootCause()); } } // namespace facebook::velox::filesystems diff --git a/velox/connectors/hive/storage_adapters/hdfs/HdfsFileSystem.h b/velox/connectors/hive/storage_adapters/hdfs/HdfsFileSystem.h index b541ec629baf..9720bb13034b 100644 --- a/velox/connectors/hive/storage_adapters/hdfs/HdfsFileSystem.h +++ b/velox/connectors/hive/storage_adapters/hdfs/HdfsFileSystem.h @@ -61,31 +61,25 @@ class HdfsFileSystem : public FileSystem { std::string_view path, const FileOptions& options = {}) override; + void close(); + + // Deletes the hdfs files. void remove(std::string_view path) override; - virtual void rename( + void rename( std::string_view path, std::string_view newPath, - bool overWrite = false) override { - VELOX_UNSUPPORTED("rename for HDFs not implemented"); - } + bool overWrite = false) override; - bool exists(std::string_view path) override { - VELOX_UNSUPPORTED("exists for HDFS not implemented"); - } + bool exists(std::string_view path) override; - virtual std::vector list(std::string_view path) override { - VELOX_UNSUPPORTED("list for HDFS not implemented"); - } + /// List the objects associated to a path. + std::vector list(std::string_view path) override; void mkdir(std::string_view path, const DirectoryOptions& options = {}) - override { - VELOX_UNSUPPORTED("mkdir for HDFS not implemented"); - } + override; - void rmdir(std::string_view path) override { - VELOX_UNSUPPORTED("rmdir for HDFS not implemented"); - } + void rmdir(std::string_view path) override; static bool isHdfsFile(std::string_view filename); diff --git a/velox/connectors/hive/storage_adapters/hdfs/HdfsReadFile.cpp b/velox/connectors/hive/storage_adapters/hdfs/HdfsReadFile.cpp index affc1dfd2ede..1d320cda44b6 100644 --- a/velox/connectors/hive/storage_adapters/hdfs/HdfsReadFile.cpp +++ b/velox/connectors/hive/storage_adapters/hdfs/HdfsReadFile.cpp @@ -103,12 +103,19 @@ class HdfsReadFile::Impl { } } - std::string_view pread(uint64_t offset, uint64_t length, void* buf) const { + std::string_view pread( + uint64_t offset, + uint64_t length, + void* buf, + const FileStorageContext& fileStorageContext) const { preadInternal(offset, length, static_cast(buf)); return {static_cast(buf), length}; } - std::string pread(uint64_t offset, uint64_t length) const { + std::string pread( + uint64_t offset, + uint64_t length, + const FileStorageContext& fileStorageContext) const { std::string result(length, 0); char* pos = result.data(); preadInternal(offset, length, pos); @@ -163,15 +170,15 @@ std::string_view HdfsReadFile::pread( uint64_t offset, uint64_t length, void* buf, - filesystems::File::IoStats* stats) const { - return pImpl->pread(offset, length, buf); + const FileStorageContext& fileStorageContext) const { + return pImpl->pread(offset, length, buf, fileStorageContext); } std::string HdfsReadFile::pread( uint64_t offset, uint64_t length, - filesystems::File::IoStats* stats) const { - return pImpl->pread(offset, length); + const FileStorageContext& fileStorageContext) const { + return pImpl->pread(offset, length, fileStorageContext); } uint64_t HdfsReadFile::size() const { diff --git a/velox/connectors/hive/storage_adapters/hdfs/HdfsReadFile.h b/velox/connectors/hive/storage_adapters/hdfs/HdfsReadFile.h index ddd35e511a71..a59b178909c6 100644 --- a/velox/connectors/hive/storage_adapters/hdfs/HdfsReadFile.h +++ b/velox/connectors/hive/storage_adapters/hdfs/HdfsReadFile.h @@ -38,12 +38,12 @@ class HdfsReadFile final : public ReadFile { uint64_t offset, uint64_t length, void* buf, - filesystems::File::IoStats* stats = nullptr) const final; + const FileStorageContext& fileStorageContext = {}) const final; std::string pread( uint64_t offset, uint64_t length, - filesystems::File::IoStats* stats = nullptr) const final; + const FileStorageContext& fileStorageContext = {}) const final; uint64_t size() const final; diff --git a/velox/connectors/hive/storage_adapters/hdfs/HdfsWriteFile.cpp b/velox/connectors/hive/storage_adapters/hdfs/HdfsWriteFile.cpp index be668a3133e1..26d43ccb9100 100644 --- a/velox/connectors/hive/storage_adapters/hdfs/HdfsWriteFile.cpp +++ b/velox/connectors/hive/storage_adapters/hdfs/HdfsWriteFile.cpp @@ -54,12 +54,14 @@ HdfsWriteFile::~HdfsWriteFile() { void HdfsWriteFile::close() { int success = driver_->CloseFile(hdfsClient_, hdfsFile_); + common::testutil::TestValue::adjust( + "facebook::velox::connectors::hive::HdfsWriteFile::close", &success); + hdfsFile_ = nullptr; VELOX_CHECK_EQ( success, 0, "Failed to close hdfs file: {}", driver_->GetLastExceptionRootCause()); - hdfsFile_ = nullptr; } void HdfsWriteFile::flush() { diff --git a/velox/connectors/hive/storage_adapters/hdfs/RegisterHdfsFileSystem.cpp b/velox/connectors/hive/storage_adapters/hdfs/RegisterHdfsFileSystem.cpp index 1f23179f0a72..bb4f208c4731 100644 --- a/velox/connectors/hive/storage_adapters/hdfs/RegisterHdfsFileSystem.cpp +++ b/velox/connectors/hive/storage_adapters/hdfs/RegisterHdfsFileSystem.cpp @@ -28,44 +28,48 @@ namespace facebook::velox::filesystems { #ifdef VELOX_ENABLE_HDFS std::mutex mtx; +folly::ConcurrentHashMap> + registeredFilesystems; + std::function(std::shared_ptr, std::string_view)> hdfsFileSystemGenerator() { - static auto filesystemGenerator = [](std::shared_ptr - properties, - std::string_view filePath) { - static folly::ConcurrentHashMap> - filesystems; - static folly:: - ConcurrentHashMap> - hdfsInitiationFlags; - HdfsServiceEndpoint endpoint = - HdfsFileSystem::getServiceEndpoint(filePath, properties.get()); - std::string hdfsIdentity = endpoint.identity(); - if (filesystems.find(hdfsIdentity) != filesystems.end()) { - return filesystems[hdfsIdentity]; - } - std::unique_lock lk(mtx, std::defer_lock); - /// If the init flag for a given hdfs identity is not found, - /// create one for init use. It's a singleton. - if (hdfsInitiationFlags.find(hdfsIdentity) == hdfsInitiationFlags.end()) { - lk.lock(); - if (hdfsInitiationFlags.find(hdfsIdentity) == hdfsInitiationFlags.end()) { - std::shared_ptr initiationFlagPtr = - std::make_shared(); - hdfsInitiationFlags.insert(hdfsIdentity, initiationFlagPtr); - } - lk.unlock(); - } - folly::call_once( - *hdfsInitiationFlags[hdfsIdentity].get(), - [&properties, endpoint, hdfsIdentity]() { - auto filesystem = - std::make_shared(properties, endpoint); - filesystems.insert(hdfsIdentity, filesystem); - }); - return filesystems[hdfsIdentity]; - }; + static auto filesystemGenerator = + [](std::shared_ptr properties, + std::string_view filePath) { + static folly:: + ConcurrentHashMap> + hdfsInitiationFlags; + HdfsServiceEndpoint endpoint = + HdfsFileSystem::getServiceEndpoint(filePath, properties.get()); + std::string hdfsIdentity = endpoint.identity(); + if (registeredFilesystems.find(hdfsIdentity) != + registeredFilesystems.end()) { + return registeredFilesystems[hdfsIdentity]; + } + std::unique_lock lk(mtx, std::defer_lock); + /// If the init flag for a given hdfs identity is not found, + /// create one for init use. It's a singleton. + if (hdfsInitiationFlags.find(hdfsIdentity) == + hdfsInitiationFlags.end()) { + lk.lock(); + if (hdfsInitiationFlags.find(hdfsIdentity) == + hdfsInitiationFlags.end()) { + std::shared_ptr initiationFlagPtr = + std::make_shared(); + hdfsInitiationFlags.insert(hdfsIdentity, initiationFlagPtr); + } + lk.unlock(); + } + folly::call_once( + *hdfsInitiationFlags[hdfsIdentity].get(), + [&properties, endpoint, hdfsIdentity]() { + auto filesystem = + std::make_shared(properties, endpoint); + registeredFilesystems.insert(hdfsIdentity, filesystem); + }); + return registeredFilesystems[hdfsIdentity]; + }; return filesystemGenerator; } diff --git a/velox/connectors/hive/storage_adapters/hdfs/RegisterHdfsFileSystem.h b/velox/connectors/hive/storage_adapters/hdfs/RegisterHdfsFileSystem.h index 6f6f0c032bd7..18eef4aca176 100644 --- a/velox/connectors/hive/storage_adapters/hdfs/RegisterHdfsFileSystem.h +++ b/velox/connectors/hive/storage_adapters/hdfs/RegisterHdfsFileSystem.h @@ -16,8 +16,15 @@ #pragma once +#include "folly/concurrency/ConcurrentHashMap.h" + namespace facebook::velox::filesystems { +class HdfsFileSystem; + +extern folly::ConcurrentHashMap> + registeredFilesystems; + // Register the HDFS. void registerHdfsFileSystem(); diff --git a/velox/connectors/hive/storage_adapters/hdfs/tests/CMakeLists.txt b/velox/connectors/hive/storage_adapters/hdfs/tests/CMakeLists.txt index 3e3b8b2d5a24..bdefd5b93c78 100644 --- a/velox/connectors/hive/storage_adapters/hdfs/tests/CMakeLists.txt +++ b/velox/connectors/hive/storage_adapters/hdfs/tests/CMakeLists.txt @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -add_executable(velox_hdfs_file_test HdfsFileSystemTest.cpp HdfsMiniCluster.cpp - HdfsUtilTest.cpp) +add_executable(velox_hdfs_file_test HdfsFileSystemTest.cpp HdfsMiniCluster.cpp HdfsUtilTest.cpp) add_test(velox_hdfs_file_test velox_hdfs_file_test) target_link_libraries( @@ -27,32 +26,33 @@ target_link_libraries( velox_exec GTest::gtest GTest::gtest_main - GTest::gmock) + GTest::gmock +) -target_compile_options(velox_hdfs_file_test - PRIVATE -Wno-deprecated-declarations) +target_compile_options(velox_hdfs_file_test PRIVATE -Wno-deprecated-declarations) -add_executable(velox_hdfs_insert_test HdfsInsertTest.cpp HdfsMiniCluster.cpp - HdfsUtilTest.cpp) +add_executable(velox_hdfs_insert_test HdfsInsertTest.cpp HdfsMiniCluster.cpp HdfsUtilTest.cpp) add_test(velox_hdfs_insert_test velox_hdfs_insert_test) target_link_libraries( velox_hdfs_insert_test + velox_dwio_parquet_reader + velox_dwio_parquet_writer + velox_hdfs velox_exec_test_lib velox_exec GTest::gtest GTest::gtest_main - GTest::gmock) + GTest::gmock +) -target_compile_options(velox_hdfs_insert_test - PRIVATE -Wno-deprecated-declarations) +target_compile_options(velox_hdfs_insert_test PRIVATE -Wno-deprecated-declarations) # velox_hdfs_insert_test and velox_hdfs_file_test two tests can't run in # parallel due to the port conflict in Hadoop NameNode and DataNode. The # namenode port conflict can be resolved using the -nnport configuration in -# hadoop-mapreduce-client-jobclient-3.3.0-tests.jar. However the data node port +# hadoop-mapreduce-client-jobclient-3.3.6-tests.jar. However the data node port # cannot be configured. Therefore, we need to make sure that # velox_hdfs_file_test runs only after velox_hdfs_insert_test has finished. -set_tests_properties(velox_hdfs_insert_test PROPERTIES DEPENDS - velox_hdfs_file_test) +set_tests_properties(velox_hdfs_insert_test PROPERTIES DEPENDS velox_hdfs_file_test) diff --git a/velox/connectors/hive/storage_adapters/hdfs/tests/HdfsFileSystemTest.cpp b/velox/connectors/hive/storage_adapters/hdfs/tests/HdfsFileSystemTest.cpp index be330b99901f..2fce65c91c55 100644 --- a/velox/connectors/hive/storage_adapters/hdfs/tests/HdfsFileSystemTest.cpp +++ b/velox/connectors/hive/storage_adapters/hdfs/tests/HdfsFileSystemTest.cpp @@ -21,6 +21,7 @@ #include "gtest/gtest.h" #include "velox/common/base/Exceptions.h" #include "velox/common/base/tests/GTestUtils.h" +#include "velox/common/testutil/TestValue.h" #include "velox/connectors/hive/storage_adapters/hdfs/HdfsReadFile.h" #include "velox/connectors/hive/storage_adapters/hdfs/RegisterHdfsFileSystem.h" #include "velox/connectors/hive/storage_adapters/hdfs/tests/HdfsMiniCluster.h" @@ -36,6 +37,9 @@ using filesystems::arrow::io::internal::LibHdfsShim; constexpr int kOneMB = 1 << 20; static const std::string kDestinationPath = "/test_file.txt"; +static const std::string kRenamePath = "/rename_file.txt"; +static const std::string kRenameNewPath = "/rename_new_file.txt"; +static const std::string kDeletedPath = "/delete_file.txt"; static const std::string kSimpleDestinationPath = "hdfs://" + kDestinationPath; static const std::string kViewfsDestinationPath = "viewfs://" + kDestinationPath; @@ -50,6 +54,8 @@ class HdfsFileSystemTest : public testing::Test { miniCluster->start(); auto tempFile = createFile(); miniCluster->addFile(tempFile->getPath(), kDestinationPath); + miniCluster->addFile(tempFile->getPath(), kRenamePath); + miniCluster->addFile(tempFile->getPath(), kDeletedPath); } configurationValues.insert( {"hive.hdfs.host", std::string(miniCluster->host())}); @@ -67,6 +73,11 @@ class HdfsFileSystemTest : public testing::Test { } static void TearDownTestSuite() { + for (const auto& [_, filesystem] : + facebook::velox::filesystems::registeredFilesystems) { + filesystem->close(); + } + miniCluster->stop(); } @@ -212,6 +223,29 @@ TEST_F(HdfsFileSystemTest, read) { readData(&readFile); } +TEST_F(HdfsFileSystemTest, rename) { + auto config = std::make_shared( + std::unordered_map(configurationValues)); + auto hdfsFileSystem = + filesystems::getFileSystem(fullDestinationPath_, config); + + ASSERT_TRUE(hdfsFileSystem->exists(kRenamePath)); + hdfsFileSystem->rename(kRenamePath, kRenameNewPath); + ASSERT_FALSE(hdfsFileSystem->exists(kRenamePath)); + ASSERT_TRUE(hdfsFileSystem->exists(kRenameNewPath)); +} + +TEST_F(HdfsFileSystemTest, delete) { + auto config = std::make_shared( + std::unordered_map(configurationValues)); + auto hdfsFileSystem = + filesystems::getFileSystem(fullDestinationPath_, config); + + ASSERT_TRUE(hdfsFileSystem->exists(kDeletedPath)); + hdfsFileSystem->remove(kDeletedPath); + ASSERT_FALSE(hdfsFileSystem->exists(kDeletedPath)); +} + TEST_F(HdfsFileSystemTest, viaFileSystem) { auto config = std::make_shared( std::unordered_map(configurationValues)); @@ -221,6 +255,31 @@ TEST_F(HdfsFileSystemTest, viaFileSystem) { readData(readFile.get()); } +TEST_F(HdfsFileSystemTest, exists) { + auto config = std::make_shared( + std::unordered_map(configurationValues)); + auto hdfsFileSystem = + filesystems::getFileSystem(fullDestinationPath_, config); + ASSERT_TRUE(hdfsFileSystem->exists(fullDestinationPath_)); + + const std::string_view notExistFilePath = + "hdfs://localhost:7777//path/that/does/not/exist"; + ASSERT_FALSE(hdfsFileSystem->exists(notExistFilePath)); +} + +TEST_F(HdfsFileSystemTest, mkdirAndRmdir) { + auto config = std::make_shared( + std::unordered_map(configurationValues)); + auto hdfsFileSystem = + filesystems::getFileSystem(fullDestinationPath_, config); + const std::string newDir = "/new_directory"; + ASSERT_FALSE(hdfsFileSystem->exists(newDir)); + hdfsFileSystem->mkdir(newDir); + ASSERT_TRUE(hdfsFileSystem->exists(newDir)); + hdfsFileSystem->rmdir(newDir); + ASSERT_FALSE(hdfsFileSystem->exists(newDir)); +} + TEST_F(HdfsFileSystemTest, initializeFsWithEndpointInfoInFilePath) { // Without host/port configured. auto config = std::make_shared( @@ -328,16 +387,6 @@ TEST_F(HdfsFileSystemTest, writeSupported) { hdfsFileSystem->openFileForWrite("/path"); } -TEST_F(HdfsFileSystemTest, removeNotSupported) { - auto config = std::make_shared( - std::unordered_map(configurationValues)); - auto hdfsFileSystem = - filesystems::getFileSystem(fullDestinationPath_, config); - VELOX_ASSERT_THROW( - hdfsFileSystem->remove("/path"), - "Does not support removing files from hdfs"); -} - TEST_F(HdfsFileSystemTest, multipleThreadsWithReadFile) { startThreads = false; @@ -457,6 +506,18 @@ TEST_F(HdfsFileSystemTest, writeWithParentDirNotExist) { ASSERT_EQ(writeFile->size(), data.size() * 3); } +TEST_F(HdfsFileSystemTest, list) { + auto config = std::make_shared( + std::unordered_map(configurationValues)); + auto hdfsFileSystem = + filesystems::getFileSystem(fullDestinationPath_, config); + + auto result = hdfsFileSystem->list(fullDestinationPath_); + + ASSERT_EQ(result.size(), 1); + ASSERT_TRUE(result[0].find(kDestinationPath) != std::string::npos); +} + TEST_F(HdfsFileSystemTest, readFailures) { filesystems::arrow::io::internal::LibHdfsShim* driver; auto hdfs = connectHdfsDriver( @@ -465,3 +526,34 @@ TEST_F(HdfsFileSystemTest, readFailures) { std::string(miniCluster->nameNodePort())); verifyFailures(driver, hdfs); } + +DEBUG_ONLY_TEST_F(HdfsFileSystemTest, writeFilePreventsDoubleClose) { + common::testutil::TestValue::enable(); + + int closeCallCount = 0; + + SCOPED_TESTVALUE_SET( + "facebook::velox::connectors::hive::HdfsWriteFile::close", + std::function([&closeCallCount](int* success) { + ++closeCallCount; + if (closeCallCount == 1) { + *success = -1; + } + })); + + auto writeFile = openFileForWrite("/test_double_close.txt"); + + writeFile->append("test data"); + writeFile->flush(); + + VELOX_ASSERT_THROW(writeFile->close(), "Failed to close hdfs file:"); + + EXPECT_EQ(closeCallCount, 1); + + // Destructor should not call close() again because hdfsFile_ is nullptr + // The closeCallCount should remain 1. + writeFile.reset(); + EXPECT_EQ(closeCallCount, 1); + + common::testutil::TestValue::disable(); +} diff --git a/velox/connectors/hive/storage_adapters/hdfs/tests/HdfsInsertTest.cpp b/velox/connectors/hive/storage_adapters/hdfs/tests/HdfsInsertTest.cpp index 9ec9a1254154..ed2287a7c42d 100644 --- a/velox/connectors/hive/storage_adapters/hdfs/tests/HdfsInsertTest.cpp +++ b/velox/connectors/hive/storage_adapters/hdfs/tests/HdfsInsertTest.cpp @@ -17,6 +17,7 @@ #include "gtest/gtest.h" #include "velox/common/memory/Memory.h" +#include "velox/connectors/hive/storage_adapters/hdfs/HdfsFileSystem.h" #include "velox/connectors/hive/storage_adapters/hdfs/RegisterHdfsFileSystem.h" #include "velox/connectors/hive/storage_adapters/hdfs/tests/HdfsMiniCluster.h" #include "velox/connectors/hive/storage_adapters/test_common/InsertTest.h" @@ -47,6 +48,10 @@ class HdfsInsertTest : public testing::Test, public InsertTest { } void TearDown() override { + for (const auto& [_, filesystem] : + facebook::velox::filesystems::registeredFilesystems) { + filesystem->close(); + } InsertTest::TearDown(); miniCluster->stop(); } diff --git a/velox/connectors/hive/storage_adapters/hdfs/tests/HdfsMiniCluster.h b/velox/connectors/hive/storage_adapters/hdfs/tests/HdfsMiniCluster.h index c54ae9589b3e..da07cb341a85 100644 --- a/velox/connectors/hive/storage_adapters/hdfs/tests/HdfsMiniCluster.h +++ b/velox/connectors/hive/storage_adapters/hdfs/tests/HdfsMiniCluster.h @@ -26,7 +26,7 @@ static const std::string kMiniClusterExecutableName{"hadoop"}; static const std::string kHadoopSearchPath{":/usr/local/hadoop/bin"}; static const std::string kJarCommand{"jar"}; static const std::string kMiniclusterJar{ - "/share/hadoop/mapreduce/hadoop-mapreduce-client-jobclient-3.3.0-tests.jar"}; + "/share/hadoop/mapreduce/hadoop-mapreduce-client-jobclient-3.3.6-tests.jar"}; static const std::string kMiniclusterCommand{"minicluster"}; static const std::string kNoMapReduceOption{"-nomr"}; static const std::string kFormatNameNodeOption{"-format"}; diff --git a/velox/connectors/hive/storage_adapters/s3fs/CMakeLists.txt b/velox/connectors/hive/storage_adapters/s3fs/CMakeLists.txt index 2b8837970bcb..741f01a61b36 100644 --- a/velox/connectors/hive/storage_adapters/s3fs/CMakeLists.txt +++ b/velox/connectors/hive/storage_adapters/s3fs/CMakeLists.txt @@ -18,14 +18,11 @@ velox_add_library(velox_s3fs RegisterS3FileSystem.cpp) if(VELOX_ENABLE_S3) velox_sources( velox_s3fs - PRIVATE - S3FileSystem.cpp - S3Util.cpp - S3Config.cpp) + PRIVATE S3FileSystem.cpp S3Util.cpp S3Config.cpp S3WriteFile.cpp S3ReadFile.cpp + ) velox_include_directories(velox_s3fs PRIVATE ${AWSSDK_INCLUDE_DIRS}) - velox_link_libraries(velox_s3fs PRIVATE velox_dwio_common Folly::folly - ${AWSSDK_LIBRARIES}) + velox_link_libraries(velox_s3fs PRIVATE velox_dwio_common Folly::folly ${AWSSDK_LIBRARIES}) if(${VELOX_BUILD_TESTING}) add_subdirectory(tests) diff --git a/velox/connectors/hive/storage_adapters/s3fs/S3FileSystem.cpp b/velox/connectors/hive/storage_adapters/s3fs/S3FileSystem.cpp index 639e82b5e1cf..d0e24d984ceb 100644 --- a/velox/connectors/hive/storage_adapters/s3fs/S3FileSystem.cpp +++ b/velox/connectors/hive/storage_adapters/s3fs/S3FileSystem.cpp @@ -20,6 +20,7 @@ #include "velox/common/file/File.h" #include "velox/connectors/hive/storage_adapters/s3fs/S3Config.h" #include "velox/connectors/hive/storage_adapters/s3fs/S3Counters.h" +#include "velox/connectors/hive/storage_adapters/s3fs/S3ReadFile.h" #include "velox/connectors/hive/storage_adapters/s3fs/S3Util.h" #include "velox/connectors/hive/storage_adapters/s3fs/S3WriteFile.h" #include "velox/dwio/common/DataBuffer.h" @@ -33,45 +34,16 @@ #include #include #include -#include -#include -#include #include #include -#include -#include -#include -#include -#include -#include -#include +#include +#include #include -#include +#include +#include namespace facebook::velox::filesystems { namespace { -// Reference: https://issues.apache.org/jira/browse/ARROW-8692 -// https://github.com/apache/arrow/blob/master/cpp/src/arrow/filesystem/s3fs.cc#L843 -// A non-copying iostream. See -// https://stackoverflow.com/questions/35322033/aws-c-sdk-uploadpart-times-out -// https://stackoverflow.com/questions/13059091/creating-an-input-stream-from-constant-memory -class StringViewStream : Aws::Utils::Stream::PreallocatedStreamBuf, - public std::iostream { - public: - StringViewStream(const void* data, int64_t nbytes) - : Aws::Utils::Stream::PreallocatedStreamBuf( - reinterpret_cast(const_cast(data)), - static_cast(nbytes)), - std::iostream(this) {} -}; - -// By default, the AWS SDK reads object data into an auto-growing StringStream. -// To avoid copies, read directly into a pre-allocated buffer instead. -// See https://github.com/aws/aws-sdk-cpp/issues/64 for an alternative but -// functionally similar recipe. -Aws::IOStreamFactory AwsWriteableStreamFactory(void* data, int64_t nbytes) { - return [=]() { return Aws::New("", data, nbytes); }; -} folly::Synchronized< std::unordered_map>& @@ -96,141 +68,6 @@ std::shared_ptr getCredentialsProviderByName( }); } -class S3ReadFile final : public ReadFile { - public: - S3ReadFile(std::string_view path, Aws::S3::S3Client* client) - : client_(client) { - getBucketAndKeyFromPath(path, bucket_, key_); - } - - // Gets the length of the file. - // Checks if there are any issues reading the file. - void initialize(const filesystems::FileOptions& options) { - if (options.fileSize.has_value()) { - VELOX_CHECK_GE( - options.fileSize.value(), 0, "File size must be non-negative"); - length_ = options.fileSize.value(); - } - - // Make it a no-op if invoked twice. - if (length_ != -1) { - return; - } - - Aws::S3::Model::HeadObjectRequest request; - request.SetBucket(awsString(bucket_)); - request.SetKey(awsString(key_)); - - RECORD_METRIC_VALUE(kMetricS3MetadataCalls); - auto outcome = client_->HeadObject(request); - if (!outcome.IsSuccess()) { - RECORD_METRIC_VALUE(kMetricS3GetMetadataErrors); - } - RECORD_METRIC_VALUE(kMetricS3GetMetadataRetries, outcome.GetRetryCount()); - VELOX_CHECK_AWS_OUTCOME( - outcome, "Failed to get metadata for S3 object", bucket_, key_); - length_ = outcome.GetResult().GetContentLength(); - VELOX_CHECK_GE(length_, 0); - } - - std::string_view pread( - uint64_t offset, - uint64_t length, - void* buffer, - File::IoStats* stats) const override { - preadInternal(offset, length, static_cast(buffer)); - return {static_cast(buffer), length}; - } - - std::string pread(uint64_t offset, uint64_t length, File::IoStats* stats) - const override { - std::string result(length, 0); - char* position = result.data(); - preadInternal(offset, length, position); - return result; - } - - uint64_t preadv( - uint64_t offset, - const std::vector>& buffers, - File::IoStats* stats) const override { - // 'buffers' contains Ranges(data, size) with some gaps (data = nullptr) in - // between. This call must populate the ranges (except gap ranges) - // sequentially starting from 'offset'. AWS S3 GetObject does not support - // multi-range. AWS S3 also charges by number of read requests and not size. - // The idea here is to use a single read spanning all the ranges and then - // populate individual ranges. We pre-allocate a buffer to support this. - size_t length = 0; - for (const auto range : buffers) { - length += range.size(); - } - // TODO: allocate from a memory pool - std::string result(length, 0); - preadInternal(offset, length, static_cast(result.data())); - size_t resultOffset = 0; - for (auto range : buffers) { - if (range.data()) { - memcpy(range.data(), &(result.data()[resultOffset]), range.size()); - } - resultOffset += range.size(); - } - return length; - } - - uint64_t size() const override { - return length_; - } - - uint64_t memoryUsage() const override { - // TODO: Check if any buffers are being used by the S3 library - return sizeof(Aws::S3::S3Client) + kS3MaxKeySize + 2 * sizeof(std::string) + - sizeof(int64_t); - } - - bool shouldCoalesce() const final { - return false; - } - - std::string getName() const final { - return fmt::format("s3://{}/{}", bucket_, key_); - } - - uint64_t getNaturalReadSize() const final { - return 72 << 20; - } - - private: - // The assumption here is that "position" has space for at least "length" - // bytes. - void preadInternal(uint64_t offset, uint64_t length, char* position) const { - // Read the desired range of bytes. - Aws::S3::Model::GetObjectRequest request; - Aws::S3::Model::GetObjectResult result; - - request.SetBucket(awsString(bucket_)); - request.SetKey(awsString(key_)); - std::stringstream ss; - ss << "bytes=" << offset << "-" << offset + length - 1; - request.SetRange(awsString(ss.str())); - request.SetResponseStreamFactory( - AwsWriteableStreamFactory(position, length)); - RECORD_METRIC_VALUE(kMetricS3ActiveConnections); - RECORD_METRIC_VALUE(kMetricS3GetObjectCalls); - auto outcome = client_->GetObject(request); - if (!outcome.IsSuccess()) { - RECORD_METRIC_VALUE(kMetricS3GetObjectErrors); - } - RECORD_METRIC_VALUE(kMetricS3GetObjectRetries, outcome.GetRetryCount()); - RECORD_METRIC_VALUE(kMetricS3ActiveConnections, -1); - VELOX_CHECK_AWS_OUTCOME(outcome, "Failed to get S3 object", bucket_, key_); - } - - Aws::S3::S3Client* client_; - std::string bucket_; - std::string key_; - int64_t length_ = -1; -}; - Aws::Utils::Logging::LogLevel inferS3LogLevel(std::string_view logLevel) { std::string level = std::string(logLevel); // Convert to upper case. @@ -272,239 +109,6 @@ Aws::Client::AWSAuthV4Signer::PayloadSigningPolicy inferPayloadSign( } } // namespace -class S3WriteFile::Impl { - public: - explicit Impl( - std::string_view path, - Aws::S3::S3Client* client, - memory::MemoryPool* pool) - : client_(client), pool_(pool) { - VELOX_CHECK_NOT_NULL(client); - VELOX_CHECK_NOT_NULL(pool); - getBucketAndKeyFromPath(path, bucket_, key_); - currentPart_ = std::make_unique>(*pool_); - currentPart_->reserve(kPartUploadSize); - // Check that the object doesn't exist, if it does throw an error. - { - Aws::S3::Model::HeadObjectRequest request; - request.SetBucket(awsString(bucket_)); - request.SetKey(awsString(key_)); - RECORD_METRIC_VALUE(kMetricS3MetadataCalls); - auto objectMetadata = client_->HeadObject(request); - if (!objectMetadata.IsSuccess()) { - RECORD_METRIC_VALUE(kMetricS3GetMetadataErrors); - } - RECORD_METRIC_VALUE( - kMetricS3GetObjectRetries, objectMetadata.GetRetryCount()); - VELOX_CHECK(!objectMetadata.IsSuccess(), "S3 object already exists"); - } - - // Create bucket if not present. - { - Aws::S3::Model::HeadBucketRequest request; - request.SetBucket(awsString(bucket_)); - auto bucketMetadata = client_->HeadBucket(request); - if (!bucketMetadata.IsSuccess()) { - Aws::S3::Model::CreateBucketRequest request; - request.SetBucket(bucket_); - auto outcome = client_->CreateBucket(request); - VELOX_CHECK_AWS_OUTCOME( - outcome, "Failed to create S3 bucket", bucket_, ""); - } - } - - // Initiate the multi-part upload. - { - Aws::S3::Model::CreateMultipartUploadRequest request; - request.SetBucket(awsString(bucket_)); - request.SetKey(awsString(key_)); - - /// If we do not set anything then the SDK will default to application/xml - /// which confuses some tools - /// (https://github.com/apache/arrow/issues/11934). So we instead default - /// to application/octet-stream which is less misleading. - request.SetContentType(kApplicationOctetStream); - // The default algorithm used is MD5. However, MD5 is not supported with - // fips and can cause a SIGSEGV. Set CRC32 instead which is a standard for - // checksum computation and is not restricted by fips. - request.SetChecksumAlgorithm(Aws::S3::Model::ChecksumAlgorithm::CRC32); - - auto outcome = client_->CreateMultipartUpload(request); - VELOX_CHECK_AWS_OUTCOME( - outcome, "Failed initiating multiple part upload", bucket_, key_); - uploadState_.id = outcome.GetResult().GetUploadId(); - } - - fileSize_ = 0; - } - - // Appends data to the end of the file. - void append(std::string_view data) { - VELOX_CHECK(!closed(), "File is closed"); - if (data.size() + currentPart_->size() >= kPartUploadSize) { - upload(data); - } else { - // Append to current part. - currentPart_->unsafeAppend(data.data(), data.size()); - } - fileSize_ += data.size(); - } - - // No-op. - void flush() { - VELOX_CHECK(!closed(), "File is closed"); - /// currentPartSize must be less than kPartUploadSize since - /// append() would have already flushed after reaching kUploadPartSize. - VELOX_CHECK_LT(currentPart_->size(), kPartUploadSize); - } - - // Complete the multipart upload and close the file. - void close() { - if (closed()) { - return; - } - RECORD_METRIC_VALUE(kMetricS3StartedUploads); - uploadPart({currentPart_->data(), currentPart_->size()}, true); - VELOX_CHECK_EQ(uploadState_.partNumber, uploadState_.completedParts.size()); - // Complete the multipart upload. - { - Aws::S3::Model::CompletedMultipartUpload completedUpload; - completedUpload.SetParts(uploadState_.completedParts); - Aws::S3::Model::CompleteMultipartUploadRequest request; - request.SetBucket(awsString(bucket_)); - request.SetKey(awsString(key_)); - request.SetUploadId(uploadState_.id); - request.SetMultipartUpload(std::move(completedUpload)); - - auto outcome = client_->CompleteMultipartUpload(request); - if (outcome.IsSuccess()) { - RECORD_METRIC_VALUE(kMetricS3SuccessfulUploads); - } else { - RECORD_METRIC_VALUE(kMetricS3FailedUploads); - } - VELOX_CHECK_AWS_OUTCOME( - outcome, "Failed to complete multiple part upload", bucket_, key_); - } - currentPart_->clear(); - } - - // Current file size, i.e. the sum of all previous appends. - uint64_t size() const { - return fileSize_; - } - - int numPartsUploaded() const { - return uploadState_.partNumber; - } - - private: - static constexpr int64_t kPartUploadSize = 10 * 1024 * 1024; - static constexpr const char* kApplicationOctetStream = - "application/octet-stream"; - - bool closed() const { - return (currentPart_->capacity() == 0); - } - - // Holds state for the multipart upload. - struct UploadState { - Aws::Vector completedParts; - int64_t partNumber = 0; - Aws::String id; - }; - UploadState uploadState_; - - // Data can be smaller or larger than the kPartUploadSize. - // Complete the currentPart_ and upload kPartUploadSize chunks of data. - // Save the remaining into currentPart_. - void upload(const std::string_view data) { - auto dataPtr = data.data(); - auto dataSize = data.size(); - // Fill-up the remaining currentPart_. - auto remainingBufferSize = currentPart_->capacity() - currentPart_->size(); - currentPart_->unsafeAppend(dataPtr, remainingBufferSize); - uploadPart({currentPart_->data(), currentPart_->size()}); - dataPtr += remainingBufferSize; - dataSize -= remainingBufferSize; - while (dataSize > kPartUploadSize) { - uploadPart({dataPtr, kPartUploadSize}); - dataPtr += kPartUploadSize; - dataSize -= kPartUploadSize; - } - // Stash the remaining at the beginning of currentPart. - currentPart_->unsafeAppend(0, dataPtr, dataSize); - } - - void uploadPart(const std::string_view part, bool isLast = false) { - // Only the last part can be less than kPartUploadSize. - VELOX_CHECK(isLast || (!isLast && (part.size() == kPartUploadSize))); - // Upload the part. - { - Aws::S3::Model::UploadPartRequest request; - request.SetBucket(bucket_); - request.SetKey(key_); - request.SetUploadId(uploadState_.id); - request.SetPartNumber(++uploadState_.partNumber); - request.SetContentLength(part.size()); - request.SetBody( - std::make_shared(part.data(), part.size())); - // The default algorithm used is MD5. However, MD5 is not supported with - // fips and can cause a SIGSEGV. Set CRC32 instead which is a standard for - // checksum computation and is not restricted by fips. - request.SetChecksumAlgorithm(Aws::S3::Model::ChecksumAlgorithm::CRC32); - auto outcome = client_->UploadPart(request); - VELOX_CHECK_AWS_OUTCOME(outcome, "Failed to upload", bucket_, key_); - // Append ETag and part number for this uploaded part. - // This will be needed for upload completion in Close(). - auto result = outcome.GetResult(); - Aws::S3::Model::CompletedPart part; - - part.SetPartNumber(uploadState_.partNumber); - part.SetETag(result.GetETag()); - // Don't add the checksum to the part if the checksum is empty. - // Some filesystems such as IBM COS require this to be not set. - if (!result.GetChecksumCRC32().empty()) { - part.SetChecksumCRC32(result.GetChecksumCRC32()); - } - uploadState_.completedParts.push_back(std::move(part)); - } - } - - Aws::S3::S3Client* client_; - memory::MemoryPool* pool_; - std::unique_ptr> currentPart_; - std::string bucket_; - std::string key_; - size_t fileSize_ = -1; -}; - -S3WriteFile::S3WriteFile( - std::string_view path, - Aws::S3::S3Client* client, - memory::MemoryPool* pool) { - impl_ = std::make_shared(path, client, pool); -} - -void S3WriteFile::append(std::string_view data) { - return impl_->append(data); -} - -void S3WriteFile::flush() { - impl_->flush(); -} - -void S3WriteFile::close() { - impl_->close(); -} - -uint64_t S3WriteFile::size() const { - return impl_->size(); -} - -int S3WriteFile::numPartsUploaded() const { - return impl_->numPartsUploaded(); -} - // Initialize and Finalize the AWS SDK C++ library. // Initialization must be done before creating a S3FileSystem. // Finalization must be done after all S3FileSystem instances have been deleted. @@ -628,6 +232,12 @@ class S3FileSystem::Impl { Impl(const S3Config& s3Config) { VELOX_CHECK(getAwsInstance()->isInitialized(), "S3 is not initialized"); Aws::S3::S3ClientConfiguration clientConfig; + // Required for AWS CLI object operations on OCI and MinIO due to checksum + // handling. + clientConfig.checksumConfig.requestChecksumCalculation = + Aws::Client::RequestChecksumCalculation::WHEN_REQUIRED; + clientConfig.checksumConfig.responseChecksumValidation = + Aws::Client::ResponseChecksumValidation::WHEN_REQUIRED; if (s3Config.endpoint().has_value()) { clientConfig.endpointOverride = s3Config.endpoint().value(); } @@ -879,4 +489,92 @@ std::string S3FileSystem::name() const { return "S3"; } +std::vector S3FileSystem::list(std::string_view path) { + std::string bucket; + std::string key; + getBucketAndKeyFromPath(getPath(path), bucket, key); + + Aws::S3::Model::ListObjectsRequest request; + request.SetBucket(awsString(bucket)); + request.SetPrefix(awsString(key)); + + auto outcome = impl_->s3Client()->ListObjects(request); + VELOX_CHECK_AWS_OUTCOME( + outcome, "Failed to list objects in S3 bucket", bucket, key); + + std::vector objectKeys; + const auto& result = outcome.GetResult(); + for (const auto& object : result.GetContents()) { + objectKeys.emplace_back(object.GetKey()); + } + + return objectKeys; +} + +bool S3FileSystem::exists(std::string_view path) { + std::string bucket; + std::string key; + getBucketAndKeyFromPath(getPath(path), bucket, key); + + Aws::S3::Model::HeadObjectRequest request; + request.SetBucket(awsString(bucket)); + request.SetKey(awsString(key)); + + return impl_->s3Client()->HeadObject(request).IsSuccess(); +} + +void S3FileSystem::mkdir( + std::string_view path, + const DirectoryOptions& options) { + std::string bucket; + std::string key; + getBucketAndKeyFromPath(getPath(path), bucket, key); + + Aws::S3::Model::PutObjectRequest request; + request.SetBucket(awsString(bucket)); + request.SetKey(awsString(key)); + + VELOX_CHECK_AWS_OUTCOME( + impl_->s3Client()->PutObject(request), + "Failed to mkdir objects in S3 bucket", + bucket, + key); +} + +void S3FileSystem::rename( + std::string_view path, + std::string_view newPath, + bool overWrite) { + std::string sourceBucket; + std::string sourceKey; + getBucketAndKeyFromPath(getPath(path), sourceBucket, sourceKey); + + std::string targetBucket; + std::string targetKey; + getBucketAndKeyFromPath(getPath(newPath), targetBucket, targetKey); + + // Copies the object to the new location. + Aws::S3::Model::CopyObjectRequest copyRequest; + copyRequest.SetCopySource(awsString(sourceBucket + "/" + sourceKey)); + copyRequest.SetBucket(awsString(targetBucket)); + copyRequest.SetKey(awsString(targetKey)); + + VELOX_CHECK_AWS_OUTCOME( + impl_->s3Client()->CopyObject(copyRequest), + "Failed to copy object in S3 during rename", + sourceBucket, + sourceKey); + + // Deletes the original object. + Aws::S3::Model::DeleteObjectRequest deleteRequest; + deleteRequest.SetBucket(awsString(sourceBucket)); + deleteRequest.SetKey(awsString(sourceKey)); + + VELOX_CHECK_AWS_OUTCOME( + impl_->s3Client()->DeleteObject(deleteRequest), + "Failed to delete original object in S3 during rename", + sourceBucket, + sourceKey); +} + } // namespace facebook::velox::filesystems diff --git a/velox/connectors/hive/storage_adapters/s3fs/S3FileSystem.h b/velox/connectors/hive/storage_adapters/s3fs/S3FileSystem.h index 8ca6634b5b3a..f121223f0f3b 100644 --- a/velox/connectors/hive/storage_adapters/s3fs/S3FileSystem.h +++ b/velox/connectors/hive/storage_adapters/s3fs/S3FileSystem.h @@ -64,25 +64,20 @@ class S3FileSystem : public FileSystem { VELOX_UNSUPPORTED("remove for S3 not implemented"); } + // Renames the path. void rename( std::string_view path, std::string_view newPath, - bool overWrite = false) override { - VELOX_UNSUPPORTED("rename for S3 not implemented"); - } + bool overWrite = false) override; - bool exists(std::string_view path) override { - VELOX_UNSUPPORTED("exists for S3 not implemented"); - } + /// Checks that the path exists. + bool exists(std::string_view path) override; - std::vector list(std::string_view path) override { - VELOX_UNSUPPORTED("list for S3 not implemented"); - } + /// List the objects associated to a path. + std::vector list(std::string_view path) override; void mkdir(std::string_view path, const DirectoryOptions& options = {}) - override { - VELOX_UNSUPPORTED("mkdir for S3 not implemented"); - } + override; void rmdir(std::string_view path) override { VELOX_UNSUPPORTED("rmdir for S3 not implemented"); diff --git a/velox/connectors/hive/storage_adapters/s3fs/S3ReadFile.cpp b/velox/connectors/hive/storage_adapters/s3fs/S3ReadFile.cpp new file mode 100644 index 000000000000..38d66318f3e2 --- /dev/null +++ b/velox/connectors/hive/storage_adapters/s3fs/S3ReadFile.cpp @@ -0,0 +1,222 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/storage_adapters/s3fs/S3ReadFile.h" +#include "velox/common/base/StatsReporter.h" +#include "velox/connectors/hive/storage_adapters/s3fs/S3Counters.h" +#include "velox/connectors/hive/storage_adapters/s3fs/S3Util.h" + +#include +#include +#include +#include + +namespace facebook::velox::filesystems { + +namespace { + +// By default, the AWS SDK reads object data into an auto-growing StringStream. +// To avoid copies, read directly into a pre-allocated buffer instead. +// See https://github.com/aws/aws-sdk-cpp/issues/64 for an alternative but +// functionally similar recipe. +Aws::IOStreamFactory AwsWriteableStreamFactory(void* data, int64_t nbytes) { + return [=]() { return Aws::New("", data, nbytes); }; +} + +} // namespace + +class S3ReadFile ::Impl { + public: + explicit Impl(std::string_view path, Aws::S3::S3Client* client) + : client_(client) { + getBucketAndKeyFromPath(path, bucket_, key_); + } + + // Gets the length of the file. + // Checks if there are any issues reading the file. + void initialize(const filesystems::FileOptions& options) { + if (options.fileSize.has_value()) { + VELOX_CHECK_GE( + options.fileSize.value(), 0, "File size must be non-negative"); + length_ = options.fileSize.value(); + } + + // Make it a no-op if invoked twice. + if (length_ != -1) { + return; + } + + Aws::S3::Model::HeadObjectRequest request; + request.SetBucket(awsString(bucket_)); + request.SetKey(awsString(key_)); + + RECORD_METRIC_VALUE(kMetricS3MetadataCalls); + auto outcome = client_->HeadObject(request); + if (!outcome.IsSuccess()) { + RECORD_METRIC_VALUE(kMetricS3GetMetadataErrors); + } + RECORD_METRIC_VALUE(kMetricS3GetMetadataRetries, outcome.GetRetryCount()); + VELOX_CHECK_AWS_OUTCOME( + outcome, "Failed to get metadata for S3 object", bucket_, key_); + length_ = outcome.GetResult().GetContentLength(); + VELOX_CHECK_GE(length_, 0); + } + + std::string_view pread( + uint64_t offset, + uint64_t length, + void* buffer, + const FileStorageContext& fileStorageContext) const { + preadInternal(offset, length, static_cast(buffer)); + return {static_cast(buffer), length}; + } + + std::string pread( + uint64_t offset, + uint64_t length, + const FileStorageContext& fileStorageContext) const { + std::string result(length, 0); + char* position = result.data(); + preadInternal(offset, length, position); + return result; + } + + uint64_t preadv( + uint64_t offset, + const std::vector>& buffers, + const FileStorageContext& fileStorageContext) const { + // 'buffers' contains Ranges(data, size) with some gaps (data = nullptr) in + // between. This call must populate the ranges (except gap ranges) + // sequentially starting from 'offset'. AWS S3 GetObject does not support + // multi-range. AWS S3 also charges by number of read requests and not size. + // The idea here is to use a single read spanning all the ranges and then + // populate individual ranges. We pre-allocate a buffer to support this. + size_t length = 0; + for (const auto range : buffers) { + length += range.size(); + } + // TODO: allocate from a memory pool + std::string result(length, 0); + preadInternal(offset, length, static_cast(result.data())); + size_t resultOffset = 0; + for (auto range : buffers) { + if (range.data()) { + memcpy(range.data(), &(result.data()[resultOffset]), range.size()); + } + resultOffset += range.size(); + } + return length; + } + + uint64_t size() const { + return length_; + } + + uint64_t memoryUsage() const { + // TODO: Check if any buffers are being used by the S3 library + return sizeof(Aws::S3::S3Client) + kS3MaxKeySize + 2 * sizeof(std::string) + + sizeof(int64_t); + } + + bool shouldCoalesce() const { + return false; + } + + std::string getName() const { + return fmt::format("s3://{}/{}", bucket_, key_); + } + + private: + // The assumption here is that "position" has space for at least "length" + // bytes. + void preadInternal(uint64_t offset, uint64_t length, char* position) const { + // Read the desired range of bytes. + Aws::S3::Model::GetObjectRequest request; + Aws::S3::Model::GetObjectResult result; + + request.SetBucket(awsString(bucket_)); + request.SetKey(awsString(key_)); + std::stringstream ss; + ss << "bytes=" << offset << "-" << offset + length - 1; + request.SetRange(awsString(ss.str())); + request.SetResponseStreamFactory( + AwsWriteableStreamFactory(position, length)); + RECORD_METRIC_VALUE(kMetricS3ActiveConnections); + RECORD_METRIC_VALUE(kMetricS3GetObjectCalls); + auto outcome = client_->GetObject(request); + if (!outcome.IsSuccess()) { + RECORD_METRIC_VALUE(kMetricS3GetObjectErrors); + } + RECORD_METRIC_VALUE(kMetricS3GetObjectRetries, outcome.GetRetryCount()); + RECORD_METRIC_VALUE(kMetricS3ActiveConnections, -1); + VELOX_CHECK_AWS_OUTCOME(outcome, "Failed to get S3 object", bucket_, key_); + } + + Aws::S3::S3Client* client_; + std::string bucket_; + std::string key_; + int64_t length_ = -1; +}; + +S3ReadFile::S3ReadFile(std::string_view path, Aws::S3::S3Client* client) { + impl_ = std::make_shared(path, client); +} + +S3ReadFile::~S3ReadFile() = default; + +void S3ReadFile::initialize(const filesystems::FileOptions& options) { + return impl_->initialize(options); +} + +std::string_view S3ReadFile::pread( + uint64_t offset, + uint64_t length, + void* buf, + const FileStorageContext& fileStorageContext) const { + return impl_->pread(offset, length, buf, fileStorageContext); +} + +std::string S3ReadFile::pread( + uint64_t offset, + uint64_t length, + const FileStorageContext& fileStorageContext) const { + return impl_->pread(offset, length, fileStorageContext); +} + +uint64_t S3ReadFile::preadv( + uint64_t offset, + const std::vector>& buffers, + const FileStorageContext& fileStorageContext) const { + return impl_->preadv(offset, buffers, fileStorageContext); +} + +uint64_t S3ReadFile::size() const { + return impl_->size(); +} + +uint64_t S3ReadFile::memoryUsage() const { + return impl_->memoryUsage(); +} + +bool S3ReadFile::shouldCoalesce() const { + return impl_->shouldCoalesce(); +} + +std::string S3ReadFile::getName() const { + return impl_->getName(); +} + +} // namespace facebook::velox::filesystems diff --git a/velox/connectors/hive/storage_adapters/s3fs/S3ReadFile.h b/velox/connectors/hive/storage_adapters/s3fs/S3ReadFile.h new file mode 100644 index 000000000000..de7eb63f5ada --- /dev/null +++ b/velox/connectors/hive/storage_adapters/s3fs/S3ReadFile.h @@ -0,0 +1,71 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "velox/common/file/File.h" + +namespace Aws::S3 { +class S3Client; +} + +namespace facebook::velox::filesystems { + +/// Implementation of s3 read file. +class S3ReadFile : public ReadFile { + public: + S3ReadFile(std::string_view path, Aws::S3::S3Client* client); + + ~S3ReadFile() override; + + std::string_view pread( + uint64_t offset, + uint64_t length, + void* buf, + const FileStorageContext& fileStorageContext = {}) const final; + + std::string pread( + uint64_t offset, + uint64_t length, + const FileStorageContext& fileStorageContext = {}) const final; + + uint64_t preadv( + uint64_t offset, + const std::vector>& buffers, + const FileStorageContext& fileStorageContext = {}) const final; + + uint64_t size() const final; + + uint64_t memoryUsage() const final; + + bool shouldCoalesce() const final; + + std::string getName() const final; + + uint64_t getNaturalReadSize() const final { + return 72 << 20; + } + + void initialize(const filesystems::FileOptions& options); + + private: + void preadInternal(uint64_t offset, uint64_t length, char* position) const; + + class Impl; + std::shared_ptr impl_; +}; + +} // namespace facebook::velox::filesystems diff --git a/velox/connectors/hive/storage_adapters/s3fs/S3Util.h b/velox/connectors/hive/storage_adapters/s3fs/S3Util.h index 966c6bfe30a7..125d9cc805b5 100644 --- a/velox/connectors/hive/storage_adapters/s3fs/S3Util.h +++ b/velox/connectors/hive/storage_adapters/s3fs/S3Util.h @@ -28,6 +28,8 @@ #include "velox/common/base/Exceptions.h" +#include + namespace facebook::velox::filesystems { namespace { @@ -202,7 +204,7 @@ std::optional parseAWSStandardRegionName( class S3ProxyConfigurationBuilder { public: S3ProxyConfigurationBuilder(const std::string& s3Endpoint) - : s3Endpoint_(s3Endpoint){}; + : s3Endpoint_(s3Endpoint) {} S3ProxyConfigurationBuilder& useSsl(const bool& useSsl) { useSsl_ = useSsl; @@ -216,11 +218,26 @@ class S3ProxyConfigurationBuilder { bool useSsl_; }; +// Reference: https://issues.apache.org/jira/browse/ARROW-8692 +// https://github.com/apache/arrow/blob/master/cpp/src/arrow/filesystem/s3fs.cc#L843 +// A non-copying iostream. See +// https://stackoverflow.com/questions/35322033/aws-c-sdk-uploadpart-times-out +// https://stackoverflow.com/questions/13059091/creating-an-input-stream-from-constant-memory +class StringViewStream : Aws::Utils::Stream::PreallocatedStreamBuf, + public std::iostream { + public: + StringViewStream(const void* data, int64_t nbytes) + : Aws::Utils::Stream::PreallocatedStreamBuf( + reinterpret_cast(const_cast(data)), + static_cast(nbytes)), + std::iostream(this) {} +}; + } // namespace facebook::velox::filesystems template <> struct fmt::formatter : formatter { - auto format(Aws::Http::HttpResponseCode s, format_context& ctx) { + auto format(Aws::Http::HttpResponseCode s, format_context& ctx) const { return formatter::format(static_cast(s), ctx); } }; diff --git a/velox/connectors/hive/storage_adapters/s3fs/S3WriteFile.cpp b/velox/connectors/hive/storage_adapters/s3fs/S3WriteFile.cpp new file mode 100644 index 000000000000..2284ad6b8ef7 --- /dev/null +++ b/velox/connectors/hive/storage_adapters/s3fs/S3WriteFile.cpp @@ -0,0 +1,264 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/storage_adapters/s3fs/S3WriteFile.h" +#include "velox/common/base/StatsReporter.h" +#include "velox/connectors/hive/storage_adapters/s3fs/S3Counters.h" +#include "velox/connectors/hive/storage_adapters/s3fs/S3Util.h" +#include "velox/dwio/common/DataBuffer.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace facebook::velox::filesystems { + +class S3WriteFile::Impl { + public: + explicit Impl( + std::string_view path, + Aws::S3::S3Client* client, + memory::MemoryPool* pool) + : client_(client), pool_(pool) { + VELOX_CHECK_NOT_NULL(client); + VELOX_CHECK_NOT_NULL(pool); + getBucketAndKeyFromPath(path, bucket_, key_); + currentPart_ = std::make_unique>(*pool_); + currentPart_->reserve(kPartUploadSize); + // Check that the object doesn't exist, if it does throw an error. + { + Aws::S3::Model::HeadObjectRequest request; + request.SetBucket(awsString(bucket_)); + request.SetKey(awsString(key_)); + RECORD_METRIC_VALUE(kMetricS3MetadataCalls); + auto objectMetadata = client_->HeadObject(request); + if (!objectMetadata.IsSuccess()) { + RECORD_METRIC_VALUE(kMetricS3GetMetadataErrors); + } + RECORD_METRIC_VALUE( + kMetricS3GetObjectRetries, objectMetadata.GetRetryCount()); + VELOX_CHECK( + !objectMetadata.IsSuccess(), + "S3 object already exists: bucket={}, key={}", + bucket_, + key_); + } + + // Create bucket if not present. + { + Aws::S3::Model::HeadBucketRequest request; + request.SetBucket(awsString(bucket_)); + auto bucketMetadata = client_->HeadBucket(request); + if (!bucketMetadata.IsSuccess()) { + Aws::S3::Model::CreateBucketRequest request; + request.SetBucket(bucket_); + auto outcome = client_->CreateBucket(request); + VELOX_CHECK_AWS_OUTCOME( + outcome, "Failed to create S3 bucket", bucket_, ""); + } + } + + // Initiate the multi-part upload. + { + Aws::S3::Model::CreateMultipartUploadRequest request; + request.SetBucket(awsString(bucket_)); + request.SetKey(awsString(key_)); + + /// If we do not set anything then the SDK will default to application/xml + /// which confuses some tools + /// (https://github.com/apache/arrow/issues/11934). So we instead default + /// to application/octet-stream which is less misleading. + request.SetContentType(kApplicationOctetStream); + auto outcome = client_->CreateMultipartUpload(request); + VELOX_CHECK_AWS_OUTCOME( + outcome, "Failed initiating multiple part upload", bucket_, key_); + uploadState_.id = outcome.GetResult().GetUploadId(); + } + + fileSize_ = 0; + } + + // Appends data to the end of the file. + void append(std::string_view data) { + VELOX_CHECK(!closed(), "File is closed"); + if (data.size() + currentPart_->size() >= kPartUploadSize) { + upload(data); + } else { + // Append to current part. + currentPart_->unsafeAppend(data.data(), data.size()); + } + fileSize_ += data.size(); + } + + // No-op. + void flush() { + VELOX_CHECK(!closed(), "File is closed"); + /// currentPartSize must be less than kPartUploadSize since + /// append() would have already flushed after reaching kUploadPartSize. + VELOX_CHECK_LT(currentPart_->size(), kPartUploadSize); + } + + // Complete the multipart upload and close the file. + void close() { + if (closed()) { + return; + } + RECORD_METRIC_VALUE(kMetricS3StartedUploads); + uploadPart({currentPart_->data(), currentPart_->size()}, true); + VELOX_CHECK_EQ(uploadState_.partNumber, uploadState_.completedParts.size()); + // Complete the multipart upload. + { + Aws::S3::Model::CompletedMultipartUpload completedUpload; + completedUpload.SetParts(uploadState_.completedParts); + Aws::S3::Model::CompleteMultipartUploadRequest request; + request.SetBucket(awsString(bucket_)); + request.SetKey(awsString(key_)); + request.SetUploadId(uploadState_.id); + request.SetMultipartUpload(std::move(completedUpload)); + + auto outcome = client_->CompleteMultipartUpload(request); + if (outcome.IsSuccess()) { + RECORD_METRIC_VALUE(kMetricS3SuccessfulUploads); + } else { + RECORD_METRIC_VALUE(kMetricS3FailedUploads); + } + VELOX_CHECK_AWS_OUTCOME( + outcome, "Failed to complete multiple part upload", bucket_, key_); + } + currentPart_->clear(); + } + + // Current file size, i.e. the sum of all previous appends. + uint64_t size() const { + return fileSize_; + } + + int numPartsUploaded() const { + return uploadState_.partNumber; + } + + private: + static constexpr int64_t kPartUploadSize = 10 * 1024 * 1024; + static constexpr const char* kApplicationOctetStream = + "application/octet-stream"; + + bool closed() const { + return (currentPart_->capacity() == 0); + } + + // Holds state for the multipart upload. + struct UploadState { + Aws::Vector completedParts; + int64_t partNumber = 0; + Aws::String id; + }; + UploadState uploadState_; + + // Data can be smaller or larger than the kPartUploadSize. + // Complete the currentPart_ and upload kPartUploadSize chunks of data. + // Save the remaining into currentPart_. + void upload(const std::string_view data) { + auto dataPtr = data.data(); + auto dataSize = data.size(); + // Fill-up the remaining currentPart_. + auto remainingBufferSize = currentPart_->capacity() - currentPart_->size(); + currentPart_->unsafeAppend(dataPtr, remainingBufferSize); + uploadPart({currentPart_->data(), currentPart_->size()}); + dataPtr += remainingBufferSize; + dataSize -= remainingBufferSize; + while (dataSize > kPartUploadSize) { + uploadPart({dataPtr, kPartUploadSize}); + dataPtr += kPartUploadSize; + dataSize -= kPartUploadSize; + } + // Stash the remaining at the beginning of currentPart. + currentPart_->unsafeAppend(0, dataPtr, dataSize); + } + + void uploadPart(const std::string_view part, bool isLast = false) { + // Only the last part can be less than kPartUploadSize. + VELOX_CHECK(isLast || (!isLast && (part.size() == kPartUploadSize))); + // Upload the part. + { + Aws::S3::Model::UploadPartRequest request; + request.SetBucket(bucket_); + request.SetKey(key_); + request.SetUploadId(uploadState_.id); + request.SetPartNumber(++uploadState_.partNumber); + request.SetContentLength(part.size()); + request.SetBody( + std::make_shared(part.data(), part.size())); + auto outcome = client_->UploadPart(request); + VELOX_CHECK_AWS_OUTCOME(outcome, "Failed to upload", bucket_, key_); + // Append ETag and part number for this uploaded part. + // This will be needed for upload completion in Close(). + auto result = outcome.GetResult(); + Aws::S3::Model::CompletedPart part; + + part.SetPartNumber(uploadState_.partNumber); + part.SetETag(result.GetETag()); + // Don't add the checksum to the part if the checksum is empty. + // Some filesystems such as IBM COS require this to be not set. + if (!result.GetChecksumCRC32().empty()) { + part.SetChecksumCRC32(result.GetChecksumCRC32()); + } + uploadState_.completedParts.push_back(std::move(part)); + } + } + + Aws::S3::S3Client* client_; + memory::MemoryPool* pool_; + std::unique_ptr> currentPart_; + std::string bucket_; + std::string key_; + size_t fileSize_ = -1; +}; + +S3WriteFile::S3WriteFile( + std::string_view path, + Aws::S3::S3Client* client, + memory::MemoryPool* pool) { + impl_ = std::make_shared(path, client, pool); +} + +void S3WriteFile::append(std::string_view data) { + return impl_->append(data); +} + +void S3WriteFile::flush() { + impl_->flush(); +} + +void S3WriteFile::close() { + impl_->close(); +} + +uint64_t S3WriteFile::size() const { + return impl_->size(); +} + +int S3WriteFile::numPartsUploaded() const { + return impl_->numPartsUploaded(); +} + +} // namespace facebook::velox::filesystems diff --git a/velox/connectors/hive/storage_adapters/s3fs/tests/CMakeLists.txt b/velox/connectors/hive/storage_adapters/s3fs/tests/CMakeLists.txt index f2bc59aa20f9..9d92727e7673 100644 --- a/velox/connectors/hive/storage_adapters/s3fs/tests/CMakeLists.txt +++ b/velox/connectors/hive/storage_adapters/s3fs/tests/CMakeLists.txt @@ -19,7 +19,8 @@ target_link_libraries( velox_common_config velox_s3fs GTest::gtest - GTest::gtest_main) + GTest::gtest_main +) add_executable(velox_s3file_test S3FileSystemTest.cpp S3UtilTest.cpp) add_test(velox_s3file_test velox_s3file_test) @@ -32,7 +33,8 @@ target_link_libraries( velox_dwio_common_exception velox_exec GTest::gtest - GTest::gtest_main) + GTest::gtest_main +) add_executable(velox_s3registration_test S3FileSystemRegistrationTest.cpp) add_test(velox_s3registration_test velox_s3registration_test) @@ -46,7 +48,8 @@ target_link_libraries( velox_dwio_common_exception velox_exec GTest::gtest - GTest::gtest_main) + GTest::gtest_main +) add_executable(velox_s3finalize_test S3FileSystemFinalizeTest.cpp) add_test(velox_s3finalize_test velox_s3finalize_test) @@ -56,7 +59,8 @@ target_link_libraries( velox_file velox_core GTest::gtest - GTest::gtest_main) + GTest::gtest_main +) add_executable(velox_s3insert_test S3InsertTest.cpp) add_test(velox_s3insert_test velox_s3insert_test) @@ -71,13 +75,15 @@ target_link_libraries( velox_dwio_common_exception velox_exec GTest::gtest - GTest::gtest_main) + GTest::gtest_main +) add_executable(velox_s3read_test S3ReadTest.cpp) add_test( NAME velox_s3read_test COMMAND velox_s3read_test - WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} +) target_link_libraries( velox_s3read_test velox_file @@ -88,7 +94,8 @@ target_link_libraries( velox_dwio_common_exception velox_exec GTest::gtest - GTest::gtest_main) + GTest::gtest_main +) add_executable(velox_s3metrics_test S3FileSystemMetricsTest.cpp) add_test(velox_s3metrics_test velox_s3metrics_test) @@ -97,7 +104,8 @@ target_link_libraries( velox_s3fs velox_exec_test_lib GTest::gtest - GTest::gtest_main) + GTest::gtest_main +) add_executable(velox_s3multiendpoints_test S3MultipleEndpointsTest.cpp) add_test(velox_s3multiendpoints_test velox_s3multiendpoints_test) @@ -112,4 +120,5 @@ target_link_libraries( velox_dwio_common_exception velox_exec GTest::gtest - GTest::gtest_main) + GTest::gtest_main +) diff --git a/velox/connectors/hive/storage_adapters/s3fs/tests/S3FileSystemMetricsTest.cpp b/velox/connectors/hive/storage_adapters/s3fs/tests/S3FileSystemMetricsTest.cpp index ca941a5feb3b..907acf927838 100644 --- a/velox/connectors/hive/storage_adapters/s3fs/tests/S3FileSystemMetricsTest.cpp +++ b/velox/connectors/hive/storage_adapters/s3fs/tests/S3FileSystemMetricsTest.cpp @@ -13,7 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + #include +#include #include "velox/common/memory/Memory.h" #include "velox/connectors/hive/storage_adapters/s3fs/RegisterS3FileSystem.h" @@ -23,8 +25,9 @@ #include -namespace facebook::velox::filesystems { +namespace facebook::velox::filesystems::test { namespace { + class S3TestReporter : public BaseStatsReporter { public: mutable std::mutex m; @@ -39,6 +42,7 @@ class S3TestReporter : public BaseStatsReporter { statTypeMap.clear(); histogramPercentilesMap.clear(); } + void registerMetricExportType(const char* key, StatType statType) const override { statTypeMap[key] = statType; @@ -168,7 +172,7 @@ TEST_F(S3FileSystemMetricsTest, metrics) { EXPECT_EQ(1, s3Reporter->counterMap[std::string{kMetricS3GetObjectCalls}]); } -} // namespace facebook::velox::filesystems +} // namespace facebook::velox::filesystems::test int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); diff --git a/velox/connectors/hive/storage_adapters/s3fs/tests/S3FileSystemRegistrationTest.cpp b/velox/connectors/hive/storage_adapters/s3fs/tests/S3FileSystemRegistrationTest.cpp index c0d106c44187..256fe6e481e5 100644 --- a/velox/connectors/hive/storage_adapters/s3fs/tests/S3FileSystemRegistrationTest.cpp +++ b/velox/connectors/hive/storage_adapters/s3fs/tests/S3FileSystemRegistrationTest.cpp @@ -69,9 +69,10 @@ TEST_F(S3FileSystemRegistrationTest, fileHandle) { } auto hiveConfig = minioServer_->hiveConfig(); FileHandleFactory factory( - std::make_unique>(1000), + std::make_unique>(1000), std::make_unique(hiveConfig)); - auto fileHandleCachePtr = factory.generate(s3File); + FileHandleKey key{s3File}; + auto fileHandleCachePtr = factory.generate(key); readData(fileHandleCachePtr->file.get()); } diff --git a/velox/connectors/hive/storage_adapters/s3fs/tests/S3FileSystemTest.cpp b/velox/connectors/hive/storage_adapters/s3fs/tests/S3FileSystemTest.cpp index 2fdc115609e1..a94e9b4c3523 100644 --- a/velox/connectors/hive/storage_adapters/s3fs/tests/S3FileSystemTest.cpp +++ b/velox/connectors/hive/storage_adapters/s3fs/tests/S3FileSystemTest.cpp @@ -175,7 +175,7 @@ TEST_F(S3FileSystemTest, noBackendServer) { minioServer_->stop(); VELOX_ASSERT_THROW( s3fs.openFileForRead(kDummyPath), - "Failed to get metadata for S3 object due to: 'Network connection'. Path:'s3://dummy/foo.txt', SDK Error Type:99, HTTP Status Code:-1, S3 Service:'Unknown', Message:'curlCode: 7, Couldn't connect to server'"); + "Failed to get metadata for S3 object due to: 'Network connection'. Path:'s3://dummy/foo.txt', SDK Error Type:99, HTTP Status Code:-1, S3 Service:'Unknown', Message:'curlCode: 7, Couldn't connect to server"); // Start Minio again. minioServer_->start(); } @@ -219,6 +219,27 @@ TEST_F(S3FileSystemTest, logLocation) { checkLogPrefix(expected); } +TEST_F(S3FileSystemTest, mkdirAndRename) { + const auto bucketName = "mkdir"; + const auto file = "mkdir-test.txt"; + const auto s3File = s3URI(bucketName, file); + addBucket(bucketName); + + auto hiveConfig = minioServer_->hiveConfig(); + filesystems::S3FileSystem s3fs(bucketName, hiveConfig); + + ASSERT_FALSE(s3fs.exists(s3File)); + s3fs.mkdir(s3File); + ASSERT_TRUE(s3fs.exists(s3File)); + + // Rename test + const auto renameFile = "rename-test.txt"; + const auto s3RenameFile = s3URI(bucketName, renameFile); + s3fs.rename(s3File, s3RenameFile); + ASSERT_TRUE(s3fs.exists(s3RenameFile)); + ASSERT_FALSE(s3fs.exists(s3File)); +} + TEST_F(S3FileSystemTest, writeFileAndRead) { const auto bucketName = "writedata"; const auto file = "test.txt"; @@ -290,6 +311,14 @@ TEST_F(S3FileSystemTest, writeFileAndRead) { } // Verify the last chunk. ASSERT_EQ(readFile->pread(contentSize * 250'000, contentSize), dataContent); + + // Verify the S3 list function. + auto result = s3fs.list(s3File); + + ASSERT_EQ(result.size(), 1); + ASSERT_TRUE(result[0] == file); + + ASSERT_TRUE(s3fs.exists(s3File)); } TEST_F(S3FileSystemTest, invalidConnectionSettings) { @@ -340,4 +369,5 @@ TEST_F(S3FileSystemTest, registerCredentialProviderFactories) { }), "CredentialsProviderFactory 'my-credentials-provider' already registered"); } + } // namespace facebook::velox::filesystems diff --git a/velox/connectors/hive/storage_adapters/s3fs/tests/S3MultipleEndpointsTest.cpp b/velox/connectors/hive/storage_adapters/s3fs/tests/S3MultipleEndpointsTest.cpp index cb6e8e783473..cf446b12d2e2 100644 --- a/velox/connectors/hive/storage_adapters/s3fs/tests/S3MultipleEndpointsTest.cpp +++ b/velox/connectors/hive/storage_adapters/s3fs/tests/S3MultipleEndpointsTest.cpp @@ -17,6 +17,7 @@ #include #include "gtest/gtest.h" +#include "velox/connectors/hive/HiveConnector.h" #include "velox/connectors/hive/storage_adapters/s3fs/RegisterS3FileSystem.h" #include "velox/connectors/hive/storage_adapters/s3fs/S3Util.h" #include "velox/connectors/hive/storage_adapters/s3fs/tests/S3Test.h" @@ -52,8 +53,6 @@ class S3MultipleEndpoints : public S3Test, public ::test::VectorTestBase { minioSecondServer_->addBucket(kBucketName.data()); filesystems::registerS3FileSystem(); - connector::registerConnectorFactory( - std::make_shared()); parquet::registerParquetReaderFactory(); parquet::registerParquetWriterFactory(); } @@ -63,20 +62,15 @@ class S3MultipleEndpoints : public S3Test, public ::test::VectorTestBase { std::string_view connectorId2, const std::unordered_map config1Override = {}, const std::unordered_map config2Override = {}) { - auto hiveConnector1 = - connector::getConnectorFactory( - connector::hive::HiveConnectorFactory::kHiveConnectorName) - ->newConnector( - std::string(connectorId1), - minioServer_->hiveConfig(config1Override), - ioExecutor_.get()); - auto hiveConnector2 = - connector::getConnectorFactory( - connector::hive::HiveConnectorFactory::kHiveConnectorName) - ->newConnector( - std::string(connectorId2), - minioSecondServer_->hiveConfig(config2Override), - ioExecutor_.get()); + connector::hive::HiveConnectorFactory factory; + auto hiveConnector1 = factory.newConnector( + std::string(connectorId1), + minioServer_->hiveConfig(config1Override), + ioExecutor_.get()); + auto hiveConnector2 = factory.newConnector( + std::string(connectorId2), + minioSecondServer_->hiveConfig(config2Override), + ioExecutor_.get()); connector::registerConnector(hiveConnector1); connector::registerConnector(hiveConnector2); } @@ -84,8 +78,6 @@ class S3MultipleEndpoints : public S3Test, public ::test::VectorTestBase { void TearDown() override { parquet::unregisterParquetReaderFactory(); parquet::unregisterParquetWriterFactory(); - connector::unregisterConnectorFactory( - connector::hive::HiveConnectorFactory::kHiveConnectorName); S3Test::TearDown(); } diff --git a/velox/connectors/hive/storage_adapters/s3fs/tests/S3ReadTest.cpp b/velox/connectors/hive/storage_adapters/s3fs/tests/S3ReadTest.cpp index 7abcbfb7e56b..bbd084837514 100644 --- a/velox/connectors/hive/storage_adapters/s3fs/tests/S3ReadTest.cpp +++ b/velox/connectors/hive/storage_adapters/s3fs/tests/S3ReadTest.cpp @@ -18,6 +18,7 @@ #include #include "velox/common/memory/Memory.h" +#include "velox/connectors/hive/HiveConnector.h" #include "velox/connectors/hive/storage_adapters/s3fs/RegisterS3FileSystem.h" #include "velox/connectors/hive/storage_adapters/s3fs/tests/S3Test.h" #include "velox/dwio/common/tests/utils/DataFiles.h" @@ -39,12 +40,9 @@ class S3ReadTest : public S3Test, public ::test::VectorTestBase { void SetUp() override { S3Test::SetUp(); filesystems::registerS3FileSystem(); - connector::registerConnectorFactory( - std::make_shared()); + connector::hive::HiveConnectorFactory factory; auto hiveConnector = - connector::getConnectorFactory( - connector::hive::HiveConnectorFactory::kHiveConnectorName) - ->newConnector(kHiveConnectorId, minioServer_->hiveConfig()); + factory.newConnector(kHiveConnectorId, minioServer_->hiveConfig()); connector::registerConnector(hiveConnector); parquet::registerParquetReaderFactory(); } @@ -52,8 +50,6 @@ class S3ReadTest : public S3Test, public ::test::VectorTestBase { void TearDown() override { parquet::unregisterParquetReaderFactory(); filesystems::finalizeS3FileSystem(); - connector::unregisterConnectorFactory( - connector::hive::HiveConnectorFactory::kHiveConnectorName); connector::unregisterConnector(kHiveConnectorId); S3Test::TearDown(); } diff --git a/velox/connectors/hive/storage_adapters/test_common/InsertTest.h b/velox/connectors/hive/storage_adapters/test_common/InsertTest.h index 0bef0f09b6ae..0700ca7a334e 100644 --- a/velox/connectors/hive/storage_adapters/test_common/InsertTest.h +++ b/velox/connectors/hive/storage_adapters/test_common/InsertTest.h @@ -18,6 +18,7 @@ #include #include "velox/common/memory/Memory.h" +#include "velox/connectors/hive/HiveConnector.h" #include "velox/dwio/parquet/RegisterParquetReader.h" #include "velox/dwio/parquet/RegisterParquetWriter.h" #include "velox/exec/TableWriter.h" @@ -33,13 +34,9 @@ class InsertTest : public velox::test::VectorTestBase { void SetUp( std::shared_ptr hiveConfig, folly::Executor* ioExecutor) { - connector::registerConnectorFactory( - std::make_shared()); - auto hiveConnector = - connector::getConnectorFactory( - connector::hive::HiveConnectorFactory::kHiveConnectorName) - ->newConnector( - exec::test::kHiveConnectorId, hiveConfig, ioExecutor); + connector::hive::HiveConnectorFactory factory; + auto hiveConnector = factory.newConnector( + exec::test::kHiveConnectorId, hiveConfig, ioExecutor); connector::registerConnector(hiveConnector); parquet::registerParquetReaderFactory(); @@ -49,8 +46,7 @@ class InsertTest : public velox::test::VectorTestBase { void TearDown() { parquet::unregisterParquetReaderFactory(); parquet::unregisterParquetWriterFactory(); - connector::unregisterConnectorFactory( - connector::hive::HiveConnectorFactory::kHiveConnectorName); + connector::unregisterConnector(exec::test::kHiveConnectorId); } diff --git a/velox/connectors/hive/tests/CMakeLists.txt b/velox/connectors/hive/tests/CMakeLists.txt index 0ff796e9077e..2662fe6824fc 100644 --- a/velox/connectors/hive/tests/CMakeLists.txt +++ b/velox/connectors/hive/tests/CMakeLists.txt @@ -20,10 +20,11 @@ add_executable( HiveConnectorUtilTest.cpp HiveConnectorSerDeTest.cpp HivePartitionFunctionTest.cpp - HivePartitionUtilTest.cpp + HivePartitionNameTest.cpp HiveSplitTest.cpp PartitionIdGeneratorTest.cpp - TableHandleTest.cpp) + TableHandleTest.cpp +) add_test(velox_hive_connector_test velox_hive_connector_test) target_link_libraries( @@ -36,13 +37,14 @@ target_link_libraries( velox_exec velox_exec_test_lib GTest::gtest - GTest::gtest_main) + GTest::gtest_main +) if(VELOX_ENABLE_PARQUET) - - target_include_directories(velox_hive_connector_test - PUBLIC ${ARROW_PREFIX}/install/include) - target_link_libraries(velox_hive_connector_test velox_dwio_parquet_writer - velox_dwio_parquet_reader) - + target_include_directories(velox_hive_connector_test PUBLIC ${ARROW_PREFIX}/install/include) + target_link_libraries( + velox_hive_connector_test + velox_dwio_parquet_writer + velox_dwio_parquet_reader + ) endif() diff --git a/velox/connectors/hive/tests/FileHandleTest.cpp b/velox/connectors/hive/tests/FileHandleTest.cpp index fcbea2c0fe6e..6c3e71c42fe2 100644 --- a/velox/connectors/hive/tests/FileHandleTest.cpp +++ b/velox/connectors/hive/tests/FileHandleTest.cpp @@ -37,9 +37,10 @@ TEST(FileHandleTest, localFile) { } FileHandleFactory factory( - std::make_unique>(1000), + std::make_unique>(1000), std::make_unique()); - auto fileHandle = factory.generate(filename); + FileHandleKey key{filename}; + auto fileHandle = factory.generate(key); ASSERT_EQ(fileHandle->file->size(), 3); char buffer[3]; ASSERT_EQ(fileHandle->file->pread(0, 3, &buffer), "foo"); @@ -61,12 +62,13 @@ TEST(FileHandleTest, localFileWithProperties) { } FileHandleFactory factory( - std::make_unique>(1000), + std::make_unique>(1000), std::make_unique()); FileProperties properties = { .fileSize = tempFile->fileSize(), .modificationTime = tempFile->fileModifiedTime()}; - auto fileHandle = factory.generate(filename, &properties); + FileHandleKey key{filename}; + auto fileHandle = factory.generate(key, &properties); ASSERT_EQ(fileHandle->file->size(), 3); char buffer[3]; ASSERT_EQ(fileHandle->file->pread(0, 3, &buffer), "foo"); diff --git a/velox/connectors/hive/tests/HiveConfigTest.cpp b/velox/connectors/hive/tests/HiveConfigTest.cpp index ae76870f492f..bdba80eefb0b 100644 --- a/velox/connectors/hive/tests/HiveConfigTest.cpp +++ b/velox/connectors/hive/tests/HiveConfigTest.cpp @@ -23,8 +23,9 @@ using namespace facebook::velox::connector::hive; using facebook::velox::connector::hive::HiveConfig; TEST(HiveConfigTest, defaultConfig) { - HiveConfig hiveConfig(std::make_shared( - std::unordered_map())); + HiveConfig hiveConfig( + std::make_shared( + std::unordered_map())); const auto emptySession = std::make_unique( std::unordered_map()); ASSERT_EQ( @@ -53,6 +54,7 @@ TEST(HiveConfigTest, defaultConfig) { ASSERT_TRUE(hiveConfig.isPartitionPathAsLowerCase(emptySession.get())); ASSERT_TRUE(hiveConfig.allowNullPartitionKeys(emptySession.get())); ASSERT_EQ(hiveConfig.loadQuantum(emptySession.get()), 8 << 20); + ASSERT_FALSE(hiveConfig.preserveFlatMapsInMemory(emptySession.get())); } TEST(HiveConfigTest, overrideConfig) { @@ -74,7 +76,9 @@ TEST(HiveConfigTest, overrideConfig) { {HiveConfig::kSortWriterMaxOutputBytes, "100MB"}, {HiveConfig::kSortWriterFinishTimeSliceLimitMs, "400"}, {HiveConfig::kReadStatsBasedFilterReorderDisabled, "true"}, - {HiveConfig::kLoadQuantum, std::to_string(4 << 20)}}; + {HiveConfig::kLoadQuantum, std::to_string(4 << 20)}, + {HiveConfig::kMaxBucketCount, std::to_string(100'000)}, + {HiveConfig::kPreserveFlatMapsInMemory, "true"}}; HiveConfig hiveConfig( std::make_shared(std::move(configFromFile))); auto emptySession = std::make_shared( @@ -104,11 +108,14 @@ TEST(HiveConfigTest, overrideConfig) { ASSERT_TRUE( hiveConfig.readStatsBasedFilterReorderDisabled(emptySession.get())); ASSERT_EQ(hiveConfig.loadQuantum(emptySession.get()), 4 << 20); + ASSERT_EQ(hiveConfig.maxBucketCount(emptySession.get()), 100'000); + ASSERT_TRUE(hiveConfig.preserveFlatMapsInMemory(emptySession.get())); } TEST(HiveConfigTest, overrideSession) { - HiveConfig hiveConfig(std::make_shared( - std::unordered_map())); + HiveConfig hiveConfig( + std::make_shared( + std::unordered_map())); std::unordered_map sessionOverride = { {HiveConfig::kInsertExistingPartitionsBehaviorSession, "OVERWRITE"}, {HiveConfig::kOrcUseColumnNamesSession, "true"}, @@ -121,7 +128,9 @@ TEST(HiveConfigTest, overrideSession) { {HiveConfig::kAllowNullPartitionKeysSession, "false"}, {HiveConfig::kIgnoreMissingFilesSession, "true"}, {HiveConfig::kReadStatsBasedFilterReorderDisabledSession, "true"}, - {HiveConfig::kLoadQuantumSession, std::to_string(4 << 20)}}; + {HiveConfig::kLoadQuantumSession, std::to_string(4 << 20)}, + {HiveConfig::kPreserveFlatMapsInMemorySession, "true"}, + }; const auto session = std::make_unique(std::move(sessionOverride)); ASSERT_EQ( @@ -147,4 +156,5 @@ TEST(HiveConfigTest, overrideSession) { ASSERT_TRUE(hiveConfig.ignoreMissingFiles(session.get())); ASSERT_TRUE(hiveConfig.readStatsBasedFilterReorderDisabled(session.get())); ASSERT_EQ(hiveConfig.loadQuantum(session.get()), 4 << 20); + ASSERT_TRUE(hiveConfig.preserveFlatMapsInMemory(session.get())); } diff --git a/velox/connectors/hive/tests/HiveConnectorSerDeTest.cpp b/velox/connectors/hive/tests/HiveConnectorSerDeTest.cpp index 4c5ce7132946..ebe71e7b4b04 100644 --- a/velox/connectors/hive/tests/HiveConnectorSerDeTest.cpp +++ b/velox/connectors/hive/tests/HiveConnectorSerDeTest.cpp @@ -15,10 +15,10 @@ */ #include -#include "velox/connectors/Connector.h" #include "velox/connectors/hive/HiveConnector.h" #include "velox/exec/tests/utils/HiveConnectorTestBase.h" #include "velox/expression/ExprToSubfieldFilter.h" +#include "velox/type/tests/SubfieldFiltersBuilder.h" namespace facebook::velox::connector::hive::test { namespace { @@ -31,14 +31,7 @@ class HiveConnectorSerDeTest : public exec::test::HiveConnectorTestBase { Type::registerSerDe(); common::Filter::registerSerDe(); core::ITypedExpr::registerSerDe(); - HiveTableHandle::registerSerDe(); - HiveColumnHandle::registerSerDe(); - LocationHandle::registerSerDe(); - HiveInsertTableHandle::registerSerDe(); - HiveBucketProperty::registerSerDe(); - HiveSortingColumn::registerSerDe(); - HiveConnectorSplit::registerSerDe(); - HiveInsertFileNameGenerator::registerSerDe(); + HiveConnector::registerSerDe(); } template @@ -161,6 +154,7 @@ TEST_F(HiveConnectorSerDeTest, hiveColumnHandle) { HiveColumnHandle::ColumnType::kRegular, HiveColumnHandle::ColumnType::kSynthesized, HiveColumnHandle::ColumnType::kRowIndex, + HiveColumnHandle::ColumnType::kRowId, }; for (auto columnHandleType : columnHandleTypes) { diff --git a/velox/connectors/hive/tests/HiveConnectorTest.cpp b/velox/connectors/hive/tests/HiveConnectorTest.cpp index 4834cf3b8335..1a6934077907 100644 --- a/velox/connectors/hive/tests/HiveConnectorTest.cpp +++ b/velox/connectors/hive/tests/HiveConnectorTest.cpp @@ -21,6 +21,7 @@ #include "velox/connectors/hive/HiveConfig.h" #include "velox/connectors/hive/HiveConnectorUtil.h" #include "velox/connectors/hive/HiveDataSource.h" +#include "velox/expression/ExprConstants.h" #include "velox/expression/ExprToSubfieldFilter.h" namespace facebook::velox::connector::hive { @@ -64,7 +65,7 @@ groupSubfields(const std::vector& subfields) { } bool mapKeyIsNotNull(const ScanSpec& mapSpec) { - return dynamic_cast( + return dynamic_cast( mapSpec.childByName(ScanSpec::kMapKeysFieldName)->filter()); } @@ -92,8 +93,10 @@ TEST_F(HiveConnectorTest, makeScanSpecRequiredSubfieldsMultilevel) { auto rowType = ROW({{"c0", columnType}}); auto subfields = makeSubfields({"c0.c0c1[3][\"foo\"].c0c1c0"}); for (bool statsBasedFilterReorderDisabled : {false, true}) { - SCOPED_TRACE(fmt::format( - "statsBasedFilterReorderDisabled {}", statsBasedFilterReorderDisabled)); + SCOPED_TRACE( + fmt::format( + "statsBasedFilterReorderDisabled {}", + statsBasedFilterReorderDisabled)); auto scanSpec = makeScanSpec( rowType, @@ -581,8 +584,8 @@ TEST_F(HiveConnectorTest, extractFiltersFromRemainingFilter) { auto expr = parseExpr("not (c0 > 0 or c1 > 0)", rowType); SubfieldFilters filters; double sampleRate = 1; - auto remaining = extractFiltersFromRemainingFilter( - expr, &evaluator, false, filters, sampleRate); + auto remaining = + extractFiltersFromRemainingFilter(expr, &evaluator, filters, sampleRate); ASSERT_FALSE(remaining); ASSERT_EQ(sampleRate, 1); ASSERT_EQ(filters.size(), 2); @@ -591,8 +594,8 @@ TEST_F(HiveConnectorTest, extractFiltersFromRemainingFilter) { expr = parseExpr("not (c0 > 0 or c1 > c0)", rowType); filters.clear(); - remaining = extractFiltersFromRemainingFilter( - expr, &evaluator, false, filters, sampleRate); + remaining = + extractFiltersFromRemainingFilter(expr, &evaluator, filters, sampleRate); ASSERT_EQ(sampleRate, 1); ASSERT_EQ(filters.size(), 1); ASSERT_GT(filters.count(Subfield("c0")), 0); @@ -602,14 +605,49 @@ TEST_F(HiveConnectorTest, extractFiltersFromRemainingFilter) { expr = parseExpr( "not (c2 > 1::decimal(20, 0) or c2 < 0::decimal(20, 0))", rowType); filters.clear(); - remaining = extractFiltersFromRemainingFilter( - expr, &evaluator, false, filters, sampleRate); + remaining = + extractFiltersFromRemainingFilter(expr, &evaluator, filters, sampleRate); ASSERT_EQ(sampleRate, 1); ASSERT_GT(filters.count(Subfield("c2")), 0); // Change these once HUGEINT filter merge is fixed. ASSERT_TRUE(remaining); ASSERT_EQ( - remaining->toString(), "not(lt(ROW[\"c2\"],cast 0 as DECIMAL(20, 0)))"); + remaining->toString(), "not(lt(ROW[\"c2\"],cast(0 as DECIMAL(20, 0))))"); + + // parseExpr gives AND/OR with 2 arguments. We need to construct the node + // manually to have more than 2. + expr = std::make_shared( + BOOLEAN(), + expression::kAnd, + parseExpr("c0 > 0", rowType), + parseExpr("c1 > 0", rowType), + parseExpr("c2 > 0::decimal(20, 0)", rowType)); + filters.clear(); + remaining = + extractFiltersFromRemainingFilter(expr, &evaluator, filters, sampleRate); + ASSERT_EQ(sampleRate, 1); + ASSERT_EQ(filters.size(), 3); + ASSERT_TRUE(filters.contains(Subfield("c0"))); + ASSERT_TRUE(filters.contains(Subfield("c1"))); + ASSERT_TRUE(filters.contains(Subfield("c2"))); + ASSERT_FALSE(remaining); + + expr = std::make_shared( + BOOLEAN(), + expression::kAnd, + parseExpr("c0 % 2 = 0", rowType), + parseExpr("c1 % 3 = 0", rowType), + parseExpr("c2 > 0::decimal(20, 0)", rowType)); + filters.clear(); + remaining = + extractFiltersFromRemainingFilter(expr, &evaluator, filters, sampleRate); + ASSERT_EQ(sampleRate, 1); + ASSERT_EQ(filters.size(), 1); + ASSERT_TRUE(filters.contains(Subfield("c2"))); + ASSERT_TRUE(remaining); + ASSERT_EQ( + remaining->toString(), + "and(eq(mod(ROW[\"c0\"],2),0),eq(mod(ROW[\"c1\"],3),0))"); } TEST_F(HiveConnectorTest, prestoTableSampling) { @@ -620,8 +658,8 @@ TEST_F(HiveConnectorTest, prestoTableSampling) { auto expr = parseExpr("rand() < 0.5", rowType); SubfieldFilters filters; double sampleRate = 1; - auto remaining = extractFiltersFromRemainingFilter( - expr, &evaluator, false, filters, sampleRate); + auto remaining = + extractFiltersFromRemainingFilter(expr, &evaluator, filters, sampleRate); ASSERT_FALSE(remaining); ASSERT_EQ(sampleRate, 0.5); ASSERT_TRUE(filters.empty()); @@ -629,8 +667,8 @@ TEST_F(HiveConnectorTest, prestoTableSampling) { expr = parseExpr("c0 > 0 and rand() < 0.5", rowType); filters.clear(); sampleRate = 1; - remaining = extractFiltersFromRemainingFilter( - expr, &evaluator, false, filters, sampleRate); + remaining = + extractFiltersFromRemainingFilter(expr, &evaluator, filters, sampleRate); ASSERT_FALSE(remaining); ASSERT_EQ(sampleRate, 0.5); ASSERT_EQ(filters.size(), 1); @@ -639,8 +677,8 @@ TEST_F(HiveConnectorTest, prestoTableSampling) { expr = parseExpr("rand() < 0.5 and rand() < 0.5", rowType); filters.clear(); sampleRate = 1; - remaining = extractFiltersFromRemainingFilter( - expr, &evaluator, false, filters, sampleRate); + remaining = + extractFiltersFromRemainingFilter(expr, &evaluator, filters, sampleRate); ASSERT_FALSE(remaining); ASSERT_EQ(sampleRate, 0.25); ASSERT_TRUE(filters.empty()); @@ -648,13 +686,73 @@ TEST_F(HiveConnectorTest, prestoTableSampling) { expr = parseExpr("c0 > 0 or rand() < 0.5", rowType); filters.clear(); sampleRate = 1; - remaining = extractFiltersFromRemainingFilter( - expr, &evaluator, false, filters, sampleRate); + remaining = + extractFiltersFromRemainingFilter(expr, &evaluator, filters, sampleRate); ASSERT_TRUE(remaining); ASSERT_EQ(*remaining, *expr); ASSERT_EQ(sampleRate, 1); ASSERT_TRUE(filters.empty()); } +#define VELOX_ASSERT_FILTER(expected, actual) \ + ASSERT_TRUE(expected->testingEquals(*actual)) \ + << expected->toString() << " vs " << actual->toString(); + +TEST_F(HiveConnectorTest, disjuncts) { + auto queryCtx = core::QueryCtx::create(); + exec::SimpleExpressionEvaluator evaluator(queryCtx.get(), pool_.get()); + auto rowType = ROW({"c0", "c1", "c2"}, {BIGINT(), BIGINT(), DECIMAL(20, 0)}); + + { + auto expr = + parseExpr("(c0 > 0 and c0 < 10) or (c0 > 5 and c0 < 15)", rowType); + + SubfieldFilters filters; + double sampleRate = 1; + auto remaining = extractFiltersFromRemainingFilter( + expr, &evaluator, filters, sampleRate); + ASSERT_TRUE(remaining == nullptr); + ASSERT_EQ(sampleRate, 1); + ASSERT_EQ(filters.size(), 1); + VELOX_ASSERT_FILTER(exec::between(1, 14), filters.begin()->second); + } + + { + auto expr = parseExpr("(c0 between -10 and 12)", rowType); + + SubfieldFilters filters; + double sampleRate = 1; + auto remaining = extractFiltersFromRemainingFilter( + expr, &evaluator, filters, sampleRate); + ASSERT_TRUE(remaining == nullptr); + ASSERT_EQ(sampleRate, 1); + ASSERT_EQ(filters.size(), 1); + ASSERT_EQ(filters.begin()->first, Subfield("c0")); + VELOX_ASSERT_FILTER(exec::between(-10, 12), filters.begin()->second); + + expr = parseExpr("(c0 > 0 and c0 < 10) or (c0 > 5 and c0 < 15)", rowType); + remaining = extractFiltersFromRemainingFilter( + expr, &evaluator, filters, sampleRate); + ASSERT_TRUE(remaining == nullptr); + ASSERT_EQ(sampleRate, 1); + ASSERT_EQ(filters.size(), 1); + ASSERT_EQ(filters.begin()->first, Subfield("c0")); + VELOX_ASSERT_FILTER(exec::between(1, 12), filters.begin()->second); + } + + { + auto expr = parseExpr("c0 not in (1, 3) or c0 in (1, 2)", rowType); + SubfieldFilters filters; + double sampleRate = 1; + auto remaining = extractFiltersFromRemainingFilter( + expr, &evaluator, filters, sampleRate); + ASSERT_EQ(remaining, expr); + ASSERT_EQ(sampleRate, 1); + ASSERT_TRUE(filters.empty()); + } +} + +#undef VELOX_ASSERT_FILTER + } // namespace } // namespace facebook::velox::connector::hive diff --git a/velox/connectors/hive/tests/HiveConnectorUtilTest.cpp b/velox/connectors/hive/tests/HiveConnectorUtilTest.cpp index fd4971bdff12..2c6dd76f4549 100644 --- a/velox/connectors/hive/tests/HiveConnectorUtilTest.cpp +++ b/velox/connectors/hive/tests/HiveConnectorUtilTest.cpp @@ -336,18 +336,6 @@ TEST_F(HiveConnectorUtilTest, cacheRetention) { } } -TEST_F(HiveConnectorUtilTest, configureRowReaderOptions) { - auto split = - std::make_shared("", "", FileFormat::UNKNOWN); - auto rowType = ROW({{"float_features", MAP(INTEGER(), REAL())}}); - auto spec = std::make_shared(""); - spec->addAllChildFields(*rowType); - auto* float_features = spec->childByName("float_features"); - float_features->childByName(common::ScanSpec::kMapKeysFieldName) - ->setFilter(common::createBigintValues({1, 3}, false)); - float_features->setFlatMapFeatureSelection({"1", "3"}); -} - TEST_F(HiveConnectorUtilTest, configureSstRowReaderOptions) { dwio::common::RowReaderOptions rowReaderOpts; auto hiveSplit = @@ -364,9 +352,118 @@ TEST_F(HiveConnectorUtilTest, configureSstRowReaderOptions) { /*hiveSplit=*/hiveSplit, /*hiveConfig=*/nullptr, /*sessionProperties=*/nullptr, + /*ioExecutor=*/nullptr, /*rowReaderOptions=*/rowReaderOpts); EXPECT_EQ(rowReaderOpts.serdeParameters(), hiveSplit->serdeParameters); } +TEST_F(HiveConnectorUtilTest, configureRowReaderOptionsFromConfig) { + // Test default behavior (preserveFlatMapsInMemory = false) + { + auto hiveConfig = + std::make_shared(std::make_shared( + std::unordered_map())); + config::ConfigBase sessionProperties({}); + + dwio::common::RowReaderOptions rowReaderOpts; + auto hiveSplit = + std::make_shared("", "", FileFormat::DWRF); + + configureRowReaderOptions( + /*tableParameters=*/{}, + /*scanSpec=*/nullptr, + /*metadataFilter=*/nullptr, + /*rowType=*/nullptr, + /*hiveSplit=*/hiveSplit, + /*hiveConfig=*/hiveConfig, + /*sessionProperties=*/&sessionProperties, + /*ioExecutor=*/nullptr, + /*rowReaderOptions=*/rowReaderOpts); + + EXPECT_FALSE(rowReaderOpts.preserveFlatMapsInMemory()); + } + + // Test with config override (preserveFlatMapsInMemory = true) + { + std::unordered_map configProps = { + {hive::HiveConfig::kPreserveFlatMapsInMemory, "true"}}; + auto hiveConfig = std::make_shared( + std::make_shared(std::move(configProps))); + config::ConfigBase sessionProperties({}); + + dwio::common::RowReaderOptions rowReaderOpts; + auto hiveSplit = + std::make_shared("", "", FileFormat::DWRF); + + configureRowReaderOptions( + /*tableParameters=*/{}, + /*scanSpec=*/nullptr, + /*metadataFilter=*/nullptr, + /*rowType=*/nullptr, + /*hiveSplit=*/hiveSplit, + /*hiveConfig=*/hiveConfig, + /*sessionProperties=*/&sessionProperties, + /*ioExecutor=*/nullptr, + /*rowReaderOptions=*/rowReaderOpts); + + EXPECT_TRUE(rowReaderOpts.preserveFlatMapsInMemory()); + } + + // Test with session override (preserveFlatMapsInMemory = true) + { + auto hiveConfig = + std::make_shared(std::make_shared( + std::unordered_map())); + std::unordered_map sessionProps = { + {hive::HiveConfig::kPreserveFlatMapsInMemorySession, "true"}}; + config::ConfigBase sessionProperties(std::move(sessionProps)); + + dwio::common::RowReaderOptions rowReaderOpts; + auto hiveSplit = + std::make_shared("", "", FileFormat::DWRF); + + configureRowReaderOptions( + /*tableParameters=*/{}, + /*scanSpec=*/nullptr, + /*metadataFilter=*/nullptr, + /*rowType=*/nullptr, + /*hiveSplit=*/hiveSplit, + /*hiveConfig=*/hiveConfig, + /*sessionProperties=*/&sessionProperties, + /*ioExecutor=*/nullptr, + /*rowReaderOptions=*/rowReaderOpts); + + EXPECT_TRUE(rowReaderOpts.preserveFlatMapsInMemory()); + } + + // Test session override takes precedence over config + { + std::unordered_map configProps = { + {hive::HiveConfig::kPreserveFlatMapsInMemory, "false"}}; + auto hiveConfig = std::make_shared( + std::make_shared(std::move(configProps))); + std::unordered_map sessionProps = { + {hive::HiveConfig::kPreserveFlatMapsInMemorySession, "true"}}; + config::ConfigBase sessionProperties(std::move(sessionProps)); + + dwio::common::RowReaderOptions rowReaderOpts; + auto hiveSplit = + std::make_shared("", "", FileFormat::DWRF); + + configureRowReaderOptions( + /*tableParameters=*/{}, + /*scanSpec=*/nullptr, + /*metadataFilter=*/nullptr, + /*rowType=*/nullptr, + /*hiveSplit=*/hiveSplit, + /*hiveConfig=*/hiveConfig, + /*sessionProperties=*/&sessionProperties, + /*ioExecutor=*/nullptr, + /*rowReaderOptions=*/rowReaderOpts); + + EXPECT_TRUE(rowReaderOpts.preserveFlatMapsInMemory()); + } +} + } // namespace facebook::velox::connector diff --git a/velox/connectors/hive/tests/HiveDataSinkTest.cpp b/velox/connectors/hive/tests/HiveDataSinkTest.cpp index 537db9e064cb..893cec45c7fa 100644 --- a/velox/connectors/hive/tests/HiveDataSinkTest.cpp +++ b/velox/connectors/hive/tests/HiveDataSinkTest.cpp @@ -15,13 +15,16 @@ */ #include +#include "velox/common/caching/AsyncDataCache.h" #include "velox/exec/tests/utils/HiveConnectorTestBase.h" #include +#include #include #include "velox/common/base/Fs.h" #include "velox/common/base/tests/GTestUtils.h" #include "velox/common/testutil/TestValue.h" +#include "velox/connectors/hive/HiveConnector.h" #include "velox/dwio/common/BufferedInput.h" #include "velox/dwio/common/Options.h" #include "velox/dwio/dwrf/reader/DwrfReader.h" @@ -71,7 +74,7 @@ class HiveDataSinkTest : public exec::test::HiveConnectorTestBase { setupMemoryPools(); spillExecutor_ = std::make_unique( - std::thread::hardware_concurrency()); + folly::hardware_concurrency()); } void TearDown() override { @@ -111,7 +114,8 @@ class HiveDataSinkTest : public exec::test::HiveConnectorTestBase { 0, 0, writerFlushThreshold, - "none"); + "none", + 0); } void setupMemoryPools() { @@ -1314,6 +1318,108 @@ TEST_F(HiveDataSinkTest, ensureFilesUnsupported) { ), "ensureFiles is not supported with bucketing"); } + +TEST_F(HiveDataSinkTest, raceWithCacheEviction) { + /// This test ensures that LRU cache staleness and StringIdMap cache + /// eviction do not cause issues with file reads. + std::atomic stop{false}; + auto cacheCleaner = std::async(std::launch::async, [&] { + auto cache = cache::AsyncDataCache::getInstance(); + auto hiveConnector = std::dynamic_pointer_cast( + getConnector(exec::test::kHiveConnectorId)); + while (!stop) { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + cache->clear(); + hiveConnector->clearFileHandleCache(); + } + }); + + const auto outputDirectory = TempDirectoryPath::create(); + auto dataSink = createDataSink(rowType_, outputDirectory->getPath()); + const auto vectors = createVectors(500 /*vectorSize*/, 10 /*numVectors*/); + for (const auto& vector : vectors) { + dataSink->appendData(vector); + } + ASSERT_TRUE(dataSink->finish()); + ASSERT_FALSE(dataSink->close().empty()); + + createDuckDbTable(vectors); + verifyWrittenData(outputDirectory->getPath()); + + stop = true; + cacheCleaner.get(); +} + +#ifdef VELOX_ENABLE_PARQUET +TEST_F(HiveDataSinkTest, lazyVectorForParquet) { + // This test ensures that lazy vector is handled correctly in HiveDataSink. + VectorFuzzer::Options options{.vectorSize = 100}; + VectorFuzzer fuzzer(options, pool()); + + auto lazyVector = fuzzer.wrapInLazyVector(fuzzer.fuzzFlat(BIGINT(), 100)); + auto lazyMapVector = fuzzer.wrapInLazyVector(fuzzer.fuzzMap( + fuzzer.fuzzFlat(BIGINT(), 100), fuzzer.fuzzFlat(VARCHAR(), 100), 100)); + + auto rowType = ROW({"c0", "c1"}, {BIGINT(), MAP(BIGINT(), VARCHAR())}); + std::vector children; + children.emplace_back(lazyVector); + children.emplace_back(lazyMapVector); + auto row = std::make_shared( + pool(), rowType, nullptr, 100, std::move(children)); + + const auto outputDirectory = TempDirectoryPath::create(); + auto dataSink = createDataSink( + rowType, outputDirectory->getPath(), dwio::common::FileFormat::PARQUET); + + dataSink->appendData(row); + ASSERT_TRUE(dataSink->finish()); + dataSink->close(); +} +#endif + +// Test to verify that each writer has its own nonReclaimableSection +// pointer when writerOptions is shared. +TEST_F(HiveDataSinkTest, sharedWriterOptionsWithMultipleWriters) { + const auto outputDirectory = TempDirectoryPath::create(); + + const int32_t numBuckets = 3; + auto bucketProperty = std::make_shared( + HiveBucketProperty::Kind::kHiveCompatible, + numBuckets, + std::vector{"c0"}, + std::vector{BIGINT()}, + std::vector>{}); + + // Create shared writer options (this simulates the scenario where + // insertTableHandle_->writerOptions() returns a shared object) + auto sharedWriterOptions = std::make_shared(); + + // Create a data sink with multiple writers (one for each bucket) + auto dataSink = createDataSink( + rowType_, + outputDirectory->getPath(), + dwio::common::FileFormat::DWRF, + {}, + bucketProperty, + sharedWriterOptions); + + const auto vectors = createVectors(200, 3); + + // Write data - this should work without throwing exceptions + for (const auto& vector : vectors) { + dataSink->appendData(vector); + } + + while (!dataSink->finish()) { + } + const auto partitions = dataSink->close(); + + ASSERT_GT(partitions.size(), 1); + createDuckDbTable(vectors); + verifyWrittenData( + outputDirectory->getPath(), static_cast(partitions.size())); +} + } // namespace } // namespace facebook::velox::connector::hive diff --git a/velox/connectors/hive/tests/HivePartitionFunctionTest.cpp b/velox/connectors/hive/tests/HivePartitionFunctionTest.cpp index 4fdf68684b04..1adfb3874458 100644 --- a/velox/connectors/hive/tests/HivePartitionFunctionTest.cpp +++ b/velox/connectors/hive/tests/HivePartitionFunctionTest.cpp @@ -832,10 +832,11 @@ TEST_F(HivePartitionFunctionTest, skew) { } std::vector partitionedInputs; for (int partition = 0; partition < numRemotePartitions; ++partition) { - partitionedInputs.push_back(exec::wrap( - partitionSizeVectors[partition], - partitionIndicesVector[partition], - input)); + partitionedInputs.push_back( + exec::wrap( + partitionSizeVectors[partition], + partitionIndicesVector[partition], + input)); } // Checks that the bad hive partition function (using round-robin map from diff --git a/velox/connectors/hive/tests/HivePartitionUtilTest.cpp b/velox/connectors/hive/tests/HivePartitionNameTest.cpp similarity index 58% rename from velox/connectors/hive/tests/HivePartitionUtilTest.cpp rename to velox/connectors/hive/tests/HivePartitionNameTest.cpp index 7942233b615a..b236951946f4 100644 --- a/velox/connectors/hive/tests/HivePartitionUtilTest.cpp +++ b/velox/connectors/hive/tests/HivePartitionNameTest.cpp @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "velox/connectors/hive/HivePartitionUtil.h" +#include "velox/connectors/hive/HivePartitionName.h" #include "velox/common/base/tests/GTestUtils.h" #include "velox/dwio/catalog/fbhive/FileUtils.h" #include "velox/vector/tests/utils/VectorTestBase.h" @@ -26,9 +26,13 @@ using namespace facebook::velox; using namespace facebook::velox::connector::hive; using namespace facebook::velox::dwio::catalog::fbhive; -class HivePartitionUtilTest : public ::testing::Test, +class HivePartitionNameTest : public ::testing::Test, public velox::test::VectorTestBase { protected: + static void SetUpTestCase() { + memory::MemoryManager::testingSetInstance(memory::MemoryManager::Options{}); + } + template VectorPtr makeDictionary(const std::vector& data) { auto base = makeFlatVector(data); @@ -58,9 +62,26 @@ class HivePartitionUtilTest : public ::testing::Test, input->size(), partitions); } + + static auto toPartitionName() { + return [](auto value, const TypePtr& type, int /*columnIndex*/) { + return HivePartitionName::toName(value, type); + }; + } + + std::vector> extractPartitionKeyValues( + RowVectorPtr input, + const std::vector& partitionChannels, + vector_size_t rowIndex = 0) { + return HivePartitionName::partitionKeyValues( + rowIndex, + makePartitionsVector(input, partitionChannels), + /*nullValueString=*/"", + toPartitionName()); + } }; -TEST_F(HivePartitionUtilTest, partitionName) { +TEST_F(HivePartitionNameTest, partitionName) { { RowVectorPtr input = makeRowVector( {"flat_bool_col", @@ -69,14 +90,17 @@ TEST_F(HivePartitionUtilTest, partitionName) { "flat_int_col", "flat_bigint_col", "dict_string_col", - "const_date_col"}, + "const_date_col", + "flat_timestamp_col"}, {makeFlatVector(std::vector{false}), makeFlatVector(std::vector{10}), makeFlatVector(std::vector{100}), makeFlatVector(std::vector{1000}), makeFlatVector(std::vector{10000}), makeDictionary(std::vector{"str1000"}), - makeConstant(10000, 1, DATE())}); + makeConstant(10000, 1, DATE()), + makeFlatVector( + std::vector{Timestamp::fromMillis(1577836800000)})}); std::vector expectedPartitionKeyValues{ "flat_bool_col=false", @@ -85,7 +109,8 @@ TEST_F(HivePartitionUtilTest, partitionName) { "flat_int_col=1000", "flat_bigint_col=10000", "dict_string_col=str1000", - "const_date_col=1997-05-19"}; + "const_date_col=1997-05-19", + "flat_timestamp_col=2019-12-31 16%3A00%3A00.0"}; std::vector partitionChannels; for (auto i = 1; i <= expectedPartitionKeyValues.size(); i++) { @@ -94,9 +119,7 @@ TEST_F(HivePartitionUtilTest, partitionName) { EXPECT_EQ( FileUtils::makePartName( - extractPartitionKeyValues( - makePartitionsVector(input, partitionChannels), 0), - true), + extractPartitionKeyValues(input, partitionChannels), true), folly::join( "/", std::vector( @@ -116,14 +139,12 @@ TEST_F(HivePartitionUtilTest, partitionName) { VELOX_ASSERT_THROW( FileUtils::makePartName( - extractPartitionKeyValues( - makePartitionsVector(input, partitionChannels), 0), - true), + extractPartitionKeyValues(input, partitionChannels), true), "Unsupported partition type: MAP"); } } -TEST_F(HivePartitionUtilTest, partitionNameForNull) { +TEST_F(HivePartitionNameTest, partitionNameForNull) { std::vector partitionColumnNames{ "flat_bool_col", "flat_tinyint_col", @@ -131,7 +152,8 @@ TEST_F(HivePartitionUtilTest, partitionNameForNull) { "flat_int_col", "flat_bigint_col", "flat_string_col", - "const_date_col"}; + "const_date_col", + "flat_timestamp_col"}; RowVectorPtr input = makeRowVector( partitionColumnNames, @@ -141,14 +163,56 @@ TEST_F(HivePartitionUtilTest, partitionNameForNull) { makeNullableFlatVector({std::nullopt}), makeNullableFlatVector({std::nullopt}), makeNullableFlatVector({std::nullopt}), - makeConstant(std::nullopt, 1, DATE())}); + makeConstant(std::nullopt, 1, DATE()), + makeNullableFlatVector({std::nullopt})}); for (auto i = 0; i < partitionColumnNames.size(); i++) { std::vector partitionChannels = {(column_index_t)i}; - auto partitionEntries = extractPartitionKeyValues( - makePartitionsVector(input, partitionChannels), 0); + auto partitionEntries = extractPartitionKeyValues(input, partitionChannels); EXPECT_EQ(1, partitionEntries.size()); EXPECT_EQ(partitionColumnNames[i], partitionEntries[0].first); EXPECT_EQ("", partitionEntries[0].second); } } + +TEST_F(HivePartitionNameTest, timestampPartitionValueFormatting) { + // Test timestamp partition value formatting to match Presto's + // java.sql.Timestamp.toString() behavior: removes trailing zeros but keeps at + // least one decimal place + std::vector timestamps = { + // Test case 1: All zeros in fractional seconds -> should become ".0" + Timestamp( + 0, 0), // 1970-01-01 00:00:00.000 UTC -> 1969-12-31 16:00:00.0 PST + + // Test case 2: Trailing zeros should be removed + Timestamp(0, 980000000), // 1970-01-01 00:00:00.980 UTC -> 1969-12-31 + // 16:00:00.98 PST + + // Test case 3: No trailing zeros, should remain unchanged + Timestamp(0, 123000000), // 1970-01-01 00:00:00.123 UTC -> 1969-12-31 + // 16:00:00.123 PST + }; + + // Expected values account for timezone conversion to PST (UTC-8) + std::vector expectedValues = { + "1969-12-31 16:00:00.0", // .000 -> .0 (converted to PST) + "1969-12-31 16:00:00.98", // .980 -> .98 (converted to PST) + "1969-12-31 16:00:00.123", // .123 -> .123 (converted to PST) + }; + + RowVectorPtr input = + makeRowVector({"timestamp_col"}, {makeFlatVector(timestamps)}); + + std::vector partitionChannels{0}; + + for (size_t i = 0; i < timestamps.size(); i++) { + auto partitionEntries = + extractPartitionKeyValues(input, partitionChannels, i); + + EXPECT_EQ(1, partitionEntries.size()); + EXPECT_EQ("timestamp_col", partitionEntries[0].first); + EXPECT_EQ(expectedValues[i], partitionEntries[0].second) + << "Failed for timestamp index " << i << " with value " + << timestamps[i].toString(); + } +} diff --git a/velox/connectors/hive/tests/PartitionIdGeneratorTest.cpp b/velox/connectors/hive/tests/PartitionIdGeneratorTest.cpp index e91782666878..7dcb0d5e1959 100644 --- a/velox/connectors/hive/tests/PartitionIdGeneratorTest.cpp +++ b/velox/connectors/hive/tests/PartitionIdGeneratorTest.cpp @@ -16,6 +16,8 @@ #include "velox/connectors/hive/PartitionIdGenerator.h" #include "velox/common/base/tests/GTestUtils.h" +#include "velox/connectors/hive/HivePartitionName.h" +#include "velox/type/TimestampConversion.h" #include "velox/vector/tests/utils/VectorTestBase.h" #include "gtest/gtest.h" @@ -33,7 +35,7 @@ class PartitionIdGeneratorTest : public ::testing::Test, TEST_F(PartitionIdGeneratorTest, consecutiveIdsSingleKey) { auto numPartitions = 100; - PartitionIdGenerator idGenerator(ROW({VARCHAR()}), {0}, 100, pool(), true); + PartitionIdGenerator idGenerator(ROW({VARCHAR()}), {0}, 100, pool()); auto input = makeRowVector( {makeFlatVector(numPartitions * 3, [&](auto row) { @@ -55,7 +57,7 @@ TEST_F(PartitionIdGeneratorTest, consecutiveIdsSingleKey) { TEST_F(PartitionIdGeneratorTest, consecutiveIdsMultipleKeys) { PartitionIdGenerator idGenerator( - ROW({VARCHAR(), INTEGER()}), {0, 1}, 100, pool(), true); + ROW({VARCHAR(), INTEGER()}), {0, 1}, 100, pool()); auto input = makeRowVector({ makeFlatVector( @@ -82,7 +84,7 @@ TEST_F(PartitionIdGeneratorTest, consecutiveIdsMultipleKeys) { TEST_F(PartitionIdGeneratorTest, multipleBoolKeys) { PartitionIdGenerator idGenerator( - ROW({BOOLEAN(), BOOLEAN()}), {0, 1}, 100, pool(), true); + ROW({BOOLEAN(), BOOLEAN()}), {0, 1}, 100, pool()); auto input = makeRowVector({ makeFlatVector( @@ -108,7 +110,7 @@ TEST_F(PartitionIdGeneratorTest, multipleBoolKeys) { } TEST_F(PartitionIdGeneratorTest, stableIdsSingleKey) { - PartitionIdGenerator idGenerator(ROW({BIGINT()}), {0}, 100, pool(), true); + PartitionIdGenerator idGenerator(ROW({BIGINT()}), {0}, 100, pool()); auto numPartitions = 40; auto input = makeRowVector({ @@ -135,7 +137,7 @@ TEST_F(PartitionIdGeneratorTest, stableIdsSingleKey) { TEST_F(PartitionIdGeneratorTest, stableIdsMultipleKeys) { PartitionIdGenerator idGenerator( - ROW({BIGINT(), VARCHAR(), INTEGER()}), {1, 2}, 100, pool(), true); + ROW({BIGINT(), VARCHAR(), INTEGER()}), {1, 2}, 100, pool()); const vector_size_t size = 1'000; auto input = makeRowVector({ @@ -174,7 +176,7 @@ TEST_F(PartitionIdGeneratorTest, stableIdsMultipleKeys) { TEST_F(PartitionIdGeneratorTest, partitionKeysCaseSensitive) { PartitionIdGenerator idGenerator( - ROW({"cc0", "Cc1"}, {BIGINT(), VARCHAR()}), {1}, 100, pool(), false); + ROW({"cc0", "Cc+1"}, {BIGINT(), VARCHAR()}), {1}, 100, pool()); auto input = makeRowVector({ makeFlatVector({1, 2, 3}), @@ -183,12 +185,19 @@ TEST_F(PartitionIdGeneratorTest, partitionKeysCaseSensitive) { raw_vector firstTimeIds; idGenerator.run(input, firstTimeIds); - EXPECT_EQ("Cc1=apple", idGenerator.partitionName(0)); - EXPECT_EQ("Cc1=orange", idGenerator.partitionName(1)); + + EXPECT_EQ( + "Cc+1=apple", + HivePartitionName::partitionName( + 0, idGenerator.partitionValues(), /*partitionKeyAsLowerCase=*/false)); + EXPECT_EQ( + "Cc+1=orange", + HivePartitionName::partitionName( + 1, idGenerator.partitionValues(), /*partitionKeyAsLowerCase=*/false)); } TEST_F(PartitionIdGeneratorTest, numPartitions) { - PartitionIdGenerator idGenerator(ROW({BIGINT()}), {0}, 100, pool(), true); + PartitionIdGenerator idGenerator(ROW({BIGINT()}), {0}, 100, pool()); // First run to process partition 0,..,9. Total num of partitions processed by // far is 10. @@ -223,7 +232,7 @@ TEST_F(PartitionIdGeneratorTest, limitOfPartitionNumber) { auto maxPartitions = 100; PartitionIdGenerator idGenerator( - ROW({INTEGER()}), {0}, maxPartitions, pool(), true); + ROW({INTEGER()}), {0}, maxPartitions, pool()); auto input = makeRowVector({ makeFlatVector(maxPartitions + 1, [](auto row) { return row; }), @@ -236,6 +245,82 @@ TEST_F(PartitionIdGeneratorTest, limitOfPartitionNumber) { fmt::format("Exceeded limit of {} distinct partitions.", maxPartitions)); } +TEST_F(PartitionIdGeneratorTest, timestampPartitionKeyComparasion) { + PartitionIdGenerator idGenerator( + ROW({"timestamp_col"}, {TIMESTAMP()}), {0}, 100, pool()); + auto timestampResult = util::fromTimestampString( + "2025-01-02 00:00:00.0", util::TimestampParseMode::kPrestoCast); + auto input = makeRowVector({ + makeFlatVector({timestampResult.value()}), + }); + raw_vector testTimeIds; + idGenerator.run(input, testTimeIds); + + EXPECT_EQ( + HivePartitionName::partitionName( + testTimeIds[0], + idGenerator.partitionValues(), + /*partitionKeyAsLowerCase=*/true), + "timestamp_col=2025-01-01 16%3A00%3A00.0"); +} + +TEST_F(PartitionIdGeneratorTest, timestampPartitionKey) { + PartitionIdGenerator idGenerator(ROW({TIMESTAMP()}), {0}, 100, pool()); + + auto numPartitions = 50; + auto input = makeRowVector({ + makeFlatVector( + numPartitions, + [](auto row) { + return Timestamp::fromMillis( + 1639426440000 + static_cast(row) * 100); + }), + }); + + raw_vector firstTimeIds; + idGenerator.run(input, firstTimeIds); + + std::unordered_set distinctIds( + firstTimeIds.begin(), firstTimeIds.end()); + EXPECT_EQ(distinctIds.size(), numPartitions); + EXPECT_EQ(*std::min_element(distinctIds.begin(), distinctIds.end()), 0); + EXPECT_EQ( + *std::max_element(distinctIds.begin(), distinctIds.end()), + numPartitions - 1); + + raw_vector secondTimeIds; + idGenerator.run(input, secondTimeIds); + + for (auto i = 0; i < input->size(); ++i) { + EXPECT_EQ(firstTimeIds[i], secondTimeIds[i]) << "at " << i; + } + + auto otherNumPartitions = 30; + auto otherInput = makeRowVector({ + makeFlatVector( + otherNumPartitions, + [](auto row) { + return Timestamp::fromMillis( + 1639426440000 + static_cast(row) * 500); + }), + }); + + raw_vector otherIds; + idGenerator.run(otherInput, otherIds); + + std::unordered_set otherDistinctIds( + otherIds.begin(), otherIds.end()); + EXPECT_EQ(otherDistinctIds.size(), otherNumPartitions); + + // Run the original input again and verify stable IDs + raw_vector thirdTimeIds; + idGenerator.run(input, thirdTimeIds); + + for (auto i = 0; i < input->size(); ++i) { + EXPECT_EQ(firstTimeIds[i], thirdTimeIds[i]) << "at " << i; + } +} + TEST_F(PartitionIdGeneratorTest, supportedPartitionKeyTypes) { // Test on supported key types. { @@ -248,23 +333,26 @@ TEST_F(PartitionIdGeneratorTest, supportedPartitionKeyTypes) { SMALLINT(), INTEGER(), BIGINT(), + TIMESTAMP(), }), - {0, 1, 2, 3, 4, 5, 6}, + {0, 1, 2, 3, 4, 5, 6, 7}, 100, - pool(), - true); - - auto input = makeRowVector({ - makeNullableFlatVector( - {"Left", std::nullopt, "Right"}, VARCHAR()), - makeNullableFlatVector({true, false, std::nullopt}), - makeFlatVector( - {"proton", "neutron", "electron"}, VARBINARY()), - makeNullableFlatVector({1, 2, std::nullopt}), - makeNullableFlatVector({1, 2, std::nullopt}), - makeNullableFlatVector({1, std::nullopt, 2}), - makeNullableFlatVector({std::nullopt, 1, 2}), - }); + pool()); + + auto input = makeRowVector( + {makeNullableFlatVector( + {"Left", std::nullopt, "Right"}, VARCHAR()), + makeNullableFlatVector({true, false, std::nullopt}), + makeFlatVector( + {"proton", "neutron", "electron"}, VARBINARY()), + makeNullableFlatVector({1, 2, std::nullopt}), + makeNullableFlatVector({1, 2, std::nullopt}), + makeNullableFlatVector({1, std::nullopt, 2}), + makeNullableFlatVector({std::nullopt, 1, 2}), + makeNullableFlatVector( + {std::nullopt, + Timestamp::fromMillis(1639426440001), + Timestamp::fromMillis(1639426440002)})}); raw_vector ids; idGenerator.run(input, ids); @@ -279,15 +367,13 @@ TEST_F(PartitionIdGeneratorTest, supportedPartitionKeyTypes) { auto input = makeRowVector({ makeConstant(1.0, 1), makeConstant(1.0, 1), - makeConstant(Timestamp::fromMillis(1639426440000), 1), makeArrayVector({{1, 2, 3}}), makeMapVector({{{1, 2}}}), }); for (column_index_t i = 1; i < input->childrenSize(); i++) { VELOX_ASSERT_THROW( - PartitionIdGenerator( - asRowType(input->type()), {i}, 100, pool(), true), + PartitionIdGenerator(asRowType(input->type()), {i}, 100, pool()), fmt::format( "Unsupported partition type: {}.", input->childAt(i)->type()->toString())); diff --git a/velox/connectors/tests/CMakeLists.txt b/velox/connectors/tests/CMakeLists.txt index 0021b6a56900..b70f227b41eb 100644 --- a/velox/connectors/tests/CMakeLists.txt +++ b/velox/connectors/tests/CMakeLists.txt @@ -20,4 +20,5 @@ target_link_libraries( GTest::gtest GTest::gtest_main glog::glog - Folly::folly) + Folly::folly +) diff --git a/velox/connectors/tests/ConnectorTest.cpp b/velox/connectors/tests/ConnectorTest.cpp index 88690fdb2c75..90de53198ad1 100644 --- a/velox/connectors/tests/ConnectorTest.cpp +++ b/velox/connectors/tests/ConnectorTest.cpp @@ -15,15 +15,11 @@ */ #include "velox/connectors/Connector.h" -#include "velox/common/base/tests/GTestUtils.h" #include "velox/common/config/Config.h" #include namespace facebook::velox::connector { - -class ConnectorTest : public testing::Test {}; - namespace { class TestConnector : public connector::Connector { @@ -32,18 +28,15 @@ class TestConnector : public connector::Connector { std::unique_ptr createDataSource( const RowTypePtr& /* outputType */, - const std::shared_ptr& /* tableHandle */, - const std::unordered_map< - std::string, - std::shared_ptr>& /* columnHandles */, + const ConnectorTableHandlePtr& /* tableHandle */, + const connector::ColumnHandleMap& /* columnHandles */, connector::ConnectorQueryCtx* connectorQueryCtx) override { VELOX_NYI(); } std::unique_ptr createDataSink( RowTypePtr /*inputType*/, - std::shared_ptr< - ConnectorInsertTableHandle> /*connectorInsertTableHandle*/, + ConnectorInsertTableHandlePtr /*connectorInsertTableHandle*/, ConnectorQueryCtx* /*connectorQueryCtx*/, CommitStrategy /*commitStrategy*/) override final { VELOX_NYI(); @@ -52,9 +45,7 @@ class TestConnector : public connector::Connector { class TestConnectorFactory : public connector::ConnectorFactory { public: - static constexpr const char* kConnectorFactoryName = "test-factory"; - - TestConnectorFactory() : ConnectorFactory(kConnectorFactoryName) {} + TestConnectorFactory() : ConnectorFactory("test-factory") {} std::shared_ptr newConnector( const std::string& id, @@ -65,39 +56,30 @@ class TestConnectorFactory : public connector::ConnectorFactory { } }; -} // namespace +TEST(ConnectorTest, getAllConnectors) { + TestConnectorFactory factory; -TEST_F(ConnectorTest, getAllConnectors) { - registerConnectorFactory(std::make_shared()); - VELOX_ASSERT_THROW( - registerConnectorFactory(std::make_shared()), - "ConnectorFactory with name 'test-factory' is already registered"); - EXPECT_TRUE(hasConnectorFactory(TestConnectorFactory::kConnectorFactoryName)); const int32_t numConnectors = 10; for (int32_t i = 0; i < numConnectors; i++) { - registerConnector( - getConnectorFactory(TestConnectorFactory::kConnectorFactoryName) - ->newConnector( - fmt::format("connector-{}", i), - std::make_shared( - std::unordered_map()))); + registerConnector(factory.newConnector( + fmt::format("connector-{}", i), + std::make_shared( + std::unordered_map()))); } + const auto& connectors = getAllConnectors(); EXPECT_EQ(connectors.size(), numConnectors); for (int32_t i = 0; i < numConnectors; i++) { EXPECT_EQ(connectors.count(fmt::format("connector-{}", i)), 1); } + for (int32_t i = 0; i < numConnectors; i++) { unregisterConnector(fmt::format("connector-{}", i)); } EXPECT_EQ(getAllConnectors().size(), 0); - EXPECT_TRUE( - unregisterConnectorFactory(TestConnectorFactory::kConnectorFactoryName)); - EXPECT_FALSE( - unregisterConnectorFactory(TestConnectorFactory::kConnectorFactoryName)); } -TEST_F(ConnectorTest, connectorSplit) { +TEST(ConnectorTest, connectorSplit) { { const ConnectorSplit split("test", 100, true); ASSERT_EQ(split.connectorId, "test"); @@ -117,4 +99,5 @@ TEST_F(ConnectorTest, connectorSplit) { "[split: connector id test, weight 50, cacheable false]"); } } +} // namespace } // namespace facebook::velox::connector diff --git a/velox/connectors/tpcds/CMakeLists.txt b/velox/connectors/tpcds/CMakeLists.txt new file mode 100644 index 000000000000..4c92a290f13c --- /dev/null +++ b/velox/connectors/tpcds/CMakeLists.txt @@ -0,0 +1,20 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +velox_add_library(velox_tpcds_connector TpcdsConnector.cpp) +velox_link_libraries(velox_tpcds_connector PUBLIC velox_connector velox_tpcds_gen PRIVATE fmt::fmt) + +if(${VELOX_BUILD_TESTING}) + add_subdirectory(tests) +endif() diff --git a/velox/connectors/tpcds/TpcdsConnector.cpp b/velox/connectors/tpcds/TpcdsConnector.cpp new file mode 100644 index 000000000000..b981bda903ad --- /dev/null +++ b/velox/connectors/tpcds/TpcdsConnector.cpp @@ -0,0 +1,141 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "velox/connectors/tpcds/TpcdsConnector.h" +#include "velox/tpcds/gen/DSDGenIterator.h" + +using namespace facebook::velox; + +namespace facebook::velox::connector::tpcds { + +std::string TpcdsTableHandle::toString() const { + return fmt::format( + "table: {}, scale factor: {}", toTableName(table_), scaleFactor_); +} + +TpcdsDataSource::TpcdsDataSource( + const std::shared_ptr& outputType, + const std::shared_ptr& + tableHandle, + const std::unordered_map< + std::string, + std::shared_ptr>& columnHandles, + velox::memory::MemoryPool* FOLLY_NONNULL pool) + : pool_(pool) { + const auto tpcdsTableHandle = + std::dynamic_pointer_cast(tableHandle); + VELOX_CHECK_NOT_NULL( + tpcdsTableHandle, "TableHandle must be an instance of TpcdsTableHandle"); + table_ = tpcdsTableHandle->getTpcdsTable(); + scaleFactor_ = tpcdsTableHandle->getScaleFactor(); + velox::tpcds::DSDGenIterator dsdGenIterator(scaleFactor_, 1, 1); + rowCount_ = dsdGenIterator.getRowCount(static_cast(table_)); + + auto tpcdsTableSchema = getTableSchema(tpcdsTableHandle->getTpcdsTable()); + VELOX_CHECK_NOT_NULL(tpcdsTableSchema, "TpcdsSchema can't be null."); + + outputColumnMappings_.reserve(outputType->size()); + + for (const auto& outputName : outputType->names()) { + auto it = columnHandles.find(outputName); + VELOX_CHECK( + it != columnHandles.end(), + "ColumnHandle is missing for output column '{}' on table '{}'", + outputName, + toTableName(table_)); + + const auto handle = + std::dynamic_pointer_cast(it->second); + VELOX_CHECK_NOT_NULL( + handle, + "ColumnHandle must be an instance of TpcdsColumnHandle " + "for '{}' on table '{}'", + it->second->name(), + toTableName(table_)); + + auto idx = tpcdsTableSchema->getChildIdxIfExists(handle->name()); + VELOX_CHECK( + idx != std::nullopt, + "Column '{}' not found on TPC-DS table '{}'.", + handle->name(), + toTableName(table_)); + outputColumnMappings_.emplace_back(*idx); + } + outputType_ = outputType; +} + +RowVectorPtr TpcdsDataSource::projectOutputColumns(RowVectorPtr inputVector) { + std::vector children; + children.reserve(outputColumnMappings_.size()); + + for (const auto channel : outputColumnMappings_) { + children.emplace_back(inputVector->childAt(channel)); + } + + return std::make_shared( + pool_, + outputType_, + BufferPtr(), + inputVector->size(), + std::move(children)); +} + +void TpcdsDataSource::addSplit(std::shared_ptr split) { + VELOX_CHECK_EQ( + currentSplit_, + nullptr, + "Previous split has not been processed yet. Call next() to process the split."); + currentSplit_ = std::dynamic_pointer_cast(split); + VELOX_CHECK(currentSplit_, "Wrong type of split for TpcdsDataSource."); + + size_t partSize = std::ceil( + static_cast(rowCount_) / + static_cast(currentSplit_->totalParts_)); + + splitOffset_ = partSize * currentSplit_->partNumber_; + splitEnd_ = splitOffset_ + partSize; +} + +std::optional TpcdsDataSource::next( + uint64_t size, + velox::ContinueFuture& /*future*/) { + VELOX_CHECK_NOT_NULL( + currentSplit_, "No split to process. Call addSplit() first."); + + size_t maxRows = std::min(size, (splitEnd_ - splitOffset_)); + vector_size_t parallel = currentSplit_->totalParts_; + vector_size_t child = currentSplit_->partNumber_; + auto outputVector = genTpcdsData( + table_, maxRows, splitOffset_, pool_, scaleFactor_, parallel, child); + + // If the split is exhausted. + if (!outputVector || outputVector->size() == 0) { + currentSplit_ = nullptr; + return nullptr; + } + + // splitOffset needs to advance based on maxRows passed to getTpcdsData(), and + // not the actual number of returned rows in the output vector, as they are + // not the same for lineitem. + splitOffset_ += maxRows; + completedRows_ += outputVector->size(); + completedBytes_ += outputVector->retainedSize(); + + return projectOutputColumns(outputVector); +} +} // namespace facebook::velox::connector::tpcds diff --git a/velox/connectors/tpcds/TpcdsConnector.h b/velox/connectors/tpcds/TpcdsConnector.h new file mode 100644 index 000000000000..329114d90057 --- /dev/null +++ b/velox/connectors/tpcds/TpcdsConnector.h @@ -0,0 +1,184 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "velox/connectors/Connector.h" +#include "velox/connectors/tpcds/TpcdsConnectorSplit.h" +#include "velox/tpcds/gen/TpcdsGen.h" + +namespace facebook::velox::connector::tpcds { + +class TpcdsConnector; + +// TPC-DS column handle only needs the column name (all columns are generated in +// the same way). +class TpcdsColumnHandle : public velox::connector::ColumnHandle { + public: + explicit TpcdsColumnHandle(const std::string& name) : name_(name) {} + + const std::string& name() const { + return name_; + } + + private: + const std::string name_; +}; + +// TPC-DS table handle uses the underlying enum to describe the target table. +class TpcdsTableHandle : public ConnectorTableHandle { + public: + explicit TpcdsTableHandle( + std::string connectorId, + velox::tpcds::Table table, + double scaleFactor = 0.01) + : ConnectorTableHandle(std::move(connectorId)), + table_(table), + name_(toTableName(table)), + scaleFactor_(scaleFactor) { + VELOX_CHECK_GT(scaleFactor, 0.0, "Tpcds scale factor must be non-negative"); + } + + const std::string& name() const override { + return name_; + } + + std::string toString() const override; + + velox::tpcds::Table getTpcdsTable() const { + return table_; + } + + double getScaleFactor() const { + return scaleFactor_; + } + + private: + const velox::tpcds::Table table_; + const std::string name_; + const double scaleFactor_; +}; + +class TpcdsDataSource : public velox::connector::DataSource { + public: + TpcdsDataSource( + const std::shared_ptr& outputType, + const std::shared_ptr& + tableHandle, + const std::unordered_map< + std::string, + std::shared_ptr>& columnHandles, + velox::memory::MemoryPool* FOLLY_NONNULL pool); + + void addSplit(std::shared_ptr split) override; + + void addDynamicFilter( + column_index_t /*outputChannel*/, + const std::shared_ptr& /*filter*/) override { + VELOX_NYI("Dynamic filters not supported by TpcdsConnector."); + } + + std::optional next(uint64_t size, velox::ContinueFuture& future) + override; + + uint64_t getCompletedRows() override { + return completedRows_; + } + + uint64_t getCompletedBytes() override { + return completedBytes_; + } + + std::unordered_map getRuntimeStats() override { + return {}; + } + + private: + RowVectorPtr projectOutputColumns(RowVectorPtr vector); + + velox::tpcds::Table table_; + double scaleFactor_{0.01}; + size_t rowCount_{0}; + RowTypePtr outputType_; + + // Mapping between output columns and their indices (column_index_t) in the + // dsdgen generated datasets. + std::vector outputColumnMappings_; + + std::shared_ptr currentSplit_; + + // Offset of the first row in current split. + uint64_t splitOffset_{0}; + // Offset of the last row in current split. + uint64_t splitEnd_{0}; + + size_t completedRows_{0}; + size_t completedBytes_{0}; + + memory::MemoryPool* FOLLY_NONNULL pool_; +}; + +class TpcdsConnector final : public velox::connector::Connector { + public: + TpcdsConnector( + const std::string& id, + std::shared_ptr config, + folly::Executor* FOLLY_NULLABLE /*executor*/) + : Connector(id) {} + + std::unique_ptr createDataSource( + const std::shared_ptr& outputType, + const std::shared_ptr& tableHandle, + const std::unordered_map< + std::string, + std::shared_ptr>& columnHandles, + ConnectorQueryCtx* FOLLY_NONNULL connectorQueryCtx) override final { + return std::make_unique( + outputType, + tableHandle, + columnHandles, + connectorQueryCtx->memoryPool()); + } + + std::unique_ptr createDataSink( + RowTypePtr /*inputType*/, + std::shared_ptr< + const ConnectorInsertTableHandle> /*connectorInsertTableHandle*/, + ConnectorQueryCtx* /*connectorQueryCtx*/, + CommitStrategy /*commitStrategy*/) override final { + VELOX_NYI("TpcdsConnector does not support data sink."); + } +}; + +class TpcdsConnectorFactory : public ConnectorFactory { + public: + static constexpr const char* kTpcdsConnectorName{"tpcds"}; + + TpcdsConnectorFactory() : ConnectorFactory(kTpcdsConnectorName) {} + + explicit TpcdsConnectorFactory(const char* connectorName) + : ConnectorFactory(connectorName) {} + + std::shared_ptr newConnector( + const std::string& id, + std::shared_ptr config, + folly::Executor* ioExecutor = nullptr, + folly::Executor* cpuExecutor = nullptr) override { + return std::make_shared(id, config, ioExecutor); + } +}; + +} // namespace facebook::velox::connector::tpcds diff --git a/velox/connectors/tpcds/TpcdsConnectorSplit.h b/velox/connectors/tpcds/TpcdsConnectorSplit.h new file mode 100644 index 000000000000..80fb8863e76a --- /dev/null +++ b/velox/connectors/tpcds/TpcdsConnectorSplit.h @@ -0,0 +1,72 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include "velox/connectors/Connector.h" + +namespace facebook::velox::connector::tpcds { + +struct TpcdsConnectorSplit : public velox::connector::ConnectorSplit { + explicit TpcdsConnectorSplit( + const std::string& connectorId, + size_t totalParts, + size_t partNumber) + : TpcdsConnectorSplit(connectorId, true, totalParts, partNumber) {} + + TpcdsConnectorSplit( + const std::string& connectorId, + bool cacheable, + size_t totalParts, + size_t partNumber) + : ConnectorSplit(connectorId, /*splitWeight=*/0, cacheable), + totalParts_(totalParts), + partNumber_(partNumber) { + VELOX_CHECK_GE(totalParts, 1, "totalParts must be >= 1"); + VELOX_CHECK_GT(totalParts, partNumber, "totalParts must be > partNumber"); + } + + // In how many parts the generated TPC-DS table will be segmented, roughly + // `rowCount / totalParts` + const vector_size_t totalParts_{1}; + + // Which of these parts will be read by this split. + const vector_size_t partNumber_{0}; +}; + +} // namespace facebook::velox::connector::tpcds + +template <> +struct fmt::formatter + : formatter { + auto format( + facebook::velox::connector::tpcds::TpcdsConnectorSplit const& s, + format_context& ctx) const { + return formatter::format(s.toString(), ctx); + } +}; + +template <> +struct fmt::formatter< + std::shared_ptr> + : formatter { + auto format( + std::shared_ptr< + facebook::velox::connector::tpcds::TpcdsConnectorSplit> const& s, + format_context& ctx) const { + return formatter::format(s->toString(), ctx); + } +}; diff --git a/velox/runner/tests/CMakeLists.txt b/velox/connectors/tpcds/tests/CMakeLists.txt similarity index 71% rename from velox/runner/tests/CMakeLists.txt rename to velox/connectors/tpcds/tests/CMakeLists.txt index 863f572fa828..e214822539c5 100644 --- a/velox/runner/tests/CMakeLists.txt +++ b/velox/connectors/tpcds/tests/CMakeLists.txt @@ -11,15 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +add_executable(velox_tpcds_connector_test TpcdsConnectorTest.cpp) -add_executable(velox_local_runner_test LocalRunnerTest.cpp Main.cpp) - -add_test(velox_local_runner_test velox_local_runner_test) +add_test(velox_tpcds_connector_test velox_tpcds_connector_test) target_link_libraries( - velox_local_runner_test - velox_exec_runner_test_util + velox_tpcds_connector_test + velox_tpcds_connector + velox_vector_test_lib velox_exec_test_lib - velox_parse_parser - velox_parse_expression - GTest::gtest) + velox_aggregates + GTest::gtest + GTest::gtest_main +) diff --git a/velox/connectors/tpcds/tests/TpcdsConnectorTest.cpp b/velox/connectors/tpcds/tests/TpcdsConnectorTest.cpp new file mode 100644 index 000000000000..67462ec7ca9d --- /dev/null +++ b/velox/connectors/tpcds/tests/TpcdsConnectorTest.cpp @@ -0,0 +1,221 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/tpcds/TpcdsConnector.h" +#include +#include "gtest/gtest.h" +#include "velox/exec/tests/utils/AssertQueryBuilder.h" +#include "velox/exec/tests/utils/OperatorTestBase.h" +#include "velox/exec/tests/utils/PlanBuilder.h" + +namespace facebook::velox::connector::tpcds::test { + +namespace { +class TpcdsConnectorTest : public exec::test::OperatorTestBase { + public: + const std::string kTpcdsConnectorId = "test-tpcds"; + + void SetUp() override { + OperatorTestBase::SetUp(); + connector::tpcds::TpcdsConnectorFactory factory; + auto connector = factory.newConnector( + kTpcdsConnectorId, + std::make_shared( + std::unordered_map())); + connector::registerConnector(connector); + } + + void TearDown() override { + connector::unregisterConnector(kTpcdsConnectorId); + OperatorTestBase::TearDown(); + } + + exec::Split makeTpcdsSplit(size_t totalParts = 1, size_t partNumber = 0) + const { + return exec::Split( + std::make_shared( + kTpcdsConnectorId, /*cacheable=*/true, totalParts, partNumber)); + } + + RowVectorPtr getResults( + const core::PlanNodePtr& planNode, + std::vector&& splits) { + return exec::test::AssertQueryBuilder(planNode) + .splits(std::move(splits)) + .copyResults(pool()); + } +}; + +// Simple scan of first 5 rows of table 'warehouse'. +TEST_F(TpcdsConnectorTest, simple) { + auto plan = + exec::test::PlanBuilder() + .tpcdsTableScan( + velox::tpcds::Table::TBL_WAREHOUSE, + {"w_warehouse_sk", "w_warehouse_name", "w_city", "w_state"}, + 1) + .limit(0, 5, false) + .planNode(); + + auto output = getResults(plan, {makeTpcdsSplit()}); + auto expected = makeRowVector({ + makeFlatVector({1, 2, 3, 4, 5}), + makeNullableFlatVector( + {"Conventional childr", + "Important issues liv", + "Doors canno", + "Bad cards must make.", + std::nullopt}), + makeConstant("Fairview", 5), + makeConstant("TN", 5), + }); + velox::test::assertEqualVectors(expected, output); +} + +// Extract single column from table 'store'. +TEST_F(TpcdsConnectorTest, singleColumn) { + auto plan = + exec::test::PlanBuilder() + .tpcdsTableScan(velox::tpcds::Table::TBL_STORE, {"s_store_name"}, 1) + .planNode(); + + auto output = getResults(plan, {makeTpcdsSplit()}); + auto expected = makeRowVector({makeFlatVector({ + "ought", + "able", + "able", + "ese", + "anti", + "cally", + "ation", + "eing", + "eing", + "bar", + "ought", + "ought", + })}); + velox::test::assertEqualVectors(expected, output); + EXPECT_EQ("s_store_name", output->type()->asRow().nameOf(0)); +} + +// Check that aliases are correctly resolved. +TEST_F(TpcdsConnectorTest, singleColumnWithAlias) { + const std::string aliasedName = "my_aliased_column_name"; + + auto outputType = ROW({aliasedName}, {VARCHAR()}); + auto plan = exec::test::PlanBuilder() + .startTableScan() + .outputType(outputType) + .tableHandle( + std::make_shared( + kTpcdsConnectorId, velox::tpcds::Table::TBL_ITEM)) + .assignments({ + {aliasedName, + std::make_shared("i_product_name")}, + {"other_name", + std::make_shared("i_product_name")}, + {"third_column", + std::make_shared("i_manager_id")}, + }) + .endTableScan() + .limit(0, 1, false) + .planNode(); + + auto output = getResults(plan, {makeTpcdsSplit()}); + auto expected = makeRowVector({makeFlatVector({ + "ought", + })}); + velox::test::assertEqualVectors(expected, output); + + EXPECT_EQ(aliasedName, output->type()->asRow().nameOf(0)); + EXPECT_EQ(1, output->childrenSize()); +} + +TEST_F(TpcdsConnectorTest, unknownColumn) { + EXPECT_THROW( + { + exec::test::PlanBuilder() + .tpcdsTableScan(velox::tpcds::Table::TBL_ITEM, {"invalid_column"}) + .planNode(); + }, + VeloxUserError); +} + +// Join warehouse and store on state. +TEST_F(TpcdsConnectorTest, join) { + auto planNodeIdGenerator = std::make_shared(); + core::PlanNodeId itemScanId; + core::PlanNodeId inventoryScanId; + auto plan = exec::test::PlanBuilder(planNodeIdGenerator) + .tpcdsTableScan( + velox::tpcds::Table::TBL_WAREHOUSE, + {"w_warehouse_id", "w_state"}, + 1) + .capturePlanNodeId(itemScanId) + .hashJoin( + {"w_state"}, + {"s_state"}, + exec::test::PlanBuilder(planNodeIdGenerator) + .tpcdsTableScan( + velox::tpcds::Table::TBL_STORE, {"s_state"}, 1) + .capturePlanNodeId(inventoryScanId) + .planNode(), + "", + {"w_warehouse_id"}) + // Get distinct values of warehouse_id + .partialAggregation({"w_warehouse_id"}, {}) + .finalAggregation() + .orderBy({"w_warehouse_id"}, false) + .planNode(); + + auto output = exec::test::AssertQueryBuilder(plan) + .split(itemScanId, makeTpcdsSplit()) + .split(inventoryScanId, makeTpcdsSplit()) + .copyResults(pool()); + + auto expected = makeRowVector({ + makeFlatVector( + {"AAAAAAAABAAAAAAA", + "AAAAAAAACAAAAAAA", + "AAAAAAAADAAAAAAA", + "AAAAAAAAEAAAAAAA", + "AAAAAAAAFAAAAAAA"}), + }); + velox::test::assertEqualVectors(expected, output); +} + +TEST_F(TpcdsConnectorTest, inventoryDateCount) { + auto plan = + exec::test::PlanBuilder() + .tpcdsTableScan( + velox::tpcds::Table::TBL_WEB_SITE, {"web_rec_start_date"}, 1) + .filter("web_rec_start_date = '1997-08-16'::DATE") + .limit(0, 10, false) + .planNode(); + + auto output = getResults(plan, {makeTpcdsSplit()}); + auto inventoryDate = output->childAt(0)->asFlatVector(); + EXPECT_EQ("1997-08-16", DATE()->toString(inventoryDate->valueAt(0))); + EXPECT_EQ(10, inventoryDate->size()); +} +} // namespace +} // namespace facebook::velox::connector::tpcds::test + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + folly::Init init{&argc, &argv, false}; + return RUN_ALL_TESTS(); +} diff --git a/velox/connectors/tpch/CMakeLists.txt b/velox/connectors/tpch/CMakeLists.txt index b8373349e7d0..43ebe514b283 100644 --- a/velox/connectors/tpch/CMakeLists.txt +++ b/velox/connectors/tpch/CMakeLists.txt @@ -14,8 +14,16 @@ velox_add_library(velox_tpch_connector OBJECT TpchConnector.cpp) -velox_link_libraries(velox_tpch_connector velox_connector velox_tpch_gen - fmt::fmt) +velox_link_libraries( + velox_tpch_connector + velox_connector + velox_tpch_gen + velox_core + velox_exec + velox_expression + velox_vector + fmt::fmt +) if(${VELOX_BUILD_TESTING}) add_subdirectory(tests) diff --git a/velox/connectors/tpch/TpchConnector.cpp b/velox/connectors/tpch/TpchConnector.cpp index b2317774fe74..cf70061aada3 100644 --- a/velox/connectors/tpch/TpchConnector.cpp +++ b/velox/connectors/tpch/TpchConnector.cpp @@ -15,12 +15,20 @@ */ #include "velox/connectors/tpch/TpchConnector.h" +#include "velox/exec/OperatorUtils.h" +#include "velox/expression/Expr.h" #include "velox/tpch/gen/TpchGen.h" namespace facebook::velox::connector::tpch { using facebook::velox::tpch::Table; +TpchConnector::TpchConnector( + const std::string& id, + std::shared_ptr config, + folly::Executor* /*executor*/) + : Connector(id, std::move(config)) {} + namespace { RowVectorPtr getTpchData( @@ -53,20 +61,23 @@ RowVectorPtr getTpchData( } // namespace std::string TpchTableHandle::toString() const { - return fmt::format( - "table: {}, scale factor: {}", toTableName(table_), scaleFactor_); + std::stringstream out; + out << "table: " << toTableName(table_) << ", scale factor: " << scaleFactor_; + if (filterExpression_ != nullptr) { + out << ", filter: " << filterExpression_->toString(); + } + return out.str(); } TpchDataSource::TpchDataSource( - const std::shared_ptr& outputType, - const std::shared_ptr& tableHandle, - const std::unordered_map< - std::string, - std::shared_ptr>& columnHandles, - velox::memory::MemoryPool* pool) - : pool_(pool) { + const RowTypePtr& outputType, + const connector::ConnectorTableHandlePtr& tableHandle, + const connector::ColumnHandleMap& columnHandles, + ConnectorQueryCtx* connectorQueryCtx) + : connectorQueryCtx_(connectorQueryCtx), + pool_(connectorQueryCtx->memoryPool()) { auto tpchTableHandle = - std::dynamic_pointer_cast(tableHandle); + std::dynamic_pointer_cast(tableHandle); VELOX_CHECK_NOT_NULL( tpchTableHandle, "TableHandle must be an instance of TpchTableHandle"); tpchTable_ = tpchTableHandle->getTable(); @@ -77,7 +88,6 @@ TpchDataSource::TpchDataSource( VELOX_CHECK_NOT_NULL(tpchTableSchema, "TpchSchema can't be null."); outputColumnMappings_.reserve(outputType->size()); - for (const auto& outputName : outputType->names()) { auto it = columnHandles.find(outputName); VELOX_CHECK( @@ -86,7 +96,7 @@ TpchDataSource::TpchDataSource( outputName, toTableName(tpchTable_)); - auto handle = std::dynamic_pointer_cast(it->second); + auto handle = std::dynamic_pointer_cast(it->second); VELOX_CHECK_NOT_NULL( handle, "ColumnHandle must be an instance of TpchColumnHandle " @@ -103,6 +113,11 @@ TpchDataSource::TpchDataSource( outputColumnMappings_.emplace_back(*idx); } outputType_ = outputType; + + if (tpchTableHandle->filterExpression()) { + filterExpression_ = connectorQueryCtx_->expressionEvaluator()->compile( + tpchTableHandle->filterExpression()); + } } RowVectorPtr TpchDataSource::projectOutputColumns(RowVectorPtr inputVector) { @@ -143,6 +158,55 @@ void TpchDataSource::addSplit(std::shared_ptr split) { splitEnd_ = splitOffset_ + partSize; } +RowVectorPtr TpchDataSource::applyFilter( + RowVectorPtr& vector, + exec::ExprSet* filter) { + if (!filter) { + return projectOutputColumns(vector); + } + + filterSelectivityVector_.resize(vector->size()); + filterSelectivityVector_.setAll(); + filterEvalCtx_.selectedIndices = + allocateIndices(vector->size(), vector->pool()); + + if (!filterMask_ || filterMask_->size() < vector->size()) { + filterMask_ = BaseVector::create(BOOLEAN(), vector->size(), pool_); + } + connectorQueryCtx_->expressionEvaluator()->evaluate( + filter, filterSelectivityVector_, *vector, filterMask_); + + auto filterResults = filterMask_->as>(); + filterSelectivityVector_.applyToSelected([&](vector_size_t row) { + if (filterResults->isNullAt(row) || !filterResults->valueAt(row)) { + filterSelectivityVector_.setValid(row, false); + } + }); + filterSelectivityVector_.updateBounds(); + + if (filterSelectivityVector_.isAllSelected()) { + return projectOutputColumns(vector); + } + + auto* selected = filterEvalCtx_.getRawSelectedIndices( + filterSelectivityVector_.size(), pool_); + vector_size_t remaining = 0; + filterSelectivityVector_.applyToSelected( + [&selected, &remaining](int32_t row) { selected[remaining++] = row; }); + + std::vector children; + children.reserve(outputType_->size()); + for (int i = 0; i < outputType_->size(); ++i) { + auto& child = vector->childAt(outputColumnMappings_[i]); + children.emplace_back( + exec::wrapChild(remaining, filterEvalCtx_.selectedIndices, child)); + } + + filterEvalCtx_.selectedIndices.reset(); + return std::make_shared( + vector->pool(), outputType_, BufferPtr(), remaining, std::move(children)); +} + std::optional TpchDataSource::next( uint64_t size, velox::ContinueFuture& /*future*/) { @@ -175,7 +239,8 @@ std::optional TpchDataSource::next( completedRows_ += outputVector->size(); completedBytes_ += outputVector->retainedSize(); - return projectOutputColumns(outputVector); + // Apply any filters pushed down into the DataSource + return applyFilter(outputVector, filterExpression_.get()); } bool TpchDataSource::isLineItem() const { diff --git a/velox/connectors/tpch/TpchConnector.h b/velox/connectors/tpch/TpchConnector.h index 64fc4abf1be0..1ec4aad0a551 100644 --- a/velox/connectors/tpch/TpchConnector.h +++ b/velox/connectors/tpch/TpchConnector.h @@ -18,7 +18,11 @@ #include "velox/common/config/Config.h" #include "velox/connectors/Connector.h" #include "velox/connectors/tpch/TpchConnectorSplit.h" +#include "velox/core/ITypedExpr.h" +#include "velox/exec/OperatorUtils.h" +#include "velox/expression/Expr.h" #include "velox/tpch/gen/TpchGen.h" +#include "velox/vector/BaseVector.h" namespace facebook::velox::connector::tpch { @@ -44,21 +48,37 @@ class TpchTableHandle : public ConnectorTableHandle { explicit TpchTableHandle( std::string connectorId, velox::tpch::Table table, - double scaleFactor = 1.0) + double scaleFactor = 1.0, + velox::core::TypedExprPtr filterExpression = nullptr) : ConnectorTableHandle(std::move(connectorId)), table_(table), - scaleFactor_(scaleFactor) { + scaleFactor_(scaleFactor), + filterExpression_(std::move(filterExpression)) { VELOX_CHECK_GE(scaleFactor, 0, "Tpch scale factor must be non-negative"); + auto sf = static_cast(scaleFactor_); + if (sf > 0) { + name_ = fmt::format("sf{}.{}", sf, velox::tpch::toTableName(table)); + } else { + name_ = fmt::format("tiny.{}", velox::tpch::toTableName(table)); + } } ~TpchTableHandle() override {} std::string toString() const override; + const std::string& name() const override { + return name_; + } + velox::tpch::Table getTable() const { return table_; } + const velox::core::TypedExprPtr& filterExpression() const { + return filterExpression_; + } + double getScaleFactor() const { return scaleFactor_; } @@ -66,17 +86,17 @@ class TpchTableHandle : public ConnectorTableHandle { private: const velox::tpch::Table table_; double scaleFactor_; + std::string name_; + const velox::core::TypedExprPtr filterExpression_; }; class TpchDataSource : public DataSource { public: TpchDataSource( - const std::shared_ptr& outputType, - const std::shared_ptr& tableHandle, - const std::unordered_map< - std::string, - std::shared_ptr>& columnHandles, - velox::memory::MemoryPool* pool); + const RowTypePtr& outputType, + const connector::ConnectorTableHandlePtr& tableHandle, + const connector::ColumnHandleMap& columnHandles, + ConnectorQueryCtx* connectorQueryCtx); void addSplit(std::shared_ptr split) override; @@ -97,7 +117,7 @@ class TpchDataSource : public DataSource { return completedBytes_; } - std::unordered_map runtimeStats() override { + std::unordered_map getRuntimeStats() override { // TODO: Which stats do we want to expose here? return {}; } @@ -106,6 +126,7 @@ class TpchDataSource : public DataSource { bool isLineItem() const; RowVectorPtr projectOutputColumns(RowVectorPtr vector); + RowVectorPtr applyFilter(RowVectorPtr& vector, exec::ExprSet* filter); velox::tpch::Table tpchTable_; double scaleFactor_{1.0}; @@ -126,6 +147,12 @@ class TpchDataSource : public DataSource { size_t completedRows_{0}; size_t completedBytes_{0}; + SelectivityVector filterSelectivityVector_; + exec::FilterEvalCtx filterEvalCtx_; + std::shared_ptr filterMask_; + std::unique_ptr filterExpression_; + + ConnectorQueryCtx* connectorQueryCtx_; memory::MemoryPool* pool_; }; @@ -134,27 +161,20 @@ class TpchConnector final : public Connector { TpchConnector( const std::string& id, std::shared_ptr config, - folly::Executor* /*executor*/) - : Connector(id) {} + folly::Executor* /*executor*/); std::unique_ptr createDataSource( - const std::shared_ptr& outputType, - const std::shared_ptr& tableHandle, - const std::unordered_map< - std::string, - std::shared_ptr>& columnHandles, + const RowTypePtr& outputType, + const ConnectorTableHandlePtr& tableHandle, + const connector::ColumnHandleMap& columnHandles, ConnectorQueryCtx* connectorQueryCtx) override final { return std::make_unique( - outputType, - tableHandle, - columnHandles, - connectorQueryCtx->memoryPool()); + outputType, tableHandle, columnHandles, connectorQueryCtx); } std::unique_ptr createDataSink( RowTypePtr /*inputType*/, - std::shared_ptr< - ConnectorInsertTableHandle> /*connectorInsertTableHandle*/, + ConnectorInsertTableHandlePtr /*connectorInsertTableHandle*/, ConnectorQueryCtx* /*connectorQueryCtx*/, CommitStrategy /*commitStrategy*/) override final { VELOX_NYI("TpchConnector does not support data sink."); diff --git a/velox/connectors/tpch/TpchConnectorSplit.h b/velox/connectors/tpch/TpchConnectorSplit.h index bfb112db5619..b4b5420d1f92 100644 --- a/velox/connectors/tpch/TpchConnectorSplit.h +++ b/velox/connectors/tpch/TpchConnectorSplit.h @@ -32,7 +32,7 @@ struct TpchConnectorSplit : public connector::ConnectorSplit { bool cacheable, size_t totalParts, size_t partNumber) - : ConnectorSplit(connectorId, /*splitWeight=*/0, cacheable), + : ConnectorSplit(connectorId, /*_splitWeight=*/0, cacheable), totalParts(totalParts), partNumber(partNumber) { VELOX_CHECK_GE(totalParts, 1, "totalParts must be >= 1"); diff --git a/velox/connectors/tpch/tests/CMakeLists.txt b/velox/connectors/tpch/tests/CMakeLists.txt index 5474de3cfa90..eb4f27185160 100644 --- a/velox/connectors/tpch/tests/CMakeLists.txt +++ b/velox/connectors/tpch/tests/CMakeLists.txt @@ -22,7 +22,8 @@ target_link_libraries( velox_exec_test_lib velox_aggregates GTest::gtest - GTest::gtest_main) + GTest::gtest_main +) add_executable(velox_tpch_speed_test SpeedTest.cpp) @@ -32,4 +33,5 @@ target_link_libraries( velox_exec velox_exec_test_lib velox_memory - fmt::fmt) + fmt::fmt +) diff --git a/velox/connectors/tpch/tests/SpeedTest.cpp b/velox/connectors/tpch/tests/SpeedTest.cpp index 0e8216260a90..e3a55e53aaad 100644 --- a/velox/connectors/tpch/tests/SpeedTest.cpp +++ b/velox/connectors/tpch/tests/SpeedTest.cpp @@ -56,22 +56,16 @@ using std::chrono::system_clock; class TpchSpeedTest { public: TpchSpeedTest() { - connector::registerConnectorFactory( - std::make_shared()); - auto tpchConnector = - connector::getConnectorFactory( - connector::tpch::TpchConnectorFactory::kTpchConnectorName) - ->newConnector( - kTpchConnectorId_, - std::make_shared( - std::unordered_map())); + connector::tpch::TpchConnectorFactory factory; + auto tpchConnector = factory.newConnector( + kTpchConnectorId_, + std::make_shared( + std::unordered_map())); connector::registerConnector(tpchConnector); } ~TpchSpeedTest() { connector::unregisterConnector(kTpchConnectorId_); - connector::unregisterConnectorFactory( - connector::tpch::TpchConnectorFactory::kTpchConnectorName); } void run(tpch::Table table, size_t scaleFactor, size_t numSplits) { @@ -123,8 +117,9 @@ class TpchSpeedTest { for (size_t i = 0; i < numSplits; ++i) { task.addSplit( scanId, - exec::Split(std::make_shared( - kTpchConnectorId_, /*cacheable=*/true, numSplits, i))); + exec::Split( + std::make_shared( + kTpchConnectorId_, /*cacheable=*/true, numSplits, i))); } task.noMoreSplits(scanId); diff --git a/velox/connectors/tpch/tests/TpchConnectorTest.cpp b/velox/connectors/tpch/tests/TpchConnectorTest.cpp index 65d0a1e09bce..1ccd9017b318 100644 --- a/velox/connectors/tpch/tests/TpchConnectorTest.cpp +++ b/velox/connectors/tpch/tests/TpchConnectorTest.cpp @@ -39,29 +39,24 @@ class TpchConnectorTest : public exec::test::OperatorTestBase { void SetUp() override { FLAGS_velox_tpch_text_pool_size_mb = 10; OperatorTestBase::SetUp(); - connector::registerConnectorFactory( - std::make_shared()); - auto tpchConnector = - connector::getConnectorFactory( - connector::tpch::TpchConnectorFactory::kTpchConnectorName) - ->newConnector( - kTpchConnectorId, - std::make_shared( - std::unordered_map())); + connector::tpch::TpchConnectorFactory factory; + auto tpchConnector = factory.newConnector( + kTpchConnectorId, + std::make_shared( + std::unordered_map())); connector::registerConnector(tpchConnector); } void TearDown() override { connector::unregisterConnector(kTpchConnectorId); - connector::unregisterConnectorFactory( - connector::tpch::TpchConnectorFactory::kTpchConnectorName); OperatorTestBase::TearDown(); } exec::Split makeTpchSplit(size_t totalParts = 1, size_t partNumber = 0) const { - return exec::Split(std::make_shared( - kTpchConnectorId, /*cacheable=*/true, totalParts, partNumber)); + return exec::Split( + std::make_shared( + kTpchConnectorId, /*cacheable=*/true, totalParts, partNumber)); } RowVectorPtr getResults( @@ -138,8 +133,9 @@ TEST_F(TpchConnectorTest, singleColumnWithAlias) { PlanBuilder() .startTableScan() .outputType(outputType) - .tableHandle(std::make_shared( - kTpchConnectorId, Table::TBL_NATION)) + .tableHandle( + std::make_shared( + kTpchConnectorId, Table::TBL_NATION)) .assignments({ {aliasedName, std::make_shared("n_name")}, {"other_name", std::make_shared("n_name")}, @@ -164,8 +160,9 @@ void TpchConnectorTest::runScaleFactorTest(double scaleFactor) { auto plan = PlanBuilder() .startTableScan() .outputType(ROW({}, {})) - .tableHandle(std::make_shared( - kTpchConnectorId, Table::TBL_SUPPLIER, scaleFactor)) + .tableHandle( + std::make_shared( + kTpchConnectorId, Table::TBL_SUPPLIER, scaleFactor)) .endTableScan() .singleAggregation({}, {"count(1)"}) .planNode(); @@ -193,8 +190,9 @@ TEST_F(TpchConnectorTest, lineitemTinyRowCount) { auto plan = PlanBuilder() .startTableScan() .outputType(ROW({}, {})) - .tableHandle(std::make_shared( - kTpchConnectorId, Table::TBL_LINEITEM, 0.01)) + .tableHandle( + std::make_shared( + kTpchConnectorId, Table::TBL_LINEITEM, 0.01)) .endTableScan() .singleAggregation({}, {"count(1)"}) .planNode(); @@ -249,6 +247,147 @@ TEST_F(TpchConnectorTest, multipleSplits) { } } +// Test filtering in the TpchConnector. +TEST_F(TpchConnectorTest, filterPushdown) { + // Test equality filter + auto plan = PlanBuilder(pool()) + .tpchTableScan( + Table::TBL_NATION, + {"n_nationkey", "n_name", "n_regionkey"}, + 1.0, + kTpchConnectorId, + "n_regionkey = 1") + .planNode(); + + auto output = getResults(plan, {makeTpchSplit()}); + + // Should only return nations with regionkey = 1 + auto expected = makeRowVector({ + // n_nationkey + makeFlatVector({1, 2, 3, 17, 24}), + // n_name + makeFlatVector({ + "ARGENTINA", + "BRAZIL", + "CANADA", + "PERU", + "UNITED STATES", + }), + // n_regionkey + makeFlatVector({1, 1, 1, 1, 1}), + }); + test::assertEqualVectors(expected, output); +} + +// Test more complex filters in the TpchConnector. +TEST_F(TpchConnectorTest, complexFilterPushdown) { + // Test range filter + auto plan = PlanBuilder(pool()) + .tpchTableScan( + Table::TBL_NATION, + {"n_nationkey", "n_name", "n_regionkey"}, + 1.0, + kTpchConnectorId, + "n_nationkey < 5 AND n_regionkey > 0") + .planNode(); + + auto output = getResults(plan, {makeTpchSplit()}); + + // Should only return nations with nationkey < 5 AND regionkey > 0 + auto expected = makeRowVector({ + // n_nationkey + makeFlatVector({1, 2, 3, 4}), + // n_name + makeFlatVector({ + "ARGENTINA", + "BRAZIL", + "CANADA", + "EGYPT", + }), + // n_regionkey + makeFlatVector({1, 1, 1, 4}), + }); + test::assertEqualVectors(expected, output); +} + +// Test filtering with LIKE operator +TEST_F(TpchConnectorTest, likeFilterPushdown) { + // Test LIKE filter + auto plan = PlanBuilder(pool()) + .tpchTableScan( + Table::TBL_NATION, + {"n_nationkey", "n_name", "n_regionkey"}, + 1.0, + kTpchConnectorId, + "n_name LIKE 'A%'") + .planNode(); + + auto output = getResults(plan, {makeTpchSplit()}); + + // Should only return nations with names starting with 'A' + auto expected = makeRowVector({ + // n_nationkey + makeFlatVector({0, 1}), + // n_name + makeFlatVector({ + "ALGERIA", + "ARGENTINA", + }), + // n_regionkey + makeFlatVector({0, 1}), + }); + test::assertEqualVectors(expected, output); +} + +// Test filtering with IN operator +TEST_F(TpchConnectorTest, inFilterPushdown) { + // Test IN filter + auto plan = PlanBuilder(pool()) + .tpchTableScan( + Table::TBL_NATION, + {"n_nationkey", "n_name", "n_regionkey"}, + 1.0, + kTpchConnectorId, + "n_nationkey IN (0, 5, 10, 15, 20)") + .planNode(); + + auto output = getResults(plan, {makeTpchSplit()}); + + // Should only return nations with nationkey in the specified list + auto expected = makeRowVector({ + // n_nationkey + makeFlatVector({0, 5, 10, 15, 20}), + // n_name + makeFlatVector({ + "ALGERIA", + "ETHIOPIA", + "IRAN", + "MOROCCO", + "SAUDI ARABIA", + }), + // n_regionkey + makeFlatVector({0, 0, 4, 0, 4}), + }); + test::assertEqualVectors(expected, output); +} + +TEST_F(TpchConnectorTest, namespaceInfer) { + std::vector> expect = { + {0.05, "tiny.customer"}, + {1.0, "sf1.customer"}, + {5.0, "sf5.customer"}, + {10.0, "sf10.customer"}, + {100.0, "sf100.customer"}, + {300.0, "sf300.customer"}, + {10000.0, "sf10000.customer"}, + }; + for (const auto& expectpair : expect) { + auto handle = std::make_shared( + kTpchConnectorId, Table::TBL_CUSTOMER, expectpair.first); + EXPECT_EQ(handle->name(), expectpair.second); + } +} + // Join nation and region. TEST_F(TpchConnectorTest, join) { auto planNodeIdGenerator = std::make_shared(); @@ -302,6 +441,20 @@ TEST_F(TpchConnectorTest, orderDateCount) { EXPECT_EQ(9, orderDate->size()); } +TEST_F(TpchConnectorTest, config) { + std::unordered_map properties = { + {"property", "value"}}; + connector::tpch::TpchConnectorFactory factory; + auto connector = factory.newConnector( + kTpchConnectorId, + std::make_shared(std::move(properties))); + + const auto& config = connector->connectorConfig(); + auto val = config->get("property"); + EXPECT_TRUE(val.has_value()); + EXPECT_EQ(val.value(), "value"); +} + } // namespace int main(int argc, char** argv) { diff --git a/velox/core/CMakeLists.txt b/velox/core/CMakeLists.txt index 33c2bb537043..5957c440b942 100644 --- a/velox/core/CMakeLists.txt +++ b/velox/core/CMakeLists.txt @@ -18,24 +18,30 @@ endif() velox_add_library( velox_core Expressions.cpp + ITypedExpr.cpp + PlanConsistencyChecker.cpp PlanFragment.cpp PlanNode.cpp QueryConfig.cpp QueryCtx.cpp - SimpleFunctionMetadata.cpp) + SimpleFunctionMetadata.cpp + TableWriteTraits.cpp +) velox_link_libraries( velox_core - PUBLIC velox_arrow_bridge - velox_caching - velox_common_config - velox_connector - velox_exception - velox_expression_functions - velox_memory - velox_type - velox_vector - Boost::headers - Folly::folly - fmt::fmt - PRIVATE velox_encode) + PUBLIC + velox_arrow_bridge + velox_caching + velox_common_config + velox_connector + velox_exception + velox_expression_functions + velox_memory + velox_type + velox_vector + Boost::headers + Folly::folly + fmt::fmt + PRIVATE velox_encode +) diff --git a/velox/core/ExpressionEvaluator.h b/velox/core/ExpressionEvaluator.h index b10b113c8b25..1309d8ed8b2a 100644 --- a/velox/core/ExpressionEvaluator.h +++ b/velox/core/ExpressionEvaluator.h @@ -41,6 +41,9 @@ class ExpressionEvaluator { virtual std::unique_ptr compile( const std::shared_ptr& expression) = 0; + virtual std::unique_ptr compile( + const std::vector>& expressions) = 0; + // Evaluates previously compiled expression on the specified rows. // Re-uses result vector if it is not null. virtual void evaluate( @@ -49,6 +52,12 @@ class ExpressionEvaluator { const RowVector& input, VectorPtr& result) = 0; + virtual void evaluate( + exec::ExprSet* exprSet, + const SelectivityVector& rows, + const RowVector& input, + std::vector& results) = 0; + // Memory pool used to construct input or output vectors. virtual memory::MemoryPool* pool() = 0; }; diff --git a/velox/core/Expressions.cpp b/velox/core/Expressions.cpp index 75c34c74feca..289d6f3e9b46 100644 --- a/velox/core/Expressions.cpp +++ b/velox/core/Expressions.cpp @@ -14,7 +14,11 @@ * limitations under the License. */ #include "velox/core/Expressions.h" +#include "velox/common/Casts.h" #include "velox/common/encode/Base64.h" +#include "velox/vector/ComplexVector.h" +#include "velox/vector/ConstantVector.h" +#include "velox/vector/SimpleVector.h" #include "velox/vector/VectorSaver.h" namespace facebook::velox::core { @@ -84,6 +88,19 @@ TypedExprPtr InputTypedExpr::create(const folly::dynamic& obj, void* context) { return std::make_shared(std::move(type)); } +std::optional ConstantTypedExpr::toBool() const { + VELOX_CHECK( + this->type()->isBoolean(), + "Expected boolean expression, but got {}", + this->type()->toString()); + + if (!isNull()) { + return valueVector_ ? valueVector_->as>()->valueAt(0) + : value_.value(); + } + return std::nullopt; +} + void ConstantTypedExpr::accept( const ITypedExprVisitor& visitor, ITypedExprVisitorContext& context) const { @@ -112,7 +129,7 @@ TypedExprPtr ConstantTypedExpr::create( auto type = core::deserializeType(obj, context); if (obj.count("value")) { - auto value = variant::create(obj["value"]); + auto value = Variant::create(obj["value"]); return std::make_shared(std::move(type), value); } @@ -125,6 +142,316 @@ TypedExprPtr ConstantTypedExpr::create( return std::make_shared(restoreVector(dataStream, pool)); } +// static +TypedExprPtr ConstantTypedExpr::makeNull(const TypePtr& type) { + return std::make_shared( + type, Variant::null(type->kind())); +} + +std::string ConstantTypedExpr::toString() const { + if (hasValueVector()) { + return valueVector_->toString(0); + } + + return value_.toStringAsVector(type()); +} + +namespace { + +bool equalsImpl( + const VectorPtr& vector, + vector_size_t index, + const Variant& value); + +template +bool equalsNoNulls( + const VectorPtr& vector, + vector_size_t index, + const Variant& value) { + using T = typename TypeTraits::NativeType; + + const auto thisValue = vector->as>()->valueAt(index); + const auto otherValue = T(value.value()); + + const auto& type = vector->type(); + + auto result = type->providesCustomComparison() + ? SimpleVector::comparePrimitiveAscWithCustomComparison( + type.get(), thisValue, otherValue) + : SimpleVector::comparePrimitiveAsc(thisValue, otherValue); + return result == 0; +} + +template <> +bool equalsNoNulls( + const VectorPtr& vector, + vector_size_t index, + const Variant& value) { + using T = std::shared_ptr; + const auto thisValue = vector->as>()->valueAt(index); + const auto& otherValue = value.value().obj; + + const auto& type = vector->type(); + + auto result = type->providesCustomComparison() + ? SimpleVector::comparePrimitiveAscWithCustomComparison( + type.get(), thisValue, otherValue) + : SimpleVector::comparePrimitiveAsc(thisValue, otherValue); + return result == 0; +} + +template <> +bool equalsNoNulls( + const VectorPtr& vector, + vector_size_t index, + const Variant& value) { + auto* wrappedVector = vector->wrappedVector(); + VELOX_CHECK_EQ(VectorEncoding::Simple::ARRAY, wrappedVector->encoding()); + + auto* arrayVector = wrappedVector->asUnchecked(); + + index = vector->wrappedIndex(index); + + const auto offset = arrayVector->offsetAt(index); + const auto size = arrayVector->sizeAt(index); + + const auto& arrayValue = value.value(); + if (size != arrayValue.size()) { + return false; + } + + for (auto i = 0; i < size; ++i) { + if (!equalsImpl(arrayVector->elements(), offset + i, arrayValue.at(i))) { + return false; + } + } + + return true; +} + +template <> +bool equalsNoNulls( + const VectorPtr& vector, + vector_size_t index, + const Variant& value) { + auto* wrappedVector = vector->wrappedVector(); + VELOX_CHECK_EQ(VectorEncoding::Simple::MAP, wrappedVector->encoding()); + + auto* mapVector = wrappedVector->asUnchecked(); + + index = vector->wrappedIndex(index); + + const auto size = mapVector->sizeAt(index); + + const auto& mapValue = value.value(); + if (size != mapValue.size()) { + return false; + } + + const auto sortedIndices = mapVector->sortedKeyIndices(index); + + size_t i = 0; + for (const auto& [key, value] : mapValue) { + if (!equalsImpl(mapVector->mapKeys(), sortedIndices[i], key)) { + return false; + } + + if (!equalsImpl(mapVector->mapValues(), sortedIndices[i], value)) { + return false; + } + + ++i; + } + + return true; +} + +template <> +bool equalsNoNulls( + const VectorPtr& vector, + vector_size_t index, + const Variant& value) { + auto* wrappedVector = vector->wrappedVector(); + VELOX_CHECK_EQ(VectorEncoding::Simple::ROW, wrappedVector->encoding()); + + auto* rowVector = wrappedVector->asUnchecked(); + + index = vector->wrappedIndex(index); + + const auto size = rowVector->type()->size(); + + const auto& rowValue = value.value(); + if (size != rowValue.size()) { + return false; + } + + for (auto i = 0; i < size; ++i) { + if (rowVector->childAt(i) == nullptr) { + return false; + } + + if (!equalsImpl(rowVector->childAt(i), index, rowValue.at(i))) { + return false; + } + } + + return true; +} + +bool equalsImpl( + const VectorPtr& vector, + vector_size_t index, + const Variant& value) { + static constexpr CompareFlags kEqualValueAtFlags = + CompareFlags::equality(CompareFlags::NullHandlingMode::kNullAsValue); + + bool thisNull = vector->isNullAt(index); + bool otherNull = value.isNull(); + + if (otherNull || thisNull) { + return BaseVector::compareNulls(thisNull, otherNull, kEqualValueAtFlags) + .value() == 0; + } + + return VELOX_DYNAMIC_TYPE_DISPATCH_ALL( + equalsNoNulls, vector->typeKind(), vector, index, value); +} +} // namespace + +bool ConstantTypedExpr::equals(const ITypedExpr& other) const { + const auto* casted = dynamic_cast(&other); + if (!casted) { + return false; + } + + if (*this->type() != *casted->type()) { + return false; + } + + if (this->hasValueVector() != casted->hasValueVector()) { + return this->hasValueVector() + ? equalsImpl(this->valueVector_, 0, casted->value_) + : equalsImpl(casted->valueVector_, 0, this->value_); + } + + if (this->hasValueVector()) { + return this->valueVector_->equalValueAt(casted->valueVector_.get(), 0, 0); + } + + return this->value_ == casted->value_; +} + +namespace { + +uint64_t hashImpl(const TypePtr& type, const Variant& value); + +template +uint64_t hashImpl(const TypePtr& type, const Variant& value) { + using T = typename TypeTraits::NativeType; + + const auto& v = value.value(); + + if (type->providesCustomComparison()) { + return SimpleVector::hashValueAtWithCustomType(type, T(v)); + } + + if constexpr (std::is_floating_point_v) { + return util::floating_point::NaNAwareHash{}(T(v)); + } else { + return folly::hasher{}(T(v)); + } +} + +template <> +uint64_t hashImpl(const TypePtr& type, const Variant& value) { + return value.hash(); +} + +template <> +uint64_t hashImpl(const TypePtr& type, const Variant& value) { + const auto& arrayValue = value.value(); + + const auto& elementType = type->childAt(0); + + uint64_t hash = BaseVector::kNullHash; + for (auto i = 0; i < arrayValue.size(); ++i) { + hash = bits::hashMix(hash, hashImpl(elementType, arrayValue.at(i))); + } + return hash; +} + +template <> +uint64_t hashImpl(const TypePtr& type, const Variant& value) { + const auto& mapValue = value.value(); + + const auto& keyType = type->childAt(0); + const auto& valueType = type->childAt(1); + + uint64_t hash = BaseVector::kNullHash; + for (const auto& [key, value] : mapValue) { + const auto keyValueHash = + bits::hashMix(hashImpl(keyType, key), hashImpl(valueType, value)); + hash = bits::commutativeHashMix(hash, keyValueHash); + } + + return hash; +} + +template <> +uint64_t hashImpl(const TypePtr& type, const Variant& value) { + const auto& rowValue = value.value(); + + uint64_t hash = BaseVector::kNullHash; + for (auto i = 0; i < rowValue.size(); ++i) { + const auto value = hashImpl(type->childAt(i), rowValue.at(i)); + if (i == 0) { + hash = value; + } else { + hash = bits::hashMix(hash, value); + } + } + return hash; +} + +uint64_t hashImpl(const TypePtr& type, const Variant& value) { + if (value.isNull()) { + return BaseVector::kNullHash; + } + + return VELOX_DYNAMIC_TYPE_DISPATCH_ALL(hashImpl, type->kind(), type, value); +} + +} // namespace + +size_t ConstantTypedExpr::localHash() const { + static const size_t kBaseHash = std::hash()("ConstantTypedExpr"); + + uint64_t h; + + if (hasValueVector()) { + h = valueVector_->hashValueAt(0); + } else { + h = hashImpl(type(), value_); + } + + return bits::hashMix(kBaseHash, h); +} + +std::string CallTypedExpr::toString() const { + std::string str{}; + str += name(); + str += "("; + for (size_t i = 0; i < inputs().size(); ++i) { + auto& input = inputs().at(i); + if (i != 0) { + str += ","; + } + str += input->toString(); + } + str += ")"; + return str; +} + void CallTypedExpr::accept( const ITypedExprVisitor& visitor, ITypedExprVisitorContext& context) const { @@ -146,6 +473,43 @@ TypedExprPtr CallTypedExpr::create(const folly::dynamic& obj, void* context) { std::move(type), std::move(inputs), obj["functionName"].asString()); } +TypedExprPtr FieldAccessTypedExpr::rewriteInputNames( + const std::unordered_map& mapping) const { + if (inputs().empty()) { + auto it = mapping.find(name_); + return it != mapping.end() + ? it->second + : std::make_shared(type(), name_); + } + + auto newInputs = rewriteInputsRecursive(mapping); + VELOX_CHECK_EQ(1, newInputs.size()); + // Only rewrite name if input in InputTypedExpr. Rewrite in other + // cases(like dereference) is unsound. + if (!newInputs[0]->isInputKind()) { + return std::make_shared(type(), newInputs[0], name_); + } + auto it = mapping.find(name_); + auto newName = name_; + if (it != mapping.end()) { + if (auto name = + std::dynamic_pointer_cast(it->second)) { + newName = name->name(); + } + } + return std::make_shared(type(), newInputs[0], newName); +} + +std::string FieldAccessTypedExpr::toString() const { + std::stringstream ss; + ss << std::quoted(name(), '"', '"'); + if (inputs().empty()) { + return fmt::format("{}", ss.str()); + } + + return fmt::format("{}[{}]", inputs()[0]->toString(), ss.str()); +} + void FieldAccessTypedExpr::accept( const ITypedExprVisitor& visitor, ITypedExprVisitorContext& context) const { @@ -202,6 +566,40 @@ TypedExprPtr DereferenceTypedExpr::create( std::move(type), std::move(inputs[0]), index); } +namespace { +TypePtr toRowType( + const std::vector& names, + const std::vector& expressions) { + std::vector types; + types.reserve(expressions.size()); + for (const auto& expr : expressions) { + types.push_back(expr->type()); + } + + auto namesCopy = names; + return ROW(std::move(namesCopy), std::move(types)); +} +} // namespace + +ConcatTypedExpr::ConcatTypedExpr( + const std::vector& names, + const std::vector& inputs) + : ITypedExpr{ExprKind::kConcat, toRowType(names, inputs), inputs} {} + +std::string ConcatTypedExpr::toString() const { + std::string str{}; + str += "CONCAT("; + for (size_t i = 0; i < inputs().size(); ++i) { + auto& input = inputs().at(i); + if (i != 0) { + str += ","; + } + str += input->toString(); + } + str += ")"; + return str; +} + void ConcatTypedExpr::accept( const ITypedExprVisitor& visitor, ITypedExprVisitorContext& context) const { @@ -221,6 +619,17 @@ TypedExprPtr ConcatTypedExpr::create(const folly::dynamic& obj, void* context) { type->asRow().names(), std::move(inputs)); } +TypedExprPtr LambdaTypedExpr::rewriteInputNames( + const std::unordered_map& mapping) const { + for (const auto& name : signature_->names()) { + if (mapping.count(name)) { + VELOX_USER_FAIL("Ambiguous variable: {}", name); + } + } + return std::make_shared( + signature_, body_->rewriteInputNames(mapping)); +} + void LambdaTypedExpr::accept( const ITypedExprVisitor& visitor, ITypedExprVisitorContext& context) const { @@ -243,6 +652,16 @@ TypedExprPtr LambdaTypedExpr::create(const folly::dynamic& obj, void* context) { asRowType(signature), std::move(body)); } +std::string CastTypedExpr::toString() const { + if (isTryCast_) { + return fmt::format( + "try_cast({} as {})", inputs()[0]->toString(), type()->toString()); + } else { + return fmt::format( + "cast({} as {})", inputs()[0]->toString(), type()->toString()); + } +} + void CastTypedExpr::accept( const ITypedExprVisitor& visitor, ITypedExprVisitorContext& context) const { @@ -251,7 +670,7 @@ void CastTypedExpr::accept( folly::dynamic CastTypedExpr::serialize() const { auto obj = ITypedExpr::serializeBase("CastTypedExpr"); - obj["nullOnFailure"] = nullOnFailure_; + obj["isTryCast"] = isTryCast_; return obj; } @@ -261,7 +680,7 @@ TypedExprPtr CastTypedExpr::create(const folly::dynamic& obj, void* context) { auto inputs = deserializeInputs(obj, context); return std::make_shared( - std::move(type), std::move(inputs), obj["nullOnFailure"].asBool()); + std::move(type), std::move(inputs), obj["isTryCast"].asBool()); } } // namespace facebook::velox::core diff --git a/velox/core/Expressions.h b/velox/core/Expressions.h index 582dbce8b7ef..ba97846c53b6 100644 --- a/velox/core/Expressions.h +++ b/velox/core/Expressions.h @@ -15,7 +15,7 @@ */ #pragma once -#include +#include "velox/common/Casts.h" #include "velox/common/base/Exceptions.h" #include "velox/core/ITypedExpr.h" #include "velox/vector/BaseVector.h" @@ -24,11 +24,11 @@ namespace facebook::velox::core { class InputTypedExpr : public ITypedExpr { public: - explicit InputTypedExpr(TypePtr type) : ITypedExpr{std::move(type)} {} + explicit InputTypedExpr(TypePtr type) + : ITypedExpr{ExprKind::kInput, std::move(type)} {} bool operator==(const ITypedExpr& other) const final { - const auto* casted = dynamic_cast(&other); - return casted != nullptr; + return other.isInputKind(); } std::string toString() const override { @@ -58,41 +58,36 @@ class InputTypedExpr : public ITypedExpr { class ConstantTypedExpr : public ITypedExpr { public: // Creates constant expression. For complex types, only - // variant::null() value is supported. - ConstantTypedExpr(TypePtr type, variant value) - : ITypedExpr{std::move(type)}, value_{std::move(value)} {} + // Variant::null() value is supported. + ConstantTypedExpr(TypePtr type, Variant value) + : ITypedExpr{ExprKind::kConstant, std::move(type)}, + value_{std::move(value)} { + VELOX_CHECK( + value_.isTypeCompatible(ITypedExpr::type()), + "Expression type {} does not match variant type {}", + ITypedExpr::type()->toString(), + value_.inferType()->toString()); + } // Creates constant expression of scalar or complex type. The value comes from // index zero. explicit ConstantTypedExpr(const VectorPtr& value) - : ITypedExpr{value->type()}, + : ITypedExpr{ExprKind::kConstant, value->type()}, valueVector_{ value->isConstantEncoding() ? value : BaseVector::wrapInConstant(1, 0, value)} {} - std::string toString() const override { - if (hasValueVector()) { - return valueVector_->toString(0); - } - return value_.toJson(type()); - } - - size_t localHash() const override { - static const size_t kBaseHash = - std::hash()("ConstantTypedExpr"); + std::string toString() const override; - return bits::hashMix( - kBaseHash, - hasValueVector() ? valueVector_->hashValueAt(0) : value_.hash()); - } + size_t localHash() const override; bool hasValueVector() const { return valueVector_ != nullptr; } - /// Returns scalar value as variant if hasValueVector() is false. - const variant& value() const { + /// Returns scalar value as Variant if hasValueVector() is false. + const Variant& value() const { return value_; } @@ -119,6 +114,10 @@ class ConstantTypedExpr : public ITypedExpr { return BaseVector::createConstant(type(), value_, 1, pool); } + /// Returns value of boolean expression, std::nullopt for null booleans. + /// Throws an error if expression is not of boolean type. + std::optional toBool() const; + const std::vector& inputs() const { static const std::vector kEmpty{}; return kEmpty; @@ -138,26 +137,7 @@ class ConstantTypedExpr : public ITypedExpr { const ITypedExprVisitor& visitor, ITypedExprVisitorContext& context) const override; - bool equals(const ITypedExpr& other) const { - const auto* casted = dynamic_cast(&other); - if (!casted) { - return false; - } - - if (*this->type() != *casted->type()) { - return false; - } - - if (this->hasValueVector() != casted->hasValueVector()) { - return false; - } - - if (this->hasValueVector()) { - return this->valueVector_->equalValueAt(casted->valueVector_.get(), 0, 0); - } - - return this->value_ == casted->value_; - } + bool equals(const ITypedExpr& other) const; bool operator==(const ITypedExpr& other) const final { return this->equals(other); @@ -171,8 +151,11 @@ class ConstantTypedExpr : public ITypedExpr { static TypedExprPtr create(const folly::dynamic& obj, void* context); + /// Returns a NULL constant expression of given type. + static TypedExprPtr makeNull(const TypePtr& type); + private: - const variant value_; + const Variant value_; const VectorPtr valueVector_; }; @@ -180,15 +163,15 @@ using ConstantTypedExprPtr = std::shared_ptr; /// Evaluates a scalar function or a special form. /// -/// Supported special forms are: and, or, cast, try_cast, coalesce, if, switch, -/// try. See registerFunctionCallToSpecialForms in +/// Supported special forms are: and, or, cast, try_cast, coalesce, if, +/// switch, try. See registerFunctionCallToSpecialForms in /// expression/RegisterSpecialForm.h for the up-to-date list. /// /// Regular functions have the following properties: (1) return type is fully -/// defined by function name and input types; (2) during evaluation all function -/// arguments are evaluated first before the function itself is evaluated on the -/// results, a failure to evaluate function argument prevents the function from -/// being evaluated. +/// defined by function name and input types; (2) during evaluation all +/// function arguments are evaluated first before the function itself is +/// evaluated on the results, a failure to evaluate function argument prevents +/// the function from being evaluated. /// /// Special forms are different from regular scalar functions as they do not /// always have the above properties. @@ -198,11 +181,11 @@ using ConstantTypedExprPtr = std::shared_ptr; /// - Conjuncts AND, OR don't have (2): these have logic to stop evaluating /// arguments if the outcome is already decided. For example, a > 10 AND b < 3 /// applied to a = 0 and b = 0 is fully decided after evaluating a > 10. The -/// result is FALSE. This is important not only from efficiency standpoint, but -/// semantically as well. Not evaluating unnecessary arguments implicitly -/// suppresses the errors that might have happened if evaluation proceeded. For -/// example, a > 10 AND b / a > 1 would fail if both expressions were evaluated -/// on a = 0. +/// result is FALSE. This is important not only from efficiency standpoint, +/// but semantically as well. Not evaluating unnecessary arguments implicitly +/// suppresses the errors that might have happened if evaluation proceeded. +/// For example, a > 10 AND b / a > 1 would fail if both expressions were +/// evaluated on a = 0. /// - Coalesce, if, switch also don't have (2): these also have logic to stop /// evaluating arguments if the outcome is already decided. /// - TRY doesn't have (2) either: it needs to capture and suppress errors @@ -216,9 +199,19 @@ class CallTypedExpr : public ITypedExpr { TypePtr type, std::vector inputs, std::string name) - : ITypedExpr{std::move(type), std::move(inputs)}, + : ITypedExpr{ExprKind::kCall, std::move(type), std::move(inputs)}, name_(std::move(name)) {} + /// @param type Return type. + /// @param name Name of the function or special form. + /// @param inputs List of input expressions. + template + CallTypedExpr(TypePtr type, std::string name, TypedExprs... inputs) + : CallTypedExpr( + std::move(type), + std::vector{std::forward(inputs)...}, + std::move(name)) {} + virtual const std::string& name() const { return name_; } @@ -230,20 +223,7 @@ class CallTypedExpr : public ITypedExpr { type(), rewriteInputsRecursive(mapping), name_); } - std::string toString() const override { - std::string str{}; - str += name(); - str += "("; - for (size_t i = 0; i < inputs().size(); ++i) { - auto& input = inputs().at(i); - if (i != 0) { - str += ","; - } - str += input->toString(); - } - str += ")"; - return str; - } + std::string toString() const override; size_t localHash() const override { static const size_t kBaseHash = std::hash()("CallTypedExpr"); @@ -270,10 +250,10 @@ class CallTypedExpr : public ITypedExpr { return false; } return std::equal( - this->inputs().begin(), - this->inputs().end(), - other.inputs().begin(), - other.inputs().end(), + this->inputs().cbegin(), + this->inputs().cend(), + other.inputs().cbegin(), + other.inputs().cend(), [](const auto& p1, const auto& p2) { return *p1 == *p2; }); } @@ -292,17 +272,16 @@ class FieldAccessTypedExpr : public ITypedExpr { public: /// Used as a leaf in an expression tree specifying input column by name. FieldAccessTypedExpr(TypePtr type, std::string name) - : ITypedExpr{std::move(type)}, + : ITypedExpr{ExprKind::kFieldAccess, std::move(type)}, name_(std::move(name)), isInputColumn_(true) {} /// Used as a dereference expression which selects a subfield in a struct by /// name. FieldAccessTypedExpr(TypePtr type, TypedExprPtr input, std::string name) - : ITypedExpr{std::move(type), {std::move(input)}}, + : ITypedExpr{ExprKind::kFieldAccess, std::move(type), {std::move(input)}}, name_(std::move(name)), - isInputColumn_(dynamic_cast(inputs()[0].get())) { - } + isInputColumn_(inputs()[0]->isInputKind()) {} const std::string& name() const { return name_; @@ -310,43 +289,9 @@ class FieldAccessTypedExpr : public ITypedExpr { TypedExprPtr rewriteInputNames( const std::unordered_map& mapping) - const override { - if (inputs().empty()) { - auto it = mapping.find(name_); - return it != mapping.end() - ? it->second - : std::make_shared(type(), name_); - } - - auto newInputs = rewriteInputsRecursive(mapping); - VELOX_CHECK_EQ(1, newInputs.size()); - // Only rewrite name if input in InputTypedExpr. Rewrite in other - // cases(like dereference) is unsound. - if (!std::dynamic_pointer_cast(newInputs[0])) { - return std::make_shared( - type(), newInputs[0], name_); - } - auto it = mapping.find(name_); - auto newName = name_; - if (it != mapping.end()) { - if (auto name = std::dynamic_pointer_cast( - it->second)) { - newName = name->name(); - } - } - return std::make_shared( - type(), newInputs[0], newName); - } - - std::string toString() const override { - std::stringstream ss; - ss << std::quoted(name(), '"', '"'); - if (inputs().empty()) { - return fmt::format("{}", ss.str()); - } + const override; - return fmt::format("{}[{}]", inputs()[0]->toString(), ss.str()); - } + std::string toString() const override; size_t localHash() const override { static const size_t kBaseHash = @@ -374,10 +319,10 @@ class FieldAccessTypedExpr : public ITypedExpr { return false; } return std::equal( - this->inputs().begin(), - this->inputs().end(), - other.inputs().begin(), - other.inputs().end(), + this->inputs().cbegin(), + this->inputs().cend(), + other.inputs().cbegin(), + other.inputs().cend(), [](const auto& p1, const auto& p2) { return *p1 == *p2; }); } @@ -397,15 +342,17 @@ class FieldAccessTypedExpr : public ITypedExpr { using FieldAccessTypedExprPtr = std::shared_ptr; -/// Represents a dereference expression which selects a subfield in a struct by -/// name. +/// Represents a dereference expression which selects a subfield in a struct +/// by name. class DereferenceTypedExpr : public ITypedExpr { public: DereferenceTypedExpr(TypePtr type, TypedExprPtr input, uint32_t index) - : ITypedExpr{std::move(type), {std::move(input)}}, index_(index) { + : ITypedExpr{ExprKind::kDereference, std::move(type), {std::move(input)}}, + index_(index) { // Make sure this isn't being used to access a top level column. - VELOX_USER_CHECK_NULL( - std::dynamic_pointer_cast(inputs()[0])); + VELOX_USER_CHECK( + !inputs()[0]->isInputKind(), + "DereferenceTypedExpr select a subfeild cannot be used to access a top level column"); } uint32_t index() const { @@ -452,10 +399,10 @@ class DereferenceTypedExpr : public ITypedExpr { return false; } return std::equal( - this->inputs().begin(), - this->inputs().end(), - other.inputs().begin(), - other.inputs().end(), + this->inputs().cbegin(), + this->inputs().cend(), + other.inputs().cbegin(), + other.inputs().cend(), [](const auto& p1, const auto& p2) { return *p1 == *p2; }); } @@ -474,8 +421,7 @@ class ConcatTypedExpr : public ITypedExpr { public: ConcatTypedExpr( const std::vector& names, - const std::vector& inputs) - : ITypedExpr{toType(names, inputs), inputs} {} + const std::vector& inputs); TypedExprPtr rewriteInputNames( const std::unordered_map& mapping) @@ -484,19 +430,7 @@ class ConcatTypedExpr : public ITypedExpr { type()->asRow().names(), rewriteInputsRecursive(mapping)); } - std::string toString() const override { - std::string str{}; - str += "CONCAT("; - for (size_t i = 0; i < inputs().size(); ++i) { - auto& input = inputs().at(i); - if (i != 0) { - str += ","; - } - str += input->toString(); - } - str += ")"; - return str; - } + std::string toString() const override; size_t localHash() const override { static const size_t kBaseHash = std::hash()("ConcatTypedExpr"); @@ -520,29 +454,16 @@ class ConcatTypedExpr : public ITypedExpr { return false; } return std::equal( - this->inputs().begin(), - this->inputs().end(), - other.inputs().begin(), - other.inputs().end(), + this->inputs().cbegin(), + this->inputs().cend(), + other.inputs().cbegin(), + other.inputs().cend(), [](const auto& p1, const auto& p2) { return *p1 == *p2; }); } folly::dynamic serialize() const override; static TypedExprPtr create(const folly::dynamic& obj, void* context); - - private: - static TypePtr toType( - const std::vector& names, - const std::vector& expressions) { - std::vector children{}; - std::vector namesCopy{}; - for (size_t i = 0; i < names.size(); ++i) { - namesCopy.push_back(names.at(i)); - children.push_back(expressions.at(i)->type()); - } - return ROW(std::move(namesCopy), std::move(children)); - } }; using ConcatTypedExprPtr = std::shared_ptr; @@ -550,9 +471,11 @@ using ConcatTypedExprPtr = std::shared_ptr; class LambdaTypedExpr : public ITypedExpr { public: LambdaTypedExpr(RowTypePtr signature, TypedExprPtr body) - : ITypedExpr(std::make_shared( - std::vector(signature->children()), - body->type())), + : ITypedExpr( + ExprKind::kLambda, + std::make_shared( + std::vector(signature->children()), + body->type())), signature_(std::move(signature)), body_(std::move(body)) {} @@ -566,15 +489,7 @@ class LambdaTypedExpr : public ITypedExpr { TypedExprPtr rewriteInputNames( const std::unordered_map& mapping) - const override { - for (const auto& name : signature_->names()) { - if (mapping.count(name)) { - VELOX_USER_FAIL("Ambiguous variable: {}", name); - } - } - return std::make_shared( - signature_, body_->rewriteInputNames(mapping)); - } + const override; std::string toString() const override { return fmt::format( @@ -623,18 +538,15 @@ class CastTypedExpr : public ITypedExpr { /// expresion. /// @param input Single input. The type of input is referred to as from-type /// and expected to be different from to-type. - /// @param nullOnFailure Whether to suppress cast errors and return null. - CastTypedExpr( - const TypePtr& type, - const TypedExprPtr& input, - bool nullOnFailure) - : ITypedExpr{type, {input}}, nullOnFailure_(nullOnFailure) {} + /// @param isTryCast Whether this expression is used for `try_cast`. + CastTypedExpr(const TypePtr& type, const TypedExprPtr& input, bool isTryCast) + : ITypedExpr{ExprKind::kCast, type, {input}}, isTryCast_(isTryCast) {} CastTypedExpr( const TypePtr& type, const std::vector& inputs, - bool nullOnFailure) - : ITypedExpr{type, inputs}, nullOnFailure_(nullOnFailure) { + bool isTryCast) + : ITypedExpr{ExprKind::kCast, type, inputs}, isTryCast_(isTryCast) { VELOX_USER_CHECK_EQ( 1, inputs.size(), "Cast expression requires exactly one input"); } @@ -643,22 +555,14 @@ class CastTypedExpr : public ITypedExpr { const std::unordered_map& mapping) const override { return std::make_shared( - type(), rewriteInputsRecursive(mapping), nullOnFailure_); + type(), rewriteInputsRecursive(mapping), isTryCast_); } - std::string toString() const override { - if (nullOnFailure_) { - return fmt::format( - "try_cast {} as {}", inputs()[0]->toString(), type()->toString()); - } else { - return fmt::format( - "cast {} as {}", inputs()[0]->toString(), type()->toString()); - } - } + std::string toString() const override; size_t localHash() const override { static const size_t kBaseHash = std::hash()("CastTypedExpr"); - return bits::hashMix(kBaseHash, std::hash()(nullOnFailure_)); + return bits::hashMix(kBaseHash, std::hash()(isTryCast_)); } void accept( @@ -672,15 +576,15 @@ class CastTypedExpr : public ITypedExpr { } if (inputs().empty()) { return type() == otherCast->type() && otherCast->inputs().empty() && - nullOnFailure_ == otherCast->nullOnFailure(); + isTryCast_ == otherCast->isTryCast(); } return *type() == *otherCast->type() && *inputs()[0] == *otherCast->inputs()[0] && - nullOnFailure_ == otherCast->nullOnFailure(); + isTryCast_ == otherCast->isTryCast(); } - bool nullOnFailure() const { - return nullOnFailure_; + bool isTryCast() const { + return isTryCast_; } folly::dynamic serialize() const override; @@ -688,8 +592,9 @@ class CastTypedExpr : public ITypedExpr { static TypedExprPtr create(const folly::dynamic& obj, void* context); private: - // Suppress exception and return null on failure to cast. - const bool nullOnFailure_; + // Whether this expression is used for `try_cast`. When true, Presto cast + // suppresses exception and return null on failure to case. + const bool isTryCast_; }; using CastTypedExprPtr = std::shared_ptr; @@ -699,7 +604,7 @@ class TypedExprs { public: /// Returns true if 'expr' is a field access expression. static bool isFieldAccess(const TypedExprPtr& expr) { - return dynamic_cast(expr.get()) != nullptr; + return expr->isFieldAccessKind(); } /// Returns 'expr' as FieldAccessTypedExprPtr or null if not field access @@ -710,7 +615,7 @@ class TypedExprs { /// Returns true if 'expr' is a constant expression. static bool isConstant(const TypedExprPtr& expr) { - return dynamic_cast(expr.get()) != nullptr; + return expr->isConstantKind(); } /// Returns 'expr' as ConstantTypedExprPtr or null if not a constant @@ -721,7 +626,7 @@ class TypedExprs { /// Returns true if 'expr' is a lambda expression. static bool isLambda(const TypedExprPtr& expr) { - return dynamic_cast(expr.get()) != nullptr; + return expr->isLambdaKind(); } /// Returns 'expr' as LambdaTypedExprPtr or null if not a lambda expression. diff --git a/velox/core/ITypedExpr.cpp b/velox/core/ITypedExpr.cpp new file mode 100644 index 000000000000..f0b11d8af809 --- /dev/null +++ b/velox/core/ITypedExpr.cpp @@ -0,0 +1,48 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/core/ITypedExpr.h" + +namespace facebook::velox::core { + +namespace { +const auto& exprKindNames() { + static const folly::F14FastMap kNames = { + {ExprKind::kInput, "INPUT"}, + {ExprKind::kFieldAccess, "FIELD"}, + {ExprKind::kDereference, "DEREFERENCE"}, + {ExprKind::kCall, "CALL"}, + {ExprKind::kCast, "CAST"}, + {ExprKind::kConstant, "CONSTANT"}, + {ExprKind::kConcat, "CONCAT"}, + {ExprKind::kLambda, "LAMBDA"}, + }; + return kNames; +} +} // namespace + +VELOX_DEFINE_ENUM_NAME(ExprKind, exprKindNames); + +size_t ITypedExprHasher::operator()(const ITypedExpr* expr) const { + return expr->hash(); +} + +bool ITypedExprComparer::operator()( + const ITypedExpr* lhs, + const ITypedExpr* rhs) const { + return *lhs == *rhs; +} +} // namespace facebook::velox::core diff --git a/velox/core/ITypedExpr.h b/velox/core/ITypedExpr.h index b0f96623aaf1..b1e705ff7f58 100644 --- a/velox/core/ITypedExpr.h +++ b/velox/core/ITypedExpr.h @@ -19,30 +19,93 @@ namespace facebook::velox::core { +enum class ExprKind : int32_t { + kInput = 0, + kFieldAccess = 1, + kDereference = 2, + kCall = 3, + kCast = 5, + kConstant = 6, + kConcat = 7, + kLambda = 8, +}; + +VELOX_DECLARE_ENUM_NAME(ExprKind); + class ITypedExpr; class ITypedExprVisitor; class ITypedExprVisitorContext; using TypedExprPtr = std::shared_ptr; +struct ITypedExprHasher { + size_t operator()(const ITypedExpr* expr) const; +}; + +struct ITypedExprComparer { + bool operator()(const ITypedExpr* lhs, const ITypedExpr* rhs) const; +}; + /// Strongly-typed expression, e.g. literal, function call, etc. class ITypedExpr : public ISerializable { public: - explicit ITypedExpr(TypePtr type) : type_{std::move(type)}, inputs_{} {} + ITypedExpr(ExprKind kind, TypePtr type) + : kind_{kind}, type_{std::move(type)}, inputs_{} {} + + ITypedExpr(ExprKind kind, TypePtr type, std::vector inputs) + : kind_{kind}, type_{std::move(type)}, inputs_{std::move(inputs)} {} - ITypedExpr(TypePtr type, std::vector inputs) - : type_{std::move(type)}, inputs_{std::move(inputs)} {} + virtual ~ITypedExpr() = default; + + ExprKind kind() const { + return kind_; + } const TypePtr& type() const { return type_; } - virtual ~ITypedExpr() = default; - const std::vector& inputs() const { return inputs_; } + bool isInputKind() const { + return kind_ == ExprKind::kInput; + } + + bool isFieldAccessKind() const { + return kind_ == ExprKind::kFieldAccess; + } + + bool isDereferenceKind() const { + return kind_ == ExprKind::kDereference; + } + + bool isCallKind() const { + return kind_ == ExprKind::kCall; + } + + bool isCastKind() const { + return kind_ == ExprKind::kCast; + } + + bool isConstantKind() const { + return kind_ == ExprKind::kConstant; + } + + bool isConcatKind() const { + return kind_ == ExprKind::kConcat; + } + + bool isLambdaKind() const { + return kind_ == ExprKind::kLambda; + } + + template + const T* asUnchecked() const { + return dynamic_cast(this); + } + /// Returns a copy of this expression with input fields replaced according /// to specified 'mapping'. Fields specified in the 'mapping' are replaced /// by the corresponding expression in 'mapping'. @@ -64,7 +127,7 @@ class ITypedExpr : public ISerializable { size_t hash() const { size_t hash = bits::hashMix(type_->hashKind(), localHash()); - for (int32_t i = 0; i < inputs_.size(); ++i) { + for (size_t i = 0; i < inputs_.size(); ++i) { hash = bits::hashMix(hash, inputs_[i]->hash()); } return hash; @@ -88,8 +151,9 @@ class ITypedExpr : public ISerializable { } private: - TypePtr type_; - std::vector inputs_; + const ExprKind kind_; + const TypePtr type_; + const std::vector inputs_; }; } // namespace facebook::velox::core diff --git a/velox/core/Metaprogramming.h b/velox/core/Metaprogramming.h index 955ea4122b46..95906ef21406 100644 --- a/velox/core/Metaprogramming.h +++ b/velox/core/Metaprogramming.h @@ -124,11 +124,10 @@ template struct has_method { private: template - static constexpr auto check(T*) -> - typename std::is_same< - decltype(std::declval().template resolve( - std::declval()...)), - TRet>::type { + static constexpr auto check(T*) -> typename std::is_same< + decltype(std::declval().template resolve( + std::declval()...)), + TRet>::type { return {}; } diff --git a/velox/core/PlanConsistencyChecker.cpp b/velox/core/PlanConsistencyChecker.cpp new file mode 100644 index 000000000000..fa20ce268aee --- /dev/null +++ b/velox/core/PlanConsistencyChecker.cpp @@ -0,0 +1,314 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/core/PlanConsistencyChecker.h" + +namespace facebook::velox::core { + +namespace { + +class Checker : public PlanNodeVisitor { + public: + void visit(const AggregationNode& node, PlanNodeVisitorContext& ctx) + const override { + const auto& rowType = node.sources().at(0)->outputType(); + for (const auto& expr : node.groupingKeys()) { + checkInputs(expr, rowType); + } + + for (const auto& expr : node.preGroupedKeys()) { + checkInputs(expr, rowType); + } + + for (const auto& aggregate : node.aggregates()) { + checkInputs(aggregate.call, rowType); + + for (const auto& expr : aggregate.sortingKeys) { + checkInputs(expr, rowType); + } + + if (aggregate.mask) { + checkInputs(aggregate.mask, rowType); + } + } + + verifyOutputNames(node); + + visitSources(&node, ctx); + } + + void visit(const ArrowStreamNode& node, PlanNodeVisitorContext& ctx) + const override { + visitSources(&node, ctx); + } + + void visit(const AssignUniqueIdNode& node, PlanNodeVisitorContext& ctx) + const override { + visitSources(&node, ctx); + } + + void visit(const EnforceSingleRowNode& node, PlanNodeVisitorContext& ctx) + const override { + visitSources(&node, ctx); + } + + void visit(const ExchangeNode& node, PlanNodeVisitorContext& ctx) + const override { + visitSources(&node, ctx); + } + + void visit(const ExpandNode& node, PlanNodeVisitorContext& ctx) + const override { + visitSources(&node, ctx); + } + + void visit(const FilterNode& node, PlanNodeVisitorContext& ctx) + const override { + checkInputs(node.filter(), node.sources().at(0)->outputType()); + + visitSources(&node, ctx); + } + + void visit(const GroupIdNode& node, PlanNodeVisitorContext& ctx) + const override { + visitSources(&node, ctx); + } + + void visit(const HashJoinNode& node, PlanNodeVisitorContext& ctx) + const override { + std::unordered_set> keyNames; + for (auto i = 0; i < node.leftKeys().size(); ++i) { + const auto& leftKey = node.leftKeys().at(i); + const auto& rightKey = node.rightKeys().at(i); + + bool unique = keyNames.emplace(leftKey->name(), rightKey->name()).second; + VELOX_CHECK( + unique, + "Duplicate join condition: {} = {}", + leftKey->toString(), + rightKey->toString()); + } + + if (node.filter() != nullptr) { + const auto& leftRowType = node.sources().at(0)->outputType(); + const auto& rightRowType = node.sources().at(1)->outputType(); + auto rowType = leftRowType->unionWith(rightRowType); + checkInputs(node.filter(), rowType); + } + + visitSources(&node, ctx); + } + + void visit(const IndexLookupJoinNode& node, PlanNodeVisitorContext& ctx) + const override { + visitSources(&node, ctx); + } + + void visit(const LimitNode& node, PlanNodeVisitorContext& ctx) + const override { + visitSources(&node, ctx); + } + + void visit(const LocalMergeNode& node, PlanNodeVisitorContext& ctx) + const override { + visitSources(&node, ctx); + } + + void visit(const LocalPartitionNode& node, PlanNodeVisitorContext& ctx) + const override { + visitSources(&node, ctx); + } + + void visit(const MarkDistinctNode& node, PlanNodeVisitorContext& ctx) + const override { + visitSources(&node, ctx); + } + + void visit(const MergeExchangeNode& node, PlanNodeVisitorContext& ctx) + const override { + visitSources(&node, ctx); + } + + void visit(const MergeJoinNode& node, PlanNodeVisitorContext& ctx) + const override { + visitSources(&node, ctx); + } + + void visit(const NestedLoopJoinNode& node, PlanNodeVisitorContext& ctx) + const override { + visitSources(&node, ctx); + } + + void visit(const SpatialJoinNode& node, PlanNodeVisitorContext& ctx) + const override { + visitSources(&node, ctx); + } + + void visit(const OrderByNode& node, PlanNodeVisitorContext& ctx) + const override { + visitSources(&node, ctx); + } + + void visit(const PartitionedOutputNode& node, PlanNodeVisitorContext& ctx) + const override { + visitSources(&node, ctx); + } + + void visit(const ProjectNode& node, PlanNodeVisitorContext& ctx) + const override { + const auto& rowType = node.sources().at(0)->outputType(); + for (const auto& expr : node.projections()) { + checkInputs(expr, rowType); + } + + verifyOutputNames(node); + + visitSources(&node, ctx); + } + + void visit(const ParallelProjectNode& node, PlanNodeVisitorContext& ctx) + const override { + const auto& rowType = node.sources().at(0)->outputType(); + for (const auto& group : node.exprGroups()) { + for (const auto& expr : group) { + checkInputs(expr, rowType); + } + } + visitSources(&node, ctx); + } + + void visit(const RowNumberNode& node, PlanNodeVisitorContext& ctx) + const override { + visitSources(&node, ctx); + } + + void visit(const TableScanNode& node, PlanNodeVisitorContext& ctx) + const override { + verifyOutputNames(node); + + // Verify assignments match outputType 1:1. + const auto& names = node.outputType()->names(); + VELOX_USER_CHECK_EQ( + names.size(), + node.assignments().size(), + "Column assignments must match output type 1:1."); + + for (const auto& name : names) { + VELOX_USER_CHECK( + node.assignments().contains(name), + "Column assignment is missing for {}", + name); + } + } + + void visit(const TableWriteNode& node, PlanNodeVisitorContext& ctx) + const override { + visitSources(&node, ctx); + } + + void visit(const TableWriteMergeNode& node, PlanNodeVisitorContext& ctx) + const override { + visitSources(&node, ctx); + } + + void visit(const TopNNode& node, PlanNodeVisitorContext& ctx) const override { + visitSources(&node, ctx); + } + + void visit(const TopNRowNumberNode& node, PlanNodeVisitorContext& ctx) + const override { + visitSources(&node, ctx); + } + + void visit(const TraceScanNode& node, PlanNodeVisitorContext& ctx) + const override { + visitSources(&node, ctx); + } + + void visit(const UnnestNode& node, PlanNodeVisitorContext& ctx) + const override { + visitSources(&node, ctx); + } + + void visit(const ValuesNode& node, PlanNodeVisitorContext& ctx) + const override { + visitSources(&node, ctx); + } + + void visit(const WindowNode& node, PlanNodeVisitorContext& ctx) + const override { + visitSources(&node, ctx); + } + + void visit(const PlanNode& node, PlanNodeVisitorContext& ctx) const override { + visitSources(&node, ctx); + } + + private: + void visitSources(const PlanNode* node, PlanNodeVisitorContext& ctx) const { + for (auto& source : node->sources()) { + source->accept(*this, ctx); + } + } + + // Verify that output column names are not empty and unique. + static void verifyOutputNames(const PlanNode& node) { + folly::F14FastSet names; + for (const auto& name : node.outputType()->names()) { + VELOX_USER_CHECK(!name.empty(), "Output column name cannot be empty"); + VELOX_USER_CHECK( + names.emplace(name).second, "Duplicate output column: {}", name); + } + } + + static void checkInputs( + const core::TypedExprPtr& expr, + const RowTypePtr& rowType) { + if (expr->isFieldAccessKind()) { + auto fieldAccess = expr->asUnchecked(); + if (fieldAccess->isInputColumn()) { + // Verify that field name points to an existing column in the input and + // the type matches. + const auto& name = fieldAccess->name(); + const auto& type = fieldAccess->type(); + const auto& expectedType = rowType->findChild(fieldAccess->name()); + VELOX_USER_CHECK( + *type == *expectedType, + "Wrong type of input column: {}, {} vs. {}", + name, + type->toString(), + expectedType->toString()); + } + } + + if (expr->isLambdaKind()) { + const auto& lambda = expr->asUnchecked(); + checkInputs(lambda->body(), lambda->signature()->unionWith(rowType)); + } + + for (const auto& input : expr->inputs()) { + checkInputs(input, rowType); + } + } +}; +} // namespace + +void PlanConsistencyChecker::check(const core::PlanNodePtr& plan) { + PlanNodeVisitorContext ctx; + Checker checker; + plan->accept(checker, ctx); +} +}; // namespace facebook::velox::core diff --git a/velox/core/PlanConsistencyChecker.h b/velox/core/PlanConsistencyChecker.h new file mode 100644 index 000000000000..1268a3cc65f5 --- /dev/null +++ b/velox/core/PlanConsistencyChecker.h @@ -0,0 +1,30 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "velox/core/PlanNode.h" + +namespace facebook::velox::core { + +/// Verify integrity of the plan. Specifically, verify that expressions used in +/// Filter, Project and ParallelProject nodes reference input columns using +/// valid names and types. +class PlanConsistencyChecker { + public: + static void check(const core::PlanNodePtr& plan); +}; +} // namespace facebook::velox::core diff --git a/velox/core/PlanFragment.cpp b/velox/core/PlanFragment.cpp index cfed635dc1c0..1799ca8be1ec 100644 --- a/velox/core/PlanFragment.cpp +++ b/velox/core/PlanFragment.cpp @@ -28,11 +28,10 @@ bool PlanFragment::canSpill(const QueryConfig& queryConfig) const { }) != nullptr; } -bool PlanFragment::supportsBarrier() const { +const PlanNode* PlanFragment::firstNodeNotSupportingBarrier() const { return PlanNode::findFirstNode( - planNode.get(), [&](const core::PlanNode* node) { - return !node->supportsBarrier(); - }) == nullptr; + planNode.get(), + [&](const core::PlanNode* node) { return !node->supportsBarrier(); }); } std::string executionStrategyToString(ExecutionStrategy strategy) { diff --git a/velox/core/PlanFragment.h b/velox/core/PlanFragment.h index 98cd64f0405b..4724a0d62eec 100644 --- a/velox/core/PlanFragment.h +++ b/velox/core/PlanFragment.h @@ -62,8 +62,9 @@ struct PlanFragment { groupedExecutionLeafNodeIds.end(); } - /// Returns true if all plan nodes support barrier. - bool supportsBarrier() const; + /// Returns first node that does not support barrier. + /// Returns nullptr if all nodes support barrier. + const PlanNode* firstNodeNotSupportingBarrier() const; /// Returns true if the spilling is enabled and there is at least one node in /// the plan, whose operator can spill. Returns false otherwise. diff --git a/velox/core/PlanNode.cpp b/velox/core/PlanNode.cpp index d4e94654f614..2f0e0a7aff42 100644 --- a/velox/core/PlanNode.cpp +++ b/velox/core/PlanNode.cpp @@ -15,12 +15,12 @@ */ #include +#include "velox/common/Casts.h" #include "velox/common/encode/Base64.h" #include "velox/core/PlanNode.h" #include "velox/vector/VectorSaver.h" namespace facebook::velox::core { - namespace { void appendComma(int32_t i, std::stringstream& sql) { @@ -51,12 +51,19 @@ IndexLookupConditionPtr createIndexJoinCondition( if (obj["type"] == "between") { return BetweenIndexLookupCondition::create(obj, context); } + if (obj["type"] == "equal") { + return EqualIndexLookupCondition::create(obj, context); + } VELOX_USER_FAIL( "Unknown index join condition type {}", obj["type"].asString()); } } // namespace -std::vector deserializeJoinConditions( +/// Deserializes lookup conditions from dynamic object for index lookup joins. +/// These conditions are more complex than simple equality join conditions and +/// can include IN, BETWEEN, and EQUAL conditions that involve both left and +/// right side columns. +std::vector deserializejoinConditions( const folly::dynamic& obj, void* context) { if (obj.count("joinConditions") == 0) { @@ -81,6 +88,15 @@ PlanNodePtr deserializeSingleSource(const folly::dynamic& obj, void* context) { PlanNodeId deserializePlanNodeId(const folly::dynamic& obj) { return obj["id"].asString(); } + +template +std::unordered_map invertMap(const std::unordered_map& mapping) { + std::unordered_map inverted; + for (const auto& [key, value] : mapping) { + inverted.emplace(value, key); + } + return inverted; +} } // namespace const SortOrder kAscNullsFirst(true, true); @@ -150,6 +166,7 @@ AggregationNode::AggregationNode( const std::vector& globalGroupingSets, const std::optional& groupId, bool ignoreNullKeys, + bool noGroupsSpanBatches, PlanNodePtr source) : PlanNode(id), step_(step), @@ -160,6 +177,7 @@ AggregationNode::AggregationNode( ignoreNullKeys_(ignoreNullKeys), groupId_(groupId), globalGroupingSets_(globalGroupingSets), + noGroupsSpanBatches_(noGroupsSpanBatches), sources_{source}, outputType_(getAggregationOutputType( groupingKeys_, @@ -204,6 +222,10 @@ AggregationNode::AggregationNode( VELOX_USER_CHECK( groupId_.has_value(), "Global grouping sets require GroupId key"); } + + VELOX_USER_CHECK( + !noGroupsSpanBatches_ || isPreGrouped(), + "noGroupsSpanBatches can only be set for streaming aggregation (pre-grouped)"); } AggregationNode::AggregationNode( @@ -214,6 +236,7 @@ AggregationNode::AggregationNode( const std::vector& aggregateNames, const std::vector& aggregates, bool ignoreNullKeys, + bool noGroupsSpanBatches, PlanNodePtr source) : AggregationNode( id, @@ -225,6 +248,7 @@ AggregationNode::AggregationNode( kDefaultGlobalGroupingSets, kDefaultGroupId, ignoreNullKeys, + noGroupsSpanBatches, source) {} namespace { @@ -281,7 +305,7 @@ bool AggregationNode::canSpill(const QueryConfig& queryConfig) const { } void AggregationNode::addDetails(std::stringstream& stream) const { - stream << stepName(step_) << " "; + stream << toName(step_) << " "; if (isPreGrouped()) { stream << "STREAMING "; @@ -319,47 +343,31 @@ void AggregationNode::addDetails(std::stringstream& stream) const { if (groupId_.has_value()) { stream << " Group Id key: " << groupId_.value()->name(); } + + if (noGroupsSpanBatches_) { + stream << " noGroupsSpanBatches"; + } } namespace { -std::unordered_map stepNames() { - return { - {AggregationNode::Step::kPartial, "PARTIAL"}, - {AggregationNode::Step::kFinal, "FINAL"}, - {AggregationNode::Step::kIntermediate, "INTERMEDIATE"}, - {AggregationNode::Step::kSingle, "SINGLE"}, - }; +const auto& stepNames() { + static const folly::F14FastMap + kNames = { + {AggregationNode::Step::kPartial, "PARTIAL"}, + {AggregationNode::Step::kFinal, "FINAL"}, + {AggregationNode::Step::kIntermediate, "INTERMEDIATE"}, + {AggregationNode::Step::kSingle, "SINGLE"}, + }; + return kNames; } -template -std::unordered_map invertMap(const std::unordered_map& mapping) { - std::unordered_map inverted; - for (const auto& [key, value] : mapping) { - inverted.emplace(value, key); - } - return inverted; -} } // namespace -// static -const char* AggregationNode::stepName(AggregationNode::Step step) { - static const auto kSteps = stepNames(); - auto it = kSteps.find(step); - VELOX_CHECK(it != kSteps.end(), "Invalid step {}", static_cast(step)); - return it->second.c_str(); -} - -// static -AggregationNode::Step AggregationNode::stepFromName(const std::string& name) { - static const auto kSteps = invertMap(stepNames()); - auto it = kSteps.find(name); - VELOX_CHECK(it != kSteps.end(), "Invalid step " + name); - return it->second; -} +VELOX_DEFINE_EMBEDDED_ENUM_NAME(AggregationNode, Step, stepNames) folly::dynamic AggregationNode::serialize() const { auto obj = PlanNode::serialize(); - obj["step"] = stepName(step_); + obj["step"] = toName(step_); obj["groupingKeys"] = ISerializable::serialize(groupingKeys_); obj["preGroupedKeys"] = ISerializable::serialize(preGroupedKeys_); obj["aggregateNames"] = ISerializable::serialize(aggregateNames_); @@ -377,6 +385,7 @@ folly::dynamic AggregationNode::serialize() const { obj["groupId"] = ISerializable::serialize(groupId_.value()); } obj["ignoreNullKeys"] = ignoreNullKeys_; + obj["noGroupsSpanBatches"] = noGroupsSpanBatches_; return obj; } @@ -394,6 +403,10 @@ std::vector deserializeFields( array, context); } +FieldAccessTypedExprPtr deserializeField(const folly::dynamic& obj) { + return ISerializable::deserialize(obj); +} + std::vector deserializeStrings(const folly::dynamic& array) { return ISerializable::deserialize>(array); } @@ -439,12 +452,13 @@ folly::dynamic AggregationNode::Aggregate::serialize() const { AggregationNode::Aggregate AggregationNode::Aggregate::deserialize( const folly::dynamic& obj, void* context) { - auto call = ISerializable::deserialize(obj["call"]); + auto call = ISerializable::deserialize(obj["call"], context); auto rawInputTypes = ISerializable::deserialize>(obj["rawInputTypes"]); FieldAccessTypedExprPtr mask; if (obj.count("mask")) { - mask = ISerializable::deserialize(obj["mask"]); + mask = + ISerializable::deserialize(obj["mask"], context); } auto sortingKeys = deserializeFields(obj["sortingKeys"], context); auto sortingOrders = deserializeSortingOrders(obj["sortingOrders"]); @@ -484,7 +498,7 @@ PlanNodePtr AggregationNode::create(const folly::dynamic& obj, void* context) { return std::make_shared( deserializePlanNodeId(obj), - stepFromName(obj["step"].asString()), + toStep(obj["step"].asString()), groupingKeys, preGroupedKeys, aggregateNames, @@ -492,6 +506,8 @@ PlanNodePtr AggregationNode::create(const folly::dynamic& obj, void* context) { globalGroupingSets, groupId, obj["ignoreNullKeys"].asBool(), + obj.count("noGroupsSpanBatches") ? obj["noGroupsSpanBatches"].asBool() + : false, deserializeSingleSource(obj, context)); } @@ -701,7 +717,8 @@ PlanNodePtr GroupIdNode::create(const folly::dynamic& obj, void* context) { for (const auto& info : obj["groupingKeyInfos"]) { groupingKeyInfos.push_back( {info["output"].asString(), - ISerializable::deserialize(info["input"])}); + ISerializable::deserialize( + info["input"], context)}); } auto groupingSets = @@ -783,6 +800,22 @@ PlanNodePtr ValuesNode::create(const folly::dynamic& obj, void* context) { obj["repeatTimes"].asInt()); } +// static +RowTypePtr AbstractProjectNode::makeOutputType( + const std::vector& names, + const std::vector& projections) { + VELOX_USER_CHECK_EQ(names.size(), projections.size()); + + std::vector types; + types.reserve(projections.size()); + for (const auto& projection : projections) { + types.push_back(projection->type()); + } + + auto namesCopy = names; + return ROW(std::move(namesCopy), std::move(types)); +} + void AbstractProjectNode::addDetails(std::stringstream& stream) const { stream << "expressions: "; for (auto i = 0; i < projections_.size(); i++) { @@ -863,7 +896,11 @@ class SummarizeExprVisitor : public ITypedExprVisitor { void visit(const FieldAccessTypedExpr& expr, ITypedExprVisitorContext& ctx) const override { auto& myCtx = static_cast(ctx); - myCtx.expressionCounts()["field"]++; + if (expr.isInputColumn()) { + myCtx.expressionCounts()["field"]++; + } else { + myCtx.expressionCounts()["dereference"]++; + } visitInputs(expr, ctx); } @@ -992,6 +1029,12 @@ void AbstractProjectNode::addSummaryDetails( SummarizeExprVisitor::Context exprCtx; SummarizeExprVisitor visitor; for (const auto& projection : projections_) { + // Skip identity projections. + if (projection->isFieldAccessKind() && + projection->asUnchecked()->isInputColumn()) { + continue; + } + projection->accept(visitor, exprCtx); } @@ -1006,12 +1049,17 @@ void AbstractProjectNode::addSummaryDetails( std::vector dereferences; dereferences.reserve(numFields); + std::vector constants; + constants.reserve(numFields); + for (auto i = 0; i < numFields; ++i) { const auto& expr = projections_[i]; - if (dynamic_cast(expr.get())) { + if (expr->isDereferenceKind()) { dereferences.push_back(i); + } else if (expr->isConstantKind()) { + constants.push_back(i); } else { - auto fae = dynamic_cast(expr.get()); + auto fae = expr->asUnchecked(); if (fae == nullptr) { projections.push_back(i); } else if (!fae->isInputColumn()) { @@ -1021,33 +1069,53 @@ void AbstractProjectNode::addSummaryDetails( } // projections: 4 out of 10 - stream << indentation << "projections: " << projections.size() << " out of " - << numFields << std::endl; - { - const auto cnt = - std::min(options.project.maxProjections, projections.size()); - appendProjections( - indentation + " ", - *this, - projections, - cnt, - stream, - options.maxLength); + if (!projections.empty()) { + stream << indentation << "projections: " << projections.size() << " out of " + << numFields << std::endl; + { + const auto cnt = + std::min(options.project.maxProjections, projections.size()); + appendProjections( + indentation + " ", + *this, + projections, + cnt, + stream, + options.maxLength); + } } // dereferences: 2 out of 10 - stream << indentation << "dereferences: " << dereferences.size() << " out of " - << numFields << std::endl; - { - const auto cnt = - std::min(options.project.maxDereferences, dereferences.size()); - appendProjections( - indentation + " ", - *this, - dereferences, - cnt, - stream, - options.maxLength); + if (!dereferences.empty()) { + stream << indentation << "dereferences: " << dereferences.size() + << " out of " << numFields << std::endl; + { + const auto cnt = + std::min(options.project.maxDereferences, dereferences.size()); + appendProjections( + indentation + " ", + *this, + dereferences, + cnt, + stream, + options.maxLength); + } + } + + // constants: 1 out of 10 + if (!constants.empty()) { + stream << indentation << "constant projections: " << constants.size() + << " out of " << numFields << std::endl; + { + const auto cnt = std::min(options.project.maxConstants, constants.size()); + appendProjections( + indentation + " ", + *this, + constants, + cnt, + stream, + options.maxLength); + } } } @@ -1078,6 +1146,122 @@ PlanNodePtr ProjectNode::create(const folly::dynamic& obj, void* context) { std::move(source)); } +namespace { +// makes a list of all names for use in the ProjectNode. +std::vector allNames( + const std::vector& names, + const std::vector& moreNames) { + auto result = names; + result.insert(result.cend(), moreNames.cbegin(), moreNames.cend()); + return result; +} + +// Flattens out projection exprs and adds dummy exprs for noLoadIdentities. +// Used to fill in ProjectNode members for use in the summary functions. +std::vector flattenExprs( + const std::vector>& exprs, + const std::vector& moreNames, + const PlanNodePtr& input) { + std::vector result; + for (auto& group : exprs) { + result.insert(result.cend(), group.cbegin(), group.cend()); + } + + const auto& sourceType = input->outputType(); + for (auto& name : moreNames) { + result.push_back( + std::make_shared( + sourceType->findChild(name), name)); + } + return result; +} +} // namespace + +ParallelProjectNode::ParallelProjectNode( + const PlanNodeId& id, + std::vector names, + std::vector> exprGroups, + std::vector noLoadIdentities, + PlanNodePtr input) + : AbstractProjectNode( + id, + allNames(names, noLoadIdentities), + flattenExprs(exprGroups, noLoadIdentities, input), + input), + exprNames_(std::move(names)), + exprGroups_(std::move(exprGroups)), + noLoadIdentities_(std::move(noLoadIdentities)) { + VELOX_USER_CHECK(!exprNames_.empty()); + VELOX_USER_CHECK(!exprGroups_.empty()); + + for (const auto& group : exprGroups_) { + VELOX_USER_CHECK(!group.empty()); + } +} + +void ParallelProjectNode::addDetails(std::stringstream& stream) const { + AbstractProjectNode::addDetails(stream); + stream << " Parallel expr groups: "; + int32_t start = 0; + for (auto i = 0; i < exprGroups_.size(); ++i) { + if (i > 0) { + stream << ", "; + } + stream << fmt::format("[{}-{}]", start, start + exprGroups_[i].size() - 1); + start += exprGroups_[i].size(); + } + stream << std::endl; +} + +folly::dynamic ParallelProjectNode::serialize() const { + auto obj = PlanNode::serialize(); + obj["names"] = ISerializable::serialize(exprNames_); + obj["projections"] = ISerializable::serialize(exprGroups_); + obj["noLoadIdentities"] = ISerializable::serialize(noLoadIdentities_); + return obj; +} + +void ParallelProjectNode::accept( + const PlanNodeVisitor& visitor, + PlanNodeVisitorContext& context) const { + visitor.visit(*this, context); +} + +// static +PlanNodePtr ParallelProjectNode::create( + const folly::dynamic& obj, + void* context) { + auto source = deserializeSingleSource(obj, context); + + auto names = deserializeStrings(obj["names"]); + auto projections = + ISerializable::deserialize>>( + obj["projections"], context); + auto noLoadIdentities = deserializeStrings(obj["noLoadIdentities"]); + return std::make_shared( + deserializePlanNodeId(obj), + std::move(names), + std::move(projections), + std::move(noLoadIdentities), + std::move(source)); +} + +// static +PlanNodePtr LazyDereferenceNode::create( + const folly::dynamic& obj, + void* context) { + auto source = deserializeSingleSource(obj, context); + + auto names = deserializeStrings(obj["names"]); + auto projections = ISerializable::deserialize>( + obj["projections"], context); + return std::make_shared( + deserializePlanNodeId(obj), + std::move(names), + std::move(projections), + std::move(source)); +} + const std::vector& TableScanNode::sources() const { return kEmptySources; } @@ -1092,6 +1276,13 @@ void TableScanNode::addDetails(std::stringstream& stream) const { stream << tableHandle_->toString(); } +void TableScanNode::addSummaryDetails( + const std::string& indentation, + const PlanSummaryOptions& /* options */, + std::stringstream& stream) const { + stream << indentation << tableHandle_->toString() << std::endl; +} + folly::dynamic TableScanNode::serialize() const { auto obj = PlanNode::serialize(); obj["outputType"] = outputType_->serialize(); @@ -1111,18 +1302,16 @@ folly::dynamic TableScanNode::serialize() const { PlanNodePtr TableScanNode::create(const folly::dynamic& obj, void* context) { auto planNodeId = obj["id"].asString(); auto outputType = deserializeRowType(obj["outputType"]); - auto tableHandle = std::const_pointer_cast( + auto tableHandle = ISerializable::deserialize( - obj["tableHandle"], context)); + obj["tableHandle"], context); - std::unordered_map> - assignments; + connector::ColumnHandleMap assignments; for (const auto& pair : obj["assignments"]) { auto assign = pair["assign"].asString(); auto columnHandle = ISerializable::deserialize( - pair["columnHandle"]); - assignments[assign] = - std::const_pointer_cast(columnHandle); + pair["columnHandle"], context); + assignments[assign] = std::move(columnHandle); } return std::make_shared( @@ -1178,12 +1367,14 @@ UnnestNode::UnnestNode( std::vector unnestVariables, std::vector unnestNames, std::optional ordinalityName, + std::optional markerName, const PlanNodePtr& source) : PlanNode(id), replicateVariables_{std::move(replicateVariables)}, unnestVariables_{std::move(unnestVariables)}, unnestNames_{std::move(unnestNames)}, ordinalityName_{std::move(ordinalityName)}, + markerName_(std::move(markerName)), sources_{source} { // Calculate output type. First come "replicate" columns, followed by // "unnest" columns, followed by an optional ordinality column. @@ -1219,6 +1410,12 @@ UnnestNode::UnnestNode( names.emplace_back(ordinalityName_.value()); types.emplace_back(BIGINT()); } + + if (markerName_.has_value()) { + names.emplace_back(markerName_.value()); + types.emplace_back(BOOLEAN()); + } + outputType_ = ROW(std::move(names), std::move(types)); } @@ -1235,6 +1432,9 @@ folly::dynamic UnnestNode::serialize() const { if (ordinalityName_.has_value()) { obj["ordinalityName"] = ordinalityName_.value(); } + if (markerName_.has_value()) { + obj["markerName"] = markerName_.value(); + } return obj; } @@ -1255,13 +1455,17 @@ PlanNodePtr UnnestNode::create(const folly::dynamic& obj, void* context) { if (obj.count("ordinalityName")) { ordinalityName = obj["ordinalityName"].asString(); } - + std::optional markerName = std::nullopt; + if (obj.count("markerName")) { + markerName = obj["markerName"].asString(); + } return std::make_shared( deserializePlanNodeId(obj), std::move(replicateVariables), std::move(unnestVariables), std::move(unnestNames), - ordinalityName, + std::move(ordinalityName), + std::move(markerName), std::move(source)); } @@ -1280,21 +1484,23 @@ AbstractJoinNode::AbstractJoinNode( rightKeys_(rightKeys), filter_(std::move(filter)), sources_({std::move(left), std::move(right)}), - outputType_(std::move(outputType)) { + outputType_(std::move(outputType)) {} + +void AbstractJoinNode::validate() const { VELOX_CHECK(!leftKeys_.empty(), "JoinNode requires at least one join key"); VELOX_CHECK_EQ( leftKeys_.size(), rightKeys_.size(), "JoinNode requires same number of join keys on left and right sides"); auto leftType = sources_[0]->outputType(); - for (auto key : leftKeys_) { + for (const auto& key : leftKeys_) { VELOX_CHECK( leftType->containsChild(key->name()), "Left side join key not found in left side output: {}", key->name()); } auto rightType = sources_[1]->outputType(); - for (auto key : rightKeys_) { + for (const auto& key : rightKeys_) { VELOX_CHECK( rightType->containsChild(key->name()), "Right side join key not found in right side output: {}", @@ -1307,32 +1513,29 @@ AbstractJoinNode::AbstractJoinNode( "Join key types on the left and right sides must match"); } - auto numOutputColumms = outputType_->size(); - if (core::isLeftSemiProjectJoin(joinType) || - core::isRightSemiProjectJoin(joinType)) { + auto numOutputColumns = outputType_->size(); + if (isLeftSemiProjectJoin() || isRightSemiProjectJoin()) { // Last output column must be a boolean 'match'. - --numOutputColumms; - VELOX_CHECK_EQ(outputType_->childAt(numOutputColumms), BOOLEAN()); + --numOutputColumns; + VELOX_CHECK_EQ(outputType_->childAt(numOutputColumns), BOOLEAN()); // Verify that 'match' column name doesn't match any column from left or // right source. - const auto& name = outputType_->nameOf(numOutputColumms); + const auto& name = outputType_->nameOf(numOutputColumns); VELOX_CHECK(!leftType->containsChild(name)); VELOX_CHECK(!rightType->containsChild(name)); } // Output of right semi join cannot include columns from the left side. bool outputMayIncludeLeftColumns = - !(core::isRightSemiFilterJoin(joinType) || - core::isRightSemiProjectJoin(joinType)); + !(isRightSemiFilterJoin() || isRightSemiProjectJoin()); // Output of left semi and anti joins cannot include columns from the right // side. bool outputMayIncludeRightColumns = - !(core::isLeftSemiFilterJoin(joinType) || - core::isLeftSemiProjectJoin(joinType) || core::isAntiJoin(joinType)); + !(isLeftSemiFilterJoin() || isLeftSemiProjectJoin() || isAntiJoin()); - for (auto i = 0; i < numOutputColumms; ++i) { + for (auto i = 0; i < numOutputColumns; ++i) { auto name = outputType_->nameOf(i); if (outputMayIncludeLeftColumns && leftType->containsChild(name)) { VELOX_CHECK( @@ -1353,7 +1556,7 @@ AbstractJoinNode::AbstractJoinNode( } void AbstractJoinNode::addDetails(std::stringstream& stream) const { - stream << joinTypeName(joinType_) << " "; + stream << JoinTypeName::toName(joinType_) << " "; for (auto i = 0; i < leftKeys_.size(); ++i) { if (i > 0) { @@ -1369,7 +1572,7 @@ void AbstractJoinNode::addDetails(std::stringstream& stream) const { folly::dynamic AbstractJoinNode::serializeBase() const { auto obj = PlanNode::serialize(); - obj["joinType"] = joinTypeName(joinType_); + obj["joinType"] = JoinTypeName::toName(joinType_); obj["leftKeys"] = ISerializable::serialize(leftKeys_); obj["rightKeys"] = ISerializable::serialize(rightKeys_); if (filter_) { @@ -1380,8 +1583,8 @@ folly::dynamic AbstractJoinNode::serializeBase() const { } namespace { -std::unordered_map joinTypeNames() { - return { +const auto& joinTypeNames() { + static const folly::F14FastMap kNames = { {JoinType::kInner, "INNER"}, {JoinType::kLeft, "LEFT"}, {JoinType::kRight, "RIGHT"}, @@ -1392,25 +1595,33 @@ std::unordered_map joinTypeNames() { {JoinType::kRightSemiProject, "RIGHT SEMI (PROJECT)"}, {JoinType::kAnti, "ANTI"}, }; + return kNames; } -} // namespace -const char* joinTypeName(JoinType joinType) { - static const auto kJoinTypes = joinTypeNames(); - auto it = kJoinTypes.find(joinType); - VELOX_CHECK( - it != kJoinTypes.end(), - "Invalid join type {}", - static_cast(joinType)); - return it->second.c_str(); +// Check that each output of the join is in exactly one of the inputs. +void checkJoinColumnNames( + const RowTypePtr& leftType, + const RowTypePtr& rightType, + const RowTypePtr& outputType, + uint32_t numColumnsToCheck) { + for (auto i = 0; i < numColumnsToCheck; ++i) { + const auto name = outputType->nameOf(i); + const bool leftContains = leftType->containsChild(name); + const bool rightContains = rightType->containsChild(name); + VELOX_USER_CHECK( + !(leftContains && rightContains), + "Duplicate column name found on join's left and right sides: {}", + name); + VELOX_USER_CHECK( + leftContains || rightContains, + "Join's output column not found in either left or right sides: {}", + name); + } } -JoinType joinTypeFromName(const std::string& name) { - static const auto kJoinTypes = invertMap(joinTypeNames()); - auto it = kJoinTypes.find(name); - VELOX_CHECK(it != kJoinTypes.end(), "Invalid join type " + name); - return it->second; -} +} // namespace + +VELOX_DEFINE_ENUM_NAME(JoinType, joinTypeNames) void HashJoinNode::addDetails(std::stringstream& stream) const { AbstractJoinNode::addDetails(stream); @@ -1422,6 +1633,7 @@ void HashJoinNode::addDetails(std::stringstream& stream) const { folly::dynamic HashJoinNode::serialize() const { auto obj = serializeBase(); obj["nullAware"] = nullAware_; + obj["useHashTableCache"] = useHashTableCache_; return obj; } @@ -1437,26 +1649,28 @@ PlanNodePtr HashJoinNode::create(const folly::dynamic& obj, void* context) { VELOX_CHECK_EQ(2, sources.size()); auto nullAware = obj["nullAware"].asBool(); + auto useHashTableCache = obj.getDefault("useHashTableCache", false).asBool(); auto leftKeys = deserializeFields(obj["leftKeys"], context); auto rightKeys = deserializeFields(obj["rightKeys"], context); TypedExprPtr filter; if (obj.count("filter")) { - filter = ISerializable::deserialize(obj["filter"]); + filter = ISerializable::deserialize(obj["filter"], context); } auto outputType = deserializeRowType(obj["outputType"]); return std::make_shared( deserializePlanNodeId(obj), - joinTypeFromName(obj["joinType"].asString()), + JoinTypeName::toJoinType(obj["joinType"].asString()), nullAware, std::move(leftKeys), std::move(rightKeys), filter, sources[0], sources[1], - outputType); + outputType, + useHashTableCache); } MergeJoinNode::MergeJoinNode( @@ -1477,10 +1691,11 @@ MergeJoinNode::MergeJoinNode( std::move(left), std::move(right), std::move(outputType)) { + validate(); VELOX_USER_CHECK( isSupported(joinType_), - "The join type is not supported by merge join: ", - joinTypeName(joinType_)); + "The join type is not supported by merge join: {}", + JoinTypeName::toName(joinType_)); } folly::dynamic MergeJoinNode::serialize() const { @@ -1488,15 +1703,15 @@ folly::dynamic MergeJoinNode::serialize() const { } // static -bool MergeJoinNode::isSupported(core::JoinType joinType) { +bool MergeJoinNode::isSupported(JoinType joinType) { switch (joinType) { - case core::JoinType::kInner: - case core::JoinType::kLeft: - case core::JoinType::kRight: - case core::JoinType::kLeftSemiFilter: - case core::JoinType::kRightSemiFilter: - case core::JoinType::kAnti: - case core::JoinType::kFull: + case JoinType::kInner: + case JoinType::kLeft: + case JoinType::kRight: + case JoinType::kLeftSemiFilter: + case JoinType::kRightSemiFilter: + case JoinType::kAnti: + case JoinType::kFull: return true; default: @@ -1520,14 +1735,14 @@ PlanNodePtr MergeJoinNode::create(const folly::dynamic& obj, void* context) { TypedExprPtr filter; if (obj.count("filter")) { - filter = ISerializable::deserialize(obj["filter"]); + filter = ISerializable::deserialize(obj["filter"], context); } auto outputType = deserializeRowType(obj["outputType"]); return std::make_shared( deserializePlanNodeId(obj), - joinTypeFromName(obj["joinType"].asString()), + JoinTypeName::toJoinType(obj["joinType"].asString()), std::move(leftKeys), std::move(rightKeys), filter, @@ -1536,30 +1751,125 @@ PlanNodePtr MergeJoinNode::create(const folly::dynamic& obj, void* context) { outputType); } +IndexLookupJoinNode::IndexLookupJoinNode( + const PlanNodeId& id, + JoinType joinType, + const std::vector& leftKeys, + const std::vector& rightKeys, + const std::vector& joinConditions, + TypedExprPtr filter, + bool hasMarker, + PlanNodePtr left, + TableScanNodePtr right, + RowTypePtr outputType) + : AbstractJoinNode( + id, + joinType, + leftKeys, + rightKeys, + std::move(filter), + std::move(left), + right, + outputType), + lookupSourceNode_(std::move(right)), + joinConditions_(joinConditions), + hasMarker_(hasMarker) { + VELOX_USER_CHECK( + !leftKeys.empty(), + "The index lookup join node requires at least one join key"); + VELOX_USER_CHECK_EQ( + leftKeys_.size(), + rightKeys_.size(), + "The index lookup join node requires same number of join keys on left and right sides"); + + // TODO: add check that (1) 'rightKeys_' form an index prefix. each of + // 'joinConditions_' uses columns from both sides and uses exactly one index + // column from the right side. + VELOX_USER_CHECK( + lookupSourceNode_->tableHandle()->supportsIndexLookup(), + "The lookup table handle {} from connector {} doesn't support index lookup", + lookupSourceNode_->tableHandle()->name(), + lookupSourceNode_->tableHandle()->connectorId()); + VELOX_USER_CHECK( + isSupported(joinType_), + "Unsupported index lookup join type {}", + JoinTypeName::toName(joinType_)); + + auto leftType = sources_[0]->outputType(); + for (const auto& key : leftKeys_) { + VELOX_USER_CHECK( + leftType->containsChild(key->name()), + "Left side join key not found in left side output: {}", + key->name()); + } + auto rightType = sources_[1]->outputType(); + for (const auto& key : rightKeys_) { + VELOX_USER_CHECK( + rightType->containsChild(key->name()), + "Right side join key not found in right side output: {}", + key->name()); + } + for (auto i = 0; i < leftKeys_.size(); ++i) { + VELOX_USER_CHECK_EQ( + leftKeys_[i]->type()->kind(), + rightKeys_[i]->type()->kind(), + "Index lookup koin key types on the left and right sides must match"); + } + + auto numOutputColumns = outputType_->size(); + if (hasMarker_) { + VELOX_USER_CHECK( + isLeftJoin(), + "Index join match column can only present for {} but not {}", + JoinTypeName::toName(JoinType::kLeft), + JoinTypeName::toName(joinType_)); + // Last output column must be a boolean 'match'. + --numOutputColumns; + VELOX_USER_CHECK_EQ( + outputType_->childAt(numOutputColumns), + BOOLEAN(), + "The last output column must be boolean type if match column is present"); + + // Verify that 'match' column name doesn't match any column from left or + // right source. + const auto& name = outputType_->nameOf(numOutputColumns); + VELOX_USER_CHECK(!leftType->containsChild(name)); + VELOX_USER_CHECK(!rightType->containsChild(name)); + } + + checkJoinColumnNames(leftType, rightType, outputType_, numOutputColumns); +} + PlanNodePtr IndexLookupJoinNode::create( const folly::dynamic& obj, void* context) { auto sources = deserializeSources(obj, context); VELOX_CHECK_EQ(2, sources.size()); TableScanNodePtr lookupSource = - std::dynamic_pointer_cast(sources[1]); - VELOX_CHECK_NOT_NULL(lookupSource); + checkedPointerCast(sources[1]); auto leftKeys = deserializeFields(obj["leftKeys"], context); auto rightKeys = deserializeFields(obj["rightKeys"], context); - VELOX_CHECK_EQ(obj.count("filter"), 0); + TypedExprPtr filter; + if (obj.count("filter")) { + filter = ISerializable::deserialize(obj["filter"], context); + } + + auto joinConditions = deserializejoinConditions(obj, context); - auto joinConditions = deserializeJoinConditions(obj, context); + const bool hasMarker = obj["hasMarker"].asBool(); auto outputType = deserializeRowType(obj["outputType"]); return std::make_shared( deserializePlanNodeId(obj), - joinTypeFromName(obj["joinType"].asString()), + JoinTypeName::toJoinType(obj["joinType"].asString()), std::move(leftKeys), std::move(rightKeys), std::move(joinConditions), + filter, + hasMarker, sources[0], std::move(lookupSource), std::move(outputType)); @@ -1574,6 +1884,10 @@ folly::dynamic IndexLookupJoinNode::serialize() const { } obj["joinConditions"] = std::move(serializedJoins); } + if (filter_) { + obj["filter"] = filter_->serialize(); + } + obj["hasMarker"] = hasMarker_; return obj; } @@ -1583,13 +1897,15 @@ void IndexLookupJoinNode::addDetails(std::stringstream& stream) const { return; } - std::vector joinConditionStrs; - joinConditionStrs.reserve(joinConditions_.size()); + std::vector joinConditionstrs; + joinConditionstrs.reserve(joinConditions_.size()); for (const auto& joinCondition : joinConditions_) { - joinConditionStrs.push_back(joinCondition->toString()); + joinConditionstrs.push_back(joinCondition->toString()); } - stream << ", joinConditions: [" << folly::join(", ", joinConditionStrs) - << " ]"; + stream << ", joinConditions: [" << folly::join(", ", joinConditionstrs) + << " ], filter: [" + << (filter_ == nullptr ? "null" : filter_->toString()) + << "], hasMarker: [" << (hasMarker_ ? "true" : "false") << "]"; } void IndexLookupJoinNode::accept( @@ -1599,21 +1915,19 @@ void IndexLookupJoinNode::accept( } // static -bool IndexLookupJoinNode::isSupported(core::JoinType joinType) { +bool IndexLookupJoinNode::isSupported(JoinType joinType) { switch (joinType) { - case core::JoinType::kInner: + case JoinType::kInner: [[fallthrough]]; - case core::JoinType::kLeft: + case JoinType::kLeft: return true; default: return false; } } -bool isIndexLookupJoin(const core::PlanNode* planNode) { - const auto* indexLookupJoin = - dynamic_cast(planNode); - return indexLookupJoin != nullptr; +bool isIndexLookupJoin(const PlanNode* planNode) { + return isInstanceOf(planNode); } // static @@ -1635,16 +1949,16 @@ NestedLoopJoinNode::NestedLoopJoinNode( outputType_(std::move(outputType)) { VELOX_USER_CHECK( isSupported(joinType_), - "The join type is not supported by nested loop join: ", - joinTypeName(joinType_)); + "The join type is not supported by nested loop join: {}", + JoinTypeName::toName(joinType_)); auto leftType = sources_[0]->outputType(); auto rightType = sources_[1]->outputType(); - auto numOutputColumms = outputType_->size(); - if (core::isLeftSemiProjectJoin(joinType)) { - --numOutputColumms; - VELOX_CHECK_EQ(outputType_->childAt(numOutputColumms), BOOLEAN()); - const auto& name = outputType_->nameOf(numOutputColumms); + auto numOutputColumns = outputType_->size(); + if (isLeftSemiProjectJoin(joinType)) { + --numOutputColumns; + VELOX_CHECK_EQ(outputType_->childAt(numOutputColumns), BOOLEAN()); + const auto& name = outputType_->nameOf(numOutputColumns); VELOX_CHECK( !leftType->containsChild(name), "Match column '{}' cannot be present in left source.", @@ -1655,19 +1969,7 @@ NestedLoopJoinNode::NestedLoopJoinNode( name); } - for (auto i = 0; i < numOutputColumms; ++i) { - const auto name = outputType_->nameOf(i); - const bool leftContains = leftType->containsChild(name); - const bool rightContains = rightType->containsChild(name); - VELOX_USER_CHECK( - !(leftContains && rightContains), - "Duplicate column name found on join's left and right sides: {}", - name); - VELOX_USER_CHECK( - leftContains || rightContains, - "Join's output column not found in either left or right sides: {}", - name); - } + checkJoinColumnNames(leftType, rightType, outputType_, numOutputColumns); } NestedLoopJoinNode::NestedLoopJoinNode( @@ -1679,18 +1981,18 @@ NestedLoopJoinNode::NestedLoopJoinNode( id, kDefaultJoinType, kDefaultJoinCondition, - left, - right, - outputType) {} + std::move(left), + std::move(right), + std::move(outputType)) {} // static -bool NestedLoopJoinNode::isSupported(core::JoinType joinType) { +bool NestedLoopJoinNode::isSupported(JoinType joinType) { switch (joinType) { - case core::JoinType::kInner: - case core::JoinType::kLeft: - case core::JoinType::kRight: - case core::JoinType::kFull: - case core::JoinType::kLeftSemiProject: + case JoinType::kInner: + case JoinType::kLeft: + case JoinType::kRight: + case JoinType::kFull: + case JoinType::kLeftSemiProject: return true; default: @@ -1699,7 +2001,7 @@ bool NestedLoopJoinNode::isSupported(core::JoinType joinType) { } void NestedLoopJoinNode::addDetails(std::stringstream& stream) const { - stream << joinTypeName(joinType_); + stream << JoinTypeName::toName(joinType_); if (joinCondition_) { stream << ", joinCondition: " << joinCondition_->toString(); } @@ -1707,7 +2009,7 @@ void NestedLoopJoinNode::addDetails(std::stringstream& stream) const { folly::dynamic NestedLoopJoinNode::serialize() const { auto obj = PlanNode::serialize(); - obj["joinType"] = joinTypeName(joinType_); + obj["joinType"] = JoinTypeName::toName(joinType_); if (joinCondition_) { obj["joinCondition"] = joinCondition_->serialize(); } @@ -1737,7 +2039,7 @@ PlanNodePtr NestedLoopJoinNode::create( return std::make_shared( deserializePlanNodeId(obj), - joinTypeFromName(obj["joinType"].asString()), + JoinTypeName::toJoinType(obj["joinType"].asString()), joinCondition, sources[0], sources[1], @@ -1824,17 +2126,17 @@ void addWindowFunction( VELOX_USER_FAIL("Window frame end cannot be UNBOUNDED PRECEDING"); } - stream << WindowNode::windowTypeName(frame.type) << " between "; + stream << WindowNode::toName(frame.type) << " between "; if (frame.startValue) { addKeys(stream, {frame.startValue}); stream << " "; } - stream << WindowNode::boundTypeName(frame.startType) << " and "; + stream << WindowNode::toName(frame.startType) << " and "; if (frame.endValue) { addKeys(stream, {frame.endValue}); stream << " "; } - stream << WindowNode::boundTypeName(frame.endType); + stream << WindowNode::toName(frame.endType); } } // namespace @@ -1926,70 +2228,42 @@ void WindowNode::addDetails(std::stringstream& stream) const { } namespace { -std::unordered_map boundTypeNames() { - return { - {WindowNode::BoundType::kCurrentRow, "CURRENT ROW"}, - {WindowNode::BoundType::kPreceding, "PRECEDING"}, - {WindowNode::BoundType::kFollowing, "FOLLOWING"}, - {WindowNode::BoundType::kUnboundedPreceding, "UNBOUNDED PRECEDING"}, - {WindowNode::BoundType::kUnboundedFollowing, "UNBOUNDED FOLLOWING"}, - }; +const auto& boundTypeNames() { + static const folly::F14FastMap + kNames = { + {WindowNode::BoundType::kCurrentRow, "CURRENT ROW"}, + {WindowNode::BoundType::kPreceding, "PRECEDING"}, + {WindowNode::BoundType::kFollowing, "FOLLOWING"}, + {WindowNode::BoundType::kUnboundedPreceding, "UNBOUNDED PRECEDING"}, + {WindowNode::BoundType::kUnboundedFollowing, "UNBOUNDED FOLLOWING"}, + }; + return kNames; } } // namespace -// static -const char* WindowNode::boundTypeName(WindowNode::BoundType type) { - static const auto kTypes = boundTypeNames(); - auto it = kTypes.find(type); - VELOX_CHECK( - it != kTypes.end(), - "Invalid window bound type {}", - static_cast(type)); - return it->second.c_str(); -} - -// static -WindowNode::BoundType WindowNode::boundTypeFromName(const std::string& name) { - static const auto kTypes = invertMap(boundTypeNames()); - auto it = kTypes.find(name); - VELOX_CHECK(it != kTypes.end(), "Invalid window bound type " + name); - return it->second; -} +VELOX_DEFINE_EMBEDDED_ENUM_NAME(WindowNode, BoundType, boundTypeNames) namespace { -std::unordered_map windowTypeNames() { - return { - {WindowNode::WindowType::kRows, "ROWS"}, - {WindowNode::WindowType::kRange, "RANGE"}, - }; +const auto& windowTypeNames() { + static const folly::F14FastMap + kNames = { + {WindowNode::WindowType::kRows, "ROWS"}, + {WindowNode::WindowType::kRange, "RANGE"}, + }; + return kNames; } } // namespace -// static -const char* WindowNode::windowTypeName(WindowNode::WindowType type) { - static const auto kTypes = windowTypeNames(); - auto it = kTypes.find(type); - VELOX_CHECK( - it != kTypes.end(), "Invalid window type {}", static_cast(type)); - return it->second.c_str(); -} - -// static -WindowNode::WindowType WindowNode::windowTypeFromName(const std::string& name) { - static const auto kTypes = invertMap(windowTypeNames()); - auto it = kTypes.find(name); - VELOX_CHECK(it != kTypes.end(), "Invalid window type " + name); - return it->second; -} +VELOX_DEFINE_EMBEDDED_ENUM_NAME(WindowNode, WindowType, windowTypeNames) folly::dynamic WindowNode::Frame::serialize() const { folly::dynamic obj = folly::dynamic::object(); - obj["type"] = windowTypeName(type); - obj["startType"] = boundTypeName(startType); + obj["type"] = toName(type); + obj["startType"] = toName(startType); if (startValue) { obj["startValue"] = startValue->serialize(); } - obj["endType"] = boundTypeName(endType); + obj["endType"] = toName(endType); if (endValue) { obj["endValue"] = endValue->serialize(); } @@ -2009,10 +2283,10 @@ WindowNode::Frame WindowNode::Frame::deserialize(const folly::dynamic& obj) { } return { - windowTypeFromName(obj["type"].asString()), - boundTypeFromName(obj["startType"].asString()), + toWindowType(obj["type"].asString()), + toBoundType(obj["startType"].asString()), startValue, - boundTypeFromName(obj["endType"].asString()), + toBoundType(obj["endType"].asString()), endValue}; } @@ -2228,8 +2502,41 @@ PlanNodePtr RowNumberNode::create(const folly::dynamic& obj, void* context) { source); } +namespace { +std::unordered_map +rankFunctionNames() { + return { + {TopNRowNumberNode::RankFunction::kRowNumber, "row_number"}, + {TopNRowNumberNode::RankFunction::kRank, "rank"}, + {TopNRowNumberNode::RankFunction::kDenseRank, "dense_rank"}, + }; +} +} // namespace + +// static +const char* TopNRowNumberNode::rankFunctionName( + TopNRowNumberNode::RankFunction function) { + static const auto kFunctionNames = rankFunctionNames(); + auto it = kFunctionNames.find(function); + VELOX_CHECK( + it != kFunctionNames.cend(), + "Invalid rank function {}", + static_cast(function)); + return it->second.c_str(); +} + +// static +TopNRowNumberNode::RankFunction TopNRowNumberNode::rankFunctionFromName( + std::string_view name) { + static const auto kFunctionNames = invertMap(rankFunctionNames()); + auto it = kFunctionNames.find(name.data()); + VELOX_CHECK(it != kFunctionNames.cend(), "Invalid rank function {}", name); + return it->second; +} + TopNRowNumberNode::TopNRowNumberNode( PlanNodeId id, + RankFunction function, std::vector partitionKeys, std::vector sortingKeys, std::vector sortingOrders, @@ -2237,6 +2544,7 @@ TopNRowNumberNode::TopNRowNumberNode( int32_t limit, PlanNodePtr source) : PlanNode(std::move(id)), + function_(function), partitionKeys_{std::move(partitionKeys)}, sortingKeys_{std::move(sortingKeys)}, sortingOrders_{std::move(sortingOrders)}, @@ -2274,6 +2582,8 @@ TopNRowNumberNode::TopNRowNumberNode( } void TopNRowNumberNode::addDetails(std::stringstream& stream) const { + stream << rankFunctionName(function_) << " "; + if (!partitionKeys_.empty()) { stream << "partition by ("; addFields(stream, partitionKeys_); @@ -2289,6 +2599,7 @@ void TopNRowNumberNode::addDetails(std::stringstream& stream) const { folly::dynamic TopNRowNumberNode::serialize() const { auto obj = PlanNode::serialize(); + obj["function"] = rankFunctionName(function_); obj["partitionKeys"] = ISerializable::serialize(partitionKeys_); obj["sortingKeys"] = ISerializable::serialize(sortingKeys_); obj["sortingOrders"] = serializeSortingOrders(sortingOrders_); @@ -2310,6 +2621,7 @@ PlanNodePtr TopNRowNumberNode::create( const folly::dynamic& obj, void* context) { auto source = deserializeSingleSource(obj, context); + auto function = rankFunctionFromName(obj["function"].asString()); auto partitionKeys = deserializeFields(obj["partitionKeys"], context); auto sortingKeys = deserializeFields(obj["sortingKeys"], context); @@ -2322,6 +2634,7 @@ PlanNodePtr TopNRowNumberNode::create( return std::make_shared( deserializePlanNodeId(obj), + function, partitionKeys, sortingKeys, sortingOrders, @@ -2364,19 +2677,79 @@ void TableWriteNode::addDetails(std::stringstream& stream) const { stream << insertTableHandle_->connectorInsertTableHandle()->toString(); } +RowTypePtr ColumnStatsSpec::outputType() const { + // Create output type based on the column stats collection specs. + std::vector names; + std::vector types; + + const auto numAggregates = aggregates.size(); + const auto outputTypeSize = groupingKeys.size() + numAggregates; + + names.reserve(outputTypeSize); + types.reserve(outputTypeSize); + + for (const auto& key : groupingKeys) { + names.push_back(key->name()); + types.push_back(key->type()); + } + + for (auto i = 0; i < numAggregates; ++i) { + names.push_back(aggregateNames[i]); + types.push_back(aggregates[i].call->type()); + } + return ROW(std::move(names), std::move(types)); +} + +folly::dynamic ColumnStatsSpec::serialize() const { + folly::dynamic obj = folly::dynamic::object; + obj["groupingKeys"] = ISerializable::serialize(groupingKeys); + obj["aggregationStep"] = AggregationNode::toName(aggregationStep); + obj["aggregateNames"] = ISerializable::serialize(aggregateNames); + obj["aggregates"] = folly::dynamic::array; + for (const auto& aggregate : aggregates) { + obj["aggregates"].push_back(aggregate.serialize()); + } + return obj; +} + +// static +ColumnStatsSpec ColumnStatsSpec::create( + const folly::dynamic& obj, + void* context) { + auto groupingKeys = deserializeFields(obj["groupingKeys"], context); + const auto aggregationStep = + AggregationNode::toStep(obj["aggregationStep"].asString()); + auto aggregateNames = ISerializable::deserialize>( + obj["aggregateNames"]); + + std::vector aggregates; + aggregates.reserve(obj["aggregates"].size()); + for (const auto& aggregate : obj["aggregates"]) { + aggregates.push_back( + AggregationNode::Aggregate::deserialize(aggregate, context)); + } + + return ColumnStatsSpec{ + std::move(groupingKeys), + aggregationStep, + std::move(aggregateNames), + std::move(aggregates)}; +} + folly::dynamic TableWriteNode::serialize() const { auto obj = PlanNode::serialize(); obj["columns"] = columns_->serialize(); obj["columnNames"] = ISerializable::serialize(columnNames_); - if (aggregationNode_ != nullptr) { - obj["aggregationNode"] = aggregationNode_->serialize(); + if (columnStatsSpec_.has_value()) { + obj["columnStatsSpec"] = columnStatsSpec_->serialize(); } obj["connectorId"] = insertTableHandle_->connectorId(); obj["connectorInsertTableHandle"] = insertTableHandle_->connectorInsertTableHandle()->serialize(); obj["hasPartitioningScheme"] = hasPartitioningScheme_; obj["outputType"] = outputType_->serialize(); - obj["commitStrategy"] = connector::commitStrategyToString(commitStrategy_); + obj["commitStrategy"] = + std::string(connector::CommitStrategyName::toName(commitStrategy_)); return obj; } @@ -2392,26 +2765,23 @@ PlanNodePtr TableWriteNode::create(const folly::dynamic& obj, void* context) { auto columns = deserializeRowType(obj["columns"]); auto columnNames = ISerializable::deserialize>(obj["columnNames"]); - std::shared_ptr aggregationNode; - if (obj.count("aggregationNode") != 0) { - aggregationNode = std::const_pointer_cast( - ISerializable::deserialize( - obj["aggregationNode"], context)); - } auto connectorId = obj["connectorId"].asString(); auto connectorInsertTableHandle = - std::const_pointer_cast( - ISerializable::deserialize( - obj["connectorInsertTableHandle"])); + ISerializable::deserialize( + obj["connectorInsertTableHandle"]); const bool hasPartitioningScheme = obj["hasPartitioningScheme"].asBool(); auto outputType = deserializeRowType(obj["outputType"]); - auto commitStrategy = - connector::stringToCommitStrategy(obj["commitStrategy"].asString()); + auto commitStrategy = connector::CommitStrategyName::toCommitStrategy( + obj["commitStrategy"].asString()); + std::optional columnStatsSpec; + if (obj.count("columnStatsSpec") != 0) { + columnStatsSpec = ColumnStatsSpec::create(obj["columnStatsSpec"], context); + } return std::make_shared( id, columns, columnNames, - std::move(aggregationNode), + std::move(columnStatsSpec), std::make_shared( connectorId, connectorInsertTableHandle), hasPartitioningScheme, @@ -2426,8 +2796,8 @@ folly::dynamic TableWriteMergeNode::serialize() const { auto obj = PlanNode::serialize(); VELOX_CHECK_EQ( sources_.size(), 1, "TableWriteMergeNode can only have one source"); - if (aggregationNode_ != nullptr) { - obj["aggregationNode"] = aggregationNode_->serialize(); + if (columnStatsSpec_.has_value()) { + obj["columnStatsSpec"] = columnStatsSpec_->serialize(); } obj["outputType"] = outputType_->serialize(); return obj; @@ -2445,13 +2815,15 @@ PlanNodePtr TableWriteMergeNode::create( void* context) { auto id = obj["id"].asString(); auto outputType = deserializeRowType(obj["outputType"]); - std::shared_ptr aggregationNode; - if (obj.count("aggregationNode") != 0) { - aggregationNode = std::const_pointer_cast( - ISerializable::deserialize(obj["aggregationNode"])); + std::optional columnStatsSpec; + if (obj.count("columnStatsSpec") != 0) { + columnStatsSpec = ColumnStatsSpec::create(obj["columnStatsSpec"], context); } return std::make_shared( - id, outputType, aggregationNode, deserializeSingleSource(obj, context)); + id, + outputType, + std::move(columnStatsSpec), + deserializeSingleSource(obj, context)); } MergeExchangeNode::MergeExchangeNode( @@ -2502,7 +2874,7 @@ PlanNodePtr MergeExchangeNode::create( } void LocalPartitionNode::addDetails(std::stringstream& stream) const { - stream << typeName(type_); + stream << toName(type_); if (type_ != Type::kGather) { stream << " " << partitionFunctionSpec_->toString(); } @@ -2513,7 +2885,7 @@ void LocalPartitionNode::addDetails(std::stringstream& stream) const { folly::dynamic LocalPartitionNode::serialize() const { auto obj = PlanNode::serialize(); - obj["type"] = typeName(type_); + obj["type"] = toName(type_); obj["scaleWriter"] = scaleWriter_; obj["partitionFunctionSpec"] = partitionFunctionSpec_->serialize(); return obj; @@ -2531,42 +2903,28 @@ PlanNodePtr LocalPartitionNode::create( void* context) { return std::make_shared( deserializePlanNodeId(obj), - typeFromName(obj["type"].asString()), + toType(obj["type"].asString()), obj["scaleWriter"].asBool(), ISerializable::deserialize( - obj["partitionFunctionSpec"]), + obj["partitionFunctionSpec"], context), deserializeSources(obj, context)); } namespace { -std::unordered_map -localPartitionTypeNames() { - return { - {LocalPartitionNode::Type::kGather, "GATHER"}, - {LocalPartitionNode::Type::kRepartition, "REPARTITION"}, - }; +const auto& localPartitionTypeNames() { + static const folly::F14FastMap + kNames = { + {LocalPartitionNode::Type::kGather, "GATHER"}, + {LocalPartitionNode::Type::kRepartition, "REPARTITION"}, + }; + return kNames; } } // namespace -// static -const char* LocalPartitionNode::typeName(Type type) { - static const auto kTypes = localPartitionTypeNames(); - auto it = kTypes.find(type); - VELOX_CHECK( - it != kTypes.end(), - "Invalid LocalPartitionNode type {}", - static_cast(type)); - return it->second.c_str(); -} - -// static -LocalPartitionNode::Type LocalPartitionNode::typeFromName( - const std::string& name) { - static const auto kTypes = invertMap(localPartitionTypeNames()); - auto it = kTypes.find(name); - VELOX_CHECK(it != kTypes.end(), "Invalid LocalPartitionNode type " + name); - return it->second; -} +VELOX_DEFINE_EMBEDDED_ENUM_NAME( + LocalPartitionNode, + Type, + localPartitionTypeNames) PartitionedOutputNode::PartitionedOutputNode( const PlanNodeId& id, @@ -2592,12 +2950,16 @@ PartitionedOutputNode::PartitionedOutputNode( VELOX_USER_CHECK( keys_.empty(), "Non-empty partitioning keys require more than one partition"); + } else { + VELOX_USER_CHECK_NOT_NULL( + partitionFunctionSpec_, + "Partition function spec must be specified when the number of destinations is more than 1."); } if (!isPartitioned()) { VELOX_USER_CHECK( keys_.empty(), "{} partitioning doesn't allow for partitioning keys", - kindString(kind_)); + toName(kind_)); } } @@ -2682,36 +3044,18 @@ PlanNodePtr EnforceSingleRowNode::create( } namespace { -std::unordered_map -partitionKindNames() { - return { - {PartitionedOutputNode::Kind::kPartitioned, "PARTITIONED"}, - {PartitionedOutputNode::Kind::kBroadcast, "BROADCAST"}, - {PartitionedOutputNode::Kind::kArbitrary, "ARBITRARY"}, - }; +const auto& partitionKindNames() { + static const folly::F14FastMap + kNames = { + {PartitionedOutputNode::Kind::kPartitioned, "PARTITIONED"}, + {PartitionedOutputNode::Kind::kBroadcast, "BROADCAST"}, + {PartitionedOutputNode::Kind::kArbitrary, "ARBITRARY"}, + }; + return kNames; } - } // namespace -// static -std::string PartitionedOutputNode::kindString(Kind kind) { - static const auto kPartitionNames = partitionKindNames(); - auto it = kPartitionNames.find(kind); - VELOX_CHECK( - it != kPartitionNames.end(), - "Invalid Output Kind {}", - static_cast(kind)); - return it->second; -} - -// static -PartitionedOutputNode::Kind PartitionedOutputNode::stringToKind( - const std::string& name) { - static const auto kPartitionKinds = invertMap(partitionKindNames()); - auto it = kPartitionKinds.find(name); - VELOX_CHECK(it != kPartitionKinds.end(), "Invalid Output Kind " + name); - return it->second; -} +VELOX_DEFINE_EMBEDDED_ENUM_NAME(PartitionedOutputNode, Kind, partitionKindNames) void PartitionedOutputNode::addDetails(std::stringstream& stream) const { if (kind_ == Kind::kBroadcast) { @@ -2739,7 +3083,7 @@ void PartitionedOutputNode::addDetails(std::stringstream& stream) const { folly::dynamic PartitionedOutputNode::serialize() const { auto obj = PlanNode::serialize(); - obj["kind"] = kindString(kind_); + obj["kind"] = toName(kind_); obj["numPartitions"] = numPartitions_; obj["keys"] = ISerializable::serialize(keys_); obj["replicateNullsAndAny"] = replicateNullsAndAny_; @@ -2761,7 +3105,7 @@ PlanNodePtr PartitionedOutputNode::create( void* context) { return std::make_shared( deserializePlanNodeId(obj), - stringToKind(obj["kind"].asString()), + toKind(obj["kind"].asString()), ISerializable::deserialize>(obj["keys"], context), obj["numPartitions"].asInt(), obj["replicateNullsAndAny"].asBool(), @@ -2772,6 +3116,123 @@ PlanNodePtr PartitionedOutputNode::create( deserializeSingleSource(obj, context)); } +SpatialJoinNode::SpatialJoinNode( + const PlanNodeId& id, + JoinType joinType, + TypedExprPtr joinCondition, + FieldAccessTypedExprPtr probeGeometry, + FieldAccessTypedExprPtr buildGeometry, + std::optional radius, + PlanNodePtr left, + PlanNodePtr right, + RowTypePtr outputType) + : PlanNode(id), + joinType_(joinType), + joinCondition_(std::move(joinCondition)), + probeGeometry_(std::move(probeGeometry)), + buildGeometry_(std::move(buildGeometry)), + radius_(std::move(radius)), + sources_({std::move(left), std::move(right)}), + outputType_(std::move(outputType)) { + VELOX_USER_CHECK( + isSupported(joinType_), + "The join type is not supported by spatial join: {}", + JoinTypeName::toName(joinType_)); + VELOX_USER_CHECK_NOT_NULL( + joinCondition_, "The join condition must not be null for spatial join"); + VELOX_USER_CHECK_NOT_NULL( + probeGeometry_, "Probe geometery must not be null for spatial joins"); + VELOX_USER_CHECK_NOT_NULL( + buildGeometry_, "Build geometery must not be null for spatial joins"); + VELOX_USER_CHECK_EQ( + sources_.size(), 2, "Must have 2 sources for spatial joins"); + VELOX_USER_CHECK( + sources_[0] != nullptr, "Left source must not be null for spatial joins"); + VELOX_USER_CHECK( + sources_[1] != nullptr, + "Right source must not be null for spatial joins"); + + checkJoinColumnNames( + sources_[0]->outputType(), + sources_[1]->outputType(), + outputType_, + outputType_->size()); +} + +bool SpatialJoinNode::isSupported(JoinType joinType) { + switch (joinType) { + case JoinType::kInner: + [[fallthrough]]; + case JoinType::kLeft: + return true; + default: + return false; + } +} + +void SpatialJoinNode::addDetails(std::stringstream& stream) const { + stream << JoinTypeName::toName(joinType_); + if (joinCondition_) { + stream << ", joinCondition: " << joinCondition_->toString(); + } + stream << ", probeGeometry: " << probeGeometry_->name(); + stream << ", buildGeometry: " << buildGeometry_->name(); + if (radius_) { + stream << ", radius: " << radius_.value()->name(); + } +} + +folly::dynamic SpatialJoinNode::serialize() const { + auto obj = PlanNode::serialize(); + obj["joinType"] = JoinTypeName::toName(joinType_); + if (joinCondition_) { + obj["joinCondition"] = joinCondition_->serialize(); + } + obj["outputType"] = outputType_->serialize(); + obj["probeGeometry"] = probeGeometry_->serialize(); + obj["buildGeometry"] = buildGeometry_->serialize(); + if (radius_) { + obj["radius"] = radius_.value()->serialize(); + } + return obj; +} + +void SpatialJoinNode::accept( + const PlanNodeVisitor& visitor, + PlanNodeVisitorContext& context) const { + visitor.visit(*this, context); +} + +PlanNodePtr SpatialJoinNode::create(const folly::dynamic& obj, void* context) { + auto sources = deserializeSources(obj, context); + VELOX_CHECK_EQ(2, sources.size()); + + TypedExprPtr joinCondition; + if (obj.count("joinCondition")) { + joinCondition = + ISerializable::deserialize(obj["joinCondition"], context); + } + + auto outputType = deserializeRowType(obj["outputType"]); + auto probeGeometry = deserializeField(obj["probeGeometry"]); + auto buildGeometry = deserializeField(obj["buildGeometry"]); + std::optional radius; + if (obj.count("radius")) { + radius = deserializeField(obj["radius"]); + } + + return std::make_shared( + deserializePlanNodeId(obj), + JoinTypeName::toJoinType(obj["joinType"].asString()), + joinCondition, + probeGeometry, + buildGeometry, + radius, + sources[0], + sources[1], + outputType); +} + TopNNode::TopNNode( const PlanNodeId& id, const std::vector& sortingKeys, @@ -2921,10 +3382,7 @@ void PlanNode::toString( bool detailed, bool recursive, size_t indentationSize, - const std::function& addContext) const { + const AddContextFunc& addContext) const { const std::string indentation(indentationSize, ' '); stream << indentation << "-- " << name() << "[" << id() << "]"; @@ -2939,10 +3397,8 @@ void PlanNode::toString( stream << std::endl; if (addContext) { - auto contextIndentation = indentation + " "; - stream << contextIndentation; + const auto contextIndentation = indentation + " "; addContext(id(), contextIndentation, stream); - stream << std::endl; } if (recursive) { @@ -2994,16 +3450,22 @@ void PlanNode::accept( void PlanNode::toSummaryString( const PlanSummaryOptions& options, std::stringstream& stream, - size_t indentationSize) const { + size_t indentationSize, + const AddContextFunc& addContext) const { const std::string indentation(indentationSize, ' '); + const auto detailsIndentation = indentation + std::string(6, ' '); + stream << indentation << "-- " << name() << "[" << id() << "]: " << summarizeOutputType(outputType(), options) << std::endl; + addSummaryDetails(detailsIndentation, options, stream); - addSummaryDetails(indentation + " ", options, stream); + if (addContext != nullptr) { + addContext(id(), detailsIndentation, stream); + } for (auto& source : sources()) { - source->toSummaryString(options, stream, indentationSize + 2); + source->toSummaryString(options, stream, indentationSize + 2, addContext); } } @@ -3017,10 +3479,54 @@ void PlanNode::addSummaryDetails( stream << indentation << truncate(out.str(), options.maxLength) << std::endl; } +void PlanNode::toSkeletonString( + std::stringstream& stream, + size_t indentationSize) const { + // Skip Project nodes. + if (const auto* project = dynamic_cast(this)) { + project->sources().at(0)->toSkeletonString(stream, indentationSize); + return; + } + + const std::string indentation(indentationSize, ' '); + + stream << indentation << "-- " << name() << "[" << id() + << "]: " << outputType()->size() << " fields" << std::endl; + + // Include table scan details. + if (const auto* scan = dynamic_cast(this)) { + stream << indentation << std::string(6, ' ') + << scan->tableHandle()->toString() << std::endl; + } + + for (const auto& source : sources()) { + source->toSkeletonString(stream, indentationSize + 2); + } +} + +// static +const PlanNode* PlanNode::findFirstNode( + const PlanNode* root, + const std::function& predicate) { + VELOX_CHECK_NOT_NULL(root); + if (predicate(root)) { + return root; + } + + // Recursively go further through the sources. + for (const auto& source : root->sources()) { + const auto* ret = PlanNode::findFirstNode(source.get(), predicate); + if (ret != nullptr) { + return ret; + } + } + return nullptr; +} + namespace { void collectLeafPlanNodeIds( - const core::PlanNode& planNode, - std::unordered_set& leafIds) { + const PlanNode& planNode, + std::unordered_set& leafIds) { if (planNode.sources().empty()) { leafIds.insert(planNode.id()); return; @@ -3034,8 +3540,8 @@ void collectLeafPlanNodeIds( } } // namespace -std::unordered_set PlanNode::leafPlanNodeIds() const { - std::unordered_set leafIds; +std::unordered_set PlanNode::leafPlanNodeIds() const { + std::unordered_set leafIds; collectLeafPlanNodeIds(*this, leafIds); return leafIds; } @@ -3062,7 +3568,9 @@ void PlanNode::registerSerDe() { registry.Register("OrderByNode", OrderByNode::create); registry.Register("PartitionedOutputNode", PartitionedOutputNode::create); registry.Register("ProjectNode", ProjectNode::create); + registry.Register("ParallelProjectNode", ParallelProjectNode::create); registry.Register("RowNumberNode", RowNumberNode::create); + registry.Register("SpatialJoinNode", SpatialJoinNode::create); registry.Register("TableScanNode", TableScanNode::create); registry.Register("TableWriteNode", TableWriteNode::create); registry.Register("TableWriteMergeNode", TableWriteMergeNode::create); @@ -3165,7 +3673,7 @@ void AggregationNode::addSummaryDetails( PlanNodePtr FilterNode::create(const folly::dynamic& obj, void* context) { auto source = deserializeSingleSource(obj, context); - auto filter = ISerializable::deserialize(obj["filter"]); + auto filter = ISerializable::deserialize(obj["filter"], context); return std::make_shared( deserializePlanNodeId(obj), filter, std::move(source)); } @@ -3177,7 +3685,7 @@ folly::dynamic IndexLookupCondition::serialize() const { } bool InIndexLookupCondition::isFilter() const { - return std::dynamic_pointer_cast(list) != nullptr; + return list->isConstantKind(); } folly::dynamic InIndexLookupCondition::serialize() const { @@ -3195,16 +3703,13 @@ void InIndexLookupCondition::validate() const { VELOX_CHECK_NOT_NULL(key); VELOX_CHECK_NOT_NULL(list); VELOX_CHECK( - std::dynamic_pointer_cast(list) || - std::dynamic_pointer_cast(list), + list->isFieldAccessKind() || list->isConstantKind(), "Invalid condition list {}", list->toString()); - const auto listType = - std::dynamic_pointer_cast(list->type()); - VELOX_CHECK_NOT_NULL(listType); + const auto& listType = list->type()->asArray(); VELOX_CHECK_EQ( key->type()->kind(), - listType->elementType()->kind(), + listType.elementType()->kind(), "In condition key and list condition element must have the same type"); } @@ -3222,9 +3727,7 @@ IndexLookupConditionPtr InIndexLookupCondition::create( } bool BetweenIndexLookupCondition::isFilter() const { - return (std::dynamic_pointer_cast(lower) != - nullptr) && - (std::dynamic_pointer_cast(upper) != nullptr); + return lower->isConstantKind() && upper->isConstantKind(); } folly::dynamic BetweenIndexLookupCondition::serialize() const { @@ -3259,14 +3762,12 @@ void BetweenIndexLookupCondition::validate() const { VELOX_CHECK_NOT_NULL(lower); VELOX_CHECK_NOT_NULL(upper); VELOX_CHECK( - std::dynamic_pointer_cast(lower) || - std::dynamic_pointer_cast(lower), + lower->isFieldAccessKind() || lower->isConstantKind(), "Invalid lower between condition {}", lower->toString()); VELOX_CHECK( - std::dynamic_pointer_cast(upper) || - std::dynamic_pointer_cast(upper), + upper->isFieldAccessKind() || upper->isConstantKind(), "Invalid upper between condition {}", upper->toString()); @@ -3280,4 +3781,44 @@ void BetweenIndexLookupCondition::validate() const { upper->type()->kind(), "Index key and upper condition must have the same type"); } + +bool EqualIndexLookupCondition::isFilter() const { + return value->isConstantKind(); +} + +folly::dynamic EqualIndexLookupCondition::serialize() const { + folly::dynamic obj = IndexLookupCondition::serialize(); + obj["type"] = "equal"; + obj["value"] = value->serialize(); + return obj; +} + +std::string EqualIndexLookupCondition::toString() const { + return fmt::format("{} = {}", key->toString(), value->toString()); +} + +IndexLookupConditionPtr EqualIndexLookupCondition::create( + const folly::dynamic& obj, + void* context) { + auto key = + ISerializable::deserialize(obj["key"], context); + return std::make_shared( + key, ISerializable::deserialize(obj["value"], context)); +} + +void EqualIndexLookupCondition::validate() const { + VELOX_CHECK_NOT_NULL(key); + VELOX_CHECK_NOT_NULL(value); + VELOX_CHECK_NOT_NULL( + checkedPointerCast(value), + "Equal condition value must be a constant expression: {}", + value->toString()); + + VELOX_CHECK_EQ( + key->type()->kind(), + value->type()->kind(), + "Equal condition key and value must have compatible types: {} vs {}", + key->type()->toString(), + value->type()->toString()); +} } // namespace facebook::velox::core diff --git a/velox/core/PlanNode.h b/velox/core/PlanNode.h index 873aa032675b..24f954d8f0b3 100644 --- a/velox/core/PlanNode.h +++ b/velox/core/PlanNode.h @@ -19,6 +19,7 @@ #include +#include "velox/common/Enums.h" #include "velox/connectors/Connector.h" #include "velox/core/Expressions.h" #include "velox/core/QueryConfig.h" @@ -31,14 +32,14 @@ namespace facebook::velox::core { class PlanNodeVisitor; class PlanNodeVisitorContext; -typedef std::string PlanNodeId; +using PlanNodeId = std::string; /// Generic representation of InsertTable struct InsertTableHandle { public: InsertTableHandle( const std::string& connectorId, - const std::shared_ptr& + const connector::ConnectorInsertTableHandlePtr& connectorInsertTableHandle) : connectorId_(connectorId), connectorInsertTableHandle_(connectorInsertTableHandle) {} @@ -47,8 +48,8 @@ struct InsertTableHandle { return connectorId_; } - const std::shared_ptr& - connectorInsertTableHandle() const { + const connector::ConnectorInsertTableHandlePtr& connectorInsertTableHandle() + const { return connectorInsertTableHandle_; } @@ -57,8 +58,7 @@ struct InsertTableHandle { const std::string connectorId_; // Write request to a DataSink of that connector type - const std::shared_ptr - connectorInsertTableHandle_; + const connector::ConnectorInsertTableHandlePtr connectorInsertTableHandle_; }; class SortOrder { @@ -74,14 +74,7 @@ class SortOrder { return nullsFirst_; } - bool operator==(const SortOrder& other) const { - return std::tie(ascending_, nullsFirst_) == - std::tie(other.ascending_, other.nullsFirst_); - } - - bool operator!=(const SortOrder& other) const { - return !(*this == other); - } + bool operator==(const SortOrder& other) const = default; std::string toString() const { return fmt::format( @@ -114,15 +107,19 @@ extern const SortOrder kDescNullsLast; struct PlanSummaryOptions { /// Options that apply specifically to PROJECT nodes. struct ProjectOptions { - /// For a given PROJECT node, maximum number of non-identity projection - /// expressions to include in the summary. By default, no expression is - /// included. + /// For a given PROJECT node, maximum number of non-identity and + /// non-constant projection expressions to include in the summary. By + /// default, no expression is included. size_t maxProjections = 0; /// For a given PROJECT node, maximum number of dereference (access of a /// struct field) expressions to include in the summary. By default, no /// expression is included. size_t maxDereferences = 0; + + /// For a given PROJECT node, maximum number of constant expressions to + /// include in the summary. By default, no expression is included. + size_t maxConstants = 0; }; ProjectOptions project = {}; @@ -157,7 +154,7 @@ struct PlanSummaryOptions { class PlanNode : public ISerializable { public: - explicit PlanNode(const PlanNodeId& id) : id_{id} {} + explicit PlanNode(PlanNodeId id) : id_{std::move(id)} {} virtual ~PlanNode() {} @@ -220,6 +217,14 @@ class PlanNode : public ISerializable { /// Returns a set of leaf plan node IDs. std::unordered_set leafPlanNodeIds() const; + /// Lambda to add context for a given plan node. Receives plan node ID, + /// indentation and std::ostream where to append the context. Start each line + /// of context with 'indentation' and end with a new-line character. + using AddContextFunc = std::function; + /// Returns human-friendly representation of the plan. By default, returns the /// plan node name. Includes plan node details such as join keys and aggregate /// function names if 'detailed' is true. Returns the whole sub-tree if @@ -227,48 +232,55 @@ class PlanNode : public ISerializable { /// 'addContext' is not null. /// /// @param addContext Optional lambda to add context for a given plan node. - /// Receives plan node ID, indentation and std::stringstream where to append - /// the context. Use indentation for second and subsequent lines of a - /// multi-line context. Do not use indentation for single-line context. Do not - /// add trailing new-line character for the last or only line of context. std::string toString( bool detailed = false, bool recursive = false, - const std::function& addContext = nullptr) const { + const AddContextFunc& addContext = nullptr) const { std::stringstream stream; toString(stream, detailed, recursive, 0, addContext); return stream.str(); } - std::string toSummaryString(PlanSummaryOptions options = {}) const { + /// @param addContext Optional lambda to add context for a given plan node. + std::string toSummaryString( + PlanSummaryOptions options = {}, + const AddContextFunc& addContext = nullptr) const { std::stringstream stream; - toSummaryString(options, stream, 0); + toSummaryString(options, stream, 0, addContext); + return stream.str(); + } + + std::string toSkeletonString() const { + std::stringstream stream; + toSkeletonString(stream, 0); return stream.str(); } /// The name of the plan node, used in toString. virtual std::string_view name() const = 0; + template + bool is() const { + return dynamic_cast(this) != nullptr; + } + + template + const T* as() const { + return dynamic_cast(this); + } + /// Recursively checks the node tree for a first node that satisfy a given /// condition. Returns pointer to the node if found, nullptr if not. static const PlanNode* findFirstNode( - const PlanNode* node, - const std::function& predicate) { - if (predicate(node)) { - return node; - } + const PlanNode* root, + const std::function& predicate); - // Recursively go further through the sources. - for (const auto& source : node->sources()) { - const auto* ret = PlanNode::findFirstNode(source.get(), predicate); - if (ret != nullptr) { - return ret; - } - } - return nullptr; + /// @return PlanNode with matching ID or nullptr if not found. + static const PlanNode* findNodeById( + const PlanNode* root, + const PlanNodeId& id) { + return findFirstNode( + root, [&](const auto* node) { return node->id() == id; }); } private: @@ -286,11 +298,9 @@ class PlanNode : public ISerializable { bool detailed, bool recursive, size_t indentationSize, - const std::function& addContext) const; + const AddContextFunc& addContext) const; + // The default implementation calls 'addDetails' and truncates the result. virtual void addSummaryDetails( const std::string& indentation, const PlanSummaryOptions& options, @@ -299,9 +309,16 @@ class PlanNode : public ISerializable { void toSummaryString( const PlanSummaryOptions& options, std::stringstream& stream, - size_t indentationSize) const; + size_t indentationSize, + const AddContextFunc& addContext) const; - const std::string id_; + // Even shorter summary of the plan. Hides all Project nodes. Shows only + // number of output columns, but no names or types. Doesn't show any details + // of the nodes, except for table scan. + void toSkeletonString(std::stringstream& stream, size_t indentationSize) + const; + + const PlanNodeId id_; }; using PlanNodePtr = std::shared_ptr; @@ -315,10 +332,7 @@ class ValuesNode : public PlanNode { size_t repeatTimes = kDefaultRepeatTimes) : PlanNode(id), values_(std::move(values)), - outputType_( - values_.empty() - ? ROW({}) - : std::dynamic_pointer_cast(values_[0]->type())), + outputType_(values_.empty() ? ROW({}) : values_[0]->rowType()), parallelizable_(parallelizable), repeatTimes_(repeatTimes) {} @@ -498,6 +512,8 @@ class ArrowStreamNode : public PlanNode { std::shared_ptr arrowStream_; }; +using ArrowStreamNodePtr = std::shared_ptr; + class TraceScanNode final : public PlanNode { public: TraceScanNode( @@ -614,6 +630,8 @@ class TraceScanNode final : public PlanNode { const RowTypePtr outputType_; }; +using TraceScanNodePtr = std::shared_ptr; + class FilterNode : public PlanNode { public: FilterNode(const PlanNodeId& id, TypedExprPtr filter, PlanNodePtr source) @@ -706,6 +724,8 @@ class FilterNode : public PlanNode { const TypedExprPtr filter_; }; +using FilterNodePtr = std::shared_ptr; + class AbstractProjectNode : public PlanNode { public: AbstractProjectNode( @@ -809,15 +829,7 @@ class AbstractProjectNode : public PlanNode { static RowTypePtr makeOutputType( const std::vector& names, - const std::vector& projections) { - std::vector types; - for (auto& projection : projections) { - types.push_back(projection->type()); - } - - auto namesCopy = names; - return std::make_shared(std::move(namesCopy), std::move(types)); - } + const std::vector& projections); const std::vector sources_; const std::vector names_; @@ -876,15 +888,92 @@ class ProjectNode : public AbstractProjectNode { static PlanNodePtr create(const folly::dynamic& obj, void* context); }; +using ProjectNodePtr = std::shared_ptr; + +/// Variant of ProjectNode that computes projections in +/// parallel. The exprs are given in groups, so that all exprs in +/// one group run together and all groups run in parallel. If lazies +/// are loaded, each lazy must be loaded by exactly one group. If +/// there are identity projections in the groups, possible lazies +/// are loaded as part of processing the group. One can additionally +/// specify 'noLoadIdentities' which are identity projected through +/// without loading. This last set must be disjoint from all columns +/// accessed by the exprs. The output type has 'names' first and +/// then 'noLoadIdentities'. The ith name corresponds to the ith +/// expr when exprs is flattened. Inherits core::ProjectNode in order to reuse +/// the summary functions. +class ParallelProjectNode : public core::AbstractProjectNode { + public: + ParallelProjectNode( + const core::PlanNodeId& id, + std::vector names, + std::vector> exprGroups, + std::vector noLoadIdentities, + core::PlanNodePtr input); + + std::string_view name() const override { + return "ParallelProject"; + } + + const std::vector& exprNames() const { + return exprNames_; + } + + const std::vector>& exprGroups() const { + return exprGroups_; + } + + const std::vector noLoadIdentities() const { + return noLoadIdentities_; + } + + folly::dynamic serialize() const override; + + void accept(const PlanNodeVisitor& visitor, PlanNodeVisitorContext& context) + const override; + + static PlanNodePtr create(const folly::dynamic& obj, void* context); + + private: + void addDetails(std::stringstream& stream) const override; + + const std::vector exprNames_; + const std::vector> exprGroups_; + const std::vector noLoadIdentities_; +}; + +/// Variant of project node that contains only field accesses and dereferences, +/// and does not materialize the input columns. Used to split subfields of +/// struct columns for later parallel processing. +class LazyDereferenceNode : public core::ProjectNode { + public: + LazyDereferenceNode( + const PlanNodeId& id, + std::vector names, + std::vector projections, + PlanNodePtr source) + : ProjectNode( + id, + std::move(names), + std::move(projections), + std::move(source)) {} + + std::string_view name() const override { + return "LazyDereference"; + } + + static PlanNodePtr create(const folly::dynamic& obj, void* context); +}; + +using ParallelProjectNodePtr = std::shared_ptr; + class TableScanNode : public PlanNode { public: TableScanNode( const PlanNodeId& id, RowTypePtr outputType, - const std::shared_ptr& tableHandle, - const std::unordered_map< - std::string, - std::shared_ptr>& assignments) + const connector::ConnectorTableHandlePtr& tableHandle, + const connector::ColumnHandleMap& assignments) : PlanNode(id), outputType_(std::move(outputType)), tableHandle_(tableHandle), @@ -911,16 +1000,12 @@ class TableScanNode : public PlanNode { return *this; } - Builder& tableHandle( - std::shared_ptr tableHandle) { + Builder& tableHandle(connector::ConnectorTableHandlePtr tableHandle) { tableHandle_ = std::move(tableHandle); return *this; } - Builder& assignments( - std::unordered_map< - std::string, - std::shared_ptr> assignments) { + Builder& assignments(connector::ColumnHandleMap assignments) { assignments_ = std::move(assignments); return *this; } @@ -944,12 +1029,8 @@ class TableScanNode : public PlanNode { private: std::optional id_; std::optional outputType_; - std::optional> - tableHandle_; - std::optional>> - assignments_; + std::optional tableHandle_; + std::optional assignments_; }; bool supportsBarrier() const override { @@ -969,13 +1050,11 @@ class TableScanNode : public PlanNode { return true; } - const std::shared_ptr& tableHandle() const { + const connector::ConnectorTableHandlePtr& tableHandle() const { return tableHandle_; } - const std:: - unordered_map>& - assignments() const { + const connector::ColumnHandleMap& assignments() const { return assignments_; } @@ -990,11 +1069,14 @@ class TableScanNode : public PlanNode { private: void addDetails(std::stringstream& stream) const override; + void addSummaryDetails( + const std::string& indentation, + const PlanSummaryOptions& options, + std::stringstream& stream) const override; + const RowTypePtr outputType_; - const std::shared_ptr tableHandle_; - const std:: - unordered_map> - assignments_; + const connector::ConnectorTableHandlePtr tableHandle_; + const connector::ColumnHandleMap assignments_; }; using TableScanNodePtr = std::shared_ptr; @@ -1012,9 +1094,7 @@ class AggregationNode : public PlanNode { kSingle }; - static const char* stepName(Step step); - - static Step stepFromName(const std::string& name); + VELOX_DECLARE_EMBEDDED_ENUM_NAME(Step); /// Aggregate function call. struct Aggregate { @@ -1028,14 +1108,14 @@ class AggregationNode : public PlanNode { /// Optional name of input column to use as a mask. Column type must be /// BOOLEAN. - FieldAccessTypedExprPtr mask; + FieldAccessTypedExprPtr mask{}; /// Optional list of input columns to sort by before applying aggregate /// function. - std::vector sortingKeys; + std::vector sortingKeys{}; /// A list of sorting orders that goes together with 'sortingKeys'. - std::vector sortingOrders; + std::vector sortingOrders{}; /// Boolean indicating whether inputs must be de-duplicated before /// aggregating. @@ -1054,8 +1134,29 @@ class AggregationNode : public PlanNode { const std::vector& aggregateNames, const std::vector& aggregates, bool ignoreNullKeys, + bool noGroupsSpanBatches, PlanNodePtr source); + AggregationNode( + const PlanNodeId& id, + Step step, + const std::vector& groupingKeys, + const std::vector& preGroupedKeys, + const std::vector& aggregateNames, + const std::vector& aggregates, + bool ignoreNullKeys, + PlanNodePtr source) + : AggregationNode( + id, + step, + groupingKeys, + preGroupedKeys, + aggregateNames, + aggregates, + ignoreNullKeys, + /*noGroupsSpanBatches=*/false, + source) {} + /// @param globalGroupingSets Group IDs of the global grouping sets produced /// by the preceding GroupId node /// @param groupId Group ID key produced by the preceding GroupId node. Must @@ -1075,8 +1176,33 @@ class AggregationNode : public PlanNode { const std::vector& globalGroupingSets, const std::optional& groupId, bool ignoreNullKeys, + bool noGroupsSpanBatches, PlanNodePtr source); + AggregationNode( + const PlanNodeId& id, + Step step, + const std::vector& groupingKeys, + const std::vector& preGroupedKeys, + const std::vector& aggregateNames, + const std::vector& aggregates, + const std::vector& globalGroupingSets, + const std::optional& groupId, + bool ignoreNullKeys, + PlanNodePtr source) + : AggregationNode( + id, + step, + groupingKeys, + preGroupedKeys, + aggregateNames, + aggregates, + globalGroupingSets, + groupId, + ignoreNullKeys, + /*noGroupsSpanBatches=*/false, + source) {} + class Builder { public: Builder() = default; @@ -1091,6 +1217,7 @@ class AggregationNode : public PlanNode { globalGroupingSets_ = other.globalGroupingSets(); groupId_ = other.groupId(); ignoreNullKeys_ = other.ignoreNullKeys(); + noGroupsSpanBatches_ = other.noGroupsSpanBatches(); VELOX_CHECK_EQ(other.sources().size(), 1); source_ = other.sources()[0]; } @@ -1141,6 +1268,11 @@ class AggregationNode : public PlanNode { return *this; } + Builder& noGroupsSpanBatches(bool noGroupsSpanBatches) { + noGroupsSpanBatches_ = noGroupsSpanBatches; + return *this; + } + Builder& source(PlanNodePtr source) { source_ = std::move(source); return *this; @@ -1175,6 +1307,7 @@ class AggregationNode : public PlanNode { globalGroupingSets_, groupId_, ignoreNullKeys_.value(), + noGroupsSpanBatches_, source_.value()); } @@ -1188,6 +1321,7 @@ class AggregationNode : public PlanNode { std::vector globalGroupingSets_ = kDefaultGlobalGroupingSets; std::optional groupId_ = kDefaultGroupId; std::optional ignoreNullKeys_; + bool noGroupsSpanBatches_{false}; std::optional source_; }; @@ -1221,10 +1355,10 @@ class AggregationNode : public PlanNode { bool isPreGrouped() const { return !preGroupedKeys_.empty() && std::equal( - preGroupedKeys_.begin(), - preGroupedKeys_.end(), - groupingKeys_.begin(), - groupingKeys_.end(), + preGroupedKeys_.cbegin(), + preGroupedKeys_.cend(), + groupingKeys_.cbegin(), + groupingKeys_.cend(), [](const FieldAccessTypedExprPtr& x, const FieldAccessTypedExprPtr& y) -> bool { return (*x == *y); @@ -1247,10 +1381,19 @@ class AggregationNode : public PlanNode { return globalGroupingSets_; } - std::optional groupId() const { + const std::optional& groupId() const { return groupId_; } + /// When true, indicates that for streaming aggregation, no sort group spans + /// across input batches. Each input batch contains complete data for its + /// groups - no group will appear in any subsequent input batch. This allows + /// the streaming aggregation operator to immediately produce the aggregation + /// result for all the groups in each input batch. + bool noGroupsSpanBatches() const { + return noGroupsSpanBatches_; + } + std::string_view name() const override { return "Aggregation"; } @@ -1287,13 +1430,22 @@ class AggregationNode : public PlanNode { const std::vector aggregates_; const bool ignoreNullKeys_; - std::optional groupId_; - std::vector globalGroupingSets_; + const std::optional groupId_; + const std::vector globalGroupingSets_; + + // When true, indicates that for streaming aggregation, no sort group spans + // across input batches. Each input batch contains complete data for its + // groups - no group will appear in any subsequent input batch. This allows + // the streaming aggregation operator to immediately produce the aggregation + // result for all the groups in each input batch. + const bool noGroupsSpanBatches_; const std::vector sources_; const RowTypePtr outputType_; }; +using AggregationNodePtr = std::shared_ptr; + inline std::ostream& operator<<( std::ostream& out, const AggregationNode::Step& step) { @@ -1316,14 +1468,64 @@ inline std::string mapAggregationStepToName(const AggregationNode::Step& step) { return ss.str(); } +/// Specify the column stats collection by aggregation. This is used by table +/// writer plan nodes. +struct ColumnStatsSpec : public ISerializable { + /// Grouping keys of the aggregation. It is set to partitioning keys for the + /// partitioned table. It is empty for unpartitioned table. + std::vector groupingKeys; + + /// Step of the aggregation. Specifies the stage of multi-step aggregation + /// processing for column statistics collection: + /// - kSingle: used by TableWrite to complete aggregation in one step when no + /// multi-stage processing needed. + /// - kPartial: used by TableWrite for the first stage of multi-step + /// aggregation, produces intermediate results that need further processing + /// (used in distributed scenarios) + /// - kIntermediate: used by TableWriteMerge in middle stage that processes + /// partial results and produces more refined intermediate results for + /// further aggregation + /// - kFinal: used by TableWriteMerge Final stage that processes intermediate + /// results to produce the complete aggregated statistics. + AggregationNode::Step aggregationStep{AggregationNode::Step::kSingle}; + + /// Names of the aggregations. + std::vector aggregateNames; + + /// Aggregations. + std::vector aggregates; + + ColumnStatsSpec( + std::vector _groupingKeys, + AggregationNode::Step _aggregationStep, + std::vector _aggregateNames, + std::vector _aggregates) + : groupingKeys(std::move(_groupingKeys)), + aggregationStep(_aggregationStep), + aggregateNames(std::move(_aggregateNames)), + aggregates(std::move(_aggregates)) { + VELOX_CHECK(!aggregates.empty()); + VELOX_CHECK_EQ(aggregates.size(), aggregateNames.size()); + } + + /// Returns the output row type that will be produced by this column stats + /// spec. The output type is determined by the grouping keys and aggregate + /// functions specified in the object. + RowTypePtr outputType() const; + + folly::dynamic serialize() const override; + + static ColumnStatsSpec create(const folly::dynamic& obj, void* context); +}; + class TableWriteNode : public PlanNode { public: TableWriteNode( const PlanNodeId& id, const RowTypePtr& columns, const std::vector& columnNames, - std::shared_ptr aggregationNode, - std::shared_ptr insertTableHandle, + std::optional columnStatsSpec, + std::shared_ptr insertTableHandle, bool hasPartitioningScheme, RowTypePtr outputType, connector::CommitStrategy commitStrategy, @@ -1332,7 +1534,7 @@ class TableWriteNode : public PlanNode { sources_{source}, columns_{columns}, columnNames_{columnNames}, - aggregationNode_(std::move(aggregationNode)), + columnStatsSpec_(std::move(columnStatsSpec)), insertTableHandle_(std::move(insertTableHandle)), hasPartitioningScheme_(hasPartitioningScheme), outputType_(std::move(outputType)), @@ -1355,7 +1557,7 @@ class TableWriteNode : public PlanNode { id_ = other.id(); columns_ = other.columns(); columnNames_ = other.columnNames(); - aggregationNode_ = other.aggregationNode(); + columnStatsSpec_ = other.columnStatsSpec(); insertTableHandle_ = other.insertTableHandle(); hasPartitioningScheme_ = other.hasPartitioningScheme(); outputType_ = other.outputType(); @@ -1379,8 +1581,8 @@ class TableWriteNode : public PlanNode { return *this; } - Builder& aggregationNode(std::shared_ptr aggregationNode) { - aggregationNode_ = std::move(aggregationNode); + Builder& columnStatsSpec(std::optional columnStatsSpec) { + columnStatsSpec_ = std::move(columnStatsSpec); return *this; } @@ -1417,10 +1619,6 @@ class TableWriteNode : public PlanNode { VELOX_USER_CHECK( columnNames_.has_value(), "TableWriteNode columnNames is not set"); VELOX_USER_CHECK( - aggregationNode_.has_value(), - "TableWriteNode aggregationNode is not set"); - VELOX_USER_CHECK( - insertTableHandle_.has_value(), "TableWriteNode insertTableHandle is not set"); VELOX_USER_CHECK( @@ -1437,7 +1635,7 @@ class TableWriteNode : public PlanNode { id_.value(), columns_.value(), columnNames_.value(), - aggregationNode_.value(), + columnStatsSpec_, insertTableHandle_.value(), hasPartitioningScheme_.value(), outputType_.value(), @@ -1449,8 +1647,8 @@ class TableWriteNode : public PlanNode { std::optional id_; std::optional columns_; std::optional> columnNames_; - std::optional> aggregationNode_; - std::optional> insertTableHandle_; + std::optional columnStatsSpec_; + std::optional> insertTableHandle_; std::optional hasPartitioningScheme_; std::optional outputType_; std::optional commitStrategy_; @@ -1480,7 +1678,7 @@ class TableWriteNode : public PlanNode { return columnNames_; } - const std::shared_ptr& insertTableHandle() const { + const std::shared_ptr& insertTableHandle() const { return insertTableHandle_; } @@ -1497,9 +1695,15 @@ class TableWriteNode : public PlanNode { return commitStrategy_; } - /// Optional aggregation node for column statistics collection - std::shared_ptr aggregationNode() const { - return aggregationNode_; + /// Returns true of this table write plan node has configured column + /// statistics collection. + bool hasColumnStatsSpec() const { + return columnStatsSpec_.has_value(); + } + + /// Optional spec for column statistics collection. + const std::optional& columnStatsSpec() const { + return columnStatsSpec_; } bool canSpill(const QueryConfig& queryConfig) const override { @@ -1520,13 +1724,15 @@ class TableWriteNode : public PlanNode { const std::vector sources_; const RowTypePtr columns_; const std::vector columnNames_; - const std::shared_ptr aggregationNode_; - const std::shared_ptr insertTableHandle_; + const std::optional columnStatsSpec_; + const std::shared_ptr insertTableHandle_; const bool hasPartitioningScheme_; const RowTypePtr outputType_; const connector::CommitStrategy commitStrategy_; }; +using TableWriteNodePtr = std::shared_ptr; + class TableWriteMergeNode : public PlanNode { public: /// 'outputType' specifies the type to store the metadata of table write @@ -1535,12 +1741,21 @@ class TableWriteMergeNode : public PlanNode { TableWriteMergeNode( const PlanNodeId& id, RowTypePtr outputType, - std::shared_ptr aggregationNode, + std::optional columnStatsSpec, PlanNodePtr source) : PlanNode(id), - aggregationNode_(std::move(aggregationNode)), + columnStatsSpec_(std::move(columnStatsSpec)), sources_{std::move(source)}, - outputType_(std::move(outputType)) {} + outputType_(std::move(outputType)) { + if (hasColumnStatsSpec()) { + VELOX_USER_CHECK( + columnStatsSpec_->aggregationStep == + core::AggregationNode::Step::kFinal || + columnStatsSpec_->aggregationStep == + core::AggregationNode::Step::kIntermediate, + "TableWriteMergeNode requires aggregation step to be intermediate or final"); + } + } class Builder { public: @@ -1549,7 +1764,7 @@ class TableWriteMergeNode : public PlanNode { explicit Builder(const TableWriteMergeNode& other) { id_ = other.id(); outputType_ = other.outputType(); - aggregationNode_ = other.aggregationNode(); + columnStatsSpec_ = other.columnStatsSpec(); VELOX_CHECK_EQ(other.sources().size(), 1); source_ = other.sources()[0]; } @@ -1564,8 +1779,8 @@ class TableWriteMergeNode : public PlanNode { return *this; } - Builder& aggregationNode(std::shared_ptr aggregationNode) { - aggregationNode_ = std::move(aggregationNode); + Builder& columnStatsSpec(std::optional columnStatsSpec) { + columnStatsSpec_ = std::move(columnStatsSpec); return *this; } @@ -1578,29 +1793,29 @@ class TableWriteMergeNode : public PlanNode { VELOX_USER_CHECK(id_.has_value(), "TableWriteMergeNode id is not set"); VELOX_USER_CHECK( outputType_.has_value(), "TableWriteMergeNode outputType is not set"); - VELOX_USER_CHECK( - aggregationNode_.has_value(), - "TableWriteMergeNode aggregationNode is not set"); VELOX_USER_CHECK( source_.has_value(), "TableWriteMergeNode source is not set"); return std::make_shared( - id_.value(), - outputType_.value(), - aggregationNode_.value(), - source_.value()); + id_.value(), outputType_.value(), columnStatsSpec_, source_.value()); } private: std::optional id_; std::optional outputType_; - std::optional> aggregationNode_; + std::optional columnStatsSpec_; std::optional source_; }; - /// Optional aggregation node for column statistics collection - std::shared_ptr aggregationNode() const { - return aggregationNode_; + /// Returns true of this table write merge plan node has configured column + /// statistics collection. + bool hasColumnStatsSpec() const { + return columnStatsSpec_.has_value(); + } + + /// Optional spec for column statistics collection. + const std::optional& columnStatsSpec() const { + return columnStatsSpec_; } const std::vector& sources() const override { @@ -1625,11 +1840,13 @@ class TableWriteMergeNode : public PlanNode { private: void addDetails(std::stringstream& stream) const override; - const std::shared_ptr aggregationNode_; + const std::optional columnStatsSpec_; const std::vector sources_; const RowTypePtr outputType_; }; +using TableWriteMergeNodePtr = std::shared_ptr; + /// For each input row, generates N rows with M columns according to /// specified 'projections'. 'projections' is an N x M matrix of expressions: /// a vector of N rows each having M columns. Each expression is either a @@ -1733,6 +1950,8 @@ class ExpandNode : public PlanNode { const std::vector> projections_; }; +using ExpandNodePtr = std::shared_ptr; + /// Plan node used to implement aggregations over grouping sets. Duplicates /// the aggregation input for each set of grouping keys. The output contains /// one column for each grouping key, followed by aggregation inputs, followed @@ -1901,6 +2120,8 @@ class GroupIdNode : public PlanNode { const std::string groupIdName_; }; +using GroupIdNodePtr = std::shared_ptr; + class ExchangeNode : public PlanNode { public: ExchangeNode( @@ -1987,6 +2208,8 @@ class ExchangeNode : public PlanNode { const VectorSerde::Kind serdeKind_; }; +using ExchangeNodePtr = std::shared_ptr; + class MergeExchangeNode : public ExchangeNode { public: MergeExchangeNode( @@ -2087,6 +2310,8 @@ class MergeExchangeNode : public ExchangeNode { const std::vector sortingOrders_; }; +using MergeExchangeNodePtr = std::shared_ptr; + class LocalMergeNode : public PlanNode { public: LocalMergeNode( @@ -2169,6 +2394,10 @@ class LocalMergeNode : public PlanNode { return sortingKeys_; } + bool canSpill(const QueryConfig& queryConfig) const override { + return !sortingKeys_.empty() && queryConfig.localMergeSpillEnabled(); + } + const std::vector& sortingOrders() const { return sortingOrders_; } @@ -2189,6 +2418,8 @@ class LocalMergeNode : public PlanNode { const std::vector sortingOrders_; }; +using LocalMergeNodePtr = std::shared_ptr; + /// Calculates partition number for each row of the specified vector. class PartitionFunction { public: @@ -2258,9 +2489,7 @@ class LocalPartitionNode : public PlanNode { kRepartition, }; - static const char* typeName(Type type); - - static Type typeFromName(const std::string& name); + VELOX_DECLARE_EMBEDDED_ENUM_NAME(Type); /// If 'scaleWriter' is true, the local partition is used to scale the table /// writer prcessing. @@ -2282,7 +2511,7 @@ class LocalPartitionNode : public PlanNode { VELOX_USER_CHECK_NOT_NULL(partitionFunctionSpec_); - for (auto i = 1; i < sources_.size(); ++i) { + for (size_t i = 1; i < sources_.size(); ++i) { VELOX_USER_CHECK( *sources_[i]->outputType() == *sources_[0]->outputType(), "All sources of the LocalPartitionedNode must have the same output type: {} vs. {}.", @@ -2412,6 +2641,8 @@ class LocalPartitionNode : public PlanNode { const PartitionFunctionSpecPtr partitionFunctionSpec_; }; +using LocalPartitionNodePtr = std::shared_ptr; + class PartitionedOutputNode : public PlanNode { public: enum class Kind { @@ -2419,8 +2650,8 @@ class PartitionedOutputNode : public PlanNode { kBroadcast, kArbitrary, }; - static std::string kindString(Kind kind); - static Kind stringToKind(const std::string& str); + + VELOX_DECLARE_EMBEDDED_ENUM_NAME(Kind) PartitionedOutputNode( const PlanNodeId& id, @@ -2617,6 +2848,7 @@ class PartitionedOutputNode : public PlanNode { } const PartitionFunctionSpec& partitionFunctionSpec() const { + VELOX_CHECK_NOT_NULL(partitionFunctionSpec_); return *partitionFunctionSpec_; } @@ -2641,10 +2873,12 @@ class PartitionedOutputNode : public PlanNode { const RowTypePtr outputType_; }; +using PartitionedOutputNodePtr = std::shared_ptr; + FOLLY_ALWAYS_INLINE std::ostream& operator<<( std::ostream& out, const PartitionedOutputNode::Kind kind) { - out << PartitionedOutputNode::kindString(kind); + out << PartitionedOutputNode::toName(kind); return out; } @@ -2652,81 +2886,75 @@ enum class JoinType { // For each row on the left, find all matching rows on the right and return // all combinations. kInner = 0, + // For each row on the left, find all matching rows on the right and return // all combinations. In addition, return all rows from the left that have no // match on the right with right-side columns filled with nulls. kLeft = 1, + // Opposite of kLeft. For each row on the right, find all matching rows on - // the - // left and return all combinations. In addition, return all rows from the + // the left and return all combinations. In addition, return all rows from the // right that have no match on the left with left-side columns filled with // nulls. kRight = 2, + // A "union" of kLeft and kRight. For each row on the left, find all - // matching - // rows on the right and return all combinations. In addition, return all - // rows - // from the left that have no - // match on the right with right-side columns filled with nulls. Also, - // return - // all rows from the - // right that have no match on the left with left-side columns filled with - // nulls. + // matching rows on the right and return all combinations. In addition, return + // all rows from the left that have no match on the right with right-side + // columns filled with nulls. Also, return all rows from the right that have + // no match on the left with left-side columns filled with nulls. kFull = 3, + // Return a subset of rows from the left side which have a match on the - // right - // side. For this join type, cardinality of the output is less than or equal - // to the cardinality of the left side. + // right side. For this join type, cardinality of the output is less than or + // equal to the cardinality of the left side. kLeftSemiFilter = 4, + // Return each row from the left side with a boolean flag indicating whether // there exists a match on the right side. For this join type, cardinality - // of - // the output equals the cardinality of the left side. + // of the output equals the cardinality of the left side. // // The handling of the rows with nulls in the join key depends on the // 'nullAware' boolean specified separately. // - // Null-aware join follows IN semantic. Regular join follows EXISTS - // semantic. + // Null-aware join follows IN semantic. Regular join follows EXISTS semantic. kLeftSemiProject = 5, + // Opposite of kLeftSemiFilter. Return a subset of rows from the right side // which have a match on the left side. For this join type, cardinality of - // the - // output is less than or equal to the cardinality of the right side. + // the output is less than or equal to the cardinality of the right side. kRightSemiFilter = 6, + // Opposite of kLeftSemiProject. Return each row from the right side with a // boolean flag indicating whether there exists a match on the left side. - // For - // this join type, cardinality of the output equals the cardinality of the + // For this join type, cardinality of the output equals the cardinality of the // right side. // // The handling of the rows with nulls in the join key depends on the // 'nullAware' boolean specified separately. // - // Null-aware join follows IN semantic. Regular join follows EXISTS - // semantic. + // Null-aware join follows IN semantic. Regular join follows EXISTS semantic. kRightSemiProject = 7, + // Return each row from the left side which has no match on the right side. // The handling of the rows with nulls in the join key depends on the // 'nullAware' boolean specified separately. // // Null-aware join follows NOT IN semantic: - // (1) return empty result if the right side contains a record with a null - // in + // (1) return empty result if the right side contains a record with a null in // the join key; - // (2) return left-side row with null in the join key only when - // the right side is empty. + // (2) return left-side row with null in the join key only when the right side + // is empty. // // Regular anti join follows NOT EXISTS semantic: // (1) ignore right-side rows with nulls in the join keys; // (2) unconditionally return left side rows with nulls in the join keys. kAnti = 8, + kNumJoinTypes = 9, }; -const char* joinTypeName(JoinType joinType); - -JoinType joinTypeFromName(const std::string& name); +VELOX_DECLARE_ENUM_NAME(JoinType); inline bool isInnerJoin(JoinType joinType) { return joinType == JoinType::kInner; @@ -2906,6 +3134,15 @@ class AbstractJoinNode : public PlanNode { return isInnerJoin() || isLeftJoin() || isAntiJoin(); } + /// Indicates if this joinNode can drop duplicate rows with same join key. + /// For left semi and anti join, it is not necessary to store duplicate rows. + bool canDropDuplicates() const { + // Left semi and anti join with no extra filter only needs to know whether + // there is a match. Hence, no need to store entries with duplicate keys. + return !filter() && + (isLeftSemiFilterJoin() || isLeftSemiProjectJoin() || isAntiJoin()); + } + const std::vector& leftKeys() const { return leftKeys_; } @@ -2919,6 +3156,7 @@ class AbstractJoinNode : public PlanNode { } protected: + void validate() const; void addDetails(std::stringstream& stream) const override; folly::dynamic serializeBase() const; @@ -2953,7 +3191,8 @@ class HashJoinNode : public AbstractJoinNode { TypedExprPtr filter, PlanNodePtr left, PlanNodePtr right, - RowTypePtr outputType) + RowTypePtr outputType, + bool useHashTableCache = false) : AbstractJoinNode( id, joinType, @@ -2963,7 +3202,10 @@ class HashJoinNode : public AbstractJoinNode { std::move(left), std::move(right), std::move(outputType)), - nullAware_{nullAware} { + nullAware_{nullAware}, + useHashTableCache_{useHashTableCache} { + validate(); + if (nullAware) { VELOX_USER_CHECK( isNullAwareSupported(joinType), @@ -2986,6 +3228,7 @@ class HashJoinNode : public AbstractJoinNode { explicit Builder(const HashJoinNode& other) : AbstractJoinNode::Builder(other) { nullAware_ = other.isNullAware(); + useHashTableCache_ = other.useHashTableCache(); } Builder& nullAware(bool value) { @@ -2993,6 +3236,11 @@ class HashJoinNode : public AbstractJoinNode { return *this; } + Builder& useHashTableCache(bool value) { + useHashTableCache_ = value; + return *this; + } + std::shared_ptr build() const { VELOX_USER_CHECK(id_.has_value(), "HashJoinNode id is not set"); VELOX_USER_CHECK( @@ -3019,11 +3267,13 @@ class HashJoinNode : public AbstractJoinNode { filter_.value_or(nullptr), left_.value(), right_.value(), - outputType_.value()); + outputType_.value(), + useHashTableCache_.value_or(false)); } private: std::optional nullAware_; + std::optional useHashTableCache_; }; std::string_view name() const override { @@ -3046,6 +3296,12 @@ class HashJoinNode : public AbstractJoinNode { return nullAware_; } + /// Returns whether hash table caching is enabled for broadcast joins. + /// Only used by Presto-on-Spark. + bool useHashTableCache() const { + return useHashTableCache_; + } + folly::dynamic serialize() const override; static PlanNodePtr create(const folly::dynamic& obj, void* context); @@ -3054,8 +3310,11 @@ class HashJoinNode : public AbstractJoinNode { void addDetails(std::stringstream& stream) const override; const bool nullAware_; + const bool useHashTableCache_; }; +using HashJoinNodePtr = std::shared_ptr; + /// Represents inner/outer/semi/anti merge joins. Translates to an /// exec::MergeJoin operator. Assumes that both left and right input data is /// sorted on the join keys. A separate pipeline that puts its output into @@ -3126,6 +3385,8 @@ class MergeJoinNode : public AbstractJoinNode { static PlanNodePtr create(const folly::dynamic& obj, void* context); }; +using MergeJoinNodePtr = std::shared_ptr; + struct IndexLookupCondition : public ISerializable { /// References to an index table column. FieldAccessTypedExprPtr key; @@ -3145,6 +3406,7 @@ struct IndexLookupCondition : public ISerializable { virtual std::string toString() const = 0; }; + using IndexLookupConditionPtr = std::shared_ptr; /// Represents IN-LIST index lookup condition: contains('list', 'key'). 'list' @@ -3172,6 +3434,7 @@ struct InIndexLookupCondition : public IndexLookupCondition { private: void validate() const; }; + using InIndexLookupConditionPtr = std::shared_ptr; /// Represents BETWEEN index lookup condition: 'key' between 'lower' and @@ -3207,9 +3470,37 @@ struct BetweenIndexLookupCondition : public IndexLookupCondition { private: void validate() const; }; + using BetweenIndexLookupConditionPtr = std::shared_ptr; +/// Represents EQUAL index lookup condition: 'key' = 'value'. 'value' must be a +/// constant value with the same type as 'key'. +struct EqualIndexLookupCondition : public IndexLookupCondition { + /// The value to compare against. + TypedExprPtr value; + + EqualIndexLookupCondition(FieldAccessTypedExprPtr _key, TypedExprPtr _value) + : IndexLookupCondition(std::move(_key)), value(std::move(_value)) { + validate(); + } + + bool isFilter() const override; + + folly::dynamic serialize() const override; + + std::string toString() const override; + + static IndexLookupConditionPtr create( + const folly::dynamic& obj, + void* context); + + private: + void validate() const; +}; + +using EqualIndexLookupConditionPtr = std::shared_ptr; + /// Represents index lookup join. Translates to an exec::IndexLookupJoin /// operator. Assumes the right input is a table scan source node that provides /// indexed table lookup for the left input with the specified join keys and @@ -3246,46 +3537,28 @@ class IndexLookupJoinNode : public AbstractJoinNode { public: /// @param joinType Specifies the lookup join type. Only INNER and LEFT joins /// are supported. + /// @param leftKeys Left side join keys used for index lookup. + /// @param rightKeys Right side join keys that form the index prefix. + /// @param joinConditions Additional conditions for index lookup that can't + /// be converted into simple equality join conditions. These conditions use + /// columns from both left and right and exactly one index column from + /// the right side.sides + /// @param filter Additional filter to apply on join results. This supports + /// filters that can't be converted into join conditions. + /// @param hasMarker if true, the output type includes a boolean + /// column at the end to indicate if a join output row has a match or not. + /// This only applies for left join. IndexLookupJoinNode( const PlanNodeId& id, JoinType joinType, const std::vector& leftKeys, const std::vector& rightKeys, const std::vector& joinConditions, + TypedExprPtr filter, + bool hasMarker, PlanNodePtr left, TableScanNodePtr right, - RowTypePtr outputType) - : AbstractJoinNode( - id, - joinType, - leftKeys, - rightKeys, - /*filter=*/nullptr, - std::move(left), - right, - outputType), - lookupSourceNode_(std::move(right)), - joinConditions_(joinConditions) { - VELOX_USER_CHECK( - !leftKeys.empty(), - "The lookup join node requires at least one join key"); - VELOX_USER_CHECK_EQ( - leftKeys_.size(), - rightKeys_.size(), - "The lookup join node requires same number of join keys on left and right sides"); - // TODO: add check that (1) 'rightKeys_' form an index prefix. each of - // 'joinConditions_' uses columns from both sides and uses exactly one index - // column from the right side. - VELOX_USER_CHECK( - lookupSourceNode_->tableHandle()->supportsIndexLookup(), - "The lookup table handle {} from connector {} doesn't support index lookup", - lookupSourceNode_->tableHandle()->name(), - lookupSourceNode_->tableHandle()->connectorId()); - VELOX_USER_CHECK( - isSupported(joinType_), - "Unsupported index lookup join type {}", - joinTypeName(joinType_)); - } + RowTypePtr outputType); class Builder : public AbstractJoinNode::Builder { @@ -3295,14 +3568,30 @@ class IndexLookupJoinNode : public AbstractJoinNode { explicit Builder(const IndexLookupJoinNode& other) : AbstractJoinNode::Builder(other) { joinConditions_ = other.joinConditions(); + filter_ = other.filter(); + hasMarker_ = other.hasMarker(); } + /// Set lookup conditions for index lookup that can't be converted into + /// simple equality join conditions. Builder& joinConditions( std::vector joinConditions) { joinConditions_ = std::move(joinConditions); return *this; } + /// Set additional filter to apply on join results. + Builder& filter(TypedExprPtr filter) { + filter_ = std::move(filter); + return *this; + } + + /// Set whether to include a marker column for left joins. + Builder& hasMarker(bool hasMarker) { + hasMarker_ = hasMarker; + return *this; + } + std::shared_ptr build() const { VELOX_USER_CHECK(id_.has_value(), "IndexLookupJoinNode id is not set"); VELOX_USER_CHECK( @@ -3317,23 +3606,23 @@ class IndexLookupJoinNode : public AbstractJoinNode { right_.has_value(), "IndexLookupJoinNode right source is not set"); VELOX_USER_CHECK( outputType_.has_value(), "IndexLookupJoinNode outputType is not set"); - VELOX_USER_CHECK( - joinConditions_.has_value(), - "IndexLookupJoinNode join conditions are not set"); return std::make_shared( id_.value(), joinType_.value(), leftKeys_.value(), rightKeys_.value(), - joinConditions_.value(), + joinConditions_, + filter_.value_or(nullptr), + hasMarker_, left_.value(), std::dynamic_pointer_cast(right_.value()), outputType_.value()); } private: - std::optional> joinConditions_; + std::vector joinConditions_; + bool hasMarker_{false}; }; bool supportsBarrier() const override { @@ -3344,6 +3633,8 @@ class IndexLookupJoinNode : public AbstractJoinNode { return lookupSourceNode_; } + /// Returns the join conditions for index lookup that can't be converted into + /// simple equality join conditions. const std::vector& joinConditions() const { return joinConditions_; } @@ -3352,6 +3643,11 @@ class IndexLookupJoinNode : public AbstractJoinNode { return "IndexLookupJoin"; } + /// Returns whether this node includes a marker column for left joins. + bool hasMarker() const { + return hasMarker_; + } + void accept(const PlanNodeVisitor& visitor, PlanNodeVisitorContext& context) const override; @@ -3365,11 +3661,20 @@ class IndexLookupJoinNode : public AbstractJoinNode { private: void addDetails(std::stringstream& stream) const override; + /// The table scan node that provides the lookup source for index operations. const TableScanNodePtr lookupSourceNode_; + /// Join conditions that can't be converted into simple equality join + /// conditions. These conditions involve columns from both left and right + /// sides and exactly one index column from the right side. const std::vector joinConditions_; + + /// Whether to include a marker column for left joins to indicate matches. + const bool hasMarker_; }; +using IndexLookupJoinNodePtr = std::shared_ptr; + /// Returns true if 'planNode' is index lookup join node. bool isIndexLookupJoin(const core::PlanNode* planNode); @@ -3514,6 +3819,8 @@ class NestedLoopJoinNode : public PlanNode { const RowTypePtr outputType_; }; +using NestedLoopJoinNodePtr = std::shared_ptr; + // Represents the 'SortBy' node in the plan. class OrderByNode : public PlanNode { public: @@ -3655,6 +3962,228 @@ class OrderByNode : public PlanNode { const std::vector sources_; }; +using OrderByNodePtr = std::shared_ptr; + +/// Represents a spatial join between two geometries. Translates to an +/// exec::SpatialJoinProbe and exec::SpatialJoinBuild. A separate +/// pipeline is produced for the build side when generating exec::Operators. +/// +/// Spatial join supports "local" spatial predicates, i.e. predicates that +/// require defined proximity. Examples include ST_Intersects or any of +/// the DE-9IM predicates except for ST_Disjoint. It also supports +/// `ST_Distance(g1, g2) <= d`. +/// +/// The local join index is a collection of bounding boxes for a quick +/// check (either "no" or "maybe"), and the actual predicate must be +/// checked for each candidate. +/// +/// Currently only INNER joins are supported, but LEFT joins are planned. +class SpatialJoinNode : public PlanNode { + public: + SpatialJoinNode( + const PlanNodeId& id, + JoinType joinType, + TypedExprPtr joinCondition, + FieldAccessTypedExprPtr probeGeometry, + FieldAccessTypedExprPtr buildGeometry, + std::optional radius, + PlanNodePtr left, + PlanNodePtr right, + RowTypePtr outputType); + + SpatialJoinNode( + const PlanNodeId& id, + JoinType joinType, + TypedExprPtr joinCondition, + PlanNodePtr left, + PlanNodePtr right, + RowTypePtr outputType); + + PlanNodePtr leftNode() const { + return sources()[0]; + } + + PlanNodePtr rightNode() const { + return sources()[1]; + } + + class Builder { + public: + Builder() = default; + + explicit Builder(const SpatialJoinNode& other) { + id_ = other.id(); + joinType_ = other.joinType(); + joinCondition_ = other.joinCondition(); + probeGeometry_ = other.probeGeometry(); + buildGeometry_ = other.buildGeometry(); + radius_ = other.radius(); + VELOX_CHECK_EQ(other.sources().size(), 2); + left_ = other.sources()[0]; + right_ = other.sources()[1]; + outputType_ = other.outputType(); + } + + Builder& id(PlanNodeId id) { + id_ = std::move(id); + return *this; + } + + Builder& joinType(JoinType joinType) { + joinType_ = joinType; + return *this; + } + + Builder& joinCondition(TypedExprPtr joinCondition) { + joinCondition_ = std::move(joinCondition); + return *this; + } + + Builder& probeGeometry(FieldAccessTypedExprPtr probeGeometry) { + probeGeometry_ = std::move(probeGeometry); + return *this; + } + + Builder& buildGeometry(FieldAccessTypedExprPtr buildGeometry) { + buildGeometry_ = std::move(buildGeometry); + return *this; + } + + Builder& radius(FieldAccessTypedExprPtr radius) { + radius_ = std::move(radius); + return *this; + } + + Builder& left(PlanNodePtr left) { + left_ = std::move(left); + return *this; + } + + Builder& right(PlanNodePtr right) { + right_ = std::move(right); + return *this; + } + + Builder& outputType(RowTypePtr outputType) { + outputType_ = std::move(outputType); + return *this; + } + + std::shared_ptr build() const { + VELOX_USER_CHECK(id_.has_value(), "SpatialJoinNode id is not set"); + VELOX_USER_CHECK( + left_.has_value(), "SpatialJoinNode left source is not set"); + VELOX_USER_CHECK( + right_.has_value(), "SpatialJoinNode right source is not set"); + VELOX_USER_CHECK( + outputType_.has_value(), "SpatialJoinNode outputType is not set"); + VELOX_USER_CHECK( + probeGeometry_.has_value(), + "SpatialJoinNode probe geometry is not set"); + VELOX_USER_CHECK( + buildGeometry_.has_value(), + "SpatialJoinNode build geometry is not set"); + + VELOX_USER_CHECK( + (probeGeometry_.has_value() && buildGeometry_.has_value()) || + (!probeGeometry_.has_value() && !buildGeometry_.has_value()), + "Either probe and build geometry must both be set, or neither"); + + if (probeGeometry_.has_value() && buildGeometry_.has_value()) { + return std::make_shared( + id_.value(), + joinType_, + joinCondition_, + probeGeometry_.value(), + buildGeometry_.value(), + radius_, + left_.value(), + right_.value(), + outputType_.value()); + } + + return std::make_shared( + id_.value(), + joinType_, + joinCondition_, + probeGeometry_.value(), + buildGeometry_.value(), + radius_, + left_.value(), + right_.value(), + outputType_.value()); + } + + private: + std::optional id_; + JoinType joinType_ = kDefaultJoinType; + TypedExprPtr joinCondition_; + std::optional probeGeometry_; + std::optional buildGeometry_; + std::optional radius_; + std::optional left_; + std::optional right_; + std::optional outputType_; + }; + + const std::vector& sources() const override { + return sources_; + } + + void accept(const PlanNodeVisitor& visitor, PlanNodeVisitorContext& context) + const override; + + const RowTypePtr& outputType() const override { + return outputType_; + } + + std::string_view name() const override { + return "SpatialJoin"; + } + + const TypedExprPtr& joinCondition() const { + return joinCondition_; + } + + const FieldAccessTypedExprPtr& probeGeometry() const { + return probeGeometry_; + } + + const FieldAccessTypedExprPtr& buildGeometry() const { + return buildGeometry_; + } + + const std::optional& radius() const { + return radius_; + } + + JoinType joinType() const { + return joinType_; + } + + folly::dynamic serialize() const override; + + /// If spatial join supports this join type. + static bool isSupported(JoinType joinType); + + static PlanNodePtr create(const folly::dynamic& obj, void* context); + + private: + constexpr static JoinType kDefaultJoinType = JoinType::kInner; + + void addDetails(std::stringstream& stream) const override; + + const JoinType joinType_; + const TypedExprPtr joinCondition_; + const FieldAccessTypedExprPtr probeGeometry_; + const FieldAccessTypedExprPtr buildGeometry_; + const std::optional radius_; + const std::vector sources_; + const RowTypePtr outputType_; +}; + +using SpatialJoinNodePtr = std::shared_ptr; + class TopNNode : public PlanNode { public: TopNNode( @@ -3782,6 +4311,8 @@ class TopNNode : public PlanNode { const std::vector sources_; }; +using TopNNodePtr = std::shared_ptr; + class LimitNode : public PlanNode { public: // @param isPartial Boolean indicating whether Limit node generates partial @@ -3866,6 +4397,10 @@ class LimitNode : public PlanNode { std::optional source_; }; + bool supportsBarrier() const override { + return true; + } + const RowTypePtr& outputType() const override { return sources_[0]->outputType(); } @@ -3906,6 +4441,8 @@ class LimitNode : public PlanNode { const std::vector sources_; }; +using LimitNodePtr = std::shared_ptr; + /// Expands arrays and maps into separate columns. Arrays are expanded into a /// single column, and maps are expanded into two columns (key, value). Can be /// used to expand multiple columns. In this case will produce as many rows as @@ -3922,12 +4459,17 @@ class UnnestNode : public PlanNode { /// names must appear in the same order as unnestVariables. /// @param ordinalityName Optional name for the ordinality columns. If not /// present, ordinality column is not produced. + /// @param markerName Optional name for column which indicates whether an + /// output row has non-empty unnested value. If not present, marker column is + /// not provided and the unnest operator also skips producing output rows + /// with empty unnest value. UnnestNode( const PlanNodeId& id, std::vector replicateVariables, std::vector unnestVariables, std::vector unnestNames, std::optional ordinalityName, + std::optional markerName, const PlanNodePtr& source); class Builder { @@ -3976,6 +4518,11 @@ class UnnestNode : public PlanNode { return *this; } + Builder& markerName(std::optional markerName) { + markerName_ = std::move(markerName); + return *this; + } + std::shared_ptr build() const { VELOX_USER_CHECK(id_.has_value(), "UnnestNode id is not set"); VELOX_USER_CHECK( @@ -3994,6 +4541,7 @@ class UnnestNode : public PlanNode { unnestVariables_.value(), unnestNames_.value(), ordinalityName_, + markerName_, source_.value()); } @@ -4003,6 +4551,7 @@ class UnnestNode : public PlanNode { std::optional> unnestVariables_; std::optional> unnestNames_; std::optional ordinalityName_; + std::optional markerName_; std::optional source_; }; @@ -4032,10 +4581,26 @@ class UnnestNode : public PlanNode { return unnestVariables_; } - bool withOrdinality() const { + const std::vector& unnestNames() const { + return unnestNames_; + } + + const std::optional& ordinalityName() const { + return ordinalityName_; + } + + bool hasOrdinality() const { return ordinalityName_.has_value(); } + const std::optional& markerName() const { + return markerName_; + } + + bool hasMarker() const { + return markerName_.has_value(); + } + std::string_view name() const override { return "Unnest"; } @@ -4051,10 +4616,13 @@ class UnnestNode : public PlanNode { const std::vector unnestVariables_; const std::vector unnestNames_; const std::optional ordinalityName_; + const std::optional markerName_; const std::vector sources_; RowTypePtr outputType_; }; +using UnnestNodePtr = std::shared_ptr; + /// Checks that input contains at most one row. Return that row as is. If /// input is empty, returns a single row with all values set to null. If input /// contains more than one row raises an exception. @@ -4124,6 +4692,8 @@ class EnforceSingleRowNode : public PlanNode { const std::vector sources_; }; +using EnforceSingleRowNodePtr = std::shared_ptr; + /// Adds a new column named `idName` at the end of the input columns /// with unique int64_t value per input row. /// @@ -4142,6 +4712,10 @@ class AssignUniqueIdNode : public PlanNode { const int32_t taskUniqueId, PlanNodePtr source); + bool supportsBarrier() const override { + return true; + } + class Builder { public: Builder() = default; @@ -4230,6 +4804,8 @@ class AssignUniqueIdNode : public PlanNode { std::shared_ptr uniqueIdCounter_; }; +using AssignUniqueIdNodePtr = std::shared_ptr; + /// PlanNode used for evaluating Sql window functions. /// All window functions evaluated in the operator have the same /// window spec (partition keys + order columns). @@ -4249,9 +4825,7 @@ class WindowNode : public PlanNode { public: enum class WindowType { kRange, kRows }; - static const char* windowTypeName(WindowType type); - - static WindowType windowTypeFromName(const std::string& name); + VELOX_DECLARE_EMBEDDED_ENUM_NAME(WindowType) enum class BoundType { kUnboundedPreceding, @@ -4261,9 +4835,7 @@ class WindowNode : public PlanNode { kUnboundedFollowing }; - static const char* boundTypeName(BoundType type); - - static BoundType boundTypeFromName(const std::string& name); + VELOX_DECLARE_EMBEDDED_ENUM_NAME(BoundType) /// Window frames can be ROW or RANGE type. /// Frame bounds can be CURRENT ROW, UNBOUNDED PRECEDING(FOLLOWING) @@ -4451,6 +5023,10 @@ class WindowNode : public PlanNode { return "Window"; } + const std::vector& windowColumnNames() const { + return windowColumnNames_; + } + folly::dynamic serialize() const override; static PlanNodePtr create(const folly::dynamic& obj, void* context); @@ -4473,6 +5049,8 @@ class WindowNode : public PlanNode { const RowTypePtr outputType_; }; +using WindowNodePtr = std::shared_ptr; + /// Optimized version of a WindowNode for a single row_number function with an /// optional limit and no sorting. /// The output of this node contains all input columns followed by an optional @@ -4607,6 +5185,8 @@ class RowNumberNode : public PlanNode { const RowTypePtr outputType_; }; +using RowNumberNodePtr = std::shared_ptr; + /// The MarkDistinct operator marks unique rows based on distinctKeys. /// The result is put in a new markerName column alongside the original input. /// @param markerName Name of the output mask channel. @@ -4717,24 +5297,40 @@ class MarkDistinctNode : public PlanNode { const RowTypePtr outputType_; }; -/// Optimized version of a WindowNode for a single row_number function with a -/// limit over sorted partitions. -/// The output of this node contains all input columns followed by an optional +using MarkDistinctNodePtr = std::shared_ptr; + +/// Optimized version of a WindowNode for a single row_number, rank or +/// dense_rank function with a limit over sorted partitions. The output of this +/// node contains all input columns followed by an optional /// 'rowNumberColumnName' BIGINT column. +/// TODO: This node will be renamed to TopNRank or TopNRowNode once all the +/// support for handling rank and dense_rank is committed to Velox. class TopNRowNumberNode : public PlanNode { public: + enum class RankFunction { + kRowNumber, + kRank, + kDenseRank, + }; + + static const char* rankFunctionName(RankFunction function); + + static RankFunction rankFunctionFromName(std::string_view name); + + /// @param rankFunction RanksFunction (row_number, rank, dense_rank) for TopN. /// @param partitionKeys Partitioning keys. May be empty. /// @param sortingKeys Sorting keys. May not be empty and may not intersect /// with 'partitionKeys'. /// @param sortingOrders Sorting orders, one per sorting key. /// @param rowNumberColumnName Optional name of the column containing row - /// numbers. If not specified, the output doesn't include 'row number' - /// column. This is used when computing partial results. + /// numbers or rank or dense_rank. If not specified, the output doesn't + /// include 'row_number' column. This is used when computing partial results. /// @param limit Per-partition limit. The number of /// rows produced by this node will not exceed this value for any given /// partition. Extra rows will be dropped. TopNRowNumberNode( PlanNodeId id, + RankFunction function, std::vector partitionKeys, std::vector sortingKeys, std::vector sortingOrders, @@ -4757,6 +5353,7 @@ class TopNRowNumberNode : public PlanNode { limit_ = other.limit(); VELOX_CHECK_EQ(other.sources().size(), 1); source_ = other.sources()[0]; + function_ = other.rankFunction(); } Builder& id(PlanNodeId id) { @@ -4764,6 +5361,11 @@ class TopNRowNumberNode : public PlanNode { return *this; } + Builder& function(RankFunction function) { + function_ = function; + return *this; + } + Builder& partitionKeys(std::vector partitionKeys) { partitionKeys_ = std::move(partitionKeys); return *this; @@ -4815,6 +5417,7 @@ class TopNRowNumberNode : public PlanNode { return std::make_shared( id_.value(), + function_.has_value() ? function_.value() : RankFunction::kRowNumber, partitionKeys_.value(), sortingKeys_.value(), sortingOrders_.value(), @@ -4825,6 +5428,7 @@ class TopNRowNumberNode : public PlanNode { private: std::optional id_; + std::optional function_; std::optional> partitionKeys_; std::optional> sortingKeys_; std::optional> sortingOrders_; @@ -4868,6 +5472,10 @@ class TopNRowNumberNode : public PlanNode { return limit_; } + RankFunction rankFunction() const { + return function_; + } + bool generateRowNumber() const { return outputType_->size() > sources_[0]->outputType()->size(); } @@ -4883,6 +5491,8 @@ class TopNRowNumberNode : public PlanNode { private: void addDetails(std::stringstream& stream) const override; + const RankFunction function_; + const std::vector partitionKeys_; const std::vector sortingKeys_; @@ -4895,6 +5505,8 @@ class TopNRowNumberNode : public PlanNode { const RowTypePtr outputType_; }; +using TopNRowNumberNodePtr = std::shared_ptr; + class PlanNodeVisitorContext { public: virtual ~PlanNodeVisitorContext() = default; @@ -4960,6 +5572,9 @@ class PlanNodeVisitor { const NestedLoopJoinNode& node, PlanNodeVisitorContext& ctx) const = 0; + virtual void visit(const SpatialJoinNode& node, PlanNodeVisitorContext& ctx) + const = 0; + virtual void visit(const OrderByNode& node, PlanNodeVisitorContext& ctx) const = 0; @@ -4970,6 +5585,10 @@ class PlanNodeVisitor { virtual void visit(const ProjectNode& node, PlanNodeVisitorContext& ctx) const = 0; + virtual void visit( + const ParallelProjectNode& node, + PlanNodeVisitorContext& ctx) const = 0; + virtual void visit(const RowNumberNode& node, PlanNodeVisitorContext& ctx) const = 0; @@ -5017,12 +5636,12 @@ class PlanNodeVisitor { template <> struct fmt::formatter - : formatter { + : formatter { auto format( facebook::velox::core::PartitionedOutputNode::Kind s, format_context& ctx) const { - return formatter::format( - facebook::velox::core::PartitionedOutputNode::kindString(s), ctx); + return formatter::format( + facebook::velox::core::PartitionedOutputNode::toName(s), ctx); } }; @@ -5032,3 +5651,25 @@ struct fmt::formatter : formatter { return formatter::format(static_cast(s), ctx); } }; + +template <> +struct fmt::formatter + : formatter { + auto format( + facebook::velox::core::TopNRowNumberNode::RankFunction f, + format_context& ctx) const { + return formatter::format( + facebook::velox::core::TopNRowNumberNode::rankFunctionName(f), ctx); + } +}; + +template <> +struct fmt::formatter + : formatter { + auto format( + facebook::velox::core::AggregationNode::Step s, + format_context& ctx) const { + return formatter::format( + facebook::velox::core::AggregationNode::toName(s), ctx); + } +}; diff --git a/velox/core/QueryConfig.cpp b/velox/core/QueryConfig.cpp index 3d5b25ff9487..736332280bb6 100644 --- a/velox/core/QueryConfig.cpp +++ b/velox/core/QueryConfig.cpp @@ -22,29 +22,26 @@ namespace facebook::velox::core { -QueryConfig::QueryConfig( - const std::unordered_map& values) - : config_{std::make_unique( - std::unordered_map(values))} { - validateConfig(); -} +QueryConfig::QueryConfig(std::unordered_map values) + : QueryConfig{ + ConfigTag{}, + std::make_shared(std::move(values))} {} -QueryConfig::QueryConfig(std::unordered_map&& values) - : config_{std::make_unique(std::move(values))} { +QueryConfig::QueryConfig( + ConfigTag /*tag*/, + std::shared_ptr config) + : config_{std::move(config)} { validateConfig(); } void QueryConfig::validateConfig() { // Validate if timezone name can be recognized. - if (config_->valueExists(QueryConfig::kSessionTimezone)) { + if (auto tz = config_->get(QueryConfig::kSessionTimezone)) { VELOX_USER_CHECK( - tz::getTimeZoneID( - config_->get(QueryConfig::kSessionTimezone).value(), - false) != -1, - fmt::format( - "session '{}' set with invalid value '{}'", - QueryConfig::kSessionTimezone, - config_->get(QueryConfig::kSessionTimezone).value())); + tz::getTimeZoneID(*tz, false) != -1, + "session '{}' set with invalid value '{}'", + QueryConfig::kSessionTimezone, + *tz); } } diff --git a/velox/core/QueryConfig.h b/velox/core/QueryConfig.h index c86058f844e5..f5b9a7fe19ab 100644 --- a/velox/core/QueryConfig.h +++ b/velox/core/QueryConfig.h @@ -21,16 +21,21 @@ namespace facebook::velox::core { -/// A simple wrapper around velox::ConfigBase. Defines constants for query +/// A simple wrapper around velox::IConfig. Defines constants for query /// config properties and accessor methods. /// Create per query context. Does not have a singleton instance. /// Does not allow altering properties on the fly. Only at creation time. class QueryConfig { public: - explicit QueryConfig( - const std::unordered_map& values); + explicit QueryConfig(std::unordered_map values); + + // This is needed only to resolve correct ctor for cases like + // QueryConfig{{}} or QueryConfig({}). + struct ConfigTag {}; - explicit QueryConfig(std::unordered_map&& values); + explicit QueryConfig( + ConfigTag /*tag*/, + std::shared_ptr config); /// Maximum memory that a query can use on a single host. static constexpr const char* kQueryMaxMemoryPerNode = @@ -40,6 +45,11 @@ class QueryConfig { /// name, e.g: "America/Los_Angeles". static constexpr const char* kSessionTimezone = "session_timezone"; + /// Session start time in milliseconds since Unix epoch. This represents when + /// the query session began execution. Used for functions that need to know + /// the session start time (e.g., current_date, localtime). + static constexpr const char* kSessionStartTime = "start_time"; + /// If true, timezone-less timestamp conversions (e.g. string to timestamp, /// when the string does not specify a timezone) will be adjusted to the user /// provided session timezone (if any). @@ -64,6 +74,26 @@ class QueryConfig { static constexpr const char* kExprTrackCpuUsage = "expression.track_cpu_usage"; + /// Takes a comma separated list of function names to track CPU usage for. + /// Only applicable when kExprTrackCpuUsage is set to false. Is empty by + /// default. This allows fine-grained control over CPU tracking overhead when + /// only specific functions need to be monitored. + static constexpr const char* kExprTrackCpuUsageForFunctions = + "expression.track_cpu_usage_for_functions"; + + /// Controls whether non-deterministic expressions are deduplicated during + /// compilation. This is intended for testing and debugging purposes. By + /// default, this is set to true to preserve standard behavior. If set to + /// false, non-deterministic functions (such as rand()) will not be + /// deduplicated. Since non-deterministic functions may yield different + /// outputs on each call, disabling deduplication guarantees that the function + /// is executed only when the original expression is evaluated, rather than + /// being triggered for every deduplicated instance. This ensures each + /// invocation corresponds directly to the actual expression, maintaining + /// independent behavior for each call. + static constexpr const char* kExprDedupNonDeterministic = + "expression.dedup_non_deterministic"; + /// Whether to track CPU usage for stages of individual operators. True by /// default. Can be expensive when processing small batches, e.g. < 10K rows. static constexpr const char* kOperatorTrackCpuUsage = @@ -139,6 +169,11 @@ class QueryConfig { static constexpr const char* kLocalExchangePartitionBufferPreserveEncoding = "local_exchange_partition_buffer_preserve_encoding"; + /// Maximum number of vectors buffered in each local merge source before + /// blocking to wait for consumers. + static constexpr const char* kLocalMergeSourceQueueSize = + "local_merge_source_queue_size"; + /// Maximum size in bytes to accumulate in ExchangeQueue. Enforced /// approximately, not strictly. static constexpr const char* kMaxExchangeBufferSize = @@ -170,12 +205,42 @@ class QueryConfig { static constexpr const char* kAbandonPartialAggregationMinPct = "abandon_partial_aggregation_min_pct"; + /// Memory threshold in bytes for triggering string compaction during + /// global aggregation. When total string storage exceeds this limit with + /// high unused memory ratio, compaction is triggered to reclaim dead strings. + /// Disabled by default (0). + /// + /// NOTE: currently only applies to approx_most_frequent aggregate with + /// StringView type during global aggregation. May extend to other types. + static constexpr const char* kAggregationCompactionBytesThreshold = + "aggregation_compaction_bytes_threshold"; + + /// Ratio of unused (evicted) bytes to total bytes that triggers compaction. + /// Value is between 0.0 and 1.0. Default is 0.25. + /// + /// NOTE: currently only applies to approx_most_frequent aggregate with + /// StringView type during global aggregation. May extend to other types. + static constexpr const char* kAggregationCompactionUnusedMemoryRatio = + "aggregation_compaction_unused_memory_ratio"; + static constexpr const char* kAbandonPartialTopNRowNumberMinRows = "abandon_partial_topn_row_number_min_rows"; static constexpr const char* kAbandonPartialTopNRowNumberMinPct = "abandon_partial_topn_row_number_min_pct"; + /// Number of input rows to receive before starting to check whether to + /// abandon building a HashTable without duplicates in HashBuild for left + /// semi/anti join. + static constexpr const char* kAbandonDedupHashMapMinRows = + "abandon_dedup_hashmap_min_rows"; + + /// Abandons building a HashTable without duplicates in HashBuild for left + /// semi/anti join if the percentage of distinct keys in the HashTable exceeds + /// this threshold. Zero means 'disable this optimization'. + static constexpr const char* kAbandonDedupHashMapMinPct = + "abandon_dedup_hashmap_min_pct"; + static constexpr const char* kMaxElementsSizeInRepeatAndSequence = "max_elements_size_in_repeat_and_sequence"; @@ -230,6 +295,10 @@ class QueryConfig { static constexpr const char* kAdaptiveFilterReorderingEnabled = "adaptive_filter_reordering_enabled"; + /// If true, allow hash probe drivers to generate build-side rows in parallel. + static constexpr const char* kParallelOutputJoinBuildRowsEnabled = + "parallel_output_join_build_rows_enabled"; + /// Global enable spilling flag. static constexpr const char* kSpillEnabled = "spill_enabled"; @@ -250,6 +319,13 @@ class QueryConfig { /// Window spilling flag, only applies if "spill_enabled" flag is set. static constexpr const char* kWindowSpillEnabled = "window_spill_enabled"; + /// When processing spilled window data, read batches of whole partitions + /// having at least that many rows. Set to 1 to read one whole partition at a + /// time. Each driver processing the Window operator will process that much + /// data at once. + static constexpr const char* kWindowSpillMinReadBatchRows = + "window_spill_min_read_batch_rows"; + /// If true, the memory arbitrator will reclaim memory from table writer by /// flushing its buffered data to disk. only applies if "spill_enabled" flag /// is set. @@ -263,6 +339,14 @@ class QueryConfig { static constexpr const char* kTopNRowNumberSpillEnabled = "topn_row_number_spill_enabled"; + /// LocalMerge spilling flag, only applies if "spill_enabled" flag is set. + static constexpr const char* kLocalMergeSpillEnabled = + "local_merge_spill_enabled"; + + /// Specify the max number of local sources to merge at a time. + static constexpr const char* kLocalMergeMaxNumMergeSources = + "local_merge_max_num_merge_sources"; + /// The max row numbers to fill and spill for each spill run. This is used to /// cap the memory used for spilling. If it is zero, then there is no limit /// and spilling might run out of memory. @@ -291,6 +375,14 @@ class QueryConfig { static constexpr const char* kSpillCompressionKind = "spill_compression_codec"; + /// The max number of files to merge at a time when merging sorted files into + /// a single ordered stream. 0 means unlimited. This is used to reduce memory + /// pressure by capping the number of open files when merging spilled sorted + /// files to avoid using too much memory and causing OOM. Note that this is + /// only applicable for ordered spill. + static constexpr const char* kSpillNumMaxMergeFiles = + "spill_num_max_merge_files"; + /// Enable the prefix sort or fallback to timsort in spill. The prefix sort is /// faster than std::sort but requires the memory to build normalized prefix /// keys, which might have potential risk of running out of server memory. @@ -316,7 +408,7 @@ class QueryConfig { "spill_file_create_config"; /// Default offset spill start partition bit. It is used with - /// 'kJoinSpillPartitionBits' or 'kAggregationSpillPartitionBits' together to + /// 'kSpillNumPartitionBits' together to /// calculate the spilling partition number for join spill or aggregation /// spill. static constexpr const char* kSpillStartPartitionBit = @@ -356,6 +448,14 @@ class QueryConfig { static constexpr const char* kPrestoArrayAggIgnoreNulls = "presto.array_agg.ignore_nulls"; + /// If true, Spark function's behavior is ANSI-compliant, e.g. throws runtime + /// exception instead of returning null on invalid inputs. It affects only + /// functions explicitly marked as "ANSI compliant". + /// Note: This feature is still under development to achieve full ANSI + /// compliance. Users can refer to the Spark function documentation to verify + /// the current support status of a specific function. + static constexpr const char* kSparkAnsiEnabled = "spark.ansi_enabled"; + /// The default number of expected items for the bloomfilter. static constexpr const char* kSparkBloomFilterExpectedNumItems = "spark.bloom_filter.expected_num_items"; @@ -383,6 +483,11 @@ class QueryConfig { static constexpr const char* kSparkLegacyStatisticalAggregate = "spark.legacy_statistical_aggregate"; + /// If true, ignore null fields when generating JSON string. + /// If false, null fields are included with a null value. + static constexpr const char* kSparkJsonIgnoreNullFields = + "spark.json_ignore_null_fields"; + /// The number of local parallel table writer operators per task. static constexpr const char* kTaskWriterCount = "task_writer_count"; @@ -396,6 +501,18 @@ class QueryConfig { static constexpr const char* kHashProbeFinishEarlyOnEmptyBuild = "hash_probe_finish_early_on_empty_build"; + /// Whether hash probe can generate any dynamic filter (including Bloom + /// filter) and push down to upstream operators. + static constexpr const char* kHashProbeDynamicFilterPushdownEnabled = + "hash_probe_dynamic_filter_pushdown_enabled"; + + /// The maximum byte size of Bloom filter that can be generated from hash + /// probe. When set to 0, no Bloom filter will be generated. To achieve + /// optimal performance, this should not be too larger than the CPU cache size + /// on the host. + static constexpr const char* kHashProbeBloomFilterPushdownMaxSize = + "hash_probe_bloom_filter_pushdown_max_size"; + /// The minimum number of table rows that can trigger the parallel hash join /// table build. static constexpr const char* kMinTableRowsForParallelJoinBuild = @@ -434,6 +551,12 @@ class QueryConfig { static constexpr const char* kDriverCpuTimeSliceLimitMs = "driver_cpu_time_slice_limit_ms"; + /// Window operator can be configured to sub-divide window partitions on each + /// thread of execution into groups of partitions for sequential processing. + /// This setting specifies how many sub-partitions to create for each thread. + static constexpr const char* kWindowNumSubPartitions = + "window_num_sub_partitions"; + /// Maximum number of bytes to use for the normalized key in prefix-sort. Use /// 0 to disable prefix-sort. static constexpr const char* kPrefixSortNormalizedKeyMaxBytes = @@ -454,9 +577,9 @@ class QueryConfig { /// Base dir of a query to store tracing data. static constexpr const char* kQueryTraceDir = "query_trace_dir"; - /// A comma-separated list of plan node ids whose input data will be traced. + /// The plan node id whose input data will be traced. /// Empty string if only want to trace the query metadata. - static constexpr const char* kQueryTraceNodeIds = "query_trace_node_ids"; + static constexpr const char* kQueryTraceNodeId = "query_trace_node_id"; /// The max trace bytes limit. Tracing is disabled if zero. static constexpr const char* kQueryTraceMaxBytes = "query_trace_max_bytes"; @@ -466,6 +589,10 @@ class QueryConfig { static constexpr const char* kQueryTraceTaskRegExp = "query_trace_task_reg_exp"; + /// If true, we only collect the input trace for a given operator but without + /// the actual execution. + static constexpr const char* kQueryTraceDryRun = "query_trace_dry_run"; + /// Config used to create operator trace directory. This config is provided to /// underlying file system and the config is free form. The form should be /// defined by the underlying file system. @@ -506,15 +633,35 @@ class QueryConfig { static constexpr const char* kDebugMemoryPoolNameRegex = "debug_memory_pool_name_regex"; + /// Warning threshold in bytes for debug memory pools. When set to a + /// non-zero value, a warning will be logged once per memory pool when + /// allocations cause the pool to exceed this threshold. This is useful for + /// identifying memory usage patterns during debugging. Requires allocation + /// tracking to be enabled via `debug_memory_pool_name_regex` for the pool. A + /// value of 0 means no warning threshold is enforced. + static constexpr const char* kDebugMemoryPoolWarnThresholdBytes = + "debug_memory_pool_warn_threshold_bytes"; + /// Some lambda functions over arrays and maps are evaluated in batches of the /// underlying elements that comprise the arrays/maps. This is done to make - /// the batch size managable as array vectors can have thousands of elements + /// the batch size manageable as array vectors can have thousands of elements /// each and hit scaling limits as implementations typically expect /// BaseVectors to a couple of thousand entries. This lets up tune those batch /// sizes. static constexpr const char* kDebugLambdaFunctionEvaluationBatchSize = "debug_lambda_function_evaluation_batch_size"; + /// The UDF `bing_tile_children` generates the children of a Bing tile based + /// on a specified target zoom level. The number of children produced is + /// determined by the difference between the target zoom level and the zoom + /// level of the input tile. This configuration limits the number of children + /// by capping the maximum zoom level difference, with a default value set + /// to 5. This cap is necessary to prevent excessively large array outputs, + /// which can exceed the size limits of the elements vector in the Velox array + /// vector. + static constexpr const char* kDebugBingTileChildrenMaxZoomShift = + "debug_bing_tile_children_max_zoom_shift"; + /// Temporary flag to control whether selective Nimble reader should be used /// in this query or not. Will be removed after the selective Nimble reader /// is fully rolled out. @@ -580,6 +727,12 @@ class QueryConfig { static constexpr const char* kIndexLookupJoinMaxPrefetchBatches = "index_lookup_join_max_prefetch_batches"; + /// If this is true, then the index join operator might split output for each + /// input batch based on the output batch size control. Otherwise, it tries to + /// produce a single output for each input batch. + static constexpr const char* kIndexLookupJoinSplitOutput = + "index_lookup_join_split_output"; + // Max wait time for exchange request in seconds. static constexpr const char* kRequestDataSizesMaxWaitSec = "request_data_sizes_max_wait_sec"; @@ -594,15 +747,86 @@ class QueryConfig { static constexpr const char* kStreamingAggregationEagerFlush = "streaming_aggregation_eager_flush"; + // If true, skip request data size if there is only single source. + // This is used to optimize the Presto-on-Spark use case where each + // exchange client has only one shuffle partition source. + static constexpr const char* kSkipRequestDataSizeWithSingleSourceEnabled = + "skip_request_data_size_with_single_source_enabled"; + + /// If true, exchange clients defer data fetching until next() is called. + /// This enables waiter tasks using cached hash tables to skip I/O entirely + /// when the table is already cached. If false (default), exchange clients + /// start fetching data immediately when remote tasks are added. + static constexpr const char* kExchangeLazyFetchingEnabled = + "exchange_lazy_fetching_enabled"; + /// If this is true, then it allows you to get the struct field names /// as json element names when casting a row to json. static constexpr const char* kFieldNamesInJsonCastEnabled = "field_names_in_json_cast_enabled"; + /// If this is true, then operators that evaluate expressions will track + /// stats for expressions that are not special forms and return them as + /// part of their operator stats. Tracking these stats can be expensive + /// (especially if operator stats are retrieved frequently) and this allows + /// the user to explicitly enable it. + static constexpr const char* kOperatorTrackExpressionStats = + "operator_track_expression_stats"; + + /// If this is true, enable the operator input/output batch size stats + /// collection in driver execution. This can be expensive for data types with + /// a large number of columns (e.g., ROW types) as it calls estimateFlatSize() + /// which recursively calculates sizes for all child vectors. + static constexpr const char* kEnableOperatorBatchSizeStats = + "enable_operator_batch_size_stats"; + + /// If this is true, then the unnest operator might split output for each + /// input batch based on the output batch size control. Otherwise, it produces + /// a single output for each input batch. + static constexpr const char* kUnnestSplitOutput = "unnest_split_output"; + + /// Priority of the query in the memory pool reclaimer. Lower value means + /// higher priority. This is used in global arbitration victim selection. + static constexpr const char* kQueryMemoryReclaimerPriority = + "query_memory_reclaimer_priority"; + + /// The max number of input splits to listen to by SplitListener per table + /// scan node per worker. It's up to the SplitListener implementation to + /// respect this config. + static constexpr const char* kMaxNumSplitsListenedTo = + "max_num_splits_listened_to"; + + /// Source of the query. Used by Presto to identify the file system username. + static constexpr const char* kSource = "source"; + + /// Client tags of the query. Used by Presto to identify the file system + /// username. + static constexpr const char* kClientTags = "client_tags"; + + /// Enable (reader) row size tracker as a fallback to file level row size + /// estimates. + static constexpr const char* kRowSizeTrackingMode = "row_size_tracking_mode"; + + /// Maximum number of distinct values to keep when merging vector hashers in + /// join HashBuild. + static constexpr const char* kJoinBuildVectorHasherMaxNumDistinct = + "join_build_vector_hasher_max_num_distinct"; + + enum class RowSizeTrackingMode { + DISABLED = 0, + EXCLUDE_DELTA_SPLITS = 1, + ENABLED_FOR_ALL = 2, + }; + bool selectiveNimbleReaderEnabled() const { return get(kSelectiveNimbleReaderEnabled, false); } + RowSizeTrackingMode rowSizeTrackingMode() const { + return get( + kRowSizeTrackingMode, RowSizeTrackingMode::ENABLED_FOR_ALL); + } + bool debugDisableExpressionsWithPeeling() const { return get(kDebugDisableExpressionWithPeeling, false); } @@ -623,6 +847,12 @@ class QueryConfig { return get(kDebugMemoryPoolNameRegex, ""); } + uint64_t debugMemoryPoolWarnThresholdBytes() const { + return config::toCapacity( + get(kDebugMemoryPoolWarnThresholdBytes, "0B"), + config::CapacityUnit::BYTE); + } + std::optional debugAggregationApproxPercentileFixedRandomSeed() const { return get(kDebugAggregationApproxPercentileFixedRandomSeed); @@ -632,6 +862,10 @@ class QueryConfig { return get(kDebugLambdaFunctionEvaluationBatchSize, 10'000); } + uint8_t debugBingTileChildrenMaxZoomShift() const { + return get(kDebugBingTileChildrenMaxZoomShift, 7); + } + uint64_t queryMaxMemoryPerNode() const { return config::toCapacity( get(kQueryMaxMemoryPerNode, "0B"), @@ -656,6 +890,14 @@ class QueryConfig { return get(kAbandonPartialAggregationMinPct, 80); } + uint64_t aggregationCompactionBytesThreshold() const { + return get(kAggregationCompactionBytesThreshold, 0); + } + + double aggregationCompactionUnusedMemoryRatio() const { + return get(kAggregationCompactionUnusedMemoryRatio, 0.25); + } + int32_t abandonPartialTopNRowNumberMinRows() const { return get(kAbandonPartialTopNRowNumberMinRows, 100'000); } @@ -664,6 +906,14 @@ class QueryConfig { return get(kAbandonPartialTopNRowNumberMinPct, 80); } + int32_t abandonHashBuildDedupMinRows() const { + return get(kAbandonDedupHashMapMinRows, 100'000); + } + + int32_t abandonHashBuildDedupMinPct() const { + return get(kAbandonDedupHashMapMinPct, 0); + } + int32_t maxElementsSizeInRepeatAndSequence() const { return get(kMaxElementsSizeInRepeatAndSequence, 10'000); } @@ -720,6 +970,10 @@ class QueryConfig { return get(kLocalExchangePartitionBufferPreserveEncoding, false); } + uint32_t localMergeSourceQueueSize() const { + return get(kLocalMergeSourceQueueSize, 2); + } + uint64_t maxExchangeBufferSize() const { static constexpr uint64_t kDefault = 32UL << 20; return get(kMaxExchangeBufferSize, kDefault); @@ -799,10 +1053,20 @@ class QueryConfig { return get(kSessionTimezone, ""); } + /// Returns the session start time in milliseconds since Unix epoch. + /// If not set, returns 0 (or epoch). + int64_t sessionStartTimeMs() const { + return get(kSessionStartTime, 0); + } + bool exprEvalSimplified() const { return get(kExprEvalSimplified, false); } + bool parallelOutputJoinBuildRowsEnabled() const { + return get(kParallelOutputJoinBuildRowsEnabled, false); + } + bool spillEnabled() const { return get(kSpillEnabled, false); } @@ -827,6 +1091,10 @@ class QueryConfig { return get(kWindowSpillEnabled, true); } + uint32_t windowSpillMinReadBatchRows() const { + return get(kWindowSpillMinReadBatchRows, 1'000); + } + bool writerSpillEnabled() const { return get(kWriterSpillEnabled, true); } @@ -839,6 +1107,17 @@ class QueryConfig { return get(kTopNRowNumberSpillEnabled, true); } + bool localMergeSpillEnabled() const { + return get(kLocalMergeSpillEnabled, false); + } + + uint32_t localMergeMaxNumMergeSources() const { + const auto maxNumMergeSources = get( + kLocalMergeMaxNumMergeSources, std::numeric_limits::max()); + VELOX_CHECK_GT(maxNumMergeSources, 0); + return maxNumMergeSources; + } + int32_t maxSpillLevel() const { return get(kMaxSpillLevel, 1); } @@ -868,6 +1147,11 @@ class QueryConfig { return get(kSpillCompressionKind, "none"); } + uint32_t spillNumMaxMergeFiles() const { + constexpr uint32_t kDefaultMergeFiles = 0; + return get(kSpillNumMaxMergeFiles, kDefaultMergeFiles); + } + bool spillPrefixSortEnabled() const { return get(kSpillPrefixSortEnabled, false); } @@ -905,9 +1189,9 @@ class QueryConfig { return get(kQueryTraceDir, ""); } - std::string queryTraceNodeIds() const { - // The default query trace nodes, empty by default. - return get(kQueryTraceNodeIds, ""); + std::string queryTraceNodeId() const { + // The default query trace node ID, empty by default. + return get(kQueryTraceNodeId, ""); } uint64_t queryTraceMaxBytes() const { @@ -919,6 +1203,10 @@ class QueryConfig { return get(kQueryTraceTaskRegExp, ""); } + bool queryTraceDryRun() const { + return get(kQueryTraceDryRun, false); + } + std::string opTraceDirectoryCreateConfig() const { return get(kOpTraceDirectoryCreateConfig, ""); } @@ -927,6 +1215,10 @@ class QueryConfig { return get(kPrestoArrayAggIgnoreNulls, false); } + bool sparkAnsiEnabled() const { + return get(kSparkAnsiEnabled, false); + } + int64_t sparkBloomFilterExpectedNumItems() const { constexpr int64_t kDefault = 1'000'000L; return get(kSparkBloomFilterExpectedNumItems, kDefault); @@ -966,10 +1258,22 @@ class QueryConfig { return get(kSparkLegacyStatisticalAggregate, false); } + bool sparkJsonIgnoreNullFields() const { + return get(kSparkJsonIgnoreNullFields, true); + } + bool exprTrackCpuUsage() const { return get(kExprTrackCpuUsage, false); } + std::string exprTrackCpuUsageForFunctions() const { + return get(kExprTrackCpuUsageForFunctions, ""); + } + + bool exprDedupNonDeterministic() const { + return get(kExprDedupNonDeterministic, true); + } + bool operatorTrackCpuUsage() const { return get(kOperatorTrackCpuUsage, true); } @@ -987,6 +1291,14 @@ class QueryConfig { return get(kHashProbeFinishEarlyOnEmptyBuild, false); } + bool hashProbeDynamicFilterPushdownEnabled() const { + return get(kHashProbeDynamicFilterPushdownEnabled, true); + } + + uint64_t hashProbeBloomFilterPushdownMaxSize() const { + return get(kHashProbeBloomFilterPushdownMaxSize, 0); + } + uint32_t minTableRowsForParallelJoinBuild() const { return get(kMinTableRowsForParallelJoinBuild, 1'000); } @@ -1023,6 +1335,10 @@ class QueryConfig { return get(kDriverCpuTimeSliceLimitMs, 0); } + uint32_t windowNumSubPartitions() const { + return get(kWindowNumSubPartitions, 1); + } + uint32_t prefixSortNormalizedKeyMaxBytes() const { return get(kPrefixSortNormalizedKeyMaxBytes, 128); } @@ -1065,6 +1381,10 @@ class QueryConfig { return get(kIndexLookupJoinMaxPrefetchBatches, 0); } + bool indexLookupJoinSplitOutput() const { + return get(kIndexLookupJoinSplitOutput, true); + } + std::string shuffleCompressionKind() const { return get(kShuffleCompressionKind, "none"); } @@ -1086,17 +1406,63 @@ class QueryConfig { return get(kStreamingAggregationMinOutputBatchRows, 0); } + bool singleSourceExchangeOptimizationEnabled() const { + return get(kSkipRequestDataSizeWithSingleSourceEnabled, false); + } + + bool exchangeLazyFetchingEnabled() const { + return get(kExchangeLazyFetchingEnabled, false); + } + bool isFieldNamesInJsonCastEnabled() const { return get(kFieldNamesInJsonCastEnabled, false); } + bool operatorTrackExpressionStats() const { + return get(kOperatorTrackExpressionStats, false); + } + + bool enableOperatorBatchSizeStats() const { + return get(kEnableOperatorBatchSizeStats, true); + } + + bool unnestSplitOutput() const { + return get(kUnnestSplitOutput, true); + } + + int32_t queryMemoryReclaimerPriority() const { + return get( + kQueryMemoryReclaimerPriority, std::numeric_limits::max()); + } + + int32_t maxNumSplitsListenedTo() const { + return get(kMaxNumSplitsListenedTo, 0); + } + + std::string source() const { + return get(kSource, ""); + } + + std::string clientTags() const { + return get(kClientTags, ""); + } + + uint32_t joinBuildVectorHasherMaxNumDistinct() const { + return get(kJoinBuildVectorHasherMaxNumDistinct, 1'000'000); + } + template T get(const std::string& key, const T& defaultValue) const { return config_->get(key, defaultValue); } + template std::optional get(const std::string& key) const { - return std::optional(config_->get(key)); + return config_->get(key); + } + + const std::shared_ptr& config() const { + return config_; } /// Test-only method to override the current query config properties. @@ -1109,6 +1475,6 @@ class QueryConfig { private: void validateConfig(); - std::unique_ptr config_; + std::shared_ptr config_; }; } // namespace facebook::velox::core diff --git a/velox/core/QueryCtx.cpp b/velox/core/QueryCtx.cpp index b65309d7d9b8..0c1c961a9e92 100644 --- a/velox/core/QueryCtx.cpp +++ b/velox/core/QueryCtx.cpp @@ -30,16 +30,34 @@ std::shared_ptr QueryCtx::create( cache::AsyncDataCache* cache, std::shared_ptr pool, folly::Executor* spillExecutor, - const std::string& queryId) { + std::string queryId, + std::shared_ptr tokenProvider) { + return QueryCtx::Builder() + .executor(executor) + .queryConfig(std::move(queryConfig)) + .connectorConfigs(std::move(connectorConfigs)) + .asyncDataCache(cache) + .pool(std::move(pool)) + .spillExecutor(spillExecutor) + .queryId(std::move(queryId)) + .tokenProvider(std::move(tokenProvider)) + .build(); +} + +std::shared_ptr QueryCtx::Builder::build() { std::shared_ptr queryCtx(new QueryCtx( - executor, - std::move(queryConfig), - std::move(connectorConfigs), - cache, - std::move(pool), - spillExecutor, - queryId)); + executor_, + std::move(queryConfig_), + std::move(connectorConfigs_), + cache_, + std::move(pool_), + spillExecutor_, + std::move(queryId_), + std::move(tokenProvider_))); queryCtx->maybeSetReclaimer(); + for (auto& cb : releaseCallbacks_) { + queryCtx->addReleaseCallback(std::move(cb)); + } return queryCtx; } @@ -51,22 +69,37 @@ QueryCtx::QueryCtx( cache::AsyncDataCache* cache, std::shared_ptr pool, folly::Executor* spillExecutor, - const std::string& queryId) + const std::string& queryId, + std::shared_ptr tokenProvider) : queryId_(queryId), executor_(executor), spillExecutor_(spillExecutor), cache_(cache), connectorSessionProperties_(connectorSessionProperties), pool_(std::move(pool)), - queryConfig_{std::move(queryConfig)} { + queryConfig_{std::move(queryConfig)}, + fsTokenProvider_(std::move(tokenProvider)) { initPool(queryId); } +QueryCtx::~QueryCtx() { + for (auto& cb : releaseCallbacks_) { + try { + cb(); + } catch (const std::exception& e) { + LOG(ERROR) << "Release callback threw exception: " << e.what(); + } catch (...) { + LOG(ERROR) << "Release callback threw unknown exception"; + } + } + VELOX_CHECK(!underArbitration_); +} + /*static*/ std::string QueryCtx::generatePoolName(const std::string& queryId) { // We attach a monotonically increasing sequence number to ensure the pool // name is unique. static std::atomic seqNum{0}; - return fmt::format("query.{}.{}", queryId.c_str(), seqNum++); + return fmt::format("query.{}.{}", queryId, seqNum++); } void QueryCtx::maybeSetReclaimer() { @@ -82,26 +115,30 @@ void QueryCtx::updateSpilledBytesAndCheckLimit(uint64_t bytes) { const auto numSpilledBytes = numSpilledBytes_.fetch_add(bytes) + bytes; if (queryConfig_.maxSpillBytes() > 0 && numSpilledBytes > queryConfig_.maxSpillBytes()) { - VELOX_SPILL_LIMIT_EXCEEDED(fmt::format( - "Query exceeded per-query local spill limit of {}", - succinctBytes(queryConfig_.maxSpillBytes()))); + VELOX_SPILL_LIMIT_EXCEEDED( + fmt::format( + "Query exceeded per-query local spill limit of {}", + succinctBytes(queryConfig_.maxSpillBytes()))); } } void QueryCtx::updateTracedBytesAndCheckLimit(uint64_t bytes) { if (numTracedBytes_.fetch_add(bytes) + bytes >= queryConfig_.queryTraceMaxBytes()) { - VELOX_TRACE_LIMIT_EXCEEDED(fmt::format( - "Query exceeded per-query local trace limit of {}", - succinctBytes(queryConfig_.queryTraceMaxBytes()))); + VELOX_TRACE_LIMIT_EXCEEDED( + fmt::format( + "Query exceeded per-query local trace limit of {}", + succinctBytes(queryConfig_.queryTraceMaxBytes()))); } } std::unique_ptr QueryCtx::MemoryReclaimer::create( QueryCtx* queryCtx, memory::MemoryPool* pool) { - return std::unique_ptr( - new QueryCtx::MemoryReclaimer(queryCtx->shared_from_this(), pool)); + return std::unique_ptr(new QueryCtx::MemoryReclaimer( + queryCtx->shared_from_this(), + pool, + queryCtx->queryConfig().queryMemoryReclaimerPriority())); } uint64_t QueryCtx::MemoryReclaimer::reclaim( diff --git a/velox/core/QueryCtx.h b/velox/core/QueryCtx.h index b9ebee5d706a..7303a3e0591f 100644 --- a/velox/core/QueryCtx.h +++ b/velox/core/QueryCtx.h @@ -18,31 +18,84 @@ #include #include +#include +#include #include "velox/common/caching/AsyncDataCache.h" #include "velox/common/memory/Memory.h" #include "velox/core/QueryConfig.h" #include "velox/vector/DecodedVector.h" #include "velox/vector/VectorPool.h" -namespace facebook::velox { -class Config; -}; - namespace facebook::velox::core { +/// Query execution context that manages resources and configuration for a +/// query. +/// +/// QueryCtx encapsulates query-level state and resources including: +/// +/// - Memory pool management and memory arbitration. +/// - Query and connector-specific configuration. +/// - Executor for parallel task execution. +/// - Async data cache for IO operations. +/// - Spill executor for disk-based operations. +/// - Query tracing and metrics tracking. +/// +/// Usage Contexts: +/// +/// - Multi-threaded execution: Used with Task::start() where an executor must +/// be provided and its lifetime must outlive all tasks using this context. +/// - Single-threaded execution: Used with ExecCtx or Task::next() for +/// expression evaluation where no executor is required. +/// +/// Construction: +/// +/// To construct a QueryCtx, prefer to use the builder pattern: +/// +/// @code +/// auto queryCtx = QueryCtx::Builder() +/// .executor(myExecutor) +/// .queryConfig(configMap) +/// .queryId("query-123") +/// .pool(myMemoryPool) +/// .build(); +/// @endcode +/// +/// Memory Management: +/// +/// - Automatically creates a root memory pool if not provided +/// - Supports memory arbitration and reclamation under memory pressure +/// - Tracks spilled bytes with configurable limits +/// - Thread-safe memory pool operations +/// +/// Thread-safety: QueryCtx is thread-safe for concurrent access across +/// multiple tasks and operators within a query execution. class QueryCtx : public std::enable_shared_from_this { public: - ~QueryCtx() { - VELOX_CHECK(!underArbitration_); - } + using ReleaseCallback = std::function; - /// QueryCtx is used in different places. When used with `Task::start()`, it's - /// required that the caller supplies the executor and ensure its lifetime - /// outlives the tasks that use it. In contrast, when used in expression - /// evaluation through `ExecCtx` or 'Task::next()' for single thread execution - /// mode, executor is not needed. Hence, we don't require executor to always - /// be passed in here, but instead, ensure that executor exists when actually - /// being used. + ~QueryCtx(); + + /// Creates a new QueryCtx instance with the specified configuration. + /// + /// This factory method constructs a QueryCtx with all necessary resources + /// and automatically sets up memory reclamation if not already configured. + /// + /// @param executor Optional executor for parallel task execution. Required + /// when used with Task::start(), but not needed for + /// expression evaluation or single-threaded execution. + /// @param queryConfig Query-level configuration settings. + /// @param connectorConfigs Connector-specific configuration mappings. + /// @param cache Async data cache for IO operations (defaults to global + /// instance). + /// @param pool Memory pool for query execution (auto-created if nullptr). + /// @param spillExecutor Optional executor for spilling operations. + /// @param queryId Unique identifier for this query. + /// @param tokenProvider Optional filesystem token provider for + /// authentication. + /// @return Shared pointer to the newly created QueryCtx. + /// + /// Note: The caller must ensure the executor's lifetime outlives all tasks + /// using this QueryCtx when executor is provided. static std::shared_ptr create( folly::Executor* executor = nullptr, QueryConfig&& queryConfig = QueryConfig{{}}, @@ -51,8 +104,103 @@ class QueryCtx : public std::enable_shared_from_this { cache::AsyncDataCache* cache = cache::AsyncDataCache::getInstance(), std::shared_ptr pool = nullptr, folly::Executor* spillExecutor = nullptr, - const std::string& queryId = ""); + std::string queryId = "", + std::shared_ptr tokenProvider = {}); + + /// Builder pattern for constructing QueryCtx instances. + /// + /// Provides a fluent interface for creating QueryCtx with optional + /// parameters. This is the recommended approach for improved readability, + /// especially when only setting a subset of configuration options. + /// + /// Example: + /// @code + /// auto ctx = QueryCtx::Builder() + /// .queryId("my-query") + /// .executor(myExecutor) + /// .queryConfig(QueryConfig{mySettings}) + /// .build(); + /// @endcode + class Builder { + public: + Builder& executor(folly::Executor* executor) { + executor_ = executor; + return *this; + } + + Builder& queryConfig(QueryConfig queryConfig) { + queryConfig_ = std::move(queryConfig); + return *this; + } + + Builder& connectorConfigs( + std::unordered_map> + connectorConfigs) { + connectorConfigs_ = std::move(connectorConfigs); + return *this; + } + + Builder& asyncDataCache(cache::AsyncDataCache* cache) { + cache_ = cache; + return *this; + } + + Builder& pool(std::shared_ptr pool) { + pool_ = std::move(pool); + return *this; + } + + Builder& spillExecutor(folly::Executor* spillExecutor) { + spillExecutor_ = spillExecutor; + return *this; + } + + Builder& queryId(std::string queryId) { + queryId_ = std::move(queryId); + return *this; + } + + Builder& tokenProvider( + std::shared_ptr tokenProvider) { + tokenProvider_ = std::move(tokenProvider); + return *this; + } + + /// Adds a callback to be invoked when the QueryCtx is destroyed. + /// Multiple callbacks can be added by calling this method multiple times. + Builder& releaseCallback(ReleaseCallback callback) { + releaseCallbacks_.push_back(std::move(callback)); + return *this; + } + + /// Constructs and returns a QueryCtx with the configured parameters. + /// + /// @return Shared pointer to the newly created QueryCtx instance + std::shared_ptr build(); + + private: + folly::Executor* executor_{nullptr}; + QueryConfig queryConfig_{QueryConfig{{}}}; + std::unordered_map> + connectorConfigs_; + cache::AsyncDataCache* cache_{cache::AsyncDataCache::getInstance()}; + std::shared_ptr pool_; + folly::Executor* spillExecutor_{nullptr}; + std::string queryId_; + std::shared_ptr tokenProvider_; + std::deque releaseCallbacks_; + }; + /// Generates a unique memory pool name for a query. + /// + /// Creates a pool name by combining the provided query ID with a + /// monotonically increasing sequence number to ensure uniqueness across + /// multiple pool creations, even for the same query ID. + /// + /// @param queryId The query identifier to incorporate into the pool name + /// @return A unique pool name in the format "query.{queryId}.{seqNum}" + /// + /// Thread-safe: Uses atomic operations for sequence number generation. static std::string generatePoolName(const std::string& queryId); memory::MemoryPool* pool() const { @@ -89,6 +237,26 @@ class QueryCtx : public std::enable_shared_from_this { return connectorSessionProperties_; } + std::shared_ptr fsTokenProvider() const { + return fsTokenProvider_; + } + + /// Registers a callback to be invoked when this QueryCtx is destroyed. + /// This allows external resources tied to the query's lifetime to be cleaned + /// up before the QueryCtx and its members are destructed. For example, + /// resources that have allocations in the query's memory pool. + /// + /// Example: HashTableCache uses this to remove cached hash tables when a + /// query completes. The cache entry holds a child memory pool of the query + /// pool, so it must be released before the query pool is destroyed. + /// + /// Note: Callbacks are invoked in registration order. Exceptions thrown by + /// callbacks are caught and logged; they do not prevent subsequent callbacks + /// from running. + void addReleaseCallback(ReleaseCallback callback) { + releaseCallbacks_.push_back(std::move(callback)); + } + /// Overrides the previous configuration. Note that this function is NOT /// thread-safe and should probably only be used in tests. void testingOverrideConfigUnsafe( @@ -153,7 +321,8 @@ class QueryCtx : public std::enable_shared_from_this { cache::AsyncDataCache* cache = cache::AsyncDataCache::getInstance(), std::shared_ptr pool = nullptr, folly::Executor* spillExecutor = nullptr, - const std::string& queryId = ""); + const std::string& queryId = "", + std::shared_ptr tokenProvider = {}); class MemoryReclaimer : public memory::MemoryReclaimer { public: @@ -208,6 +377,7 @@ class QueryCtx : public std::enable_shared_from_this { // Invoked to start memory arbitration on this query. void startArbitration(); + // Invoked to stop memory arbitration on this query. void finishArbitration(); @@ -224,9 +394,13 @@ class QueryCtx : public std::enable_shared_from_this { std::atomic numTracedBytes_{0}; mutable std::mutex mutex_; + // Indicates if this query is under memory arbitration or not. std::atomic_bool underArbitration_{false}; std::vector arbitrationPromises_; + std::shared_ptr fsTokenProvider_; + // Callbacks invoked before destruction to clean up external resources. + std::deque releaseCallbacks_; }; // Represents the state of one thread of query execution. diff --git a/velox/core/SimpleFunctionMetadata.h b/velox/core/SimpleFunctionMetadata.h index 0bf202827038..633553069a48 100644 --- a/velox/core/SimpleFunctionMetadata.h +++ b/velox/core/SimpleFunctionMetadata.h @@ -15,7 +15,6 @@ */ #pragma once -#include #include #include @@ -28,7 +27,6 @@ #include "velox/expression/SignatureBinder.h" #include "velox/type/SimpleFunctionApi.h" #include "velox/type/Type.h" -#include "velox/type/Variant.h" namespace facebook::velox::core { @@ -122,22 +120,24 @@ struct TypeAnalysisResults { size_t concreteCount = 0; // Set a priority based on the collected information. Lower priorities are - // picked first during function resolution. Each signature get a rank out - // of 4, those ranks form a Lattice ordering. + // picked first during function resolution. Each signature receives a rank + // from 1 to 4. Those ranks provice a lattice ordering. + // // rank 1: generic free and variadic free. - // e.g: int, int, int -> int. + // e.g: int, int, int -> int. // rank 2: has variadic but generic free. - // e.g: Variadic -> int. + // e.g: Variadic -> int. // rank 3: has generic but no variadic of generic. - // e.g: Any, Any, -> int. + // e.g: Any, Any, -> int. // rank 4: has variadic of generic. - // e.g: Variadic -> int. - + // e.g: Variadic -> int. + // // If two functions have the same rank, then concreteCount is used to // to resolve the ordering. - // e.g: consider the two functions: - // 1. int, Any, Variadic -> has rank 3. concreteCount =2 - // 2. int, Any, Any -> has rank 3. concreteCount =1 + // + // E.g. consider the two functions: + // 1. int, Any, Variadic -> rank 3; concreteCount 2 + // 2. int, Any, Any -> rank 3; concreteCount 1 // in this case (1) is picked. // e.g: (Any, int) will be picked before (Any, Any) // e.g: Variadic> is picked before Variadic. @@ -243,13 +243,14 @@ struct TypeAnalysis> { } else { auto typeVariableName = fmt::format("__user_T{}", T::getId()); results.out << typeVariableName; - results.addVariable(exec::SignatureVariable( - typeVariableName, - std::nullopt, - exec::ParameterType::kTypeParameter, - false, - orderable, - comparable)); + results.addVariable( + exec::SignatureVariable( + typeVariableName, + std::nullopt, + exec::ParameterType::kTypeParameter, + false, + orderable, + comparable)); } results.stats.hasGeneric = true; results.physicalType = UNKNOWN(); @@ -264,10 +265,12 @@ struct TypeAnalysis> { const auto p = P::name(); const auto s = S::name(); results.out << fmt::format("decimal({},{})", p, s); - results.addVariable(exec::SignatureVariable( - p, std::nullopt, exec::ParameterType::kIntegerParameter)); - results.addVariable(exec::SignatureVariable( - s, std::nullopt, exec::ParameterType::kIntegerParameter)); + results.addVariable( + exec::SignatureVariable( + p, std::nullopt, exec::ParameterType::kIntegerParameter)); + results.addVariable( + exec::SignatureVariable( + s, std::nullopt, exec::ParameterType::kIntegerParameter)); results.physicalType = BIGINT(); } }; @@ -280,14 +283,44 @@ struct TypeAnalysis> { const auto p = P::name(); const auto s = S::name(); results.out << fmt::format("decimal({},{})", p, s); - results.addVariable(exec::SignatureVariable( - p, std::nullopt, exec::ParameterType::kIntegerParameter)); - results.addVariable(exec::SignatureVariable( - s, std::nullopt, exec::ParameterType::kIntegerParameter)); + results.addVariable( + exec::SignatureVariable( + p, std::nullopt, exec::ParameterType::kIntegerParameter)); + results.addVariable( + exec::SignatureVariable( + s, std::nullopt, exec::ParameterType::kIntegerParameter)); results.physicalType = HUGEINT(); } }; +template +struct TypeAnalysis> { + void run(TypeAnalysisResults& results) { + results.stats.concreteCount++; + + const auto e = E::name(); + results.out << fmt::format("bigint_enum({})", e); + results.addVariable( + exec::SignatureVariable( + e, std::nullopt, exec::ParameterType::kEnumParameter)); + results.physicalType = BIGINT(); + } +}; + +template +struct TypeAnalysis> { + void run(TypeAnalysisResults& results) { + results.stats.concreteCount++; + + const auto e = E::name(); + results.out << fmt::format("varchar_enum({})", e); + results.addVariable( + exec::SignatureVariable( + e, std::nullopt, exec::ParameterType::kEnumParameter)); + results.physicalType = VARCHAR(); + } +}; + template struct TypeAnalysis> { void run(TypeAnalysisResults& results) { @@ -377,6 +410,24 @@ struct TypeAnalysis> { } }; +template +struct TypeAnalysis> { + void run(TypeAnalysisResults& results) { + // Need to call the TypeAnalysis on T, not T::type for BigintEnum type (on + // BigintEnumT, not Bigint). + TypeAnalysis>().run(results); + } +}; + +template +struct TypeAnalysis> { + void run(TypeAnalysisResults& results) { + // Need to call the TypeAnalysis on T, not T::type for VarcharEnum type (on + // VarcharEnumT, not Varchar). + TypeAnalysis>().run(results); + } +}; + class ISimpleFunctionMetadata { public: virtual ~ISimpleFunctionMetadata() = default; @@ -598,11 +649,12 @@ class SimpleFunctionMetadata : public ISimpleFunctionMetadata { ...); for (const auto& constraint : constraints) { - VELOX_CHECK( - !constraint.constraint().empty(), - "Constraint must be set for variable {}", - constraint.name()); - + if (constraint.isIntegerParameter()) { + VELOX_CHECK( + !constraint.constraint().empty(), + "Constraint must be set for variable {}", + constraint.name()); + } results.variablesInformation.erase(constraint.name()); results.variablesInformation.emplace(constraint.name(), constraint); } @@ -829,8 +881,28 @@ class UDFHolder { (udf_has_callAscii_return_void && udf_has_call_return_bool)), "The return type for callAscii() must match the return type for call()."); - // initialize(): - static constexpr bool udf_has_initialize = util::has_method< + // Detects if initialize() is a template method using SFINAE. + // Template methods can match any signature via template parameter deduction, + // causing false positives in trait detection. We probe with a dummy type + // that's not in our expected signature to identify templates. + struct DummyProbeType {}; + + template + struct has_template_initialize : std::false_type {}; + + template + struct has_template_initialize< + U, + util::detail::void_t().initialize( + std::declval&>(), + std::declval(), + std::declval()))>> : std::true_type {}; + + static constexpr bool is_initialize_template = + has_template_initialize::value; + + // Check for initialize() without MemoryPool parameter. + static constexpr bool udf_has_initialize_without_pool = util::has_method< Fun, initialize_method_resolver, void, @@ -838,6 +910,25 @@ class UDFHolder { const core::QueryConfig&, const exec_arg_type*...>::value; + // Check for initialize() with MemoryPool parameter. + // Excludes template methods to prevent them from incorrectly matching + // via template parameter substitution (e.g., T=MemoryPool). + static constexpr bool udf_has_initialize_with_pool = + !is_initialize_template && + util::has_method< + Fun, + initialize_method_resolver, + void, + const std::vector&, + const core::QueryConfig&, + memory::MemoryPool*, + const exec_arg_type*...>::value; + + // Combined trait for backward compatibility: true if ANY initialize exists + // This preserves the original meaning of udf_has_initialize + static constexpr bool udf_has_initialize = + udf_has_initialize_with_pool || udf_has_initialize_without_pool; + // TODO Remove static constexpr bool udf_has_legacy_initialize = util::has_method< Fun, @@ -921,11 +1012,29 @@ class UDFHolder { } } + FOLLY_ALWAYS_INLINE void initialize( + const std::vector& inputTypes, + const core::QueryConfig& config, + memory::MemoryPool* memoryPool, + const typename exec_resolver::in_type*... constantArgs) { + // Prefer non-MemoryPool signature first to handle template methods + // correctly. Template initialize() methods can match any signature via + // template parameter deduction, so we avoid passing MemoryPool to them. + if constexpr (udf_has_initialize_without_pool) { + return instance_.initialize(inputTypes, config, constantArgs...); + } else if constexpr (udf_has_initialize_with_pool) { + return instance_.initialize( + inputTypes, config, memoryPool, constantArgs...); + } + } + + // Overload for backward compatibility with callers that don't pass + // MemoryPool (e.g., Koski batch functions). FOLLY_ALWAYS_INLINE void initialize( const std::vector& inputTypes, const core::QueryConfig& config, const typename exec_resolver::in_type*... constantArgs) { - if constexpr (udf_has_initialize) { + if constexpr (udf_has_initialize_without_pool) { return instance_.initialize(inputTypes, config, constantArgs...); } } diff --git a/velox/core/TableWriteTraits.cpp b/velox/core/TableWriteTraits.cpp new file mode 100644 index 000000000000..f2520987a90c --- /dev/null +++ b/velox/core/TableWriteTraits.cpp @@ -0,0 +1,132 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/core/TableWriteTraits.h" +#include "velox/vector/ComplexVector.h" +#include "velox/vector/ConstantVector.h" +#include "velox/vector/FlatVector.h" + +namespace facebook::velox::core { + +// static +RowVectorPtr TableWriteTraits::createAggregationStatsOutput( + RowTypePtr outputType, + RowVectorPtr aggregationOutput, + StringView tableCommitContext, + velox::memory::MemoryPool* pool) { + // TODO: record aggregation stats output time. + if (aggregationOutput == nullptr) { + return nullptr; + } + VELOX_CHECK_GT(aggregationOutput->childrenSize(), 0); + const vector_size_t numOutputRows = aggregationOutput->childAt(0)->size(); + std::vector columns; + for (int channel = 0; channel < outputType->size(); channel++) { + if (channel < TableWriteTraits::kContextChannel) { + // 1. Set null rows column. + // 2. Set null fragments column. + columns.push_back( + BaseVector::createNullConstant( + outputType->childAt(channel), numOutputRows, pool)); + continue; + } + if (channel == TableWriteTraits::kContextChannel) { + // 3. Set commitcontext column. + columns.push_back( + std::make_shared>( + pool, + numOutputRows, + false /*isNull*/, + VARBINARY(), + // Note that we move tableCommitContext here, so ensure this + // branch is only executed once in the loop. + std::move(tableCommitContext))); + continue; + } + // 4. Set statistics columns. + columns.push_back( + aggregationOutput->childAt(channel - TableWriteTraits::kStatsChannel)); + } + return std::make_shared( + pool, outputType, nullptr, numOutputRows, columns); +} + +std::string TableWriteTraits::rowCountColumnName() { + static const std::string kRowCountName = "rows"; + return kRowCountName; +} + +std::string TableWriteTraits::fragmentColumnName() { + static const std::string kFragmentName = "fragments"; + return kFragmentName; +} + +std::string TableWriteTraits::contextColumnName() { + static const std::string kContextName = "commitcontext"; + return kContextName; +} + +const TypePtr& TableWriteTraits::rowCountColumnType() { + static const TypePtr kRowCountType = BIGINT(); + return kRowCountType; +} + +const TypePtr& TableWriteTraits::fragmentColumnType() { + static const TypePtr kFragmentType = VARBINARY(); + return kFragmentType; +} + +const TypePtr& TableWriteTraits::contextColumnType() { + static const TypePtr kContextType = VARBINARY(); + return kContextType; +} + +// static. +RowTypePtr TableWriteTraits::outputType( + const std::optional& columnStatsSpec) { + static const auto kOutputTypeWithoutStats = + ROW({rowCountColumnName(), fragmentColumnName(), contextColumnName()}, + {rowCountColumnType(), fragmentColumnType(), contextColumnType()}); + if (!columnStatsSpec.has_value()) { + return kOutputTypeWithoutStats; + } + return kOutputTypeWithoutStats->unionWith(columnStatsSpec->outputType()); +} + +folly::dynamic TableWriteTraits::getTableCommitContext( + const RowVectorPtr& input) { + VELOX_CHECK_GT(input->size(), 0); + auto* contextVector = + input->childAt(kContextChannel)->as>(); + return folly::parseJson( + std::string_view(contextVector->valueAt(input->size() - 1))); +} + +int64_t TableWriteTraits::getRowCount(const RowVectorPtr& output) { + VELOX_CHECK_GT(output->size(), 0); + auto rowCountVector = + output->childAt(kRowCountChannel)->asFlatVector(); + VELOX_CHECK_NOT_NULL(rowCountVector); + int64_t rowCount{0}; + for (int i = 0; i < output->size(); ++i) { + if (!rowCountVector->isNullAt(i)) { + rowCount += rowCountVector->valueAt(i); + } + } + return rowCount; +} + +} // namespace facebook::velox::core diff --git a/velox/core/TableWriteTraits.h b/velox/core/TableWriteTraits.h new file mode 100644 index 000000000000..3147294528d6 --- /dev/null +++ b/velox/core/TableWriteTraits.h @@ -0,0 +1,100 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include "velox/core/PlanNode.h" +#include "velox/type/Type.h" + +namespace facebook::velox::core { + +/// Defines table writer output related config properties that are shared +/// between TableWriter and TableWriteMerger. +/// +/// TODO: the table write output processing is Prestissimo specific. Consider +/// move these part logic to Prestissimo and pass to Velox through a customized +/// output processing callback. +class TableWriteTraits { + public: + /// Defines the column names/types in table write output. + static std::string rowCountColumnName(); + static std::string fragmentColumnName(); + static std::string contextColumnName(); + + static const TypePtr& rowCountColumnType(); + static const TypePtr& fragmentColumnType(); + static const TypePtr& contextColumnType(); + + /// Defines the column channels in table write output. + /// Both the statistics and the row_count + fragments are transferred over the + /// same communication link between the TableWriter and TableFinish. Thus the + /// multiplexing is needed. + /// + /// The transferred page layout looks like: + /// [row_count_channel], [fragment_channel], [context_channel], + /// [statistic_channel_1] ... [statistic_channel_N]] + /// + /// [row_count_channel] - contains number of rows processed by a TableWriter + /// [fragment_channel] - contains data provided by the DataSink#finish + /// [statistic_channel_1] ...[statistic_channel_N] - + /// contain aggregated statistics computed by the statistics aggregation + /// within the TableWriter + /// + /// For convenience, we never set both: [row_count_channel] + + /// [fragment_channel] and the [statistic_channel_1] ... + /// [statistic_channel_N]. + /// + /// If this is a row that holds statistics - the [row_count_channel] + + /// [fragment_channel] will be NULL. + /// + /// If this is a row that holds the row count + /// or the fragment - all the statistics channels will be set to NULL. + static constexpr int32_t kRowCountChannel = 0; + static constexpr int32_t kFragmentChannel = 1; + static constexpr int32_t kContextChannel = 2; + static constexpr int32_t kStatsChannel = 3; + + /// Defines the names of metadata in commit context in table writer output. + static constexpr std::string_view kLifeSpanContextKey = "lifespan"; + static constexpr std::string_view kTaskIdContextKey = "taskId"; + static constexpr std::string_view kCommitStrategyContextKey = + "pageSinkCommitStrategy"; + static constexpr std::string_view klastPageContextKey = "lastPage"; + + static RowTypePtr outputType( + const std::optional& columnStatsSpec); + + /// Returns the parsed commit context from table writer 'output'. + static folly::dynamic getTableCommitContext(const RowVectorPtr& output); + + /// Returns the sum of row counts from table writer 'output'. + static int64_t getRowCount(const RowVectorPtr& output); + + /// Creates the statistics output. + /// Statistics page layout (aggregate by partition): + /// row fragments context [partition] stats1 stats2 ... + /// null null X [X] X X + /// null null X [X] X X + static RowVectorPtr createAggregationStatsOutput( + RowTypePtr outputType, + RowVectorPtr aggregationOutput, + StringView tableCommitContext, + velox::memory::MemoryPool* pool); +}; + +} // namespace facebook::velox::core diff --git a/velox/core/tests/CMakeLists.txt b/velox/core/tests/CMakeLists.txt index 6092bb4b98f6..7b45ae059137 100644 --- a/velox/core/tests/CMakeLists.txt +++ b/velox/core/tests/CMakeLists.txt @@ -16,11 +16,14 @@ add_executable( velox_core_test ConstantTypedExprTest.cpp PlanFragmentTest.cpp + PlanNodeBuilderTest.cpp PlanNodeTest.cpp QueryConfigTest.cpp + QueryCtxTest.cpp StringTest.cpp TypeAnalysisTest.cpp - TypedExprSerdeTest.cpp) + TypedExprSerdeTest.cpp +) add_test(velox_core_test velox_core_test) @@ -34,4 +37,14 @@ target_link_libraries( velox_type velox_vector_test_lib GTest::gtest - GTest::gtest_main) + GTest::gtest_main +) + +add_executable(velox_core_plan_consistency_checker_test PlanConsistencyCheckerTest.cpp) + +add_test(velox_core_plan_consistency_checker_test velox_core_plan_consistency_checker_test) + +target_link_libraries( + velox_core_plan_consistency_checker_test + PRIVATE velox_core GTest::gtest GTest::gtest_main +) diff --git a/velox/core/tests/ConstantTypedExprTest.cpp b/velox/core/tests/ConstantTypedExprTest.cpp index 3d067f9ede87..2107d373e8ae 100644 --- a/velox/core/tests/ConstantTypedExprTest.cpp +++ b/velox/core/tests/ConstantTypedExprTest.cpp @@ -15,15 +15,179 @@ */ #include +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/common/memory/Memory.h" #include "velox/core/Expressions.h" #include "velox/functions/prestosql/types/HyperLogLogType.h" #include "velox/functions/prestosql/types/JsonType.h" #include "velox/functions/prestosql/types/TDigestType.h" #include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h" +#include "velox/type/Variant.h" +#include "velox/vector/BaseVector.h" +#include "velox/vector/tests/utils/VectorTestBase.h" namespace facebook::velox::core::test { -TEST(ConstantTypedExprTest, null) { +namespace { +struct TestOpaqueStruct { + int value; + std::string name; + + TestOpaqueStruct(int v, std::string n) : value(v), name(std::move(n)) {} + + bool operator==(const TestOpaqueStruct& other) const { + return value == other.value && name == other.name; + } +}; + +} // namespace + +class ConstantTypedExprTest : public ::testing::Test, + public velox::test::VectorTestBase { + protected: + static void SetUpTestCase() { + memory::MemoryManager::testingSetInstance(memory::MemoryManager::Options{}); + } + + void SetUp() override { + pool_ = memory::memoryManager()->addLeafPool(); + + // Register serialization/deserialization functions needed for the tests + Type::registerSerDe(); + ITypedExpr::registerSerDe(); + + // Register OPAQUE type serialization for TestOpaqueStruct + static folly::once_flag once; + folly::call_once(once, []() { + OpaqueType::registerSerialization( + "TestOpaqueStruct", + [](const std::shared_ptr& obj) -> std::string { + return folly::json::serialize( + folly::dynamic::object("value", obj->value)("name", obj->name), + folly::json::serialization_opts{}); + }, + [](const std::string& json) -> std::shared_ptr { + folly::dynamic obj = folly::parseJson(json); + return std::make_shared( + obj["value"].asInt(), obj["name"].asString()); + }); + }); + } + + // Helper functions + std::shared_ptr createVariantExpr( + const TypePtr& type, + const Variant& value) { + return std::make_shared(type, value); + } + + std::shared_ptr createNullVariantExpr( + const TypePtr& type) { + return std::make_shared( + type, variant::null(type->kind())); + } + + std::shared_ptr createVectorExpr(const VectorPtr& vector) { + return std::make_shared(vector); + } + + template + VectorPtr createConstantVector(const TypePtr& type, const T& value) { + return BaseVector::createConstant(type, variant(value), 1, pool_.get()); + } + + VectorPtr createNullConstantVector(const TypePtr& type) { + return BaseVector::createNullConstant(type, 1, pool_.get()); + } + + // Test Data + struct TestValues { + variant nullValue; + std::vector nonNullValues; + + TestValues(TypeKind kind) : nullValue(variant::null(kind)) {} + }; + + TestValues getTestValues(TypeKind kind) { + TestValues values(kind); + + switch (kind) { + case TypeKind::BOOLEAN: + values.nonNullValues = {variant(true), variant(false)}; + break; + case TypeKind::TINYINT: + values.nonNullValues = { + variant(int8_t(0)), variant(int8_t(127)), variant(int8_t(-128))}; + break; + case TypeKind::SMALLINT: + values.nonNullValues = { + variant(int16_t(0)), + variant(int16_t(32767)), + variant(int16_t(-32768))}; + break; + case TypeKind::INTEGER: + values.nonNullValues = { + variant(int32_t(0)), + variant(int32_t(2147483647)), + variant(int32_t(-2147483648))}; + break; + case TypeKind::BIGINT: + values.nonNullValues = { + variant(int64_t(0)), + variant(int64_t(9223372036854775807LL)), + variant(int64_t(-9223372036854775808ULL))}; + break; + case TypeKind::REAL: + values.nonNullValues = {variant(0.0f), variant(3.14f), variant(-1.5f)}; + break; + case TypeKind::DOUBLE: + values.nonNullValues = { + variant(0.0), variant(3.14159), variant(-2.71828)}; + break; + case TypeKind::VARCHAR: + values.nonNullValues = { + variant(""), variant("hello"), variant("test string")}; + break; + case TypeKind::VARBINARY: + values.nonNullValues = { + variant::binary(""), + variant::binary("binary data"), + variant::binary("\x00\x01\x02")}; + break; + case TypeKind::TIMESTAMP: + values.nonNullValues = { + variant(Timestamp(0, 0)), + variant(Timestamp(1234567890, 123456789))}; + break; + case TypeKind::HUGEINT: + values.nonNullValues = { + variant(int128_t(0)), + variant(int128_t(123)), + variant(int128_t(-456))}; + break; + default: + // For complex types, we'll handle them within individual tests. + break; + } + return values; + } + + std::shared_ptr pool_; + const std::vector scalarTypes_ = { + TypeKind::BOOLEAN, + TypeKind::TINYINT, + TypeKind::SMALLINT, + TypeKind::INTEGER, + TypeKind::BIGINT, + TypeKind::REAL, + TypeKind::DOUBLE, + TypeKind::VARCHAR, + TypeKind::VARBINARY, + TypeKind::TIMESTAMP, + TypeKind::HUGEINT}; +}; + +TEST_F(ConstantTypedExprTest, null) { auto makeNull = [](const TypePtr& type) { return std::make_shared( type, variant::null(type->kind())); @@ -67,4 +231,386 @@ TEST(ConstantTypedExprTest, null) { *makeNull(ROW({"x", "y"}, {INTEGER(), REAL()}))); } +TEST_F(ConstantTypedExprTest, hashScalarTypes) { + // Tests the consistency of the hash value returned by the ConstantTypedExpr + // between its construction using variant and Velox vectors. + for (auto kind : scalarTypes_) { + auto type = createScalarType(kind); + auto testValues = getTestValues(kind); + + // null values + auto nullVariantExpr = createNullVariantExpr(type); + auto nullVectorExpr = createVectorExpr(createNullConstantVector(type)); + EXPECT_EQ(nullVariantExpr->hash(), nullVectorExpr->hash()) + << "Hash mismatch for null " << TypeKindName::toName(kind); + + // non-null values + for (const auto& value : testValues.nonNullValues) { + auto variantExpr = std::make_shared(type, value); + auto vectorExpr = createVectorExpr( + BaseVector::createConstant(type, value, 1, pool_.get())); + EXPECT_EQ(variantExpr->hash(), vectorExpr->hash()) + << "Hash mismatch for non-null " << TypeKindName::toName(kind) + << " with value " << value.toJson(type); + } + } +} + +TEST_F(ConstantTypedExprTest, hashComplexTypes) { + // ARRAY + auto arrayType = ARRAY(INTEGER()); + + // null values + auto nullArrayVariantExpr = createNullVariantExpr(arrayType); + auto nullArrayVectorExpr = + createVectorExpr(createNullConstantVector(arrayType)); + EXPECT_EQ(nullArrayVariantExpr->hash(), nullArrayVectorExpr->hash()) + << "Hash mismatch for null ARRAY variant vs vector"; + + // non-null values + auto arrayVariant = Variant::array({1, 2, 3}); + auto arrayVariantExpr = + std::make_shared(arrayType, arrayVariant); + auto arrayVector = makeArrayVector({{1, 2, 3}}); + auto arrayVectorExpr = createVectorExpr(arrayVector); + EXPECT_EQ(arrayVariantExpr->hash(), arrayVectorExpr->hash()) + << "Hash mismatch for non-null ARRAY variant vs vector"; + + // MAP + auto mapType = MAP(VARCHAR(), INTEGER()); + + // null values + auto nullMapVariantExpr = createNullVariantExpr(mapType); + auto nullMapVectorExpr = createVectorExpr(createNullConstantVector(mapType)); + EXPECT_EQ(nullMapVariantExpr->hash(), nullMapVectorExpr->hash()) + << "Hash mismatch for null MAP variant vs vector"; + + // non-null values + std::map mapData = {{"key1", 1}, {"key2", 2}}; + auto mapVariant = Variant::map(mapData); + auto mapVariantExpr = + std::make_shared(mapType, mapVariant); + auto mapVector = + makeMapVector({{{"key1", 1}, {"key2", 2}}}); + auto mapVectorExpr = createVectorExpr(mapVector); + EXPECT_EQ(mapVariantExpr->hash(), mapVectorExpr->hash()) + << "Hash mismatch for non-null MAP variant vs vector"; + + // ROW + auto rowType = ROW({{"a", INTEGER()}, {"b", VARCHAR()}}); + + // null values + auto nullRowVariantExpr = createNullVariantExpr(rowType); + auto nullRowVectorExpr = createVectorExpr(createNullConstantVector(rowType)); + EXPECT_EQ(nullRowVariantExpr->hash(), nullRowVectorExpr->hash()) + << "Hash mismatch for null ROW variant vs vector"; + + // non-null values + auto rowVariant = Variant::row({42, "hello"}); + auto rowVariantExpr = + std::make_shared(rowType, rowVariant); + auto rowVector = makeRowVector( + {makeFlatVector({42}), makeFlatVector({"hello"})}); + auto rowVectorExpr = createVectorExpr(rowVector); + EXPECT_EQ(rowVariantExpr->hash(), rowVectorExpr->hash()) + << "Hash mismatch for non-null ROW variant vs vector"; + + // OPAQUE + auto testObj = std::make_shared(42, "test_data"); + auto opaqueType = OPAQUE(); + + // null values + auto nullOpaqueVariantExpr = createNullVariantExpr(opaqueType); + auto nullOpaqueVectorExpr = + createVectorExpr(createNullConstantVector(opaqueType)); + EXPECT_EQ(nullOpaqueVariantExpr->hash(), nullOpaqueVectorExpr->hash()) + << "Hash mismatch for null OPAQUE"; + + // non-null values + auto opaqueVariant = Variant::opaque(testObj); + auto opaqueVariantExpr = + std::make_shared(opaqueType, opaqueVariant); + auto opaqueVectorExpr = createVectorExpr( + BaseVector::createConstant(opaqueType, opaqueVariant, 1, pool_.get())); + EXPECT_EQ(opaqueVariantExpr->hash(), opaqueVectorExpr->hash()) + << "Hash mismatch for non-null OPAQUE"; +} + +TEST_F(ConstantTypedExprTest, serdeScalarTypes) { + // Test serialize/deserialize APIs for scalar types to ensure backward + // compatibility. + for (auto kind : scalarTypes_) { + auto type = createScalarType(kind); + auto testValues = getTestValues(kind); + + // null values + auto nullVariantExpr = createNullVariantExpr(type); + auto serialized = nullVariantExpr->serialize(); + auto deserialized = ConstantTypedExpr::create(serialized, pool_.get()); + EXPECT_TRUE(*nullVariantExpr == *deserialized) + << "Serialize/deserialize mismatch for null variant " + << TypeKindName::toName(kind); + auto nullVectorExpr = createVectorExpr(createNullConstantVector(type)); + serialized = nullVectorExpr->serialize(); + deserialized = ConstantTypedExpr::create(serialized, pool_.get()); + EXPECT_TRUE(*nullVectorExpr == *deserialized) + << "Serialize/deserialize mismatch for null vector " + << TypeKindName::toName(kind); + + // non-null values + for (const auto& value : testValues.nonNullValues) { + auto variantExpr = std::make_shared(type, value); + serialized = variantExpr->serialize(); + deserialized = ConstantTypedExpr::create(serialized, pool_.get()); + EXPECT_TRUE(*variantExpr == *deserialized) + << "Serialize/deserialize mismatch for variant " + << TypeKindName::toName(kind); + + auto vectorExpr = createVectorExpr( + BaseVector::createConstant(type, value, 1, pool_.get())); + serialized = vectorExpr->serialize(); + deserialized = ConstantTypedExpr::create(serialized, pool_.get()); + EXPECT_TRUE(*vectorExpr == *deserialized) + << "Serialize/deserialize mismatch for vector " + << TypeKindName::toName(kind) << " with value " << value.toJson(type); + } + } +} + +TEST_F(ConstantTypedExprTest, serdeComplexTypes) { + // ARRAY + auto arrayType = ARRAY(INTEGER()); + + // null values + auto nullArrayVariantExpr = createNullVariantExpr(arrayType); + auto serialized = nullArrayVariantExpr->serialize(); + auto deserialized = ConstantTypedExpr::create(serialized, nullptr); + EXPECT_TRUE(*nullArrayVariantExpr == *deserialized) + << "Serialize/deserialize mismatch for null ARRAY variant"; + auto nullArrayVectorExpr = + createVectorExpr(createNullConstantVector(arrayType)); + serialized = nullArrayVectorExpr->serialize(); + deserialized = ConstantTypedExpr::create(serialized, pool_.get()); + EXPECT_TRUE(*nullArrayVectorExpr == *deserialized) + << "Serialize/deserialize mismatch for null ARRAY vector"; + + // non-null values + auto arrayVariant = Variant::array({1, 2, 3}); + auto arrayVariantExpr = + std::make_shared(arrayType, arrayVariant); + serialized = arrayVariantExpr->serialize(); + deserialized = ConstantTypedExpr::create(serialized, nullptr); + EXPECT_TRUE(*arrayVariantExpr == *deserialized) + << "Serialize/deserialize mismatch for ARRAY variant with data"; + auto arrayVector = makeArrayVector({{1, 2, 3}}); + auto arrayVectorExpr = createVectorExpr(arrayVector); + serialized = arrayVectorExpr->serialize(); + deserialized = ConstantTypedExpr::create(serialized, pool_.get()); + EXPECT_TRUE(*arrayVectorExpr == *deserialized) + << "Serialize/deserialize mismatch for ARRAY vector with data"; + + // MAP + auto mapType = MAP(VARCHAR(), INTEGER()); + // null values + auto nullMapVariantExpr = createNullVariantExpr(mapType); + serialized = nullMapVariantExpr->serialize(); + deserialized = ConstantTypedExpr::create(serialized, nullptr); + EXPECT_TRUE(*nullMapVariantExpr == *deserialized) + << "Serialize/deserialize mismatch for null MAP variant"; + + // non-null values + std::map mapData = {{"key1", 1}, {"key2", 2}}; + auto mapVariant = Variant::map(mapData); + auto mapVariantExpr = + std::make_shared(mapType, mapVariant); + serialized = mapVariantExpr->serialize(); + deserialized = ConstantTypedExpr::create(serialized, nullptr); + EXPECT_TRUE(*mapVariantExpr == *deserialized) + << "Serialize/deserialize mismatch for MAP variant with data"; + + // ROW + auto rowType = ROW({{"a", INTEGER()}, {"b", VARCHAR()}}); + // null values + auto nullRowVariantExpr = createNullVariantExpr(rowType); + serialized = nullRowVariantExpr->serialize(); + deserialized = ConstantTypedExpr::create(serialized, nullptr); + EXPECT_TRUE(*nullRowVariantExpr == *deserialized) + << "Serialize/deserialize mismatch for null ROW variant"; + + // non-null values + auto rowVariant = Variant::row({42, "hello"}); + auto rowVariantExpr = + std::make_shared(rowType, rowVariant); + serialized = rowVariantExpr->serialize(); + deserialized = ConstantTypedExpr::create(serialized, nullptr); + EXPECT_TRUE(*rowVariantExpr == *deserialized) + << "Serialize/deserialize mismatch for ROW variant with data"; + + // OPAQUE + auto opaqueType = OPAQUE(); + + // null values + auto nullOpaqueVariantExpr = createNullVariantExpr(opaqueType); + serialized = nullOpaqueVariantExpr->serialize(); + deserialized = ConstantTypedExpr::create(serialized, nullptr); + EXPECT_TRUE(*nullOpaqueVariantExpr == *deserialized) + << "Serialize/deserialize mismatch for null OPAQUE variant"; + auto nullOpaqueVectorExpr = + createVectorExpr(createNullConstantVector(opaqueType)); + serialized = nullOpaqueVectorExpr->serialize(); + deserialized = ConstantTypedExpr::create(serialized, pool_.get()); + EXPECT_TRUE(*nullOpaqueVectorExpr == *deserialized) + << "Serialize/deserialize mismatch for null OPAQUE vector"; + + // non-null values + auto testObj = std::make_shared(42, "test_data"); + auto opaqueVariant = Variant::opaque(testObj); + auto opaqueVariantExpr = + std::make_shared(opaqueType, opaqueVariant); + serialized = opaqueVariantExpr->serialize(); + deserialized = ConstantTypedExpr::create(serialized, nullptr); + auto actualObj = static_pointer_cast(deserialized) + ->value() + .value() + .obj; + EXPECT_EQ(*testObj, *static_pointer_cast(actualObj)); +} + +TEST_F(ConstantTypedExprTest, toStringScalarTypes) { + for (auto kind : scalarTypes_) { + auto type = createScalarType(kind); + auto testValues = getTestValues(kind); + + // null values + auto nullVariantExpr = createNullVariantExpr(type); + auto nullVectorExpr = createVectorExpr(createNullConstantVector(type)); + EXPECT_EQ(nullVariantExpr->toString(), nullVectorExpr->toString()) + << "toString mismatch for null " << TypeKindName::toName(kind); + + // non-null values + for (const auto& value : testValues.nonNullValues) { + auto variantExpr = std::make_shared(type, value); + auto vectorExpr = createVectorExpr( + BaseVector::createConstant(type, value, 1, pool_.get())); + EXPECT_EQ(variantExpr->toString(), vectorExpr->toString()) + << "toString mismatch for " << TypeKindName::toName(kind) + << " with value " << value.toJson(type); + } + } +} + +TEST_F(ConstantTypedExprTest, toStringComplexTypes) { + // ARRAY + auto arrayType = ARRAY(INTEGER()); + + // null values + auto nullArrayVariantExpr = createNullVariantExpr(arrayType); + auto nullArrayVectorExpr = + createVectorExpr(createNullConstantVector(arrayType)); + EXPECT_EQ(nullArrayVariantExpr->toString(), nullArrayVectorExpr->toString()) + << "toString mismatch for null ARRAY"; + + // non-null values + auto arrayVariant = Variant::array({1, 2, 3}); + auto arrayVariantExpr = + std::make_shared(arrayType, arrayVariant); + auto arrayVector = makeArrayVector({{1, 2, 3}}); + auto arrayVectorExpr = createVectorExpr(arrayVector); + EXPECT_EQ(arrayVariantExpr->toString(), arrayVectorExpr->toString()) + << "toString mismatch for ARRAY variant vs vector"; + + // MAP + auto mapType = MAP(VARCHAR(), INTEGER()); + + // null values + auto nullMapVariantExpr = createNullVariantExpr(mapType); + auto nullMapVectorExpr = createVectorExpr(createNullConstantVector(mapType)); + EXPECT_EQ(nullMapVariantExpr->toString(), nullMapVectorExpr->toString()) + << "toString mismatch for null MAP"; + + // non-null values + std::map mapData = {{"key1", 1}, {"key2", 2}}; + auto mapVariant = Variant::map(mapData); + auto mapVariantExpr = + std::make_shared(mapType, mapVariant); + auto mapVector = + makeMapVector({{{"key1", 1}, {"key2", 2}}}); + auto mapVectorExpr = createVectorExpr(mapVector); + EXPECT_EQ(mapVariantExpr->toString(), mapVectorExpr->toString()) + << "toString mismatch for MAP variant vs vector"; + + // ROW + auto rowType = ROW({{"a", INTEGER()}, {"b", VARCHAR()}}); + + // null values + auto nullRowVariantExpr = createNullVariantExpr(rowType); + auto nullRowVectorExpr = createVectorExpr(createNullConstantVector(rowType)); + EXPECT_EQ(nullRowVariantExpr->toString(), nullRowVectorExpr->toString()) + << "toString mismatch for null ROW"; + + // non-null values + auto rowVariant = Variant::row({42, "hello"}); + auto rowVariantExpr = + std::make_shared(rowType, rowVariant); + auto rowVector = makeRowVector( + {makeFlatVector({42}), makeFlatVector({"hello"})}); + auto rowVectorExpr = createVectorExpr(rowVector); + EXPECT_EQ(rowVariantExpr->toString(), rowVectorExpr->toString()) + << "toString mismatch for ROW variant vs vector"; + + // OPAQUE + auto opaqueType = OPAQUE(); + + // null values + auto nullOpaqueVariantExpr = createNullVariantExpr(opaqueType); + auto nullOpaqueVectorExpr = + createVectorExpr(createNullConstantVector(opaqueType)); + EXPECT_EQ(nullOpaqueVariantExpr->toString(), nullOpaqueVectorExpr->toString()) + << "toString mismatch for null OPAQUE"; + + // non-null values + auto testObj = std::make_shared(42, "test_data"); + auto opaqueVariant = Variant::opaque(testObj); + auto opaqueVariantExpr = + std::make_shared(opaqueType, opaqueVariant); + auto opaqueVectorExpr = createVectorExpr( + BaseVector::createConstant(opaqueType, opaqueVariant, 1, pool_.get())); + EXPECT_EQ(opaqueVariantExpr->toString(), opaqueVectorExpr->toString()) + << "toString mismatch for OPAQUE variant vs vector"; +} + +TEST_F(ConstantTypedExprTest, variantTypeCheck) { + auto testVariantExpr = [&](const Variant& value, + const TypePtr& type, + const TypePtr& expectedType) { + VELOX_ASSERT_THROW( + createVariantExpr(type, value), + fmt::format( + "Expression type {} does not match variant type {}", + type->toString(), + expectedType->toString())); + if (type->isPrimitiveType()) { + VELOX_ASSERT_THROW( + createVariantExpr(type, Variant::null(expectedType->kind())), + fmt::format( + "Expression type {} does not match variant type {}", + type->toString(), + expectedType->toString())); + } else { + ASSERT_NO_THROW( + createVariantExpr(type, Variant::null(expectedType->kind()))); + } + }; + + testVariantExpr("abc", INTEGER(), VARCHAR()); + testVariantExpr(variant(123LL), INTEGER(), BIGINT()); + testVariantExpr(2.0, BIGINT(), DOUBLE()); + testVariantExpr( + variant::array({1, 2, 3}), ARRAY(VARCHAR()), ARRAY(INTEGER())); + testVariantExpr( + variant::map({{2.0, "xyz"}}), + MAP(INTEGER(), VARCHAR()), + MAP(DOUBLE(), VARCHAR())); +} + } // namespace facebook::velox::core::test diff --git a/velox/core/tests/PlanConsistencyCheckerTest.cpp b/velox/core/tests/PlanConsistencyCheckerTest.cpp new file mode 100644 index 000000000000..7faad88f4e37 --- /dev/null +++ b/velox/core/tests/PlanConsistencyCheckerTest.cpp @@ -0,0 +1,406 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/core/PlanConsistencyChecker.h" +#include "velox/parse/PlanNodeIdGenerator.h" + +namespace facebook::velox::core { + +namespace { +class PlanConsistencyCheckerTest : public testing::Test { + protected: + static void SetUpTestCase() { + memory::MemoryManager::testingSetInstance(memory::MemoryManager::Options{}); + } + + void SetUp() override { + idGenerator_.reset(); + } + + std::string nextId() { + return idGenerator_.next(); + } + + PlanNodeIdGenerator idGenerator_; +}; + +TypedExprPtr Lit(Variant value) { + auto type = value.inferType(); + return std::make_shared(std::move(type), std::move(value)); +} + +FieldAccessTypedExprPtr Col(TypePtr type, std::string name) { + return std::make_shared( + std::move(type), std::move(name)); +} + +TEST_F(PlanConsistencyCheckerTest, filter) { + auto valuesNode = + std::make_shared(nextId(), std::vector{}); + + auto projectNode = std::make_shared( + nextId(), + std::vector{"a", "b", "c"}, + std::vector{Lit(true), Lit(1), Lit(0.1)}, + valuesNode); + + auto filterNode = + std::make_shared(nextId(), Col(BOOLEAN(), "a"), projectNode); + ASSERT_NO_THROW(PlanConsistencyChecker::check(filterNode)); + + // Wrong type. + filterNode = + std::make_shared(nextId(), Col(BOOLEAN(), "b"), projectNode); + + VELOX_ASSERT_THROW( + PlanConsistencyChecker::check(filterNode), + "Wrong type of input column: b, BOOLEAN vs. INTEGER"); + + // Wrong name. + filterNode = + std::make_shared(nextId(), Col(BOOLEAN(), "x"), projectNode); + + VELOX_ASSERT_THROW( + PlanConsistencyChecker::check(filterNode), "Field not found: x"); + + // Non-existent column referenced in a lambda expression. + filterNode = std::make_shared( + nextId(), + std::make_shared( + BOOLEAN(), + "any_match", + Lit(Variant::array({1, 2, 3})), + std::make_shared( + ROW("x", INTEGER()), + std::make_shared( + BOOLEAN(), + "lt", + Col(INTEGER(), "x"), + Col(INTEGER(), "blah")))), + projectNode); + + VELOX_ASSERT_THROW( + PlanConsistencyChecker::check(filterNode), "Field not found: blah"); +} + +TEST_F(PlanConsistencyCheckerTest, project) { + auto valuesNode = + std::make_shared(nextId(), std::vector{}); + + auto projectNode = std::make_shared( + nextId(), + std::vector{"a", "b", "c"}, + std::vector{Lit(true), Lit(1), Lit(0.1)}, + valuesNode); + ASSERT_NO_THROW(PlanConsistencyChecker::check(projectNode)); + + // Duplicate output name. + projectNode = std::make_shared( + nextId(), + std::vector{"a", "a", "c"}, + std::vector{Lit(true), Lit(1), Lit(0.1)}, + valuesNode); + + VELOX_ASSERT_THROW( + PlanConsistencyChecker::check(projectNode), "Duplicate output column: a"); + + // Wrong column name. + projectNode = std::make_shared( + nextId(), + std::vector{"a", "a", "c"}, + std::vector{Lit(true), Col(REAL(), "x"), Lit(0.1)}, + valuesNode); + + VELOX_ASSERT_THROW( + PlanConsistencyChecker::check(projectNode), "Field not found: x"); +} + +TEST_F(PlanConsistencyCheckerTest, aggregation) { + auto valuesNode = + std::make_shared(nextId(), std::vector{}); + + auto projectNode = std::make_shared( + nextId(), + std::vector{"a", "b", "c"}, + std::vector{Lit(true), Lit(1), Lit(0.1)}, + valuesNode); + ASSERT_NO_THROW(PlanConsistencyChecker::check(projectNode)); + + { + auto aggregationNode = std::make_shared( + nextId(), + AggregationNode::Step::kPartial, + std::vector{}, + std::vector{}, + std::vector{"sum", "cnt"}, + std::vector{ + { + .call = std::make_shared( + BIGINT(), "sum", Col(INTEGER(), "x")), + .rawInputTypes = {BIGINT()}, + }, + { + .call = std::make_shared(BIGINT(), "count"), + .rawInputTypes = {}, + }, + }, + /*ignoreNullKeys*/ false, + projectNode); + VELOX_ASSERT_THROW( + PlanConsistencyChecker::check(aggregationNode), "Field not found: x"); + } + + { + auto aggregationNode = std::make_shared( + nextId(), + AggregationNode::Step::kPartial, + std::vector{Col(INTEGER(), "y")}, + std::vector{}, + std::vector{"sum", "cnt"}, + std::vector{ + { + .call = std::make_shared( + BIGINT(), "sum", Col(INTEGER(), "b")), + .rawInputTypes = {BIGINT()}, + }, + { + .call = std::make_shared(BIGINT(), "count"), + .rawInputTypes = {}, + }, + }, + /*ignoreNullKeys*/ false, + projectNode); + VELOX_ASSERT_THROW( + PlanConsistencyChecker::check(aggregationNode), "Field not found: y"); + } + + { + auto aggregationNode = std::make_shared( + nextId(), + AggregationNode::Step::kPartial, + std::vector{}, + std::vector{}, + std::vector{"sum", "cnt"}, + std::vector{ + { + .call = std::make_shared( + BIGINT(), "sum", Col(INTEGER(), "b")), + .rawInputTypes = {BIGINT()}, + .mask = Col(BOOLEAN(), "z"), + }, + { + .call = std::make_shared(BIGINT(), "count"), + .rawInputTypes = {}, + }, + }, + /*ignoreNullKeys*/ false, + projectNode); + VELOX_ASSERT_THROW( + PlanConsistencyChecker::check(aggregationNode), "Field not found: z"); + } + + { + auto aggregationNode = std::make_shared( + nextId(), + AggregationNode::Step::kPartial, + std::vector{}, + std::vector{}, + std::vector{"sum", "sum"}, + std::vector{ + { + .call = std::make_shared( + BIGINT(), "sum", Col(INTEGER(), "b")), + .rawInputTypes = {BIGINT()}, + }, + { + .call = std::make_shared(BIGINT(), "count"), + .rawInputTypes = {}, + }, + }, + /*ignoreNullKeys*/ false, + projectNode); + VELOX_ASSERT_THROW( + PlanConsistencyChecker::check(aggregationNode), + "Duplicate output column: sum"); + } +} + +TEST_F(PlanConsistencyCheckerTest, hashJoin) { + auto leftValuesNode = + std::make_shared(nextId(), std::vector{}); + + auto leftProjectNode = std::make_shared( + nextId(), + std::vector{"a", "b"}, + std::vector{Lit(1), Lit(2)}, + leftValuesNode); + ASSERT_NO_THROW(PlanConsistencyChecker::check(leftValuesNode)); + + auto rightValuesNode = + std::make_shared(nextId(), std::vector{}); + + auto rightProjectNode = std::make_shared( + nextId(), + std::vector{"c", "d"}, + std::vector{Lit(1), Lit(2)}, + leftValuesNode); + ASSERT_NO_THROW(PlanConsistencyChecker::check(rightProjectNode)); + + // Invalid reference in the filter. + { + auto joinNode = std::make_shared( + nextId(), + JoinType::kLeft, + /*nullAware=*/false, + std::vector{Col(INTEGER(), "a")}, + std::vector{Col(INTEGER(), "c")}, + std::make_shared( + BOOLEAN(), "lt", Col(INTEGER(), "b"), Col(INTEGER(), "blah")), + leftProjectNode, + rightProjectNode, + ROW({})); + VELOX_ASSERT_THROW( + PlanConsistencyChecker::check(joinNode), + "Field not found: blah. Available fields are: a, b, c, d."); + } + + // Duplicate join condition. + { + auto joinNode = std::make_shared( + nextId(), + JoinType::kLeft, + /*nullAware=*/false, + std::vector{ + Col(INTEGER(), "a"), Col(INTEGER(), "a")}, + std::vector{ + Col(INTEGER(), "c"), Col(INTEGER(), "c")}, + /*filter=*/nullptr, + leftProjectNode, + rightProjectNode, + ROW({})); + VELOX_ASSERT_THROW( + PlanConsistencyChecker::check(joinNode), + "Duplicate join condition: \"a\" = \"c\""); + } +} + +namespace { +class TestTableHandle : public connector::ConnectorTableHandle { + public: + explicit TestTableHandle(std::string connectorId, std::string name) + : connector::ConnectorTableHandle(std::move(connectorId)), + name_{std::move(name)} {} + + const std::string& name() const override { + return name_; + } + + private: + const std::string name_; +}; + +class TestColumnHandle : public connector::ColumnHandle { + public: + explicit TestColumnHandle(std::string name) : name_{std::move(name)} {} + + const std::string& name() const override { + return name_; + } + + private: + const std::string name_; +}; +} // namespace + +TEST_F(PlanConsistencyCheckerTest, tableScan) { + // Empty output column name. + { + auto scanNode = std::make_shared( + nextId(), + ROW({"", "b"}, INTEGER()), + std::make_shared("test", "t"), + connector::ColumnHandleMap{}); + + VELOX_ASSERT_THROW( + PlanConsistencyChecker::check(scanNode), + "Output column name cannot be empty"); + } + + // Duplicate output column name. + { + auto scanNode = std::make_shared( + nextId(), + ROW({"a", "b", "a"}, INTEGER()), + std::make_shared("test", "t"), + connector::ColumnHandleMap{}); + + VELOX_ASSERT_THROW( + PlanConsistencyChecker::check(scanNode), "Duplicate output column: a"); + } + + // Missing assignments. + { + auto scanNode = std::make_shared( + nextId(), + ROW({"a", "b", "c"}, INTEGER()), + std::make_shared("test", "t"), + connector::ColumnHandleMap{}); + + VELOX_ASSERT_THROW( + PlanConsistencyChecker::check(scanNode), + "Column assignments must match output type"); + } + + { + connector::ColumnHandleMap assignments{ + {"a", std::make_shared("x")}, + {"b", std::make_shared("y")}, + {"blah", std::make_shared("z")}, + }; + + auto scanNode = std::make_shared( + nextId(), + ROW({"a", "b", "c"}, INTEGER()), + std::make_shared("test", "t"), + assignments); + + VELOX_ASSERT_THROW( + PlanConsistencyChecker::check(scanNode), + "Column assignment is missing for c"); + } + + // No issues. + { + connector::ColumnHandleMap assignments{ + {"a", std::make_shared("x")}, + {"b", std::make_shared("y")}, + {"c", std::make_shared("z")}, + }; + + auto scanNode = std::make_shared( + nextId(), + ROW({"a", "b", "c"}, INTEGER()), + std::make_shared("test", "t"), + assignments); + + ASSERT_NO_THROW(PlanConsistencyChecker::check(scanNode)); + } +} + +} // namespace +} // namespace facebook::velox::core diff --git a/velox/core/tests/PlanFragmentTest.cpp b/velox/core/tests/PlanFragmentTest.cpp index e5ddab00fbf7..64bc4e587dfd 100644 --- a/velox/core/tests/PlanFragmentTest.cpp +++ b/velox/core/tests/PlanFragmentTest.cpp @@ -129,10 +129,9 @@ TEST_F(PlanFragmentTest, aggregationCanSpill) { const std::vector emptyPreGroupingKeys; std::vector aggregateNames{"sum"}; std::vector emptyAggregateNames{}; - std::vector aggregateInputs{ - std::make_shared(BIGINT())}; + auto aggregateInput = std::make_shared(BIGINT()); const std::vector aggregates{ - {std::make_shared(BIGINT(), aggregateInputs, "sum"), + {std::make_shared(BIGINT(), "sum", aggregateInput), {}, nullptr, {}, @@ -149,8 +148,9 @@ TEST_F(PlanFragmentTest, aggregationCanSpill) { std::string debugString() const { return fmt::format( - "aggregationStep:{} isSpillEnabled:{} isAggregationSpillEnabled:{} isDistinct:{} hasPreAggregation:{} expectedCanSpill:{}", - AggregationNode::stepName(aggregationStep), + "aggregationStep:{} isSpillEnabled:{} isAggregationSpillEnabled:{} " + "isDistinct:{} hasPreAggregation:{} expectedCanSpill:{}", + AggregationNode::toName(aggregationStep), isSpillEnabled, isAggregationSpillEnabled, isDistinct, @@ -188,7 +188,8 @@ TEST_F(PlanFragmentTest, aggregationCanSpill) { testData.hasPreAggregation ? preGroupingKeys : emptyPreGroupingKeys, testData.isDistinct ? emptyAggregateNames : aggregateNames, testData.isDistinct ? emptyAggregates : aggregates, - false, + /*ignoreNullKeys=*/false, + /*noGroupsSpanBatches=*/false, valueNode_); auto queryCtx = getSpillQueryCtx( testData.isSpillEnabled, @@ -354,7 +355,7 @@ TEST_F(PlanFragmentTest, supportsBarrier) { core::JoinType::kAnti) .planNode(); const PlanFragment planFragment{plan}; - ASSERT_FALSE(planFragment.supportsBarrier()); + ASSERT_TRUE(planFragment.firstNodeNotSupportingBarrier() != nullptr); } // Plan fragment with plan node supporting barrier. { @@ -374,6 +375,6 @@ TEST_F(PlanFragmentTest, supportsBarrier) { core::JoinType::kInner) .planNode(); const PlanFragment planFragment{plan}; - ASSERT_TRUE(planFragment.supportsBarrier()); + ASSERT_TRUE(planFragment.firstNodeNotSupportingBarrier() == nullptr); } } diff --git a/velox/core/tests/PlanNodeBuilderTest.cpp b/velox/core/tests/PlanNodeBuilderTest.cpp index 59b588df7f15..00d9d034075e 100644 --- a/velox/core/tests/PlanNodeBuilderTest.cpp +++ b/velox/core/tests/PlanNodeBuilderTest.cpp @@ -17,6 +17,9 @@ #include "velox/common/memory/Memory.h" #include "velox/core/PlanNode.h" +#include "velox/duckdb/conversion/DuckParser.h" +#include "velox/exec/tests/utils/AggregationResolver.h" +#include "velox/functions/prestosql/aggregates/RegisterAggregateFunctions.h" #include "velox/vector/tests/utils/VectorTestBase.h" using namespace ::facebook::velox; @@ -31,11 +34,82 @@ class PlanNodeBuilderTest : public testing::Test, public test::VectorTestBase { protected: static void SetUpTestCase() { memory::MemoryManager::testingSetInstance(memory::MemoryManager::Options{}); + aggregate::prestosql::registerAllAggregateFunctions(); + } + + core::ColumnStatsSpec createStatsSpec( + const RowTypePtr& type, + const std::vector& groupingKeys, + core::AggregationNode::Step step, + const std::vector& aggregates, + const std::vector>& rawInputArgs = {}) { + std::vector aggs; + aggs.reserve(aggregates.size()); + std::vector names; + names.reserve(aggregates.size()); + + duckdb::ParseOptions options; + options.parseIntegerAsBigint = true; + exec::test::AggregateTypeResolver resolver(step); + + for (auto i = 0; i < aggregates.size(); ++i) { + const auto& aggregate = aggregates[i]; + const auto untypedExpr = duckdb::parseAggregateExpr(aggregate, options); + + if (!rawInputArgs.empty()) { + resolver.setRawInputTypes(rawInputArgs[i]); + } + + core::AggregationNode::Aggregate agg; + agg.call = std::dynamic_pointer_cast( + core::Expressions::inferTypes(untypedExpr.expr, type, pool())); + + if (step == core::AggregationNode::Step::kPartial || + step == core::AggregationNode::Step::kSingle) { + VELOX_CHECK(rawInputArgs.empty()); + for (const auto& input : agg.call->inputs()) { + agg.rawInputTypes.push_back(input->type()); + } + } else { + agg.rawInputTypes = rawInputArgs[i]; + } + + VELOX_CHECK_NULL(untypedExpr.maskExpr); + VELOX_CHECK(!untypedExpr.distinct); + VELOX_CHECK(untypedExpr.orderBy.empty()); + + aggs.emplace_back(agg); + + if (untypedExpr.expr->alias().has_value()) { + names.push_back(untypedExpr.expr->alias().value()); + } else { + names.push_back(fmt::format("a{}", i)); + } + } + VELOX_CHECK_EQ(aggs.size(), names.size()); + + std::vector groupingKeyExprs; + groupingKeyExprs.reserve(groupingKeys.size()); + for (const auto& groupingKey : groupingKeys) { + auto untypedGroupingKeyExpr = duckdb::parseExpr(groupingKey, options); + auto groupingKeyExpr = + std::dynamic_pointer_cast( + core::Expressions::inferTypes( + untypedGroupingKeyExpr, type, pool())); + VELOX_CHECK_NOT_NULL( + groupingKeyExpr, + "Grouping key must use a column name, not an expression: {}", + groupingKey); + groupingKeyExprs.emplace_back(std::move(groupingKeyExpr)); + } + + return core::ColumnStatsSpec( + std::move(groupingKeyExprs), step, std::move(names), std::move(aggs)); } // A default source node, these are frequently needed when construting // PlanNodes, so providing one here. - const std::shared_ptr source = + const std::shared_ptr source_ = ValuesNode::Builder() .id("values_node_id") .values({makeRowVector({makeFlatVector({1, 2, 3})})}) @@ -49,13 +123,17 @@ class TestConnectorTableHandleForLookupJoin explicit TestConnectorTableHandleForLookupJoin(std::string connectorId) : connector::ConnectorTableHandle(std::move(connectorId)) {} + const std::string& name() const override { + VELOX_NYI(); + } + bool supportsIndexLookup() const override { return true; } }; } // namespace -TEST_F(PlanNodeBuilderTest, ValuesNode) { +TEST_F(PlanNodeBuilderTest, valuesNode) { const PlanNodeId id = "values_node_id"; const std::vector values{ makeRowVector( @@ -87,7 +165,7 @@ TEST_F(PlanNodeBuilderTest, ValuesNode) { verify(node2); } -TEST_F(PlanNodeBuilderTest, ArrowStreamNode) { +TEST_F(PlanNodeBuilderTest, arrowStreamNode) { const PlanNodeId id = "arrow_stream_node_id"; const RowTypePtr outputType = ROW({"c0", "c1"}, {INTEGER(), VARCHAR()}); auto arrowStream = std::make_shared(); @@ -109,7 +187,7 @@ TEST_F(PlanNodeBuilderTest, ArrowStreamNode) { verify(node2); } -TEST_F(PlanNodeBuilderTest, TraceScanNode) { +TEST_F(PlanNodeBuilderTest, traceScanNode) { const PlanNodeId id = "trace_scan_node_id"; const std::string traceDir = "/tmp/trace"; const uint32_t pipelineId = 7; @@ -137,7 +215,7 @@ TEST_F(PlanNodeBuilderTest, TraceScanNode) { verify(node2); } -TEST_F(PlanNodeBuilderTest, FilterNode) { +TEST_F(PlanNodeBuilderTest, filterNode) { const PlanNodeId id = "filter_node_id"; const auto filter = std::make_shared(BOOLEAN(), "col0"); @@ -145,18 +223,18 @@ TEST_F(PlanNodeBuilderTest, FilterNode) { const auto verify = [&](const std::shared_ptr& node) { EXPECT_EQ(node->id(), id); EXPECT_EQ(node->filter(), filter); - EXPECT_EQ(node->sources(), std::vector{source}); + EXPECT_EQ(node->sources(), std::vector{source_}); }; const auto node = - FilterNode::Builder().id(id).filter(filter).source(source).build(); + FilterNode::Builder().id(id).filter(filter).source(source_).build(); verify(node); const auto node2 = FilterNode::Builder(*node).build(); verify(node2); } -TEST_F(PlanNodeBuilderTest, ProjectNode) { +TEST_F(PlanNodeBuilderTest, projectNode) { const PlanNodeId id = "project_node_id"; const std::vector names{"out_col"}; const std::vector projections{ @@ -166,14 +244,14 @@ TEST_F(PlanNodeBuilderTest, ProjectNode) { EXPECT_EQ(node->id(), id); EXPECT_EQ(node->names(), names); EXPECT_EQ(node->projections(), projections); - EXPECT_EQ(node->sources(), std::vector{source}); + EXPECT_EQ(node->sources(), std::vector{source_}); }; const auto node = ProjectNode::Builder() .id(id) .names(names) .projections(projections) - .source(source) + .source(source_) .build(); verify(node); @@ -181,16 +259,30 @@ TEST_F(PlanNodeBuilderTest, ProjectNode) { verify(node2); } -TEST_F(PlanNodeBuilderTest, TableScanNode) { +class DummyTableHandle : public connector::ConnectorTableHandle { + public: + DummyTableHandle(const std::string& connectorId) + : connector::ConnectorTableHandle(connectorId) {} + + const std::string& name() const override { + VELOX_NYI(); + } +}; + +class DummyColumnHandle : public connector::ColumnHandle { + public: + const std::string& name() const override { + VELOX_NYI(); + } +}; + +TEST_F(PlanNodeBuilderTest, tableScanNode) { const PlanNodeId id = "table_scan_node_id"; const RowTypePtr outputType = ROW({"c0", "c1"}, {INTEGER(), VARCHAR()}); - const auto tableHandle = - std::make_shared("connector_id"); - const std:: - unordered_map> - assignments{ - {"c0", std::make_shared()}, - {"c1", std::make_shared()}}; + const auto tableHandle = std::make_shared("connector_id"); + const connector::ColumnHandleMap assignments{ + {"c0", std::make_shared()}, + {"c1", std::make_shared()}}; const auto verify = [&](const std::shared_ptr& node) { EXPECT_EQ(node->id(), id); @@ -211,7 +303,7 @@ TEST_F(PlanNodeBuilderTest, TableScanNode) { verify(node2); } -TEST_F(PlanNodeBuilderTest, AggregationNode) { +TEST_F(PlanNodeBuilderTest, aggregationNode) { const PlanNodeId id = "aggregation_node_id"; const auto step = AggregationNode::Step::kSingle; const auto rowType = ROW({"c0"}, {INTEGER()}); @@ -223,8 +315,7 @@ TEST_F(PlanNodeBuilderTest, AggregationNode) { const std::vector aggregateNames{"a0"}; const std::vector aggregates{ AggregationNode::Aggregate{ - .call = std::make_shared( - INTEGER(), std::vector{}, "sum"), + .call = std::make_shared(INTEGER(), "sum"), .rawInputTypes = {INTEGER()}}}; const std::vector globalGroupingSets{0}; const std::optional groupId{ @@ -242,7 +333,7 @@ TEST_F(PlanNodeBuilderTest, AggregationNode) { EXPECT_EQ(node->globalGroupingSets(), globalGroupingSets); EXPECT_EQ(node->groupId(), groupId); EXPECT_EQ(node->ignoreNullKeys(), ignoreNullKeys); - EXPECT_EQ(node->sources(), std::vector{source}); + EXPECT_EQ(node->sources(), std::vector{source_}); }; const auto node = AggregationNode::Builder() @@ -255,7 +346,7 @@ TEST_F(PlanNodeBuilderTest, AggregationNode) { .globalGroupingSets(globalGroupingSets) .groupId(groupId) .ignoreNullKeys(ignoreNullKeys) - .source(source) + .source(source_) .build(); verify(node); @@ -263,7 +354,7 @@ TEST_F(PlanNodeBuilderTest, AggregationNode) { verify(node2); } -TEST_F(PlanNodeBuilderTest, TableWriteNode) { +TEST_F(PlanNodeBuilderTest, tableWriteNode) { const PlanNodeId id = "table_write_node_id"; const RowTypePtr columns = ROW({"c0"}, {INTEGER()}); const std::vector columnNames{"c0"}; @@ -271,16 +362,11 @@ TEST_F(PlanNodeBuilderTest, TableWriteNode) { const bool hasPartitioningScheme = true; const auto commitStrategy = connector::CommitStrategy::kNoCommit; - const auto aggregationNode = AggregationNode::Builder() - .id("aggregation_node_id") - .step(AggregationNode::Step::kPartial) - .groupingKeys({}) - .preGroupedKeys({}) - .aggregateNames({}) - .aggregates({}) - .ignoreNullKeys(true) - .source(source) - .build(); + const auto statsSpec = createStatsSpec( + columns, + std::vector{}, + AggregationNode::Step::kPartial, + std::vector{"sum(c0)"}); const auto insertTableHandle = std::make_shared("connector_id", nullptr); @@ -289,24 +375,24 @@ TEST_F(PlanNodeBuilderTest, TableWriteNode) { EXPECT_EQ(node->id(), id); EXPECT_EQ(node->columns(), columns); EXPECT_EQ(node->columnNames(), columnNames); - EXPECT_EQ(node->aggregationNode(), aggregationNode); EXPECT_EQ(node->insertTableHandle(), insertTableHandle); + EXPECT_TRUE(node->hasColumnStatsSpec()); EXPECT_EQ(node->hasPartitioningScheme(), hasPartitioningScheme); EXPECT_EQ(node->outputType(), outputType); EXPECT_EQ(node->commitStrategy(), commitStrategy); - EXPECT_EQ(node->sources(), std::vector{source}); + EXPECT_EQ(node->sources(), std::vector{source_}); }; const auto node = TableWriteNode::Builder() .id(id) .columns(columns) .columnNames(columnNames) - .aggregationNode(aggregationNode) + .columnStatsSpec(statsSpec) .insertTableHandle(insertTableHandle) .hasPartitioningScheme(hasPartitioningScheme) .outputType(outputType) .commitStrategy(commitStrategy) - .source(source) + .source(source_) .build(); verify(node); @@ -314,34 +400,30 @@ TEST_F(PlanNodeBuilderTest, TableWriteNode) { verify(node2); } -TEST_F(PlanNodeBuilderTest, TableWriteMergeNode) { +TEST_F(PlanNodeBuilderTest, tableWriteMergeNode) { const PlanNodeId id = "table_write_merge_node_id"; const RowTypePtr outputType = ROW({"c0"}, {BIGINT()}); - const auto aggregationNode = AggregationNode::Builder() - .id("aggregation_node_id") - .step(AggregationNode::Step::kPartial) - .groupingKeys({}) - .preGroupedKeys({}) - .aggregateNames({}) - .aggregates({}) - .ignoreNullKeys(true) - .source(source) - .build(); + const auto statsSpec = createStatsSpec( + outputType, + std::vector{}, + AggregationNode::Step::kIntermediate, + std::vector{"sum(c0)"}, + {{BIGINT()}}); const auto verify = [&](const std::shared_ptr& node) { EXPECT_EQ(node->id(), id); EXPECT_EQ(node->outputType(), outputType); - EXPECT_EQ(node->aggregationNode(), aggregationNode); - EXPECT_EQ(node->sources()[0], source); + EXPECT_TRUE(node->hasColumnStatsSpec()); + EXPECT_EQ(node->sources()[0], source_); }; const auto node = TableWriteMergeNode::Builder() .id(id) .outputType(outputType) - .aggregationNode(aggregationNode) - .source(source) + .columnStatsSpec(statsSpec) + .source(source_) .build(); verify(node); @@ -349,7 +431,7 @@ TEST_F(PlanNodeBuilderTest, TableWriteMergeNode) { verify(node2); } -TEST_F(PlanNodeBuilderTest, ExpandNode) { +TEST_F(PlanNodeBuilderTest, expandNode) { const PlanNodeId id = "expand_node_id"; std::vector> projections{ {std::make_shared(INTEGER(), "col0")}}; @@ -359,14 +441,14 @@ TEST_F(PlanNodeBuilderTest, ExpandNode) { EXPECT_EQ(node->id(), id); EXPECT_EQ(node->projections(), projections); EXPECT_EQ(node->names(), names); - EXPECT_EQ(node->sources(), std::vector{source}); + EXPECT_EQ(node->sources(), std::vector{source_}); }; const auto node = ExpandNode::Builder() .id(id) .projections(projections) .names(names) - .source(source) + .source(source_) .build(); verify(node); @@ -374,7 +456,7 @@ TEST_F(PlanNodeBuilderTest, ExpandNode) { verify(node2); } -TEST_F(PlanNodeBuilderTest, GroupIdNode) { +TEST_F(PlanNodeBuilderTest, groupIdNode) { const PlanNodeId id = "group_id_node_id"; const std::vector> groupingSets{{"a"}, {"b"}}; const std::vector groupingKeyInfos{ @@ -394,7 +476,7 @@ TEST_F(PlanNodeBuilderTest, GroupIdNode) { groupingKeyInfos[0].serialize()); EXPECT_EQ(node->aggregationInputs(), aggregationInputs); EXPECT_EQ(node->groupIdName(), groupIdName); - EXPECT_EQ(node->sources(), std::vector{source}); + EXPECT_EQ(node->sources(), std::vector{source_}); }; const auto node = GroupIdNode::Builder() @@ -403,7 +485,7 @@ TEST_F(PlanNodeBuilderTest, GroupIdNode) { .groupingKeyInfos(groupingKeyInfos) .aggregationInputs(aggregationInputs) .groupIdName(groupIdName) - .source(source) + .source(source_) .build(); verify(node); @@ -411,7 +493,7 @@ TEST_F(PlanNodeBuilderTest, GroupIdNode) { verify(node2); } -TEST_F(PlanNodeBuilderTest, ExchangeNode) { +TEST_F(PlanNodeBuilderTest, exchangeNode) { const PlanNodeId id = "exchange_node_id"; const RowTypePtr type = ROW({"c0"}, {BIGINT()}); const auto serdeKind = VectorSerde::Kind::kPresto; @@ -433,7 +515,7 @@ TEST_F(PlanNodeBuilderTest, ExchangeNode) { verify(node2); } -TEST_F(PlanNodeBuilderTest, MergeExchangeNode) { +TEST_F(PlanNodeBuilderTest, mergeExchangeNode) { const PlanNodeId id = "merge_exchange_node_id"; const RowTypePtr type = ROW({"c0"}, {BIGINT()}); const auto serdeKind = VectorSerde::Kind::kPresto; @@ -463,7 +545,7 @@ TEST_F(PlanNodeBuilderTest, MergeExchangeNode) { verify(node2); } -TEST_F(PlanNodeBuilderTest, LocalMergeNode) { +TEST_F(PlanNodeBuilderTest, localMergeNode) { const PlanNodeId id = "local_merge_node_id"; std::vector sortingKeys = { std::make_shared(BIGINT(), "c0")}; @@ -473,14 +555,14 @@ TEST_F(PlanNodeBuilderTest, LocalMergeNode) { EXPECT_EQ(node->id(), id); EXPECT_EQ(node->sortingKeys(), sortingKeys); EXPECT_EQ(node->sortingOrders(), sortingOrders); - EXPECT_EQ(node->sources(), std::vector{source}); + EXPECT_EQ(node->sources(), std::vector{source_}); }; const auto node = LocalMergeNode::Builder() .id(id) .sortingKeys(sortingKeys) .sortingOrders(sortingOrders) - .sources({source}) + .sources({source_}) .build(); verify(node); @@ -488,7 +570,7 @@ TEST_F(PlanNodeBuilderTest, LocalMergeNode) { verify(node2); } -TEST_F(PlanNodeBuilderTest, LocalPartitionNode) { +TEST_F(PlanNodeBuilderTest, localPartitionNode) { const PlanNodeId id = "local_partition_node_id"; const auto type = LocalPartitionNode::Type::kGather; const bool scaleWriter = true; @@ -500,7 +582,7 @@ TEST_F(PlanNodeBuilderTest, LocalPartitionNode) { EXPECT_EQ(node->id(), id); EXPECT_EQ(node->type(), type); EXPECT_EQ(node->scaleWriter(), scaleWriter); - EXPECT_EQ(node->sources(), std::vector{source}); + EXPECT_EQ(node->sources(), std::vector{source_}); EXPECT_EQ( node->partitionFunctionSpec().serialize(), partitionFunctionSpec->serialize()); @@ -511,7 +593,7 @@ TEST_F(PlanNodeBuilderTest, LocalPartitionNode) { .type(type) .scaleWriter(scaleWriter) .partitionFunctionSpec(partitionFunctionSpec) - .sources({source}) + .sources({source_}) .build(); verify(node); @@ -519,7 +601,7 @@ TEST_F(PlanNodeBuilderTest, LocalPartitionNode) { verify(node2); } -TEST_F(PlanNodeBuilderTest, PartitionedOutputNode) { +TEST_F(PlanNodeBuilderTest, partitionedOutputNode) { const PlanNodeId id = "partitioned_output_node_id"; const auto kind = PartitionedOutputNode::Kind::kPartitioned; std::vector keys{ @@ -541,7 +623,7 @@ TEST_F(PlanNodeBuilderTest, PartitionedOutputNode) { EXPECT_EQ(node->outputType(), outputType); EXPECT_EQ(node->serdeKind(), serdeKind); EXPECT_EQ(node->partitionFunctionSpecPtr(), partitionFunctionSpec); - EXPECT_EQ(node->sources(), std::vector{source}); + EXPECT_EQ(node->sources(), std::vector{source_}); }; const auto node = PartitionedOutputNode::Builder() @@ -553,7 +635,7 @@ TEST_F(PlanNodeBuilderTest, PartitionedOutputNode) { .partitionFunctionSpec(partitionFunctionSpec) .outputType(outputType) .serdeKind(serdeKind) - .source(source) + .source(source_) .build(); verify(node); @@ -561,7 +643,7 @@ TEST_F(PlanNodeBuilderTest, PartitionedOutputNode) { verify(node2); } -TEST_F(PlanNodeBuilderTest, HashJoinNode) { +TEST_F(PlanNodeBuilderTest, hashJoinNode) { const PlanNodeId id = "hash_join_node_id"; const auto joinType = JoinType::kLeftSemiProject; const bool nullAware = true; @@ -614,7 +696,7 @@ TEST_F(PlanNodeBuilderTest, HashJoinNode) { verify(node2); } -TEST_F(PlanNodeBuilderTest, MergeJoinNode) { +TEST_F(PlanNodeBuilderTest, mergeJoinNode) { const PlanNodeId id = "merge_join_node_id"; const auto joinType = JoinType::kInner; const std::vector leftKeys{ @@ -664,7 +746,7 @@ TEST_F(PlanNodeBuilderTest, MergeJoinNode) { verify(node2); } -TEST_F(PlanNodeBuilderTest, IndexLookupJoinNode) { +TEST_F(PlanNodeBuilderTest, indexLookupJoinNode) { const PlanNodeId id = "index_lookup_join_node_id"; const auto joinType = JoinType::kInner; const std::vector leftKeys{ @@ -674,8 +756,8 @@ TEST_F(PlanNodeBuilderTest, IndexLookupJoinNode) { const std::vector joinConditions{ std::make_shared( std::make_shared(BIGINT(), "c0"), - std::make_shared(BIGINT(), variant(1)), - std::make_shared(BIGINT(), variant(2)))}; + std::make_shared(BIGINT(), Variant(1LL)), + std::make_shared(BIGINT(), Variant(2LL)))}; const auto left = ValuesNode::Builder() .id("values_node_id_1") @@ -686,9 +768,10 @@ TEST_F(PlanNodeBuilderTest, IndexLookupJoinNode) { TableScanNode::Builder() .id("values_node_id_2") .outputType(ROW({"c1"}, {VARCHAR()})) - .tableHandle(std::make_shared( - "connector_id")) - .assignments({{"c1", std::make_shared()}}) + .tableHandle( + std::make_shared( + "connector_id")) + .assignments({{"c1", std::make_shared()}}) .build(); const auto outputType = ROW({"c0"}, {BIGINT()}); @@ -723,11 +806,11 @@ TEST_F(PlanNodeBuilderTest, IndexLookupJoinNode) { verify(node2); } -TEST_F(PlanNodeBuilderTest, NestedLoopJoinNode) { +TEST_F(PlanNodeBuilderTest, nestedLoopJoinNode) { const PlanNodeId id = "nested_loop_join_node_id"; const auto joinType = JoinType::kLeft; const auto joinCondition = - std::make_shared(BOOLEAN(), variant(true)); + std::make_shared(BOOLEAN(), Variant(true)); const auto left = ValuesNode::Builder() .id("values_node_id_1") @@ -766,7 +849,7 @@ TEST_F(PlanNodeBuilderTest, NestedLoopJoinNode) { verify(node2); } -TEST_F(PlanNodeBuilderTest, OrderByNode) { +TEST_F(PlanNodeBuilderTest, orderByNode) { const PlanNodeId id = "order_by_node_id"; const std::vector sortingKeys{ std::make_shared(BIGINT(), "c0")}; @@ -779,7 +862,7 @@ TEST_F(PlanNodeBuilderTest, OrderByNode) { EXPECT_EQ(node->sortingOrders(), sortingOrders); EXPECT_EQ(node->isPartial(), isPartial); EXPECT_EQ(node->sources().size(), 1); - EXPECT_EQ(node->sources()[0], source); + EXPECT_EQ(node->sources()[0], source_); }; const auto node = OrderByNode::Builder() @@ -787,7 +870,7 @@ TEST_F(PlanNodeBuilderTest, OrderByNode) { .sortingKeys(sortingKeys) .sortingOrders(sortingOrders) .isPartial(isPartial) - .source(source) + .source(source_) .build(); verify(node); @@ -795,7 +878,61 @@ TEST_F(PlanNodeBuilderTest, OrderByNode) { verify(node2); } -TEST_F(PlanNodeBuilderTest, TopNNode) { +TEST_F(PlanNodeBuilderTest, spatialJoinNode) { + const PlanNodeId id = "spatial_join_node_id"; + const auto joinType = JoinType::kInner; + const auto joinCondition = + std::make_shared(BOOLEAN(), Variant(true)); + const auto left = ValuesNode::Builder() + .id("values_node_id_1") + .values({makeRowVector( + {"c0", "g0"}, + {makeFlatVector(std::vector{1}), + makeFlatVector( + std::vector{"POINT(0 0)"})})}) + .build(); + const auto right = ValuesNode::Builder() + .id("values_node_id_2") + .values({makeRowVector( + {"c1", "g1"}, + {makeFlatVector(std::vector{2}), + makeFlatVector( + std::vector{"POINT(0 0)"})})}) + .build(); + const auto outputType = ROW({"c0"}, {BIGINT()}); + const auto probeGeom = + std::make_shared(VARCHAR(), "g0"); + const auto buildGeom = + std::make_shared(VARCHAR(), "g1"); + + const auto verify = [&](const std::shared_ptr& node) { + EXPECT_EQ(node->id(), id); + EXPECT_EQ(node->joinType(), joinType); + EXPECT_EQ(node->joinCondition(), joinCondition); + EXPECT_EQ(node->probeGeometry(), probeGeom); + EXPECT_EQ(node->buildGeometry(), buildGeom); + EXPECT_EQ(node->sources()[0], left); + EXPECT_EQ(node->sources()[1], right); + EXPECT_EQ(node->outputType(), outputType); + }; + + const auto node = SpatialJoinNode::Builder() + .id(id) + .joinType(joinType) + .joinCondition(joinCondition) + .left(left) + .right(right) + .probeGeometry(probeGeom) + .buildGeometry(buildGeom) + .outputType(outputType) + .build(); + verify(node); + + const auto node2 = SpatialJoinNode::Builder(*node).build(); + verify(node2); +} + +TEST_F(PlanNodeBuilderTest, topNNode) { const PlanNodeId id = "topn_node_id"; const std::vector sortingKeys{ std::make_shared(BIGINT(), "c0")}; @@ -810,7 +947,7 @@ TEST_F(PlanNodeBuilderTest, TopNNode) { EXPECT_EQ(node->count(), count); EXPECT_EQ(node->isPartial(), isPartial); EXPECT_EQ(node->sources().size(), 1); - EXPECT_EQ(node->sources()[0], source); + EXPECT_EQ(node->sources()[0], source_); }; const auto node = TopNNode::Builder() @@ -819,7 +956,7 @@ TEST_F(PlanNodeBuilderTest, TopNNode) { .sortingOrders(sortingOrders) .count(count) .isPartial(isPartial) - .source(source) + .source(source_) .build(); verify(node); @@ -827,7 +964,7 @@ TEST_F(PlanNodeBuilderTest, TopNNode) { verify(node2); } -TEST_F(PlanNodeBuilderTest, LimitNode) { +TEST_F(PlanNodeBuilderTest, limitNode) { const PlanNodeId id = "limit_node_id"; const int64_t offset = 1; const int64_t count = 5; @@ -839,7 +976,7 @@ TEST_F(PlanNodeBuilderTest, LimitNode) { EXPECT_EQ(node->count(), count); EXPECT_EQ(node->isPartial(), isPartial); EXPECT_EQ(node->sources().size(), 1); - EXPECT_EQ(node->sources()[0], source); + EXPECT_EQ(node->sources()[0], source_); }; const auto node = LimitNode::Builder() @@ -847,7 +984,7 @@ TEST_F(PlanNodeBuilderTest, LimitNode) { .offset(offset) .count(count) .isPartial(isPartial) - .source(source) + .source(source_) .build(); verify(node); @@ -855,7 +992,7 @@ TEST_F(PlanNodeBuilderTest, LimitNode) { verify(node2); } -TEST_F(PlanNodeBuilderTest, UnnestNode) { +TEST_F(PlanNodeBuilderTest, unnestNode) { const PlanNodeId id = "unnest_node_id"; std::vector replicateVariables{ std::make_shared(BIGINT(), "a")}; @@ -869,8 +1006,8 @@ TEST_F(PlanNodeBuilderTest, UnnestNode) { EXPECT_EQ(node->id(), id); EXPECT_EQ(node->replicateVariables(), replicateVariables); EXPECT_EQ(node->unnestVariables(), unnestVariables); - EXPECT_TRUE(node->withOrdinality()); - EXPECT_EQ(node->sources()[0], source); + EXPECT_TRUE(node->hasOrdinality()); + EXPECT_EQ(node->sources()[0], source_); for (int i = 0; i < node->outputType()->size(); ++i) { if (i < replicateVariables.size()) { @@ -892,7 +1029,7 @@ TEST_F(PlanNodeBuilderTest, UnnestNode) { .unnestVariables(unnestVariables) .unnestNames(unnestNames) .ordinalityName(ordinalityName) - .source(source) + .source(source_) .build(); verify(node); @@ -900,25 +1037,25 @@ TEST_F(PlanNodeBuilderTest, UnnestNode) { verify(node2); } -TEST_F(PlanNodeBuilderTest, EnforceSingleRowNode) { +TEST_F(PlanNodeBuilderTest, enforceSingleRowNode) { const PlanNodeId id = "enforce_single_row_id"; const auto verify = [&](const std::shared_ptr& node) { EXPECT_EQ(node->id(), id); EXPECT_EQ(node->sources().size(), 1); - EXPECT_EQ(node->sources()[0], source); + EXPECT_EQ(node->sources()[0], source_); }; const auto node = - EnforceSingleRowNode::Builder().id(id).source(source).build(); + EnforceSingleRowNode::Builder().id(id).source(source_).build(); verify(node); const auto node2 = EnforceSingleRowNode::Builder(*node).build(); verify(node2); } -TEST_F(PlanNodeBuilderTest, AssignUniqueIdNode) { +TEST_F(PlanNodeBuilderTest, assignUniqueIdNode) { const PlanNodeId id = "assign_unique_id_id"; const std::string idName = "unique_id"; const int32_t taskUniqueId = 42; @@ -929,14 +1066,14 @@ TEST_F(PlanNodeBuilderTest, AssignUniqueIdNode) { EXPECT_EQ(node->outputType()->names().back(), idName); EXPECT_EQ(node->taskUniqueId(), taskUniqueId); EXPECT_EQ(node->sources().size(), 1); - EXPECT_EQ(node->sources()[0], source); + EXPECT_EQ(node->sources()[0], source_); }; const auto node = AssignUniqueIdNode::Builder() .id(id) .idName(idName) .taskUniqueId(taskUniqueId) - .source(source) + .source(source_) .build(); verify(node); @@ -944,7 +1081,7 @@ TEST_F(PlanNodeBuilderTest, AssignUniqueIdNode) { verify(node2); } -TEST_F(PlanNodeBuilderTest, WindowNode) { +TEST_F(PlanNodeBuilderTest, windowNode) { const PlanNodeId id = "window_node_id"; const std::vector partitionKeys{ std::make_shared(BIGINT(), "a")}; @@ -955,8 +1092,7 @@ TEST_F(PlanNodeBuilderTest, WindowNode) { const bool inputsSorted = true; // Create a dummy window function. - const auto functionCall = std::make_shared( - BIGINT(), std::vector{}, "rank"); + const auto functionCall = std::make_shared(BIGINT(), "rank"); WindowNode::Frame frame{ WindowNode::WindowType::kRows, WindowNode::BoundType::kUnboundedPreceding, @@ -975,13 +1111,13 @@ TEST_F(PlanNodeBuilderTest, WindowNode) { EXPECT_EQ( node->outputType()->size(), - source->outputType()->size() + windowColumnNames.size()); - for (size_t i = source->outputType()->size(); + source_->outputType()->size() + windowColumnNames.size()); + for (size_t i = source_->outputType()->size(); i < node->outputType()->size(); ++i) { EXPECT_EQ( node->outputType()->nameOf(i), - windowColumnNames[i - source->outputType()->size()]); + windowColumnNames[i - source_->outputType()->size()]); } EXPECT_EQ(node->windowFunctions().size(), 1); @@ -989,7 +1125,7 @@ TEST_F(PlanNodeBuilderTest, WindowNode) { node->windowFunctions()[0].serialize(), windowFunctions[0].serialize()); EXPECT_EQ(node->inputsSorted(), inputsSorted); EXPECT_EQ(node->sources().size(), 1); - EXPECT_EQ(node->sources()[0], source); + EXPECT_EQ(node->sources()[0], source_); }; const auto node = WindowNode::Builder() @@ -1000,7 +1136,7 @@ TEST_F(PlanNodeBuilderTest, WindowNode) { .windowColumnNames(windowColumnNames) .windowFunctions(windowFunctions) .inputsSorted(inputsSorted) - .source(source) + .source(source_) .build(); verify(node); @@ -1008,7 +1144,7 @@ TEST_F(PlanNodeBuilderTest, WindowNode) { verify(node2); } -TEST_F(PlanNodeBuilderTest, RowNumberNode) { +TEST_F(PlanNodeBuilderTest, rowNumberNode) { const PlanNodeId id = "row_number_node_id"; const std::vector partitionKeys{ std::make_shared(BIGINT(), "c0")}; @@ -1021,7 +1157,7 @@ TEST_F(PlanNodeBuilderTest, RowNumberNode) { EXPECT_EQ(node->limit(), limit); EXPECT_TRUE(node->generateRowNumber()); EXPECT_EQ(node->outputType()->names().back(), rowNumberColumnName); - EXPECT_EQ(node->sources()[0], source); + EXPECT_EQ(node->sources()[0], source_); }; const auto node = RowNumberNode::Builder() @@ -1029,7 +1165,7 @@ TEST_F(PlanNodeBuilderTest, RowNumberNode) { .partitionKeys(partitionKeys) .rowNumberColumnName(rowNumberColumnName) .limit(limit) - .source(source) + .source(source_) .build(); verify(node); @@ -1037,7 +1173,7 @@ TEST_F(PlanNodeBuilderTest, RowNumberNode) { verify(node2); } -TEST_F(PlanNodeBuilderTest, MarkDistinctNode) { +TEST_F(PlanNodeBuilderTest, markDistinctNode) { const PlanNodeId id = "mark_distinct_node_id"; const std::string markerName = "is_distinct"; const std::vector distinctKeys{ @@ -1048,14 +1184,14 @@ TEST_F(PlanNodeBuilderTest, MarkDistinctNode) { EXPECT_EQ(node->markerName(), markerName); EXPECT_EQ(node->distinctKeys(), distinctKeys); EXPECT_EQ(node->sources().size(), 1); - EXPECT_EQ(node->sources()[0], source); + EXPECT_EQ(node->sources()[0], source_); }; const auto node = MarkDistinctNode::Builder() .id(id) .markerName(markerName) .distinctKeys(distinctKeys) - .source(source) + .source(source_) .build(); verify(node); @@ -1063,7 +1199,7 @@ TEST_F(PlanNodeBuilderTest, MarkDistinctNode) { verify(node2); } -TEST_F(PlanNodeBuilderTest, TopNRowNumberNode) { +TEST_F(PlanNodeBuilderTest, topNRowNumberNode) { const PlanNodeId id = "topn_row_number_node_id"; const std::vector partitionKeys{ std::make_shared(BIGINT(), "c0")}; @@ -1083,7 +1219,7 @@ TEST_F(PlanNodeBuilderTest, TopNRowNumberNode) { EXPECT_TRUE(node->generateRowNumber()); EXPECT_EQ(node->outputType()->names().back(), rowNumberColumnName); EXPECT_EQ(node->sources().size(), 1); - EXPECT_EQ(node->sources()[0], source); + EXPECT_EQ(node->sources()[0], source_); }; const auto node = TopNRowNumberNode::Builder() @@ -1093,7 +1229,7 @@ TEST_F(PlanNodeBuilderTest, TopNRowNumberNode) { .sortingOrders(sortingOrders) .rowNumberColumnName(rowNumberColumnName) .limit(limit) - .source(source) + .source(source_) .build(); verify(node); diff --git a/velox/core/tests/PlanNodeTest.cpp b/velox/core/tests/PlanNodeTest.cpp index b1026300c0dd..fa39d0927d90 100644 --- a/velox/core/tests/PlanNodeTest.cpp +++ b/velox/core/tests/PlanNodeTest.cpp @@ -46,8 +46,7 @@ TEST_F(PlanNodeTest, findFirstNode) { auto rowType = ROW({"name1"}, {BIGINT()}); std::shared_ptr tableHandle; - std::unordered_map> - assignments; + connector::ColumnHandleMap assignments; std::shared_ptr tableScan3 = std::make_shared("3", rowType, tableHandle, assignments); @@ -86,6 +85,55 @@ TEST_F(PlanNodeTest, findFirstNode) { })); } +TEST_F(PlanNodeTest, findNodeById) { + auto values = std::make_shared("1", std::vector{}); + auto project = std::make_shared( + "2", + std::vector{"a", "b"}, + std::vector{ + std::make_shared(DOUBLE(), "rand"), + std::make_shared(DOUBLE(), "rand"), + }, + values); + + auto filter = std::make_shared( + "3", + std::make_shared( + BOOLEAN(), + "gt", + std::make_shared(DOUBLE(), "a"), + std::make_shared(DOUBLE(), 0.5)), + project); + + auto limit = std::make_shared("4", 0, 10, false, filter); + + ASSERT_EQ(PlanNode::findNodeById(limit.get(), "1"), values.get()); + ASSERT_EQ(PlanNode::findNodeById(limit.get(), "2"), project.get()); + ASSERT_EQ(PlanNode::findNodeById(limit.get(), "3"), filter.get()); + ASSERT_EQ(PlanNode::findNodeById(limit.get(), "4"), limit.get()); + + ASSERT_EQ(PlanNode::findNodeById(limit.get(), "5"), nullptr); + ASSERT_EQ(PlanNode::findNodeById(project.get(), "4"), nullptr); +} + +TEST_F(PlanNodeTest, is) { + auto values = std::make_shared("1", std::vector{}); + auto project = std::make_shared( + "2", + std::vector{"a", "b"}, + std::vector{ + std::make_shared(DOUBLE(), "rand"), + std::make_shared(DOUBLE(), "rand"), + }, + values); + + ASSERT_TRUE(values->is()); + ASSERT_FALSE(values->is()); + + ASSERT_FALSE(project->is()); + ASSERT_TRUE(project->is()); +} + TEST_F(PlanNodeTest, sortOrder) { struct { SortOrder order1; @@ -133,6 +181,7 @@ TEST_F(PlanNodeTest, duplicateSortKeys) { "orderBy", sortingKeys, sortingOrders, false, nullptr), "Duplicate sorting keys are not allowed: c0"); } + class TestIndexTableHandle : public connector::ConnectorTableHandle { public: TestIndexTableHandle() @@ -164,7 +213,7 @@ class TestIndexTableHandle : public connector::ConnectorTableHandle { } }; -TEST_F(PlanNodeTest, isIndexLookupJoin) { +TEST_F(PlanNodeTest, indexLookupJoin) { const auto rowType = ROW({"name"}, {BIGINT()}); const auto valueNode = std::make_shared("orderBy", rowData_); ASSERT_FALSE(isIndexLookupJoin(valueNode.get())); @@ -174,34 +223,326 @@ TEST_F(PlanNodeTest, isIndexLookupJoin) { const RowTypePtr outputType = ROW({"c0", "c1"}, {BIGINT(), BIGINT()}); auto indexTableHandle = std::make_shared(); const auto probeNode = std::make_shared( - "tableScan-probe", - probeType, - nullptr, - std::unordered_map< - std::string, - std::shared_ptr>{}); + "tableScan-probe", probeType, nullptr, connector::ColumnHandleMap{}); ASSERT_FALSE(isIndexLookupJoin(probeNode.get())); const auto buildNode = std::make_shared( "tableScan-build", buildType, indexTableHandle, - std::unordered_map< - std::string, - std::shared_ptr>{}); + connector::ColumnHandleMap{}); ASSERT_FALSE(isIndexLookupJoin(buildNode.get())); const std::vector leftKeys{ std::make_shared(BIGINT(), "c0")}; const std::vector rightKeys{ std::make_shared(BIGINT(), "c1")}; - const auto indexJoinNode = std::make_shared( - "indexJoinNode", - core::JoinType::kInner, - leftKeys, - rightKeys, - std::vector{}, - probeNode, - buildNode, - outputType); - ASSERT_TRUE(isIndexLookupJoin(indexJoinNode.get())); + { + const auto indexJoinNodeWithInnerJoin = + std::make_shared( + "indexJoinNode", + core::JoinType::kInner, + leftKeys, + rightKeys, + std::vector{}, + /*filter=*/nullptr, + /*hasMarker=*/false, + probeNode, + buildNode, + outputType); + ASSERT_TRUE(isIndexLookupJoin(indexJoinNodeWithInnerJoin.get())); + ASSERT_FALSE(indexJoinNodeWithInnerJoin->hasMarker()); + ASSERT_EQ(indexJoinNodeWithInnerJoin->filter(), nullptr); + ASSERT_EQ( + indexJoinNodeWithInnerJoin->toString(/*detailed=*/true), + "-- IndexLookupJoin[indexJoinNode][INNER c0=c1] -> c0:BIGINT, c1:BIGINT\n"); + } + { + const RowTypePtr outputTypeWithMatchColumn = + ROW({"c0", "c1", "c2"}, {BIGINT(), BIGINT(), BOOLEAN()}); + const auto indexJoinNodeWithLeftJoin = + std::make_shared( + "indexJoinNode", + core::JoinType::kLeft, + leftKeys, + rightKeys, + std::vector{}, + /*filter=*/nullptr, + /*hasMarker=*/true, + probeNode, + buildNode, + outputTypeWithMatchColumn); + ASSERT_TRUE(isIndexLookupJoin(indexJoinNodeWithLeftJoin.get())); + ASSERT_TRUE(indexJoinNodeWithLeftJoin->hasMarker()); + ASSERT_EQ(indexJoinNodeWithLeftJoin->filter(), nullptr); + ASSERT_EQ( + indexJoinNodeWithLeftJoin->toString(/*detailed=*/true), + "-- IndexLookupJoin[indexJoinNode][LEFT c0=c1] -> c0:BIGINT, c1:BIGINT, c2:BOOLEAN\n"); + } + { + // Test IndexLookupJoinNode with filter + const auto filterExpr = std::make_shared( + BOOLEAN(), "filter_column"); + const auto indexJoinNodeWithFilter = std::make_shared( + "indexJoinNodeWithFilter", + core::JoinType::kInner, + leftKeys, + rightKeys, + std::vector{}, + /*filter=*/filterExpr, + /*hasMarker=*/false, + probeNode, + buildNode, + outputType); + ASSERT_TRUE(isIndexLookupJoin(indexJoinNodeWithFilter.get())); + ASSERT_FALSE(indexJoinNodeWithFilter->hasMarker()); + ASSERT_EQ(indexJoinNodeWithFilter->filter(), filterExpr); + ASSERT_EQ( + indexJoinNodeWithFilter->toString(/*detailed=*/true), + "-- IndexLookupJoin[indexJoinNodeWithFilter][INNER c0=c1, filter: \"filter_column\"] -> c0:BIGINT, c1:BIGINT\n"); + } + // Error case. + { + VELOX_ASSERT_THROW( + std::make_shared( + "indexJoinNode", + core::JoinType::kInner, + leftKeys, + rightKeys, + std::vector{}, + /*filter=*/nullptr, + /*hasMarker=*/true, + probeNode, + buildNode, + outputType), + "Index join match column can only present for LEFT but not INNER"); + } + { + VELOX_ASSERT_THROW( + std::make_shared( + "indexJoinNode", + core::JoinType::kLeft, + leftKeys, + rightKeys, + std::vector{}, + /*filter=*/nullptr, + /*hasMarker=*/true, + probeNode, + buildNode, + outputType), + "The last output column must be boolean type if match column is present"); + } + { + const RowTypePtr outputTypeWithDuplicateMatchColumn = + ROW({"c0", "c1", "c0"}, {BIGINT(), BIGINT(), BOOLEAN()}); + VELOX_ASSERT_THROW( + std::make_shared( + "indexJoinNode", + core::JoinType::kLeft, + leftKeys, + rightKeys, + std::vector{}, + /*filter=*/nullptr, + /*hasMarker=*/true, + probeNode, + buildNode, + outputTypeWithDuplicateMatchColumn), + ""); + } +} + +TEST_F(PlanNodeTest, partitionedOutputNode) { + const PlanNodeId id{"partitionedOutputNode"}; + const PartitionedOutputNode::Kind kind = + PartitionedOutputNode::Kind::kPartitioned; + const std::vector keys = { + std::make_shared(BIGINT(), "c0")}; + const PartitionFunctionSpecPtr partitionFunctionSpec = + std::make_shared(); + const VectorSerde::Kind serdeKind = VectorSerde::Kind::kPresto; + PlanNodePtr source = std::make_shared("source", rowData_); + + { + // Creating a PartitionedOutputNode with a single partition, empty keys, and + // a null partition function should succeed. + PartitionedOutputNode node( + id, + kind, + {}, + 1, // numPartitions + true, // replicateNullsAndAny + nullptr, // partitionFunctionSpec + rowType_, + serdeKind, + source); + // Attempting to dereference the nullptr should fail. + ASSERT_EQ(node.partitionFunctionSpecPtr(), nullptr); + VELOX_ASSERT_THROW(node.partitionFunctionSpec(), ""); + } + + // Creating a PartitionedOutputNode that is not partitioned and has empty keys + // and a partition function (even kinds other than partitioned still use a + // partition function) should succeed. + { + PartitionedOutputNode node( + id, + PartitionedOutputNode::Kind::kArbitrary, + {}, + 10, // numPartitions + true, // replicateNullsAndAny + partitionFunctionSpec, + rowType_, + serdeKind, + source); + // We should be able to dereference the partition function spec. + ASSERT_EQ(node.partitionFunctionSpecPtr(), partitionFunctionSpec); + ASSERT_EQ( + node.partitionFunctionSpec().toString(), + partitionFunctionSpec->toString()); + } + + // Creating a PartitionedOutputNode with numPartitions = 0 should throw. + VELOX_ASSERT_THROW( + PartitionedOutputNode( + id, + kind, + keys, + 0, // numPartitions + true, // replicateNullsAndAny + partitionFunctionSpec, + rowType_, + serdeKind, + source), + ""); + + // Creating a PartitionedOutputNode with numPartitions = 1 and non-empty + // keys should throw. + VELOX_ASSERT_THROW( + PartitionedOutputNode( + id, + kind, + keys, + 1, // numPartitions + true, // replicateNullsAndAny + partitionFunctionSpec, + rowType_, + serdeKind, + source), + "Non-empty partitioning keys require more than one partition"); + + // Creating a PartitionedOutputNode with numPartitions > 1 and no partition + // function should throw. + VELOX_ASSERT_THROW( + PartitionedOutputNode( + id, + kind, + keys, + 5, // numPartitions + true, // replicateNullsAndAny + nullptr, // partitionFunctionSpec + rowType_, + serdeKind, + source), + "Partition function spec must be specified when the number of destinations is more than 1."); + + // Creating a PartitionedOutputNode that is not partitioned with non-empty + // keys should throw. + VELOX_ASSERT_THROW( + PartitionedOutputNode( + id, + PartitionedOutputNode::Kind::kArbitrary, + keys, + 5, // numPartitions + true, // replicateNullsAndAny + partitionFunctionSpec, + rowType_, + serdeKind, + source), + "partitioning doesn't allow for partitioning keys"); +} + +TEST_F(PlanNodeTest, aggregationNodeNoGroupsSpanBatches) { + auto values = std::make_shared("values", rowData_); + + const std::vector groupingKeys{ + std::make_shared(BIGINT(), "c0")}; + const std::vector preGroupedKeys{ + std::make_shared(BIGINT(), "c0")}; + const std::vector aggregateNames{"sum"}; + const std::vector aggregates{ + {.call = std::make_shared(BIGINT(), "sum"), + .rawInputTypes = {BIGINT()}}}; + + // noGroupsSpanBatches=true with preGroupedKeys (streaming aggregation) should + // succeed and the accessor should return true. + { + auto aggNode = std::make_shared( + "agg", + AggregationNode::Step::kSingle, + groupingKeys, + preGroupedKeys, + aggregateNames, + aggregates, + /*ignoreNullKeys=*/false, + /*noGroupsSpanBatches=*/true, + values); + ASSERT_TRUE(aggNode->noGroupsSpanBatches()); + ASSERT_TRUE(aggNode->isPreGrouped()); + ASSERT_EQ( + aggNode->toString(true), + "-- Aggregation[agg][SINGLE STREAMING [c0] sum := sum() noGroupsSpanBatches] -> c0:BIGINT, sum:BIGINT\n"); + } + + // noGroupsSpanBatches=false with preGroupedKeys should succeed and the + // accessor should return false. + { + auto aggNode = std::make_shared( + "agg", + AggregationNode::Step::kSingle, + groupingKeys, + preGroupedKeys, + aggregateNames, + aggregates, + /*ignoreNullKeys=*/false, + /*noGroupsSpanBatches=*/false, + values); + ASSERT_FALSE(aggNode->noGroupsSpanBatches()); + ASSERT_TRUE(aggNode->isPreGrouped()); + ASSERT_EQ( + aggNode->toString(true), + "-- Aggregation[agg][SINGLE STREAMING [c0] sum := sum()] -> c0:BIGINT, sum:BIGINT\n"); + } + + // noGroupsSpanBatches=true without preGroupedKeys (non-streaming aggregation) + // should fail. + VELOX_ASSERT_THROW( + std::make_shared( + "agg", + AggregationNode::Step::kSingle, + groupingKeys, + /*preGroupedKeys=*/std::vector{}, + aggregateNames, + aggregates, + /*ignoreNullKeys=*/false, + /*noGroupsSpanBatches=*/true, + values), + "noGroupsSpanBatches can only be set for streaming aggregation (pre-grouped)"); + + // noGroupsSpanBatches=false without preGroupedKeys should succeed. + { + auto aggNode = std::make_shared( + "agg", + AggregationNode::Step::kSingle, + groupingKeys, + /*preGroupedKeys=*/std::vector{}, + aggregateNames, + aggregates, + /*ignoreNullKeys=*/false, + /*noGroupsSpanBatches=*/false, + values); + ASSERT_FALSE(aggNode->noGroupsSpanBatches()); + ASSERT_FALSE(aggNode->isPreGrouped()); + ASSERT_EQ( + aggNode->toString(true), + "-- Aggregation[agg][SINGLE [c0] sum := sum()] -> c0:BIGINT, sum:BIGINT\n"); + } } } // namespace diff --git a/velox/core/tests/QueryConfigTest.cpp b/velox/core/tests/QueryConfigTest.cpp index 599c1ce65338..8dec2d31aea6 100644 --- a/velox/core/tests/QueryConfigTest.cpp +++ b/velox/core/tests/QueryConfigTest.cpp @@ -33,6 +33,7 @@ TEST_F(QueryConfigTest, emptyConfig) { const QueryConfig& config = queryCtx->queryConfig(); ASSERT_FALSE(config.isLegacyCast()); + EXPECT_EQ(config.maxNumSplitsListenedTo(), 0); } TEST_F(QueryConfigTest, setConfig) { @@ -204,4 +205,79 @@ TEST_F(QueryConfigTest, expressionEvaluationRelatedConfigs) { testConfig(createConfig(false, false, false, true)); } +TEST_F(QueryConfigTest, sessionStartTime) { + // Test with no session start time set + { + auto queryCtx = QueryCtx::create(nullptr, QueryConfig{{}}); + const QueryConfig& config = queryCtx->queryConfig(); + + EXPECT_EQ(config.sessionStartTimeMs(), 0); + } + + // Test with session start time set + { + int64_t startTimeMs = 1674123456789; // Some timestamp in milliseconds + std::unordered_map configData( + {{QueryConfig::kSessionStartTime, std::to_string(startTimeMs)}}); + auto queryCtx = + QueryCtx::create(nullptr, QueryConfig{std::move(configData)}); + const QueryConfig& config = queryCtx->queryConfig(); + + EXPECT_EQ(config.sessionStartTimeMs(), startTimeMs); + } + + // Test with negative session start time (should be valid) + { + int64_t negativeStartTime = -1000; + std::unordered_map configData( + {{QueryConfig::kSessionStartTime, std::to_string(negativeStartTime)}}); + auto queryCtx = + QueryCtx::create(nullptr, QueryConfig{std::move(configData)}); + const QueryConfig& config = queryCtx->queryConfig(); + + EXPECT_EQ(config.sessionStartTimeMs(), negativeStartTime); + } + + // Test with maximum int64_t value + { + int64_t maxTime = std::numeric_limits::max(); + std::unordered_map configData( + {{QueryConfig::kSessionStartTime, std::to_string(maxTime)}}); + auto queryCtx = + QueryCtx::create(nullptr, QueryConfig{std::move(configData)}); + const QueryConfig& config = queryCtx->queryConfig(); + + EXPECT_EQ(config.sessionStartTimeMs(), maxTime); + } +} + +TEST_F(QueryConfigTest, singleSourceExchangeOptimizationConfig) { + // Test default value (should be false) + { + auto queryCtx = QueryCtx::create(nullptr, QueryConfig{{}}); + const QueryConfig& config = queryCtx->queryConfig(); + EXPECT_FALSE(config.singleSourceExchangeOptimizationEnabled()); + } + + // Test with optimization enabled + { + std::unordered_map configData( + {{QueryConfig::kSkipRequestDataSizeWithSingleSourceEnabled, "true"}}); + auto queryCtx = + QueryCtx::create(nullptr, QueryConfig{std::move(configData)}); + const QueryConfig& config = queryCtx->queryConfig(); + EXPECT_TRUE(config.singleSourceExchangeOptimizationEnabled()); + } + + // Test with optimization explicitly disabled + { + std::unordered_map configData( + {{QueryConfig::kSkipRequestDataSizeWithSingleSourceEnabled, "false"}}); + auto queryCtx = + QueryCtx::create(nullptr, QueryConfig{std::move(configData)}); + const QueryConfig& config = queryCtx->queryConfig(); + EXPECT_FALSE(config.singleSourceExchangeOptimizationEnabled()); + } +} + } // namespace facebook::velox::core::test diff --git a/velox/core/tests/QueryCtxTest.cpp b/velox/core/tests/QueryCtxTest.cpp new file mode 100644 index 000000000000..44dbc5927c47 --- /dev/null +++ b/velox/core/tests/QueryCtxTest.cpp @@ -0,0 +1,140 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/core/QueryCtx.h" + +namespace facebook::velox::core::test { + +class QueryCtxTest : public testing::Test { + protected: + static void SetUpTestCase() { + memory::MemoryManager::testingSetInstance(memory::MemoryManager::Options{}); + } +}; + +TEST_F(QueryCtxTest, withSysRootPool) { + auto queryCtx = QueryCtx::create( + nullptr, + QueryConfig{{}}, + std::unordered_map>{}, + nullptr, + memory::deprecatedRootPool().shared_from_this()); + auto* queryPool = queryCtx->pool(); + ASSERT_EQ(&memory::deprecatedRootPool(), queryPool); + ASSERT_NE(queryPool->reclaimer(), nullptr); + try { + VELOX_FAIL("Trigger Error"); + } catch (const velox::VeloxRuntimeError&) { + VELOX_ASSERT_THROW( + queryPool->reclaimer()->abort(queryPool, std::current_exception()), + "SysMemoryReclaimer::abort is not supported"); + } + ASSERT_EQ(queryPool->reclaimer()->priority(), 0); + memory::MemoryReclaimer::Stats stats; + ASSERT_EQ(queryPool->reclaimer()->reclaim(queryPool, 1'000, 1'000, stats), 0); + uint64_t reclaimableBytes{0}; + ASSERT_FALSE( + queryPool->reclaimer()->reclaimableBytes(*queryPool, reclaimableBytes)); +} + +TEST_F(QueryCtxTest, releaseCallbacks) { + int callbackCount = 0; + std::string capturedQueryId; + + { + auto queryCtx = QueryCtx::create( + nullptr, + QueryConfig{{}}, + std::unordered_map>{}, + nullptr, + nullptr, + nullptr, + "test_query_id"); + + // Add multiple callbacks. + queryCtx->addReleaseCallback([&callbackCount]() { ++callbackCount; }); + + queryCtx->addReleaseCallback( + [&callbackCount, &capturedQueryId, id = queryCtx->queryId()]() { + ++callbackCount; + capturedQueryId = id; + }); + + // Callbacks should not be invoked yet. + ASSERT_EQ(callbackCount, 0); + } + + // After QueryCtx destruction, all callbacks should have been invoked. + ASSERT_EQ(callbackCount, 2); + ASSERT_EQ(capturedQueryId, "test_query_id"); +} + +TEST_F(QueryCtxTest, releaseCallbackException) { + int callbackCount = 0; + + { + auto queryCtx = QueryCtx::create( + nullptr, + QueryConfig{{}}, + std::unordered_map>{}, + nullptr, + nullptr, + nullptr, + "test_query_id"); + + // First callback succeeds. + queryCtx->addReleaseCallback([&callbackCount]() { ++callbackCount; }); + + // Second callback throws an exception. + queryCtx->addReleaseCallback( + []() { throw std::runtime_error("Test exception"); }); + + // Third callback should still execute despite the previous exception. + queryCtx->addReleaseCallback([&callbackCount]() { ++callbackCount; }); + } + + // All callbacks should have been attempted, with exception caught and logged. + // First and third callbacks should have incremented the counter. + ASSERT_EQ(callbackCount, 2); +} + +TEST_F(QueryCtxTest, builderReleaseCallbacks) { + int callbackCount = 0; + std::string capturedQueryId; + + { + // Use builder to add release callbacks during construction. + auto queryCtx = + QueryCtx::Builder() + .queryId("builder_test_query_id") + .releaseCallback([&callbackCount]() { ++callbackCount; }) + .releaseCallback([&callbackCount, &capturedQueryId]() { + ++callbackCount; + capturedQueryId = "builder_test_query_id"; + }) + .build(); + + // Callbacks should not be invoked yet. + ASSERT_EQ(callbackCount, 0); + } + + // After QueryCtx destruction, all callbacks should have been invoked. + ASSERT_EQ(callbackCount, 2); + ASSERT_EQ(capturedQueryId, "builder_test_query_id"); +} +} // namespace facebook::velox::core::test diff --git a/velox/core/tests/TypedExprSerdeTest.cpp b/velox/core/tests/TypedExprSerdeTest.cpp index b49e36222140..62c6ae93006e 100644 --- a/velox/core/tests/TypedExprSerdeTest.cpp +++ b/velox/core/tests/TypedExprSerdeTest.cpp @@ -87,33 +87,25 @@ TEST_F(TypedExprSerDeTest, call) { // a + b auto expression = std::make_shared( BIGINT(), - std::vector{ - std::make_shared(BIGINT(), "a"), - std::make_shared(BIGINT(), "b"), - }, - "plus"); + "plus", + std::make_shared(BIGINT(), "a"), + std::make_shared(BIGINT(), "b")); testSerde(expression); // f(g(h(a, b), c)) expression = std::make_shared( VARCHAR(), - std::vector{ + "f", + std::make_shared( + DOUBLE(), + "g", std::make_shared( - DOUBLE(), - std::vector{ - std::make_shared( - BIGINT(), - std::vector{ - std::make_shared(BIGINT(), "a"), - std::make_shared(BIGINT(), "b"), - }, - "h"), - std::make_shared(BIGINT(), "c"), - }, - "g"), - }, - "f"); + BIGINT(), + "h", + std::make_shared(BIGINT(), "a"), + std::make_shared(BIGINT(), "b")), + std::make_shared(BIGINT(), "c"))); testSerde(expression); } @@ -144,13 +136,11 @@ TEST_F(TypedExprSerDeTest, lambda) { ROW({"x"}, {BIGINT()}), std::make_shared( BOOLEAN(), - std::vector{ - std::make_shared(BIGINT(), "x"), - std::make_shared(makeArrayVector({ - {1, 2, 3, 4, 5}, - })), - }, - "in")); + "in", + std::make_shared(BIGINT(), "x"), + std::make_shared(makeArrayVector({ + {1, 2, 3, 4, 5}, + })))); testSerde(expression); } diff --git a/velox/docs/conf.py b/velox/docs/conf.py index d9dd62e7450b..a87507e85dae 100644 --- a/velox/docs/conf.py +++ b/velox/docs/conf.py @@ -50,6 +50,8 @@ "issue", "pr", "spark", + "iceberg", + "delta", "sphinx.ext.autodoc", "sphinx.ext.doctest", "sphinx.ext.mathjax", diff --git a/velox/docs/configs.rst b/velox/docs/configs.rst index aee9aa327981..ba71a407a65c 100644 --- a/velox/docs/configs.rst +++ b/velox/docs/configs.rst @@ -43,6 +43,16 @@ Generic Configuration - integer - 80 - Abandons partial TopNRowNumber if number of output rows equals or exceeds this percentage of the number of input rows. + * - abandon_dedup_hashmap_min_rows + - integer + - 100,000 + - Number of input rows to receive before starting to check whether to abandon building a HashTable without + duplicates in HashBuild for left semi/anti join. + * - abandon_dedup_hashmap_min_pct + - integer + - 0 + - Abandons building a HashTable without duplicates in HashBuild for left semi/anti join if the percentage of + distinct keys in the HashTable exceeds this threshold. Zero means 'disable this optimization'. * - session_timezone - string - @@ -60,6 +70,11 @@ Generic Configuration - true - Whether to track CPU usage for stages of individual operators. Can be expensive when processing small batches, e.g. < 10K rows. + * - operator_batch_size_stats_enabled + - bool + - true + - If true, the driver will collect the operator's input/output batch size through vector flat size estimation, otherwise not. + We might turn this off in use cases which have very wide column width and batch size estimation has non-trivial cpu cost. * - hash_adaptivity_enabled - bool - true @@ -68,6 +83,11 @@ Generic Configuration - bool - true - If true, the conjunction expression can reorder inputs based on the time taken to calculate them. + * - parallel_join_build_rows_enabled + - bool + - false + - If true, the hash probe drivers can output build-side rows in parallel for full and right joins (only when spilling is not + - enabled by hash probe). If false, only the last prober is allowed to output build-side rows. * - max_local_exchange_buffer_size - integer - 32MB @@ -101,6 +121,16 @@ Generic Configuration client. Enforced approximately, not strictly. A larger size can increase network throughput for larger clusters and thus decrease query processing time at the expense of reducing the amount of memory available for other usage. + * - skip_request_data_size_with_single_source_enabled + - bool + - false + - If true, skip request data size if there is only single source. + This is used to optimize the Presto-on-Spark use case where each exchange client + has only one shuffle partition source. + * - local_merge_source_queue_size + - integer + - 2 + - Maximum number of vectors buffered in each local merge source before blocking to wait for consumers. * - max_page_partitioning_buffer_size - integer - 32MB @@ -117,6 +147,17 @@ Generic Configuration - integer - 1000 - The minimum number of table rows that can trigger the parallel hash join table build. + * - hash_probe_dynamic_filter_pushdown_enabled + - bool + - true + - Whether hash probe can generate any dynamic filter (including Bloom filter) and push down to upstream operators. + * - hash_probe_bloom_filter_pushdown_max_size + - integer + - 0 + - The maximum byte size of Bloom filter that can be generated from hash + probe. When set to 0, no Bloom filter will be generated. To achieve + optimal performance, this should not be too larger than the CPU cache + size on the host. * - debug.validate_output_from_operators - bool - false @@ -139,6 +180,12 @@ Generic Configuration - 0 - If it is not zero, specifies the time limit that a driver can continuously run on a thread before yield. If it is zero, then it no limit. + * - window_num_sub_partitions + - integer + - 1 + - Window operator can be configured to sub-divide window partitions on each thread of execution into groups of + sub partitions for sequential processing. This setting specifies how many sub-partitions to create for each + thread. Use 1 to disable sub partitioning. * - prefixsort_normalized_key_max_bytes - integer - 128 @@ -167,8 +214,30 @@ Generic Configuration - 0 - Specifies the max number of input batches to prefetch to do index lookup ahead. If it is zero, then process one input batch at a time. + * - index_lookup_join_split_output + - bool + - true + - If this is true, then the index join operator might split output for each input batch based + on the output batch size control. Otherwise, it tries to produce a single output for each input + batch. + * - unnest_split_output_batch + - bool + - true + - If this is true, then the unnest operator might split output for each input batch based on the + output batch size control. Otherwise, it produces a single output for each input batch. + * - max_num_splits_listened_to + - integer + - 0 + - Specifies The max number of input splits to listen to by SplitListener per table scan node per + worker. It's up to the SplitListener implementation to respect this config. + * - operator_track_expression_stats + - bool + - false + - If this is true, then operators that evaluate expressions will track stats for expressions that + are not special forms and return them as part of their operator stats. Tracking these stats can + be expensive (especially if operator stats are retrieved frequently) and this allows the user to + explicitly enable it. -.. _expression-evaluation-conf: Expression Evaluation Configuration ----------------------------------- @@ -189,6 +258,13 @@ Expression Evaluation Configuration - false - Whether to track CPU usage for individual expressions (supported by call and cast expressions). Can be expensive when processing small batches, e.g. < 10K rows. + * - expression.track_cpu_usage_for_functions + - string + - "" + - Comma-separated list of function names to selectively track CPU usage for. Only applicable when + ``expression.track_cpu_usage`` is set to false. Function names are case-insensitive and will be normalized + to lowercase. This allows fine-grained control over CPU tracking overhead when only specific functions need to + be monitored. * - legacy_cast - bool - false @@ -226,6 +302,11 @@ Expression Evaluation Configuration - integer - 10000 - Some lambda functions over arrays and maps are evaluated in batches of the underlying elements that comprise the arrays/maps. This is done to make the batch size managable as array vectors can have thousands of elements each and hit scaling limits as implementations typically expect BaseVectors to a couple of thousand entries. This lets up tune those batch sizes. Setting this to zero is setting unlimited batch size. + * - debug_bing_tile_children_max_zoom_shift + - integer + - 5 + - The UDF `bing_tile_children` generates the children of a Bing tile based on a specified target zoom level. The number of children produced is determined by the difference between the target zoom level and the zoom level of the input tile. This configuration limits the number of children by capping the maximum zoom level difference, with a default value set to 5. This cap is necessary to prevent excessively large array outputs, which can exceed the size limits of the elements vector in the Velox array vector. + Memory Management ----------------- @@ -254,6 +335,11 @@ Memory Management memory limit for partial aggregation is automatically doubled up to `max_extended_partial_aggregation_memory`. This adaptation is disabled by default, since the value of `max_extended_partial_aggregation_memory` equals the value of `max_partial_aggregation_memory`. Specify higher value for `max_extended_partial_aggregation_memory` to enable. + * - query_memory_reclaimer_priority + - integer + - 2147483647 + - Priority of the query in the memory pool reclaimer. Lower value means higher priority. This is used in + global arbitration victim selection. Spilling -------- @@ -277,6 +363,10 @@ Spilling - boolean - true - When `spill_enabled` is true, determines whether HashBuild and HashProbe operators can spill to disk under memory pressure. + * - local_merge_spill_enabled + - boolean + - false + - When `spill_enabled` is true, determines whether LocalMerge operators can spill to disk to cap memory usage. * - mixed_grouped_mode_hash_join_spill_enabled - boolean - false @@ -289,6 +379,12 @@ Spilling - boolean - true - When `spill_enabled` is true, determines whether Window operator can spill to disk under memory pressure. + * - window_spill_min_read_batch_rows + - integer + - 1000 + - When processing spilled window data, read batches of whole partitions having at least that many rows. Set to 1 to + read one whole partition at a time. Each driver processing the Window operator will process that much data at + once. * - row_number_spill_enabled - boolean - true @@ -382,6 +478,12 @@ Spilling - Specifies the compression algorithm type to compress the spilled data before write to disk to trade CPU for IO efficiency. The supported compression codecs are: zlib, snappy, lzo, zstd, lz4 and gzip. none means no compression. + * - spill_num_max_merge_files + - integer + - 0 + - The max number of files to merge at a time when merging sorted files into a single ordered stream. 0 means unlimited. + This is used to reduce memory pressure by capping the number of open files when merging spilled sorted files to + avoid using too much memory and causing OOM. Note that this is only applicable for ordered spill. * - spill_prefixsort_enabled - bool - false @@ -390,7 +492,7 @@ Spilling * - spiller_start_partition_bit - integer - 29 - - The start partition bit which is used with `spiller_partition_bits` together to calculate the spilling partition number. + - The start partition bit which is used with `spiller_num_partition_bits` together to calculate the spilling partition number. * - spiller_num_partition_bits - integer - 3 @@ -419,6 +521,21 @@ Aggregation - integer - 80 - Abandons partial aggregation if number of groups equals or exceeds this percentage of the number of input rows. + * - aggregation_compaction_bytes_threshold + - integer + - 0 + - Memory threshold in bytes for triggering string compaction during global + aggregation. When total string storage exceeds this limit with high unused + memory ratio, compaction is triggered to reclaim dead strings. Disabled by + default (0). Currently only applies to approx_most_frequent aggregate with + StringView type during global aggregation. + * - aggregation_compaction_unused_memory_ratio + - double + - 0.25 + - Ratio of unused (evicted) bytes to total bytes that triggers compaction. + The value is in the range of [0, 1). Currently only applies to approx_most_frequent + aggregate with StringView type during global aggregation. May be extended + to other aggregation types on-demand. * - streaming_aggregation_min_output_batch_rows - integer - 0 @@ -502,6 +619,31 @@ Table Writer - Minimum amount of data processed by all the logical table partitions to trigger skewed partition rebalancing by scale writer exchange. +Connector Config +---------------- +Connector config is initialized on velox runtime startup and is shared among queries as the default config across all connectors. +Each query can override the config by setting corresponding query session properties such as in Prestissimo. + +.. list-table:: + :widths: 20 20 10 10 70 + :header-rows: 1 + + * - user + - + - string + - "" + - The user of the query. Used for storage logging. + * - source + - + - string + - "" + - The source of the query. Used for storage access and logging. + * - schema + - + - string + - "" + - The schema of the query. Used for storage logging. + Hive Connector -------------- Hive Connector config is initialized on velox runtime startup and is shared among queries as the default config. @@ -521,6 +663,11 @@ Each query can override the config by setting corresponding query session proper - integer - 100 - Maximum number of (bucketed) partitions per a single table writer instance. + * - hive.max-bucket-count + - hive.max_bucket_count + - integer + - 100000 + - Maximum number of buckets that a table writer is allowed to write to. * - insert-existing-partitions-behavior - insert_existing_partitions_behavior - string @@ -629,6 +776,11 @@ Each query can override the config by setting corresponding query session proper - bool - true - Reads timestamp partition value as local time if true. Otherwise, reads as UTC. + * - hive.preserve-flat-maps-in-memory + - hive.preserve_flat_maps_in_memory + - bool + - false + - Whether to preserve flat maps in memory as FlatMapVectors instead of converting them to MapVectors. This is only applied during data reading inside the DWRF and Nimble readers, not during downstream processing like expression evaluation etc. ``ORC File Format Configuration`` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -720,6 +872,11 @@ Each query can override the config by setting corresponding query session proper - integer - 1024 - Batch size used when writing into Parquet through Arrow bridge. + * - hive.parquet.writer.created-by + - + - string + - parquet-cpp-velox version 0.0.0 + - Created-by value used when writing to Parquet. ``Amazon S3 Configuration`` ^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -852,6 +1009,12 @@ These semantics are similar to the `Apache Hadoop-Aws module .dfs.core.windows.net - string - - SharedKey + - - Specifies the authentication mechanism to use for Azure storage accounts. **Allowed values:** "SharedKey", "OAuth", "SAS". "SharedKey": Uses the storage account name and key for authentication. "OAuth": Utilizes OAuth tokens for secure authentication. "SAS": Employs Shared Access Signatures for granular access control. + To create Azure clients with the configured authentication type, the caller must + register the corresponding Azure client provider from the configuration by calling + `registerAzureClientProvider`. * - fs.azure.account.key..dfs.core.windows.net - string - @@ -899,6 +1065,12 @@ These semantics are similar to the `Apache Hadoop-Aws module /oauth2/token`. + * - fs.azure.sas.token.renew.period.for.streams + - string + - 120 + - Specifies the period in seconds to re-use SAS tokens until the expiry is within this number of seconds. + This configuration is used together with `registerSasTokenProvider` for dynamic SAS token renewal. + When a SAS token is close to expiry, it will be renewed by getting a new token from the provider. CLP Connector ----------------------------- @@ -973,6 +1145,14 @@ Spark-specific Configuration - Type - Default Value - Description + * - spark.ansi_enabled + - bool + - false + - If true, Spark function's behavior is ANSI-compliant, e.g. throws runtime exception instead + of returning null on invalid inputs. It affects only functions explicitly marked as "ANSI compliant". + Note: This feature is still under development to achieve full ANSI compliance. Users can + refer to the Spark function documentation to verify the current support status of a specific + function. * - spark.legacy_size_of_null - bool - true @@ -1006,6 +1186,10 @@ Spark-specific Configuration - false - If true, Spark statistical aggregation functions including skewness, kurtosis, stddev, stddev_samp, variance, var_samp, covar_samp and corr will return NaN instead of NULL when dividing by zero during expression evaluation. + * - spark.json_ignore_null_fields + - bool + - true + - If true, ignore null fields when generating JSON string. If false, null fields are included with a null value. Tracing -------- @@ -1025,10 +1209,10 @@ Tracing - string - - The root directory to store the tracing data and metadata for a query. - * - query_trace_node_ids + * - query_trace_node_id - string - - - A comma-separated list of plan node ids whose input data will be trace. If it is empty, then we only trace the + - The plan node id whose input data will be trace. If it is empty, then we only trace the query metadata which includes the query plan and configs etc. * - query_trace_task_reg_exp - string @@ -1038,3 +1222,60 @@ Tracing - integer - 0 - The max trace bytes limit. Tracing is disabled if zero. + * - query_trace_dry_run + - boolean + - false + - If true, we only collect the input trace for a given operator but without the actual + execution. This is used for crash debugging. + +Cudf-specific Configuration (Experimental) +------------------------------------------ +These configurations are available when `compiled with cuDF `_. +Note: These configurations are experimental and subject to change. + +.. list-table:: + :widths: 30 10 10 70 + :header-rows: 1 + + * - Property Name + - Type + - Default Value + - Description + * - cudf.enabled + - bool + - true + - If true, enable cuDF. By default, it is enabled if compiled with cuDF. + * - cudf.memory_resource + - string + - async + - The memory resource to use for cuDF. Possible values are (cuda, pool, async, arena, managed, managed_pool, prefetch_managed, prefetch_managed_pool). + The prefetch options enable automatic prefetching for better GPU memory performance: prefetch_managed uses CUDA unified memory with prefetching, + prefetch_managed_pool uses a pooled version of CUDA unified memory with prefetching. + * - cudf.memory_percent + - integer + - 50 + - The initial percent of GPU memory to allocate for pool or arena memory resources. + * - cudf.function_name_prefix + - string + - "" + - The prefix to use for the function names in cuDF. + * - cudf.ast_expression_enabled + - bool + - true + - If true, enable using cuDF AST-based expression evaluation when supported. + * - cudf.ast_expression_priority + - integer + - 100 + - Priority of cuDF AST expressions. Higher value wins when multiple cuDF execution options are available for the same Velox expression. Standalone cuDF functions have priority 50. If enabled, with a default priority of 100, AST will be chosen as replacement for cudf execution. + * - cudf.allow_cpu_fallback + - bool + - true + - If true, allow falling back to Velox CPU execution when an operation is not supported in cuDF execution. If false, an error will be thrown if an operation is not supported in cuDF execution. + * - cudf.debug_enabled + - bool + - false + - If true, enable debug printing. + * - cudf.log_fallback + - bool + - true + - If true, log a reason for falling back to Velox CPU execution, when an operation is not supported in cuDF execution. diff --git a/velox/docs/develop/aggregations.rst b/velox/docs/develop/aggregations.rst index 934252f7c14a..86c7f31293ff 100644 --- a/velox/docs/develop/aggregations.rst +++ b/velox/docs/develop/aggregations.rst @@ -112,6 +112,27 @@ encounters a row with a different values in pre-grouped keys. This helps reduce the total amount of memory used and allows to unblock downstream operators faster. +noGroupsSpanBatches Optimization +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +AggregationNode supports an optional ``noGroupsSpanBatches`` flag that can be set +to true for streaming aggregations. When enabled, this flag indicates that no +sort group spans across input batches - each input batch contains complete data +for its groups, and no group will appear in any subsequent input batch. + +This optimization allows the StreamingAggregation operator to immediately produce +aggregation results for all groups in each input batch without waiting to see if +more data for those groups will arrive in subsequent batches. This can +significantly improve output latency and reduce memory usage since the operator +doesn't need to hold onto partial aggregation state across batches. + +The ``noGroupsSpanBatches`` flag can only be set when the aggregation is +pre-grouped (streaming). Setting it on a non-streaming aggregation will result +in an error. + +This optimization is typically set automatically by query optimizers when they can +guarantee that the input data meets the required properties. + Push-Down into Table Scan ------------------------- diff --git a/velox/docs/develop/connectors.rst b/velox/docs/develop/connectors.rst index 22886134e3b2..b4c4275cdb85 100644 --- a/velox/docs/develop/connectors.rst +++ b/velox/docs/develop/connectors.rst @@ -29,13 +29,13 @@ Connector Interface * - Connector Factory - Enables creating instances of a particular connector. -Velox provides Hive, CLP, and TPC-H Connectors out of the box. -Let's examine the implementation details of both the Hive and CLP Connectors as examples. +Velox provides Hive and TPC-H Connectors out of the box. +Let's see how the above connector interfaces are implemented in the Hive Connector in detail below. Hive Connector -------------- The Hive Connector is used to read and write data files (Parquet, DWRF) residing on -an external storage (S3, HDFS, GCS, Linux FS). +an external storage. Supported external "Storage Adapters" are listed below. HiveConnectorSplit ~~~~~~~~~~~~~~~~~~ @@ -80,7 +80,7 @@ sources and each require a different configuration. Storage Adapters ~~~~~~~~~~~~~~~~ Hive Connector allows reading and writing files from a variety of distributed storage systems. -The supported storage API are S3, HDFS, GCS, Linux FS. +The supported storage API are S3 (Amazon S3 Compatible Storage), HDFS (Hadoop Distributed File System), GCS (Google Cloud Storage), ABFS (Azure Blob File Storage), Linux File System. If file is not found when reading, `openFileForRead` API throws `VeloxRuntimeError` with `error_code::kFileNotFound`. This behavior is necessary to support the `ignore_missing_files` configuration property. @@ -89,7 +89,7 @@ S3 is supported using the `AWS SDK for C++ ` S3 supported schemes are `s3://` (Amazon S3, Minio), `s3a://` (Hadoop 3.x), `s3n://` (Deprecated in Hadoop 3.x), `oss://` (Alibaba cloud storage), and `cos://`, `cosn://` (Tencent cloud storage). -HDFS is supported using the +HDFS is supported using the `Apache Hadoop libhdfs.so `_ and `Apache Hawk libhdfs3 `_ library. HDFS supported schemes are `hdfs://`. @@ -97,8 +97,8 @@ GCS is supported using the `Google Cloud Platform C++ Client Libraries `_. GCS supported schemes are `gs://`. -ABS (Azure Blob Storage) is supported using the -`Azure SDK for C++ `_ library. ABS supported schemes are `abfs(s)://`. +ABFS is supported using the +`Azure SDK for C++ `_ library. ABFS supported schemes are `abfs(s)://`. S3 Storage adapter using a proxy ******************************** @@ -122,57 +122,7 @@ This is the behavior when the proxy settings are enabled: 5. Use . or \*. to indicate domain suffix matching, e.g. `.foobar.com` will match `test.foobar.com` or `foo.foobar.com`. -CLP Connector -------------- -The CLP Connector is used to read CLP splits stored on a local file system or S3. It implements similar interfaces as -the Hive Connector except for the ``DataSink`` interface. Here we only describe the ``DataSource`` interface and the -``ConnectorSplit`` interface implementation since ``Connector`` and ``ConnectorFactory`` are similar to the Hive -Connector. We also describe ``ClpS3AuthProviderBase``, an interface that allows users to customize S3 authentication. - -ClpConnectorSplit -~~~~~~~~~~~~~~~~~ -``ClpConnectorSplit`` describes a data chunk using ``path``. This path may be the absolute file path to the split file -if it is stored on a local file system, or the complete (or partial) URL of the split if it is stored on S3. In the -latter case, when only a partial URL is provided, ``ClpS3AuthProviderBase`` provides a hook in ``ClpDataSource`` to -assist in constructing the full URL. Refer to :ref:`ClpS3AuthProviderBase` for details. -``ClpConnectorSplit`` also includes a ``type`` property that specifies whether the split is an archive or an IR -(Internal Representation) stream. - -BaseClpCursor -~~~~~~~~~~~~~ -``BaseClpCursor`` is responsible for preparing pushdown operations, loading splits, filtering data, and returning -results. Each split type (archive or IR stream) has its own corresponding subclass (``ClpArchiveCursor`` and -``ClpIrCursor``). +HDFS Storage adapter +******************** -ClpDataSource -~~~~~~~~~~~~~ -``ClpDataSource`` implements the ``addSplit`` API that consumes a ``ClpConnectorSplit`` and ``next`` API that -processes the split and returns a batch of rows. - -During initialization, it records the KQL query and split source (S3 or local). It then iterates through -each output column, accessing its handle to get its type and original name. For row types, it recursively -traverses the nested structure to process each field; for non-row types, it directly maps the Velox column -type to a CLP column type. - -When a split is added, a ``BaseClpCursor`` instance is created with the split path and input source (which -may be either a ``ClpArchiveCursor`` or a ``ClpIrCursor``). The query is parsed and simplified into an AST. -On ``next``, the cursor finds matching row indices and, if any exist, ``ClpDataSource`` recursively creates -a row vector composed of lazy vectors, which use CLP column readers to decode and load data as needed during -execution. - -.. _ClpS3AuthProviderBase: - -ClpS3AuthProviderBase -~~~~~~~~~~~~~~~~~~~~~ -``ClpS3AuthProviderBase`` defines an interface for obtaining S3 authentication information and constructing the -complete split URL from the current URL stored in ``ClpConnectorSplit``. It provides the following two functions: - -1. ``exportAuthEnvironmentVariables()`` – Parses user-defined configuration options and exports the three environment - variables required by CLP-s to the system. This ensures that, at runtime, CLP-s can execute S3-related operations - correctly. This function is invoked immediately after the configuration is parsed. -2. ``constructS3Url()`` – Builds the full S3 URL for a split. It takes the ``path`` property of ``ClpConnectorSplit`` - as its argument, allowing customization in URL construction. For example, the ``path`` could be ``"prefix/split"``, - which must be prefixed with ``"https://bucket.s3.region.amazonaws.com/"`` to form the complete URL. - -Additionally, this interface maintains a reference to ``config_``, enabling users to define custom configuration -options for passing any required information. +Velox currently supports HDFS by dynamically loading libhdfs.so from the environment's ${HADOOP_HOME}/native/lib directory. If you prefer to use libhdfs3 instead, you can create a symbolic link from libhdfs.so to libhdfs3.so within the same directory. diff --git a/velox/docs/develop/debugging/tracing.rst b/velox/docs/develop/debugging/tracing.rst index 6966e97ca0b2..aad6e836f147 100644 --- a/velox/docs/develop/debugging/tracing.rst +++ b/velox/docs/develop/debugging/tracing.rst @@ -101,7 +101,7 @@ to the following simplified, pretty-printed JSON string (with some content remov "name":"HashJoinNode" }, "connectorProperties":{...}, - "queryConfig":{"query_trace_node_ids":"2", ...} + "queryConfig":{"query_trace_node_id":"2", ...} } **OperatorTraceInputWriter** @@ -301,7 +301,7 @@ It would show something as the follows: ++++++Query configs++++++ query_trace_task_reg_exp: .* - query_trace_node_ids: 2 + query_trace_node_id: 2 query_trace_max_bytes: 107374182400 query_trace_dir: /tmp/velox_test_aJqeFd/basic/traceRoot/ query_trace_enabled: 1 diff --git a/velox/docs/develop/joins.rst b/velox/docs/develop/joins.rst index 0a1bf3d6a5c5..07b35462f085 100644 --- a/velox/docs/develop/joins.rst +++ b/velox/docs/develop/joins.rst @@ -198,6 +198,11 @@ the join is executed using broadcast or partitioned strategy has no effect on the join execution itself. The only difference is that broadcast execution allows for dynamic filter pushdown while partitioned execution does not. +HashJoinNode supports a ``useHashTableCache`` flag (used only by Presto-on-Spark) +that enables caching of the hash table built for broadcast joins. When enabled, +the first task to build the hash table stores it in a global cache, and subsequent +tasks from same query reuse the cached table instead of rebuilding it. + PartitionedOutput operator and OutputBufferManager support broadcasting the results of the plan evaluation. This functionality is enabled by setting boolean flag "broadcast" in the PartitionedOutputNode to true. diff --git a/velox/docs/develop/memory.rst b/velox/docs/develop/memory.rst index 06ff5fe667ed..df4bc5a96a12 100644 --- a/velox/docs/develop/memory.rst +++ b/velox/docs/develop/memory.rst @@ -122,14 +122,14 @@ Memory Manager :alt: Memory Manager The memory manager is created on server startup with the provided -*MemoryManagerOption*. It creates a memory allocator instance to manage the +*MemoryManager::Options*. It creates a memory allocator instance to manage the physical memory allocations for both query memory allocated through memory pool and cache memory allocated through the file cache. It ensures the total allocated memory is within the system memory limit (specified by -*MemoryManagerOptions::allocatorCapacity*). The memory manager also creates a +*MemoryManager::Options::allocatorCapacity*). The memory manager also creates a memory arbitrator instance to arbitrate the memory capacity among running queries. It ensures the total allocated query memory capacity is within the -query memory limit (specified by *MemoryManagerOptions::arbitratorCapacity*). The +query memory limit (specified by *MemoryManager::Options::arbitratorCapacity*). The memory arbitrator also prevents each individual query running out of its per-query memory limit (specified by *QueryConfig::query_max_memory_per_node*) by reclaiming overused memory through `disk spilling `_ (refer to `memory arbitrator @@ -171,8 +171,8 @@ allocator guarantees the actual allocated memory are within the system memory limit no matter if it is for system operation or for user query execution. In practice, we shall reserve some space from the memory allocator to compensate for such system memory usage. We can do that by configuring the query -memory limit (*MemoryManagerOptions::arbitratorCapacity*) to be smaller than the system memory -limit (*MemoryManagerOptions::allocatorCapacity*) (refer to `OOM prevention section <#server-oom-prevention>`_ +memory limit (*MemoryManager::Options::arbitratorCapacity*) to be smaller than the system memory +limit (*MemoryManager::Options::allocatorCapacity*) (refer to `OOM prevention section <#server-oom-prevention>`_ for detail). Memory System Setup @@ -186,7 +186,7 @@ Here is the code block from Prestissimo that initializes the Velox memory system void PrestoServer::initializeVeloxMemory() { auto* systemConfig = SystemConfig::instance(); const uint64_t memoryGb = systemConfig->systemMemoryGb(); - MemoryManagerOptions options; + MemoryManager::Options options; options.allocatorCapacity = memoryGb << 30; options.useMmapAllocator = systemConfig->useMmapAllocator(); if (!systemConfig->memoryArbitratorKind().empty()) { @@ -217,7 +217,7 @@ Here is the code block from Prestissimo that initializes the Velox memory system * L10: set the memory arbitrator capacity (query memory limit) from the Prestissimo system config * L13: creates the process-wide memory manager which creates memory - allocator and arbitrator inside based on MemoryManagerOptions initialized from previous steps + allocator and arbitrator inside based on MemoryManager::Options initialized from previous steps * L15-19: creates the file cache if it is enabled in Prestissimo system config @@ -334,7 +334,7 @@ root memory pool to reduce the cpu cost, the leaf memory pool does quantized memory reservation. It rounds up the actual reservation bytes to the next large quantized reservation value (MemoryPool::quantizedSize): -- round up to next MB if size < 16MB +- round up to next 1MB if size < 16MB - round up to next 4MB if size < 64MB - round up to next 8MB if size >= 64MB @@ -539,10 +539,10 @@ between queries by adjusting their memory pool’s capacities accordingly (see The *MemoryArbitrator* is defined to support different implementations for different query systems. As for now, we implement *SharedArbitrator* for both -Prestissimo and Prestissimo-on-Spark. `Gulten `_ implements its own memory +Prestissimo and Prestissimo-on-Spark. `Gluten `_ implements its own memory arbitrator to integrate with the `Spark memory system `_. *SharedArbitrator* ensures the total allocated memory capacity is within the query memory limit -(*MemoryManagerOptions::arbitratorCapacity*), and also ensures each individual +(*MemoryManager::Options::arbitratorCapacity*), and also ensures each individual query’s capacity is within the per-query memory limit (*MemoryPool::maxCapacity_*). When a query needs to grow its capacity, *SharedArbitrator* either reclaims the used memory from the query itself if it has exceeded its max memory capacity, @@ -700,7 +700,7 @@ control on RSS. We haven't yet confirmed whether *MmapAllocator* works better than *MallocAllocator*, but we are able to run a sizable Prestissimo workload using it. We will compare that workload using two allocators to determine which one is better in the future. Users can choose the allocator for their -application by setting *MemoryManagerOptions::useMmapAllocator* (see +application by setting *MemoryManager::Options::useMmapAllocator* (see `memory system setup section <#memory-system-setup>`_ for example). Non-Contiguous Allocation diff --git a/velox/docs/develop/operators.rst b/velox/docs/develop/operators.rst index 9af21bcdacd1..108213e1b9c6 100644 --- a/velox/docs/develop/operators.rst +++ b/velox/docs/develop/operators.rst @@ -547,6 +547,8 @@ and emitting results. - Join type: inner, left, right, full, left semi filter, left semi project, right semi filter, right semi project, anti. You can read about different join types in this `blog post `_. * - nullAware - Applies to anti and semi project joins only. Indicates whether the join semantic is IN (nullAware = true) or EXISTS (nullAware = false). + * - useHashTableCache + - Optional. Used only by Presto-on-Spark. When true, enables caching of the hash table built for broadcast joins so that subsequent tasks can reuse it. * - leftKeys - Columns from the left hand side input that are part of the equality condition. At least one must be specified. * - rightKeys @@ -652,8 +654,12 @@ The unnest operation expands arrays and maps into separate columns. Arrays are expanded into a single column, and maps are expanded into two columns (key, value). Can be used to expand multiple columns. In this case produces as many rows as the highest cardinality array or map (the other columns are padded -with nulls). Optionally can produce an ordinality column that specifies the row -number starting with 1. +with nulls). Optionally, it can include an ordinality column to indicate the row +number starting from 1, and an emptyUnnestValue column to indicate whether an +output row has empty unnest value or not. If the ordinality column is specified +along with the emptyUnnestValue column, the ordinality for the output row with +empty unnest values is set to zero. If the emptyUnnestValue column is not specified, +output rows with empty unnest values are not produced. .. list-table:: :widths: 10 30 @@ -670,6 +676,8 @@ number starting with 1. - Names to use for expanded columns. One name per array column. Two names per map column. * - ordinalityName - Optional name for the ordinality column. + * - emptyUnnestValueName + - Optional name for the emptyUnnestValue column. .. _TableWriteNode: @@ -941,7 +949,7 @@ results available before seeing all input. TopNRowNumberNode ~~~~~~~~~~~~~~~~~ -An optimized version of a WindowNode with a single row_number function and a +An optimized version of a WindowNode with a single row_number, rank or dense_rank function and a limit over sorted partitions. Partitions the input using specified partitioning keys and maintains up to @@ -949,11 +957,11 @@ a 'limit' number of top rows for each partition. After receiving all input, assigns row numbers within each partition starting from 1. This operator accumulates state: a hash table mapping partition keys to a list -of top 'limit' rows within that partition. Returning the row numbers as +of top 'limit' rows within that partition. Returning the row number or rank as a column in the output is optional. This operator supports spilling as well. This operator is logically equivalent to a WindowNode followed by -FilterNode(row_number <= limit), but it uses less memory and CPU. +FilterNode(rank/row_number <= limit), but it uses less memory and CPU. .. list-table:: :widths: 10 30 diff --git a/velox/docs/develop/scalar-functions.rst b/velox/docs/develop/scalar-functions.rst index 19688defb94a..663aa6bd0a79 100644 --- a/velox/docs/develop/scalar-functions.rst +++ b/velox/docs/develop/scalar-functions.rst @@ -722,7 +722,7 @@ cardinality function for maps: BaseVector::ensureWritable(rows, BIGINT(), context.pool(), result); BufferPtr resultValues = - result->as>()->mutableValues(rows.size()); + result->as>()->mutableValues(); auto rawResult = resultValues->asMutable(); auto mapVector = args[0]->as(); diff --git a/velox/docs/develop/testing.rst b/velox/docs/develop/testing.rst index 93f92b50a04e..e9dabed5229a 100644 --- a/velox/docs/develop/testing.rst +++ b/velox/docs/develop/testing.rst @@ -11,5 +11,6 @@ Testing Tools testing/join-fuzzer testing/memory-arbitration-fuzzer testing/row-number-fuzzer + testing/spatial-join-fuzzer testing/writer-fuzzer testing/spark-query-runner.rst diff --git a/velox/docs/develop/testing/memory-arbitration-fuzzer.rst b/velox/docs/develop/testing/memory-arbitration-fuzzer.rst index 2895138a6faf..005b80d4ddb1 100644 --- a/velox/docs/develop/testing/memory-arbitration-fuzzer.rst +++ b/velox/docs/develop/testing/memory-arbitration-fuzzer.rst @@ -9,8 +9,8 @@ It works as follows: 1. Data Generation: It starts by generating a random set of input data, also known as a vector. This data can have a variety of encodings and data layouts to ensure thorough testing. -2. Plan Generation: Generate multiple plans with different query shapes. Currently, it supports HashJoin and - HashAggregation plans. +2. Plan Generation: Generate multiple plans with different query shapes. Currently, it supports HashJoin, + HashAggregation, RowNumber, TopNRowNumber, and OrderBy plans. 3. Query Execution: Create multiple threads, each thread randomly picks a plan with spill enabled or not, and repeatedly running this process until ${iteration_duration_sec} seconds. The query thread expects query to succeed or fail with query OOM or abort errors, otherwise it throws. @@ -19,11 +19,11 @@ It works as follows: How to run ---------- -Use velox_memory_arbitration_fuzzer_test binary to run this fuzzer: +Use velox_memory_arbitration_fuzzer binary to run this fuzzer: :: - velox/exec/tests/velox_memory_arbitration_fuzzer_test --seed 123 --duration_sec 60 + velox/exec/tests/velox_memory_arbitration_fuzzer --seed 123 --duration_sec 60 By default, the fuzzer will go through 10 iterations. Use --steps or --duration-sec flag to run fuzzer for longer. Use --seed to diff --git a/velox/docs/develop/testing/row-number-fuzzer.rst b/velox/docs/develop/testing/row-number-fuzzer.rst index f90b33f76210..f90edce12ed5 100644 --- a/velox/docs/develop/testing/row-number-fuzzer.rst +++ b/velox/docs/develop/testing/row-number-fuzzer.rst @@ -1,14 +1,16 @@ -================ -RowNumber Fuzzer -================ +================================== +RowNumber and TopNRowNumber Fuzzer +================================== -The RowNumberFuzzer is a testing tool that automatically generate equivalent query plans and then executes these plans -to validate the consistency of the results. It works as follows: +The RowNumberFuzzer and TopNRowNumberFuzzer are testing tools that automatically generate equivalent query plans that +use the RowNumber and TopNRowNumber Velox plan nodes, and then execute these plans to validate the consistency of +the results. They works as follows: -1. Data Generation: It starts by generating a random set of input data, also known as a vector. This data can +1. Data Generation: Generate a random set of input data, also known as a vector. This data can have a variety of encodings and data layouts to ensure thorough testing. -2. Plan Generation: Generate two equivalent query plans, one is row-number over ValuesNode as the base plan. - and the other is over TableScanNode as the alter plan. +2. Plan Generation: Generate two equivalent query plans: one is RowNumber over ValuesNode as + the base plan and the other is over TableScanNode as the alternative plan. The TopNRowNumberFuzzer generates similar + plans with TopNRowNumber node instead. 3. Query Execution: Executes those equivalent query plans using the generated data and asserts that the results are consistent across different plans. i. Execute the base plan, compare the result with the reference (DuckDB or Presto) and use it as the expected result. @@ -19,12 +21,16 @@ to validate the consistency of the results. It works as follows: How to run ---------- -Use velox_row_number_fuzzer binary to run rowNumber fuzzer: - +Use velox_row_number_fuzzer to run RowNumberFuzzer :: velox/exec/fuzzer/velox_row_number_fuzzer --seed 123 --duration_sec 60 +Similarly, use velox_topn_row_number_fuzzer to run TopNRowNumberFuzzer +:: + + velox/exec/fuzzer/velox_topn_row_number_fuzzer --seed 123 --duration_sec 60 + By default, the fuzzer will go through 10 iterations. Use --steps or --duration-sec flag to run fuzzer for longer. Use --seed to reproduce fuzzer failures. diff --git a/velox/docs/develop/testing/spatial-join-fuzzer.rst b/velox/docs/develop/testing/spatial-join-fuzzer.rst new file mode 100644 index 000000000000..519ab3590932 --- /dev/null +++ b/velox/docs/develop/testing/spatial-join-fuzzer.rst @@ -0,0 +1,118 @@ +==================== +Spatial Join Fuzzer +==================== + +Overview +======== + +The Spatial Join Fuzzer tests the correctness of the SpatialJoin operator by generating random geometry data and spatial join plans. It verifies that SpatialJoin produces the same results as NestedLoopJoin for equivalent queries. + + +Supported Features +================== + +Join Types +---------- + +The fuzzer tests the two join types supported by SpatialJoin (as defined in ``SpatialJoinNode::isSupported()``): + +* **INNER** - Only matching rows from both sides +* **LEFT** - All rows from left side, matched rows from right side + +Spatial Predicates +------------------ + +The fuzzer tests these spatial predicates: + +* ``ST_Intersects(geometry1, geometry2)`` - Tests if geometries intersect +* ``ST_Contains(geometry1, geometry2)`` - Tests if one geometry contains another +* ``ST_Within(geometry1, geometry2)`` - Tests if one geometry is within another +* ``ST_Distance(geometry1, geometry2) < threshold`` - Tests distance with threshold + +Geometry Types +-------------- + +The fuzzer generates Well-Known Text (WKT) strings for three geometry types: + +* **POINT** - Single coordinate point (e.g., ``POINT (10.5 20.3)``) +* **POLYGON** - Closed shape with vertices +* **LINESTRING** - Line segment between two points + +Distribution Patterns +--------------------- + +Geometries are generated using three distribution patterns: + +* **Uniform** - Geometries uniformly distributed in space (0-1000 range) +* **Clustered** - Geometries grouped in 5 specific regions to test overlap scenarios +* **Sparse** - Geometries widely spread (0-2000 range) with low overlap probability + +Implementation Details +====================== + + +Geometry Generation +------------------- + +Geometries are generated using ``AbstractInputGenerator`` subclasses: + +* ``PointInputGenerator`` - Generates POINT WKT strings +* ``PolygonInputGenerator`` - Generates POLYGON WKT strings +* ``LineStringInputGenerator`` - Generates LINESTRING WKT strings + +Each generator implements the ``generate(vector_size_t index)`` method to produce geometry strings based on the distribution pattern. + +**Uniform Distribution**:: + + x = random(0, 1000) + y = random(0, 1000) + POINT (x y) + +**Clustered Distribution**:: + + cluster = row % 5 // 5 clusters + centerX = cluster * 200 + 100 + centerY = cluster * 200 + 100 + x = centerX + random(-50, 50) + y = centerY + random(-50, 50) + POINT (x y) + +**Sparse Distribution**:: + + x = random(0, 2000) // Larger Range + y = random(0, 2000) + POINT (x y) + +Data Matching Strategy +---------------------- + +To ensure some matches occur during joins: + +* Build side copies ~30% of geometries from probe side +* 10% chance of empty build side to test edge cases + +Verification +------------ + +The fuzzer compares results from two equivalent plans: + +1. **SpatialJoin plan** - Using the specialized SpatialJoin operator +2. **NestedLoopJoin plan** - Using NestedLoopJoin with the same spatial predicate as a filter + +Results must match exactly, validating that SpatialJoin implements spatial predicates correctly. + +Key Differences from JoinFuzzer +================================ + +Join Conditions +--------------- + +Unlike regular joins with simple equality predicates:: + + // Regular join + probe.id = build.id + + // Spatial join + ST_Intersects(probe_geom, build_geom) + +Spatial joins use **function call expressions** as join conditions rather than simple column references. diff --git a/velox/docs/develop/types.rst b/velox/docs/develop/types.rst index 0253c9266c3a..9a420bf7a607 100644 --- a/velox/docs/develop/types.rst +++ b/velox/docs/develop/types.rst @@ -115,6 +115,7 @@ DATE INTEGER DECIMAL BIGINT if precision <= 18, HUGEINT if precision >= 19 INTERVAL DAY TO SECOND BIGINT INTERVAL YEAR TO MONTH INTEGER +TIME BIGINT ====================== ====================================================== DECIMAL type carries additional `precision`, @@ -130,6 +131,9 @@ upto 38 precision, with a range of :math:`[-10^{38} + 1, +10^{38} - 1]`. All the three values, precision, scale, unscaled value are required to represent a decimal value. +TIME type represents time in milliseconds from midnight UTC. Thus min/max value can range from UTC-14:00 at 00:00:00 to UTC+14:00 at 23:59:59.999 modulo 24 hours. +TIME type is backed by BIGINT physical type. + Custom Types ~~~~~~~~~~~~ Most custom types can be represented as logical types and can be built by extending @@ -173,7 +177,14 @@ TIMESTAMP WITH TIME ZONE BIGINT UUID HUGEINT IPADDRESS HUGEINT IPPREFIX ROW(HUGEINT,TINYINT) +BINGTILE BIGINT GEOMETRY VARBINARY +SPHERICALGEOGRAPHY VARBINARY +TDIGEST VARBINARY +QDIGEST VARBINARY +BIGINT_ENUM BIGINT +VARCHAR_ENUM VARCHAR +TIME WITH TIME ZONE BIGINT ======================== ===================== TIMESTAMP WITH TIME ZONE represents a time point in milliseconds precision @@ -209,6 +220,53 @@ As a result the IPPREFIX object stores *FFFF:FFFF::* and the length 32 for both IPPREFIX 'FFFF:FFFF:FFFF:FFFF:FFFF:FFFF:FFFF:FFFF/32' -- IPPREFIX 'FFFF:FFFF:0000:0000:0000:0000:0000:0000/32' IPPREFIX 'FFFF:FFFF:4455:6677:8899:AABB:CCDD:EEFF/32' -- IPPREFIX 'FFFF:FFFF:0000:0000:0000:0000:0000:0000/32' +TDIGEST(DOUBLE) is a data sketch for estimating rank-based metrics. +T-digests may be merged without losing precision, and for storage and retrieval +they may be cast to/from VARBINARY. The T-digest accepts a parameter of type +DOUBLE which represents the set of numbers to be ingested by the T-digest. + +QDIGEST(BIGINT), QDIGEST(REAL), QDIGEST(DOUBLE) are data sketches for +estimating rank-based metrics. A quantile digest captures the approximate distribution of +data for a given input set, and can be queried to retrieve approximate quantile values from the +distribution. They may be merged without losing precision, and for storage and retrieval they may +be cast to/from VARBINARY. The parameter type (BIGINT, REAL, or DOUBLE) represents +the set of numbers that may be ingested by the quantile digest. + +BIGINT_ENUM(LongEnumParameter) type represents an enumerated value where the physical type is BIGINT. +It takes one LongEnumParameter as parameter, which consists of a string name and a mapping of +string keys to BIGINT values. +There is a static cache which stores instances of different BIGINT_ENUM types. This is to treat each +different enum type as a singleton. The LongEnumParameter is used as the key to retrieve the cached instance, +and a new instance is only created if it has not been created with the given LongEnumParameter. +Casting is permitted from any integer type to an enum type. Casting is only permitted from an enum type +to a BIGINT type. Casting between different enum types is not permitted. +Comparison operations are only allowed between values of the same enum type. + +VARCHAR_ENUM(VarcharEnumParameter) type represents an enumerated value where the physical type is VARCHAR. +It takes one VarcharEnumParameter as parameter, which consists of a string name and a mapping of +string keys to VARCHAR values. +Similar to BIGINT_ENUM, there is a static cache which stores instances of different VARCHAR_ENUM types, with the +VarcharEnumParameter as the key. +Casting is only permitted to and from VARCHAR type, and is case-sensitive. Casting between different enum types is not permitted. +Comparison operations are only allowed between values of the same enum type. + +TIME WITH TIME ZONE represents time from midnight in milliseconds precision at a particular timezone. +Its physical type is BIGINT. The high 52 bits of bigint store signed integer for milliseconds in UTC. +The lower 12 bits store the time zone offsets minutes. This allows the time to be converted at any point of +time without ambiguity of daylight savings time. Time zone offsets range from -14:00 hours to +14:00 hours. + +BINGTILE represents a `Bing tile `_. +It is a quadtree in the Web Mercator projection, where each tile is 256x256 pixels. Its physical type is BIGINT. + +GEOMETRY represents a geometry as defined in `Simple Feature Access `_. +Subtypes include Point, MultiPoint, LineString, MultiLineString, Polygon, MultiPolygon, and GeometryCollection. They +are often stored as `Well-Known Text `_ or +`Well-Known Binary `_. + +SPHERICALGEOGRAPHY represents a geometry on a spherical model of the Earth. It is internally represented the same +way as GEOMETRY, but only certain functions are supported. Moreover, these functions will return values in meters +as opposed to the units of the coordinate space. + Spark Types ~~~~~~~~~~~~ The `data types `_ in Spark have some semantic differences compared to those in diff --git a/velox/docs/ext/delta.py b/velox/docs/ext/delta.py new file mode 100644 index 000000000000..40a76426a2e0 --- /dev/null +++ b/velox/docs/ext/delta.py @@ -0,0 +1,773 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Generating delta function link :delta:func:``""" + +from __future__ import annotations + +from function import ( + function_sig_re, + pairindextypes, + parse_arglist, + pseudo_parse_arglist, + parse_annotation, + ObjectEntry, + ModuleEntry, +) +from typing import Any, Iterable, Iterator, Tuple, cast + +from docutils import nodes +from docutils.nodes import Element, Node +from docutils.parsers.rst import directives +from sphinx import addnodes +from sphinx.addnodes import desc_signature, pending_xref +from sphinx.application import Sphinx +from sphinx.builders import Builder +from sphinx.directives import ObjectDescription +from sphinx.domains import Domain, Index, IndexEntry, ObjType +from sphinx.environment import BuildEnvironment +from sphinx.locale import _, __ +from sphinx.roles import XRefRole +from sphinx.util import logging +from sphinx.util.docfields import Field +from sphinx.util.nodes import ( + find_pending_xref_condition, + make_id, + make_refnode, +) +from sphinx.util.typing import OptionSpec + +logger = logging.getLogger(__name__) + +function_module = "delta" + + +class DeltaObject(ObjectDescription[Tuple[str, str]]): + """ + Description of a general Delta object. + + :cvar allow_nesting: Class is an object that allows for nested namespaces + :vartype allow_nesting: bool + """ + + option_spec: OptionSpec = { + "noindex": directives.flag, + "noindexentry": directives.flag, + "nocontentsentry": directives.flag, + "module": directives.unchanged, + "canonical": directives.unchanged, + "annotation": directives.unchanged, + } + + doc_field_types = [ + Field( + "returnvalue", + label=_("Returns"), + has_arg=False, + names=("returns", "return"), + ), + ] + + allow_nesting = False + + def get_signature_prefix(self, sig: str) -> list[nodes.Node]: + """May return a prefix to put before the object name in the + signature. + """ + return [] + + def needs_arglist(self) -> bool: + """May return true if an empty argument list is to be generated even if + the document contains none. + """ + return False + + def handle_signature(self, sig: str, signode: desc_signature) -> tuple[str, str]: + """Transform a Delta signature into RST nodes. + Return (fully qualified name of the thing, classname if any). + If inside a class, the current class name is handled intelligently: + * it is stripped from the displayed name if present + * it is added to the full name (return value) if not present + """ + m = function_sig_re.match(sig) + if m is None: + raise ValueError + prefix, name, arglist, retann = m.groups() + + # determine module and class name (if applicable), as well as full name + modname = self.options.get("module", self.env.ref_context.get("delta:module")) + classname = self.env.ref_context.get("delta:class") + if classname: + add_module = False + if prefix and (prefix == classname or prefix.startswith(classname + ".")): + fullname = prefix + name + # class name is given again in the signature + prefix = prefix[len(classname) :].lstrip(".") + elif prefix: + # class name is given in the signature, but different + # (shouldn't happen) + fullname = classname + "." + prefix + name + else: + # class name is not given in the signature + fullname = classname + "." + name + else: + add_module = True + if prefix: + classname = prefix.rstrip(".") + fullname = prefix + name + else: + classname = "" + fullname = name + + signode["module"] = modname + signode["class"] = classname + signode["fullname"] = fullname + + sig_prefix = self.get_signature_prefix(sig) + if sig_prefix: + if type(sig_prefix) is str: + raise TypeError( + "Python directive method get_signature_prefix()" + " must return a list of nodes." + f" Return value was '{sig_prefix}'." + ) + else: + signode += addnodes.desc_annotation(str(sig_prefix), "", *sig_prefix) + + if prefix: + signode += addnodes.desc_addname(prefix, prefix) + elif modname and add_module and self.env.config.add_module_names: + nodetext = modname + "." + signode += addnodes.desc_addname(nodetext, nodetext) + + signode += addnodes.desc_name(name, name) + if arglist: + try: + signode += parse_arglist(function_module, arglist, self.env) + except SyntaxError: + # fallback to parse arglist original parser. + # it supports to represent optional arguments (ex. "func(foo [, bar])") + pseudo_parse_arglist(signode, arglist) + except NotImplementedError as exc: + logger.warning( + "could not parse arglist (%r): %s", arglist, exc, location=signode + ) + pseudo_parse_arglist(signode, arglist) + else: + if self.needs_arglist(): + # for callables, add an empty parameter list + signode += addnodes.desc_parameterlist() + + if retann: + children = parse_annotation(function_module, retann, self.env) + signode += addnodes.desc_returns(retann, "", *children) + + anno = self.options.get("annotation") + if anno: + signode += addnodes.desc_annotation( + " " + anno, "", addnodes.desc_sig_space(), nodes.Text(anno) + ) + + return fullname, prefix + + def _object_hierarchy_parts(self, sig_node: desc_signature) -> tuple[str, ...]: + if "fullname" not in sig_node: + return () + modname = sig_node.get("module") + fullname = sig_node["fullname"] + + if modname: + return (modname, *fullname.split(".")) + else: + return tuple(fullname.split(".")) + + def get_index_text(self, modname: str, name: tuple[str, str]) -> str: + """Return the text for the index entry of the object.""" + raise NotImplementedError("must be implemented in subclasses") + + def add_target_and_index( + self, name_cls: tuple[str, str], sig: str, signode: desc_signature + ) -> None: + modname = self.options.get("module", self.env.ref_context.get("delta:module")) + fullname = (modname + "." if modname else "") + name_cls[0] + node_id = make_id(self.env, self.state.document, "", fullname) + signode["ids"].append(node_id) + self.state.document.note_explicit_target(signode) + + domain = cast(DeltaDomain, self.env.get_domain("delta")) + domain.note_object(fullname, self.objtype, node_id, location=signode) + + canonical_name = self.options.get("canonical") + if canonical_name: + domain.note_object( + canonical_name, self.objtype, node_id, aliased=True, location=signode + ) + + if "noindexentry" not in self.options: + indextext = self.get_index_text(modname, name_cls) + if indextext: + self.indexnode["entries"].append( + ("single", indextext, node_id, "", None) + ) + + def before_content(self) -> None: + """Handle object nesting before content + + For constructs that aren't nestable, the stack is bypassed, and instead + only the most recent object is tracked. This object prefix name will be + removed with :delta:meth:`after_content`. + """ + prefix = None + if self.names: + # fullname and name_prefix come from the `handle_signature` method. + # fullname represents the full object name that is constructed using + # object nesting and explicit prefixes. `name_prefix` is the + # explicit prefix given in a signature + (fullname, name_prefix) = self.names[-1] + if self.allow_nesting: + prefix = fullname + elif name_prefix: + prefix = name_prefix.strip(".") + if prefix: + self.env.ref_context["delta:class"] = prefix + if self.allow_nesting: + classes = self.env.ref_context.setdefault("delta:classes", []) + classes.append(prefix) + if "module" in self.options: + modules = self.env.ref_context.setdefault("delta:modules", []) + modules.append(self.env.ref_context.get("delta:module")) + self.env.ref_context["delta:module"] = self.options["module"] + + def after_content(self) -> None: + """Handle object de-nesting after content + + If this class is a nestable object, removing the last nested class prefix + ends further nesting in the object. + + If this class is not a nestable object, the list of classes should not + be altered as we didn't affect the nesting levels in + :delta:meth:`before_content`. + """ + classes = self.env.ref_context.setdefault("delta:classes", []) + if self.allow_nesting: + try: + classes.pop() + except IndexError: + pass + self.env.ref_context["delta:class"] = classes[-1] if len(classes) > 0 else None + if "module" in self.options: + modules = self.env.ref_context.setdefault("delta:modules", []) + if modules: + self.env.ref_context["delta:module"] = modules.pop() + else: + self.env.ref_context.pop("delta:module") + + def _toc_entry_name(self, sig_node: desc_signature) -> str: + if not sig_node.get("_toc_parts"): + return "" + + config = self.env.app.config + objtype = sig_node.parent.get("objtype") + if config.add_function_parentheses and objtype in {"function", "method"}: + parens = "()" + else: + parens = "" + *parents, name = sig_node["_toc_parts"] + if config.toc_object_entries_show_parents == "domain": + return sig_node.get("fullname", name) + parens + if config.toc_object_entries_show_parents == "hide": + return name + parens + if config.toc_object_entries_show_parents == "all": + return ".".join(parents + [name + parens]) + return "" + + +class DeltaFunction(DeltaObject): + """Description of a function.""" + + option_spec: OptionSpec = DeltaObject.option_spec.copy() + option_spec.update( + { + "async": directives.flag, + } + ) + + def get_signature_prefix(self, sig: str) -> list[nodes.Node]: + if "async" in self.options: + return [addnodes.desc_sig_keyword("", "async"), addnodes.desc_sig_space()] + else: + return [] + + def needs_arglist(self) -> bool: + return True + + def add_target_and_index( + self, name_cls: tuple[str, str], sig: str, signode: desc_signature + ) -> None: + super().add_target_and_index(name_cls, sig, signode) + if "noindexentry" not in self.options: + modname = self.options.get( + "module", self.env.ref_context.get("delta:module") + ) + node_id = signode["ids"][0] + + name, cls = name_cls + if modname: + text = _("%s() (in module %s)") % (name, modname) + self.indexnode["entries"].append(("single", text, node_id, "", None)) + else: + text = f"{pairindextypes['builtin']}; {name}()" + self.indexnode["entries"].append(("pair", text, node_id, "", None)) + + def get_index_text(self, modname: str, name_cls: tuple[str, str]) -> str | None: + # add index in own add_target_and_index() instead. + return None + + +class DeltaXRefRole(XRefRole): + def process_link( + self, + env: BuildEnvironment, + refnode: Element, + has_explicit_title: bool, + title: str, + target: str, + ) -> tuple[str, str]: + refnode["delta:module"] = env.ref_context.get("delta:module") + refnode["delta:class"] = env.ref_context.get("delta:class") + if not has_explicit_title: + title = title.lstrip(".") # only has a meaning for the target + target = target.lstrip("~") # only has a meaning for the title + # if the first character is a tilde, don't display the module/class + # parts of the contents + if title[0:1] == "~": + title = title[1:] + dot = title.rfind(".") + if dot != -1: + title = title[dot + 1 :] + # if the first character is a dot, search more specific namespaces first + # else search builtins first + if target[0:1] == ".": + target = target[1:] + refnode["refspecific"] = True + return title, target + + +class DeltaModuleIndex(Index): + """ + Index subclass to provide the Delta module index. + """ + + name = "modindex" + localname = _("Delta Module Index") + shortname = _("modules") + + def generate( + self, docnames: Iterable[str] | None = None + ) -> tuple[list[tuple[str, list[IndexEntry]]], bool]: + content: dict[str, list[IndexEntry]] = {} + # list of prefixes to ignore + ignores: list[str] = self.domain.env.config["modindex_common_prefix"] + ignores = sorted(ignores, key=len, reverse=True) + # list of all modules, sorted by module name + modules = sorted( + self.domain.data["modules"].items(), key=lambda x: x[0].lower() + ) + # sort out collapsible modules + prev_modname = "" + num_toplevels = 0 + for modname, (docname, node_id, synopsis, platforms, deprecated) in modules: + if docnames and docname not in docnames: + continue + + for ignore in ignores: + if modname.startswith(ignore): + modname = modname[len(ignore) :] + stripped = ignore + break + else: + stripped = "" + + # we stripped the whole module name? + if not modname: + modname, stripped = stripped, "" + + entries = content.setdefault(modname[0].lower(), []) + + package = modname.split(".")[0] + if package != modname: + # it's a submodule + if prev_modname == package: + # first submodule - make parent a group head + if entries: + last = entries[-1] + entries[-1] = IndexEntry( + last[0], 1, last[2], last[3], last[4], last[5], last[6] + ) + elif not prev_modname.startswith(package): + # submodule without parent in list, add dummy entry + entries.append( + IndexEntry(stripped + package, 1, "", "", "", "", "") + ) + subtype = 2 + else: + num_toplevels += 1 + subtype = 0 + + qualifier = _("Deprecated") if deprecated else "" + entries.append( + IndexEntry( + stripped + modname, + subtype, + docname, + node_id, + platforms, + qualifier, + synopsis, + ) + ) + prev_modname = modname + + # apply heuristics when to collapse modindex at page load: + # only collapse if number of toplevel modules is larger than + # number of submodules + collapse = len(modules) - num_toplevels < num_toplevels + + # sort by first letter + sorted_content = sorted(content.items()) + + return sorted_content, collapse + + +class DeltaDomain(Domain): + """Delta domain.""" + + name = "delta" + label = "Delta" + object_types: dict[str, ObjType] = { + "function": ObjType(_("function"), "func", "obj"), + } + + directives = { + "function": DeltaFunction, + } + roles = { + "func": DeltaXRefRole(fix_parens=True), + } + initial_data: dict[str, dict[str, tuple[Any]]] = { + "objects": {}, # fullname -> docname, objtype + "modules": {}, # modname -> docname, synopsis, platform, deprecated + } + indices = [ + DeltaModuleIndex, + ] + + @property + def objects(self) -> dict[str, ObjectEntry]: + return self.data.setdefault("objects", {}) # fullname -> ObjectEntry + + def note_object( + self, + name: str, + objtype: str, + node_id: str, + aliased: bool = False, + location: Any = None, + ) -> None: + """Note a delta object for cross reference. + + .. versionadded:: 2.1 + """ + if name in self.objects: + other = self.objects[name] + if other.aliased and aliased is False: + # The original definition found. Override it! + pass + elif other.aliased is False and aliased: + # The original definition is already registered. + return + else: + # duplicated + logger.warning( + __( + "duplicate object description of %s, " + "other instance in %s, use :noindex: for one of them" + ), + name, + other.docname, + location=location, + ) + self.objects[name] = ObjectEntry(self.env.docname, node_id, objtype, aliased) + + @property + def modules(self) -> dict[str, ModuleEntry]: + return self.data.setdefault("modules", {}) # modname -> ModuleEntry + + def note_module( + self, name: str, node_id: str, synopsis: str, platform: str, deprecated: bool + ) -> None: + """Note a delta module for cross reference. + + .. versionadded:: 2.1 + """ + self.modules[name] = ModuleEntry( + self.env.docname, node_id, synopsis, platform, deprecated + ) + + def clear_doc(self, docname: str) -> None: + for fullname, obj in list(self.objects.items()): + if obj.docname == docname: + del self.objects[fullname] + for modname, mod in list(self.modules.items()): + if mod.docname == docname: + del self.modules[modname] + + def merge_domaindata(self, docnames: list[str], otherdata: dict[str, Any]) -> None: + # XXX check duplicates? + for fullname, obj in otherdata["objects"].items(): + if obj.docname in docnames: + self.objects[fullname] = obj + for modname, mod in otherdata["modules"].items(): + if mod.docname in docnames: + self.modules[modname] = mod + + def find_obj( + self, + env: BuildEnvironment, + modname: str, + classname: str, + name: str, + type: str | None, + searchmode: int = 0, + ) -> list[tuple[str, ObjectEntry]]: + """Find a Delta object for "name", perhaps using the given module + and/or classname. Returns a list of (name, object entry) tuples. + """ + # skip parens + if name[-2:] == "()": + name = name[:-2] + + if not name: + return [] + + matches: list[tuple[str, ObjectEntry]] = [] + + newname = None + if searchmode == 1: + if type is None: + objtypes = list(self.object_types) + else: + objtypes = self.objtypes_for_role(type) + if objtypes is not None: + if modname and classname: + fullname = modname + "." + classname + "." + name + if ( + fullname in self.objects + and self.objects[fullname].objtype in objtypes + ): + newname = fullname + if not newname: + if ( + modname + and modname + "." + name in self.objects + and self.objects[modname + "." + name].objtype in objtypes + ): + newname = modname + "." + name + elif ( + name in self.objects and self.objects[name].objtype in objtypes + ): + newname = name + else: + # "fuzzy" searching mode + searchname = "." + name + matches = [ + (oname, self.objects[oname]) + for oname in self.objects + if oname.endswith(searchname) + and self.objects[oname].objtype in objtypes + ] + else: + # NOTE: searching for exact match, object type is not considered + if name in self.objects: + newname = name + elif type == "mod": + # only exact matches allowed for modules + return [] + elif classname and classname + "." + name in self.objects: + newname = classname + "." + name + elif modname and modname + "." + name in self.objects: + newname = modname + "." + name + elif ( + modname + and classname + and modname + "." + classname + "." + name in self.objects + ): + newname = modname + "." + classname + "." + name + if newname is not None: + matches.append((newname, self.objects[newname])) + return matches + + def resolve_xref( + self, + env: BuildEnvironment, + fromdocname: str, + builder: Builder, + type: str, + target: str, + node: pending_xref, + contnode: Element, + ) -> Element | None: + modname = node.get("delta:module") + clsname = node.get("delta:class") + searchmode = 1 if node.hasattr("refspecific") else 0 + matches = self.find_obj(env, modname, clsname, target, type, searchmode) + + if not matches and type == "attr": + # fallback to meth (for property; Sphinx-2.4.x) + # this ensures that `:attr:` role continues to refer to the old property entry + # that defined by ``method`` directive in old reST files. + matches = self.find_obj(env, modname, clsname, target, "meth", searchmode) + if not matches and type == "meth": + # fallback to attr (for property) + # this ensures that `:meth:` in the old reST files can refer to the property + # entry that defined by ``property`` directive. + # + # Note: _prop is a secret role only for internal look-up. + matches = self.find_obj(env, modname, clsname, target, "_prop", searchmode) + + if not matches: + return None + elif len(matches) > 1: + canonicals = [m for m in matches if not m[1].aliased] + if len(canonicals) == 1: + matches = canonicals + else: + logger.warning( + __("more than one target found for cross-reference %r: %s"), + target, + ", ".join(match[0] for match in matches), + type="ref", + subtype="python", + location=node, + ) + name, obj = matches[0] + + if obj[2] == "module": + return self._make_module_refnode(builder, fromdocname, name, contnode) + else: + # determine the content of the reference by conditions + content = find_pending_xref_condition(node, "resolved") + if content: + children = content.children + else: + # if not found, use contnode + children = [contnode] + + return make_refnode(builder, fromdocname, obj[0], obj[1], children, name) + + def resolve_any_xref( + self, + env: BuildEnvironment, + fromdocname: str, + builder: Builder, + target: str, + node: pending_xref, + contnode: Element, + ) -> list[tuple[str, Element]]: + modname = node.get("delta:module") + clsname = node.get("delta:class") + results: list[tuple[str, Element]] = [] + + # always search in "refspecific" mode with the :any: role + matches = self.find_obj(env, modname, clsname, target, None, 1) + multiple_matches = len(matches) > 1 + + for name, obj in matches: + if multiple_matches and obj.aliased: + # Skip duplicated matches + continue + + if obj[2] == "module": + results.append( + ( + "delta:mod", + self._make_module_refnode(builder, fromdocname, name, contnode), + ) + ) + else: + # determine the content of the reference by conditions + content = find_pending_xref_condition(node, "resolved") + if content: + children = content.children + else: + # if not found, use contnode + children = [contnode] + + results.append( + ( + "delta:" + self.role_for_objtype(obj[2]), + make_refnode( + builder, fromdocname, obj[0], obj[1], children, name + ), + ) + ) + return results + + def _make_module_refnode( + self, builder: Builder, fromdocname: str, name: str, contnode: Node + ) -> Element: + # get additional info for modules + module = self.modules[name] + title = name + if module.synopsis: + title += ": " + module.synopsis + if module.deprecated: + title += _(" (deprecated)") + if module.platform: + title += " (" + module.platform + ")" + return make_refnode( + builder, fromdocname, module.docname, module.node_id, contnode, title + ) + + def get_objects(self) -> Iterator[tuple[str, str, str, str, str, int]]: + for modname, mod in self.modules.items(): + yield (modname, modname, "module", mod.docname, mod.node_id, 0) + for refname, obj in self.objects.items(): + if obj.objtype != "module": # modules are already handled + if obj.aliased: + # aliased names are not full-text searchable. + yield (refname, refname, obj.objtype, obj.docname, obj.node_id, -1) + else: + yield (refname, refname, obj.objtype, obj.docname, obj.node_id, 1) + + def get_full_qualified_name(self, node: Element) -> str | None: + modname = node.get("delta:module") + clsname = node.get("delta:class") + target = node.get("reftarget") + if target is None: + return None + else: + return ".".join(filter(None, [modname, clsname, target])) + + +def setup(app: Sphinx) -> dict[str, Any]: + app.setup_extension("sphinx.directives") + app.add_domain(DeltaDomain) + + return { + "version": "builtin", + "env_version": 3, + "parallel_read_safe": True, + "parallel_write_safe": True, + } diff --git a/velox/docs/ext/function.py b/velox/docs/ext/function.py new file mode 100644 index 000000000000..c89f888f0f28 --- /dev/null +++ b/velox/docs/ext/function.py @@ -0,0 +1,374 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""The Function common functions.""" + +from __future__ import annotations + +import ast +import re +from inspect import Parameter +from typing import NamedTuple + +from docutils import nodes +from docutils.nodes import Element, Node +from sphinx import addnodes +from sphinx.addnodes import desc_signature, pending_xref + +from sphinx.environment import BuildEnvironment +from sphinx.locale import _ +from sphinx.util.inspect import signature_from_str + +# REs for Function signatures +function_sig_re = re.compile( + r"""^ ([\w.]*\.)? # class name(s) + (\w+) \s* # thing name + (?: \(\s*(.*)\s*\) # optional: arguments + (?:\s* -> \s* (.*))? # return annotation + )? $ # and nothing more + """, + re.VERBOSE, +) + +pairindextypes = { + "module": _("module"), + "keyword": _("keyword"), + "operator": _("operator"), + "object": _("object"), + "exception": _("exception"), + "statement": _("statement"), + "builtin": _("built-in function"), +} + + +class ObjectEntry(NamedTuple): + docname: str + node_id: str + objtype: str + aliased: bool + + +class ModuleEntry(NamedTuple): + docname: str + node_id: str + synopsis: str + platform: str + deprecated: bool + + +def parse_reftarget( + reftarget: str, suppress_prefix: bool = False +) -> tuple[str, str, str, bool]: + """Parse a type string and return (reftype, reftarget, title, refspecific flag)""" + refspecific = False + if reftarget.startswith("."): + reftarget = reftarget[1:] + title = reftarget + refspecific = True + elif reftarget.startswith("~"): + reftarget = reftarget[1:] + title = reftarget.split(".")[-1] + elif suppress_prefix: + title = reftarget.split(".")[-1] + elif reftarget.startswith("typing."): + title = reftarget[7:] + else: + title = reftarget + + if reftarget == "None" or reftarget.startswith("typing."): + # typing module provides non-class types. Obj reference is good to refer them. + reftype = "obj" + else: + reftype = "class" + + return reftype, reftarget, title, refspecific + + +def type_to_xref( + function_module: str, + target: str, + env: BuildEnvironment | None = None, + suppress_prefix: bool = False, +) -> addnodes.pending_xref: + """Convert a type string to a cross reference node.""" + if env: + kwargs = { + function_module + ":module": env.ref_context.get( + function_module + ":module" + ), + function_module + ":class": env.ref_context.get(function_module + ":class"), + } + else: + kwargs = {} + + reftype, target, title, refspecific = parse_reftarget(target, suppress_prefix) + contnodes = [nodes.Text(title)] + + return pending_xref( + "", + *contnodes, + refdomain=function_module, + reftype=reftype, + reftarget=target, + refspecific=refspecific, + **kwargs, + ) + + +def parse_annotation( + function_module: str, annotation: str, env: BuildEnvironment | None +) -> list[Node]: + """Parse type annotation.""" + + def unparse(node: ast.AST) -> list[Node]: + if isinstance(node, ast.Attribute): + return [nodes.Text(f"{unparse(node.value)[0]}.{node.attr}")] + elif isinstance(node, ast.BinOp): + result: list[Node] = unparse(node.left) + result.extend(unparse(node.op)) + result.extend(unparse(node.right)) + return result + elif isinstance(node, ast.BitOr): + return [ + addnodes.desc_sig_space(), + addnodes.desc_sig_punctuation("", "|"), + addnodes.desc_sig_space(), + ] + elif isinstance(node, ast.Constant): + if node.value is Ellipsis: + return [addnodes.desc_sig_punctuation("", "...")] + elif isinstance(node.value, bool): + return [addnodes.desc_sig_keyword("", repr(node.value))] + elif isinstance(node.value, int): + return [addnodes.desc_sig_literal_number("", repr(node.value))] + elif isinstance(node.value, str): + return [addnodes.desc_sig_literal_string("", repr(node.value))] + else: + # handles None, which is further handled by type_to_xref later + # and fallback for other types that should be converted + return [nodes.Text(repr(node.value))] + elif isinstance(node, ast.Expr): + return unparse(node.value) + elif isinstance(node, ast.Index): + return unparse(node.value) + elif isinstance(node, ast.Invert): + return [addnodes.desc_sig_punctuation("", "~")] + elif isinstance(node, ast.List): + result = [addnodes.desc_sig_punctuation("", "[")] + if node.elts: + # check if there are elements in node.elts to only pop the + # last element of result if the for-loop was run at least + # once + for elem in node.elts: + result.extend(unparse(elem)) + result.append(addnodes.desc_sig_punctuation("", ",")) + result.append(addnodes.desc_sig_space()) + result.pop() + result.pop() + result.append(addnodes.desc_sig_punctuation("", "]")) + return result + elif isinstance(node, ast.Module): + return sum((unparse(e) for e in node.body), []) + elif isinstance(node, ast.Name): + return [nodes.Text(node.id)] + elif isinstance(node, ast.Subscript): + if getattr(node.value, "id", "") in {"Optional", "Union"}: + return _unparse_pep_604_annotation(node) + result = unparse(node.value) + result.append(addnodes.desc_sig_punctuation("", "]")) + result.extend(unparse(node.slice)) + result.append(addnodes.desc_sig_punctuation("", "]")) + + # Wrap the Text nodes inside brackets by literal node if the subscript is a Literal + if result[0] in ("Literal", "typing.Literal"): + for i, subnode in enumerate(result[1:], start=1): + if isinstance(subnode, nodes.Text): + result[i] = nodes.literal("", "", subnode) + return result + elif isinstance(node, ast.UnaryOp): + return unparse(node.op) + unparse(node.operand) + elif isinstance(node, ast.Tuple): + if node.elts: + result = [] + for elem in node.elts: + result.extend(unparse(elem)) + result.append(addnodes.desc_sig_punctuation("", ",")) + result.append(addnodes.desc_sig_space()) + result.pop() + result.pop() + else: + result = [ + addnodes.desc_sig_punctuation("", "("), + addnodes.desc_sig_punctuation("", ")"), + ] + + return result + else: + raise SyntaxError # unsupported syntax + + def _unparse_pep_604_annotation(node: ast.Subscript) -> list[Node]: + subscript = node.slice + if isinstance(subscript, ast.Index): + subscript = subscript.value # type: ignore[assignment] + + flattened: list[Node] = [] + if isinstance(subscript, ast.Tuple): + flattened.extend(unparse(subscript.elts[0])) + for elt in subscript.elts[1:]: + flattened.extend(unparse(ast.BitOr())) + flattened.extend(unparse(elt)) + else: + # e.g. a Union[] inside an Optional[] + flattened.extend(unparse(subscript)) + + if getattr(node.value, "id", "") == "Optional": + flattened.extend(unparse(ast.BitOr())) + flattened.append(nodes.Text("None")) + + return flattened + + try: + tree = ast.parse(annotation) + result: list[Node] = [] + for node in unparse(tree): + if isinstance(node, nodes.literal): + result.append(node[0]) + elif isinstance(node, nodes.Text) and node.strip(): + if ( + result + and isinstance(result[-1], addnodes.desc_sig_punctuation) + and result[-1].astext() == "~" + ): + result.pop() + result.append( + type_to_xref( + function_module, str(node), env, suppress_prefix=True + ) + ) + else: + result.append(type_to_xref(function_module, str(node), env)) + else: + result.append(node) + return result + except SyntaxError: + return [type_to_xref(function_module, annotation, env)] + + +def parse_arglist( + function_module: str, arglist: str, env: BuildEnvironment | None = None +) -> addnodes.desc_parameterlist: + """Parse a list of arguments using AST parser""" + params = addnodes.desc_parameterlist(arglist) + sig = signature_from_str("(%s)" % arglist) + last_kind = None + for param in sig.parameters.values(): + if param.kind != param.POSITIONAL_ONLY and last_kind == param.POSITIONAL_ONLY: + # PEP-570: Separator for Positional Only Parameter: / + params += addnodes.desc_parameter( + "", "", addnodes.desc_sig_operator("", "/") + ) + if param.kind == param.KEYWORD_ONLY and last_kind in ( + param.POSITIONAL_OR_KEYWORD, + param.POSITIONAL_ONLY, + None, + ): + # PEP-3102: Separator for Keyword Only Parameter: * + params += addnodes.desc_parameter( + "", "", addnodes.desc_sig_operator("", "*") + ) + + node = addnodes.desc_parameter() + if param.kind == param.VAR_POSITIONAL: + node += addnodes.desc_sig_operator("", "*") + node += addnodes.desc_sig_name("", param.name) + elif param.kind == param.VAR_KEYWORD: + node += addnodes.desc_sig_operator("", "**") + node += addnodes.desc_sig_name("", param.name) + else: + node += addnodes.desc_sig_name("", param.name) + + if param.annotation is not param.empty: + children = parse_annotation(function_module, param.annotation, env) + node += addnodes.desc_sig_punctuation("", ":") + node += addnodes.desc_sig_space() + node += addnodes.desc_sig_name("", "", *children) # type: ignore + if param.default is not param.empty: + if param.annotation is not param.empty: + node += addnodes.desc_sig_space() + node += addnodes.desc_sig_operator("", "=") + node += addnodes.desc_sig_space() + else: + node += addnodes.desc_sig_operator("", "=") + node += nodes.inline( + "", param.default, classes=["default_value"], support_smartquotes=False + ) + + params += node + last_kind = param.kind + + if last_kind == Parameter.POSITIONAL_ONLY: + # PEP-570: Separator for Positional Only Parameter: / + params += addnodes.desc_parameter("", "", addnodes.desc_sig_operator("", "/")) + + return params + + +def pseudo_parse_arglist(signode: desc_signature, arglist: str) -> None: + """ "Parse" a list of arguments separated by commas. + + Arguments can have "optional" annotations given by enclosing them in + brackets. Currently, this will split at any comma, even if it's inside a + string literal (e.g. default argument value). + """ + paramlist = addnodes.desc_parameterlist() + stack: list[Element] = [paramlist] + try: + for argument in arglist.split(","): + argument = argument.strip() + ends_open = ends_close = 0 + while argument.startswith("["): + stack.append(addnodes.desc_optional()) + stack[-2] += stack[-1] + argument = argument[1:].strip() + while argument.startswith("]"): + stack.pop() + argument = argument[1:].strip() + while argument.endswith("]") and not argument.endswith("[]"): + ends_close += 1 + argument = argument[:-1].strip() + while argument.endswith("["): + ends_open += 1 + argument = argument[:-1].strip() + if argument: + stack[-1] += addnodes.desc_parameter( + "", "", addnodes.desc_sig_name(argument, argument) + ) + while ends_open: + stack.append(addnodes.desc_optional()) + stack[-2] += stack[-1] + ends_open -= 1 + while ends_close: + stack.pop() + ends_close -= 1 + if len(stack) != 1: + raise IndexError + except IndexError: + # if there are too few or too many elements on the stack, just give up + # and treat the whole argument list as one argument, discarding the + # already partially populated paramlist node + paramlist = addnodes.desc_parameterlist() + paramlist += addnodes.desc_parameter(arglist, arglist) + signode += paramlist + else: + signode += paramlist diff --git a/velox/docs/ext/iceberg.py b/velox/docs/ext/iceberg.py new file mode 100644 index 000000000000..1b2b746a30e1 --- /dev/null +++ b/velox/docs/ext/iceberg.py @@ -0,0 +1,775 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Generating iceberg function link :iceberg:func:``""" + +from __future__ import annotations + +from function import ( + function_sig_re, + pairindextypes, + parse_arglist, + pseudo_parse_arglist, + parse_annotation, + ObjectEntry, + ModuleEntry, +) +from typing import Any, Iterable, Iterator, Tuple, cast + +from docutils import nodes +from docutils.nodes import Element, Node +from docutils.parsers.rst import directives +from sphinx import addnodes +from sphinx.addnodes import desc_signature, pending_xref +from sphinx.application import Sphinx +from sphinx.builders import Builder +from sphinx.directives import ObjectDescription +from sphinx.domains import Domain, Index, IndexEntry, ObjType +from sphinx.environment import BuildEnvironment +from sphinx.locale import _, __ +from sphinx.roles import XRefRole +from sphinx.util import logging +from sphinx.util.docfields import Field +from sphinx.util.nodes import ( + find_pending_xref_condition, + make_id, + make_refnode, +) +from sphinx.util.typing import OptionSpec + +logger = logging.getLogger(__name__) + +function_module = "iceberg" + + +class IcebergObject(ObjectDescription[Tuple[str, str]]): + """ + Description of a general Iceberg object. + + :cvar allow_nesting: Class is an object that allows for nested namespaces + :vartype allow_nesting: bool + """ + + option_spec: OptionSpec = { + "noindex": directives.flag, + "noindexentry": directives.flag, + "nocontentsentry": directives.flag, + "module": directives.unchanged, + "canonical": directives.unchanged, + "annotation": directives.unchanged, + } + + doc_field_types = [ + Field( + "returnvalue", + label=_("Returns"), + has_arg=False, + names=("returns", "return"), + ), + ] + + allow_nesting = False + + def get_signature_prefix(self, sig: str) -> list[nodes.Node]: + """May return a prefix to put before the object name in the + signature. + """ + return [] + + def needs_arglist(self) -> bool: + """May return true if an empty argument list is to be generated even if + the document contains none. + """ + return False + + def handle_signature(self, sig: str, signode: desc_signature) -> tuple[str, str]: + """Transform a Iceberg signature into RST nodes. + Return (fully qualified name of the thing, classname if any). + If inside a class, the current class name is handled intelligently: + * it is stripped from the displayed name if present + * it is added to the full name (return value) if not present + """ + m = function_sig_re.match(sig) + if m is None: + raise ValueError + prefix, name, arglist, retann = m.groups() + + # determine module and class name (if applicable), as well as full name + modname = self.options.get("module", self.env.ref_context.get("iceberg:module")) + classname = self.env.ref_context.get("iceberg:class") + if classname: + add_module = False + if prefix and (prefix == classname or prefix.startswith(classname + ".")): + fullname = prefix + name + # class name is given again in the signature + prefix = prefix[len(classname) :].lstrip(".") + elif prefix: + # class name is given in the signature, but different + # (shouldn't happen) + fullname = classname + "." + prefix + name + else: + # class name is not given in the signature + fullname = classname + "." + name + else: + add_module = True + if prefix: + classname = prefix.rstrip(".") + fullname = prefix + name + else: + classname = "" + fullname = name + + signode["module"] = modname + signode["class"] = classname + signode["fullname"] = fullname + + sig_prefix = self.get_signature_prefix(sig) + if sig_prefix: + if type(sig_prefix) is str: + raise TypeError( + "Python directive method get_signature_prefix()" + " must return a list of nodes." + f" Return value was '{sig_prefix}'." + ) + else: + signode += addnodes.desc_annotation(str(sig_prefix), "", *sig_prefix) + + if prefix: + signode += addnodes.desc_addname(prefix, prefix) + elif modname and add_module and self.env.config.add_module_names: + nodetext = modname + "." + signode += addnodes.desc_addname(nodetext, nodetext) + + signode += addnodes.desc_name(name, name) + if arglist: + try: + signode += parse_arglist(function_module, arglist, self.env) + except SyntaxError: + # fallback to parse arglist original parser. + # it supports to represent optional arguments (ex. "func(foo [, bar])") + pseudo_parse_arglist(signode, arglist) + except NotImplementedError as exc: + logger.warning( + "could not parse arglist (%r): %s", arglist, exc, location=signode + ) + pseudo_parse_arglist(signode, arglist) + else: + if self.needs_arglist(): + # for callables, add an empty parameter list + signode += addnodes.desc_parameterlist() + + if retann: + children = parse_annotation(function_module, retann, self.env) + signode += addnodes.desc_returns(retann, "", *children) + + anno = self.options.get("annotation") + if anno: + signode += addnodes.desc_annotation( + " " + anno, "", addnodes.desc_sig_space(), nodes.Text(anno) + ) + + return fullname, prefix + + def _object_hierarchy_parts(self, sig_node: desc_signature) -> tuple[str, ...]: + if "fullname" not in sig_node: + return () + modname = sig_node.get("module") + fullname = sig_node["fullname"] + + if modname: + return (modname, *fullname.split(".")) + else: + return tuple(fullname.split(".")) + + def get_index_text(self, modname: str, name: tuple[str, str]) -> str: + """Return the text for the index entry of the object.""" + raise NotImplementedError("must be implemented in subclasses") + + def add_target_and_index( + self, name_cls: tuple[str, str], sig: str, signode: desc_signature + ) -> None: + modname = self.options.get("module", self.env.ref_context.get("iceberg:module")) + fullname = (modname + "." if modname else "") + name_cls[0] + node_id = make_id(self.env, self.state.document, "", fullname) + signode["ids"].append(node_id) + self.state.document.note_explicit_target(signode) + + domain = cast(IcebergDomain, self.env.get_domain("iceberg")) + domain.note_object(fullname, self.objtype, node_id, location=signode) + + canonical_name = self.options.get("canonical") + if canonical_name: + domain.note_object( + canonical_name, self.objtype, node_id, aliased=True, location=signode + ) + + if "noindexentry" not in self.options: + indextext = self.get_index_text(modname, name_cls) + if indextext: + self.indexnode["entries"].append( + ("single", indextext, node_id, "", None) + ) + + def before_content(self) -> None: + """Handle object nesting before content + + For constructs that aren't nestable, the stack is bypassed, and instead + only the most recent object is tracked. This object prefix name will be + removed with :iceberg:meth:`after_content`. + """ + prefix = None + if self.names: + # fullname and name_prefix come from the `handle_signature` method. + # fullname represents the full object name that is constructed using + # object nesting and explicit prefixes. `name_prefix` is the + # explicit prefix given in a signature + (fullname, name_prefix) = self.names[-1] + if self.allow_nesting: + prefix = fullname + elif name_prefix: + prefix = name_prefix.strip(".") + if prefix: + self.env.ref_context["iceberg:class"] = prefix + if self.allow_nesting: + classes = self.env.ref_context.setdefault("iceberg:classes", []) + classes.append(prefix) + if "module" in self.options: + modules = self.env.ref_context.setdefault("iceberg:modules", []) + modules.append(self.env.ref_context.get("iceberg:module")) + self.env.ref_context["iceberg:module"] = self.options["module"] + + def after_content(self) -> None: + """Handle object de-nesting after content + + If this class is a nestable object, removing the last nested class prefix + ends further nesting in the object. + + If this class is not a nestable object, the list of classes should not + be altered as we didn't affect the nesting levels in + :iceberg:meth:`before_content`. + """ + classes = self.env.ref_context.setdefault("iceberg:classes", []) + if self.allow_nesting: + try: + classes.pop() + except IndexError: + pass + self.env.ref_context["iceberg:class"] = ( + classes[-1] if len(classes) > 0 else None + ) + if "module" in self.options: + modules = self.env.ref_context.setdefault("iceberg:modules", []) + if modules: + self.env.ref_context["iceberg:module"] = modules.pop() + else: + self.env.ref_context.pop("iceberg:module") + + def _toc_entry_name(self, sig_node: desc_signature) -> str: + if not sig_node.get("_toc_parts"): + return "" + + config = self.env.app.config + objtype = sig_node.parent.get("objtype") + if config.add_function_parentheses and objtype in {"function", "method"}: + parens = "()" + else: + parens = "" + *parents, name = sig_node["_toc_parts"] + if config.toc_object_entries_show_parents == "domain": + return sig_node.get("fullname", name) + parens + if config.toc_object_entries_show_parents == "hide": + return name + parens + if config.toc_object_entries_show_parents == "all": + return ".".join(parents + [name + parens]) + return "" + + +class IcebergFunction(IcebergObject): + """Description of a function.""" + + option_spec: OptionSpec = IcebergObject.option_spec.copy() + option_spec.update( + { + "async": directives.flag, + } + ) + + def get_signature_prefix(self, sig: str) -> list[nodes.Node]: + if "async" in self.options: + return [addnodes.desc_sig_keyword("", "async"), addnodes.desc_sig_space()] + else: + return [] + + def needs_arglist(self) -> bool: + return True + + def add_target_and_index( + self, name_cls: tuple[str, str], sig: str, signode: desc_signature + ) -> None: + super().add_target_and_index(name_cls, sig, signode) + if "noindexentry" not in self.options: + modname = self.options.get( + "module", self.env.ref_context.get("iceberg:module") + ) + node_id = signode["ids"][0] + + name, cls = name_cls + if modname: + text = _("%s() (in module %s)") % (name, modname) + self.indexnode["entries"].append(("single", text, node_id, "", None)) + else: + text = f"{pairindextypes['builtin']}; {name}()" + self.indexnode["entries"].append(("pair", text, node_id, "", None)) + + def get_index_text(self, modname: str, name_cls: tuple[str, str]) -> str | None: + # add index in own add_target_and_index() instead. + return None + + +class IcebergXRefRole(XRefRole): + def process_link( + self, + env: BuildEnvironment, + refnode: Element, + has_explicit_title: bool, + title: str, + target: str, + ) -> tuple[str, str]: + refnode["iceberg:module"] = env.ref_context.get("iceberg:module") + refnode["iceberg:class"] = env.ref_context.get("iceberg:class") + if not has_explicit_title: + title = title.lstrip(".") # only has a meaning for the target + target = target.lstrip("~") # only has a meaning for the title + # if the first character is a tilde, don't display the module/class + # parts of the contents + if title[0:1] == "~": + title = title[1:] + dot = title.rfind(".") + if dot != -1: + title = title[dot + 1 :] + # if the first character is a dot, search more specific namespaces first + # else search builtins first + if target[0:1] == ".": + target = target[1:] + refnode["refspecific"] = True + return title, target + + +class IcebergModuleIndex(Index): + """ + Index subclass to provide the Iceberg module index. + """ + + name = "modindex" + localname = _("Iceberg Module Index") + shortname = _("modules") + + def generate( + self, docnames: Iterable[str] | None = None + ) -> tuple[list[tuple[str, list[IndexEntry]]], bool]: + content: dict[str, list[IndexEntry]] = {} + # list of prefixes to ignore + ignores: list[str] = self.domain.env.config["modindex_common_prefix"] + ignores = sorted(ignores, key=len, reverse=True) + # list of all modules, sorted by module name + modules = sorted( + self.domain.data["modules"].items(), key=lambda x: x[0].lower() + ) + # sort out collapsible modules + prev_modname = "" + num_toplevels = 0 + for modname, (docname, node_id, synopsis, platforms, deprecated) in modules: + if docnames and docname not in docnames: + continue + + for ignore in ignores: + if modname.startswith(ignore): + modname = modname[len(ignore) :] + stripped = ignore + break + else: + stripped = "" + + # we stripped the whole module name? + if not modname: + modname, stripped = stripped, "" + + entries = content.setdefault(modname[0].lower(), []) + + package = modname.split(".")[0] + if package != modname: + # it's a submodule + if prev_modname == package: + # first submodule - make parent a group head + if entries: + last = entries[-1] + entries[-1] = IndexEntry( + last[0], 1, last[2], last[3], last[4], last[5], last[6] + ) + elif not prev_modname.startswith(package): + # submodule without parent in list, add dummy entry + entries.append( + IndexEntry(stripped + package, 1, "", "", "", "", "") + ) + subtype = 2 + else: + num_toplevels += 1 + subtype = 0 + + qualifier = _("Deprecated") if deprecated else "" + entries.append( + IndexEntry( + stripped + modname, + subtype, + docname, + node_id, + platforms, + qualifier, + synopsis, + ) + ) + prev_modname = modname + + # apply heuristics when to collapse modindex at page load: + # only collapse if number of toplevel modules is larger than + # number of submodules + collapse = len(modules) - num_toplevels < num_toplevels + + # sort by first letter + sorted_content = sorted(content.items()) + + return sorted_content, collapse + + +class IcebergDomain(Domain): + """Iceberg domain.""" + + name = "iceberg" + label = "Iceberg" + object_types: dict[str, ObjType] = { + "function": ObjType(_("function"), "func", "obj"), + } + + directives = { + "function": IcebergFunction, + } + roles = { + "func": IcebergXRefRole(fix_parens=True), + } + initial_data: dict[str, dict[str, tuple[Any]]] = { + "objects": {}, # fullname -> docname, objtype + "modules": {}, # modname -> docname, synopsis, platform, deprecated + } + indices = [ + IcebergModuleIndex, + ] + + @property + def objects(self) -> dict[str, ObjectEntry]: + return self.data.setdefault("objects", {}) # fullname -> ObjectEntry + + def note_object( + self, + name: str, + objtype: str, + node_id: str, + aliased: bool = False, + location: Any = None, + ) -> None: + """Note a iceberg object for cross reference. + + .. versionadded:: 2.1 + """ + if name in self.objects: + other = self.objects[name] + if other.aliased and aliased is False: + # The original definition found. Override it! + pass + elif other.aliased is False and aliased: + # The original definition is already registered. + return + else: + # duplicated + logger.warning( + __( + "duplicate object description of %s, " + "other instance in %s, use :noindex: for one of them" + ), + name, + other.docname, + location=location, + ) + self.objects[name] = ObjectEntry(self.env.docname, node_id, objtype, aliased) + + @property + def modules(self) -> dict[str, ModuleEntry]: + return self.data.setdefault("modules", {}) # modname -> ModuleEntry + + def note_module( + self, name: str, node_id: str, synopsis: str, platform: str, deprecated: bool + ) -> None: + """Note a iceberg module for cross reference. + + .. versionadded:: 2.1 + """ + self.modules[name] = ModuleEntry( + self.env.docname, node_id, synopsis, platform, deprecated + ) + + def clear_doc(self, docname: str) -> None: + for fullname, obj in list(self.objects.items()): + if obj.docname == docname: + del self.objects[fullname] + for modname, mod in list(self.modules.items()): + if mod.docname == docname: + del self.modules[modname] + + def merge_domaindata(self, docnames: list[str], otherdata: dict[str, Any]) -> None: + # XXX check duplicates? + for fullname, obj in otherdata["objects"].items(): + if obj.docname in docnames: + self.objects[fullname] = obj + for modname, mod in otherdata["modules"].items(): + if mod.docname in docnames: + self.modules[modname] = mod + + def find_obj( + self, + env: BuildEnvironment, + modname: str, + classname: str, + name: str, + type: str | None, + searchmode: int = 0, + ) -> list[tuple[str, ObjectEntry]]: + """Find a Iceberg object for "name", perhaps using the given module + and/or classname. Returns a list of (name, object entry) tuples. + """ + # skip parens + if name[-2:] == "()": + name = name[:-2] + + if not name: + return [] + + matches: list[tuple[str, ObjectEntry]] = [] + + newname = None + if searchmode == 1: + if type is None: + objtypes = list(self.object_types) + else: + objtypes = self.objtypes_for_role(type) + if objtypes is not None: + if modname and classname: + fullname = modname + "." + classname + "." + name + if ( + fullname in self.objects + and self.objects[fullname].objtype in objtypes + ): + newname = fullname + if not newname: + if ( + modname + and modname + "." + name in self.objects + and self.objects[modname + "." + name].objtype in objtypes + ): + newname = modname + "." + name + elif ( + name in self.objects and self.objects[name].objtype in objtypes + ): + newname = name + else: + # "fuzzy" searching mode + searchname = "." + name + matches = [ + (oname, self.objects[oname]) + for oname in self.objects + if oname.endswith(searchname) + and self.objects[oname].objtype in objtypes + ] + else: + # NOTE: searching for exact match, object type is not considered + if name in self.objects: + newname = name + elif type == "mod": + # only exact matches allowed for modules + return [] + elif classname and classname + "." + name in self.objects: + newname = classname + "." + name + elif modname and modname + "." + name in self.objects: + newname = modname + "." + name + elif ( + modname + and classname + and modname + "." + classname + "." + name in self.objects + ): + newname = modname + "." + classname + "." + name + if newname is not None: + matches.append((newname, self.objects[newname])) + return matches + + def resolve_xref( + self, + env: BuildEnvironment, + fromdocname: str, + builder: Builder, + type: str, + target: str, + node: pending_xref, + contnode: Element, + ) -> Element | None: + modname = node.get("iceberg:module") + clsname = node.get("iceberg:class") + searchmode = 1 if node.hasattr("refspecific") else 0 + matches = self.find_obj(env, modname, clsname, target, type, searchmode) + + if not matches and type == "attr": + # fallback to meth (for property; Sphinx-2.4.x) + # this ensures that `:attr:` role continues to refer to the old property entry + # that defined by ``method`` directive in old reST files. + matches = self.find_obj(env, modname, clsname, target, "meth", searchmode) + if not matches and type == "meth": + # fallback to attr (for property) + # this ensures that `:meth:` in the old reST files can refer to the property + # entry that defined by ``property`` directive. + # + # Note: _prop is a secret role only for internal look-up. + matches = self.find_obj(env, modname, clsname, target, "_prop", searchmode) + + if not matches: + return None + elif len(matches) > 1: + canonicals = [m for m in matches if not m[1].aliased] + if len(canonicals) == 1: + matches = canonicals + else: + logger.warning( + __("more than one target found for cross-reference %r: %s"), + target, + ", ".join(match[0] for match in matches), + type="ref", + subtype="python", + location=node, + ) + name, obj = matches[0] + + if obj[2] == "module": + return self._make_module_refnode(builder, fromdocname, name, contnode) + else: + # determine the content of the reference by conditions + content = find_pending_xref_condition(node, "resolved") + if content: + children = content.children + else: + # if not found, use contnode + children = [contnode] + + return make_refnode(builder, fromdocname, obj[0], obj[1], children, name) + + def resolve_any_xref( + self, + env: BuildEnvironment, + fromdocname: str, + builder: Builder, + target: str, + node: pending_xref, + contnode: Element, + ) -> list[tuple[str, Element]]: + modname = node.get("iceberg:module") + clsname = node.get("iceberg:class") + results: list[tuple[str, Element]] = [] + + # always search in "refspecific" mode with the :any: role + matches = self.find_obj(env, modname, clsname, target, None, 1) + multiple_matches = len(matches) > 1 + + for name, obj in matches: + if multiple_matches and obj.aliased: + # Skip duplicated matches + continue + + if obj[2] == "module": + results.append( + ( + "iceberg:mod", + self._make_module_refnode(builder, fromdocname, name, contnode), + ) + ) + else: + # determine the content of the reference by conditions + content = find_pending_xref_condition(node, "resolved") + if content: + children = content.children + else: + # if not found, use contnode + children = [contnode] + + results.append( + ( + "iceberg:" + self.role_for_objtype(obj[2]), + make_refnode( + builder, fromdocname, obj[0], obj[1], children, name + ), + ) + ) + return results + + def _make_module_refnode( + self, builder: Builder, fromdocname: str, name: str, contnode: Node + ) -> Element: + # get additional info for modules + module = self.modules[name] + title = name + if module.synopsis: + title += ": " + module.synopsis + if module.deprecated: + title += _(" (deprecated)") + if module.platform: + title += " (" + module.platform + ")" + return make_refnode( + builder, fromdocname, module.docname, module.node_id, contnode, title + ) + + def get_objects(self) -> Iterator[tuple[str, str, str, str, str, int]]: + for modname, mod in self.modules.items(): + yield (modname, modname, "module", mod.docname, mod.node_id, 0) + for refname, obj in self.objects.items(): + if obj.objtype != "module": # modules are already handled + if obj.aliased: + # aliased names are not full-text searchable. + yield (refname, refname, obj.objtype, obj.docname, obj.node_id, -1) + else: + yield (refname, refname, obj.objtype, obj.docname, obj.node_id, 1) + + def get_full_qualified_name(self, node: Element) -> str | None: + modname = node.get("iceberg:module") + clsname = node.get("iceberg:class") + target = node.get("reftarget") + if target is None: + return None + else: + return ".".join(filter(None, [modname, clsname, target])) + + +def setup(app: Sphinx) -> dict[str, Any]: + app.setup_extension("sphinx.directives") + app.add_domain(IcebergDomain) + + return { + "version": "builtin", + "env_version": 3, + "parallel_read_safe": True, + "parallel_write_safe": True, + } diff --git a/velox/docs/ext/spark.py b/velox/docs/ext/spark.py index 134a882db99f..8781aee89418 100644 --- a/velox/docs/ext/spark.py +++ b/velox/docs/ext/spark.py @@ -16,10 +16,16 @@ from __future__ import annotations -import ast -import re -from inspect import Parameter -from typing import Any, Iterable, Iterator, NamedTuple, Tuple, cast +from function import ( + function_sig_re, + pairindextypes, + parse_arglist, + pseudo_parse_arglist, + parse_annotation, + ObjectEntry, + ModuleEntry, +) +from typing import Any, Iterable, Iterator, Tuple, cast from docutils import nodes from docutils.nodes import Element, Node @@ -35,7 +41,6 @@ from sphinx.roles import XRefRole from sphinx.util import logging from sphinx.util.docfields import Field -from sphinx.util.inspect import signature_from_str from sphinx.util.nodes import ( find_pending_xref_condition, make_id, @@ -45,338 +50,7 @@ logger = logging.getLogger(__name__) - -# REs for Spark signatures -spark_sig_re = re.compile( - r"""^ ([\w.]*\.)? # class name(s) - (\w+) \s* # thing name - (?: \(\s*(.*)\s*\) # optional: arguments - (?:\s* -> \s* (.*))? # return annotation - )? $ # and nothing more - """, - re.VERBOSE, -) - -pairindextypes = { - "module": _("module"), - "keyword": _("keyword"), - "operator": _("operator"), - "object": _("object"), - "exception": _("exception"), - "statement": _("statement"), - "builtin": _("built-in function"), -} - - -class ObjectEntry(NamedTuple): - docname: str - node_id: str - objtype: str - aliased: bool - - -class ModuleEntry(NamedTuple): - docname: str - node_id: str - synopsis: str - platform: str - deprecated: bool - - -def parse_reftarget( - reftarget: str, suppress_prefix: bool = False -) -> tuple[str, str, str, bool]: - """Parse a type string and return (reftype, reftarget, title, refspecific flag)""" - refspecific = False - if reftarget.startswith("."): - reftarget = reftarget[1:] - title = reftarget - refspecific = True - elif reftarget.startswith("~"): - reftarget = reftarget[1:] - title = reftarget.split(".")[-1] - elif suppress_prefix: - title = reftarget.split(".")[-1] - elif reftarget.startswith("typing."): - title = reftarget[7:] - else: - title = reftarget - - if reftarget == "None" or reftarget.startswith("typing."): - # typing module provides non-class types. Obj reference is good to refer them. - reftype = "obj" - else: - reftype = "class" - - return reftype, reftarget, title, refspecific - - -def type_to_xref( - target: str, env: BuildEnvironment | None = None, suppress_prefix: bool = False -) -> addnodes.pending_xref: - """Convert a type string to a cross reference node.""" - if env: - kwargs = { - "spark:module": env.ref_context.get("spark:module"), - "spark:class": env.ref_context.get("spark:class"), - } - else: - kwargs = {} - - reftype, target, title, refspecific = parse_reftarget(target, suppress_prefix) - contnodes = [nodes.Text(title)] - - return pending_xref( - "", - *contnodes, - refdomain="spark", - reftype=reftype, - reftarget=target, - refspecific=refspecific, - **kwargs, - ) - - -def _parse_annotation(annotation: str, env: BuildEnvironment | None) -> list[Node]: - """Parse type annotation.""" - - def unparse(node: ast.AST) -> list[Node]: - if isinstance(node, ast.Attribute): - return [nodes.Text(f"{unparse(node.value)[0]}.{node.attr}")] - elif isinstance(node, ast.BinOp): - result: list[Node] = unparse(node.left) - result.extend(unparse(node.op)) - result.extend(unparse(node.right)) - return result - elif isinstance(node, ast.BitOr): - return [ - addnodes.desc_sig_space(), - addnodes.desc_sig_punctuation("", "|"), - addnodes.desc_sig_space(), - ] - elif isinstance(node, ast.Constant): - if node.value is Ellipsis: - return [addnodes.desc_sig_punctuation("", "...")] - elif isinstance(node.value, bool): - return [addnodes.desc_sig_keyword("", repr(node.value))] - elif isinstance(node.value, int): - return [addnodes.desc_sig_literal_number("", repr(node.value))] - elif isinstance(node.value, str): - return [addnodes.desc_sig_literal_string("", repr(node.value))] - else: - # handles None, which is further handled by type_to_xref later - # and fallback for other types that should be converted - return [nodes.Text(repr(node.value))] - elif isinstance(node, ast.Expr): - return unparse(node.value) - elif isinstance(node, ast.Index): - return unparse(node.value) - elif isinstance(node, ast.Invert): - return [addnodes.desc_sig_punctuation("", "~")] - elif isinstance(node, ast.List): - result = [addnodes.desc_sig_punctuation("", "[")] - if node.elts: - # check if there are elements in node.elts to only pop the - # last element of result if the for-loop was run at least - # once - for elem in node.elts: - result.extend(unparse(elem)) - result.append(addnodes.desc_sig_punctuation("", ",")) - result.append(addnodes.desc_sig_space()) - result.pop() - result.pop() - result.append(addnodes.desc_sig_punctuation("", "]")) - return result - elif isinstance(node, ast.Module): - return sum((unparse(e) for e in node.body), []) - elif isinstance(node, ast.Name): - return [nodes.Text(node.id)] - elif isinstance(node, ast.Subscript): - if getattr(node.value, "id", "") in {"Optional", "Union"}: - return _unparse_pep_604_annotation(node) - result = unparse(node.value) - result.append(addnodes.desc_sig_punctuation("", "]")) - result.extend(unparse(node.slice)) - result.append(addnodes.desc_sig_punctuation("", "]")) - - # Wrap the Text nodes inside brackets by literal node if the subscript is a Literal - if result[0] in ("Literal", "typing.Literal"): - for i, subnode in enumerate(result[1:], start=1): - if isinstance(subnode, nodes.Text): - result[i] = nodes.literal("", "", subnode) - return result - elif isinstance(node, ast.UnaryOp): - return unparse(node.op) + unparse(node.operand) - elif isinstance(node, ast.Tuple): - if node.elts: - result = [] - for elem in node.elts: - result.extend(unparse(elem)) - result.append(addnodes.desc_sig_punctuation("", ",")) - result.append(addnodes.desc_sig_space()) - result.pop() - result.pop() - else: - result = [ - addnodes.desc_sig_punctuation("", "("), - addnodes.desc_sig_punctuation("", ")"), - ] - - return result - else: - raise SyntaxError # unsupported syntax - - def _unparse_pep_604_annotation(node: ast.Subscript) -> list[Node]: - subscript = node.slice - if isinstance(subscript, ast.Index): - subscript = subscript.value # type: ignore[assignment] - - flattened: list[Node] = [] - if isinstance(subscript, ast.Tuple): - flattened.extend(unparse(subscript.elts[0])) - for elt in subscript.elts[1:]: - flattened.extend(unparse(ast.BitOr())) - flattened.extend(unparse(elt)) - else: - # e.g. a Union[] inside an Optional[] - flattened.extend(unparse(subscript)) - - if getattr(node.value, "id", "") == "Optional": - flattened.extend(unparse(ast.BitOr())) - flattened.append(nodes.Text("None")) - - return flattened - - try: - tree = ast.parse(annotation) - result: list[Node] = [] - for node in unparse(tree): - if isinstance(node, nodes.literal): - result.append(node[0]) - elif isinstance(node, nodes.Text) and node.strip(): - if ( - result - and isinstance(result[-1], addnodes.desc_sig_punctuation) - and result[-1].astext() == "~" - ): - result.pop() - result.append(type_to_xref(str(node), env, suppress_prefix=True)) - else: - result.append(type_to_xref(str(node), env)) - else: - result.append(node) - return result - except SyntaxError: - return [type_to_xref(annotation, env)] - - -def _parse_arglist( - arglist: str, env: BuildEnvironment | None = None -) -> addnodes.desc_parameterlist: - """Parse a list of arguments using AST parser""" - params = addnodes.desc_parameterlist(arglist) - sig = signature_from_str("(%s)" % arglist) - last_kind = None - for param in sig.parameters.values(): - if param.kind != param.POSITIONAL_ONLY and last_kind == param.POSITIONAL_ONLY: - # PEP-570: Separator for Positional Only Parameter: / - params += addnodes.desc_parameter( - "", "", addnodes.desc_sig_operator("", "/") - ) - if param.kind == param.KEYWORD_ONLY and last_kind in ( - param.POSITIONAL_OR_KEYWORD, - param.POSITIONAL_ONLY, - None, - ): - # PEP-3102: Separator for Keyword Only Parameter: * - params += addnodes.desc_parameter( - "", "", addnodes.desc_sig_operator("", "*") - ) - - node = addnodes.desc_parameter() - if param.kind == param.VAR_POSITIONAL: - node += addnodes.desc_sig_operator("", "*") - node += addnodes.desc_sig_name("", param.name) - elif param.kind == param.VAR_KEYWORD: - node += addnodes.desc_sig_operator("", "**") - node += addnodes.desc_sig_name("", param.name) - else: - node += addnodes.desc_sig_name("", param.name) - - if param.annotation is not param.empty: - children = _parse_annotation(param.annotation, env) - node += addnodes.desc_sig_punctuation("", ":") - node += addnodes.desc_sig_space() - node += addnodes.desc_sig_name("", "", *children) # type: ignore - if param.default is not param.empty: - if param.annotation is not param.empty: - node += addnodes.desc_sig_space() - node += addnodes.desc_sig_operator("", "=") - node += addnodes.desc_sig_space() - else: - node += addnodes.desc_sig_operator("", "=") - node += nodes.inline( - "", param.default, classes=["default_value"], support_smartquotes=False - ) - - params += node - last_kind = param.kind - - if last_kind == Parameter.POSITIONAL_ONLY: - # PEP-570: Separator for Positional Only Parameter: / - params += addnodes.desc_parameter("", "", addnodes.desc_sig_operator("", "/")) - - return params - - -def _pseudo_parse_arglist(signode: desc_signature, arglist: str) -> None: - """ "Parse" a list of arguments separated by commas. - - Arguments can have "optional" annotations given by enclosing them in - brackets. Currently, this will split at any comma, even if it's inside a - string literal (e.g. default argument value). - """ - paramlist = addnodes.desc_parameterlist() - stack: list[Element] = [paramlist] - try: - for argument in arglist.split(","): - argument = argument.strip() - ends_open = ends_close = 0 - while argument.startswith("["): - stack.append(addnodes.desc_optional()) - stack[-2] += stack[-1] - argument = argument[1:].strip() - while argument.startswith("]"): - stack.pop() - argument = argument[1:].strip() - while argument.endswith("]") and not argument.endswith("[]"): - ends_close += 1 - argument = argument[:-1].strip() - while argument.endswith("["): - ends_open += 1 - argument = argument[:-1].strip() - if argument: - stack[-1] += addnodes.desc_parameter( - "", "", addnodes.desc_sig_name(argument, argument) - ) - while ends_open: - stack.append(addnodes.desc_optional()) - stack[-2] += stack[-1] - ends_open -= 1 - while ends_close: - stack.pop() - ends_close -= 1 - if len(stack) != 1: - raise IndexError - except IndexError: - # if there are too few or too many elements on the stack, just give up - # and treat the whole argument list as one argument, discarding the - # already partially populated paramlist node - paramlist = addnodes.desc_parameterlist() - paramlist += addnodes.desc_parameter(arglist, arglist) - signode += paramlist - else: - signode += paramlist +function_module = "spark" class SparkObject(ObjectDescription[Tuple[str, str]]): @@ -426,7 +100,7 @@ def handle_signature(self, sig: str, signode: desc_signature) -> tuple[str, str] * it is stripped from the displayed name if present * it is added to the full name (return value) if not present """ - m = spark_sig_re.match(sig) + m = function_sig_re.match(sig) if m is None: raise ValueError prefix, name, arglist, retann = m.groups() @@ -480,23 +154,23 @@ def handle_signature(self, sig: str, signode: desc_signature) -> tuple[str, str] signode += addnodes.desc_name(name, name) if arglist: try: - signode += _parse_arglist(arglist, self.env) + signode += parse_arglist(function_module, arglist, self.env) except SyntaxError: # fallback to parse arglist original parser. # it supports to represent optional arguments (ex. "func(foo [, bar])") - _pseudo_parse_arglist(signode, arglist) + pseudo_parse_arglist(signode, arglist) except NotImplementedError as exc: logger.warning( "could not parse arglist (%r): %s", arglist, exc, location=signode ) - _pseudo_parse_arglist(signode, arglist) + pseudo_parse_arglist(signode, arglist) else: if self.needs_arglist(): # for callables, add an empty parameter list signode += addnodes.desc_parameterlist() if retann: - children = _parse_annotation(retann, self.env) + children = parse_annotation(function_module, retann, self.env) signode += addnodes.desc_returns(retann, "", *children) anno = self.options.get("annotation") diff --git a/velox/docs/functions.rst b/velox/docs/functions.rst index 9c9baa9e849f..38dc1b24c9af 100644 --- a/velox/docs/functions.rst +++ b/velox/docs/functions.rst @@ -21,7 +21,12 @@ Presto Functions functions/presto/aggregate functions/presto/window functions/presto/hyperloglog + functions/presto/tdigest + functions/presto/qdigest + functions/presto/geospatial + functions/presto/ipaddress functions/presto/uuid + functions/presto/enum functions/presto/misc Here is a list of all scalar and aggregate Presto functions available in Velox. @@ -62,99 +67,143 @@ for :doc:`all ` and :doc:`most used :widths: auto :class: rows - ======================================== ======================================== ======================================== == ======================================== == ======================================== - Scalar Functions Aggregate Functions Window Functions - ============================================================================================================================ == ======================================== == ======================================== - :func:`$internal$split_to_map` :func:`format_datetime` :func:`plus` :func:`any_value` :func:`cume_dist` - :func:`abs` :func:`from_base` :func:`poisson_cdf` :func:`approx_distinct` :func:`dense_rank` - :func:`acos` :func:`from_base64` :func:`pow` :func:`approx_most_frequent` :func:`first_value` - :func:`all_keys_match` :func:`from_base64url` :func:`power` :func:`approx_percentile` :func:`lag` - :func:`all_match` :func:`from_big_endian_32` :func:`quarter` :func:`approx_set` :func:`last_value` - :func:`any_keys_match` :func:`from_big_endian_64` :func:`radians` :func:`arbitrary` :func:`lead` - :func:`any_match` :func:`from_hex` :func:`rand` :func:`array_agg` :func:`nth_value` - :func:`any_values_match` :func:`from_ieee754_32` :func:`random` :func:`avg` :func:`ntile` - :func:`array_average` :func:`from_ieee754_64` :func:`reduce` :func:`bitwise_and_agg` :func:`percent_rank` - :func:`array_constructor` :func:`from_iso8601_date` :func:`regexp_extract` :func:`bitwise_or_agg` :func:`rank` - :func:`array_cum_sum` :func:`from_iso8601_timestamp` :func:`regexp_extract_all` :func:`bitwise_xor_agg` :func:`row_number` - :func:`array_distinct` :func:`from_unixtime` :func:`regexp_like` :func:`bool_and` - :func:`array_duplicates` :func:`from_utf8` :func:`regexp_replace` :func:`bool_or` - :func:`array_except` :func:`gamma_cdf` :func:`regexp_split` :func:`checksum` - :func:`array_frequency` :func:`greatest` :func:`remove_nulls` :func:`classification_fall_out` - :func:`array_has_duplicates` :func:`gt` :func:`repeat` :func:`classification_miss_rate` - :func:`array_intersect` :func:`gte` :func:`replace` :func:`classification_precision` - :func:`array_join` :func:`hamming_distance` :func:`replace_first` :func:`classification_recall` - :func:`array_max` :func:`hmac_md5` :func:`reverse` :func:`classification_thresholds` - :func:`array_min` :func:`hmac_sha1` :func:`round` :func:`corr` - :func:`array_normalize` :func:`hmac_sha256` :func:`rpad` :func:`count` - :func:`array_position` :func:`hmac_sha512` :func:`rtrim` :func:`count_if` - :func:`array_remove` :func:`hour` :func:`second` :func:`covar_pop` - :func:`array_sort` in :func:`secure_rand` :func:`covar_samp` - :func:`array_sort_desc` :func:`infinity` :func:`secure_random` :func:`entropy` - :func:`array_sum` :func:`inverse_beta_cdf` :func:`sequence` :func:`every` - :func:`array_sum_propagate_element_null` :func:`inverse_cauchy_cdf` :func:`sha1` :func:`geometric_mean` - :func:`array_union` :func:`inverse_laplace_cdf` :func:`sha256` :func:`histogram` - :func:`arrays_overlap` :func:`inverse_normal_cdf` :func:`sha512` :func:`kurtosis` - :func:`asin` :func:`inverse_weibull_cdf` :func:`shuffle` :func:`map_agg` - :func:`at_timezone` :func:`ip_prefix` :func:`sign` :func:`map_union` - :func:`atan` :func:`is_finite` :func:`sin` :func:`map_union_sum` - :func:`atan2` :func:`is_infinite` :func:`slice` :func:`max` - :func:`beta_cdf` :func:`is_json_scalar` :func:`split` :func:`max_by` - :func:`between` :func:`is_nan` :func:`split_part` :func:`max_data_size_for_stats` - :func:`binomial_cdf` :func:`is_null` :func:`split_to_map` :func:`merge` - :func:`bit_count` :func:`json_array_contains` :func:`spooky_hash_v2_32` :func:`min` - :func:`bitwise_and` :func:`json_array_get` :func:`spooky_hash_v2_64` :func:`min_by` - :func:`bitwise_arithmetic_shift_right` :func:`json_array_length` :func:`sqrt` :func:`multimap_agg` - :func:`bitwise_left_shift` :func:`json_extract` :func:`starts_with` :func:`reduce_agg` - :func:`bitwise_logical_shift_right` :func:`json_extract_scalar` :func:`strpos` :func:`regr_avgx` - :func:`bitwise_not` :func:`json_format` :func:`strrpos` :func:`regr_avgy` - :func:`bitwise_or` :func:`json_parse` :func:`subscript` :func:`regr_count` - :func:`bitwise_right_shift` :func:`json_size` :func:`substr` :func:`regr_intercept` - :func:`bitwise_right_shift_arithmetic` :func:`laplace_cdf` :func:`tan` :func:`regr_r2` - :func:`bitwise_shift_left` :func:`last_day_of_month` :func:`tanh` :func:`regr_slope` - :func:`bitwise_xor` :func:`least` :func:`timezone_hour` :func:`regr_sxx` - :func:`cardinality` :func:`length` :func:`timezone_minute` :func:`regr_sxy` - :func:`cauchy_cdf` :func:`levenshtein_distance` :func:`to_base` :func:`regr_syy` - :func:`cbrt` :func:`like` :func:`to_base64` :func:`set_agg` - :func:`ceil` :func:`ln` :func:`to_base64url` :func:`set_union` - :func:`ceiling` :func:`log10` :func:`to_big_endian_32` :func:`skewness` - :func:`chi_squared_cdf` :func:`log2` :func:`to_big_endian_64` :func:`stddev` - :func:`chr` :func:`lower` :func:`to_hex` :func:`stddev_pop` - :func:`clamp` :func:`lpad` :func:`to_ieee754_32` :func:`stddev_samp` - :func:`codepoint` :func:`lt` :func:`to_ieee754_64` :func:`sum` - :func:`combinations` :func:`lte` :func:`to_iso8601` :func:`sum_data_size_for_stats` - :func:`concat` :func:`ltrim` :func:`to_milliseconds` :func:`var_pop` - :func:`contains` :func:`map` :func:`to_unixtime` :func:`var_samp` - :func:`cos` :func:`map_concat` :func:`to_utf8` :func:`variance` - :func:`cosh` :func:`map_entries` :func:`trail` - :func:`cosine_similarity` :func:`map_filter` :func:`transform` - :func:`crc32` :func:`map_from_entries` :func:`transform_keys` - :func:`current_date` :func:`map_key_exists` :func:`transform_values` - :func:`date` :func:`map_keys` :func:`trim` - :func:`date_add` :func:`map_normalize` :func:`trim_array` - :func:`date_diff` :func:`map_remove_null_values` :func:`truncate` - :func:`date_format` :func:`map_subset` :func:`typeof` - :func:`date_parse` :func:`map_top_n` :func:`upper` - :func:`date_trunc` :func:`map_top_n_keys` :func:`url_decode` - :func:`day` :func:`map_values` :func:`url_encode` - :func:`day_of_month` :func:`map_zip_with` :func:`url_extract_fragment` - :func:`day_of_week` :func:`md5` :func:`url_extract_host` - :func:`day_of_year` :func:`millisecond` :func:`url_extract_parameter` - :func:`degrees` :func:`minus` :func:`url_extract_path` - :func:`distinct_from` :func:`minute` :func:`url_extract_port` - :func:`divide` :func:`mod` :func:`url_extract_protocol` - :func:`dow` :func:`month` :func:`url_extract_query` - :func:`doy` :func:`multimap_from_entries` :func:`uuid` - :func:`e` :func:`multiply` :func:`week` - :func:`element_at` :func:`nan` :func:`week_of_year` - :func:`empty_approx_set` :func:`negate` :func:`weibull_cdf` - :func:`ends_with` :func:`neq` :func:`width_bucket` - :func:`eq` :func:`ngrams` :func:`wilson_interval_lower` - :func:`exp` :func:`no_keys_match` :func:`wilson_interval_upper` - :func:`f_cdf` :func:`no_values_match` :func:`word_stem` - :func:`fail` :func:`none_match` :func:`xxhash64` - :func:`filter` :func:`normal_cdf` :func:`year` - :func:`find_first` :func:`normalize` :func:`year_of_week` - :func:`find_first_index` not :func:`yow` - :func:`flatten` :func:`parse_datetime` :func:`zip` - :func:`floor` :func:`pi` :func:`zip_with` - ======================================== ======================================== ======================================== == ======================================== == ======================================== + ================================================= ================================================= ================================================= == ================================================= == ================================================= + Scalar Functions Aggregate Functions Window Functions + ======================================================================================================================================================= == ================================================= == ================================================= + :func:`$internal$json_string_to_array_cast` :func:`geometry_to_dissolved_bing_tiles` :func:`secure_rand` :func:`any_value` :func:`cume_dist` + :func:`$internal$json_string_to_map_cast` :func:`geometry_union` :func:`secure_random` :func:`approx_distinct` :func:`dense_rank` + :func:`$internal$json_string_to_row_cast` :func:`great_circle_distance` :func:`sequence` :func:`approx_most_frequent` :func:`first_value` + :func:`$internal$split_to_map` :func:`greatest` :func:`sha1` :func:`approx_percentile` :func:`lag` + :func:`abs` :func:`gt` :func:`sha256` :func:`approx_set` :func:`last_value` + :func:`acos` :func:`gte` :func:`sha512` :func:`arbitrary` :func:`lead` + :func:`all_keys_match` :func:`hamming_distance` :func:`shuffle` :func:`array_agg` :func:`nth_value` + :func:`all_match` :func:`hmac_md5` :func:`sign` :func:`avg` :func:`ntile` + :func:`any_keys_match` :func:`hmac_sha1` :func:`simplify_geometry` :func:`bitwise_and_agg` :func:`percent_rank` + :func:`any_match` :func:`hmac_sha256` :func:`sin` :func:`bitwise_or_agg` :func:`rank` + :func:`any_values_match` :func:`hmac_sha512` :func:`slice` :func:`bitwise_xor_agg` :func:`row_number` + :func:`array_average` :func:`hour` :func:`split` :func:`bool_and` + :func:`array_constructor` in :func:`split_part` :func:`bool_or` + :func:`array_cum_sum` :func:`infinity` :func:`split_to_map` :func:`checksum` + :func:`array_distinct` :func:`inverse_beta_cdf` :func:`split_to_multimap` :func:`classification_fall_out` + :func:`array_duplicates` :func:`inverse_binomial_cdf` :func:`spooky_hash_v2_32` :func:`classification_miss_rate` + :func:`array_except` :func:`inverse_cauchy_cdf` :func:`spooky_hash_v2_64` :func:`classification_precision` + :func:`array_frequency` :func:`inverse_chi_squared_cdf` :func:`sqrt` :func:`classification_recall` + :func:`array_has_duplicates` :func:`inverse_f_cdf` :func:`st_area` :func:`classification_thresholds` + :func:`array_intersect` :func:`inverse_gamma_cdf` :func:`st_asbinary` :func:`corr` + :func:`array_join` :func:`inverse_laplace_cdf` :func:`st_astext` :func:`count` + :func:`array_max` :func:`inverse_normal_cdf` :func:`st_boundary` :func:`count_if` + :func:`array_max_by` :func:`inverse_poisson_cdf` :func:`st_buffer` :func:`covar_pop` + :func:`array_min` :func:`inverse_t_cdf` :func:`st_centroid` :func:`covar_samp` + :func:`array_min_by` :func:`inverse_weibull_cdf` :func:`st_contains` :func:`entropy` + :func:`array_normalize` :func:`ip_prefix` :func:`st_convexhull` :func:`every` + :func:`array_position` :func:`ip_prefix_collapse` :func:`st_coorddim` :func:`geometric_mean` + :func:`array_remove` :func:`ip_prefix_subnets` :func:`st_crosses` :func:`histogram` + :func:`array_sort` :func:`ip_subnet_max` :func:`st_difference` :func:`kurtosis` + :func:`array_sort_desc` :func:`ip_subnet_min` :func:`st_dimension` :func:`map_agg` + :func:`array_subset` :func:`ip_subnet_range` :func:`st_disjoint` :func:`map_union` + :func:`array_sum` :func:`is_finite` :func:`st_distance` :func:`map_union_sum` + :func:`array_sum_propagate_element_null` :func:`is_infinite` :func:`st_endpoint` :func:`max` + :func:`array_top_n` :func:`is_json_scalar` :func:`st_envelope` :func:`max_by` + :func:`array_union` :func:`is_nan` :func:`st_envelopeaspts` :func:`max_data_size_for_stats` + :func:`arrays_overlap` :func:`is_null` :func:`st_equals` :func:`merge` + :func:`asin` :func:`is_private_ip` :func:`st_exteriorring` :func:`min` + :func:`at_timezone` :func:`is_subnet_of` :func:`st_geometries` :func:`min_by` + :func:`atan` :func:`json_array_contains` :func:`st_geometryfromtext` :func:`multimap_agg` + :func:`atan2` :func:`json_array_get` :func:`st_geometryn` :func:`noisy_approx_distinct_sfm` + :func:`beta_cdf` :func:`json_array_length` :func:`st_geometrytype` :func:`noisy_approx_set_sfm` + :func:`between` :func:`json_extract` :func:`st_geomfrombinary` :func:`noisy_approx_set_sfm_from_index_and_zeros` + :func:`bing_tile` :func:`json_extract_scalar` :func:`st_interiorringn` :func:`noisy_avg_gaussian` + :func:`bing_tile_at` :func:`json_format` :func:`st_interiorrings` :func:`noisy_count_gaussian` + :func:`bing_tile_children` :func:`json_parse` :func:`st_intersection` :func:`noisy_count_if_gaussian` + :func:`bing_tile_coordinates` :func:`json_size` :func:`st_intersects` :func:`noisy_sum_gaussian` + :func:`bing_tile_parent` :func:`laplace_cdf` :func:`st_isclosed` :func:`numeric_histogram` + :func:`bing_tile_polygon` :func:`last_day_of_month` :func:`st_isempty` :func:`qdigest_agg` + :func:`bing_tile_quadkey` :func:`least` :func:`st_isring` :func:`reduce_agg` + :func:`bing_tile_zoom_level` :func:`length` :func:`st_issimple` :func:`regr_avgx` + :func:`bing_tiles_around` :func:`levenshtein_distance` :func:`st_isvalid` :func:`regr_avgy` + :func:`binomial_cdf` :func:`like` :func:`st_length` :func:`regr_count` + :func:`bit_count` :func:`line_interpolate_point` :func:`st_linefromtext` :func:`regr_intercept` + :func:`bit_length` :func:`line_locate_point` :func:`st_linestring` :func:`regr_r2` + :func:`bitwise_and` :func:`ln` :func:`st_multipoint` :func:`regr_slope` + :func:`bitwise_arithmetic_shift_right` :func:`localtime` :func:`st_numgeometries` :func:`regr_sxx` + :func:`bitwise_left_shift` :func:`log10` :func:`st_numinteriorring` :func:`regr_sxy` + :func:`bitwise_logical_shift_right` :func:`log2` :func:`st_numpoints` :func:`regr_syy` + :func:`bitwise_not` :func:`longest_common_prefix` :func:`st_overlaps` :func:`set_agg` + :func:`bitwise_or` :func:`lower` :func:`st_point` :func:`set_union` + :func:`bitwise_right_shift` :func:`lpad` :func:`st_pointn` :func:`skewness` + :func:`bitwise_right_shift_arithmetic` :func:`lt` :func:`st_points` :func:`stddev` + :func:`bitwise_shift_left` :func:`lte` :func:`st_polygon` :func:`stddev_pop` + :func:`bitwise_xor` :func:`ltrim` :func:`st_relate` :func:`stddev_samp` + :func:`cardinality` :func:`map` :func:`st_startpoint` :func:`sum` + :func:`cauchy_cdf` :func:`map_concat` :func:`st_symdifference` :func:`sum_data_size_for_stats` + :func:`cbrt` :func:`map_entries` :func:`st_touches` :func:`tdigest_agg` + :func:`ceil` :func:`map_filter` :func:`st_union` :func:`var_pop` + :func:`ceiling` :func:`map_from_entries` :func:`st_within` :func:`var_samp` + :func:`chi_squared_cdf` :func:`map_intersect` :func:`st_x` :func:`variance` + :func:`chr` :func:`map_key_exists` :func:`st_xmax` + :func:`clamp` :func:`map_keys` :func:`st_xmin` + :func:`codepoint` :func:`map_keys_by_top_n_values` :func:`st_y` + :func:`combinations` :func:`map_normalize` :func:`st_ymax` + :func:`combine_hash_internal` :func:`map_remove_null_values` :func:`st_ymin` + :func:`concat` :func:`map_subset` :func:`starts_with` + :func:`construct_tdigest` :func:`map_top_n` :func:`strpos` + :func:`contains` :func:`map_top_n_keys` :func:`strrpos` + :func:`cos` :func:`map_top_n_values` :func:`subscript` + :func:`cosh` :func:`map_values` :func:`substr` + :func:`cosine_similarity` :func:`map_zip_with` :func:`substring` + :func:`crc32` :func:`md5` :func:`t_cdf` + :func:`current_date` :func:`merge_hll` :func:`tan` + :func:`date` :func:`merge_sfm` :func:`tanh` + :func:`date_add` :func:`merge_tdigest` :func:`timezone_hour` + :func:`date_diff` :func:`millisecond` :func:`timezone_minute` + :func:`date_format` :func:`minus` :func:`to_base` + :func:`date_parse` :func:`minute` :func:`to_base64` + :func:`date_trunc` :func:`mod` :func:`to_base64url` + :func:`day` :func:`month` :func:`to_big_endian_32` + :func:`day_of_month` :func:`multimap_from_entries` :func:`to_big_endian_64` + :func:`day_of_week` :func:`multiply` :func:`to_hex` + :func:`day_of_year` :func:`murmur3_x64_128` :func:`to_ieee754_32` + :func:`degrees` :func:`nan` :func:`to_ieee754_64` + :func:`destructure_tdigest` :func:`negate` :func:`to_iso8601` + :func:`distinct_from` :func:`neq` :func:`to_milliseconds` + :func:`divide` :func:`ngrams` :func:`to_unixtime` + :func:`dot_product` :func:`no_keys_match` :func:`to_utf8` + :func:`dow` :func:`no_values_match` :func:`trail` + :func:`doy` :func:`noisy_empty_approx_set_sfm` :func:`transform` + :func:`e` :func:`none_match` :func:`transform_keys` + :func:`element_at` :func:`normal_cdf` :func:`transform_values` + :func:`empty_approx_set` :func:`normalize` :func:`trim` + :func:`ends_with` not :func:`trim_array` + :func:`enum_key` :func:`parse_datetime` :func:`trimmed_mean` + :func:`eq` :func:`parse_duration` :func:`truncate` + :func:`exp` :func:`parse_presto_data_size` :func:`typeof` + :func:`expand_envelope` :func:`pi` :func:`upper` + :func:`f_cdf` :func:`plus` :func:`url_decode` + :func:`fail` :func:`poisson_cdf` :func:`url_encode` + :func:`filter` :func:`pow` :func:`url_extract_fragment` + :func:`find_first` :func:`power` :func:`url_extract_host` + :func:`find_first_index` :func:`quantile_at_value` :func:`url_extract_parameter` + :func:`flatten` :func:`quantiles_at_values` :func:`url_extract_path` + :func:`flatten_geometry_collections` :func:`quarter` :func:`url_extract_port` + :func:`floor` :func:`radians` :func:`url_extract_protocol` + :func:`format_datetime` :func:`rand` :func:`url_extract_query` + :func:`from_base` :func:`random` :func:`uuid` + :func:`from_base32` :func:`reduce` :func:`value_at_quantile` + :func:`from_base64` :func:`regexp_extract` :func:`values_at_quantiles` + :func:`from_base64url` :func:`regexp_extract_all` :func:`week` + :func:`from_big_endian_32` :func:`regexp_like` :func:`week_of_year` + :func:`from_big_endian_64` :func:`regexp_replace` :func:`weibull_cdf` + :func:`from_hex` :func:`regexp_split` :func:`width_bucket` + :func:`from_ieee754_32` :func:`remap_keys` :func:`wilson_interval_lower` + :func:`from_ieee754_64` :func:`remove_nulls` :func:`wilson_interval_upper` + :func:`from_iso8601_date` :func:`repeat` :func:`word_stem` + :func:`from_iso8601_timestamp` :func:`replace` :func:`xxhash64` + :func:`from_unixtime` :func:`replace_first` :func:`xxhash64_internal` + :func:`from_utf8` :func:`reverse` :func:`year` + :func:`gamma_cdf` :func:`round` :func:`year_of_week` + :func:`geometry_as_geojson` :func:`rpad` :func:`yow` + :func:`geometry_from_geojson` :func:`rtrim` :func:`zip` + :func:`geometry_invalid_reason` :func:`scale_qdigest` :func:`zip_with` + :func:`geometry_nearest_points` :func:`scale_tdigest` + :func:`geometry_to_bing_tiles` :func:`second` + ================================================= ================================================= ================================================= == ================================================= == ================================================= diff --git a/velox/docs/functions/delta/functions.rst b/velox/docs/functions/delta/functions.rst new file mode 100644 index 000000000000..e1b3de5673fb --- /dev/null +++ b/velox/docs/functions/delta/functions.rst @@ -0,0 +1,13 @@ +******************** +Delta Lake Functions +******************** + +Here is a list of all scalar Delta Lake functions available in Velox. +Function names link to function description. + +These functions are used in deletion vector read. +Refer to `Delta Lake documentation `_ and `Delta Lake deletion vector blog `_ for details. + +.. delta:function:: bitmap_array_contains(bitmap_array: varbinary, input: bigint) -> bool + + Not implemented. diff --git a/velox/docs/functions/iceberg/functions.rst b/velox/docs/functions/iceberg/functions.rst new file mode 100644 index 000000000000..a878f8c40ea8 --- /dev/null +++ b/velox/docs/functions/iceberg/functions.rst @@ -0,0 +1,79 @@ +***************** +Iceberg Functions +***************** + +Here is a list of all scalar Iceberg functions available in Velox. +Function names link to function description. + +These functions are used in partition transform. +Refer to `Iceberg documenation `_ for details. + +.. iceberg:function:: bucket(numBuckets, input) -> integer + + Returns an integer between 0 and numBuckets - 1, indicating the assigned bucket. + Bucket partitioning is based on a 32-bit hash of the input, specifically using the x86 + variant of the Murmur3 hash function with a seed of 0. + + The function can be expressed in pseudo-code as below. :: + + def bucket(numBuckets, input)= (murmur3_x86_32_hash(input) & Integer.MAX_VALUE) % numBuckets + + The ``numBuckets`` is of type INTEGER and must be greater than 0. Otherwise, an exception is thrown. + Supported types for ``input`` are INTEGER, BIGINT, DECIMAL, DATE, TIMESTAMP, VARCHAR, VARBINARY. :: + SELECT bucket(128, 'abcd'); -- 4 + SELECT bucket(100, 34L); -- 79 + +.. iceberg:function:: days(input) -> date + + Returns the date. :: + + SELECT days(DATE '2017-12-01'); -- 2017-12-01 + SELECT days(TIMESTAMP '2017-12-01 10:12:55.038194'); -- 2017-12-01 + SELECT days(DATE '1969-12-31'); -- 1969-12-31 + +.. iceberg:function:: hours(input) -> integer + + Returns the number of hours since epoch (1970-01-01 00:00:00). Returns 0 for '1970-01-01 00:00:00' timestamps. + Returns negative value for timestamps before '1970-01-01 00:00:00'. :: + + SELECT hours(TIMESTAMP '2017-12-01 10:12:55.038194'); -- 420034 + SELECT hours(TIMESTAMP '1969-12-31 23:59:58.999999'); -- -1 + +.. iceberg:function:: months(input) -> integer + + Returns the number of months since epoch (1970-01-01). Returns 0 for '1970-01-01' date and timestamps. + Returns negative value for dates and timestamps before '1970-01-01'. :: + + SELECT months(DATE '2017-12-01'); -- 575 + SELECT months(TIMESTAMP '2017-12-01 10:12:55.038194'); -- 575 + SELECT months(DATE '1960-01-01'); -- -120 + +.. iceberg:function:: truncate(width, input) -> same type as input + + Returns the truncated value of the input based on the specified width. + For numeric values, truncate to the nearest lower multiple of ``width``, the truncate function is: input - (((input % width) + width) % width). + The ``width`` is used to truncate decimal values is applied using unscaled value to avoid additional (and potentially conflicting) parameters. + For string values, it truncates a valid UTF-8 string with no more than ``width`` code points. + In contrast to strings, binary values do not have an assumed encoding and are truncated to ``width`` bytes. + + Argument ``width`` must be a positive integer. + Supported types for ``input`` are: SHORTINT, TYNYINT, SMALLINT, INTEGER, BIGINT, DECIMAL, VARCHAR, VARBINARY. :: + + SELECT truncate(10, 11); -- 10 + SELECT truncate(10, -11); -- -20 + SELECT truncate(7, 22); -- 21 + SELECT truncate(0, 11); -- error: Reason: (0 vs. 0) Invalid truncate width\nExpression: width <= 0 + SELECT truncate(-3, 11); -- error: Reason: (-3 vs. 0) Invalid truncate width\nExpression: width <= 0 + SELECT truncate(4, 'iceberg'); -- 'iceb' + SELECT truncate(1, '测试'); -- 测 + SELECT truncate(6, '测试'); -- 测试 + SELECT truncate(6, cast('测试' as binary)); -- 测试_ + +.. iceberg:function:: years(input) -> integer + + Returns the number of years since epoch (1970-01-01). Returns 0 for '1970-01-01' date and timestamps. + Returns negative value for dates and timestamps before '1970-01-01'. :: + + SELECT years(DATE '2017-12-01'); -- 47 + SELECT years(TIMESTAMP '2017-12-01 10:12:55.038194'); -- 47 + SELECT years(DATE '1960-01-01'); -- -10 diff --git a/velox/docs/functions/presto/aggregate.rst b/velox/docs/functions/presto/aggregate.rst index 89ce72189634..8212675c328d 100644 --- a/velox/docs/functions/presto/aggregate.rst +++ b/velox/docs/functions/presto/aggregate.rst @@ -31,11 +31,20 @@ General Aggregate Functions inputs if :doc:`presto.array_agg.ignore_nulls <../../configs>` is set to false. -.. function:: avg(x) -> double|real +.. function:: avg(x) -> double|real|decimal Returns the average (arithmetic mean) of all non-null input values. When x is of type REAL, the result type is REAL. - For all other input types, the result type is DOUBLE. + When x is an integer or a DOUBLE, the result is DOUBLE. + When x is of type DECIMAL(p, s), the result type is DECIMAL(p, s). + Note: For the overflow cases, Velox returns a result when Presto throws "Decimal overflow". :: + SELECT AVG(col) + FROM ( VALUES + (CAST(9999999999999999999999999999999.9999999 AS DECIMAL(38,7))), + (CAST(9999999999999999999999999999999.9999999 AS DECIMAL(38,7))) + ) AS t(col); + -- Velox: 9999999999999999999999999999999.9999999 + -- Presto: Decimal overflow .. function:: bool_and(boolean) -> boolean @@ -709,20 +718,223 @@ Statistical Aggregate Functions Noisy Aggregate Functions ------------------------- -.. function:: noisy_count_if_gaussian(col, noise_scale) -> bigint +Overview +~~~~~~~~ + +Noisy aggregate functions provide random, noisy approximations of common +aggregations like ``sum()``, ``count()``, and ``approx_distinct()`` as well as sketches like +``approx_set()``. By injecting random noise into results, noisy aggregation functions make it +more difficult to determine or confirm the exact data that was aggregated. + +While many of these functions resemble `differential privacy `_ +mechanisms, neither the values returned by these functions nor the query results that incorporate +these functions are differentially private in general. See Limitations_ below for more details. +Users who wish to support a strong privacy guarantee should discuss with a suitable technical +expert first. + +Counts, Sums, and Averages +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. function:: noisy_count_if_gaussian(col, noise_scale[, random_seed]) -> bigint Counts the ``TRUE`` values in ``col`` and then adds a normally distributed random double value with 0 mean and standard deviation of ``noise_scale`` to the true count. The noisy count is post-processed to be non-negative and rounded to bigint. + If provided, ``random_seed`` is used to seed the random number generator. + Otherwise, noise is drawn from a secure random. (*Note: ``random_seed`` is a constant and shared across all groups in a query. + It is kept in each accmulator to ensure the ``random_seed`` is accessible in the final aggregation step.*) + :: - SELECT noisy_count_if_gaussian(orderkey > 10000, 20.0) FROM tpch.tiny.lineitem; -- 50180 (1 row) - SELECT noisy_count_if_gaussian(orderkey > 10000, 20.0) FROM tpch.tiny.lineitem WHERE false; -- NULL (1 row) + SELECT noisy_count_if_gaussian(orderkey > 10000, 20.0) FROM lineitem; -- 50180 (1 row) + SELECT noisy_count_if_gaussian(orderkey > 10000, 20.0) FROM lineitem WHERE false; -- NULL (1 row) .. note:: - Unlike :func:`!count_if`, this function returns ``NULL`` when the (true) count is 0. + Unlike :func:`count_if`, this function returns ``NULL`` when the (true) count is 0. + +.. function:: noisy_count_gaussian(col, noise_scale[, random_seed]) -> bigint + + Counts the non-null values in ``col`` and then adds a normally distributed random double + value with 0 mean and standard deviation of ``noise_scale`` to the true count. + The noisy count is post-processed to be non-negative and rounded to bigint. + + If provided, ``random_seed`` is used to seed the random number generator. + Otherwise, noise is drawn from a secure random. + + :: + + SELECT noisy_count_gaussian(orderkey, 20.0) FROM lineitem; -- 60181 (1 row) + SELECT noisy_count_gaussian(orderkey, 20.0) FROM lineitem WHERE false; -- NULL (1 row) + + .. note:: + + Unlike :func:`!count`, this function returns ``NULL`` when the (true) count is 0. + + Distinct counting can be performed using :func:`noisy_count_gaussian` ``(DISTINCT col, ...)``, or with ``noisy_approx_distinct_sfm()``. + Generally speaking, :func:`noisy_count_gaussian` returns more accurate results but at a larger computational cost. + +.. function:: noisy_sum_gaussian(col, noise_scale[, random_seed]) -> double + + Calculates the sum over the input values in ``col`` and then adds a normally distributed + random double value with 0 mean and standard deviation of ``noise_scale``. + + If provided, ``random_seed`` is used to seed the random number generator. + Otherwise, noise is drawn from a secure random. + +.. function:: noisy_sum_gaussian(col, noise_scale, lower, upper[, random_seed]) -> double + + Calculates the sum over the input values in ``col`` and then adds a normally distributed + random double value with 0 mean and standard deviation of ``noise_scale``. + Each value is clipped to the range of [``lower``, ``upper``] before adding to the sum. + + If provided, ``random_seed`` is used to seed the random number generator. + Otherwise, noise is drawn from a secure random. + +.. function:: noisy_avg_gaussian(col, noise_scale[, random_seed]) -> double + + Calculates the average (arithmetic mean) of all the input values in col and then adds a + normally distributed random double value with 0 mean and standard deviation of noise_scale. + + If provided, ``random_seed`` is used to seed the random number generator. + Otherwise, noise is drawn from a secure random. + +.. function:: noisy_avg_gaussian(col, noise_scale, lower, upper[, random_seed]) -> double + + Calculates the average (arithmetic mean) of all the input values in ``col`` and then adds a + normally distributed random double value with 0 mean and standard deviation of ``noise_scale``. + Each value is clipped to the range of [``lower``, ``upper``] before averaging. + + If provided, ``random_seed`` is used to seed the random number generator. + Otherwise, noise is drawn from a secure random. + +.. function:: noisy_approx_set_sfm(col, epsilon[, buckets[, precision]]) -> SfmSketch + + Returns an SFM sketch of the input values in ``col``. This is analogous to the ``approx_set()`` function, + which returns a (deterministic) HyperLogLog sketch. + + * ``col`` currently supports types: "bigint", "double", "string", "varbinary". + * ``epsilon`` (double) is a positive number that controls the level of noise in the sketch, as described in [Hehir2023]_. + Smaller values of epsilon correspond to noisier sketches. + * ``buckets`` (int) defaults to 4096. + * ``precision`` (int) defaults to 24. + + .. note:: + + Unlike ``approx_set()``, this function returns ``NULL`` when ``col`` is empty. + If this behavior is undesirable, use ``coalesce()`` with :func:`noisy_empty_approx_set_sfm`. + +.. function:: noisy_approx_distinct_sfm(col, epsilon[, buckets[, precision]]) -> bigint + + Equivalent to ``cardinality(noisy_approx_set_sfm(col, epsilon, buckets, precision))``, + this returns the approximate cardinality (distinct count) of the column col. + This is analogous to the (deterministic) :func:`approx_distinct` function. + + .. note:: + + Unlike :func:`approx_distinct`, this function returns ``NULL`` when ``col`` is empty. + +.. function:: noisy_empty_approx_set_sfm(epsilon[, buckets[, precision]]) -> SfmSketch + + Returns an SFM sketch with no items in it. This is analogous to the ``empty_approx_set()`` function, + which returns an empty (deterministic) ``HyperLogLog`` sketch. + + * ``epsilon`` (double) is a positive number that controls the level of noise in the sketch, as described in [Hehir2023]_. Smaller values of epsilon correspond to noisier sketches. + * ``buckets`` (int) defaults to 4096. + * ``precision`` (int) defaults to 24. + +.. function:: noisy_approx_set_sfm_from_index_and_zeros(col_index, col_zeros, epsilon, buckets[, precision]) -> SfmSketch + + Returns an SFM sketch of the input values in ``col_index`` and ``col_zeros``. + + This is similar to :func:`noisy_approx_set_sfm` except that function calculates a ``xxhash64()`` of ``col``, + and calculates the SFM PCSA bucket index and number of trailing zeros as described in + [FlajoletMartin1985]_. In this function, the caller must explicitly calculate the hash bucket index + and zeros themselves and pass them as arguments ``col_index`` and ``col_zeros``. + + - ``col_index`` (bigint) must be in the range ``0..buckets-1``. + - ``col_zeros`` (bigint) must be in the range ``0..64``. If it exceeds ``precision``, it + is cropped to ``precision-1``. + - ``epsilon`` (double) is a positive number that controls the level of noise in + the sketch, as described in [Hehir2023]_. Smaller values of epsilon correspond + to noisier sketches. + - ``buckets`` (int) is the number of buckets in the SFM PCSA sketch as described in [Hehir2023]_. + - ``precision`` (int) defaults to 24. + + .. note:: + + Like :func:`noisy_approx_set_sfm`, this function returns ``NULL`` when ``col_index`` + or ``col_zeros`` is ``NULL``. + If this behavior is undesirable, use :func:`!coalesce` with :func:`noisy_empty_approx_set_sfm`. + +.. function:: cardinality(SfmSketch) -> bigint + + Returns the estimated cardinality (distinct count) of an ``SfmSketch`` object. + +.. function:: merge(SfmSketch) -> SfmSketch + + An aggregator function that returns a merged ``SfmSketch`` of the set union of + individual ``SfmSketch`` objects, similar to ``merge(HyperLogLog)``. + + :: + + SELECT year, cardinality(merge(sketch)) AS annual_distinct_count + FROM monthly_sketches + GROUP BY 1 + +.. function:: merge_sfm(ARRAY[SfmSketch, ...]) -> SfmSketch + + A scalar function that returns a merged ``SfmSketch`` of the set union of an array of ``SfmSketch`` objects, similar to ``merge_hll()``. + + :: + + SELECT cardinality(merge_sfm(ARRAY[ + noisy_approx_set_sfm(col_1, 5.0), + noisy_approx_set_sfm(col_2, 5.0), + noisy_approx_set_sfm(col_3, 5.0) + ])) AS distinct_count_over_3_cols + FROM my_table + +Limitations +~~~~~~~~~~~ + +While these functions resemble differential privacy mechanisms, the values returned by these +functions are not differentially private in general. There are several important limitations +to keep in mind if using these functions for privacy-preserving purposes, including: + +* All noisy aggregate functions return ``NULL`` when aggregating empty sets. This means a ``NULL`` + return value noiselessly indicates the absence of data. + +* ``GROUP BY`` clauses used in combination with noisy aggregation functions reveal non-noisy + information: the presence or absence of a group noiselessly indicates the presence or + absence of data. See, e.g., [Wilkins2024]_. + +* Functions relying on floating-point noise may be susceptible to inference attacks such as + those identified in [Mironov2012]_ and [Casacuberta2022]_. + +References +~~~~~~~~~~ + +.. [Casacuberta2022] Casacuberta, S., Shoemate, M., Vadhan, S., & Wagaman, C. (2022). + `Widespread Underestimation of Sensitivity in Differentially Private Libraries and How to Fix It `_. + In Proceedings of the 2022 ACM SIGSAC Conference on Computer and Communications Security (pp. 471-484). + +.. [Hehir2023] Hehir, J., Ting, D., & Cormode, G. (2023). + `Sketch-Flip-Merge: Mergeable Sketches for Private Distinct Counting `_. + In Proceedings of the 40th International Conference on Machine Learning (Vol. 202). + +.. [Mironov2012] Mironov, I. (2012). + `On significance of the least significant bits for differential privacy `_. + In Proceedings of the 2012 ACM Conference on Computer and Communications Security (pp. 650-661). + +.. [Wilkins2024] Wilkins, A., Kifer, D., Zhang, D., & Karrer, B. (2024). + `Exact Privacy Analysis of the Gaussian Sparse Histogram Mechanism `_. + Journal of Privacy and Confidentiality, 14 (1). + +.. [FlajoletMartin1985] Flajolet, P, Martin, G. N. (1985). + `Probabilistic Counting Algorithms for Data Base Applications `_. + In Journal of Computer and System Sciences, 31:182-209, 1985 Miscellaneous ------------- diff --git a/velox/docs/functions/presto/array.rst b/velox/docs/functions/presto/array.rst index ffc691c44458..bcb355ed6f56 100644 --- a/velox/docs/functions/presto/array.rst +++ b/velox/docs/functions/presto/array.rst @@ -34,8 +34,11 @@ Array Functions Returns the average of all non-null elements of the array. If there are no non-null elements, returns null. .. function:: array_cum_sum(array(T)) -> array(T) + Returns the array whose elements are the cumulative sum of the input array, i.e. result[i] = input[1] + input[2] + - … + input[i]. If there there is null elements in the array, the cumulative sum at and after the element is null. :: + … + input[i]. If there there is null elements in the array, the cumulative sum at and after the element is null. + The following types are supported: int8_t, int16_t, int32_t, int64_t, int128_t, float, double, ShortDecimal, + and LongDecimal. :: SELECT array_cum_sum(ARRAY [1, 2, 3]) -- array[1, 3, 6] SELECT array_cum_sum(ARRAY [1, 2, null, 3]) -- array[1, 3, null, null] @@ -129,6 +132,13 @@ Array Functions SELECT array_min(ARRAY[{-1, -2, -3, nan()]); -- -1 SELECT array_min(ARRAY[{infinity(), nan()]); -- Infinity +.. function:: array_max_by(array(T), function(T, U)) -> T() + + Applies the provided function to each element, and returns the element that gives the maximum value. + ``U`` can be any orderable type. :: + + SELECT array_max_by(ARRAY ['a', 'bbb', 'cc'], x -> LENGTH(x)) -- 'bbb' + .. function:: array_normalize(array(E), E) -> array(E) Normalizes array ``x`` by dividing each element by the p-norm of the array. It is equivalent to ``TRANSFORM(array, v -> v / REDUCE(array, 0, (a, v) -> a + POW(ABS(v), p), a -> POW(a, 1 / p))``, but the reduce part is only executed once. Returns null if the array is null or there are null array elements. If ``p`` is 0, then the input array is returned. Only REAL and DOUBLE types are supported. @@ -207,7 +217,7 @@ Array Functions SELECT array_sort(ARRAY [ARRAY [1, 2], ARRAY [1, null]]); -- failed: Ordering nulls is not supported .. function:: array_sort_desc(array(T), function(T,U)) -> array(T) - :noindex: + :noindex: Returns the array sorted by values computed using specified lambda in descending order. U must be an orderable type. Null elements will be placed at the end of @@ -215,7 +225,20 @@ Array Functions nested nulls. Throws if deciding the order of elements would require comparing nested null values. :: - SELECT array_sort_desc(ARRAY ['cat', 'leopard', 'mouse'], x -> length(x)); -- ['leopard', 'mouse', 'cat'] + SELECT array_sort_desc(ARRAY ['cat', 'leopard', 'mouse'], x -> length(x)); -- ['leopard', 'mouse', 'cat'] + +.. function:: array_subset(array(T), array(int)) -> array(T) + + Returns an array containing elements from the input array at the specified 1-based indices. + Indices must be positive integers. Invalid indices (out of bounds, zero, or negative) are ignored. + Null elements at valid indices are preserved in the output. Duplicate indices result in duplicate elements in the output. + The output maintains the order of the indices array. :: + + SELECT array_subset(ARRAY[1, 2, 3, 4, 5], ARRAY[1, 3, 5]); -- [1, 3, 5] + SELECT array_subset(ARRAY['a', 'b', 'c'], ARRAY[3, 1, 2]); -- ['c', 'a', 'b'] + SELECT array_subset(ARRAY[1, NULL, 3], ARRAY[2]); -- [NULL] + SELECT array_subset(ARRAY[1, 2, 3], ARRAY[1, 1, 2]); -- [1, 1, 2] + SELECT array_subset(ARRAY[1, 2, 3], ARRAY[5, 0, -1]); -- [] .. function:: array_sum(array(T)) -> bigint/double diff --git a/velox/docs/functions/presto/binary.rst b/velox/docs/functions/presto/binary.rst index 1e9ee252b135..315ba71be2c3 100644 --- a/velox/docs/functions/presto/binary.rst +++ b/velox/docs/functions/presto/binary.rst @@ -150,3 +150,7 @@ Binary Functions .. function:: xxhash64(binary) -> varbinary Computes the xxhash64 hash of ``binary``. + +.. function:: xxhash64(binary, bigint) -> varbinary + + Computes the xxhash64 hash of ``binary`` with ``bigint`` seed. diff --git a/velox/docs/functions/presto/conversion.rst b/velox/docs/functions/presto/conversion.rst index b04dd316918d..d33e6c492d4c 100644 --- a/velox/docs/functions/presto/conversion.rst +++ b/velox/docs/functions/presto/conversion.rst @@ -30,7 +30,7 @@ are supported if the conversion of their element types are supported. In additio supported conversions to/from JSON are listed in :doc:`json`. .. list-table:: - :widths: 25 25 25 25 25 25 25 25 25 25 25 25 25 25 25 25 25 + :widths: 25 25 25 25 25 25 25 25 25 25 25 25 25 25 25 25 25 25 25 :header-rows: 1 * - @@ -50,6 +50,8 @@ supported conversions to/from JSON are listed in :doc:`json`. - decimal - ipaddress - ipprefix + - tdigest + - qdigest * - tinyint - Y - Y @@ -67,6 +69,8 @@ supported conversions to/from JSON are listed in :doc:`json`. - Y - - + - + - * - smallint - Y - Y @@ -84,6 +88,8 @@ supported conversions to/from JSON are listed in :doc:`json`. - Y - - + - + - * - integer - Y - Y @@ -101,6 +107,8 @@ supported conversions to/from JSON are listed in :doc:`json`. - Y - - + - + - * - bigint - Y - Y @@ -118,6 +126,8 @@ supported conversions to/from JSON are listed in :doc:`json`. - Y - - + - + - * - boolean - Y - Y @@ -135,6 +145,8 @@ supported conversions to/from JSON are listed in :doc:`json`. - Y - - + - + - * - real - Y - Y @@ -152,6 +164,8 @@ supported conversions to/from JSON are listed in :doc:`json`. - Y - - + - + - * - double - Y - Y @@ -169,6 +183,8 @@ supported conversions to/from JSON are listed in :doc:`json`. - Y - - + - + - * - varchar - Y - Y @@ -186,6 +202,8 @@ supported conversions to/from JSON are listed in :doc:`json`. - Y - Y - Y + - + - * - varbinary - - @@ -202,7 +220,9 @@ supported conversions to/from JSON are listed in :doc:`json`. - - - Y - - + - Y + - Y + - Y * - timestamp - - @@ -220,6 +240,8 @@ supported conversions to/from JSON are listed in :doc:`json`. - - - + - + - * - timestamp with time zone - - @@ -237,6 +259,8 @@ supported conversions to/from JSON are listed in :doc:`json`. - - - + - + - * - date - - @@ -254,6 +278,8 @@ supported conversions to/from JSON are listed in :doc:`json`. - - - + - + - * - interval day to second - - @@ -271,6 +297,8 @@ supported conversions to/from JSON are listed in :doc:`json`. - - - + - + - * - decimal - Y - Y @@ -288,6 +316,8 @@ supported conversions to/from JSON are listed in :doc:`json`. - Y - - + - + - * - ipaddress - - @@ -305,6 +335,8 @@ supported conversions to/from JSON are listed in :doc:`json`. - - - Y + - + - * - ipprefix - - @@ -322,6 +354,47 @@ supported conversions to/from JSON are listed in :doc:`json`. - - Y - Y + - + - + * - tdigest + - + - + - + - + - + - + - + - + - Y + - + - + - + - + - + - + - + - + - + * - qdigest + - + - + - + - + - + - + - + - + - Y + - + - + - + - + - + - + - + - + - + Cast to Integral Types ---------------------- @@ -570,7 +643,7 @@ Invalid example SELECT cast(decimal '300.001' as tinyint); -- Out of range Cast to VARCHAR --------------- +--------------- Casting from scalar types to string is allowed. @@ -777,6 +850,26 @@ IPV4 mapped IPV6: SELECT cast('::ffff:ffff:ffff' as ipaddress); -- 0x00000000000000000000ffffffffffff +From TDIGEST(DOUBLE) +^^^^^^^^^^^^^^^^^^^^ + +Returns the T-digest as a varbinary string containing the serialized representation of the T-digest data structure. +This allows T-digests to be stored and retrieved for later use. + +:: + + SELECT cast(tdigest_agg(cast(1.0 as double)) as varbinary); -- AQAAAAAAAADwPwAAAAAAAPA/AAAAAAAA8D8AAAAAAABZQAAAAAAAAPA/AQAAAAAAAAAAAPA/AAAAAAAA8D8= + +From QDIGEST(BIGINT), QDIGEST(REAL), QDIGEST(DOUBLE) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Returns the quantile digest as a varbinary string containing the serialized representation of the quantile digest data structure. +This allows quantile digests to be stored and retrieved for later use. + +:: + + SELECT cast(qdigest_agg(cast(1.0 as double)) as varbinary); -- AHsUrkfheoQ/AAAAAAAAAAAAAAAAAAAAAAAAAAAAAPA/AAAAAAAA8D8BAAAAAAAAAAAAAPA/AAAAAAAA8L8= + Cast to TIMESTAMP ----------------- @@ -1203,6 +1296,34 @@ Examples: SELECT cast(ipprefix '1.2.3.4/24' as ipaddress) -- ipaddress '1.2.3.0' SELECT cast(ipprefix '2001:db8::ff00:42:8329/64' as ipaddress) -- ipaddress '2001:db8::' +Cast to TDIGEST(DOUBLE) +----------------------- + +From VARBINARY +^^^^^^^^^^^^^^ + +Returns a T-digest reconstructed from the varbinary string containing the serialized representation. +This allows previously stored T-digests to be restored for use. + +:: + + SELECT cast(stored_tdigest_binary as tdigest(double)); + +Cast to QDIGEST(BIGINT), QDIGEST(REAL), QDIGEST(DOUBLE) +------------------------------------------------------- + +From VARBINARY +^^^^^^^^^^^^^^ + +Returns a quantile digest reconstructed from the varbinary string containing the serialized representation. +This allows previously stored quantile digests to be restored for use. + +:: + + SELECT cast(stored_qdigest_binary as qdigest(bigint)); + SELECT cast(stored_qdigest_binary as qdigest(real)); + SELECT cast(stored_qdigest_binary as qdigest(double)); + Cast to IPPREFIX ---------------- diff --git a/velox/docs/functions/presto/coverage.rst b/velox/docs/functions/presto/coverage.rst index ab1c906785bc..c6698306a2fa 100644 --- a/velox/docs/functions/presto/coverage.rst +++ b/velox/docs/functions/presto/coverage.rst @@ -13,97 +13,127 @@ Here is a list of all scalar, aggregate, and window functions from Presto, with table.coverage td:nth-child(8) {background-color: lightblue;} table.coverage tr:nth-child(1) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(1) td:nth-child(2) {background-color: #6BA81E;} + table.coverage tr:nth-child(1) td:nth-child(3) {background-color: #6BA81E;} table.coverage tr:nth-child(1) td:nth-child(4) {background-color: #6BA81E;} + table.coverage tr:nth-child(1) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(1) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(1) td:nth-child(9) {background-color: #6BA81E;} table.coverage tr:nth-child(2) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(2) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(2) td:nth-child(3) {background-color: #6BA81E;} table.coverage tr:nth-child(2) td:nth-child(4) {background-color: #6BA81E;} + table.coverage tr:nth-child(2) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(2) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(2) td:nth-child(9) {background-color: #6BA81E;} table.coverage tr:nth-child(3) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(3) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(3) td:nth-child(3) {background-color: #6BA81E;} table.coverage tr:nth-child(3) td:nth-child(4) {background-color: #6BA81E;} + table.coverage tr:nth-child(3) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(3) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(3) td:nth-child(9) {background-color: #6BA81E;} table.coverage tr:nth-child(4) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(4) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(4) td:nth-child(3) {background-color: #6BA81E;} table.coverage tr:nth-child(4) td:nth-child(4) {background-color: #6BA81E;} + table.coverage tr:nth-child(4) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(4) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(4) td:nth-child(9) {background-color: #6BA81E;} table.coverage tr:nth-child(5) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(5) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(5) td:nth-child(3) {background-color: #6BA81E;} table.coverage tr:nth-child(5) td:nth-child(4) {background-color: #6BA81E;} + table.coverage tr:nth-child(5) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(5) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(5) td:nth-child(9) {background-color: #6BA81E;} table.coverage tr:nth-child(6) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(6) td:nth-child(2) {background-color: #6BA81E;} + table.coverage tr:nth-child(6) td:nth-child(3) {background-color: #6BA81E;} table.coverage tr:nth-child(6) td:nth-child(4) {background-color: #6BA81E;} + table.coverage tr:nth-child(6) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(6) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(6) td:nth-child(9) {background-color: #6BA81E;} table.coverage tr:nth-child(7) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(7) td:nth-child(2) {background-color: #6BA81E;} + table.coverage tr:nth-child(7) td:nth-child(3) {background-color: #6BA81E;} table.coverage tr:nth-child(7) td:nth-child(4) {background-color: #6BA81E;} + table.coverage tr:nth-child(7) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(7) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(7) td:nth-child(9) {background-color: #6BA81E;} table.coverage tr:nth-child(8) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(8) td:nth-child(2) {background-color: #6BA81E;} + table.coverage tr:nth-child(8) td:nth-child(3) {background-color: #6BA81E;} + table.coverage tr:nth-child(8) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(8) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(8) td:nth-child(9) {background-color: #6BA81E;} table.coverage tr:nth-child(9) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(9) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(9) td:nth-child(3) {background-color: #6BA81E;} + table.coverage tr:nth-child(9) td:nth-child(4) {background-color: #6BA81E;} + table.coverage tr:nth-child(9) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(9) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(9) td:nth-child(9) {background-color: #6BA81E;} table.coverage tr:nth-child(10) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(10) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(10) td:nth-child(3) {background-color: #6BA81E;} table.coverage tr:nth-child(10) td:nth-child(4) {background-color: #6BA81E;} + table.coverage tr:nth-child(10) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(10) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(10) td:nth-child(9) {background-color: #6BA81E;} table.coverage tr:nth-child(11) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(11) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(11) td:nth-child(3) {background-color: #6BA81E;} + table.coverage tr:nth-child(11) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(11) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(11) td:nth-child(9) {background-color: #6BA81E;} table.coverage tr:nth-child(12) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(12) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(12) td:nth-child(3) {background-color: #6BA81E;} table.coverage tr:nth-child(12) td:nth-child(4) {background-color: #6BA81E;} + table.coverage tr:nth-child(12) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(12) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(13) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(13) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(13) td:nth-child(3) {background-color: #6BA81E;} table.coverage tr:nth-child(13) td:nth-child(4) {background-color: #6BA81E;} + table.coverage tr:nth-child(13) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(13) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(14) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(14) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(14) td:nth-child(3) {background-color: #6BA81E;} table.coverage tr:nth-child(14) td:nth-child(4) {background-color: #6BA81E;} + table.coverage tr:nth-child(14) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(14) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(15) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(15) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(15) td:nth-child(3) {background-color: #6BA81E;} table.coverage tr:nth-child(15) td:nth-child(4) {background-color: #6BA81E;} + table.coverage tr:nth-child(15) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(15) td:nth-child(7) {background-color: #6BA81E;} + table.coverage tr:nth-child(16) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(16) td:nth-child(3) {background-color: #6BA81E;} + table.coverage tr:nth-child(16) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(16) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(17) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(17) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(17) td:nth-child(4) {background-color: #6BA81E;} + table.coverage tr:nth-child(17) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(17) td:nth-child(7) {background-color: #6BA81E;} - table.coverage tr:nth-child(18) td:nth-child(3) {background-color: #6BA81E;} + table.coverage tr:nth-child(18) td:nth-child(1) {background-color: #6BA81E;} + table.coverage tr:nth-child(18) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(18) td:nth-child(4) {background-color: #6BA81E;} + table.coverage tr:nth-child(18) td:nth-child(5) {background-color: #6BA81E;} + table.coverage tr:nth-child(18) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(19) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(19) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(19) td:nth-child(3) {background-color: #6BA81E;} table.coverage tr:nth-child(19) td:nth-child(4) {background-color: #6BA81E;} + table.coverage tr:nth-child(19) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(19) td:nth-child(7) {background-color: #6BA81E;} + table.coverage tr:nth-child(20) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(20) td:nth-child(3) {background-color: #6BA81E;} + table.coverage tr:nth-child(20) td:nth-child(4) {background-color: #6BA81E;} + table.coverage tr:nth-child(20) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(20) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(21) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(21) td:nth-child(2) {background-color: #6BA81E;} @@ -119,40 +149,48 @@ Here is a list of all scalar, aggregate, and window functions from Presto, with table.coverage tr:nth-child(22) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(23) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(23) td:nth-child(2) {background-color: #6BA81E;} + table.coverage tr:nth-child(23) td:nth-child(3) {background-color: #6BA81E;} table.coverage tr:nth-child(23) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(23) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(23) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(24) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(24) td:nth-child(2) {background-color: #6BA81E;} + table.coverage tr:nth-child(24) td:nth-child(3) {background-color: #6BA81E;} table.coverage tr:nth-child(24) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(24) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(25) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(25) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(25) td:nth-child(3) {background-color: #6BA81E;} table.coverage tr:nth-child(25) td:nth-child(4) {background-color: #6BA81E;} - table.coverage tr:nth-child(25) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(25) td:nth-child(7) {background-color: #6BA81E;} + table.coverage tr:nth-child(26) td:nth-child(2) {background-color: #6BA81E;} + table.coverage tr:nth-child(26) td:nth-child(3) {background-color: #6BA81E;} table.coverage tr:nth-child(26) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(26) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(27) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(27) td:nth-child(2) {background-color: #6BA81E;} + table.coverage tr:nth-child(27) td:nth-child(3) {background-color: #6BA81E;} table.coverage tr:nth-child(27) td:nth-child(4) {background-color: #6BA81E;} + table.coverage tr:nth-child(27) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(27) td:nth-child(7) {background-color: #6BA81E;} - table.coverage tr:nth-child(28) td:nth-child(3) {background-color: #6BA81E;} + table.coverage tr:nth-child(28) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(28) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(28) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(28) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(29) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(29) td:nth-child(3) {background-color: #6BA81E;} table.coverage tr:nth-child(29) td:nth-child(4) {background-color: #6BA81E;} - table.coverage tr:nth-child(29) td:nth-child(5) {background-color: #6BA81E;} + table.coverage tr:nth-child(29) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(30) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(30) td:nth-child(3) {background-color: #6BA81E;} + table.coverage tr:nth-child(30) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(30) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(30) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(31) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(31) td:nth-child(3) {background-color: #6BA81E;} table.coverage tr:nth-child(31) td:nth-child(4) {background-color: #6BA81E;} + table.coverage tr:nth-child(31) td:nth-child(5) {background-color: #6BA81E;} + table.coverage tr:nth-child(31) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(32) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(32) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(32) td:nth-child(3) {background-color: #6BA81E;} @@ -162,6 +200,7 @@ Here is a list of all scalar, aggregate, and window functions from Presto, with table.coverage tr:nth-child(33) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(33) td:nth-child(3) {background-color: #6BA81E;} table.coverage tr:nth-child(33) td:nth-child(5) {background-color: #6BA81E;} + table.coverage tr:nth-child(34) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(34) td:nth-child(3) {background-color: #6BA81E;} table.coverage tr:nth-child(34) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(34) td:nth-child(5) {background-color: #6BA81E;} @@ -169,166 +208,262 @@ Here is a list of all scalar, aggregate, and window functions from Presto, with table.coverage tr:nth-child(35) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(35) td:nth-child(3) {background-color: #6BA81E;} table.coverage tr:nth-child(35) td:nth-child(5) {background-color: #6BA81E;} + table.coverage tr:nth-child(36) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(36) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(36) td:nth-child(3) {background-color: #6BA81E;} table.coverage tr:nth-child(36) td:nth-child(4) {background-color: #6BA81E;} + table.coverage tr:nth-child(36) td:nth-child(5) {background-color: #6BA81E;} + table.coverage tr:nth-child(37) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(37) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(37) td:nth-child(3) {background-color: #6BA81E;} table.coverage tr:nth-child(37) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(37) td:nth-child(5) {background-color: #6BA81E;} + table.coverage tr:nth-child(37) td:nth-child(7) {background-color: #6BA81E;} + table.coverage tr:nth-child(38) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(38) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(38) td:nth-child(3) {background-color: #6BA81E;} table.coverage tr:nth-child(38) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(38) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(38) td:nth-child(7) {background-color: #6BA81E;} + table.coverage tr:nth-child(39) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(39) td:nth-child(2) {background-color: #6BA81E;} + table.coverage tr:nth-child(39) td:nth-child(3) {background-color: #6BA81E;} + table.coverage tr:nth-child(39) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(39) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(39) td:nth-child(7) {background-color: #6BA81E;} + table.coverage tr:nth-child(40) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(40) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(40) td:nth-child(3) {background-color: #6BA81E;} table.coverage tr:nth-child(40) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(40) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(40) td:nth-child(7) {background-color: #6BA81E;} + table.coverage tr:nth-child(41) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(41) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(41) td:nth-child(3) {background-color: #6BA81E;} table.coverage tr:nth-child(41) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(41) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(41) td:nth-child(7) {background-color: #6BA81E;} + table.coverage tr:nth-child(42) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(42) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(42) td:nth-child(3) {background-color: #6BA81E;} table.coverage tr:nth-child(42) td:nth-child(4) {background-color: #6BA81E;} + table.coverage tr:nth-child(42) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(42) td:nth-child(7) {background-color: #6BA81E;} + table.coverage tr:nth-child(43) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(43) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(43) td:nth-child(3) {background-color: #6BA81E;} + table.coverage tr:nth-child(43) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(43) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(43) td:nth-child(7) {background-color: #6BA81E;} + table.coverage tr:nth-child(44) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(44) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(44) td:nth-child(3) {background-color: #6BA81E;} + table.coverage tr:nth-child(44) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(44) td:nth-child(5) {background-color: #6BA81E;} + table.coverage tr:nth-child(44) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(45) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(45) td:nth-child(2) {background-color: #6BA81E;} + table.coverage tr:nth-child(45) td:nth-child(3) {background-color: #6BA81E;} + table.coverage tr:nth-child(45) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(45) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(45) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(46) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(46) td:nth-child(2) {background-color: #6BA81E;} + table.coverage tr:nth-child(46) td:nth-child(3) {background-color: #6BA81E;} + table.coverage tr:nth-child(46) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(46) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(46) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(47) td:nth-child(1) {background-color: #6BA81E;} - table.coverage tr:nth-child(47) td:nth-child(3) {background-color: #6BA81E;} + table.coverage tr:nth-child(47) td:nth-child(2) {background-color: #6BA81E;} + table.coverage tr:nth-child(47) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(47) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(47) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(48) td:nth-child(1) {background-color: #6BA81E;} + table.coverage tr:nth-child(48) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(48) td:nth-child(3) {background-color: #6BA81E;} + table.coverage tr:nth-child(48) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(48) td:nth-child(5) {background-color: #6BA81E;} + table.coverage tr:nth-child(48) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(49) td:nth-child(1) {background-color: #6BA81E;} + table.coverage tr:nth-child(49) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(49) td:nth-child(3) {background-color: #6BA81E;} + table.coverage tr:nth-child(49) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(49) td:nth-child(5) {background-color: #6BA81E;} + table.coverage tr:nth-child(49) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(50) td:nth-child(1) {background-color: #6BA81E;} + table.coverage tr:nth-child(50) td:nth-child(2) {background-color: #6BA81E;} + table.coverage tr:nth-child(50) td:nth-child(3) {background-color: #6BA81E;} + table.coverage tr:nth-child(50) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(50) td:nth-child(5) {background-color: #6BA81E;} + table.coverage tr:nth-child(50) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(51) td:nth-child(1) {background-color: #6BA81E;} + table.coverage tr:nth-child(51) td:nth-child(2) {background-color: #6BA81E;} + table.coverage tr:nth-child(51) td:nth-child(3) {background-color: #6BA81E;} + table.coverage tr:nth-child(51) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(51) td:nth-child(5) {background-color: #6BA81E;} + table.coverage tr:nth-child(51) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(52) td:nth-child(1) {background-color: #6BA81E;} + table.coverage tr:nth-child(52) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(52) td:nth-child(3) {background-color: #6BA81E;} + table.coverage tr:nth-child(52) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(52) td:nth-child(5) {background-color: #6BA81E;} + table.coverage tr:nth-child(52) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(53) td:nth-child(1) {background-color: #6BA81E;} + table.coverage tr:nth-child(53) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(53) td:nth-child(3) {background-color: #6BA81E;} + table.coverage tr:nth-child(53) td:nth-child(4) {background-color: #6BA81E;} + table.coverage tr:nth-child(53) td:nth-child(5) {background-color: #6BA81E;} + table.coverage tr:nth-child(53) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(54) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(54) td:nth-child(3) {background-color: #6BA81E;} + table.coverage tr:nth-child(54) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(54) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(54) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(55) td:nth-child(1) {background-color: #6BA81E;} - table.coverage tr:nth-child(55) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(55) td:nth-child(3) {background-color: #6BA81E;} + table.coverage tr:nth-child(55) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(55) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(55) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(56) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(56) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(56) td:nth-child(3) {background-color: #6BA81E;} + table.coverage tr:nth-child(56) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(56) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(56) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(57) td:nth-child(1) {background-color: #6BA81E;} + table.coverage tr:nth-child(57) td:nth-child(2) {background-color: #6BA81E;} + table.coverage tr:nth-child(57) td:nth-child(3) {background-color: #6BA81E;} + table.coverage tr:nth-child(57) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(57) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(57) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(58) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(58) td:nth-child(2) {background-color: #6BA81E;} + table.coverage tr:nth-child(58) td:nth-child(3) {background-color: #6BA81E;} + table.coverage tr:nth-child(58) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(58) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(58) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(59) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(59) td:nth-child(2) {background-color: #6BA81E;} + table.coverage tr:nth-child(59) td:nth-child(3) {background-color: #6BA81E;} + table.coverage tr:nth-child(59) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(59) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(59) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(60) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(60) td:nth-child(2) {background-color: #6BA81E;} - table.coverage tr:nth-child(60) td:nth-child(3) {background-color: #6BA81E;} + table.coverage tr:nth-child(60) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(60) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(60) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(61) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(61) td:nth-child(2) {background-color: #6BA81E;} - table.coverage tr:nth-child(61) td:nth-child(3) {background-color: #6BA81E;} + table.coverage tr:nth-child(61) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(61) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(61) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(62) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(62) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(62) td:nth-child(3) {background-color: #6BA81E;} + table.coverage tr:nth-child(62) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(62) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(62) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(63) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(63) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(63) td:nth-child(3) {background-color: #6BA81E;} + table.coverage tr:nth-child(63) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(63) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(63) td:nth-child(7) {background-color: #6BA81E;} + table.coverage tr:nth-child(64) td:nth-child(1) {background-color: #6BA81E;} + table.coverage tr:nth-child(64) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(64) td:nth-child(3) {background-color: #6BA81E;} + table.coverage tr:nth-child(64) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(64) td:nth-child(5) {background-color: #6BA81E;} - table.coverage tr:nth-child(65) td:nth-child(1) {background-color: #6BA81E;} + table.coverage tr:nth-child(64) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(65) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(65) td:nth-child(3) {background-color: #6BA81E;} + table.coverage tr:nth-child(65) td:nth-child(4) {background-color: #6BA81E;} + table.coverage tr:nth-child(65) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(65) td:nth-child(7) {background-color: #6BA81E;} + table.coverage tr:nth-child(66) td:nth-child(1) {background-color: #6BA81E;} + table.coverage tr:nth-child(66) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(66) td:nth-child(3) {background-color: #6BA81E;} + table.coverage tr:nth-child(66) td:nth-child(4) {background-color: #6BA81E;} + table.coverage tr:nth-child(66) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(66) td:nth-child(7) {background-color: #6BA81E;} - table.coverage tr:nth-child(67) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(67) td:nth-child(2) {background-color: #6BA81E;} + table.coverage tr:nth-child(67) td:nth-child(3) {background-color: #6BA81E;} + table.coverage tr:nth-child(67) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(67) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(68) td:nth-child(1) {background-color: #6BA81E;} + table.coverage tr:nth-child(68) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(68) td:nth-child(3) {background-color: #6BA81E;} + table.coverage tr:nth-child(68) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(68) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(69) td:nth-child(1) {background-color: #6BA81E;} + table.coverage tr:nth-child(69) td:nth-child(2) {background-color: #6BA81E;} + table.coverage tr:nth-child(69) td:nth-child(3) {background-color: #6BA81E;} + table.coverage tr:nth-child(69) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(69) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(69) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(70) td:nth-child(1) {background-color: #6BA81E;} + table.coverage tr:nth-child(70) td:nth-child(2) {background-color: #6BA81E;} + table.coverage tr:nth-child(70) td:nth-child(3) {background-color: #6BA81E;} + table.coverage tr:nth-child(70) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(70) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(71) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(71) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(71) td:nth-child(3) {background-color: #6BA81E;} + table.coverage tr:nth-child(71) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(71) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(71) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(72) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(72) td:nth-child(2) {background-color: #6BA81E;} + table.coverage tr:nth-child(72) td:nth-child(3) {background-color: #6BA81E;} + table.coverage tr:nth-child(72) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(72) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(72) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(73) td:nth-child(1) {background-color: #6BA81E;} + table.coverage tr:nth-child(73) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(73) td:nth-child(3) {background-color: #6BA81E;} + table.coverage tr:nth-child(73) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(73) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(73) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(74) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(74) td:nth-child(2) {background-color: #6BA81E;} - table.coverage tr:nth-child(74) td:nth-child(3) {background-color: #6BA81E;} + table.coverage tr:nth-child(74) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(74) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(74) td:nth-child(7) {background-color: #6BA81E;} + table.coverage tr:nth-child(75) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(75) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(75) td:nth-child(3) {background-color: #6BA81E;} + table.coverage tr:nth-child(75) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(75) td:nth-child(5) {background-color: #6BA81E;} + table.coverage tr:nth-child(75) td:nth-child(7) {background-color: #6BA81E;} + table.coverage tr:nth-child(76) td:nth-child(2) {background-color: #6BA81E;} + table.coverage tr:nth-child(76) td:nth-child(3) {background-color: #6BA81E;} + table.coverage tr:nth-child(76) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(76) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(76) td:nth-child(7) {background-color: #6BA81E;} + table.coverage tr:nth-child(77) td:nth-child(1) {background-color: #6BA81E;} + table.coverage tr:nth-child(77) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(77) td:nth-child(3) {background-color: #6BA81E;} + table.coverage tr:nth-child(77) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(77) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(77) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(78) td:nth-child(1) {background-color: #6BA81E;} + table.coverage tr:nth-child(78) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(78) td:nth-child(3) {background-color: #6BA81E;} - table.coverage tr:nth-child(78) td:nth-child(5) {background-color: #6BA81E;} + table.coverage tr:nth-child(78) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(78) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(79) td:nth-child(1) {background-color: #6BA81E;} + table.coverage tr:nth-child(79) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(79) td:nth-child(3) {background-color: #6BA81E;} - table.coverage tr:nth-child(79) td:nth-child(5) {background-color: #6BA81E;} + table.coverage tr:nth-child(79) td:nth-child(4) {background-color: #6BA81E;} + table.coverage tr:nth-child(80) td:nth-child(1) {background-color: #6BA81E;} + table.coverage tr:nth-child(80) td:nth-child(2) {background-color: #6BA81E;} + table.coverage tr:nth-child(80) td:nth-child(3) {background-color: #6BA81E;} + table.coverage tr:nth-child(80) td:nth-child(4) {background-color: #6BA81E;} + table.coverage tr:nth-child(81) td:nth-child(1) {background-color: #6BA81E;} + table.coverage tr:nth-child(81) td:nth-child(2) {background-color: #6BA81E;} + table.coverage tr:nth-child(81) td:nth-child(3) {background-color: #6BA81E;} + table.coverage tr:nth-child(81) td:nth-child(4) {background-color: #6BA81E;} .. table:: @@ -338,83 +473,85 @@ Here is a list of all scalar, aggregate, and window functions from Presto, with ======================================== ======================================== ======================================== ======================================== ======================================== == ======================================== == ======================================== Scalar Functions Aggregate Functions Window Functions ================================================================================================================================================================================================================ == ======================================== == ======================================== - :func:`abs` :func:`date_diff` ip_subnet_range :func:`random` st_numgeometries :func:`approx_distinct` :func:`cume_dist` - :func:`acos` :func:`date_format` :func:`is_finite` :func:`reduce` st_numinteriorring :func:`approx_most_frequent` :func:`dense_rank` - :func:`all_match` :func:`date_parse` :func:`is_infinite` :func:`regexp_extract` st_numpoints :func:`approx_percentile` :func:`first_value` - :func:`any_keys_match` :func:`date_trunc` :func:`is_json_scalar` :func:`regexp_extract_all` st_overlaps :func:`approx_set` :func:`lag` - :func:`any_match` :func:`day` :func:`is_nan` :func:`regexp_like` st_point :func:`arbitrary` :func:`last_value` - :func:`any_values_match` :func:`day_of_month` is_private_ip :func:`regexp_replace` st_pointn :func:`array_agg` :func:`lead` - :func:`array_average` :func:`day_of_week` is_subnet_of :func:`regexp_split` st_points :func:`avg` :func:`nth_value` - :func:`array_cum_sum` :func:`day_of_year` jaccard_index regress st_polygon :func:`bitwise_and_agg` :func:`ntile` - :func:`array_distinct` :func:`degrees` :func:`json_array_contains` reidentification_potential st_relate :func:`bitwise_or_agg` :func:`percent_rank` - :func:`array_duplicates` :func:`dow` :func:`json_array_get` :func:`remove_nulls` st_startpoint :func:`bool_and` :func:`rank` - :func:`array_except` :func:`doy` :func:`json_array_length` render st_symdifference :func:`bool_or` :func:`row_number` - :func:`array_frequency` :func:`e` :func:`json_extract` :func:`repeat` st_touches :func:`checksum` - :func:`array_has_duplicates` :func:`element_at` :func:`json_extract_scalar` :func:`replace` st_union :func:`classification_fall_out` - :func:`array_intersect` :func:`empty_approx_set` :func:`json_format` :func:`replace_first` st_within :func:`classification_miss_rate` - :func:`array_join` :func:`ends_with` :func:`json_parse` :func:`reverse` st_x :func:`classification_precision` - array_least_frequent enum_key :func:`json_size` rgb st_xmax :func:`classification_recall` - :func:`array_max` :func:`exp` key_sampling_percent :func:`round` st_xmin :func:`classification_thresholds` - array_max_by expand_envelope :func:`laplace_cdf` :func:`rpad` st_y convex_hull_agg - :func:`array_min` :func:`f_cdf` :func:`last_day_of_month` :func:`rtrim` st_ymax :func:`corr` - array_min_by features :func:`least` scale_qdigest st_ymin :func:`count` - :func:`array_normalize` :func:`filter` :func:`length` :func:`second` :func:`starts_with` :func:`count_if` - :func:`array_position` :func:`filter` :func:`levenshtein_distance` :func:`secure_rand` :func:`strpos` :func:`covar_pop` - :func:`array_remove` :func:`find_first` line_interpolate_point :func:`secure_random` :func:`strrpos` :func:`covar_samp` - :func:`array_sort` :func:`find_first_index` line_locate_point :func:`sequence` :func:`substr` differential_entropy - :func:`array_sort_desc` :func:`flatten` :func:`ln` :func:`sha1` :func:`tan` :func:`entropy` - array_split_into_chunks flatten_geometry_collections localtime :func:`sha256` :func:`tanh` evaluate_classifier_predictions - :func:`array_sum` :func:`floor` localtimestamp :func:`sha512` tdigest_agg :func:`every` - array_top_n fnv1_32 :func:`log10` :func:`shuffle` :func:`timezone_hour` :func:`geometric_mean` - :func:`array_union` fnv1_64 :func:`log2` :func:`sign` :func:`timezone_minute` geometry_union_agg - :func:`arrays_overlap` fnv1a_32 :func:`lower` simplify_geometry :func:`to_base` :func:`histogram` - :func:`asin` fnv1a_64 :func:`lpad` :func:`sin` to_base32 khyperloglog_agg - :func:`atan` :func:`format_datetime` :func:`ltrim` sketch_kll_quantile :func:`to_base64` :func:`kurtosis` - :func:`atan2` :func:`from_base` :func:`map` sketch_kll_rank :func:`to_base64url` learn_classifier - bar from_base32 :func:`map_concat` :func:`slice` :func:`to_big_endian_32` learn_libsvm_classifier - :func:`beta_cdf` :func:`from_base64` :func:`map_entries` spatial_partitions :func:`to_big_endian_64` learn_libsvm_regressor - bing_tile :func:`from_base64url` :func:`map_filter` :func:`split` to_geometry learn_regressor - bing_tile_at :func:`from_big_endian_32` :func:`map_from_entries` :func:`split_part` :func:`to_hex` make_set_digest - bing_tile_children :func:`from_big_endian_64` :func:`map_keys` :func:`split_to_map` :func:`to_ieee754_32` :func:`map_agg` - bing_tile_coordinates :func:`from_hex` map_keys_by_top_n_values split_to_multimap :func:`to_ieee754_64` :func:`map_union` - bing_tile_parent :func:`from_ieee754_32` :func:`map_normalize` :func:`spooky_hash_v2_32` :func:`to_iso8601` :func:`map_union_sum` - bing_tile_polygon :func:`from_ieee754_64` :func:`map_remove_null_values` :func:`spooky_hash_v2_64` :func:`to_milliseconds` :func:`max` - bing_tile_quadkey :func:`from_iso8601_date` :func:`map_subset` :func:`sqrt` to_spherical_geography :func:`max_by` - bing_tile_zoom_level :func:`from_iso8601_timestamp` :func:`map_top_n` st_area :func:`to_unixtime` :func:`merge` - bing_tiles_around :func:`from_unixtime` :func:`map_top_n_keys` st_asbinary :func:`to_utf8` merge_set_digest - :func:`binomial_cdf` :func:`from_utf8` map_top_n_keys_by_value st_astext :func:`trail` :func:`min` - :func:`bit_count` :func:`gamma_cdf` map_top_n_values st_boundary :func:`transform` :func:`min_by` - :func:`bitwise_and` geometry_as_geojson :func:`map_values` st_buffer :func:`transform_keys` :func:`multimap_agg` - :func:`bitwise_arithmetic_shift_right` geometry_from_geojson :func:`map_zip_with` st_centroid :func:`transform_values` noisy_avg_gaussian - :func:`bitwise_left_shift` geometry_invalid_reason :func:`md5` st_contains :func:`trim` noisy_count_gaussian - :func:`bitwise_logical_shift_right` geometry_nearest_points merge_hll st_convexhull :func:`trim_array` noisy_count_if_gaussian - :func:`bitwise_not` geometry_to_bing_tiles merge_khll st_coorddim :func:`truncate` noisy_sum_gaussian - :func:`bitwise_or` geometry_to_dissolved_bing_tiles :func:`millisecond` st_crosses :func:`typeof` numeric_histogram - :func:`bitwise_right_shift` geometry_union :func:`minute` st_difference uniqueness_distribution qdigest_agg - :func:`bitwise_right_shift_arithmetic` great_circle_distance :func:`mod` st_dimension :func:`upper` :func:`reduce_agg` - :func:`bitwise_shift_left` :func:`greatest` :func:`month` st_disjoint :func:`url_decode` :func:`regr_avgx` - :func:`bitwise_xor` :func:`hamming_distance` :func:`multimap_from_entries` st_distance :func:`url_encode` :func:`regr_avgy` - :func:`cardinality` hash_counts murmur3_x64_128 st_endpoint :func:`url_extract_fragment` :func:`regr_count` - :func:`cauchy_cdf` :func:`hmac_md5` myanmar_font_encoding st_envelope :func:`url_extract_host` :func:`regr_intercept` - :func:`cbrt` :func:`hmac_sha1` myanmar_normalize_unicode st_envelopeaspts :func:`url_extract_parameter` :func:`regr_r2` - :func:`ceil` :func:`hmac_sha256` :func:`nan` st_equals :func:`url_extract_path` :func:`regr_slope` - :func:`ceiling` :func:`hmac_sha512` :func:`ngrams` st_exteriorring :func:`url_extract_port` :func:`regr_sxx` - :func:`chi_squared_cdf` :func:`hour` :func:`no_keys_match` st_geometries :func:`url_extract_protocol` :func:`regr_sxy` - :func:`chr` :func:`infinity` :func:`no_values_match` st_geometryfromtext :func:`url_extract_query` :func:`regr_syy` - classify intersection_cardinality :func:`none_match` st_geometryn :func:`uuid` reservoir_sample - :func:`codepoint` :func:`inverse_beta_cdf` :func:`normal_cdf` st_geometrytype value_at_quantile :func:`set_agg` - color inverse_binomial_cdf :func:`normalize` st_geomfrombinary values_at_quantiles :func:`set_union` - :func:`combinations` :func:`inverse_cauchy_cdf` now st_interiorringn :func:`week` sketch_kll - :func:`concat` inverse_chi_squared_cdf :func:`parse_datetime` st_interiorrings :func:`week_of_year` sketch_kll_with_k - :func:`contains` inverse_f_cdf parse_duration st_intersection :func:`weibull_cdf` :func:`skewness` - :func:`cos` inverse_gamma_cdf parse_presto_data_size st_intersects :func:`width_bucket` spatial_partitioning - :func:`cosh` :func:`inverse_laplace_cdf` :func:`pi` st_isclosed :func:`wilson_interval_lower` :func:`stddev` - :func:`cosine_similarity` :func:`inverse_normal_cdf` pinot_binary_decimal_to_double st_isempty :func:`wilson_interval_upper` :func:`stddev_pop` - :func:`crc32` inverse_poisson_cdf :func:`poisson_cdf` st_isring :func:`word_stem` :func:`stddev_samp` - :func:`current_date` :func:`inverse_weibull_cdf` :func:`pow` st_issimple :func:`xxhash64` :func:`sum` - current_time :func:`ip_prefix` :func:`power` st_isvalid :func:`year` tdigest_agg - current_timestamp ip_prefix_collapse quantile_at_value st_length :func:`year_of_week` :func:`var_pop` - current_timezone ip_prefix_subnets :func:`quarter` st_linefromtext :func:`yow` :func:`var_samp` - :func:`date` ip_subnet_max :func:`radians` st_linestring :func:`zip` :func:`variance` - :func:`date_add` ip_subnet_min :func:`rand` st_multipoint :func:`zip_with` + :func:`abs` :func:`date_format` :func:`ip_subnet_range` :func:`random` :func:`st_numpoints` :func:`approx_distinct` :func:`cume_dist` + :func:`acos` :func:`date_parse` :func:`is_finite` :func:`reduce` :func:`st_overlaps` :func:`approx_most_frequent` :func:`dense_rank` + :func:`all_match` :func:`date_trunc` :func:`is_infinite` :func:`regexp_extract` :func:`st_point` :func:`approx_percentile` :func:`first_value` + :func:`any_keys_match` :func:`day` :func:`is_json_scalar` :func:`regexp_extract_all` :func:`st_pointn` :func:`approx_set` :func:`lag` + :func:`any_match` :func:`day_of_month` :func:`is_nan` :func:`regexp_like` :func:`st_points` :func:`arbitrary` :func:`last_value` + :func:`any_values_match` :func:`day_of_week` :func:`is_private_ip` :func:`regexp_replace` :func:`st_polygon` :func:`array_agg` :func:`lead` + :func:`array_average` :func:`day_of_year` :func:`is_subnet_of` :func:`regexp_split` :func:`st_relate` :func:`avg` :func:`nth_value` + :func:`array_cum_sum` :func:`degrees` :func:`jaccard_index` regress :func:`st_startpoint` :func:`bitwise_and_agg` :func:`ntile` + :func:`array_distinct` :func:`dot_product` :func:`json_array_contains` :func:`reidentification_potential` :func:`st_symdifference` :func:`bitwise_or_agg` :func:`percent_rank` + :func:`array_duplicates` :func:`dow` :func:`json_array_get` :func:`remove_nulls` :func:`st_touches` :func:`bool_and` :func:`rank` + :func:`array_except` :func:`doy` :func:`json_array_length` render :func:`st_union` :func:`bool_or` :func:`row_number` + :func:`array_frequency` :func:`e` :func:`json_extract` :func:`repeat` :func:`st_within` :func:`checksum` + :func:`array_has_duplicates` :func:`element_at` :func:`json_extract_scalar` :func:`replace` :func:`st_x` :func:`classification_fall_out` + :func:`array_intersect` :func:`empty_approx_set` :func:`json_format` :func:`replace_first` :func:`st_xmax` :func:`classification_miss_rate` + :func:`array_join` :func:`ends_with` :func:`json_parse` :func:`reverse` :func:`st_xmin` :func:`classification_precision` + array_least_frequent :func:`enum_key` :func:`json_size` rgb :func:`st_y` :func:`classification_recall` + :func:`array_max` :func:`exp` key_sampling_percent :func:`round` :func:`st_ymax` :func:`classification_thresholds` + :func:`array_max_by` :func:`expand_envelope` l2_squared :func:`rpad` :func:`st_ymin` :func:`convex_hull_agg` + :func:`array_min` :func:`f_cdf` :func:`laplace_cdf` :func:`rtrim` :func:`starts_with` :func:`corr` + :func:`array_min_by` features :func:`last_day_of_month` :func:`scale_qdigest` :func:`strpos` :func:`count` + :func:`array_normalize` :func:`filter` :func:`least` :func:`second` :func:`strrpos` :func:`count_if` + :func:`array_position` :func:`filter` :func:`length` :func:`secure_rand` :func:`substr` :func:`covar_pop` + :func:`array_remove` :func:`find_first` :func:`levenshtein_distance` :func:`secure_random` :func:`tan` :func:`covar_samp` + :func:`array_sort` :func:`find_first_index` :func:`line_interpolate_point` :func:`sequence` :func:`tanh` differential_entropy + :func:`array_sort_desc` :func:`flatten` :func:`line_locate_point` :func:`sha1` tdigest_agg :func:`entropy` + array_split_into_chunks :func:`flatten_geometry_collections` :func:`ln` :func:`sha256` :func:`timezone_hour` evaluate_classifier_predictions + :func:`array_sum` :func:`floor` :func:`localtime` :func:`sha512` :func:`timezone_minute` :func:`every` + :func:`array_top_n` fnv1_32 localtimestamp :func:`shuffle` :func:`to_base` :func:`geometric_mean` + :func:`array_union` fnv1_64 :func:`log10` :func:`sign` to_base32 :func:`geometry_union_agg` + :func:`arrays_overlap` fnv1a_32 :func:`log2` :func:`simplify_geometry` :func:`to_base64` :func:`histogram` + :func:`asin` fnv1a_64 :func:`longest_common_prefix` :func:`sin` :func:`to_base64url` :func:`khyperloglog_agg` + :func:`atan` :func:`format_datetime` :func:`lower` sketch_kll_quantile :func:`to_big_endian_32` :func:`kurtosis` + :func:`atan2` :func:`from_base` :func:`lpad` sketch_kll_rank :func:`to_big_endian_64` learn_classifier + bar :func:`from_base32` :func:`ltrim` :func:`slice` :func:`to_geometry` learn_libsvm_classifier + :func:`beta_cdf` :func:`from_base64` :func:`map` spatial_partitions :func:`to_hex` learn_libsvm_regressor + :func:`bing_tile` :func:`from_base64url` :func:`map_concat` :func:`split` :func:`to_ieee754_32` learn_regressor + :func:`bing_tile_at` :func:`from_big_endian_32` :func:`map_entries` :func:`split_part` :func:`to_ieee754_64` :func:`make_set_digest` + :func:`bing_tile_children` :func:`from_big_endian_64` :func:`map_filter` :func:`split_to_map` :func:`to_iso8601` :func:`map_agg` + :func:`bing_tile_coordinates` :func:`from_hex` :func:`map_from_entries` :func:`split_to_multimap` :func:`to_milliseconds` :func:`map_union` + :func:`bing_tile_parent` :func:`from_ieee754_32` :func:`map_keys` :func:`spooky_hash_v2_32` :func:`to_spherical_geography` :func:`map_union_sum` + :func:`bing_tile_polygon` :func:`from_ieee754_64` :func:`map_keys_by_top_n_values` :func:`spooky_hash_v2_64` :func:`to_unixtime` :func:`max` + :func:`bing_tile_quadkey` :func:`from_iso8601_date` :func:`map_normalize` :func:`sqrt` :func:`to_utf8` :func:`max_by` + :func:`bing_tile_zoom_level` :func:`from_iso8601_timestamp` :func:`map_remove_null_values` :func:`st_area` :func:`trail` :func:`merge` + :func:`bing_tiles_around` :func:`from_unixtime` :func:`map_subset` :func:`st_asbinary` :func:`transform` :func:`merge_set_digest` + :func:`binomial_cdf` :func:`from_utf8` :func:`map_top_n` :func:`st_astext` :func:`transform_keys` :func:`min` + :func:`bit_count` :func:`gamma_cdf` :func:`map_top_n_keys` :func:`st_boundary` :func:`transform_values` :func:`min_by` + :func:`bit_length` :func:`geometry_as_geojson` map_top_n_keys_by_value :func:`st_buffer` :func:`trim` :func:`multimap_agg` + :func:`bitwise_and` :func:`geometry_from_geojson` :func:`map_top_n_values` :func:`st_centroid` :func:`trim_array` :func:`noisy_avg_gaussian` + :func:`bitwise_arithmetic_shift_right` :func:`geometry_invalid_reason` :func:`map_values` :func:`st_contains` :func:`truncate` :func:`noisy_count_gaussian` + :func:`bitwise_left_shift` :func:`geometry_nearest_points` :func:`map_zip_with` :func:`st_convexhull` :func:`typeof` :func:`noisy_count_if_gaussian` + :func:`bitwise_logical_shift_right` :func:`geometry_to_bing_tiles` :func:`md5` :func:`st_coorddim` :func:`uniqueness_distribution` :func:`noisy_sum_gaussian` + :func:`bitwise_not` :func:`geometry_to_dissolved_bing_tiles` :func:`merge_hll` :func:`st_crosses` :func:`upper` :func:`numeric_histogram` + :func:`bitwise_or` :func:`geometry_union` :func:`merge_khll` :func:`st_difference` :func:`url_decode` :func:`qdigest_agg` + :func:`bitwise_right_shift` google_polyline_decode :func:`millisecond` :func:`st_dimension` :func:`url_encode` :func:`reduce_agg` + :func:`bitwise_right_shift_arithmetic` google_polyline_encode :func:`minute` :func:`st_disjoint` :func:`url_extract_fragment` :func:`regr_avgx` + :func:`bitwise_shift_left` :func:`great_circle_distance` :func:`mod` :func:`st_distance` :func:`url_extract_host` :func:`regr_avgy` + :func:`bitwise_xor` :func:`greatest` :func:`month` :func:`st_endpoint` :func:`url_extract_parameter` :func:`regr_count` + :func:`cardinality` :func:`hamming_distance` :func:`multimap_from_entries` :func:`st_envelope` :func:`url_extract_path` :func:`regr_intercept` + :func:`cauchy_cdf` :func:`hash_counts` :func:`murmur3_x64_128` :func:`st_envelopeaspts` :func:`url_extract_port` :func:`regr_r2` + :func:`cbrt` :func:`hmac_md5` myanmar_font_encoding :func:`st_equals` :func:`url_extract_protocol` :func:`regr_slope` + :func:`ceil` :func:`hmac_sha1` myanmar_normalize_unicode :func:`st_exteriorring` :func:`url_extract_query` :func:`regr_sxx` + :func:`ceiling` :func:`hmac_sha256` :func:`nan` :func:`st_geometries` :func:`uuid` :func:`regr_sxy` + :func:`chi_squared_cdf` :func:`hmac_sha512` :func:`ngrams` :func:`st_geometryfromtext` :func:`value_at_quantile` :func:`regr_syy` + :func:`chr` :func:`hour` :func:`no_keys_match` :func:`st_geometryn` :func:`values_at_quantiles` :func:`reservoir_sample` + classify :func:`infinity` :func:`no_values_match` :func:`st_geometrytype` :func:`week` :func:`set_agg` + :func:`codepoint` :func:`intersection_cardinality` :func:`none_match` :func:`st_geomfrombinary` :func:`week_of_year` :func:`set_union` + color :func:`inverse_beta_cdf` :func:`normal_cdf` :func:`st_interiorringn` :func:`weibull_cdf` sketch_kll + :func:`combinations` :func:`inverse_binomial_cdf` :func:`normalize` :func:`st_interiorrings` :func:`width_bucket` sketch_kll_with_k + :func:`concat` :func:`inverse_cauchy_cdf` :func:`now` :func:`st_intersection` :func:`wilson_interval_lower` :func:`skewness` + :func:`contains` :func:`inverse_chi_squared_cdf` :func:`parse_datetime` :func:`st_intersects` :func:`wilson_interval_upper` spatial_partitioning + :func:`cos` :func:`inverse_f_cdf` :func:`parse_duration` :func:`st_isclosed` :func:`word_stem` :func:`stddev` + :func:`cosh` :func:`inverse_gamma_cdf` :func:`parse_presto_data_size` :func:`st_isempty` :func:`xxhash64` :func:`stddev_pop` + :func:`cosine_similarity` :func:`inverse_laplace_cdf` :func:`pi` :func:`st_isring` :func:`year` :func:`stddev_samp` + :func:`crc32` :func:`inverse_normal_cdf` pinot_binary_decimal_to_double :func:`st_issimple` :func:`year_of_week` :func:`sum` + :func:`current_date` :func:`inverse_poisson_cdf` :func:`poisson_cdf` :func:`st_isvalid` :func:`yow` :func:`tdigest_agg` + current_time :func:`inverse_weibull_cdf` :func:`pow` :func:`st_length` :func:`zip` :func:`var_pop` + :func:`current_timestamp` :func:`ip_prefix` :func:`power` :func:`st_linefromtext` :func:`zip_with` :func:`var_samp` + :func:`current_timezone` :func:`ip_prefix_collapse` :func:`quantile_at_value` :func:`st_linestring` :func:`variance` + :func:`date` :func:`ip_prefix_subnets` :func:`quarter` :func:`st_multipoint` + :func:`date_add` :func:`ip_subnet_max` :func:`radians` :func:`st_numgeometries` + :func:`date_diff` :func:`ip_subnet_min` :func:`rand` :func:`st_numinteriorring` ======================================== ======================================== ======================================== ======================================== ======================================== == ======================================== == ======================================== diff --git a/velox/docs/functions/presto/datetime.rst b/velox/docs/functions/presto/datetime.rst index 98e42dac552b..cb19b2f75da8 100644 --- a/velox/docs/functions/presto/datetime.rst +++ b/velox/docs/functions/presto/datetime.rst @@ -144,7 +144,10 @@ Date and Time Functions .. function:: from_unixtime(unixtime) -> timestamp - Returns the UNIX timestamp ``unixtime`` as a timestamp. + Returns the UNIX timestamp ``unixtime`` as a timestamp. If the + :doc:`adjust_timestamp_to_session_timezone <../../configs>` property is set + to true, then the timestamp is adjusted to the time zone specified in + :doc:`session_timezone <../../configs>`. .. function:: from_unixtime(unixtime, string) -> timestamp with time zone :noindex: @@ -177,6 +180,27 @@ Date and Time Functions Returns ``timestamp`` as a UNIX timestamp. +.. function:: current_timezone() -> varchar + + Returns the current session time zone as a varchar. + + Example:: + + SELECT current_timezone; -- Asia/Kolkata + +.. function:: current_timestamp() -> timestamp with time zone +.. function:: now() -> timestamp with time zone + + Returns the current timestamp with session time zone applied. + The timestamp is captured once at the start of query execution and remains + constant throughout the query. This matches the standard SQL behavior for + ``CURRENT_TIMESTAMP`` and ``NOW()``. + + Example:: + + SELECT current_timestamp; -- 2025-07-17 14:53:12.123 Asia/Kolkata + SELECT now(); -- 2025-07-17 14:53:12.123 Asia/Kolkata + Truncation Function ------------------- diff --git a/velox/docs/functions/presto/enum.rst b/velox/docs/functions/presto/enum.rst new file mode 100644 index 000000000000..8e1a11d82b87 --- /dev/null +++ b/velox/docs/functions/presto/enum.rst @@ -0,0 +1,11 @@ +============== +Enum Functions +============== + +.. function:: enum_key(x) -> varchar + + Returns the string key of the enum value, where ``x`` is either a BigintEnum or VarcharEnum value. :: + + SELECT enum_key(alpha.A); -- "A" + + where alpha is a BigintEnum type with name "alpha" and values {"A": 1, "B": 2} diff --git a/velox/docs/functions/presto/geospatial.rst b/velox/docs/functions/presto/geospatial.rst index 63a80da47c33..4ae9f461a333 100644 --- a/velox/docs/functions/presto/geospatial.rst +++ b/velox/docs/functions/presto/geospatial.rst @@ -43,9 +43,6 @@ the coordinate order is (x, y). The details of both WKT and WKB can be found `here `_. -.. _OpenGIS Specifications: https://www.ogc.org/standards/ogcapi-features/ -.. _SQL/MM Part 3: Spatial: https://www.iso.org/standard/31369.html - Geometry Constructors --------------------- @@ -72,6 +69,44 @@ Geometry Constructors Returns the Point geometry at the given coordinates. This will raise an error if ``x`` or ``y`` is ``NaN`` or ``infinity``. +.. function:: ST_Polygon(wkt: varchar) -> polygon: Geometry + + Returns a geometry type polygon object from WKT representation. + +.. function:: ST_LineFromText(wkt: varchar) -> linestring: Geometry + + Returns a geometry type linestring object from WKT representation. + An error is returned if the input WKT represents a valid non-LineString + geometry. Null input returns null output. + +.. function:: ST_LineString(points: array(Geometry)) -> linestring: Geometry + + Returns a LineString formed from an array of points. If there are fewer + than two non-empty points in the input array, an empty LineString will + be returned. Throws an exception if any element in the array is null or + empty or same as the previous one. The returned geometry may not be simple, + e.g. may self-intersect or may contain duplicate vertexes depending on the + input. + +.. function:: ST_MultiPoint(points: array(Geometry)) -> multipoint: Geometry + + Returns a MultiPoint geometry object formed from the specified points. + Return null if input array is empty. Throws an exception if any element + in the array is null or empty. The returned geometry may not be simple + and may contain duplicate points if input array has duplicates. + +.. function:: to_spherical_geography(input: Geometry) -> output: SphericalGeography + + Converts a ``Geometry`` object to a SphericalGeography object on the sphere + of the Earth’s radius. For each point of the input geometry, it verifies that + point.x is within [-180.0, 180.0] and point.y is within [-90.0, 90.0], + and uses them as (longitude, latitude) degrees to construct the shape + of the ``SphericalGeography`` result. + +.. function:: to_geometry(input: SphericalGeography) -> output: Geometry + + Converts a SphericalGeography object to a Geometry object. + Spatial Predicates ------------------ @@ -113,7 +148,7 @@ function you are using. Returns ``true`` if the given geometries share space, are of the same dimension, but are not completely contained by each other. -.. function:: ST_Relat(geometry1: Geometry, geometry2: Geometry, relation: varchar) -> boolean +.. function:: ST_Relate(geometry1: Geometry, geometry2: Geometry, relation: varchar) -> boolean Returns true if first geometry is spatially related to second geometry as described by the relation. The relation is a string like ``'"1*T***T**'``: @@ -134,6 +169,11 @@ function you are using. Spatial Operations ------------------ +.. function:: ST_Boundary(geometry: Geometry) -> boundary: Geometry + + Returns the closure of the combinatorial boundary of ``geometry``. + Empty geometry inputs result in empty output. + .. function:: ST_Difference(geometry1: Geometry, geometry2: Geometry) -> difference: Geometry Returns the geometry that represents the portion of ``geometry1`` that is @@ -155,8 +195,119 @@ Spatial Operations Returns the geometry that represents the all points in either ``geometry1`` or ``geometry2``. +.. function:: ST_Envelope(geometry: Geometry) -> envelope: Geometry + + Returns the bounding rectangular polygon of a ``geometry``. Empty input will + result in empty output. + +.. function:: ST_ExteriorRing(geometry: Geometry) -> output: Geometry + + Returns a LineString representing the exterior ring of the input polygon. + Empty or null inputs result in null output. Non-polygon types will return + an error. + +.. function:: expand_envelope(geometry: Geometry, distance: double) -> output: Geometry + + Returns the bounding rectangular polygon of a geometry, expanded by a distance. + Empty geometries will return an empty polygon. Negative or NaN distances will + return an error. Positive infinity distances may lead to undefined results. + +.. function:: geometry_union(geometries: array(Geometry)) -> union: Geometry + + Returns a geometry that represents the point set union of the input geometries. + Performance of this function, in conjunction with array_agg() to first + aggregate the input geometries, may be better than geometry_union_agg(), + at the expense of higher memory utilization. Null elements in the input + array are ignored. Empty array input returns null. + +.. function:: geometry_union_agg(geometry: Geometry) -> union: Geometry + + Returns a geometry that represents the point set union of the aggregated + input geometries. Null geometries are ignored. Empty input returns null. + +.. function:: convex_hull_agg(geometry: Geometry) -> union: Geometry + + Returns a geometry that represents the convex hull of the points in the + aggregated input geometries. Null geometries are ignored. Empty input + returns null. + Accessors --------- +.. function:: ST_IsValid(geometry: Geometry) -> valid: bool + + Returns if ``geometry`` is valid, according to `SQL/MM Part 3: Spatial`_. + Examples of non-valid geometries include Polygons with self-intersecting shells. + +.. function:: ST_IsSimple(geometry: Geometry) -> simple: bool + + Returns if ``geometry`` is simple, according to `SQL/MM Part 3: Spatial`_. + Examples of non-simple geometries include LineStrings with self-intersections, + Polygons with empty rings for holes, and more. + +.. function:: ST_IsClosed(geometry: Geometry) -> closed: bool + + Returns true if the LineString’s start and end points are coincident. Will + return an error if the input geometry is not a LineString or MultiLineString. + +.. function:: ST_IsRing(geometry: Geometry) -> ring: bool + + Returns true if and only if the line is closed and simple. Will return an error + if input geometry is not a LineString. + +.. function:: ST_IsEmpty(geometry: Geometry) -> empty: bool + + Returns true if and only if this Geometry is an empty GeometryCollection, Polygon, + Point etc. + +.. function:: ST_Length(geometry: Geometry) -> length: double + + Returns the length of a LineString or MultiLineString using Euclidean measurement + on a two dimensional plane (based on spatial ref) in projected units. Will + return an error if the input geometry is not a LineString or MultiLineString. + +.. function:: ST_Length(sphericalgeography: SphericalGeography) -> length: double + + Returns the length of a ``LineString`` or ``MultiLineString`` on a spherical model of the + Earth. This is equivalent to the sum of great-circle distances between adjacent points + on the ``LineString``. + +.. function:: ST_PointN(linestring: Geometry, index: integer) -> point: geometry + + Returns the vertex of a LineString at a given index (indices start at 1). + If the given index is less than 1 or greater than the total number of elements + in the collection, returns NULL. + +.. function:: ST_Points(geometry: Geometry) -> points: array(geometry) + + Returns an array of points in a geometry. Empty or null inputs + return null. + +.. function:: ST_NumPoints(geometry: Geometry) -> points: bigint + + Returns the number of points in a geometry. This is an extension + to the SQL/MM ``ST_NumPoints`` function which only applies to + point and linestring. + +.. function:: geometry_nearest_points(geometry1: Geometry, geometry2: Geometry) -> points: array(geometry) + + Returns the points on each geometry nearest the other. If either geometry + is empty, return null. Otherwise, return an array of two Points that have + the minimum distance of any two points on the geometries. The first Point + will be from the first Geometry argument, the second from the second Geometry + argument. If there are multiple pairs with the minimum distance, one pair + is chosen arbitrarily. + +.. function:: ST_EnvelopeAsPts(geometry: Geometry) -> points: array(geometry) + + Returns an array of two points: the lower left and upper right corners + of the bounding rectangular polygon of a geometry. Empty or null inputs + return null. + +.. function:: geometry_invalid_reason(geometry: Geometry) -> reason: varchar + + If ``geometry`` is not valid or not simple, return a description of the + reason. If the geometry is valid and simple (or ``NULL``), return ``NULL``. + This function is relatively expensive. .. function:: ST_Area(geometry: Geometry) -> area: double @@ -165,6 +316,40 @@ Accessors returns the sum of the areas of the individual geometries. Empty geometries return 0. +.. function:: ST_Area(sphericalgeography: SphericalGeography) -> area: double + + Returns the area of a polygon or multi-polygon in square meters using a spherical model for Earth. + +.. function:: ST_Centroid(geometry: Geometry) -> geometry: Geometry + + Returns the point value that is the mathematical centroid of ``geometry``. + Empty geometry inputs result in empty output. + +.. function:: ST_Centroid(SphericalGeography) -> Point + + Returns the point value that is the mathematical centroid of a spherical geometry. + Empty geometry inputs result in null output. + + It supports Points and MultiPoints as input and returns the three-dimensional + centroid projected onto the surface of the (spherical) Earth. + For example, MULTIPOINT (0 -45, 0 45, 30 0, -30 0) returns Point(0, 0). + Note: In the case that the three-dimensional centroid is at (0, 0, 0) + (e.g. MULTIPOINT (0 0, -180 0)), the spherical centroid is undefined and an + arbitrary point will be returned. + +.. function:: ST_Distance(geometry1: Geometry, geometry2: Geometry) -> distance: double + + Returns the 2-dimensional cartesian minimum distance (based on spatial ref) + between two geometries in projected units. Empty geometries result in null output. + +.. function:: ST_Distance(sphericalgeography1: SphericalGeography, sphericalgeography2: SphericalGeography) -> distance: double + + Returns the great-circle distance in meters between two SphericalGeography points. + +.. function:: ST_GeometryType(geometry: Geometry) -> type: varchar + + Returns the type of the geometry. + .. function:: ST_X(geometry: Geometry) -> x: double Returns the ``x`` coordinate of the geometry if it is a Point. Returns @@ -177,6 +362,162 @@ Accessors ``null`` if the geometry is empty. Raises an error if the geometry is not a Point and not empty. +.. function:: ST_XMin(geometry: Geometry) -> x: double + + Returns the minimum ``x`` coordinate of the geometries bounding box. + Returns ``null`` if the geometry is empty. + +.. function:: ST_YMin(geometry: Geometry) -> y: double + + Returns the minimum ``y`` coordinate of the geometries bounding box. + Returns ``null`` if the geometry is empty. + +.. function:: ST_XMax(geometry: Geometry) -> x: double + + Returns the maximum ``x`` coordinate of the geometries bounding box. + Returns ``null`` if the geometry is empty. + +.. function:: ST_YMax(geometry: Geometry) -> y: double + + Returns the maximum ``y`` coordinate of the geometries bounding box. + Returns ``null`` if the geometry is empty. + +.. function:: ST_StartPoint(geometry: Geometry) -> point: Geometry + + Returns the first point of a LineString geometry as a Point. + This is a shortcut for ``ST_PointN(geometry, 1)``. Empty + input will return ``null``. + +.. function:: ST_EndPoint(geometry: Geometry) -> point: Geometry + + Returns the last point of a LineString geometry as a Point. + This is a shortcut for ``ST_PointN(geometry, ST_NumPoints(geometry))``. + Empty input will return ``null``. + +.. function:: ST_GeometryN(geometry: Geometry, index: integer) -> geometry: Geometry + + Returns the ``geometry`` element at a given index (indices start at 1). + If the ``geometry`` is a collection of geometries (e.g., GeometryCollection or + Multi*), returns the ``geometry`` at a given index. If the given index is less + than 1 or greater than the total number of elements in the collection, returns + NULL. Use ``:func:ST_NumGeometries`` to find out the total number of elements. + Singular geometries (e.g., Point, LineString, Polygon), are treated as + collections of one element. Empty geometries are treated as empty collections. + +.. function:: ST_InteriorRingN(geometry: Geometry, index: integer) -> geometry: Geometry + + Returns the interior ring element at the specified index (indices start at 1). + If the given index is less than 1 or greater than the total number of interior + rings in the input ``geometry``, returns NULL. Throws an error if the input geometry + is not a polygon. Use ``:func:ST_NumInteriorRing`` to find out the total number of + elements. + +.. function:: ST_NumGeometries(geometry: Geometry) -> output: integer + + Returns the number of geometries in the collection. If the geometry is a + collection of geometries (e.g., GeometryCollection or Multi*), + returns the number of geometries, for single geometries returns 1, + for empty geometries returns 0. Note that empty geometries inside of a + GeometryCollection will count as a geometry if and only if there is at + least 1 non-empty geometry in the collection. e.g. + ``ST_NumGeometries(ST_GeometryFromText('GEOMETRYCOLLECTION(POINT EMPTY)'))`` + will evaluate to 0, but + ``ST_NumGeometries(ST_GeometryFromText('GEOMETRYCOLLECTION(POINT EMPTY, POINT (1 2))'))`` + will evaluate to 1. + +.. function:: ST_InteriorRings(geometry: Geometry) -> output: array(geometry) + + Returns an array of all interior rings found in the input geometry, + or an empty array if the polygon has no interior rings. Returns + null if the input geometry is empty. + Throws an error if the input geometry is not a polygon. + +.. function:: ST_Geometries(geometry: Geometry) -> output: array(geometry) + + Returns an array of geometries in the specified collection. Returns + a one-element array if the input geometry is not a multi-geometry. + Returns null if input geometry is empty. For example, a MultiLineString + will create an array of LineStrings. A GeometryCollection will + produce an un-flattened array of its constituents: + GEOMETRYCOLLECTION (MULTIPOINT(0 0, 1 1), + GEOMETRYCOLLECTION (MULTILINESTRING((2 2, 3 3))) ) would produce + array[MULTIPOINT(0 0, 1 1), GEOMETRYCOLLECTION( MULTILINESTRING((2 2, 3 3)) )] + +.. function:: flatten_geometry_collections(geometry: Geometry) -> output: array(geometry) + + Recursively flattens any GeometryCollections in Geometry, returning an array + of constituent non-GeometryCollection geometries. The order of the array + is arbitrary and should not be relied upon. null input results in null output. + Examples: + + POINT (0 0) -> [POINT (0 0)], MULTIPOINT (0 0, 1 1) -> [MULTIPOINT (0 0, 1 1)], + GEOMETRYCOLLECTION (POINT (0 0), GEOMETRYCOLLECTION (POINT (1 1))) -> + [POINT (0 0), POINT (1 1)], GEOMETRYCOLLECTION EMPTY -> []. + +.. function:: ST_NumInteriorRing(geometry: Geometry) -> output: bigint + + Returns the cardinality of the collection of interior rings of a polygon. + +.. function:: ST_ConvexHull(geometry: Geometry) -> output: Geometry + + Returns the minimum convex geometry that encloses all input geometries. + +.. function:: ST_CoordDim(geometry: Geometry) -> output: tinyint + + Return the coordinate dimension of the geometry. + +.. function:: ST_Dimension(geometry: Geometry) -> output: tinyint + + Returns the inherent dimension of this geometry object, which + must be less than or equal to the coordinate dimension. + +.. function:: ST_ExteriorRing(geometry: Geometry) -> output: Geometry + + Returns a line string representing the exterior ring of the input polygon. + +.. function:: ST_Buffer(geometry: Geometry, distance: double) -> output: Geometry + + Returns the geometry that represents all points whose distance from the + specified ``geometry`` is less than or equal to the specified ``distance``. + If the points of the ``geometry`` are extremely close together + (delta < 1e-8), this might return an empty geometry. Empty inputs return + null. + +.. function:: simplify_geometry(geometry: Geometry, tolerance: double) -> output: Geometry + + Returns a "simplified" version of the input geometry using the + Douglas-Peucker algorithm. Will avoid creating geometries (polygons in + particular) that are invalid. Tolerance must be a non-negative finite value. + Using tolerance of 0 will return the original geometry. Empty geometries + will also be returned as-is. + +.. function:: line_locate_point(linestring: Geometry, point: Geometry) -> output: double + + Returns a float between 0 and 1 representing the location of the closest + point on the LineString to the given Point, as a fraction of total 2d line length. + + Returns null if a LineString or a Point is empty or null. + +.. function:: line_interpolate_point(linestring: Geometry, fraction: double) -> output: geometry + + Returns the Point on the LineString at a fractional distance given by + the double argument. Throws an exception if the distance is not between 0 and 1. + + Returns an empty Point if the LineString is empty. + Returns null if either the LineString or double is null. + +.. function:: geometry_as_geojson(geometry: Geometry) -> output: varchar + + Returns the GeoJSON encoded defined by the input geometry. If the + geometry is atomic (non-multi) empty, this function would return null. + Null input returns null output. + +.. function:: geometry_from_geojson(geometry: varchar) -> output: geometry + + Returns the geometry type object from the GeoJSON representation. + The geometry cannot be empty if it is an atomic (non-multi) geometry type. + Null input returns null output. + Bing Tile Functions ------------------- @@ -243,3 +584,20 @@ for more details. .. function:: bing_tile_quadkey() -> quadKey: varchar Returns the quadkey representing the provided bing tile. + +.. function:: geometry_to_bing_tiles(geometry: Geometry, zoom_level: tinyint) -> tiles: array(BingTile) + + Returns the minimum set of Bing tiles that fully covers a given geometry at a + given zoom level. Empty inputs return an empty array, and null inputs return + null. + +.. function:: geometry_to_dissolved_bing_tiles(geometry: Geometry, max_zoom_level: tinyint) -> tile: array(BingTile) + + Returns the minimum set of Bing tiles that fully covers a given geometry at a + given zoom level, recursively dissolving full sets of children into parents. + This results in a smaller array of tiles of different zoom levels. + For example, if the non-dissolved covering is [“00”, “01”, “02”, “03”, “10”], + the dissolved covering would be [“0”, “10”]. Zoom levels from 0 to 23 are supported. + +.. _OpenGIS Specifications: https://www.ogc.org/standards/ogcapi-features/ +.. _SQL/MM Part 3: Spatial: https://www.iso.org/standard/31369.html diff --git a/velox/docs/functions/presto/hyperloglog.rst b/velox/docs/functions/presto/hyperloglog.rst index ecd8e6d384ab..fd10739f5a53 100644 --- a/velox/docs/functions/presto/hyperloglog.rst +++ b/velox/docs/functions/presto/hyperloglog.rst @@ -70,3 +70,10 @@ Functions Returns the ``HyperLogLog`` of the aggregate union of the individual ``hll`` HyperLogLog structures. + +.. function:: merge_hll(array(HyperLogLog)) -> HyperLogLog + + Returns the ``HyperLogLog`` of the union of an array of ``HyperLogLog`` structures. + + * Returns ``NULL`` if the input array is ``NULL``, empty, or contains only ``NULL`` elements + * Ignores ``NULL`` elements and merges only valid ``HyperLogLog`` structures when the array contains a mix of ``NULL`` and non-null elements diff --git a/velox/docs/functions/presto/map.rst b/velox/docs/functions/presto/map.rst index c81d7eba948b..51c74d2bdc84 100644 --- a/velox/docs/functions/presto/map.rst +++ b/velox/docs/functions/presto/map.rst @@ -46,6 +46,18 @@ Map Functions See also :func:`map_agg` for creating a map as an aggregation. +.. function:: map_append(map(K,V), array(K), array(V)) -> map(K,V) + + Returns a map with new key-value pairs appended to the input map. The new keys are provided in the first array parameter and corresponding values in the second array parameter. + Keys and values arrays must have the same length. New keys must not already exist in the input map. Duplicate keys in the new keys array are not allowed. + Null keys are ignored. Null values are preserved in the output map. For REAL and DOUBLE, NaNs (Not-a-Number) are considered equal. :: + + SELECT map_append(MAP(ARRAY[1, 2], ARRAY[10, 20]), ARRAY[3, 4], ARRAY[30, 40]); -- {1 -> 10, 2 -> 20, 3 -> 30, 4 -> 40} + SELECT map_append(MAP(ARRAY['a', 'b'], ARRAY[1, 2]), ARRAY['c'], ARRAY[3]); -- {'a' -> 1, 'b' -> 2, 'c' -> 3} + SELECT map_append(MAP(ARRAY[1], ARRAY[10]), ARRAY[2, null, 3], ARRAY[20, 30, 40]); -- {1 -> 10, 2 -> 20, 3 -> 40} + SELECT map_append(MAP(ARRAY[1], ARRAY[10]), ARRAY[2, 3], ARRAY[null, 30]); -- {1 -> 10, 2 -> null, 3 -> 30} + SELECT map_append(MAP(ARRAY[1], ARRAY[10]), ARRAY[], ARRAY[]); -- {1 -> 10} + .. function:: map_concat(map1(K,V), map2(K,V), ..., mapN(K,V)) -> map(K,V) Returns the union of all the given maps. If a key is found in multiple given maps, @@ -94,6 +106,16 @@ Map Functions SELECT map_remove_null_values(MAP(ARRAY[1, 2, 3], ARRAY[3, 4, NULL])); -- {1=3, 2=4} SELECT map_remove_null_values(NULL); -- NULL +.. function:: remap_keys(map(K,V), array(K), array(K)) -> map(K,V) + + Returns a map with keys remapped according to the oldKeys and newKeys arrays. + Unmapped keys remain unchanged. Values are preserved. Null keys are ignored. :: + + SELECT remap_keys(MAP(ARRAY[1, 2, 3], ARRAY[10, 20, 30]), ARRAY[1, 3], ARRAY[100, 300]); -- {100 -> 10, 2 -> 20, 300 -> 30} + SELECT remap_keys(MAP(ARRAY['a', 'b', 'c'], ARRAY[1, 2, 3]), ARRAY['a', 'c'], ARRAY['alpha', 'charlie']); -- {alpha -> 1, b -> 2, charlie -> 3} + SELECT remap_keys(MAP(ARRAY[1, 2, 3], ARRAY[10, null, 30]), ARRAY[1, 2], ARRAY[100, 200]); -- {100 -> 10, 200 -> null, 3 -> 30} + SELECT remap_keys(MAP(ARRAY[1, 2], ARRAY[10, 20]), ARRAY[], ARRAY[]); -- {1 -> 10, 2 -> 20} + .. function:: map_subset(map(K,V), array(k)) -> map(K,V) Constructs a map from those entries of ``map`` for which the key is in the array given @@ -102,8 +124,40 @@ Map Functions SELECT map_subset(MAP(ARRAY[1,2], ARRAY['a','b']), ARRAY[10]); -- {} SELECT map_subset(MAP(ARRAY[1,2], ARRAY['a','b']), ARRAY[1]); -- {1->'a'} SELECT map_subset(MAP(ARRAY[1,2], ARRAY['a','b']), ARRAY[1,3]); -- {1->'a'} - SELECT map_subset(MAP(ARRAY[1,2], ARRAY['a','b']), ARRAY[]); -- {} SELECT map_subset(MAP(ARRAY[], ARRAY[]), ARRAY[1,2]); -- {} + SELECT map_subset(MAP(ARRAY[], ARRAY[]), ARRAY[]); -- {} + +.. function:: map_intersect(map(K,V), array(K)) -> map(K,V) + + Returns a map containing only the entries from the input map whose keys are present in the given array. + This function is equivalent to map_subset. Null keys in the array are ignored. + For keys containing REAL and DOUBLE, NANs (Not-a-Number) are considered equal. :: + + SELECT map_intersect(MAP(ARRAY[1,2,3], ARRAY['a','b','c']), ARRAY[1,3]); -- {1->'a', 3->'c'} + SELECT map_intersect(MAP(ARRAY[1,2], ARRAY['a','b']), ARRAY[10]); -- {} + SELECT map_intersect(MAP(ARRAY[1,2], ARRAY['a','b']), ARRAY[]); -- {} + SELECT map_intersect(MAP(ARRAY[], ARRAY[]), ARRAY[1,2]); -- {} + +.. function:: map_except(map(K,V), array(k)) -> map(K,V) + + Constructs a map from those entries of ``map`` for which the key is not in the array given. + For keys containing REAL and DOUBLE, NANs (Not-a-Number) are considered equal. :: + + SELECT map_except(MAP(ARRAY[1,2], ARRAY['a','b']), ARRAY[10]); -- {1->'a', 2->'b'} + SELECT map_except(MAP(ARRAY[1,2], ARRAY['a','b']), ARRAY[1]); -- {2->'b'} + SELECT map_except(MAP(ARRAY[1,2], ARRAY['a','b']), ARRAY[1,3]); -- {2->'b'} + SELECT map_except(MAP(ARRAY[1,2], ARRAY['a','b']), ARRAY[]); -- {1->'a', 2->'b'} + SELECT map_except(MAP(ARRAY[], ARRAY[]), ARRAY[1,2]); -- {} + +.. function:: map_keys_overlap(map(K,V), array(K)) -> boolean + + Returns true if any key in the map matches any element in the given array, false otherwise. + Returns false if either the map or array is empty. Null keys in the array are ignored. :: + + SELECT map_keys_overlap(MAP(ARRAY[1, 2, 3], ARRAY[10, 20, 30]), ARRAY[1, 5]); -- true + SELECT map_keys_overlap(MAP(ARRAY[1, 2, 3], ARRAY[10, 20, 30]), ARRAY[4, 5]); -- false + SELECT map_keys_overlap(MAP(ARRAY['a', 'b'], ARRAY[1, 2]), ARRAY['a']); -- true + SELECT map_keys_overlap(MAP(ARRAY[], ARRAY[]), ARRAY[1]); -- false .. function:: map_top_n(map(K,V), n) -> map(K, V) diff --git a/velox/docs/functions/presto/math.rst b/velox/docs/functions/presto/math.rst index 7de371d10249..91e3cc7e6a20 100644 --- a/velox/docs/functions/presto/math.rst +++ b/velox/docs/functions/presto/math.rst @@ -51,6 +51,47 @@ Mathematical Functions SELECT cosine_similarity(ARRAY[], ARRAY[]); -- NaN +.. function:: cosine_similarity(array(real), array(real)) -> real + + Returns the `cosine similarity `_ between the vectors represented as array(real). + If any input array is empty, the function returns NaN. If the input arrays have different sizes, the function throws VeloxUserError. + +.. function:: l2_squared(array(real), array(real)) -> real + + Returns the squared `Euclidean distance `_ between the vectors represented as array(real). + If any input array is empty, the function returns NaN. If the input arrays have different sizes, the function throws VeloxUserError. + + SELECT l2_squared(ARRAY[1], ARRAY[2]); -- 1.0 + + SELECT l2_squared(ARRAY[1.0, 2.0], ARRAY[NULL, 3.0]); -- NULL + + SELECT l2_squared(ARRAY[], ARRAY[2, 3]); -- Throws VeloxUserError + + SELECT l2_squared(ARRAY[], ARRAY[]); -- NaN + +.. function:: l2_squared(array(double), array(double)) -> double + + Returns the squared `Euclidean distance `_ between the vectors represented as array(double). + If any input array is empty, the function returns NaN. If the input arrays have different sizes, the function throws VeloxUserError. + +.. function:: dot_product(array(real), array(real)) -> real + + Returns the `Dot Product `_ between the vectors represented as array(real). + If any input array is empty, the function returns NaN. If the input arrays have different sizes, the function throws VeloxUserError. + + SELECT dot_product(ARRAY[1], ARRAY[2]); -- 2.0 + + SELECT dot_product(ARRAY[1.0, 2.0], ARRAY[NULL, 3.0]); -- NULL + + SELECT dot_product(ARRAY[], ARRAY[2, 3]); -- Throws VeloxUserError + + SELECT dot_product(ARRAY[], ARRAY[]); -- NaN + +.. function:: dot_product(array(double), array(double)) -> double + + Returns the `Dot Product `_ between the vectors represented as array(double). + If any input array is empty, the function returns NaN. If the input arrays have different sizes, the function throws VeloxUserError. + .. function:: degrees(x) -> double Converts angle x in radians to degrees. @@ -342,6 +383,10 @@ Probability Functions: cdf Compute the Poisson cdf with given lambda (mean) parameter: P(N <= value; lambda). The lambda parameter must be a positive real number (of type DOUBLE) and value must be a non-negative integer. +.. function:: t_cdf(df, value) -> double + + Compute the Student's t cdf with given degrees of freedom: P(N < value; df). + The degrees of freedom must be a positive real number and value must be a real value. .. function:: weibull_cdf(a, b, value) -> double @@ -375,6 +420,51 @@ Probability Functions: inverse_cdf The mean must be a real value and the scale must be a positive real value (both of type DOUBLE). The probability ``p`` must lie on the interval [0, 1]. +.. function:: inverse_f_cdf(df1, df2, p) -> double + + Compute the inverse of the Fisher F cdf with a given ``df1`` (numerator degrees of freedom) and ``df2`` (denominator degrees of freedom) parameters + for the cumulative probability (p): P(N < n). The numerator and denominator df parameters must be positive real numbers. + The probability ``p`` must lie on the interval [0, 1]. + +.. function:: inverse_normal_cdf(mean, sd, p) -> double + + Compute the inverse of the Normal cdf with given mean and standard + deviation (sd) for the cumulative probability (p): P(N < n). The mean must be + a real value and the standard deviation must be a real and positive value (both of type DOUBLE). + The probability p must lie on the interval (0, 1). + +.. function:: inverse_gamma_cdf(shape, scale, p) -> double + + Compute the inverse of the Gamma cdf with given shape and scale parameters for the cumulative + probability (p): P(N < n). The shape and scale parameters must be positive real values. + The probability p must lie on the interval [0, 1]. + +.. function:: inverse_binomial_cdf(numberOfTrials, successProbability, p) -> int + + Compute the inverse of the Binomial cdf with given numberOfTrials and successProbability (of a single trial) the + cumulative probability (p): P(N <= n). + The successProbability and p must be real values in [0, 1] and the numberOfTrials must be + a positive integer. + +.. function:: inverse_poisson_cdf(lambda, p) -> integer + + Compute the inverse of the Poisson cdf with given lambda (mean) parameter for the cumulative + probability (p). It returns the value of n so that: P(N <= n; lambda) = p. + The lambda parameter must be a positive real number (of type DOUBLE). + The probability p must lie on the interval [0, 1). + +.. function:: inverse_chi_squared_cdf(df, p) -> double + + Compute the inverse of the Chi-square cdf with given df (degrees of freedom) parameter for the cumulative + probability (p): P(N < n). The df parameter must be positive real values. + The probability p must lie on the interval [0, 1]. + +.. function:: inverse_t_cdf(df, p) -> double + + Compute the inverse of the Student's t cdf with given degrees of freedom for the cumulative + probability (p): P(N < n). The degrees of freedom must be a positive real value. + The probability p must lie on the interval [0, 1]. + ==================================== Statistical Functions ==================================== diff --git a/velox/docs/functions/presto/most_used_coverage.rst b/velox/docs/functions/presto/most_used_coverage.rst index 9a9e42a40a92..50e52edd5493 100644 --- a/velox/docs/functions/presto/most_used_coverage.rst +++ b/velox/docs/functions/presto/most_used_coverage.rst @@ -54,6 +54,7 @@ Here is a list of most used scalar and aggregate Presto functions with functions table.coverage tr:nth-child(7) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(7) td:nth-child(3) {background-color: #6BA81E;} table.coverage tr:nth-child(7) td:nth-child(4) {background-color: #6BA81E;} + table.coverage tr:nth-child(7) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(7) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(8) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(8) td:nth-child(2) {background-color: #6BA81E;} @@ -71,10 +72,13 @@ Here is a list of most used scalar and aggregate Presto functions with functions table.coverage tr:nth-child(10) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(10) td:nth-child(3) {background-color: #6BA81E;} table.coverage tr:nth-child(10) td:nth-child(4) {background-color: #6BA81E;} + table.coverage tr:nth-child(10) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(10) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(11) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(11) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(11) td:nth-child(3) {background-color: #6BA81E;} + table.coverage tr:nth-child(11) td:nth-child(4) {background-color: #6BA81E;} + table.coverage tr:nth-child(11) td:nth-child(5) {background-color: #6BA81E;} table.coverage tr:nth-child(11) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(12) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(12) td:nth-child(2) {background-color: #6BA81E;} @@ -100,6 +104,7 @@ Here is a list of most used scalar and aggregate Presto functions with functions table.coverage tr:nth-child(16) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(16) td:nth-child(2) {background-color: #6BA81E;} table.coverage tr:nth-child(16) td:nth-child(3) {background-color: #6BA81E;} + table.coverage tr:nth-child(16) td:nth-child(4) {background-color: #6BA81E;} table.coverage tr:nth-child(16) td:nth-child(7) {background-color: #6BA81E;} table.coverage tr:nth-child(17) td:nth-child(1) {background-color: #6BA81E;} table.coverage tr:nth-child(17) td:nth-child(2) {background-color: #6BA81E;} @@ -126,16 +131,16 @@ Here is a list of most used scalar and aggregate Presto functions with functions :func:`upper` :func:`transform_values` :func:`map_filter` :func:`map_zip_with` :func:`xxhash64` :func:`array_agg` :func:`split` :func:`map_entries` :func:`regexp_extract` :func:`year` :func:`to_hex` :func:`arbitrary` :func:`random` :func:`concat` :func:`map_values` :func:`slice` :func:`transform_keys` :func:`min` - :func:`floor` :func:`cardinality` :func:`map_keys` :func:`month` bing_tile_quadkey :func:`max_by` + :func:`floor` :func:`cardinality` :func:`map_keys` :func:`month` :func:`bing_tile_quadkey` :func:`max_by` :func:`contains` :func:`sequence` :func:`reduce` :func:`any_match` :func:`to_utf8` :func:`approx_distinct` :func:`map_concat` :func:`substr` :func:`greatest` :func:`bitwise_and` :func:`crc32` :func:`count_if` - :func:`length` :func:`date` :func:`date_trunc` :func:`date_parse` st_y :func:`approx_percentile` - :func:`from_unixtime` :func:`is_nan` :func:`date_diff` bing_tile_at st_x :func:`avg` + :func:`length` :func:`date` :func:`date_trunc` :func:`date_parse` :func:`st_y` :func:`approx_percentile` + :func:`from_unixtime` :func:`is_nan` :func:`date_diff` :func:`bing_tile_at` :func:`st_x` :func:`avg` :func:`transform` :func:`rand` :func:`array_max` :func:`array_union` now :func:`map_agg` :func:`to_unixtime` :func:`filter` :func:`from_iso8601_date` :func:`reverse` :func:`truncate` :func:`min_by` :func:`regexp_like` :func:`sqrt` :func:`json_extract` :func:`array_intersect` :func:`stddev` :func:`array_join` :func:`least` :func:`mod` :func:`repeat` :func:`set_agg` - :func:`replace` :func:`json_parse` :func:`array_distinct` st_geometryfromtext :func:`histogram` + :func:`replace` :func:`json_parse` :func:`array_distinct` :func:`st_geometryfromtext` :func:`histogram` :func:`regexp_replace` :func:`map_from_entries` :func:`pow` :func:`split_part` :func:`set_union` :func:`parse_datetime` :func:`date_add` :func:`power` :func:`log10` :func:`merge` =========================== =========================== =========================== =========================== =========================== == =========================== == =========================== diff --git a/velox/docs/functions/presto/qdigest.rst b/velox/docs/functions/presto/qdigest.rst new file mode 100644 index 000000000000..9b7b00dabb05 --- /dev/null +++ b/velox/docs/functions/presto/qdigest.rst @@ -0,0 +1,65 @@ +========================= +Quantile Digest Functions +========================= + +Quantile digest and `T-digest `_ are two +older algorithms for estimating rank-based metrics. While T-digest generally has `better +performance `_ and better accuracy at the tails, +quantile digest supports more numeric types (``bigint``, ``double``, ``real``) and +provides a maximum rank error guarantee, which ensures relative uniformity of precision +along the quantiles. Quantile digests are also formally proven to support lossless merges, +while T-digest is not (though it does empirically demonstrate lossless merges). + +Velox uses the modern KLL sketch algorithm for the ``approx_percentile`` function, which +provides stronger accuracy guarantees than both quantile digest and T-digest. +The quantile digest functions documented here exist primarily to support +pre-existing workloads that have data stored using the `quantile +digest `_ format for backward compatibility. + +Data Structures +--------------- + +A quantile digest is a data sketch which stores approximate percentile +information. The Velox type for this data structure is called ``qdigest``, +and it takes a parameter which must be one of ``bigint``, ``double`` or +``real`` which represent the set of numbers that may be ingested by the +``qdigest``. They may be merged without losing precision, and for storage +and retrieval they may be cast to/from ``VARBINARY``. + +In the function signatures below, ``T`` represents the parameterized type of the qdigest, +which can be ``bigint``, ``double``, or ``real``. + +Functions +--------- + +.. function:: merge(qdigest) -> qdigest + + Merges all input ``qdigest``\ s into a single ``qdigest``. + +.. function:: qdigest_agg(x: T) -> qdigest + + Returns the ``qdigest`` which summarizes the approximate distribution of all input values of ``x``. + +.. function:: qdigest_agg(x: T, w: bigint) -> qdigest + :noindex: + + Returns the ``qdigest`` which summarizes the approximate distribution of all input values of ``x`` using + the per-item weight ``w``. + +.. function:: qdigest_agg(x: T, w: bigint, accuracy: double) -> qdigest + :noindex: + + Returns the ``qdigest`` which summarizes the approximate distribution of all input values of ``x`` using + the per-item weight ``w`` and maximum error of ``accuracy``. ``accuracy`` + must be a value greater than zero and less than one, and it must be constant + for all input rows. + +.. function:: value_at_quantile(digest: qdigest, quantile: double) -> T + + Returns the approximate percentile values from the quantile digest ``digest`` given the ``quantile``. + The ``quantile`` must be between zero and one (inclusive). + +.. function:: values_at_quantiles(digest: qdigest, quantiles: array) -> array + + Returns the approximate percentile values as an array from the quantile digest ``digest`` at each of the specified quantiles given in the ``quantiles`` array. + All quantile values must be between zero and one (inclusive). diff --git a/velox/docs/functions/presto/regexp.rst b/velox/docs/functions/presto/regexp.rst index e9b8914c8d60..dee06e7d0157 100644 --- a/velox/docs/functions/presto/regexp.rst +++ b/velox/docs/functions/presto/regexp.rst @@ -114,3 +114,8 @@ limited to 20 different expressions per instance and thread of execution. array. Trailing empty strings are preserved:: SELECT regexp_split('1a 2b 14m', '\s*[a-z]+\s*'); -- [1, 2, 14, ] + + Note: When the regular expression ``pattern`` matches an empty string, the result array contains an + empty string, a single-character string for each character in the input string, and another empty string:: + + SELECT regexp_split('1a 2b 14m', ''); -- [, 1, 'a', ' ', 2, 'b', ' ', 1, 4, 'm', ] diff --git a/velox/docs/functions/presto/string.rst b/velox/docs/functions/presto/string.rst index 7806bd646ee0..d9ffbe0c8b2d 100644 --- a/velox/docs/functions/presto/string.rst +++ b/velox/docs/functions/presto/string.rst @@ -60,6 +60,10 @@ String Functions i.e. the number of positions at which the corresponding characters are different. Note that the two strings must have the same length. +.. function:: jarowinkler_similarity(string1, string2) -> double + + Returns the Jaro-Winkler similarity of ``string1`` and ``string2``. + .. function:: length(string) -> bigint Returns the length of ``string`` in characters. diff --git a/velox/docs/functions/presto/tdigest.rst b/velox/docs/functions/presto/tdigest.rst new file mode 100644 index 000000000000..b43be2f3de6c --- /dev/null +++ b/velox/docs/functions/presto/tdigest.rst @@ -0,0 +1,116 @@ +================== +T-Digest Functions +================== + +T-digest and `quantile digest `_ are two +older algorithms for estimating rank-based metrics. T-digest generally has `better +performance `_ than quantile digest and better accuracy +at the tails (often dramatically better), but may have worse accuracy at the median +depending on the compression factor used. In comparison, quantile digest supports more +numeric types and provides a maximum rank error guarantee, which ensures relative uniformity +of precision along the quantiles. Quantile digests are also formally proven to support +lossless merges, while T-digest is not (though it does empirically demonstrate lossless merges). + +T-digest was developed by Ted Dunning and is more restrictive in its type support, +accepting only ``double`` type parameters. This contrasts with quantile digest, which +supports a broader range of numeric types including ``bigint``, ``double``, and ``real``, +making quantile digest more versatile for different data types. + +Velox uses the modern KLL sketch algorithm for the ``approx_percentile`` function, which +provides stronger accuracy guarantees than both T-digest and quantile digest. +The T-digest functions documented here exist primarily to support +pre-existing workloads that have data stored using the `T-digest +`_ format for backward compatibility. + +Data Structures +--------------- + +A T-digest is a data sketch which stores approximate percentile information. +The Velox type for this data structure is called ``tdigest``, +and it accepts a parameter of type ``double`` which represents the set of +numbers to be ingested by the ``tdigest``. + +T-digests may be merged without losing precision, and for storage and retrieval +they may be cast to/from ``VARBINARY``. + +Functions +--------- + +.. function:: construct_tdigest(means: array, counts: array, compression: double, min: double, max: double, sum: double, count: bigint) -> tdigest + + Constructs a T-digest from the given parameters: + + * ``means`` - array of centroid means + * ``counts`` - array of centroid counts (weights) + * ``compression`` - compression factor + * ``min`` - minimum value + * ``max`` - maximum value + * ``sum`` - sum of all values + * ``count`` - total count of values + +.. function:: destructure_tdigest(digest: tdigest) -> row(means array, counts array, compression double, min double, max double, sum double, count bigint) + + Destructures a T-digest into its component parts, returning a row containing: + + * ``means`` - array of centroid means + * ``counts`` - array of centroid counts + * ``compression`` - compression factor + * ``min`` - minimum value + * ``max`` - maximum value + * ``sum`` - sum of all values + * ``count`` - total count of values + +.. function:: merge(tdigest) -> tdigest + + Merges all input ``tdigest``\ s into a single ``tdigest``. + +.. function:: merge_tdigest(digests: array>) -> tdigest + + Merges an array of T-digests into a single T-digest. + +.. function:: quantile_at_value(digest: tdigest, value: double) -> double + + Returns the approximate quantile (percentile) of the given ``value`` based on the T-digest ``digest``. + The result will be between zero and one (inclusive). + +.. function:: quantiles_at_values(digest: tdigest, values: array) -> array + + Returns the approximate quantiles (percentiles) as an array for each of the given ``values`` based on the T-digest ``digest``. + All results will be between zero and one (inclusive). + +.. function:: scale_tdigest(digest: tdigest, scale: double) -> tdigest + + Scales the T-digest ``digest`` by the given ``scale`` factor. + This multiplies all the centroid values in the T-digest by the scale factor. + +.. function:: tdigest_agg(x: double) -> tdigest + + Returns the ``tdigest`` which summarizes the approximate distribution of all input values of ``x``. + The default compression factor is ``100``. + +.. function:: tdigest_agg(x: double, w: double) -> tdigest + :noindex: + + Returns the ``tdigest`` which summarizes the approximate distribution of all input values of ``x`` using per-item weight ``w``. + The default compression factor is ``100``. + +.. function:: tdigest_agg(x: double, w: double, compression: double) -> tdigest + :noindex: + + Returns the ``tdigest`` which summarizes the approximate distribution of all input values of ``x`` using per-item weight ``w`` and the specified compression factor. + ``compression`` must be a positive constant for all input rows. The default is ``100``, maximum is ``1000``, and values lower than ``10`` are rounded to ``10``. Higher compression means more accuracy at the cost of more memory. + +.. function:: trimmed_mean(digest: tdigest, low_quantile: double, high_quantile: double) -> double + + Returns the mean of values between ``low_quantile`` and ``high_quantile`` (inclusive) from the T-digest ``digest``. + Both quantile values must be between zero and one (inclusive), and ``low_quantile`` must be less than or equal to ``high_quantile``. + +.. function:: value_at_quantile(digest: tdigest, quantile: double) -> double + + Returns the approximate percentile value from the T-digest ``digest`` at the given ``quantile``. + The ``quantile`` must be between zero and one (inclusive). + +.. function:: values_at_quantiles(digest: tdigest, quantiles: array) -> array + + Returns the approximate percentile values as an array from the T-digest ``digest`` at each of the specified quantiles given in the ``quantiles`` array. + All quantile values must be between zero and one (inclusive). diff --git a/velox/docs/functions/spark/aggregate.rst b/velox/docs/functions/spark/aggregate.rst index a9ac2bd001c5..9c5e90c6dd43 100644 --- a/velox/docs/functions/spark/aggregate.rst +++ b/velox/docs/functions/spark/aggregate.rst @@ -10,11 +10,12 @@ General Aggregate Functions .. spark:function:: avg(x) -> double|decimal Returns the average (arithmetic mean) of all non-null input values. - When x is of type DECIMAL, the result type is DECIMAL, - and the intermediate results are varbinarys or (sum, count) pairs represented as row(decimal, bigint). + When ``x`` is of type DECIMAL(p, s), the result type is DECIMAL(p + 4, s + 4), + and the intermediate results are (sum, count) pairs represented as ROW(DECIMAL(p + 10, s), BIGINT). + The current implementation for DECIMAL matches Spark avg's default behavior with spark.sql.decimalOperations.allowPrecisionLoss=true. For all other input types, the result type is DOUBLE, - and the intermediate results are (sum, count) pairs represented as row(double, bigint). - When all inputs are nulls, the intermediate result is row(0, 0), + and the intermediate results are (sum, count) pairs represented as ROW(DOUBLE, BIGINT). + When all inputs are nulls, the intermediate result is ROW(0, 0), and the final result is null. .. spark:function:: bit_xor(x) -> bigint diff --git a/velox/docs/functions/spark/array.rst b/velox/docs/functions/spark/array.rst index 88689787c3b4..42039e6c636d 100644 --- a/velox/docs/functions/spark/array.rst +++ b/velox/docs/functions/spark/array.rst @@ -1,6 +1,6 @@ -============================= +=============== Array Functions -============================= +=============== .. spark:function:: aggregate(array(E), start, merge, finish) -> array(E) @@ -171,13 +171,35 @@ Array Functions .. spark:function:: array_sort(array(E)) -> array(E) Returns an array which has the sorted order of the input array(E). The elements of array(E) must - be orderable. Null elements will be placed at the end of the returned array. :: + be orderable. NULL and NaN elements will be placed at the end of the returned array, with NaN elements appearing before NULL elements for floating-point types. :: SELECT array_sort(array(1, 2, 3)); -- [1, 2, 3] SELECT array_sort(array(3, 2, 1)); -- [1, 2, 3] - SELECT array_sort(array(2, 1, NULL); -- [1, 2, NULL] + SELECT array_sort(array(2, 1, NULL)); -- [1, 2, NULL] SELECT array_sort(array(NULL, 1, NULL)); -- [1, NULL, NULL] SELECT array_sort(array(NULL, 2, 1)); -- [1, 2, NULL] + SELECT array_sort(array(4.0, NULL, float('nan'), 3.0)); -- [3.0, 4.0, NaN, NULL] + SELECT array_sort(array(array(), array(1, 3, NULL), array(NULL, 6), NULL, array(2, 1))); -- [[], [NULL, 6], [1, 3, NULL], [2, 1], NULL] + +.. spark:function:: array_sort(array(E), function(E,U)) -> array(E) + :noindex: + + Returns the array sorted by values computed using specified lambda in ascending order. ``U`` must be an orderable type. + NULL and NaN elements returned by the lambda function will be placed at the end of the returned array, with NaN elements appearing before NULL elements. + This function is not supported in Spark and is only used inside Velox for rewriting :spark:func:`array_sort(array(E), function(E,E,U)) -> array(E)` as :spark:func:`array_sort(array(E), function(E,U)) -> array(E)`. :: + +.. spark:function:: array_sort(array(E), function(E,E,U)) -> array(E) + :noindex: + + Returns the array sorted by values computed using specified lambda in ascending + order. ``U`` must be an orderable type. + The function attempts to analyze the lambda function and rewrite it into a simpler call that + specifies the sort-by expression (like :spark:func:`array_sort(array(E), function(E,U)) -> array(E)`). For example, ``(left, right) -> if(length(left) > length(right), 1, if(length(left) < length(right), -1, 0))`` will be rewritten to ``x -> length(x)``. If rewrite is not possible, a user error will be thrown. + If the rewritten function returns NULL, the corresponding element will be placed at the end the returned array. Please note that due to this rewrite optimization, the NULL handling logics between Spark and Velox differ. In Spark, the position of NULL element is determined by the comparison of NULL with other elements. :: + + SELECT array_sort(array('cat', 'leopard', 'mouse'), (left, right) -> if(length(left) > length(right), 1, if(length(left) < length(right), -1, 0))); -- ['cat', 'mouse', 'leopard'] + select array_sort(array("abcd123", "abcd", NULL, "abc"), (left, right) -> if(length(left) > length(right), 1, if(length(left) < length(right), -1, 0))); -- ["abc", "abcd", "abcd123", NULL] + select array_sort(array("abcd123", "abcd", NULL, "abc"), (left, right) -> if(length(left) > length(right), 1, if(length(left) = length(right), 0, -1))); -- ["abc", "abcd", "abcd123", NULL] different with Spark: ["abc", NULL, "abcd", "abcd123"] .. spark:function:: array_union(array(E) x, array(E) y) -> array(E) diff --git a/velox/docs/functions/spark/binary.rst b/velox/docs/functions/spark/binary.rst index ab594134d870..4b49e6e36f8f 100644 --- a/velox/docs/functions/spark/binary.rst +++ b/velox/docs/functions/spark/binary.rst @@ -16,16 +16,6 @@ Binary Functions Computes the hash of one or more input values using specified seed. For multiple arguments, their types can be different. -.. spark:function:: xxhash64(x, ...) -> bigint - - Computes the xxhash64 of one or more input values using seed value of 42. - For multiple arguments, their types can be different. - -.. spark:function:: xxhash64_with_seed(seed, x, ...) -> bigint - - Computes the xxhash64 of one or more input values using specified seed. For - multiple arguments, their types can be different. - .. spark:function:: md5(x) -> varbinary Computes the md5 of x. @@ -50,3 +40,13 @@ Binary Functions have a value of 224, 256, 384, 512, or 0 (which is equivalent to 256). If asking for an unsupported bitLength, the return value is NULL. Note: x can only be varbinary type. + +.. spark:function:: xxhash64(x, ...) -> bigint + + Computes the xxhash64 of one or more input values using seed value of 42. + For multiple arguments, their types can be different. + +.. spark:function:: xxhash64_with_seed(seed, x, ...) -> bigint + + Computes the xxhash64 of one or more input values using specified seed. For + multiple arguments, their types can be different. diff --git a/velox/docs/functions/spark/comparison.rst b/velox/docs/functions/spark/comparison.rst index 7dc953930299..11a89738dd85 100644 --- a/velox/docs/functions/spark/comparison.rst +++ b/velox/docs/functions/spark/comparison.rst @@ -1,6 +1,6 @@ -===================================== +==================== Comparison Functions -===================================== +==================== .. spark:function:: between(x, min, max) -> boolean @@ -86,33 +86,3 @@ Comparison Functions Returns true if x is not equal to y. Supports all scalar types. The types of x and y must be the same. Corresponds to Spark's operator ``!=``. - -.. spark:function:: decimal_lessthan(x, y) -> boolean - - Returns true if x is less than y. Supports decimal types with different precisions and scales. - Corresponds to Spark's operator ``<``. - -.. spark:function:: decimal_lessthanorequal(x, y) -> boolean - - Returns true if x is less than y or x is equal to y. Supports decimal types with different precisions and scales. - Corresponds to Spark's operator ``<=``. - -.. spark:function:: decimal_equalto(x, y) -> boolean - - Returns true if x is equal to y. Supports decimal types with different precisions and scales. - Corresponds to Spark's operator ``==``. - -.. spark:function:: decimal_notequalto(x, y) -> boolean - - Returns true if x is not equal to y. Supports decimal types with different precisions and scales. - Corresponds to Spark's operator ``!=``. - -.. spark:function:: decimal_greaterthan(x, y) -> boolean - - Returns true if x is greater than y. Supports decimal types with different precisions and scales. - Corresponds to Spark's operator ``>``. - -.. spark:function:: decimal_greaterthanorequal(x, y) -> boolean - - Returns true if x is greater than y or x is equal to y. Supports decimal types with different precisions and scales. - Corresponds to Spark's operator ``>=``. diff --git a/velox/docs/functions/spark/conversion.rst b/velox/docs/functions/spark/conversion.rst index 3661dd023c07..6ea90482f8cd 100644 --- a/velox/docs/functions/spark/conversion.rst +++ b/velox/docs/functions/spark/conversion.rst @@ -2,6 +2,38 @@ Conversion Functions ==================== +.. spark:function:: cast(value AS type) -> type + + Explicitly cast a ``value`` to a specified ``type``. + Follows the behavior when Spark ANSI mode is disabled, and does not support + the behavior when ANSI is turned on: + + * If the ``value`` exceeds the range of the ``type``, no error is raised. + Instead, the ``value`` is "wrapped" around. + + * If the ``value`` has an invalid format or contains characters incompatible + with the target ``type``, the cast function returns NULL. :: + + SELECT cast(128 as tinyint); -- -128 + SELECT cast('2012-Oct-23' as date); -- NULL + +.. spark:function:: try_cast(value AS type) -> type + + Returns the ``value`` cast to ``type`` if possible, or NULL if not possible. + Its behavior is independent of the ANSI mode setting, and it acts identically + to cast with ANSI mode enabled but returns NULL rather than throwing errors + for failure to cast. + ``try_cast`` differs from ``cast`` function with ANSI mode disabled in following case: + + * If the ``value`` cannot fit within the domain of ``type``, the result is NULL. :: + + SELECT try_cast(128 as tinyint); -- NULL + SELECT try_cast(cast(550000.0 as DECIMAL(8, 1)) as smallint); -- NULL + SELECT try_cast(1e12 as int); -- NULL + +Cast from UNKNOWN Type +---------------------- + Casting from UNKNOWN type to all other scalar types is supported, e.g., cast(NULL as int). Cast to Integral Types @@ -74,13 +106,13 @@ Invalid examples :: - SELECT cast('1234567' as tinyint); -- Out of range - SELECT cast('1a' as tinyint); -- Invalid argument - SELECT cast('' as tinyint); -- Invalid argument - SELECT cast('1,234,567' as bigint); -- Invalid argument - SELECT cast('1'234'567' as bigint); -- Invalid argument - SELECT cast('nan' as bigint); -- Invalid argument - SELECT cast('infinity' as bigint); -- Invalid argument + SELECT cast('1234567' as tinyint); -- NULL // Reason: Out of range + SELECT cast('1a' as tinyint); -- NULL // Invalid argument + SELECT cast('' as tinyint); -- NULL // Invalid argument + SELECT cast('1,234,567' as bigint); -- NULL // Invalid argument + SELECT cast('1'234'567' as bigint); -- NULL // Invalid argument + SELECT cast('nan' as bigint); -- NULL // Invalid argument + SELECT cast('infinity' as bigint); -- NULL // Invalid argument From decimal ^^^^^^^^^^^^ @@ -100,7 +132,7 @@ Valid examples SELECT cast(cast(2147483648.90 as DECIMAL(12, 2)) as bigint); -- 2147483648 From timestamp -^^^^^^^^^^^^^ +^^^^^^^^^^^^^^ Casting timestamp as integral types returns the number of seconds by converting timestamp as microseconds, dividing by the number of microseconds in a second, and then rounding down to the nearest second since the epoch (1970-01-01 00:00:00 UTC). @@ -145,13 +177,13 @@ Invalid examples :: - SELECT cast('1.7E308' as boolean); -- Invalid argument - SELECT cast('nan' as boolean); -- Invalid argument - SELECT cast('infinity' as boolean); -- Invalid argument - SELECT cast('12' as boolean); -- Invalid argument - SELECT cast('-1' as boolean); -- Invalid argument - SELECT cast('tr' as boolean); -- Invalid argument - SELECT cast('tru' as boolean); -- Invalid argument + SELECT cast('1.7E308' as boolean); -- NULL // Invalid argument + SELECT cast('nan' as boolean); -- NULL // Invalid argument + SELECT cast('infinity' as boolean); -- NULL // Invalid argument + SELECT cast('12' as boolean); -- NULL // Invalid argument + SELECT cast('-1' as boolean); -- NULL // Invalid argument + SELECT cast('tr' as boolean); -- NULL // Invalid argument + SELECT cast('tru' as boolean); -- NULL // Invalid argument Cast to String -------------- @@ -215,9 +247,9 @@ Invalid examples :: - SELECT cast('2012-Oct-23' as date); -- Invalid argument - SELECT cast('2012/10/23' as date); -- Invalid argument - SELECT cast('2012.10.23' as date); -- Invalid argument + SELECT cast('2012-Oct-23' as date); -- NULL // Invalid argument + SELECT cast('2012/10/23' as date); -- NULL // Invalid argument + SELECT cast('2012.10.23' as date); -- NULL // Invalid argument Cast to Decimal --------------- diff --git a/velox/docs/functions/spark/coverage.rst b/velox/docs/functions/spark/coverage.rst index f721415c82ab..f8fe4e00fd20 100644 --- a/velox/docs/functions/spark/coverage.rst +++ b/velox/docs/functions/spark/coverage.rst @@ -1,3 +1,9 @@ +================= +Function Coverage +================= + +Here is a list of all scalar, aggregate, and window functions from Spark, with functions that are available in Velox highlighted. + .. raw:: html @@ -81,73 +208,73 @@ ========================================= ========================================= ========================================= ========================================= ========================================= == ========================================= == ========================================= Scalar Functions Aggregate Functions Window Functions ===================================================================================================================================================================================================================== == ========================================= == ========================================= - :spark:func:`abs` count_if inline nvl sqrt any cume_dist - :spark:func:`acos count_min_sketch inline_outer nvl2 stack approx_count_distinct dense_rank + :spark:func:`abs` count_if inline nvl :spark:func:`sqrt` any cume_dist + :spark:func:`acos` count_min_sketch inline_outer nvl2 stack approx_count_distinct :spark:func:`dense_rank` :spark:func:`acosh` covar_pop input_file_block_length octet_length std approx_percentile first_value - add_months covar_samp input_file_block_start or stddev array_agg lag - :spark:func:`aggregate` crc32 input_file_name overlay stddev_pop avg last_value + :spark:func:`add_months` covar_samp input_file_block_start or stddev array_agg lag + :spark:func:`aggregate` :spark:func:`crc32` input_file_name :spark:func:`overlay` stddev_pop :spark:func:`avg` last_value and cume_dist :spark:func:`instr` parse_url stddev_samp bit_and lead - any current_catalog int percent_rank str_to_map bit_or :spark:func:`nth_value` - approx_count_distinct current_database isnan percentile string :spark:func:`bit_xor` ntile + any current_catalog int percent_rank :spark:func:`str_to_map` bit_or :spark:func:`nth_value` + approx_count_distinct current_database :spark:func:`isnan` percentile string :spark:func:`bit_xor` :spark:func:`ntile` approx_percentile current_date :spark:func:`isnotnull` percentile_approx struct bool_and percent_rank - :spark:func:`array` current_timestamp :spark:func:`isnull` pi substr bool_or rank - :spark:func:`array_contains` current_timezone java_method :spark:func:`pmod` :spark:func:`substring` collect_list row_number - array_distinct current_user json_array_length posexplode substring_index collect_set - array_except date json_object_keys posexplode_outer sum corr - :spark:func:`array_intersect` date_add json_tuple position tan count - array_join date_format kurtosis positive tanh count_if - array_max date_from_unix_date lag pow timestamp count_min_sketch - array_min date_part last :spark:func:`power` timestamp_micros covar_pop - array_position date_sub last_day printf timestamp_millis covar_samp - array_remove date_trunc last_value quarter timestamp_seconds every - array_repeat datediff lcase radians tinyint :spark:func:`first` - :spark:func:`array_sort` day lead raise_error to_csv first_value - array_union dayofmonth :spark:func:`least` :spark:func:`rand` to_date grouping - arrays_overlap dayofweek :spark:func:`left` randn to_json grouping_id - arrays_zip dayofyear :spark:func:`length` random to_timestamp histogram_numeric - :spark:func:`ascii` decimal levenshtein range :spark:func:`to_unix_timestamp` kurtosis - asin decode like rank to_utc_timestamp :spark:func:`last` - :spark:func:`asinh` degrees ln reflect :spark:func:`transform` last_value - assert_true dense_rank locate regexp transform_keys max - atan div log :spark:func:`regexp_extract` transform_values max_by - atan2 double log10 regexp_extract_all translate mean - :spark:func:`atanh` e :spark:func:`log1p` regexp_like :spark:func:`trim` min - avg :spark:func:`element_at` log2 regexp_replace trunc min_by - base64 elt :spark:func:`lower` repeat try_add percentile - :spark:func:`between` encode lpad :spark:func:`replace` try_divide percentile_approx - bigint every :spark:func:`ltrim` reverse typeof regr_avgx - :spark:func:`bin` exists make_date right ucase regr_avgy - binary :spark:func:`exp` make_dt_interval rint unbase64 regr_count - bit_and explode make_interval :spark:func:`rlike` unhex regr_r2 - bit_count explode_outer make_timestamp :spark:func:`round` unix_date skewness - bit_get expm1 make_ym_interval row_number unix_micros some - bit_length extract :spark:func:`map` rpad unix_millis std - bit_or factorial map_concat :spark:func:`rtrim` unix_seconds stddev - bit_xor :spark:func:`filter` map_entries schema_of_csv :spark:func:`unix_timestamp` stddev_pop - bool_and find_in_set :spark:func:`map_filter` schema_of_json :spark:func:`upper` stddev_samp - bool_or first :spark:func:`map_from_arrays` second uuid sum + :spark:func:`array` current_timestamp :spark:func:`isnull` pi substr bool_or :spark:func:`rank` + :spark:func:`array_contains` current_timezone java_method :spark:func:`pmod` :spark:func:`substring` :spark:func:`collect_list` :spark:func:`row_number` + :spark:func:`array_distinct` current_user :spark:func:`json_array_length` posexplode :spark:func:`substring_index` :spark:func:`collect_set` + :spark:func:`array_except` date :spark:func:`json_object_keys` posexplode_outer sum :spark:func:`corr` + :spark:func:`array_intersect` :spark:func:`date_add` json_tuple position tan count + :spark:func:`array_join` :spark:func:`date_format` kurtosis positive tanh count_if + :spark:func:`array_max` :spark:func:`date_from_unix_date` lag pow timestamp count_min_sketch + :spark:func:`array_min` date_part last :spark:func:`power` :spark:func:`timestamp_micros` covar_pop + :spark:func:`array_position` :spark:func:`date_sub` :spark:func:`last_day` printf :spark:func:`timestamp_millis` :spark:func:`covar_samp` + :spark:func:`array_remove` :spark:func:`date_trunc` last_value :spark:func:`quarter` timestamp_seconds every + :spark:func:`array_repeat` :spark:func:`datediff` lcase radians tinyint :spark:func:`first` + :spark:func:`array_sort` :spark:func:`day` lead :spark:func:`raise_error` to_csv first_value + :spark:func:`array_union` :spark:func:`dayofmonth` :spark:func:`least` :spark:func:`rand` to_date grouping + arrays_overlap :spark:func:`dayofweek` :spark:func:`left` randn to_json grouping_id + :spark:func:`arrays_zip` :spark:func:`dayofyear` :spark:func:`length` :spark:func:`random` to_timestamp histogram_numeric + :spark:func:`ascii` decimal :spark:func:`levenshtein` range :spark:func:`to_unix_timestamp` :spark:func:`kurtosis` + :spark:func:`asin` decode :spark:func:`like` rank :spark:func:`to_utc_timestamp` :spark:func:`last` + :spark:func:`asinh` :spark:func:`degrees` ln reflect :spark:func:`transform` last_value + assert_true dense_rank :spark:func:`locate` regexp transform_keys :spark:func:`max` + :spark:func:`atan` div :spark:func:`log` :spark:func:`regexp_extract` transform_values :spark:func:`max_by` + :spark:func:`atan2` double :spark:func:`log10` :spark:func:`regexp_extract_all` :spark:func:`translate` mean + :spark:func:`atanh` e :spark:func:`log1p` regexp_like :spark:func:`trim` :spark:func:`min` + avg :spark:func:`element_at` :spark:func:`log2` :spark:func:`regexp_replace` :spark:func:`trunc` :spark:func:`min_by` + base64 elt :spark:func:`lower` :spark:func:`repeat` try_add percentile + :spark:func:`between` encode :spark:func:`lpad` :spark:func:`replace` try_divide percentile_approx + bigint every :spark:func:`ltrim` :spark:func:`reverse` typeof regr_avgx + :spark:func:`bin` :spark:func:`exists` :spark:func:`make_date` right ucase regr_avgy + binary :spark:func:`exp` make_dt_interval :spark:func:`rint` :spark:func:`unbase64` regr_count + bit_and explode make_interval :spark:func:`rlike` :spark:func:`unhex` regr_r2 + :spark:func:`bit_count` explode_outer :spark:func:`make_timestamp` :spark:func:`round` :spark:func:`unix_date` :spark:func:`skewness` + :spark:func:`bit_get` :spark:func:`expm1` :spark:func:`make_ym_interval` row_number :spark:func:`unix_micros` some + :spark:func:`bit_length` extract :spark:func:`map` :spark:func:`rpad` :spark:func:`unix_millis` std + bit_or :spark:func:`factorial` :spark:func:`map_concat` :spark:func:`rtrim` :spark:func:`unix_seconds` :spark:func:`stddev` + bit_xor :spark:func:`filter` :spark:func:`map_entries` schema_of_csv :spark:func:`unix_timestamp` stddev_pop + bool_and :spark:func:`find_in_set` :spark:func:`map_filter` schema_of_json :spark:func:`upper` :spark:func:`stddev_samp` + bool_or first :spark:func:`map_from_arrays` :spark:func:`second` :spark:func:`uuid` :spark:func:`sum` boolean first_value map_from_entries sentences var_pop try_avg - bround flatten map_keys sequence var_samp try_sum - btrim float map_values session_window variance var_pop - cardinality :spark:func:`floor` map_zip_with sha version var_samp - case forall max :spark:func:`sha1` weekday variance + bround :spark:func:`flatten` :spark:func:`map_keys` sequence var_samp try_sum + btrim float :spark:func:`map_values` session_window variance var_pop + cardinality :spark:func:`floor` :spark:func:`map_zip_with` sha version :spark:func:`var_samp` + case :spark:func:`forall` max :spark:func:`sha1` :spark:func:`weekday` :spark:func:`variance` cast format_number max_by :spark:func:`sha2` weekofyear - cbrt format_string :spark:func:`md5` :spark:func:`shiftleft` when - :spark:func:`ceil` from_csv mean :spark:func:`shiftright` width_bucket + :spark:func:`cbrt` format_string :spark:func:`md5` :spark:func:`shiftleft` when + :spark:func:`ceil` from_csv mean :spark:func:`shiftright` :spark:func:`width_bucket` ceiling from_json min shiftrightunsigned window - char from_unixtime min_by shuffle xpath - char_length from_utc_timestamp minute sign xpath_boolean + char :spark:func:`from_unixtime` min_by :spark:func:`shuffle` xpath + char_length :spark:func:`from_utc_timestamp` :spark:func:`minute` :spark:func:`sign` xpath_boolean character_length :spark:func:`get_json_object` mod signum xpath_double - :spark:func:`chr` getbit monotonically_increasing_id sin xpath_float - coalesce :spark:func:`greatest` month :spark:func:`sinh` xpath_int + :spark:func:`chr` getbit :spark:func:`monotonically_increasing_id` sin xpath_float + coalesce :spark:func:`greatest` :spark:func:`month` :spark:func:`sinh` xpath_int collect_list grouping months_between :spark:func:`size` xpath_long collect_set grouping_id named_struct skewness xpath_number - :spark:func:`concat` :spark:func:`hash` nanvl slice xpath_short - concat_ws hex negative smallint xpath_string - conv hour next_day some :spark:func:`xxhash64` + :spark:func:`concat` :spark:func:`hash` nanvl :spark:func:`slice` xpath_short + concat_ws :spark:func:`hex` negative smallint xpath_string + :spark:func:`conv` :spark:func:`hour` :spark:func:`next_day` some :spark:func:`xxhash64` corr :spark:func:`hypot` :spark:func:`not` :spark:func:`sort_array` :spark:func:`year` - cos if now soundex zip_with - cosh ifnull nth_value space - cot :spark:func:`in` ntile spark_partition_id + :spark:func:`cos` if now :spark:func:`soundex` :spark:func:`zip_with` + :spark:func:`cosh` ifnull nth_value space + :spark:func:`cot` :spark:func:`in` ntile :spark:func:`spark_partition_id` count initcap nullif :spark:func:`split` ========================================= ========================================= ========================================= ========================================= ========================================= == ========================================= == ========================================= diff --git a/velox/docs/functions/spark/datetime.rst b/velox/docs/functions/spark/datetime.rst index e94004cd56b1..97e6f2b25166 100644 --- a/velox/docs/functions/spark/datetime.rst +++ b/velox/docs/functions/spark/datetime.rst @@ -1,6 +1,6 @@ -===================================== +======================= Date and Time Functions -===================================== +======================= Convenience Extraction Functions -------------------------------- @@ -183,35 +183,6 @@ These functions support TIMESTAMP and DATE input types. ``day`` need to be from 1 to 31, and matches the number of days in each month. days of ``year-month-day - 1970-01-01`` need to be in the range of INTEGER type. -.. spark:function:: make_ym_interval([years[, months]]) -> interval year to month - - Make year-month interval from ``years`` and ``months`` fields. - Returns the actual year-month with month in the range of [0, 11]. - Both ``years`` and ``months`` can be zero, positive or negative. - Throws an error when inputs lead to int overflow, - e.g., make_ym_interval(178956970, 8). :: - - SELECT make_ym_interval(1, 2); -- 1-2 - SELECT make_ym_interval(1, 0); -- 1-0 - SELECT make_ym_interval(-1, 1); -- -0-11 - SELECT make_ym_interval(1, 100); -- 9-4 - SELECT make_ym_interval(1, 12); -- 2-0 - SELECT make_ym_interval(1, -12); -- 0-0 - SELECT make_ym_interval(2); -- 2-0 - SELECT make_ym_interval(); -- 0-0 - -.. spark:function:: minute(timestamp) -> integer - - Returns the minutes of ``timestamp``.:: - - SELECT minute('2009-07-30 12:58:59'); -- 58 - -.. spark:function:: quarter(date) -> integer - - Returns the quarter of ``date``. The value ranges from ``1`` to ``4``. :: - - SELECT quarter('2009-07-30'); -- 3 - .. spark:function:: make_timestamp(year, month, day, hour, minute, second[, timezone]) -> timestamp Create timestamp from ``year``, ``month``, ``day``, ``hour``, ``minute`` and ``second`` fields. @@ -244,12 +215,49 @@ These functions support TIMESTAMP and DATE input types. SELECT make_timestamp(2014, 12, 28, 6, 30, 60.000001); -- NULL SELECT make_timestamp(2014, 13, 28, 6, 30, 45.887); -- NULL +.. spark:function:: make_ym_interval([years[, months]]) -> interval year to month + + Make year-month interval from ``years`` and ``months`` fields. + Returns the actual year-month with month in the range of [0, 11]. + Both ``years`` and ``months`` can be zero, positive or negative. + Throws an error when inputs lead to int overflow, + e.g., make_ym_interval(178956970, 8). :: + + SELECT make_ym_interval(1, 2); -- 1-2 + SELECT make_ym_interval(1, 0); -- 1-0 + SELECT make_ym_interval(-1, 1); -- -0-11 + SELECT make_ym_interval(1, 100); -- 9-4 + SELECT make_ym_interval(1, 12); -- 2-0 + SELECT make_ym_interval(1, -12); -- 0-0 + SELECT make_ym_interval(2); -- 2-0 + SELECT make_ym_interval(); -- 0-0 + +.. spark:function:: minute(timestamp) -> integer + + Returns the minutes of ``timestamp``.:: + + SELECT minute('2009-07-30 12:58:59'); -- 58 + .. spark:function:: month(date) -> integer Returns the month of ``date``. :: SELECT month('2009-07-30'); -- 7 +.. spark:function:: months_between(timestamp1, timestamp2, roundOff) -> double + + Returns number of months between times ``timestamp1`` and ``timestamp2``. + If ``timestamp1`` is later than ``timestamp2``, the result is positive. + If ``timestamp1`` and ``timestamp2`` are on the same day of month, or both are the + last day of month, time of day will be ignored. Otherwise, the difference is calculated + based on 31 days per month, and rounded to 8 digits unless ``roundOff`` is false. :: + + SELECT months_between('1997-02-28 10:30:00', '1996-10-30', true); -- 3.94959677 + SELECT months_between('1997-02-28 10:30:00', '1996-10-30', false); -- 3.9495967741935485 + SELECT months_between('1997-02-28 10:30:00', '1996-03-31 11:00:00', true); -- 11.0 + SELECT months_between('1997-02-28 10:30:00', '1996-03-28 11:00:00', true); -- 11.0 + SELECT months_between('1997-02-21 10:30:00', '1996-03-21 11:00:00', true); -- 11.0 + .. spark:function:: next_day(startDate, dayOfWeek) -> date Returns the first date which is later than ``startDate`` and named as ``dayOfWeek``. @@ -265,12 +273,47 @@ These functions support TIMESTAMP and DATE input types. SELECT next_day('2015-07-23', "tu"); -- '2015-07-28' SELECT next_day('2015-07-23', "we"); -- '2015-07-29' +.. spark:function:: quarter(date) -> integer + + Returns the quarter of ``date``. The value ranges from ``1`` to ``4``. :: + + SELECT quarter('2009-07-30'); -- 3 + .. spark:function:: second(timestamp) -> integer Returns the seconds of ``timestamp``. :: SELECT second('2009-07-30 12:58:59'); -- 59 +.. spark:function:: timestampadd(unit, value, timestamp) -> timestamp + + Adds an int or bigint interval ``value`` of type ``unit`` to ``timestamp``. + Subtraction can be performed by using a negative ``value``. + Throws exception if ``unit`` is invalid. + ``unit`` is case insensitive and must be one of the following: + ``YEAR``, ``QUARTER``, ``MONTH``, ``WEEK``, ``DAY``, ``DAYOFYEAR``, ``HOUR``, ``MINUTE``, ``SECOND``, + ``MILLISECOND``, ``MICROSECOND``. :: + + SELECT timestampadd(YEAR, 1, '2030-02-28 10:00:00.500'); -- 2031-02-28 10:00:00.500 + SELECT timestampadd(DAY, 1, '2020-02-29 10:00:00.500'); -- 2020-03-01 10:00:00.500 + SELECT timestampadd(DAYOFYEAR, 1, '2020-02-29 10:00:00.500'); -- 2020-03-01 10:00:00.500 + SELECT timestampadd(SECOND, 10, '2019-03-01 10:00:00.500'); -- 2019-03-01 10:00:10.500 + SELECT timestampadd(MICROSECOND, 500, '2019-02-28 10:01:00.500999'); -- 2019-02-28 10:01:00.501499 + +.. spark:function:: timestampdiff(unit, timestamp1, timestamp2) -> bigint + + Returns ``timestamp2`` - ``timestamp1`` expressed in terms of ``unit``, the fraction + part is truncated. + Throws exception if ``unit`` is invalid. + ``unit`` is case insensitive and must be one of the following: + ``YEAR``, ``QUARTER``, ``MONTH``, ``WEEK``, ``DAY``, ``HOUR``, ``MINUTE``, ``SECOND``, + ``MILLISECOND``, ``MICROSECOND``. :: + + SELECT timestampdiff(YEAR, '2020-02-29 10:00:00.500', '2030-02-28 10:00:00.500'); -- 9 + SELECT timestampdiff(DAY, '2019-01-30 10:00:00.500', '2020-02-29 10:00:00.500'); -- 395 + SELECT timestampdiff(SECOND, '2019-02-28 10:00:00.500', '2019-03-01 10:00:00.500'); -- 86400 + SELECT timestampdiff(MICROSECOND, '2019-02-28 10:00:00.000000', '2019-02-28 10:01:00.500999'); -- 60500999 + .. spark:function:: timestamp_micros(x) -> timestamp Returns timestamp from the number of microseconds since UTC epoch. @@ -285,6 +328,23 @@ These functions support TIMESTAMP and DATE input types. SELECT timestamp_millis(1230219000123); -- '2008-12-25 15:30:00.123' +.. spark:function:: timestamp_seconds(x) -> timestamp + + Returns timestamp from the number of seconds (can be fractional) since UTC epoch. + Supported types are: TINYINT, SMALLINT, INTEGER, BIGINT, FLOAT, and DOUBLE. + For integral types (TINYINT, SMALLINT, INTEGER, BIGINT), the function directly + converts the number of seconds to a timestamp. For floating-point types + (FLOAT, DOUBLE), the function scales the input to microseconds, truncates + towards zero, and saturates the result to the minimum and maximum values allowed + in Spark. Returns NULL when ``x`` is NaN or Infinity. :: + + SELECT timestamp_seconds(1230219000); -- '2008-12-25 15:30:00' + SELECT timestamp_seconds(1230219000.123); -- '2008-12-25 15:30:00.123' + SELECT timestamp_seconds(double(1.1234567)); -- '1970-01-01 00:00:01.123456' + SELECT timestamp_seconds(double('inf')); -- NULL + SELECT timestamp_seconds(float(3.4028235E+38)); -- '+294247-01-10 04:00:54.775807' + SELECT timestamp_seconds(float('nan')); -- NULL + .. spark:function:: to_unix_timestamp(date) -> bigint :noindex: @@ -310,6 +370,26 @@ These functions support TIMESTAMP and DATE input types. SELECT to_utc_timestamp('2015-07-24 00:00:00', 'America/Los_Angeles'); -- '2015-07-24 07:00:00' +.. spark:function:: trunc(date, fmt) -> date + + Returns ``date`` truncated to the unit specified by the format model ``fmt``. + Returns NULL if ``fmt`` is invalid. + + ``fmt`` is case insensitive and must be one of the following: + * "YEAR", "YYYY", "YY" - truncate to the first date of the year that the ``date`` falls in + * "QUARTER" - truncate to the first date of the quarter that the ``date`` falls in + * "MONTH", "MM", "MON" - truncate to the first date of the month that the ``date`` falls in + * "WEEK" - truncate to the Monday of the week that the ``date`` falls in + + :: + + SELECT trunc('2019-08-04', 'week'); -- 2019-07-29 + SELECT trunc('2019-08-04', 'quarter'); -- 2019-07-01 + SELECT trunc('2009-02-12', 'MM'); -- 2009-02-01 + SELECT trunc('2015-10-27', 'YEAR'); -- 2015-01-01 + SELECT trunc('2015-10-27', ''); -- NULL + SELECT trunc('2015-10-27', 'day'); -- NULL + .. spark:function:: unix_date(date) -> integer Returns the number of days since 1970-01-01. :: diff --git a/velox/docs/functions/spark/decimal.rst b/velox/docs/functions/spark/decimal.rst index 3d1c1c18fc23..9573e3aecb7e 100644 --- a/velox/docs/functions/spark/decimal.rst +++ b/velox/docs/functions/spark/decimal.rst @@ -38,7 +38,7 @@ The HiveQL behavior: https://cwiki.apache.org/confluence/download/attachments/27362075/Hive_Decimal_Precision_Scale_Support.pdf Additionally, the computation of decimal division adapts to the allow-precision-loss flag, -while the decimal addition, subtraction, and multiplication do not. +while the decimal addition, subtraction, multiplication and integer division do not. Addition and Subtraction ~~~~~~~~~~~~~~~~~~~~~~~~ @@ -74,6 +74,15 @@ When allow-precision-loss is false: p = wholeDigits + fractionalDigits s = fractionalDigits +Integer Division +~~~~~~~~~~~~~~~~ + +:: + + precision = p1 - s1 + s2 + p = precision == 0 ? 1 : min(38, precision) + s = 0 + Decimal Precision and Scale Adjustment -------------------------------------- @@ -112,6 +121,28 @@ Decimal division uses a different formula: Returns NULL when the actual result cannot be represented with the calculated decimal type. +Arithmetic Functions +-------------------- + +.. spark:function:: div(x: decimal(p1, s1), y: decimal(p2, s2)) -> bigint + + Performs integer division and returns the bigint result of dividing ``x`` by ``y``, truncating toward zero. + Truncation occurs if the result is within the result precision but exceeds the BIGINT range. + Division by zero or overflow results in NULL. Does not respect the ``allow-precision-loss`` configuration. + Corresponds to Spark's operator ``div`` with ``spark.sql.ansi.enabled`` set to false. :: + + SELECT CAST(1 as DECIMAL(17, 3)) div CAST(2 as DECIMAL(17, 3)); -- 0 + SELECT CAST(21 as DECIMAL(20, 3)) div CAST(20 as DECIMAL(20, 2)); -- 1 + SELECT CAST(1 as DECIMAL(20, 3)) div CAST(0 as DECIMAL(20, 3)); -- NULL + SELECT CAST(99999999999999999999999999999999999 as DECIMAL(38, 1)) div CAST(0.001 as DECIMAL(7, 4)); -- 687399551400672280 // Result is truncated to int64_t. + +.. spark:function:: checked_div(x: decimal(p1, s1), y: decimal(p2, s2)) -> bigint + + Performs integer division and returns the bigint result of dividing ``x`` by ``y``, truncating toward zero. + Truncation occurs if the result is within the result precision but exceeds the BIGINT range. + Division by zero or overflow results in an error. + Corresponds to Spark's operator ``div`` with ``spark.sql.ansi.enabled`` set to true. + Decimal Functions ----------------- .. spark:function:: ceil(x: decimal(p, s)) -> r: decimal(pr, 0) @@ -122,6 +153,36 @@ Decimal Functions SELECT ceil(cast(1.23 as DECIMAL(3, 2))); -- 2 // Output type: decimal(2,0) +.. spark:function:: decimal_equalto(x, y) -> boolean + + Returns true if x is equal to y. Supports decimal types with different precisions and scales. + Corresponds to Spark's operator ``==``. + +.. spark:function:: decimal_greaterthan(x, y) -> boolean + + Returns true if x is greater than y. Supports decimal types with different precisions and scales. + Corresponds to Spark's operator ``>``. + +.. spark:function:: decimal_greaterthanorequal(x, y) -> boolean + + Returns true if x is greater than y or x is equal to y. Supports decimal types with different precisions and scales. + Corresponds to Spark's operator ``>=``. + +.. spark:function:: decimal_lessthan(x, y) -> boolean + + Returns true if x is less than y. Supports decimal types with different precisions and scales. + Corresponds to Spark's operator ``<``. + +.. spark:function:: decimal_lessthanorequal(x, y) -> boolean + + Returns true if x is less than y or x is equal to y. Supports decimal types with different precisions and scales. + Corresponds to Spark's operator ``<=``. + +.. spark:function:: decimal_notequalto(x, y) -> boolean + + Returns true if x is not equal to y. Supports decimal types with different precisions and scales. + Corresponds to Spark's operator ``!=``. + .. spark:function:: floor(x: decimal(p, s)) -> r: decimal(pr, 0) Returns ``x`` rounded down to the type ``decimal(min(38, p - s + min(1, s)), 0)``. @@ -150,12 +211,6 @@ Decimal Functions Decimal Special Forms --------------------- -.. spark:function:: make_decimal(x[, nullOnOverflow]) -> decimal - - Create ``decimal`` of requsted precision and scale from an unscaled bigint value ``x``. - By default, the value of ``nullOnOverflow`` is true, and null will be returned when ``x`` is too large for the result precision. - Otherwise, exception will be thrown when ``x`` overflows. - .. spark:function:: decimal_round(decimal[, scale]) -> [decimal] Returns ``decimal`` rounded to a new scale using HALF_UP rounding mode. In HALF_UP rounding, the digit 5 is rounded up. @@ -195,3 +250,9 @@ Decimal Special Forms SELECT round(cast (85.681 as DECIMAL(5, 3)), 1); -- decimal 85.7 SELECT round(cast (85.681 as DECIMAL(5, 3)), 999); -- decimal 85.681 SELECT round(cast (0.1234567890123456789 as DECIMAL(19, 19)), 14); -- decimal 0.12345678901235 + +.. spark:function:: make_decimal(x[, nullOnOverflow]) -> decimal + + Create ``decimal`` of requsted precision and scale from an unscaled bigint value ``x``. + By default, the value of ``nullOnOverflow`` is true, and null will be returned when ``x`` is too large for the result precision. + Otherwise, exception will be thrown when ``x`` overflows. diff --git a/velox/docs/functions/spark/json.rst b/velox/docs/functions/spark/json.rst index 5ffe3af0fca5..c7f37a7ffa51 100644 --- a/velox/docs/functions/spark/json.rst +++ b/velox/docs/functions/spark/json.rst @@ -24,9 +24,9 @@ JSON Functions Casts ``jsonString`` to an ARRAY, MAP, or ROW type, with the output type determined by the expression. Returns NULL, if the input string is unparsable. Supported element types include BOOLEAN, TINYINT, SMALLINT, INTEGER, BIGINT, - REAL, DOUBLE, DATE, VARCHAR, ARRAY, MAP and ROW. When casting to ARRAY or MAP, - the element type of the array or the value type of the map must be one of - these supported types, and for maps, the key type must be VARCHAR. Casting + REAL, DOUBLE, DECIMAL, DATE, VARCHAR, ARRAY, MAP and ROW. When casting to ARRAY + or MAP, the element type of the array or the value type of the map must be one + of these supported types, and for maps, the key type must be VARCHAR. Casting to ROW supports only JSON objects. Note that since the result type can be inferred from the expression, in Velox we do not need to provide the ``schema`` parameter as required by Spark's from_json @@ -35,6 +35,7 @@ JSON Functions SELECT from_json('{"a": true}', 'a BOOLEAN'); -- {'a'=true} SELECT from_json('{"a": 1}', 'a INT'); -- {'a'=1} SELECT from_json('{"a": 1.0}', 'a DOUBLE'); -- {'a'=1.0} + SELECT from_json('{"a": 5.321E2}', 'a DECIMAL(7, 2)'); -- {'a'=532.10} SELECT from_json('{"a":"2021-7-1T"}', 'a DATE'); -- {'a'="2021-07-01"} SELECT from_json('{"a":"1"}', 'a DATE'); -- {'a'="1970-01-02"} SELECT from_json('["name", "age", "id"]', 'ARRAY'); -- ['name', 'age', 'id'] @@ -58,10 +59,7 @@ JSON Functions .. spark:function:: get_json_object(jsonString, path) -> varchar Returns a json object, represented by VARCHAR, from ``jsonString`` by searching ``path``. - Valid ``path`` should start with '$' and then contain "[index]", "['field']" or ".field" - to define a JSON path. Here are some examples: "$.a" "$.a.b", "$[0]['a'].b". Returns - ``jsonString`` if ``path`` is "$". Returns NULL if ``jsonString`` or ``path`` is malformed. - Returns NULL if ``path`` does not exist. :: + Returns NULL if ``jsonString`` or ``path`` is malformed or ``path`` does not exist. :: SELECT get_json_object('{"a":"b"}', '$.a'); -- 'b' SELECT get_json_object('{"a":{"b":"c"}}', '$.a'); -- '{"b":"c"}' @@ -69,6 +67,12 @@ JSON Functions SELECT get_json_object('{"a"-3}'', '$.a'); -- NULL (malformed JSON string) SELECT get_json_object('{"a":3}'', '.a'); -- NULL (malformed JSON path) + Valid ``path`` syntax: + * Must start with '$'. + * Using "[index]", "['field']" or ".field" to navigate to the desired JSON object. + * Whitespace is allowed **after the dot** and **before the field name**, e.g., "$. field". + * Trailing whitespace after '$' is allowed, e.g., "$ ". + .. spark:function:: json_array_length(jsonString) -> integer Returns the number of elements in the outermost JSON array from ``jsonString``. @@ -90,3 +94,21 @@ JSON Functions SELECT json_object_keys(1); -- NULL SELECT json_object_keys('"hello"'); -- NULL SELECT json_object_keys("invalid json"); -- NULL + +.. spark:function:: to_json(jsonObject) -> jsonString + + Converts a Json object (ROW, ARRAY or MAP) into a JSON string. :: + + SELECT to_json(named_struct('c0', 1, 'c1', 'a')); -- {"c0":1,"c1":"a"} + SELECT to_json(ARRAY(1, 2, 3)); -- [1,2,3] + SELECT to_json(MAP('x', 1, 'y', 2)); -- {"x":1,"y":2} + + The current implementation has following limitations. + + * Does not support user provided options. :: + + to_json(MAP(1, 'a'), map('option', 'value')) + + * MAP key type cannot be/contain MAP. :: + + to_json(MAP(MAP('a', 1), 10)) diff --git a/velox/docs/functions/spark/map.rst b/velox/docs/functions/spark/map.rst index 0efb77da7ce6..85532a6a12a8 100644 --- a/velox/docs/functions/spark/map.rst +++ b/velox/docs/functions/spark/map.rst @@ -1,6 +1,6 @@ -=========================== +============= Map Functions -=========================== +============= .. spark:function:: element_at(map(K,V), key) -> V @@ -8,9 +8,14 @@ Map Functions .. spark:function:: map(K, V, K, V, ...) -> map(K,V) - Returns a map created using the given key/value pairs. Keys are not allowed to be null. :: + Returns a map created using the given key/value pairs. If there is duplicate key, by default that + key's value comes from last value for that key in the arguments. + If configuration `throw_exception_on_duplicate_map_keys` is set true, + throws exception for duplicate keys. Keys are not allowed to be null. :: SELECT map(1, 2, 3, 4); -- {1 -> 2, 3 -> 4} + SELECT map(1, 2, 3, 4, 1, 5); -- {1 -> 5, 3 -> 4} (LAST_WIN behavior) + SELECT map(1, 2, 3, 4, 1, 5); -- "Duplicate map key (1) was found" (EXCEPTION behavior) SELECT map(array(1, 2), array(3, 4)); -- {[1, 2] -> [3, 4]} @@ -70,7 +75,7 @@ Map Functions (k, v1, v2) -> k || CAST(v1/v2 AS VARCHAR)); .. spark:function:: size(map(K,V), legacySizeOfNull) -> integer - :noindex: + :noindex: Returns the size of the input map. Returns null for null input if ``legacySizeOfNull`` is set to false. Otherwise, returns -1 for null input. :: diff --git a/velox/docs/functions/spark/math.rst b/velox/docs/functions/spark/math.rst index dcc088d756fa..4bdc0d58fd65 100644 --- a/velox/docs/functions/spark/math.rst +++ b/velox/docs/functions/spark/math.rst @@ -1,11 +1,18 @@ -==================================== +====================== Mathematical Functions -==================================== +====================== -.. spark:function:: abs(x) -> [same as x] +.. spark:function:: abs(x) -> [same as x] (ANSI compliant) - Returns the absolute value of ``x``. + Returns the absolute value of ``x``. When ``x`` is negative minimum + value of integral type returns the same value as ``x`` following + the behavior when Spark ANSI mode is disabled and throws exception + when Spark ANSI mode is enabled. :: + SELECT abs(-42); -- 42 + SELECT abs(3.14); -- 3.14 + SELECT abs(-128); -- 128 (with ANSI mode disabled) + SELECT abs(-128); -- Overflow exception (with ANSI mode enabled for TINYINT) .. spark:function:: acos(x) -> double Returns the inverse cosine (a.k.a. arc cosine) of ``x``. @@ -62,6 +69,10 @@ Mathematical Functions Returns the string representation of the long value ``x`` represented in binary. +.. spark:function:: cbrt(x) -> double + + Returns the cube root of ``x``. + .. spark:function:: ceil(x) -> [same as x] Returns ``x`` rounded up to the nearest integer. @@ -72,7 +83,14 @@ Mathematical Functions Returns the result of adding x to y. The types of x and y must be the same. For integral types, overflow results in an error. Corresponds to Spark's operator ``+`` with ``failOnError`` as true. -.. function:: checked_divide(x, y) -> [same as x] +.. function:: checked_div(x, y) -> bigint + + Returns the result of integer division of ``x`` by ``y``, truncating toward zero. + Supported types are integral types, ``x`` and ``y`` must have the same type. + Division by zero or overflow results in an error. This function operates in ANSI mode (error on invalid input). + Corresponds to Spark's operator ``div`` with ``spark.sql.ansi.enabled`` set to true. + +.. spark:function:: checked_divide(x, y) -> [same as x] Returns the results of dividing x by y. The types of x and y must be the same. Division by zero results in an error. Corresponds to Spark's operator ``/`` with ``failOnError`` as true. @@ -107,6 +125,16 @@ Mathematical Functions Converts angle x in radians to degrees. +.. spark:function:: div(x, y) -> bigint + + Returns the results of dividing x by y. Performs the integer division truncates toward zero. + Supported types are integral types, x and y must have the same type. + Division by zero or overflow results in null. :: + + SELECT 3 div 2; -- 1 + SELECT 1L div 2L; -- 0 + SELECT 3 div 0; -- NULL + .. spark:function:: divide(x, y) -> double Returns the results of dividing x by y. Performs floating point division. diff --git a/velox/docs/functions/spark/misc.rst b/velox/docs/functions/spark/misc.rst index de826e35b140..311586faa6f9 100644 --- a/velox/docs/functions/spark/misc.rst +++ b/velox/docs/functions/spark/misc.rst @@ -1,6 +1,6 @@ -==================================== +======================= Miscellaneous Functions -==================================== +======================= .. spark:function:: at_least_n_non_nulls(n, value1, value2, ..., valueN) -> bool @@ -15,6 +15,16 @@ Miscellaneous Functions SELECT at_least_n_non_nulls(2, 0, 1.0, NULL); -- true SELECT at_least_n_non_nulls(2, 0, array(NULL), NULL); -- true +.. spark:function:: get_array_struct_fields(array, ordinal) -> array(T) + + Extracts the ``ordinal``-th fields of all array elements, and returns them as a new array. + The first input must be of array(strcut) type and nested complex type is allowed. + The ``ordinal`` is 0-based, and if ``ordinal`` is negative or no less than + the children size of strcut, exception is thrown. :: + + SELECT items.col1 FROM VALUES (array(struct(100,'foo'), struct(200,'bar'))) AS t(items); -- array(100, 200) + SELECT items.col2 FROM VALUES (array(struct(100,'foo'), struct(200,'bar'))) AS t(items); -- array('foo', 'bar') + .. spark:function:: get_struct_field(struct, ordinal) -> T Returns the value of nested subfield at position ``ordinal`` in the input ``struct``. diff --git a/velox/docs/functions/spark/regexp.rst b/velox/docs/functions/spark/regexp.rst index 0552e2a4678f..ab994e738f2d 100644 --- a/velox/docs/functions/spark/regexp.rst +++ b/velox/docs/functions/spark/regexp.rst @@ -32,7 +32,7 @@ See https://github.com/google/re2/wiki/Syntax for more information. '%hello', '%__hello_', '%hello%', where 'hello', 'velox' contains only regular characters and '_' wildcards are evaluated without using regular expressions. Only those patterns that require the compilation of - regular expressions are counted towards the limit. + regular expressions are counted towards the limit. :: SELECT like('abc', '%b%'); -- true SELECT like('a_c', '%#_%', '#'); -- true @@ -48,7 +48,7 @@ See https://github.com/google/re2/wiki/Syntax for more information. SELECT regexp_extract('1a 2b 14m', '\d+'); -- 1 .. spark:function:: regexp_extract(string, pattern, group) -> varchar - :noindex: + :noindex: Finds the first occurrence of the regular expression ``pattern`` in ``string`` and returns the capturing group number ``group``. diff --git a/velox/docs/functions/spark/string.rst b/velox/docs/functions/spark/string.rst index e788f1d2ddb2..6e8c5f96de86 100644 --- a/velox/docs/functions/spark/string.rst +++ b/velox/docs/functions/spark/string.rst @@ -1,6 +1,6 @@ -==================================== +================ String Functions -==================================== +================ .. note:: @@ -13,12 +13,37 @@ String Functions Returns unicode code point of the first character of ``string``. Returns 0 if ``string`` is empty. +.. spark:function:: base64(expr) -> varchar + + Converts ``expr`` to a base 64 string using RFC2045 Base64 transfer encoding for MIME. :: + + SELECT base64('Spark SQL'); -- 'U3BhcmsgU1FM' + .. spark:function:: bit_length(string/binary) -> integer Returns the bit length for the specified string column. :: SELECT bit_length('123'); -- 24 +.. spark:function:: char_type_write_side_check(string, limit) -> varchar + + Ensures that input ``string`` fits within the specified length ``limit`` in characters by padding or trimming spaces as needed. + If the length of ``string`` is less than ``limit``, it is padded with trailing spaces (ASCII 32) to reach ``limit``. + If the length of ``string`` is greater than ``limit``, trailing spaces are trimmed to fit within ``limit``. + Throws exception when ``string`` still exceeds ``limit`` after trimming trailing spaces or when ``limit`` is not greater than 0. + Note: This function is not directly callable in Spark SQL, but internally used for length check when writing char type columns. :: + + -- Function call examples (this function is not directly callable in Spark SQL). + char_type_write_side_check("abc", 3) -- "abc" + char_type_write_side_check("ab", 3) -- "ab " + char_type_write_side_check("a", 3) -- "a " + char_type_write_side_check("abc ", 3) -- "abc" + char_type_write_side_check("abcd", 3) -- VeloxUserError: "Exceeds allowed length limitation: '3'" + char_type_write_side_check("世界", 2) -- "世界" + char_type_write_side_check("世", 3) -- "世 " + char_type_write_side_check("世界人", 2) -- VeloxUserError: "Exceeds allowed length limitation: '2'" + char_type_write_side_check("abc", 0) -- VeloxUserError: "The length limit must be greater than 0." + .. spark:function:: chr(n) -> varchar Returns the Unicode code point ``n`` as a single character string. @@ -108,6 +133,20 @@ String Functions SELECT find_in_set(NULL, ',123'); -- NULL SELECT find_in_set("abc", NULL); -- NULL +.. spark:function:: initcap(string) -> varchar + + The ``initcap`` function converts the first character of each word to uppercase + and all other characters in the word to lowercase. It supports UTF-8 multibyte + characters, up to four bytes per character. + + A *word* is defined as a sequence of characters separated by whitespace. :: + + SELECT initcap('spark sql'); -- Spark Sql + SELECT initcap('spARK sQL'); -- Spark Sql + SELECT initcap('123abc DEF!ghi'); -- 123abc Def!ghi + SELECT initcap('élan vital für alle'); -- Élan Vital Für Alle + SELECT initcap('hello-world test_case'); -- Hello-world Test_case + .. spark:function:: instr(string, substring) -> integer Returns the starting position of the first instance of ``substring`` in @@ -184,7 +223,7 @@ String Functions SELECT ltrim(' data '); -- "data " .. spark:function:: ltrim(trimCharacters, string) -> varchar - :noindex: + :noindex: Removes specified leading characters from ``string``. The specified character is any character contained in ``trimCharacters``. @@ -247,6 +286,24 @@ String Functions SELECT overlay('Spark SQL', 'tructured', 2, 4); -- "Structured SQL" SELECT overlay('Spark SQL', '_', -6, 3); -- "_Sql" +.. spark:function:: read_side_padding(string, limit) -> varchar + + Right-pads the given string with spaces to the specified length ``limit``. + If the string's length is already greater than or equal to ``limit``, it is returned as-is. + Throws an exception if ``limit`` is not greater than 0. + Note: This function is not directly callable in Spark SQL, but is used internally for reading CHAR type columns. :: + + -- Function call examples (this function is not directly callable in Spark SQL). + read_side_padding("a", 3) -- "a " + read_side_padding("abc", 3) -- "abc" + read_side_padding("abcd", 3) -- "abcd" + read_side_padding("世", 3) -- "世 " + read_side_padding("世界", 2) -- "世界" + read_side_padding("Привет", 8) -- "Привет " + read_side_padding("Γειά", 5) -- "Γειά " + read_side_padding("Приветик", 6) -- "Приветик" + read_side_padding("a", 0) -- VeloxUserError: "The length limit must be greater than 0." + .. spark:function:: repeat(input, n) -> varchar Returns the string which repeats ``input`` ``n`` times. @@ -294,7 +351,7 @@ String Functions SELECT rtrim(' data '); -- " data" .. spark:function:: rtrim(trimCharacters, string) -> varchar - :noindex: + :noindex: Removes specified trailing characters from ``string``. The specified character is any character contained in ``trimCharacters``. @@ -360,7 +417,7 @@ String Functions the meaning is to refer to the first character.Type of 'start' must be an INTEGER. .. spark:function:: substring(string, start, length) -> varchar - :noindex: + :noindex: Returns a substring from ``string`` of length ``length`` from the starting position ``start``. Positions start with ``1``. A negative starting @@ -425,7 +482,7 @@ String Functions SELECT trim(' data '); -- "data" .. spark:function:: trim(trimCharacters, string) -> varchar - :noindex: + :noindex: Removes specified leading and trailing characters from ``string``. The specified character is any character contained in ``trimCharacters``. @@ -433,13 +490,17 @@ String Functions SELECT trim('sprk', 'spark'); -- "a" +.. spark:function:: unbase64(expr) -> varbinary + + Returns a decoded base64 string as binary. :: + + SELECT cast(unbase64('U3BhcmsgU1FM') AS STRING); -- 'Spark SQL' + .. spark:function:: upper(string) -> string Returns string with all characters changed to uppercase. :: SELECT upper('SparkSql'); -- SPARKSQL -<<<<<<< HEAD -======= .. spark:function:: varchar_type_write_side_check(string, limit) -> varchar @@ -456,4 +517,3 @@ String Functions varchar_type_write_side_check("中文中国", 3) -- VeloxUserError: "Exceeds allowed length limitation: '3'" varchar_type_write_side_check(" ", 0) -- VeloxUserError: "The length limit must be greater than 0." varchar_type_write_side_check("", 3) -- "" ->>>>>>> 7c73c1106 (misc: Use pre-commit for quality checks (#13361)) diff --git a/velox/docs/functions/spark/url.rst b/velox/docs/functions/spark/url.rst index a6eaf70edba9..26544fb4a184 100644 --- a/velox/docs/functions/spark/url.rst +++ b/velox/docs/functions/spark/url.rst @@ -48,6 +48,11 @@ digits after the percent character "%". All the url extract functions will retur Encoding Functions ------------------ +.. spark:function:: url_decode(value) -> varchar + + Unescapes the URL encoded ``value``. + This function is the inverse of :spark:func:`url_encode`. + .. spark:function:: url_encode(value) -> varchar Escapes ``value`` by encoding it so that it can be safely included in @@ -59,8 +64,3 @@ Encoding Functions * All other characters are converted to UTF-8 and the bytes are encoded as the string ``%XX`` where ``XX`` is the uppercase hexadecimal value of the UTF-8 byte. - -.. spark:function:: url_decode(value) -> varchar - - Unescapes the URL encoded ``value``. - This function is the inverse of :spark:func:`url_encode`. diff --git a/velox/docs/functions/spark/window.rst b/velox/docs/functions/spark/window.rst index 2a4d7921c95a..2420e9f07f71 100644 --- a/velox/docs/functions/spark/window.rst +++ b/velox/docs/functions/spark/window.rst @@ -3,7 +3,7 @@ Window functions ================ Spark window functions can be used to compute SQL window functions. -More details about window functions can be found at :doc:`/develop/window` +More details about window functions can be found at :doc:`/develop/window`. Value functions --------------- @@ -18,14 +18,6 @@ in the window, null is returned. It is an error for the offset to be zero or neg Rank functions --------------- -.. spark:function:: row_number() -> integer - -Returns a unique, sequential number to each row, starting with one, according to the ordering of rows within the window partition. - -.. spark:function:: rank() -> integer - -Returns the rank of a value in a group of values. The rank is one plus the number of rows preceding the row that are not peer with the row. Thus, the values in the ordering will produce gaps in the sequence. The ranking is performed for each window partition. - .. spark:function:: dense_rank() -> integer Returns the rank of a value in a group of values. This is similar to rank(), except that tie values do not produce gaps in the sequence. @@ -35,3 +27,11 @@ Returns the rank of a value in a group of values. This is similar to rank(), exc Divides the rows for each window partition into n buckets ranging from 1 to at most ``n``. Bucket values will differ by at most 1. If the number of rows in the partition does not divide evenly into the number of buckets, then the remainder values are distributed one per bucket, starting with the first bucket. For example, with 6 rows and 4 buckets, the bucket values would be as follows: ``1 1 2 2 3 4`` + +.. spark:function:: rank() -> integer + +Returns the rank of a value in a group of values. The rank is one plus the number of rows preceding the row that are not peer with the row. Thus, the values in the ordering will produce gaps in the sequence. The ranking is performed for each window partition. + +.. spark:function:: row_number() -> integer + +Returns a unique, sequential number to each row, starting with one, according to the ordering of rows within the window partition. diff --git a/velox/docs/index.rst b/velox/docs/index.rst index 37b9470c0d24..f1d65f2065e0 100644 --- a/velox/docs/index.rst +++ b/velox/docs/index.rst @@ -9,6 +9,8 @@ Velox Documentation monthly-updates functions spark_functions + functions/iceberg/functions + functions/delta/functions configs monitoring bindings/python/index diff --git a/velox/docs/mailmap_base64 b/velox/docs/mailmap_base64 index 40240370d949..5f2e47c368f7 100644 --- a/velox/docs/mailmap_base64 +++ b/velox/docs/mailmap_base64 @@ -1 +1,3 @@ -Q2hyaXN0aWFuIFplbnRncmFmIDxJQk0+IDxraXRnb2N6QGdtYWlsLmNvbT4KSmFjb2IgV3VqY2lhay1KZW5zIDxWb2x0cm9uIERhdGE+IDxqYWNvYkB3dWpjaWFrLmRlPgpQcmFtb2QgU2F0eWEgPElCTT4gPHByYW1vZEBhaGFuYS5pbz4KUHJhbW9kIFNhdHlhIDxJQk0+IDxwcmFtb2Quc2F0eWFAaWJtLmNvbT4K +Q2hyaXN0aWFuIFplbnRncmFmIDxJQk0+IDxraXRnb2N6QGdtYWlsLmNvbT4KSmFjb2IgV3VqY2lh +ay1KZW5zIDwudHh0PiA8amFjb2JAd3VqY2lhay5kZT4KUHJhbW9kIFNhdHlhIDxJQk0+IDxwcmFt +b2RAYWhhbmEuaW8+ClByYW1vZCBTYXR5YSA8SUJNPiA8cHJhbW9kLnNhdHlhQGlibS5jb20+Cg== diff --git a/velox/docs/monitoring/metrics.rst b/velox/docs/monitoring/metrics.rst index 63961e46e920..25d42af25cc2 100644 --- a/velox/docs/monitoring/metrics.rst +++ b/velox/docs/monitoring/metrics.rst @@ -9,14 +9,17 @@ instance, the collected data can help automatically generate alerts at an outage. Velox provides a framework to collect the metrics which consists of three steps: -**Define**: define the name and type for the metric through DEFINE_METRIC and -DEFINE_HISTOGRAM_METRIC macros. DEFINE_HISTOGRAM_METRIC is used for histogram -metric type and DEFINE_METRIC is used for the other types (see metric type -definition below). BaseStatsReporter provides methods for metric definition. -Register metrics during startup using registerVeloxMetrics() API. - -**Record**: record the metric data point using RECORD_METRIC_VALUE and -RECORD_HISTOGRAM_METRIC_VALUE macros when the corresponding event happens. +**Define**: define the name and type for the metric through DEFINE_METRIC, +DEFINE_HISTOGRAM_METRIC, DEFINE_QUANTILE_STAT, and DEFINE_DYNAMIC_QUANTILE_STAT +macros. DEFINE_HISTOGRAM_METRIC is used for histogram metric type, +DEFINE_QUANTILE_STAT for quantile metrics, DEFINE_DYNAMIC_QUANTILE_STAT for +dynamic quantile metrics, and DEFINE_METRIC is used for the other types (see +metric type definition below). BaseStatsReporter provides methods for metric +definition. Register metrics during startup using registerVeloxMetrics() API. + +**Record**: record the metric data point using RECORD_METRIC_VALUE, +RECORD_HISTOGRAM_METRIC_VALUE, RECORD_QUANTILE_STAT_VALUE, and +RECORD_DYNAMIC_QUANTILE_STAT_VALUE macros when the corresponding event happens. BaseStatsReporter provides methods for metric recording. **Export**: aggregates the collected data points based on the defined metrics, @@ -26,7 +29,7 @@ implementation of BaseStatsReporter is required to integrate with a specific monitoring service. The metric aggregation granularity and export interval are also configured based on the actual used monitoring service. -Velox supports five metric types: +Velox supports seven metric types: **Count**: tracks the count of events, such as the number of query failures. @@ -49,6 +52,23 @@ in max bucket. It also allows to specify the value percentiles to report for monitoring. This allows BaseStatsReporter and the backend monitoring service to optimize the aggregated data storage. +**Quantile**: tracks quantile statistics (percentiles) of event data point values +over configurable sliding time windows, such as P50, P95, and P99 latencies over +the last 60 seconds. Unlike histograms which use fixed buckets, quantile metrics +dynamically calculate percentiles from the actual data distribution. +DEFINE_QUANTILE_STAT specifies the stat types to export (e.g., AVG, COUNT, SUM), +the percentiles to track (as values between 0.0 and 1.0), and the sliding window +periods in seconds. This provides more accurate percentile calculations compared +to histogram approximations, especially for metrics with varying distributions. + +**Dynamic Quantile**: extends quantile metrics to support dynamic key patterns +with runtime substitution using format placeholders (e.g., "latency.{}.{}" where +placeholders are replaced with actual values like database names or endpoint names). +DEFINE_DYNAMIC_QUANTILE_STAT registers a pattern template, and +RECORD_DYNAMIC_QUANTILE_STAT_VALUE substitutes the placeholders to create specific +metric instances. This enables efficient tracking of quantile metrics across +multiple dimensions without pre-registering every possible combination. + Task Execution -------------- .. list-table:: @@ -73,13 +93,16 @@ Task Execution 30 buckets. It is configured to report the latency at P50, P90, P99, and P100 percentiles. * - task_batch_process_time_ms - - Average + - Avg - Tracks the averaged task batch processing time. This only applies for sequential task execution mode. * - task_barrier_process_time_ms - Histogram - Tracks task barrier execution time in range of [0, 30s] with 30 buckets and each bucket with time window of 1s. We report P50, P90, P99, and P100. + * - task_splits_count + - Count + - The total number of splits received by all tasks. Memory Management ----------------- @@ -140,9 +163,6 @@ Memory Management * - task_memory_reclaim_wait_timeout_count - Count - The number of times that the task memory reclaim wait timeouts. - * - task_splits_count - - Count - - The total number of splits received by all tasks. * - memory_non_reclaimable_count - Count - The number of times that the memory reclaim fails because the operator is executing a @@ -206,11 +226,11 @@ Memory Management arbitration operation in range of [0, 600s] with 20 buckets. It is configured to report the latency at P50, P90, P99 and P100 percentiles. * - arbitrator_free_capacity_bytes - - Average + - Avg - The average of total free memory capacity which is managed by the memory arbitrator. * - arbitrator_free_reserved_capacity_bytes - - Average + - Avg - The average of free memory capacity reserved to ensure each query has the minimal required capacity to run. * - memory_pool_initial_capacity_bytes @@ -243,15 +263,14 @@ Memory Management the bytes that are either currently being allocated or were in the past allocated, not yet been returned back to the operating system, in the form of 'Allocation' or 'ContiguousAllocation'. - * - memory_allocator_alloc_bytes + * - memory_allocator_allocated_bytes - Avg - Number of bytes currently allocated (used) from MemoryAllocator in the form of 'Allocation' or 'ContiguousAllocation'. - * - mmap_allocator_external_mapped_bytes + * - memory_allocator_external_mapped_bytes - Avg - - Number of bytes currently mapped in MmapAllocator, in the form of + - Number of bytes currently mapped in MemoryAllocator, in the form of 'ContiguousAllocation'. - NOTE: This applies only to MmapAllocator * - mmap_allocator_delegated_alloc_bytes - Avg - Number of bytes currently allocated from MmapAllocator directly from raw @@ -273,9 +292,12 @@ Cache - Avg - Max possible age of AsyncDataCache and SsdCache entries since the raw file was opened to load the cache. - * - memory_cache_num_entries + * - memory_cache_num_large_entries + - Avg + - Total number of large cache entries. + * - memory_cache_num_tiny_entries - Avg - - Total number of cache entries. + - Total number of tiny cache entries. * - memory_cache_num_empty_entries - Avg - Total number of cache entries that do not cache anything. @@ -397,9 +419,15 @@ Cache * - ssd_cache_write_ssd_errors - Sum - Total number of error while writing to SSD cache files. + * - ssd_cache_write_no_space_errors + - Sum + - Total number of errors due to SSD no space for writes. * - ssd_cache_write_ssd_dropped - Sum - Total number of writes dropped due to no cache space. + * - ssd_cache_write_exceed_entry_limit + - Sum + - Total number of writes dropped due to entry limit exceeded. * - ssd_cache_write_checkpoint_errors - Sum - Total number of errors while writing SSD checkpoint file. @@ -600,6 +628,9 @@ Index Join - The distribution of index lookup result bytes in range of [0, 128MB] with 128 buckets. It is configured to report the capacity at P50, P90, P99, and P100 percentiles. + * - index_lookup_error_result_count + - Count + - The number of results with error. Table Scan ---------- @@ -612,12 +643,11 @@ Table Scan - Type - Description * - table_scan_batch_process_time_ms - - Histogram - - The time distribution of table scan batch processing time in range of [0, - 16s] with 512 buckets and reports P50, P90, P99, and P100. + - Avg + - Tracks the averaged table scan batch processing time in milliseconds. * - table_scan_batch_bytes - - Histogram - - The size distribution of table scan output batch in range of [0, 512MB] + - Avg + - Tracks the averaged table scan output batch size in bytes. with 512 buckets and reports P50, P90, P99, and P100 S3 FileSystem diff --git a/velox/docs/monitoring/stats.rst b/velox/docs/monitoring/stats.rst index 818902b58c13..e2ff1db8db69 100644 --- a/velox/docs/monitoring/stats.rst +++ b/velox/docs/monitoring/stats.rst @@ -189,6 +189,31 @@ These stats are reported only by IndexLookupJoin operator - bytes - The byte size of the result data in velox vectors that are decoded from the raw data received from the remote storage lookup. + * - clientNumLazyDecodedResultBatches + - + - The number of lazy decoded result batches returned from the storage client. + +Merge +----- +These stats are reported only by Merge operator + +.. list-table:: + :widths: 50 25 50 + :header-rows: 1 + + * - Stats + - Unit + - Description + * - streamingSourceReadWallNanos + - nanos + - The number of a spillable operators that don't support spill because of + spill limitation. For instance, a window operator do not support spill + if there is no partitioning. + * - spilledSourceReadWallNanos + - nanos + - The running wall time of the merge operator reading from the spilled source to + produce the final output. This only applies when spilling is enabled for local + merge. Spilling -------- @@ -305,3 +330,24 @@ These stats are reported by IterativeVectorSerializer. * - compressionSkippedBytes - - The number of bytes that skip in-efficient compression. + +Connector +--------- +These stats are reported only by connector data or index sources. + +.. list-table:: + :widths: 50 25 50 + :header-rows: 1 + + * - Stats + - Unit + - Description + * - totalRemainingFilterWallNanos + - nanos + - The total walltime in nanoseconds that the data or index connector do the remaining filtering. + * - numIndexFilterConversions + - + - The number of index columns that were converted from ScanSpec filters to + index bounds for index-based filtering (e.g., cluster index pruning in + Nimble). A value greater than zero indicates filters were successfully + converted to leverage file index structures for row pruning. diff --git a/velox/docs/monthly-updates.rst b/velox/docs/monthly-updates.rst index 34830dbb3ac3..fa1abeafaeed 100644 --- a/velox/docs/monthly-updates.rst +++ b/velox/docs/monthly-updates.rst @@ -5,14 +5,10 @@ Monthly Updates .. toctree:: :maxdepth: 1 - monthly-updates/august-2024 - monthly-updates/july-2024 - monthly-updates/june-2024 - monthly-updates/may-2024 - monthly-updates/april-2024 - monthly-updates/march-2024 - monthly-updates/february-2024 - monthly-updates/january-2024 + monthly-updates/july-2025 + monthly-updates/june-2025 + monthly-updates/may-2025 + monthly-updates/2024/index monthly-updates/2023/index monthly-updates/2022/index monthly-updates/2021/index diff --git a/velox/docs/monthly-updates/april-2024.rst b/velox/docs/monthly-updates/2024/april-2024.rst similarity index 100% rename from velox/docs/monthly-updates/april-2024.rst rename to velox/docs/monthly-updates/2024/april-2024.rst diff --git a/velox/docs/monthly-updates/august-2024.rst b/velox/docs/monthly-updates/2024/august-2024.rst similarity index 100% rename from velox/docs/monthly-updates/august-2024.rst rename to velox/docs/monthly-updates/2024/august-2024.rst diff --git a/velox/docs/monthly-updates/february-2024.rst b/velox/docs/monthly-updates/2024/february-2024.rst similarity index 100% rename from velox/docs/monthly-updates/february-2024.rst rename to velox/docs/monthly-updates/2024/february-2024.rst diff --git a/velox/docs/monthly-updates/2024/index.rst b/velox/docs/monthly-updates/2024/index.rst new file mode 100644 index 000000000000..b098f4799151 --- /dev/null +++ b/velox/docs/monthly-updates/2024/index.rst @@ -0,0 +1,15 @@ +*************** +2024 +*************** + +.. toctree:: + :maxdepth: 1 + + august-2024 + july-2024 + june-2024 + may-2024 + april-2024 + march-2024 + february-2024 + january-2024 diff --git a/velox/docs/monthly-updates/january-2024.rst b/velox/docs/monthly-updates/2024/january-2024.rst similarity index 100% rename from velox/docs/monthly-updates/january-2024.rst rename to velox/docs/monthly-updates/2024/january-2024.rst diff --git a/velox/docs/monthly-updates/july-2024.rst b/velox/docs/monthly-updates/2024/july-2024.rst similarity index 100% rename from velox/docs/monthly-updates/july-2024.rst rename to velox/docs/monthly-updates/2024/july-2024.rst diff --git a/velox/docs/monthly-updates/june-2024.rst b/velox/docs/monthly-updates/2024/june-2024.rst similarity index 100% rename from velox/docs/monthly-updates/june-2024.rst rename to velox/docs/monthly-updates/2024/june-2024.rst diff --git a/velox/docs/monthly-updates/march-2024.rst b/velox/docs/monthly-updates/2024/march-2024.rst similarity index 100% rename from velox/docs/monthly-updates/march-2024.rst rename to velox/docs/monthly-updates/2024/march-2024.rst diff --git a/velox/docs/monthly-updates/may-2024.rst b/velox/docs/monthly-updates/2024/may-2024.rst similarity index 100% rename from velox/docs/monthly-updates/may-2024.rst rename to velox/docs/monthly-updates/2024/may-2024.rst diff --git a/velox/docs/monthly-updates/july-2025.rst b/velox/docs/monthly-updates/july-2025.rst new file mode 100644 index 000000000000..0146c286aa5c --- /dev/null +++ b/velox/docs/monthly-updates/july-2025.rst @@ -0,0 +1,83 @@ +**************** +July 2025 Update +**************** + +This update was generated with the assistance of AI. While we strive for accuracy, please note +that AI-generated content may not always be error-free. We encourage you to verify any information +that you deem important. + +Core Library +============ + +* Switch to C++20 standard. :pr:`10866` +* Add ParallelProject node and operator from Verax. :pr:`14220` +* Add barriered execution support to AssignUniqueId. :pr:`14224` +* Add support for left join semantics to Unnest. :pr:`14095` +* Add support for converting OPAQUE vectors to variant. :pr:`14235` +* Add basic support for coercions to help with function type resolution. :pr:`14113` +* Optimize streaming aggregation by removing max output batch size limit, reducing peak memory 4x. :pr:`14238` +* Fix deadlock when dropping non-existent child memory pools. :pr:`14202` +* Fix Variant::hash and BaseVector::hashValueAt for arrays and maps. :pr:`14019` +* Fix ConstantTypedExpr equals/hash/toString methods for various data types. :pr:`14055` +* Fix AssignUniqueId needsInput logic. :pr:`14127` + +Presto Functions +================ + +* Add :func:`dot_product` function for embedding similarity calculations. +* Add seeded version of :func:`xxhash64` function. +* Add SFM sketch functions: :func:`merge`, :func:`noisy_approx_set_sfm_from_index_and_zeros`, :func:`noisy_approx_distinct_sfm` and :func:`noisy_approx_set_sfm` aggregate functions. +* Add :func:`quantile_at_value` and :func:`scale_qdigest` functions. +* Add :func:`geometry_nearest_points`, :func:`ST_NumPoints`, :func:`ST_EnvelopeAsPts`, :func:`ST_Points` functions. +* Add :func:`ST_Buffer`, :func:`ST_CoordDim`, :func:`ST_Envelope`, :func:`ST_ExteriorRing` functions. +* Add :func:`ST_ConvexHull`, :func:`ST_Dimension`, :func:`ST_NumInteriorRing`, :func:`ST_NumGeometries` functions. +* Add :func:`ST_GeometryN`, :func:`ST_InteriorRingN`, :func:`ST_StartPoint`, :func:`ST_EndPoint` functions. +* Add :func:`ST_PointN`, :func:`ST_Length`, :func:`ST_IsClosed`, :func:`ST_Empty`, :func:`ST_IsRing` functions. +* Add :func:`ST_Polygon` function. +* Fix Geometry serialization/deserialization errors for GeometryCollections with empty geometries. :pr:`14243` +* Optimize :func:`flatten` as a VectorFunction to enable zero copy. :pr:`14215` + +Spark Functions +=============== + +* Add :spark:func:`base64` and :spark:func:`initcap` functions. +* Add support for decimal type in :spark:func:`from_json` function. +* Add :spark:func:`abs` function to handle ANSI mode differences from Presto. +* Fix :spark:func:`corr` aggregate function to return NaN instead of NULL when variance is zero. :pr:`13956` +* Fix :spark:func:`covar_samp` aggregate function to return NaN instead of Inf when c2 is infinite. :pr:`13990` +* Fix :spark:func:`get_json_object` function to normalize JSON paths properly. :pr:`13854` + +Connectors +========== + +* Add metadata support and filter pushdown to TpchConnector. :pr:`14099` +* Add HDFS filesystem operations: remove, rmdir, rename, mkdir. :pr:`13948` +* Add S3 filesystem operations: exists and list. :pr:`13893` +* Add TokenProvider support to ConnectorQueryCtx for authentication. :pr:`13919` +* Add support for timestamp as Hive partition ID. :pr:`13494` +* Add text format write support for complex types: ROW, MAP, and ARRAY. :pr:`14064` +* Add escape character support for text parsing. :pr:`14130` +* Add backward compatibility support for TIMESTAMP in TextReader. :pr:`14063` +* Fix HiveDataSink to materialize input before writes to prevent lazy vector errors. :pr:`14085` + +Performance and Correctness +=========================== + +* Make BingTile and other custom types non-orderable to match Presto behavior. :pr:`14100` +* Fix Cast to JSON output size estimation for invalid Unicode input. :pr:`14062` +* Fix constant input size validation in streaming aggregation. :pr:`13933` +* Fix hash aggregation row container cleanup crash on abort. :pr:`13979` + +Credits +======= + +Amit Dutta, Bikramjeet Vig, Bowen Wu, Chengcheng Jin, Christian Zentgraf, +Elodie Li, Eric Jia, Heidi Han, Henry Edwin Dikeman, Hongze Zhang, Jacob +Khaliqi, Jacob Wujciak-Jens, James Gill, Jialiang Tan, Jimmy Lu, Joe Abraham, +Ke Jia, Ke Wang, Kevin Wilfong, Konstantinos Karatsenidis, Krishna Pai, Libin +Bai, Manikanta Loya, Masha Basmanova, Natasha Sehgal, Oliver Xu, Orri Erling, +Patrick Sullivan, Pedro Eugenio Rocha Pedreira, Peter Enescu, Pramod Satya, +Raaghav Ravishankar, Rajeev Dharmendra Singh, Rui Mo, Sutou Kouhei, Tony Liu, +Vincent Crabtree, Xiao Du, Xiaoxuan Meng, Yi Cheng Lee, Yuxuan Chen, Zhen Li, +Zhiying Liang, aditi-pandit, lingbin, nimesh.k, wecharyu, wraymo, zhli1142015, +zml1206 diff --git a/velox/docs/monthly-updates/june-2025.rst b/velox/docs/monthly-updates/june-2025.rst new file mode 100644 index 000000000000..caa55f9202a4 --- /dev/null +++ b/velox/docs/monthly-updates/june-2025.rst @@ -0,0 +1,74 @@ +**************** +June 2025 Update +**************** + +This update was generated with the assistance of AI. While we strive for accuracy, please note +that AI-generated content may not always be error-free. We encourage you to verify any information +that is important to you. + +Core Library +============ + +* Add null key support for index join. :pr:`13891` +* Add Async SpillMerger in LocalMerge. :pr:`13634` +* Share filter among drivers for improved efficiency. :pr:`13784` +* Make global arbitration consider query priority. :pr:`13827` +* Fix anti and semi join result mismatch with filter and multiple matches. :pr:`13123` +* Fix the server crash caused by remote exchange error. :pr:`13905` +* Fix resource release for memoizing constant folding expression. :pr:`13755` +* Fix Expr::isDeterministic for lambda functions. :pr:`13647` +* Enable constant folding for lambda functions. :pr:`13642` + +Presto Functions +================ + +* Add :func:`noisy_avg_gaussian`, :func:`noisy_sum_gaussian`, :func:`noisy_count_gaussian`, :func:`noisy_count_if_gaussian` functions. +* Add :func:`tdigest_agg`, :func:`value_at_quatile`, :func:`values_at_quantiles`, :func:`quantile_at_value` functions. +* Add :func:`inverse_chi_squared_cdf`, :func:`inverse_f_cdf` functions. +* Add :func:`l2_squared` function. +* Add :func:`ST_Distance`, :func:`ST_GeometryType`, :func:`St_Centroid`, :func:`St_Boundary`, :func:`ST_XMin`, :func:`ST_XMax`, :func:`ST_YMin`, :func:`ST_YMax` functions. +* Add serialization and deserialization of Geometry type to/from WKT and WKB. :pr:`12771` +* Fix :func:`values_at_quantiles` function to throw error on null input. :pr:`13810` +* Optimize :func:`array_remove` to avoid unnecessary string copies. + +Spark Functions +=============== + +* Add :spark:func:`unbase64`, :spark:func:`trunc`, :spark:func:`varchar_type_write_side_check`, :spark:func:`cbrt` functions. +* Fix try_cast and cast function error case handling. :pr:`12993` +* Fix :spark:func:`lower` function on unicode character. :pr:`13158` +* Fix handling of extreme floating-point values in :spark:func:`from_json`. :pr:`13378` + +Connectors +========== + +* Add support for exist and list functions in HdfsFileSystem. :pr:`13813` +* Add GCS filesystem operations: rmdir, mkdir, and rename. :pr:`13533`, :pr:`13532`, :pr:`13490` +* Flatten complex-type vectors when writing to Parquet. :pr:`13338` +* Fix multi range filter in timestamp reader for Parquet. :pr:`12926` +* Fix NPE when reading complex type data from Parquet v2. :pr:`13512` + +Performance and Correctness +=========================== + +* Optimize getStringView performance. :pr:`13870` +* Optimize streaming aggregation performance. :pr:`13812` +* Reduce HashTable load factor from 0.875 to 0.7 for better performance. :pr:`13694` +* Improve fillNewMemory tight loop performance. :pr:`13883` +* Increase readBatchSize when the last batch is empty in TableScan. :pr:`13626` + +Credits +======= + +aditi-pandit, Anders Dellien, Andrii Rosa, Artem Selishchev, Bikramjeet Vig, +Bowen Wu, Chandrashekhar Kumar Singh, Chengcheng Jin, Christian Zentgraf, +David Reveman, Deepak Majeti, Devavret Makkar, Dharan Aditya, duanmeng, +Eric Jia, Huameng (Michael) Jiang, iiFeung, Jacob Khaliqi, Jacob Wujciak-Jens, +Jialiang Tan, Jim Meyering, Jimmy Lu, Jin Chengcheng, Joe Giardino, joey.ljy, +Ke Jia, Ke Wang, Kent Yao, Kevin Wilfong, Konstantinos Karatsenidis, +Kostas Xirogiannopoulos, Krishna Pai, lingbin, Luis Garcés-Erice, Mario Ruiz, +Masha Basmanova, Natasha Sehgal, Oliver Xu, Orri Erling, Patrick Sullivan, +Paul Meng, Peter Enescu, Ping Liu, Richard Barnes, Rui Mo, Shakyan Kushwaha, +Wei He, wraymo, xhs7700, Xiao Du, xiaodou, Xiaoxuan Meng, Xin Zhang, Yabin Ma, +Yi Cheng Lee, yingsu00, Yoav Helfman, yumwang@ebay.com, Zhichen Xu, +zhli1142015, zml1206, Zoltan Arnold Nagy diff --git a/velox/docs/monthly-updates/may-2025.rst b/velox/docs/monthly-updates/may-2025.rst new file mode 100644 index 000000000000..d049377a9a52 --- /dev/null +++ b/velox/docs/monthly-updates/may-2025.rst @@ -0,0 +1,83 @@ +*************** +May 2025 Update +*************** + +This update was generated with the assistance of AI. While we strive for accuracy, please note +that AI-generated content may not always be error-free. We encourage you to verify any information +that is important to you. + +Core Library +============ + +* Add task barrier. :pr:`13087` +* Add task barrier support for streaming aggregation, unnest and index join. :pr:`13273`, :pr:`13293`, :pr:`13244` +* Add support for memory pool priority. :pr:`13386` +* Support buffering in local exchange operator. :pr:`13234` +* Add PageSpill for OutputBuffer spill. :pr:`13305` +* Add lazy start with spill for LocalMerge. :pr:`13337` +* Add support for leftSemiProject join in nested loop join. :pr:`12172` +* Fix crash when aggregate push down applied on updated column with sparse row set. :pr:`13503` +* Fix the timing order of CpuWallTimer. :pr:`13313` +* Fix overflow in NegatedBigintValuesUsingHashTable::testInt64Range. :pr:`13523` + +Presto Functions +================ + +* Add :func:`simplify_geometry`, :func:`geometry_invalid_reason`, :func:`ST_IsValid`, :func:`ST_IsSimple`, :func:`ST_Point`, :func:`ST_X`, :func:`ST_Y`, :func:`ST_area` functions. +* Add :func:`bing_tile_at`, :func:`bing_tiles_around` functions. +* Add :func:`qdigest_agg`, :func:`cosine_similarity` functions. +* Add :func:`noisy_count_if_gaussian` function. +* Add :func:`quantile_at_value`, :func:`trimmed_mean` functions. +* Add geometry functions for WKT/WKB conversion. +* Add :func:`xxhash64_internal` with extended type support. +* Fix undefined behavior in qdigest when total weight exceeds int64_t max. :pr:`13336` +* Fix handing of empty arrays in :func:`array_min`, :func:`array_max_by` functions. :pr:`13272` +* Fix overflow check in timestamp addition. :pr:`13444` +* Fix overflow in :func:`from_unixtime`. :pr:`13262` + +Spark Functions +=============== + +* Add :spark:func:`sqrt`, :spark:func:`luhn_check` functions. +* Add CAST(bool as timestamp) support. +* Add support for legacy behavior in covariance functions. :pr:`12994` +* Fix casting complex types to only cast recursively if child type changes. :pr:`13245` +* Fix unescape json elements in :spark:func:`array_join`. :pr:`13222` +* Fix the duplicate map key handling for :spark:func:`map` function. :pr:`13183` + +Connectors +========== + +* Support bucket write with non-partitioned table. :pr:`13283` +* Support delta update on bucket column. :pr:`13404` +* Support Null Column Projection in Batch Reader Adapter. :pr:`13430` +* Add ColumnReaderOptions for better reader configuration. :pr:`12840` +* Add support for float-to-double schema evolution. :pr:`13317` +* Fix selective flatmap column reader read offset when all rows filtered out. :pr:`13350` +* Fix access after buffer boundary causing crash in selective reader. :pr:`13344` +* Fix crash if table column type does not match file column type. :pr:`12350` +* Fix incorrect filter result during schema evolution when range is outside of old type. :pr:`13459` + +Performance and Correctness +=========================== + +* Optimize selective ARRAY and MAP reader. :pr:`13240` +* Avoid decompressing data when estimating row size. :pr:`13365` +* Clear hash table as soon as probe finish. :pr:`13254` +* Allow empty file in table scan. :pr:`13241` + +Credits +======= + +aditi-pandit, ajeyabsf, Ali LeClerc, alileclerc, Amit Dutta, Anders Dellien, +Andrii Rosa, arnavb, Artem Selishchev, Bikramjeet Vig, Chandrashekhar Kumar Singh, +Chengcheng Jin, Christian Zentgraf, Deepak Majeti, Devavret Makkar, duanmeng, +Eric Jia, Haiping Xue, Heidi Han, Jacob Khaliqi, Jacob Wujciak-Jens, James Gill, +Jialiang Tan, Jimmy Lu, Ke Jia, Ke Wang, Kent Yao, Kevin Wilfong, Kien Nguyen, +Kk Pulla, Kostas Xirogiannopoulos, Krishna Pai, lingbin, Lukas Krenz, +MacVincent Agha-Oko, Mario Ruiz, Mingyu Zhang, Natasha Sehgal, Nathan Phan, +NEUpanning, Oliver Xu, Patrick Sullivan, Pedro Eugenio Rocha Pedreira, +Peter Enescu, Pradeep Vaka, Qian Sun, Rui Mo, Serge Druzkin, Shakyan Kushwaha, +Soumya Duriseti, Surbhi Vijayvargeeya, Tanay Bhartia, Wei He, Xiao Du, +Xiaoxuan Meng, Yabin Ma, Yenda Li, Yi Cheng Lee, Zhenyuan Zhao, Zhiguo Wu, +Zhiying Liang diff --git a/velox/docs/spark_functions.rst b/velox/docs/spark_functions.rst index 24c825ac1ef5..e6944a702a9e 100644 --- a/velox/docs/spark_functions.rst +++ b/velox/docs/spark_functions.rst @@ -2,7 +2,11 @@ Spark Functions *********************** -The semantics of Spark functions match Spark 3.5 with ANSI OFF. +The semantics of Spark functions align with +`Spark 3.5 `_. +In the function descriptions, a function is marked as *ANSI compliant* +if it adheres to ANSI standard, subject to the :doc:`spark.ansi_enabled ` +configuration. Otherwise, it simply follows Spark's semantics in ANSI OFF mode. .. toctree:: :maxdepth: 1 @@ -61,35 +65,80 @@ for :doc:`all ` functions. :widths: auto :class: rows - ================================ ================================ ================================ == ================================ == ================================ - Scalar Functions Aggregate Functions Window Functions - ==================================================================================================== == ================================ == ================================ - :spark:func:`abs` :spark:func:`floor` :spark:func:`power` :spark:func:`bit_xor` :spark:func:`nth_value` - :spark:func:`acos` :spark:func:`get_json_object` :spark:func:`rand` :spark:func:`first` - :spark:func:`acosh` :spark:func:`greaterthan` :spark:func:`regexp_extract` :spark:func:`first_ignore_null` - :spark:func:`add` :spark:func:`greaterthanorequal` :spark:func:`remainder` :spark:func:`last` - :spark:func:`aggregate` :spark:func:`greatest` :spark:func:`replace` :spark:func:`last_ignore_null` - :spark:func:`array` :spark:func:`hash` :spark:func:`rlike` - :spark:func:`array_contains` :spark:func:`hypot` :spark:func:`round` - :spark:func:`array_intersect` :spark:func:`in` :spark:func:`rtrim` - :spark:func:`array_sort` :spark:func:`instr` :spark:func:`sec` - :spark:func:`ascii` :spark:func:`isnotnull` :spark:func:`sha1` - :spark:func:`asinh` :spark:func:`isnull` :spark:func:`sha2` - :spark:func:`atanh` :spark:func:`least` :spark:func:`shiftleft` - :spark:func:`between` :spark:func:`left` :spark:func:`shiftright` - :spark:func:`bin` :spark:func:`length` :spark:func:`sinh` - :spark:func:`bitwise_and` :spark:func:`lessthan` :spark:func:`size` - :spark:func:`bitwise_or` :spark:func:`lessthanorequal` :spark:func:`sort_array` - :spark:func:`ceil` :spark:func:`log1p` :spark:func:`split` - :spark:func:`chr` :spark:func:`lower` :spark:func:`startswith` - :spark:func:`concat` :spark:func:`ltrim` :spark:func:`substring` - :spark:func:`contains` :spark:func:`map` :spark:func:`subtract` - :spark:func:`csc` :spark:func:`map_filter` :spark:func:`to_unix_timestamp` - :spark:func:`divide` :spark:func:`map_from_arrays` :spark:func:`transform` - :spark:func:`element_at` :spark:func:`md5` :spark:func:`trim` - :spark:func:`endswith` :spark:func:`might_contain` :spark:func:`unaryminus` - :spark:func:`equalnullsafe` :spark:func:`multiply` :spark:func:`unix_timestamp` - :spark:func:`equalto` :spark:func:`not` :spark:func:`upper` - :spark:func:`exp` :spark:func:`notequalto` :spark:func:`xxhash64` - :spark:func:`filter` :spark:func:`pmod` :spark:func:`year` - ================================ ================================ ================================ == ================================ == ================================ + =========================================== =========================================== =========================================== == =========================================== == =========================================== + Scalar Functions Aggregate Functions Window Functions + ===================================================================================================================================== == =========================================== == =========================================== + :spark:func:`abs` :spark:func:`divide_deny_precision_loss` :spark:func:`not` :spark:func:`avg` :spark:func:`dense_rank` + :spark:func:`acos` :spark:func:`doy` :spark:func:`overlay` :spark:func:`bit_xor` :spark:func:`nth_value` + :spark:func:`acosh` :spark:func:`element_at` :spark:func:`pmod` :spark:func:`bloom_filter_agg` :spark:func:`ntile` + :spark:func:`add` :spark:func:`empty2null` :spark:func:`power` :spark:func:`collect_list` :spark:func:`rank` + :spark:func:`add_deny_precision_loss` :spark:func:`endswith` :spark:func:`quarter` :spark:func:`collect_set` :spark:func:`row_number` + :spark:func:`add_months` :spark:func:`equalnullsafe` :spark:func:`raise_error` :spark:func:`corr` + :spark:func:`aggregate` :spark:func:`equalto` :spark:func:`rand` :spark:func:`covar_samp` + :spark:func:`array` :spark:func:`exists` :spark:func:`random` :spark:func:`first` + :spark:func:`array_append` :spark:func:`exp` :spark:func:`regexp_extract` :spark:func:`first_ignore_null` + :spark:func:`array_compact` :spark:func:`expm1` :spark:func:`regexp_extract_all` :spark:func:`kurtosis` + :spark:func:`array_contains` :spark:func:`factorial` :spark:func:`regexp_replace` :spark:func:`last` + :spark:func:`array_distinct` :spark:func:`filter` :spark:func:`remainder` :spark:func:`last_ignore_null` + :spark:func:`array_except` :spark:func:`find_in_set` :spark:func:`repeat` :spark:func:`max` + :spark:func:`array_insert` :spark:func:`flatten` :spark:func:`replace` :spark:func:`max_by` + :spark:func:`array_intersect` :spark:func:`floor` :spark:func:`reverse` :spark:func:`min` + :spark:func:`array_join` :spark:func:`forall` :spark:func:`rint` :spark:func:`min_by` + :spark:func:`array_max` :spark:func:`from_unixtime` :spark:func:`rlike` :spark:func:`mode` + :spark:func:`array_min` :spark:func:`from_utc_timestamp` :spark:func:`round` :spark:func:`regr_replacement` + :spark:func:`array_position` :spark:func:`get` :spark:func:`rpad` :spark:func:`skewness` + :spark:func:`array_prepend` :spark:func:`get_json_object` :spark:func:`rtrim` :spark:func:`stddev` + :spark:func:`array_remove` :spark:func:`get_timestamp` :spark:func:`sec` :spark:func:`stddev_samp` + :spark:func:`array_repeat` :spark:func:`greaterthan` :spark:func:`second` :spark:func:`sum` + :spark:func:`array_sort` :spark:func:`greaterthanorequal` :spark:func:`sha1` :spark:func:`var_samp` + :spark:func:`array_union` :spark:func:`greatest` :spark:func:`sha2` :spark:func:`variance` + :spark:func:`arrays_zip` :spark:func:`hash` :spark:func:`shiftleft` + :spark:func:`ascii` :spark:func:`hash_with_seed` :spark:func:`shiftright` + :spark:func:`asin` :spark:func:`hex` :spark:func:`shuffle` + :spark:func:`asinh` :spark:func:`hour` :spark:func:`sign` + :spark:func:`atan` :spark:func:`hypot` :spark:func:`sinh` + :spark:func:`atan2` :spark:func:`in` :spark:func:`size` + :spark:func:`atanh` :spark:func:`instr` :spark:func:`slice` + :spark:func:`between` :spark:func:`isnan` :spark:func:`sort_array` + :spark:func:`bin` :spark:func:`isnotnull` :spark:func:`soundex` + :spark:func:`bit_count` :spark:func:`isnull` :spark:func:`spark_partition_id` + :spark:func:`bit_get` :spark:func:`json_array_length` :spark:func:`split` + :spark:func:`bit_length` :spark:func:`json_object_keys` :spark:func:`sqrt` + :spark:func:`bitwise_and` :spark:func:`last_day` :spark:func:`startswith` + :spark:func:`bitwise_not` :spark:func:`least` :spark:func:`str_to_map` + :spark:func:`bitwise_or` :spark:func:`left` :spark:func:`substring` + :spark:func:`bitwise_xor` :spark:func:`length` :spark:func:`substring_index` + :spark:func:`cbrt` :spark:func:`lessthan` :spark:func:`subtract` + :spark:func:`ceil` :spark:func:`lessthanorequal` :spark:func:`subtract_deny_precision_loss` + :spark:func:`checked_add` :spark:func:`levenshtein` :spark:func:`timestamp_micros` + :spark:func:`checked_divide` :spark:func:`like` :spark:func:`timestamp_millis` + :spark:func:`checked_multiply` :spark:func:`locate` :spark:func:`to_unix_timestamp` + :spark:func:`checked_subtract` :spark:func:`log` :spark:func:`to_utc_timestamp` + :spark:func:`chr` :spark:func:`log10` :spark:func:`transform` + :spark:func:`concat` :spark:func:`log1p` :spark:func:`translate` + :spark:func:`contains` :spark:func:`log2` :spark:func:`trim` + :spark:func:`conv` :spark:func:`lower` :spark:func:`trunc` + :spark:func:`cos` :spark:func:`lpad` :spark:func:`unaryminus` + :spark:func:`cosh` :spark:func:`ltrim` :spark:func:`unbase64` + :spark:func:`cot` :spark:func:`luhn_check` :spark:func:`unhex` + :spark:func:`crc32` :spark:func:`make_date` :spark:func:`unix_date` + :spark:func:`csc` :spark:func:`make_timestamp` :spark:func:`unix_micros` + :spark:func:`date_add` :spark:func:`make_ym_interval` :spark:func:`unix_millis` + :spark:func:`date_format` :spark:func:`map` :spark:func:`unix_seconds` + :spark:func:`date_from_unix_date` :spark:func:`map_concat` :spark:func:`unix_timestamp` + :spark:func:`date_sub` :spark:func:`map_entries` :spark:func:`unscaled_value` + :spark:func:`date_trunc` :spark:func:`map_filter` :spark:func:`upper` + :spark:func:`datediff` :spark:func:`map_from_arrays` :spark:func:`url_decode` + :spark:func:`day` :spark:func:`map_keys` :spark:func:`url_encode` + :spark:func:`dayofmonth` :spark:func:`map_values` :spark:func:`uuid` + :spark:func:`dayofweek` :spark:func:`map_zip_with` :spark:func:`varchar_type_write_side_check` + :spark:func:`dayofyear` :spark:func:`mask` :spark:func:`week_of_year` + :spark:func:`decimal_equalto` :spark:func:`md5` :spark:func:`weekday` + :spark:func:`decimal_greaterthan` :spark:func:`might_contain` :spark:func:`width_bucket` + :spark:func:`decimal_greaterthanorequal` :spark:func:`minute` :spark:func:`xxhash64` + :spark:func:`decimal_lessthan` :spark:func:`monotonically_increasing_id` :spark:func:`xxhash64_with_seed` + :spark:func:`decimal_lessthanorequal` :spark:func:`month` :spark:func:`year` + :spark:func:`decimal_notequalto` :spark:func:`multiply` :spark:func:`year_of_week` + :spark:func:`degrees` :spark:func:`multiply_deny_precision_loss` :spark:func:`zip_with` + :spark:func:`divide` :spark:func:`next_day` + =========================================== =========================================== =========================================== == =========================================== == =========================================== diff --git a/velox/duckdb/conversion/CMakeLists.txt b/velox/duckdb/conversion/CMakeLists.txt index 8361cb68e20f..010f13f8f358 100644 --- a/velox/duckdb/conversion/CMakeLists.txt +++ b/velox/duckdb/conversion/CMakeLists.txt @@ -13,13 +13,16 @@ # limitations under the License. velox_add_library(velox_duckdb_conversion DuckConversion.cpp) -velox_link_libraries(velox_duckdb_conversion velox_core velox_vector - duckdb_static) +velox_link_libraries(velox_duckdb_conversion velox_core velox_vector duckdb_static) velox_add_library(velox_duckdb_parser DuckParser.cpp) -velox_link_libraries(velox_duckdb_parser velox_duckdb_conversion - velox_parse_expression duckdb_static) +velox_link_libraries( + velox_duckdb_parser + velox_duckdb_conversion + velox_parse_expression + duckdb_static +) if(${VELOX_BUILD_TESTING}) add_subdirectory(tests) diff --git a/velox/duckdb/conversion/DuckConversion.cpp b/velox/duckdb/conversion/DuckConversion.cpp index 912e0746243d..e1b23eda4e70 100644 --- a/velox/duckdb/conversion/DuckConversion.cpp +++ b/velox/duckdb/conversion/DuckConversion.cpp @@ -32,21 +32,21 @@ using ::duckdb::dtime_t; using ::duckdb::string_t; using ::duckdb::timestamp_t; -variant decimalVariant(const Value& val) { +Variant decimalVariant(const Value& val) { VELOX_DCHECK(val.type().id() == LogicalTypeId::DECIMAL); switch (val.type().InternalType()) { case ::duckdb::PhysicalType::INT128: { auto unscaledValue = val.GetValueUnsafe<::duckdb::hugeint_t>(); - return variant(HugeInt::build(unscaledValue.upper, unscaledValue.lower)); + return Variant(HugeInt::build(unscaledValue.upper, unscaledValue.lower)); } case ::duckdb::PhysicalType::INT16: { - return variant(static_cast(val.GetValueUnsafe())); + return Variant(static_cast(val.GetValueUnsafe())); } case ::duckdb::PhysicalType::INT32: { - return variant(static_cast(val.GetValueUnsafe())); + return Variant(static_cast(val.GetValueUnsafe())); } case ::duckdb::PhysicalType::INT64: { - return variant(val.GetValueUnsafe()); + return Variant(val.GetValueUnsafe()); } default: VELOX_UNSUPPORTED(); @@ -192,6 +192,9 @@ TypePtr toVeloxType(LogicalType type, bool fileColumnNamesReadAsLowerCase) { if (auto customType = getCustomType(name, {})) { return customType; } + if (name == "OPAQUE") { + return OPAQUE(); + } [[fallthrough]]; } default: @@ -201,37 +204,37 @@ TypePtr toVeloxType(LogicalType type, bool fileColumnNamesReadAsLowerCase) { } } -variant duckValueToVariant(const Value& val) { +Variant duckValueToVariant(const Value& val) { switch (val.type().id()) { case LogicalTypeId::SQLNULL: - return variant(TypeKind::UNKNOWN); + return Variant(TypeKind::UNKNOWN); case LogicalTypeId::BOOLEAN: - return variant(val.GetValue()); + return Variant(val.GetValue()); case LogicalTypeId::TINYINT: - return variant(val.GetValue()); + return Variant(val.GetValue()); case LogicalTypeId::SMALLINT: - return variant(val.GetValue()); + return Variant(val.GetValue()); case LogicalTypeId::INTEGER: - return variant(val.GetValue()); + return Variant(val.GetValue()); case LogicalTypeId::BIGINT: - return variant(val.GetValue()); + return Variant(val.GetValue()); case LogicalTypeId::FLOAT: - return variant(val.GetValue()); + return Variant(val.GetValue()); case LogicalTypeId::DOUBLE: - return variant(val.GetValue()); + return Variant(val.GetValue()); case LogicalTypeId::TIMESTAMP: - return variant(duckdbTimestampToVelox(val.GetValue())); + return Variant(duckdbTimestampToVelox(val.GetValue())); case LogicalTypeId::DECIMAL: return decimalVariant(val); case LogicalTypeId::VARCHAR: - return variant(val.GetValue()); + return Variant(val.GetValue()); case LogicalTypeId::BLOB: - return variant::binary(val.GetValue()); + return Variant::binary(val.GetValue()); case LogicalTypeId::DATE: - return variant(val.GetValue<::duckdb::date_t>().days); + return Variant(val.GetValue<::duckdb::date_t>().days); default: throw std::runtime_error( - "unsupported type for duckdb value -> velox variant conversion: " + + "unsupported type for duckdb value -> velox variant conversion: " + val.type().ToString()); } } diff --git a/velox/duckdb/conversion/DuckConversion.h b/velox/duckdb/conversion/DuckConversion.h index 9888a0cb66a8..c4ffa51ef776 100644 --- a/velox/duckdb/conversion/DuckConversion.h +++ b/velox/duckdb/conversion/DuckConversion.h @@ -20,7 +20,7 @@ #include // @manual namespace facebook::velox { -class variant; +class Variant; } namespace facebook::velox::duckdb { @@ -57,16 +57,16 @@ static Timestamp duckdbTimestampToVelox( } // Converts a duckDB Value (class that holds an arbitrary data type) into -// Velox variant. -variant duckValueToVariant(const ::duckdb::Value& val); +// Velox Variant. +Variant duckValueToVariant(const ::duckdb::Value& val); -// Converts duckDB decimal Value into appropriate decimal variant. +// Converts duckDB decimal Value into appropriate decimal Variant. // The duckdb::Value::GetValue() call for decimal type returns a double value. // To avoid this, this method uses the duckdb::Value::GetUnsafeValue() // method. // @param val duckdb decimal value. -// @return decimal variant. -variant decimalVariant(const ::duckdb::Value& val); +// @return decimal Variant. +Variant decimalVariant(const ::duckdb::Value& val); // value conversion routines template diff --git a/velox/duckdb/conversion/DuckParser.cpp b/velox/duckdb/conversion/DuckParser.cpp index b5427b7f33fa..dcbbfcb48962 100644 --- a/velox/duckdb/conversion/DuckParser.cpp +++ b/velox/duckdb/conversion/DuckParser.cpp @@ -15,7 +15,6 @@ */ #include "velox/duckdb/conversion/DuckParser.h" #include "velox/common/base/Exceptions.h" -#include "velox/core/PlanNode.h" #include "velox/duckdb/conversion/DuckConversion.h" #include "velox/parse/Expressions.h" #include "velox/type/Variant.h" @@ -59,9 +58,7 @@ using ::duckdb::WindowExpression; namespace { -std::shared_ptr parseExpr( - ParsedExpression& expr, - const ParseOptions& options); +core::ExprPtr parseExpr(ParsedExpression& expr, const ParseOptions& options); std::string normalizeFuncName(std::string input) { static std::map kLookup{ @@ -133,7 +130,7 @@ std::optional getAlias(const ParsedExpression& expr) { std::shared_ptr callExpr( std::string name, - std::vector> params, + std::vector params, std::optional alias, const ParseOptions& options) { return std::make_shared( @@ -144,10 +141,10 @@ std::shared_ptr callExpr( std::shared_ptr callExpr( std::string name, - const std::shared_ptr& param, + const core::ExprPtr& param, std::optional alias, const ParseOptions& options) { - std::vector> params = {param}; + std::vector params = {param}; return std::make_shared( toFullFunctionName(name, options.functionPrefix), std::move(params), @@ -155,7 +152,7 @@ std::shared_ptr callExpr( } // Parse a constant (1, 99.8, "string", etc). -std::shared_ptr parseConstantExpr( +core::ExprPtr parseConstantExpr( ParsedExpression& expr, const ParseOptions& options) { auto& constantExpr = dynamic_cast(expr); @@ -179,7 +176,7 @@ std::shared_ptr parseConstantExpr( } // Parse a column reference (col1, "col2", tbl.col, etc). -std::shared_ptr parseColumnRefExpr( +core::ExprPtr parseColumnRefExpr( ParsedExpression& expr, const ParseOptions& options) { const auto& colRefExpr = dynamic_cast(expr); @@ -190,9 +187,8 @@ std::shared_ptr parseColumnRefExpr( return std::make_shared( colRefExpr.GetColumnName(), getAlias(expr), - std::vector>{ - std::make_shared( - colRefExpr.GetTableName(), std::nullopt)}); + std::vector{std::make_shared( + colRefExpr.GetTableName(), std::nullopt)}); } namespace { @@ -213,7 +209,7 @@ std::optional extractInteger(const core::ConstantExpr& constInput) { std::shared_ptr tryParseInterval( const std::string& functionName, - const std::shared_ptr& input, + const core::ExprPtr& input, std::optional alias) { std::optional value; @@ -221,8 +217,8 @@ std::shared_ptr tryParseInterval( value = extractInteger(*constInput); } else if ( auto castInput = dynamic_cast(input.get())) { - if (auto constInput = dynamic_cast( - castInput->getInput().get())) { + if (auto constInput = + dynamic_cast(castInput->input().get())) { value = extractInteger(*constInput); } } @@ -256,20 +252,20 @@ std::shared_ptr tryParseInterval( } return std::make_shared( INTERVAL_YEAR_MONTH(), - variant((int32_t)(value.value() * multiplier)), + Variant((int32_t)(value.value() * multiplier)), alias); } return std::make_shared( - INTERVAL_DAY_TIME(), variant(value.value() * multiplier), alias); + INTERVAL_DAY_TIME(), Variant(value.value() * multiplier), alias); } // Parse a function call (avg(a), func(1, b), etc). // Arithmetic operators also follow this path (a + b, a * b, etc). -std::shared_ptr parseFunctionExpr( +core::ExprPtr parseFunctionExpr( ParsedExpression& expr, const ParseOptions& options) { const auto& functionExpr = dynamic_cast(expr); - std::vector> params; + std::vector params; params.reserve(functionExpr.children.size()); for (const auto& c : functionExpr.children) { @@ -295,11 +291,11 @@ std::shared_ptr parseFunctionExpr( } // Parse a comparison (a > b, a = b, etc). -std::shared_ptr parseComparisonExpr( +core::ExprPtr parseComparisonExpr( ParsedExpression& expr, const ParseOptions& options) { const auto& compExpr = dynamic_cast(expr); - std::vector> params{ + std::vector params{ parseExpr(*compExpr.left, options), parseExpr(*compExpr.right, options)}; return callExpr( normalizeFuncName(ExpressionTypeToOperator(expr.GetExpressionType())), @@ -309,7 +305,7 @@ std::shared_ptr parseComparisonExpr( } // Parse x between lower and upper -std::shared_ptr parseBetweenExpr( +core::ExprPtr parseBetweenExpr( ParsedExpression& expr, const ParseOptions& options) { const auto& betweenExpr = dynamic_cast(expr); @@ -323,7 +319,7 @@ std::shared_ptr parseBetweenExpr( } // Parse a conjunction (AND or OR). -std::shared_ptr parseConjunctionExpr( +core::ExprPtr parseConjunctionExpr( ParsedExpression& expr, const ParseOptions& options) { const auto& conjExpr = dynamic_cast(expr); @@ -331,19 +327,20 @@ std::shared_ptr parseConjunctionExpr( StringUtil::Lower(ExpressionTypeToOperator(expr.GetExpressionType())); if (conjExpr.children.size() < 2) { - throw std::invalid_argument(folly::sformat( - "Malformed conjunction expression " - "(expected at least 2 input columns, got {}).", - conjExpr.children.size())); + throw std::invalid_argument( + folly::sformat( + "Malformed conjunction expression " + "(expected at least 2 input columns, got {}).", + conjExpr.children.size())); } // DuckDB's parser returns conjunction involving multiple input in a flat // expression, in the form `AND(a, b, d, e)`, but internally we expect // conjunctions to have exactly 2 input. This code converts that input into // `AND(AND(AND(a, b), d), e)` (so it's executed in the same order). - std::shared_ptr current; + core::ExprPtr current; for (size_t i = 1; i < conjExpr.children.size(); ++i) { - std::vector> params; + std::vector params; params.reserve(2); if (current == nullptr) { @@ -368,7 +365,7 @@ static bool areAllChildrenConstant(const OperatorExpression& operExpr) { } // Parse an "operator", like NOT. -std::shared_ptr parseOperatorExpr( +core::ExprPtr parseOperatorExpr( ParsedExpression& expr, const ParseOptions& options) { const auto& operExpr = dynamic_cast(expr); @@ -376,7 +373,7 @@ std::shared_ptr parseOperatorExpr( // Code for array literal parsing (e.g. "ARRAY[1, 2, 3]") if (expr.GetExpressionType() == ExpressionType::ARRAY_CONSTRUCTOR) { if (areAllChildrenConstant(operExpr)) { - std::vector arrayElements; + std::vector arrayElements; arrayElements.reserve(operExpr.children.size()); TypePtr valueType = UNKNOWN(); @@ -397,9 +394,9 @@ std::shared_ptr parseOperatorExpr( } } return std::make_shared( - ARRAY(valueType), variant::array(arrayElements), getAlias(expr)); + ARRAY(valueType), Variant::array(arrayElements), getAlias(expr)); } else { - std::vector> params; + std::vector params; params.reserve(operExpr.children.size()); for (const auto& child : operExpr.children) { @@ -415,8 +412,16 @@ std::shared_ptr parseOperatorExpr( expr.GetExpressionType() == ExpressionType::COMPARE_NOT_IN) { auto numValues = operExpr.children.size() - 1; - std::vector values; - values.reserve(numValues); + std::vector values; + if (options.parseInListAsArray) { + values.reserve(numValues); + } + + std::vector params; + if (!options.parseInListAsArray) { + params.reserve(numValues + 1); + } + params.emplace_back(parseExpr(*operExpr.children[0], options)); TypePtr valueType = UNKNOWN(); for (auto i = 0; i < numValues; i++) { @@ -428,21 +433,29 @@ std::shared_ptr parseOperatorExpr( dynamic_cast(castExpr->child.get()); auto value = constExpr->value.DefaultCastAs( castExpr->cast_type, !castExpr->try_cast); - values.emplace_back(duckValueToVariant(value)); - valueType = toVeloxType(castExpr->cast_type); + if (options.parseInListAsArray) { + values.emplace_back(duckValueToVariant(value)); + valueType = toVeloxType(castExpr->cast_type); + } else { + params.emplace_back(parseExpr(*castExpr->child, options)); + } continue; } } if (auto constantExpr = dynamic_cast(valueExpr)) { - auto& value = constantExpr->value; - if (options.parseDecimalAsDouble && - value.type().id() == duckdb::LogicalTypeId::DECIMAL) { - value = Value::DOUBLE(value.GetValue()); - } - values.emplace_back(duckValueToVariant(value)); - if (!value.IsNull()) { - valueType = toVeloxType(value.type()); + if (options.parseInListAsArray) { + auto& value = constantExpr->value; + if (options.parseDecimalAsDouble && + value.type().id() == duckdb::LogicalTypeId::DECIMAL) { + value = Value::DOUBLE(value.GetValue()); + } + values.emplace_back(duckValueToVariant(value)); + if (!value.IsNull()) { + valueType = toVeloxType(value.type()); + } + } else { + params.emplace_back(parseExpr(*constantExpr, options)); } continue; } @@ -450,10 +463,11 @@ std::shared_ptr parseOperatorExpr( VELOX_UNSUPPORTED("IN list values need to be constant"); } - std::vector> params; - params.emplace_back(parseExpr(*operExpr.children[0], options)); - params.emplace_back(std::make_shared( - ARRAY(valueType), variant::array(values), std::nullopt)); + if (options.parseInListAsArray) { + params.emplace_back( + std::make_shared( + ARRAY(valueType), Variant::array(values), std::nullopt)); + } auto inExpr = callExpr("in", std::move(params), getAlias(expr), options); // Translate COMPARE_NOT_IN into NOT(IN()). return (expr.GetExpressionType() == ExpressionType::COMPARE_IN) @@ -461,7 +475,7 @@ std::shared_ptr parseOperatorExpr( : callExpr("not", inExpr, std::nullopt, options); } - std::vector> params; + std::vector params; params.reserve(operExpr.children.size()); for (const auto& child : operExpr.children) { @@ -472,7 +486,7 @@ std::shared_ptr parseOperatorExpr( // (a).b.c, (a.b).c if (expr.GetExpressionType() == ExpressionType::STRUCT_EXTRACT) { VELOX_CHECK_EQ(params.size(), 2); - std::vector> input = {params[0]}; + std::vector input = {params[0]}; if (auto constantExpr = std::dynamic_pointer_cast(params[1])) { @@ -501,7 +515,7 @@ std::shared_ptr parseOperatorExpr( } namespace { -bool isNullConstant(const std::shared_ptr& expr) { +bool isNullConstant(const core::ExprPtr& expr) { if (auto constExpr = std::dynamic_pointer_cast(expr)) { return constExpr->value().isNull(); @@ -512,7 +526,7 @@ bool isNullConstant(const std::shared_ptr& expr) { } // namespace // Parse an IF()/CASE expression. -std::shared_ptr parseCaseExpr( +core::ExprPtr parseCaseExpr( ParsedExpression& expr, const ParseOptions& options) { const auto& caseExpr = dynamic_cast(expr); @@ -521,7 +535,7 @@ std::shared_ptr parseCaseExpr( if (checks.size() == 1) { const auto& check = checks.front(); - std::vector> params{ + std::vector params{ parseExpr(*check.when_expr, options), parseExpr(*check.then_expr, options), parseExpr(*caseExpr.else_expr, options), @@ -529,7 +543,7 @@ std::shared_ptr parseCaseExpr( return callExpr("if", std::move(params), getAlias(expr), options); } - std::vector> inputs; + std::vector inputs; inputs.reserve(checks.size() * 2 + 1); for (auto& check : checks) { inputs.emplace_back(parseExpr(*check.when_expr, options)); @@ -545,12 +559,11 @@ std::shared_ptr parseCaseExpr( } // Parse an CAST expression. -std::shared_ptr parseCastExpr( +core::ExprPtr parseCastExpr( ParsedExpression& expr, const ParseOptions& options) { const auto& castExpr = dynamic_cast(expr); - std::vector> params{ - parseExpr(*castExpr.child, options)}; + std::vector params{parseExpr(*castExpr.child, options)}; // We may need to expand toVeloxType in the future to support // Map and Array and Struct properly. auto targetType = toVeloxType(castExpr.cast_type); @@ -561,7 +574,7 @@ std::shared_ptr parseCastExpr( dynamic_cast(params[0].get())) { if (constant->value().isNull()) { return std::make_shared( - targetType, variant::null(targetType->kind()), getAlias(expr)); + targetType, Variant::null(targetType->kind()), getAlias(expr)); } // DuckDB parses BOOLEAN literal as cast expression. Try to restore it back @@ -573,25 +586,25 @@ std::shared_ptr parseCastExpr( if (s == "t") { return std::make_shared( BOOLEAN(), - variant::create(true), + Variant::create(true), getAlias(expr)); } if (s == "f") { return std::make_shared( BOOLEAN(), - variant::create(false), + Variant::create(false), getAlias(expr)); } } } - const bool nullOnFailure = castExpr.try_cast; + const bool isTryCast = castExpr.try_cast; return std::make_shared( - targetType, params[0], nullOnFailure, getAlias(expr)); + targetType, params[0], isTryCast, getAlias(expr)); } -std::shared_ptr parseLambdaExpr( +core::ExprPtr parseLambdaExpr( ParsedExpression& expr, const ParseOptions& options) { const auto& lambdaExpr = dynamic_cast<::duckdb::LambdaExpression&>(expr); @@ -604,18 +617,17 @@ std::shared_ptr parseLambdaExpr( std::vector names; if (auto fieldExpr = std::dynamic_pointer_cast(capture)) { - names.push_back(fieldExpr->getFieldName()); + names.push_back(fieldExpr->name()); } else if ( auto callExpr = std::dynamic_pointer_cast(capture)) { VELOX_CHECK_EQ( - toFullFunctionName("row", options.functionPrefix), - callExpr->getFunctionName()); - for (auto& input : callExpr->getInputs()) { + toFullFunctionName("row", options.functionPrefix), callExpr->name()); + for (auto& input : callExpr->inputs()) { auto fieldExpr = std::dynamic_pointer_cast(input); VELOX_CHECK_NOT_NULL(fieldExpr); - names.push_back(fieldExpr->getFieldName()); + names.push_back(fieldExpr->name()); } } else { VELOX_FAIL( @@ -627,9 +639,7 @@ std::shared_ptr parseLambdaExpr( std::move(names), std::move(body)); } -std::shared_ptr parseExpr( - ParsedExpression& expr, - const ParseOptions& options) { +core::ExprPtr parseExpr(ParsedExpression& expr, const ParseOptions& options) { switch (expr.GetExpressionClass()) { case ExpressionClass::CONSTANT: return parseConstantExpr(expr, options); @@ -689,19 +699,19 @@ std::unique_ptr<::duckdb::ParsedExpression> parseSingleExpression( } } // namespace -std::shared_ptr parseExpr( +core::ExprPtr parseExpr( const std::string& exprString, const ParseOptions& options) { auto parsed = parseSingleExpression(exprString); return parseExpr(*parsed, options); } -std::vector> parseMultipleExpressions( +std::vector parseMultipleExpressions( const std::string& exprString, const ParseOptions& options) { auto parsedExpressions = parseExpression(exprString); VELOX_CHECK_GT(parsedExpressions.size(), 0); - std::vector> exprs; + std::vector exprs; exprs.reserve(parsedExpressions.size()); for (const auto& parsedExpr : parsedExpressions) { exprs.push_back(parseExpr(*parsedExpr, options)); @@ -745,8 +755,7 @@ bool isNullsFirst( } } // namespace -std::pair, core::SortOrder> parseOrderByExpr( - const std::string& exprString) { +OrderByClause parseOrderByExpr(const std::string& exprString) { ParserOptions options; ParseOptions parseOptions; options.preserve_identifier_case = false; @@ -763,8 +772,9 @@ std::pair, core::SortOrder> parseOrderByExpr( const bool nullsFirst = isNullsFirst(orderByNode.null_order, exprString); return { - parseExpr(*orderByNode.expression, parseOptions), - core::SortOrder(ascending, nullsFirst)}; + .expr = parseExpr(*orderByNode.expression, parseOptions), + .ascending = ascending, + .nullsFirst = nullsFirst}; } AggregateExpr parseAggregateExpr( @@ -783,8 +793,10 @@ AggregateExpr parseAggregateExpr( const bool ascending = isAscending(orderByNode.type, exprString); const bool nullsFirst = isNullsFirst(orderByNode.null_order, exprString); aggregateExpr.orderBy.emplace_back( - parseExpr(*orderByNode.expression, options), - core::SortOrder(ascending, nullsFirst)); + OrderByClause{ + parseExpr(*orderByNode.expression, options), + ascending, + nullsFirst}); } } @@ -856,11 +868,13 @@ IExprWindowFunction parseWindowExpr( const bool ascending = isAscending(orderByNode.type, windowString); const bool nullsFirst = isNullsFirst(orderByNode.null_order, windowString); windowIExpr.orderBy.emplace_back( - parseExpr(*orderByNode.expression, options), - core::SortOrder(ascending, nullsFirst)); + OrderByClause{ + parseExpr(*orderByNode.expression, options), + ascending, + nullsFirst}); } - std::vector> params; + std::vector params; params.reserve(windowExpr.children.size()); for (const auto& c : windowExpr.children) { params.emplace_back(parseExpr(*c, options)); @@ -894,4 +908,12 @@ IExprWindowFunction parseWindowExpr( return windowIExpr; } +std::string OrderByClause::toString() const { + return fmt::format( + "{} {} NULLS {}", + expr->toString(), + (ascending ? "ASC" : "DESC"), + (nullsFirst ? "FIRST" : "LAST")); +} + } // namespace facebook::velox::duckdb diff --git a/velox/duckdb/conversion/DuckParser.h b/velox/duckdb/conversion/DuckParser.h index 05d5d5a91bae..413e9280561a 100644 --- a/velox/duckdb/conversion/DuckParser.h +++ b/velox/duckdb/conversion/DuckParser.h @@ -15,21 +15,27 @@ */ #pragma once -#include #include -#include - -namespace facebook::velox::core { -class IExpr; -class SortOrder; -} // namespace facebook::velox::core +#include "velox/parse/IExpr.h" namespace facebook::velox::duckdb { + +struct OrderByClause { + core::ExprPtr expr; + bool ascending; + bool nullsFirst; + + std::string toString() const; +}; + /// Hold parsing options. struct ParseOptions { // Retain legacy behavior by default. bool parseDecimalAsDouble = true; bool parseIntegerAsBigint = true; + // Whether to parse the values in an IN list as separate arguments or as a + // single array argument. + bool parseInListAsArray = true; /// SQL functions could be registered with different prefixes by the user. /// This parameter is the registered prefix of presto or spark functions, @@ -44,20 +50,19 @@ struct ParseOptions { // are lower-cased, what prevents you to use functions and column names // containing upper case letters (e.g: "concatRow" will be parsed as // "concatrow"). -std::shared_ptr parseExpr( +core::ExprPtr parseExpr( const std::string& exprString, const ParseOptions& options); -std::vector> parseMultipleExpressions( +std::vector parseMultipleExpressions( const std::string& exprString, const ParseOptions& options); struct AggregateExpr { - std::shared_ptr expr; - std::vector, core::SortOrder>> - orderBy; + core::ExprPtr expr; + std::vector orderBy; bool distinct{false}; - std::shared_ptr maskExpr{nullptr}; + core::ExprPtr maskExpr{nullptr}; }; /// Parses aggregate function call expression with optional ORDER by clause. @@ -69,11 +74,9 @@ AggregateExpr parseAggregateExpr( const std::string& exprString, const ParseOptions& options); -// Parses an ORDER BY clause using DuckDB's internal postgresql-based parser, -// converting it to a pair of an IExpr tree and a core::SortOrder. Uses ASC -// NULLS LAST as the default sort order. -std::pair, core::SortOrder> parseOrderByExpr( - const std::string& exprString); +// Parses an ORDER BY clause using DuckDB's internal postgresql-based parser. +// Uses ASC NULLS LAST as the default sort order. +OrderByClause parseOrderByExpr(const std::string& exprString); // Parses a WINDOW function SQL string using DuckDB's internal postgresql-based // parser. Window Functions are executed by Velox Window PlanNodes and not the @@ -93,19 +96,18 @@ enum class BoundType { struct IExprWindowFrame { WindowType type; BoundType startType; - std::shared_ptr startValue; + core::ExprPtr startValue; BoundType endType; - std::shared_ptr endValue; + core::ExprPtr endValue; }; struct IExprWindowFunction { - std::shared_ptr functionCall; + core::ExprPtr functionCall; IExprWindowFrame frame; bool ignoreNulls; - std::vector> partitionBy; - std::vector, core::SortOrder>> - orderBy; + std::vector partitionBy; + std::vector orderBy; }; IExprWindowFunction parseWindowExpr( diff --git a/velox/duckdb/conversion/tests/CMakeLists.txt b/velox/duckdb/conversion/tests/CMakeLists.txt index 3a4607dd8a89..afe10aa20fa5 100644 --- a/velox/duckdb/conversion/tests/CMakeLists.txt +++ b/velox/duckdb/conversion/tests/CMakeLists.txt @@ -12,18 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -add_executable(velox_duckdb_conversion_test DuckConversionTest.cpp - DuckParserTest.cpp) +add_executable(velox_duckdb_conversion_test DuckConversionTest.cpp DuckParserTest.cpp) add_test(velox_duckdb_conversion_test velox_duckdb_conversion_test) target_link_libraries( velox_duckdb_conversion_test velox_duckdb_parser + velox_common_base velox_parse_expression velox_functions_prestosql velox_functions_lib velox_functions_test_lib GTest::gtest GTest::gtest_main - gflags::gflags) + gflags::gflags +) diff --git a/velox/duckdb/conversion/tests/DuckParserTest.cpp b/velox/duckdb/conversion/tests/DuckParserTest.cpp index 85fb67cbb219..f4b4044e016a 100644 --- a/velox/duckdb/conversion/tests/DuckParserTest.cpp +++ b/velox/duckdb/conversion/tests/DuckParserTest.cpp @@ -43,8 +43,8 @@ TEST(DuckParserTest, constants) { EXPECT_EQ("-303.1234", parseExpr("-303.1234")->toString()); // Strings. - EXPECT_EQ("\"\"", parseExpr("''")->toString()); - EXPECT_EQ("\"hello world\"", parseExpr("'hello world'")->toString()); + EXPECT_EQ("", parseExpr("''")->toString()); + EXPECT_EQ("hello world", parseExpr("'hello world'")->toString()); // Nulls EXPECT_EQ("null", parseExpr("NULL")->toString()); @@ -57,16 +57,15 @@ TEST(DuckParserTest, constants) { TEST(DuckParserTest, arrays) { // Literal arrays with different types. - EXPECT_EQ("[1,2,-33]", parseExpr("ARRAY[1, 2, -33]")->toString()); + EXPECT_EQ("{1, 2, -33}", parseExpr("ARRAY[1, 2, -33]")->toString()); EXPECT_EQ( - "[1.99,-8.3,0.878]", parseExpr("ARRAY[1.99, -8.3, 0.878]")->toString()); + "{1.99, -8.3, 0.878}", parseExpr("ARRAY[1.99, -8.3, 0.878]")->toString()); EXPECT_EQ( - "[\"asd\",\"qwe\",\"ewqq\"]", - parseExpr("ARRAY['asd', 'qwe', 'ewqq']")->toString()); - EXPECT_EQ("[null,1]", parseExpr("ARRAY[NULL, 1]")->toString()); + "{asd, qwe, ewqq}", parseExpr("ARRAY['asd', 'qwe', 'ewqq']")->toString()); + EXPECT_EQ("{null, 1}", parseExpr("ARRAY[NULL, 1]")->toString()); // Empty array. - EXPECT_EQ("[]", parseExpr("ARRAY[]")->toString()); + EXPECT_EQ("", parseExpr("ARRAY[]")->toString()); // Expressions with variables and without. EXPECT_EQ( @@ -92,8 +91,7 @@ TEST(DuckParserTest, variables) { TEST(DuckParserTest, functions) { EXPECT_EQ("avg(\"col1\")", parseExpr("avg(col1)")->toString()); - EXPECT_EQ( - "func(1,3.4,\"str\")", parseExpr("func(1, 3.4, 'str')")->toString()); + EXPECT_EQ("func(1,3.4,str)", parseExpr("func(1, 3.4, 'str')")->toString()); // Nested calls. EXPECT_EQ( @@ -101,10 +99,7 @@ TEST(DuckParserTest, functions) { } namespace { -std::string toString( - const std::vector< - std::pair, core::SortOrder>>& - orderBy) { +std::string toString(const std::vector& orderBy) { std::stringstream out; if (!orderBy.empty()) { out << "ORDER BY "; @@ -112,8 +107,7 @@ std::string toString( if (i > 0) { out << ", "; } - out << orderBy[i].first->toString() << " " - << orderBy[i].second.toString(); + out << orderBy[i].toString(); } } @@ -176,11 +170,10 @@ TEST(DuckParserTest, subscript) { "subscript(\"col\",plus(10,99))", parseExpr("col[10 + 99]")->toString()); EXPECT_EQ("subscript(\"m1\",34)", parseExpr("m1[34]")->toString()); EXPECT_EQ( - "subscript(\"m2\",\"key str\")", parseExpr("m2['key str']")->toString()); + "subscript(\"m2\",key str)", parseExpr("m2['key str']")->toString()); EXPECT_EQ( - "subscript(func(),\"key str\")", - parseExpr("func()['key str']")->toString()); + "subscript(func(),key str)", parseExpr("func()['key str']")->toString()); } TEST(DuckParserTest, coalesce) { @@ -191,52 +184,105 @@ TEST(DuckParserTest, coalesce) { } TEST(DuckParserTest, in) { - EXPECT_EQ("in(\"col1\",[1,2,3])", parseExpr("col1 in (1, 2, 3)")->toString()); EXPECT_EQ( - "in(\"col1\",[1,2,null,3])", + "in(\"col1\",{1, 2, 3})", parseExpr("col1 in (1, 2, 3)")->toString()); + EXPECT_EQ( + "in(\"col1\",{1, 2, null, 3})", parseExpr("col1 in (1, 2, null, 3)")->toString()); EXPECT_EQ( - "in(\"col1\",[\"a\",\"b\",\"c\"])", + "in(\"col1\",{a, b, c})", parseExpr("col1 in ('a', 'b', 'c')")->toString()); EXPECT_EQ( - "in(\"col1\",[\"a\",null,\"b\",\"c\"])", + "in(\"col1\",{a, null, b, c})", parseExpr("col1 in ('a', null, 'b', 'c')")->toString()); } -TEST(DuckParserTest, notin) { +TEST(DuckParserTest, inListAsArray) { + ParseOptions parseOptions{.parseInListAsArray = false, .functionPrefix = ""}; + EXPECT_EQ( + "in(\"col1\",1,2,3)", + parseExpr("col1 in (1, 2, 3)", parseOptions)->toString()); + EXPECT_EQ( + "in(\"col1\",1,2,null,3)", + parseExpr("col1 in (1, 2, null, 3)", parseOptions)->toString()); + EXPECT_EQ( + "in(\"col1\",a,b,c)", + parseExpr("col1 in ('a', 'b', 'c')", parseOptions)->toString()); + EXPECT_EQ( + "in(\"col1\",a,null,b,c)", + parseExpr("col1 in ('a', null, 'b', 'c')", parseOptions)->toString()); +} + +TEST(DuckParserTest, notIn) { EXPECT_EQ( - "not(in(\"col1\",[1,2,3]))", + "not(in(\"col1\",{1, 2, 3}))", parseExpr("col1 not in (1, 2, 3)")->toString()); EXPECT_EQ( - "not(in(\"col1\",[1,2,3]))", + "not(in(\"col1\",{1, 2, 3}))", parseExpr("not(col1 in (1, 2, 3))")->toString()); EXPECT_EQ( - "not(in(\"col1\",[1,2,null,3]))", + "not(in(\"col1\",{1, 2, null, 3}))", parseExpr("col1 not in (1, 2, null, 3)")->toString()); EXPECT_EQ( - "not(in(\"col1\",[1,2,null,3]))", + "not(in(\"col1\",{1, 2, null, 3}))", parseExpr("not(col1 in (1, 2, null, 3))")->toString()); EXPECT_EQ( - "not(in(\"col1\",[\"a\",\"b\",\"c\"]))", + "not(in(\"col1\",{a, b, c}))", parseExpr("col1 not in ('a', 'b', 'c')")->toString()); EXPECT_EQ( - "not(in(\"col1\",[\"a\",\"b\",\"c\"]))", + "not(in(\"col1\",{a, b, c}))", parseExpr("not(col1 in ('a', 'b', 'c'))")->toString()); EXPECT_EQ( - "not(in(\"col1\",[\"a\",null,\"b\",\"c\"]))", + "not(in(\"col1\",{a, null, b, c}))", parseExpr("col1 not in ('a', null, 'b', 'c')")->toString()); EXPECT_EQ( - "not(in(\"col1\",[\"a\",null,\"b\",\"c\"]))", + "not(in(\"col1\",{a, null, b, c}))", parseExpr("not(col1 in ('a', null, 'b', 'c'))")->toString()); } +TEST(DuckParserTest, notInListAsArray) { + ParseOptions parseOptions{.parseInListAsArray = false, .functionPrefix = ""}; + EXPECT_EQ( + "not(in(\"col1\",1,2,3))", + parseExpr("col1 not in (1, 2, 3)", parseOptions)->toString()); + + EXPECT_EQ( + "not(in(\"col1\",1,2,3))", + parseExpr("not(col1 in (1, 2, 3))", parseOptions)->toString()); + + EXPECT_EQ( + "not(in(\"col1\",1,2,null,3))", + parseExpr("col1 not in (1, 2, null, 3)", parseOptions)->toString()); + + EXPECT_EQ( + "not(in(\"col1\",1,2,null,3))", + parseExpr("not(col1 in (1, 2, null, 3))", parseOptions)->toString()); + + EXPECT_EQ( + "not(in(\"col1\",a,b,c))", + parseExpr("col1 not in ('a', 'b', 'c')", parseOptions)->toString()); + + EXPECT_EQ( + "not(in(\"col1\",a,b,c))", + parseExpr("not(col1 in ('a', 'b', 'c'))", parseOptions)->toString()); + + EXPECT_EQ( + "not(in(\"col1\",a,null,b,c))", + parseExpr("col1 not in ('a', null, 'b', 'c')", parseOptions)->toString()); + + EXPECT_EQ( + "not(in(\"col1\",a,null,b,c))", + parseExpr("not(col1 in ('a', null, 'b', 'c'))", parseOptions) + ->toString()); +} + TEST(DuckParserTest, expressions) { // Comparisons. EXPECT_EQ("eq(1,0)", parseExpr("1 = 0")->toString()); @@ -338,36 +384,34 @@ TEST(DuckParserTest, intervalYearMonth) { } TEST(DuckParserTest, cast) { + EXPECT_EQ("cast(1 as BIGINT)", parseExpr("cast('1' as bigint)")->toString()); EXPECT_EQ( - "cast(\"1\", BIGINT)", parseExpr("cast('1' as bigint)")->toString()); + "cast(0.99 as INTEGER)", parseExpr("cast(0.99 as integer)")->toString()); EXPECT_EQ( - "cast(0.99, INTEGER)", parseExpr("cast(0.99 as integer)")->toString()); + "cast(0.99 as SMALLINT)", + parseExpr("cast(0.99 as smallint)")->toString()); EXPECT_EQ( - "cast(0.99, SMALLINT)", parseExpr("cast(0.99 as smallint)")->toString()); + "cast(0.99 as TINYINT)", parseExpr("cast(0.99 as tinyint)")->toString()); + EXPECT_EQ("cast(1 as DOUBLE)", parseExpr("cast(1 as double)")->toString()); EXPECT_EQ( - "cast(0.99, TINYINT)", parseExpr("cast(0.99 as tinyint)")->toString()); - EXPECT_EQ("cast(1, DOUBLE)", parseExpr("cast(1 as double)")->toString()); + "cast(\"col1\" as REAL)", parseExpr("cast(col1 as real)")->toString()); EXPECT_EQ( - "cast(\"col1\", REAL)", parseExpr("cast(col1 as real)")->toString()); + "cast(\"col1\" as REAL)", parseExpr("cast(col1 as float)")->toString()); EXPECT_EQ( - "cast(\"col1\", REAL)", parseExpr("cast(col1 as float)")->toString()); + "cast(0.99 as VARCHAR)", parseExpr("cast(0.99 as string)")->toString()); EXPECT_EQ( - "cast(0.99, VARCHAR)", parseExpr("cast(0.99 as string)")->toString()); + "cast(0.99 as VARCHAR)", parseExpr("cast(0.99 as varchar)")->toString()); + EXPECT_EQ("abc", parseExpr("cast('abc' as varbinary)")->toString()); EXPECT_EQ( - "cast(0.99, VARCHAR)", parseExpr("cast(0.99 as varchar)")->toString()); - // Cast varchar to varbinary produces a varbinary value which is serialized - // using base64 encoding. - EXPECT_EQ("\"YWJj\"", parseExpr("cast('abc' as varbinary)")->toString()); - EXPECT_EQ( - "cast(\"str_col\", TIMESTAMP)", + "cast(\"str_col\" as TIMESTAMP)", parseExpr("cast(str_col as timestamp)")->toString()); EXPECT_EQ( - "cast(\"str_col\", DATE)", + "cast(\"str_col\" as DATE)", parseExpr("cast(str_col as date)")->toString()); EXPECT_EQ( - "cast(\"str_col\", INTERVAL DAY TO SECOND)", + "cast(\"str_col\" as INTERVAL DAY TO SECOND)", parseExpr("cast(str_col as interval day to second)")->toString()); // Unsupported casts for now. @@ -375,39 +419,39 @@ TEST(DuckParserTest, cast) { // Complex types. EXPECT_EQ( - "cast(\"c0\", ARRAY)", parseExpr("c0::bigint[]")->toString()); + "cast(\"c0\" as ARRAY)", parseExpr("c0::bigint[]")->toString()); EXPECT_EQ( - "cast(\"c0\", ARRAY)", + "cast(\"c0\" as ARRAY)", parseExpr("cast(c0 as bigint[])")->toString()); EXPECT_EQ( - "cast(\"c0\", MAP)", + "cast(\"c0\" as MAP)", parseExpr("c0::map(varchar, bigint)")->toString()); EXPECT_EQ( - "cast(\"c0\", MAP)", + "cast(\"c0\" as MAP)", parseExpr("cast(c0 as map(varchar, bigint))")->toString()); EXPECT_EQ( - "cast(\"c0\", ROW)", + "cast(\"c0\" as ROW)", parseExpr("c0::struct(a bigint, b real, c varchar)")->toString()); EXPECT_EQ( - "cast(\"c0\", ROW)", + "cast(\"c0\" as ROW)", parseExpr("cast(c0 as struct(a bigint, b real, c varchar))")->toString()); } TEST(DuckParserTest, castToJson) { registerJsonType(); - EXPECT_EQ("cast(\"c0\", JSON)", parseExpr("cast(c0 as json)")->toString()); - EXPECT_EQ("cast(\"c0\", JSON)", parseExpr("cast(c0 as JSON)")->toString()); + EXPECT_EQ("cast(\"c0\" as JSON)", parseExpr("cast(c0 as json)")->toString()); + EXPECT_EQ("cast(\"c0\" as JSON)", parseExpr("cast(c0 as JSON)")->toString()); } TEST(DuckParserTest, castToTimestampWithTimeZone) { registerTimestampWithTimeZoneType(); EXPECT_EQ( - "cast(\"c0\", TIMESTAMP WITH TIME ZONE)", + "cast(\"c0\" as TIMESTAMP WITH TIME ZONE)", parseExpr("cast(c0 as timestamp with time zone)")->toString()); EXPECT_EQ( - "cast(\"c0\", TIMESTAMP WITH TIME ZONE)", + "cast(\"c0\" as TIMESTAMP WITH TIME ZONE)", parseExpr("cast(c0 as TIMESTAMP WITH TIME ZONE)")->toString()); } @@ -438,7 +482,7 @@ TEST(DuckParserTest, switchCase) { parseExpr("case when a > 0 then 1 when a < 0 then -1end")->toString()); EXPECT_EQ( - "switch(eq(\"a\",1),\"x\",eq(\"a\",5),\"y\",\"z\")", + "switch(eq(\"a\",1),x,eq(\"a\",5),y,z)", parseExpr("case a when 1 then 'x' when 5 then 'y' else 'z' end") ->toString()); } @@ -479,24 +523,23 @@ TEST(DuckParserTest, alias) { "gt(\"a\",\"b\") AS result", parseExpr("a > b AS result")->toString()); EXPECT_EQ("2 AS multiplier", parseExpr("2 AS multiplier")->toString()); EXPECT_EQ( - "cast(\"a\", DOUBLE) AS a_double", + "cast(\"a\" as DOUBLE) AS a_double", parseExpr("cast(a AS DOUBLE) AS a_double")->toString()); EXPECT_EQ("\"a\" AS b", parseExpr("a AS b")->toString()); } TEST(DuckParserTest, like) { - EXPECT_EQ("like(\"name\",\"%b%\")", parseExpr("name LIKE '%b%'")->toString()); + EXPECT_EQ("like(\"name\",%b%)", parseExpr("name LIKE '%b%'")->toString()); EXPECT_EQ( - "like(\"name\",\"%#_%\",\"#\")", + "like(\"name\",%#_%,#)", parseExpr("name LIKE '%#_%' ESCAPE '#'")->toString()); } TEST(DuckParserTest, notLike) { EXPECT_EQ( - "not(like(\"name\",\"%b%\"))", - parseExpr("name NOT LIKE '%b%'")->toString()); + "not(like(\"name\",%b%))", parseExpr("name NOT LIKE '%b%'")->toString()); EXPECT_EQ( - "not(like(\"name\",\"%#_%\",\"#\"))", + "not(like(\"name\",%#_%,#))", parseExpr("name NOT LIKE '%#_%' ESCAPE '#'")->toString()); } @@ -509,9 +552,7 @@ TEST(DuckParserTest, count) { TEST(DuckParserTest, orderBy) { auto parse = [](const auto& expr) { - auto orderBy = parseOrderByExpr(expr); - return fmt::format( - "{} {}", orderBy.first->toString(), orderBy.second.toString()); + return parseOrderByExpr(expr).toString(); }; EXPECT_EQ("\"c1\" ASC NULLS LAST", parse("c1")); @@ -612,8 +653,9 @@ TEST(DuckParserTest, window) { EXPECT_EQ( "row_number() OVER (PARTITION BY \"a\" ORDER BY \"b\" ASC NULLS LAST " "ROWS BETWEEN plus(\"a\",10) PRECEDING AND 10 FOLLOWING)", - parseWindow("row_number() over (partition by a order by b " - "rows between a + 10 preceding and 10 following)")); + parseWindow( + "row_number() over (partition by a order by b " + "rows between a + 10 preceding and 10 following)")); EXPECT_EQ( "row_number() OVER (PARTITION BY \"a\" ORDER BY \"b\" DESC NULLS FIRST " "ROWS BETWEEN plus(\"a\",10) PRECEDING AND 10 FOLLOWING)", @@ -645,8 +687,8 @@ TEST(DuckParserTest, windowWithIntegerConstant) { std::dynamic_pointer_cast(windowExpr.functionCall); ASSERT_TRUE(func != nullptr) << windowExpr.functionCall->toString() << " is not a call expr"; - EXPECT_EQ(func->getInputs().size(), 2); - auto param = func->getInputs()[1]; + EXPECT_EQ(func->inputs().size(), 2); + auto param = func->inputs()[1]; auto constant = std::dynamic_pointer_cast(param); ASSERT_TRUE(constant != nullptr) << param->toString() << " is not a constant"; EXPECT_EQ(*constant->type(), *INTEGER()); @@ -686,13 +728,13 @@ TEST(DuckParserTest, parseWithPrefix) { ParseOptions options; options.functionPrefix = "prefix."; EXPECT_EQ( - "prefix.in(\"col1\",[1,2,3])", + "prefix.in(\"col1\",{1, 2, 3})", parseExpr("col1 in (1, 2, 3)", options)->toString()); EXPECT_EQ( - "prefix.like(\"name\",\"%b%\")", + "prefix.like(\"name\",%b%)", parseExpr("name LIKE '%b%'", options)->toString()); EXPECT_EQ( - "prefix.not(prefix.like(\"name\",\"%b%\"))", + "prefix.not(prefix.like(\"name\",%b%))", parseExpr("name NOT LIKE '%b%'", options)->toString()); // Arithmetic operators. @@ -748,7 +790,7 @@ TEST(DuckParserTest, parseWithPrefix) { EXPECT_EQ( "coalesce(null,0)", parseExpr("coalesce(NULL, 0)", options)->toString()); EXPECT_EQ( - "cast(\"1\", BIGINT)", + "cast(1 as BIGINT)", parseExpr("cast('1' as bigint)", options)->toString()); EXPECT_EQ( "try(prefix.plus(\"c0\",\"c1\"))", diff --git a/velox/dwio/CMakeLists.txt b/velox/dwio/CMakeLists.txt index d4a879e8f680..13a22460e902 100644 --- a/velox/dwio/CMakeLists.txt +++ b/velox/dwio/CMakeLists.txt @@ -28,7 +28,8 @@ velox_link_libraries( velox_type_fbhive velox_vector Folly::folly - fmt::fmt) + fmt::fmt +) add_subdirectory(common) add_subdirectory(catalog) diff --git a/velox/dwio/catalog/fbhive/CMakeLists.txt b/velox/dwio/catalog/fbhive/CMakeLists.txt index cb6a0d30b50f..17a778f41c3c 100644 --- a/velox/dwio/catalog/fbhive/CMakeLists.txt +++ b/velox/dwio/catalog/fbhive/CMakeLists.txt @@ -13,8 +13,7 @@ # limitations under the License. velox_add_library(velox_dwio_catalog_fbhive FileUtils.cpp) -velox_link_libraries(velox_dwio_catalog_fbhive velox_exception fmt::fmt - Folly::folly) +velox_link_libraries(velox_dwio_catalog_fbhive velox_exception fmt::fmt Folly::folly) if(${VELOX_BUILD_TESTING}) add_subdirectory(test) diff --git a/velox/dwio/catalog/fbhive/FileUtils.cpp b/velox/dwio/catalog/fbhive/FileUtils.cpp index 36d1b29d98e5..621ee20bbbf4 100644 --- a/velox/dwio/catalog/fbhive/FileUtils.cpp +++ b/velox/dwio/catalog/fbhive/FileUtils.cpp @@ -112,9 +112,10 @@ std::vector> extractPartitionKeyValues( std::vector tokens; folly::split('=', partitionPart, tokens); if (tokens.size() == 2) { - parsedParts.emplace_back(std::make_pair( - FileUtils::unescapePathName(tokens[0]), - FileUtils::unescapePathName(tokens[1]))); + parsedParts.emplace_back( + std::make_pair( + FileUtils::unescapePathName(tokens[0]), + FileUtils::unescapePathName(tokens[1]))); } }); } @@ -157,46 +158,29 @@ std::string FileUtils::unescapePathName(const std::string& data) { std::string FileUtils::makePartName( const std::vector>& entries, - bool partitionPathAsLowerCase) { - size_t size = 0; - size_t escapeCount = 0; - std::for_each(entries.begin(), entries.end(), [&](auto& pair) { - auto keySize = pair.first.size(); - VELOX_CHECK_GT(keySize, 0); - size += keySize; - escapeCount += countEscape(pair.first); - - auto valSize = pair.second.size(); - if (valSize == 0) { - size += kDefaultPartitionValue.size(); - } else { - size += valSize; - escapeCount += countEscape(pair.second); + bool partitionPathAsLowerCase, + bool useDefaultPartitionValue, + const EncodeFunction& encodeFunc) { + VELOX_CHECK(!entries.empty()); + std::ostringstream out; + + for (const auto& [key, value] : entries) { + VELOX_CHECK(!key.empty()); + if (out.tellp() > 0) { + out << '/'; } - }); - std::string ret; - ret.reserve(size + escapeCount * HEX_WIDTH + entries.size() - 1); - - std::for_each(entries.begin(), entries.end(), [&](auto& pair) { - if (ret.size() > 0) { - ret += "/"; - } - if (partitionPathAsLowerCase) { - ret += escapePathName(toLower(pair.first)); - } else { - ret += escapePathName(pair.first); - } + std::string keyToEncode = partitionPathAsLowerCase ? toLower(key) : key; + out << encodeFunc(keyToEncode) << '='; - ret += "="; - if (pair.second.size() == 0) { - ret += kDefaultPartitionValue; + if (value.empty() && useDefaultPartitionValue) { + out << kDefaultPartitionValue; } else { - ret += escapePathName(pair.second); + out << encodeFunc(value); } - }); + } - return ret; + return out.str(); } std::vector> FileUtils::parsePartKeyValues( diff --git a/velox/dwio/catalog/fbhive/FileUtils.h b/velox/dwio/catalog/fbhive/FileUtils.h index a8ca8bf07efd..519c274fc6f1 100644 --- a/velox/dwio/catalog/fbhive/FileUtils.h +++ b/velox/dwio/catalog/fbhive/FileUtils.h @@ -29,6 +29,10 @@ namespace fbhive { class FileUtils { public: + /// Function type for encoding partition key/value strings. + /// Takes a string to encode and returns the encoded string. + using EncodeFunction = std::function; + /// Converts the path name to be hive metastore compliant, will do /// url-encoding when needed. static std::string escapePathName(const std::string& data); @@ -39,9 +43,19 @@ class FileUtils { /// Creates the partition directory path from the list of partition key/value /// pairs, will do url-encoding when needed. + /// @param entries Vector of (key, value) pairs for partition columns. Cannot + /// be empty. + /// @param partitionPathAsLowerCase Whether to convert keys to lowercase + /// @param useDefaultPartitionValue If true, empty values are replaced with + /// kDefaultPartitionValue. If false, empty values are encoded as-is. + /// Defaults to true for Hive compatibility. + /// @param encodeFunc Function to use for encoding keys and values. + /// Defaults to escapePathName. static std::string makePartName( const std::vector>& entries, - bool partitionPathAsLowerCase); + bool partitionPathAsLowerCase, + bool useDefaultPartitionValue = true, + const EncodeFunction& encodeFunc = escapePathName); /// Converts the hive-metastore-compliant path name back to the corresponding /// partition key/value pairs. diff --git a/velox/dwio/catalog/fbhive/test/CMakeLists.txt b/velox/dwio/catalog/fbhive/test/CMakeLists.txt index d7fb52887727..fede218f0acb 100644 --- a/velox/dwio/catalog/fbhive/test/CMakeLists.txt +++ b/velox/dwio/catalog/fbhive/test/CMakeLists.txt @@ -20,4 +20,5 @@ target_link_libraries( velox_exception GTest::gtest GTest::gtest_main - GTest::gmock) + GTest::gmock +) diff --git a/velox/dwio/catalog/fbhive/test/FileUtilsTests.cpp b/velox/dwio/catalog/fbhive/test/FileUtilsTests.cpp index 042c5ba93080..579c7533ecd0 100644 --- a/velox/dwio/catalog/fbhive/test/FileUtilsTests.cpp +++ b/velox/dwio/catalog/fbhive/test/FileUtilsTests.cpp @@ -19,10 +19,12 @@ #include "velox/common/base/Exceptions.h" #include "velox/dwio/catalog/fbhive/FileUtils.h" +namespace facebook::velox::dwio::catalog::fbhive { +namespace { + using namespace ::testing; -using namespace facebook::velox::dwio::catalog::fbhive; -TEST(FileUtilsTests, MakePartName) { +TEST(FileUtilsTests, makePartName) { std::vector> pairs{ {"ds", "2016-01-01"}, {"FOO", ""}, {"a\nb:c", "a#b=c"}}; ASSERT_EQ( @@ -31,9 +33,22 @@ TEST(FileUtilsTests, MakePartName) { ASSERT_EQ( FileUtils::makePartName(pairs, false), "ds=2016-01-01/FOO=__HIVE_DEFAULT_PARTITION__/a%0Ab%3Ac=a%23b%3Dc"); + ASSERT_THROW(FileUtils::makePartName({}, false), VeloxException); +} + +TEST(FileUtilsTests, makePartNameWithoutDefaultPartitionValue) { + std::vector> pairs{ + {"ds", "2016-01-01"}, {"FOO", ""}, {"a\nb:c", "a#b=c"}}; + // Test with useDefaultPartitionValue = false. + ASSERT_EQ( + FileUtils::makePartName(pairs, true, false), + "ds=2016-01-01/foo=/a%0Ab%3Ac=a%23b%3Dc"); + ASSERT_EQ( + FileUtils::makePartName(pairs, false, false), + "ds=2016-01-01/FOO=/a%0Ab%3Ac=a%23b%3Dc"); } -TEST(FileUtilsTests, ParsePartKeyValues) { +TEST(FileUtilsTests, parsePartKeyValues) { EXPECT_THROW( FileUtils::parsePartKeyValues("ds"), facebook::velox::VeloxRuntimeError); EXPECT_THROW( @@ -60,7 +75,7 @@ TEST(FileUtilsTests, ParsePartKeyValues) { std::make_pair("a\nb:c", "a#b=c/"))); } -TEST(FileUtilsTests, ExtractPartitionName) { +TEST(FileUtilsTests, extractPartitionName) { struct TestCase { public: TestCase(const std::string& filePath, const std::string& partitionName) @@ -88,3 +103,6 @@ TEST(FileUtilsTests, ExtractPartitionName) { FileUtils::extractPartitionName(testCase.filePath)); } } + +} // namespace +} // namespace facebook::velox::dwio::catalog::fbhive diff --git a/velox/dwio/common/Adaptor.h b/velox/dwio/common/Adaptor.h index 2fef0a96655c..ae89074808fd 100644 --- a/velox/dwio/common/Adaptor.h +++ b/velox/dwio/common/Adaptor.h @@ -27,7 +27,7 @@ #define DIAGNOSTIC_PUSH _Pragma("GCC diagnostic push") #define DIAGNOSTIC_POP _Pragma("GCC diagnostic pop") #else -#error("Unknown compiler") +#error ("Unknown compiler") #endif #define PRAGMA(TXT) _Pragma(#TXT) diff --git a/velox/dwio/common/Arena.h b/velox/dwio/common/Arena.h new file mode 100644 index 000000000000..9690ce727d35 --- /dev/null +++ b/velox/dwio/common/Arena.h @@ -0,0 +1,36 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace facebook::velox::dwio::common { + +/// Wrapper over protobuf's arena allocation. The API changes from +/// CreateMessage() to Create() in newer protobuf versions. +template +T* ArenaCreate(google::protobuf::Arena* arena, Args&&... args) { +#if GOOGLE_PROTOBUF_VERSION >= 5030000 + return google::protobuf::Arena::Create(arena, std::forward(args)...); +#else + return google::protobuf::Arena::CreateMessage( + arena, std::forward(args)...); +#endif +} + +} // namespace facebook::velox::dwio::common diff --git a/velox/dwio/common/BitPackDecoder.h b/velox/dwio/common/BitPackDecoder.h index 2aa785e2d29f..5f31267026d8 100644 --- a/velox/dwio/common/BitPackDecoder.h +++ b/velox/dwio/common/BitPackDecoder.h @@ -109,7 +109,7 @@ static inline uint32_t unpackNaive( uint64_t inputBufferLen, uint64_t numValues, uint8_t bitWidth, - T* FOLLY_NONNULL& result); + T * FOLLY_NONNULL & result); /// Unpack numValues number of input values from inputBuffer. The results /// will be written to result. numValues must be a multiple of 8. The @@ -122,7 +122,7 @@ inline void unpack( uint64_t inputBufferLen, uint64_t numValues, uint8_t bitWidth, - T* FOLLY_NONNULL& result) { + T * FOLLY_NONNULL & result) { unpackNaive(inputBits, inputBufferLen, numValues, bitWidth, result); } @@ -159,7 +159,7 @@ static inline uint32_t unpackNaive( uint64_t inputBufferLen, uint64_t numValues, uint8_t bitWidth, - T* FOLLY_NONNULL& result) { + T * FOLLY_NONNULL & result) { VELOX_CHECK(bitWidth >= 1 && bitWidth <= sizeof(T) * 8); VELOX_CHECK(inputBufferLen * 8 >= bitWidth * numValues); diff --git a/velox/dwio/common/BufferedInput.cpp b/velox/dwio/common/BufferedInput.cpp index 8791e418f49e..6062bcfd22ad 100644 --- a/velox/dwio/common/BufferedInput.cpp +++ b/velox/dwio/common/BufferedInput.cpp @@ -114,8 +114,9 @@ std::unique_ptr BufferedInput::enqueue( // help faster lookup using enqueuedToBufferOffset_ later. [region, this, i = regions_.size() - 1]() { auto result = readInternal(region.offset, region.length, i); - VELOX_CHECK( - std::get<1>(result) != MAX_UINT64, + VELOX_CHECK_NE( + std::get<1>(result), + MAX_UINT64, "Fail to read region offset={} length={}", region.offset, region.length); diff --git a/velox/dwio/common/BufferedInput.h b/velox/dwio/common/BufferedInput.h index 1f877b3fa8d0..e2089c42a799 100644 --- a/velox/dwio/common/BufferedInput.h +++ b/velox/dwio/common/BufferedInput.h @@ -37,13 +37,15 @@ class BufferedInput { IoStatistics* stats = nullptr, filesystems::File::IoStats* fsStats = nullptr, uint64_t maxMergeDistance = kMaxMergeDistance, - std::optional wsVRLoad = std::nullopt) + std::optional wsVRLoad = std::nullopt, + folly::F14FastMap fileReadOps = {}) : BufferedInput( std::make_shared( std::move(readFile), metricsLog, stats, - fsStats), + fsStats, + std::move(fileReadOps)), pool, maxMergeDistance, wsVRLoad) {} diff --git a/velox/dwio/common/CMakeLists.txt b/velox/dwio/common/CMakeLists.txt index 55fadf6c2f21..2023728686bc 100644 --- a/velox/dwio/common/CMakeLists.txt +++ b/velox/dwio/common/CMakeLists.txt @@ -14,6 +14,7 @@ add_subdirectory(compression) add_subdirectory(encryption) add_subdirectory(exception) +add_subdirectory(wrap) if(${VELOX_BUILD_TESTING}) add_subdirectory(tests) @@ -45,6 +46,7 @@ velox_add_library( MetadataFilter.cpp Options.cpp OutputStream.cpp + ParallelUnitLoader.cpp ParallelFor.cpp Range.cpp Reader.cpp @@ -53,6 +55,7 @@ velox_add_library( SeekableInputStream.cpp SelectiveByteRleColumnReader.cpp SelectiveColumnReader.cpp + SelectiveFlatMapColumnReader.cpp SelectiveRepeatedColumnReader.cpp SelectiveStructColumnReader.cpp SortingWriter.cpp @@ -61,7 +64,8 @@ velox_add_library( TypeUtils.cpp TypeWithId.cpp Writer.cpp - WriterFactory.cpp) + WriterFactory.cpp +) velox_include_directories(velox_dwio_common PRIVATE ${Protobuf_INCLUDE_DIRS}) @@ -78,7 +82,7 @@ velox_link_libraries( velox_expression velox_memory velox_type_tz - Boost::regex Folly::folly glog::glog - protobuf::libprotobuf) + protobuf::libprotobuf +) diff --git a/velox/dwio/common/CacheInputStream.cpp b/velox/dwio/common/CacheInputStream.cpp index dedda0c72f6a..cbf1f6ecf609 100644 --- a/velox/dwio/common/CacheInputStream.cpp +++ b/velox/dwio/common/CacheInputStream.cpp @@ -142,8 +142,8 @@ bool CacheInputStream::SkipInt64(int64_t count) { return false; } -google::protobuf::int64 CacheInputStream::ByteCount() const { - return static_cast(position_); +int64_t CacheInputStream::ByteCount() const { + return static_cast(position_); } void CacheInputStream::seekToPosition(PositionProvider& seekPosition) { @@ -394,7 +394,7 @@ velox::common::Region CacheInputStream::nextQuantizedLoadRegion( nextRegion.offset += (prevLoadedPosition / loadQuantum_) * loadQuantum_; // Set length to be the lesser of 'loadQuantum_' and distance to end of // 'region_' - nextRegion.length = std::min( + nextRegion.length = std::min( loadQuantum_, region_.length - (nextRegion.offset - region_.offset)); return nextRegion; } diff --git a/velox/dwio/common/CacheInputStream.h b/velox/dwio/common/CacheInputStream.h index 195bdbdd1351..1e1dab8d47a9 100644 --- a/velox/dwio/common/CacheInputStream.h +++ b/velox/dwio/common/CacheInputStream.h @@ -51,7 +51,7 @@ class CacheInputStream : public SeekableInputStream { bool Next(const void** data, int* size) override; void BackUp(int count) override; bool SkipInt64(int64_t count) override; - google::protobuf::int64 ByteCount() const override; + int64_t ByteCount() const override; void seekToPosition(PositionProvider& position) override; std::string getName() const override; size_t positionSize() const override; diff --git a/velox/dwio/common/CachedBufferedInput.cpp b/velox/dwio/common/CachedBufferedInput.cpp index 55ee71e49468..a6255a87c234 100644 --- a/velox/dwio/common/CachedBufferedInput.cpp +++ b/velox/dwio/common/CachedBufferedInput.cpp @@ -48,20 +48,20 @@ std::unique_ptr CachedBufferedInput::enqueue( } VELOX_CHECK_LE(region.offset + region.length, fileSize_); requests_.emplace_back( - RawFileCacheKey{fileNum_, region.offset}, region.length, id); + RawFileCacheKey{fileNum_.id(), region.offset}, region.length, id); if (tracker_ != nullptr) { - tracker_->recordReference(id, region.length, fileNum_, groupId_); + tracker_->recordReference(id, region.length, fileNum_.id(), groupId_.id()); } auto stream = std::make_unique( this, ioStats_.get(), region, input_, - fileNum_, + fileNum_.id(), options_.noCacheRetention(), tracker_, id, - groupId_, + groupId_.id(), options_.loadQuantum()); requests_.back().stream = stream.get(); return stream; @@ -127,10 +127,11 @@ std::vector makeRequestParts( std::vector parts; for (uint64_t offset = 0; offset < request.size; offset += loadQuantum) { const int32_t size = std::min(loadQuantum, request.size - offset); - extraRequests.push_back(std::make_unique( - RawFileCacheKey{request.key.fileNum, request.key.offset + offset}, - size, - request.trackingId)); + extraRequests.push_back( + std::make_unique( + RawFileCacheKey{request.key.fileNum, request.key.offset + offset}, + size, + request.trackingId)); parts.push_back(extraRequests.back().get()); parts.back()->coalesces = prefetch; if (prefetchOne) { @@ -166,7 +167,7 @@ void CachedBufferedInput::load(const LogType /*unused*/) { cache::SsdFile* ssdFile{nullptr}; auto* ssdCache = cache_->ssdCache(); if (ssdCache != nullptr) { - ssdFile = &ssdCache->file(fileNum_); + ssdFile = &ssdCache->file(fileNum_.id()); } // Extra requests made for pre-loadable regions that are larger than @@ -467,14 +468,14 @@ void CachedBufferedInput::readRegion( std::shared_ptr load; if (!requests[0]->ssdPin.empty()) { load = std::make_shared( - *cache_, ioStats_, fsStats_, groupId_, requests); + *cache_, ioStats_, fsStats_, groupId_.id(), requests); } else { load = std::make_shared( *cache_, input_, ioStats_, fsStats_, - groupId_, + groupId_.id(), requests, options_.maxCoalesceDistance()); } @@ -553,7 +554,7 @@ std::unique_ptr CachedBufferedInput::read( ioStats_.get(), Region{offset, length}, input_, - fileNum_, + fileNum_.id(), options_.noCacheRetention(), nullptr, TrackingId(), diff --git a/velox/dwio/common/CachedBufferedInput.h b/velox/dwio/common/CachedBufferedInput.h index 6782c2dacf54..4fb2475f5c85 100644 --- a/velox/dwio/common/CachedBufferedInput.h +++ b/velox/dwio/common/CachedBufferedInput.h @@ -27,8 +27,6 @@ #include "velox/dwio/common/CacheInputStream.h" #include "velox/dwio/common/InputStream.h" -DECLARE_int32(cache_load_quantum); - namespace facebook::velox::dwio::common { struct CacheRequest { @@ -57,24 +55,28 @@ class CachedBufferedInput : public BufferedInput { CachedBufferedInput( std::shared_ptr readFile, const MetricsLogPtr& metricsLog, - uint64_t fileNum, + StringIdLease fileNum, cache::AsyncDataCache* cache, std::shared_ptr tracker, - uint64_t groupId, + StringIdLease groupId, std::shared_ptr ioStats, std::shared_ptr fsStats, folly::Executor* executor, - const io::ReaderOptions& readerOptions) + const io::ReaderOptions& readerOptions, + folly::F14FastMap fileReadOps = {}) : BufferedInput( std::move(readFile), readerOptions.memoryPool(), metricsLog, ioStats.get(), - fsStats.get()), + fsStats.get(), + kMaxMergeDistance, + std::nullopt, + std::move(fileReadOps)), cache_(cache), - fileNum_(fileNum), + fileNum_(std::move(fileNum)), tracker_(std::move(tracker)), - groupId_(groupId), + groupId_(std::move(groupId)), ioStats_(std::move(ioStats)), fsStats_(std::move(fsStats)), executor_(executor), @@ -85,19 +87,19 @@ class CachedBufferedInput : public BufferedInput { CachedBufferedInput( std::shared_ptr input, - uint64_t fileNum, + StringIdLease fileNum, cache::AsyncDataCache* cache, std::shared_ptr tracker, - uint64_t groupId, + StringIdLease groupId, std::shared_ptr ioStats, std::shared_ptr fsStats, folly::Executor* executor, const io::ReaderOptions& readerOptions) : BufferedInput(std::move(input), readerOptions.memoryPool()), cache_(cache), - fileNum_(fileNum), + fileNum_(std::move(fileNum)), tracker_(std::move(tracker)), - groupId_(groupId), + groupId_(std::move(groupId)), ioStats_(std::move(ioStats)), fsStats_(std::move(fsStats)), executor_(executor), @@ -140,7 +142,7 @@ class CachedBufferedInput : public BufferedInput { void setNumStripes(int32_t numStripes) override { auto stats = tracker_->fileGroupStats(); if (stats) { - stats->recordFile(fileNum_, groupId_, numStripes); + stats->recordFile(fileNum_.id(), groupId_.id(), numStripes); } } @@ -210,9 +212,9 @@ class CachedBufferedInput : public BufferedInput { } cache::AsyncDataCache* const cache_; - const uint64_t fileNum_; + const StringIdLease fileNum_; const std::shared_ptr tracker_; - const uint64_t groupId_; + const StringIdLease groupId_; const std::shared_ptr ioStats_; const std::shared_ptr fsStats_; folly::Executor* const executor_; diff --git a/velox/dwio/common/ChainedBuffer.h b/velox/dwio/common/ChainedBuffer.h index 7b514d6b460b..38ba8d7c269c 100644 --- a/velox/dwio/common/ChainedBuffer.h +++ b/velox/dwio/common/ChainedBuffer.h @@ -20,10 +20,7 @@ #include "velox/common/base/GTestMacros.h" #include "velox/dwio/common/DataBuffer.h" -namespace facebook { -namespace velox { -namespace dwio { -namespace common { +namespace facebook::velox::dwio::common { namespace { @@ -228,7 +225,4 @@ class ChainedBuffer { VELOX_FRIEND_TEST(ChainedBufferTests, testClearAll); }; -} // namespace common -} // namespace dwio -} // namespace velox -} // namespace facebook +} // namespace facebook::velox::dwio::common diff --git a/velox/dwio/common/ColumnLoader.cpp b/velox/dwio/common/ColumnLoader.cpp index e4a000f295f5..7d62008440b9 100644 --- a/velox/dwio/common/ColumnLoader.cpp +++ b/velox/dwio/common/ColumnLoader.cpp @@ -87,61 +87,6 @@ void scatter(RowSet rows, vector_size_t resultSize, VectorPtr* result) { *result = BaseVector::wrapInDictionary(nullptr, indices, resultSize, *result); } -template -void addToHookImpl( - const DecodedVector& decoded, - const RowSet& rows, - ValueHook& hook) { - if (decoded.isIdentityMapping()) { - auto* values = decoded.data(); - hook.addValues(rows.data(), values, rows.size()); - return; - } - for (auto i : rows) { - if (!decoded.isNullAt(i)) { - hook.addValueTyped(i, decoded.valueAt(i)); - } else if (hook.acceptsNulls()) { - hook.addNull(i); - } - } -} - -void addToHook( - const DecodedVector& decoded, - const RowSet& rows, - ValueHook& hook) { - switch (decoded.base()->typeKind()) { - case TypeKind::BOOLEAN: - addToHookImpl(decoded, rows, hook); - break; - case TypeKind::TINYINT: - addToHookImpl(decoded, rows, hook); - break; - case TypeKind::SMALLINT: - addToHookImpl(decoded, rows, hook); - break; - case TypeKind::INTEGER: - addToHookImpl(decoded, rows, hook); - break; - case TypeKind::BIGINT: - addToHookImpl(decoded, rows, hook); - break; - case TypeKind::REAL: - addToHookImpl(decoded, rows, hook); - break; - case TypeKind::DOUBLE: - addToHookImpl(decoded, rows, hook); - break; - case TypeKind::VARCHAR: - case TypeKind::VARBINARY: - addToHookImpl(decoded, rows, hook); - break; - default: - VELOX_FAIL( - "Unsupported type kind for hook: {}", decoded.base()->typeKind()); - } -} - } // namespace void ColumnLoader::loadInternal( @@ -156,7 +101,7 @@ void ColumnLoader::loadInternal( ->debugString(); }, structReader_}); - raw_vector selectedRows; + raw_vector selectedRows(fieldReader_->memoryPool()); auto effectiveRows = read(structReader_, fieldReader_, version_, rows, selectedRows, hook); if (!hook) { @@ -189,17 +134,14 @@ void DeltaUpdateColumnLoader::loadInternal( // method return. VELOX_CHECK(!scanSpec->hasFilter()); scanSpec->setValueHook(nullptr); - raw_vector selectedRows; + raw_vector selectedRows(fieldReader_->memoryPool()); RowSet effectiveRows; effectiveRows = read(structReader_, fieldReader_, version_, rows, selectedRows, nullptr); fieldReader_->getValues(effectiveRows, result); scanSpec->deltaUpdate()->update(effectiveRows, *result); - if (hook) { - DecodedVector decoded(**result); - addToHook(decoded, effectiveRows, *hook); - } else if ( - rows.back() + 1 < resultSize || + VELOX_CHECK_NULL(hook); + if (rows.back() + 1 < resultSize || rows.size() != structReader_->outputRows().size()) { scatter(rows, resultSize, result); } diff --git a/velox/dwio/common/ColumnLoader.h b/velox/dwio/common/ColumnLoader.h index 2a2109f35cf0..71d86c7e0aaf 100644 --- a/velox/dwio/common/ColumnLoader.h +++ b/velox/dwio/common/ColumnLoader.h @@ -17,6 +17,7 @@ #pragma once #include "velox/dwio/common/SelectiveStructColumnReader.h" +#include "velox/vector/LazyVector.h" namespace facebook::velox::dwio::common { @@ -30,7 +31,13 @@ class ColumnLoader : public VectorLoader { fieldReader_(fieldReader), version_(version) {} - private: + virtual ~ColumnLoader() = default; + + bool supportsHook() const override { + return true; + } + + protected: void loadInternal( RowSet rows, ValueHook* hook, diff --git a/velox/dwio/common/ColumnSelector.h b/velox/dwio/common/ColumnSelector.h index 62408e55a213..9bef521d99a1 100644 --- a/velox/dwio/common/ColumnSelector.h +++ b/velox/dwio/common/ColumnSelector.h @@ -388,8 +388,10 @@ class ColumnSelector { // expect a runtime_error rather than fault. // Do-Not change the message as expected by client in failure case if (!notFound.empty()) { - throw std::runtime_error(folly::to( - "Columns not found in hive table: ", folly::join(", ", notFound))); + throw std::runtime_error( + folly::to( + "Columns not found in hive table: ", + folly::join(", ", notFound))); } } diff --git a/velox/dwio/common/ColumnVisitors.h b/velox/dwio/common/ColumnVisitors.h index 85798ee5bdf3..6e5f0db8187f 100644 --- a/velox/dwio/common/ColumnVisitors.h +++ b/velox/dwio/common/ColumnVisitors.h @@ -21,7 +21,6 @@ #include "velox/dwio/common/DecoderUtil.h" #include "velox/dwio/common/SelectiveColumnReader.h" #include "velox/dwio/common/TypeUtil.h" - namespace facebook::velox::dwio::common { // structs for extractValues in ColumnVisitor. @@ -164,7 +163,7 @@ class ColumnVisitor { static constexpr bool kFilterOnly = std::is_same_v; ColumnVisitor( - TFilter& filter, + const TFilter& filter, SelectiveColumnReader* reader, const RowSet& rows, ExtractValues values) @@ -409,7 +408,7 @@ class ColumnVisitor { inline void addNull(); inline void addOutputRow(vector_size_t row); - TFilter& filter() { + const TFilter& filter() { return filter_; } @@ -491,8 +490,8 @@ class ColumnVisitor { } protected: - TFilter& filter_; - SelectiveColumnReader* reader_; + const TFilter& filter_; + SelectiveColumnReader* const reader_; const bool allowNulls_; const vector_size_t* rows_; vector_size_t numRows_; @@ -724,7 +723,15 @@ inline xsimd::batch cvtU32toI64( xsimd::batch values) { return _mm256_cvtepu32_epi64(values); } -#elif XSIMD_WITH_SSE2 || XSIMD_WITH_NEON +#elif (XSIMD_WITH_SVE && SVE_BITS == 256) +inline xsimd::batch cvtU32toI64(simd::Batch128 values) { + int64_t element_1 = static_cast(values.data[0]); + int64_t element_2 = static_cast(values.data[1]); + int64_t element_3 = static_cast(values.data[2]); + int64_t element_4 = static_cast(values.data[3]); + return xsimd::batch(element_1, element_2, element_3, element_4); +} +#elif XSIMD_WITH_SSE2 || XSIMD_WITH_NEON || (XSIMD_WITH_SVE && SVE_BITS == 128) inline xsimd::batch cvtU32toI64(simd::Batch64 values) { int64_t lo = static_cast(values.data[0]); int64_t hi = static_cast(values.data[1]); @@ -741,7 +748,7 @@ class DictionaryColumnVisitor public: DictionaryColumnVisitor( - TFilter& filter, + const TFilter& filter, SelectiveColumnReader* reader, const RowSet& rows, ExtractValues values) @@ -913,10 +920,20 @@ class DictionaryColumnVisitor dictMask, reinterpret_cast(filterCache() - 3), indices); - auto unknowns = simd::toBitMask(xsimd::batch_bool( - simd::reinterpretBatch((cache & (kUnknown << 24)) << 1))); +#ifdef SVE_BITS + auto unknowns = simd::toBitMask( + simd::reinterpretBatch((cache & (kUnknown << 24)) << 1) != + xsimd::batch(0)); + auto passed = simd::toBitMask( + (simd::reinterpretBatch(cache) & + xsimd::batch(1)) != xsimd::batch(0)); +#else + auto unknowns = simd::toBitMask( + xsimd::batch_bool(simd::reinterpretBatch( + (cache & (kUnknown << 24)) << 1))); auto passed = simd::toBitMask( xsimd::batch_bool(simd::reinterpretBatch(cache))); +#endif if (UNLIKELY(unknowns)) { uint16_t bits = unknowns; // Ranges only over inputs that are in dictionary, the not in dictionary @@ -1159,7 +1176,7 @@ ColumnVisitor:: } auto result = DictionaryColumnVisitor( filter_, reader_, RowSet(rows_ + rowIndex_, numRows_), values_); - result.numValuesBias_ = numValuesBias_; + result.setNumValuesBias(numValuesBias_); return result; } @@ -1192,7 +1209,7 @@ class StringDictionaryColumnVisitor public: StringDictionaryColumnVisitor( - TFilter& filter, + const TFilter& filter, SelectiveColumnReader* reader, RowSet rows, ExtractValues values) @@ -1305,10 +1322,20 @@ class StringDictionaryColumnVisitor } else { cache = simd::gather(base, indices); } - auto unknowns = simd::toBitMask(xsimd::batch_bool( - simd::reinterpretBatch((cache & (kUnknown << 24)) << 1))); +#ifdef SVE_BITS + auto unknowns = simd::toBitMask( + simd::reinterpretBatch((cache & (kUnknown << 24)) << 1) != + xsimd::batch(0)); + auto passed = simd::toBitMask( + (simd::reinterpretBatch(cache) & + xsimd::batch(1)) != xsimd::batch(0)); +#else + auto unknowns = simd::toBitMask( + xsimd::batch_bool(simd::reinterpretBatch( + (cache & (kUnknown << 24)) << 1))); auto passed = simd::toBitMask( xsimd::batch_bool(simd::reinterpretBatch(cache))); +#endif if (UNLIKELY(unknowns)) { uint16_t bits = unknowns; while (bits) { @@ -1394,7 +1421,7 @@ class StringDictionaryColumnVisitor } } - folly::StringPiece valueInDictionary(int64_t index) { + StringView valueInDictionary(int64_t index) { auto stripeDictSize = DictSuper::state_.dictionary.numValues; if (index < stripeDictSize) { return reinterpret_cast( @@ -1412,7 +1439,7 @@ class DirectRleColumnVisitor public: DirectRleColumnVisitor( - TFilter& filter, + const TFilter& filter, SelectiveColumnReader* reader, RowSet rows, ExtractValues values) @@ -1547,12 +1574,15 @@ class StringColumnReadWithVisitorHelper { private: template void readHelper( - velox::common::Filter* filter, + const velox::common::Filter* filter, ExtractValues extractValues, F readWithVisitor) { readWithVisitor( - ColumnVisitor( - *static_cast(filter), &reader_, rows_, extractValues)); + ColumnVisitor( + *static_cast(filter), + &reader_, + rows_, + extractValues)); } template diff --git a/velox/dwio/common/DataBuffer.h b/velox/dwio/common/DataBuffer.h index ba84cb45a219..1f2964c1a61e 100644 --- a/velox/dwio/common/DataBuffer.h +++ b/velox/dwio/common/DataBuffer.h @@ -23,10 +23,7 @@ #include "velox/common/memory/Memory.h" #include "velox/dwio/common/exception/Exception.h" -namespace facebook { -namespace velox { -namespace dwio { -namespace common { +namespace facebook::velox::dwio::common { template >> class DataBuffer { @@ -34,8 +31,9 @@ class DataBuffer { explicit DataBuffer(velox::memory::MemoryPool& pool, uint64_t size = 0) : pool_(&pool), // Initial allocation uses calloc, to avoid memset. - buf_(reinterpret_cast( - pool_->allocateZeroFilled(1, sizeInBytes(size)))), + buf_( + reinterpret_cast( + pool_->allocateZeroFilled(1, sizeInBytes(size)))), size_(size), capacity_(size) { VELOX_CHECK(buf_ != nullptr || size_ == 0); @@ -233,7 +231,5 @@ class DataBuffer { // Maximum capacity of items of type T. uint64_t capacity_; }; -} // namespace common -} // namespace dwio -} // namespace velox -} // namespace facebook + +} // namespace facebook::velox::dwio::common diff --git a/velox/dwio/common/DataBufferHolder.cpp b/velox/dwio/common/DataBufferHolder.cpp index f6ff2757a352..0f5b246f062b 100644 --- a/velox/dwio/common/DataBufferHolder.cpp +++ b/velox/dwio/common/DataBufferHolder.cpp @@ -21,7 +21,7 @@ using facebook::velox::common::testutil::TestValue; namespace facebook::velox::dwio::common { -void DataBufferHolder::take(const std::vector& buffers) { +void DataBufferHolder::take(const std::vector& buffers) { // compute size uint64_t totalSize = 0; for (auto& buf : buffers) { @@ -38,7 +38,7 @@ void DataBufferHolder::take(const std::vector& buffers) { auto* data = buf.data(); for (auto& buffer : buffers) { const auto size = buffer.size(); - ::memcpy(data, buffer.begin(), size); + ::memcpy(data, buffer.cbegin(), size); data += size; } // If possibly, write content of the data to output immediately. Otherwise, diff --git a/velox/dwio/common/DataBufferHolder.h b/velox/dwio/common/DataBufferHolder.h index 1c291f13328b..91f4fe56007b 100644 --- a/velox/dwio/common/DataBufferHolder.h +++ b/velox/dwio/common/DataBufferHolder.h @@ -47,14 +47,14 @@ class DataBufferHolder { /// Takes content of the incoming data buffer. It is the caller's /// responsibility to resize the buffer (if required). - void take(const std::vector& buffers); + void take(const std::vector& buffers); - void take(folly::StringPiece buffer) { - take(std::vector{buffer}); + void take(std::string_view buffer) { + take(std::vector{buffer}); } void take(const dwio::common::DataBuffer& buffer) { - take(folly::StringPiece{buffer.data(), buffer.size()}); + take(std::string_view{buffer.data(), buffer.size()}); } std::vector>& getBuffers() { diff --git a/velox/dwio/common/DecoderUtil.h b/velox/dwio/common/DecoderUtil.h index a79f3bce6365..9527915e0998 100644 --- a/velox/dwio/common/DecoderUtil.h +++ b/velox/dwio/common/DecoderUtil.h @@ -162,7 +162,7 @@ void fixedWidthScan( dwio::common::SeekableInputStream& input, const char*& bufferStart, const char*& bufferEnd, - TFilter& filter, + const TFilter& filter, THook& hook) { constexpr int32_t kWidth = xsimd::batch::size; constexpr bool is16 = sizeof(T) == 2; @@ -473,7 +473,7 @@ void processFixedWidthRun( T* values, int32_t* filterHits, int32_t& numValues, - TFilter& filter, + const TFilter& filter, THook& hook) { constexpr int32_t kWidth = xsimd::batch::size; constexpr bool hasFilter = diff --git a/velox/dwio/common/DirectBufferedInput.cpp b/velox/dwio/common/DirectBufferedInput.cpp index 5b35f8b942d2..a7ebe5300bd0 100644 --- a/velox/dwio/common/DirectBufferedInput.cpp +++ b/velox/dwio/common/DirectBufferedInput.cpp @@ -51,17 +51,17 @@ std::unique_ptr DirectBufferedInput::enqueue( VELOX_CHECK_LE(region.offset + region.length, fileSize_); requests_.emplace_back(region, id); if (tracker_) { - tracker_->recordReference(id, region.length, fileNum_, groupId_); + tracker_->recordReference(id, region.length, fileNum_.id(), groupId_.id()); } auto stream = std::make_unique( this, ioStats_.get(), region, input_, - fileNum_, + fileNum_.id(), tracker_, id, - groupId_, + groupId_.id(), options_.loadQuantum()); requests_.back().stream = stream.get(); return stream; @@ -189,7 +189,7 @@ void DirectBufferedInput::readRegion( input_, ioStats_, fsStats_, - groupId_, + groupId_.id(), requests, pool_, options_.loadQuantum()); @@ -235,7 +235,7 @@ std::shared_ptr DirectBufferedInput::coalescedLoad( return streamToCoalescedLoad_.withWLock( [&](auto& loads) -> std::shared_ptr { auto it = loads.find(stream); - if (it == loads.end()) { + if (it == loads.cend()) { return nullptr; } auto load = std::move(it->second); @@ -254,7 +254,7 @@ std::unique_ptr DirectBufferedInput::read( ioStats_.get(), Region{offset, length}, input_, - fileNum_, + fileNum_.id(), nullptr, TrackingId(), 0, @@ -286,10 +286,11 @@ std::vector DirectCoalescedLoad::loadData(bool prefetch) { for (auto& request : requests_) { const auto& region = request.region; if (region.offset > lastEnd) { - buffers.push_back(folly::Range( - nullptr, - reinterpret_cast( - static_cast(region.offset - lastEnd)))); + buffers.push_back( + folly::Range( + nullptr, + reinterpret_cast( + static_cast(region.offset - lastEnd)))); overread += buffers.back().size(); } @@ -342,7 +343,7 @@ int32_t DirectCoalescedLoad::getData( requests_.begin(), requests_.end(), offset, [](auto& x, auto offset) { return x.region.offset < offset; }); - if (it == requests_.end() || it->region.offset != offset) { + if (it == requests_.cend() || it->region.offset != offset) { return 0; } data = std::move(it->data); diff --git a/velox/dwio/common/DirectBufferedInput.h b/velox/dwio/common/DirectBufferedInput.h index 4485f2c4cd81..ae1ff124ef73 100644 --- a/velox/dwio/common/DirectBufferedInput.h +++ b/velox/dwio/common/DirectBufferedInput.h @@ -71,9 +71,10 @@ class DirectCoalescedLoad : public cache::CoalescedLoad { pool_(pool) { VELOX_DCHECK_NOT_NULL(pool_); VELOX_DCHECK( - std::is_sorted(requests.begin(), requests.end(), [](auto* x, auto* y) { - return x->region.offset < y->region.offset; - })); + std::is_sorted( + requests.cbegin(), requests.cend(), [](auto* x, auto* y) { + return x->region.offset < y->region.offset; + })); requests_.reserve(requests.size()); for (auto i = 0; i < requests.size(); ++i) { requests_.push_back(std::move(*requests[i])); @@ -117,22 +118,26 @@ class DirectBufferedInput : public BufferedInput { DirectBufferedInput( std::shared_ptr readFile, const MetricsLogPtr& metricsLog, - uint64_t fileNum, + StringIdLease fileNum, std::shared_ptr tracker, - uint64_t groupId, + StringIdLease groupId, std::shared_ptr ioStats, std::shared_ptr fsStats, folly::Executor* executor, - const io::ReaderOptions& readerOptions) + const io::ReaderOptions& readerOptions, + folly::F14FastMap fileReadOps = {}) : BufferedInput( std::move(readFile), readerOptions.memoryPool(), metricsLog, ioStats.get(), - fsStats.get()), - fileNum_(fileNum), + fsStats.get(), + kMaxMergeDistance, + std::nullopt, + std::move(fileReadOps)), + fileNum_(std::move(fileNum)), tracker_(std::move(tracker)), - groupId_(groupId), + groupId_(std::move(groupId)), ioStats_(std::move(ioStats)), fsStats_(std::move(fsStats)), executor_(executor), @@ -166,7 +171,7 @@ class DirectBufferedInput : public BufferedInput { void setNumStripes(int32_t numStripes) override { auto* stats = tracker_->fileGroupStats(); if (stats) { - stats->recordFile(fileNum_, groupId_, numStripes); + stats->recordFile(fileNum_.id(), groupId_.id(), numStripes); } } @@ -208,17 +213,17 @@ class DirectBufferedInput : public BufferedInput { /// Constructor used by clone(). DirectBufferedInput( std::shared_ptr input, - uint64_t fileNum, + StringIdLease fileNum, std::shared_ptr tracker, - uint64_t groupId, + StringIdLease groupId, std::shared_ptr ioStats, std::shared_ptr fsStats, folly::Executor* executor, const io::ReaderOptions& readerOptions) : BufferedInput(std::move(input), readerOptions.memoryPool()), - fileNum_(fileNum), + fileNum_(std::move(fileNum)), tracker_(std::move(tracker)), - groupId_(groupId), + groupId_(std::move(groupId)), ioStats_(std::move(ioStats)), fsStats_(std::move(fsStats)), executor_(executor), @@ -261,9 +266,9 @@ class DirectBufferedInput : public BufferedInput { } }; - const uint64_t fileNum_; + const StringIdLease fileNum_; const std::shared_ptr tracker_; - const uint64_t groupId_; + const StringIdLease groupId_; const std::shared_ptr ioStats_; const std::shared_ptr fsStats_; folly::Executor* const executor_; diff --git a/velox/dwio/common/DirectInputStream.cpp b/velox/dwio/common/DirectInputStream.cpp index 68173b40f259..8840768854c0 100644 --- a/velox/dwio/common/DirectInputStream.cpp +++ b/velox/dwio/common/DirectInputStream.cpp @@ -91,8 +91,8 @@ bool DirectInputStream::SkipInt64(int64_t count) { return false; } -google::protobuf::int64 DirectInputStream::ByteCount() const { - return static_cast(offsetInRegion_); +int64_t DirectInputStream::ByteCount() const { + return static_cast(offsetInRegion_); } void DirectInputStream::seekToPosition(PositionProvider& seekPosition) { diff --git a/velox/dwio/common/DirectInputStream.h b/velox/dwio/common/DirectInputStream.h index 3d75b4459568..45fb465f37ab 100644 --- a/velox/dwio/common/DirectInputStream.h +++ b/velox/dwio/common/DirectInputStream.h @@ -44,7 +44,7 @@ class DirectInputStream : public SeekableInputStream { bool Next(const void** data, int* size) override; void BackUp(int count) override; bool SkipInt64(int64_t count) override; - google::protobuf::int64 ByteCount() const override; + int64_t ByteCount() const override; void seekToPosition(PositionProvider& position) override; std::string getName() const override; diff --git a/velox/dwio/common/ErrorTolerance.h b/velox/dwio/common/ErrorTolerance.h index db186084cea1..4f2aafd6111f 100644 --- a/velox/dwio/common/ErrorTolerance.h +++ b/velox/dwio/common/ErrorTolerance.h @@ -18,10 +18,7 @@ #include "velox/dwio/common/exception/Exception.h" -namespace facebook { -namespace velox { -namespace dwio { -namespace common { +namespace facebook::velox::dwio::common { /** * Error tolerance level for readers @@ -63,7 +60,4 @@ struct ErrorTolerance { } }; -} // namespace common -} // namespace dwio -} // namespace velox -} // namespace facebook +} // namespace facebook::velox::dwio::common diff --git a/velox/dwio/common/FileSink.cpp b/velox/dwio/common/FileSink.cpp index 3dcc44186182..6654dfb64070 100644 --- a/velox/dwio/common/FileSink.cpp +++ b/velox/dwio/common/FileSink.cpp @@ -26,8 +26,8 @@ namespace facebook::velox::dwio::common { namespace { -constexpr std::string_view kFileScheme("file:"); -constexpr std::string_view kFileSep("/"); +constexpr std::string_view kFileScheme{"file:"}; +constexpr char kFileSep{'/'}; std::vector& factories() { static std::vector factories; @@ -37,10 +37,10 @@ std::vector& factories() { std::unique_ptr localFileSink( const std::string& filePath, const FileSink::Options& options) { - if (filePath.find(kFileScheme) == 0) { + if (filePath.starts_with(kFileScheme)) { return std::make_unique(filePath.substr(5), options); } - if (filePath.find(kFileSep) == 0) { + if (filePath.starts_with(kFileSep)) { return std::make_unique(filePath, options); } return nullptr; diff --git a/velox/dwio/common/FileSink.h b/velox/dwio/common/FileSink.h index 6fe559790279..914d06055f1a 100644 --- a/velox/dwio/common/FileSink.h +++ b/velox/dwio/common/FileSink.h @@ -46,6 +46,7 @@ class FileSink : public Closeable { memory::MemoryPool* pool{nullptr}; MetricsLogPtr metricLogger{MetricsLog::voidLog()}; IoStatistics* stats{nullptr}; + filesystems::File::IoStats* fileSystemStats{nullptr}; }; FileSink(std::string name, const Options& options) @@ -54,6 +55,7 @@ class FileSink : public Closeable { pool_(options.pool), metricLogger_{options.metricLogger}, stats_{options.stats}, + fileSystemStats_{options.fileSystemStats}, size_{0} {} ~FileSink() override { @@ -117,6 +119,7 @@ class FileSink : public Closeable { memory::MemoryPool* const pool_; const MetricsLogPtr metricLogger_; IoStatistics* const stats_; + filesystems::File::IoStats* const fileSystemStats_; uint64_t size_; }; diff --git a/velox/dwio/common/FlatMapHelper.cpp b/velox/dwio/common/FlatMapHelper.cpp index 5d48887af475..e392bf60a65f 100644 --- a/velox/dwio/common/FlatMapHelper.cpp +++ b/velox/dwio/common/FlatMapHelper.cpp @@ -21,11 +21,20 @@ namespace facebook::velox::dwio::common::flatmap { namespace detail { -void reset(VectorPtr& vector, vector_size_t size, bool hasNulls) { +void reset( + VectorPtr& vector, + VectorEncoding::Simple desiredEncoding, + vector_size_t size, + bool hasNulls) { if (!vector) { return; } + if (vector->encoding() != desiredEncoding) { + vector.reset(); + return; + } + if (vector.use_count() > 1) { vector.reset(); return; @@ -39,6 +48,9 @@ void reset(VectorPtr& vector, vector_size_t size, bool hasNulls) { } } vector->resize(size); + // Reside BaseVector::length_ as it will be updated in the subsequent copy() + // calls. + vector->BaseVector::resize(0); } void initializeStringVector( @@ -162,7 +174,7 @@ void initializeVectorImpl( } } - detail::reset(vector, size, hasNulls); + detail::reset(vector, VectorEncoding::Simple::ARRAY, size, hasNulls); VectorPtr origElementsVector; if (vector) { auto& arrayVector = dynamic_cast(*vector); @@ -226,7 +238,7 @@ void initializeMapVector( size = sizeOverride.value(); } - detail::reset(vector, size, hasNulls); + detail::reset(vector, VectorEncoding::Simple::MAP, size, hasNulls); VectorPtr origKeysVector; VectorPtr origValuesVector; if (vector) { @@ -298,7 +310,7 @@ void initializeVectorImpl( } } - detail::reset(vector, size, hasNulls); + detail::reset(vector, VectorEncoding::Simple::ROW, size, hasNulls); std::vector origChildren; if (vector) { auto& rowVector = dynamic_cast(*vector); @@ -359,8 +371,9 @@ vector_size_t copyNulls( vector_size_t nulls = 0; // it's assumed that initVector is called before calling this method to // properly allocate/clear nulls buffer. So we only need to check against - // target vector here. - target.resize(targetIndex + count, false); + // target vector here. We only call BaseVector::resize here to make sure + // BaseVector::size() is only updated. + target.BaseVector::resize(targetIndex + count, false); if (target.mayHaveNulls()) { auto tgtNulls = const_cast(target.rawNulls()); if (source.isConstantEncoding()) { @@ -485,7 +498,10 @@ vector_size_t copyOffsets( vector_size_t sourceIndex, vector_size_t count, vector_size_t& childOffset) { - target.resize(targetIndex + count); + // Its expected that initVector is called before calling this method so the + // offsets and sizes buffers are properly allocated. We only call + // BaseVector::resize here to make sure BaseVector::size() is only updated. + target.BaseVector::resize(targetIndex + count); auto tgtOffsets = const_cast(target.rawOffsets()); auto tgtSizes = const_cast(target.rawSizes()); auto srcSizes = source.rawSizes(); @@ -701,8 +717,9 @@ bool copyNull( vector_size_t sourceIndex) { // it's assumed that initVector is called before calling this method to // properly allocate/clear nulls buffer. So we only need to check against - // target vector here. - target.resize(targetIndex + 1, false); + // target vector here. We only call BaseVector::resize here to make sure + // BaseVector::size() is only updated. + target.BaseVector::resize(targetIndex + 1, false); if (target.mayHaveNulls()) { bool srcIsNull = (source.isConstantEncoding() || @@ -792,7 +809,10 @@ vector_size_t copyOffset( const T& source, vector_size_t sourceIndex, vector_size_t& childOffset) { - target.resize(targetIndex + 1); + // Its expected that initVector is called before calling this method so the + // offsets and sizes buffers are properly allocated. We only call + // BaseVector::resize here to make sure BaseVector::size() is only updated. + target.BaseVector::resize(targetIndex + 1); auto tgtSizes = const_cast(target.rawSizes()); childOffset = nextChildOffset(target, targetIndex); const_cast(target.rawOffsets())[targetIndex] = childOffset; diff --git a/velox/dwio/common/FlatMapHelper.h b/velox/dwio/common/FlatMapHelper.h index ac97b5c3b134..404e7891a089 100644 --- a/velox/dwio/common/FlatMapHelper.h +++ b/velox/dwio/common/FlatMapHelper.h @@ -24,7 +24,11 @@ namespace facebook::velox::dwio::common::flatmap { namespace detail { // Reset vector with the desired size/hasNulls properties -void reset(VectorPtr& vector, vector_size_t size, bool hasNulls); +void reset( + VectorPtr& vector, + VectorEncoding::Simple desiredEncoding, + vector_size_t size, + bool hasNulls); // Reset vector smart pointer if any of the buffers is not single referenced. template @@ -39,6 +43,14 @@ void resetIfNotWritable(VectorPtr& vector, const T&... buffer) { } // namespace detail +// Output type of flat map column reader, indicates the in memory representation +// the flat map should be read into. +enum class FlatMapOutput : uint8_t { + kMap = 0, // MapVector + kStruct = 1, // RowVector + kFlatMap = 2, // FlatMapVector +}; + // Struct for keeping track flatmap key stream metrics. // Used by keySelectionCallback_ in FlatMapColumnReader struct FlatMapKeySelectionStats { @@ -55,7 +67,7 @@ void initializeFlatVector( vector_size_t size, bool hasNulls, std::vector&& stringBuffers = {}) { - detail::reset(vector, size, hasNulls); + detail::reset(vector, VectorEncoding::Simple::FLAT, size, hasNulls); if (vector) { auto& flatVector = dynamic_cast&>(*vector); detail::resetIfNotWritable(vector, flatVector.nulls(), flatVector.values()); @@ -122,6 +134,7 @@ class KeyValue { const T& get() const { return value_; } + std::size_t hash() const { return h_; } @@ -223,14 +236,14 @@ KeyPredicate prepareKeyPredicate(std::string_view expression) { // You cannot mix allow key and reject key. VELOX_CHECK( modes.empty() || - std::all_of(modes.begin(), modes.end(), [&modes](const auto& v) { + std::all_of(modes.cbegin(), modes.cend(), [&modes](const auto& v) { return v == modes.front(); })); auto mode = modes.empty() ? KeyProjectionMode::ALLOW : modes.front(); return KeyPredicate( - mode, typename KeyPredicate::Lookup(keys.begin(), keys.end())); + mode, typename KeyPredicate::Lookup(keys.cbegin(), keys.cend())); } } // namespace facebook::velox::dwio::common::flatmap diff --git a/velox/dwio/common/FormatData.h b/velox/dwio/common/FormatData.h index 4e8a5548fb06..d045f36cb4a2 100644 --- a/velox/dwio/common/FormatData.h +++ b/velox/dwio/common/FormatData.h @@ -131,8 +131,8 @@ class FormatData { /// Base class for format-specific reader initialization arguments. class FormatParams { public: - explicit FormatParams(memory::MemoryPool& pool, ColumnReaderStatistics& stats) - : pool_(pool), stats_(stats) {} + FormatParams(memory::MemoryPool& pool, ColumnReaderStatistics& stats) + : pool_(&pool), stats_(&stats) {} virtual ~FormatParams() = default; @@ -143,16 +143,16 @@ class FormatParams { const velox::common::ScanSpec& scanSpec) = 0; memory::MemoryPool& pool() { - return pool_; + return *pool_; } ColumnReaderStatistics& runtimeStatistics() { - return stats_; + return *stats_; } private: - memory::MemoryPool& pool_; - ColumnReaderStatistics& stats_; + memory::MemoryPool* const pool_; + ColumnReaderStatistics* const stats_; }; } // namespace facebook::velox::dwio::common diff --git a/velox/dwio/common/InputStream.cpp b/velox/dwio/common/InputStream.cpp index cc20a5fc55e5..c25a7b0be772 100644 --- a/velox/dwio/common/InputStream.cpp +++ b/velox/dwio/common/InputStream.cpp @@ -64,8 +64,10 @@ ReadFileInputStream::ReadFileInputStream( std::shared_ptr readFile, const MetricsLogPtr& metricsLog, IoStatistics* stats, - filesystems::File::IoStats* fsStats) + filesystems::File::IoStats* fsStats, + folly::F14FastMap fileReadOps) : InputStream(readFile->getName(), metricsLog, stats, fsStats), + fileStorageContext_(fsStats, std::move(fileReadOps)), readFile_(std::move(readFile)) {} void ReadFileInputStream::read( @@ -79,7 +81,7 @@ void ReadFileInputStream::read( std::string_view readData; { MicrosecondTimer timer(&readTimeUs); - readData = readFile_->pread(offset, length, buf, fsStats_); + readData = readFile_->pread(offset, length, buf, fileStorageContext_); } if (stats_) { stats_->incRawBytesRead(length); @@ -102,7 +104,7 @@ void ReadFileInputStream::read( LogType logType) { const int64_t bufferSize = totalBufferSize(buffers); logRead(offset, bufferSize, logType); - const auto size = readFile_->preadv(offset, buffers, fsStats_); + const auto size = readFile_->preadv(offset, buffers, fileStorageContext_); VELOX_CHECK_EQ( size, bufferSize, @@ -119,7 +121,7 @@ folly::SemiFuture ReadFileInputStream::readAsync( LogType logType) { const int64_t bufferSize = totalBufferSize(buffers); logRead(offset, bufferSize, logType); - return readFile_->preadvAsync(offset, buffers, fsStats_); + return readFile_->preadvAsync(offset, buffers, fileStorageContext_); } bool ReadFileInputStream::hasReadAsync() const { @@ -138,7 +140,7 @@ void ReadFileInputStream::vread( [&](size_t acc, const auto& r) { return acc + r.length; }); logRead(regions[0].offset, length, purpose); auto readStartMicros = getCurrentTimeMicro(); - readFile_->preadv(regions, iobufs, fsStats_); + readFile_->preadv(regions, iobufs, fileStorageContext_); if (stats_) { stats_->incRawBytesRead(length); stats_->incTotalScanTime((getCurrentTimeMicro() - readStartMicros) * 1000); diff --git a/velox/dwio/common/InputStream.h b/velox/dwio/common/InputStream.h index 30210cb31a96..34dc948550c0 100644 --- a/velox/dwio/common/InputStream.h +++ b/velox/dwio/common/InputStream.h @@ -16,10 +16,6 @@ #pragma once -#include -#include -#include -#include #include #include #include @@ -30,6 +26,7 @@ #include #include +#include #include "velox/common/file/File.h" #include "velox/common/file/Region.h" #include "velox/common/io/IoStatistics.h" @@ -147,7 +144,8 @@ class ReadFileInputStream final : public InputStream { std::shared_ptr, const MetricsLogPtr& metricsLog = MetricsLog::voidLog(), IoStatistics* stats = nullptr, - filesystems::File::IoStats* fsStats = nullptr); + filesystems::File::IoStats* fsStats = nullptr, + folly::F14FastMap fileReadOps = {}); ~ReadFileInputStream() override = default; @@ -183,6 +181,7 @@ class ReadFileInputStream final : public InputStream { } private: + FileStorageContext fileStorageContext_; std::shared_ptr readFile_; }; diff --git a/velox/dwio/common/MeasureTime.h b/velox/dwio/common/MeasureTime.h index c4eaf1281f37..1b7de305a032 100644 --- a/velox/dwio/common/MeasureTime.h +++ b/velox/dwio/common/MeasureTime.h @@ -20,10 +20,7 @@ #include #include -namespace facebook { -namespace velox { -namespace dwio { -namespace common { +namespace facebook::velox::dwio::common { class MeasureTime { public: @@ -61,7 +58,4 @@ inline std::optional measureTimeIfCallback( return std::nullopt; } -} // namespace common -} // namespace dwio -} // namespace velox -} // namespace facebook +} // namespace facebook::velox::dwio::common diff --git a/velox/dwio/common/MetadataFilter.cpp b/velox/dwio/common/MetadataFilter.cpp index 374a2d861082..88e3a5481e28 100644 --- a/velox/dwio/common/MetadataFilter.cpp +++ b/velox/dwio/common/MetadataFilter.cpp @@ -18,6 +18,7 @@ #include #include "velox/dwio/common/ScanSpec.h" +#include "velox/expression/ExprConstants.h" #include "velox/expression/ExprToSubfieldFilter.h" namespace facebook::velox::common { @@ -67,86 +68,118 @@ class MetadataFilter::LeafNode : public Node { std::unique_ptr filter_; }; -struct MetadataFilter::AndNode : Node { +struct MetadataFilter::ConditionNode : Node { static std::unique_ptr create( - std::unique_ptr lhs, - std::unique_ptr rhs) { - if (!lhs) { - return rhs; - } - if (!rhs) { - return lhs; + bool conjuction, + std::vector> args); + + static std::unique_ptr fromExpression( + const std::vector& inputs, + core::ExpressionEvaluator* evaluator, + bool conjunction, + bool negated) { + conjunction = negated ? !conjunction : conjunction; + std::vector> args; + args.reserve(inputs.size()); + for (const auto& input : inputs) { + auto node = Node::fromExpression(*input, evaluator, negated); + if (node) { + args.push_back(std::move(node)); + } else if (!conjunction) { + return nullptr; + } } - return std::make_unique(std::move(lhs), std::move(rhs)); + return create(conjunction, std::move(args)); } - AndNode(std::unique_ptr lhs, std::unique_ptr rhs) - : lhs_(std::move(lhs)), rhs_(std::move(rhs)) {} + explicit ConditionNode(std::vector> args) + : args_{std::move(args)} {} - void addToScanSpec(ScanSpec& scanSpec) const override { - lhs_->addToScanSpec(scanSpec); - rhs_->addToScanSpec(scanSpec); - } - - uint64_t* eval(LeafResults& leafResults, int size) const override { - auto* l = lhs_->eval(leafResults, size); - auto* r = rhs_->eval(leafResults, size); - if (!l) { - return r; - } - if (!r) { - return l; + void addToScanSpec(ScanSpec& scanSpec) const final { + for (const auto& arg : args_) { + arg->addToScanSpec(scanSpec); } - bits::orBits(l, r, 0, size); - return l; } - std::string toString() const override { - return "and(" + lhs_->toString() + "," + rhs_->toString() + ")"; + protected: + std::string ToStringImpl(std::string_view prefix) const { + std::string result{prefix}; + for (size_t i = 0; i < args_.size(); ++i) { + if (i != 0) { + result += ","; + } + result += args_[i]->toString(); + } + result += ")"; + return result; } - private: - std::unique_ptr lhs_; - std::unique_ptr rhs_; + std::vector> args_; }; -struct MetadataFilter::OrNode : Node { - static std::unique_ptr create( - std::unique_ptr lhs, - std::unique_ptr rhs) { - if (!lhs || !rhs) { - return nullptr; +struct MetadataFilter::AndNode final : ConditionNode { + using ConditionNode::ConditionNode; + + uint64_t* eval(LeafResults& leafResults, int size) const final { + uint64_t* result = nullptr; + for (const auto& arg : args_) { + auto* a = arg->eval(leafResults, size); + if (!a) { + continue; + } + if (!result) { + result = a; + } else { + bits::orBits(result, a, 0, size); + } } - return std::make_unique(std::move(lhs), std::move(rhs)); + return result; } - OrNode(std::unique_ptr lhs, std::unique_ptr rhs) - : lhs_(std::move(lhs)), rhs_(std::move(rhs)) {} - - void addToScanSpec(ScanSpec& scanSpec) const override { - lhs_->addToScanSpec(scanSpec); - rhs_->addToScanSpec(scanSpec); + std::string toString() const final { + return ToStringImpl("and("); } +}; - uint64_t* eval(LeafResults& leafResults, int size) const override { - auto* l = lhs_->eval(leafResults, size); - auto* r = rhs_->eval(leafResults, size); - if (!l || !r) { - return nullptr; +struct MetadataFilter::OrNode final : ConditionNode { + using ConditionNode::ConditionNode; + + uint64_t* eval(LeafResults& leafResults, int size) const final { + uint64_t* result = nullptr; + for (const auto& arg : args_) { + auto* a = arg->eval(leafResults, size); + if (!a) { + return nullptr; + } + if (!result) { + result = a; + } else { + bits::andBits(result, a, 0, size); + } } - bits::andBits(l, r, 0, size); - return l; + return result; } - std::string toString() const override { - return "or(" + lhs_->toString() + "," + rhs_->toString() + ")"; + std::string toString() const final { + return ToStringImpl("or("); } - - private: - std::unique_ptr lhs_; - std::unique_ptr rhs_; }; +std::unique_ptr MetadataFilter::ConditionNode::create( + bool conjunction, + std::vector> args) { + if (args.empty()) { + return nullptr; + } + if (args.size() == 1) { + return std::move(args[0]); + } + if (conjunction) { + return std::make_unique(std::move(args)); + } + return std::make_unique(std::move(args)); +} + namespace { const core::CallTypedExpr* asCall(const core::ITypedExpr* expr) { @@ -163,29 +196,26 @@ std::unique_ptr MetadataFilter::Node::fromExpression( if (!call) { return nullptr; } - if (call->name() == "and") { - auto lhs = fromExpression(*call->inputs()[0], evaluator, negated); - auto rhs = fromExpression(*call->inputs()[1], evaluator, negated); - return negated ? OrNode::create(std::move(lhs), std::move(rhs)) - : AndNode::create(std::move(lhs), std::move(rhs)); + if (call->name() == expression::kAnd) { + return ConditionNode::fromExpression( + call->inputs(), evaluator, true, negated); } - if (call->name() == "or") { - auto lhs = fromExpression(*call->inputs()[0], evaluator, negated); - auto rhs = fromExpression(*call->inputs()[1], evaluator, negated); - return negated ? AndNode::create(std::move(lhs), std::move(rhs)) - : OrNode::create(std::move(lhs), std::move(rhs)); + if (call->name() == expression::kOr) { + return ConditionNode::fromExpression( + call->inputs(), evaluator, false, negated); } if (call->name() == "not") { return fromExpression(*call->inputs()[0], evaluator, !negated); } try { - Subfield subfield; - auto filter = + auto subfieldAndFilter = exec::ExprToSubfieldFilterParser::getInstance() - ->leafCallToSubfieldFilter(*call, subfield, evaluator, negated); - if (!filter) { + ->leafCallToSubfieldFilter(*call, evaluator, negated); + if (!subfieldAndFilter.has_value()) { return nullptr; } + + auto& [subfield, filter] = subfieldAndFilter.value(); VELOX_CHECK( subfield.valid(), "Invalid subfield from expression: {}", diff --git a/velox/dwio/common/MetadataFilter.h b/velox/dwio/common/MetadataFilter.h index 62b604b14407..d626bbdd9675 100644 --- a/velox/dwio/common/MetadataFilter.h +++ b/velox/dwio/common/MetadataFilter.h @@ -50,6 +50,7 @@ class MetadataFilter { private: struct Node; + struct ConditionNode; struct AndNode; struct OrNode; diff --git a/velox/dwio/common/MetricsLog.h b/velox/dwio/common/MetricsLog.h index 555e92f24d63..db54c36821e6 100644 --- a/velox/dwio/common/MetricsLog.h +++ b/velox/dwio/common/MetricsLog.h @@ -19,15 +19,12 @@ #include "velox/dwio/common/FilterNode.h" #include "velox/dwio/common/exception/Exception.h" -namespace facebook { -namespace velox { -namespace dwio { -namespace common { +namespace facebook::velox::dwio::common { class MetricsLog { public: static constexpr std::string_view LIB_VERSION_STRING{"1.1"}; - static constexpr folly::StringPiece WRITE_OPERATION{"WRITE"}; + static constexpr std::string_view WRITE_OPERATION{"WRITE"}; enum class MetricsType { HEADER, @@ -128,8 +125,8 @@ class MetricsLog { virtual void logFileClose(const FileCloseMetrics& /* metrics */) const {} static std::shared_ptr voidLog() { - static std::shared_ptr log{new MetricsLog("")}; - return log; + static const MetricsLog kInstance{{}}; + return {std::shared_ptr{}, &kInstance}; } protected: @@ -178,7 +175,4 @@ void registerMetricsLogFactory(std::shared_ptr factory); DwioMetricsLogFactory& getMetricsLogFactory(); -} // namespace common -} // namespace dwio -} // namespace velox -} // namespace facebook +} // namespace facebook::velox::dwio::common diff --git a/velox/dwio/common/OnDemandUnitLoader.cpp b/velox/dwio/common/OnDemandUnitLoader.cpp index d4ef4f0a5ef2..6a5616a31e53 100644 --- a/velox/dwio/common/OnDemandUnitLoader.cpp +++ b/velox/dwio/common/OnDemandUnitLoader.cpp @@ -15,12 +15,12 @@ */ #include "velox/dwio/common/OnDemandUnitLoader.h" +#include "velox/common/time/Timer.h" #include #include "velox/common/base/Exceptions.h" #include "velox/dwio/common/MeasureTime.h" -#include "velox/dwio/common/UnitLoaderTools.h" using facebook::velox::dwio::common::measureTimeIfCallback; @@ -42,6 +42,7 @@ class OnDemandUnitLoader : public UnitLoader { LoadUnit& getLoadedUnit(uint32_t unit) override { VELOX_CHECK_LT(unit, loadUnits_.size(), "Unit out of range"); + processedUnits_.insert(unit); if (loadedUnit_.has_value()) { if (loadedUnit_.value() == unit) { return *loadUnits_[unit]; @@ -51,11 +52,14 @@ class OnDemandUnitLoader : public UnitLoader { loadedUnit_.reset(); } + uint64_t unitLoadNanos{0}; { + NanosecondTimer timer{&unitLoadNanos}; auto measure = measureTimeIfCallback(blockedOnIoCallback_); loadUnits_[unit]->load(); } loadedUnit_ = unit; + unitLoadNanos_ += unitLoadNanos; return *loadUnits_[unit]; } @@ -73,11 +77,28 @@ class OnDemandUnitLoader : public UnitLoader { rowOffsetInUnit, loadUnits_[unit]->getNumRows(), "Row out of range"); } + UnitLoaderStats stats() override { + UnitLoaderStats stats; + stats.addCounter("processedUnits", RuntimeCounter(processedUnits_.size())); + stats.addCounter( + "unitLoadNanos", + RuntimeCounter( + unitLoadNanos_ > std::numeric_limits::max() + ? std::numeric_limits::max() + : unitLoadNanos_, + RuntimeCounter::Unit::kNanos)); + return stats; + } + private: const std::vector> loadUnits_; const std::function blockedOnIoCallback_; std::optional loadedUnit_; + + // Stats + std::unordered_set processedUnits_; + uint64_t unitLoadNanos_{0}; }; } // namespace diff --git a/velox/dwio/common/Options.h b/velox/dwio/common/Options.h index 0526cfbd9dd7..ccfd6e78eb62 100644 --- a/velox/dwio/common/Options.h +++ b/velox/dwio/common/Options.h @@ -17,7 +17,8 @@ #pragma once #include -#include +#include +#include #include #include "velox/common/base/RandomUtil.h" @@ -51,6 +52,7 @@ enum class FileFormat { NIMBLE = 8, ORC = 9, SST = 10, // rocksdb sst format + FLUX = 11, }; FileFormat toFileFormat(std::string_view s); @@ -72,7 +74,34 @@ enum class SerDeSeparator { class SerDeOptions { public: + /// The following members control how data is separated in TEXT format files: + /// + /// - 'separators': An array of separator characters used to delimit columns + /// and nested data. + /// - 'separators[0]' defines the delimiter that separates top-level + /// columns. + /// - 'separators[1 to depth_-1]' defines the delimiters that separate + /// nested data within a ComplexType column. + /// - 'newLine': The character used to separate rows in the file. + /// + /// Suppose we have a schema: ROW(MAP(VARCHAR(), ARRAY(BIGINT())), BOOLEAN()) + /// With the following configuration: + /// - separators = [',', '@', ':', '#] + /// - newLine = '\n' + /// - nullString = "NULL" + /// + /// With the following data to be written: + /// - row1: {key1:[10, 20, 30], key2:[40, 50, 60]}, true + /// - row2: {key3:[100, 2, 30], key4:[80, 40, 45]}, true + /// + /// A sample text file with the 2 rows of data above would look like this: + /// key1:10#20#30@key2:40#50#60,true\n + /// key3:100#2#30@key4:80#40#45,true\n + std::array separators; + uint8_t newLine; + + /// Null values are represented by 'nullString' std::string nullString; bool lastColumnTakesRest; uint8_t escapeChar; @@ -88,8 +117,10 @@ class SerDeOptions { uint8_t collectionDelim = '\2', uint8_t mapKeyDelim = '\3', uint8_t escape = '\\', - bool isEscapedFlag = false) + bool isEscapedFlag = false, + uint8_t newLine = '\n') : separators{{fieldDelim, collectionDelim, mapKeyDelim, 4, 5, 6, 7, 8}}, + newLine(newLine), nullString("\\N"), lastColumnTakesRest(false), escapeChar(escape), @@ -253,6 +284,22 @@ class RowReaderOptions { scanSpec_ = std::move(scanSpec); } + folly::Executor* ioExecutor() const { + return ioExecutor_; + } + + void setIOExecutor(folly::Executor* const ioExecutor) { + ioExecutor_ = ioExecutor; + } + + const size_t parallelUnitLoadCount() const { + return parallelUnitLoadCount_; + } + + void setParallelUnitLoadCount(size_t parallelUnitLoadCount) { + parallelUnitLoadCount_ = parallelUnitLoadCount; + } + const std::shared_ptr& metadataFilter() const { return metadataFilter_; } @@ -267,8 +314,8 @@ class RowReaderOptions { flatmapNodeIdsAsStruct) { VELOX_CHECK( std::all_of( - flatmapNodeIdsAsStruct.begin(), - flatmapNodeIdsAsStruct.end(), + flatmapNodeIdsAsStruct.cbegin(), + flatmapNodeIdsAsStruct.cend(), [](const auto& kv) { return !kv.second.empty(); }), "To use struct encoding for flatmap, keys to project must be specified"); flatmapNodeIdAsStruct_ = std::move(flatmapNodeIdsAsStruct); @@ -279,6 +326,10 @@ class RowReaderOptions { return flatmapNodeIdAsStruct_; } + void setPreserveFlatMapsInMemory(bool preserveFlatMapsInMemory) { + preserveFlatMapsInMemory_ = preserveFlatMapsInMemory; + } + void setDecodingExecutor(std::shared_ptr executor) { decodingExecutor_ = executor; } @@ -348,6 +399,10 @@ class RowReaderOptions { return skipRows_; } + bool preserveFlatMapsInMemory() const { + return preserveFlatMapsInMemory_; + } + void setUnitLoaderFactory( std::shared_ptr unitLoaderFactory) { unitLoaderFactory_ = std::move(unitLoaderFactory); @@ -391,19 +446,43 @@ class RowReaderOptions { serdeParameters_ = std::move(serdeParameters); } + bool trackRowSize() const { + return trackRowSize_; + } + + void setTrackRowSize(bool trackRowSize) { + trackRowSize_ = trackRowSize; + } + + bool passStringBuffersFromDecoder() const { + return passStringBuffersFromDecoder_; + } + + void setPassStringBuffersFromDecoder(bool passStringBuffersFromDecoder) { + passStringBuffersFromDecoder_ = passStringBuffersFromDecoder; + } + private: uint64_t dataStart_; uint64_t dataLength_; bool preloadStripe_; bool projectSelectedType_; bool returnFlatVector_ = false; + size_t parallelUnitLoadCount_ = 0; ErrorTolerance errorTolerance_; std::shared_ptr selector_; RowTypePtr requestedType_; std::shared_ptr scanSpec_{nullptr}; std::shared_ptr metadataFilter_; + // Node id for map column to a list of keys to be projected as a struct. std::unordered_map> flatmapNodeIdAsStruct_; + + // Whether to generate FlatMapVectors when reading flat maps from the file. By + // default, converts flat maps in the file to MapVectors. + bool preserveFlatMapsInMemory_ = false; + // Optional io executor to enable parallel unit loader. + folly::Executor* ioExecutor_{nullptr}; // Optional executors to enable internal reader parallelism. // 'decodingExecutor' allow parallelising the vector decoding process. // 'ioExecutor' enables parallelism when performing file system read @@ -411,6 +490,7 @@ class RowReaderOptions { std::shared_ptr decodingExecutor_; size_t decodingParallelismFactor_{0}; std::optional rowNumberColumnInfo_{std::nullopt}; + // Parameters that are provided as the physical storage properties. std::unordered_map storageParameters_{}; // Parameters that are provided as the serialization/deserialization @@ -441,6 +521,8 @@ class RowReaderOptions { TimestampPrecision timestampPrecision_ = TimestampPrecision::kMilliseconds; std::shared_ptr formatSpecificOptions_; + bool trackRowSize_{false}; + bool passStringBuffersFromDecoder_{true}; }; /// Options for creating a Reader. @@ -463,6 +545,13 @@ class ReaderOptions : public io::ReaderOptions { return *this; } + /// Sets the property bag. + ReaderOptions& setProperties( + std::unordered_map properties) { + properties_ = std::move(properties); + return *this; + } + /// Sets the current table schema of the file (a Type tree). This could be /// different from the actual schema in file if schema evolution happened. /// For "dwrf" format, a default schema is derived from the file. For "rc" @@ -536,6 +625,11 @@ class ReaderOptions : public io::ReaderOptions { return fileFormat_; } + /// Gets the property bag. + const std::unordered_map& properties() const { + return properties_; + } + /// Gets the file schema. const std::shared_ptr& fileSchema() const { return fileSchema_; @@ -626,6 +720,7 @@ class ReaderOptions : public io::ReaderOptions { FileFormat fileFormat_; RowTypePtr fileSchema_; SerDeOptions serDeOptions_; + std::unordered_map properties_{}; std::shared_ptr decrypterFactory_; uint64_t footerEstimatedSize_{kDefaultFooterEstimatedSize}; uint64_t filePreloadThreshold_{kDefaultFilePreloadThreshold}; diff --git a/velox/dwio/common/OutputStream.h b/velox/dwio/common/OutputStream.h index 46e90410ab18..d106cd4c9f3b 100644 --- a/velox/dwio/common/OutputStream.h +++ b/velox/dwio/common/OutputStream.h @@ -47,8 +47,8 @@ class BufferedOutputStream : public google::protobuf::io::ZeroCopyOutputStream { void BackUp(int32_t count) override; - google::protobuf::int64 ByteCount() const override { - return static_cast(size()); + int64_t ByteCount() const override { + return static_cast(size()); } bool WriteAliasedRaw(const void* /* unused */, int32_t /* unused */) @@ -116,7 +116,7 @@ class BufferedOutputStream : public google::protobuf::io::ZeroCopyOutputStream { void** buffer, int32_t* size, uint64_t headerSize, - const std::vector& bufferToFlush) { + const std::vector& bufferToFlush) { bufferHolder_.take(bufferToFlush); *buffer = buffer_.data() + headerSize; *size = static_cast(buffer_.size() - headerSize); diff --git a/velox/dwio/common/ParallelUnitLoader.cpp b/velox/dwio/common/ParallelUnitLoader.cpp new file mode 100644 index 000000000000..2ef6718e8fcf --- /dev/null +++ b/velox/dwio/common/ParallelUnitLoader.cpp @@ -0,0 +1,199 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/dwio/common/ParallelUnitLoader.h" +#include +#include "velox/common/base/AsyncSource.h" +#include "velox/common/base/Exceptions.h" +#include "velox/common/testutil/TestValue.h" +#include "velox/common/time/Timer.h" + +namespace facebook::velox::dwio::common { + +class ParallelUnitLoader : public UnitLoader { + public: + /// Enables concurrent loading of multiple units (stripes, row groups, etc.) + /// using asynchronous I/O to improve throughput and reduce read latency. + /// + /// **Loading Strategy:** + /// - Initialization: Preloads up to `maxConcurrentLoads` units concurrently + /// - Access pattern: On each getLoadedUnit() call, ensures the requested unit + /// is loaded and triggers loading of subsequent units within the window + /// - Memory management: Unloads all previous units to control memory usage + /// + /// **Performance Characteristics:** + /// - Best suited for sequential access patterns + /// - Memory usage: O(maxConcurrentLoads * average_unit_size) + /// - I/O parallelism: Up to `maxConcurrentLoads` concurrent load operations + /// + /// **Parameters:** + /// @param units All units to be loaded + /// @param ioExecutor Thread pool for asynchronous unit loading operations + /// @param maxConcurrentLoads Maximum units to load concurrently (sliding + /// window size) + /// + /// **Example with maxConcurrentLoads=3:** + /// ``` + /// Units: [0,1,2,3,4,5,6,7,8,9] + /// Init: Load [0,1,2] concurrently + /// Get(0): Wait for unit 0, trigger load of units [0,1,2], unload none + /// Get(1): Wait for unit 1, trigger load of units [1,2,3], unload [0] + /// Get(2): Wait for unit 2, trigger load of units [2,3,4], unload [0,1] + /// ``` + ParallelUnitLoader( + std::vector> units, + folly::Executor* ioExecutor, + uint16_t maxConcurrentLoads) + : loadUnits_( + std::make_move_iterator(units.begin()), + std::make_move_iterator(units.end())), + ioExecutor_(ioExecutor), + maxConcurrentLoads_(maxConcurrentLoads) { + VELOX_CHECK_NOT_NULL(ioExecutor, "ParallelUnitLoader ioExecutor is null"); + VELOX_CHECK_GT( + maxConcurrentLoads_, + 0, + "ParallelUnitLoader maxConcurrentLoads should be larger than 0"); + asyncSources_.resize(loadUnits_.size()); + unitsLoaded_.resize(loadUnits_.size()); + } + + /// Destructor ensures all pending load operations are properly cancelled + /// and waited for to prevent resource leaks and dangling references. + ~ParallelUnitLoader() override { + for (auto& source : asyncSources_) { + if (source) { + source->cancel(); + } + } + } + + LoadUnit& getLoadedUnit(uint32_t unit) override { + VELOX_CHECK_LT(unit, loadUnits_.size(), "Unit out of range"); + + processedUnits_.insert(unit); + // Ensure sliding window of units [unit, unit + maxConcurrentLoads_) is + // loading + for (size_t i = unit; + i < loadUnits_.size() && i < unit + maxConcurrentLoads_; + ++i) { + if (!unitsLoaded_[i]) { + load(i); + } + } + + uint64_t unitLoadNanos{0}; + try { + NanosecondTimer timer{&unitLoadNanos}; + asyncSources_[unit]->move(); + } catch (const std::exception& e) { + VELOX_FAIL("Failed to load unit {}: {}", unit, e.what()); + } + waitForUnitReadyNanos_ += unitLoadNanos; + + // Unload the previous units + unloadUntil(unit); + + return *loadUnits_[unit]; + } + + void onRead(uint32_t unit, uint64_t rowOffsetInUnit, uint64_t /* rowCount */) + override { + VELOX_CHECK_LT(unit, loadUnits_.size(), "Unit out of range"); + VELOX_CHECK_LT( + rowOffsetInUnit, loadUnits_[unit]->getNumRows(), "Row out of range"); + } + + void onSeek(uint32_t unit, uint64_t rowOffsetInUnit) override { + VELOX_CHECK_LT(unit, loadUnits_.size(), "Unit out of range"); + VELOX_CHECK_LE( + rowOffsetInUnit, loadUnits_[unit]->getNumRows(), "Row out of range"); + } + + UnitLoaderStats stats() override { + UnitLoaderStats stats; + stats.addCounter("processedUnits", RuntimeCounter(processedUnits_.size())); + stats.addCounter( + "waitForUnitReadyNanos", + RuntimeCounter( + waitForUnitReadyNanos_ > std::numeric_limits::max() + ? std::numeric_limits::max() + : waitForUnitReadyNanos_, + RuntimeCounter::Unit::kNanos)); + return stats; + } + + private: + /// Submits the unit's load() to the I/O thread pool + void load(uint32_t unitIndex) { + VELOX_CHECK_LT(unitIndex, loadUnits_.size(), "Unit index out of bounds"); + VELOX_CHECK_NOT_NULL(ioExecutor_, "ParallelUnitLoader ioExecutor is null"); + VELOX_DCHECK(!loadUnits_.empty(), "loadUnits_ should not be empty"); + + // Capture shared_ptr by value to prevent use-after-free if + // ParallelUnitLoader is destroyed while async operation is running + auto unit = loadUnits_[unitIndex]; + auto asyncSource = std::make_shared>([unit] { + unit->load(); + return std::make_unique(); + }); + asyncSources_[unitIndex] = asyncSource; + ioExecutor_->add([asyncSource] { + velox::common::testutil::TestValue::adjust( + "facebook::velox::dwio::common::ParallelUnitLoader::load", + asyncSource.get()); + asyncSource->prepare(); + }); + unitsLoaded_[unitIndex] = true; + } + + /// Unloads all the units before 'unitIndex' + void unloadUntil(uint32_t unitIndex) { + for (size_t i = 0; i < unitIndex; ++i) { + if (unitsLoaded_[i]) { + loadUnits_[i]->unload(); + unitsLoaded_[i] = false; + } + } + } + + std::vector unitsLoaded_; + std::vector> loadUnits_; + std::vector>> asyncSources_; + folly::Executor* ioExecutor_; + size_t maxConcurrentLoads_; + + // Stats + std::unordered_set processedUnits_; + uint64_t waitForUnitReadyNanos_{0}; +}; + +std::unique_ptr ParallelUnitLoaderFactory::create( + std::vector> loadUnits, + uint64_t rowsToSkip) { + const auto totalRows = std::accumulate( + loadUnits.cbegin(), loadUnits.cend(), 0UL, [](uint64_t sum, auto& unit) { + return sum + unit->getNumRows(); + }); + VELOX_CHECK_LE( + rowsToSkip, + totalRows, + "Can only skip up to the past-the-end row of the file."); + return std::make_unique( + std::move(loadUnits), ioExecutor_, maxConcurrentLoads_); +} + +} // namespace facebook::velox::dwio::common diff --git a/velox/dwio/common/ParallelUnitLoader.h b/velox/dwio/common/ParallelUnitLoader.h new file mode 100644 index 000000000000..0ba89028326c --- /dev/null +++ b/velox/dwio/common/ParallelUnitLoader.h @@ -0,0 +1,42 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +#include +#include "velox/dwio/common/UnitLoader.h" + +namespace facebook::velox::dwio::common { +class ParallelUnitLoaderFactory : public UnitLoaderFactory { + public: + ParallelUnitLoaderFactory( + folly::Executor* ioExecutor, + size_t maxConcurrentLoads) + : ioExecutor_(ioExecutor), maxConcurrentLoads_(maxConcurrentLoads) {} + + std::unique_ptr create( + std::vector> loadUnits, + uint64_t rowsToSkip) override; + + private: + folly::Executor* ioExecutor_; + size_t maxConcurrentLoads_; +}; + +} // namespace facebook::velox::dwio::common diff --git a/velox/dwio/common/PositionProvider.h b/velox/dwio/common/PositionProvider.h index 7be3bc7a1602..c99655a87e7c 100644 --- a/velox/dwio/common/PositionProvider.h +++ b/velox/dwio/common/PositionProvider.h @@ -23,7 +23,7 @@ namespace facebook::velox::dwio::common { class PositionProvider { public: explicit PositionProvider(const std::vector& positions) - : position_{positions.begin()}, end_{positions.end()} {} + : position_{positions.cbegin()}, end_{positions.cend()} {} uint64_t next(); diff --git a/velox/dwio/common/RandGen.h b/velox/dwio/common/RandGen.h index b83743bbcca1..a1fc53e8dc7a 100644 --- a/velox/dwio/common/RandGen.h +++ b/velox/dwio/common/RandGen.h @@ -19,10 +19,7 @@ #include #include -namespace facebook { -namespace velox { -namespace dwio { -namespace common { +namespace facebook::velox::dwio::common { class RandGen { public: @@ -68,7 +65,4 @@ class RandGen { std::uniform_int_distribution dist_; }; -} // namespace common -} // namespace dwio -} // namespace velox -} // namespace facebook +} // namespace facebook::velox::dwio::common diff --git a/velox/dwio/common/Range.h b/velox/dwio/common/Range.h index 8f12abff182f..d63a6506e47a 100644 --- a/velox/dwio/common/Range.h +++ b/velox/dwio/common/Range.h @@ -63,14 +63,8 @@ class Ranges { } } - bool operator==(const Iterator& other) const { - return std::tie(cur_, end_, val_) == - std::tie(other.cur_, other.end_, other.val_); - } - - bool operator!=(const Iterator& other) const { - return !operator==(other); - } + // TODO: Maybe only cur_ or cur_ with end_ should be compared? + bool operator==(const Iterator& other) const = default; Iterator& operator++() { VELOX_DCHECK(cur_ != end_); diff --git a/velox/dwio/common/Reader.cpp b/velox/dwio/common/Reader.cpp index 55066e979629..559dab5ae80e 100644 --- a/velox/dwio/common/Reader.cpp +++ b/velox/dwio/common/Reader.cpp @@ -61,7 +61,7 @@ VectorPtr RowReader::projectColumns( childType = inputRowType.childAt(childIdx); child = inputRow->childAt(childIdx); if (child) { - childSpec->applyFilter(*child, passed.data()); + childSpec->applyFilter(*child, inputRow->size(), passed.data()); } } if (!childSpec->projectOut()) { diff --git a/velox/dwio/common/Reader.h b/velox/dwio/common/Reader.h index 9dddfaeaca08..3a2fb37808f8 100644 --- a/velox/dwio/common/Reader.h +++ b/velox/dwio/common/Reader.h @@ -44,6 +44,13 @@ class RowReader { public: static constexpr int64_t kAtEnd = -1; + /// Runtime stat names. + /// Tracks the number of index columns that were converted from ScanSpec + /// filters to index bounds for index-based filtering (e.g., cluster index + /// pruning in Nimble). + static inline const std::string kNumIndexFilterConversions = + "numIndexFilterConversions"; + virtual ~RowReader() = default; /** diff --git a/velox/dwio/common/Retry.h b/velox/dwio/common/Retry.h index d4d52c120f28..9ea087d4361f 100644 --- a/velox/dwio/common/Retry.h +++ b/velox/dwio/common/Retry.h @@ -28,10 +28,7 @@ #include "velox/dwio/common/RandGen.h" #include "velox/dwio/common/exception/Exception.h" -namespace facebook { -namespace velox { -namespace dwio { -namespace common { +namespace facebook::velox::dwio::common { class retriable_error : public std::runtime_error { public: @@ -61,25 +58,24 @@ namespace retrypolicy { class IRetryPolicy { public: virtual ~IRetryPolicy() = default; - virtual folly::Optional nextWaitTime() = 0; + virtual std::optional nextWaitTime() = 0; virtual void start() {} }; class KAttempts : public IRetryPolicy { public: explicit KAttempts(std::vector durations) - : index_(0), durations_(std::move(durations)) {} + : durations_{std::move(durations)} {} - folly::Optional nextWaitTime() override { + std::optional nextWaitTime() override { if (index_ < durations_.size()) { - return folly::Optional(durations_[index_++]); - } else { - return folly::Optional(); + return durations_[index_++]; } + return std::nullopt; } private: - size_t index_; + size_t index_{0}; const std::vector durations_; }; @@ -105,16 +101,16 @@ class ExponentialBackoff : public IRetryPolicy { startTime_ = std::chrono::system_clock::now(); } - folly::Optional nextWaitTime() override { + std::optional nextWaitTime() override { if (retriesLeft_ == 0 || (maxTotal_.count() > 0 && total() >= maxTotal_)) { - return folly::Optional(); + return std::nullopt; } RetryDuration waitTime = nextWait_ + jitter(); nextWait_ = std::min(nextWait_ + nextWait_, maxWait_); --retriesLeft_; totalWait_ += waitTime; - return folly::Optional(waitTime); + return waitTime; } private: @@ -204,7 +200,7 @@ class RetryModule { return func(); } catch (const retriable_error& error) { LOG(INFO) << "RetryModule caught retriable exception. " << error.what(); - folly::Optional wait = policy->nextWaitTime(); + auto wait = policy->nextWaitTime(); if (wait.has_value()) { auto ms = wait.value().count(); LOG(INFO) << "RetryModule : Waiting for " << ms @@ -230,7 +226,4 @@ class RetryModule { } }; -} // namespace common -} // namespace dwio -} // namespace velox -} // namespace facebook +} // namespace facebook::velox::dwio::common diff --git a/velox/dwio/common/ScanSpec.cpp b/velox/dwio/common/ScanSpec.cpp index fc247ba4dd2a..6e3d7ae4b131 100644 --- a/velox/dwio/common/ScanSpec.cpp +++ b/velox/dwio/common/ScanSpec.cpp @@ -21,6 +21,21 @@ namespace facebook::velox::common { +// static +std::string_view ScanSpec::columnTypeString(ScanSpec::ColumnType columnType) { + switch (columnType) { + case ScanSpec::ColumnType::kRegular: + return "REGULAR"; + case ScanSpec::ColumnType::kRowIndex: + return "ROW_INDEX"; + case ScanSpec::ColumnType::kComposite: + return "COMPOSITE"; + default: + VELOX_UNREACHABLE( + "Unrecognized ColumnType: {}", static_cast(columnType)); + } +} + ScanSpec* ScanSpec::getOrCreateChild(const std::string& name) { if (auto it = this->childByFieldName_.find(name); it != this->childByFieldName_.end()) { @@ -37,7 +52,7 @@ ScanSpec* ScanSpec::getOrCreateChild(const Subfield& subfield) { const auto& path = subfield.path(); for (size_t depth = 0; depth < path.size(); ++depth) { const auto element = path[depth].get(); - VELOX_CHECK_EQ(element->kind(), kNestedField); + VELOX_CHECK_EQ(element->kind(), SubfieldKind::kNestedField); auto* nestedField = static_cast(element); container = container->getOrCreateChild(nestedField->name()); } @@ -89,8 +104,8 @@ uint64_t ScanSpec::newRead() { if (numReads_ == 0 || (!disableStatsBasedFilterReorder_ && !std::is_sorted( - children_.begin(), - children_.end(), + children_.cbegin(), + children_.cend(), [this]( const std::shared_ptr& left, const std::shared_ptr& right) { @@ -204,7 +219,7 @@ void ScanSpec::moveAdaptationFrom(ScanSpec& other) { namespace { bool testIntFilter( - common::Filter* filter, + const common::Filter* filter, dwio::common::IntegerColumnStatistics* intStats, bool mayHaveNull) { if (!intStats) { @@ -239,7 +254,7 @@ bool testIntFilter( } bool testDoubleFilter( - common::Filter* filter, + const common::Filter* filter, dwio::common::DoubleColumnStatistics* doubleStats, bool mayHaveNull) { if (!doubleStats) { @@ -274,7 +289,7 @@ bool testDoubleFilter( } bool testStringFilter( - common::Filter* filter, + const common::Filter* filter, dwio::common::StringColumnStatistics* stringStats, bool mayHaveNull) { if (!stringStats) { @@ -304,7 +319,7 @@ bool testStringFilter( } bool testBoolFilter( - common::Filter* filter, + const common::Filter* filter, dwio::common::BooleanColumnStatistics* boolStats) { const auto trueCount = boolStats->getTrueCount(); const auto falseCount = boolStats->getFalseCount(); @@ -325,7 +340,7 @@ bool testBoolFilter( } // namespace bool testFilter( - common::Filter* filter, + const common::Filter* filter, dwio::common::ColumnStatistics* stats, uint64_t totalRows, const TypePtr& type) { @@ -428,10 +443,6 @@ std::string ScanSpec::toString() const { return out.str(); } -void ScanSpec::addFilter(const Filter& filter) { - filter_ = filter_ ? filter_->mergeWith(&filter) : filter.clone(); -} - ScanSpec* ScanSpec::addField(const std::string& name, column_index_t channel) { auto child = getOrCreateChild(name); child->setProjectOut(true); @@ -510,7 +521,7 @@ namespace { template void filterSimpleVectorRows( const BaseVector& vector, - Filter& filter, + const Filter& filter, vector_size_t size, uint64_t* result) { VELOX_CHECK(size == 0 || result); @@ -529,9 +540,10 @@ void filterSimpleVectorRows( void filterRows( const BaseVector& vector, - Filter& filter, + const Filter& filter, vector_size_t size, uint64_t* result) { + VELOX_CHECK_LE(size, vector.size()); switch (vector.typeKind()) { case TypeKind::ARRAY: case TypeKind::MAP: @@ -562,20 +574,39 @@ void filterRows( } // namespace -void ScanSpec::applyFilter(const BaseVector& vector, uint64_t* result) const { +void ScanSpec::applyFilter( + const BaseVector& vector, + vector_size_t size, + uint64_t* result) const { if (filter_) { - filterRows(vector, *filter_, vector.size(), result); + filterRows(vector, *filter_, size, result); } if (!vector.type()->isRow()) { // Filter on MAP or ARRAY children are pruning, and won't affect correctness // of the result. return; } + auto& rowType = vector.type()->asRow(); - auto* rowVector = vector.asChecked(); - for (int i = 0; i < rowType.size(); ++i) { - if (auto* child = childByName(rowType.nameOf(i))) { - child->applyFilter(*rowVector->childAt(i), result); + if (vector.encoding() == VectorEncoding::Simple::ROW) { + auto rowVector = vector.asUnchecked(); + for (int i = 0; i < rowType.size(); ++i) { + if (auto* child = childByName(rowType.nameOf(i))) { + child->applyFilter(*rowVector->childAt(i), size, result); + } + } + } else { + DecodedVector decoded{vector}; + auto rowVector = decoded.base()->asUnchecked(); + + for (int i = 0; i < rowType.size(); ++i) { + if (auto* child = childByName(rowType.nameOf(i))) { + child->applyFilter( + *(decoded.wrap( + rowVector->childAt(i), *vector.pool(), vector.size())), + size, + result); + } } } } diff --git a/velox/dwio/common/ScanSpec.h b/velox/dwio/common/ScanSpec.h index a040fc5316bd..bee9ef5b89e1 100644 --- a/velox/dwio/common/ScanSpec.h +++ b/velox/dwio/common/ScanSpec.h @@ -28,17 +28,16 @@ #include -namespace facebook { -namespace velox { +namespace facebook::velox { namespace dwio::common { class ColumnStatistics; } namespace common { -// Describes the filtering and value extraction for a -// SelectiveColumnReader. This is owned by the TableScan Operator and -// is passed to SelectiveColumnReaders at construction. This is -// mutable by readers to reflect filter order and other adaptation. +/// Describes the filtering and value extraction for a +/// SelectiveColumnReader. This is owned by the TableScan Operator and +/// is passed to SelectiveColumnReaders at construction. This is +/// mutable by readers to reflect filter order and other adaptations. class ScanSpec { public: enum class ColumnType : int8_t { @@ -47,6 +46,9 @@ class ScanSpec { kComposite, // A struct with all children not read from file }; + /// Convert ColumnType to its string name representation. + static std::string_view columnTypeString(ColumnType columnType); + static constexpr column_index_t kNoChannel = ~0; static constexpr const char* kMapKeysFieldName = "keys"; static constexpr const char* kMapValuesFieldName = "values"; @@ -54,21 +56,19 @@ class ScanSpec { explicit ScanSpec(const std::string& name) : fieldName_(name) {} - // Filter to apply. If 'this' corresponds to a struct/list/map, this - // can only be isNull or isNotNull, other filtering is given by - // 'children'. - common::Filter* filter() const { + /// Filter to apply. If 'this' corresponds to a struct/list/map, this + /// can only be isNull or isNotNull, other filtering is given by + /// 'children'. + const common::Filter* filter() const { return filterDisabled_ ? nullptr : filter_.get(); } // Sets 'filter_'. May be used at initialization or when adding a // pushed down filter, e.g. top k cutoff. - void setFilter(std::unique_ptr filter) { + void setFilter(std::shared_ptr filter) { filter_ = std::move(filter); } - void addFilter(const Filter&); - void setMaxArrayElementsCount(vector_size_t count) { maxArrayElementsCount_ = count; } @@ -148,7 +148,9 @@ class ScanSpec { } void setSubscript(int64_t subscript) { - subscript_ = subscript; + if (subscript_ != subscript) { + subscript_ = subscript; + } } // True if the value is returned from scan. A runtime pushdown of a filter @@ -165,8 +167,8 @@ class ScanSpec { return projectOut_ || deltaUpdate_; } - // Position in the RowVector returned by the top level scan. Applies - // only to children of the root struct where projectOut_ is true. + /// Position in the RowVector returned by the top level scan. Applies + /// only to children of the root struct where projectOut_ is true. column_index_t channel() const { return channel_; } @@ -357,10 +359,17 @@ class ScanSpec { } } - /// Apply filter to the input `vector' and set the passed bits in `result'. + /// Apply filter to the first `size' rows of input `vector' and set the passed + /// bits in `result'. `size' is usually the size of top most RowVector, since + /// the child could be larger in some suboptimal/corrupted cases and we do not + /// want to crash the process for it. + /// /// This method is used by non-selective reader and delta update, so it /// ignores the filterDisabled_ state. - void applyFilter(const BaseVector& vector, uint64_t* result) const; + void applyFilter( + const BaseVector& vector, + vector_size_t size, + uint64_t* result) const; bool isFlatMapAsStruct() const { return isFlatMapAsStruct_; @@ -399,30 +408,30 @@ class ScanSpec { // Number of times read is called on the corresponding reader. This // is used for setup on first use and to produce a read sequence // number for LazyVectors. - uint64_t numReads_ = 0; + uint64_t numReads_{0}; // Ordinal position of 'this' in its containing spec. For a struct // member this is the position of the reader in the child // readers. If this describes an operation on an array element or a // map with numeric key, this is the subscript as defined for array // or map. - int64_t subscript_ = -1; + int64_t subscript_{-1}; // Column name if this is a struct mamber. String key if this // describes an operation on a map value. std::string fieldName_; // Ordinal position of the extracted value in the containing // RowVector. Set only when this describes a struct member. - column_index_t channel_ = kNoChannel; + column_index_t channel_{kNoChannel}; VectorPtr constantValue_; - bool projectOut_ = false; + bool projectOut_{false}; - ColumnType columnType_ = ColumnType::kRegular; + ColumnType columnType_{ColumnType::kRegular}; // True if a string dictionary or flat map in this field should be // returned as flat. - bool makeFlat_ = false; - std::unique_ptr filter_; + bool makeFlat_{false}; + std::shared_ptr filter_; bool filterDisabled_ = false; dwio::common::DeltaColumnUpdater* deltaUpdate_ = nullptr; @@ -437,6 +446,7 @@ class ScanSpec { SelectivityInfo selectivity_; std::vector> children_; + // Read-only copy of children, not subject to reordering. Used when // asynchronously constructing reader trees for read-ahead, while // 'children_' is reorderable by a running scan. @@ -495,11 +505,21 @@ void ScanSpec::visit(const Type& type, F&& f) { // Returns false if no value from a range defined by stats can pass the // filter. True, otherwise. bool testFilter( - common::Filter* filter, + const common::Filter* filter, dwio::common::ColumnStatistics* stats, uint64_t totalRows, const TypePtr& type); } // namespace common -} // namespace velox -} // namespace facebook +} // namespace facebook::velox + +template <> +struct fmt::formatter + : formatter { + auto format( + facebook::velox::common::ScanSpec::ColumnType columnType, + format_context& ctx) const { + return formatter::format( + facebook::velox::common::ScanSpec::columnTypeString(columnType), ctx); + } +}; diff --git a/velox/dwio/common/SeekableInputStream.cpp b/velox/dwio/common/SeekableInputStream.cpp index 2f461551626b..db3c7f4a5028 100644 --- a/velox/dwio/common/SeekableInputStream.cpp +++ b/velox/dwio/common/SeekableInputStream.cpp @@ -163,8 +163,8 @@ bool SeekableArrayInputStream::SkipInt64(int64_t count) { return false; } -google::protobuf::int64 SeekableArrayInputStream::ByteCount() const { - return static_cast(position_); +int64_t SeekableArrayInputStream::ByteCount() const { + return static_cast(position_); } void SeekableArrayInputStream::seekToPosition(PositionProvider& position) { @@ -241,8 +241,8 @@ bool SeekableFileInputStream::SkipInt64(int64_t signedCount) { return position_ < length_; } -google::protobuf::int64 SeekableFileInputStream::ByteCount() const { - return static_cast(position_); +int64_t SeekableFileInputStream::ByteCount() const { + return static_cast(position_); } void SeekableFileInputStream::seekToPosition(PositionProvider& location) { diff --git a/velox/dwio/common/SeekableInputStream.h b/velox/dwio/common/SeekableInputStream.h index c53347a6204f..f4e33008523c 100644 --- a/velox/dwio/common/SeekableInputStream.h +++ b/velox/dwio/common/SeekableInputStream.h @@ -79,7 +79,7 @@ class SeekableArrayInputStream : public SeekableInputStream { virtual bool Next(const void** data, int32_t* size) override; virtual void BackUp(int32_t count) override; virtual bool SkipInt64(int64_t count) override; - virtual google::protobuf::int64 ByteCount() const override; + virtual int64_t ByteCount() const override; virtual void seekToPosition(PositionProvider& position) override; virtual std::string getName() const override; virtual size_t positionSize() const override; @@ -120,7 +120,7 @@ class SeekableFileInputStream : public SeekableInputStream { virtual bool Next(const void** data, int32_t* size) override; virtual void BackUp(int32_t count) override; virtual bool SkipInt64(int64_t count) override; - virtual google::protobuf::int64 ByteCount() const override; + virtual int64_t ByteCount() const override; virtual void seekToPosition(PositionProvider& position) override; virtual std::string getName() const override; virtual size_t positionSize() const override; diff --git a/velox/dwio/common/SelectiveByteRleColumnReader.h b/velox/dwio/common/SelectiveByteRleColumnReader.h index 3c437aba973b..c276e31fb22d 100644 --- a/velox/dwio/common/SelectiveByteRleColumnReader.h +++ b/velox/dwio/common/SelectiveByteRleColumnReader.h @@ -41,7 +41,7 @@ class SelectiveByteRleColumnReader : public SelectiveColumnReader { bool kEncodingHasNulls, typename ExtractValues> void processFilter( - velox::common::Filter* filter, + const velox::common::Filter* filter, ExtractValues extractValues, const RowSet& rows); @@ -54,7 +54,7 @@ class SelectiveByteRleColumnReader : public SelectiveColumnReader { bool isDense, typename ExtractValues> void readHelper( - velox::common::Filter* filter, + const velox::common::Filter* filter, const RowSet& rows, ExtractValues extractValues); @@ -69,13 +69,13 @@ template < bool isDense, typename ExtractValues> void SelectiveByteRleColumnReader::readHelper( - velox::common::Filter* filter, + const velox::common::Filter* filter, const RowSet& rows, ExtractValues extractValues) { reinterpret_cast(this)->readWithVisitor( rows, ColumnVisitor( - *reinterpret_cast(filter), this, rows, extractValues)); + *static_cast(filter), this, rows, extractValues)); } template < @@ -84,7 +84,7 @@ template < bool kEncodingHasNulls, typename ExtractValues> void SelectiveByteRleColumnReader::processFilter( - velox::common::Filter* filter, + const velox::common::Filter* filter, ExtractValues extractValues, const RowSet& rows) { using velox::common::FilterKind; @@ -162,7 +162,7 @@ void SelectiveByteRleColumnReader::readCommon( const uint64_t* incomingNulls) { prepareRead(offset, rows, incomingNulls); const bool isDense = rows.back() == rows.size() - 1; - velox::common::Filter* filter = + auto* filter = scanSpec_->filter() ? scanSpec_->filter() : &dwio::common::alwaysTrue(); if (scanSpec_->keepValues()) { if (scanSpec_->valueHook()) { diff --git a/velox/dwio/common/SelectiveColumnReader.cpp b/velox/dwio/common/SelectiveColumnReader.cpp index 17b2207d65af..df67694643f1 100644 --- a/velox/dwio/common/SelectiveColumnReader.cpp +++ b/velox/dwio/common/SelectiveColumnReader.cpp @@ -410,7 +410,7 @@ void SelectiveColumnReader::compactScalarValues( values_->setSize(bits::nbytes(numValues_)); } -char* SelectiveColumnReader::copyStringValue(folly::StringPiece value) { +char* SelectiveColumnReader::copyStringValue(std::string_view value) { uint64_t size = value.size(); if (stringBuffers_.empty() || rawStringUsed_ + size > rawStringSize_) { auto bytes = std::max(size, kStringBufferSize); @@ -431,7 +431,7 @@ char* SelectiveColumnReader::copyStringValue(folly::StringPiece value) { return rawStringBuffer_ + start; } -void SelectiveColumnReader::addStringValue(folly::StringPiece value) { +void SelectiveColumnReader::addStringValue(std::string_view value) { auto copy = copyStringValue(value); reinterpret_cast(rawValues_)[numValues_++] = StringView(copy, value.size()); @@ -449,8 +449,11 @@ void SelectiveColumnReader::setNulls(BufferPtr resultNulls) { void SelectiveColumnReader::resetFilterCaches() { if (scanState_.filterCache.empty() && scanSpec_->hasFilter()) { - scanState_.filterCache.resize(std::max( - 1, scanState_.dictionary.numValues + scanState_.dictionary2.numValues)); + scanState_.filterCache.resize( + std::max( + 1, + scanState_.dictionary.numValues + + scanState_.dictionary2.numValues)); scanState_.updateRawState(); } if (!scanState_.filterCache.empty()) { @@ -476,7 +479,7 @@ void SelectiveColumnReader::addSkippedParentNulls( int64_t from, int64_t to, int32_t numNulls) { - auto rowsPerRowGroup = formatData_->rowsPerRowGroup(); + const auto rowsPerRowGroup = formatData_->rowsPerRowGroup(); if (rowsPerRowGroup.has_value() && from / rowsPerRowGroup.value() > parentNullsRecordedTo_ / rowsPerRowGroup.value()) { @@ -484,7 +487,7 @@ void SelectiveColumnReader::addSkippedParentNulls( parentNullsRecordedTo_ = from; numParentNulls_ = 0; } - if (parentNullsRecordedTo_) { + if (parentNullsRecordedTo_ > 0) { VELOX_CHECK_EQ(parentNullsRecordedTo_, from); } numParentNulls_ += numNulls; diff --git a/velox/dwio/common/SelectiveColumnReader.h b/velox/dwio/common/SelectiveColumnReader.h index 8352c9b388d1..dbbbd55ec938 100644 --- a/velox/dwio/common/SelectiveColumnReader.h +++ b/velox/dwio/common/SelectiveColumnReader.h @@ -27,6 +27,8 @@ namespace facebook::velox::dwio::common { +using ScanSpec = velox::common::ScanSpec; + /// Generalized representation of a set of distinct values for dictionary /// encodings. struct DictionaryValues { @@ -161,12 +163,11 @@ class SelectiveColumnReader { // from a downstream operator. virtual void resetFilterCaches(); - // Seeks to offset and reads the rows in 'rows' and applies - // filters and value processing as given by 'scanSpec supplied at - // construction. 'offset' is relative to start of stripe. 'rows' are - // relative to 'offset', so that row 0 is the 'offset'th row from - // start of stripe. 'rows' is expected to stay constant - // between this and the next call to read. + // Seeks to offset and reads the rows in 'rows' and applies filters and value + // processing as given by 'scanSpec supplied at construction. 'offset' is + // relative to start of stripe. 'rows' are relative to 'offset', so that row 0 + // is the 'offset'th row from start of stripe. 'rows' is expected to stay + // constant between this and the next call to read. virtual void read(int64_t offset, const RowSet& rows, const uint64_t* incomingNulls) = 0; @@ -335,7 +336,7 @@ class SelectiveColumnReader { template inline void addValue(T value) { static_assert( - std::is_pod_v, + std::is_standard_layout_v, "General case of addValue is only for primitive types"); VELOX_DCHECK_NOT_NULL(rawValues_); VELOX_DCHECK_LE((numValues_ + 1) * sizeof(T), values_->capacity()); @@ -483,7 +484,7 @@ class SelectiveColumnReader { return false; } - StringView copyStringValueIfNeed(folly::StringPiece value) { + StringView copyStringValueIfNeed(std::string_view value) { if (value.size() <= StringView::kInlineSize) { return StringView(value); } @@ -495,6 +496,10 @@ class SelectiveColumnReader { VELOX_UNREACHABLE("Only struct reader supports this method"); } + memory::MemoryPool* memoryPool() const { + return memoryPool_; + } + protected: template void prepareRead( @@ -578,11 +583,11 @@ class SelectiveColumnReader { // Checks consistency of nulls-related state. const uint64_t* shouldMoveNulls(const RowSet& rows); - void addStringValue(folly::StringPiece value); + void addStringValue(std::string_view value); // Copies 'value' to buffers owned by 'this' and returns the start of the // copy. - char* copyStringValue(folly::StringPiece value); + char* copyStringValue(std::string_view value); virtual bool hasDeletion() const { return false; @@ -729,7 +734,7 @@ class SelectiveColumnReader { }; template <> -inline void SelectiveColumnReader::addValue(const folly::StringPiece value) { +inline void SelectiveColumnReader::addValue(const std::string_view value) { const uint64_t size = value.size(); if (size <= StringView::kInlineSize) { reinterpret_cast(rawValues_)[numValues_++] = @@ -763,7 +768,7 @@ struct NoHook final : public ValueHook { void addValue(vector_size_t /*row*/, double /*value*/) final {} - void addValue(vector_size_t /*row*/, folly::StringPiece /*value*/) final {} + void addValue(vector_size_t /*row*/, std::string_view /*value*/) final {} }; } // namespace facebook::velox::dwio::common diff --git a/velox/dwio/common/SelectiveColumnReaderInternal.h b/velox/dwio/common/SelectiveColumnReaderInternal.h index f491870bfff9..c2ff5b0b0e84 100644 --- a/velox/dwio/common/SelectiveColumnReaderInternal.h +++ b/velox/dwio/common/SelectiveColumnReaderInternal.h @@ -21,8 +21,8 @@ #include "velox/dwio/common/DirectDecoder.h" #include "velox/dwio/common/SelectiveColumnReader.h" #include "velox/dwio/common/TypeUtils.h" -#include "velox/exec/AggregationHook.h" #include "velox/type/Timestamp.h" +#include "velox/vector/AggregationHook.h" #include "velox/vector/ConstantVector.h" #include "velox/vector/DictionaryVector.h" #include "velox/vector/FlatVector.h" diff --git a/velox/dwio/common/SelectiveFlatMapColumnReader.cpp b/velox/dwio/common/SelectiveFlatMapColumnReader.cpp new file mode 100644 index 000000000000..05bf1a7a9915 --- /dev/null +++ b/velox/dwio/common/SelectiveFlatMapColumnReader.cpp @@ -0,0 +1,119 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/dwio/common/SelectiveFlatMapColumnReader.h" + +namespace facebook::velox::dwio::common { +namespace { + +// Check if `result` can be reused to prevent allocation of a new vector. +FlatMapVector* tryReuseResult( + const VectorPtr& result, + const VectorPtr& distinctKeys, + vector_size_t size) { + if (result.use_count() != 1) { + return nullptr; + } + + if (result->encoding() == VectorEncoding::Simple::FLAT_MAP) { + auto* flatMap = static_cast(result.get()); + if (flatMap->distinctKeys() != distinctKeys) { + flatMap->setDistinctKeys(distinctKeys); + } + flatMap->resize(size); + return flatMap; + } + return nullptr; +} +} // namespace + +void SelectiveFlatMapColumnReader::getValues( + const RowSet& rows, + VectorPtr* result) { + VELOX_CHECK(!scanSpec_->children().empty()); + VELOX_CHECK_NOT_NULL( + *result, "SelectiveFlatMapColumnReader expects a non-null result"); + VELOX_CHECK( + result->get()->type()->isMap(), + "Struct reader expects a result of type MAP."); + + if (rows.empty()) { + return; + } + + auto* resultFlatMap = prepareResult(*result, keysVector_, rows.size()); + setComplexNulls(rows, *result); + + for (const auto& childSpec : scanSpec_->children()) { + VELOX_TRACE_HISTORY_PUSH("getValues %s", childSpec->fieldName().c_str()); + if (!childSpec->keepValues()) { + continue; + } + + VELOX_CHECK( + childSpec->readFromFile(), + "Flatmap children must always be read from file."); + + if (childSpec->subscript() == kConstantChildSpecSubscript) { + continue; + } + + const auto channel = childSpec->channel(); + const auto index = childSpec->subscript(); + auto& childResult = resultFlatMap->mapValuesAt(channel); + + VELOX_CHECK( + !childSpec->deltaUpdate(), + "Delta update not supported in flat map yet"); + VELOX_CHECK( + !childSpec->isConstant(), + "Flat map values cannot be constant in scanSpec."); + VELOX_CHECK_EQ( + childSpec->columnType(), + velox::common::ScanSpec::ColumnType::kRegular, + "Flat map only supports regular column types in scan spec."); + + children_[index]->getValues(rows, &childResult); + + for (size_t i = 0; i < children_.size(); ++i) { + const auto& inMap = inMapBuffer(i); + if (inMap) { + resultFlatMap->inMapsAt(i, true) = inMap; + } + } + } +} + +FlatMapVector* SelectiveFlatMapColumnReader::prepareResult( + VectorPtr& result, + const VectorPtr& distinctKeys, + vector_size_t size) const { + if (auto reused = tryReuseResult(result, distinctKeys, size)) { + return reused; + } + + auto flatMap = std::make_shared( + result->pool(), + result->type(), + nullptr, + size, + distinctKeys, + std::vector(distinctKeys->size()), + std::vector{}); + result = flatMap; + return flatMap.get(); +} +} // namespace facebook::velox::dwio::common diff --git a/velox/dwio/common/SelectiveFlatMapColumnReader.h b/velox/dwio/common/SelectiveFlatMapColumnReader.h new file mode 100644 index 000000000000..8a939ffa7bbe --- /dev/null +++ b/velox/dwio/common/SelectiveFlatMapColumnReader.h @@ -0,0 +1,56 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "velox/dwio/common/SelectiveStructColumnReader.h" +#include "velox/vector/FlatMapVector.h" + +namespace facebook::velox::dwio::common { + +class SelectiveFlatMapColumnReader : public SelectiveStructColumnReaderBase { + protected: + SelectiveFlatMapColumnReader( + const TypePtr& requestedType, + const std::shared_ptr& fileType, + FormatParams& params, + velox::common::ScanSpec& scanSpec) + : SelectiveStructColumnReaderBase( + requestedType, + fileType, + params, + scanSpec, + false, + false) {} + + void getValues(const RowSet& rows, VectorPtr* result) override; + + void seekTo(int64_t offset, bool /*readsNullsOnly*/) override { + seekToPropagateNullsToChildren(offset); + } + + virtual const BufferPtr& inMapBuffer(column_index_t childIndex) const = 0; + + VectorPtr keysVector_; + + private: + FlatMapVector* prepareResult( + VectorPtr& result, + const VectorPtr& distinctKeys, + vector_size_t size) const; +}; + +} // namespace facebook::velox::dwio::common diff --git a/velox/dwio/common/SelectiveFloatingPointColumnReader.h b/velox/dwio/common/SelectiveFloatingPointColumnReader.h index 7de695fee86e..1f3445918b6d 100644 --- a/velox/dwio/common/SelectiveFloatingPointColumnReader.h +++ b/velox/dwio/common/SelectiveFloatingPointColumnReader.h @@ -35,9 +35,11 @@ class SelectiveFloatingPointColumnReader : public SelectiveColumnReader { params, scanSpec) {} - // Offers fast path only if data and result widths match. + // Offers a fast path only if data and result widths match. + static constexpr bool kHasBulkPath = std::is_same_v; + bool hasBulkPath() const override { - return std::is_same_v; + return kHasBulkPath; } template @@ -55,7 +57,7 @@ class SelectiveFloatingPointColumnReader : public SelectiveColumnReader { bool isDense, typename ExtractValues> void readHelper( - velox::common::Filter* filter, + const velox::common::Filter* filter, const RowSet& rows, ExtractValues values); @@ -65,7 +67,7 @@ class SelectiveFloatingPointColumnReader : public SelectiveColumnReader { bool kEncodingHasNulls, typename ExtractValues> void processFilter( - velox::common::Filter* filter, + const velox::common::Filter* filter, const RowSet& rows, ExtractValues extractValues); @@ -80,7 +82,7 @@ template < bool isDense, typename ExtractValues> void SelectiveFloatingPointColumnReader::readHelper( - velox::common::Filter* filter, + const velox::common::Filter* filter, const RowSet& rows, ExtractValues extractValues) { reinterpret_cast(this)->readWithVisitor( @@ -90,8 +92,8 @@ void SelectiveFloatingPointColumnReader::readHelper( TFilter, ExtractValues, isDense, - std::is_same_v>( - *reinterpret_cast(filter), this, rows, extractValues)); + Reader::kHasBulkPath>( + *static_cast(filter), this, rows, extractValues)); } template @@ -101,7 +103,7 @@ template < bool kEncodingHasNulls, typename ExtractValues> void SelectiveFloatingPointColumnReader::processFilter( - velox::common::Filter* filter, + const velox::common::Filter* filter, const RowSet& rows, ExtractValues extractValues) { if (filter == nullptr) { diff --git a/velox/dwio/common/SelectiveIntegerColumnReader.h b/velox/dwio/common/SelectiveIntegerColumnReader.h index c94b5a60d789..d9ba805167a8 100644 --- a/velox/dwio/common/SelectiveIntegerColumnReader.h +++ b/velox/dwio/common/SelectiveIntegerColumnReader.h @@ -47,7 +47,7 @@ class SelectiveIntegerColumnReader : public SelectiveColumnReader { bool kEncodingHasNulls, typename ExtractValues> void processFilter( - velox::common::Filter* filter, + const velox::common::Filter* filter, ExtractValues extractValues, const RowSet& rows); @@ -63,7 +63,7 @@ class SelectiveIntegerColumnReader : public SelectiveColumnReader { bool isDense, typename ExtractValues> void readHelper( - velox::common::Filter* filter, + const velox::common::Filter* filter, const RowSet& rows, ExtractValues extractValues); @@ -80,7 +80,7 @@ template < bool isDense, typename ExtractValues> void SelectiveIntegerColumnReader::readHelper( - velox::common::Filter* filter, + const velox::common::Filter* filter, const RowSet& rows, ExtractValues extractValues) { switch (valueSize_) { @@ -88,28 +88,28 @@ void SelectiveIntegerColumnReader::readHelper( reinterpret_cast(this)->Reader::readWithVisitor( rows, ColumnVisitor( - *reinterpret_cast(filter), this, rows, extractValues)); + *static_cast(filter), this, rows, extractValues)); break; case 4: reinterpret_cast(this)->Reader::readWithVisitor( rows, ColumnVisitor( - *reinterpret_cast(filter), this, rows, extractValues)); + *static_cast(filter), this, rows, extractValues)); break; case 8: reinterpret_cast(this)->Reader::readWithVisitor( rows, ColumnVisitor( - *reinterpret_cast(filter), this, rows, extractValues)); + *static_cast(filter), this, rows, extractValues)); break; case 16: reinterpret_cast(this)->Reader::readWithVisitor( rows, ColumnVisitor( - *reinterpret_cast(filter), this, rows, extractValues)); + *static_cast(filter), this, rows, extractValues)); break; default: @@ -123,7 +123,7 @@ template < bool kEncodingHasNulls, typename ExtractValues> void SelectiveIntegerColumnReader::processFilter( - velox::common::Filter* filter, + const velox::common::Filter* filter, ExtractValues extractValues, const RowSet& rows) { if (filter == nullptr) { @@ -200,6 +200,13 @@ void SelectiveIntegerColumnReader::processFilter( velox::common::NegatedBigintValuesUsingBitmask, isDense>(filter, rows, extractValues); break; + case velox::common::FilterKind::kBigintValuesUsingBloomFilter: + static_cast(this) + ->template readHelper< + Reader, + velox::common::BigintValuesUsingBloomFilter, + isDense>(filter, rows, extractValues); + break; default: static_cast(this) ->template readHelper( @@ -246,8 +253,7 @@ void SelectiveIntegerColumnReader::processValueHook( template void SelectiveIntegerColumnReader::readCommon(const RowSet& rows) { const bool isDense = rows.back() == rows.size() - 1; - velox::common::Filter* filter = - scanSpec_->filter() ? scanSpec_->filter() : &alwaysTrue(); + auto* filter = scanSpec_->filter() ? scanSpec_->filter() : &alwaysTrue(); if (scanSpec_->keepValues()) { if (scanSpec_->valueHook()) { if (isDense) { diff --git a/velox/dwio/common/SelectiveRepeatedColumnReader.cpp b/velox/dwio/common/SelectiveRepeatedColumnReader.cpp index ba6600002651..cee47f445c5f 100644 --- a/velox/dwio/common/SelectiveRepeatedColumnReader.cpp +++ b/velox/dwio/common/SelectiveRepeatedColumnReader.cpp @@ -254,6 +254,9 @@ void SelectiveListColumnReader::read( makeNestedRowSet(activeRows, rows.back()); if (child_ && !nestedRows_.empty()) { child_->read(child_->readOffset(), nestedRows_, nullptr); + nestedRowsAllSelected_ = nestedRowsAllSelected_ && + nestedRows_.size() == child_->outputRows().size(); + nestedRows_ = child_->outputRows(); } numValues_ = activeRows.size(); readOffset_ = offset + rows.back() + 1; @@ -274,15 +277,7 @@ void SelectiveListColumnReader::getValues( } } -SelectiveMapColumnReader::SelectiveMapColumnReader( - const TypePtr& requestedType, - const std::shared_ptr& fileType, - FormatParams& params, - velox::common::ScanSpec& scanSpec) - : SelectiveRepeatedColumnReader(requestedType, params, scanSpec, fileType) { -} - -uint64_t SelectiveMapColumnReader::skip(uint64_t numValues) { +uint64_t SelectiveMapColumnReaderBase::skip(uint64_t numValues) { numValues = formatData_->skipNulls(numValues); if (keyReader_ || elementReader_) { std::array buffer; @@ -312,7 +307,7 @@ uint64_t SelectiveMapColumnReader::skip(uint64_t numValues) { return numValues; } -void SelectiveMapColumnReader::read( +void SelectiveMapColumnReaderBase::read( int64_t offset, const RowSet& rows, const uint64_t* incomingNulls) { @@ -338,12 +333,34 @@ void SelectiveMapColumnReader::read( nestedRows_ = keyReader_->outputRows(); if (!nestedRows_.empty()) { elementReader_->read(elementReader_->readOffset(), nestedRows_, nullptr); + nestedRowsAllSelected_ = nestedRowsAllSelected_ && + nestedRows_.size() == elementReader_->outputRows().size(); + nestedRows_ = elementReader_->outputRows(); } } numValues_ = activeRows.size(); readOffset_ = offset + rows.back() + 1; } +SelectiveMapColumnReader::SelectiveMapColumnReader( + const TypePtr& requestedType, + const TypeWithIdPtr& fileType, + FormatParams& params, + ScanSpec& scanSpec) + : SelectiveMapColumnReaderBase(requestedType, params, scanSpec, fileType) { + VELOX_CHECK(!scanSpec_->isFlatMapAsStruct()); + // We should not need this anymore. Is there a safe way to find out if there + // is any prod usages that forget to set up the map children in scan spec? + // This should be only possible when user bypasses the connector interface and + // create file readers directly. + if (scanSpec_->children().empty()) { + scanSpec_->getOrCreateChild(ScanSpec::kMapKeysFieldName); + scanSpec_->getOrCreateChild(ScanSpec::kMapValuesFieldName); + } + scanSpec_->children()[0]->setProjectOut(true); + scanSpec_->children()[1]->setProjectOut(true); +} + void SelectiveMapColumnReader::getValues( const RowSet& rows, VectorPtr* result) { @@ -368,4 +385,110 @@ void SelectiveMapColumnReader::getValues( } } +SelectiveMapAsStructColumnReader::SelectiveMapAsStructColumnReader( + const TypePtr& requestedType, + const TypeWithIdPtr& fileType, + FormatParams& params, + ScanSpec& scanSpec) + : SelectiveMapColumnReaderBase(requestedType, params, scanSpec, fileType) { + VELOX_CHECK(scanSpec_->isFlatMapAsStruct() && requestedType_->isMap()); + mapScanSpec_.addMapKeyFieldRecursively(*requestedType_->childAt(0)); + mapScanSpec_.addMapValueFieldRecursively(*requestedType_->childAt(1)); + column_index_t maxChannel = 0; + for (auto& childSpec : scanSpec_->children()) { + auto field = folly::tryTo(childSpec->fieldName()); + VELOX_CHECK( + field.hasValue(), + "Fail to parse field name: {}", + childSpec->fieldName()); + keyToIndex_[*field] = childSpec->channel(); + maxChannel = std::max(maxChannel, childSpec->channel()); + } + copyRanges_.resize(maxChannel + 1); +} + +void SelectiveMapAsStructColumnReader::getValues( + const RowSet& rows, + VectorPtr* result) { + VELOX_CHECK_NOT_NULL(*result); + VELOX_CHECK( + result->get()->type()->isRow(), + "Expect ROW, got {}", + result->get()->type()->toString()); + BaseVector::prepareForReuse(*result, rows.size()); + auto* resultRow = result->get()->asChecked(); + setComplexNulls(rows, *result); + for (auto& child : resultRow->children()) { + bits::fillBits(child->mutableRawNulls(), 0, rows.size(), bits::kNull); + } + numValues_ = rows.size(); + if (nestedRows_.empty()) { + return; + } + keyReader_->getValues(nestedRows_, &mapKeys_); + prepareStructResult(requestedType_->childAt(1), &mapValues_); + elementReader_->getValues(nestedRows_, &mapValues_); + decodedKeys_.decode(*mapKeys_); + for (auto& ranges : copyRanges_) { + ranges.clear(); + } + switch (mapKeys_->type()->kind()) { + case TypeKind::TINYINT: + makeCopyRanges(rows); + break; + case TypeKind::SMALLINT: + makeCopyRanges(rows); + break; + case TypeKind::INTEGER: + makeCopyRanges(rows); + break; + case TypeKind::BIGINT: + makeCopyRanges(rows); + break; + default: + VELOX_UNSUPPORTED( + "Unsupported key type: {}", mapKeys_->type()->toString()); + } + for (column_index_t i = 0; i < resultRow->childrenSize(); ++i) { + resultRow->childAt(i)->copyRanges(mapValues_.get(), copyRanges_[i]); + } +} + +template +void SelectiveMapAsStructColumnReader::makeCopyRanges(const RowSet& rows) { + auto* nulls = nullsInReadRange_ ? nullsInReadRange_->as() : nullptr; + for (vector_size_t i = 0, + currentOffset = 0, + currentRow = 0, + nestedRowIndex = 0; + i < rows.size(); + ++i) { + const auto row = rows[i]; + if (nulls && bits::isBitNull(nulls, row)) { + anyNulls_ = true; + continue; + } + currentOffset += sumLengths(allLengths_, nulls, currentRow, row); + currentRow = row + 1; + nestedRowIndex = + advanceNestedRows(nestedRows_, nestedRowIndex, currentOffset); + currentOffset += allLengths_[row]; + const auto newNestedRowIndex = + advanceNestedRows(nestedRows_, nestedRowIndex, currentOffset); + for (auto j = nestedRowIndex; j < newNestedRowIndex; ++j) { + VELOX_CHECK(!decodedKeys_.isNullAt(j)); + auto it = keyToIndex_.find(decodedKeys_.valueAt(j)); + if (it == keyToIndex_.end()) { + continue; + } + copyRanges_[it->second].push_back({ + .sourceIndex = j, + .targetIndex = i, + .count = 1, + }); + } + nestedRowIndex = newNestedRowIndex; + } +} + } // namespace facebook::velox::dwio::common diff --git a/velox/dwio/common/SelectiveRepeatedColumnReader.h b/velox/dwio/common/SelectiveRepeatedColumnReader.h index f1792ff9511e..de072b8d028b 100644 --- a/velox/dwio/common/SelectiveRepeatedColumnReader.h +++ b/velox/dwio/common/SelectiveRepeatedColumnReader.h @@ -128,13 +128,9 @@ class SelectiveListColumnReader : public SelectiveRepeatedColumnReader { std::unique_ptr child_; }; -class SelectiveMapColumnReader : public SelectiveRepeatedColumnReader { +class SelectiveMapColumnReaderBase : public SelectiveRepeatedColumnReader { public: - SelectiveMapColumnReader( - const TypePtr& requestedType, - const std::shared_ptr& fileType, - FormatParams& params, - velox::common::ScanSpec& scanSpec); + using SelectiveRepeatedColumnReader::SelectiveRepeatedColumnReader; void resetFilterCaches() override { keyReader_->resetFilterCaches(); @@ -146,11 +142,53 @@ class SelectiveMapColumnReader : public SelectiveRepeatedColumnReader { void read(int64_t offset, const RowSet& rows, const uint64_t* incomingNulls) override; - void getValues(const RowSet& rows, VectorPtr* result) override; + void seekToRowGroup(int64_t index) override { + SelectiveRepeatedColumnReader::seekToRowGroup(index); + keyReader_->seekToRowGroup(index); + keyReader_->setReadOffsetRecursive(0); + elementReader_->seekToRowGroup(index); + elementReader_->setReadOffsetRecursive(0); + childTargetReadOffset_ = 0; + } protected: std::unique_ptr keyReader_; std::unique_ptr elementReader_; }; +class SelectiveMapColumnReader : public SelectiveMapColumnReaderBase { + public: + SelectiveMapColumnReader( + const TypePtr& requestedType, + const TypeWithIdPtr& fileType, + FormatParams& params, + ScanSpec& scanSpec); + + void getValues(const RowSet& rows, VectorPtr* result) override; +}; + +class SelectiveMapAsStructColumnReader : public SelectiveMapColumnReaderBase { + public: + SelectiveMapAsStructColumnReader( + const TypePtr& requestedType, + const TypeWithIdPtr& fileType, + FormatParams& params, + ScanSpec& scanSpec); + + void getValues(const RowSet& rows, VectorPtr* result) override; + + protected: + ScanSpec mapScanSpec_{""}; + + private: + template + void makeCopyRanges(const RowSet& rows); + + folly::F14FastMap keyToIndex_; + std::vector> copyRanges_; + VectorPtr mapKeys_; + VectorPtr mapValues_; + DecodedVector decodedKeys_; +}; + } // namespace facebook::velox::dwio::common diff --git a/velox/dwio/common/SelectiveStructColumnReader.cpp b/velox/dwio/common/SelectiveStructColumnReader.cpp index 4401e73977bb..e1a990fee89c 100644 --- a/velox/dwio/common/SelectiveStructColumnReader.cpp +++ b/velox/dwio/common/SelectiveStructColumnReader.cpp @@ -20,7 +20,6 @@ #include "velox/dwio/common/ColumnLoader.h" namespace facebook::velox::dwio::common { - namespace { bool testFilterOnConstant(const velox::common::ScanSpec& spec) { @@ -224,6 +223,74 @@ uint64_t SelectiveStructColumnReaderBase::skip(uint64_t numValues) { return numValues; } +void SelectiveStructColumnReaderBase::seekToPropagateNullsToChildren( + int64_t offset) { + if (offset == readOffset_) { + return; + } + if (readOffset_ < offset) { + if (numParentNulls_) { + VELOX_CHECK_LE( + parentNullsRecordedTo_, + offset, + "Must not seek to before parentNullsRecordedTo_"); + } + auto distance = offset - readOffset_ - numParentNulls_; + auto numNonNulls = formatData_->skipNulls(distance); + // We inform children how many nulls there were between original position + // and destination. The children will seek this many less. The + // nulls include the nulls found here as well as the enclosing + // level nulls reported to this by parents. + for (auto& child : children_) { + if (child) { + child->addSkippedParentNulls( + readOffset_, offset, numParentNulls_ + distance - numNonNulls); + } + } + numParentNulls_ = 0; + parentNullsRecordedTo_ = 0; + readOffset_ = offset; + } else { + VELOX_FAIL("Seeking backward on a ColumnReader"); + } +} + +void SelectiveStructColumnReaderBase::seekToRowGroupFixedRowsPerRowGroup( + const int64_t index, + const int32_t rowsPerRowGroup) { + dwio::common::SelectiveStructColumnReaderBase::seekToRowGroup(index); + if (isTopLevel_ && !formatData_->hasNulls()) { + readOffset_ = index * rowsPerRowGroup; + return; + } + + // There may be a nulls stream but no other streams for the struct. + formatData_->seekToRowGroup(index); + // Set the read offset recursively. Do this before seeking the children + // because list/map children will reset the offsets for their children. + setReadOffsetRecursive(index * rowsPerRowGroup); + for (auto& child : children_) { + child->seekToRowGroup(index); + } +} + +/// Advance field reader to the row group closest to specified offset by +/// calling seekToRowGroup. +void SelectiveStructColumnReaderBase::advanceFieldReaderFixedRowsPerRowGroup( + dwio::common::SelectiveColumnReader* reader, + const int64_t offset, + const int32_t rowsPerRowGroup) { + if (!reader->isTopLevel()) { + return; + } + const auto rowGroup = reader->readOffset() / rowsPerRowGroup; + const auto nextRowGroup = offset / rowsPerRowGroup; + if (nextRowGroup > rowGroup) { + reader->seekToRowGroup(nextRowGroup); + reader->setReadOffset(nextRowGroup * rowsPerRowGroup); + } +} + void SelectiveStructColumnReaderBase::fillOutputRowsFromMutation( vector_size_t size) { if (mutation_->deletedRows) { @@ -339,8 +406,7 @@ void SelectiveStructColumnReaderBase::read( activeRows = outputRows_; } - const uint64_t* structNulls = - nullsInReadRange_ ? nullsInReadRange_->as() : nullptr; + const uint64_t* structNulls = nulls(); // A struct reader may have a null/non-null filter if (scanSpec_->filter()) { const auto kind = scanSpec_->filter()->kind(); @@ -362,7 +428,6 @@ void SelectiveStructColumnReaderBase::read( VELOX_CHECK(!childSpecs.empty()); for (size_t i = 0; i < childSpecs.size(); ++i) { const auto& childSpec = childSpecs[i]; - VELOX_TRACE_HISTORY_PUSH("read %s", childSpec->fieldName().c_str()); if (childSpec->deltaUpdate()) { // Will make LazyVector. @@ -385,7 +450,7 @@ void SelectiveStructColumnReaderBase::read( const auto fieldIndex = childSpec->subscript(); auto* reader = children_.at(fieldIndex); if (reader->isTopLevel() && childSpec->projectOut() && - !childSpec->hasFilter()) { + !childSpec->hasFilter() && generateLazyChildren_) { // Will make a LazyVector. continue; } @@ -441,10 +506,7 @@ void SelectiveStructColumnReaderBase::recordParentNullsInChildren( const auto fieldIndex = childSpec->subscript(); auto* reader = children_.at(fieldIndex); - reader->addParentNulls( - offset, - nullsInReadRange_ ? nullsInReadRange_->as() : nullptr, - rows); + reader->addParentNulls(offset, nulls(), rows); } } @@ -467,6 +529,12 @@ bool SelectiveStructColumnReaderBase::isChildMissing( childSpec.channel() >= fileType_->size()); } +std::unique_ptr +SelectiveStructColumnReaderBase::makeColumnLoader(vector_size_t index) { + return std::make_unique( + this, children_[index], numReads_); +} + void SelectiveStructColumnReaderBase::getValues( const RowSet& rows, VectorPtr* result) { @@ -486,7 +554,6 @@ void SelectiveStructColumnReaderBase::getValues( setComplexNulls(rows, *result); for (const auto& childSpec : scanSpec_->children()) { - VELOX_TRACE_HISTORY_PUSH("getValues %s", childSpec->fieldName().c_str()); if (!childSpec->keepValues()) { continue; } @@ -543,7 +610,8 @@ void SelectiveStructColumnReaderBase::getValues( continue; } - if (childSpec->hasFilter() || !children_[index]->isTopLevel()) { + if (childSpec->hasFilter() || !children_[index]->isTopLevel() || + !generateLazyChildren_) { children_[index]->getValues(rows, &childResult); continue; } @@ -551,7 +619,7 @@ void SelectiveStructColumnReaderBase::getValues( // LazyVector result. setOutputRowsForLazy(rows); setLazyField( - std::make_unique(this, children_[index], numReads_), + makeColumnLoader(index), resultRow->type()->childAt(channel), rows.size(), memoryPool_, diff --git a/velox/dwio/common/SelectiveStructColumnReader.h b/velox/dwio/common/SelectiveStructColumnReader.h index 6aa21cae9c4f..5b8e8f720475 100644 --- a/velox/dwio/common/SelectiveStructColumnReader.h +++ b/velox/dwio/common/SelectiveStructColumnReader.h @@ -20,6 +20,8 @@ namespace facebook::velox::dwio::common { +class ColumnLoader; + template class SelectiveFlatMapColumnReaderHelper; @@ -108,18 +110,20 @@ class SelectiveStructColumnReaderBase : public SelectiveColumnReader { // The subscript of childSpecs will be set to this value if the column is // constant (either explicitly or because it's missing). - static constexpr int32_t kConstantChildSpecSubscript = -1; + static constexpr int32_t kConstantChildSpecSubscript{-1}; SelectiveStructColumnReaderBase( const TypePtr& requestedType, const std::shared_ptr& fileType, FormatParams& params, velox::common::ScanSpec& scanSpec, - bool isRoot = false) + bool isRoot = false, + bool generateLazyChildren = true) : SelectiveColumnReader(requestedType, fileType, params, scanSpec), debugString_( getExceptionContext().message(VeloxException::Type::kSystem)), isRoot_(isRoot), + generateLazyChildren_(generateLazyChildren), rows_(memoryPool_) {} bool hasDeletion() const final { @@ -136,17 +140,40 @@ class SelectiveStructColumnReaderBase : public SelectiveColumnReader { isChildMissing(childSpec); } - std::vector children_; - - private: - void fillOutputRowsFromMutation(vector_size_t size); - /// Records the number of nulls added by 'this' between the end position of /// each child reader and the end of the range of 'read(). This must be done /// also if a child is not read so that we know how much to skip when seeking /// forward within the row group. void recordParentNullsInChildren(int64_t offset, const RowSet& rows); + /// An implementation of seekTo that calls addSkippedParentNulls on each + /// child. Available as a helper function to formats that need it. + void seekToPropagateNullsToChildren(const int64_t offset); + + /// A helper function that implements seekToRowGroup for formats that support + /// a fixed number of rows per row group. + void seekToRowGroupFixedRowsPerRowGroup( + const int64_t index, + const int32_t rowsPerRowGroup); + + /// A helper function that implements advanceFieldReader for formats that + /// support a fixed number of rows per row group + void advanceFieldReaderFixedRowsPerRowGroup( + SelectiveColumnReader* reader, + const int64_t offset, + const int32_t rowsPerRowGroup); + + virtual std::unique_ptr makeColumnLoader( + vector_size_t index); + + // Sequence number of output batch. Checked against ColumnLoaders + // created by 'this' to verify they are still valid at load. + uint64_t numReads_ = 0; + std::vector children_; + + private: + void fillOutputRowsFromMutation(vector_size_t size); + void setOutputRowsForLazy(const RowSet& rows) { if (useOutputRows() && rows.size() != outputRows_.size()) { setOutputRows(rows); @@ -164,16 +191,15 @@ class SelectiveStructColumnReaderBase : public SelectiveColumnReader { // table. const bool isRoot_; + // Whether or not this should produce lazy vectors for children. + const bool generateLazyChildren_; + // Dense set of rows to read in next(). raw_vector rows_; - // Sequence number of output batch. Checked against ColumnLoaders - // created by 'this' to verify they are still valid at load. - uint64_t numReads_ = 0; - int64_t lazyVectorReadOffset_; - int64_t currentRowNumber_ = -1; + int64_t currentRowNumber_{-1}; const Mutation* mutation_ = nullptr; diff --git a/velox/dwio/common/SortingWriter.cpp b/velox/dwio/common/SortingWriter.cpp index 00bab4ee5360..d67efd6ec223 100644 --- a/velox/dwio/common/SortingWriter.cpp +++ b/velox/dwio/common/SortingWriter.cpp @@ -142,8 +142,9 @@ vector_size_t SortingWriter::outputBatchRows() { std::numeric_limits::max(); if (sortBuffer_->estimateOutputRowSize().has_value() && sortBuffer_->estimateOutputRowSize().value() != 0) { - const uint64_t maxOutputRows = - maxOutputBytesConfig_ / sortBuffer_->estimateOutputRowSize().value(); + const auto maxOutputRows = std::max( + static_cast(1), + maxOutputBytesConfig_ / sortBuffer_->estimateOutputRowSize().value()); if (UNLIKELY(maxOutputRows > std::numeric_limits::max())) { return maxOutputRowsConfig_; } diff --git a/velox/dwio/common/Statistics.h b/velox/dwio/common/Statistics.h index 1c6965d6d71c..95beba497df9 100644 --- a/velox/dwio/common/Statistics.h +++ b/velox/dwio/common/Statistics.h @@ -18,9 +18,11 @@ #include #include +#include "velox/dwio/common/UnitLoader.h" #include "velox/common/base/Exceptions.h" #include "velox/common/base/RuntimeMetrics.h" +#include "velox/common/io/IoStatistics.h" #include "velox/dwio/common/exception/Exception.h" namespace facebook::velox::dwio::common { @@ -493,10 +495,11 @@ class MapColumnStatistics : public virtual ColumnStatistics { values.reserve(entryStatistics_.size()); for (const auto& entry : entryStatistics_) { auto& stats = *entry.second; - values.push_back(fmt::format( - "{{ Key: {}, Stats: {},}}", - entry.first.toString(), - stats.toString())); + values.push_back( + fmt::format( + "{{ Key: {}, Stats: {},}}", + entry.first.toString(), + stats.toString())); } std::string repr; folly::join(",", values, repr); @@ -536,6 +539,9 @@ struct ColumnReaderStatistics { // Number of rows returned by string dictionary reader that is flattened // instead of keeping dictionary encoding. int64_t flattenStringDictionaryValues{0}; + + // Total time spent in loading pages, in nanoseconds. + io::IoCounter pageLoadTimeNs; }; struct RuntimeStatistics { @@ -558,39 +564,53 @@ struct RuntimeStatistics { int64_t numStripes{0}; + UnitLoaderStats unitLoaderStats; ColumnReaderStatistics columnReaderStatistics; - std::unordered_map toMap() { - std::unordered_map result; + std::unordered_map toRuntimeMetricMap() { + std::unordered_map result; + for (const auto& [name, metric] : unitLoaderStats.stats()) { + result.emplace(name, RuntimeMetric(metric.sum, metric.unit)); + } if (skippedSplits > 0) { - result.emplace("skippedSplits", RuntimeCounter(skippedSplits)); + result.emplace("skippedSplits", RuntimeMetric(skippedSplits)); } if (processedSplits > 0) { - result.emplace("processedSplits", RuntimeCounter(processedSplits)); + result.emplace("processedSplits", RuntimeMetric(processedSplits)); } if (skippedSplitBytes > 0) { result.emplace( "skippedSplitBytes", - RuntimeCounter(skippedSplitBytes, RuntimeCounter::Unit::kBytes)); + RuntimeMetric(skippedSplitBytes, RuntimeCounter::Unit::kBytes)); } if (skippedStrides > 0) { - result.emplace("skippedStrides", RuntimeCounter(skippedStrides)); + result.emplace("skippedStrides", RuntimeMetric(skippedStrides)); } if (processedStrides > 0) { - result.emplace("processedStrides", RuntimeCounter(processedStrides)); + result.emplace("processedStrides", RuntimeMetric(processedStrides)); } if (footerBufferOverread > 0) { result.emplace( "footerBufferOverread", - RuntimeCounter(footerBufferOverread, RuntimeCounter::Unit::kBytes)); + RuntimeMetric(footerBufferOverread, RuntimeCounter::Unit::kBytes)); } if (numStripes > 0) { - result.emplace("numStripes", RuntimeCounter(numStripes)); + result.emplace("numStripes", RuntimeMetric(numStripes)); } if (columnReaderStatistics.flattenStringDictionaryValues > 0) { result.emplace( "flattenStringDictionaryValues", - RuntimeCounter(columnReaderStatistics.flattenStringDictionaryValues)); + RuntimeMetric(columnReaderStatistics.flattenStringDictionaryValues)); + } + if (columnReaderStatistics.pageLoadTimeNs.sum() > 0) { + result.emplace( + "pageLoadTimeNs", + RuntimeMetric( + columnReaderStatistics.pageLoadTimeNs.sum(), + columnReaderStatistics.pageLoadTimeNs.count(), + columnReaderStatistics.pageLoadTimeNs.min(), + columnReaderStatistics.pageLoadTimeNs.max(), + RuntimeCounter::Unit::kNanos)); } return result; } diff --git a/velox/dwio/common/StreamUtil.h b/velox/dwio/common/StreamUtil.h index 0a0a2b2b0e86..77b0d429ab76 100644 --- a/velox/dwio/common/StreamUtil.h +++ b/velox/dwio/common/StreamUtil.h @@ -144,7 +144,7 @@ inline void readContiguous( // Returns the number of elements in rows that are < limit. inline int32_t numBelow(folly::Range rows, int32_t limit) { - return std::lower_bound(rows.begin(), rows.end(), limit) - rows.begin(); + return std::lower_bound(rows.cbegin(), rows.cend(), limit) - rows.cbegin(); } template diff --git a/velox/dwio/common/TypeUtils.cpp b/velox/dwio/common/TypeUtils.cpp index 29e22046196f..b0220b9acdc2 100644 --- a/velox/dwio/common/TypeUtils.cpp +++ b/velox/dwio/common/TypeUtils.cpp @@ -131,11 +131,13 @@ void checkTypeCompatibility( const FShouldRead& shouldRead, const std::function& exceptionMessageCreator) { if (shouldRead(to) && !isCompatible(from.kind(), kind(to))) { - VELOX_SCHEMA_MISMATCH_ERROR(fmt::format( - "{}, From Kind: {}, To Kind: {}", - exceptionMessageCreator ? exceptionMessageCreator() : "Schema mismatch", - mapTypeKindToName(from.kind()), - mapTypeKindToName(kind(to)))); + VELOX_SCHEMA_MISMATCH_ERROR( + fmt::format( + "{}, From Kind: {}, To Kind: {}", + exceptionMessageCreator ? exceptionMessageCreator() + : "Schema mismatch", + TypeKindName::toName(from.kind()), + TypeKindName::toName(kind(to)))); } if (recurse) { diff --git a/velox/dwio/common/TypeWithId.h b/velox/dwio/common/TypeWithId.h index a147cfe5066f..80084dbfb535 100644 --- a/velox/dwio/common/TypeWithId.h +++ b/velox/dwio/common/TypeWithId.h @@ -97,4 +97,6 @@ class TypeWithId : public velox::Tree> { const std::vector> children_; }; +using TypeWithIdPtr = std::shared_ptr; + } // namespace facebook::velox::dwio::common diff --git a/velox/dwio/common/UnitLoader.h b/velox/dwio/common/UnitLoader.h index d3125dacc4be..d1fc54ab2407 100644 --- a/velox/dwio/common/UnitLoader.h +++ b/velox/dwio/common/UnitLoader.h @@ -16,9 +16,13 @@ #pragma once +#include +#include #include #include #include +#include "velox/common/base/Exceptions.h" +#include "velox/common/base/RuntimeMetrics.h" namespace facebook::velox::dwio::common { @@ -39,6 +43,44 @@ class LoadUnit { virtual uint64_t getIoSize() = 0; }; +class UnitLoaderStats { + public: + UnitLoaderStats() = default; + + void addCounter(const std::string& name, RuntimeCounter counter) { + auto locked = stats_.wlock(); + auto it = locked->find(name); + if (it == locked->end()) { + auto [ptr, inserted] = locked->emplace(name, RuntimeMetric(counter.unit)); + VELOX_CHECK(inserted); + ptr->second.addValue(counter.value); + } else { + VELOX_CHECK_EQ(it->second.unit, counter.unit); + it->second.addValue(counter.value); + } + } + + void merge(const UnitLoaderStats& other) { + auto otherStats = other.stats(); + auto locked = stats_.wlock(); + for (const auto& [name, metric] : otherStats) { + auto it = locked->find(name); + if (it == locked->end()) { + locked->emplace(name, metric); + } else { + it->second.merge(metric); + } + } + } + + folly::F14FastMap stats() const { + return stats_.copy(); + } + + private: + folly::Synchronized> stats_; +}; + class UnitLoader { public: virtual ~UnitLoader() = default; @@ -56,6 +98,10 @@ class UnitLoader { /// Reader reports seek calling this method. The call must be done **before** /// getLoadedUnit for the new unit. virtual void onSeek(uint32_t unit, uint64_t rowOffsetInUnit) = 0; + + virtual UnitLoaderStats stats() { + return UnitLoaderStats(); + }; }; class UnitLoaderFactory { diff --git a/velox/dwio/common/Writer.cpp b/velox/dwio/common/Writer.cpp index 87951cad7c59..10384341fc2b 100644 --- a/velox/dwio/common/Writer.cpp +++ b/velox/dwio/common/Writer.cpp @@ -80,6 +80,13 @@ bool Writer::isFinishing() const { } void Writer::checkRunning() const { + // Typically represents writer misuse. + VELOX_USER_CHECK_NE( + state_, + State::kClosed, + "Writer is not running: {}. Write operations are not allowed on a closed writer.", + Writer::stateString(state_)); + VELOX_CHECK_EQ( state_, State::kRunning, diff --git a/velox/dwio/common/compression/CMakeLists.txt b/velox/dwio/common/compression/CMakeLists.txt index e1f1d0ac60cf..ed39366b3a54 100644 --- a/velox/dwio/common/compression/CMakeLists.txt +++ b/velox/dwio/common/compression/CMakeLists.txt @@ -12,12 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -velox_add_library(velox_dwio_common_compression Compression.cpp - PagedInputStream.cpp PagedOutputStream.cpp) +velox_add_library( + velox_dwio_common_compression + Compression.cpp + PagedInputStream.cpp + PagedOutputStream.cpp +) velox_link_libraries( velox_dwio_common_compression velox_dwio_common xsimd Folly::folly - Snappy::snappy) + Snappy::snappy +) diff --git a/velox/dwio/common/compression/Compression.cpp b/velox/dwio/common/compression/Compression.cpp index 1b66c17df041..376f2994b9f6 100644 --- a/velox/dwio/common/compression/Compression.cpp +++ b/velox/dwio/common/compression/Compression.cpp @@ -35,6 +35,8 @@ using memory::MemoryPool; namespace { +constexpr int kGzipCodec = 16; + class ZstdCompressor : public Compressor { public: explicit ZstdCompressor(int32_t level) : Compressor{level} {} @@ -57,7 +59,7 @@ ZstdCompressor::compress(const void* src, void* dest, uint64_t length) { class ZlibCompressor : public Compressor { public: - explicit ZlibCompressor(int32_t level); + explicit ZlibCompressor(int32_t level, int32_t windowBits, bool isGzip); ~ZlibCompressor() override; @@ -68,13 +70,17 @@ class ZlibCompressor : public Compressor { z_stream stream_; }; -ZlibCompressor::ZlibCompressor(int32_t level) +ZlibCompressor::ZlibCompressor(int32_t level, int32_t windowBits, bool isGzip) : Compressor{level}, isCompressCalled_{false} { stream_.zalloc = Z_NULL; stream_.zfree = Z_NULL; stream_.opaque = Z_NULL; + if (isGzip) { + windowBits = (windowBits < 0 ? -windowBits : windowBits) | kGzipCodec; + } DWIO_ENSURE_EQ( - deflateInit2(&stream_, level_, Z_DEFLATED, -15, 8, Z_DEFAULT_STRATEGY), + deflateInit2( + &stream_, level_, Z_DEFLATED, windowBits, 8, Z_DEFAULT_STRATEGY), Z_OK, "Error while calling deflateInit2() for zlib."); } @@ -150,9 +156,9 @@ ZlibDecompressor::ZlibDecompressor( zstream_.next_out = Z_NULL; zstream_.avail_out = folly::to(blockSize); int zlibWindowBits = windowBits; - constexpr int GZIP_DETECT_CODE = 32; if (isGzip) { - zlibWindowBits = zlibWindowBits | GZIP_DETECT_CODE; + zlibWindowBits = + (zlibWindowBits < 0 ? -zlibWindowBits : zlibWindowBits) | kGzipCodec; } const auto result = inflateInit2(&zstream_, zlibWindowBits); DWIO_ENSURE_EQ( @@ -392,9 +398,6 @@ uint64_t Lz4Decompressor::decompressInternal( return static_cast(result); } -// NOTE: We do not keep `ZSTD_DCtx' around on purpose, because if we keep it -// around, in flat map column reader we have hundreds of thousands of -// decompressors at same time and causing OOM. class ZstdDecompressor : public Decompressor { public: explicit ZstdDecompressor( @@ -418,7 +421,10 @@ uint64_t ZstdDecompressor::decompress( uint64_t srcLength, char* dest, uint64_t destLength) { - auto ret = ZSTD_decompress(dest, destLength, src, srcLength); + // Reuse 'ZSTD_DCtx' per-thread to avoid repeated allocations. + thread_local std::unique_ptr ctx{ + ZSTD_createDCtx(), ZSTD_freeDCtx}; + auto ret = ZSTD_decompressDCtx(ctx.get(), dest, destLength, src, srcLength); DWIO_ENSURE( !ZSTD_isError(ret), "ZSTD returned an error: ", @@ -549,6 +555,8 @@ bool ZlibDecompressionStream::readOrSkip(const void** data, int32_t* size) { *size = static_cast(availSize); outputBufferPtr_ = inputBufferPtr_ + availSize; outputBufferLength_ = 0; + inputBufferPtr_ += availSize; + remainingLength_ -= availSize; } else { DWIO_ENSURE_EQ( state_, @@ -561,42 +569,49 @@ bool ZlibDecompressionStream::readOrSkip(const void** data, int32_t* size) { getDecompressedLength(inputBufferPtr_, availSize).first); reset(); - zstream_.next_in = - reinterpret_cast(const_cast(inputBufferPtr_)); - zstream_.avail_in = folly::to(availSize); - outputBufferPtr_ = outputBuffer_->data(); - zstream_.next_out = - reinterpret_cast(const_cast(outputBufferPtr_)); - zstream_.avail_out = folly::to(blockSize_); int32_t result; + *size = 0; do { - result = inflate( - &zstream_, availSize == remainingLength_ ? Z_FINISH : Z_SYNC_FLUSH); - switch (result) { - case Z_OK: - remainingLength_ -= availSize; - inputBufferPtr_ += availSize; - readBuffer(true); - availSize = std::min( - static_cast(inputBufferPtrEnd_ - inputBufferPtr_), - remainingLength_); - zstream_.next_in = - reinterpret_cast(const_cast(inputBufferPtr_)); - zstream_.avail_in = static_cast(availSize); - break; - case Z_STREAM_END: - break; - default: - DWIO_RAISE( - "Error in ZlibDecompressionStream::Next in ", - getName(), - ". error: ", - result, - " Info: ", - ZlibDecompressor::streamDebugInfo_); + if (inputBufferPtr_ == inputBufferPtrEnd_) { + readBuffer(true); } + zstream_.next_in = + reinterpret_cast(const_cast(inputBufferPtr_)); + zstream_.avail_in = + static_cast(inputBufferPtrEnd_ - inputBufferPtr_); + + do { + // size_ of outputBuffer_ is not updated in inflate, so *size is used + // here to ensure enough capacity for the output data. + outputBuffer_->extend(*size); + outputBufferPtr_ = outputBuffer_->data(); + zstream_.next_out = reinterpret_cast( + const_cast(outputBufferPtr_ + *size)); + zstream_.avail_out = folly::to(blockSize_); + result = inflate(&zstream_, Z_SYNC_FLUSH); + // Result handling adapted from https://zlib.net/zlib_how.html + switch (result) { + case Z_NEED_DICT: + result = Z_DATA_ERROR; + [[fallthrough]]; + case Z_DATA_ERROR: + [[fallthrough]]; + case Z_MEM_ERROR: + [[fallthrough]]; + case Z_STREAM_ERROR: + DWIO_RAISE("Failed to inflate input data. error: ", result); + default: + *size += static_cast( + blockSize_ - static_cast(zstream_.avail_out)); + const size_t inputConsumed = + reinterpret_cast(zstream_.next_in) - + inputBufferPtr_; + remainingLength_ -= inputConsumed; + inputBufferPtr_ += inputConsumed; + } + } while (zstream_.avail_out == 0); } while (result != Z_STREAM_END); - *size = static_cast(blockSize_ - zstream_.avail_out); + if (data) { *data = outputBufferPtr_; } @@ -604,8 +619,6 @@ bool ZlibDecompressionStream::readOrSkip(const void** data, int32_t* size) { outputBufferPtr_ += *size; } - inputBufferPtr_ += availSize; - remainingLength_ -= availSize; bytesReturned_ += *size; return true; } @@ -623,7 +636,18 @@ std::unique_ptr createCompressor( "Initialized zlib compressor with compression level {}", options.format.zlib.compressionLevel); return std::make_unique( + options.format.zlib.compressionLevel, + options.format.zlib.windowBits, + false); + } + case CompressionKind::CompressionKind_GZIP: { + XLOG_FIRST_N(INFO, 1) << fmt::format( + "Initialized zlib compressor with compression level {}", options.format.zlib.compressionLevel); + return std::make_unique( + options.format.zlib.compressionLevel, + options.format.zlib.windowBits, + true); } case CompressionKind::CompressionKind_ZSTD: { XLOG_FIRST_N(INFO, 1) << fmt::format( diff --git a/velox/dwio/common/compression/PagedInputStream.cpp b/velox/dwio/common/compression/PagedInputStream.cpp index 08357298cf57..66aa8ffb156d 100644 --- a/velox/dwio/common/compression/PagedInputStream.cpp +++ b/velox/dwio/common/compression/PagedInputStream.cpp @@ -165,7 +165,7 @@ bool PagedInputStream::readOrSkip(const void** data, int32_t* size) { // perform decryption if (decrypter_) { decryptionBuffer_ = - decrypter_->decrypt(folly::StringPiece{input, remainingLength_}); + decrypter_->decrypt(std::string_view{input, remainingLength_}); input = reinterpret_cast(decryptionBuffer_->data()); remainingLength_ = decryptionBuffer_->length(); if (data) { diff --git a/velox/dwio/common/compression/PagedInputStream.h b/velox/dwio/common/compression/PagedInputStream.h index 15b1acd630a0..39d48cbb6869 100644 --- a/velox/dwio/common/compression/PagedInputStream.h +++ b/velox/dwio/common/compression/PagedInputStream.h @@ -56,7 +56,7 @@ class PagedInputStream : public dwio::common::SeekableInputStream { // NOTE: This always returns true. bool SkipInt64(int64_t count) override; - google::protobuf::int64 ByteCount() const override { + int64_t ByteCount() const override { return bytesReturned_ + pendingSkip_; } diff --git a/velox/dwio/common/compression/PagedOutputStream.cpp b/velox/dwio/common/compression/PagedOutputStream.cpp index 18d993bf74f0..2519893be4bf 100644 --- a/velox/dwio/common/compression/PagedOutputStream.cpp +++ b/velox/dwio/common/compression/PagedOutputStream.cpp @@ -18,7 +18,7 @@ namespace facebook::velox::dwio::common::compression { -std::vector PagedOutputStream::createPage() { +std::vector PagedOutputStream::createPage() { auto origSize = buffer_.size(); VELOX_CHECK_GT(origSize, pageHeaderSize_); origSize -= pageHeaderSize_; @@ -34,15 +34,15 @@ std::vector PagedOutputStream::createPage() { origSize); } - folly::StringPiece compressed; + std::string_view compressed; if (compressedSize >= origSize) { // write orig writeHeader(buffer_.data(), origSize, true); - compressed = folly::StringPiece(buffer_.data(), origSize + pageHeaderSize_); + compressed = std::string_view(buffer_.data(), origSize + pageHeaderSize_); } else { // write compressed writeHeader(compressionBuffer_->data(), compressedSize, false); - compressed = folly::StringPiece( + compressed = std::string_view( compressionBuffer_->data(), compressedSize + pageHeaderSize_); } @@ -50,13 +50,13 @@ std::vector PagedOutputStream::createPage() { return {compressed}; } - encryptionBuffer_ = encryptor_->encrypt(folly::StringPiece( - compressed.begin() + pageHeaderSize_, compressed.end())); + encryptionBuffer_ = encryptor_->encrypt( + std::string_view(compressed.begin() + pageHeaderSize_, compressed.end())); updateSize( const_cast(compressed.begin()), encryptionBuffer_->length()); return { - folly::StringPiece(compressed.begin(), pageHeaderSize_), - folly::StringPiece( + std::string_view(compressed.begin(), pageHeaderSize_), + std::string_view( reinterpret_cast(encryptionBuffer_->data()), encryptionBuffer_->length())}; } diff --git a/velox/dwio/common/compression/PagedOutputStream.h b/velox/dwio/common/compression/PagedOutputStream.h index 498bba3c781c..415d2961ecad 100644 --- a/velox/dwio/common/compression/PagedOutputStream.h +++ b/velox/dwio/common/compression/PagedOutputStream.h @@ -64,8 +64,8 @@ class PagedOutputStream : public BufferedOutputStream { int32_t strideIndex = -1) const override; private: - // create page using compressor and encryptor - std::vector createPage(); + // Create page using compressor and encryptor. + std::vector createPage(); void writeHeader(char* buffer, size_t compressedSize, bool original); diff --git a/velox/dwio/common/encryption/Encryption.cpp b/velox/dwio/common/encryption/Encryption.cpp index 2e7df9a47932..1db5b8c24dee 100644 --- a/velox/dwio/common/encryption/Encryption.cpp +++ b/velox/dwio/common/encryption/Encryption.cpp @@ -16,19 +16,11 @@ #include "velox/dwio/common/encryption/Encryption.h" -namespace facebook { -namespace velox { -namespace dwio { -namespace common { -namespace encryption { +namespace facebook::velox::dwio::common::encryption { bool operator==(const EncryptionProperties& a, const EncryptionProperties& b) { return std::addressof(a) == std::addressof(b) || (typeid(a) == typeid(b) && a.equals(b)); } -} // namespace encryption -} // namespace common -} // namespace dwio -} // namespace velox -} // namespace facebook +} // namespace facebook::velox::dwio::common::encryption diff --git a/velox/dwio/common/encryption/Encryption.h b/velox/dwio/common/encryption/Encryption.h index db3c2196d954..798f67d3c9fa 100644 --- a/velox/dwio/common/encryption/Encryption.h +++ b/velox/dwio/common/encryption/Encryption.h @@ -20,11 +20,7 @@ #include "folly/io/IOBuf.h" #include "velox/dwio/common/exception/Exception.h" -namespace facebook { -namespace velox { -namespace dwio { -namespace common { -namespace encryption { +namespace facebook::velox::dwio::common::encryption { enum class EncryptionProvider { Unknown = 0, CryptoService }; @@ -64,7 +60,7 @@ class Encrypter { virtual const std::string& getKey() const = 0; virtual std::unique_ptr encrypt( - folly::StringPiece input) const = 0; + std::string_view input) const = 0; virtual std::unique_ptr clone() const = 0; }; @@ -87,7 +83,7 @@ class Decrypter { virtual bool isKeyLoaded() const = 0; virtual std::unique_ptr decrypt( - folly::StringPiece input) const = 0; + std::string_view input) const = 0; virtual std::unique_ptr clone() const = 0; }; @@ -108,7 +104,7 @@ class DummyDecrypter : public Decrypter { } std::unique_ptr decrypt( - folly::StringPiece /* unused */) const override { + std::string_view /* unused */) const override { DWIO_RAISE("Failed to access encrypted data"); } @@ -127,8 +123,4 @@ class DummyDecrypterFactory : public DecrypterFactory { } }; -} // namespace encryption -} // namespace common -} // namespace dwio -} // namespace velox -} // namespace facebook +} // namespace facebook::velox::dwio::common::encryption diff --git a/velox/dwio/common/encryption/TestProvider.h b/velox/dwio/common/encryption/TestProvider.h index 7ef41de8e574..b50f727a1e92 100644 --- a/velox/dwio/common/encryption/TestProvider.h +++ b/velox/dwio/common/encryption/TestProvider.h @@ -20,12 +20,7 @@ #include "velox/dwio/common/encryption/Encryption.h" #include "velox/dwio/common/exception/Exception.h" -namespace facebook { -namespace velox { -namespace dwio { -namespace common { -namespace encryption { -namespace test { +namespace facebook::velox::dwio::common::encryption::test { class TestEncryption { public: @@ -37,18 +32,19 @@ class TestEncryption { return key_; } - std::unique_ptr encrypt(folly::StringPiece input) const { + std::unique_ptr encrypt(std::string_view input) const { ++count_; auto encoded = velox::encoding::Base64::encodeUrl(input); return folly::IOBuf::copyBuffer(key_ + encoded); } - std::unique_ptr decrypt(folly::StringPiece input) const { + std::unique_ptr decrypt(std::string_view input) const { ++count_; - std::string key{input.begin(), key_.size()}; + std::string key{input.cbegin(), key_.size()}; DWIO_ENSURE_EQ(key_, key); - auto decoded = velox::encoding::Base64::decodeUrl(folly::StringPiece{ - input.begin() + key_.size(), input.size() - key_.size()}); + auto decoded = velox::encoding::Base64::decodeUrl( + std::string_view{ + input.begin() + key_.size(), input.size() - key_.size()}); return folly::IOBuf::copyBuffer(decoded); } @@ -67,8 +63,7 @@ class TestEncrypter : public TestEncryption, public Encrypter { return TestEncryption::getKey(); } - std::unique_ptr encrypt( - folly::StringPiece input) const override { + std::unique_ptr encrypt(std::string_view input) const override { return TestEncryption::encrypt(input); } @@ -89,8 +84,7 @@ class TestDecrypter : public TestEncryption, public Decrypter { return !getKey().empty(); } - std::unique_ptr decrypt( - folly::StringPiece input) const override { + std::unique_ptr decrypt(std::string_view input) const override { return TestEncryption::decrypt(input); } @@ -144,9 +138,4 @@ class TestDecrypterFactory : public DecrypterFactory { } }; -} // namespace test -} // namespace encryption -} // namespace common -} // namespace dwio -} // namespace velox -} // namespace facebook +} // namespace facebook::velox::dwio::common::encryption::test diff --git a/velox/dwio/common/exception/CMakeLists.txt b/velox/dwio/common/exception/CMakeLists.txt index 43d2d8a067b2..c30b1577c29f 100644 --- a/velox/dwio/common/exception/CMakeLists.txt +++ b/velox/dwio/common/exception/CMakeLists.txt @@ -14,5 +14,4 @@ velox_add_library(velox_dwio_common_exception Exception.cpp Exceptions.cpp) -velox_link_libraries(velox_dwio_common_exception velox_exception Folly::folly - glog::glog) +velox_link_libraries(velox_dwio_common_exception velox_exception Folly::folly glog::glog) diff --git a/velox/dwio/common/exception/Exception.cpp b/velox/dwio/common/exception/Exception.cpp index 97b0b78fd3cf..e15f503d3ee0 100644 --- a/velox/dwio/common/exception/Exception.cpp +++ b/velox/dwio/common/exception/Exception.cpp @@ -17,11 +17,7 @@ #include "velox/dwio/common/exception/Exception.h" #include "velox/common/base/Exceptions.h" -namespace facebook { -namespace velox { -namespace dwio { -namespace common { -namespace exception { +namespace facebook::velox::dwio::common::exception { std::unique_ptr& exceptionLogger() { static std::unique_ptr logger(nullptr); @@ -40,8 +36,4 @@ ExceptionLogger* getExceptionLogger() { return exceptionLogger().get(); } -} // namespace exception -} // namespace common -} // namespace dwio -} // namespace velox -} // namespace facebook +} // namespace facebook::velox::dwio::common::exception diff --git a/velox/dwio/common/exception/Exception.h b/velox/dwio/common/exception/Exception.h index 814d25b728e6..140df90e1959 100644 --- a/velox/dwio/common/exception/Exception.h +++ b/velox/dwio/common/exception/Exception.h @@ -18,11 +18,8 @@ #include "velox/common/base/VeloxException.h" -namespace facebook { -namespace velox { -namespace dwio { -namespace common { -namespace exception { +namespace facebook::velox::dwio { +namespace common::exception { class ExceptionLogger { public: @@ -52,7 +49,7 @@ class LoggedException : public velox::VeloxException { explicit LoggedException( const std::string& errorMessage, const std::string& errorSource = - ::facebook::velox::error_source::kErrorSourceRuntime, + ::facebook::velox::error_source::kErrorSourceExternal, const std::string& errorCode = ::facebook::velox::error_code::kUnknown, const bool isRetriable = false) : VeloxException( @@ -101,8 +98,7 @@ class LoggedException : public velox::VeloxException { } }; -} // namespace exception -} // namespace common +} // namespace common::exception #define DWIO_WARN_IF(e, ...) \ ({ \ @@ -181,7 +177,7 @@ containing information about the file, line, and function where it happened. #define DWIO_RAISE(...) \ DWIO_EXCEPTION_CUSTOM( \ facebook::velox::dwio::common::exception::LoggedException, \ - ::facebook::velox::error_source::kErrorSourceRuntime, \ + ::facebook::velox::error_source::kErrorSourceExternal, \ ::facebook::velox::error_code::kUnknown, \ ##__VA_ARGS__) @@ -189,7 +185,7 @@ containing information about the file, line, and function where it happened. DWIO_ENFORCE_CUSTOM( \ facebook::velox::dwio::common::exception::LoggedException, \ expr, \ - ::facebook::velox::error_source::kErrorSourceRuntime, \ + ::facebook::velox::error_source::kErrorSourceExternal, \ ::facebook::velox::error_code::kUnknown, \ ##__VA_ARGS__) @@ -256,6 +252,4 @@ containing information about the file, line, and function where it happened. "]: ", \ ##__VA_ARGS__); -} // namespace dwio -} // namespace velox -} // namespace facebook +} // namespace facebook::velox::dwio diff --git a/velox/dwio/common/exception/Exceptions.cpp b/velox/dwio/common/exception/Exceptions.cpp index 04835dc5ac1a..6c1bc5818e3e 100644 --- a/velox/dwio/common/exception/Exceptions.cpp +++ b/velox/dwio/common/exception/Exceptions.cpp @@ -22,10 +22,7 @@ #include #include -namespace facebook { -namespace velox { -namespace dwio { -namespace common { +namespace facebook::velox::dwio::common { void verify_range(uint64_t v, uint64_t rangeMask) { auto mv = (v & rangeMask); @@ -67,7 +64,4 @@ std::string format_error_string(std::string fmt...) { return s; } -} // namespace common -} // namespace dwio -} // namespace velox -} // namespace facebook +} // namespace facebook::velox::dwio::common diff --git a/velox/dwio/common/exception/Exceptions.h b/velox/dwio/common/exception/Exceptions.h index d6963468ea40..c241e5788f43 100644 --- a/velox/dwio/common/exception/Exceptions.h +++ b/velox/dwio/common/exception/Exceptions.h @@ -22,10 +22,7 @@ #include #include "velox/dwio/common/exception/Exception.h" -namespace facebook { -namespace velox { -namespace dwio { -namespace common { +namespace facebook::velox::dwio::common { class NotImplementedYet : public std::logic_error { public: @@ -117,7 +114,4 @@ using logic_error = exception_error; using runtime_error = exception_error; using EOF_error = exception_error; -} // namespace common -} // namespace dwio -} // namespace velox -} // namespace facebook +} // namespace facebook::velox::dwio::common diff --git a/velox/dwio/common/tests/CMakeLists.txt b/velox/dwio/common/tests/CMakeLists.txt index 30efa1be5125..873814eec418 100644 --- a/velox/dwio/common/tests/CMakeLists.txt +++ b/velox/dwio/common/tests/CMakeLists.txt @@ -14,8 +14,7 @@ add_subdirectory(utils) add_library(velox_dwio_faulty_file_sink FaultyFileSink.cpp) -target_link_libraries( - velox_dwio_faulty_file_sink velox_file_test_utils velox_dwio_common) +target_link_libraries(velox_dwio_faulty_file_sink velox_file_test_utils velox_dwio_common) # There is an issue with the VTT symbol for the InlineExecutor from folly when # building on Linux with Clang15. It is not created and results in a SEGV when @@ -26,10 +25,7 @@ target_link_libraries( # optimization level to 0 results in proper creation/linkage and successful # execution of the test. Review if this is still necessary when upgrading Clang. if(CMAKE_CXX_COMPILER_ID MATCHES "Clang") - set_property( - SOURCE ParallelForTest.cpp - APPEND - PROPERTY COMPILE_OPTIONS -O0) + set_property(SOURCE ParallelForTest.cpp APPEND PROPERTY COMPILE_OPTIONS -O0) endif() add_executable( @@ -42,6 +38,7 @@ add_executable( DecoderUtilTest.cpp ExecutorBarrierTest.cpp OnDemandUnitLoaderTests.cpp + ParallelUnitLoaderTest.cpp LocalFileSinkTest.cpp MemorySinkTest.cpp LoggedExceptionTest.cpp @@ -51,20 +48,23 @@ add_executable( ReadFileInputStreamTests.cpp ReaderTest.cpp RetryTests.cpp + ScanSpecTest.cpp + SortingWriterTest.cpp TestBufferedInput.cpp ThrottlerTest.cpp TypeTests.cpp UnitLoaderToolsTests.cpp WriterTest.cpp - OptionsTests.cpp) + OptionsTests.cpp +) add_test(velox_dwio_common_test velox_dwio_common_test) target_link_libraries( velox_dwio_common_test velox_dwio_common_test_utils velox_temp_path velox_vector_test_lib - Boost::regex velox_link_libs + velox_hive_connector Folly::folly ${TEST_LINK_LIBS} GTest::gtest @@ -72,11 +72,11 @@ target_link_libraries( GTest::gmock glog::glog fmt::fmt - protobuf::libprotobuf) + protobuf::libprotobuf +) if(VELOX_ENABLE_BENCHMARKS) - add_executable(velox_dwio_common_data_buffer_benchmark - DataBufferBenchmark.cpp) + add_executable(velox_dwio_common_data_buffer_benchmark DataBufferBenchmark.cpp) target_link_libraries( velox_dwio_common_data_buffer_benchmark @@ -84,25 +84,27 @@ if(VELOX_ENABLE_BENCHMARKS) velox_memory velox_dwio_common_exception Folly::folly - Folly::follybenchmark) + Folly::follybenchmark + ) - add_executable(velox_dwio_common_int_decoder_benchmark - IntDecoderBenchmark.cpp) + add_executable(velox_dwio_common_int_decoder_benchmark IntDecoderBenchmark.cpp) target_link_libraries( velox_dwio_common_int_decoder_benchmark velox_dwio_common_exception velox_exception velox_dwio_dwrf_common Folly::folly - Folly::follybenchmark) + Folly::follybenchmark + ) if(VELOX_ENABLE_ARROW) add_subdirectory(Lemire/FastPFor) - add_executable(velox_dwio_common_bitpack_decoder_benchmark - BitPackDecoderBenchmark.cpp) + add_executable(velox_dwio_common_bitpack_decoder_benchmark BitPackDecoderBenchmark.cpp) - target_compile_options(velox_dwio_common_bitpack_decoder_benchmark - PRIVATE -Wno-deprecated-declarations) + target_compile_options( + velox_dwio_common_bitpack_decoder_benchmark + PRIVATE -Wno-deprecated-declarations + ) target_link_libraries( velox_dwio_common_bitpack_decoder_benchmark @@ -111,6 +113,7 @@ if(VELOX_ENABLE_BENCHMARKS) velox_fastpforlib duckdb_static Folly::folly - Folly::follybenchmark) + Folly::follybenchmark + ) endif() endif() diff --git a/velox/dwio/common/tests/ChainedBufferTests.cpp b/velox/dwio/common/tests/ChainedBufferTests.cpp index 0612f4341ff9..43820bc0c77c 100644 --- a/velox/dwio/common/tests/ChainedBufferTests.cpp +++ b/velox/dwio/common/tests/ChainedBufferTests.cpp @@ -23,10 +23,7 @@ using namespace ::testing; -namespace facebook { -namespace velox { -namespace dwio { -namespace common { +namespace facebook::velox::dwio::common { class ChainedBufferTests : public Test { protected: @@ -64,8 +61,9 @@ TEST_F(ChainedBufferTests, testCreate) { TEST_F(ChainedBufferTests, testReserve) { for (const uint32_t initialCapacityBytes : {0, 16}) { - SCOPED_TRACE(fmt::format( - "initialCapacityBytes ", succinctBytes(initialCapacityBytes))); + SCOPED_TRACE( + fmt::format( + "initialCapacityBytes ", succinctBytes(initialCapacityBytes))); ChainedBuffer buf{*pool_, initialCapacityBytes, 1024}; ASSERT_EQ(buf.capacity(), initialCapacityBytes); ASSERT_EQ(buf.size(), 0); @@ -252,8 +250,9 @@ TEST_F(ChainedBufferTests, testTrailingZeros) { TEST_F(ChainedBufferTests, testClearAll) { for (const uint32_t initialCapacityBytes : {0, 128}) { - SCOPED_TRACE(fmt::format( - "initialCapacityBytes ", succinctBytes(initialCapacityBytes))); + SCOPED_TRACE( + fmt::format( + "initialCapacityBytes ", succinctBytes(initialCapacityBytes))); ChainedBuffer buf{*pool_, initialCapacityBytes, 1024}; ASSERT_EQ(buf.capacity(), initialCapacityBytes); ASSERT_EQ(buf.size(), 0); @@ -321,7 +320,5 @@ TEST_F(ChainedBufferTests, testClearAll) { ASSERT_EQ(buf.pages_.size(), 9); } } -} // namespace common -} // namespace dwio -} // namespace velox -} // namespace facebook + +} // namespace facebook::velox::dwio::common diff --git a/velox/dwio/common/tests/ColumnSelectorTests.cpp b/velox/dwio/common/tests/ColumnSelectorTests.cpp index f0176b3824dd..b65411d0e6fb 100644 --- a/velox/dwio/common/tests/ColumnSelectorTests.cpp +++ b/velox/dwio/common/tests/ColumnSelectorTests.cpp @@ -394,17 +394,19 @@ TEST(ColumnSelectorTests, testFlatMapKeyFilterAllowed) { } TEST(ColumnSelectorTests, testPartitionKeysMark) { - const auto schema = std::dynamic_pointer_cast( - HiveTypeParser().parse("struct<" - "id:bigint" - "memo:string" - "ds:string" - "key:string>")); - - const auto physicalSchema = std::dynamic_pointer_cast( - HiveTypeParser().parse("struct<" - "id:bigint" - "memo:string>")); + const auto schema = + std::dynamic_pointer_cast(HiveTypeParser().parse( + "struct<" + "id:bigint" + "memo:string" + "ds:string" + "key:string>")); + + const auto physicalSchema = + std::dynamic_pointer_cast(HiveTypeParser().parse( + "struct<" + "id:bigint" + "memo:string>")); // use schema and physical schema to initialize a column selector // without filtering @@ -456,14 +458,16 @@ TEST(ColumnSelectorTests, testPartitionKeysMark) { EXPECT_EQ(root->childAt(3)->getNode().expression, "gold"); // test apply to real data file disk schema - const auto schemaMore = std::dynamic_pointer_cast( - HiveTypeParser().parse("struct<" - "id:bigint" - "memo:string" - "extra:array>")); - const auto schemaLess = std::dynamic_pointer_cast( - HiveTypeParser().parse("struct<" - "id:bigint>")); + const auto schemaMore = + std::dynamic_pointer_cast(HiveTypeParser().parse( + "struct<" + "id:bigint" + "memo:string" + "extra:array>")); + const auto schemaLess = + std::dynamic_pointer_cast(HiveTypeParser().parse( + "struct<" + "id:bigint>")); auto csMore = ColumnSelector::apply(cs, schemaMore); LOG(INFO) << "CS filter size: " << cs->getProjection().size(); @@ -510,14 +514,15 @@ TEST(ColumnSelectorTests, testPartitionKeysMark) { } TEST(ColumnSelectorTests, testProjectionUnchangedWhenReadSetChanged) { - const auto schema = std::dynamic_pointer_cast( - HiveTypeParser().parse("struct<" - "id:bigint" - "values:array" - "tags:map" - "notes:struct" - "memo:string" - "extra:string>")); + const auto schema = + std::dynamic_pointer_cast(HiveTypeParser().parse( + "struct<" + "id:bigint" + "values:array" + "tags:map" + "notes:struct" + "memo:string" + "extra:string>")); ColumnSelector cs(schema, std::vector{"id", "values"}); cs.setRead(cs.findColumn("notes")); @@ -559,13 +564,14 @@ TEST(ColumnSelectorTests, testProjectionUnchangedWhenReadSetChanged) { } TEST(ColumnSelectorTests, testProjectOrder) { - const auto schema = std::dynamic_pointer_cast( - HiveTypeParser().parse("struct<" - "id:bigint" - "values:array" - "tags:map" - "notes:struct" - "memo:string>")); + const auto schema = + std::dynamic_pointer_cast(HiveTypeParser().parse( + "struct<" + "id:bigint" + "values:array" + "tags:map" + "notes:struct" + "memo:string>")); // test filter with names with order of tags, memo and id { @@ -642,14 +648,15 @@ TEST(ColumnSelectorTests, testProjectOrder) { } TEST(ColumnSelectorTests, testNonexistingColFilters) { - const auto schema = std::dynamic_pointer_cast( - HiveTypeParser().parse("struct<" - "id:bigint" - "values:array" - "tags:map" - "notes:struct" - "memo:string" - "extra:string>")); + const auto schema = + std::dynamic_pointer_cast(HiveTypeParser().parse( + "struct<" + "id:bigint" + "values:array" + "tags:map" + "notes:struct" + "memo:string" + "extra:string>")); EXPECT_THROW( ColumnSelector cs( @@ -659,15 +666,16 @@ TEST(ColumnSelectorTests, testNonexistingColFilters) { } TEST(TestColumnSelector, fileColumnNamesReadAsLowerCaseDuplicateColFilters) { - const auto schema = std::dynamic_pointer_cast( - HiveTypeParser().parse("struct<" - "id:bigint" - "id:bigint" - "values:array" - "tags:map" - "notes:struct" - "memo:string" - "extra:string>")); + const auto schema = + std::dynamic_pointer_cast(HiveTypeParser().parse( + "struct<" + "id:bigint" + "id:bigint" + "values:array" + "tags:map" + "notes:struct" + "memo:string" + "extra:string>")); EXPECT_THROW( ColumnSelector cs(schema, std::vector{"id"}, nullptr, true), diff --git a/velox/dwio/common/tests/DataBufferTests.cpp b/velox/dwio/common/tests/DataBufferTests.cpp index a6ddee1afab6..c5e7b816cbc6 100644 --- a/velox/dwio/common/tests/DataBufferTests.cpp +++ b/velox/dwio/common/tests/DataBufferTests.cpp @@ -21,10 +21,8 @@ #include "velox/common/memory/Memory.h" #include "velox/dwio/common/DataBuffer.h" -namespace facebook { -namespace velox { -namespace dwio { -namespace common { +namespace facebook::velox::dwio::common { + using namespace facebook::velox::memory; using namespace testing; using MemoryPool = facebook::velox::memory::MemoryPool; @@ -175,7 +173,5 @@ TEST_F(DataBufferTest, Move) { } ASSERT_EQ(0, pool_->usedBytes()); } -} // namespace common -} // namespace dwio -} // namespace velox -} // namespace facebook + +} // namespace facebook::velox::dwio::common diff --git a/velox/dwio/common/tests/Lemire/FastPFor/CMakeLists.txt b/velox/dwio/common/tests/Lemire/FastPFor/CMakeLists.txt index 5f6ff7652e32..d74b5b7af666 100644 --- a/velox/dwio/common/tests/Lemire/FastPFor/CMakeLists.txt +++ b/velox/dwio/common/tests/Lemire/FastPFor/CMakeLists.txt @@ -13,5 +13,4 @@ # limitations under the License. add_library(velox_fastpforlib STATIC bitpacking.cpp) -target_include_directories( - velox_fastpforlib PUBLIC $) +target_include_directories(velox_fastpforlib PUBLIC $) diff --git a/velox/dwio/common/tests/LoggedExceptionTest.cpp b/velox/dwio/common/tests/LoggedExceptionTest.cpp index 6eea7e8b5c78..0f17e1f50a33 100644 --- a/velox/dwio/common/tests/LoggedExceptionTest.cpp +++ b/velox/dwio/common/tests/LoggedExceptionTest.cpp @@ -34,11 +34,12 @@ void testTraceCollectionSwitchControl(bool enabled) { try { throw LoggedException("Test error message"); } catch (VeloxException& e) { - SCOPED_TRACE(fmt::format( - "enabled: {}, user flag: {}, sys flag: {}", - enabled, - FLAGS_velox_exception_user_stacktrace_enabled, - FLAGS_velox_exception_system_stacktrace_enabled)); + SCOPED_TRACE( + fmt::format( + "enabled: {}, user flag: {}, sys flag: {}", + enabled, + FLAGS_velox_exception_user_stacktrace_enabled, + FLAGS_velox_exception_system_stacktrace_enabled)); ASSERT_TRUE(e.exceptionType() == VeloxException::Type::kSystem); ASSERT_EQ(enabled, e.stackTrace() != nullptr); } diff --git a/velox/dwio/common/tests/OnDemandUnitLoaderTests.cpp b/velox/dwio/common/tests/OnDemandUnitLoaderTests.cpp index 178ad21f9b2d..245c7d6186de 100644 --- a/velox/dwio/common/tests/OnDemandUnitLoaderTests.cpp +++ b/velox/dwio/common/tests/OnDemandUnitLoaderTests.cpp @@ -17,9 +17,8 @@ #include #include -#include "velox/common/base/tests/GTestUtils.h" #include "velox/dwio/common/OnDemandUnitLoader.h" -#include "velox/dwio/common/UnitLoaderTools.h" +#include "velox/dwio/common/tests/UnitLoaderBaseTest.h" #include "velox/dwio/common/tests/utils/UnitLoaderTestTools.h" using namespace ::testing; @@ -31,6 +30,38 @@ using facebook::velox::dwio::common::test::getUnitsLoadedWithFalse; using facebook::velox::dwio::common::test::LoadUnitMock; using facebook::velox::dwio::common::test::ReaderMock; +class OnDemandUnitLoaderCommonTests + : public UnitLoaderBaseTest { + protected: + OnDemandUnitLoaderFactory createFactory() override { + return OnDemandUnitLoaderFactory(nullptr); + } +}; + +TEST_F(OnDemandUnitLoaderCommonTests, NoUnitButSkip) { + testNoUnitButSkip(); +} + +TEST_F(OnDemandUnitLoaderCommonTests, InitialSkip) { + testInitialSkip(); +} + +TEST_F(OnDemandUnitLoaderCommonTests, CanRequestUnitMultipleTimes) { + testCanRequestUnitMultipleTimes(); +} + +TEST_F(OnDemandUnitLoaderCommonTests, UnitOutOfRange) { + testUnitOutOfRange(); +} + +TEST_F(OnDemandUnitLoaderCommonTests, SeekOutOfRange) { + testSeekOutOfRange(); +} + +TEST_F(OnDemandUnitLoaderCommonTests, SeekOutOfRangeReaderError) { + testSeekOutOfRangeReaderError(); +} + TEST(OnDemandUnitLoaderTests, LoadsCorrectlyWithReader) { size_t blockedOnIoCount = 0; OnDemandUnitLoaderFactory factory([&](auto) { ++blockedOnIoCount; }); @@ -127,96 +158,3 @@ TEST(OnDemandUnitLoaderTests, CanSeek) { EXPECT_EQ(readerMock.unitsLoaded(), std::vector({true, false, false})); EXPECT_EQ(blockedOnIoCount, 4); } - -TEST(OnDemandUnitLoaderTests, SeekOutOfRangeReaderError) { - size_t blockedOnIoCount = 0; - OnDemandUnitLoaderFactory factory([&](auto) { ++blockedOnIoCount; }); - ReaderMock readerMock{{10, 20, 30}, {0, 0, 0}, factory, 0}; - EXPECT_EQ(readerMock.unitsLoaded(), std::vector({false, false, false})); - EXPECT_EQ(blockedOnIoCount, 0); - readerMock.seek(59); - - readerMock.seek(60); - - VELOX_ASSERT_THROW( - readerMock.seek(61), - "Can't seek to possition 61 in file. Must be up to 60."); -} - -TEST(OnDemandUnitLoaderTests, SeekOutOfRange) { - OnDemandUnitLoaderFactory factory(nullptr); - std::vector unitsLoaded(getUnitsLoadedWithFalse(1)); - std::vector> units; - units.push_back(std::make_unique(10, 0, unitsLoaded, 0)); - - auto unitLoader = factory.create(std::move(units), 0); - - unitLoader->onSeek(0, 10); - - VELOX_ASSERT_THROW(unitLoader->onSeek(0, 11), "Row out of range"); -} - -TEST(OnDemandUnitLoaderTests, UnitOutOfRange) { - OnDemandUnitLoaderFactory factory(nullptr); - std::vector unitsLoaded(getUnitsLoadedWithFalse(1)); - std::vector> units; - units.push_back(std::make_unique(10, 0, unitsLoaded, 0)); - - auto unitLoader = factory.create(std::move(units), 0); - unitLoader->getLoadedUnit(0); - - VELOX_ASSERT_THROW(unitLoader->getLoadedUnit(1), "Unit out of range"); -} - -TEST(OnDemandUnitLoaderTests, CanRequestUnitMultipleTimes) { - OnDemandUnitLoaderFactory factory(nullptr); - std::vector unitsLoaded(getUnitsLoadedWithFalse(1)); - std::vector> units; - units.push_back(std::make_unique(10, 0, unitsLoaded, 0)); - - auto unitLoader = factory.create(std::move(units), 0); - unitLoader->getLoadedUnit(0); - unitLoader->getLoadedUnit(0); - unitLoader->getLoadedUnit(0); -} - -TEST(OnDemandUnitLoaderTests, InitialSkip) { - auto getFactoryWithSkip = [](uint64_t skipToRow) { - auto factory = std::make_unique(nullptr); - std::vector unitsLoaded(getUnitsLoadedWithFalse(1)); - std::vector> units; - units.push_back(std::make_unique(10, 0, unitsLoaded, 0)); - units.push_back(std::make_unique(20, 0, unitsLoaded, 1)); - units.push_back(std::make_unique(30, 0, unitsLoaded, 2)); - factory->create(std::move(units), skipToRow); - }; - - EXPECT_NO_THROW(getFactoryWithSkip(0)); - EXPECT_NO_THROW(getFactoryWithSkip(1)); - EXPECT_NO_THROW(getFactoryWithSkip(9)); - EXPECT_NO_THROW(getFactoryWithSkip(10)); - EXPECT_NO_THROW(getFactoryWithSkip(11)); - EXPECT_NO_THROW(getFactoryWithSkip(29)); - EXPECT_NO_THROW(getFactoryWithSkip(30)); - EXPECT_NO_THROW(getFactoryWithSkip(31)); - EXPECT_NO_THROW(getFactoryWithSkip(59)); - EXPECT_NO_THROW(getFactoryWithSkip(60)); - VELOX_ASSERT_THROW( - getFactoryWithSkip(61), - "Can only skip up to the past-the-end row of the file."); - VELOX_ASSERT_THROW( - getFactoryWithSkip(100), - "Can only skip up to the past-the-end row of the file."); -} - -TEST(OnDemandUnitLoaderTests, NoUnitButSkip) { - OnDemandUnitLoaderFactory factory(nullptr); - std::vector> units; - - EXPECT_NO_THROW(factory.create(std::move(units), 0)); - - std::vector> units2; - VELOX_ASSERT_THROW( - factory.create(std::move(units2), 1), - "Can only skip up to the past-the-end row of the file."); -} diff --git a/velox/dwio/common/tests/ParallelUnitLoaderTest.cpp b/velox/dwio/common/tests/ParallelUnitLoaderTest.cpp new file mode 100644 index 000000000000..690acd9fc119 --- /dev/null +++ b/velox/dwio/common/tests/ParallelUnitLoaderTest.cpp @@ -0,0 +1,141 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/dwio/common/ParallelUnitLoader.h" +#include "velox/dwio/common/OnDemandUnitLoader.h" +#include "velox/dwio/common/tests/UnitLoaderBaseTest.h" +#include "velox/dwio/common/tests/utils/UnitLoaderTestTools.h" + +#include +#include +#include + +using namespace facebook::velox::dwio::common; +using namespace facebook::velox::dwio::common::test; + +class ParallelUnitLoaderTest + : public UnitLoaderBaseTest { + protected: + ParallelUnitLoaderFactory createFactory() override { + return ParallelUnitLoaderFactory(ioExecutor_.get(), 2); + } + + std::unique_ptr ioExecutor_ = + std::make_unique(10); +}; + +TEST_F(ParallelUnitLoaderTest, NoUnitButSkip) { + testNoUnitButSkip(); +} + +TEST_F(ParallelUnitLoaderTest, InitialSkip) { + testInitialSkip(); +} + +TEST_F(ParallelUnitLoaderTest, CanRequestUnitMultipleTimes) { + testCanRequestUnitMultipleTimes(); +} + +TEST_F(ParallelUnitLoaderTest, UnitOutOfRange) { + testUnitOutOfRange(); +} + +TEST_F(ParallelUnitLoaderTest, SeekOutOfRange) { + testSeekOutOfRange(); +} + +TEST_F(ParallelUnitLoaderTest, SeekOutOfRangeReaderError) { + testSeekOutOfRangeReaderError(); +} + +TEST_F(ParallelUnitLoaderTest, LoadsCorrectlyWithReader) { + auto factory = createFactory(); + ReaderMock readerMock{{10, 20, 30}, {0, 0, 0}, factory, 0}; + + EXPECT_TRUE(readerMock.read(3)); // Unit: 0, rows: 0-2, load(0) + std::this_thread::sleep_for(std::chrono::milliseconds(500)); + EXPECT_EQ(readerMock.unitsLoaded(), std::vector({true, true, false})); + + EXPECT_TRUE(readerMock.read(3)); // Unit: 0, rows: 3-5 + EXPECT_EQ(readerMock.unitsLoaded(), std::vector({true, true, false})); + + EXPECT_TRUE(readerMock.read(4)); // Unit: 0, rows: 6-9 + EXPECT_EQ(readerMock.unitsLoaded(), std::vector({true, true, false})); + + EXPECT_TRUE(readerMock.read(14)); // Unit: 1, rows: 0-13, unload(0), load(1) + std::this_thread::sleep_for(std::chrono::milliseconds(500)); + EXPECT_EQ(readerMock.unitsLoaded(), std::vector({false, true, true})); + + // will only read 5 rows, no more rows in unit 1 + EXPECT_TRUE(readerMock.read(10)); // Unit: 1, rows: 14-19 + EXPECT_EQ(readerMock.unitsLoaded(), std::vector({false, true, true})); + + EXPECT_TRUE(readerMock.read(30)); // Unit: 2, rows: 0-29, unload(1), load(2) + std::this_thread::sleep_for(std::chrono::milliseconds(500)); + EXPECT_EQ(readerMock.unitsLoaded(), std::vector({false, false, true})); + + EXPECT_FALSE(readerMock.read(30)); // No more data + EXPECT_EQ(readerMock.unitsLoaded(), std::vector({false, false, true})); +} + +// Performance comparison test +TEST_F(ParallelUnitLoaderTest, PerformanceComparison) { + std::vector rowsPerUnit = {100, 100, 100, 100, 100, 100, 100, 100}; + std::vector ioSizes = { + 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024}; + + // Measure ParallelUnitLoader performance + auto parallelStart = std::chrono::high_resolution_clock::now(); + { + auto factory = createFactory(); + ReaderMock reader(rowsPerUnit, ioSizes, factory, 0); + + for (size_t i = 0; i < rowsPerUnit.size(); ++i) { + uint64_t totalRowsRead = 0; + while (totalRowsRead < rowsPerUnit[i]) { + reader.read(25); + int nextRead = rowsPerUnit[i] - totalRowsRead; + totalRowsRead += std::min(25, nextRead); + } + } + } + auto parallelEnd = std::chrono::high_resolution_clock::now(); + + // Measure OnDemandUnitLoader performance + auto onDemandStart = std::chrono::high_resolution_clock::now(); + { + auto factory = std::make_shared(nullptr); + ReaderMock reader(rowsPerUnit, ioSizes, *factory, 0); + + for (size_t i = 0; i < rowsPerUnit.size(); ++i) { + uint64_t totalRowsRead = 0; + while (totalRowsRead < rowsPerUnit[i]) { + reader.read(25); + int nextRead = rowsPerUnit[i] - totalRowsRead; + totalRowsRead += std::min(25, nextRead); + } + } + } + auto onDemandEnd = std::chrono::high_resolution_clock::now(); + + auto parallelDuration = std::chrono::duration_cast( + parallelEnd - parallelStart); + auto onDemandDuration = std::chrono::duration_cast( + onDemandEnd - onDemandStart); + + // ParallelUnitLoader should be faster + EXPECT_GT(onDemandDuration.count(), parallelDuration.count()); +} diff --git a/velox/dwio/common/tests/ReaderTest.cpp b/velox/dwio/common/tests/ReaderTest.cpp index 59a870e3c8a8..780e178c32a3 100644 --- a/velox/dwio/common/tests/ReaderTest.cpp +++ b/velox/dwio/common/tests/ReaderTest.cpp @@ -148,24 +148,67 @@ TEST_F(ReaderTest, projectColumnsMutation) { makeFlatVector({0, 1, 3, 4, 5, 6, 7, 8, 9}), }); test::assertEqualVectors(expected, actual); - random::setSeed(42); - random::RandomSkipTracker randomSkip(0.5); - mutation.randomSkip = &randomSkip; - actual = RowReader::projectColumns(input, spec, &mutation); -#if FOLLY_HAVE_EXTRANDOM_SFMT19937 - expected = makeRowVector({ - makeFlatVector({0, 1, 3, 5, 6, 8}), - }); -#elif __APPLE__ - expected = makeRowVector({ - makeFlatVector({1, 5, 6, 7, 8, 9}), - }); -#else - expected = makeRowVector({ - makeFlatVector({3, 4, 7, 9}), - }); -#endif - test::assertEqualVectors(expected, actual); + + constexpr auto kNumRounds = 1U << 6; + + size_t numNonZero = 0; + size_t numNonMax = 0; + + // Test with random skip - use property-based testing instead of hardcoded + // outputs to avoid brittleness when folly::Random implementation changes. + std::mt19937 seeds; + for (size_t round = 0; round < kNumRounds; ++round) { + const auto seed = seeds(); + + random::setSeed(folly::to_narrow(seed)); + random::RandomSkipTracker randomSkip(0.5); + mutation.randomSkip = &randomSkip; + actual = RowReader::projectColumns(input, spec, &mutation); + + // Property 1: Result size should be less than input size (some rows + // skipped). With 0.5 sample rate and 9 eligible rows (excluding deleted row + // 2), we expect roughly 4-5 rows, but allow wider range for RNG variance. + EXPECT_GE(actual->size(), 0); + EXPECT_LE(actual->size(), kSize - 1); + + numNonZero += actual->size() > 0; + numNonMax += actual->size() < kSize - 1; + + // The result is a RowVector with one child column. Assume it. + auto res = actual->as()->childAt(0)->as>(); + std::vector vec; + vec.reserve(actual->size()); + for (vector_size_t i = 0; i < actual->size(); ++i) { + vec.push_back(res->valueAt(i)); + } + + // Property 2: All values in result must be from original input. + for (auto val : vec) { + // Each value must be in valid range + EXPECT_GE(val, 0); + EXPECT_LT(val, kSize); + // Deleted row should never appear + EXPECT_NE(val, 2); + } + + // Property 3: Values should be in ascending order (projectColumns preserves + // order). + EXPECT_TRUE(std::is_sorted(vec.begin(), vec.end())); + + // Property 4: No duplicate values (each input row appears at most once). + EXPECT_TRUE(std::adjacent_find(vec.begin(), vec.end()) == vec.end()); + + // Property 5: With a fixed seed, the result should be deterministic + // (same seed = same output, even if we don't know what that output is) + random::setSeed(folly::to_narrow(seed)); + random::RandomSkipTracker randomSkip2(0.5); + mutation.randomSkip = &randomSkip2; + auto actual2 = RowReader::projectColumns(input, spec, &mutation); + test::assertEqualVectors(actual, actual2); + } + + EXPECT_NE(0, numNonZero); + EXPECT_NE(0, numNonMax); } } // namespace diff --git a/velox/dwio/common/tests/ScanSpecTest.cpp b/velox/dwio/common/tests/ScanSpecTest.cpp new file mode 100644 index 000000000000..d61675878d27 --- /dev/null +++ b/velox/dwio/common/tests/ScanSpecTest.cpp @@ -0,0 +1,190 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/dwio/common/ScanSpec.h" +#include "velox/vector/tests/utils/VectorTestBase.h" + +#include + +namespace facebook::velox::common { +namespace { + +class ScanSpecTest : public testing::Test, public test::VectorTestBase { + protected: + static void SetUpTestCase() { + memory::MemoryManager::testingSetInstance(memory::MemoryManager::Options{}); + } +}; + +TEST_F(ScanSpecTest, applyFilter) { + auto rowVector = makeRowVector({ + makeFlatVector(64, folly::identity), + makeFlatVector(128, folly::identity), + }); + ASSERT_EQ(rowVector->size(), 64); + ScanSpec scanSpec(""); + scanSpec.addAllChildFields(*rowVector->type()); + scanSpec.childByName("c1")->setFilter(createBigintValues({63, 64}, false)); + uint64_t result = -1ll; + scanSpec.applyFilter(*rowVector, rowVector->size(), &result); + ASSERT_EQ(result, 1ull << 63); + result = -1ll; + scanSpec.childByName("c1")->applyFilter( + *rowVector->childAt("c1"), rowVector->size(), &result); + ASSERT_EQ(result, 1ull << 63); + rowVector = makeRowVector({ + makeFlatVector(128, folly::identity), + makeFlatVector(64, folly::identity), + }); + ASSERT_THROW( + scanSpec.applyFilter(*rowVector, rowVector->size(), &result), + VeloxRuntimeError); +} + +class TypedScanSpecTest : public testing::TestWithParam, + public test::VectorTestBase { + protected: + static void SetUpTestCase() { + memory::MemoryManager::testingSetInstance(memory::MemoryManager::Options{}); + } + + VectorPtr makeConstNullVector(TypePtr type, vector_size_t size) { + return BaseVector::createNullConstant(type, size, pool()); + } + + void addIsNullFilterRecursive(ScanSpec& scanSpec) { + scanSpec.setFilter(std::make_shared()); + for (auto& child : scanSpec.children()) { + addIsNullFilterRecursive(*child); + } + } + + void addIsNotNullFilterRecursive(ScanSpec& scanSpec) { + scanSpec.setFilter(std::make_shared()); + for (auto& child : scanSpec.children()) { + addIsNullFilterRecursive(*child); + } + } + + void addIsNullFilterToLeaf(ScanSpec& scanSpec) { + if (scanSpec.children().empty()) { + scanSpec.setFilter(std::make_shared()); + } else { + for (auto& child : scanSpec.children()) { + addIsNullFilterToLeaf(*child); + } + } + } + + void addIsNotNullFilterToLeaf(ScanSpec& scanSpec) { + if (scanSpec.children().empty()) { + scanSpec.setFilter(std::make_shared()); + } else { + for (auto& child : scanSpec.children()) { + addIsNotNullFilterToLeaf(*child); + } + } + } +}; + +// Due to how subfield filters of maps and arrays are pruning +// and can't affect the row selectivity, the current test skips +// cases when maps and arrays are the lone child of (nested) structs. +INSTANTIATE_TEST_SUITE_P( + TypedScanSpecTestSuite, + TypedScanSpecTest, + testing::Values( + TINYINT(), + SMALLINT(), + INTEGER(), + BIGINT(), + REAL(), + DOUBLE(), + VARCHAR(), + VARBINARY(), + ROW({"int", "real"}, {INTEGER(), REAL()}), + // TODO: the test cases fail when not specifying names for + // the struct fields. This indicates bug in internal topology + // when finding children of nested scan specs. + ROW({"int", "map"}, {INTEGER(), MAP(INTEGER(), REAL())}), + ROW({"int", "array"}, {INTEGER(), ARRAY(INTEGER())}), + ROW({"int0", "array0", "row0"}, + {INTEGER(), + ARRAY(INTEGER()), + ROW({"int1", "real1", "row1"}, + {INTEGER(), + REAL(), + ROW({"int2", "real2"}, {INTEGER(), REAL()})})}))); + +TEST_P(TypedScanSpecTest, applyFilterSchemaEvolution) { + auto rowVector = makeRowVector({ + makeFlatVector(64, folly::identity), + makeConstNullVector(GetParam(), 64), + }); + ASSERT_EQ(rowVector->size(), 64); + LOG(INFO) << "Testing with type: " << rowVector->type()->toString(); + + { + ScanSpec scanSpec(""); + scanSpec.addAllChildFields(*rowVector->type()); + + ASSERT_TRUE(scanSpec.childByName("c0")); + scanSpec.childByName("c0")->setFilter( + std::make_shared(32, 64, false)); + + ASSERT_TRUE(scanSpec.childByName("c1")); + addIsNullFilterRecursive(*scanSpec.childByName("c1")); + + uint64_t result = -1ll; + scanSpec.applyFilter(*rowVector, rowVector->size(), &result); + ASSERT_EQ(result, -1ll << 32); + + // Now add a non-null filter on the missing column. + ASSERT_TRUE(scanSpec.childByName("c1")); + addIsNotNullFilterRecursive(*scanSpec.childByName("c1")); + result = -1ll; + scanSpec.applyFilter(*rowVector, rowVector->size(), &result); + ASSERT_EQ(result, 0); + } + + { + ScanSpec scanSpec(""); + scanSpec.addAllChildFields(*rowVector->type()); + + ASSERT_TRUE(scanSpec.childByName("c0")); + scanSpec.childByName("c0")->setFilter( + std::make_shared(32, 64, false)); + + // Now add a null filter only on the innermost node of the missing column. + // Should have the same result as recursive filters. + ASSERT_TRUE(scanSpec.childByName("c1")); + addIsNullFilterToLeaf(*scanSpec.childByName("c1")); + uint64_t result = -1ll; + scanSpec.applyFilter(*rowVector, rowVector->size(), &result); + ASSERT_EQ(result, -1ll << 32); + + // Now add is not null filter only on the innermost node of the missing + // column. Should have the same result as recursive filters. + ASSERT_TRUE(scanSpec.childByName("c1")); + addIsNotNullFilterToLeaf(*scanSpec.childByName("c1")); + result = -1ll; + scanSpec.applyFilter(*rowVector, rowVector->size(), &result); + ASSERT_EQ(result, 0); + } +} + +} // namespace +} // namespace facebook::velox::common diff --git a/velox/dwio/common/tests/SortingWriterTest.cpp b/velox/dwio/common/tests/SortingWriterTest.cpp new file mode 100644 index 000000000000..4beadc61e1ad --- /dev/null +++ b/velox/dwio/common/tests/SortingWriterTest.cpp @@ -0,0 +1,129 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/dwio/common/SortingWriter.h" +#include + +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/exec/SortBuffer.h" +#include "velox/vector/tests/utils/VectorTestBase.h" + +using namespace facebook::velox::exec; +using namespace facebook::velox::memory; + +namespace facebook::velox::dwio::common::test { + +class MockWriter : public Writer { + public: + MockWriter() { + setState(State::kRunning); + } + + void write(const VectorPtr& data) override { + writtenData_.push_back(data); + totalRowsWritten_ += data->size(); + } + + bool finish() override { + return true; + } + + void flush() override {} + + void close() override { + setState(State::kClosed); + } + + void abort() override { + setState(State::kAborted); + } + + uint64_t getTotalRowsWritten() const { + return totalRowsWritten_; + } + + private: + std::vector writtenData_; + uint64_t totalRowsWritten_ = 0; +}; + +class SortingWriterTest : public testing::Test, + public velox::test::VectorTestBase { + protected: + static void SetUpTestCase() { + memory::MemoryManager::testingSetInstance(memory::MemoryManager::Options{}); + } + + std::unique_ptr createSortBuffer() { + const RowTypePtr inputType = + ROW({{"c0", BIGINT()}, {"c1", INTEGER()}, {"c2", VARCHAR()}}); + + const std::vector sortColumnIndices{1}; + const std::vector sortCompareFlags{ + {true, true, false, CompareFlags::NullHandlingMode::kNullAsValue}}; + + const velox::common::PrefixSortConfig prefixSortConfig{ + std::numeric_limits::max(), + std::numeric_limits::max(), + 12}; + + return std::make_unique( + inputType, + sortColumnIndices, + sortCompareFlags, + pool_.get(), + &nonReclaimableSection_, + prefixSortConfig); + } + + RowVectorPtr createTestData(uint64_t numRows = 1000) { + return makeRowVector( + {makeFlatVector( + numRows, [](vector_size_t row) { return row; }), + makeFlatVector( + numRows, [numRows](vector_size_t row) { return numRows - row; }), + makeFlatVector(numRows, [](vector_size_t row) { + return fmt::format("row_{}", row); + })}); + } + + tsan_atomic nonReclaimableSection_{false}; +}; + +TEST_F(SortingWriterTest, largeRowSizeExceedsMaxOutputBytes) { + auto mockWriter = std::make_unique(); + auto mockWriterPtr = mockWriter.get(); + + auto sortBuffer = createSortBuffer(); + + const vector_size_t maxOutputRowsConfig = 1000; + const uint64_t maxOutputBytesConfig = 1; + const uint64_t outputTimeSliceLimitMs = 1000; + + SortingWriter sortingWriter( + std::move(mockWriter), + std::move(sortBuffer), + maxOutputRowsConfig, + maxOutputBytesConfig, + outputTimeSliceLimitMs); + + RowVectorPtr testData = createTestData(10); + sortingWriter.write(testData); + ASSERT_TRUE(sortingWriter.finish()); + ASSERT_GT(mockWriterPtr->getTotalRowsWritten(), 0); +} + +} // namespace facebook::velox::dwio::common::test diff --git a/velox/dwio/common/tests/TestBufferedInput.cpp b/velox/dwio/common/tests/TestBufferedInput.cpp index 6fa5be8da000..39461993bd24 100644 --- a/velox/dwio/common/tests/TestBufferedInput.cpp +++ b/velox/dwio/common/tests/TestBufferedInput.cpp @@ -16,6 +16,8 @@ #include #include +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/connectors/hive/BufferedInputBuilder.h" #include "velox/dwio/common/BufferedInput.h" using namespace facebook::velox::dwio::common; @@ -29,13 +31,25 @@ class ReadFileMock : public ::facebook::velox::ReadFile { public: virtual ~ReadFileMock() override = default; +// On Centos9 the gtest mock header doesn't initialize the +// buffer_ member in MatcherBase correctly - the default constructor only +// initializes one: /usr/include/gtest/gtest-matchers.h:302:33 resulting in +// error: +// '.testing::Matcher::.testing::internal::MatcherBase::buffer_' is used uninitialized +// [-Werror=uninitialized] +// 302 | : vtable_(other.vtable_), buffer_(other.buffer_) { +// Fix: https://github.com/google/googletest/pull/3797 +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wuninitialized" MOCK_METHOD( std::string_view, pread, (uint64_t offset, uint64_t length, void* buf, - facebook::velox::filesystems::File::IoStats* stats), + (const facebook::velox::FileStorageContext&)fileStorageContext), (const, override)); MOCK_METHOD(bool, shouldCoalesce, (), (const, override)); @@ -48,7 +62,7 @@ class ReadFileMock : public ::facebook::velox::ReadFile { preadv, (folly::Range regions, folly::Range iobufs, - facebook::velox::filesystems::File::IoStats* stats), + (const facebook::velox::FileStorageContext&)fileStorageContext), (const, override)); }; @@ -60,14 +74,14 @@ void expectPreads( EXPECT_CALL(file, size()).WillRepeatedly(Return(content.size())); for (auto& read : reads) { ASSERT_GE(content.size(), read.offset + read.length); - EXPECT_CALL(file, pread(read.offset, read.length, _, nullptr)) + EXPECT_CALL(file, pread(read.offset, read.length, _, _)) .Times(1) .WillOnce( [content]( uint64_t offset, uint64_t length, void* buf, - facebook::velox::filesystems::File::IoStats* stats) + const facebook::velox::FileStorageContext& fileStorageContext) -> std::string_view { memcpy(buf, content.data() + offset, length); return {content.data() + offset, length}; @@ -81,13 +95,14 @@ void expectPreadvs( std::vector reads) { EXPECT_CALL(file, getName()).WillRepeatedly(Return("mock_name")); EXPECT_CALL(file, size()).WillRepeatedly(Return(content.size())); - EXPECT_CALL(file, preadv(_, _, nullptr)) + EXPECT_CALL(file, preadv(_, _, _)) .Times(1) .WillOnce( [content, reads]( folly::Range regions, folly::Range iobufs, - facebook::velox::filesystems::File::IoStats* stats) -> uint64_t { + const facebook::velox::FileStorageContext& fileStorageContext) + -> uint64_t { EXPECT_EQ(regions.size(), reads.size()); uint64_t length = 0; for (size_t i = 0; i < reads.size(); ++i) { @@ -110,6 +125,7 @@ void expectPreadvs( return length; }); } +#pragma GCC diagnostic pop std::optional getNext(SeekableInputStream& input) { const void* buf = nullptr; @@ -130,7 +146,6 @@ class TestBufferedInput : public testing::Test { const std::shared_ptr pool_ = memoryManager()->addLeafPool(); }; -} // namespace TEST_F(TestBufferedInput, ZeroLengthStream) { auto readFile = @@ -377,3 +392,45 @@ TEST_F(TestBufferedInput, VReadSortingWithLabels) { EXPECT_EQ(next.value(), r.second); } } + +class CustomBufferedInputBuilder + : public facebook::velox::connector::hive::BufferedInputBuilder { + public: + std::unique_ptr create( + const facebook::velox::FileHandle& fileHandle, + const facebook::velox::dwio::common::ReaderOptions& readerOpts, + const facebook::velox::connector::ConnectorQueryCtx* connectorQueryCtx, + std::shared_ptr ioStats, + std::shared_ptr fsStats, + folly::Executor* executor, + const folly::F14FastMap& fileReadOps = {}) + override { + VELOX_NYI("Not implemented in CustomBufferedInputBuilder"); + } +}; + +class CustomBufferedInputTest : public testing::Test { + protected: + static void SetUpTestCase() { + MemoryManager::testingSetInstance(MemoryManager::Options{}); + facebook::velox::connector::hive::BufferedInputBuilder::registerBuilder( + std::make_shared()); + } + + const std::shared_ptr pool_ = memoryManager()->addLeafPool(); +}; + +} // namespace + +TEST_F(CustomBufferedInputTest, basic) { + facebook::velox::FileHandle fileHandle; + facebook::velox::dwio::common::ReaderOptions readerOpts(pool_.get()); + auto ioStats = std::make_shared(); + auto fsStats = + std::make_shared(); + + VELOX_ASSERT_THROW( + facebook::velox::connector::hive::BufferedInputBuilder::getInstance() + ->create(fileHandle, readerOpts, nullptr, ioStats, fsStats, nullptr), + "Not implemented in CustomBufferedInputBuilder"); +} diff --git a/velox/dwio/common/tests/ThrottlerTest.cpp b/velox/dwio/common/tests/ThrottlerTest.cpp index 773e8d0dd70e..12ea0c86d75b 100644 --- a/velox/dwio/common/tests/ThrottlerTest.cpp +++ b/velox/dwio/common/tests/ThrottlerTest.cpp @@ -110,14 +110,15 @@ TEST_F(ThrottlerTest, throttle) { SCOPED_TRACE(fmt::format("signal: {}", Throttler::signalTypeName(signal))); Throttler::testingReset(); - Throttler::init(Throttler::Config( - true, - minThrottleBackoffMs, - maxThrottleBackoffMs, - 2.0, - signal == Throttler::SignalType::kLocal ? 2 : 1'000, - signal == Throttler::SignalType::kGlobal ? 2 : 1'000, - signal == Throttler::SignalType::kNetwork ? 2 : 1'000)); + Throttler::init( + Throttler::Config( + true, + minThrottleBackoffMs, + maxThrottleBackoffMs, + 2.0, + signal == Throttler::SignalType::kLocal ? 2 : 1'000, + signal == Throttler::SignalType::kGlobal ? 2 : 1'000, + signal == Throttler::SignalType::kNetwork ? 2 : 1'000)); auto* instance = Throttler::instance(); const auto& stats = instance->stats(); ASSERT_EQ(stats.localThrottled, 0); @@ -186,16 +187,17 @@ TEST_F(ThrottlerTest, expire) { for (const auto signal : kSignalTypes) { SCOPED_TRACE(fmt::format("signal: {}", Throttler::signalTypeName(signal))); Throttler::testingReset(); - Throttler::init(Throttler::Config( - true, - minThrottleBackoffMs, - maxThrottleBackoffMs, - 2.0, - signal == Throttler::SignalType::kLocal ? 2 : 1'000, - signal == Throttler::SignalType::kGlobal ? 2 : 1'000, - signal == Throttler::SignalType::kNetwork ? 2 : 1'000, - 1'000, - 1'000)); + Throttler::init( + Throttler::Config( + true, + minThrottleBackoffMs, + maxThrottleBackoffMs, + 2.0, + signal == Throttler::SignalType::kLocal ? 2 : 1'000, + signal == Throttler::SignalType::kGlobal ? 2 : 1'000, + signal == Throttler::SignalType::kNetwork ? 2 : 1'000, + 1'000, + 1'000)); auto* instance = Throttler::instance(); const auto& stats = instance->stats(); ASSERT_EQ(stats.localThrottled, 0); @@ -228,8 +230,9 @@ TEST_F(ThrottlerTest, expire) { TEST_F(ThrottlerTest, differentLocals) { const uint64_t minThrottleBackoffMs = 1'000; const uint64_t maxThrottleBackoffMs = 2'000; - Throttler::init(Throttler::Config( - true, minThrottleBackoffMs, maxThrottleBackoffMs, 2.0, 2, 1'0000)); + Throttler::init( + Throttler::Config( + true, minThrottleBackoffMs, maxThrottleBackoffMs, 2.0, 2, 1'0000)); auto* instance = Throttler::instance(); const auto& stats = instance->stats(); ASSERT_EQ(stats.localThrottled, 0); @@ -320,8 +323,9 @@ TEST_F(ThrottlerTest, differentLocals) { TEST_F(ThrottlerTest, differentGlobals) { const uint64_t minThrottleBackoffMs = 1'000; const uint64_t maxThrottleBackoffMs = 2'000; - Throttler::init(Throttler::Config( - true, minThrottleBackoffMs, maxThrottleBackoffMs, 2.0, 1'0000, 2)); + Throttler::init( + Throttler::Config( + true, minThrottleBackoffMs, maxThrottleBackoffMs, 2.0, 1'0000, 2)); auto* instance = Throttler::instance(); const auto& stats = instance->stats(); ASSERT_EQ(stats.localThrottled, 0); @@ -409,14 +413,15 @@ TEST_F(ThrottlerTest, differentGlobals) { TEST_F(ThrottlerTest, differentNetworks) { const uint64_t minThrottleBackoffMs = 1'000; const uint64_t maxThrottleBackoffMs = 2'000; - Throttler::init(Throttler::Config( - true, - minThrottleBackoffMs, - maxThrottleBackoffMs, - 2.0, - 1'0000, - 1'0000, - 2)); + Throttler::init( + Throttler::Config( + true, + minThrottleBackoffMs, + maxThrottleBackoffMs, + 2.0, + 1'0000, + 1'0000, + 2)); auto* instance = Throttler::instance(); const auto& stats = instance->stats(); ASSERT_EQ(stats.localThrottled, 0); @@ -511,8 +516,9 @@ TEST_F(ThrottlerTest, maxOfGlobalAndLocal) { for (const bool localFirst : {false, true}) { SCOPED_TRACE(fmt::format("localFirst: {}", localFirst)); Throttler::testingReset(); - Throttler::init(Throttler::Config( - true, minThrottleBackoffMs, maxThrottleBackoffMs, 2.0, 2, 2)); + Throttler::init( + Throttler::Config( + true, minThrottleBackoffMs, maxThrottleBackoffMs, 2.0, 2, 2)); auto* instance = Throttler::instance(); const auto& stats = instance->stats(); ASSERT_EQ(stats.localThrottled, 0); @@ -627,16 +633,17 @@ TEST_F(ThrottlerTest, fuzz) { const uint32_t maxCacheEntries = 64; const uint32_t cacheTTLMs = 10; Throttler::testingReset(); - Throttler::init(Throttler::Config( - true, - minThrottleBackoffMs, - maxThrottleBackoffMs, - backoffScaleFactor, - minLocalThrottledSignals, - minGlobalThrottledSignals, - minNetworkThrottledSignals, - maxCacheEntries, - cacheTTLMs)); + Throttler::init( + Throttler::Config( + true, + minThrottleBackoffMs, + maxThrottleBackoffMs, + backoffScaleFactor, + minLocalThrottledSignals, + minGlobalThrottledSignals, + minNetworkThrottledSignals, + maxCacheEntries, + cacheTTLMs)); auto* instance = Throttler::instance(); const auto seed = getCurrentTimeMs(); diff --git a/velox/dwio/common/tests/UnitLoaderBaseTest.h b/velox/dwio/common/tests/UnitLoaderBaseTest.h new file mode 100644 index 000000000000..9faf9bc91b3e --- /dev/null +++ b/velox/dwio/common/tests/UnitLoaderBaseTest.h @@ -0,0 +1,140 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/dwio/common/UnitLoaderTools.h" +#include "velox/dwio/common/tests/utils/UnitLoaderTestTools.h" + +using facebook::velox::dwio::common::LoadUnit; +using facebook::velox::dwio::common::test::getUnitsLoadedWithFalse; +using facebook::velox::dwio::common::test::LoadUnitMock; +using facebook::velox::dwio::common::test::ReaderMock; + +/// Base test class template that provides common test functionality for +/// different UnitLoader implementations. This template class can be inherited +/// by specific test classes to get access to common test methods. Each derived +/// class should provide a createFactory() method that returns the appropriate +/// factory instance. +template +class UnitLoaderBaseTest : public ::testing::Test { + protected: + /// Factory method to create the appropriate UnitLoaderFactory instance. + /// This method should be implemented by derived classes. + virtual UnitLoaderFactoryType createFactory() = 0; + + /// Test that UnitLoader factory handles the case where no units exist but + /// skip is requested + void testNoUnitButSkip() { + UnitLoaderFactoryType factory = createFactory(); + std::vector> units; + + EXPECT_NO_THROW(factory.create(std::move(units), 0)); + + std::vector> units2; + VELOX_ASSERT_THROW( + factory.create(std::move(units2), 1), + "Can only skip up to the past-the-end row of the file."); + } + + /// Test that UnitLoader factory handles initial skip correctly for various + /// skip values + void testInitialSkip() { + auto getFactoryWithSkip = [this](uint64_t skipToRow) { + auto factory = createFactory(); + std::vector unitsLoaded(getUnitsLoadedWithFalse(3)); + std::vector> units; + units.push_back(std::make_unique(10, 0, unitsLoaded, 0)); + units.push_back(std::make_unique(20, 0, unitsLoaded, 1)); + units.push_back(std::make_unique(30, 0, unitsLoaded, 2)); + factory.create(std::move(units), skipToRow); + }; + + EXPECT_NO_THROW(getFactoryWithSkip(0)); + EXPECT_NO_THROW(getFactoryWithSkip(1)); + EXPECT_NO_THROW(getFactoryWithSkip(9)); + EXPECT_NO_THROW(getFactoryWithSkip(10)); + EXPECT_NO_THROW(getFactoryWithSkip(11)); + EXPECT_NO_THROW(getFactoryWithSkip(29)); + EXPECT_NO_THROW(getFactoryWithSkip(30)); + EXPECT_NO_THROW(getFactoryWithSkip(31)); + EXPECT_NO_THROW(getFactoryWithSkip(59)); + EXPECT_NO_THROW(getFactoryWithSkip(60)); + VELOX_ASSERT_THROW( + getFactoryWithSkip(61), + "Can only skip up to the past-the-end row of the file."); + VELOX_ASSERT_THROW( + getFactoryWithSkip(100), + "Can only skip up to the past-the-end row of the file."); + } + + /// Test that the same unit can be requested multiple times without issues + void testCanRequestUnitMultipleTimes() { + auto factory = createFactory(); + std::vector unitsLoaded(getUnitsLoadedWithFalse(1)); + std::vector> units; + units.push_back(std::make_unique(10, 0, unitsLoaded, 0)); + + auto unitLoader = factory.create(std::move(units), 0); + unitLoader->getLoadedUnit(0); + unitLoader->getLoadedUnit(0); + unitLoader->getLoadedUnit(0); + } + + /// Test that requesting a unit index out of range throws an exception + void testUnitOutOfRange() { + auto factory = createFactory(); + std::vector unitsLoaded(getUnitsLoadedWithFalse(1)); + std::vector> units; + units.push_back(std::make_unique(10, 0, unitsLoaded, 0)); + + auto unitLoader = factory.create(std::move(units), 0); + unitLoader->getLoadedUnit(0); + + VELOX_ASSERT_THROW(unitLoader->getLoadedUnit(1), "Unit out of range"); + } + + /// Test that seeking out of range throws an exception + void testSeekOutOfRange() { + auto factory = createFactory(); + std::vector unitsLoaded(getUnitsLoadedWithFalse(1)); + std::vector> units; + units.push_back(std::make_unique(10, 0, unitsLoaded, 0)); + + auto unitLoader = factory.create(std::move(units), 0); + + unitLoader->onSeek(0, 10); + + VELOX_ASSERT_THROW(unitLoader->onSeek(0, 11), "Row out of range"); + } + + /// Test that seeking out of range in ReaderMock throws appropriate exception + void testSeekOutOfRangeReaderError() { + auto factory = createFactory(); + ReaderMock readerMock{{10, 20, 30}, {0, 0, 0}, factory, 0}; + + readerMock.seek(59); + readerMock.seek(60); + + VELOX_ASSERT_THROW( + readerMock.seek(61), + "Can't seek to possition 61 in file. Must be up to 60."); + } +}; diff --git a/velox/dwio/common/tests/WriterTest.cpp b/velox/dwio/common/tests/WriterTest.cpp index 51e68de59d08..8fbb5a444b96 100644 --- a/velox/dwio/common/tests/WriterTest.cpp +++ b/velox/dwio/common/tests/WriterTest.cpp @@ -22,6 +22,32 @@ using namespace ::testing; namespace facebook::velox::dwio::common { namespace { + +class MockWriter : public Writer { + public: + MockWriter() = default; + + void setStateForTest(State state) { + setState(state); + } + + void callCheckRunning() const { + checkRunning(); + } + + void write(const VectorPtr& /*data*/) override {} + + void flush() override {} + + bool finish() override { + return true; + } + + void abort() override {} + + void close() override {} +}; + TEST(WriterTest, stateString) { ASSERT_EQ(Writer::stateString(Writer::State::kInit), "INIT"); ASSERT_EQ(Writer::stateString(Writer::State::kRunning), "RUNNING"); @@ -31,5 +57,16 @@ TEST(WriterTest, stateString) { VELOX_ASSERT_THROW( Writer::stateString(static_cast(100)), "BAD STATE: 100"); } + +TEST(WriterTest, checkRunning) { + MockWriter writer; + VELOX_ASSERT_THROW(writer.callCheckRunning(), "Writer is not running: INIT"); + writer.setStateForTest(Writer::State::kRunning); + ASSERT_NO_THROW(writer.callCheckRunning()); + writer.setStateForTest(Writer::State::kClosed); + VELOX_ASSERT_USER_THROW( + writer.callCheckRunning(), + "Writer is not running: CLOSED. Write operations are not allowed on a closed writer."); +} } // namespace } // namespace facebook::velox::dwio::common diff --git a/velox/dwio/common/tests/utils/CMakeLists.txt b/velox/dwio/common/tests/utils/CMakeLists.txt index 2765c1091d6a..bbbd5c2ebe54 100644 --- a/velox/dwio/common/tests/utils/CMakeLists.txt +++ b/velox/dwio/common/tests/utils/CMakeLists.txt @@ -19,7 +19,8 @@ add_library( DataSetBuilder.cpp FilterGenerator.cpp UnitLoaderTestTools.cpp - E2EFilterTestBase.cpp) + E2EFilterTestBase.cpp +) target_link_libraries( velox_dwio_common_test_utils @@ -36,7 +37,8 @@ target_link_libraries( velox_type velox_type_fbhive velox_vector - velox_vector_test_lib) + velox_vector_test_lib +) # older versions of GCC need it to allow std::filesystem if(CMAKE_COMPILER_IS_GNUCC AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 9) diff --git a/velox/dwio/common/tests/utils/DataSetBuilder.cpp b/velox/dwio/common/tests/utils/DataSetBuilder.cpp index 85c508dd0a3b..27ecb8a9ffea 100644 --- a/velox/dwio/common/tests/utils/DataSetBuilder.cpp +++ b/velox/dwio/common/tests/utils/DataSetBuilder.cpp @@ -47,8 +47,9 @@ DataSetBuilder& DataSetBuilder::makeDataset( for (size_t i = 0; i < batchCount; ++i) { if (withRecursiveNulls) { - batches_->push_back(std::static_pointer_cast( - BatchMaker::createBatch(rowType, numRows, pool_, nullptr, i))); + batches_->push_back( + std::static_pointer_cast( + BatchMaker::createBatch(rowType, numRows, pool_, nullptr, i))); } else { batches_->push_back( std::static_pointer_cast(BatchMaker::createBatch( @@ -207,7 +208,7 @@ DataSetBuilder& DataSetBuilder::withUniqueStringsForField( if (strings->isNullAt(row)) { continue; } - std::string value = strings->valueAt(row); + auto value = std::string(strings->valueAt(row)); value += fmt::format("{}", row); strings->set(row, StringView(value)); } @@ -282,7 +283,7 @@ DataSetBuilder& DataSetBuilder::makeMapStringValues( continue; } if (!keys->isNullAt(i) && i % 3 == 0) { - std::string str = keys->valueAt(i); + auto str = std::string(keys->valueAt(i)); str += "----123456789"; keys->set(i, StringView(str)); } @@ -304,7 +305,7 @@ DataSetBuilder& DataSetBuilder::makeMapStringValues( continue; } if (!values->isNullAt(i) && i % 3 == 0) { - std::string str = values->valueAt(i); + auto str = std::string(values->valueAt(i)); str += "----123456789"; values->set(i, StringView(str)); } diff --git a/velox/dwio/common/tests/utils/E2EFilterTestBase.cpp b/velox/dwio/common/tests/utils/E2EFilterTestBase.cpp index 446aaceee766..4c6e2dd70f2c 100644 --- a/velox/dwio/common/tests/utils/E2EFilterTestBase.cpp +++ b/velox/dwio/common/tests/utils/E2EFilterTestBase.cpp @@ -18,6 +18,7 @@ #include "velox/dwio/common/tests/utils/DataSetBuilder.h" #include "velox/expression/Expr.h" +#include "velox/expression/ExprConstants.h" #include "velox/expression/ExprToSubfieldFilter.h" #include "velox/functions/prestosql/registration/RegistrationFunctions.h" #include "velox/parse/Expressions.h" @@ -167,7 +168,7 @@ void E2EFilterTestBase::readWithFilter( auto resultBatch = BaseVector::create(rowType_, 1, leafPool_.get()); resetReadBatchSizes(); int32_t clearCnt = 0; - auto deletedRowsIter = mutationSpec.deletedRows.begin(); + auto deletedRowsIter = mutationSpec.deletedRows.cbegin(); while (true) { { MicrosecondTimer timer(&time); @@ -181,7 +182,7 @@ void E2EFilterTestBase::readWithFilter( auto readSize = rowReader->nextReadSize(nextReadBatchSize()); std::vector isDeleted(bits::nwords(readSize)); bool haveDelete = false; - for (; deletedRowsIter != mutationSpec.deletedRows.end(); + for (; deletedRowsIter != mutationSpec.deletedRows.cend(); ++deletedRowsIter) { auto i = *deletedRowsIter; if (i < nextRowNumber) { @@ -491,15 +492,31 @@ void E2EFilterTestBase::testMetadataFilterImpl( core::ExpressionEvaluator* evaluator, const std::string& remainingFilter, std::function validationFilter) { - SCOPED_TRACE(fmt::format("remainingFilter={}", remainingFilter)); + SCOPED_TRACE(fmt::format("remainingFilter='{}'", remainingFilter)); + auto untypedExpr = parse::parseExpr(remainingFilter, {}); + auto typedExpr = core::Expressions::inferTypes( + untypedExpr, batches[0]->type(), leafPool_.get()); + testMetadataFilterImpl( + batches, + std::move(filterField), + std::move(filter), + evaluator, + std::move(typedExpr), + std::move(validationFilter)); +} + +void E2EFilterTestBase::testMetadataFilterImpl( + const std::vector& batches, + common::Subfield filterField, + std::unique_ptr filter, + core::ExpressionEvaluator* evaluator, + core::TypedExprPtr typedExpr, + std::function validationFilter) { auto spec = std::make_shared(""); if (filter) { spec->getOrCreateChild(std::move(filterField)) ->setFilter(std::move(filter)); } - auto untypedExpr = parse::parseExpr(remainingFilter, {}); - auto typedExpr = core::Expressions::inferTypes( - untypedExpr, batches[0]->type(), leafPool_.get()); auto metadataFilter = std::make_shared(*spec, *typedExpr, evaluator); auto specA = spec->getOrCreateChild(common::Subfield("a")); @@ -580,12 +597,13 @@ void E2EFilterTestBase::testMetadataFilter() { nullptr, c->size(), std::vector({c})); - batches.push_back(std::make_shared( - leafPool_.get(), - ROW({{"a", a->type()}, {"b", b->type()}}), - nullptr, - a->size(), - std::vector({a, b}))); + batches.push_back( + std::make_shared( + leafPool_.get(), + ROW({{"a", a->type()}, {"b", b->type()}}), + nullptr, + a->size(), + std::vector({a, b}))); } writeToMemory(batches[0]->type(), batches, false); @@ -621,6 +639,56 @@ void E2EFilterTestBase::testMetadataFilter() { [](int64_t a, int64_t) { return !!(a == 2 || a == 3 || a == 5 || a == 7); }); + { + SCOPED_TRACE("remainingFilter='a == 1 or a == 3 or a == 8'"); + auto typedExpr1 = core::Expressions::inferTypes( + parse::parseExpr("a == 1", {}), batches[0]->type(), leafPool_.get()); + auto typedExpr2 = core::Expressions::inferTypes( + parse::parseExpr("a == 3", {}), batches[0]->type(), leafPool_.get()); + auto typedExpr3 = core::Expressions::inferTypes( + parse::parseExpr("a == 8", {}), batches[0]->type(), leafPool_.get()); + + auto typedExpr = std::make_shared( + velox::BOOLEAN(), + std::vector{ + std::move(typedExpr1), + std::move(typedExpr2), + std::move(typedExpr3), + }, + expression::kOr); + testMetadataFilterImpl( + batches, + common::Subfield("a"), + nullptr, + &evaluator, + std::move(typedExpr), + [](int64_t a, int64_t) { return a == 1 || a == 3 || a == 8; }); + } + { + SCOPED_TRACE("remainingFilter='a >= 1 and a <= 100 and a == 8'"); + auto typedExpr1 = core::Expressions::inferTypes( + parse::parseExpr("a >= 1", {}), batches[0]->type(), leafPool_.get()); + auto typedExpr2 = core::Expressions::inferTypes( + parse::parseExpr("a <= 100", {}), batches[0]->type(), leafPool_.get()); + auto typedExpr3 = core::Expressions::inferTypes( + parse::parseExpr("b.c != 8", {}), batches[0]->type(), leafPool_.get()); + + auto typedExpr = std::make_shared( + velox::BOOLEAN(), + std::vector{ + std::move(typedExpr1), + std::move(typedExpr2), + std::move(typedExpr3), + }, + expression::kAnd); + testMetadataFilterImpl( + batches, + common::Subfield("a"), + nullptr, + &evaluator, + std::move(typedExpr), + [](int64_t a, int64_t c) { return a >= 1 && a <= 100 && c != 8; }); + } { SCOPED_TRACE("Values not unique in row group"); @@ -681,8 +749,18 @@ void E2EFilterTestBase::testSubfieldsPruning() { [](auto) { return 1; }, [](auto) { return 0; }, [](auto) { return "foofoofoofoofoo"_sv; }); - batches.push_back( - vectorMaker.rowVector({"a", "b", "c", "d"}, {a, b, c, d})); + auto e = vectorMaker.mapVector( + batchSize_, + [&](auto) { return kMapSize; }, + [](auto j) { return j; }, + [&](auto j) { return j % kMapSize; }); + auto f = vectorMaker.arrayVector( + batchSize_, + [&](auto j) { return kMapSize; }, + [&](auto j) { return j % kMapSize; }, + [&](auto j) { return j >= i + 1 && j % 23 == (i + 1) % 23; }); + batches.push_back(vectorMaker.rowVector( + {"a", "b", "c", "d", "e", "f"}, {a, b, c, d, e, f})); } writeToMemory(batches[0]->type(), batches, false); auto spec = std::make_shared(""); @@ -707,6 +785,12 @@ void E2EFilterTestBase::testSubfieldsPruning() { auto specD = spec->addFieldRecursively("d", *MAP(BIGINT(), VARCHAR()), 3); specD->childByName(common::ScanSpec::kMapKeysFieldName) ->setFilter(common::createBigintValues({1}, false)); + auto specE = spec->addFieldRecursively("e", *MAP(BIGINT(), BIGINT()), 4); + specE->childByName(common::ScanSpec::kMapValuesFieldName) + ->setFilter(common::createBigintValues({0, 2, 4}, false)); + auto specF = spec->addFieldRecursively("f", *ARRAY(BIGINT()), 5); + specF->childByName(common::ScanSpec::kArrayElementsFieldName) + ->setFilter(common::createBigintValues({0, 2, 4}, false)); ReaderOptions readerOpts{leafPool_.get()}; RowReaderOptions rowReaderOpts; auto input = std::make_unique( @@ -756,6 +840,31 @@ void E2EFilterTestBase::testSubfieldsPruning() { auto* dd = actual->childAt(3)->loadedVector()->asUnchecked(); ASSERT_FALSE(dd->isNullAt(ii)); ASSERT_EQ(dd->sizeAt(ii), 0); + auto* e = expected->childAt(4)->asUnchecked(); + auto* ee = actual->childAt(4)->loadedVector()->asUnchecked(); + ASSERT_FALSE(ee->isNullAt(ii)); + ASSERT_EQ(ee->sizeAt(ii), (kMapSize + 1) / 2); + for (int k = 0; k < kMapSize; k += 2) { + int k1 = ee->offsetAt(ii) + k / 2; + int k2 = e->offsetAt(j) + k; + ASSERT_TRUE(ee->mapKeys()->equalValueAt(e->mapKeys().get(), k1, k2)); + ASSERT_TRUE( + ee->mapValues()->equalValueAt(e->mapValues().get(), k1, k2)); + } + auto* f = expected->childAt(5)->asUnchecked(); + auto* ff = actual->childAt(5)->loadedVector()->asUnchecked(); + if (f->isNullAt(j)) { + ASSERT_TRUE(ff->isNullAt(ii)); + } else { + ASSERT_FALSE(ff->isNullAt(ii)); + for (int k = 0; k < kMapSize; k += 2) { + int k1 = ff->offsetAt(ii) + k / 2; + int k2 = f->offsetAt(j) + k; + + ASSERT_TRUE( + ff->elements()->equalValueAt(f->elements().get(), k1, k2)); + } + } ++ii; } } diff --git a/velox/dwio/common/tests/utils/E2EFilterTestBase.h b/velox/dwio/common/tests/utils/E2EFilterTestBase.h index 16d30e36a7b7..eada6837f60d 100644 --- a/velox/dwio/common/tests/utils/E2EFilterTestBase.h +++ b/velox/dwio/common/tests/utils/E2EFilterTestBase.h @@ -64,7 +64,7 @@ class TestingHook : public ValueHook { } } - void addValue(vector_size_t row, folly::StringPiece value) override { + void addValue(vector_size_t row, std::string_view value) override { if constexpr (std::is_same_v) { result_->set(row, StringView(value)); } else { @@ -336,6 +336,14 @@ class E2EFilterTestBase : public testing::Test { const std::string& remainingFilter, std::function validationFilter); + void testMetadataFilterImpl( + const std::vector& batches, + common::Subfield filterField, + std::unique_ptr filter, + core::ExpressionEvaluator* evaluator, + core::TypedExprPtr typedExpr, + std::function validationFilter); + protected: void testMetadataFilter(); diff --git a/velox/dwio/common/tests/utils/FilterGenerator.cpp b/velox/dwio/common/tests/utils/FilterGenerator.cpp index d532b4cec8f3..63fc29e33e56 100644 --- a/velox/dwio/common/tests/utils/FilterGenerator.cpp +++ b/velox/dwio/common/tests/utils/FilterGenerator.cpp @@ -447,7 +447,9 @@ void FilterGenerator::addToScanSpec( const SubfieldFilters& filters, ScanSpec& spec) { for (auto& pair : filters) { - spec.getOrCreateChild(pair.first)->addFilter(*pair.second); + auto* child = spec.getOrCreateChild(pair.first); + VELOX_CHECK_NULL(child->filter()); + child->setFilter(pair.second); } } @@ -661,7 +663,7 @@ void pruneRandomSubfield( break; case TypeKind::VARCHAR: case TypeKind::VARBINARY: - stringKeys.push_back( + stringKeys.emplace_back( keys->asUnchecked>()->valueAt(jj)); break; default: diff --git a/velox/dwio/common/tests/utils/UnitLoaderTestTools.cpp b/velox/dwio/common/tests/utils/UnitLoaderTestTools.cpp index e2ec87ae605c..0342d0d530e3 100644 --- a/velox/dwio/common/tests/utils/UnitLoaderTestTools.cpp +++ b/velox/dwio/common/tests/utils/UnitLoaderTestTools.cpp @@ -90,8 +90,9 @@ bool ReaderMock::loadUnit() { std::vector> ReaderMock::getUnits() { std::vector> units; for (size_t i = 0; i < rowsPerUnit_.size(); ++i) { - units.emplace_back(std::make_unique( - rowsPerUnit_[i], ioSizes_[i], unitsLoaded_, i)); + units.emplace_back( + std::make_unique( + rowsPerUnit_[i], ioSizes_[i], unitsLoaded_, i)); } return units; } diff --git a/velox/dwio/common/tests/utils/UnitLoaderTestTools.h b/velox/dwio/common/tests/utils/UnitLoaderTestTools.h index 9eae97f575c2..6c36c10c8651 100644 --- a/velox/dwio/common/tests/utils/UnitLoaderTestTools.h +++ b/velox/dwio/common/tests/utils/UnitLoaderTestTools.h @@ -32,16 +32,20 @@ class LoadUnitMock : public LoadUnit { uint64_t rowCount, uint64_t ioSize, std::vector& unitsLoaded, - size_t unitId) + size_t unitId, + std::chrono::milliseconds loadDelay = std::chrono::milliseconds(100)) : rowCount_{rowCount}, ioSize_{ioSize}, unitsLoaded_{unitsLoaded}, - unitId_{unitId} {} + unitId_{unitId}, + loadDelay_(loadDelay) {} ~LoadUnitMock() override = default; void load() override { VELOX_CHECK(!isLoaded()); + // Simulate loading time + std::this_thread::sleep_for(loadDelay_); unitsLoaded_[unitId_] = true; } @@ -67,6 +71,7 @@ class LoadUnitMock : public LoadUnit { uint64_t ioSize_; std::vector& unitsLoaded_; size_t unitId_; + std::chrono::milliseconds loadDelay_; }; class ReaderMock { @@ -82,7 +87,7 @@ class ReaderMock { void seek(uint64_t rowNumber); std::vector unitsLoaded() const { - return {unitsLoaded_.begin(), unitsLoaded_.end()}; + return {unitsLoaded_.cbegin(), unitsLoaded_.cend()}; } private: diff --git a/velox/dwio/common/wrap/CMakeLists.txt b/velox/dwio/common/wrap/CMakeLists.txt new file mode 100644 index 000000000000..a598690b32e5 --- /dev/null +++ b/velox/dwio/common/wrap/CMakeLists.txt @@ -0,0 +1,14 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +velox_install_library_headers() diff --git a/velox/dwio/dwrf/CMakeLists.txt b/velox/dwio/dwrf/CMakeLists.txt index 613f69ad2f06..65188eb58d64 100644 --- a/velox/dwio/dwrf/CMakeLists.txt +++ b/velox/dwio/dwrf/CMakeLists.txt @@ -22,3 +22,5 @@ elseif(${VELOX_BUILD_TEST_UTILS}) endif() add_subdirectory(utils) add_subdirectory(writer) + +velox_install_library_headers() diff --git a/velox/dwio/dwrf/common/CMakeLists.txt b/velox/dwio/dwrf/common/CMakeLists.txt index bb6356c94e57..3f8f38028eaa 100644 --- a/velox/dwio/dwrf/common/CMakeLists.txt +++ b/velox/dwio/dwrf/common/CMakeLists.txt @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +add_subdirectory(wrap) + velox_add_library( velox_dwio_dwrf_common ByteRLE.cpp @@ -26,7 +28,8 @@ velox_add_library( RLEv2.cpp Statistics.cpp wrap/orc-proto-wrapper.cpp - wrap/dwrf-proto-wrapper.cpp) + wrap/dwrf-proto-wrapper.cpp +) velox_link_libraries( velox_dwio_dwrf_common @@ -38,10 +41,14 @@ velox_link_libraries( velox_caching Snappy::snappy zstd::zstd - protobuf::libprotobuf) + protobuf::libprotobuf +) # required for the wrapped protobuf headers/sources -velox_include_directories(velox_dwio_dwrf_common PUBLIC ${PROJECT_BINARY_DIR}) +velox_include_directories( + velox_dwio_dwrf_common + PUBLIC "$" +) if(NOT VELOX_MONO_LIBRARY) # trigger generation of pb files diff --git a/velox/dwio/dwrf/common/Common.h b/velox/dwio/dwrf/common/Common.h index 783f8a0329d7..ae467654c1fa 100644 --- a/velox/dwio/dwrf/common/Common.h +++ b/velox/dwio/dwrf/common/Common.h @@ -29,11 +29,11 @@ namespace facebook::velox::dwrf { // Writer version -constexpr folly::StringPiece WRITER_NAME_KEY{"orc.writer.name"}; -constexpr folly::StringPiece WRITER_VERSION_KEY{"orc.writer.version"}; -constexpr folly::StringPiece WRITER_HOSTNAME_KEY{"orc.writer.host"}; -constexpr folly::StringPiece kDwioWriter{"dwio"}; -constexpr folly::StringPiece kPrestoWriter{"presto"}; +constexpr std::string_view kWriterNameKey{"orc.writer.name"}; +constexpr std::string_view kWriterVersionKey{"orc.writer.version"}; +constexpr std::string_view kWriterHostnameKey{"orc.writer.host"}; +constexpr std::string_view kDwioWriter{"dwio"}; +constexpr std::string_view kPrestoWriter{"presto"}; enum class DwrfFormat : uint8_t { kDwrf = 0, diff --git a/velox/dwio/dwrf/common/Config.cpp b/velox/dwio/dwrf/common/Config.cpp index 8ce1ad5f2155..b9c556d8a88c 100644 --- a/velox/dwio/dwrf/common/Config.cpp +++ b/velox/dwio/dwrf/common/Config.cpp @@ -130,7 +130,7 @@ Config::Entry> Config::MAP_FLAT_COLS( [](const std::string& /* key */, const std::string& val) { std::vector result; if (!val.empty()) { - std::vector pieces; + std::vector pieces; folly::split(',', val, pieces, true); for (const auto& p : pieces) { const auto& trimmedCol = folly::trimWhitespace(p); @@ -182,7 +182,7 @@ Config::Entry>> Config::Entry Config::MAP_FLAT_MAX_KEYS( "orc.map.flat.max.keys", - 20000); + 30000); Config::Entry Config::MAX_DICTIONARY_SIZE( "hive.exec.orc.max.dictionary.size", diff --git a/velox/dwio/dwrf/common/FileMetadata.cpp b/velox/dwio/dwrf/common/FileMetadata.cpp index ccb9f6faa7f5..2b3f2f80d499 100644 --- a/velox/dwio/dwrf/common/FileMetadata.cpp +++ b/velox/dwio/dwrf/common/FileMetadata.cpp @@ -37,6 +37,29 @@ CompressionKind orcCompressionToCompressionKind( } VELOX_FAIL("Unknown compression kind: {}", CompressionKind_Name(compression)); } + +static proto::orc::CompressionKind compressionKindToOrcCompression( + CompressionKind compressionKind) { + switch (compressionKind) { + case CompressionKind::CompressionKind_NONE: + return proto::orc::CompressionKind::NONE; + case CompressionKind::CompressionKind_ZLIB: + return proto::orc::CompressionKind::ZLIB; + case CompressionKind::CompressionKind_SNAPPY: + return proto::orc::CompressionKind::SNAPPY; + case CompressionKind::CompressionKind_LZO: + return proto::orc::CompressionKind::LZO; + case CompressionKind::CompressionKind_ZSTD: + return proto::orc::CompressionKind::ZSTD; + case CompressionKind::CompressionKind_LZ4: + return proto::orc::CompressionKind::LZ4; + case CompressionKind::CompressionKind_GZIP: + default: + VELOX_FAIL( + "Unknown compression kind: {}", + compressionKindToString(compressionKind)); + } +} } // namespace detail TypeKind TypeWrapper::kind() const { @@ -102,9 +125,10 @@ TypeKind TypeWrapper::kind() const { } case proto::orc::Type_Kind_CHAR: case proto::orc::Type_Kind_TIMESTAMP_INSTANT: - VELOX_FAIL(fmt::format( - "{} not supported yet.", - proto::orc::Type_Kind_Name(orcPtr()->kind()))); + VELOX_FAIL( + fmt::format( + "{} not supported yet.", + proto::orc::Type_Kind_Name(orcPtr()->kind()))); default: VELOX_FAIL("Unknown type kind: {}", Type_Kind_Name(orcPtr()->kind())); } @@ -116,4 +140,13 @@ common::CompressionKind PostScript::compression() const { : detail::orcCompressionToCompressionKind(orcPtr()->compression()); } +void PostScriptWriteWrapper::setCompression( + common::CompressionKind compressionKind) { + format_ == DwrfFormat::kDwrf + ? dwrfPtr()->set_compression( + static_cast(compressionKind)) + : orcPtr()->set_compression( + detail::compressionKindToOrcCompression(compressionKind)); +} + } // namespace facebook::velox::dwrf diff --git a/velox/dwio/dwrf/common/FileMetadata.h b/velox/dwio/dwrf/common/FileMetadata.h index 87e8c12719fe..3665d4b91201 100644 --- a/velox/dwio/dwrf/common/FileMetadata.h +++ b/velox/dwio/dwrf/common/FileMetadata.h @@ -19,6 +19,7 @@ #include "velox/common/base/Exceptions.h" #include "velox/common/compression/Compression.h" +#include "velox/dwio/common/OutputStream.h" #include "velox/dwio/dwrf/common/Common.h" #include "velox/dwio/dwrf/common/wrap/dwrf-proto-wrapper.h" #include "velox/dwio/dwrf/common/wrap/orc-proto-wrapper.h" @@ -43,6 +44,24 @@ class ProtoWrapperBase { const void* const impl_; }; +class ProtoWriteWrapperBase { + protected: + ProtoWriteWrapperBase(DwrfFormat format, void* impl) + : format_{format}, impl_{impl} {} + + DwrfFormat format_; + void* impl_; + + public: + DwrfFormat format() const { + return format_; + } + + inline void* rawProtoPtr() const { + return impl_; + } +}; + /*** * PostScript that takes the ownership of proto::PostScript / *proto::orc::PostScript and provides access to the attributes @@ -93,6 +112,10 @@ class PostScript { : orcPtr()->footerlength(); } + uint64_t metadataLength() const { + return format_ == DwrfFormat::kDwrf ? 0 : orcPtr()->metadatalength(); + } + bool hasCompression() const { return format_ == DwrfFormat::kDwrf ? dwrfPtr()->has_compression() : orcPtr()->has_compression(); @@ -244,6 +267,46 @@ class StripeInformationWrapper : public ProtoWrapperBase { } }; +class ColumnEncodingKindWrapper : public ProtoWrapperBase { + public: + explicit ColumnEncodingKindWrapper(proto::ColumnEncoding_Kind* stream) + : ProtoWrapperBase(DwrfFormat::kDwrf, stream) {} + + explicit ColumnEncodingKindWrapper(proto::orc::ColumnEncoding_Kind* stream) + : ProtoWrapperBase(DwrfFormat::kOrc, stream) {} +}; + +class ColumnEncodingWrapper : public ProtoWrapperBase { + public: + explicit ColumnEncodingWrapper(const proto::ColumnEncoding* columnEncoding) + : ProtoWrapperBase(DwrfFormat::kDwrf, columnEncoding) {} + explicit ColumnEncodingWrapper( + const proto::orc::ColumnEncoding* columnEncoding) + : ProtoWrapperBase(DwrfFormat::kOrc, columnEncoding) {} + + void Clear() {} + + proto::ColumnEncoding_Kind kind() const { + VELOX_CHECK_EQ(format_, DwrfFormat::kDwrf); + return dwrfPtr()->kind(); + } + + uint32_t node() const { + VELOX_CHECK_EQ(format_, DwrfFormat::kDwrf); + return dwrfPtr()->node(); + } + + private: + // private helper with no format checking + inline const proto::ColumnEncoding* dwrfPtr() const { + return reinterpret_cast(rawProtoPtr()); + } + + inline const proto::orc::ColumnEncoding* orcPtr() const { + return reinterpret_cast(rawProtoPtr()); + } +}; + class TypeWrapper : public ProtoWrapperBase { public: explicit TypeWrapper(const proto::Type* t) @@ -940,6 +1003,1027 @@ class StripeFooterWrapper : public ProtoWrapperBase { std::shared_ptr orcStripeFooter_ = nullptr; }; +class StripeInformationWriteWrapper : public ProtoWriteWrapperBase { + public: + explicit StripeInformationWriteWrapper( + proto::StripeInformation* stripeInformation) + : ProtoWriteWrapperBase(DwrfFormat::kDwrf, stripeInformation) {} + + explicit StripeInformationWriteWrapper( + proto::orc::StripeInformation* stripeInformation) + : ProtoWriteWrapperBase(DwrfFormat::kOrc, stripeInformation) {} + + uint64_t numberOfRows() { + return format_ == DwrfFormat::kDwrf ? dwrfPtr()->numberofrows() + : orcPtr()->numberofrows(); + } + + uint64_t rawDataSize() { + return format_ == DwrfFormat::kDwrf ? dwrfPtr()->rawdatasize() : 0; + } + + bool hasChecksum() const { + return format_ == DwrfFormat::kDwrf ? dwrfPtr()->has_checksum() : false; + } + + uint64_t checksum() const { + return format_ == DwrfFormat::kDwrf ? dwrfPtr()->checksum() : 0; + } + + void setNumberOfRows(uint64_t stripeRowCount) { + return format_ == DwrfFormat::kDwrf + ? dwrfPtr()->set_numberofrows(stripeRowCount) + : orcPtr()->set_numberofrows(stripeRowCount); + } + + void setRawDataSize(uint64_t rawDataSize) { + if (format_ == DwrfFormat::kDwrf) { + dwrfPtr()->set_rawdatasize(rawDataSize); + } + } + + void setChecksum(int64_t checksum) { + VELOX_CHECK_EQ(format_, DwrfFormat::kDwrf); + dwrfPtr()->set_checksum(checksum); + } + + void setGroupSize(uint64_t groupSize) { + VELOX_CHECK_EQ(format_, DwrfFormat::kDwrf); + dwrfPtr()->set_groupsize(groupSize); + } + + uint64_t groupSize() const { + VELOX_CHECK_EQ(format_, DwrfFormat::kDwrf); + return dwrfPtr()->groupsize(); + } + + void setOffset(uint64_t offset) { + format_ == DwrfFormat::kDwrf ? dwrfPtr()->set_offset(offset) + : orcPtr()->set_offset(offset); + } + + uint64_t offset() const { + return format_ == DwrfFormat::kDwrf ? dwrfPtr()->offset() + : orcPtr()->offset(); + } + + void setIndexLength(uint64_t indexLength) { + format_ == DwrfFormat::kDwrf ? dwrfPtr()->set_indexlength(indexLength) + : orcPtr()->set_indexlength(indexLength); + } + + uint64_t indexLength() const { + return format_ == DwrfFormat::kDwrf ? dwrfPtr()->indexlength() + : orcPtr()->indexlength(); + } + + void setDataLength(uint64_t dataLength) { + format_ == DwrfFormat::kDwrf ? dwrfPtr()->set_datalength(dataLength) + : orcPtr()->set_datalength(dataLength); + } + + uint64_t dataLength() const { + return format_ == DwrfFormat::kDwrf ? dwrfPtr()->datalength() + : orcPtr()->datalength(); + } + + void setFooterLength(uint64_t footerLength) { + format_ == DwrfFormat::kDwrf ? dwrfPtr()->set_footerlength(footerLength) + : orcPtr()->set_footerlength(footerLength); + } + + uint64_t footerLength() const { + return format_ == DwrfFormat::kDwrf ? dwrfPtr()->footerlength() + : orcPtr()->footerlength(); + } + + std::string* addKeyMetadata() { + VELOX_CHECK_EQ(format_, DwrfFormat::kDwrf); + return dwrfPtr()->add_keymetadata(); + } + + private: + // private helper with no format checking + inline proto::StripeInformation* dwrfPtr() const { + return reinterpret_cast(rawProtoPtr()); + } + inline proto::orc::StripeInformation* orcPtr() const { + return reinterpret_cast(rawProtoPtr()); + } +}; + +class TypeKindWrapper : public ProtoWriteWrapperBase { + public: + explicit TypeKindWrapper(proto::Type_Kind* footer) + : ProtoWriteWrapperBase(DwrfFormat::kDwrf, footer) {} + + explicit TypeKindWrapper(proto::orc::Type_Kind* footer) + : ProtoWriteWrapperBase(DwrfFormat::kOrc, footer) {} +}; + +class TypeWriteWrapper : public ProtoWriteWrapperBase { + public: + explicit TypeWriteWrapper(proto::Type* footer) + : ProtoWriteWrapperBase(DwrfFormat::kDwrf, footer) {} + + explicit TypeWriteWrapper(proto::orc::Type* footer) + : ProtoWriteWrapperBase(DwrfFormat::kOrc, footer) {} + + const proto::Type* getDwrfPtr() const { + VELOX_CHECK_EQ(format_, DwrfFormat::kDwrf); + return reinterpret_cast(rawProtoPtr()); + } + + const proto::orc::Type* getOrcPtr() const { + VELOX_CHECK_EQ(format_, DwrfFormat::kOrc); + return reinterpret_cast(rawProtoPtr()); + } + + void setKind(TypeKindWrapper typeKindWrapper) { + format_ == DwrfFormat::kDwrf + ? dwrfPtr()->set_kind(*reinterpret_cast( + typeKindWrapper.rawProtoPtr())) + : orcPtr()->set_kind(*reinterpret_cast( + typeKindWrapper.rawProtoPtr())); + } + + void setScale(uint32_t scale) { + VELOX_CHECK_EQ(format_, DwrfFormat::kOrc); + orcPtr()->set_scale(scale); + } + + void setPrecision(uint32_t precision) { + VELOX_CHECK_EQ(format_, DwrfFormat::kOrc); + orcPtr()->set_precision(precision); + } + + void addFieldnames(const std::string& fieldName) { + format_ == DwrfFormat::kDwrf ? dwrfPtr()->add_fieldnames(fieldName) + : orcPtr()->add_fieldnames(fieldName); + } + + void addSubtypes(int fieldName) { + format_ == DwrfFormat::kDwrf ? dwrfPtr()->add_subtypes(fieldName) + : orcPtr()->add_subtypes(fieldName); + } + + private: + // private helper with no format checking + inline proto::Type* dwrfPtr() const { + return reinterpret_cast(rawProtoPtr()); + } + inline proto::orc::Type* orcPtr() const { + return reinterpret_cast(rawProtoPtr()); + } +}; + +class UserMetadataItemWriteWrapper : public ProtoWriteWrapperBase { + public: + explicit UserMetadataItemWriteWrapper( + proto::UserMetadataItem* userMetadataItem) + : ProtoWriteWrapperBase(DwrfFormat::kDwrf, userMetadataItem) {} + + explicit UserMetadataItemWriteWrapper( + proto::orc::UserMetadataItem* userMetadataItem) + : ProtoWriteWrapperBase(DwrfFormat::kOrc, userMetadataItem) {} + + void setName(const std::string& name) { + format_ == DwrfFormat::kDwrf ? dwrfPtr()->set_name(name) + : orcPtr()->set_name(name); + } + + void setValue(const std::string& value) { + format_ == DwrfFormat::kDwrf ? dwrfPtr()->set_value(value) + : orcPtr()->set_value(value); + } + + private: + // private helper with no format checking + inline proto::UserMetadataItem* dwrfPtr() const { + return reinterpret_cast(rawProtoPtr()); + } + inline proto::orc::UserMetadataItem* orcPtr() const { + return reinterpret_cast(rawProtoPtr()); + } +}; + +class BucketStatisticsWriteWrapper : public ProtoWriteWrapperBase { + public: + explicit BucketStatisticsWriteWrapper( + proto::BucketStatistics* bucketStatistics) + : ProtoWriteWrapperBase(DwrfFormat::kDwrf, bucketStatistics) {} + + explicit BucketStatisticsWriteWrapper( + proto::orc::BucketStatistics* bucketStatistics) + : ProtoWriteWrapperBase(DwrfFormat::kOrc, bucketStatistics) {} + + int countSize() { + return format_ == DwrfFormat::kDwrf ? dwrfPtr()->count_size() + : orcPtr()->count_size(); + } + + void addCount(uint64_t count) { + format_ == DwrfFormat::kDwrf ? dwrfPtr()->add_count(count) + : orcPtr()->add_count(count); + } + + private: + // private helper with no format checking + inline proto::BucketStatistics* dwrfPtr() const { + return reinterpret_cast(rawProtoPtr()); + } + inline proto::orc::BucketStatistics* orcPtr() const { + return reinterpret_cast(rawProtoPtr()); + } +}; + +class IntegerStatisticsWriteWrapper : public ProtoWriteWrapperBase { + public: + explicit IntegerStatisticsWriteWrapper( + proto::IntegerStatistics* integerStatistics) + : ProtoWriteWrapperBase(DwrfFormat::kDwrf, integerStatistics) {} + + explicit IntegerStatisticsWriteWrapper( + proto::orc::IntegerStatistics* integerStatistics) + : ProtoWriteWrapperBase(DwrfFormat::kOrc, integerStatistics) {} + + void setSum(int64_t sum) { + format_ == DwrfFormat::kDwrf ? dwrfPtr()->set_sum(sum) + : orcPtr()->set_sum(sum); + } + + void setMinimum(int64_t minimum) { + format_ == DwrfFormat::kDwrf ? dwrfPtr()->set_minimum(minimum) + : orcPtr()->set_minimum(minimum); + } + + void setMaximum(int64_t maximum) { + format_ == DwrfFormat::kDwrf ? dwrfPtr()->set_maximum(maximum) + : orcPtr()->set_maximum(maximum); + } + + private: + // private helper with no format checking + inline proto::IntegerStatistics* dwrfPtr() const { + return reinterpret_cast(rawProtoPtr()); + } + inline proto::orc::IntegerStatistics* orcPtr() const { + return reinterpret_cast(rawProtoPtr()); + } +}; + +class DoubleStatisticsWriteWrapper : public ProtoWriteWrapperBase { + public: + explicit DoubleStatisticsWriteWrapper( + proto::DoubleStatistics* doubleStatistics) + : ProtoWriteWrapperBase(DwrfFormat::kDwrf, doubleStatistics) {} + + explicit DoubleStatisticsWriteWrapper( + proto::orc::DoubleStatistics* doubleStatistics) + : ProtoWriteWrapperBase(DwrfFormat::kOrc, doubleStatistics) {} + + void setSum(double sum) { + format_ == DwrfFormat::kDwrf ? dwrfPtr()->set_sum(sum) + : orcPtr()->set_sum(sum); + } + + void setMinimum(double minimum) { + format_ == DwrfFormat::kDwrf ? dwrfPtr()->set_minimum(minimum) + : orcPtr()->set_minimum(minimum); + } + + void setMaximum(double maximum) { + format_ == DwrfFormat::kDwrf ? dwrfPtr()->set_maximum(maximum) + : orcPtr()->set_maximum(maximum); + } + + private: + // private helper with no format checking + inline proto::DoubleStatistics* dwrfPtr() const { + return reinterpret_cast(rawProtoPtr()); + } + inline proto::orc::DoubleStatistics* orcPtr() const { + return reinterpret_cast(rawProtoPtr()); + } +}; + +class StringStatisticsWriteWrapper : public ProtoWriteWrapperBase { + public: + explicit StringStatisticsWriteWrapper( + proto::StringStatistics* stringStatistics) + : ProtoWriteWrapperBase(DwrfFormat::kDwrf, stringStatistics) {} + + explicit StringStatisticsWriteWrapper( + proto::orc::StringStatistics* stringStatistics) + : ProtoWriteWrapperBase(DwrfFormat::kOrc, stringStatistics) {} + + void setSum(uint64_t sum) { + format_ == DwrfFormat::kDwrf ? dwrfPtr()->set_sum(sum) + : orcPtr()->set_sum(sum); + } + + void setMinimum(std::string minimum) { + format_ == DwrfFormat::kDwrf ? dwrfPtr()->set_minimum(minimum) + : orcPtr()->set_minimum(minimum); + } + + void setMaximum(std::string maximum) { + format_ == DwrfFormat::kDwrf ? dwrfPtr()->set_maximum(maximum) + : orcPtr()->set_maximum(maximum); + } + + private: + // private helper with no format checking + inline proto::StringStatistics* dwrfPtr() const { + return reinterpret_cast(rawProtoPtr()); + } + inline proto::orc::StringStatistics* orcPtr() const { + return reinterpret_cast(rawProtoPtr()); + } +}; + +class BinaryStatisticsWriteWrapper : public ProtoWriteWrapperBase { + public: + explicit BinaryStatisticsWriteWrapper( + proto::BinaryStatistics* binaryStatistics) + : ProtoWriteWrapperBase(DwrfFormat::kDwrf, binaryStatistics) {} + + explicit BinaryStatisticsWriteWrapper( + proto::orc::BinaryStatistics* binaryStatistics) + : ProtoWriteWrapperBase(DwrfFormat::kOrc, binaryStatistics) {} + + void setSum(uint64_t sum) { + format_ == DwrfFormat::kDwrf ? dwrfPtr()->set_sum(sum) + : orcPtr()->set_sum(sum); + } + + private: + // private helper with no format checking + inline proto::BinaryStatistics* dwrfPtr() const { + return reinterpret_cast(rawProtoPtr()); + } + inline proto::orc::BinaryStatistics* orcPtr() const { + return reinterpret_cast(rawProtoPtr()); + } +}; + +class ColumnStatisticsWriteWrapper : public ProtoWriteWrapperBase { + public: + explicit ColumnStatisticsWriteWrapper(proto::ColumnStatistics* footer) + : ProtoWriteWrapperBase(DwrfFormat::kDwrf, footer) {} + + explicit ColumnStatisticsWriteWrapper(proto::orc::ColumnStatistics* footer) + : ProtoWriteWrapperBase(DwrfFormat::kOrc, footer) {} + + void setSize(uint64_t size) { + if (format_ == DwrfFormat::kDwrf) { + dwrfPtr()->set_size(size); + } + } + + void setHasNull(bool hasNull) { + format_ == DwrfFormat::kDwrf ? dwrfPtr()->set_hasnull(hasNull) + : orcPtr()->set_hasnull(hasNull); + } + + void setNumberOfValues(uint64_t numberOfValues) { + format_ == DwrfFormat::kDwrf ? dwrfPtr()->set_numberofvalues(numberOfValues) + : orcPtr()->set_numberofvalues(numberOfValues); + } + + void setRawSize(uint64_t rawSize) { + if (format_ == DwrfFormat::kDwrf) { + dwrfPtr()->set_rawsize(rawSize); + } + } + + uint64_t getRawSize() const { + VELOX_CHECK_EQ(format_, DwrfFormat::kDwrf); + return dwrfPtr()->rawsize(); + } + + uint64_t getSize() const { + VELOX_CHECK_EQ(format_, DwrfFormat::kDwrf); + return dwrfPtr()->size(); + } + + bool hasMapStatistics() const { + VELOX_CHECK_EQ(format_, DwrfFormat::kDwrf); + return dwrfPtr()->has_mapstatistics(); + } + + proto::MapStatistics* mutableMapStatistics() { + VELOX_CHECK_EQ(format_, DwrfFormat::kDwrf); + return dwrfPtr()->mutable_mapstatistics(); + } + + BinaryStatisticsWriteWrapper mutableBinaryStatistics() { + return format_ == DwrfFormat::kDwrf + ? BinaryStatisticsWriteWrapper(dwrfPtr()->mutable_binarystatistics()) + : BinaryStatisticsWriteWrapper(orcPtr()->mutable_binarystatistics()); + } + + StringStatisticsWriteWrapper mutableStringStatistics() { + return format_ == DwrfFormat::kDwrf + ? StringStatisticsWriteWrapper(dwrfPtr()->mutable_stringstatistics()) + : StringStatisticsWriteWrapper(orcPtr()->mutable_stringstatistics()); + } + + DoubleStatisticsWriteWrapper mutableDoubleStatistics() { + return format_ == DwrfFormat::kDwrf + ? DoubleStatisticsWriteWrapper(dwrfPtr()->mutable_doublestatistics()) + : DoubleStatisticsWriteWrapper(orcPtr()->mutable_doublestatistics()); + } + + IntegerStatisticsWriteWrapper mutableIntegerStatistics() { + return format_ == DwrfFormat::kDwrf + ? IntegerStatisticsWriteWrapper(dwrfPtr()->mutable_intstatistics()) + : IntegerStatisticsWriteWrapper(orcPtr()->mutable_intstatistics()); + } + + proto::orc::DateStatistics* mutableDateStatistics() { + VELOX_CHECK_EQ(format_, DwrfFormat::kOrc); + return orcPtr()->mutable_datestatistics(); + } + + proto::orc::TimestampStatistics* mutableTimestampStatistics() { + VELOX_CHECK_EQ(format_, DwrfFormat::kOrc); + return orcPtr()->mutable_timestampstatistics(); + } + + proto::orc::DecimalStatistics* mutableDecimalStatistics() { + VELOX_CHECK_EQ(format_, DwrfFormat::kOrc); + return orcPtr()->mutable_decimalstatistics(); + } + + BucketStatisticsWriteWrapper mutableBucketStatistics() { + return format_ == DwrfFormat::kDwrf + ? BucketStatisticsWriteWrapper(dwrfPtr()->mutable_bucketstatistics()) + : BucketStatisticsWriteWrapper(orcPtr()->mutable_bucketstatistics()); + } + + void reset(const proto::ColumnStatistics* dwrfStatistics) { + VELOX_CHECK_EQ(format_, DwrfFormat::kDwrf); + VELOX_CHECK_NOT_NULL(dwrfStatistics); + dwrfPtr()->CopyFrom(*dwrfStatistics); + } + + private: + // private helper with no format checking + inline proto::ColumnStatistics* dwrfPtr() const { + return reinterpret_cast(rawProtoPtr()); + } + inline proto::orc::ColumnStatistics* orcPtr() const { + return reinterpret_cast(rawProtoPtr()); + } +}; + +class FooterWriteWrapper : public ProtoWriteWrapperBase { + public: + explicit FooterWriteWrapper(proto::Footer* footer) + : ProtoWriteWrapperBase(DwrfFormat::kDwrf, footer) {} + + explicit FooterWriteWrapper(proto::orc::Footer* footer) + : ProtoWriteWrapperBase(DwrfFormat::kOrc, footer) {} + + const proto::Footer* getDwrfPtr() const { + VELOX_CHECK_EQ(format_, DwrfFormat::kDwrf); + return reinterpret_cast(rawProtoPtr()); + } + + proto::Footer* getMutableDwrfPtr() { + VELOX_CHECK_EQ(format_, DwrfFormat::kDwrf); + return reinterpret_cast(rawProtoPtr()); + } + + const proto::orc::Footer* getOrcPtr() const { + VELOX_CHECK_EQ(format_, DwrfFormat::kOrc); + return reinterpret_cast(rawProtoPtr()); + } + + const StripeInformationWriteWrapper addStripes() const { + return format_ == DwrfFormat::kDwrf + ? StripeInformationWriteWrapper(dwrfPtr()->add_stripes()) + : StripeInformationWriteWrapper(orcPtr()->add_stripes()); + } + + void setHeaderLength(uint64_t headerLength) const { + return format_ == DwrfFormat::kDwrf + ? dwrfPtr()->set_headerlength(headerLength) + : orcPtr()->set_headerlength(headerLength); + } + + void setContentLength(uint64_t contentLength) const { + return format_ == DwrfFormat::kDwrf + ? dwrfPtr()->set_contentlength(contentLength) + : orcPtr()->set_contentlength(contentLength); + } + + void setRowIndexStride(uint32_t rowIndexStride) const { + format_ == DwrfFormat::kDwrf ? dwrfPtr()->set_rowindexstride(rowIndexStride) + : orcPtr()->set_rowindexstride(rowIndexStride); + } + + void setNumberOfRows(uint64_t numberOfRows) const { + format_ == DwrfFormat::kDwrf ? dwrfPtr()->set_numberofrows(numberOfRows) + : orcPtr()->set_numberofrows(numberOfRows); + } + + void setRawDataSize(uint64_t numberOfRows) const { + if (format_ == DwrfFormat::kDwrf) { + dwrfPtr()->set_rawdatasize(numberOfRows); + } + } + + void setWriter(uint32_t writer) const { + if (format_ == DwrfFormat::kOrc) { + orcPtr()->set_writer(writer); + } + } + + void setCheckSumAlgorithm(proto::ChecksumAlgorithm checksum) const { + if (format_ == DwrfFormat::kDwrf) { + dwrfPtr()->set_checksumalgorithm(checksum); + } + } + + void addStripeCacheOffsets(uint32_t stripeCacheOffsets) const { + if (format_ == DwrfFormat::kDwrf) { + dwrfPtr()->add_stripecacheoffsets(stripeCacheOffsets); + } else { + // + } + } + + TypeWriteWrapper addTypes() const { + return format_ == DwrfFormat::kDwrf + ? TypeWriteWrapper(dwrfPtr()->add_types()) + : TypeWriteWrapper(orcPtr()->add_types()); + } + + UserMetadataItemWriteWrapper addMetadata() const { + return format_ == DwrfFormat::kDwrf + ? UserMetadataItemWriteWrapper(dwrfPtr()->add_metadata()) + : UserMetadataItemWriteWrapper(orcPtr()->add_metadata()); + } + + ColumnStatisticsWriteWrapper addStatistics() const { + return format_ == DwrfFormat::kDwrf + ? ColumnStatisticsWriteWrapper(dwrfPtr()->add_statistics()) + : ColumnStatisticsWriteWrapper(orcPtr()->add_statistics()); + } + + const ::google::protobuf::RepeatedPtrField< + ::facebook::velox::dwrf::proto::ColumnStatistics>& + statistics() const { + VELOX_CHECK_EQ(format_, DwrfFormat::kDwrf); + return dwrfPtr()->statistics(); + } + + int typesSize() const { + return format_ == DwrfFormat::kDwrf ? dwrfPtr()->types_size() + : orcPtr()->types_size(); + } + + int statisticsSize() const { + return format_ == DwrfFormat::kDwrf ? dwrfPtr()->statistics_size() + : orcPtr()->statistics_size(); + } + + uint64_t contentLength() const { + return format_ == DwrfFormat::kDwrf ? dwrfPtr()->contentlength() + : orcPtr()->contentlength(); + } + + uint64_t numberOfRows() const { + return format_ == DwrfFormat::kDwrf ? dwrfPtr()->numberofrows() + : orcPtr()->numberofrows(); + } + + // DWRF-specific fields + inline uint64_t rawDataSize() const { + VELOX_CHECK_EQ(format_, DwrfFormat::kDwrf); + return dwrfPtr()->rawdatasize(); + } + + inline int stripesSize() const { + VELOX_CHECK_EQ(format_, DwrfFormat::kDwrf); + return dwrfPtr()->stripes_size(); + } + + inline proto::Encryption* mutableEncryption() { + VELOX_CHECK_EQ(format_, DwrfFormat::kDwrf); + return dwrfPtr()->mutable_encryption(); + } + + private: + // private helper with no format checking + inline proto::Footer* dwrfPtr() const { + return reinterpret_cast(rawProtoPtr()); + } + inline proto::orc::Footer* orcPtr() const { + return reinterpret_cast(rawProtoPtr()); + } +}; + +class RowIndexEntryWriteWrapper : public ProtoWriteWrapperBase { + public: + explicit RowIndexEntryWriteWrapper(proto::RowIndexEntry* rowIndexEntry) + : ProtoWriteWrapperBase(DwrfFormat::kDwrf, rowIndexEntry) {} + + explicit RowIndexEntryWriteWrapper(proto::orc::RowIndexEntry* rowIndexEntry) + : ProtoWriteWrapperBase(DwrfFormat::kOrc, rowIndexEntry) {} + + ColumnStatisticsWriteWrapper mutableStatistics() { + return format_ == DwrfFormat::kDwrf + ? ColumnStatisticsWriteWrapper(dwrfPtr()->mutable_statistics()) + : ColumnStatisticsWriteWrapper(orcPtr()->mutable_statistics()); + } + + bool hasStatistics() const { + return format_ == DwrfFormat::kDwrf ? dwrfPtr()->has_statistics() + : orcPtr()->has_statistics(); + } + + void mutablePositions(int start, int num) { + return format_ == DwrfFormat::kDwrf + ? dwrfPtr()->mutable_positions()->ExtractSubrange(start, num, nullptr) + : orcPtr()->mutable_positions()->ExtractSubrange(start, num, nullptr); + } + + void addPositions(uint64_t pos) { + return format_ == DwrfFormat::kDwrf ? dwrfPtr()->add_positions(pos) + : orcPtr()->add_positions(pos); + } + + uint64_t positionsSize() const { + return format_ == DwrfFormat::kDwrf ? dwrfPtr()->positions_size() + : orcPtr()->positions_size(); + } + + const ::google::protobuf::RepeatedField positions() const { + return format_ == DwrfFormat::kDwrf ? dwrfPtr()->positions() + : orcPtr()->positions(); + } + + void clear() { + return format_ == DwrfFormat::kDwrf ? dwrfPtr()->Clear() + : orcPtr()->Clear(); + } + + private: + // private helper with no format checking + inline proto::RowIndexEntry* dwrfPtr() const { + return reinterpret_cast(rawProtoPtr()); + } + inline proto::orc::RowIndexEntry* orcPtr() const { + return reinterpret_cast(rawProtoPtr()); + } +}; + +class RowIndexWriteWrapper : public ProtoWriteWrapperBase { + public: + explicit RowIndexWriteWrapper(proto::RowIndex* rowIndex) + : ProtoWriteWrapperBase(DwrfFormat::kDwrf, rowIndex) {} + + explicit RowIndexWriteWrapper(proto::orc::RowIndex* rowIndex) + : ProtoWriteWrapperBase(DwrfFormat::kOrc, rowIndex) {} + + void addEntry(std::unique_ptr& entry) { + if (format_ == DwrfFormat::kDwrf) { + auto e = reinterpret_cast(entry->rawProtoPtr()); + *dwrfPtr()->add_entry() = *e; + } else { + auto e = + reinterpret_cast(entry->rawProtoPtr()); + *orcPtr()->add_entry() = *e; + } + } + + int32_t entrySize() { + return format_ == DwrfFormat::kDwrf ? dwrfPtr()->entry_size() + : orcPtr()->entry_size(); + } + + RowIndexEntryWriteWrapper mutableEntry(int32_t index) { + return format_ == DwrfFormat::kDwrf + ? RowIndexEntryWriteWrapper(dwrfPtr()->mutable_entry(index)) + : RowIndexEntryWriteWrapper(orcPtr()->mutable_entry(index)); + } + + void SerializeToZeroCopyStream( + dwio::common::BufferedOutputStream* out) const { + format_ == DwrfFormat::kDwrf ? dwrfPtr()->SerializeToZeroCopyStream(out) + : orcPtr()->SerializeToZeroCopyStream(out); + } + + void clear() { + return format_ == DwrfFormat::kDwrf ? dwrfPtr()->Clear() + : orcPtr()->Clear(); + } + + private: + // private helper with no format checking + inline proto::RowIndex* dwrfPtr() const { + return reinterpret_cast(rawProtoPtr()); + } + inline proto::orc::RowIndex* orcPtr() const { + return reinterpret_cast(rawProtoPtr()); + } +}; + +class ColumnEncodingWriteWrapper : public ProtoWriteWrapperBase { + public: + explicit ColumnEncodingWriteWrapper(proto::ColumnEncoding* stream) + : ProtoWriteWrapperBase(DwrfFormat::kDwrf, stream) {} + + explicit ColumnEncodingWriteWrapper(proto::orc::ColumnEncoding* stream) + : ProtoWriteWrapperBase(DwrfFormat::kOrc, stream) {} + + void setKind(ColumnEncodingKindWrapper columnEncodingKindWrapper) { + format_ == DwrfFormat::kDwrf + ? dwrfPtr()->set_kind( + *reinterpret_cast( + columnEncodingKindWrapper.rawProtoPtr())) + : orcPtr()->set_kind( + *reinterpret_cast( + columnEncodingKindWrapper.rawProtoPtr())); + } + + void setDictionarySize(uint32_t dictionarySize) { + format_ == DwrfFormat::kDwrf ? dwrfPtr()->set_dictionarysize(dictionarySize) + : orcPtr()->set_dictionarysize(dictionarySize); + } + + void setNode(uint32_t node) { + if (format_ == DwrfFormat::kDwrf) { + dwrfPtr()->set_node(node); + } + } + + void setSequence(uint32_t sequence) { + if (format_ == DwrfFormat::kDwrf) { + dwrfPtr()->set_sequence(sequence); + } + } + + proto::KeyInfo* mutableKey() { + return dwrfPtr()->mutable_key(); + } + + void Clear() { + format_ == DwrfFormat::kDwrf ? dwrfPtr()->Clear() : orcPtr()->Clear(); + } + + void reset(const proto::ColumnEncoding* dwrfEncoding) { + VELOX_CHECK_EQ(format_, DwrfFormat::kDwrf); + VELOX_CHECK_NOT_NULL(dwrfEncoding); + dwrfPtr()->CopyFrom(*dwrfEncoding); + } + + private: + // private helper with no format checking + inline proto::ColumnEncoding* dwrfPtr() const { + return reinterpret_cast(rawProtoPtr()); + } + inline proto::orc::ColumnEncoding* orcPtr() const { + return reinterpret_cast(rawProtoPtr()); + } +}; + +class StreamWriteWrapper : public ProtoWriteWrapperBase { + public: + explicit StreamWriteWrapper(proto::Stream* stream) + : ProtoWriteWrapperBase(DwrfFormat::kDwrf, stream) {} + + explicit StreamWriteWrapper(proto::orc::Stream* stream) + : ProtoWriteWrapperBase(DwrfFormat::kOrc, stream) {} + + void setOffset(uint64_t offset) { + VELOX_CHECK_EQ(format_, DwrfFormat::kDwrf); + dwrfPtr()->set_offset(offset); + } + + void setKind(const StreamKind& kind) { + format_ == DwrfFormat::kDwrf + ? dwrfPtr()->set_kind(static_cast(kind)) + : orcPtr()->set_kind(static_cast(kind)); + } + + void setColumn(uint32_t column) { + format_ == DwrfFormat::kDwrf ? dwrfPtr()->set_column(column) + : orcPtr()->set_column(column); + } + + void setLength(uint64_t length) { + format_ == DwrfFormat::kDwrf ? dwrfPtr()->set_length(length) + : orcPtr()->set_length(length); + } + + void setNode(uint32_t node) { + VELOX_CHECK_EQ(format_, DwrfFormat::kDwrf); + dwrfPtr()->set_node(node); + } + + void setSequence(uint32_t sequence) { + VELOX_CHECK_EQ(format_, DwrfFormat::kDwrf); + dwrfPtr()->set_sequence(sequence); + } + + void setUseVints(bool useVints) { + VELOX_CHECK_EQ(format_, DwrfFormat::kDwrf); + dwrfPtr()->set_usevints(useVints); + } + + private: + // private helper with no format checking + inline proto::Stream* dwrfPtr() const { + return reinterpret_cast(rawProtoPtr()); + } + inline proto::orc::Stream* orcPtr() const { + return reinterpret_cast(rawProtoPtr()); + } +}; + +class StripeEncryptionGroupWriteWrapper : public ProtoWriteWrapperBase { + public: + explicit StripeEncryptionGroupWriteWrapper( + proto::StripeEncryptionGroup* stripeFooter = nullptr) + : ProtoWriteWrapperBase(DwrfFormat::kDwrf, stripeFooter) {} + + // See https://orc.apache.org/specification/ORCv1/ + explicit StripeEncryptionGroupWriteWrapper( + proto::orc::StripeEncryptionVariant* stripeFooter) + : ProtoWriteWrapperBase(DwrfFormat::kOrc, stripeFooter) {} + + void encoding( + std::vector& columnEncodingWrappers) const { + VELOX_CHECK_EQ(format_, DwrfFormat::kDwrf); + for (const proto::ColumnEncoding& encoding : dwrfPtr()->encoding()) { + auto ce = ColumnEncodingWrapper(&encoding); + columnEncodingWrappers.emplace_back(ce); + } + } + + ColumnEncodingWriteWrapper addEncoding() { + return format_ == DwrfFormat::kDwrf + ? ColumnEncodingWriteWrapper(dwrfPtr()->add_encoding()) + : ColumnEncodingWriteWrapper(orcPtr()->add_encoding()); + } + + StreamWriteWrapper addStreams() { + return format_ == DwrfFormat::kDwrf + ? StreamWriteWrapper(dwrfPtr()->add_streams()) + : StreamWriteWrapper(orcPtr()->add_streams()); + } + + void SerializeToZeroCopyStream( + dwio::common::BufferedOutputStream* output) const { + format_ == DwrfFormat::kDwrf ? dwrfPtr()->SerializeToZeroCopyStream(output) + : orcPtr()->SerializeToZeroCopyStream(output); + } + + private: + // private helper with no format checking + inline proto::StripeEncryptionGroup* dwrfPtr() const { + return reinterpret_cast(rawProtoPtr()); + } + inline proto::orc::StripeEncryptionVariant* orcPtr() const { + return reinterpret_cast( + rawProtoPtr()); + } +}; + +class StripeFooterWriteWrapper : public ProtoWriteWrapperBase { + public: + explicit StripeFooterWriteWrapper(proto::StripeFooter* stripeFooter) + : ProtoWriteWrapperBase(DwrfFormat::kDwrf, stripeFooter) {} + + explicit StripeFooterWriteWrapper(proto::orc::StripeFooter* stripeFooter) + : ProtoWriteWrapperBase(DwrfFormat::kOrc, stripeFooter) {} + + void encoding( + std::vector& columnEncodingWrappers) const { + if (format_ == DwrfFormat::kDwrf) { + for (const proto::ColumnEncoding& encoding : dwrfPtr()->encoding()) { + auto ce = ColumnEncodingWrapper(&encoding); + columnEncodingWrappers.emplace_back(ce); + } + } else { + for (const proto::orc::ColumnEncoding& encoding : orcPtr()->columns()) { + auto ce = ColumnEncodingWrapper(&encoding); + columnEncodingWrappers.emplace_back(ce); + } + } + } + + void setWriterTimezone() const { + if (format_ == DwrfFormat::kOrc) { + // orcPtr()->set_writertimezone("Asia/Shanghai"); + } + } + + StreamWriteWrapper addStreams() { + return format_ == DwrfFormat::kDwrf + ? StreamWriteWrapper(dwrfPtr()->add_streams()) + : StreamWriteWrapper(orcPtr()->add_streams()); + } + + std::string* addEncryptionGroups() { + VELOX_CHECK_EQ(format_, DwrfFormat::kDwrf); + return dwrfPtr()->add_encryptiongroups(); + } + + ColumnEncodingWriteWrapper addEncoding() { + return format_ == DwrfFormat::kDwrf + ? ColumnEncodingWriteWrapper(dwrfPtr()->add_encoding()) + : ColumnEncodingWriteWrapper(orcPtr()->add_columns()); + } + + void SerializeToZeroCopyStream( + dwio::common::BufferedOutputStream* output) const { + format_ == DwrfFormat::kDwrf ? dwrfPtr()->SerializeToZeroCopyStream(output) + : orcPtr()->SerializeToZeroCopyStream(output); + } + + inline proto::StripeFooter* dwrfPtr() const { + VELOX_CHECK_EQ(format_, DwrfFormat::kDwrf); + return reinterpret_cast(rawProtoPtr()); + } + + private: + // private helper with no format checking + inline proto::orc::StripeFooter* orcPtr() const { + return reinterpret_cast(rawProtoPtr()); + } +}; + +class PostScriptWriteWrapper : public ProtoWriteWrapperBase { + public: + explicit PostScriptWriteWrapper(proto::PostScript* postScript) + : ProtoWriteWrapperBase(DwrfFormat::kDwrf, postScript) {} + + explicit PostScriptWriteWrapper(proto::orc::PostScript* postScript) + : ProtoWriteWrapperBase(DwrfFormat::kOrc, postScript) {} + + void setWriterVersion(uint32_t writerVersion) { + format_ == DwrfFormat::kDwrf ? dwrfPtr()->set_writerversion(writerVersion) + : orcPtr()->set_writerversion(6); + } + + void addVersion(uint32_t version) { + if (format_ == DwrfFormat::kOrc) { + orcPtr()->add_version(version); + } + } + + void setFooterLength(uint64_t footerLength) { + format_ == DwrfFormat::kDwrf ? dwrfPtr()->set_footerlength(footerLength) + : orcPtr()->set_footerlength(footerLength); + } + + void setCompression(common::CompressionKind compressionKind); + + void setCompressionBlockSize(uint64_t compressionBlockSize) { + format_ == DwrfFormat::kDwrf + ? dwrfPtr()->set_compressionblocksize(compressionBlockSize) + : orcPtr()->set_compressionblocksize(compressionBlockSize); + } + + void setCacheMode(StripeCacheMode cacheMode) { + if (format_ == DwrfFormat::kDwrf) { + dwrfPtr()->set_cachemode(static_cast(cacheMode)); + } + } + + void setCacheSize(uint32_t cacheSize) { + if (format_ == DwrfFormat::kDwrf) { + dwrfPtr()->set_cachesize(cacheSize); + } + } + + void setMetaDataLength(uint64_t metaDataLength) { + if (format_ == DwrfFormat::kOrc) { + orcPtr()->set_metadatalength(metaDataLength); + } + } + + void SerializeToZeroCopyStream( + dwio::common::BufferedOutputStream* out) const { + format_ == DwrfFormat::kDwrf ? dwrfPtr()->SerializeToZeroCopyStream(out) + : orcPtr()->SerializeToZeroCopyStream(out); + } + + private: + // private helper with no format checking + inline proto::PostScript* dwrfPtr() const { + return reinterpret_cast(rawProtoPtr()); + } + inline proto::orc::PostScript* orcPtr() const { + return reinterpret_cast(rawProtoPtr()); + } +}; + } // namespace facebook::velox::dwrf template <> diff --git a/velox/dwio/dwrf/common/IntEncoder.h b/velox/dwio/dwrf/common/IntEncoder.h index 1051967bb39b..bba9060a1839 100644 --- a/velox/dwio/dwrf/common/IntEncoder.h +++ b/velox/dwio/dwrf/common/IntEncoder.h @@ -341,8 +341,9 @@ template case 57 ... 63: return writeVarint<1>(value, buffer); } - DWIO_RAISE(folly::sformat( - "Unexpected leading zeros {} for value {}", leadingZeros, value)); + DWIO_RAISE( + folly::sformat( + "Unexpected leading zeros {} for value {}", leadingZeros, value)); } template diff --git a/velox/dwio/dwrf/common/RLEv1.h b/velox/dwio/dwrf/common/RLEv1.h index 082b57c10e08..62ab5d444f1a 100644 --- a/velox/dwio/dwrf/common/RLEv1.h +++ b/velox/dwio/dwrf/common/RLEv1.h @@ -417,7 +417,7 @@ class RleDecoderV1 : public dwio::common::IntDecoder { rows + rowIndex, std::min(remainingValues_, numRows - rowIndex)); const auto endOfRun = currentRow + remainingValues_; - const auto bound = std::lower_bound(range.begin(), range.end(), endOfRun); + const auto bound = std::lower_bound(range.cbegin(), range.cend(), endOfRun); return std::make_pair(bound - range.begin(), bound[-1] - currentRow + 1); } diff --git a/velox/dwio/dwrf/common/wrap/CMakeLists.txt b/velox/dwio/dwrf/common/wrap/CMakeLists.txt new file mode 100644 index 000000000000..a598690b32e5 --- /dev/null +++ b/velox/dwio/dwrf/common/wrap/CMakeLists.txt @@ -0,0 +1,14 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +velox_install_library_headers() diff --git a/velox/dwio/dwrf/proto/CMakeLists.txt b/velox/dwio/dwrf/proto/CMakeLists.txt index 2f0de359d08a..fd2ade2dac98 100644 --- a/velox/dwio/dwrf/proto/CMakeLists.txt +++ b/velox/dwio/dwrf/proto/CMakeLists.txt @@ -13,19 +13,13 @@ # limitations under the License. # Set up Proto -file( - GLOB PROTO_FILES - RELATIVE ${PROJECT_SOURCE_DIR} - ${CMAKE_CURRENT_SOURCE_DIR}/*.proto) +file(GLOB PROTO_FILES RELATIVE ${PROJECT_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/*.proto) foreach(PROTO ${PROTO_FILES}) get_filename_component(PROTO_DIR ${PROTO} DIRECTORY) get_filename_component(PROTO_NAME ${PROTO} NAME_WE) - list(APPEND PROTO_SRCS - "${PROJECT_BINARY_DIR}/${PROTO_DIR}/${PROTO_NAME}.pb.cc") - list(APPEND PROTO_HDRS - "${PROJECT_BINARY_DIR}/${PROTO_DIR}/${PROTO_NAME}.pb.h") - list(APPEND PROTO_FILES_FULL - "${PROJECT_SOURCE_DIR}/${PROTO_DIR}/${PROTO_NAME}.proto") + list(APPEND PROTO_SRCS "${PROJECT_BINARY_DIR}/${PROTO_DIR}/${PROTO_NAME}.pb.cc") + list(APPEND PROTO_HDRS "${PROJECT_BINARY_DIR}/${PROTO_DIR}/${PROTO_NAME}.pb.h") + list(APPEND PROTO_FILES_FULL "${PROJECT_SOURCE_DIR}/${PROTO_DIR}/${PROTO_NAME}.proto") endforeach() set(PROTO_OUTPUT_FILES ${PROTO_HDRS} ${PROTO_SRCS}) set_source_files_properties(${PROTO_OUTPUT_FILES} PROPERTIES GENERATED TRUE) @@ -37,11 +31,11 @@ endforeach() add_custom_command( OUTPUT ${PROTO_OUTPUT_FILES} - COMMAND protobuf::protoc ${PROTO_PATH_ARGS} --cpp_out ${PROJECT_BINARY_DIR} - ${PROTO_FILES_FULL} + COMMAND protobuf::protoc ${PROTO_PATH_ARGS} --cpp_out ${PROJECT_BINARY_DIR} ${PROTO_FILES_FULL} DEPENDS protobuf::protoc COMMENT "Running PROTO compiler" - VERBATIM) + VERBATIM +) add_custom_target(dwio_proto ALL DEPENDS ${PROTO_OUTPUT_FILES}) if(VELOX_MONO_LIBRARY) diff --git a/velox/dwio/dwrf/proto/dwrf_proto.proto b/velox/dwio/dwrf/proto/dwrf_proto.proto index aaf69dcf1555..ba0f3350a26d 100644 --- a/velox/dwio/dwrf/proto/dwrf_proto.proto +++ b/velox/dwio/dwrf/proto/dwrf_proto.proto @@ -284,6 +284,7 @@ enum CompressionKind { LZO = 3; ZSTD = 4; LZ4 = 5; + GZIP = 6; } enum StripeCacheMode { diff --git a/velox/dwio/dwrf/proto/fb_protobuf.sh b/velox/dwio/dwrf/proto/fb_protobuf.sh index c7367f75d8e7..0212d8c890af 100755 --- a/velox/dwio/dwrf/proto/fb_protobuf.sh +++ b/velox/dwio/dwrf/proto/fb_protobuf.sh @@ -16,6 +16,6 @@ set -e ( - "${PROTOC:-protoc}" dwrf_proto.proto --cpp_out="$INSTALL_DIR" - "${PROTOC:-protoc}" orc_proto.proto --cpp_out="$INSTALL_DIR" + "${PROTOC:-protoc}" dwrf_proto.proto --cpp_out="$INSTALL_DIR" + "${PROTOC:-protoc}" orc_proto.proto --cpp_out="$INSTALL_DIR" ) diff --git a/velox/dwio/dwrf/reader/CMakeLists.txt b/velox/dwio/dwrf/reader/CMakeLists.txt index cae8fcc24bbe..8b6f04ec9625 100644 --- a/velox/dwio/dwrf/reader/CMakeLists.txt +++ b/velox/dwio/dwrf/reader/CMakeLists.txt @@ -33,7 +33,8 @@ velox_add_library( StreamLabels.cpp StripeDictionaryCache.cpp StripeReaderBase.cpp - StripeStream.cpp) + StripeStream.cpp +) velox_link_libraries( velox_dwio_dwrf_reader @@ -42,4 +43,5 @@ velox_link_libraries( velox_caching velox_dwio_dwrf_utils velox_test_util - fmt::fmt) + fmt::fmt +) diff --git a/velox/dwio/dwrf/reader/ColumnReader.cpp b/velox/dwio/dwrf/reader/ColumnReader.cpp index 919ebf61af4e..0e20372722cc 100644 --- a/velox/dwio/dwrf/reader/ColumnReader.cpp +++ b/velox/dwio/dwrf/reader/ColumnReader.cpp @@ -302,7 +302,7 @@ void ByteRleColumnReader::next( BufferPtr values; if (flatVector) { - values = flatVector->mutableValues(numValues); + values = flatVector->mutableValues(); } if (flatVector) { @@ -475,7 +475,7 @@ void DecimalColumnReader::next( } BufferPtr values; if (flatVector) { - values = flatVector->mutableValues(numValues); + values = flatVector->mutableValues(); } BufferPtr nulls = readNulls(numValues, result, incomingNulls); @@ -585,8 +585,8 @@ void IntegerDirectColumnReader::next( auto flatVector = resetIfWrongFlatVectorType(result); BufferPtr values; if (flatVector) { - values = flatVector->mutableValues(numValues); result->resize(numValues, false); + values = flatVector->mutableValues(); } BufferPtr nulls = readNulls(numValues, result, incomingNulls); @@ -743,7 +743,7 @@ void IntegerDictionaryColumnReader::next( BufferPtr values; if (result) { result->resize(numValues, false); - values = flatVector->mutableValues(numValues); + values = flatVector->mutableValues(); } BufferPtr nulls = readNulls(numValues, result, incomingNulls); @@ -869,7 +869,7 @@ void TimestampColumnReader::next( BufferPtr values; if (flatVector) { result->resize(numValues, false); - values = flatVector->mutableValues(numValues); + values = flatVector->mutableValues(); } BufferPtr nulls = readNulls(numValues, result, incomingNulls); @@ -1023,7 +1023,7 @@ void FloatingPointColumnReader::next( } BufferPtr values; if (flatVector) { - values = flatVector->mutableValues(numValues); + values = flatVector->mutableValues(); } BufferPtr nulls = readNulls(numValues, result, incomingNulls); @@ -1311,7 +1311,7 @@ void StringDictionaryColumnReader::loadStrideDictionary() { if (strideDictCount_ > 0) { // seek stride dictionary related streams std::vector pos( - positions.begin() + positionOffset_, positions.end()); + positions.cbegin() + positionOffset_, positions.cend()); dwio::common::PositionProvider pp(pos); strideDictStream_->seekToPosition(pp); strideDictLengthDecoder_->seekToRowGroup(pp); @@ -1544,7 +1544,7 @@ void StringDictionaryColumnReader::readFlatVector( BufferPtr data; if (flatVector) { - data = flatVector->mutableValues(numValues); + data = flatVector->mutableValues(); } BufferPtr nulls = readNulls(numValues, result, incomingNulls); @@ -1773,7 +1773,7 @@ void StringDirectColumnReader::next( BufferPtr values; if (flatVector) { flatVector->resize(numValues, false); - values = flatVector->mutableValues(numValues); + values = flatVector->mutableValues(); } BufferPtr nulls = readNulls(numValues, result, incomingNulls); @@ -2461,8 +2461,9 @@ std::unique_ptr buildByteRleColumnReader( RleDecoderFactory::get(), std::move(flatMapContext)); default: - DWIO_RAISE(fmt::format( - "Unsupported upcast to typekind: {}", requestedType->toString())); + DWIO_RAISE( + fmt::format( + "Unsupported upcast to typekind: {}", requestedType->toString())); } } @@ -2502,9 +2503,10 @@ std::unique_ptr buildTypedIntegerColumnReader( numBytes, std::move(flatMapContext)); default: - DWIO_RAISE(fmt::format( - "Unsupported requested integral type: {}", - requestedType->toString())); + DWIO_RAISE( + fmt::format( + "Unsupported requested integral type: {}", + requestedType->toString())); } } diff --git a/velox/dwio/dwrf/reader/DwrfData.cpp b/velox/dwio/dwrf/reader/DwrfData.cpp index 509791b44d12..afbf5b4ee3bb 100644 --- a/velox/dwio/dwrf/reader/DwrfData.cpp +++ b/velox/dwio/dwrf/reader/DwrfData.cpp @@ -101,6 +101,7 @@ void DwrfData::ensureRowGroupIndex() { dwio::common::PositionProvider DwrfData::seekToRowGroup(int64_t index) { ensureRowGroupIndex(); + VELOX_CHECK_LT(index, index_->entry_size(), "RowGroup index is corrupted"); positionsHolder_ = toPositionsInner(index_->entry(index)); dwio::common::PositionProvider positionProvider(positionsHolder_); diff --git a/velox/dwio/dwrf/reader/DwrfData.h b/velox/dwio/dwrf/reader/DwrfData.h index 5c17cd77aafc..04fa978a96e8 100644 --- a/velox/dwio/dwrf/reader/DwrfData.h +++ b/velox/dwio/dwrf/reader/DwrfData.h @@ -72,6 +72,10 @@ class DwrfData : public dwio::common::FormatData { return flatMapContext_.inMapDecoder ? inMap_->as() : nullptr; } + const velox::BufferPtr& inMapBuffer() { + return inMap_; + } + /// Seeks possible flat map in map streams and nulls to the row group /// and returns a PositionsProvider for the other streams. dwio::common::PositionProvider seekToRowGroup(int64_t index) override; @@ -96,7 +100,7 @@ class DwrfData : public dwio::common::FormatData { static std::vector toPositionsInner( const proto::RowIndexEntry& entry) { return std::vector( - entry.positions().begin(), entry.positions().end()); + entry.positions().cbegin(), entry.positions().cend()); } memory::MemoryPool& memoryPool_; diff --git a/velox/dwio/dwrf/reader/DwrfReader.cpp b/velox/dwio/dwrf/reader/DwrfReader.cpp index 637247e1a685..f311d3953a6f 100644 --- a/velox/dwio/dwrf/reader/DwrfReader.cpp +++ b/velox/dwio/dwrf/reader/DwrfReader.cpp @@ -19,6 +19,7 @@ #include #include "velox/dwio/common/OnDemandUnitLoader.h" +#include "velox/dwio/common/ParallelUnitLoader.h" #include "velox/dwio/common/TypeUtils.h" #include "velox/dwio/common/exception/Exception.h" #include "velox/dwio/dwrf/reader/ColumnReader.h" @@ -38,24 +39,26 @@ using dwio::common::UnitLoaderFactory; class DwrfUnit : public LoadUnit { public: DwrfUnit( - const StripeReaderBase& stripeReaderBase, + std::shared_ptr readerBase, const StrideIndexProvider& strideIndexProvider, - dwio::common::ColumnReaderStatistics& columnReaderStatistics, + std::shared_ptr + columnReaderStatistics, uint32_t stripeIndex, std::shared_ptr columnSelector, - const std::shared_ptr& projectedNodes, + std::shared_ptr projectedNodes, RowReaderOptions options, - const dwio::common::ColumnReaderOptions& columnReaderOptions) - : stripeReaderBase_{stripeReaderBase}, + dwio::common::ColumnReaderOptions columnReaderOptions) + : stripeReaderBase_{readerBase}, + memoryPool_(readerBase->memoryPool().shared_from_this()), strideIndexProvider_{strideIndexProvider}, - columnReaderStatistics_{&columnReaderStatistics}, + columnReaderStatistics_{std::move(columnReaderStatistics)}, stripeIndex_{stripeIndex}, columnSelector_{std::move(columnSelector)}, - projectedNodes_{projectedNodes}, + projectedNodes_{std::move(projectedNodes)}, options_{std::move(options)}, - columnReaderOptions_{columnReaderOptions}, + columnReaderOptions_{std::move(columnReaderOptions)}, stripeInfo_{ - stripeReaderBase.getReader().footer().stripes(stripeIndex_)} {} + stripeReaderBase_.getReader().footer().stripes(stripeIndex_)} {} ~DwrfUnit() override = default; @@ -85,14 +88,25 @@ class DwrfUnit : public LoadUnit { void loadDecoders(); // Immutables - const StripeReaderBase& stripeReaderBase_; + const StripeReaderBase stripeReaderBase_; + // Not used in DwrfUnit directly, it is to keep memory pool alive for + // readerBase + const std::shared_ptr memoryPool_; + + // SAFETY: This reference is safe despite DwrfUnit potentially outliving + // DwrfRowReader during async operations. The reference is only STORED (not + // dereferenced) during load() path. Actual dereferencing via + // getStrideIndex() only happens during synchronous data reading in + // ColumnReader::next(), where DwrfRowReader is guaranteed to be alive. const StrideIndexProvider& strideIndexProvider_; - dwio::common::ColumnReaderStatistics* const columnReaderStatistics_; + + const std::shared_ptr + columnReaderStatistics_; const uint32_t stripeIndex_; const std::shared_ptr columnSelector_; const std::shared_ptr projectedNodes_; const RowReaderOptions options_; - const dwio::common::ColumnReaderOptions& columnReaderOptions_; + const dwio::common::ColumnReaderOptions columnReaderOptions_; const StripeInformationWrapper stripeInfo_; // Mutables @@ -242,6 +256,8 @@ DwrfRowReader::DwrfRowReader( reader->schema()))}, decodingTimeCallback_{options_.decodingTimeCallback()}, strideIndex_{0}, + columnReaderStatistics_{ + std::make_shared()}, currentUnit_{nullptr} { const auto& fileFooter = getReader().footer(); const uint32_t numberOfStripes = fileFooter.stripesSize(); @@ -328,22 +344,30 @@ std::unique_ptr DwrfRowReader::getUnitLoader() { std::vector> loadUnits; loadUnits.reserve(stripeCeiling_ - firstStripe_); for (auto stripe = firstStripe_; stripe < stripeCeiling_; ++stripe) { - loadUnits.emplace_back(std::make_unique( - /*stripeReaderBase=*/*this, - /*strideIndexProvider=*/*this, - columnReaderStatistics_, - stripe, - columnSelector_, - projectedNodes_, - options_, - columnReaderOptions_)); + loadUnits.emplace_back( + std::make_unique( + /*readerBase=*/readerBaseShared(), + /*strideIndexProvider=*/*this, + columnReaderStatistics_, + stripe, + columnSelector_, + projectedNodes_, + options_, + columnReaderOptions_)); } std::shared_ptr unitLoaderFactory = options_.unitLoaderFactory(); if (!unitLoaderFactory) { - unitLoaderFactory = - std::make_shared( - options_.blockedOnIoCallback()); + if (loadUnits.size() > 1 && options_.parallelUnitLoadCount() > 1 && + options_.ioExecutor() != nullptr) { + unitLoaderFactory = + std::make_shared( + options_.ioExecutor(), options_.parallelUnitLoadCount()); + } else { + unitLoaderFactory = + std::make_shared( + options_.blockedOnIoCallback()); + } } return unitLoaderFactory->create(std::move(loadUnits), 0); } @@ -632,6 +656,8 @@ uint64_t DwrfRowReader::next( } else { previousRow_ = 0; } + // Collect unit loader stats at the end. + unitLoadStats_ = unitLoader_->stats(); return 0; } @@ -752,8 +778,10 @@ std::optional DwrfRowReader::estimatedRowSizeHelper( case TypeKind::ARRAY: case TypeKind::MAP: case TypeKind::ROW: { - // start the estimate with the offsets and hasNulls vectors sizes - size_t totalEstimate = valueCount * (sizeof(uint8_t) + sizeof(uint64_t)); + // Start the estimate with the offsets and sizes buffers. + size_t totalEstimate = nodeType.kind() == TypeKind::ROW + ? 0 + : 2 * valueCount * sizeof(vector_size_t); for (int32_t i = 0; i < nodeType.subtypesSize(); ++i) { if (!shouldReadNode(nodeType.subtypes(i))) { continue; @@ -774,15 +802,23 @@ std::optional DwrfRowReader::estimatedRowSizeHelper( } std::optional DwrfRowReader::estimatedRowSize() const { + if (hasRowEstimate_) { + return estimatedRowSize_; + } + const auto& reader = getReader(); const auto& fileFooter = reader.footer(); + hasRowEstimate_ = true; + if (!fileFooter.hasNumberOfRows()) { - return std::nullopt; + estimatedRowSize_ = std::nullopt; + return estimatedRowSize_; } if (fileFooter.numberOfRows() < 1) { - return 0; + estimatedRowSize_ = 0; + return estimatedRowSize_; } // Estimate with projections. @@ -791,9 +827,12 @@ std::optional DwrfRowReader::estimatedRowSize() const { const auto projectedSize = estimatedRowSizeHelper(fileFooter, *stats, ROOT_NODE_ID); if (projectedSize.has_value()) { - return projectedSize.value() / fileFooter.numberOfRows(); + estimatedRowSize_ = projectedSize.value() / fileFooter.numberOfRows(); + return estimatedRowSize_; } - return std::nullopt; + + estimatedRowSize_ = std::nullopt; + return estimatedRowSize_; } DwrfReader::DwrfReader( @@ -815,8 +854,9 @@ DwrfReader::DwrfReader( void DwrfReader::updateColumnNamesFromTableSchema() { const auto& tableSchema = readerBase_->readerOptions().fileSchema(); const auto& fileSchema = readerBase_->schema(); - readerBase_->setSchema(std::dynamic_pointer_cast( - updateColumnNames(fileSchema, tableSchema))); + readerBase_->setSchema( + std::dynamic_pointer_cast( + updateColumnNames(fileSchema, tableSchema))); } std::unique_ptr DwrfReader::getStripe( diff --git a/velox/dwio/dwrf/reader/DwrfReader.h b/velox/dwio/dwrf/reader/DwrfReader.h index dcb38dbb5a1d..e5bb7380f8c3 100644 --- a/velox/dwio/dwrf/reader/DwrfReader.h +++ b/velox/dwio/dwrf/reader/DwrfReader.h @@ -112,7 +112,8 @@ class DwrfRowReader : public StrideIndexProvider, stats.footerBufferOverread += getReader().footerBufferOverread(); stats.numStripes += stripeCeiling_ - firstStripe_; stats.columnReaderStatistics.flattenStringDictionaryValues += - columnReaderStatistics_.flattenStringDictionaryValues; + columnReaderStatistics_->flattenStringDictionaryValues; + stats.unitLoaderStats.merge(unitLoadStats_); } void resetFilterCaches() override; @@ -210,17 +211,22 @@ class DwrfRowReader : public StrideIndexProvider, // Number of processed strides. int64_t processedStrides_{0}; + dwio::common::UnitLoaderStats unitLoadStats_; + // Set to true after clearing filter caches, i.e. adding a dynamic filter. // Causes filters to be re-evaluated against stride stats on next stride // instead of next stripe. bool recomputeStridesToSkip_{false}; - dwio::common::ColumnReaderStatistics columnReaderStatistics_; + std::shared_ptr columnReaderStatistics_; std::optional nextRowNumber_; std::unique_ptr unitLoader_; DwrfUnit* currentUnit_; + + mutable std::optional estimatedRowSize_; + mutable bool hasRowEstimate_{false}; }; class DwrfReader : public dwio::common::Reader { diff --git a/velox/dwio/dwrf/reader/FlatMapColumnReader.cpp b/velox/dwio/dwrf/reader/FlatMapColumnReader.cpp index 22f7220e8533..d1a611755bf4 100644 --- a/velox/dwio/dwrf/reader/FlatMapColumnReader.cpp +++ b/velox/dwio/dwrf/reader/FlatMapColumnReader.cpp @@ -163,12 +163,13 @@ std::vector>> getKeyNodesFiltered( .inMapDecoder = inMapDecoder.get(), .keySelectionCallback = nullptr}); - keyNodes.push_back(std::make_unique>( - std::move(valueReader), - std::move(inMapDecoder), - key, - sequence, - memoryPool)); + keyNodes.push_back( + std::make_unique>( + std::move(valueReader), + std::move(inMapDecoder), + key, + sequence, + memoryPool)); }); keySelectionStats.selectedKeys = keyNodes.size(); diff --git a/velox/dwio/dwrf/reader/ReaderBase.cpp b/velox/dwio/dwrf/reader/ReaderBase.cpp index 71dbc2e41510..e227a219e3c9 100644 --- a/velox/dwio/dwrf/reader/ReaderBase.cpp +++ b/velox/dwio/dwrf/reader/ReaderBase.cpp @@ -19,11 +19,14 @@ #include #include "velox/common/process/TraceContext.h" +#include "velox/dwio/common/Arena.h" #include "velox/dwio/common/Mutation.h" #include "velox/dwio/common/exception/Exception.h" +#include "velox/functions/lib/string/StringImpl.h" namespace facebook::velox::dwrf { +using dwio::common::ArenaCreate; using dwio::common::ColumnStatistics; using dwio::common::FileFormat; using dwio::common::LogType; @@ -95,7 +98,7 @@ template std::unique_ptr parseFooter( dwio::common::SeekableInputStream* input, google::protobuf::Arena* arena) { - auto* impl = google::protobuf::Arena::CreateMessage(arena); + auto* impl = ArenaCreate(arena); VELOX_CHECK(impl->ParseFromZeroCopyStream(input)); return std::make_unique(impl); } @@ -329,7 +332,7 @@ std::shared_ptr ReaderBase::convertType( const FooterWrapper& footer, uint32_t index, bool fileColumnNamesReadAsLowerCase) { - VELOX_CHECK_LT( + VELOX_USER_CHECK_LT( index, folly::to(footer.typesSize()), "Corrupted file, invalid types"); @@ -385,7 +388,7 @@ std::shared_ptr ReaderBase::convertType( footer, type.subtypes(i), fileColumnNamesReadAsLowerCase); auto childName = type.fieldNames(i); if (fileColumnNamesReadAsLowerCase) { - folly::toLowerAscii(childName); + childName = functions::stringImpl::utf8StrToLowerCopy(childName); } names.push_back(std::move(childName)); types.push_back(std::move(childType)); diff --git a/velox/dwio/dwrf/reader/ReaderBase.h b/velox/dwio/dwrf/reader/ReaderBase.h index 561d88e56841..ba560165068d 100644 --- a/velox/dwio/dwrf/reader/ReaderBase.h +++ b/velox/dwio/dwrf/reader/ReaderBase.h @@ -199,7 +199,7 @@ class ReaderBase { const std::string& writerName() const { for (int32_t index = 0; index < footer_->metadataSize(); ++index) { auto entry = footer_->metadata(index); - if (entry.name() == WRITER_NAME_KEY) { + if (entry.name() == kWriterNameKey) { return entry.value(); } } diff --git a/velox/dwio/dwrf/reader/SelectiveDecimalColumnReader.cpp b/velox/dwio/dwrf/reader/SelectiveDecimalColumnReader.cpp index ec570ae05b7f..095a93ca7a9b 100644 --- a/velox/dwio/dwrf/reader/SelectiveDecimalColumnReader.cpp +++ b/velox/dwio/dwrf/reader/SelectiveDecimalColumnReader.cpp @@ -75,16 +75,17 @@ void SelectiveDecimalColumnReader::seekToRowGroup(int64_t index) { template template -void SelectiveDecimalColumnReader::readHelper(RowSet rows) { - vector_size_t numRows = rows.back() + 1; +void SelectiveDecimalColumnReader::readHelper( + const common::Filter* filter, + RowSet rows) { ExtractToReader extractValues(this); - common::AlwaysTrue filter; + common::AlwaysTrue alwaysTrue; DirectRleColumnVisitor< int64_t, common::AlwaysTrue, decltype(extractValues), kDense> - visitor(filter, this, rows, extractValues); + visitor(alwaysTrue, this, rows, extractValues); // decode scale stream if (version_ == velox::dwrf::RleVersion_1) { @@ -104,14 +105,135 @@ void SelectiveDecimalColumnReader::readHelper(RowSet rows) { // reset numValues_ before reading values numValues_ = 0; valueSize_ = sizeof(DataT); + vector_size_t numRows = rows.back() + 1; ensureValuesCapacity(numRows); // decode value stream facebook::velox::dwio::common:: ColumnVisitor - valueVisitor(filter, this, rows, extractValues); + valueVisitor(alwaysTrue, this, rows, extractValues); decodeWithVisitor>(valueDecoder_.get(), valueVisitor); readOffset_ += numRows; + + // Fill decimals before applying filter. + fillDecimals(); + + // 'nullsInReadRange_' is the nulls for the entire read range, and if the row + // set is not dense, result nulls should be allocated, which represents the + // nulls for the selected rows before filtering. + const auto rawNulls = nullsInReadRange_ + ? (kDense ? nullsInReadRange_->as() : rawResultNulls_) + : nullptr; + // Process filter. + process(filter, rows, rawNulls); +} + +template +void SelectiveDecimalColumnReader::processNulls( + bool isNull, + const RowSet& rows, + const uint64_t* rawNulls) { + if (!rawNulls) { + return; + } + returnReaderNulls_ = false; + anyNulls_ = !isNull; + allNull_ = isNull; + + auto rawDecimal = values_->asMutable(); + auto rawScale = scaleBuffer_->asMutable(); + + vector_size_t idx = 0; + if (isNull) { + for (vector_size_t i = 0; i < numValues_; i++) { + if (bits::isBitNull(rawNulls, i)) { + bits::setNull(rawResultNulls_, idx); + addOutputRow(rows[i]); + idx++; + } + } + } else { + for (vector_size_t i = 0; i < numValues_; i++) { + if (!bits::isBitNull(rawNulls, i)) { + bits::setNull(rawResultNulls_, idx, false); + rawDecimal[idx] = rawDecimal[i]; + rawScale[idx] = rawScale[i]; + addOutputRow(rows[i]); + idx++; + } + } + } +} + +template +void SelectiveDecimalColumnReader::processFilter( + const common::Filter* filter, + const RowSet& rows, + const uint64_t* rawNulls) { + VELOX_CHECK_NOT_NULL(filter, "Filter must not be null."); + returnReaderNulls_ = false; + anyNulls_ = false; + allNull_ = true; + + vector_size_t idx = 0; + auto rawDecimal = values_->asMutable(); + for (vector_size_t i = 0; i < numValues_; i++) { + if (rawNulls && bits::isBitNull(rawNulls, i)) { + if (filter->testNull()) { + bits::setNull(rawResultNulls_, idx); + addOutputRow(rows[i]); + anyNulls_ = true; + idx++; + } + } else { + bool tested; + if constexpr (std::is_same_v) { + tested = filter->testInt64(rawDecimal[i]); + } else { + tested = filter->testInt128(rawDecimal[i]); + } + + if (tested) { + if (rawNulls) { + bits::setNull(rawResultNulls_, idx, false); + } + rawDecimal[idx] = rawDecimal[i]; + addOutputRow(rows[i]); + allNull_ = false; + idx++; + } + } + } +} + +template +void SelectiveDecimalColumnReader::process( + const common::Filter* filter, + const RowSet& rows, + const uint64_t* rawNulls) { + if (!filter) { + // No filter and "hasDeletion" is false so input rows will be + // reused. + return; + } + + switch (filter->kind()) { + case common::FilterKind::kIsNull: + processNulls(true, rows, rawNulls); + break; + case common::FilterKind::kIsNotNull: { + if (rawNulls) { + processNulls(false, rows, rawNulls); + } else { + for (vector_size_t i = 0; i < numValues_; i++) { + addOutputRow(rows[i]); + } + } + break; + } + default: + processFilter(filter, rows, rawNulls); + } } template @@ -119,14 +241,23 @@ void SelectiveDecimalColumnReader::read( int64_t offset, const RowSet& rows, const uint64_t* incomingNulls) { - VELOX_CHECK(!scanSpec_->filter()); VELOX_CHECK(!scanSpec_->valueHook()); prepareRead(offset, rows, incomingNulls); + if (!scanSpec_->keepValues() && scanSpec_->filter() && + (!resultNulls_ || !resultNulls_->unique() || + resultNulls_->capacity() * 8 < rows.size())) { + // Make sure a dedicated resultNulls_ is allocated with enough capacity as + // RleDecoder always assumes it is available and 'prepareRead' skips + // allocation when the column is not projected. + resultNulls_ = AlignedBuffer::allocate(rows.size(), memoryPool_); + rawResultNulls_ = resultNulls_->asMutable(); + } + rawValues_ = values_->asMutable(); bool isDense = rows.back() == rows.size() - 1; if (isDense) { - readHelper(rows); + readHelper(scanSpec_->filter(), rows); } else { - readHelper(rows); + readHelper(scanSpec_->filter(), rows); } } @@ -134,16 +265,17 @@ template void SelectiveDecimalColumnReader::getValues( const RowSet& rows, VectorPtr* result) { + getIntValues(rows, requestedType_, result); +} + +template +void SelectiveDecimalColumnReader::fillDecimals() { auto nullsPtr = resultNulls() ? resultNulls()->template as() : nullptr; auto scales = scaleBuffer_->as(); auto values = values_->asMutable(); - DecimalUtil::fillDecimals( values, nullsPtr, values, scales, numValues_, scale_); - - rawValues_ = values_->asMutable(); - getIntValues(rows, requestedType_, result); } template class SelectiveDecimalColumnReader; diff --git a/velox/dwio/dwrf/reader/SelectiveDecimalColumnReader.h b/velox/dwio/dwrf/reader/SelectiveDecimalColumnReader.h index 67a82b051e36..338d8ac4756f 100644 --- a/velox/dwio/dwrf/reader/SelectiveDecimalColumnReader.h +++ b/velox/dwio/dwrf/reader/SelectiveDecimalColumnReader.h @@ -49,7 +49,24 @@ class SelectiveDecimalColumnReader : public SelectiveColumnReader { private: template - void readHelper(RowSet rows); + void readHelper(const common::Filter* filter, RowSet rows); + + // Process IsNull and IsNotNull filters. + void processNulls(bool isNull, const RowSet& rows, const uint64_t* rawNulls); + + // Process filters on decimal values. + void processFilter( + const common::Filter* filter, + const RowSet& rows, + const uint64_t* rawNulls); + + // Dispatch to the respective filter processing based on the filter type. + void process( + const common::Filter* filter, + const RowSet& rows, + const uint64_t* rawNulls); + + void fillDecimals(); std::unique_ptr> valueDecoder_; std::unique_ptr> scaleDecoder_; diff --git a/velox/dwio/dwrf/reader/SelectiveDwrfReader.cpp b/velox/dwio/dwrf/reader/SelectiveDwrfReader.cpp index 05c16970ad9a..da57c358ebd5 100644 --- a/velox/dwio/dwrf/reader/SelectiveDwrfReader.cpp +++ b/velox/dwio/dwrf/reader/SelectiveDwrfReader.cpp @@ -99,6 +99,10 @@ std::unique_ptr SelectiveDwrfReader::build( return createSelectiveFlatMapColumnReader( columnReaderOptions, requestedType, fileType, params, scanSpec); } + if (scanSpec.isFlatMapAsStruct()) { + return std::make_unique( + columnReaderOptions, requestedType, fileType, params, scanSpec); + } return std::make_unique( columnReaderOptions, requestedType, fileType, params, scanSpec); case TypeKind::REAL: @@ -152,7 +156,7 @@ std::unique_ptr SelectiveDwrfReader::build( default: VELOX_FAIL( "buildReader unhandled type: " + - mapTypeKindToName(fileType->type()->kind())); + std::string(TypeKindName::toName(fileType->type()->kind()))); } } diff --git a/velox/dwio/dwrf/reader/SelectiveFlatMapColumnReader.cpp b/velox/dwio/dwrf/reader/SelectiveFlatMapColumnReader.cpp index 3f7b30c0d6a9..670fe3bea89b 100644 --- a/velox/dwio/dwrf/reader/SelectiveFlatMapColumnReader.cpp +++ b/velox/dwio/dwrf/reader/SelectiveFlatMapColumnReader.cpp @@ -17,11 +17,12 @@ #include "velox/dwio/dwrf/reader/SelectiveFlatMapColumnReader.h" #include "velox/dwio/common/FlatMapHelper.h" +#include "velox/dwio/common/SelectiveFlatMapColumnReader.h" #include "velox/dwio/dwrf/reader/SelectiveDwrfReader.h" #include "velox/dwio/dwrf/reader/SelectiveStructColumnReader.h" +#include "velox/vector/FlatMapVector.h" namespace facebook::velox::dwrf { - namespace { template @@ -39,7 +40,7 @@ inline dwio::common::flatmap::KeyValue extractKey( template std::string toString(const T& x) { if constexpr (std::is_same_v) { - return x; + return std::string(x); } else { return std::to_string(x); } @@ -72,7 +73,7 @@ std::vector> getKeyNodes( const std::shared_ptr& fileType, DwrfParams& params, common::ScanSpec& scanSpec, - bool asStruct) { + dwio::common::flatmap::FlatMapOutput outputType) { using namespace dwio::common::flatmap; std::vector> keyNodes; @@ -86,23 +87,43 @@ std::vector> getKeyNodes( common::ScanSpec* valuesSpec = nullptr; std::unordered_map, common::ScanSpec*, KeyValueHash> childSpecs; - if (!asStruct) { - keysSpec = scanSpec.getOrCreateChild(common::ScanSpec::kMapKeysFieldName); - valuesSpec = - scanSpec.getOrCreateChild(common::ScanSpec::kMapValuesFieldName); - VELOX_CHECK(!valuesSpec->hasFilter()); - keysSpec->setProjectOut(true); - valuesSpec->setProjectOut(true); - } else { - for (auto& c : scanSpec.children()) { - T key; - if constexpr (std::is_same_v) { - key = StringView(c->fieldName()); - } else { - key = folly::to(c->fieldName()); + + // Adjust the scan spec according to the output type. + switch (outputType) { + // For a kMap output, just need a scan spec for map keys and one for map + // values. + case FlatMapOutput::kMap: { + keysSpec = scanSpec.getOrCreateChild(common::ScanSpec::kMapKeysFieldName); + valuesSpec = + scanSpec.getOrCreateChild(common::ScanSpec::kMapValuesFieldName); + VELOX_CHECK(!valuesSpec->hasFilter()); + keysSpec->setProjectOut(true); + valuesSpec->setProjectOut(true); + break; + } + // For a kStruct output, the streams to be read are part of the scan spec + // already. + case FlatMapOutput::kStruct: { + for (auto& c : scanSpec.children()) { + T key; + if constexpr (std::is_same_v) { + key = StringView(c->fieldName()); + } else { + key = folly::to(c->fieldName()); + } + childSpecs[KeyValue(key)] = c.get(); } - childSpecs[KeyValue(key)] = c.get(); + break; } + case FlatMapOutput::kFlatMap: + // Remove on filters on keys stream since it doesn't exist (it's common to + // filter out nulls). + keysSpec = scanSpec.getOrCreateChild(common::ScanSpec::kMapKeysFieldName); + valuesSpec = + scanSpec.getOrCreateChild(common::ScanSpec::kMapValuesFieldName); + keysSpec->setFilter(nullptr); + VELOX_CHECK(!valuesSpec->hasFilter()); + break; } // Load all sub streams. @@ -119,10 +140,14 @@ std::vector> getKeyNodes( const auto& keyInfo = stripe.getEncoding(seqEk).key(); auto key = extractKey(keyInfo); common::ScanSpec* childSpec; - if (auto it = childSpecs.find(key); - it != childSpecs.end() && !it->second->isConstant()) { + if (outputType == FlatMapOutput::kFlatMap) { + childSpec = scanSpec.getOrCreateChild(toString(key.get())); + childSpec->setProjectOut(true); + childSpec->setChannel(sequence - 1); + } else if (auto it = childSpecs.find(key); + it != childSpecs.end() && !it->second->isConstant()) { childSpec = it->second; - } else if (asStruct) { + } else if (outputType == FlatMapOutput::kStruct) { // Column not selected in 'scanSpec', skipping it. return; } else { @@ -159,7 +184,7 @@ std::vector> getKeyNodes( << fileType->id() << ", keys=" << keyNodes.size() << ", streams=" << streams; - if (!asStruct) { + if (outputType != FlatMapOutput::kStruct) { std::sort(keyNodes.begin(), keyNodes.end(), [](auto& x, auto& y) { return x.sequence < y.sequence; }); @@ -181,13 +206,14 @@ class SelectiveFlatMapAsStructReader : public SelectiveStructColumnReaderBase { fileType, params, scanSpec), - keyNodes_(getKeyNodes( - columnReaderOptions, - requestedType, - fileType, - params, - scanSpec, - true)) { + keyNodes_( + getKeyNodes( + columnReaderOptions, + requestedType, + fileType, + params, + scanSpec, + dwio::common::flatmap::FlatMapOutput::kStruct)) { VELOX_CHECK( !keyNodes_.empty(), "For struct encoding, keys to project must be configured"); @@ -206,9 +232,9 @@ class SelectiveFlatMapAsStructReader : public SelectiveStructColumnReaderBase { }; template -class SelectiveFlatMapReader : public SelectiveStructColumnReaderBase { +class SelectiveFlatMapAsMapReader : public SelectiveStructColumnReaderBase { public: - SelectiveFlatMapReader( + SelectiveFlatMapAsMapReader( const dwio::common::ColumnReaderOptions& columnReaderOptions, const TypePtr& requestedType, const std::shared_ptr& fileType, @@ -227,7 +253,12 @@ class SelectiveFlatMapReader : public SelectiveStructColumnReaderBase { fileType, params, scanSpec, - false)) {} + dwio::common::flatmap::FlatMapOutput::kMap)) {} + + void setIsTopLevel() override { + // Children are not considered top level since this is materialized as MAP. + SelectiveColumnReader::setIsTopLevel(); + } void read(int64_t offset, const RowSet& rows, const uint64_t* incomingNulls) override { @@ -243,6 +274,68 @@ class SelectiveFlatMapReader : public SelectiveStructColumnReaderBase { flatMap_; }; +template +class SelectiveFlatMapReader + : public dwio::common::SelectiveFlatMapColumnReader { + public: + SelectiveFlatMapReader( + const dwio::common::ColumnReaderOptions& columnReaderOptions, + const TypePtr& requestedType, + const std::shared_ptr& fileType, + DwrfParams& params, + common::ScanSpec& scanSpec) + : dwio::common::SelectiveFlatMapColumnReader( + requestedType, + fileType, + params, + scanSpec), + keyNodes_( + getKeyNodes( + columnReaderOptions, + requestedType, + fileType, + params, + scanSpec, + dwio::common::flatmap::FlatMapOutput::kFlatMap)), + rowsPerRowGroup_(formatData_->rowsPerRowGroup().value()) { + // Instantiate and populate distinct keys vector. + keysVector_ = BaseVector::create( + CppToType::create(), + (vector_size_t)keyNodes_.size(), + ¶ms.pool()); + auto rawKeys = keysVector_->values()->asMutable(); + children_.resize(keyNodes_.size()); + + for (int i = 0; i < keyNodes_.size(); ++i) { + keyNodes_[i].reader->scanSpec()->setSubscript(i); + children_[i] = keyNodes_[i].reader.get(); + + rawKeys[i] = keyNodes_[i].key.get(); + } + } + + const BufferPtr& inMapBuffer(column_index_t childIndex) const override { + return children_[childIndex] + ->formatData() + .template as() + .inMapBuffer(); + } + + void seekToRowGroup(int64_t index) override { + seekToRowGroupFixedRowsPerRowGroup(index, rowsPerRowGroup_); + } + + void advanceFieldReader( + dwio::common::SelectiveColumnReader* reader, + int64_t offset) override { + advanceFieldReaderFixedRowsPerRowGroup(reader, offset, rowsPerRowGroup_); + } + + private: + std::vector> keyNodes_; + const int32_t rowsPerRowGroup_; +}; + template std::unique_ptr createReader( const dwio::common::ColumnReaderOptions& columnReaderOptions, @@ -253,9 +346,14 @@ std::unique_ptr createReader( if (scanSpec.isFlatMapAsStruct()) { return std::make_unique>( columnReaderOptions, requestedType, fileType, params, scanSpec); - } else { + } else if (params.stripeStreams() + .rowReaderOptions() + .preserveFlatMapsInMemory()) { return std::make_unique>( columnReaderOptions, requestedType, fileType, params, scanSpec); + } else { + return std::make_unique>( + columnReaderOptions, requestedType, fileType, params, scanSpec); } } @@ -268,7 +366,8 @@ createSelectiveFlatMapColumnReader( const std::shared_ptr& fileType, DwrfParams& params, common::ScanSpec& scanSpec) { - auto kind = fileType->childAt(0)->type()->kind(); + VELOX_DCHECK(requestedType->isMap()); + auto kind = requestedType->childAt(0)->kind(); switch (kind) { case TypeKind::TINYINT: return createReader( diff --git a/velox/dwio/dwrf/reader/SelectiveRepeatedColumnReader.cpp b/velox/dwio/dwrf/reader/SelectiveRepeatedColumnReader.cpp index 6cbfd654a4a2..e7ba8206754b 100644 --- a/velox/dwio/dwrf/reader/SelectiveRepeatedColumnReader.cpp +++ b/velox/dwio/dwrf/reader/SelectiveRepeatedColumnReader.cpp @@ -60,7 +60,6 @@ SelectiveListColumnReader::SelectiveListColumnReader( params, scanSpec), length_(makeLengthDecoder(*fileType_, params, *memoryPool_)) { - VELOX_CHECK_EQ(fileType_->id(), fileType->id(), "working on the same node"); EncodingKey encodingKey{fileType_->id(), params.flatMapContext().sequence}; auto& stripe = params.stripeStreams(); // count the number of selected sub-columns @@ -84,54 +83,88 @@ SelectiveListColumnReader::SelectiveListColumnReader( children_ = {child_.get()}; } -SelectiveMapColumnReader::SelectiveMapColumnReader( - const dwio::common::ColumnReaderOptions& columnReaderOptions, - const TypePtr& requestedType, - const std::shared_ptr& fileType, +namespace { + +void makeMapChildrenReaders( + const dwio::common::TypeWithId& fileType, + const Type& requestedType, DwrfParams& params, - common::ScanSpec& scanSpec) - : dwio::common::SelectiveMapColumnReader( - requestedType, - fileType, - params, - scanSpec), - length_(makeLengthDecoder(*fileType_, params, *memoryPool_)) { - VELOX_CHECK_EQ(fileType_->id(), fileType->id(), "working on the same node"); + const dwio::common::ColumnReaderOptions& columnReaderOptions, + const common::ScanSpec& scanSpec, + std::unique_ptr& keyReader, + std::unique_ptr& elementReader) { const EncodingKey encodingKey{ - fileType_->id(), params.flatMapContext().sequence}; + fileType.id(), params.flatMapContext().sequence}; auto& stripe = params.stripeStreams(); - if (scanSpec_->children().empty()) { - scanSpec_->getOrCreateChild(common::ScanSpec::kMapKeysFieldName); - scanSpec_->getOrCreateChild(common::ScanSpec::kMapValuesFieldName); - } - scanSpec_->children()[0]->setProjectOut(true); - scanSpec_->children()[1]->setProjectOut(true); - - auto& keyType = requestedType_->childAt(0); - auto keyParams = DwrfParams( + DwrfParams keyParams( stripe, params.streamLabels(), params.runtimeStatistics(), flatMapContextFromEncodingKey(encodingKey)); - keyReader_ = SelectiveDwrfReader::build( + keyReader = SelectiveDwrfReader::build( columnReaderOptions, - keyType, - fileType_->childAt(0), + requestedType.childAt(0), + fileType.childAt(0), keyParams, - *scanSpec_->children()[0].get()); - - auto& valueType = requestedType_->childAt(1); - auto elementParams = DwrfParams( + *scanSpec.children()[0]); + DwrfParams elementParams = DwrfParams( stripe, params.streamLabels(), params.runtimeStatistics(), flatMapContextFromEncodingKey(encodingKey)); - elementReader_ = SelectiveDwrfReader::build( + elementReader = SelectiveDwrfReader::build( columnReaderOptions, - valueType, - fileType_->childAt(1), + requestedType.childAt(1), + fileType.childAt(1), elementParams, - *scanSpec_->children()[1]); + *scanSpec.children()[1]); +} + +} // namespace + +SelectiveMapColumnReader::SelectiveMapColumnReader( + const dwio::common::ColumnReaderOptions& columnReaderOptions, + const TypePtr& requestedType, + const std::shared_ptr& fileType, + DwrfParams& params, + common::ScanSpec& scanSpec) + : dwio::common::SelectiveMapColumnReader( + requestedType, + fileType, + params, + scanSpec), + length_(makeLengthDecoder(*fileType_, params, *memoryPool_)) { + makeMapChildrenReaders( + *fileType_, + *requestedType_, + params, + columnReaderOptions, + *scanSpec_, + keyReader_, + elementReader_); + children_ = {keyReader_.get(), elementReader_.get()}; +} + +SelectiveMapAsStructColumnReader::SelectiveMapAsStructColumnReader( + const dwio::common::ColumnReaderOptions& columnReaderOptions, + const TypePtr& requestedType, + const std::shared_ptr& fileType, + DwrfParams& params, + common::ScanSpec& scanSpec) + : dwio::common::SelectiveMapAsStructColumnReader( + requestedType, + fileType, + params, + scanSpec), + length_(makeLengthDecoder(*fileType_, params, *memoryPool_)) { + makeMapChildrenReaders( + *fileType_, + *requestedType_, + params, + columnReaderOptions, + mapScanSpec_, + keyReader_, + elementReader_); children_ = {keyReader_.get(), elementReader_.get()}; } diff --git a/velox/dwio/dwrf/reader/SelectiveRepeatedColumnReader.h b/velox/dwio/dwrf/reader/SelectiveRepeatedColumnReader.h index ad6e8575bdd9..05e836244f0b 100644 --- a/velox/dwio/dwrf/reader/SelectiveRepeatedColumnReader.h +++ b/velox/dwio/dwrf/reader/SelectiveRepeatedColumnReader.h @@ -34,10 +34,6 @@ class SelectiveListColumnReader DwrfParams& params, common::ScanSpec& scanSpec); - void resetFilterCaches() override { - child_->resetFilterCaches(); - } - void seekToRowGroup(int64_t index) override { dwio::common::SelectiveListColumnReader::seekToRowGroup(index); auto positionsProvider = formatData_->seekToRowGroup(index); @@ -68,24 +64,37 @@ class SelectiveMapColumnReader : public dwio::common::SelectiveMapColumnReader { DwrfParams& params, common::ScanSpec& scanSpec); - void resetFilterCaches() override { - keyReader_->resetFilterCaches(); - elementReader_->resetFilterCaches(); - } - void seekToRowGroup(int64_t index) override { dwio::common::SelectiveMapColumnReader::seekToRowGroup(index); auto positionsProvider = formatData_->seekToRowGroup(index); - length_->seekToRowGroup(positionsProvider); - VELOX_CHECK(!positionsProvider.hasNext()); + } - keyReader_->seekToRowGroup(index); - keyReader_->setReadOffsetRecursive(0); - elementReader_->seekToRowGroup(index); - elementReader_->setReadOffsetRecursive(0); - childTargetReadOffset_ = 0; + void readLengths(int32_t* lengths, int32_t numLengths, const uint64_t* nulls) + override { + length_->next(lengths, numLengths, nulls); + } + + private: + std::unique_ptr> length_; +}; + +class SelectiveMapAsStructColumnReader + : public dwio::common::SelectiveMapAsStructColumnReader { + public: + SelectiveMapAsStructColumnReader( + const dwio::common::ColumnReaderOptions& columnReaderOptions, + const TypePtr& requestedType, + const std::shared_ptr& fileType, + DwrfParams& params, + common::ScanSpec& scanSpec); + + void seekToRowGroup(int64_t index) override { + dwio::common::SelectiveMapAsStructColumnReader::seekToRowGroup(index); + auto positionsProvider = formatData_->seekToRowGroup(index); + length_->seekToRowGroup(positionsProvider); + VELOX_CHECK(!positionsProvider.hasNext()); } void readLengths(int32_t* lengths, int32_t numLengths, const uint64_t* nulls) diff --git a/velox/dwio/dwrf/reader/SelectiveStringDictionaryColumnReader.cpp b/velox/dwio/dwrf/reader/SelectiveStringDictionaryColumnReader.cpp index d88b8c24e404..d6914a47b94e 100644 --- a/velox/dwio/dwrf/reader/SelectiveStringDictionaryColumnReader.cpp +++ b/velox/dwio/dwrf/reader/SelectiveStringDictionaryColumnReader.cpp @@ -156,7 +156,7 @@ void SelectiveStringDictionaryColumnReader::loadStrideDictionary() { if (scanState_.dictionary2.numValues > 0) { // seek stride dictionary related streams std::vector pos( - positions.begin() + positionOffset_, positions.end()); + positions.cbegin() + positionOffset_, positions.cend()); PositionProvider pp(pos); strideDictStream_->seekToPosition(pp); strideDictLengthDecoder_->seekToRowGroup(pp); diff --git a/velox/dwio/dwrf/reader/SelectiveStringDirectColumnReader.cpp b/velox/dwio/dwrf/reader/SelectiveStringDirectColumnReader.cpp index 657470e92676..c4777ea43be5 100644 --- a/velox/dwio/dwrf/reader/SelectiveStringDirectColumnReader.cpp +++ b/velox/dwio/dwrf/reader/SelectiveStringDirectColumnReader.cpp @@ -362,20 +362,19 @@ void SelectiveStringDirectColumnReader::skipInDecode( lengthIndex_ += numValues; } -folly::StringPiece SelectiveStringDirectColumnReader::readValue( - int32_t length) { +std::string_view SelectiveStringDirectColumnReader::readValue(int32_t length) { skipBytes(bytesToSkip_, blobStream_.get(), bufferStart_, bufferEnd_); bytesToSkip_ = 0; // bufferStart_ may be null if length is 0 and this is the first string // we're reading. if (bufferEnd_ - bufferStart_ >= length) { bytesToSkip_ = length; - return folly::StringPiece(bufferStart_, length); + return std::string_view(bufferStart_, length); } tempString_.resize(length); readBytes( length, blobStream_.get(), tempString_.data(), bufferStart_, bufferEnd_); - return folly::StringPiece(tempString_); + return std::string_view(tempString_); } template @@ -474,7 +473,7 @@ void SelectiveStringDirectColumnReader::read( int64_t offset, const RowSet& rows, const uint64_t* incomingNulls) { - prepareRead(offset, rows, incomingNulls); + prepareRead(offset, rows, incomingNulls); auto numRows = rows.back() + 1; auto numNulls = nullsInReadRange_ ? BaseVector::countNulls(nullsInReadRange_, 0, numRows) diff --git a/velox/dwio/dwrf/reader/SelectiveStringDirectColumnReader.h b/velox/dwio/dwrf/reader/SelectiveStringDirectColumnReader.h index 8da1e77401d2..e0a2369ceb57 100644 --- a/velox/dwio/dwrf/reader/SelectiveStringDirectColumnReader.h +++ b/velox/dwio/dwrf/reader/SelectiveStringDirectColumnReader.h @@ -58,7 +58,7 @@ class SelectiveStringDirectColumnReader template void skipInDecode(int32_t numValues, int32_t current, const uint64_t* nulls); - folly::StringPiece readValue(int32_t length); + std::string_view readValue(int32_t length); template void decode(const uint64_t* nulls, Visitor visitor); diff --git a/velox/dwio/dwrf/reader/SelectiveStructColumnReader.cpp b/velox/dwio/dwrf/reader/SelectiveStructColumnReader.cpp index b1f47289505e..364270e53911 100644 --- a/velox/dwio/dwrf/reader/SelectiveStructColumnReader.cpp +++ b/velox/dwio/dwrf/reader/SelectiveStructColumnReader.cpp @@ -86,12 +86,13 @@ SelectiveStructColumnReader::SelectiveStructColumnReader( .sequence = encodingKey.sequence(), .inMapDecoder = nullptr, .keySelectionCallback = nullptr}); - addChild(SelectiveDwrfReader::build( - columnReaderOptions, - childRequestedType, - childFileType, - childParams, - *childSpec)); + addChild( + SelectiveDwrfReader::build( + columnReaderOptions, + childRequestedType, + childFileType, + childParams, + *childSpec)); childSpec->setSubscript(children_.size() - 1); } } diff --git a/velox/dwio/dwrf/reader/SelectiveStructColumnReader.h b/velox/dwio/dwrf/reader/SelectiveStructColumnReader.h index a334f5b7ca9e..6e323b67fe39 100644 --- a/velox/dwio/dwrf/reader/SelectiveStructColumnReader.h +++ b/velox/dwio/dwrf/reader/SelectiveStructColumnReader.h @@ -29,13 +29,15 @@ class SelectiveStructColumnReaderBase const std::shared_ptr& fileType, DwrfParams& params, common::ScanSpec& scanSpec, - bool isRoot = false) + bool isRoot = false, + bool generateLazyChildren = true) : dwio::common::SelectiveStructColumnReaderBase( requestedType, fileType, params, scanSpec, - isRoot), + isRoot, + generateLazyChildren), rowsPerRowGroup_(formatData_->rowsPerRowGroup().value()) { VELOX_CHECK_EQ(fileType_->id(), fileType->id(), "working on the same node"); } @@ -43,35 +45,14 @@ class SelectiveStructColumnReaderBase void seekTo(int64_t offset, bool readsNullsOnly) override; void seekToRowGroup(int64_t index) override { - dwio::common::SelectiveStructColumnReaderBase::seekToRowGroup(index); - if (isTopLevel_ && !formatData_->hasNulls()) { - readOffset_ = index * rowsPerRowGroup_; - return; - } - - // There may be a nulls stream but no other streams for the struct. - formatData_->seekToRowGroup(index); - // Set the read offset recursively. Do this before seeking the children - // because list/map children will reset the offsets for their children. - setReadOffsetRecursive(index * rowsPerRowGroup_); - for (auto& child : children_) { - child->seekToRowGroup(index); - } + seekToRowGroupFixedRowsPerRowGroup(index, rowsPerRowGroup_); } /// Advance field reader to the row group closest to specified offset by /// calling seekToRowGroup. void advanceFieldReader(SelectiveColumnReader* reader, int64_t offset) override { - if (!reader->isTopLevel()) { - return; - } - const auto rowGroup = reader->readOffset() / rowsPerRowGroup_; - const auto nextRowGroup = offset / rowsPerRowGroup_; - if (nextRowGroup > rowGroup) { - reader->seekToRowGroup(nextRowGroup); - reader->setReadOffset(nextRowGroup * rowsPerRowGroup_); - } + advanceFieldReaderFixedRowsPerRowGroup(reader, offset, rowsPerRowGroup_); } private: diff --git a/velox/dwio/dwrf/reader/SelectiveTimestampColumnReader.cpp b/velox/dwio/dwrf/reader/SelectiveTimestampColumnReader.cpp index 99f790851e97..10f9be96a57b 100644 --- a/velox/dwio/dwrf/reader/SelectiveTimestampColumnReader.cpp +++ b/velox/dwio/dwrf/reader/SelectiveTimestampColumnReader.cpp @@ -100,7 +100,7 @@ void SelectiveTimestampColumnReader::read( template void SelectiveTimestampColumnReader::readHelper( - common::Filter* filter, + const common::Filter* filter, const RowSet& rows) { ExtractToReader extractValues(this); common::AlwaysTrue alwaysTrue; diff --git a/velox/dwio/dwrf/reader/SelectiveTimestampColumnReader.h b/velox/dwio/dwrf/reader/SelectiveTimestampColumnReader.h index 44ba1feb113d..b9817b71f3aa 100644 --- a/velox/dwio/dwrf/reader/SelectiveTimestampColumnReader.h +++ b/velox/dwio/dwrf/reader/SelectiveTimestampColumnReader.h @@ -42,7 +42,7 @@ class SelectiveTimestampColumnReader private: template - void readHelper(common::Filter* filter, const RowSet& rows); + void readHelper(const common::Filter* filter, const RowSet& rows); void processNulls(const bool isNull, const RowSet& rows, const uint64_t* rawNulls); diff --git a/velox/dwio/dwrf/reader/StripeMetadataCache.h b/velox/dwio/dwrf/reader/StripeMetadataCache.h index 4a7a1125610d..9cd5507932dd 100644 --- a/velox/dwio/dwrf/reader/StripeMetadataCache.h +++ b/velox/dwio/dwrf/reader/StripeMetadataCache.h @@ -95,7 +95,7 @@ class StripeMetadataCache { std::vector offsets; offsets.reserve(footer.stripeCacheOffsetsSize()); const auto& from = footer.stripeCacheOffsets(); - offsets.assign(from.begin(), from.end()); + offsets.assign(from.cbegin(), from.cend()); return offsets; } diff --git a/velox/dwio/dwrf/reader/StripeReaderBase.cpp b/velox/dwio/dwrf/reader/StripeReaderBase.cpp index 44ba6aa81e8b..9bc5549d9237 100644 --- a/velox/dwio/dwrf/reader/StripeReaderBase.cpp +++ b/velox/dwio/dwrf/reader/StripeReaderBase.cpp @@ -16,8 +16,11 @@ #include "velox/dwio/dwrf/reader/StripeReaderBase.h" +#include "velox/dwio/common/Arena.h" + namespace facebook::velox::dwrf { +using dwio::common::ArenaCreate; using dwio::common::LogType; // preload is not considered or mutated if stripe has already been fetched. e.g. @@ -95,9 +98,7 @@ std::unique_ptr StripeReaderBase::fetchStripe( }; if (fileFooter.format() == DwrfFormat::kDwrf) { - auto* rawFooter = - google::protobuf::Arena::CreateMessage( - arena.get()); + auto* rawFooter = ArenaCreate(arena.get()); ProtoUtils::readProtoInto( reader_->createDecompressedStream( std::move(footerStream), streamDebugInfo), @@ -110,9 +111,7 @@ std::unique_ptr StripeReaderBase::fetchStripe( return createStripeMetadata(std::move(stripeFooter)); } else { - auto* rawFooter = - google::protobuf::Arena::CreateMessage( - arena.get()); + auto* rawFooter = ArenaCreate(arena.get()); ProtoUtils::readProtoInto( reader_->createDecompressedStream( std::move(footerStream), streamDebugInfo), diff --git a/velox/dwio/dwrf/test/CMakeLists.txt b/velox/dwio/dwrf/test/CMakeLists.txt index e8da27a642e0..cbd29a9ac2f8 100644 --- a/velox/dwio/dwrf/test/CMakeLists.txt +++ b/velox/dwio/dwrf/test/CMakeLists.txt @@ -14,42 +14,47 @@ add_subdirectory(utils) -set(TEST_LINK_LIBS - velox_dwio_common_test_utils - velox_dwio_dwrf_reader - velox_dwio_dwrf_writer - velox_row_fast - GTest::gtest - GTest::gtest_main - GTest::gmock - glog::glog) +set( + TEST_LINK_LIBS + velox_dwio_common_test_utils + velox_dwio_dwrf_reader + velox_dwio_dwrf_writer + velox_row_fast + GTest::gtest + GTest::gtest_main + GTest::gmock + glog::glog +) -add_executable(velox_dwio_dwrf_buffered_output_stream_test - TestBufferedOutputStream.cpp) -add_test(velox_dwio_dwrf_buffered_output_stream_test - velox_dwio_dwrf_buffered_output_stream_test) +add_executable(velox_dwio_dwrf_buffered_output_stream_test TestBufferedOutputStream.cpp) +add_test(velox_dwio_dwrf_buffered_output_stream_test velox_dwio_dwrf_buffered_output_stream_test) target_link_libraries( - velox_dwio_dwrf_buffered_output_stream_test velox_link_libs Folly::folly - ${TEST_LINK_LIBS}) + velox_dwio_dwrf_buffered_output_stream_test + velox_link_libs + Folly::folly + ${TEST_LINK_LIBS} +) -add_executable(velox_dwio_dwrf_column_statistics_test - TestDwrfColumnStatistics.cpp) -add_test(velox_dwio_dwrf_column_statistics_test - velox_dwio_dwrf_column_statistics_test) +add_executable(velox_dwio_dwrf_column_statistics_test TestDwrfColumnStatistics.cpp) +add_test(velox_dwio_dwrf_column_statistics_test velox_dwio_dwrf_column_statistics_test) target_link_libraries( - velox_dwio_dwrf_column_statistics_test velox_link_libs Folly::folly - ${TEST_LINK_LIBS}) + velox_dwio_dwrf_column_statistics_test + velox_link_libs + Folly::folly + ${TEST_LINK_LIBS} +) -add_executable(velox_dwio_orc_column_statistics_test - TestOrcColumnStatistics.cpp) -add_test(velox_dwio_orc_column_statistics_test - velox_dwio_orc_column_statistics_test) +add_executable(velox_dwio_orc_column_statistics_test TestOrcColumnStatistics.cpp) +add_test(velox_dwio_orc_column_statistics_test velox_dwio_orc_column_statistics_test) target_link_libraries( - velox_dwio_orc_column_statistics_test velox_link_libs Folly::folly - ${TEST_LINK_LIBS}) + velox_dwio_orc_column_statistics_test + velox_link_libs + Folly::folly + ${TEST_LINK_LIBS} +) add_executable(velox_dwio_dwrf_compression_test CompressionTest.cpp) add_test(velox_dwio_dwrf_compression_test velox_dwio_dwrf_compression_test) @@ -60,26 +65,27 @@ target_link_libraries( Folly::folly glog::glog lz4::lz4 - lzo2::lzo2 zstd::zstd ZLIB::ZLIB - ${TEST_LINK_LIBS}) + ${TEST_LINK_LIBS} +) add_executable(velox_dwio_dwrf_decompression_test TestDecompression.cpp) add_test( NAME velox_dwio_dwrf_decompression_test COMMAND velox_dwio_dwrf_decompression_test - WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} +) target_link_libraries( velox_dwio_dwrf_decompression_test velox_link_libs Folly::folly lz4::lz4 - lzo2::lzo2 zstd::zstd ZLIB::ZLIB - ${TEST_LINK_LIBS}) + ${TEST_LINK_LIBS} +) add_executable(velox_dwio_dwrf_stripe_stream_test TestStripeStream.cpp) add_test(velox_dwio_dwrf_stripe_stream_test velox_dwio_dwrf_stripe_stream_test) @@ -90,10 +96,10 @@ target_link_libraries( Folly::folly fmt::fmt lz4::lz4 - lzo2::lzo2 zstd::zstd ZLIB::ZLIB - ${TEST_LINK_LIBS}) + ${TEST_LINK_LIBS} +) add_executable(velox_dwio_dwrf_stream_labels_test StreamLabelsTests.cpp) add_test(velox_dwio_dwrf_stream_labels_test velox_dwio_dwrf_stream_labels_test) @@ -104,15 +110,13 @@ target_link_libraries( Folly::folly fmt::fmt lz4::lz4 - lzo2::lzo2 zstd::zstd ZLIB::ZLIB - ${TEST_LINK_LIBS}) + ${TEST_LINK_LIBS} +) -add_executable(velox_dwio_dwrf_stripe_dictionary_cache_test - TestStripeDictionaryCache.cpp) -add_test(velox_dwio_dwrf_stripe_dictionary_cache_test - velox_dwio_dwrf_stripe_dictionary_cache_test) +add_executable(velox_dwio_dwrf_stripe_dictionary_cache_test TestStripeDictionaryCache.cpp) +add_test(velox_dwio_dwrf_stripe_dictionary_cache_test velox_dwio_dwrf_stripe_dictionary_cache_test) target_link_libraries( velox_dwio_dwrf_stripe_dictionary_cache_test @@ -120,46 +124,58 @@ target_link_libraries( Folly::folly fmt::fmt lz4::lz4 - lzo2::lzo2 zstd::zstd ZLIB::ZLIB - ${TEST_LINK_LIBS}) + ${TEST_LINK_LIBS} +) -add_executable(velox_dwio_dwrf_dictionary_encoder_test - TestIntegerDictionaryEncoder.cpp TestStringDictionaryEncoder.cpp) -add_test(velox_dwio_dwrf_dictionary_encoder_test - velox_dwio_dwrf_dictionary_encoder_test) +add_executable( + velox_dwio_dwrf_dictionary_encoder_test + TestIntegerDictionaryEncoder.cpp + TestStringDictionaryEncoder.cpp +) +add_test(velox_dwio_dwrf_dictionary_encoder_test velox_dwio_dwrf_dictionary_encoder_test) target_link_libraries( velox_dwio_dwrf_dictionary_encoder_test velox_link_libs Folly::folly gflags::gflags - ${TEST_LINK_LIBS}) + ${TEST_LINK_LIBS} +) add_executable(velox_dwio_dwrf_encoding_selector_test TestEncodingSelector.cpp) -add_test(velox_dwio_dwrf_encoding_selector_test - velox_dwio_dwrf_encoding_selector_test) +add_test(velox_dwio_dwrf_encoding_selector_test velox_dwio_dwrf_encoding_selector_test) target_link_libraries( - velox_dwio_dwrf_encoding_selector_test velox_link_libs Folly::folly - ${TEST_LINK_LIBS}) + velox_dwio_dwrf_encoding_selector_test + velox_link_libs + Folly::folly + ${TEST_LINK_LIBS} +) add_executable(velox_dwio_dwrf_index_builder_test IndexBuilderTests.cpp) add_test(velox_dwio_dwrf_index_builder_test velox_dwio_dwrf_index_builder_test) target_link_libraries( - velox_dwio_dwrf_index_builder_test velox_link_libs Folly::folly - ${TEST_LINK_LIBS}) + velox_dwio_dwrf_index_builder_test + velox_link_libs + Folly::folly + ${TEST_LINK_LIBS} +) -add_executable(velox_dwio_dwrf_dictionary_encoding_utils_test - TestDictionaryEncodingUtils.cpp) -add_test(velox_dwio_dwrf_dictionary_encoding_utils_test - velox_dwio_dwrf_dictionary_encoding_utils_test) +add_executable(velox_dwio_dwrf_dictionary_encoding_utils_test TestDictionaryEncodingUtils.cpp) +add_test( + velox_dwio_dwrf_dictionary_encoding_utils_test + velox_dwio_dwrf_dictionary_encoding_utils_test +) target_link_libraries( - velox_dwio_dwrf_dictionary_encoding_utils_test velox_link_libs Folly::folly - ${TEST_LINK_LIBS}) + velox_dwio_dwrf_dictionary_encoding_utils_test + velox_link_libs + Folly::folly + ${TEST_LINK_LIBS} +) add_executable(velox_dwio_dwrf_checksum_test ChecksumTests.cpp) add_test(velox_dwio_dwrf_checksum_test velox_dwio_dwrf_checksum_test) @@ -169,7 +185,8 @@ target_link_libraries( velox_link_libs Folly::folly Boost::headers - ${TEST_LINK_LIBS}) + ${TEST_LINK_LIBS} +) add_executable(velox_dwio_dwrf_writer_test WriterTest.cpp) add_test(velox_dwio_dwrf_writer_test velox_dwio_dwrf_writer_test) @@ -179,110 +196,112 @@ target_link_libraries( velox_link_libs Folly::folly lz4::lz4 - lzo2::lzo2 zstd::zstd ZLIB::ZLIB - ${TEST_LINK_LIBS}) + ${TEST_LINK_LIBS} +) add_executable(velox_dwio_dwrf_writer_context_test WriterContextTest.cpp) -add_test(velox_dwio_dwrf_writer_context_test - velox_dwio_dwrf_writer_context_test) +add_test(velox_dwio_dwrf_writer_context_test velox_dwio_dwrf_writer_context_test) target_link_libraries( velox_dwio_dwrf_writer_context_test velox_link_libs Folly::folly lz4::lz4 - lzo2::lzo2 zstd::zstd ZLIB::ZLIB - ${TEST_LINK_LIBS}) + ${TEST_LINK_LIBS} +) -add_executable(velox_dwio_dwrf_writer_encoding_manager_test - EncodingManagerTests.cpp) -add_test(velox_dwio_dwrf_writer_encoding_manager_test - velox_dwio_dwrf_writer_encoding_manager_test) +add_executable(velox_dwio_dwrf_writer_encoding_manager_test EncodingManagerTests.cpp) +add_test(velox_dwio_dwrf_writer_encoding_manager_test velox_dwio_dwrf_writer_encoding_manager_test) target_link_libraries( velox_dwio_dwrf_writer_encoding_manager_test velox_link_libs Folly::folly lz4::lz4 - lzo2::lzo2 zstd::zstd ZLIB::ZLIB - ${TEST_LINK_LIBS}) + ${TEST_LINK_LIBS} +) add_executable(velox_dwio_dwrf_writer_sink_test WriterSinkTest.cpp) add_test(velox_dwio_dwrf_writer_sink_test velox_dwio_dwrf_writer_sink_test) target_link_libraries( - velox_dwio_dwrf_writer_sink_test velox_link_libs Folly::folly - ${TEST_LINK_LIBS}) + velox_dwio_dwrf_writer_sink_test + velox_link_libs + Folly::folly + ${TEST_LINK_LIBS} +) -add_executable(velox_dwio_dwrf_data_buffer_holder_test - DataBufferHolderTests.cpp) -add_test(velox_dwio_dwrf_data_buffer_holder_test - velox_dwio_dwrf_data_buffer_holder_test) +add_executable(velox_dwio_dwrf_data_buffer_holder_test DataBufferHolderTests.cpp) +add_test(velox_dwio_dwrf_data_buffer_holder_test velox_dwio_dwrf_data_buffer_holder_test) target_link_libraries( - velox_dwio_dwrf_data_buffer_holder_test velox_link_libs Folly::folly - ${TEST_LINK_LIBS}) + velox_dwio_dwrf_data_buffer_holder_test + velox_link_libs + Folly::folly + ${TEST_LINK_LIBS} +) add_executable(velox_dwio_dwrf_layout_planner_test LayoutPlannerTests.cpp) -add_test(velox_dwio_dwrf_layout_planner_test - velox_dwio_dwrf_layout_planner_test) +add_test(velox_dwio_dwrf_layout_planner_test velox_dwio_dwrf_layout_planner_test) target_link_libraries( velox_dwio_dwrf_layout_planner_test velox_link_libs Folly::folly lz4::lz4 - lzo2::lzo2 zstd::zstd ZLIB::ZLIB - ${TEST_LINK_LIBS}) + ${TEST_LINK_LIBS} +) add_executable(velox_dwio_dwrf_decryption_test DecryptionTests.cpp) add_test(velox_dwio_dwrf_decryption_test velox_dwio_dwrf_decryption_test) target_link_libraries( - velox_dwio_dwrf_decryption_test velox_link_libs Folly::folly - ${TEST_LINK_LIBS}) + velox_dwio_dwrf_decryption_test + velox_link_libs + Folly::folly + ${TEST_LINK_LIBS} +) add_executable(velox_dwio_dwrf_encryption_test EncryptionTests.cpp) add_test(velox_dwio_dwrf_encryption_test velox_dwio_dwrf_encryption_test) target_link_libraries( - velox_dwio_dwrf_encryption_test velox_link_libs Folly::folly - ${TEST_LINK_LIBS}) + velox_dwio_dwrf_encryption_test + velox_link_libs + Folly::folly + ${TEST_LINK_LIBS} +) add_executable(velox_filemetadata_test FileMetadataTest.cpp) add_test(velox_filemetadata_test velox_filemetadata_test) -target_link_libraries( - velox_filemetadata_test velox_link_libs ${TEST_LINK_LIBS}) +target_link_libraries(velox_filemetadata_test velox_link_libs ${TEST_LINK_LIBS}) add_executable(velox_common_test CommonTests.cpp) add_test(velox_common_test velox_common_test) -target_link_libraries( - velox_common_test velox_link_libs ${TEST_LINK_LIBS}) +target_link_libraries(velox_common_test velox_link_libs ${TEST_LINK_LIBS}) -add_executable(velox_dwio_dwrf_stripe_reader_base_test - StripeReaderBaseTests.cpp) -add_test(velox_dwio_dwrf_stripe_reader_base_test - velox_dwio_dwrf_stripe_reader_base_test) +add_executable(velox_dwio_dwrf_stripe_reader_base_test StripeReaderBaseTests.cpp) +add_test(velox_dwio_dwrf_stripe_reader_base_test velox_dwio_dwrf_stripe_reader_base_test) target_link_libraries( velox_dwio_dwrf_stripe_reader_base_test velox_link_libs Folly::folly lz4::lz4 - lzo2::lzo2 zstd::zstd ZLIB::ZLIB - ${TEST_LINK_LIBS}) + ${TEST_LINK_LIBS} +) add_executable(velox_dwio_dwrf_reader_base_test ReaderBaseTests.cpp) add_test(velox_dwio_dwrf_reader_base_test velox_dwio_dwrf_reader_base_test) @@ -293,10 +312,10 @@ target_link_libraries( Folly::folly glog::glog lz4::lz4 - lzo2::lzo2 zstd::zstd ZLIB::ZLIB - ${TEST_LINK_LIBS}) + ${TEST_LINK_LIBS} +) add_executable(velox_dwio_dwrf_config_test ConfigTests.cpp) add_test(velox_dwio_dwrf_config_test velox_dwio_dwrf_config_test) @@ -306,21 +325,28 @@ target_link_libraries( velox_link_libs velox_row_fast Folly::folly - ${TEST_LINK_LIBS}) + ${TEST_LINK_LIBS} +) add_executable(velox_dwio_dwrf_ratio_checker_test RatioTrackerTest.cpp) add_test(velox_dwio_dwrf_ratio_checker_test velox_dwio_dwrf_ratio_checker_test) target_link_libraries( - velox_dwio_dwrf_ratio_checker_test velox_link_libs Folly::folly - ${TEST_LINK_LIBS}) + velox_dwio_dwrf_ratio_checker_test + velox_link_libs + Folly::folly + ${TEST_LINK_LIBS} +) add_executable(velox_dwio_dwrf_flush_policy_test FlushPolicyTest.cpp) add_test(velox_dwio_dwrf_flush_policy_test velox_dwio_dwrf_flush_policy_test) target_link_libraries( - velox_dwio_dwrf_flush_policy_test velox_link_libs Folly::folly - ${TEST_LINK_LIBS}) + velox_dwio_dwrf_flush_policy_test + velox_link_libs + Folly::folly + ${TEST_LINK_LIBS} +) add_executable(velox_dwio_dwrf_byte_rle_test TestByteRle.cpp) add_test(velox_dwio_dwrf_byte_rle_test velox_dwio_dwrf_byte_rle_test) @@ -330,31 +356,34 @@ target_link_libraries( velox_link_libs fmt::fmt Folly::folly - ${TEST_LINK_LIBS}) + ${TEST_LINK_LIBS} +) add_executable(velox_dwio_dwrf_byte_rle_encoder_test TestByteRLEEncoder.cpp) -add_test(velox_dwio_dwrf_byte_rle_encoder_test - velox_dwio_dwrf_byte_rle_encoder_test) +add_test(velox_dwio_dwrf_byte_rle_encoder_test velox_dwio_dwrf_byte_rle_encoder_test) target_link_libraries( velox_dwio_dwrf_byte_rle_encoder_test velox_link_libs fmt::fmt Folly::folly - ${TEST_LINK_LIBS}) + ${TEST_LINK_LIBS} +) add_executable(velox_dwio_dwrf_int_encoder_test TestIntEncoder.cpp) add_test(velox_dwio_dwrf_int_encoder_test velox_dwio_dwrf_int_encoder_test) target_link_libraries( - velox_dwio_dwrf_int_encoder_test velox_link_libs Folly::folly - ${TEST_LINK_LIBS}) + velox_dwio_dwrf_int_encoder_test + velox_link_libs + Folly::folly + ${TEST_LINK_LIBS} +) add_executable(velox_dwio_dwrf_rle_test TestRle.cpp) add_test(velox_dwio_dwrf_rle_test velox_dwio_dwrf_rle_test) -target_link_libraries( - velox_dwio_dwrf_rle_test velox_link_libs Folly::folly ${TEST_LINK_LIBS}) +target_link_libraries(velox_dwio_dwrf_rle_test velox_link_libs Folly::folly ${TEST_LINK_LIBS}) add_executable(velox_dwio_dwrf_int_direct_test TestIntDirect.cpp) add_test(velox_dwio_dwrf_int_direct_test velox_dwio_dwrf_int_direct_test) @@ -364,14 +393,18 @@ target_link_libraries( velox_link_libs fmt::fmt Folly::folly - ${TEST_LINK_LIBS}) + ${TEST_LINK_LIBS} +) add_executable(velox_dwio_dwrf_rlev1_encoder_test TestRLEv1Encoder.cpp) add_test(velox_dwio_dwrf_rlev1_encoder_test velox_dwio_dwrf_rlev1_encoder_test) target_link_libraries( - velox_dwio_dwrf_rlev1_encoder_test velox_link_libs Folly::folly - ${TEST_LINK_LIBS}) + velox_dwio_dwrf_rlev1_encoder_test + velox_link_libs + Folly::folly + ${TEST_LINK_LIBS} +) add_executable(velox_dwio_dwrf_column_reader_test TestColumnReader.cpp) add_test(velox_dwio_dwrf_column_reader_test velox_dwio_dwrf_column_reader_test) @@ -381,13 +414,15 @@ target_link_libraries( velox_link_libs Folly::folly fmt::fmt - ${TEST_LINK_LIBS}) + ${TEST_LINK_LIBS} +) add_executable(velox_dwio_dwrf_reader_test ReaderTest.cpp) add_test( NAME velox_dwio_dwrf_reader_test COMMAND velox_dwio_dwrf_reader_test - WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} +) target_link_libraries( velox_dwio_dwrf_reader_test @@ -397,10 +432,10 @@ target_link_libraries( Folly::folly fmt::fmt lz4::lz4 - lzo2::lzo2 zstd::zstd ZLIB::ZLIB - ${TEST_LINK_LIBS}) + ${TEST_LINK_LIBS} +) add_executable(velox_dwrf_e2e_filter_test E2EFilterTest.cpp) add_test(velox_dwrf_e2e_filter_test velox_dwrf_e2e_filter_test) @@ -413,19 +448,20 @@ target_link_libraries( Folly::folly fmt::fmt lz4::lz4 - lzo2::lzo2 zstd::zstd ZLIB::ZLIB - ${TEST_LINK_LIBS}) + ${TEST_LINK_LIBS} +) -add_executable(velox_dwrf_statistics_builder_utils_test - TestStatisticsBuilderUtils.cpp) -add_test(velox_dwrf_statistics_builder_utils_test - velox_dwrf_statistics_builder_utils_test) +add_executable(velox_dwrf_statistics_builder_utils_test TestStatisticsBuilderUtils.cpp) +add_test(velox_dwrf_statistics_builder_utils_test velox_dwrf_statistics_builder_utils_test) target_link_libraries( - velox_dwrf_statistics_builder_utils_test velox_link_libs Folly::folly - ${TEST_LINK_LIBS}) + velox_dwrf_statistics_builder_utils_test + velox_link_libs + Folly::folly + ${TEST_LINK_LIBS} +) add_executable(velox_dwrf_column_writer_test ColumnWriterTest.cpp) add_test(velox_dwrf_column_writer_test velox_dwrf_column_writer_test) @@ -438,14 +474,13 @@ target_link_libraries( Folly::folly fmt::fmt lz4::lz4 - lzo2::lzo2 zstd::zstd ZLIB::ZLIB - ${TEST_LINK_LIBS}) + ${TEST_LINK_LIBS} +) add_executable(velox_dwrf_column_writer_index_test ColumnWriterIndexTest.cpp) -add_test(velox_dwrf_column_writer_index_test - velox_dwrf_column_writer_index_test) +add_test(velox_dwrf_column_writer_index_test velox_dwrf_column_writer_index_test) target_link_libraries( velox_dwrf_column_writer_index_test @@ -454,10 +489,10 @@ target_link_libraries( Folly::folly fmt::fmt lz4::lz4 - lzo2::lzo2 zstd::zstd ZLIB::ZLIB - ${TEST_LINK_LIBS}) + ${TEST_LINK_LIBS} +) add_executable(velox_dwrf_writer_flush_test WriterFlushTest.cpp) add_test(velox_dwrf_writer_flush_test velox_dwrf_writer_flush_test) @@ -469,10 +504,10 @@ target_link_libraries( Folly::folly fmt::fmt lz4::lz4 - lzo2::lzo2 zstd::zstd ZLIB::ZLIB - ${TEST_LINK_LIBS}) + ${TEST_LINK_LIBS} +) add_executable(velox_dwrf_e2e_writer_test E2EWriterTest.cpp) add_test(velox_dwrf_e2e_writer_test velox_dwrf_e2e_writer_test) @@ -486,10 +521,10 @@ target_link_libraries( Folly::folly fmt::fmt lz4::lz4 - lzo2::lzo2 zstd::zstd ZLIB::ZLIB - ${TEST_LINK_LIBS}) + ${TEST_LINK_LIBS} +) add_executable(velox_dwrf_e2e_reader_test E2EReaderTest.cpp) add_test(velox_dwrf_e2e_reader_test velox_dwrf_e2e_reader_test) @@ -503,10 +538,10 @@ target_link_libraries( Folly::folly fmt::fmt lz4::lz4 - lzo2::lzo2 zstd::zstd ZLIB::ZLIB - ${TEST_LINK_LIBS}) + ${TEST_LINK_LIBS} +) add_executable(velox_dwrf_writer_extended_test WriterExtendedTests.cpp) @@ -517,14 +552,13 @@ target_link_libraries( Folly::folly fmt::fmt lz4::lz4 - lzo2::lzo2 zstd::zstd ZLIB::ZLIB - ${TEST_LINK_LIBS}) + ${TEST_LINK_LIBS} +) add_executable(velox_dwrf_column_writer_stats_test ColumnWriterStatsTests.cpp) -add_test(velox_dwrf_column_writer_stats_test - velox_dwrf_column_writer_stats_test) +add_test(velox_dwrf_column_writer_stats_test velox_dwrf_column_writer_stats_test) target_link_libraries( velox_dwrf_column_writer_stats_test @@ -533,10 +567,10 @@ target_link_libraries( Folly::folly fmt::fmt lz4::lz4 - lzo2::lzo2 zstd::zstd ZLIB::ZLIB - ${TEST_LINK_LIBS}) + ${TEST_LINK_LIBS} +) add_executable(physical_size_aggregator_test PhysicalSizeAggregatorTest.cpp) add_test(physical_size_aggregator_test physical_size_aggregator_test) @@ -548,10 +582,10 @@ target_link_libraries( Folly::folly fmt::fmt lz4::lz4 - lzo2::lzo2 zstd::zstd ZLIB::ZLIB - ${TEST_LINK_LIBS}) + ${TEST_LINK_LIBS} +) if(VELOX_ENABLE_BENCHMARKS) add_executable(velox_dwrf_int_encoder_benchmark IntEncoderBenchmark.cpp) @@ -561,10 +595,10 @@ if(VELOX_ENABLE_BENCHMARKS) velox_memory velox_dwio_common_exception Folly::folly - Folly::follybenchmark) + Folly::follybenchmark + ) - add_executable(velox_dwrf_float_column_writer_benchmark - FloatColumnWriterBenchmark.cpp) + add_executable(velox_dwrf_float_column_writer_benchmark FloatColumnWriterBenchmark.cpp) target_link_libraries( velox_dwrf_float_column_writer_benchmark velox_vector @@ -572,11 +606,11 @@ if(VELOX_ENABLE_BENCHMARKS) velox_dwio_dwrf_writer Folly::folly Folly::follybenchmark - fmt::fmt) + fmt::fmt + ) endif() -add_executable(velox_dwio_cache_test CacheInputTest.cpp - DirectBufferedInputTest.cpp) +add_executable(velox_dwio_cache_test CacheInputTest.cpp DirectBufferedInputTest.cpp) add_test(velox_dwio_cache_test velox_dwio_cache_test) @@ -588,7 +622,7 @@ target_link_libraries( Folly::folly fmt::fmt lz4::lz4 - lzo2::lzo2 zstd::zstd ZLIB::ZLIB - ${TEST_LINK_LIBS}) + ${TEST_LINK_LIBS} +) diff --git a/velox/dwio/dwrf/test/CacheInputTest.cpp b/velox/dwio/dwrf/test/CacheInputTest.cpp index 39ac5ff92aa1..c43d6dc1f884 100644 --- a/velox/dwio/dwrf/test/CacheInputTest.cpp +++ b/velox/dwio/dwrf/test/CacheInputTest.cpp @@ -121,8 +121,9 @@ class CacheTest : public ::testing::Test { std::make_unique(cache_.get()); cache_->setVerifyHook(checkEntry); for (auto i = 0; i < kMaxStreams; ++i) { - streamIds_.push_back(std::make_unique( - i, i, 0, dwrf::StreamKind_DATA)); + streamIds_.push_back( + std::make_unique( + i, i, 0, dwrf::StreamKind_DATA)); } streamStarts_.resize(kMaxStreams + 1); streamStarts_[0] = 0; @@ -187,25 +188,23 @@ class CacheTest : public ::testing::Test { return lease.id(); } - std::shared_ptr - inputByPath(const std::string& path, uint64_t& fileId, uint64_t& groupId) { + std::shared_ptr inputByPath( + const std::string& path, + StringIdLease& fileId, + StringIdLease& groupId) { std::lock_guard l(mutex_); - StringIdLease fileLease(fileIds(), path); - fileId = fileLease.id(); - StringIdLease groupLease(fileIds(), fmt::format("group{}", fileId / 2)); - groupId = groupLease.id(); - auto it = pathToInput_.find(fileId); + fileId = StringIdLease{fileIds(), path}; + groupId = StringIdLease{fileIds(), fmt::format("group{}", fileId.id() / 2)}; + auto it = pathToInput_.find(fileId.id()); if (it != pathToInput_.end()) { return it->second; } - fileIds_.push_back(fileLease); - fileIds_.push_back(groupLease); + fileIds_.push_back(fileId); + fileIds_.push_back(groupId); // Creates an extremely large read file for test. auto stream = std::make_shared( - fileLease.id(), - 1UL << 63, - std::make_shared()); - pathToInput_[fileLease.id()] = stream; + fileId.id(), 1UL << 63, std::make_shared()); + pathToInput_[fileId.id()] = stream; return stream; } @@ -215,8 +214,8 @@ class CacheTest : public ::testing::Test { std::shared_ptr readFile, int32_t numColumns, std::shared_ptr tracker, - uint64_t fileId, - uint64_t groupId, + const StringIdLease& fileId, + const StringIdLease& groupId, int64_t offset, bool noCacheRetention, const IoStatisticsPtr& ioStats, @@ -356,11 +355,11 @@ class CacheTest : public ::testing::Test { io::ReaderOptions::kDefaultLoadQuantum, groupStats_); std::vector> stripes; - uint64_t fileId; - uint64_t groupId; + StringIdLease fileId; + StringIdLease groupId; auto readFile = inputByPath(filename, fileId, groupId); if (groupStats_) { - groupStats_->recordFile(fileId, groupId, numStripes); + groupStats_->recordFile(fileId.id(), groupId.id(), numStripes); } for (auto stripeIndex = 0; stripeIndex < numStripes; ++stripeIndex) { const auto firstPrefetchStripe = stripeIndex + stripes.size(); @@ -473,8 +472,8 @@ TEST_F(CacheTest, window) { nullptr, io::ReaderOptions::kDefaultLoadQuantum, groupStats_); - uint64_t fileId; - uint64_t groupId; + StringIdLease fileId; + StringIdLease groupId; auto file = inputByPath("test_for_window", fileId, groupId); auto input = std::make_unique( file, @@ -685,21 +684,22 @@ TEST_F(CacheTest, ssdThreads) { for (int i = 0; i < kNumThreads; ++i) { stats.push_back(std::make_shared()); fsStats.push_back(std::make_shared()); - threads.push_back(std::thread( - [i, this, threadStats = stats.back(), fsStat = fsStats.back()]() { - for (auto counter = 0; counter < 4; ++counter) { - readLoop( - fmt::format("testfile{}", i / 2), - 10, - 70, - 10, - 20, - 2, - /*noCacheRetention=*/false, - threadStats, - fsStat); - } - })); + threads.push_back( + std::thread( + [i, this, threadStats = stats.back(), fsStat = fsStats.back()]() { + for (auto counter = 0; counter < 4; ++counter) { + readLoop( + fmt::format("testfile{}", i / 2), + 10, + 70, + 10, + 20, + 2, + /*noCacheRetention=*/false, + threadStats, + fsStat); + } + })); } for (int i = 0; i < kNumThreads; ++i) { threads[i].join(); @@ -737,18 +737,19 @@ class FileWithReadAhead { bufferedInput_ = std::make_unique( file_, MetricsLog::voidLog(), - fileId_->id(), + *fileId_, cache, nullptr, - 0, + StringIdLease{}, stats, fsStats, executor, options_); auto sequential = StreamIdentifier::sequentialFile(); stream_ = bufferedInput_->enqueue(Region{0, file_->size()}, &sequential); - VELOX_CHECK(reinterpret_cast(stream_.get()) - ->testingNoCacheRetention()); + VELOX_CHECK( + reinterpret_cast(stream_.get()) + ->testingNoCacheRetention()); // Trigger load of next 4MB after reading the first 2MB of the previous 4MB // quantum. reinterpret_cast(stream_.get())->setPrefetchPct(50); @@ -790,58 +791,67 @@ TEST_F(CacheTest, readAhead) { for (int threadIndex = 0; threadIndex < kNumThreads; ++threadIndex) { stats.push_back(std::make_shared()); fsStats.push_back(std::make_shared()); - threads.push_back(std::thread([threadIndex, - this, - threadStats = stats.back(), - fsStat = fsStats.back()]() { - std::vector> files; - auto firstFileNumber = threadIndex * kFilesPerThread; - for (auto i = 0; i < kFilesPerThread; ++i) { - auto name = fmt::format("prefetch_{}", i + firstFileNumber); - files.push_back(std::make_unique( - name, cache_.get(), threadStats, fsStat, *pool_, executor_.get())); - } - std::vector totalRead(kFilesPerThread); - std::vector bytesLeft(kFilesPerThread); - for (auto counter = 0; counter < 100; ++counter) { - for (auto i = 0; i < kFilesPerThread; ++i) { - if (!files[i]) { - continue; // This set of files is finished. + threads.push_back( + std::thread([threadIndex, + this, + threadStats = stats.back(), + fsStat = fsStats.back()]() { + std::vector> files; + auto firstFileNumber = threadIndex * kFilesPerThread; + for (auto i = 0; i < kFilesPerThread; ++i) { + auto name = fmt::format("prefetch_{}", i + firstFileNumber); + files.push_back( + std::make_unique( + name, + cache_.get(), + threadStats, + fsStat, + *pool_, + executor_.get())); } - // Read from the next file. Different files advance at slightly - // different rates. - auto bytesNeeded = kMinRead + i * 1000; - while (bytesLeft[i] < bytesNeeded) { - const void* buffer; - int32_t size; - if (!files[i]->next(buffer, size)) { - // End of file. Check that a multiple of file size has been read. - EXPECT_EQ(0, totalRead[i] % FileWithReadAhead::kFileSize); - if (totalRead[i] >= 3 * FileWithReadAhead::kFileSize) { - files[i] = nullptr; - break; + std::vector totalRead(kFilesPerThread); + std::vector bytesLeft(kFilesPerThread); + for (auto counter = 0; counter < 100; ++counter) { + for (auto i = 0; i < kFilesPerThread; ++i) { + if (!files[i]) { + continue; // This set of files is finished. + } + // Read from the next file. Different files advance at slightly + // different rates. + auto bytesNeeded = kMinRead + i * 1000; + while (bytesLeft[i] < bytesNeeded) { + const void* buffer; + int32_t size; + if (!files[i]->next(buffer, size)) { + // End of file. Check that a multiple of file size has been + // read. + EXPECT_EQ(0, totalRead[i] % FileWithReadAhead::kFileSize); + if (totalRead[i] >= 3 * FileWithReadAhead::kFileSize) { + files[i] = nullptr; + break; + } + // Open a new file with a different unique name. + auto newName = fmt::format( + "prefetch_{}", + (static_cast(firstFileNumber) + i + i) * + 1000000000 + + totalRead[i]); + files[i] = std::make_unique( + newName, + cache_.get(), + threadStats, + fsStat, + *pool_, + executor_.get()); + continue; + } + totalRead[i] += size; + bytesLeft[i] += size; } - // Open a new file with a different unique name. - auto newName = fmt::format( - "prefetch_{}", - (static_cast(firstFileNumber) + i + i) * 1000000000 + - totalRead[i]); - files[i] = std::make_unique( - newName, - cache_.get(), - threadStats, - fsStat, - *pool_, - executor_.get()); - continue; + bytesLeft[i] -= bytesNeeded; } - totalRead[i] += size; - bytesLeft[i] += size; } - bytesLeft[i] -= bytesNeeded; - } - } - })); + })); } int64_t bytes = 0; int32_t count = 0; @@ -935,19 +945,19 @@ TEST_F(CacheTest, noCacheRetention) { TEST_F(CacheTest, loadQuotumTooLarge) { initializeCache(64 << 20, 256 << 20); - auto fileId = std::make_unique(fileIds(), "foo"); + StringIdLease fileId{fileIds(), "foo"}; auto readFile = - std::make_shared(fileId->id(), 10 << 20, nullptr); + std::make_shared(fileId.id(), 10 << 20, nullptr); auto readOptions = io::ReaderOptions(pool_.get()); readOptions.setLoadQuantum(9 << 20 /*9MB*/); VELOX_ASSERT_THROW( std::make_unique( readFile, MetricsLog::voidLog(), - fileId->id(), + fileId, cache_.get(), nullptr, - 0, + StringIdLease{}, nullptr, nullptr, executor_.get(), @@ -961,8 +971,8 @@ TEST_F(CacheTest, ssdReadVerification) { // 32 RAM, 256MB SSD, with checksumWrite/checksumReadVerification enabled. initializeCache(kMemoryBytes, kSsdBytes, true); - uint64_t fileId; - uint64_t groupId; + StringIdLease fileId; + StringIdLease groupId; auto file = inputByPath("test_file", fileId, groupId); auto tracker = std::make_shared( "testTracker", nullptr, io::ReaderOptions::kDefaultLoadQuantum); diff --git a/velox/dwio/dwrf/test/ColumnStatisticsBase.h b/velox/dwio/dwrf/test/ColumnStatisticsBase.h index 32b09aacfcf0..e1264cd9f27e 100644 --- a/velox/dwio/dwrf/test/ColumnStatisticsBase.h +++ b/velox/dwio/dwrf/test/ColumnStatisticsBase.h @@ -18,10 +18,14 @@ #include +#include "velox/dwio/common/Arena.h" #include "velox/dwio/common/Statistics.h" #include "velox/dwio/dwrf/writer/StatisticsBuilder.h" namespace facebook::velox::dwrf { + +using dwio::common::ArenaCreate; + class ColumnStatisticsBase { public: ColumnStatisticsBase() @@ -752,8 +756,7 @@ class ColumnStatisticsBase { if (format == DwrfFormat::kDwrf) { auto columnStatistics = - google::protobuf::Arena::CreateMessage( - arena_.get()); + ArenaCreate(arena_.get()); if (from == State::kFalse) { columnStatistics->set_hasnull(false); } else if (from == State::kTrue) { @@ -762,8 +765,8 @@ class ColumnStatisticsBase { target.merge(*buildColumnStatisticsFromProto( ColumnStatisticsWrapper(columnStatistics), context())); } else { - auto columnStatistics = google::protobuf::Arena::CreateMessage< - proto::orc::ColumnStatistics>(arena_.get()); + auto columnStatistics = + ArenaCreate(arena_.get()); if (from == State::kFalse) { columnStatistics->set_hasnull(false); } else if (from == State::kTrue) { diff --git a/velox/dwio/dwrf/test/ColumnWriterIndexTest.cpp b/velox/dwio/dwrf/test/ColumnWriterIndexTest.cpp index a8abecaed2fb..abf23b57b67b 100644 --- a/velox/dwio/dwrf/test/ColumnWriterIndexTest.cpp +++ b/velox/dwio/dwrf/test/ColumnWriterIndexTest.cpp @@ -371,9 +371,10 @@ class WriterEncodingIndexTest2 { } } proto::StripeFooter stripeFooter; + auto sfw = StripeFooterWriteWrapper(&stripeFooter); columnWriter->flush( - [&stripeFooter](uint32_t /* unused */) -> proto::ColumnEncoding& { - return *stripeFooter.add_encoding(); + [&sfw](uint32_t /* unused */) -> ColumnEncodingWriteWrapper { + return sfw.addEncoding(); }); // Simulate continue writing to next stripe, so internally buffered data @@ -821,9 +822,10 @@ class IntegerColumnWriterDirectEncodingIndexTest : public testing::Test { // *all* streams EXPECT_CALL(*mockIndexBuilderPtr, add(0, -1)).Times(positionCount); proto::StripeFooter stripeFooter; + auto sfw = StripeFooterWriteWrapper(&stripeFooter); columnWriter->flush( - [&stripeFooter](uint32_t /* unused */) -> proto::ColumnEncoding& { - return *stripeFooter.add_encoding(); + [&sfw](uint32_t /* unused */) -> ColumnEncodingWriteWrapper { + return sfw.addEncoding(); }); } else { for (size_t i = 0; i != pageCount; ++i) { @@ -847,9 +849,10 @@ class IntegerColumnWriterDirectEncodingIndexTest : public testing::Test { EXPECT_CALL(*mockIndexBuilderPtr, flush()); EXPECT_CALL(*mockIndexBuilderPtr, add(0, -1)).Times(positionCount); proto::StripeFooter stripeFooter; + auto sfw = StripeFooterWriteWrapper(&stripeFooter); columnWriter->flush( - [&stripeFooter](uint32_t /* unused */) -> proto::ColumnEncoding& { - return *stripeFooter.add_encoding(); + [&sfw](uint32_t /* unused */) -> ColumnEncodingWriteWrapper { + return sfw.addEncoding(); }); } @@ -972,9 +975,10 @@ class StringColumnWriterDictionaryEncodingIndexTest : public testing::Test { // Recording PRESENT stream starting positions for the new stripe. EXPECT_CALL(*mockIndexBuilderPtr, add(0, -1)).Times(4); proto::StripeFooter stripeFooter; + auto sfw = StripeFooterWriteWrapper(&stripeFooter); columnWriter->flush( - [&stripeFooter](uint32_t /* unused */) -> proto::ColumnEncoding& { - return *stripeFooter.add_encoding(); + [&sfw](uint32_t /* unused */) -> ColumnEncodingWriteWrapper { + return sfw.addEncoding(); }); // Simulate continue writing to next stripe, so internally buffered data @@ -1128,9 +1132,10 @@ class StringColumnWriterDirectEncodingIndexTest : public testing::Test { // *all* streams EXPECT_CALL(*mockIndexBuilderPtr, add(0, -1)).Times(positionCount); proto::StripeFooter stripeFooter; + auto sfw = StripeFooterWriteWrapper(&stripeFooter); columnWriter->flush( - [&stripeFooter](uint32_t /* unused */) -> proto::ColumnEncoding& { - return *stripeFooter.add_encoding(); + [&sfw](uint32_t /* unused */) -> ColumnEncodingWriteWrapper { + return sfw.addEncoding(); }); } else { for (size_t i = 0; i != pageCount; ++i) { @@ -1154,9 +1159,10 @@ class StringColumnWriterDirectEncodingIndexTest : public testing::Test { EXPECT_CALL(*mockIndexBuilderPtr, flush()); EXPECT_CALL(*mockIndexBuilderPtr, add(0, -1)).Times(positionCount); proto::StripeFooter stripeFooter; + auto sfw = StripeFooterWriteWrapper(&stripeFooter); columnWriter->flush( - [&stripeFooter](uint32_t /* unused */) -> proto::ColumnEncoding& { - return *stripeFooter.add_encoding(); + [&sfw](uint32_t /* unused */) -> ColumnEncodingWriteWrapper { + return sfw.addEncoding(); }); } @@ -1208,7 +1214,7 @@ class ListColumnWriterEncodingIndexTest : public testing::Test, public WriterEncodingIndexTest2 { public: ListColumnWriterEncodingIndexTest() - : WriterEncodingIndexTest2(ARRAY(REAL())){}; + : WriterEncodingIndexTest2(ARRAY(REAL())) {}; protected: static void SetUpTestCase() { diff --git a/velox/dwio/dwrf/test/ColumnWriterTest.cpp b/velox/dwio/dwrf/test/ColumnWriterTest.cpp index 2cf83e888804..88bb5d65c686 100644 --- a/velox/dwio/dwrf/test/ColumnWriterTest.cpp +++ b/velox/dwio/dwrf/test/ColumnWriterTest.cpp @@ -31,7 +31,6 @@ #include "velox/dwio/dwrf/reader/DwrfReader.h" #include "velox/dwio/dwrf/writer/Writer.h" #include "velox/type/Type.h" -#include "velox/vector/DictionaryVector.h" #include "velox/vector/tests/utils/VectorMaker.h" using namespace ::testing; @@ -214,7 +213,7 @@ VectorPtr populateBatch( auto valuesPtr = values->asMutableRange(); const size_t nulloptCount = - std::count(data.begin(), data.end(), std::nullopt); + std::count(data.cbegin(), data.cend(), std::nullopt); if (nulloptCount == 0) { size_t index = 0; for (auto val : data) { @@ -355,12 +354,13 @@ void testDataTypeWriter( for (auto stripeI = 0; stripeI < stripeCount; ++stripeI) { proto::StripeFooter sf; + auto sfw = StripeFooterWriteWrapper(&sf); for (auto strideI = 0; strideI < strideCount; ++strideI) { writer->write(batch, common::Ranges::of(0, size)); writer->createIndexEntry(); } - writer->flush([&sf](uint32_t /* unused */) -> proto::ColumnEncoding& { - return *sf.add_encoding(); + writer->flush([&sfw](uint32_t /* unused */) -> ColumnEncodingWriteWrapper { + return sfw.addEncoding(); }); TestStripeStreams streams(context, sf, rowType, pool.get()); @@ -804,6 +804,30 @@ void printRow(const std::string& title, const VectorPtr& batch) { } } +void printFlatMap(const std::string& title, const VectorPtr& batch) { + auto mv = std::dynamic_pointer_cast(batch); + if (!mv) { + VLOG(3) << "To be implemented for encoded vector"; + return; + } + + VLOG(3) << "*******" << title << "*******"; + VLOG(3) << "Size: " << mv->size() << ", Null count: " << getNullCountStr(*mv); + + VLOG(3) << "Keys:"; + for (int i = 0; i < mv->numDistinctKeys(); i++) { + VLOG(3) << " Key:[" << i << "] " << mv->distinctKeys()->toString(i); + const auto& inMaps = mv->inMaps(); + for (int j = 0; j < mv->size(); j++) { + if (inMaps[i] == nullptr || + bits::isBitSet(inMaps[i]->as(), j)) { + VLOG(3) << " Value:[" << j << "] " + << mv->mapValues()[i]->toString(j); + } + } + } +} + VectorPtr wrapInDictionary(const VectorPtr& batch, size_t stride, MemoryPool& pool) { VectorPtr ret = batch; @@ -846,10 +870,8 @@ wrapInDictionary(const VectorPtr& batch, size_t stride, MemoryPool& pool) { return ret; } -VectorPtr wrapInDictionaryRow(const VectorPtr& batch, MemoryPool& pool) { - auto row = batch->as(); - - auto size = row->size(); +VectorPtr wrapInDictionary(const VectorPtr& batch, MemoryPool& pool) { + auto size = batch->size(); auto indices = AlignedBuffer::allocate(size, &pool); auto rawIndices = indices->asMutable(); for (auto i = 0; i < size; i++) { @@ -937,48 +959,72 @@ void mapToStruct( } } -template +enum class MapWriterInputType { kMap, kStruct, kFlatMap }; + +template void testMapWriter( MemoryPool& pool, const std::vector& batches, bool useFlatMap, + MapWriterInputType inputType, bool disableDictionaryEncoding, bool testEncoded, bool printMaps = true) { const auto rowType = CppToType>>::create(); const auto dataType = rowType->childAt(0); - const auto rowTypeWithId = TypeWithId::create(rowType); + const std::shared_ptr rowTypeWithId = + TypeWithId::create(rowType); const auto dataTypeWithId = rowTypeWithId->childAt(0); const auto writerSchema = TypeWithId::create(rowType); const auto writerDataTypeWithId = writerSchema->childAt(0); VLOG(2) << "Testing map writer " << dataType->toString() << " using " << (useFlatMap ? "Flat Map" : "Regular Map") - << (useFlatMap && useStruct ? " - Struct" : ""); + << (useFlatMap && inputType == MapWriterInputType::kStruct + ? " - Struct" + : "") + << (useFlatMap && inputType == MapWriterInputType::kFlatMap + ? " - FlatMap" + : ""); const auto config = std::make_shared(); auto* pBatches = &batches; - std::vector structs; + std::vector transformedInput; std::unordered_map> structReaderContext; if (useFlatMap) { - if constexpr (useStruct) { - structs = batches; - pBatches = &structs; - std::vector uniqueKeys; - ASSERT_NO_FATAL_FAILURE(getUniqueKeys(uniqueKeys, batches)); - ASSERT_NO_FATAL_FAILURE( - (mapToStruct(pool, structs, uniqueKeys))); - - std::vector uniqueKeysString; - uniqueKeysString.reserve(uniqueKeys.size()); - std::transform( - uniqueKeys.cbegin(), - uniqueKeys.cend(), - std::back_inserter(uniqueKeysString), - [](const auto& e) { return folly::to(e); }); - ASSERT_EQ(writerDataTypeWithId->column(), 0); - config->set(Config::MAP_FLAT_COLS_STRUCT_KEYS, {uniqueKeysString}); - structReaderContext[writerDataTypeWithId->id()] = uniqueKeysString; + if (inputType == MapWriterInputType::kStruct) { + if constexpr (!CppToType::isPrimitiveType) { + } else { + transformedInput = batches; + pBatches = &transformedInput; + std::vector uniqueKeys; + ASSERT_NO_FATAL_FAILURE(getUniqueKeys(uniqueKeys, batches)); + ASSERT_NO_FATAL_FAILURE( + (mapToStruct(pool, transformedInput, uniqueKeys))); + + std::vector uniqueKeysString; + uniqueKeysString.reserve(uniqueKeys.size()); + std::transform( + uniqueKeys.cbegin(), + uniqueKeys.cend(), + std::back_inserter(uniqueKeysString), + [](const auto& e) { return folly::to(e); }); + ASSERT_EQ(writerDataTypeWithId->column(), 0); + config->set(Config::MAP_FLAT_COLS_STRUCT_KEYS, {uniqueKeysString}); + structReaderContext[writerDataTypeWithId->id()] = uniqueKeysString; + } + } else if (inputType == MapWriterInputType::kFlatMap) { + transformedInput = batches; + pBatches = &transformedInput; + VectorMaker maker(&pool); + + for (size_t i = 0; i < transformedInput.size(); i++) { + auto mapBatch = + std::dynamic_pointer_cast(transformedInput[i]); + ASSERT_TRUE(mapBatch); + + transformedInput[i] = maker.flatMapVector(mapBatch); + } } config->set(Config::FLATTEN_MAP, true); @@ -998,24 +1044,27 @@ void testMapWriter( // Each batch represents an input for a separate stripe for (auto batch : *pBatches) { - auto isStruct = useFlatMap && useStruct; + auto isStruct = useFlatMap && inputType == MapWriterInputType::kStruct; if (printMaps) { if (isStruct) { printRow("Input", batch); + } else if (inputType == MapWriterInputType::kFlatMap) { + printFlatMap("Input", batch); } else { printMap("Input", batch); } } proto::StripeFooter sf; + auto sfw = StripeFooterWriteWrapper(&sf); std::vector writtenBatches; // Write map/row for (auto strideI = 0; strideI < strideCount; ++strideI) { auto toWrite = batch; if (testEncoded) { - if (isStruct) { - toWrite = wrapInDictionaryRow(toWrite, pool); + if (inputType != MapWriterInputType::kMap) { + toWrite = wrapInDictionary(toWrite, pool); } else { toWrite = wrapInDictionary(toWrite, strideI, pool); } @@ -1025,8 +1074,8 @@ void testMapWriter( writtenBatches.push_back(toWrite); } - writer->flush([&sf](uint32_t /* unused */) -> proto::ColumnEncoding& { - return *sf.add_encoding(); + writer->flush([&sfw](uint32_t /* unused */) -> ColumnEncodingWriteWrapper { + return sfw.addEncoding(); }); auto validate = [&](bool returnFlatVector = false) { @@ -1047,6 +1096,8 @@ void testMapWriter( if (printMaps) { if (isStruct) { printRow("Result", batch); + } else if (inputType == MapWriterInputType::kFlatMap) { + printFlatMap("Result", out); } else { printMap("Result", out); } @@ -1150,19 +1201,20 @@ void testMapWriterRow( } proto::StripeFooter sf; + auto sfw = StripeFooterWriteWrapper(&sf); std::vector writtenBatches; // Write map/row auto toWrite = batch; if (testEncoded) { - toWrite = wrapInDictionaryRow(toWrite, pool); + toWrite = wrapInDictionary(toWrite, pool); } writer->write(toWrite, common::Ranges::of(0, toWrite->size())); writer->createIndexEntry(); writtenBatches.push_back(toWrite); - writer->flush([&sf](uint32_t /* unused */) -> proto::ColumnEncoding& { - return *sf.add_encoding(); + writer->flush([&sfw](uint32_t /* unused */) -> ColumnEncodingWriteWrapper { + return sfw.addEncoding(); }); auto validate = [&](bool returnFlatVector = false) { @@ -1253,25 +1305,38 @@ TEST_F(ColumnWriterTest, TestMapWriterNestedRow) { testMapWriterRowImpl>(); } -template +template void testMapWriter( MemoryPool& pool, const VectorPtr& batch, bool useFlatMap, + MapWriterInputType inputType = MapWriterInputType::kMap, bool printMaps = true) { - std::vector batches{batch, batch}; - testMapWriter( - pool, batches, useFlatMap, true, false, printMaps); + testMapWriter( + pool, {batch, batch}, useFlatMap, inputType, printMaps); +} + +template +void testMapWriter( + MemoryPool& pool, + const std::vector& batches, + bool useFlatMap, + MapWriterInputType inputType = MapWriterInputType::kMap, + bool printMaps = true) { + testMapWriter( + pool, batches, useFlatMap, inputType, true, false, printMaps); if (useFlatMap) { - testMapWriter( - pool, batches, useFlatMap, false, false, printMaps); - testMapWriter( - pool, batches, useFlatMap, true, true, printMaps); + testMapWriter( + pool, batches, useFlatMap, inputType, false, false, printMaps); + testMapWriter( + pool, batches, useFlatMap, inputType, true, true, printMaps); } } -template -void testMapWriterNumericKey(bool useFlatMap) { +template +void testMapWriterNumericKey( + bool useFlatMap, + MapWriterInputType inputType = MapWriterInputType::kMap) { using b = MapBuilder; auto pool = memory::memoryManager()->addLeafPool(); @@ -1285,14 +1350,19 @@ void testMapWriterNumericKey(bool useFlatMap) { typename b::pair{ std::numeric_limits::min(), std::numeric_limits::min()}}}); - testMapWriter(*pool, batch, useFlatMap, true); + testMapWriter(*pool, batch, useFlatMap, inputType, true); } -// Workaround to avoid issues with two template arguments when wrapped in gtest -// EXPECT macros. +// Workaround to avoid issues with two template arguments when wrapped in +// gtest EXPECT macros. template void testMapWriterNumericKeyUseStruct(bool useFlatMap) { - testMapWriterNumericKey(useFlatMap); + testMapWriterNumericKey(useFlatMap, MapWriterInputType::kStruct); +} + +template +void testMapWriterNumericKeyUseFlatMap(bool useFlatMap) { + testMapWriterNumericKey(useFlatMap, MapWriterInputType::kFlatMap); } TEST_F(ColumnWriterTest, TestMapWriterFloatKey) { @@ -1308,12 +1378,20 @@ TEST_F(ColumnWriterTest, TestMapWriterFloatKey) { /* useFlatMap */ true); }, exception::LoggedException); + + EXPECT_THROW( + { + testMapWriterNumericKeyUseFlatMap( + /* useFlatMap */ true); + }, + exception::LoggedException); } TEST_F(ColumnWriterTest, TestMapWriterInt64Key) { testMapWriterNumericKey(/* useFlatMap */ false); testMapWriterNumericKey(/* useFlatMap */ true); testMapWriterNumericKeyUseStruct(/* useFlatMap */ true); + testMapWriterNumericKeyUseFlatMap(/* useFlatMap */ true); } TEST_F(ColumnWriterTest, TestMapWriterDuplicatedInt64Key) { @@ -1333,18 +1411,21 @@ TEST_F(ColumnWriterTest, TestMapWriterInt32Key) { testMapWriterNumericKey(/* useFlatMap */ false); testMapWriterNumericKey(/* useFlatMap */ true); testMapWriterNumericKeyUseStruct(/* useFlatMap */ true); + testMapWriterNumericKeyUseFlatMap(/* useFlatMap */ true); } TEST_F(ColumnWriterTest, TestMapWriterInt16Key) { testMapWriterNumericKey(/* useFlatMap */ false); testMapWriterNumericKey(/* useFlatMap */ true); testMapWriterNumericKeyUseStruct(/* useFlatMap */ true); + testMapWriterNumericKeyUseFlatMap(/* useFlatMap */ true); } TEST_F(ColumnWriterTest, TestMapWriterInt8Key) { testMapWriterNumericKey(/* useFlatMap */ false); testMapWriterNumericKey(/* useFlatMap */ true); testMapWriterNumericKeyUseStruct(/* useFlatMap */ true); + testMapWriterNumericKeyUseFlatMap(/* useFlatMap */ true); } TEST_F(ColumnWriterTest, TestMapWriterStringKey) { @@ -1360,8 +1441,104 @@ TEST_F(ColumnWriterTest, TestMapWriterStringKey) { testMapWriter(*pool_, batch, /* useFlatMap */ false); testMapWriter(*pool_, batch, /* useFlatMap */ true); - testMapWriter( - *pool_, batch, /* useFlatMap */ true, true); + testMapWriter( + *pool_, + batch, + /* useFlatMap */ true, + MapWriterInputType::kStruct, + true); + testMapWriter( + *pool_, + batch, + /* useFlatMap */ true, + MapWriterInputType::kFlatMap, + true); +} + +void testFlatMapWriter( + const std::vector& batches, + MemoryPool* pool) { + const auto rowType = + std::dynamic_pointer_cast(batches[0]->type()); + const auto dataType = rowType->childAt(0); + const std::shared_ptr rowTypeWithId = + TypeWithId::create(rowType); + const auto dataTypeWithId = rowTypeWithId->childAt(0); + const auto writerSchema = TypeWithId::create(rowType); + const auto writerDataTypeWithId = writerSchema->childAt(0); + + const auto config = std::make_shared(); + config->set(Config::FLATTEN_MAP, true); + config->set(Config::MAP_FLAT_COLS, {writerDataTypeWithId->column()}); + config->set(Config::MAP_FLAT_DISABLE_DICT_ENCODING, true); + + WriterContext context{config, memory::memoryManager()->addRootPool()}; + context.initBuffer(); + const auto writer = BaseColumnWriter::create(context, *writerDataTypeWithId); + + for (auto batch : batches) { + writer->write(batch->childAt(0), common::Ranges::of(0, batch->size())); + } + writer->createIndexEntry(); + + proto::StripeFooter sf; + auto sfw = StripeFooterWriteWrapper(&sf); + writer->flush([&sfw](uint32_t /*unused*/) -> ColumnEncodingWriteWrapper { + return sfw.addEncoding(); + }); + + // Reading the vector out + TestStripeStreams streams(context, sf, rowType, pool); + EXPECT_CALL(streams.getMockStrideIndexProvider(), getStrideIndex()) + .WillRepeatedly(Return(0)); + auto reqType = rowTypeWithId->childAt(0); + memory::AllocationPool allocPool(pool); + StreamLabels labels(allocPool); + auto reader = + ColumnReader::build(reqType, reqType, streams, labels, nullptr, 0); + VectorPtr out; + + for (const auto& batch : batches) { + reader->next(batch->size(), out); + for (int32_t i = 0; i < batch->size(); ++i) { + ASSERT_TRUE(batch->childAt(0)->equalValueAt(out.get(), i, i)) + << "Row mismatch at index " << i; + } + } +} + +TEST_F(ColumnWriterTest, TestFlatMapKeyNotInAllBatches) { + VectorMaker maker(pool_.get()); + // Test the case where not all keys appear in all batches. + const std::vector batches{ + maker.rowVector({maker.flatMapVector( + {{{"1", "3"}, {"2", "2"}}, {{"2", "5"}, {"3", "8"}}})}), + maker.rowVector({maker.flatMapVector( + {{{"4", "3"}, {"5", "2"}}, {{"6", "5"}, {"7", "8"}}})}), + maker.rowVector({maker.flatMapVector( + {{{"1", "3"}, {"2", "2"}}, {{"2", "5"}, {"3", "8"}}})})}; + + testFlatMapWriter(batches, pool_.get()); +} + +TEST_F(ColumnWriterTest, TesFlatMapDuplicatedKey) { + const size_t size = 3; + const BufferPtr inMaps = AlignedBuffer::allocate(size, pool_.get()); + bits::fillBits(inMaps->asMutable(), 1, size, pool_.get()); + + VectorMaker maker(pool_.get()); + RowVectorPtr batch = maker.rowVector({std::make_shared( + pool_.get(), + MAP(VARCHAR(), VARCHAR()), + nullptr, + 3, + maker.flatVector({"2", "2"}), + std::vector{ + maker.flatVector({"1", "2", "3"}), maker.flatVector({"4", "5", "6"})}, + std::vector{inMaps, inMaps})}); + + VELOX_ASSERT_THROW( + testFlatMapWriter({batch}, pool_.get()), "Duplicated key in map: 2"); } TEST_F(ColumnWriterTest, TestMapWriterDuplicatedStringKey) { @@ -1434,7 +1611,7 @@ TEST_F(ColumnWriterTest, TestMapWriterMixedBatchTypeHandling) { *pool_, batches, /* useFlatMap */ true, - + MapWriterInputType::kMap, true, false)), ""); @@ -1453,8 +1630,18 @@ TEST_F(ColumnWriterTest, TestMapWriterBinaryKey) { testMapWriter(*pool_, batch, /* useFlatMap */ false); testMapWriter(*pool_, batch, /* useFlatMap */ true); - testMapWriter( - *pool_, batch, /* useFlatMap */ true, true); + testMapWriter( + *pool_, + batch, + /* useFlatMap */ true, + MapWriterInputType::kStruct, + true); + testMapWriter( + *pool_, + batch, + /* useFlatMap */ true, + MapWriterInputType::kFlatMap, + true); } template @@ -1504,12 +1691,14 @@ TEST_F(ColumnWriterTest, TestMapWriterDifferentStripeBatches) { *pool_, batches, /* useFlatMap */ false, + MapWriterInputType::kMap, false, false); testMapWriter( *pool_, batches, /* useFlatMap */ true, + MapWriterInputType::kMap, false, false); } @@ -2179,6 +2368,7 @@ struct IntegerColumnWriterTypedTestCase { for (size_t i = 0; i != flushCount; ++i) { proto::StripeFooter stripeFooter; + auto sfw = StripeFooterWriteWrapper(&stripeFooter); for (size_t j = 0; j != repetitionCount; ++j) { columnWriter->write(batch, common::Ranges::of(0, batch->size())); postProcess(*columnWriter, i, j); @@ -2186,8 +2376,8 @@ struct IntegerColumnWriterTypedTestCase { } // We only flush once per stripe. columnWriter->flush( - [&stripeFooter](uint32_t /* unused */) -> proto::ColumnEncoding& { - return *stripeFooter.add_encoding(); + [&sfw](uint32_t /* unused */) -> ColumnEncodingWriteWrapper { + return sfw.addEncoding(); }); // Read and verify. @@ -3339,7 +3529,7 @@ struct StringColumnWriterTestCase { postProcess{postProcess}, repetitionCount{repetitionCount}, flushCount{flushCount}, - type{CppToType::create()} {} + type{CppToType::create()} {} virtual ~StringColumnWriterTestCase() = default; @@ -3413,6 +3603,7 @@ struct StringColumnWriterTestCase { for (size_t i = 0; i != flushCount; ++i) { proto::StripeFooter stripeFooter; + auto sfw = StripeFooterWriteWrapper(&stripeFooter); // Write Stride for (size_t j = 0; j != repetitionCount; ++j) { // TODO: break the batch into multiple strides. @@ -3423,8 +3614,8 @@ struct StringColumnWriterTestCase { // Flush when all strides are written (once per stripe). columnWriter->flush( - [&stripeFooter](uint32_t /* unused */) -> proto::ColumnEncoding& { - return *stripeFooter.add_encoding(); + [&sfw](uint32_t /* unused */) -> ColumnEncodingWriteWrapper { + return sfw.addEncoding(); }); // Read and verify. @@ -4253,8 +4444,9 @@ TEST_F(ColumnWriterTest, IntDictWriterDirectValueOverflow) { writer->write(vector, common::Ranges::of(0, size)); writer->createIndexEntry(); proto::StripeFooter sf; - writer->flush([&sf](auto /* unused */) -> proto::ColumnEncoding& { - return *sf.add_encoding(); + auto sfw = StripeFooterWriteWrapper(&sf); + writer->flush([&sfw](auto /* unused */) -> ColumnEncodingWriteWrapper { + return sfw.addEncoding(); }); auto& enc = sf.encoding(0); ASSERT_EQ(enc.kind(), proto::ColumnEncoding_Kind_DICTIONARY); @@ -4298,8 +4490,9 @@ TEST_F(ColumnWriterTest, ShortDictWriterDictValueOverflow) { writer->write(vector, common::Ranges::of(0, size)); writer->createIndexEntry(); proto::StripeFooter sf; - writer->flush([&sf](auto /* unused */) -> proto::ColumnEncoding& { - return *sf.add_encoding(); + auto sfw = StripeFooterWriteWrapper(&sf); + writer->flush([&sfw](uint32_t /* unused */) -> ColumnEncodingWriteWrapper { + return sfw.addEncoding(); }); auto& enc = sf.encoding(0); ASSERT_EQ(enc.kind(), proto::ColumnEncoding_Kind_DICTIONARY); @@ -4339,8 +4532,9 @@ TEST_F(ColumnWriterTest, RemovePresentStream) { writer->write(vector, common::Ranges::of(0, size)); writer->createIndexEntry(); proto::StripeFooter sf; - writer->flush([&sf](auto /* unused */) -> proto::ColumnEncoding& { - return *sf.add_encoding(); + auto sfw = StripeFooterWriteWrapper(&sf); + writer->flush([&sfw](uint32_t /* unused */) -> ColumnEncodingWriteWrapper { + return sfw.addEncoding(); }); // get data stream @@ -4377,8 +4571,9 @@ TEST_F(ColumnWriterTest, ColumnIdInStream) { writer->write(vector, common::Ranges::of(0, size)); writer->createIndexEntry(); proto::StripeFooter sf; - writer->flush([&sf](auto /* unused */) -> proto::ColumnEncoding& { - return *sf.add_encoding(); + auto sfw = StripeFooterWriteWrapper(&sf); + writer->flush([&sfw](uint32_t /* unused */) -> ColumnEncodingWriteWrapper { + return sfw.addEncoding(); }); // get data stream @@ -4506,8 +4701,9 @@ struct DictColumnWriterTestCase { writer->createIndexEntry(); proto::StripeFooter sf; - writer->flush([&sf](uint32_t /* unused */) -> proto::ColumnEncoding& { - return *sf.add_encoding(); + auto sfw = StripeFooterWriteWrapper(&sf); + writer->flush([&sfw](uint32_t /* unused */) -> ColumnEncodingWriteWrapper { + return sfw.addEncoding(); }); // Reading the vector out diff --git a/velox/dwio/dwrf/test/CompressionTest.cpp b/velox/dwio/dwrf/test/CompressionTest.cpp index 33ba45cbf20a..9343af57e6dc 100644 --- a/velox/dwio/dwrf/test/CompressionTest.cpp +++ b/velox/dwio/dwrf/test/CompressionTest.cpp @@ -291,6 +291,8 @@ VELOX_INSTANTIATE_TEST_SUITE_P( Values( std::make_tuple(CompressionKind_ZLIB, nullptr, nullptr), std::make_tuple(CompressionKind_ZLIB, &testEncrypter, &testDecrypter), + std::make_tuple(CompressionKind_GZIP, nullptr, nullptr), + std::make_tuple(CompressionKind_GZIP, &testEncrypter, &testDecrypter), std::make_tuple(CompressionKind_ZSTD, nullptr, nullptr), std::make_tuple(CompressionKind_ZSTD, &testEncrypter, &testDecrypter), std::make_tuple(CompressionKind_NONE, nullptr, nullptr), @@ -399,10 +401,11 @@ TEST_P(CompressionTest, getCompressionBufferOOM) { {true, true}, {true, false}, {false, true}, {false, false}}; for (const auto& testData : testSettings) { - SCOPED_TRACE(fmt::format( - "{} compression {}", - testData.debugString(), - compressionKindToString(kind_))); + SCOPED_TRACE( + fmt::format( + "{} compression {}", + testData.debugString(), + compressionKindToString(kind_))); auto config = std::make_shared(); config->set(Config::COMPRESSION, kind_); @@ -456,6 +459,8 @@ VELOX_INSTANTIATE_TEST_SUITE_P( Values( std::make_tuple(CompressionKind_ZLIB, nullptr), std::make_tuple(CompressionKind_ZLIB, &testEncrypter), + std::make_tuple(CompressionKind_GZIP, nullptr), + std::make_tuple(CompressionKind_GZIP, &testEncrypter), std::make_tuple(CompressionKind_ZSTD, nullptr), std::make_tuple(CompressionKind_ZSTD, &testEncrypter), std::make_tuple(CompressionKind_NONE, &testEncrypter))); diff --git a/velox/dwio/dwrf/test/DataBufferHolderTests.cpp b/velox/dwio/dwrf/test/DataBufferHolderTests.cpp index 6b53cbfab22e..d4b5f6e1238a 100644 --- a/velox/dwio/dwrf/test/DataBufferHolderTests.cpp +++ b/velox/dwio/dwrf/test/DataBufferHolderTests.cpp @@ -36,9 +36,15 @@ TEST_F(DataBufferHolderTest, InputCheck) { VELOX_ASSERT_THROW((DataBufferHolder{*pool_, 1024, 2048}), ""); VELOX_ASSERT_THROW((DataBufferHolder{*pool_, 1024, 1024, 1.1f}), ""); - { DataBufferHolder holder{*pool_, 1024}; } - { DataBufferHolder holder{*pool_, 1024, 512}; } - { DataBufferHolder holder{*pool_, 1024, 512, 3.0f}; } + { + DataBufferHolder holder{*pool_, 1024}; + } + { + DataBufferHolder holder{*pool_, 1024, 512}; + } + { + DataBufferHolder holder{*pool_, 1024, 512, 3.0f}; + } } TEST_F(DataBufferHolderTest, TakeAndGetBuffer) { diff --git a/velox/dwio/dwrf/test/DecryptionTests.cpp b/velox/dwio/dwrf/test/DecryptionTests.cpp index 6a19a41c4846..7cff7a4acb26 100644 --- a/velox/dwio/dwrf/test/DecryptionTests.cpp +++ b/velox/dwio/dwrf/test/DecryptionTests.cpp @@ -32,7 +32,8 @@ TEST(Decryption, NotEncrypted) { HiveTypeParser parser; auto type = parser.parse("struct"); proto::Footer footer; - ProtoUtils::writeType(*type, footer); + auto footerWrapper = FooterWriteWrapper(&footer); + ProtoUtils::writeType(*type, footerWrapper); TestDecrypterFactory factory; auto handler = DecryptionHandler::create(footer, &factory); ASSERT_FALSE(handler->isEncrypted()); @@ -42,7 +43,8 @@ TEST(Decryption, NoKeyProvider) { HiveTypeParser parser; auto type = parser.parse("struct"); proto::Footer footer; - ProtoUtils::writeType(*type, footer); + auto footerWrapper = FooterWriteWrapper(&footer); + ProtoUtils::writeType(*type, footerWrapper); footer.mutable_encryption(); TestDecrypterFactory factory; ASSERT_THROW( @@ -53,7 +55,8 @@ TEST(Decryption, EmptyGroup) { HiveTypeParser parser; auto type = parser.parse("struct"); proto::Footer footer; - ProtoUtils::writeType(*type, footer); + auto footerWrapper = FooterWriteWrapper(&footer); + ProtoUtils::writeType(*type, footerWrapper); auto enc = footer.mutable_encryption(); enc->set_keyprovider(proto::Encryption_KeyProvider_UNKNOWN); TestDecrypterFactory factory; @@ -65,7 +68,8 @@ TEST(Decryption, EmptyNodes) { HiveTypeParser parser; auto type = parser.parse("struct"); proto::Footer footer; - ProtoUtils::writeType(*type, footer); + auto footerWrapper = FooterWriteWrapper(&footer); + ProtoUtils::writeType(*type, footerWrapper); auto enc = footer.mutable_encryption(); enc->set_keyprovider(proto::Encryption_KeyProvider_UNKNOWN); auto group = enc->add_encryptiongroups(); @@ -79,7 +83,8 @@ TEST(Decryption, StatsMismatch) { HiveTypeParser parser; auto type = parser.parse("struct"); proto::Footer footer; - ProtoUtils::writeType(*type, footer); + auto footerWrapper = FooterWriteWrapper(&footer); + ProtoUtils::writeType(*type, footerWrapper); auto enc = footer.mutable_encryption(); enc->set_keyprovider(proto::Encryption_KeyProvider_UNKNOWN); auto group = enc->add_encryptiongroups(); @@ -96,7 +101,8 @@ TEST(Decryption, KeyExistenceMismatch) { HiveTypeParser parser; auto type = parser.parse("struct"); proto::Footer footer; - ProtoUtils::writeType(*type, footer); + auto footerWrapper = FooterWriteWrapper(&footer); + ProtoUtils::writeType(*type, footerWrapper); auto enc = footer.mutable_encryption(); enc->set_keyprovider(proto::Encryption_KeyProvider_UNKNOWN); for (size_t i = 0; i < 2; ++i) { @@ -116,7 +122,8 @@ TEST(Decryption, ReuseStripeKey) { HiveTypeParser parser; auto type = parser.parse("struct"); proto::Footer footer; - ProtoUtils::writeType(*type, footer); + auto footerWrapper = FooterWriteWrapper(&footer); + ProtoUtils::writeType(*type, footerWrapper); auto enc = footer.mutable_encryption(); enc->set_keyprovider(proto::Encryption_KeyProvider_UNKNOWN); auto group = enc->add_encryptiongroups(); @@ -135,7 +142,8 @@ TEST(Decryption, StripeKeyMismatch) { HiveTypeParser parser; auto type = parser.parse("struct"); proto::Footer footer; - ProtoUtils::writeType(*type, footer); + auto footerWrapper = FooterWriteWrapper(&footer); + ProtoUtils::writeType(*type, footerWrapper); auto enc = footer.mutable_encryption(); enc->set_keyprovider(proto::Encryption_KeyProvider_UNKNOWN); auto group = enc->add_encryptiongroups(); @@ -153,7 +161,8 @@ TEST(Decryption, Basic) { HiveTypeParser parser; auto type = parser.parse("struct"); proto::Footer footer; - ProtoUtils::writeType(*type, footer); + auto footerWrapper = FooterWriteWrapper(&footer); + ProtoUtils::writeType(*type, footerWrapper); auto enc = footer.mutable_encryption(); enc->set_keyprovider(proto::Encryption_KeyProvider_UNKNOWN); for (auto i = 0; i < 5; ++i) { @@ -183,7 +192,8 @@ TEST(Decryption, NestedType) { auto type = parser.parse( "struct>,c:struct,d:array>"); proto::Footer footer; - ProtoUtils::writeType(*type, footer); + auto footerWrapper = FooterWriteWrapper(&footer); + ProtoUtils::writeType(*type, footerWrapper); auto enc = footer.mutable_encryption(); enc->set_keyprovider(proto::Encryption_KeyProvider_UNKNOWN); @@ -222,7 +232,8 @@ TEST(Decryption, RootNode) { HiveTypeParser parser; auto type = parser.parse("struct"); proto::Footer footer; - ProtoUtils::writeType(*type, footer); + auto footerWrapper = FooterWriteWrapper(&footer); + ProtoUtils::writeType(*type, footerWrapper); auto enc = footer.mutable_encryption(); enc->set_keyprovider(proto::Encryption_KeyProvider_UNKNOWN); auto group = enc->add_encryptiongroups(); @@ -238,7 +249,8 @@ TEST(Decryption, GroupOverlap) { HiveTypeParser parser; auto type = parser.parse("struct>"); proto::Footer footer; - ProtoUtils::writeType(*type, footer); + auto footerWrapper = FooterWriteWrapper(&footer); + ProtoUtils::writeType(*type, footerWrapper); auto enc = footer.mutable_encryption(); enc->set_keyprovider(proto::Encryption_KeyProvider_UNKNOWN); diff --git a/velox/dwio/dwrf/test/DirectBufferedInputTest.cpp b/velox/dwio/dwrf/test/DirectBufferedInputTest.cpp index 076e871ce88f..e3ddd5b457d9 100644 --- a/velox/dwio/dwrf/test/DirectBufferedInputTest.cpp +++ b/velox/dwio/dwrf/test/DirectBufferedInputTest.cpp @@ -71,9 +71,9 @@ class DirectBufferedInputTest : public testing::Test { return std::make_unique( file_, dwio::common::MetricsLog::voidLog(), - 1, + StringIdLease{}, tracker_, - 2, + StringIdLease{}, ioStats_, fsStats, executor_.get(), diff --git a/velox/dwio/dwrf/test/E2EFilterTest.cpp b/velox/dwio/dwrf/test/E2EFilterTest.cpp index 43b67e91e550..d9ebb9f52e62 100644 --- a/velox/dwio/dwrf/test/E2EFilterTest.cpp +++ b/velox/dwio/dwrf/test/E2EFilterTest.cpp @@ -241,6 +241,62 @@ TEST_F(E2EFilterTest, floatAndDouble) { false); } +TEST_F(E2EFilterTest, DISABLED_shortDecimal) { + // ORC write functionality is not yet supported. Enable this test once it + // becomes available and set the file format to ORC at that time. + // options.format = DwrfFormat::kOrc; + const std::unordered_map types = { + {"shortdecimal_val:decimal(8, 5)", DECIMAL(8, 5)}, + {"shortdecimal_val:decimal(10, 5)", DECIMAL(10, 5)}, + {"shortdecimal_val:decimal(17, 5)", DECIMAL(17, 5)}}; + + for (const auto& pair : types) { + testWithTypes( + pair.first, + [&]() { + makeIntDistribution( + "shortdecimal_val", + 10, // min + 100, // max + 22, // repeats + 19, // rareFrequency + -999, // rareMin + 30000, // rareMax + true); + }, + false, + {"shortdecimal_val"}, + 20); + } +} + +TEST_F(E2EFilterTest, DISABLED_longDecimal) { + // ORC write functionality is not yet supported. Enable this test once it + // becomes available and set the file format to ORC at that time. + // options.format = DwrfFormat::kOrc; + const std::unordered_map types = { + {"longdecimal_val:decimal(30, 10)", DECIMAL(30, 10)}, + {"longdecimal_val:decimal(37, 15)", DECIMAL(37, 15)}}; + for (const auto& pair : types) { + testWithTypes( + pair.first, + [&]() { + makeIntDistribution( + "longdecimal_val", + 10, // min + 100, // max + 22, // repeats + 19, // rareFrequency + -999, // rareMin + 30000, // rareMax + true); + }, + false, + {"longdecimal_val"}, + 20); + } +} + TEST_F(E2EFilterTest, stringDirect) { testutil::TestValue::enable(); bool coverage[2][2]{}; diff --git a/velox/dwio/dwrf/test/E2EReaderTest.cpp b/velox/dwio/dwrf/test/E2EReaderTest.cpp index 6b696392910e..9ba378ea9da6 100644 --- a/velox/dwio/dwrf/test/E2EReaderTest.cpp +++ b/velox/dwio/dwrf/test/E2EReaderTest.cpp @@ -73,11 +73,11 @@ class ValueTypes { } auto begin() const { - return values_.begin(); + return values_.cbegin(); } auto end() const { - return values_.end(); + return values_.cend(); } const std::shared_ptr& decodingExecutor() const { @@ -244,127 +244,159 @@ TEST_P(E2EReaderTest, SharedDictionaryFlatmapReadAsStruct) { INSTANTIATE_TEST_SUITE_P( SingleTypesSerialMap, E2EReaderTest, - ValuesIn(std::vector{ - ValueTypes(Decoding::SERIAL, FlatMapAs::MAP, {"tinyint"}), - ValueTypes(Decoding::SERIAL, FlatMapAs::MAP, {"smallint"}), - ValueTypes(Decoding::SERIAL, FlatMapAs::MAP, {"integer"}), - ValueTypes(Decoding::SERIAL, FlatMapAs::MAP, {"bigint"}), - ValueTypes(Decoding::SERIAL, FlatMapAs::MAP, {"string"}), - ValueTypes(Decoding::SERIAL, FlatMapAs::MAP, {"array"}), - ValueTypes(Decoding::SERIAL, FlatMapAs::MAP, {"array"}), - ValueTypes(Decoding::SERIAL, FlatMapAs::MAP, {"array"}), - ValueTypes(Decoding::SERIAL, FlatMapAs::MAP, {"array"}), - ValueTypes(Decoding::SERIAL, FlatMapAs::MAP, {"array"})})); + ValuesIn( + std::vector{ + ValueTypes(Decoding::SERIAL, FlatMapAs::MAP, {"tinyint"}), + ValueTypes(Decoding::SERIAL, FlatMapAs::MAP, {"smallint"}), + ValueTypes(Decoding::SERIAL, FlatMapAs::MAP, {"integer"}), + ValueTypes(Decoding::SERIAL, FlatMapAs::MAP, {"bigint"}), + ValueTypes(Decoding::SERIAL, FlatMapAs::MAP, {"string"}), + ValueTypes(Decoding::SERIAL, FlatMapAs::MAP, {"array"}), + ValueTypes(Decoding::SERIAL, FlatMapAs::MAP, {"array"}), + ValueTypes(Decoding::SERIAL, FlatMapAs::MAP, {"array"}), + ValueTypes(Decoding::SERIAL, FlatMapAs::MAP, {"array"}), + ValueTypes(Decoding::SERIAL, FlatMapAs::MAP, {"array"})})); INSTANTIATE_TEST_SUITE_P( SingleTypesSerialStruct, E2EReaderTest, - ValuesIn(std::vector{ - ValueTypes(Decoding::SERIAL, FlatMapAs::STRUCT, {"tinyint"}), - ValueTypes(Decoding::SERIAL, FlatMapAs::STRUCT, {"smallint"}), - ValueTypes(Decoding::SERIAL, FlatMapAs::STRUCT, {"integer"}), - ValueTypes(Decoding::SERIAL, FlatMapAs::STRUCT, {"bigint"}), - ValueTypes(Decoding::SERIAL, FlatMapAs::STRUCT, {"string"}), - ValueTypes(Decoding::SERIAL, FlatMapAs::STRUCT, {"array"}), - ValueTypes(Decoding::SERIAL, FlatMapAs::STRUCT, {"array"}), - ValueTypes(Decoding::SERIAL, FlatMapAs::STRUCT, {"array"}), - ValueTypes(Decoding::SERIAL, FlatMapAs::STRUCT, {"array"}), - ValueTypes(Decoding::SERIAL, FlatMapAs::STRUCT, {"array"})})); + ValuesIn( + std::vector{ + ValueTypes(Decoding::SERIAL, FlatMapAs::STRUCT, {"tinyint"}), + ValueTypes(Decoding::SERIAL, FlatMapAs::STRUCT, {"smallint"}), + ValueTypes(Decoding::SERIAL, FlatMapAs::STRUCT, {"integer"}), + ValueTypes(Decoding::SERIAL, FlatMapAs::STRUCT, {"bigint"}), + ValueTypes(Decoding::SERIAL, FlatMapAs::STRUCT, {"string"}), + ValueTypes(Decoding::SERIAL, FlatMapAs::STRUCT, {"array"}), + ValueTypes( + Decoding::SERIAL, + FlatMapAs::STRUCT, + {"array"}), + ValueTypes(Decoding::SERIAL, FlatMapAs::STRUCT, {"array"}), + ValueTypes(Decoding::SERIAL, FlatMapAs::STRUCT, {"array"}), + ValueTypes( + Decoding::SERIAL, + FlatMapAs::STRUCT, + {"array"})})); INSTANTIATE_TEST_SUITE_P( AllTypesSerialMap, E2EReaderTest, - ValuesIn(std::vector{ValueTypes( - Decoding::SERIAL, - FlatMapAs::MAP, - {"tinyint", - "smallint", - "integer", - "bigint", - "string", - "array", - "array", - "array", - "array", - "array"})})); + ValuesIn( + std::vector{ValueTypes( + Decoding::SERIAL, + FlatMapAs::MAP, + {"tinyint", + "smallint", + "integer", + "bigint", + "string", + "array", + "array", + "array", + "array", + "array"})})); INSTANTIATE_TEST_SUITE_P( AllTypesSerialStruct, E2EReaderTest, - ValuesIn(std::vector{ValueTypes( - Decoding::SERIAL, - FlatMapAs::STRUCT, - {"tinyint", - "smallint", - "integer", - "bigint", - "string", - "array", - "array", - "array", - "array", - "array"})})); + ValuesIn( + std::vector{ValueTypes( + Decoding::SERIAL, + FlatMapAs::STRUCT, + {"tinyint", + "smallint", + "integer", + "bigint", + "string", + "array", + "array", + "array", + "array", + "array"})})); INSTANTIATE_TEST_SUITE_P( SingleTypesParallelMap, E2EReaderTest, - ValuesIn(std::vector{ - ValueTypes(Decoding::PARALLEL, FlatMapAs::MAP, {"tinyint"}), - ValueTypes(Decoding::PARALLEL, FlatMapAs::MAP, {"smallint"}), - ValueTypes(Decoding::PARALLEL, FlatMapAs::MAP, {"integer"}), - ValueTypes(Decoding::PARALLEL, FlatMapAs::MAP, {"bigint"}), - ValueTypes(Decoding::PARALLEL, FlatMapAs::MAP, {"string"}), - ValueTypes(Decoding::PARALLEL, FlatMapAs::MAP, {"array"}), - ValueTypes(Decoding::PARALLEL, FlatMapAs::MAP, {"array"}), - ValueTypes(Decoding::PARALLEL, FlatMapAs::MAP, {"array"}), - ValueTypes(Decoding::PARALLEL, FlatMapAs::MAP, {"array"}), - ValueTypes(Decoding::PARALLEL, FlatMapAs::MAP, {"array"})})); + ValuesIn( + std::vector{ + ValueTypes(Decoding::PARALLEL, FlatMapAs::MAP, {"tinyint"}), + ValueTypes(Decoding::PARALLEL, FlatMapAs::MAP, {"smallint"}), + ValueTypes(Decoding::PARALLEL, FlatMapAs::MAP, {"integer"}), + ValueTypes(Decoding::PARALLEL, FlatMapAs::MAP, {"bigint"}), + ValueTypes(Decoding::PARALLEL, FlatMapAs::MAP, {"string"}), + ValueTypes(Decoding::PARALLEL, FlatMapAs::MAP, {"array"}), + ValueTypes(Decoding::PARALLEL, FlatMapAs::MAP, {"array"}), + ValueTypes(Decoding::PARALLEL, FlatMapAs::MAP, {"array"}), + ValueTypes(Decoding::PARALLEL, FlatMapAs::MAP, {"array"}), + ValueTypes( + Decoding::PARALLEL, + FlatMapAs::MAP, + {"array"})})); INSTANTIATE_TEST_SUITE_P( SingleTypesParallelStruct, E2EReaderTest, - ValuesIn(std::vector{ - ValueTypes(Decoding::PARALLEL, FlatMapAs::STRUCT, {"tinyint"}), - ValueTypes(Decoding::PARALLEL, FlatMapAs::STRUCT, {"smallint"}), - ValueTypes(Decoding::PARALLEL, FlatMapAs::STRUCT, {"integer"}), - ValueTypes(Decoding::PARALLEL, FlatMapAs::STRUCT, {"bigint"}), - ValueTypes(Decoding::PARALLEL, FlatMapAs::STRUCT, {"string"}), - ValueTypes(Decoding::PARALLEL, FlatMapAs::STRUCT, {"array"}), - ValueTypes(Decoding::PARALLEL, FlatMapAs::STRUCT, {"array"}), - ValueTypes(Decoding::PARALLEL, FlatMapAs::STRUCT, {"array"}), - ValueTypes(Decoding::PARALLEL, FlatMapAs::STRUCT, {"array"}), - ValueTypes(Decoding::PARALLEL, FlatMapAs::STRUCT, {"array"})})); + ValuesIn( + std::vector{ + ValueTypes(Decoding::PARALLEL, FlatMapAs::STRUCT, {"tinyint"}), + ValueTypes(Decoding::PARALLEL, FlatMapAs::STRUCT, {"smallint"}), + ValueTypes(Decoding::PARALLEL, FlatMapAs::STRUCT, {"integer"}), + ValueTypes(Decoding::PARALLEL, FlatMapAs::STRUCT, {"bigint"}), + ValueTypes(Decoding::PARALLEL, FlatMapAs::STRUCT, {"string"}), + ValueTypes( + Decoding::PARALLEL, + FlatMapAs::STRUCT, + {"array"}), + ValueTypes( + Decoding::PARALLEL, + FlatMapAs::STRUCT, + {"array"}), + ValueTypes( + Decoding::PARALLEL, + FlatMapAs::STRUCT, + {"array"}), + ValueTypes( + Decoding::PARALLEL, + FlatMapAs::STRUCT, + {"array"}), + ValueTypes( + Decoding::PARALLEL, + FlatMapAs::STRUCT, + {"array"})})); INSTANTIATE_TEST_SUITE_P( AllTypesParallelMap, E2EReaderTest, - ValuesIn(std::vector{ValueTypes( - Decoding::PARALLEL, - FlatMapAs::MAP, - {"tinyint", - "smallint", - "integer", - "bigint", - "string", - "array", - "array", - "array", - "array", - "array"})})); + ValuesIn( + std::vector{ValueTypes( + Decoding::PARALLEL, + FlatMapAs::MAP, + {"tinyint", + "smallint", + "integer", + "bigint", + "string", + "array", + "array", + "array", + "array", + "array"})})); INSTANTIATE_TEST_SUITE_P( AllTypesParallelStruct, E2EReaderTest, - ValuesIn(std::vector{ValueTypes( - Decoding::PARALLEL, - FlatMapAs::STRUCT, - {"tinyint", - "smallint", - "integer", - "bigint", - "string", - "array", - "array", - "array", - "array", - "array"})})); + ValuesIn( + std::vector{ValueTypes( + Decoding::PARALLEL, + FlatMapAs::STRUCT, + {"tinyint", + "smallint", + "integer", + "bigint", + "string", + "array", + "array", + "array", + "array", + "array"})})); diff --git a/velox/dwio/dwrf/test/E2EWriterTest.cpp b/velox/dwio/dwrf/test/E2EWriterTest.cpp index e069d56c8644..771a8078f707 100644 --- a/velox/dwio/dwrf/test/E2EWriterTest.cpp +++ b/velox/dwio/dwrf/test/E2EWriterTest.cpp @@ -261,7 +261,8 @@ class E2EWriterTest : public testing::Test { 0, 0, writerFlushThresholdSize, - "none"); + "none", + 0); } std::shared_ptr rootPool_; @@ -980,7 +981,7 @@ TEST_F(E2EWriterTest, OversizeRows) { config, /*flushPolicyFactory=*/nullptr, /*layoutPlannerFactory=*/nullptr, - /*memoryBudget=*/std::numeric_limits::max(), + /*writerMemoryCap=*/std::numeric_limits::max(), false); } @@ -1012,7 +1013,7 @@ TEST_F(E2EWriterTest, OversizeBatches) { config, /*flushPolicyFactory=*/nullptr, /*layoutPlannerFactory=*/nullptr, - /*memoryBudget=*/std::numeric_limits::max(), + /*writerMemoryCap=*/std::numeric_limits::max(), false); // Test splitting multiple huge batches. @@ -1028,7 +1029,7 @@ TEST_F(E2EWriterTest, OversizeBatches) { config, /*flushPolicyFactory=*/nullptr, /*layoutPlannerFactory=*/nullptr, - /*memoryBudget=*/std::numeric_limits::max(), + /*writerMemoryCap=*/std::numeric_limits::max(), false); } @@ -1088,7 +1089,7 @@ TEST_F(E2EWriterTest, OverflowLengthIncrements) { config, /*flushPolicyFactory=*/nullptr, /*layoutPlannerFactory=*/nullptr, - /*memoryBudget=*/std::numeric_limits::max(), + /*writerMemoryCap=*/std::numeric_limits::max(), false); } @@ -1634,6 +1635,76 @@ TEST_F(E2EWriterTest, fuzzFlatmap) { } } +TEST_F(E2EWriterTest, fuzzFlatmapWithFlatmapInput) { + auto pool = memory::memoryManager()->addLeafPool(); + auto type = ROW({ + {"flatmap1", MAP(INTEGER(), REAL())}, + {"flatmap2", MAP(VARCHAR(), ARRAY(REAL()))}, + {"flatmap3", MAP(INTEGER(), MAP(INTEGER(), REAL()))}, + }); + auto config = std::make_shared(); + config->set(dwrf::Config::FLATTEN_MAP, true); + config->set(dwrf::Config::MAP_FLAT_COLS, {0, 1, 2}); + auto seed = folly::Random::rand32(); + LOG(INFO) << "seed: " << seed; + std::mt19937 rng{seed}; + + // Small batches creates more edge cases. + size_t batchSize = 10; + VectorFuzzer fuzzer( + { + .vectorSize = batchSize, + .nullRatio = 0, + .stringLength = 20, + .stringVariableLength = true, + .containerLength = 5, + .containerVariableLength = true, + }, + pool.get(), + seed); + + auto genFlatMap = [&](auto type) { + auto& mapType = type->asMap(); + + std::vector inMaps; + std::vector values; + for (size_t key = 0; key < batchSize; ++key) { + inMaps.push_back(AlignedBuffer::allocate(batchSize, pool.get(), 0)); + auto* rawInMaps = inMaps.back()->asMutable(); + for (size_t row = 0; row < batchSize; ++row) { + bits::setBit(rawInMaps, row, folly::Random::oneIn(2, rng)); + } + + values.push_back(fuzzer.fuzz(mapType.valueType())); + } + + return std::make_shared( + pool.get(), + type, + nullptr, + batchSize, + createKeys(mapType.keyType(), *pool, rng, batchSize, batchSize), + std::move(values), + std::move(inMaps)); + }; + + auto gen = [&]() { + std::vector children(type->size()); + for (auto i = 0; i < type->size(); ++i) { + children[i] = genFlatMap(type->childAt(i)); + } + + return std::make_shared( + pool.get(), type, nullptr, batchSize, std::move(children)); + }; + + auto iterations = 20; + auto batches = 20; + for (auto i = 0; i < iterations; ++i) { + testWriter(*pool, type, batches, gen, config); + } +} + TEST_F(E2EWriterTest, memoryConfigError) { const auto type = ROW( {{"int_val", INTEGER()}, @@ -2070,9 +2141,10 @@ DEBUG_ONLY_TEST_F(E2EWriterTest, memoryReclaimThreshold) { } const std::vector writerFlushThresholdSizes = {0, 1L << 30}; for (uint64_t writerFlushThresholdSize : writerFlushThresholdSizes) { - SCOPED_TRACE(fmt::format( - "writerFlushThresholdSize {}", - succinctBytes(writerFlushThresholdSize))); + SCOPED_TRACE( + fmt::format( + "writerFlushThresholdSize {}", + succinctBytes(writerFlushThresholdSize))); const common::SpillConfig spillConfig = getSpillConfig(10, 20, writerFlushThresholdSize); diff --git a/velox/dwio/dwrf/test/EncodingManagerTests.cpp b/velox/dwio/dwrf/test/EncodingManagerTests.cpp index 681e0908157b..e53c276e29c5 100644 --- a/velox/dwio/dwrf/test/EncodingManagerTests.cpp +++ b/velox/dwio/dwrf/test/EncodingManagerTests.cpp @@ -34,10 +34,10 @@ TEST(TestEncodingIter, Ctor) { footer, encryptionGroups, -1, - footer.encoding().begin(), - footer.encoding().end()}; - EXPECT_EQ(footer.encoding().begin(), iter.current_); - EXPECT_EQ(footer.encoding().end(), iter.currentEnd_); + footer.encoding().cbegin(), + footer.encoding().cend()}; + EXPECT_EQ(footer.encoding().cbegin(), iter.current_); + EXPECT_EQ(footer.encoding().cend(), iter.currentEnd_); } { // A valid end iterator. @@ -45,10 +45,10 @@ TEST(TestEncodingIter, Ctor) { footer, encryptionGroups, -1, - footer.encoding().end(), - footer.encoding().end()}; - EXPECT_EQ(footer.encoding().end(), iter.current_); - EXPECT_EQ(footer.encoding().end(), iter.currentEnd_); + footer.encoding().cend(), + footer.encoding().cend()}; + EXPECT_EQ(footer.encoding().cend(), iter.current_); + EXPECT_EQ(footer.encoding().cend(), iter.currentEnd_); } footer.add_encoding(); // footer [e] @@ -58,10 +58,10 @@ TEST(TestEncodingIter, Ctor) { footer, encryptionGroups, -1, - footer.encoding().begin(), - footer.encoding().end()}; - EXPECT_EQ(footer.encoding().begin(), iter.current_); - EXPECT_EQ(footer.encoding().end(), iter.currentEnd_); + footer.encoding().cbegin(), + footer.encoding().cend()}; + EXPECT_EQ(footer.encoding().cbegin(), iter.current_); + EXPECT_EQ(footer.encoding().cend(), iter.currentEnd_); } { // A valid end iterator. @@ -69,10 +69,10 @@ TEST(TestEncodingIter, Ctor) { footer, encryptionGroups, -1, - footer.encoding().end(), - footer.encoding().end()}; - EXPECT_EQ(footer.encoding().end(), iter.current_); - EXPECT_EQ(footer.encoding().end(), iter.currentEnd_); + footer.encoding().cend(), + footer.encoding().cend()}; + EXPECT_EQ(footer.encoding().cend(), iter.current_); + EXPECT_EQ(footer.encoding().cend(), iter.currentEnd_); } proto::StripeEncryptionGroup group1; proto::StripeEncryptionGroup group2; @@ -85,10 +85,10 @@ TEST(TestEncodingIter, Ctor) { footer, encryptionGroups, -1, - footer.encoding().begin(), - footer.encoding().end()}; - EXPECT_EQ(footer.encoding().begin(), iter.current_); - EXPECT_EQ(footer.encoding().end(), iter.currentEnd_); + footer.encoding().cbegin(), + footer.encoding().cend()}; + EXPECT_EQ(footer.encoding().cbegin(), iter.current_); + EXPECT_EQ(footer.encoding().cend(), iter.currentEnd_); } { // A valid end iterator. @@ -96,10 +96,10 @@ TEST(TestEncodingIter, Ctor) { footer, encryptionGroups, 1, - encryptionGroups.at(1).encoding().end(), - encryptionGroups.at(1).encoding().end()}; - EXPECT_EQ(encryptionGroups.at(1).encoding().end(), iter.current_); - EXPECT_EQ(encryptionGroups.at(1).encoding().end(), iter.currentEnd_); + encryptionGroups.at(1).encoding().cend(), + encryptionGroups.at(1).encoding().cend()}; + EXPECT_EQ(encryptionGroups.at(1).encoding().cend(), iter.current_); + EXPECT_EQ(encryptionGroups.at(1).encoding().cend(), iter.currentEnd_); } { // An adjusted end iterator. @@ -107,10 +107,10 @@ TEST(TestEncodingIter, Ctor) { footer, encryptionGroups, -1, - footer.encoding().end(), - footer.encoding().end()}; - EXPECT_EQ(encryptionGroups.at(1).encoding().end(), iter.current_); - EXPECT_EQ(encryptionGroups.at(1).encoding().end(), iter.currentEnd_); + footer.encoding().cend(), + footer.encoding().cend()}; + EXPECT_EQ(encryptionGroups.at(1).encoding().cend(), iter.current_); + EXPECT_EQ(encryptionGroups.at(1).encoding().cend(), iter.currentEnd_); } encryptionGroups[1].add_encoding(); // footer [e] @@ -120,10 +120,10 @@ TEST(TestEncodingIter, Ctor) { footer, encryptionGroups, -1, - footer.encoding().begin(), - footer.encoding().end()}; - EXPECT_EQ(footer.encoding().begin(), iter.current_); - EXPECT_EQ(footer.encoding().end(), iter.currentEnd_); + footer.encoding().cbegin(), + footer.encoding().cend()}; + EXPECT_EQ(footer.encoding().cbegin(), iter.current_); + EXPECT_EQ(footer.encoding().cend(), iter.currentEnd_); } { // An adjusted iterator. @@ -131,10 +131,10 @@ TEST(TestEncodingIter, Ctor) { footer, encryptionGroups, -1, - footer.encoding().end(), - footer.encoding().end()}; - EXPECT_EQ(encryptionGroups.at(1).encoding().begin(), iter.current_); - EXPECT_EQ(encryptionGroups.at(1).encoding().end(), iter.currentEnd_); + footer.encoding().cend(), + footer.encoding().cend()}; + EXPECT_EQ(encryptionGroups.at(1).encoding().cbegin(), iter.current_); + EXPECT_EQ(encryptionGroups.at(1).encoding().cend(), iter.currentEnd_); } { // An adjusted iterator. @@ -142,10 +142,10 @@ TEST(TestEncodingIter, Ctor) { footer, encryptionGroups, 0, - encryptionGroups.at(0).encoding().begin(), - encryptionGroups.at(0).encoding().end()}; - EXPECT_EQ(encryptionGroups.at(1).encoding().begin(), iter.current_); - EXPECT_EQ(encryptionGroups.at(1).encoding().end(), iter.currentEnd_); + encryptionGroups.at(0).encoding().cbegin(), + encryptionGroups.at(0).encoding().cend()}; + EXPECT_EQ(encryptionGroups.at(1).encoding().cbegin(), iter.current_); + EXPECT_EQ(encryptionGroups.at(1).encoding().cend(), iter.currentEnd_); } { // A valid end iterator. @@ -153,10 +153,10 @@ TEST(TestEncodingIter, Ctor) { footer, encryptionGroups, 1, - encryptionGroups.at(1).encoding().end(), - encryptionGroups.at(1).encoding().end()}; - EXPECT_EQ(encryptionGroups.at(1).encoding().end(), iter.current_); - EXPECT_EQ(encryptionGroups.at(1).encoding().end(), iter.currentEnd_); + encryptionGroups.at(1).encoding().cend(), + encryptionGroups.at(1).encoding().cend()}; + EXPECT_EQ(encryptionGroups.at(1).encoding().cend(), iter.current_); + EXPECT_EQ(encryptionGroups.at(1).encoding().cend(), iter.currentEnd_); } footer.Clear(); // footer [] @@ -167,10 +167,10 @@ TEST(TestEncodingIter, Ctor) { footer, encryptionGroups, -1, - footer.encoding().begin(), - footer.encoding().end()}; - EXPECT_EQ(encryptionGroups.at(1).encoding().begin(), iter.current_); - EXPECT_EQ(encryptionGroups.at(1).encoding().end(), iter.currentEnd_); + footer.encoding().cbegin(), + footer.encoding().cend()}; + EXPECT_EQ(encryptionGroups.at(1).encoding().cbegin(), iter.current_); + EXPECT_EQ(encryptionGroups.at(1).encoding().cend(), iter.currentEnd_); } { // An adjusted iterator further back. @@ -178,10 +178,10 @@ TEST(TestEncodingIter, Ctor) { footer, encryptionGroups, -1, - footer.encoding().end(), - footer.encoding().end()}; - EXPECT_EQ(encryptionGroups.at(1).encoding().begin(), iter.current_); - EXPECT_EQ(encryptionGroups.at(1).encoding().end(), iter.currentEnd_); + footer.encoding().cend(), + footer.encoding().cend()}; + EXPECT_EQ(encryptionGroups.at(1).encoding().cbegin(), iter.current_); + EXPECT_EQ(encryptionGroups.at(1).encoding().cend(), iter.currentEnd_); } encryptionGroups.at(1).Clear(); // footer [] @@ -192,10 +192,10 @@ TEST(TestEncodingIter, Ctor) { footer, encryptionGroups, 0, - encryptionGroups.at(0).encoding().end(), - encryptionGroups.at(0).encoding().end()}; - EXPECT_EQ(encryptionGroups.at(1).encoding().end(), iter.current_); - EXPECT_EQ(encryptionGroups.at(1).encoding().end(), iter.currentEnd_); + encryptionGroups.at(0).encoding().cend(), + encryptionGroups.at(0).encoding().cend()}; + EXPECT_EQ(encryptionGroups.at(1).encoding().cend(), iter.current_); + EXPECT_EQ(encryptionGroups.at(1).encoding().cend(), iter.currentEnd_); } { // An adjusted end iterator further back. @@ -203,10 +203,10 @@ TEST(TestEncodingIter, Ctor) { footer, encryptionGroups, -1, - footer.encoding().end(), - footer.encoding().end()}; - EXPECT_EQ(encryptionGroups.at(1).encoding().end(), iter.current_); - EXPECT_EQ(encryptionGroups.at(1).encoding().end(), iter.currentEnd_); + footer.encoding().cend(), + footer.encoding().cend()}; + EXPECT_EQ(encryptionGroups.at(1).encoding().cend(), iter.current_); + EXPECT_EQ(encryptionGroups.at(1).encoding().cend(), iter.currentEnd_); } } @@ -227,15 +227,15 @@ TEST(TestEncodingIter, EncodingIterBeginAndEnd) { footer, encryptionGroups, -1, - footer.encoding().begin(), - footer.encoding().end()}; + footer.encoding().cbegin(), + footer.encoding().cend()}; EXPECT_EQ(begin, EncodingIter::begin(footer, encryptionGroups)); EncodingIter end{ footer, encryptionGroups, 1, - encryptionGroups.at(1).encoding().end(), - encryptionGroups.at(1).encoding().end()}; + encryptionGroups.at(1).encoding().cend(), + encryptionGroups.at(1).encoding().cend()}; EXPECT_EQ(end, EncodingIter::end(footer, encryptionGroups)); } diff --git a/velox/dwio/dwrf/test/FlushPolicyTest.cpp b/velox/dwio/dwrf/test/FlushPolicyTest.cpp index ed10f458aa2b..b045326664e9 100644 --- a/velox/dwio/dwrf/test/FlushPolicyTest.cpp +++ b/velox/dwio/dwrf/test/FlushPolicyTest.cpp @@ -51,8 +51,9 @@ TEST_F(DefaultFlushPolicyTest, StripeProgressTest) { testCase.stripeSizeThreshold, /*dictionarySizeThreshold=*/0}; EXPECT_EQ( testCase.shouldFlush, - policy.shouldFlush(dwio::common::StripeProgress{ - .stripeSizeEstimate = testCase.stripeSize})); + policy.shouldFlush( + dwio::common::StripeProgress{ + .stripeSizeEstimate = testCase.stripeSize})); } } @@ -115,8 +116,8 @@ TEST_F(DefaultFlushPolicyTest, AdditionalCriteriaTest) { .dictionarySize = 42, .decision = FlushDecision::SKIP}}; for (const auto& testCase : testCases) { - DefaultFlushPolicy policy{ - /*stripeSizeThreshold=*/1000, testCase.dictionarySizeThreshold}; + DefaultFlushPolicy policy{/*stripeSizeThreshold=*/1000, + testCase.dictionarySizeThreshold}; EXPECT_EQ( testCase.decision, policy.shouldFlushDictionary( diff --git a/velox/dwio/dwrf/test/IndexBuilderTests.cpp b/velox/dwio/dwrf/test/IndexBuilderTests.cpp index 66969d007905..ee5914e236af 100644 --- a/velox/dwio/dwrf/test/IndexBuilderTests.cpp +++ b/velox/dwio/dwrf/test/IndexBuilderTests.cpp @@ -25,17 +25,17 @@ namespace facebook::velox::dwrf { class IndexBuilderTest : public testing::Test { protected: - static const proto::RowIndexEntry& getEntry( + static const RowIndexEntryWriteWrapper getEntry( IndexBuilder& builder, size_t index) { - return *builder.getEntry(index); + return builder.getEntry(index); } static std::vector getPositions( IndexBuilder& builder, size_t index) { - auto& positions = builder.getEntry(index)->positions(); - return std::vector{positions.begin(), positions.end()}; + auto& positions = builder.getEntry(index).positions(); + return std::vector{positions.cbegin(), positions.cend()}; } StatisticsBuilderOptions options_{16}; @@ -45,7 +45,7 @@ TEST_F(IndexBuilderTest, Constructor) { IndexBuilder builder{nullptr}; EXPECT_EQ(1, builder.getEntrySize()); // Ensure a clean start. - EXPECT_EQ(0, getEntry(builder, 0).positions_size()); + EXPECT_EQ(0, getEntry(builder, 0).positionsSize()); } TEST_F(IndexBuilderTest, AddEntry) { @@ -65,8 +65,8 @@ TEST_F(IndexBuilderTest, AddEntry) { ASSERT_EQ(51, builder.getEntrySize()); for (size_t i = 0; i != 50; ++i) { // The newly added entries should be empty. - EXPECT_EQ(0, getEntry(builder, i + 1).positions_size()); - EXPECT_TRUE(getEntry(builder, i).has_statistics()); + EXPECT_EQ(0, getEntry(builder, i + 1).positionsSize()); + EXPECT_TRUE(getEntry(builder, i).hasStatistics()); } } @@ -94,9 +94,9 @@ TEST_F(IndexBuilderTest, Backfill) { IndexBuilder builder{nullptr}; StatisticsBuilder sb{options_}; builder.addEntry(sb); - ASSERT_EQ(0, getEntry(builder, 0).positions_size()); + ASSERT_EQ(0, getEntry(builder, 0).positionsSize()); builder.add(0uL); - ASSERT_EQ(0, getEntry(builder, 0).positions_size()); + ASSERT_EQ(0, getEntry(builder, 0).positionsSize()); ASSERT_THAT(getPositions(builder, 1), ElementsAreArray({0uL})); builder.add(42uL, 0); @@ -112,16 +112,16 @@ TEST_F(IndexBuilderTest, Backfill) { builder.addEntry(sb); } for (size_t i = 2; i != 7; ++i) { - ASSERT_EQ(0, getEntry(builder, i).positions_size()); + ASSERT_EQ(0, getEntry(builder, i).positionsSize()); } builder.add(144uL, 4); EXPECT_THAT(getPositions(builder, 0), ElementsAreArray({42uL})); EXPECT_THAT(getPositions(builder, 1), ElementsAreArray({0uL, 0uL, 7uL})); - ASSERT_EQ(0, getEntry(builder, 2).positions_size()); - ASSERT_EQ(0, getEntry(builder, 3).positions_size()); + ASSERT_EQ(0, getEntry(builder, 2).positionsSize()); + ASSERT_EQ(0, getEntry(builder, 3).positionsSize()); EXPECT_THAT(getPositions(builder, 4), ElementsAreArray({144uL})); - ASSERT_EQ(0, getEntry(builder, 5).positions_size()); - ASSERT_EQ(0, getEntry(builder, 6).positions_size()); + ASSERT_EQ(0, getEntry(builder, 5).positionsSize()); + ASSERT_EQ(0, getEntry(builder, 6).positionsSize()); } TEST_F(IndexBuilderTest, RemovePresentStreamPositions) { diff --git a/velox/dwio/dwrf/test/IntEncoderBenchmark.cpp b/velox/dwio/dwrf/test/IntEncoderBenchmark.cpp index 644ddab0c587..e79fa02bf759 100644 --- a/velox/dwio/dwrf/test/IntEncoderBenchmark.cpp +++ b/velox/dwio/dwrf/test/IntEncoderBenchmark.cpp @@ -117,8 +117,9 @@ FOLLY_ALWAYS_INLINE static int32_t findSetBitsNew(uint64_t value) { case 57 ... 63: return 1; } - DWIO_RAISE(folly::sformat( - "Unexpected leading zeros {} for value {}", leadingZeros, value)); + DWIO_RAISE( + folly::sformat( + "Unexpected leading zeros {} for value {}", leadingZeros, value)); } size_t iters = 2000; diff --git a/velox/dwio/dwrf/test/LayoutPlannerTests.cpp b/velox/dwio/dwrf/test/LayoutPlannerTests.cpp index a2de14e8e91f..32cf1a5062bd 100644 --- a/velox/dwio/dwrf/test/LayoutPlannerTests.cpp +++ b/velox/dwio/dwrf/test/LayoutPlannerTests.cpp @@ -128,12 +128,12 @@ TEST_F(LayoutPlannerTest, Basic) { uint32_t seq, proto::ColumnEncoding_Kind kind, std::optional key = std::nullopt) { - auto& encoding = encodingManager.addEncodingToFooter(node); - encoding.set_node(node); - encoding.set_sequence(seq); - encoding.set_kind(kind); + auto encoding = encodingManager.addEncodingToFooter(node); + encoding.setNode(node); + encoding.setSequence(seq); + encoding.setKind(ColumnEncodingKindWrapper(&kind)); if (key.has_value()) { - encoding.mutable_key()->set_intkey(*key); + encoding.mutableKey()->set_intkey(*key); } }; diff --git a/velox/dwio/dwrf/test/OrcTest.h b/velox/dwio/dwrf/test/OrcTest.h index 3bd87c45f075..6b0b3987ac2c 100644 --- a/velox/dwio/dwrf/test/OrcTest.h +++ b/velox/dwio/dwrf/test/OrcTest.h @@ -28,8 +28,6 @@ namespace facebook::velox::dwrf { -#define VELOX_ARRAY_SIZE(array) (sizeof(array) / sizeof(*array)) - using MemoryPool = memory::MemoryPool; inline std::string getExampleFilePath(const std::string& fileName) { diff --git a/velox/dwio/dwrf/test/ReaderBaseTests.cpp b/velox/dwio/dwrf/test/ReaderBaseTests.cpp index 04f0685c2521..8afd8b5e643e 100644 --- a/velox/dwio/dwrf/test/ReaderBaseTests.cpp +++ b/velox/dwio/dwrf/test/ReaderBaseTests.cpp @@ -16,7 +16,9 @@ #include #include + #include "velox/common/base/tests/GTestUtils.h" +#include "velox/dwio/common/Arena.h" #include "velox/dwio/common/InputStream.h" #include "velox/dwio/common/encryption/TestProvider.h" #include "velox/dwio/common/exception/Exception.h" @@ -72,14 +74,14 @@ class EncryptedStatsTest : public Test { TestEncrypter encrypter; HiveTypeParser parser; auto type = parser.parse("struct,c:int,d:int>"); - auto footer = - google::protobuf::Arena::CreateMessage(&arena_); + auto footer = ArenaCreate(&arena_); + auto footerWrapper = FooterWriteWrapper(footer); // add empty stats to the file for (size_t i = 0; i < 7; ++i) { - footer->add_statistics()->set_numberofvalues(i); + footerWrapper.addStatistics().setNumberOfValues(i); } - ProtoUtils::writeType(*type, *footer); - auto enc = footer->mutable_encryption(); + ProtoUtils::writeType(*type, footerWrapper); + auto enc = footerWrapper.mutableEncryption(); enc->set_keyprovider(proto::Encryption_KeyProvider_UNKNOWN); auto group = enc->add_encryptiongroups(); group->add_nodes(1); @@ -107,7 +109,7 @@ class EncryptedStatsTest : public Test { *readerPool_, std::make_unique(readFile, *readerPool_), std::make_unique(std::move(ps)), - footer, + footerWrapper.getDwrfPtr(), nullptr, std::move(handler)); } diff --git a/velox/dwio/dwrf/test/ReaderTest.cpp b/velox/dwio/dwrf/test/ReaderTest.cpp index b132757ee54f..6dcb1896e8b2 100644 --- a/velox/dwio/dwrf/test/ReaderTest.cpp +++ b/velox/dwio/dwrf/test/ReaderTest.cpp @@ -18,8 +18,11 @@ #include #include "folly/Random.h" #include "folly/executors/CPUThreadPoolExecutor.h" +#include "folly/executors/IOThreadPoolExecutor.h" #include "folly/lang/Assume.h" +#include "folly/synchronization/Baton.h" #include "velox/common/base/tests/GTestUtils.h" +#include "velox/common/testutil/TestValue.h" #include "velox/dwio/common/ExecutorBarrier.h" #include "velox/dwio/common/FileSink.h" #include "velox/dwio/common/tests/utils/BatchMaker.h" @@ -27,6 +30,7 @@ #include "velox/dwio/dwrf/reader/DwrfReader.h" #include "velox/dwio/dwrf/test/OrcTest.h" #include "velox/dwio/dwrf/test/utils/E2EWriterTestUtil.h" +#include "velox/dwio/dwrf/writer/Writer.h" #include "velox/type/fbhive/HiveTypeParser.h" #include "velox/vector/ComplexVector.h" #include "velox/vector/FlatVector.h" @@ -38,14 +42,14 @@ #include #include +namespace facebook::velox::dwrf { +namespace { + using namespace ::testing; using namespace facebook::velox::dwio::common; using namespace facebook::velox::type::fbhive; -using namespace facebook::velox; -using namespace facebook::velox::dwrf; using namespace facebook::velox::test; -namespace { const std::string& getStructFile() { static const std::string structFile_ = getExampleFilePath("struct.orc"); return structFile_; @@ -80,6 +84,7 @@ class TestReaderP protected: static void SetUpTestCase() { memory::MemoryManager::testingSetInstance(memory::MemoryManager::Options{}); + facebook::velox::common::testutil::TestValue::enable(); } folly::Executor* executor() { @@ -102,6 +107,7 @@ class TestReader : public testing::Test, public VectorTestBase { protected: static void SetUpTestCase() { memory::MemoryManager::testingSetInstance(memory::MemoryManager::Options{}); + facebook::velox::common::testutil::TestValue::enable(); } std::vector createBatches( @@ -116,8 +122,6 @@ class TestReader : public testing::Test, public VectorTestBase { } }; -} // namespace - TEST_F(TestReader, testWriterVersions) { EXPECT_EQ("original", writerVersionToString(ORIGINAL)); EXPECT_EQ("dwrf-4.9", writerVersionToString(DWRF_4_9)); @@ -823,7 +827,7 @@ TEST_F(TestReader, testEstimatedSize) { rowReaderOpts.select(cs); auto rowReader = reader->createRowReader(rowReaderOpts); - ASSERT_EQ(rowReader->estimatedRowSize(), 79); + ASSERT_EQ(rowReader->estimatedRowSize(), 67); } { @@ -835,7 +839,7 @@ TEST_F(TestReader, testEstimatedSize) { RowReaderOptions rowReaderOpts; rowReaderOpts.select(cs); auto rowReader = reader->createRowReader(rowReaderOpts); - ASSERT_EQ(rowReader->estimatedRowSize(), 13); + ASSERT_EQ(rowReader->estimatedRowSize(), 4); } } @@ -1053,8 +1057,9 @@ TEST_F(TestReader, testMismatchSchemaMoreFields) { std::shared_ptr requestedType = std::dynamic_pointer_cast(HiveTypeParser().parse( "struct,c:float,d:string>")); - rowReaderOpts.select(std::make_shared( - requestedType, std::vector{1, 2, 3})); + rowReaderOpts.select( + std::make_shared( + requestedType, std::vector{1, 2, 3})); auto reader = DwrfReader::create( createFileBufferedInput(getStructFile(), readerOpts.memoryPool()), readerOpts); @@ -1098,8 +1103,9 @@ TEST_F(TestReader, testMismatchSchemaFewerFields) { std::shared_ptr requestedType = std::dynamic_pointer_cast(HiveTypeParser().parse( "struct>")); - rowReaderOpts.select(std::make_shared( - requestedType, std::vector{1})); + rowReaderOpts.select( + std::make_shared( + requestedType, std::vector{1})); auto reader = DwrfReader::create( createFileBufferedInput(getStructFile(), readerOpts.memoryPool()), readerOpts); @@ -1140,8 +1146,9 @@ TEST_F(TestReader, testMismatchSchemaNestedMoreFields) { std::dynamic_pointer_cast(HiveTypeParser().parse( "struct,c:float>")); LOG(INFO) << requestedType->toString(); - rowReaderOpts.select(std::make_shared( - requestedType, std::vector{"b.b", "b.c", "b.d", "c"})); + rowReaderOpts.select( + std::make_shared( + requestedType, std::vector{"b.b", "b.c", "b.d", "c"})); auto reader = DwrfReader::create( createFileBufferedInput(getStructFile(), readerOpts.memoryPool()), readerOpts); @@ -1205,8 +1212,9 @@ TEST_F(TestReader, testMismatchSchemaNestedFewerFields) { std::shared_ptr requestedType = std::dynamic_pointer_cast(HiveTypeParser().parse( "struct,c:float>")); - rowReaderOpts.select(std::make_shared( - requestedType, std::vector{"b.b", "c"})); + rowReaderOpts.select( + std::make_shared( + requestedType, std::vector{"b.b", "c"})); auto reader = DwrfReader::create( createFileBufferedInput(getStructFile(), readerOpts.memoryPool()), readerOpts); @@ -1262,8 +1270,9 @@ TEST_F(TestReader, testMismatchSchemaIncompatibleNotSelected) { std::shared_ptr requestedType = std::dynamic_pointer_cast(HiveTypeParser().parse( "struct,c:int>")); - rowReaderOpts.select(std::make_shared( - requestedType, std::vector{"b.b"})); + rowReaderOpts.select( + std::make_shared( + requestedType, std::vector{"b.b"})); auto reader = DwrfReader::create( createFileBufferedInput(getStructFile(), readerOpts.memoryPool()), readerOpts); @@ -2269,24 +2278,25 @@ TEST_F(TestReader, failToReuseReaderNulls) { TEST_F(TestReader, readFlatMapsSomeEmpty) { // Test reading a flat map where the key filter means that some maps are // empty. - auto keys = makeFlatVector(std::vector{ - 1, - 2, - 3, - 4, - 5, - 6, // map 1 has more than just the selected keys. - 1, - 2, - 3, // map 2 has only selected keys. - 4, - 5, - 6, // map 3 has no selected keys. - 1, - 2, - 5, - 6 // map 4 has some selected keys. - }); + auto keys = makeFlatVector( + std::vector{ + 1, + 2, + 3, + 4, + 5, + 6, // map 1 has more than just the selected keys. + 1, + 2, + 3, // map 2 has only selected keys. + 4, + 5, + 6, // map 3 has no selected keys. + 1, + 2, + 5, + 6 // map 4 has some selected keys. + }); auto values = makeFlatVector(16, folly::identity); auto maps = makeMapVector(std::vector{0, 6, 9, 12, 16}, keys, values); @@ -2412,6 +2422,66 @@ TEST_F(TestReader, readFlatMapsWithNullMaps) { } } +TEST_F(TestReader, readFlatMapsAsFlatMaps) { + auto testRoundTrip = [&](const FlatMapVectorPtr& flatMap) { + auto input = makeRowVector({flatMap->toMapVector()}); + + std::shared_ptr config = std::make_shared(); + config->set(dwrf::Config::FLATTEN_MAP, true); + config->set(dwrf::Config::MAP_FLAT_COLS, {0}); + + auto [writer, reader] = createWriterReader({input}, pool(), config); + + auto schema = asRowType(input->type()); + auto spec = std::make_shared(""); + spec->addAllChildFields(*schema); + + RowReaderOptions rowReaderOpts; + rowReaderOpts.setScanSpec(spec); + rowReaderOpts.setPreserveFlatMapsInMemory(true); + + auto rowReader = reader->createRowReader(rowReaderOpts); + VectorPtr batch = BaseVector::create(schema, 0, pool()); + + rowReader->next(flatMap->size(), batch); + auto rowVector = batch->as(); + auto resultMaps = rowVector->childAt(0); + + assertEqualVectors(flatMap, resultMaps); + }; + + testRoundTrip( + makeFlatMapVector({ + {}, + {{1, 1.9}, {2, 2.1}, {0, 3.12}}, + {{127, 0.12}}, + })); + + testRoundTrip( + makeFlatMapVector({ + {{"a", "a1"}}, + {{"b", "b1"}}, + {{"c", "c1"}}, + {{"d", "d1"}}, + })); + + testRoundTrip( + makeNullableFlatMapVector({ + {{{101, 1}, {102, 2}, {103, 3}}}, + {{{105, 0}, {106, 0}}}, + {std::nullopt}, + {{{101, 11}, {103, 13}, {105, std::nullopt}}}, + {{{101, 1}, {102, 2}, {103, 3}}}, + })); + + testRoundTrip( + makeFlatMapVector( + {{{0, 0}, {1, 1}, {2, 2}, {3, 3}}, + {{0, 4}, {1, 5}, {2, 6}, {3, 7}}, + {{0, 8}, {1, 9}, {2, 10}, {3, 11}}, + {{0, 12}, {1, 13}, {2, 14}, {3, 15}}})); +} + TEST_F(TestReader, readStructWithWholeBatchFiltered) { // Test reading a struct with a pushdown filter that filters out all rows // for a certain batch. @@ -2511,8 +2581,9 @@ TEST_F(TestReader, readStringDictionaryAsFlat) { dwio::common::RuntimeStatistics stats; rowReader->updateRuntimeStats(stats); ASSERT_EQ(stats.columnReaderStatistics.flattenStringDictionaryValues, 0); - spec->childByName("c0")->setFilter(std::make_unique( - std::vector{"aaaaaaaaaaaaaaaaaaaa"}, false)); + spec->childByName("c0")->setFilter( + std::make_unique( + std::vector{"aaaaaaaaaaaaaaaaaaaa"}, false)); spec->resetCachedValues(true); rowReader = reader->createRowReader(rowReaderOpts); ASSERT_EQ(rowReader->next(20, actual), 20); @@ -2665,3 +2736,188 @@ TEST_F(TestReader, skipLongString) { validate(batch); } } + +TEST_F(TestReader, mapAsStruct) { + auto row = makeRowVector({ + makeMapVector({{{1, 4}, {2, 5}}, {{1, 6}, {3, 7}}}), + }); + auto [writer, reader] = createWriterReader({row}, pool()); + auto outType = ROW({"c0"}, {ROW({"3", "1"}, BIGINT())}); + auto spec = std::make_shared(""); + spec->addAllChildFields(*outType); + spec->childByName("c0")->setFlatMapAsStruct(true); + RowReaderOptions rowReaderOpts; + rowReaderOpts.setScanSpec(spec); + auto rowReader = reader->createRowReader(rowReaderOpts); + VectorPtr batch = BaseVector::create(outType, 0, pool()); + ASSERT_EQ(rowReader->next(10, batch), 2); + auto expected = makeRowVector({ + makeRowVector( + {"3", "1"}, + { + makeNullableFlatVector({std::nullopt, 7}), + makeFlatVector({4, 6}), + }), + }); + assertEqualVectors(expected, batch); +} + +TEST_F(TestReader, mapAsStructFilterAfterRead) { + auto row = makeRowVector({ + makeMapVector({{{1, 4}, {2, 5}}, {}, {{1, 6}, {3, 7}}}), + makeRowVector( + {makeConstant(0, 3)}, [](auto i) { return i == 0; }), + }); + auto [writer, reader] = createWriterReader({row}, pool()); + auto outType = + ROW({"c0", "c1"}, {ROW({"3", "1"}, BIGINT()), ROW({"c0"}, BIGINT())}); + auto spec = std::make_shared(""); + spec->addAllChildFields(*outType); + auto* c0Spec = spec->childByName("c0"); + c0Spec->setFlatMapAsStruct(true); + c0Spec->setFilter(std::make_shared()); + spec->childByName("c1")->setFilter(std::make_shared()); + RowReaderOptions rowReaderOpts; + rowReaderOpts.setScanSpec(spec); + auto rowReader = reader->createRowReader(rowReaderOpts); + VectorPtr batch = BaseVector::create(outType, 0, pool()); + ASSERT_EQ(rowReader->next(10, batch), 3); + auto expected = makeRowVector({ + makeRowVector( + {"3", "1"}, + { + makeNullableFlatVector({std::nullopt, 7}), + makeNullableFlatVector({std::nullopt, 6}), + }), + makeRowVector({makeConstant(0, 2)}), + }); + assertEqualVectors(expected, batch); +} + +TEST_F(TestReader, mapAsStructAllEmpty) { + auto row = makeRowVector({makeMapVector({{}, {}})}); + auto [writer, reader] = createWriterReader({row}, pool()); + auto outType = ROW({"c0"}, {ROW({"1"}, BIGINT())}); + auto spec = std::make_shared(""); + spec->addAllChildFields(*outType); + spec->childByName("c0")->setFlatMapAsStruct(true); + RowReaderOptions rowReaderOpts; + rowReaderOpts.setScanSpec(spec); + auto rowReader = reader->createRowReader(rowReaderOpts); + VectorPtr batch = BaseVector::create(outType, 0, pool()); + ASSERT_EQ(rowReader->next(10, batch), 2); + auto expected = makeRowVector({ + makeRowVector({"1"}, {makeNullConstant(TypeKind::BIGINT, 2)}), + }); + assertEqualVectors(expected, batch); +} + +// Verify DwrfRowReader can be destroyed while ParallelUnitLoader async load() +// are in progress. This regression test ensures that: +// 1. ParallelUnitLoader destructor doesn't wait for async load() operations +// 2. Async load() from DwrfUnit can still function after ParallelUnitLoader +// destruction and DwrfRowReader destruction, which means all dependencies in +// DwrfUnit remain valid (eg ReaderBase) +// +// If a future change adds an unsafe raw pointer to DwrfUnit's dependencies that +// would be freed by ParallelUnitLoader or DwrfRowReader's destruction, this +// test may crash due to use-after-free. +DEBUG_ONLY_TEST_F(TestReader, asyncLoadSurvivesReaderDestruction) { + const int kNumStripes = 2; + const int kRowsPerStripe = 100; + std::vector batches; + batches.reserve(kNumStripes); + for (int stripe = 0; stripe < kNumStripes; ++stripe) { + batches.push_back(makeRowVector({ + makeFlatVector( + kRowsPerStripe, + [stripe](auto row) { return stripe * kRowsPerStripe + row; }), + })); + } + + // Write the DWRF file - force each batch into its own stripe + auto config = std::make_shared(); + + auto sink = + std::make_unique(1 << 20, FileSink::Options{.pool = pool()}); + auto* sinkPtr = sink.get(); + auto writer = E2EWriterTestUtil::writeData( + std::move(sink), + asRowType(batches[0]->type()), + batches, + config, + // Force flush after each batch to create separate stripes + E2EWriterTestUtil::simpleFlushPolicyFactory(true)); + + std::string data(sinkPtr->data(), sinkPtr->size()); + auto input = std::make_unique( + std::make_shared(std::move(data)), *pool()); + + std::atomic asyncLoadsStarted{0}; + std::atomic asyncLoadsCompleted{0}; + folly::Baton<> readerDestroyed; + + SCOPED_TESTVALUE_SET( + "facebook::velox::dwio::common::ParallelUnitLoader::load", + std::function([&](void*) { + // Only block the second stripe (index 1) - let the first stripe load + // normally so rowReader->next() can complete + // fetch_add returns the value before increment: 0 for first, 1 for + // second, etc. + if (asyncLoadsStarted.fetch_add(1) == 1) { + // Block here until reader is destroyed + readerDestroyed.wait(); + } + asyncLoadsCompleted.fetch_add(1); + })); + + auto ioExecutor = std::make_shared(2); + + // Make sure ReaderOptions and DwrfRowReader are freed after {} scope + { + dwio::common::ReaderOptions readerOpts(pool()); + readerOpts.setFileFormat(FileFormat::DWRF); + auto reader = DwrfReader::create(std::move(input), readerOpts); + + // Enable parallel unit load + RowReaderOptions rowReaderOpts; + rowReaderOpts.setParallelUnitLoadCount(2); + rowReaderOpts.setIOExecutor(ioExecutor.get()); + auto rowReader = reader->createRowReader(rowReaderOpts); + + VectorPtr batch; + rowReader->next(50, batch); // Read first stripe + + auto start = std::chrono::steady_clock::now(); + rowReader.reset(); + auto duration = std::chrono::steady_clock::now() - start; + // Verify destruction was fast (didn't wait for async operations) + EXPECT_LT(duration, std::chrono::seconds(1)) + << "Destruction should not wait for async loads"; + } + + // Now signal that reader is destroyed + readerDestroyed.post(); + + // Wait for async loads to complete + int maxWaitMs = 2000; + int waitedMs = 0; + while (asyncLoadsCompleted.load() < 2 && waitedMs < maxWaitMs) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + waitedMs += 100; + } + + // Verify that both async loads completed successfully after reader + // destruction This proves the fix works: async operations can complete even + // after DwrfRowReader is destroyed because LoadUnit is captured as shared_ptr + EXPECT_EQ(asyncLoadsCompleted.load(), 2) + << "Both async loads should complete even after reader destruction. " + << "If this fails, it means async operations are being cancelled or " + << "crashing after DwrfRowReader destruction, indicating unsafe pointers."; + + // Clean up + ioExecutor->join(); +} + +} // namespace +} // namespace facebook::velox::dwrf diff --git a/velox/dwio/dwrf/test/StripeReaderBaseTests.cpp b/velox/dwio/dwrf/test/StripeReaderBaseTests.cpp index 4091204ba960..acce6cbf5cdd 100644 --- a/velox/dwio/dwrf/test/StripeReaderBaseTests.cpp +++ b/velox/dwio/dwrf/test/StripeReaderBaseTests.cpp @@ -42,7 +42,8 @@ class StripeLoadKeysTest : public Test { HiveTypeParser parser; auto type = parser.parse("struct"); footer_ = std::make_unique(); - ProtoUtils::writeType(*type, *footer_); + auto footerWrapper = FooterWriteWrapper(footer_.get()); + ProtoUtils::writeType(*type, footerWrapper); auto enc = footer_->mutable_encryption(); enc->set_keyprovider(proto::Encryption_KeyProvider_UNKNOWN); auto group = enc->add_encryptiongroups(); @@ -93,8 +94,8 @@ class StripeLoadKeysTest : public Test { handler_ = std::move(handler); - enc_ = const_cast( - std::addressof(dynamic_cast( + enc_ = const_cast(std::addressof( + dynamic_cast( handler_->getEncryptionProviderByIndex(0)))); } diff --git a/velox/dwio/dwrf/test/TestByteRle.cpp b/velox/dwio/dwrf/test/TestByteRle.cpp index ad19d341a50a..94f8ecf9da68 100644 --- a/velox/dwio/dwrf/test/TestByteRle.cpp +++ b/velox/dwio/dwrf/test/TestByteRle.cpp @@ -39,9 +39,9 @@ std::unique_ptr createBooleanDecoder( TEST(ByteRle, simpleTest) { const unsigned char buffer[] = {0x61, 0x00, 0xfd, 0x44, 0x45, 0x46}; - std::unique_ptr rle = - createByteDecoder(std::unique_ptr( - new SeekableArrayInputStream(buffer, VELOX_ARRAY_SIZE(buffer)))); + std::unique_ptr rle = createByteDecoder( + std::unique_ptr( + new SeekableArrayInputStream(buffer, std::size(buffer)))); std::vector data(103); rle->next(data.data(), data.size(), nullptr); @@ -54,9 +54,9 @@ TEST(ByteRle, simpleTest) { } TEST(ByteRle, nullTest) { - char buffer[258]; - uint64_t nulls[5]; - char result[266]; + char buffer[258] = {'\0'}; + uint64_t nulls[5] = {'\0'}; + char result[266] = {'\0'}; buffer[0] = -128; buffer[129] = -128; for (int32_t i = 0; i < 128; ++i) { @@ -66,8 +66,8 @@ TEST(ByteRle, nullTest) { for (int32_t i = 0; i < 266; ++i) { bits::setNull(nulls, i, i < 10); } - std::unique_ptr rle = - createByteDecoder(std::unique_ptr( + std::unique_ptr rle = createByteDecoder( + std::unique_ptr( new SeekableArrayInputStream(buffer, sizeof(buffer)))); rle->next(result, sizeof(result), nulls); for (size_t i = 0; i < sizeof(result); ++i) { @@ -93,9 +93,9 @@ TEST(ByteRle, literalCrossBuffer) { 0x09, 0x07, 0x10}; - std::unique_ptr rle = - createByteDecoder(std::unique_ptr( - new SeekableArrayInputStream(buffer, VELOX_ARRAY_SIZE(buffer), 6))); + std::unique_ptr rle = createByteDecoder( + std::unique_ptr( + new SeekableArrayInputStream(buffer, std::size(buffer), 6))); std::vector data(20); rle->next(data.data(), data.size(), nullptr); @@ -109,9 +109,9 @@ TEST(ByteRle, literalCrossBuffer) { TEST(ByteRle, skipLiteralBufferUnderflowTest) { const unsigned char buffer[] = {0xf8, 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7}; - std::unique_ptr rle = - createByteDecoder(std::unique_ptr( - new SeekableArrayInputStream(buffer, VELOX_ARRAY_SIZE(buffer), 4))); + std::unique_ptr rle = createByteDecoder( + std::unique_ptr( + new SeekableArrayInputStream(buffer, std::size(buffer), 4))); std::vector data(8); rle->next(data.data(), 3, nullptr); EXPECT_EQ(0x0, data[0]); @@ -127,9 +127,9 @@ TEST(ByteRle, skipLiteralBufferUnderflowTest) { TEST(ByteRle, simpleRuns) { const unsigned char buffer[] = {0x0d, 0xff, 0x0d, 0xfe, 0x0d, 0xfd}; - std::unique_ptr rle = - createByteDecoder(std::unique_ptr( - new SeekableArrayInputStream(buffer, VELOX_ARRAY_SIZE(buffer)))); + std::unique_ptr rle = createByteDecoder( + std::unique_ptr( + new SeekableArrayInputStream(buffer, std::size(buffer)))); std::vector data(16); for (size_t i = 0; i < 3; ++i) { rle->next(data.data(), data.size(), nullptr); @@ -145,9 +145,9 @@ TEST(ByteRle, splitHeader) { 0x00, 0x01, 0xe0, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20}; - std::unique_ptr rle = - createByteDecoder(std::unique_ptr( - new SeekableArrayInputStream(buffer, VELOX_ARRAY_SIZE(buffer), 1))); + std::unique_ptr rle = createByteDecoder( + std::unique_ptr( + new SeekableArrayInputStream(buffer, std::size(buffer), 1))); std::vector data(35); rle->next(data.data(), data.size(), nullptr); for (size_t i = 0; i < 3; ++i) { @@ -179,9 +179,9 @@ TEST(ByteRle, splitRuns) { 0x0e, 0x0f, 0x10}; - std::unique_ptr rle = - createByteDecoder(std::unique_ptr( - new SeekableArrayInputStream(buffer, VELOX_ARRAY_SIZE(buffer)))); + std::unique_ptr rle = createByteDecoder( + std::unique_ptr( + new SeekableArrayInputStream(buffer, std::size(buffer)))); std::vector data(5); for (size_t i = 0; i < 3; ++i) { rle->next(data.data(), data.size(), nullptr); @@ -227,9 +227,9 @@ TEST(ByteRle, testNulls) { 0x0f, 0x3d, 0xdc}; - std::unique_ptr rle = - createByteDecoder(std::unique_ptr( - new SeekableArrayInputStream(buffer, VELOX_ARRAY_SIZE(buffer), 3))); + std::unique_ptr rle = createByteDecoder( + std::unique_ptr( + new SeekableArrayInputStream(buffer, std::size(buffer), 3))); std::vector data(16, -1); std::vector nulls(1); for (size_t i = 0; i < data.size(); ++i) { @@ -276,9 +276,9 @@ TEST(ByteRle, testAllNulls) { 0x0f, 0x3d, 0xdc}; - std::unique_ptr rle = - createByteDecoder(std::unique_ptr( - new SeekableArrayInputStream(buffer, VELOX_ARRAY_SIZE(buffer)))); + std::unique_ptr rle = createByteDecoder( + std::unique_ptr( + new SeekableArrayInputStream(buffer, std::size(buffer)))); std::vector data(16, -1); std::vector allNull(1, bits::kNull64); std::vector noNull(1, bits::kNotNull64); @@ -413,7 +413,7 @@ TEST(ByteRle, testSkip) { 0xf8, 0xf9, 0xfa, 0xfb, 0xfc, 0xfd, 0xfe, 0xff, }; SeekableInputStream* const stream = - new SeekableArrayInputStream(buffer, VELOX_ARRAY_SIZE(buffer)); + new SeekableArrayInputStream(buffer, std::size(buffer)); std::unique_ptr rle = createByteDecoder(std::unique_ptr(stream)); std::vector data(1); @@ -570,7 +570,7 @@ TEST(ByteRle, testSeek) { 0xf8, 0xf9, 0xfa, 0xfb, 0xfc, 0xfd, 0xfe, 0xff, }; std::unique_ptr stream( - new SeekableArrayInputStream(buffer, VELOX_ARRAY_SIZE(buffer))); + new SeekableArrayInputStream(buffer, std::size(buffer))); const uint64_t fileLocs[] = { 0, 0, 0, 0, 0, 2, 2, 2, 2, 4, 4, 4, 4, 6, 6, 6, 6, 8, 8, 8, 8, 10, 10, 10, @@ -907,7 +907,7 @@ TEST(ByteRle, testSeek) { // Seek to end std::vector position; - position.push_back(VELOX_ARRAY_SIZE(buffer)); + position.push_back(std::size(buffer)); position.push_back(0); PositionProvider pp{position}; rle->seekToRowGroup(pp); @@ -916,7 +916,7 @@ TEST(ByteRle, testSeek) { // Seek to end + 1 position.clear(); - position.push_back(VELOX_ARRAY_SIZE(buffer)); + position.push_back(std::size(buffer)); position.push_back(1); PositionProvider pp2{position}; rle->seekToRowGroup(pp2); @@ -926,7 +926,7 @@ TEST(ByteRle, testSeek) { TEST(BooleanRle, simpleTest) { const unsigned char buffer[] = {0x61, 0xf0, 0xfd, 0x55, 0xAA, 0x55}; std::unique_ptr stream( - new SeekableArrayInputStream(buffer, VELOX_ARRAY_SIZE(buffer))); + new SeekableArrayInputStream(buffer, std::size(buffer))); std::unique_ptr rle = createBooleanDecoder(std::move(stream)); std::vector data(50); for (size_t i = 0; i < 16; ++i) { @@ -951,7 +951,7 @@ TEST(BooleanRle, runsTest) { const unsigned char buffer[] = { 0xf7, 0xff, 0x80, 0x3f, 0xe0, 0x0f, 0xf8, 0x03, 0xfe, 0x00}; std::unique_ptr stream( - new SeekableArrayInputStream(buffer, VELOX_ARRAY_SIZE(buffer))); + new SeekableArrayInputStream(buffer, std::size(buffer))); std::unique_ptr rle = createBooleanDecoder(std::move(stream)); std::vector data(72); rle->next(data.data(), data.size(), nullptr); @@ -973,7 +973,7 @@ TEST(BooleanRle, runsTestWithNull) { const unsigned char buffer[] = { 0xf7, 0xff, 0x80, 0x3f, 0xe0, 0x0f, 0xf8, 0x03, 0xfe, 0x00}; std::unique_ptr stream( - new SeekableArrayInputStream(buffer, VELOX_ARRAY_SIZE(buffer))); + new SeekableArrayInputStream(buffer, std::size(buffer))); std::unique_ptr rle = createBooleanDecoder(std::move(stream)); std::vector data(72); std::vector nulls(bits::nwords(data.size()), bits::kNotNull64); @@ -1089,7 +1089,7 @@ TEST(BooleanRle, skipTest) { 0x71, 0xc7, 0x1c, 0x71, 0xc7, 0x1c, 0x71, 0xc7, 0x1c, 0x71, 0xc7, 0x1c, 0x71, 0xc7, 0x1c, 0x71}; std::unique_ptr stream( - new SeekableArrayInputStream(buffer, VELOX_ARRAY_SIZE(buffer))); + new SeekableArrayInputStream(buffer, std::size(buffer))); std::unique_ptr rle = createBooleanDecoder(std::move(stream)); std::vector data(1); for (size_t i = 0; i < 16384; i += 5) { @@ -1200,7 +1200,7 @@ TEST(BooleanRle, skipTestWithNulls) { 0x71, 0xc7, 0x1c, 0x71, 0xc7, 0x1c, 0x71, 0xc7, 0x1c, 0x71, 0xc7, 0x1c, 0x71, 0xc7, 0x1c, 0x71}; std::unique_ptr stream( - new SeekableArrayInputStream(buffer, VELOX_ARRAY_SIZE(buffer))); + new SeekableArrayInputStream(buffer, std::size(buffer))); std::unique_ptr rle = createBooleanDecoder(std::move(stream)); raw_vector data; data.resize(3); @@ -1365,7 +1365,7 @@ TEST(BooleanRle, seekTest) { 0x71, 0xc7, 0x1c, 0x71, 0xc7, 0x1c, 0x71, 0xc7, 0x1c, 0x71, 0xc7, 0x1c, 0x71, 0xc7, 0x1c, 0x71}; std::unique_ptr stream( - new SeekableArrayInputStream(buffer, VELOX_ARRAY_SIZE(buffer))); + new SeekableArrayInputStream(buffer, std::size(buffer))); std::unique_ptr rle = createBooleanDecoder(std::move(stream)); // Read all 16384 values and validate them. @@ -1501,7 +1501,7 @@ TEST(BooleanRle, seekTestWithNulls) { 0x71, 0xc7, 0x1c, 0x71, 0xc7, 0x1c, 0x71, 0xc7, 0x1c, 0x71, 0xc7, 0x1c, 0x71, 0xc7, 0x1c, 0x71, 0xc7, 0x1c, 0x71, 0xc7, 0x1c, 0x71, 0xc7, 0x1c, 0x71, 0xc7, 0x1c, 0x71}; - auto* stream = new SeekableArrayInputStream(buffer, VELOX_ARRAY_SIZE(buffer)); + auto* stream = new SeekableArrayInputStream(buffer, std::size(buffer)); auto rle = createBooleanDecoder(std::unique_ptr(stream)); ASSERT_EQ(stream->totalRead(), 0); auto lastTotalReadBytes = stream->totalRead(); @@ -1519,7 +1519,7 @@ TEST(BooleanRle, seekTestWithNulls) { EXPECT_EQ(0, bits::isBitSet(data.data(), i)) << "Output wrong at " << i; } rle->next(data.data(), data.size(), noNull.data()); - ASSERT_EQ(getNumReadBytes(), VELOX_ARRAY_SIZE(buffer)); + ASSERT_EQ(getNumReadBytes(), std::size(buffer)); for (size_t i = 0; i < data.size(); ++i) { EXPECT_EQ(i < 8192 ? i & 1 : (i / 3) & 1, bits::isBitSet(data.data(), i)) << "Output wrong at " << i; @@ -1573,7 +1573,7 @@ TEST(BooleanRle, seekBoolAndByteRLE) { 0xf9, 0xf0, 0xf0, 0xf7, 0x1c, 0x71, 0xc1, 0x80}; std::unique_ptr stream( - new SeekableArrayInputStream(buffer, VELOX_ARRAY_SIZE(buffer))); + new SeekableArrayInputStream(buffer, std::size(buffer))); std::unique_ptr rle = createBooleanDecoder(std::move(stream)); std::vector data(sizeof(num) / sizeof(char)); rle->next(data.data(), data.size(), nullptr); @@ -1596,7 +1596,7 @@ TEST(BooleanRle, seekBoolAndByteRLE) { TEST(BooleanRle, skipToEnd) { const unsigned char buffer[] = {0xfe, 0xff, 0xff}; std::unique_ptr stream( - new SeekableArrayInputStream(buffer, VELOX_ARRAY_SIZE(buffer))); + new SeekableArrayInputStream(buffer, std::size(buffer))); std::unique_ptr rle = createBooleanDecoder(std::move(stream)); char value[1]; rle->next(value, 1, nullptr); diff --git a/velox/dwio/dwrf/test/TestColumnReader.cpp b/velox/dwio/dwrf/test/TestColumnReader.cpp index da7b9f743f61..def2f28ea979 100644 --- a/velox/dwio/dwrf/test/TestColumnReader.cpp +++ b/velox/dwio/dwrf/test/TestColumnReader.cpp @@ -517,14 +517,14 @@ TEST_P(TestColumnReader, testBooleanWithNulls) { // alternate 4 non-null and 4 null via [0xf0 for x in range(512 / 8)] const unsigned char buffer1[] = {0x3d, 0xf0}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_PRESENT, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer1, VELOX_ARRAY_SIZE(buffer1)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer1, std::size(buffer1)))); // [0x0f for x in range(256 / 8)] const unsigned char buffer2[] = {0x1d, 0x0f}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_DATA, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer2, VELOX_ARRAY_SIZE(buffer2)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer2, std::size(buffer2)))); // create the row type auto rowType = HiveTypeParser().parse("struct"); @@ -628,13 +628,13 @@ TEST_P(TestColumnReader, testBooleanSkipsWithNulls) { // alternate 4 non-null and 4 null via [0xf0 for x in range(512 / 8)] const unsigned char buffer1[] = {0x3d, 0xf0}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_PRESENT, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer1, VELOX_ARRAY_SIZE(buffer1)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer1, std::size(buffer1)))); // [0x0f for x in range(128 / 8)] const unsigned char buffer2[] = {0x1d, 0x0f}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_DATA, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer2, VELOX_ARRAY_SIZE(buffer2)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer2, std::size(buffer2)))); // create the row type auto rowType = HiveTypeParser().parse("struct"); @@ -683,8 +683,8 @@ TEST_P(TestColumnReader, testByteWithNulls) { // alternate 4 non-null and 4 null via [0xf0 for x in range(512 / 8)] const unsigned char buffer1[] = {0x3d, 0xf0}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_PRESENT, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer1, VELOX_ARRAY_SIZE(buffer1)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer1, std::size(buffer1)))); // range(256) char buffer[258]; @@ -697,8 +697,8 @@ TEST_P(TestColumnReader, testByteWithNulls) { buffer[i + 2] = static_cast(i); } EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_DATA, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer, VELOX_ARRAY_SIZE(buffer)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer, std::size(buffer)))); // create the row type auto rowType = HiveTypeParser().parse("struct"); @@ -778,8 +778,8 @@ TEST_P(TestColumnReader, testByteSkipsWithNulls) { // alternate 4 non-null and 4 null via [0xf0 for x in range(512 / 8)] const unsigned char buffer1[] = {0x3d, 0xf0}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_PRESENT, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer1, VELOX_ARRAY_SIZE(buffer1)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer1, std::size(buffer1)))); // range(256) char buffer[258]; @@ -792,8 +792,8 @@ TEST_P(TestColumnReader, testByteSkipsWithNulls) { buffer[i + 2] = static_cast(i); } EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_DATA, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer, VELOX_ARRAY_SIZE(buffer)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer, std::size(buffer)))); // create the row type auto rowType = HiveTypeParser().parse("struct"); @@ -856,7 +856,7 @@ TEST_P(TestColumnReader, testIntegerRLEv2) { int32_t expects_col0[] = {2110, 2120, 2130, 2140}; int32_t expects_col1[] = {11, 12, 13, 14}; int32_t expects_col2[] = {32, 34, 36, 38}; - int32_t size = VELOX_ARRAY_SIZE(col0); + int32_t size = std::size(col0); // set format streams_.setFormat(DwrfFormat::kOrc); @@ -884,8 +884,8 @@ TEST_P(TestColumnReader, testIntegerRLEv2) { // col_0's DATA stream EXPECT_CALL( streams_, getStreamOrcProxy(1, proto::orc::Stream_Kind_DATA, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer0, VELOX_ARRAY_SIZE(buffer0)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer0, std::size(buffer0)))); // col_1's DATA stream std::array data; std::vector v; @@ -903,8 +903,8 @@ TEST_P(TestColumnReader, testIntegerRLEv2) { // col_2's DATA stream EXPECT_CALL( streams_, getStreamOrcProxy(3, proto::orc::Stream_Kind_DATA, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer2, VELOX_ARRAY_SIZE(buffer2)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer2, std::size(buffer2)))); // create the row type auto rowType = @@ -929,7 +929,7 @@ TEST_P(TestColumnReader, testIntegerRLEv2) { auto colBatch = getChild>(batch, 0); auto colBatch2 = getChild>(batch, 1); auto colBatch3 = getChild>(batch, 2); - ASSERT_EQ(VELOX_ARRAY_SIZE(expects_col0), colBatch->size()); + ASSERT_EQ(std::size(expects_col0), colBatch->size()); ASSERT_EQ(colBatch->size(), colBatch2->size()); ASSERT_EQ(colBatch2->size(), colBatch3->size()); for (size_t i = 0; i < batch->size(); ++i) { @@ -972,8 +972,8 @@ TEST_P(TestColumnReader, testIntegerWithNulls) { .WillRepeatedly(Return(nullptr)); const unsigned char buffer1[] = {0x16, 0xf0}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_PRESENT, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer1, VELOX_ARRAY_SIZE(buffer1)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer1, std::size(buffer1)))); char buffer2[1024]; size_t size = writeRange(buffer2, 0, 100); @@ -1091,8 +1091,8 @@ TEST_P(TestColumnReader, testIntDictSkipWithNulls) { .WillRepeatedly(Return(nullptr)); const unsigned char buffer1[] = {0x16, 0xaa}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_PRESENT, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer1, VELOX_ARRAY_SIZE(buffer1)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer1, std::size(buffer1)))); // even row points to dictionary. char buffer2[1024]; @@ -1108,8 +1108,8 @@ TEST_P(TestColumnReader, testIntDictSkipWithNulls) { const unsigned char buffer3[] = {0x0a, 0xaa}; EXPECT_CALL( streams_, getStreamProxy(1, proto::Stream_Kind_IN_DICTIONARY, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer3, VELOX_ARRAY_SIZE(buffer3)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer3, std::size(buffer3)))); EXPECT_CALL(streams_, genMockDictDataSetter(1, 0)) .WillRepeatedly(Return([&](BufferPtr& buffer, MemoryPool* pool) { @@ -1249,21 +1249,21 @@ TEST_P(StringReaderTests, testDictionaryWithNulls) { .WillRepeatedly(Return(nullptr)); const unsigned char buffer1[] = {0x19, 0xf0}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_PRESENT, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer1, VELOX_ARRAY_SIZE(buffer1)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer1, std::size(buffer1)))); const unsigned char buffer2[] = {0x2f, 0x00, 0x00, 0x2f, 0x00, 0x01}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_DATA, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer2, VELOX_ARRAY_SIZE(buffer2)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer2, std::size(buffer2)))); const unsigned char buffer3[] = {0x4f, 0x52, 0x43, 0x4f, 0x77, 0x65, 0x6e}; EXPECT_CALL( streams_, getStreamProxy(1, proto::Stream_Kind_DICTIONARY_DATA, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer3, VELOX_ARRAY_SIZE(buffer3)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer3, std::size(buffer3)))); const unsigned char buffer4[] = {0x02, 0x01, 0x03}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_LENGTH, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer4, VELOX_ARRAY_SIZE(buffer4)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer4, std::size(buffer4)))); TestStrideIndexProvider provider(10000); EXPECT_CALL(streams_, getStrideIndexProviderProxy()) @@ -1395,8 +1395,8 @@ TEST_P(StringReaderTests, testStringDictSkipNoNulls) { const unsigned char inDict[] = {0x0a, 0xaa}; EXPECT_CALL( streams_, getStreamProxy(1, proto::Stream_Kind_IN_DICTIONARY, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(inDict, VELOX_ARRAY_SIZE(inDict)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(inDict, std::size(inDict)))); auto indexData = index.SerializePartialAsString(); EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_ROW_INDEX, _)) @@ -1484,8 +1484,8 @@ TEST_P(StringReaderTests, testStringDictSkipWithNulls) { .WillRepeatedly(Return(nullptr)); const unsigned char present[] = {0x16, 0xaa}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_PRESENT, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(present, VELOX_ARRAY_SIZE(present)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(present, std::size(present)))); char data[1024]; data[0] = 0x9c; @@ -1565,8 +1565,8 @@ TEST_P(StringReaderTests, testStringDictSkipWithNulls) { const unsigned char inDict[] = {0x0a, 0xaa}; EXPECT_CALL( streams_, getStreamProxy(1, proto::Stream_Kind_IN_DICTIONARY, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(inDict, VELOX_ARRAY_SIZE(inDict)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(inDict, std::size(inDict)))); auto indexData = index.SerializePartialAsString(); EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_ROW_INDEX, _)) @@ -1637,18 +1637,18 @@ TEST_P(TestNonSelectiveColumnReader, testSubstructsWithNulls) { const unsigned char buffer1[] = {0x16, 0x0f}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_PRESENT, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer1, VELOX_ARRAY_SIZE(buffer1)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer1, std::size(buffer1)))); const unsigned char buffer2[] = {0x0a, 0x55}; EXPECT_CALL(streams_, getStreamProxy(2, proto::Stream_Kind_PRESENT, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer2, VELOX_ARRAY_SIZE(buffer2)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer2, std::size(buffer2)))); const unsigned char buffer3[] = {0x04, 0xf0}; EXPECT_CALL(streams_, getStreamProxy(3, proto::Stream_Kind_PRESENT, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer3, VELOX_ARRAY_SIZE(buffer3)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer3, std::size(buffer3)))); char buffer4[256]; size_t size = writeRange(buffer4, 0, 26); @@ -1723,13 +1723,13 @@ TEST_P(TestColumnReader, testSkipWithNulls) { const unsigned char buffer1[] = { 0x03, 0x00, 0xff, 0x3f, 0x08, 0xff, 0xff, 0xfc, 0x03, 0x00}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_PRESENT, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer1, VELOX_ARRAY_SIZE(buffer1)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer1, std::size(buffer1)))); EXPECT_CALL(streams_, getStreamProxy(2, _, _)) .WillRepeatedly(Return(nullptr)); EXPECT_CALL(streams_, getStreamProxy(2, proto::Stream_Kind_PRESENT, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer1, VELOX_ARRAY_SIZE(buffer1)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer1, std::size(buffer1)))); TestStrideIndexProvider provider(10000); EXPECT_CALL(streams_, getStrideIndexProviderProxy()) @@ -1741,8 +1741,8 @@ TEST_P(TestColumnReader, testSkipWithNulls) { .WillRepeatedly(Return(new SeekableArrayInputStream(buffer2, size))); const unsigned char buffer3[] = {0x61, 0x01, 0x00}; EXPECT_CALL(streams_, getStreamProxy(2, proto::Stream_Kind_DATA, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer3, VELOX_ARRAY_SIZE(buffer3)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer3, std::size(buffer3)))); // fill the dictionary with '00' to '99' char digits[200]; @@ -1754,12 +1754,12 @@ TEST_P(TestColumnReader, testSkipWithNulls) { } EXPECT_CALL( streams_, getStreamProxy(2, proto::Stream_Kind_DICTIONARY_DATA, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(digits, VELOX_ARRAY_SIZE(digits)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(digits, std::size(digits)))); const unsigned char buffer4[] = {0x61, 0x00, 0x02}; EXPECT_CALL(streams_, getStreamProxy(2, proto::Stream_Kind_LENGTH, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer4, VELOX_ARRAY_SIZE(buffer4)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer4, std::size(buffer4)))); // create the row type auto rowType = HiveTypeParser().parse("struct"); @@ -1838,12 +1838,12 @@ TEST_P(StringReaderTests, testBinaryDirect) { } EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_DATA, true)) .WillRepeatedly( - Return(new SeekableArrayInputStream(blob, VELOX_ARRAY_SIZE(blob)))); + Return(new SeekableArrayInputStream(blob, std::size(blob)))); const unsigned char buffer[] = {0x61, 0x00, 0x02}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_LENGTH, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer, VELOX_ARRAY_SIZE(buffer)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer, std::size(buffer)))); // create the row type auto rowType = HiveTypeParser().parse("struct"); @@ -1892,8 +1892,8 @@ TEST_P(StringReaderTests, testBinaryDirectWithNulls) { const unsigned char buffer1[] = {0x1d, 0xf0}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_PRESENT, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer1, VELOX_ARRAY_SIZE(buffer1)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer1, std::size(buffer1)))); char blob[256]; for (size_t i = 0; i < 8; ++i) { @@ -1904,12 +1904,12 @@ TEST_P(StringReaderTests, testBinaryDirectWithNulls) { } EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_DATA, true)) .WillRepeatedly( - Return(new SeekableArrayInputStream(blob, VELOX_ARRAY_SIZE(blob)))); + Return(new SeekableArrayInputStream(blob, std::size(blob)))); const unsigned char buffer2[] = {0x7d, 0x00, 0x02}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_LENGTH, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer2, VELOX_ARRAY_SIZE(buffer2)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer2, std::size(buffer2)))); // create the row type auto rowType = HiveTypeParser().parse("struct"); @@ -1969,12 +1969,12 @@ TEST_P(TestColumnReader, testShortBlobError) { char blob[100]; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_DATA, true)) .WillRepeatedly( - Return(new SeekableArrayInputStream(blob, VELOX_ARRAY_SIZE(blob)))); + Return(new SeekableArrayInputStream(blob, std::size(blob)))); const unsigned char buffer1[] = {0x61, 0x00, 0x02}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_LENGTH, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer1, VELOX_ARRAY_SIZE(buffer1)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer1, std::size(buffer1)))); // create the row type auto rowType = HiveTypeParser().parse("struct"); @@ -2020,13 +2020,13 @@ TEST_P(StringReaderTests, testStringDirectShortBuffer) { } } EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_DATA, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(blob, VELOX_ARRAY_SIZE(blob), 3))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(blob, std::size(blob), 3))); const unsigned char buffer1[] = {0x61, 0x00, 0x02}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_LENGTH, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer1, VELOX_ARRAY_SIZE(buffer1)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer1, std::size(buffer1)))); // create the row type auto rowType = HiveTypeParser().parse("struct"); @@ -2075,8 +2075,8 @@ TEST_P(StringReaderTests, testStringDirectShortBufferWithNulls) { const unsigned char buffer1[] = {0x3d, 0xf0}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_PRESENT, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer1, VELOX_ARRAY_SIZE(buffer1)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer1, std::size(buffer1)))); char blob[512]; for (size_t i = 0; i < 16; ++i) { @@ -2086,13 +2086,13 @@ TEST_P(StringReaderTests, testStringDirectShortBufferWithNulls) { } } EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_DATA, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(blob, VELOX_ARRAY_SIZE(blob), 30))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(blob, std::size(blob), 30))); const unsigned char buffer2[] = {0x7d, 0x00, 0x02, 0x7d, 0x00, 0x02}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_LENGTH, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer2, VELOX_ARRAY_SIZE(buffer2)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer2, std::size(buffer2)))); // create the row type auto rowType = HiveTypeParser().parse("struct"); @@ -2154,19 +2154,19 @@ TEST_P(StringReaderTests, testStringDirectNullAcrossWindow) { const unsigned char isNull[2] = {0xff, 0x7f}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_PRESENT, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(isNull, VELOX_ARRAY_SIZE(isNull)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(isNull, std::size(isNull)))); const char blob[] = "abcdefg"; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_DATA, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(blob, VELOX_ARRAY_SIZE(blob), 4))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(blob, std::size(blob), 4))); // [1] * 7 const unsigned char lenData[] = {0x04, 0x00, 0x01}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_LENGTH, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(lenData, VELOX_ARRAY_SIZE(lenData)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(lenData, std::size(lenData)))); // create the row type auto rowType = HiveTypeParser().parse("struct"); @@ -2250,8 +2250,8 @@ TEST_P(StringReaderTests, testStringDirectSkip) { 0x01, 0x8a, 0x05, 0x7f, 0x01, 0x8c, 0x06, 0x7f, 0x01, 0x8e, 0x07, 0x7f, 0x01, 0x90, 0x08, 0x1b, 0x01, 0x92, 0x09}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_LENGTH, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer1, VELOX_ARRAY_SIZE(buffer1)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer1, std::size(buffer1)))); // create the row type auto rowType = HiveTypeParser().parse("struct"); @@ -2298,8 +2298,8 @@ TEST_P(StringReaderTests, testStringDirectSkipWithNulls) { // alternate 4 non-null and 4 null via [0xf0 for x in range(2400 / 8)] const unsigned char buffer1[] = {0x7f, 0xf0, 0x7f, 0xf0, 0x25, 0xf0}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_PRESENT, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer1, VELOX_ARRAY_SIZE(buffer1)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer1, std::size(buffer1)))); // sum(range(1200)) const size_t BLOB_SIZE = 719400; @@ -2323,8 +2323,8 @@ TEST_P(StringReaderTests, testStringDirectSkipWithNulls) { 0x01, 0x8a, 0x05, 0x7f, 0x01, 0x8c, 0x06, 0x7f, 0x01, 0x8e, 0x07, 0x7f, 0x01, 0x90, 0x08, 0x1b, 0x01, 0x92, 0x09}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_LENGTH, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer2, VELOX_ARRAY_SIZE(buffer2)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer2, std::size(buffer2)))); // create the row type auto rowType = HiveTypeParser().parse("struct"); @@ -2386,8 +2386,8 @@ TEST_P(TestColumnReader, testList) { 0x00, 0x02}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_LENGTH, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer1, VELOX_ARRAY_SIZE(buffer1)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer1, std::size(buffer1)))); // range(1200) char buffer2[8192]; @@ -2435,8 +2435,8 @@ TEST_P(TestNonSelectiveColumnReader, testListPropagateNulls) { // set getStream const unsigned char buffer[] = {0xff, 0x00}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_PRESENT, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer, VELOX_ARRAY_SIZE(buffer)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer, std::size(buffer)))); EXPECT_CALL(streams_, getStreamProxy(2, proto::Stream_Kind_LENGTH, true)) .WillRepeatedly(Return(new SeekableArrayInputStream(buffer, 0))); @@ -2477,8 +2477,8 @@ TEST_P(TestNonSelectiveColumnReader, testListWithNulls) { // [0xaa for x in range(2048/8)] const unsigned char buffer1[] = {0x7f, 0xaa, 0x7b, 0xaa}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_PRESENT, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer1, VELOX_ARRAY_SIZE(buffer1)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer1, std::size(buffer1)))); EXPECT_CALL(streams_, getStreamProxy(2, proto::Stream_Kind_PRESENT, false)) .WillRepeatedly(Return(nullptr)); @@ -2493,8 +2493,8 @@ TEST_P(TestNonSelectiveColumnReader, testListWithNulls) { 0x00, 0x7f, 0x00, 0x00, 0x7f, 0x00, 0x03, 0x6e, 0x00, 0x03, 0xff, 0x13}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_LENGTH, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer2, VELOX_ARRAY_SIZE(buffer2)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer2, std::size(buffer2)))); // range(2048) char buffer3[8192]; @@ -2636,8 +2636,8 @@ TEST_P(TestNonSelectiveColumnReader, testListSkipWithNulls) { // [0xaa for x in range(2048/8)] const unsigned char buffer1[] = {0x7f, 0xaa, 0x7b, 0xaa}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_PRESENT, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer1, VELOX_ARRAY_SIZE(buffer1)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer1, std::size(buffer1)))); EXPECT_CALL(streams_, getStreamProxy(2, proto::Stream_Kind_PRESENT, false)) .WillRepeatedly(Return(nullptr)); @@ -2652,8 +2652,8 @@ TEST_P(TestNonSelectiveColumnReader, testListSkipWithNulls) { 0x00, 0x7f, 0x00, 0x00, 0x7f, 0x00, 0x03, 0x6e, 0x00, 0x03, 0xff, 0x13}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_LENGTH, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer2, VELOX_ARRAY_SIZE(buffer2)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer2, std::size(buffer2)))); // range(2048) char buffer3[8192]; @@ -2737,8 +2737,8 @@ TEST_P(TestNonSelectiveColumnReader, testListSkipWithNullsNoData) { // [0xaa for x in range(2048/8)] const unsigned char buffer1[] = {0x7f, 0xaa, 0x7b, 0xaa}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_PRESENT, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer1, VELOX_ARRAY_SIZE(buffer1)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer1, std::size(buffer1)))); EXPECT_CALL(streams_, getStreamProxy(2, proto::Stream_Kind_PRESENT, false)) .WillRepeatedly(Return(nullptr)); @@ -2753,8 +2753,8 @@ TEST_P(TestNonSelectiveColumnReader, testListSkipWithNullsNoData) { 0x00, 0x7f, 0x00, 0x00, 0x7f, 0x00, 0x03, 0x6e, 0x00, 0x03, 0xff, 0x13}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_LENGTH, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer2, VELOX_ARRAY_SIZE(buffer2)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer2, std::size(buffer2)))); EXPECT_CALL(streams_, getStreamProxy(2, proto::Stream_Kind_DATA, true)) .WillRepeatedly(Return(nullptr)); @@ -2815,8 +2815,8 @@ TEST_P(TestNonSelectiveColumnReader, testListWithAllNulls) { // set getStream const unsigned char buffer[] = {0xff, 0x00}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_PRESENT, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer, VELOX_ARRAY_SIZE(buffer)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer, std::size(buffer)))); EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_LENGTH, true)) .WillRepeatedly(Return(new SeekableArrayInputStream(buffer, 0))); @@ -2869,8 +2869,8 @@ TEST_P(TestColumnReader, testMap) { 0x00, 0x02}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_LENGTH, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer1, VELOX_ARRAY_SIZE(buffer1)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer1, std::size(buffer1)))); // range(1200) char buffer2[8192]; @@ -2926,8 +2926,8 @@ TEST_P(TestNonSelectiveColumnReader, testMapWithNulls) { // [0xaa for x in range(2048/8)] const unsigned char buffer1[] = {0x7f, 0xaa, 0x7b, 0xaa}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_PRESENT, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer1, VELOX_ARRAY_SIZE(buffer1)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer1, std::size(buffer1)))); EXPECT_CALL(streams_, getStreamProxy(2, proto::Stream_Kind_PRESENT, false)) .WillRepeatedly(Return(nullptr)); @@ -2935,8 +2935,8 @@ TEST_P(TestNonSelectiveColumnReader, testMapWithNulls) { // [0x55 for x in range(2048/8)] const unsigned char buffer2[] = {0x7f, 0x55, 0x7b, 0x55}; EXPECT_CALL(streams_, getStreamProxy(3, proto::Stream_Kind_PRESENT, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer2, VELOX_ARRAY_SIZE(buffer2)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer2, std::size(buffer2)))); // [1 for x in range(260)] + // [4 for x in range(260)] + @@ -2948,8 +2948,8 @@ TEST_P(TestNonSelectiveColumnReader, testMapWithNulls) { 0x00, 0x7f, 0x00, 0x00, 0x7f, 0x00, 0x03, 0x6e, 0x00, 0x03, 0xff, 0x13}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_LENGTH, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer3, VELOX_ARRAY_SIZE(buffer3)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer3, std::size(buffer3)))); // range(2048) char buffer4[8192]; @@ -3131,8 +3131,8 @@ TEST_P(TestNonSelectiveColumnReader, testMapSkipWithNulls) { // [0xaa for x in range(2048/8)] const unsigned char buffer1[] = {0x7f, 0xaa, 0x7b, 0xaa}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_PRESENT, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer1, VELOX_ARRAY_SIZE(buffer1)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer1, std::size(buffer1)))); // [1 for x in range(260)] + // [4 for x in range(260)] + @@ -3144,8 +3144,8 @@ TEST_P(TestNonSelectiveColumnReader, testMapSkipWithNulls) { 0x00, 0x7f, 0x00, 0x00, 0x7f, 0x00, 0x03, 0x6e, 0x00, 0x03, 0xff, 0x13}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_LENGTH, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer2, VELOX_ARRAY_SIZE(buffer2)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer2, std::size(buffer2)))); // range(2048) char buffer3[8192]; @@ -3254,8 +3254,8 @@ TEST_P(TestNonSelectiveColumnReader, testMapSkipWithNullsNoData) { // [0xaa for x in range(2048/8)] const unsigned char buffer1[] = {0x7f, 0xaa, 0x7b, 0xaa}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_PRESENT, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer1, VELOX_ARRAY_SIZE(buffer1)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer1, std::size(buffer1)))); // [1 for x in range(260)] + // [4 for x in range(260)] + @@ -3267,8 +3267,8 @@ TEST_P(TestNonSelectiveColumnReader, testMapSkipWithNullsNoData) { 0x00, 0x7f, 0x00, 0x00, 0x7f, 0x00, 0x03, 0x6e, 0x00, 0x03, 0xff, 0x13}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_LENGTH, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer2, VELOX_ARRAY_SIZE(buffer2)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer2, std::size(buffer2)))); // create the row type auto rowType = HiveTypeParser().parse("struct>"); @@ -3322,8 +3322,8 @@ TEST_P(TestNonSelectiveColumnReader, testMapWithAllNulls) { const unsigned char buffer1[] = {0xff, 0x00}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_PRESENT, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer1, VELOX_ARRAY_SIZE(buffer1)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer1, std::size(buffer1)))); EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_LENGTH, true)) .WillRepeatedly(Return(new SeekableArrayInputStream(buffer1, 0))); @@ -3365,9 +3365,7 @@ TEST_P(TestColumnReader, testFloatBatchNotAligned) { EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_DATA, true)) .WillRepeatedly(Return(new SeekableArrayInputStream( - byteValues, - VELOX_ARRAY_SIZE(byteValues), - VELOX_ARRAY_SIZE(byteValues) / 2))); + byteValues, std::size(byteValues), std::size(byteValues) / 2))); // create the row type auto rowType = HiveTypeParser().parse("struct"); @@ -3403,8 +3401,8 @@ TEST_P(TestColumnReader, testFloatWithNulls) { // 13 non-nulls followed by 19 nulls const unsigned char buffer1[] = {0xfc, 0xff, 0xf8, 0x0, 0x0}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_PRESENT, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer1, VELOX_ARRAY_SIZE(buffer1)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer1, std::size(buffer1)))); const float test_vals[] = { 1.0f, @@ -3427,8 +3425,8 @@ TEST_P(TestColumnReader, testFloatWithNulls) { 0x0, 0x80, 0xff, 0xff, 0xff, 0x7f, 0x7f, 0xff, 0xff, 0x7f, 0xff, 0x1, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x80}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_DATA, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer2, VELOX_ARRAY_SIZE(buffer2)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer2, std::size(buffer2)))); // create the row type auto rowType = HiveTypeParser().parse("struct"); @@ -3472,8 +3470,8 @@ TEST_P(TestColumnReader, testFloatSkipWithNulls) { // 2 non-nulls, 2 nulls, 2 non-nulls, 2 nulls const unsigned char buffer1[] = {0xff, 0xcc}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_PRESENT, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer1, VELOX_ARRAY_SIZE(buffer1)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer1, std::size(buffer1)))); // 1, 2.5, -100.125, 10000 const unsigned char buffer2[] = { @@ -3494,8 +3492,8 @@ TEST_P(TestColumnReader, testFloatSkipWithNulls) { 0x1c, 0x46}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_DATA, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer2, VELOX_ARRAY_SIZE(buffer2)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer2, std::size(buffer2)))); // create the row type auto rowType = HiveTypeParser().parse("struct"); @@ -3563,8 +3561,8 @@ TEST_P(TestColumnReader, testDoubleWithNulls) { // 13 non-nulls followed by 19 nulls const unsigned char buffer1[] = {0xfc, 0xff, 0xf8, 0x0, 0x0}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_PRESENT, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer1, VELOX_ARRAY_SIZE(buffer1)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer1, std::size(buffer1)))); const double test_vals[] = { 1.0, @@ -3591,8 +3589,8 @@ TEST_P(TestColumnReader, testDoubleWithNulls) { 0xff, 0xff, 0xef, 0xff, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x80}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_DATA, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer2, VELOX_ARRAY_SIZE(buffer2)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer2, std::size(buffer2)))); // create the row type auto rowType = HiveTypeParser().parse("struct"); @@ -3637,16 +3635,16 @@ TEST_P(TestColumnReader, testDoubleSkipWithNulls) { // 1 non-null, 5 nulls, 2 non-nulls const unsigned char buffer1[] = {0xff, 0x83}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_PRESENT, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer1, VELOX_ARRAY_SIZE(buffer1)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer1, std::size(buffer1)))); // 1, 2, -2 const unsigned char buffer2[] = { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xf0, 0x3f, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xc0}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_DATA, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer2, VELOX_ARRAY_SIZE(buffer2)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer2, std::size(buffer2)))); // create the row type auto rowType = HiveTypeParser().parse("struct"); @@ -3713,8 +3711,8 @@ TEST_P(TestColumnReader, testTimestampSkipWithNulls) { // 2 non-nulls, 2 nulls, 2 non-nulls, 2 nulls const unsigned char buffer1[] = {0xff, 0xcc}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_PRESENT, false)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer1, VELOX_ARRAY_SIZE(buffer1)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer1, std::size(buffer1)))); const unsigned char buffer2[] = { 0xfc, @@ -3735,13 +3733,13 @@ TEST_P(TestColumnReader, testTimestampSkipWithNulls) { 0xd4, 0x30}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_DATA, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer2, VELOX_ARRAY_SIZE(buffer2)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer2, std::size(buffer2)))); const unsigned char buffer3[] = {0x1, 0x8, 0x5e}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_NANO_DATA, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer3, VELOX_ARRAY_SIZE(buffer3)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer3, std::size(buffer3)))); // create the row type auto rowType = HiveTypeParser().parse("struct"); @@ -3831,15 +3829,15 @@ TEST_P(TestColumnReader, testTimestamp) { 0xba, 0xa0, 0x1a, 0x9d, 0x88, 0xa6, 0x82, 0x1a, 0x9d, 0xba, 0x9c, 0xe4, 0x19, 0x9d, 0xee, 0xe1, 0xcd, 0x18}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_DATA, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer1, VELOX_ARRAY_SIZE(buffer1)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer1, std::size(buffer1)))); const unsigned char buffer2[] = { 0xf6, 0x00, 0xa8, 0xd1, 0xf9, 0xd6, 0x03, 0x00, 0x9e, 0x01, 0xec, 0x76, 0xf4, 0x76, 0xfc, 0x76, 0x84, 0x77, 0x8c, 0x77, 0xfd, 0x0b}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_NANO_DATA, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer2, VELOX_ARRAY_SIZE(buffer2)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer2, std::size(buffer2)))); // create the row type auto rowType = HiveTypeParser().parse("struct"); @@ -3909,13 +3907,13 @@ TEST_P(TestColumnReader, testDecimal64) { } } EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_DATA, true)) - .WillRepeatedly(Return(new SeekableArrayInputStream( - numBuffer, VELOX_ARRAY_SIZE(numBuffer), 3))); + .WillRepeatedly(Return( + new SeekableArrayInputStream(numBuffer, std::size(numBuffer), 3))); // col_0's Secondary Stream const unsigned char buffer2[] = {0x3e, 0x00, 0x04}; // [0x02] * 65 EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_NANO_DATA, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer2, VELOX_ARRAY_SIZE(buffer2)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer2, std::size(buffer2)))); // create the row type auto rowType = HiveTypeParser().parse("struct"); @@ -3948,19 +3946,19 @@ TEST_P(TestColumnReader, testDecimal64WithSkip) { const unsigned char presentBuffer[] = {0xfe, 0xff, 0x80}; // [0xff] EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_PRESENT, false)) .WillRepeatedly(Return(new SeekableArrayInputStream( - presentBuffer, VELOX_ARRAY_SIZE(presentBuffer)))); + presentBuffer, std::size(presentBuffer)))); const unsigned char numBuffer[] = { 0xf8, 0xe8, 0xe2, 0xcf, 0xf4, 0xcb, 0xb6, 0xda, 0x0d, 0x86, 0xc1, 0xcc, 0xcd, 0x9e, 0xd5, 0xc5, 0x11, 0xb4, 0xf6, 0xfc, 0xf3, 0xb9, 0xba, 0x16, 0xca, 0xe7, 0xa3, 0xa6, 0xdf, 0x1c, 0xea, 0xad, 0xc0, 0xe5, 0x24, 0xf8, 0x94, 0x8c, 0x2f, 0x86, 0xa4, 0x3c, 0x94, 0x4d, 0x62}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_DATA, true)) - .WillRepeatedly(Return(new SeekableArrayInputStream( - numBuffer, VELOX_ARRAY_SIZE(numBuffer)))); + .WillRepeatedly(Return( + new SeekableArrayInputStream(numBuffer, std::size(numBuffer)))); const unsigned char buffer1[] = {0x06, 0x00, 0x14}; // [0x0a] * 9 EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_NANO_DATA, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer1, VELOX_ARRAY_SIZE(buffer1)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer1, std::size(buffer1)))); // create the row type auto rowType = HiveTypeParser().parse("struct"); @@ -4004,7 +4002,7 @@ TEST_P(TestColumnReader, testDecimal128WithSkip) { const unsigned char presentBuffer[] = {0xfe, 0xff, 0xf8}; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_PRESENT, false)) .WillRepeatedly(Return(new SeekableArrayInputStream( - presentBuffer, VELOX_ARRAY_SIZE(presentBuffer)))); + presentBuffer, std::size(presentBuffer)))); const unsigned char numBuffer[] = { 0xf8, 0xe8, 0xe2, 0xcf, 0xf4, 0xcb, 0xb6, 0xda, 0x0d, 0x86, 0xc1, 0xcc, 0xcd, 0x9e, 0xd5, 0xc5, 0x11, 0xb4, 0xf6, 0xfc, 0xf3, 0xb9, 0xba, 0x16, @@ -4018,12 +4016,12 @@ TEST_P(TestColumnReader, testDecimal128WithSkip) { 0x93, 0xe8, 0xa3, 0xec, 0xd0, 0x96, 0xd4, 0xcc, 0xf6, 0xac, 0x02, }; EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_DATA, true)) - .WillRepeatedly(Return(new SeekableArrayInputStream( - numBuffer, VELOX_ARRAY_SIZE(numBuffer)))); + .WillRepeatedly(Return( + new SeekableArrayInputStream(numBuffer, std::size(numBuffer)))); const unsigned char buffer1[] = {0x0a, 0x00, 0x4a}; // [0x02] * 13 EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_NANO_DATA, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(buffer1, VELOX_ARRAY_SIZE(buffer1)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(buffer1, std::size(buffer1)))); // create the row type auto rowType = HiveTypeParser().parse("struct"); @@ -4087,8 +4085,8 @@ TEST_P(TestColumnReader, testLargeSkip) { length[pos + 2] = 0x01; } EXPECT_CALL(streams_, getStreamProxy(1, proto::Stream_Kind_LENGTH, true)) - .WillRepeatedly(Return( - new SeekableArrayInputStream(length, VELOX_ARRAY_SIZE(length)))); + .WillRepeatedly( + Return(new SeekableArrayInputStream(length, std::size(length)))); char data[1024 * 1024]; size_t size = writeRange(data, 0, 73200); diff --git a/velox/dwio/dwrf/test/TestDecompression.cpp b/velox/dwio/dwrf/test/TestDecompression.cpp index 3dcf5f68e46b..7a4d6b283773 100644 --- a/velox/dwio/dwrf/test/TestDecompression.cpp +++ b/velox/dwio/dwrf/test/TestDecompression.cpp @@ -336,7 +336,7 @@ TEST_F(DecompressionTest, testLzoSmall) { std::unique_ptr result = createTestDecompressor( CompressionKind_LZO, std::unique_ptr( - new SeekableArrayInputStream(buffer, VELOX_ARRAY_SIZE(buffer))), + new SeekableArrayInputStream(buffer, std::size(buffer))), 128 * 1024); const void* ptr; int32_t length; @@ -353,7 +353,7 @@ TEST_F(DecompressionTest, testLzoSmall) { TEST_F(DecompressionTest, testLzoLong) { // set up a framed lzo buffer with 100,000 'a' unsigned char buffer[482]; - bzero(buffer, VELOX_ARRAY_SIZE(buffer)); + bzero(buffer, std::size(buffer)); // header buffer[0] = 190; buffer[1] = 3; @@ -378,7 +378,7 @@ TEST_F(DecompressionTest, testLzoLong) { std::unique_ptr result = createTestDecompressor( CompressionKind_LZO, std::unique_ptr( - new SeekableArrayInputStream(buffer, VELOX_ARRAY_SIZE(buffer))), + new SeekableArrayInputStream(buffer, std::size(buffer))), 128 * 1024); const void* ptr; int32_t length; @@ -413,7 +413,7 @@ TEST_F(DecompressionTest, testLz4Small) { std::unique_ptr result = createTestDecompressor( CompressionKind_LZ4, std::unique_ptr( - new SeekableArrayInputStream(buffer, VELOX_ARRAY_SIZE(buffer))), + new SeekableArrayInputStream(buffer, std::size(buffer))), 128 * 1024); const void* ptr; int32_t length; @@ -430,7 +430,7 @@ TEST_F(DecompressionTest, testLz4Small) { TEST_F(DecompressionTest, testLz4Long) { // set up a framed lzo buffer with 100,000 'a' unsigned char buffer[406]; - memset(buffer, 255, VELOX_ARRAY_SIZE(buffer)); + memset(buffer, 255, std::size(buffer)); // header buffer[0] = 38; buffer[1] = 3; @@ -448,7 +448,7 @@ TEST_F(DecompressionTest, testLz4Long) { std::unique_ptr result = createTestDecompressor( CompressionKind_LZ4, std::unique_ptr( - new SeekableArrayInputStream(buffer, VELOX_ARRAY_SIZE(buffer))), + new SeekableArrayInputStream(buffer, std::size(buffer))), 128 * 1024); const void* ptr; int32_t length; @@ -465,7 +465,7 @@ TEST_F(DecompressionTest, testCreateZlib) { std::unique_ptr result = createTestDecompressor( CompressionKind_ZLIB, std::unique_ptr( - new SeekableArrayInputStream(buffer, VELOX_ARRAY_SIZE(buffer))), + new SeekableArrayInputStream(buffer, std::size(buffer))), 32768); EXPECT_EQ( "PagedInputStream StreamInfo (Test Decompression) input stream (SeekableArrayInputStream 0 of 8) State (0) remaining length (0)", @@ -497,7 +497,7 @@ TEST_F(DecompressionTest, testLiteralBlocks) { std::unique_ptr result = createTestDecompressor( CompressionKind_ZLIB, std::unique_ptr( - new SeekableArrayInputStream(buffer, VELOX_ARRAY_SIZE(buffer), 5)), + new SeekableArrayInputStream(buffer, std::size(buffer), 5)), 5); EXPECT_EQ( "PagedInputStream StreamInfo (Test Decompression) input stream (SeekableArrayInputStream 0 of 23) State (0) remaining length (0)", @@ -539,7 +539,7 @@ TEST_F(DecompressionTest, testInflate) { std::unique_ptr result = createTestDecompressor( CompressionKind_ZLIB, std::unique_ptr( - new SeekableArrayInputStream(buffer, VELOX_ARRAY_SIZE(buffer))), + new SeekableArrayInputStream(buffer, std::size(buffer))), 1000); const void* ptr; int32_t length; @@ -552,6 +552,25 @@ TEST_F(DecompressionTest, testInflate) { } } +TEST_F(DecompressionTest, testSmallBufferInflate) { + const unsigned char buffer[] = { + 0xe, 0x0, 0x0, 0x63, 0x60, 0x64, 0x62, 0xc0, 0x8d, 0x0}; + const std::unique_ptr result = createTestDecompressor( + CompressionKind_ZLIB, + std::make_unique(buffer, std::size(buffer)), + 1 // blockSize 1 to test multiple inflate calls during decompression. + ); + const void* ptr; + int32_t length; + ASSERT_EQ(true, result->Next(&ptr, &length)); + ASSERT_EQ(30, length); + for (int32_t i = 0; i < 10; ++i) { + for (int32_t j = 0; j < 3; ++j) { + EXPECT_EQ(j, static_cast(ptr)[i * 3 + j]); + } + } +} + TEST_F(DecompressionTest, testInflateSequence) { const unsigned char buffer[] = {0xe, 0x0, 0x0, 0x63, 0x60, 0x64, 0x62, 0xc0, 0x8d, 0x0, 0xe, 0x0, 0x0, 0x63, @@ -559,7 +578,7 @@ TEST_F(DecompressionTest, testInflateSequence) { std::unique_ptr result = createTestDecompressor( CompressionKind_ZLIB, std::unique_ptr( - new SeekableArrayInputStream(buffer, VELOX_ARRAY_SIZE(buffer), 3)), + new SeekableArrayInputStream(buffer, std::size(buffer), 3)), 1000); const void* ptr; int32_t length; @@ -594,7 +613,7 @@ TEST_F(DecompressionTest, testSkipZlib) { std::unique_ptr result = createTestDecompressor( CompressionKind_ZLIB, std::unique_ptr( - new SeekableArrayInputStream(buffer, VELOX_ARRAY_SIZE(buffer), 5)), + new SeekableArrayInputStream(buffer, std::size(buffer), 5)), 5); const void* ptr; int32_t length; @@ -906,7 +925,7 @@ class TestSeek : public ::testing::Test { size_t arr[][2]{{0, 0}, {0, seekPos}, {offset1, seekPos}, {offset1, 0}}; char* input[]{input1, input1, input2, input2}; - for (size_t i = 0; i < VELOX_ARRAY_SIZE(arr); ++i) { + for (size_t i = 0; i < std::size(arr); ++i) { auto pos = arr[i]; std::vector list{pos[0], pos[1]}; PositionProvider pp(list); @@ -1046,8 +1065,8 @@ class TestingSeekableInputStream : public SeekableInputStream { return true; } - google::protobuf::int64 ByteCount() const override { - return position_; + int64_t ByteCount() const override { + return static_cast(position_); } void seekToPosition(PositionProvider& position) override { @@ -1097,7 +1116,7 @@ TEST_F(TestSeek, uncompressedLarge) { entry.getCompressed()[i] = static_cast(i); } written += runSize + kHeaderSize; - data.insert(data.end(), entry.data().begin(), entry.data().end()); + data.insert(data.end(), entry.data().cbegin(), entry.data().cend()); } auto stream = createTestDecompressor( CompressionKind_SNAPPY, diff --git a/velox/dwio/dwrf/test/TestDictionaryEncodingUtils.cpp b/velox/dwio/dwrf/test/TestDictionaryEncodingUtils.cpp index dc42b6941b15..0fb09b43fa43 100644 --- a/velox/dwio/dwrf/test/TestDictionaryEncodingUtils.cpp +++ b/velox/dwio/dwrf/test/TestDictionaryEncodingUtils.cpp @@ -22,7 +22,9 @@ using namespace testing; using namespace facebook::velox::memory; -namespace facebook::velox::dwrf { +namespace facebook::velox::dwrf::test { +namespace { + class DictionaryEncodingUtilsTest : public testing::Test { protected: static void SetUpTestCase() { @@ -36,7 +38,7 @@ TEST_F(DictionaryEncodingUtilsTest, StringGetSortedIndexLookupTable) { bool sort, std::function ordering, - const std::vector& addKeySequence, + const std::vector& addKeySequence, const std::vector& lookupTable) : sort{sort}, ordering{ordering}, @@ -45,7 +47,7 @@ TEST_F(DictionaryEncodingUtilsTest, StringGetSortedIndexLookupTable) { bool sort; std::function ordering; - std::vector addKeySequence; + std::vector addKeySequence; std::vector lookupTable; }; @@ -151,7 +153,7 @@ TEST_F(DictionaryEncodingUtilsTest, StringStrideDictOptimization) { bool sort, std::function ordering, - const std::vector& addKeySequence, + const std::vector& addKeySequence, const std::vector& lookupTable, const std::vector& inDict, size_t finalDictSize, @@ -166,7 +168,7 @@ TEST_F(DictionaryEncodingUtilsTest, StringStrideDictOptimization) { bool sort; std::function ordering; - std::vector addKeySequence; + std::vector addKeySequence; std::vector lookupTable; std::vector inDict; size_t finalDictSize; @@ -315,7 +317,7 @@ TEST_F(DictionaryEncodingUtilsTest, StringStrideDictOptimization) { dwio::common::DataBuffer strideDictSizes{ *pool, rowCount / kStrideSize + 1}; - std::vector expected{testCase.finalDictSize}; + std::vector expected{testCase.finalDictSize}; std::vector expectedSize(testCase.finalDictSize); for (size_t i = 0; i < testCase.lookupTable.size(); ++i) { if (testCase.inDict[i]) { @@ -368,4 +370,5 @@ TEST_F(DictionaryEncodingUtilsTest, StringStrideDictOptimization) { } } -} // namespace facebook::velox::dwrf +} // namespace +} // namespace facebook::velox::dwrf::test diff --git a/velox/dwio/dwrf/test/TestDwrfColumnStatistics.cpp b/velox/dwio/dwrf/test/TestDwrfColumnStatistics.cpp index cf13c6df826b..a71ddc8a4bc4 100644 --- a/velox/dwio/dwrf/test/TestDwrfColumnStatistics.cpp +++ b/velox/dwio/dwrf/test/TestDwrfColumnStatistics.cpp @@ -195,12 +195,12 @@ void checkEntries( for (const auto& entry : entries) { EXPECT_NE( std::find_if( - expectedEntries.begin(), - expectedEntries.end(), + expectedEntries.cbegin(), + expectedEntries.cend(), [&](const ColumnStatistics& expectedStats) { return expectedStats == entry; }), - expectedEntries.end()); + expectedEntries.cend()); } } @@ -490,11 +490,11 @@ TEST(MapStatisticsBuilderTest, mergeKeyStats) { statsBuilder.increaseRawSize(8); mapStatsBuilder.addValues(createKeyInfo(1), statsBuilder); - keyStats = dynamic_cast( + auto& keyStats1 = dynamic_cast( *mapStatsBuilder.getEntryStatistics().at(KeyInfo{1})); - ASSERT_EQ(2, keyStats.getNumberOfValues()); - ASSERT_TRUE(keyStats.getRawSize().has_value()); - ASSERT_EQ(8, keyStats.getRawSize().value()); - EXPECT_TRUE(keyStats.getSize().has_value()); - EXPECT_EQ(42, keyStats.getSize().value()); + ASSERT_EQ(2, keyStats1.getNumberOfValues()); + ASSERT_TRUE(keyStats1.getRawSize().has_value()); + ASSERT_EQ(8, keyStats1.getRawSize().value()); + EXPECT_TRUE(keyStats1.getSize().has_value()); + EXPECT_EQ(42, keyStats1.getSize().value()); } diff --git a/velox/dwio/dwrf/test/TestReadFile.h b/velox/dwio/dwrf/test/TestReadFile.h index 8501b231ed84..e47d83224182 100644 --- a/velox/dwio/dwrf/test/TestReadFile.h +++ b/velox/dwio/dwrf/test/TestReadFile.h @@ -42,15 +42,15 @@ class TestReadFile : public velox::ReadFile { uint64_t offset, uint64_t length, void* buffer, - filesystems::File::IoStats* stats = nullptr) const override { + const FileStorageContext& fileStorageContext = {}) const override { const uint64_t content = offset + seed_; const uint64_t available = std::min(length_ - offset, length); int fill; for (fill = 0; fill < available; ++fill) { reinterpret_cast(buffer)[fill] = content + fill; } - if (stats) { - stats->addCounter( + if (fileStorageContext.ioStats) { + fileStorageContext.ioStats->addCounter( "read", RuntimeCounter(fill, RuntimeCounter::Unit::kBytes)); } return std::string_view(static_cast(buffer), fill); @@ -59,10 +59,10 @@ class TestReadFile : public velox::ReadFile { uint64_t preadv( uint64_t offset, const std::vector>& buffers, - filesystems::File::IoStats* stats = nullptr) const override { - auto res = ReadFile::preadv(offset, buffers, stats); - if (stats) { - stats->addCounter( + const FileStorageContext& fileStorageContext = {}) const override { + auto res = ReadFile::preadv(offset, buffers, fileStorageContext); + if (fileStorageContext.ioStats) { + fileStorageContext.ioStats->addCounter( "read", RuntimeCounter( static_cast(res), RuntimeCounter::Unit::kBytes)); diff --git a/velox/dwio/dwrf/test/TestRle.cpp b/velox/dwio/dwrf/test/TestRle.cpp index b779dedff8e8..d1fe6f2c813d 100644 --- a/velox/dwio/dwrf/test/TestRle.cpp +++ b/velox/dwio/dwrf/test/TestRle.cpp @@ -88,7 +88,7 @@ TEST_F(RLEv2Test, basicDelta0) { } const unsigned char bytes[] = {0xc0, 0x13, 0x00, 0x02}; - unsigned long l = sizeof(bytes) / sizeof(char); + unsigned long l = sizeof(bytes); // Read 1 at a time, then 3 at a time, etc. checkResults(values, decodeRLEv2(bytes, l, 1, count), 1); checkResults(values, decodeRLEv2(bytes, l, 3, count), 3); @@ -106,7 +106,7 @@ TEST_F(RLEv2Test, basicDelta1) { const unsigned char bytes[] = { 0xce, 0x04, 0xe7, 0x07, 0xc8, 0x01, 0x32, 0x19, 0x0f}; - unsigned long l = sizeof(bytes) / sizeof(char); + unsigned long l = sizeof(bytes); // Read 1 at a time, then 3 at a time, etc. checkResults(values, decodeRLEv2(bytes, l, 1, values.size()), 1); checkResults(values, decodeRLEv2(bytes, l, 3, values.size()), 3); @@ -127,7 +127,7 @@ TEST_F(RLEv2Test, basicDelta2) { const unsigned char bytes[] = { 0xce, 0x04, 0xe7, 0x07, 0xc7, 0x01, 0x32, 0x19, 0x23}; - unsigned long l = sizeof(bytes) / sizeof(char); + unsigned long l = sizeof(bytes); // Read 1 at a time, then 3 at a time, etc. checkResults(values, decodeRLEv2(bytes, l, 1, values.size()), 1); checkResults(values, decodeRLEv2(bytes, l, 3, values.size()), 3); @@ -148,7 +148,7 @@ TEST_F(RLEv2Test, basicDelta3) { const unsigned char bytes[] = { 0xce, 0x04, 0xe8, 0x07, 0xc7, 0x01, 0x32, 0x19, 0x0f}; - unsigned long l = sizeof(bytes) / sizeof(char); + unsigned long l = sizeof(bytes); // Read 1 at a time, then 3 at a time, etc. checkResults(values, decodeRLEv2(bytes, l, 1, values.size()), 1); checkResults(values, decodeRLEv2(bytes, l, 3, values.size()), 3); @@ -169,7 +169,7 @@ TEST_F(RLEv2Test, basicDelta4) { const unsigned char bytes[] = { 0xce, 0x04, 0xe8, 0x07, 0xc8, 0x01, 0x32, 0x19, 0x23}; - unsigned long l = sizeof(bytes) / sizeof(char); + unsigned long l = sizeof(bytes); // Read 1 at a time, then 3 at a time, etc. checkResults(values, decodeRLEv2(bytes, l, 1, values.size()), 1); checkResults(values, decodeRLEv2(bytes, l, 3, values.size()), 3); @@ -188,7 +188,7 @@ TEST_F(RLEv2Test, delta0Width) { createRleDecoder( std::unique_ptr( new dwio::common::SeekableArrayInputStream( - buffer, VELOX_ARRAY_SIZE(buffer))), + buffer, std::size(buffer))), RleVersion_2, *pool, true /* doesn't matter */, @@ -219,7 +219,7 @@ TEST_F(RLEv2Test, basicDelta0WithNulls) { } const unsigned char bytes[] = {0xc0, 0x13, 0x00, 0x02}; - unsigned long l = sizeof(bytes) / sizeof(char); + unsigned long l = sizeof(bytes); const size_t count = values.size(); // Read 1 at a time, then 3 at a time, etc. checkResults(values, decodeRLEv2(bytes, l, 1, count, nulls), 1, nulls); @@ -243,7 +243,7 @@ TEST_F(RLEv2Test, shortRepeats) { const unsigned char bytes[] = {0x04, 0x00, 0x04, 0x02, 0x04, 0x04, 0x04, 0x06, 0x04, 0x08, 0x04, 0x0a, 0x04, 0x0c, 0x04, 0x0e, 0x04, 0x10, 0x04, 0x12}; - unsigned long l = sizeof(bytes) / sizeof(char); + unsigned long l = sizeof(bytes); // Read 1 at a time, then 3 at a time, etc. checkResults(values, decodeRLEv2(bytes, l, 1, count), 1); checkResults(values, decodeRLEv2(bytes, l, 3, count), 3); @@ -266,7 +266,7 @@ TEST_F(RLEv2Test, multiByteShortRepeats) { 0x00, 0x00, 0x3c, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x3c, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04}; - unsigned long l = sizeof(bytes) / sizeof(char); + unsigned long l = sizeof(bytes); // Read 1 at a time, then 3 at a time, etc. checkResults(values, decodeRLEv2(bytes, l, 1, count), 1); checkResults(values, decodeRLEv2(bytes, l, 3, count), 3); @@ -280,7 +280,7 @@ TEST_F(RLEv2Test, 0to2Repeat1Direct) { std::unique_ptr> rle = createRleDecoder( std::unique_ptr( new dwio::common::SeekableArrayInputStream( - buffer, VELOX_ARRAY_SIZE(buffer))), + buffer, std::size(buffer))), RleVersion_2, *pool, true /* doesn't matter */, @@ -302,7 +302,7 @@ TEST_F(RLEv2Test, bitSize2Direct) { } const unsigned char bytes[] = {0x42, 0x13, 0x22, 0x22, 0x22, 0x22, 0x22}; - unsigned long l = sizeof(bytes) / sizeof(char); + unsigned long l = sizeof(bytes); // Read 1 at a time, then 3 at a time, etc. checkResults(values, decodeRLEv2(bytes, l, 1, count), 1); checkResults(values, decodeRLEv2(bytes, l, 3, count), 3); @@ -320,7 +320,7 @@ TEST_F(RLEv2Test, bitSize4Direct) { const unsigned char bytes[] = { 0x46, 0x13, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04}; - unsigned long l = sizeof(bytes) / sizeof(char); + unsigned long l = sizeof(bytes); // Read 1 at a time, then 3 at a time, etc. checkResults(values, decodeRLEv2(bytes, l, 1, count), 1); @@ -360,7 +360,7 @@ TEST_F(RLEv2Test, multipleRunsDirect) { 0x04, 0x04, 0x04}; - unsigned long l = sizeof(bytes) / sizeof(char); + unsigned long l = sizeof(bytes); // Read 1 at a time, then 3 at a time, etc. checkResults(values, decodeRLEv2(bytes, l, 1, values.size()), 1); @@ -382,7 +382,7 @@ TEST_F(RLEv2Test, largeNegativesDirect) { std::unique_ptr> rle = createRleDecoder( std::unique_ptr( new dwio::common::SeekableArrayInputStream( - buffer, VELOX_ARRAY_SIZE(buffer))), + buffer, std::size(buffer))), RleVersion_2, *pool, true /* doesn't matter */, @@ -408,7 +408,7 @@ TEST_F(RLEv2Test, overflowDirect) { 0x7e, 0x03, 0x7d, 0x45, 0x3c, 0x12, 0x41, 0x48, 0xf4, 0xbe, 0x7d, 0x45, 0x3c, 0x12, 0x41, 0x48, 0xf4, 0xae, 0x50, 0xce, 0xad, 0x2a, 0x30, 0x0e, 0xd2, 0x96, 0xfe, 0xd8, 0xd2, 0x38, 0x54, 0x6e, 0x3d, 0x81}; - unsigned long l = sizeof(bytes) / sizeof(char); + unsigned long l = sizeof(bytes); // Read 1 at a time, then 3 at a time, etc. checkResults(values, decodeRLEv2(bytes, l, 1, values.size()), 1); checkResults(values, decodeRLEv2(bytes, l, 3, values.size()), 3); @@ -445,7 +445,7 @@ TEST_F(RLEv2Test, basicPatched0) { 0x5a, 0xfc, 0xe8}; - unsigned long l = sizeof(bytes) / sizeof(char); + unsigned long l = sizeof(bytes); // Read 1 at a time, then 3 at a time, etc. checkResults(values, decodeRLEv2(bytes, l, 1, values.size()), 1); checkResults(values, decodeRLEv2(bytes, l, 3, values.size()), 3); @@ -484,7 +484,7 @@ TEST_F(RLEv2Test, basicPatched1) { 0xe0, 0x78, 0x00, 0x1c, 0x0f, 0x08, 0x06, 0x81, 0xc6, 0x90, 0x80, 0x68, 0x24, 0x1b, 0x0b, 0x26, 0x83, 0x21, 0x30, 0xe0, 0x98, 0x3c, 0x6f, 0x06, 0xb7, 0x03, 0x70}; - unsigned long l = sizeof(bytes) / sizeof(char); + unsigned long l = sizeof(bytes); // Read 1 at a time, then 3 at a time, etc. checkResults(values, decodeRLEv2(bytes, l, 1, values.size()), 1); checkResults(values, decodeRLEv2(bytes, l, 3, values.size()), 3); @@ -544,7 +544,7 @@ TEST_F(RLEv2Test, mixedPatchedAndShortRepeats) { 0x00, 0x0c, 0x02, 0x08, 0x18, 0x00, 0x40, 0x00, 0x01, 0x00, 0x00, 0x08, 0x30, 0x33, 0x80, 0x00, 0x02, 0x0c, 0x10, 0x20, 0x20, 0x47, 0x80, 0x13, 0x4c}; - unsigned long l = sizeof(bytes) / sizeof(char); + unsigned long l = sizeof(bytes); // Read 1 at a time, then 3 at a time, etc. checkResults(values, decodeRLEv2(bytes, l, 1, values.size()), 1); checkResults(values, decodeRLEv2(bytes, l, 3, values.size()), 3); @@ -579,7 +579,7 @@ TEST_F(RLEv2Test, basicDirectSeek) { 0x04, 0x04, 0x04}; - unsigned long l = sizeof(bytes) / sizeof(char); + unsigned long l = sizeof(bytes); std::unique_ptr> rle = createRleDecoder( std::unique_ptr( @@ -688,7 +688,7 @@ TEST_F(RLEv1Test, simpleTest) { createRleDecoder( std::unique_ptr( new dwio::common::SeekableArrayInputStream( - buffer, VELOX_ARRAY_SIZE(buffer))), + buffer, std::size(buffer))), RleVersion_1, *pool, true, @@ -712,7 +712,7 @@ TEST_F(RLEv1Test, signedNullLiteralTest) { std::unique_ptr> rle = createRleDecoder( std::unique_ptr( new dwio::common::SeekableArrayInputStream( - buffer, VELOX_ARRAY_SIZE(buffer))), + buffer, std::size(buffer))), RleVersion_1, *pool, true, @@ -733,7 +733,7 @@ TEST_F(RLEv1Test, splitHeader) { createRleDecoder( std::unique_ptr( new dwio::common::SeekableArrayInputStream( - buffer, VELOX_ARRAY_SIZE(buffer), 4)), + buffer, std::size(buffer), 4)), RleVersion_1, *pool, true, @@ -751,8 +751,7 @@ TEST_F(RLEv1Test, splitRuns) { const unsigned char buffer[] = { 0x7d, 0x01, 0xff, 0x01, 0xfb, 0x01, 0x02, 0x03, 0x04, 0x05}; dwio::common::SeekableInputStream* const stream = - new dwio::common::SeekableArrayInputStream( - buffer, VELOX_ARRAY_SIZE(buffer)); + new dwio::common::SeekableArrayInputStream(buffer, std::size(buffer)); std::unique_ptr> rle = createRleDecoder( std::unique_ptr(stream), @@ -784,8 +783,7 @@ TEST_F(RLEv1Test, testSigned) { auto pool = memory::memoryManager()->addLeafPool(); const unsigned char buffer[] = {0x7f, 0xff, 0x20}; dwio::common::SeekableInputStream* const stream = - new dwio::common::SeekableArrayInputStream( - buffer, VELOX_ARRAY_SIZE(buffer)); + new dwio::common::SeekableArrayInputStream(buffer, std::size(buffer)); std::unique_ptr> rle = createRleDecoder( std::unique_ptr(stream), RleVersion_1, @@ -808,8 +806,7 @@ TEST_F(RLEv1Test, testNull) { auto pool = memory::memoryManager()->addLeafPool(); const unsigned char buffer[] = {0x75, 0x02, 0x00}; dwio::common::SeekableInputStream* const stream = - new dwio::common::SeekableArrayInputStream( - buffer, VELOX_ARRAY_SIZE(buffer)); + new dwio::common::SeekableArrayInputStream(buffer, std::size(buffer)); std::unique_ptr> rle = createRleDecoder( std::unique_ptr(stream), RleVersion_1, @@ -842,8 +839,7 @@ TEST_F(RLEv1Test, testAllNulls) { 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x3d, 0x00, 0x12}; dwio::common::SeekableInputStream* const stream = - new dwio::common::SeekableArrayInputStream( - buffer, VELOX_ARRAY_SIZE(buffer)); + new dwio::common::SeekableArrayInputStream(buffer, std::size(buffer)); std::unique_ptr> rle = createRleDecoder( std::unique_ptr(stream), @@ -1095,8 +1091,7 @@ TEST_F(RLEv1Test, skipTest) { 128, 228, 63, 128, 232, 63, 128, 236, 63, 128, 240, 63, 128, 244, 63, 128, 248, 63, 128, 252, 63}; dwio::common::SeekableInputStream* const stream = - new dwio::common::SeekableArrayInputStream( - buffer, VELOX_ARRAY_SIZE(buffer)); + new dwio::common::SeekableArrayInputStream(buffer, std::size(buffer)); std::unique_ptr> rle = createRleDecoder( std::unique_ptr(stream), RleVersion_1, @@ -1872,8 +1867,8 @@ TEST_F(RLEv1Test, seekTest) { 151, 12, 193, 190, 224, 143, 9, 129, 245, 133, 204, 8, 182, 209, 250, 178, 8, 148, 139, 144, 193, 11, 230, 182, 245, 164, 7, 149, 204, 161, 226, 14, 175, 229, 148, 166, 13, 148, 140, 189, 216, 3}; - auto* stream = new dwio::common::SeekableArrayInputStream( - buffer, VELOX_ARRAY_SIZE(buffer)); + auto* stream = + new dwio::common::SeekableArrayInputStream(buffer, std::size(buffer)); const long junk[] = { -1192035722, 1672896916, 1491444859, -1244121273, -791680696, 1681943525, -571055948, -1744759283, -998345856, 240559198, @@ -2991,7 +2986,7 @@ TEST_F(RLEv1Test, seekTest) { }; std::vector data(2048); rle->next(data.data(), data.size(), nullptr); - ASSERT_EQ(getNumReadBytes(), VELOX_ARRAY_SIZE(buffer)); + ASSERT_EQ(getNumReadBytes(), std::size(buffer)); for (size_t i = 0; i < data.size(); ++i) { if (i < 1024) { EXPECT_EQ(i / 4, data[i]) << "Wrong output at " << i; @@ -3023,7 +3018,7 @@ TEST_F(RLEv1Test, seekTest) { // Seek to end std::vector position; - position.push_back(VELOX_ARRAY_SIZE(buffer)); + position.push_back(std::size(buffer)); position.push_back(0); dwio::common::PositionProvider pp{position}; rle->seekToRowGroup(pp); @@ -3033,7 +3028,7 @@ TEST_F(RLEv1Test, seekTest) { // Seek to end + 1 position.clear(); - position.push_back(VELOX_ARRAY_SIZE(buffer)); + position.push_back(std::size(buffer)); position.push_back(1); dwio::common::PositionProvider pp2{position}; // Seek is fine (because it's lazy), but read should fail @@ -3049,7 +3044,7 @@ TEST_F(RLEv1Test, testLeadingNulls) { createRleDecoder( std::unique_ptr( new dwio::common::SeekableArrayInputStream( - buffer, VELOX_ARRAY_SIZE(buffer))), + buffer, std::size(buffer))), RleVersion_1, *pool, true, diff --git a/velox/dwio/dwrf/test/TestStringDictionaryEncoder.cpp b/velox/dwio/dwrf/test/TestStringDictionaryEncoder.cpp index 2dfd93797ec0..f8c33d91ff6d 100644 --- a/velox/dwio/dwrf/test/TestStringDictionaryEncoder.cpp +++ b/velox/dwio/dwrf/test/TestStringDictionaryEncoder.cpp @@ -20,8 +20,6 @@ DECLARE_bool(velox_enable_memory_usage_track_in_default_memory_pool); -using namespace facebook::velox::memory; - namespace facebook::velox::dwrf { class TestStringDictionaryEncoder : public ::testing::Test { @@ -35,10 +33,10 @@ class TestStringDictionaryEncoder : public ::testing::Test { TEST_F(TestStringDictionaryEncoder, AddKey) { struct TestCase { explicit TestCase( - const std::vector& addKeySequence, + const std::vector& addKeySequence, const std::vector& encodedSequence) : addKeySequence{addKeySequence}, encodedSequence{encodedSequence} {} - std::vector addKeySequence; + std::vector addKeySequence; std::vector encodedSequence; }; @@ -50,7 +48,7 @@ TEST_F(TestStringDictionaryEncoder, AddKey) { TestCase{{"doe", "sow", "sow", "doe", "sow"}, {0, 1, 1, 0, 1}}}; for (const auto& testCase : testCases) { - auto pool = memoryManager()->addLeafPool(); + auto pool = memory::memoryManager()->addLeafPool(); StringDictionaryEncoder stringDictEncoder{*pool, *pool}; std::vector actualEncodedSequence{}; for (const auto& key : testCase.addKeySequence) { @@ -63,14 +61,14 @@ TEST_F(TestStringDictionaryEncoder, AddKey) { TEST_F(TestStringDictionaryEncoder, GetIndex) { struct TestCase { explicit TestCase( - const std::vector& addKeySequence, - const std::vector& getIndexSequence, + const std::vector& addKeySequence, + const std::vector& getIndexSequence, const std::vector& encodedSequence) : addKeySequence{addKeySequence}, getIndexSequence{getIndexSequence}, encodedSequence{encodedSequence} {} - std::vector addKeySequence; - std::vector getIndexSequence; + std::vector addKeySequence; + std::vector getIndexSequence; std::vector encodedSequence; }; @@ -94,7 +92,7 @@ TEST_F(TestStringDictionaryEncoder, GetIndex) { {0, 3, 4, 2, 1, 3, 2, 4, 2, 0, 1, 0, 3}}}; for (const auto& testCase : testCases) { - auto pool = memoryManager()->addLeafPool(); + auto pool = memory::memoryManager()->addLeafPool(); StringDictionaryEncoder stringDictEncoder{*pool, *pool}; for (const auto& key : testCase.addKeySequence) { stringDictEncoder.addKey(key, 0); @@ -111,14 +109,14 @@ TEST_F(TestStringDictionaryEncoder, GetIndex) { TEST_F(TestStringDictionaryEncoder, GetCount) { struct TestCase { explicit TestCase( - const std::vector& addKeySequence, - const std::vector& getCountSequence, + const std::vector& addKeySequence, + const std::vector& getCountSequence, const std::vector& countSequence) : addKeySequence{addKeySequence}, getCountSequence{getCountSequence}, countSequence{countSequence} {} - std::vector addKeySequence; - std::vector getCountSequence; + std::vector addKeySequence; + std::vector getCountSequence; std::vector countSequence; }; @@ -143,7 +141,7 @@ TEST_F(TestStringDictionaryEncoder, GetCount) { {3, 2, 3, 3, 2}}}; for (const auto& testCase : testCases) { - auto pool = memoryManager()->addLeafPool(); + auto pool = memory::memoryManager()->addLeafPool(); StringDictionaryEncoder stringDictEncoder{*pool, *pool}; for (const auto& key : testCase.addKeySequence) { stringDictEncoder.addKey(key, 0); @@ -161,15 +159,14 @@ TEST_F(TestStringDictionaryEncoder, GetCount) { TEST_F(TestStringDictionaryEncoder, GetStride) { struct TestCase { explicit TestCase( - const std::vector>& - addKeySequence, - const std::vector& getStrideSequence, + const std::vector>& addKeySequence, + const std::vector& getStrideSequence, const std::vector& strideSequence) : addKeySequence{addKeySequence}, getStrideSequence{getStrideSequence}, strideSequence{strideSequence} {} - std::vector> addKeySequence; - std::vector getStrideSequence; + std::vector> addKeySequence; + std::vector getStrideSequence; std::vector strideSequence; }; @@ -197,7 +194,7 @@ TEST_F(TestStringDictionaryEncoder, GetStride) { {1, 1, 6, 3, 4}}}; for (const auto& testCase : testCases) { - auto pool = memoryManager()->addLeafPool(); + auto pool = memory::memoryManager()->addLeafPool(); StringDictionaryEncoder stringDictEncoder{*pool, *pool}; for (const auto& kv : testCase.addKeySequence) { stringDictEncoder.addKey(kv.first, kv.second); @@ -220,7 +217,7 @@ std::string genPaddedIntegerString(size_t integer, size_t length) { } TEST_F(TestStringDictionaryEncoder, Clear) { - auto pool = memoryManager()->addLeafPool(); + auto pool = memory::memoryManager()->addLeafPool(); StringDictionaryEncoder stringDictEncoder{*pool, *pool}; std::string baseString{"jjkkll"}; for (size_t i = 0; i != 2500; ++i) { @@ -242,7 +239,7 @@ TEST_F(TestStringDictionaryEncoder, Clear) { } TEST_F(TestStringDictionaryEncoder, MemBenchmark) { - auto pool = memoryManager()->addLeafPool(); + auto pool = memory::memoryManager()->addLeafPool(); StringDictionaryEncoder stringDictEncoder{*pool, *pool}; std::string baseString{"jjkkll"}; for (size_t i = 0; i != 10000; ++i) { @@ -253,15 +250,15 @@ TEST_F(TestStringDictionaryEncoder, MemBenchmark) { } TEST_F(TestStringDictionaryEncoder, Limit) { - auto pool = memoryManager()->addLeafPool(); + auto pool = memory::memoryManager()->addLeafPool(); StringDictionaryEncoder encoder{*pool, *pool}; - encoder.addKey(folly::StringPiece{"abc"}, 0); + encoder.addKey(std::string_view{"abc"}, 0); dwio::common::DataBuffer buf{*pool}; buf.resize(std::numeric_limits::max()); ASSERT_THROW( encoder.addKey( - folly::StringPiece{ + std::string_view{ buf.data(), std::numeric_limits::max() - 3}, 0), dwio::common::exception::LoggedException); diff --git a/velox/dwio/dwrf/test/TestStripeStream.cpp b/velox/dwio/dwrf/test/TestStripeStream.cpp index 2e9132aeb5fb..9dff9261c386 100644 --- a/velox/dwio/dwrf/test/TestStripeStream.cpp +++ b/velox/dwio/dwrf/test/TestStripeStream.cpp @@ -15,6 +15,7 @@ */ #include "velox/common/base/tests/GTestUtils.h" +#include "velox/dwio/common/Arena.h" #include "velox/dwio/common/encryption/TestProvider.h" #include "velox/dwio/dwrf/reader/StripeStream.h" #include "velox/dwio/dwrf/test/OrcTest.h" @@ -42,8 +43,8 @@ class RecordingInputStream : public facebook::velox::InMemoryReadFile { uint64_t offset, uint64_t length, void* buf, - facebook::velox::filesystems::File::IoStats* stats = - nullptr) const override { + const facebook::velox::FileStorageContext& fileStorageContext = {}) + const override { reads_.push_back({offset, length}); return {static_cast(buf), length}; } @@ -158,10 +159,11 @@ INSTANTIATE_TEST_SUITE_P( TEST_P(StripeStreamFormatTypeTest, planReads) { google::protobuf::Arena arena; - auto footer = google::protobuf::Arena::CreateMessage(&arena); - footer->set_rowindexstride(100); + auto footer = ArenaCreate(&arena); + auto footerWrapper = FooterWriteWrapper(footer); + footerWrapper.setRowIndexStride(100); auto type = HiveTypeParser().parse("struct"); - ProtoUtils::writeType(*type, *footer); + ProtoUtils::writeType(*type, footerWrapper); auto is = std::make_unique(); auto isPtr = is.get(); auto readerBase = std::make_shared( @@ -247,10 +249,11 @@ TEST_P(StripeStreamFormatTypeTest, planReads) { TEST_F(StripeStreamTest, filterSequences) { google::protobuf::Arena arena; - auto footer = google::protobuf::Arena::CreateMessage(&arena); - footer->set_rowindexstride(100); + auto footer = ArenaCreate(&arena); + auto footerWrapper = FooterWriteWrapper(footer); + footerWrapper.setRowIndexStride(100); auto type = HiveTypeParser().parse("struct>"); - ProtoUtils::writeType(*type, *footer); + ProtoUtils::writeType(*type, footerWrapper); auto is = std::make_unique(); auto isPtr = is.get(); auto readerBase = std::make_shared( @@ -311,10 +314,11 @@ TEST_F(StripeStreamTest, filterSequences) { TEST_P(StripeStreamFormatTypeTest, zeroLength) { google::protobuf::Arena arena; - auto footer = google::protobuf::Arena::CreateMessage(&arena); - footer->set_rowindexstride(100); + auto footer = ArenaCreate(&arena); + auto footerWrapper = FooterWriteWrapper(footer); + footerWrapper.setRowIndexStride(100); auto type = HiveTypeParser().parse("struct"); - ProtoUtils::writeType(*type, *footer); + ProtoUtils::writeType(*type, footerWrapper); proto::PostScript ps; ps.set_compressionblocksize(1024); ps.set_compression(proto::CompressionKind::ZSTD); @@ -437,12 +441,13 @@ TEST_P(StripeStreamFormatTypeTest, planReadsIndex) { index.SerializeToOstream(&buffer); // build footer - auto footer = google::protobuf::Arena::CreateMessage(&arena); - footer->set_rowindexstride(100); - footer->add_stripecacheoffsets(0); - footer->add_stripecacheoffsets(buffer.tellp()); + auto footer = ArenaCreate(&arena); + auto footerWrapper = FooterWriteWrapper(footer); + footerWrapper.setRowIndexStride(100); + footerWrapper.addStripeCacheOffsets(0); + footerWrapper.addStripeCacheOffsets(buffer.tellp()); auto type = HiveTypeParser().parse("struct"); - ProtoUtils::writeType(*type, *footer); + ProtoUtils::writeType(*type, footerWrapper); // build cache std::string str(buffer.str()); @@ -596,23 +601,24 @@ TEST_F(StripeStreamTest, readEncryptedStreams) { proto::PostScript ps; ps.set_compression(proto::CompressionKind::ZSTD); ps.set_compressionblocksize(256 * 1024); - auto footer = google::protobuf::Arena::CreateMessage(&arena); + auto footer = ArenaCreate(&arena); + auto footerWrapper = FooterWriteWrapper(footer); // a: not encrypted, projected // encryption group 1: b, c. projected b. // group 2: d. projected d. // group 3: e. not projected auto type = HiveTypeParser().parse("struct"); - ProtoUtils::writeType(*type, *footer); + ProtoUtils::writeType(*type, footerWrapper); - auto enc = footer->mutable_encryption(); + auto enc = footerWrapper.mutableEncryption(); enc->set_keyprovider(proto::Encryption_KeyProvider_UNKNOWN); addEncryptionGroup(*enc, {2, 3}); addEncryptionGroup(*enc, {4}); addEncryptionGroup(*enc, {5}); - auto stripe = footer->add_stripes(); + auto stripe = footerWrapper.addStripes(); for (auto i = 0; i < 3; ++i) { - *stripe->add_keymetadata() = folly::to("key", i); + *stripe.addKeyMetadata() = folly::to("key", i); } TestDecrypterFactory factory; auto handler = DecryptionHandler::create(FooterWrapper(footer), &factory); @@ -689,19 +695,20 @@ TEST_F(StripeStreamTest, schemaMismatch) { proto::PostScript ps; ps.set_compression(proto::CompressionKind::ZSTD); ps.set_compressionblocksize(256 * 1024); - auto footer = google::protobuf::Arena::CreateMessage(&arena); + auto footer = ArenaCreate(&arena); + auto footerWrapper = FooterWriteWrapper(footer); // a: not encrypted, has schema change // b: encrypted // c: not encrypted auto type = HiveTypeParser().parse("struct,b:int,c:int>"); - ProtoUtils::writeType(*type, *footer); + ProtoUtils::writeType(*type, footerWrapper); - auto enc = footer->mutable_encryption(); + auto enc = footerWrapper.mutableEncryption(); enc->set_keyprovider(proto::Encryption_KeyProvider_UNKNOWN); addEncryptionGroup(*enc, {3}); - auto stripe = footer->add_stripes(); - *stripe->add_keymetadata() = "key"; + auto stripe = footerWrapper.addStripes(); + *stripe.addKeyMetadata() = "key"; TestDecrypterFactory factory; auto handler = DecryptionHandler::create(FooterWrapper(footer), &factory); TestEncrypter encrypter; diff --git a/velox/dwio/dwrf/test/WriterFlushTest.cpp b/velox/dwio/dwrf/test/WriterFlushTest.cpp index b8dfcac394f1..f55d3bc44495 100644 --- a/velox/dwio/dwrf/test/WriterFlushTest.cpp +++ b/velox/dwio/dwrf/test/WriterFlushTest.cpp @@ -146,8 +146,9 @@ class MockMemoryPool : public velox::memory::MemoryPool { VELOX_UNSUPPORTED("allocateContiguous unsupported"); } - void freeContiguous(velox::memory::ContiguousAllocation& - /*unused*/) override { + void freeContiguous( + velox::memory::ContiguousAllocation& + /*unused*/) override { VELOX_UNSUPPORTED("freeContiguous unsupported"); } diff --git a/velox/dwio/dwrf/test/WriterTest.cpp b/velox/dwio/dwrf/test/WriterTest.cpp index 7126ef0b4ecd..82488fb39f28 100644 --- a/velox/dwio/dwrf/test/WriterTest.cpp +++ b/velox/dwio/dwrf/test/WriterTest.cpp @@ -75,7 +75,7 @@ class WriterTest : public Test { return writer_->getFooter(); } - auto& addStripeInfo() { + StripeInformationWriteWrapper addStripeInfo() { return writer_->addStripeInfo(); } @@ -168,13 +168,12 @@ TEST_P(AllWriterCompressionTest, compression) { folly::to(i), folly::to(i + 1)); } for (size_t i = 0; i < 4; ++i) { - getFooter().add_statistics(); + getFooter()->addStatistics(); } if (compressionKind_ == CompressionKind::CompressionKind_SNAPPY || compressionKind_ == CompressionKind::CompressionKind_LZO || compressionKind_ == CompressionKind::CompressionKind_LZ4 || - compressionKind_ == CompressionKind::CompressionKind_GZIP || compressionKind_ == CompressionKind::CompressionKind_MAX) { VELOX_ASSERT_THROW( writeFooter(*schema), @@ -231,7 +230,7 @@ TEST_P(SupportedCompressionTest, WriteFooter) { folly::to(i), folly::to(i + 1)); } for (size_t i = 0; i < 4; ++i) { - getFooter().add_statistics(); + getFooter()->addStatistics(); } writeFooter(*schema); writer.close(); @@ -263,11 +262,11 @@ TEST_P(SupportedCompressionTest, WriteFooter) { ASSERT_EQ(footer.metadataSize(), 5); for (size_t i = 0; i < 4; ++i) { auto item = footer.metadata(i); - if (item.name() == WRITER_NAME_KEY) { + if (item.name() == kWriterNameKey) { ASSERT_EQ(item.value(), kDwioWriter); - } else if (item.name() == WRITER_VERSION_KEY) { + } else if (item.name() == kWriterVersionKey) { ASSERT_EQ(item.value(), folly::to(reader->writerVersion())); - } else if (item.name() == WRITER_HOSTNAME_KEY) { + } else if (item.name() == kWriterHostnameKey) { ASSERT_EQ(item.value(), process::getHostName()); } else { ASSERT_EQ( @@ -307,9 +306,9 @@ TEST_P(SupportedCompressionTest, AddStripeInfo) { writerSink.addBuffer(*pool_, data.data(), data.size()); writerSink.setMode(WriterSink::Mode::None); - auto& ret = addStripeInfo(); - ASSERT_EQ(ret.numberofrows(), 101); - ASSERT_EQ(ret.rawdatasize(), 202); + auto ret = addStripeInfo(); + ASSERT_EQ(ret.numberOfRows(), 101); + ASSERT_EQ(ret.rawDataSize(), 202); ASSERT_EQ(ret.checksum(), 8963334039576633799); writer.close(); } @@ -327,14 +326,14 @@ TEST_P(SupportedCompressionTest, NoChecksum) { writerSink.addBuffer(*pool_, data.data(), data.size()); writerSink.setMode(WriterSink::Mode::None); - auto& ret = addStripeInfo(); - ASSERT_FALSE(ret.has_checksum()); + auto ret = addStripeInfo(); + ASSERT_FALSE(ret.hasChecksum()); std::string typeStr{"struct"}; HiveTypeParser parser; auto schema = parser.parse(typeStr); for (size_t i = 0; i < 4; ++i) { - getFooter().add_statistics(); + getFooter()->addStatistics(); } writeFooter(*schema); writer.close(); @@ -369,7 +368,7 @@ TEST_P(SupportedCompressionTest, NoCache) { HiveTypeParser parser; auto schema = parser.parse(typeStr); for (size_t i = 0; i < 4; ++i) { - getFooter().add_statistics(); + getFooter()->addStatistics(); } writeFooter(*schema); writer.close(); @@ -457,7 +456,20 @@ class MockFileSink : public dwio::common::FileSink { MOCK_METHOD(uint64_t, size, (), (const override)); MOCK_METHOD(bool, isBuffered, (), (const override)); +// On Centos9 the gtest mock header doesn't initialize the +// buffer_ member in MatcherBase correctly - the default constructor only +// initializes one: /usr/include/gtest/gtest-matchers.h:302:33 resulting in +// error: +// '.testing::Matcher::.testing::internal::MatcherBase::buffer_' is used uninitialized +// [-Werror=uninitialized] +// 302 | : vtable_(other.vtable_), buffer_(other.buffer_) { +// Fix: https://github.com/google/googletest/pull/3797 +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wuninitialized" MOCK_METHOD(void, write, (std::vector>&)); +#pragma GCC diagnostic pop }; TEST_F(WriterTest, FlushWriterSinkUponClose) { diff --git a/velox/dwio/dwrf/test/utils/CMakeLists.txt b/velox/dwio/dwrf/test/utils/CMakeLists.txt index 16f035f7670a..617b90490771 100644 --- a/velox/dwio/dwrf/test/utils/CMakeLists.txt +++ b/velox/dwio/dwrf/test/utils/CMakeLists.txt @@ -27,4 +27,5 @@ target_link_libraries( velox_type velox_vector GTest::gtest - GTest::gtest_main) + GTest::gtest_main +) diff --git a/velox/dwio/dwrf/utils/CMakeLists.txt b/velox/dwio/dwrf/utils/CMakeLists.txt index 9fff6d39972b..74655a9b0b1d 100644 --- a/velox/dwio/dwrf/utils/CMakeLists.txt +++ b/velox/dwio/dwrf/utils/CMakeLists.txt @@ -17,5 +17,4 @@ if(${VELOX_BUILD_TESTING}) endif() velox_add_library(velox_dwio_dwrf_utils ProtoUtils.cpp BitIterator.h) -velox_link_libraries(velox_dwio_dwrf_utils velox_dwio_dwrf_common velox_type - velox_memory) +velox_link_libraries(velox_dwio_dwrf_utils velox_dwio_dwrf_common velox_type velox_memory) diff --git a/velox/dwio/dwrf/utils/ProtoUtils.cpp b/velox/dwio/dwrf/utils/ProtoUtils.cpp index 405d2e79ddfb..8cfbf4594405 100644 --- a/velox/dwio/dwrf/utils/ProtoUtils.cpp +++ b/velox/dwio/dwrf/utils/ProtoUtils.cpp @@ -51,31 +51,34 @@ CREATE_TYPE_TRAIT(ROW, STRUCT) void ProtoUtils::writeType( const Type& type, - proto::Footer& footer, - proto::Type* parent) { - auto self = footer.add_types(); + FooterWriteWrapper& footer, + TypeWriteWrapper* parent) { + auto self = footer.addTypes(); if (parent) { - parent->add_subtypes(footer.types_size() - 1); + parent->addSubtypes(footer.typesSize() - 1); } + auto kind = VELOX_STATIC_FIELD_DYNAMIC_DISPATCH(SchemaType, kind, type.kind()); - self->set_kind(kind); + auto typeKindWrapper = TypeKindWrapper(&kind); + self.setKind(typeKindWrapper); + switch (type.kind()) { case TypeKind::ROW: { auto& row = type.asRow(); for (size_t i = 0; i < row.size(); ++i) { - self->add_fieldnames(row.nameOf(i)); - writeType(*row.childAt(i), footer, self); + self.addFieldnames(row.nameOf(i)); + writeType(*row.childAt(i), footer, &self); } break; } case TypeKind::ARRAY: - writeType(*type.asArray().elementType(), footer, self); + writeType(*type.asArray().elementType(), footer, &self); break; case TypeKind::MAP: { auto& map = type.asMap(); - writeType(*map.keyType(), footer, self); - writeType(*map.valueType(), footer, self); + writeType(*map.keyType(), footer, &self); + writeType(*map.valueType(), footer, &self); break; } default: diff --git a/velox/dwio/dwrf/utils/ProtoUtils.h b/velox/dwio/dwrf/utils/ProtoUtils.h index bea1d1003303..310aaf7eb18b 100644 --- a/velox/dwio/dwrf/utils/ProtoUtils.h +++ b/velox/dwio/dwrf/utils/ProtoUtils.h @@ -17,6 +17,7 @@ #pragma once #include "velox/dwio/common/SeekableInputStream.h" +#include "velox/dwio/dwrf/common/FileMetadata.h" #include "velox/dwio/dwrf/common/wrap/dwrf-proto-wrapper.h" #include "velox/type/Type.h" @@ -26,8 +27,8 @@ class ProtoUtils final { public: static void writeType( const Type& type, - proto::Footer& footer, - proto::Type* parent = nullptr); + FooterWriteWrapper&, + TypeWriteWrapper* parent = nullptr); static std::shared_ptr fromFooter( const proto::Footer& footer, diff --git a/velox/dwio/dwrf/utils/test/CMakeLists.txt b/velox/dwio/dwrf/utils/test/CMakeLists.txt index 9da234fd235e..982d47bb2ea2 100644 --- a/velox/dwio/dwrf/utils/test/CMakeLists.txt +++ b/velox/dwio/dwrf/utils/test/CMakeLists.txt @@ -12,8 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -add_executable(velox_dwio_dwrf_utils_test - ProtoUtilsTests.cpp BitIteratorTests.cpp BufferedWriterTest.cpp) +add_executable( + velox_dwio_dwrf_utils_test + ProtoUtilsTests.cpp + BitIteratorTests.cpp + BufferedWriterTest.cpp +) add_test(velox_dwio_dwrf_utils_test velox_dwio_dwrf_utils_test) @@ -25,4 +29,5 @@ target_link_libraries( velox_type_fbhive GTest::gtest GTest::gtest_main - glog::glog) + glog::glog +) diff --git a/velox/dwio/dwrf/utils/test/ProtoUtilsTests.cpp b/velox/dwio/dwrf/utils/test/ProtoUtilsTests.cpp index 92c6d01b00a4..f32700513d24 100644 --- a/velox/dwio/dwrf/utils/test/ProtoUtilsTests.cpp +++ b/velox/dwio/dwrf/utils/test/ProtoUtilsTests.cpp @@ -31,7 +31,8 @@ TEST(ProtoUtilsTests, AllTypes) { HiveTypeParser parser; auto schema = parser.parse(type); proto::Footer footer; - ProtoUtils::writeType(*schema, footer); + auto footerWrapper = FooterWriteWrapper(&footer); + ProtoUtils::writeType(*schema, footerWrapper); auto out = ProtoUtils::fromFooter(footer); auto str = HiveTypeSerializer::serialize(out); @@ -45,7 +46,8 @@ TEST(ProtoUtilsTests, Projection) { auto schema = parser.parse( "struct>"); proto::Footer footer; - ProtoUtils::writeType(*schema, footer); + auto footerWrapper = FooterWriteWrapper(&footer); + ProtoUtils::writeType(*schema, footerWrapper); auto type = ProtoUtils::fromFooter( footer, [](auto id) { return id != 2 && id != 5; }); diff --git a/velox/dwio/dwrf/writer/CMakeLists.txt b/velox/dwio/dwrf/writer/CMakeLists.txt index 3150c30e55b8..ec121ef6aaed 100644 --- a/velox/dwio/dwrf/writer/CMakeLists.txt +++ b/velox/dwio/dwrf/writer/CMakeLists.txt @@ -24,7 +24,8 @@ velox_add_library( Writer.cpp WriterBase.cpp WriterContext.cpp - WriterSink.cpp) + WriterSink.cpp +) velox_link_libraries( velox_dwio_dwrf_writer @@ -38,6 +39,6 @@ velox_link_libraries( velox_vector Boost::headers lz4::lz4 - lzo2::lzo2 zstd::zstd - ZLIB::ZLIB) + ZLIB::ZLIB +) diff --git a/velox/dwio/dwrf/writer/ColumnWriter.cpp b/velox/dwio/dwrf/writer/ColumnWriter.cpp index 2a4cf2077961..2100d18fd8a4 100644 --- a/velox/dwio/dwrf/writer/ColumnWriter.cpp +++ b/velox/dwio/dwrf/writer/ColumnWriter.cpp @@ -68,8 +68,9 @@ class ByteRleColumnWriter : public BaseColumnWriter { uint64_t write(const VectorPtr& slice, const common::Ranges& ranges) override; void flush( - std::function encodingFactory, - std::function encodingOverride) override { + std::function encodingFactory, + std::function encodingOverride) + override { BaseColumnWriter::flush(encodingFactory, encodingOverride); data_->flush(); } @@ -248,8 +249,9 @@ class IntegerColumnWriter : public BaseColumnWriter { } void flush( - std::function encodingFactory, - std::function encodingOverride) override { + std::function encodingFactory, + std::function encodingOverride) + override { tryAbandonDictionaries(false); initStreamWriters(useDictionaryEncoding_); @@ -287,12 +289,14 @@ class IntegerColumnWriter : public BaseColumnWriter { // FIXME: call base class set encoding first to deal with sequence and // whatnot. - void setEncoding(proto::ColumnEncoding& encoding) const override { + void setEncoding(ColumnEncodingWriteWrapper& encoding) const override { BaseColumnWriter::setEncoding(encoding); if (useDictionaryEncoding_) { - encoding.set_kind( - proto::ColumnEncoding_Kind::ColumnEncoding_Kind_DICTIONARY); - encoding.set_dictionarysize(finalDictionarySize_); + auto columnEncodingKind = + proto::ColumnEncoding_Kind::ColumnEncoding_Kind_DICTIONARY; + encoding.setKind(ColumnEncodingKindWrapper(&columnEncodingKind)); + + encoding.setDictionarySize(finalDictionarySize_); } } @@ -679,8 +683,9 @@ class TimestampColumnWriter : public BaseColumnWriter { uint64_t write(const VectorPtr& slice, const common::Ranges& ranges) override; void flush( - std::function encodingFactory, - std::function encodingOverride) override { + std::function encodingFactory, + std::function encodingOverride) + override { BaseColumnWriter::flush(encodingFactory, encodingOverride); seconds_->flush(); nanos_->flush(); @@ -788,8 +793,9 @@ class DecimalColumnWriter : public BaseColumnWriter { } void flush( - std::function encodingFactory, - std::function encodingOverride) override { + std::function encodingFactory, + std::function encodingOverride) + override { BaseColumnWriter::flush(encodingFactory, encodingOverride); unscaledValues_->flush(); scales_->flush(); @@ -954,8 +960,9 @@ class StringColumnWriter : public BaseColumnWriter { } void flush( - std::function encodingFactory, - std::function encodingOverride) override { + std::function encodingFactory, + std::function encodingOverride) + override { tryAbandonDictionaries(false); initStreamWriters(useDictionaryEncoding_); @@ -1000,12 +1007,13 @@ class StringColumnWriter : public BaseColumnWriter { // FIXME: call base class set encoding first to deal with sequence and // whatnot. - void setEncoding(proto::ColumnEncoding& encoding) const override { + void setEncoding(ColumnEncodingWriteWrapper& encoding) const override { BaseColumnWriter::setEncoding(encoding); if (useDictionaryEncoding_) { - encoding.set_kind( - proto::ColumnEncoding_Kind::ColumnEncoding_Kind_DICTIONARY); - encoding.set_dictionarysize(finalDictionarySize_); + auto columnEncodingKind = + proto::ColumnEncoding_Kind::ColumnEncoding_Kind_DICTIONARY; + encoding.setKind(ColumnEncodingKindWrapper(&columnEncodingKind)); + encoding.setDictionarySize(finalDictionarySize_); } } @@ -1230,10 +1238,12 @@ uint64_t StringColumnWriter::writeDict( size_t strideIndex = strideOffsets_.size() - 1; uint64_t rawSize = 0; auto processRow = [&](size_t pos) { - auto sp = decodedVector.valueAt(pos); - rows_.unsafeAppend(dictEncoder_.addKey(sp, strideIndex)); - statsBuilder.addValues(sp); - rawSize += sp.size(); + auto sv = decodedVector.valueAt(pos); + // TODO: Remove explicit std::string_view cast. + rows_.unsafeAppend(dictEncoder_.addKey(std::string_view(sv), strideIndex)); + // TODO: Remove explicit std::string_view cast. + statsBuilder.addValues(std::string_view(sv)); + rawSize += sv.size(); }; uint64_t nullCount = 0; @@ -1274,10 +1284,11 @@ uint64_t StringColumnWriter::writeDirect( uint64_t rawSize = 0; auto processRow = [&](size_t pos) { - auto sp = decodedVector.valueAt(pos); - auto size = sp.size(); - dataDirect_->write(sp.data(), size); - statsBuilder.addValues(sp); + auto sv = decodedVector.valueAt(pos); + auto size = sv.size(); + dataDirect_->write(sv.data(), size); + // TODO: Remove explicit std::string_view cast. + statsBuilder.addValues(std::string_view(sv)); rawSize += size; lengths.unsafeAppend(size); }; @@ -1481,8 +1492,9 @@ class FloatColumnWriter : public BaseColumnWriter { uint64_t write(const VectorPtr& slice, const common::Ranges& ranges) override; void flush( - std::function encodingFactory, - std::function encodingOverride) override { + std::function encodingFactory, + std::function encodingOverride) + override { BaseColumnWriter::flush(encodingFactory, encodingOverride); data_.flush(); } @@ -1612,8 +1624,9 @@ class BinaryColumnWriter : public BaseColumnWriter { uint64_t write(const VectorPtr& slice, const common::Ranges& ranges) override; void flush( - std::function encodingFactory, - std::function encodingOverride) override { + std::function encodingFactory, + std::function encodingOverride) + override { BaseColumnWriter::flush(encodingFactory, encodingOverride); data_.flush(); lengths_->flush(); @@ -1726,8 +1739,9 @@ class StructColumnWriter : public BaseColumnWriter { uint64_t write(const VectorPtr& slice, const common::Ranges& ranges) override; void flush( - std::function encodingFactory, - std::function encodingOverride) override { + std::function encodingFactory, + std::function encodingOverride) + override { BaseColumnWriter::flush(encodingFactory, encodingOverride); for (auto& c : children_) { c->flush(encodingFactory); @@ -1855,8 +1869,9 @@ class ListColumnWriter : public BaseColumnWriter { uint64_t write(const VectorPtr& slice, const common::Ranges& ranges) override; void flush( - std::function encodingFactory, - std::function encodingOverride) override { + std::function encodingFactory, + std::function encodingOverride) + override { BaseColumnWriter::flush(encodingFactory, encodingOverride); lengths_->flush(); children_.at(0)->flush(encodingFactory); @@ -1982,8 +1997,9 @@ class MapColumnWriter : public BaseColumnWriter { uint64_t write(const VectorPtr& slice, const common::Ranges& ranges) override; void flush( - std::function encodingFactory, - std::function encodingOverride) override { + std::function encodingFactory, + std::function encodingOverride) + override { BaseColumnWriter::flush(encodingFactory, encodingOverride); lengths_->flush(); children_.at(0)->flush(encodingFactory); @@ -2119,7 +2135,7 @@ std::unique_ptr BaseColumnWriter::create( "MAP_FLAT_COLS contains column {}, but the root type of this column is {}." " Column root types must be of type MAP", type.column(), - mapTypeKindToName(type.type()->kind())); + TypeKindName::toName(type.type()->kind())); } const auto structColumnKeys = context.getConfig(Config::MAP_FLAT_COLS_STRUCT_KEYS); @@ -2212,7 +2228,7 @@ std::unique_ptr BaseColumnWriter::create( } default: VELOX_FAIL( - "not supported yet: {}", mapTypeKindToName(type.type()->kind())); + "not supported yet: {}", TypeKindName::toName(type.type()->kind())); } } } // namespace facebook::velox::dwrf diff --git a/velox/dwio/dwrf/writer/ColumnWriter.h b/velox/dwio/dwrf/writer/ColumnWriter.h index 98c5691babb9..ec7c8f09ab19 100644 --- a/velox/dwio/dwrf/writer/ColumnWriter.h +++ b/velox/dwio/dwrf/writer/ColumnWriter.h @@ -45,12 +45,31 @@ class ColumnWriter { virtual void reset() = 0; virtual void flush( - std::function encodingFactory, - std::function encodingOverride = - [](auto& /* e */) {}) = 0; + std::function encodingFactory, + std::function encodingOverride = + [](auto /* e */) {}) { + VELOX_NYI(); + } + + virtual void flush( + std::function + encodingFactory, + std::function + encodingOverride = [](auto& /* e */) {}) { + VELOX_NYI(); + } virtual uint64_t writeFileStats( - std::function statsFactory) const = 0; + std::function statsFactory) + const { + VELOX_NYI(); + } + + virtual uint64_t writeFileStats( + std::function + statsFactory) const { + VELOX_NYI(); + } virtual bool tryAbandonDictionaries(bool force) = 0; @@ -61,11 +80,13 @@ class ColumnWriter { const uint32_t sequence) : id_{id}, sequence_{sequence}, context_{context} {} - virtual void setEncoding(proto::ColumnEncoding& encoding) const { - encoding.set_kind(proto::ColumnEncoding_Kind::ColumnEncoding_Kind_DIRECT); - encoding.set_dictionarysize(0); - encoding.set_node(id_); - encoding.set_sequence(sequence_); + virtual void setEncoding(ColumnEncodingWriteWrapper& columnEncoding) const { + auto columnEncodingKind = + proto::ColumnEncoding_Kind::ColumnEncoding_Kind_DIRECT; + columnEncoding.setKind(ColumnEncodingKindWrapper(&columnEncodingKind)); + columnEncoding.setDictionarySize(0); + columnEncoding.setNode(id_); + columnEncoding.setSequence(sequence_); } const uint32_t id_; @@ -99,9 +120,9 @@ class BaseColumnWriter : public ColumnWriter { } void flush( - std::function encodingFactory, - std::function encodingOverride = - [](auto& /* e */) {}) override { + std::function encodingFactory, + std::function encodingOverride = + [](auto /* e */) {}) override { if (!isRoot()) { present_->flush(); @@ -113,21 +134,22 @@ class BaseColumnWriter : public ColumnWriter { } } - auto& encoding = encodingFactory(id_); + auto encoding = encodingFactory(id_); setEncoding(encoding); encodingOverride(encoding); indexBuilder_->flush(); } - uint64_t writeFileStats(std::function - statsFactory) const override { - auto& stats = statsFactory(id_); + uint64_t writeFileStats( + std::function statsFactory) + const override { + auto stats = statsFactory(id_); fileStatsBuilder_->toProto(stats); const uint64_t size = context_.getPhysicalSizeAggregator(id_).getResult(); for (auto& child : children_) { child->writeFileStats(statsFactory); } - stats.set_size(size); + stats.setSize(size); return size; } diff --git a/velox/dwio/dwrf/writer/FlatMapColumnWriter.cpp b/velox/dwio/dwrf/writer/FlatMapColumnWriter.cpp index 09297b38be61..65bea8d379fb 100644 --- a/velox/dwio/dwrf/writer/FlatMapColumnWriter.cpp +++ b/velox/dwio/dwrf/writer/FlatMapColumnWriter.cpp @@ -18,10 +18,9 @@ #include #include "velox/type/Type.h" #include "velox/vector/ComplexVector.h" -#include "velox/vector/FlatVector.h" +#include "velox/vector/FlatMapVector.h" namespace facebook::velox::dwrf { - namespace { template @@ -84,15 +83,17 @@ FlatMapColumnWriter::FlatMapColumnWriter( template void FlatMapColumnWriter::setEncoding( - proto::ColumnEncoding& encoding) const { + ColumnEncodingWriteWrapper& encoding) const { BaseColumnWriter::setEncoding(encoding); - encoding.set_kind(proto::ColumnEncoding_Kind::ColumnEncoding_Kind_MAP_FLAT); + auto columnEncodingKind = + proto::ColumnEncoding_Kind::ColumnEncoding_Kind_MAP_FLAT; + encoding.setKind(ColumnEncodingKindWrapper(&columnEncodingKind)); } template void FlatMapColumnWriter::flush( - std::function encodingFactory, - std::function encodingOverride) { + std::function encodingFactory, + std::function encodingOverride) { BaseColumnWriter::flush(encodingFactory, encodingOverride); for (auto& pair : valueWriters_) { @@ -136,18 +137,18 @@ void FlatMapColumnWriter::createIndexEntry() { template uint64_t FlatMapColumnWriter::writeFileStats( - std::function statsFactory) const { - auto& stats = statsFactory(id_); + std::function statsFactory) const { + auto stats = statsFactory(id_); fileStatsBuilder_->toProto(stats); uint64_t size = context_.getPhysicalSizeAggregator(id_).getResult(); - auto& keyStats = statsFactory(keyType_.id()); + auto keyStats = statsFactory(keyType_.id()); keyFileStatsBuilder_->toProto(keyStats); auto keySize = context_.getPhysicalSizeAggregator(keyType_.id()).getResult(); - keyStats.set_size(keySize); + keyStats.setSize(keySize); valueFileStatsBuilder_->writeFileStats(statsFactory); - stats.set_size(size); + stats.setSize(size); return size; } @@ -175,6 +176,7 @@ void FlatMapColumnWriter::reset() { valueWriters_.clear(); rowsInStrides_.clear(); rowsInCurrentStride_ = 0; + totalRows_ = 0; } KeyInfo getKeyInfo(int64_t key) { @@ -199,11 +201,12 @@ ValueWriter& FlatMapColumnWriter::getValueWriter( } if (valueWriters_.size() >= maxKeyCount_) { - DWIO_RAISE(fmt::format( - "Too many map keys requested in (node {}, column {}). Allowed: {}", - id_, - type_.column(), - maxKeyCount_)); + DWIO_RAISE( + fmt::format( + "Too many map keys requested in (node {}, column {}). Allowed: {}", + id_, + type_.column(), + maxKeyCount_)); } auto keyInfo = getKeyInfo(key); @@ -239,27 +242,30 @@ ValueWriter& FlatMapColumnWriter::getValueWriter( template uint32_t updateKeyStatistics( typename TypeInfo::StatisticsBuilder& keyStatsBuilder, - typename TypeTraits::NativeType value) { - keyStatsBuilder.addValues(value); - return sizeof(typename TypeTraits::NativeType); + typename TypeTraits::NativeType value, + uint64_t count = 1) { + keyStatsBuilder.addValues(value, count); + return sizeof(typename TypeTraits::NativeType) * count; } template <> uint32_t updateKeyStatistics( StringStatisticsBuilder& keyStatsBuilder, - StringView value) { + StringView value, + uint64_t count) { auto size = value.size(); - keyStatsBuilder.addValues(folly::StringPiece{value.data(), size}); - return size; + keyStatsBuilder.addValues(std::string_view{value.data(), size}, count); + return size * count; } template <> uint32_t updateKeyStatistics( BinaryStatisticsBuilder& keyStatsBuilder, - StringView value) { + StringView value, + uint64_t count) { auto size = value.size(); - keyStatsBuilder.addValues(size); - return size; + keyStatsBuilder.addValues(size, count); + return size * count; } namespace { @@ -351,6 +357,11 @@ uint64_t FlatMapColumnWriter::write( const common::Ranges& ranges) { switch (slice->typeKind()) { case TypeKind::MAP: + if (slice->wrappedVector()->encoding() == + VectorEncoding::Simple::FLAT_MAP) { + return writeFlatMap(slice, ranges); + } + return writeMap(slice, ranges); case TypeKind::ROW: if (!structKeys_.empty()) { @@ -466,6 +477,7 @@ uint64_t FlatMapColumnWriter::writeMap( rawSize += nullCount * NULL_SIZE; } rowsInCurrentStride_ += mapCount; + totalRows_ += mapCount; indexStatsBuilder_->increaseValueCount(mapCount); indexStatsBuilder_->increaseRawSize(rawSize); return rawSize; @@ -486,6 +498,88 @@ common::Ranges getNonNullRanges(const common::Ranges& ranges, const Row& row) { return nonNullRanges; } +template +uint64_t FlatMapColumnWriter::writeFlatMap( + const VectorPtr& vector, + const common::Ranges& ranges) { + uint64_t rawSize = 0; + common::Ranges nonNullRanges; + + const FlatMapVector* flatMap = vector->as(); + if (flatMap) { + // FlatMap has no additional encodings. + writeNulls(vector, ranges); + nonNullRanges = getNonNullRanges(ranges, Flat{vector}); + } else { + // FlatMap has additional encodings, we need to decode. + auto localDecodedFlatMap = decode(vector, ranges); + auto& decodedFlatMap = localDecodedFlatMap.get(); + writeNulls(decodedFlatMap, ranges); + flatMap = decodedFlatMap.base()->template as(); + nonNullRanges = getNonNullRanges(ranges, Decoded{decodedFlatMap}); + } + + auto processFlatMap = [&](const auto& keys) { + const auto& inMaps = flatMap->inMaps(); + const auto& values = flatMap->mapValues(); + + for (size_t i = 0; i < flatMap->numDistinctKeys(); i++) { + const uint64_t* inMapsForKey = nullptr; + auto nonNullRangesForKey = nonNullRanges; + + if (i < inMaps.size() && inMaps[i] != nullptr) { + inMapsForKey = inMaps[i]->as(); + + // Filter out only values where the key is present in the map. + nonNullRangesForKey = nonNullRanges.filter([&inMapsForKey](auto i) { + return bits::isBitSet(inMapsForKey, i); + }); + } + const auto& valuesForKey = values[i]; + const auto& key = keys.valueAt(i); + + auto keySize = updateKeyStatistics( + *keyFileStatsBuilder_, key, nonNullRangesForKey.size()); + keyFileStatsBuilder_->increaseRawSize(keySize); + rawSize += keySize; + + ValueWriter& valueWriter = getValueWriter(key, nonNullRanges.size()); + valueWriter.writeBuffers( + nonNullRangesForKey, valuesForKey, nonNullRanges, inMapsForKey); + } + }; + + if (flatMap->distinctKeys()->isFlatEncoding()) { + processFlatMap(Flat{flatMap->distinctKeys()}); + } else { + auto localDecodedKeys = decode( + flatMap->distinctKeys(), + common::Ranges::of(0, flatMap->distinctKeys()->size())); + processFlatMap(Decoded{localDecodedKeys.get()}); + } + + totalRows_ += nonNullRanges.size(); + + for (auto& pair : valueWriters_) { + if (totalRows_ != pair.second.writtenValues()) { + if (pair.second.writtenValues() > totalRows_) { + DWIO_RAISE("Duplicated key in map: ", pair.first); + } + pair.second.backfill(nonNullRanges.size()); + } + } + + size_t numNullRows = ranges.size() - nonNullRanges.size(); + if (numNullRows > 0) { + indexStatsBuilder_->setHasNull(); + rawSize += numNullRows * NULL_SIZE; + } + rowsInCurrentStride_ += nonNullRanges.size(); + indexStatsBuilder_->increaseValueCount(nonNullRanges.size()); + indexStatsBuilder_->increaseRawSize(rawSize); + return rawSize; +} + template uint64_t FlatMapColumnWriter::writeRow( const VectorPtr& slice, @@ -536,6 +630,7 @@ uint64_t FlatMapColumnWriter::writeRow( rawSize += numNullRows * NULL_SIZE; } rowsInCurrentStride_ += nonNullRanges.size(); + totalRows_ += nonNullRanges.size(); indexStatsBuilder_->increaseValueCount(nonNullRanges.size()); indexStatsBuilder_->increaseRawSize(rawSize); return rawSize; diff --git a/velox/dwio/dwrf/writer/FlatMapColumnWriter.h b/velox/dwio/dwrf/writer/FlatMapColumnWriter.h index 6e50a0e435ac..ab92e520b76d 100644 --- a/velox/dwio/dwrf/writer/FlatMapColumnWriter.h +++ b/velox/dwio/dwrf/writer/FlatMapColumnWriter.h @@ -53,14 +53,15 @@ class ValueStatisticsBuilder { } uint64_t writeFileStats( - std::function statsFactory) const { - auto& stats = statsFactory(id_); + std::function statsFactory) + const { + auto stats = statsFactory(id_); statisticsBuilder_->toProto(stats); uint64_t size = context_.getPhysicalSizeAggregator(id_).getResult(); for (int32_t i = 0; i < children_.size(); ++i) { children_[i]->writeFileStats(statsFactory); } - stats.set_size(size); + stats.setSize(size); return size; } @@ -137,6 +138,7 @@ class ValueWriter { if (mapCount) { inMap_->add( inMapBuffer_.data(), common::Ranges::of(0, mapCount), nullptr); + writtenValues_ += mapCount; } if (values) { @@ -146,7 +148,7 @@ class ValueWriter { } // used for struct encoding writer - uint64_t writeBuffers( + void writeBuffers( const VectorPtr& values, const common::Ranges& nonNullRanges, const BufferPtr& inMapBuffer /* all 1 */) { @@ -155,12 +157,28 @@ class ValueWriter { inMapBuffer->as(), common::Ranges::of(0, nonNullRanges.size()), nullptr); + writtenValues_ += nonNullRanges.size(); } if (values) { - return columnWriter_->write(values, nonNullRanges); + columnWriter_->write(values, nonNullRanges); + } + } + + // used for flat map encoding writer + void writeBuffers( + const common::Ranges& valuesRanges, + const VectorPtr& values, + const common::Ranges& inMapRanges, + const uint64_t* inMapBuffer) { + if (inMapRanges.size()) { + inMap_->addBits(inMapBuffer, inMapRanges, nullptr, false); + writtenValues_ += inMapRanges.size(); + } + + if (valuesRanges.size()) { + columnWriter_->write(values, valuesRanges); } - return 0; } void backfill(uint32_t count) { @@ -171,6 +189,7 @@ class ValueWriter { inMapBuffer_.reserve(count); std::memset(inMapBuffer_.data(), 0, count); inMap_->add(inMapBuffer_.data(), common::Ranges::of(0, count), nullptr); + writtenValues_ += count; } uint32_t getSequence() const { @@ -191,15 +210,17 @@ class ValueWriter { columnWriter_->createIndexEntry(); } - void flush(std::function encodingFactory) { + void flush( + std::function encodingFactory) { inMap_->flush(); - columnWriter_->flush(encodingFactory, [&](auto& encoding) { - *encoding.mutable_key() = keyInfo_; + columnWriter_->flush(encodingFactory, [&](auto encoding) { + *encoding.mutableKey() = keyInfo_; }); } void reset() { columnWriter_->reset(); + writtenValues_ = 0; } void resizeBuffers(size_t inMap) { @@ -208,6 +229,10 @@ class ValueWriter { ranges_.clear(); } + size_t writtenValues() const { + return writtenValues_; + } + private: uint32_t sequence_; const proto::KeyInfo keyInfo_; @@ -216,6 +241,7 @@ class ValueWriter { dwio::common::DataBuffer inMapBuffer_; common::Ranges ranges_; const bool collectMapStats_; + size_t writtenValues_{0}; }; namespace { @@ -266,25 +292,29 @@ class FlatMapColumnWriter : public BaseColumnWriter { uint64_t write(const VectorPtr& slice, const common::Ranges& ranges) override; void flush( - std::function encodingFactory, - std::function encodingOverride) override; + std::function encodingFactory, + std::function encodingOverride) + override; void createIndexEntry() override; void reset() override; - uint64_t writeFileStats(std::function - statsFactory) const override; + uint64_t writeFileStats( + std::function statsFactory) + const override; private: using KeyType = typename TypeTraits::NativeType; - void setEncoding(proto::ColumnEncoding& encoding) const override; + void setEncoding(ColumnEncodingWriteWrapper& encoding) const override; ValueWriter& getValueWriter(KeyType key, uint32_t inMapSize); - // write() calls writeMap() or writeRow() depending on input type + // write() calls writeMap(), writeFlatMap(), or writeRow() depending on input + // type and encoding uint64_t writeMap(const VectorPtr& slice, const common::Ranges& ranges); + uint64_t writeFlatMap(const VectorPtr& slice, const common::Ranges& ranges); uint64_t writeRow(const VectorPtr& slice, const common::Ranges& ranges); void clearNodes(); @@ -300,6 +330,10 @@ class FlatMapColumnWriter : public BaseColumnWriter { // Captures current row count for current (incomplete) stride size_t rowsInCurrentStride_{0}; + // Captures current row count for current stripe (sum of rowsInStrides_ + + // rowsInCurrentStride_) + size_t totalRows_{0}; + // Remember key and value types. Needed for constructing value writers const dwio::common::TypeWithId& keyType_; const dwio::common::TypeWithId& valueType_; diff --git a/velox/dwio/dwrf/writer/IndexBuilder.h b/velox/dwio/dwrf/writer/IndexBuilder.h index 57582199fc56..88cbade7dc35 100644 --- a/velox/dwio/dwrf/writer/IndexBuilder.h +++ b/velox/dwio/dwrf/writer/IndexBuilder.h @@ -16,12 +16,14 @@ #pragma once +#include "velox/dwio/common/Arena.h" #include "velox/dwio/common/OutputStream.h" #include "velox/dwio/dwrf/common/wrap/dwrf-proto-wrapper.h" #include "velox/dwio/dwrf/writer/StatisticsBuilder.h" namespace facebook::velox::dwrf { +using dwio::common::ArenaCreate; using dwio::common::BufferedOutputStream; using dwio::common::PositionRecorder; @@ -35,41 +37,50 @@ constexpr int32_t PRESENT_STREAM_INDEX_ENTRIES_PAGED = class IndexBuilder : public PositionRecorder { public: - IndexBuilder(std::unique_ptr out) - : out_{std::move(out)} {} + IndexBuilder( + std::unique_ptr out, + dwio::common::FileFormat fileFormat = dwio::common::FileFormat::DWRF) + : out_{std::move(out)}, + arena_(std::make_unique()) { + auto rowIndex = ArenaCreate(arena_.get()); + auto rowIndexEntry = ArenaCreate(arena_.get()); + + index_ = std::make_unique(rowIndex); + entry_ = std::make_unique(rowIndexEntry); + } virtual ~IndexBuilder() = default; void add(uint64_t pos, int32_t index = -1) override { - getEntry(index)->add_positions(pos); + getEntry(index).addPositions(pos); } virtual void addEntry(const StatisticsBuilder& writer) { - auto* stats = entry_.mutable_statistics(); - writer.toProto(*stats); - *index_.add_entry() = entry_; - entry_.Clear(); + auto stats = entry_->mutableStatistics(); + writer.toProto(stats); + index_->addEntry(entry_); + entry_->clear(); } virtual size_t getEntrySize() const { - const int32_t size = index_.entry_size() + 1; + const int32_t size = index_->entrySize() + 1; VELOX_CHECK_GT(size, 0, "Invalid entry size or missing current entry."); return size; } virtual void flush() { // remove isPresent positions if none is null - index_.SerializeToZeroCopyStream(out_.get()); + index_->SerializeToZeroCopyStream(out_.get()); out_->flush(); - index_.Clear(); - entry_.Clear(); + index_->clear(); + entry_->clear(); } void capturePresentStreamOffset() { if (!presentStreamOffset_.has_value()) { - presentStreamOffset_ = entry_.positions_size(); + presentStreamOffset_ = entry_->positionsSize(); } else { - DWIO_ENSURE_EQ(presentStreamOffset_.value(), entry_.positions_size()); + DWIO_ENSURE_EQ(presentStreamOffset_.value(), entry_->positionsSize()); } } @@ -79,27 +90,28 @@ class IndexBuilder : public PositionRecorder { : PRESENT_STREAM_INDEX_ENTRIES_UNPAGED; // Only need to process entries that have been added to the row index - for (uint32_t i = 0; i < index_.entry_size(); ++i) { - index_.mutable_entry(i)->mutable_positions()->ExtractSubrange( - presentStreamOffset_.value(), streamCount, nullptr); + for (uint32_t i = 0; i < index_->entrySize(); ++i) { + index_->mutableEntry(i).mutablePositions( + presentStreamOffset_.value(), streamCount); } } private: - proto::RowIndexEntry* getEntry(int32_t index) { + RowIndexEntryWriteWrapper getEntry(int32_t index) { if (index < 0) { - return &entry_; - } else if (index < index_.entry_size()) { - return index_.mutable_entry(index); + return *entry_; + } else if (index < index_->entrySize()) { + return index_->mutableEntry(index); } else { - VELOX_CHECK_EQ(index, index_.entry_size()); - return &entry_; + VELOX_CHECK_EQ(index, index_->entrySize()); + return *entry_; } } const std::unique_ptr out_; - proto::RowIndex index_; - proto::RowIndexEntry entry_; + std::unique_ptr index_; + std::unique_ptr entry_; + std::unique_ptr arena_; std::optional presentStreamOffset_; friend class IndexBuilderTest; diff --git a/velox/dwio/dwrf/writer/LayoutPlanner.cpp b/velox/dwio/dwrf/writer/LayoutPlanner.cpp index 9ef2fb6223c9..334097a435f4 100644 --- a/velox/dwio/dwrf/writer/LayoutPlanner.cpp +++ b/velox/dwio/dwrf/writer/LayoutPlanner.cpp @@ -16,14 +16,19 @@ #include "velox/dwio/dwrf/writer/LayoutPlanner.h" +#include "velox/dwio/common/Arena.h" + namespace facebook::velox::dwrf { +using dwio::common::ArenaCreate; + StreamList getStreamList(WriterContext& context) { StreamList streams; streams.reserve(context.getStreamCount()); context.iterateUnSuppressedStreams([&](auto& pair) { - streams.push_back(std::make_pair( - std::addressof(pair.first), std::addressof(pair.second))); + streams.push_back( + std::make_pair( + std::addressof(pair.first), std::addressof(pair.second))); }); return streams; } @@ -114,10 +119,6 @@ bool EncodingIter::operator==(const EncodingIter& other) const { return current_ == other.current_; } -bool EncodingIter::operator!=(const EncodingIter& other) const { - return current_ != other.current_; -} - EncodingIter::reference EncodingIter::operator*() const { return *current_; } @@ -128,49 +129,54 @@ EncodingIter::pointer EncodingIter::operator->() const { EncodingManager::EncodingManager( const encryption::EncryptionHandler& encryptionHandler) - : encryptionHandler_{encryptionHandler} { + : encryptionHandler_{encryptionHandler}, + arena_{std::make_unique()} { initEncryptionGroups(); + auto dwrfStripeFooter = ArenaCreate(arena_.get()); + footer_ = std::make_unique(dwrfStripeFooter); } -proto::ColumnEncoding& EncodingManager::addEncodingToFooter(uint32_t nodeId) { +ColumnEncodingWriteWrapper EncodingManager::addEncodingToFooter( + uint32_t nodeId) { if (encryptionHandler_.isEncrypted(nodeId)) { auto index = encryptionHandler_.getEncryptionGroupIndex(nodeId); - return *encryptionGroups_.at(index).add_encoding(); + return ColumnEncodingWriteWrapper( + encryptionGroups_.at(index).add_encoding()); } else { - return *footer_.add_encoding(); + return footer_->addEncoding(); } } -proto::Stream* EncodingManager::addStreamToFooter( +StreamWriteWrapper EncodingManager::addStreamToFooter( uint32_t nodeId, uint32_t& currentIndex) { if (encryptionHandler_.isEncrypted(nodeId)) { currentIndex = encryptionHandler_.getEncryptionGroupIndex(nodeId); - return encryptionGroups_.at(currentIndex).add_streams(); + return StreamWriteWrapper(encryptionGroups_.at(currentIndex).add_streams()); } else { currentIndex = std::numeric_limits::max(); - return footer_.add_streams(); + return footer_->addStreams(); } } std::string* EncodingManager::addEncryptionGroupToFooter() { - return footer_.add_encryptiongroups(); + return footer_->addEncryptionGroups(); } proto::StripeEncryptionGroup EncodingManager::getEncryptionGroup(uint32_t i) { return encryptionGroups_.at(i); } -const proto::StripeFooter& EncodingManager::getFooter() const { - return footer_; +const StripeFooterWriteWrapper& EncodingManager::getFooter() const { + return *footer_; } EncodingIter EncodingManager::begin() const { - return EncodingIter::begin(footer_, encryptionGroups_); + return EncodingIter::begin(*footer_->dwrfPtr(), encryptionGroups_); } EncodingIter EncodingManager::end() const { - return EncodingIter::end(footer_, encryptionGroups_); + return EncodingIter::end(*footer_->dwrfPtr(), encryptionGroups_); } void EncodingManager::initEncryptionGroups() { diff --git a/velox/dwio/dwrf/writer/LayoutPlanner.h b/velox/dwio/dwrf/writer/LayoutPlanner.h index 558c1467bccb..ac22f9d736f2 100644 --- a/velox/dwio/dwrf/writer/LayoutPlanner.h +++ b/velox/dwio/dwrf/writer/LayoutPlanner.h @@ -48,7 +48,6 @@ class EncodingIter { EncodingIter& operator++(); EncodingIter operator++(int); bool operator==(const EncodingIter& other) const; - bool operator!=(const EncodingIter& other) const; reference operator*() const; pointer operator->() const; @@ -90,11 +89,11 @@ class EncodingManager : public EncodingContainer { const encryption::EncryptionHandler& encryptionHandler); virtual ~EncodingManager() override = default; - proto::ColumnEncoding& addEncodingToFooter(uint32_t nodeId); - proto::Stream* addStreamToFooter(uint32_t nodeId, uint32_t& currentIndex); + ColumnEncodingWriteWrapper addEncodingToFooter(uint32_t nodeId); + StreamWriteWrapper addStreamToFooter(uint32_t nodeId, uint32_t& currentIndex); std::string* addEncryptionGroupToFooter(); proto::StripeEncryptionGroup getEncryptionGroup(uint32_t i); - const proto::StripeFooter& getFooter() const; + const StripeFooterWriteWrapper& getFooter() const; EncodingIter begin() const override; EncodingIter end() const override; @@ -103,7 +102,8 @@ class EncodingManager : public EncodingContainer { void initEncryptionGroups(); const encryption::EncryptionHandler& encryptionHandler_; - proto::StripeFooter footer_; + std::unique_ptr footer_; + std::unique_ptr arena_; std::vector encryptionGroups_; }; diff --git a/velox/dwio/dwrf/writer/StatisticsBuilder.cpp b/velox/dwio/dwrf/writer/StatisticsBuilder.cpp index 3000230c1156..b716d1aa71aa 100644 --- a/velox/dwio/dwrf/writer/StatisticsBuilder.cpp +++ b/velox/dwio/dwrf/writer/StatisticsBuilder.cpp @@ -16,8 +16,12 @@ #include "velox/dwio/dwrf/writer/StatisticsBuilder.h" +#include "velox/dwio/common/Arena.h" + namespace facebook::velox::dwrf { +using dwio::common::ArenaCreate; + namespace { static bool isValidLength(const std::optional& length) { @@ -101,28 +105,30 @@ void StatisticsBuilder::merge( } } -void StatisticsBuilder::toProto(proto::ColumnStatistics& stats) const { +void StatisticsBuilder::toProto(ColumnStatisticsWriteWrapper& stats) const { if (hasNull_.has_value()) { - stats.set_hasnull(hasNull_.value()); + stats.setHasNull(hasNull_.value()); } if (valueCount_.has_value()) { - stats.set_numberofvalues(valueCount_.value()); + stats.setNumberOfValues(valueCount_.value()); } if (rawSize_.has_value()) { - stats.set_rawsize(rawSize_.value()); + stats.setRawSize(rawSize_.value()); } if (size_.has_value()) { - stats.set_size(size_.value()); + stats.setSize(size_.value()); } } std::unique_ptr StatisticsBuilder::build() const { - proto::ColumnStatistics stats; + auto columnStatistics = ArenaCreate(arena_.get()); + auto stats = ColumnStatisticsWriteWrapper(columnStatistics); toProto(stats); + StatsContext context{WriterVersion_CURRENT}; - auto result = - buildColumnStatisticsFromProto(ColumnStatisticsWrapper(&stats), context); + auto result = buildColumnStatisticsFromProto( + ColumnStatisticsWrapper(columnStatistics), context); // We do not alter the proto since this is part of the file format // and the file format. The distinct count does not exist in the // file format but is added here for use in on demand sampling. @@ -230,13 +236,14 @@ void BooleanStatisticsBuilder::merge( mergeCount(trueCount_, stats->getTrueCount()); } -void BooleanStatisticsBuilder::toProto(proto::ColumnStatistics& stats) const { +void BooleanStatisticsBuilder::toProto( + ColumnStatisticsWriteWrapper& stats) const { StatisticsBuilder::toProto(stats); // Serialize type specific stats only if there is non-null values if (!isEmpty(*this) && trueCount_.has_value()) { - auto bStats = stats.mutable_bucketstatistics(); - DWIO_ENSURE_EQ(bStats->count_size(), 0); - bStats->add_count(trueCount_.value()); + auto bStats = stats.mutableBucketStatistics(); + DWIO_ENSURE_EQ(bStats.countSize(), 0); + bStats.addCount(trueCount_.value()); } } @@ -263,20 +270,21 @@ void IntegerStatisticsBuilder::merge( mergeWithOverflowCheck(sum_, stats->getSum()); } -void IntegerStatisticsBuilder::toProto(proto::ColumnStatistics& stats) const { +void IntegerStatisticsBuilder::toProto( + ColumnStatisticsWriteWrapper& stats) const { StatisticsBuilder::toProto(stats); // Serialize type specific stats only if there is non-null values if (!isEmpty(*this) && (min_.has_value() || max_.has_value() || sum_.has_value())) { - auto iStats = stats.mutable_intstatistics(); + auto iStats = stats.mutableIntegerStatistics(); if (min_.has_value()) { - iStats->set_minimum(min_.value()); + iStats.setMinimum(min_.value()); } if (max_.has_value()) { - iStats->set_maximum(max_.value()); + iStats.setMaximum(max_.value()); } if (sum_.has_value()) { - iStats->set_sum(sum_.value()); + iStats.setSum(sum_.value()); } } } @@ -305,20 +313,21 @@ void DoubleStatisticsBuilder::merge( } } -void DoubleStatisticsBuilder::toProto(proto::ColumnStatistics& stats) const { +void DoubleStatisticsBuilder::toProto( + ColumnStatisticsWriteWrapper& stats) const { StatisticsBuilder::toProto(stats); // Serialize type specific stats only if there is non-null values if (!isEmpty(*this) && (min_.has_value() || max_.has_value() || sum_.has_value())) { - auto dStats = stats.mutable_doublestatistics(); + auto dStats = stats.mutableDoubleStatistics(); if (min_.has_value()) { - dStats->set_minimum(min_.value()); + dStats.setMinimum(min_.value()); } if (max_.has_value()) { - dStats->set_maximum(max_.value()); + dStats.setMaximum(max_.value()); } if (sum_.has_value()) { - dStats->set_sum(sum_.value()); + dStats.setSum(sum_.value()); } } } @@ -361,22 +370,23 @@ void StringStatisticsBuilder::merge( mergeWithOverflowCheck(length_, stats->getTotalLength()); } -void StringStatisticsBuilder::toProto(proto::ColumnStatistics& stats) const { +void StringStatisticsBuilder::toProto( + ColumnStatisticsWriteWrapper& stats) const { StatisticsBuilder::toProto(stats); // If string value is too long, drop it and fall back to basic stats if (!isEmpty(*this) && (shouldKeep(min_) || shouldKeep(max_) || isValidLength(length_))) { - auto dStats = stats.mutable_stringstatistics(); + auto dStats = stats.mutableStringStatistics(); if (isValidLength(length_)) { - dStats->set_sum(length_.value()); + dStats.setSum(length_.value()); } if (shouldKeep(min_)) { - dStats->set_minimum(min_.value()); + dStats.setMinimum(min_.value()); } if (shouldKeep(max_)) { - dStats->set_maximum(max_.value()); + dStats.setMaximum(max_.value()); } } } @@ -399,12 +409,13 @@ void BinaryStatisticsBuilder::merge( mergeWithOverflowCheck(length_, stats->getTotalLength()); } -void BinaryStatisticsBuilder::toProto(proto::ColumnStatistics& stats) const { +void BinaryStatisticsBuilder::toProto( + ColumnStatisticsWriteWrapper& stats) const { StatisticsBuilder::toProto(stats); // Serialize type specific stats only if there is non-null values if (!isEmpty(*this) && isValidLength(length_)) { - auto bStats = stats.mutable_binarystatistics(); - bStats->set_sum(length_.value()); + auto bStats = stats.mutableBinaryStatistics(); + bStats.setSum(length_.value()); } } @@ -427,10 +438,10 @@ void MapStatisticsBuilder::merge( } } -void MapStatisticsBuilder::toProto(proto::ColumnStatistics& stats) const { +void MapStatisticsBuilder::toProto(ColumnStatisticsWriteWrapper& stats) const { StatisticsBuilder::toProto(stats); if (!isEmpty(*this) && !entryStatistics_.empty()) { - auto mapStats = stats.mutable_mapstatistics(); + auto mapStats = stats.mutableMapStatistics(); for (const auto& entry : entryStatistics_) { auto entryStatistics = mapStats->add_stats(); const auto& key = entry.first; @@ -440,8 +451,8 @@ void MapStatisticsBuilder::toProto(proto::ColumnStatistics& stats) const { } else if (key.bytesKey.has_value()) { entryStatistics->mutable_key()->set_byteskey(key.bytesKey.value()); } - dynamic_cast(*entry.second) - .toProto(*entryStatistics->mutable_stats()); + auto c = ColumnStatisticsWriteWrapper(entryStatistics->mutable_stats()); + dynamic_cast(*entry.second).toProto(c); } } } diff --git a/velox/dwio/dwrf/writer/StatisticsBuilder.h b/velox/dwio/dwrf/writer/StatisticsBuilder.h index 3441a5ddf9a1..497d78377468 100644 --- a/velox/dwio/dwrf/writer/StatisticsBuilder.h +++ b/velox/dwio/dwrf/writer/StatisticsBuilder.h @@ -109,7 +109,7 @@ class StatisticsBuilder : public virtual dwio::common::ColumnStatistics { public: /// Constructs with 'options'. explicit StatisticsBuilder(const StatisticsBuilderOptions& options) - : options_{options} { + : options_{options}, arena_(std::make_unique()) { init(); } @@ -177,7 +177,7 @@ class StatisticsBuilder : public virtual dwio::common::ColumnStatistics { /* * Write stats to proto */ - virtual void toProto(proto::ColumnStatistics& stats) const; + virtual void toProto(ColumnStatisticsWriteWrapper& stats) const; std::unique_ptr build() const; @@ -198,13 +198,14 @@ class StatisticsBuilder : public virtual dwio::common::ColumnStatistics { rawSize_ = 0; size_ = options_.initialSize; if (options_.countDistincts) { - hll_ = std::make_shared(options_.allocator); + hll_ = std::make_shared>(options_.allocator); } } protected: StatisticsBuilderOptions options_; - std::shared_ptr hll_; + std::shared_ptr> hll_; + std::unique_ptr arena_; }; class BooleanStatisticsBuilder : public StatisticsBuilder, @@ -233,7 +234,7 @@ class BooleanStatisticsBuilder : public StatisticsBuilder, init(); } - void toProto(proto::ColumnStatistics& stats) const override; + void toProto(ColumnStatisticsWriteWrapper& stats) const override; private: void init() { @@ -272,7 +273,7 @@ class IntegerStatisticsBuilder : public StatisticsBuilder, init(); } - void toProto(proto::ColumnStatistics& stats) const override; + void toProto(ColumnStatisticsWriteWrapper& stats) const override; private: void init() { @@ -332,7 +333,7 @@ class DoubleStatisticsBuilder : public StatisticsBuilder, init(); } - void toProto(proto::ColumnStatistics& stats) const override; + void toProto(ColumnStatisticsWriteWrapper& stats) const override; private: void init() { @@ -358,7 +359,7 @@ class StringStatisticsBuilder : public StatisticsBuilder, ~StringStatisticsBuilder() override = default; - void addValues(folly::StringPiece value, uint64_t count = 1) { + void addValues(std::string_view value, uint64_t count = 1) { // min_/max_ is not initialized with default that can be compared against // easily. So we need to capture whether self is empty and handle // differently. @@ -368,10 +369,10 @@ class StringStatisticsBuilder : public StatisticsBuilder, min_ = value; max_ = value; } else { - if (min_.has_value() && value < folly::StringPiece{min_.value()}) { + if (min_.has_value() && value < std::string_view{min_.value()}) { min_ = value; } - if (max_.has_value() && value > folly::StringPiece{max_.value()}) { + if (max_.has_value() && value > std::string_view{max_.value()}) { max_ = value; } } @@ -389,7 +390,7 @@ class StringStatisticsBuilder : public StatisticsBuilder, init(); } - void toProto(proto::ColumnStatistics& stats) const override; + void toProto(ColumnStatisticsWriteWrapper& stats) const override; private: uint32_t lengthLimit_; @@ -429,7 +430,7 @@ class BinaryStatisticsBuilder : public StatisticsBuilder, init(); } - void toProto(proto::ColumnStatistics& stats) const override; + void toProto(ColumnStatisticsWriteWrapper& stats) const override; private: void init() { @@ -477,7 +478,7 @@ class MapStatisticsBuilder : public StatisticsBuilder, init(); } - void toProto(proto::ColumnStatistics& stats) const override; + void toProto(ColumnStatisticsWriteWrapper& stats) const override; private: void init() { diff --git a/velox/dwio/dwrf/writer/StatisticsBuilderUtils.cpp b/velox/dwio/dwrf/writer/StatisticsBuilderUtils.cpp index 3babe373db2a..89dbcd1cfefb 100644 --- a/velox/dwio/dwrf/writer/StatisticsBuilderUtils.cpp +++ b/velox/dwio/dwrf/writer/StatisticsBuilderUtils.cpp @@ -81,18 +81,18 @@ void StatisticsBuilderUtils::addValues( const VectorPtr& vector, const common::Ranges& ranges) { auto nulls = vector->rawNulls(); - auto data = vector->asFlatVector()->rawValues(); + auto* data = vector->asFlatVector()->rawValues(); if (vector->mayHaveNulls()) { for (auto& pos : ranges) { if (bits::isBitNull(nulls, pos)) { builder.setHasNull(); } else { - builder.addValues(folly::StringPiece{data[pos]}); + builder.addValues(std::string_view{data[pos]}); } } } else { for (auto& pos : ranges) { - builder.addValues(folly::StringPiece{data[pos]}); + builder.addValues(std::string_view{data[pos]}); } } } diff --git a/velox/dwio/dwrf/writer/StringDictionaryEncoder.h b/velox/dwio/dwrf/writer/StringDictionaryEncoder.h index 8ecb635f8655..20abfaac8582 100644 --- a/velox/dwio/dwrf/writer/StringDictionaryEncoder.h +++ b/velox/dwio/dwrf/writer/StringDictionaryEncoder.h @@ -30,22 +30,22 @@ namespace detail { // Each new string inserted into dictionary is assigned an incrementing id. // A set is maintained with all of the DictStringId created. Using -// Heterogeneous lookup techniques, incoming StringPiece is first looked for +// Heterogeneous lookup techniques, incoming string_view is first looked for // a match in the set. If no match exists a new id is generated and inserted // into the set. What is Heterogeneous lookup ? // http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2018/p0919r1.html // Heterogeneous lookup is not available in standard CPP and proposed for CPP20. // Follys:F14* variant supports it, so leveraging folly for now. struct StringLookupKey { - StringLookupKey(folly::StringPiece sp, uint32_t index) - : sp{sp}, + StringLookupKey(std::string_view sv, uint32_t index) + : sv{sv}, index{index}, hash{folly::crc32c( - reinterpret_cast(sp.data()), - sp.size(), + reinterpret_cast(sv.data()), + sv.size(), 0 /* seed */)} {} - const folly::StringPiece sp; + const std::string_view sv; const uint32_t index; const uint32_t hash; }; @@ -117,9 +117,9 @@ class StringDictionaryEncoder { } uint32_t - addKey(folly::StringPiece sp, uint32_t strideIndex, uint32_t count = 1) { + addKey(std::string_view sv, uint32_t strideIndex, uint32_t count = 1) { auto newIndex = size(); - detail::StringLookupKey key{sp, newIndex}; + detail::StringLookupKey key{sv, newIndex}; auto result = keyIndex_.insert(key); if (!result.second) { auto index = result.first->getIndex(); @@ -130,12 +130,12 @@ class StringDictionaryEncoder { auto bytesCount = keyBytes_.size(); if (UNLIKELY( newIndex == std::numeric_limits::max() || - (std::numeric_limits::max() - bytesCount <= sp.size()))) { + (std::numeric_limits::max() - bytesCount <= sv.size()))) { DWIO_RAISE("exceeds dictionary size limit"); } // append keys - keyBytes_.extendAppend(bytesCount, sp.data(), sp.size()); + keyBytes_.extendAppend(bytesCount, sv.data(), sv.size()); keyOffsets_.append(keyBytes_.size()); hash_.append(key.hash); counts_.append(count); @@ -153,11 +153,11 @@ class StringDictionaryEncoder { return firstSeenStrideIndex_[index]; } - folly::StringPiece getKey(uint32_t index) const { + std::string_view getKey(uint32_t index) const { DCHECK(index < keyOffsets_.size() - 1); auto startOffset = keyOffsets_[index]; auto endOffset = keyOffsets_[index + 1]; - return folly::StringPiece{ + return std::string_view{ keyBytes_.data() + startOffset, endOffset - startOffset}; } @@ -178,8 +178,8 @@ class StringDictionaryEncoder { VELOX_FRIEND_TEST(TestStringDictionaryEncoder, Clear); // Intended for testing only. - uint32_t getIndex(folly::StringPiece sp) { - detail::StringLookupKey key{sp, 0}; + uint32_t getIndex(std::string_view sv) { + detail::StringLookupKey key{sv, 0}; auto result = keyIndex_.find(key); if (result != keyIndex_.end()) { return result->getIndex(); @@ -221,7 +221,7 @@ FOLLY_ALWAYS_INLINE bool DictStringIdEquality::operator()( FOLLY_ALWAYS_INLINE bool DictStringIdEquality::operator()( detail::StringLookupKey key, detail::DictStringId lhs) const { - return encoder_.getKey(lhs.getIndex()) == key.sp; + return encoder_.getKey(lhs.getIndex()) == key.sv; } FOLLY_ALWAYS_INLINE uint32_t diff --git a/velox/dwio/dwrf/writer/Writer.cpp b/velox/dwio/dwrf/writer/Writer.cpp index d6011c38f8de..26edce0d9105 100644 --- a/velox/dwio/dwrf/writer/Writer.cpp +++ b/velox/dwio/dwrf/writer/Writer.cpp @@ -213,10 +213,11 @@ Writer::Writer( : Writer{ std::move(sink), options, - options.memoryPool->addAggregateChild(fmt::format( - "{}.dwrf.{}", - options.memoryPool->name(), - folly::to(folly::Random::rand64())))} {} + options.memoryPool->addAggregateChild( + fmt::format( + "{}.dwrf.{}", + options.memoryPool->name(), + folly::to(folly::Random::rand64())))} {} void Writer::setMemoryReclaimers( const std::shared_ptr& pool) { @@ -521,7 +522,7 @@ void Writer::flushStripe(bool close) { const auto& handler = context.getEncryptionHandler(); EncodingManager encodingManager{handler}; - writer_->flush([&](uint32_t nodeId) -> proto::ColumnEncoding& { + writer_->flush([&](uint32_t nodeId) -> ColumnEncodingWriteWrapper { return encodingManager.addEncodingToFooter(nodeId); }); @@ -546,7 +547,8 @@ void Writer::flushStripe(bool close) { const DataBufferHolder& out) { uint32_t currentIndex = 0; const auto nodeId = stream.encodingKey().node(); - proto::Stream* s = encodingManager.addStreamToFooter(nodeId, currentIndex); + StreamWriteWrapper s = + encodingManager.addStreamToFooter(nodeId, currentIndex); // set offset only when needed, ie. when offset of current stream cannot be // calculated based on offset and length of previous stream. In that case, @@ -554,19 +556,19 @@ void Writer::flushStripe(bool close) { // encryption group or neither are encrypted. So the logic is simplified to // check if group index are the same for current and previous stream if (offset > 0 && lastIndex != currentIndex) { - s->set_offset(offset); + s.setOffset(offset); } lastIndex = currentIndex; // Jolly/Presto readers can't read streams bigger than 2GB. writerBase_->validateStreamSize(stream, out.size()); - s->set_kind(static_cast(stream.kind())); - s->set_node(nodeId); - s->set_column(stream.column()); - s->set_sequence(stream.encodingKey().sequence()); - s->set_length(out.size()); - s->set_usevints(context.getConfig(Config::USE_VINTS)); + s.setKind(stream.kind()); + s.setNode(nodeId); + s.setColumn(stream.column()); + s.setSequence(stream.encodingKey().sequence()); + s.setLength(out.size()); + s.setUseVints(context.getConfig(Config::USE_VINTS)); offset += out.size(); context.recordPhysicalSize(stream, out.size()); @@ -619,19 +621,20 @@ void Writer::flushStripe(bool close) { VELOX_CHECK_EQ(footerOffset, stripeOffset + dataLength + indexLength); sink.setMode(WriterSink::Mode::Footer); - writerBase_->writeProto(encodingManager.getFooter()); + encodingManager.getFooter().setWriterTimezone(); + writerBase_->writeProto(&encodingManager.getFooter()); sink.setMode(WriterSink::Mode::None); - auto& stripe = writerBase_->addStripeInfo(); - stripe.set_offset(stripeOffset); - stripe.set_indexlength(indexLength); - stripe.set_datalength(dataLength); - stripe.set_footerlength(sink.size() - footerOffset); + auto stripe = writerBase_->addStripeInfo(); + stripe.setOffset(stripeOffset); + stripe.setIndexLength(indexLength); + stripe.setDataLength(dataLength); + stripe.setFooterLength(sink.size() - footerOffset); // set encryption key metadata if (handler.isEncrypted() && context.stripeIndex() == 0) { for (uint32_t i = 0; i < handler.getEncryptionGroupCount(); ++i) { - *stripe.add_keymetadata() = + *stripe.addKeyMetadata() = handler.getEncryptionProviderByIndex(i).getKey(); } } @@ -694,10 +697,10 @@ void Writer::flushInternal(bool close) { proto::Encryption* encryption = nullptr; // initialize encryption related metadata only when there is data written - if (handler.isEncrypted() && footer.stripes_size() > 0) { + if (handler.isEncrypted() && footer->stripesSize() > 0) { const auto count = handler.getEncryptionGroupCount(); stats.resize(count); - encryption = footer.mutable_encryption(); + encryption = footer->mutableEncryption(); encryption->set_keyprovider( encryption::toProto(handler.getKeyProviderType())); for (uint32_t i = 0; i < count; ++i) { @@ -708,25 +711,28 @@ void Writer::flushInternal(bool close) { std::optional lastRoot; std::unordered_map statsMap; - writer_->writeFileStats([&](uint32_t nodeId) -> proto::ColumnStatistics& { - auto entry = footer.add_statistics(); - if (!encryption || !handler.isEncrypted(nodeId)) { - return *entry; - } + writer_->writeFileStats( + [&](uint32_t nodeId) -> ColumnStatisticsWriteWrapper { + auto entry = footer->addStatistics(); + if (!encryption || !handler.isEncrypted(nodeId)) { + return entry; + } - auto root = handler.getEncryptionRoot(nodeId); - auto groupIndex = handler.getEncryptionGroupIndex(nodeId); - auto& group = stats.at(groupIndex); - if (!lastRoot || root != lastRoot.value()) { - // this is a new root, add to the footer, and use a new slot - group.emplace_back(); - encryption->mutable_encryptiongroups(groupIndex)->add_nodes(root); - } - lastRoot = root; - auto encryptedStats = group.back().add_statistics(); - statsMap[entry] = encryptedStats; - return *encryptedStats; - }); + auto root = handler.getEncryptionRoot(nodeId); + auto groupIndex = handler.getEncryptionGroupIndex(nodeId); + auto& group = stats.at(groupIndex); + if (!lastRoot || root != lastRoot.value()) { + // this is a new root, add to the footer, and use a new slot + group.emplace_back(); + encryption->mutable_encryptiongroups(groupIndex)->add_nodes(root); + } + lastRoot = root; + auto encryptedStats = group.back().add_statistics(); + auto cs = + reinterpret_cast(entry.rawProtoPtr()); + statsMap[cs] = encryptedStats; + return ColumnStatisticsWriteWrapper(encryptedStats); + }); #define COPY_STAT(from, to, stat) \ if (from->has_##stat()) { \ @@ -770,7 +776,7 @@ void Writer::flushInternal(bool close) { dwio::common::MetricsLog::FileCloseMetrics{ .writerVersion = writerVersionToString( context.getConfig(Config::WRITER_VERSION)), - .footerLength = footer.contentlength(), + .footerLength = footer->contentLength(), .fileSize = sink.size(), .cacheSize = sink.getCacheSize(), .numCacheBlocks = sink.getCacheOffsets().size() - 1, diff --git a/velox/dwio/dwrf/writer/Writer.h b/velox/dwio/dwrf/writer/Writer.h index 3864b57790d1..4a7d1ab540b5 100644 --- a/velox/dwio/dwrf/writer/Writer.h +++ b/velox/dwio/dwrf/writer/Writer.h @@ -60,10 +60,11 @@ class Writer : public dwio::common::Writer { : Writer{ std::move(sink), options, - parentPool.addAggregateChild(fmt::format( - "{}.dwrf_{}", - parentPool.name(), - folly::to(folly::Random::rand64())))} {} + parentPool.addAggregateChild( + fmt::format( + "{}.dwrf_{}", + parentPool.name(), + folly::to(folly::Random::rand64())))} {} Writer( std::unique_ptr sink, @@ -134,7 +135,7 @@ class Writer : public dwio::common::Writer { return writerBase_->getContext(); } - const proto::Footer& getFooter() const { + const std::unique_ptr& getFooter() const { return writerBase_->getFooter(); } diff --git a/velox/dwio/dwrf/writer/WriterBase.cpp b/velox/dwio/dwrf/writer/WriterBase.cpp index 6fd51477d048..dd182917661d 100644 --- a/velox/dwio/dwrf/writer/WriterBase.cpp +++ b/velox/dwio/dwrf/writer/WriterBase.cpp @@ -22,8 +22,8 @@ namespace facebook::velox::dwrf { void WriterBase::writeFooter(const Type& type) { auto pos = writerSink_->size(); - footer_.set_headerlength(ORC_MAGIC_LEN); - footer_.set_contentlength(pos - ORC_MAGIC_LEN); + footer_->setHeaderLength(ORC_MAGIC_LEN); + footer_->setContentLength(pos - ORC_MAGIC_LEN); writerSink_->setMode(WriterSink::Mode::None); // write cache when available @@ -31,45 +31,46 @@ void WriterBase::writeFooter(const Type& type) { if (cacheSize > 0) { writerSink_->writeCache(); for (auto& i : writerSink_->getCacheOffsets()) { - footer_.add_stripecacheoffsets(i); + footer_->addStripeCacheOffsets(i); } pos = writerSink_->size(); } - ProtoUtils::writeType(type, footer_); - DWIO_ENSURE_EQ(footer_.types_size(), footer_.statistics_size()); + ProtoUtils::writeType(type, *footer_); + DWIO_ENSURE_EQ(footer_->typesSize(), footer_->statisticsSize()); auto writerVersion = static_cast(context_->getConfig(Config::WRITER_VERSION)); writeUserMetadata(writerVersion); - footer_.set_numberofrows(context_->fileRowCount()); - footer_.set_rowindexstride(context_->indexStride()); + footer_->setNumberOfRows(context_->fileRowCount()); + footer_->setRowIndexStride(context_->indexStride()); if (context_->fileRawSize() > 0 || context_->fileRowCount() == 0) { // ColumnTransformWriter, when rewriting presto written file does not have // rawSize. - footer_.set_rawdatasize(context_->fileRawSize()); + footer_->setRawDataSize(context_->fileRawSize()); } auto* checksum = writerSink_->getChecksum(); - footer_.set_checksumalgorithm( + footer_->setCheckSumAlgorithm( (checksum != nullptr) ? checksum->getType() : proto::ChecksumAlgorithm::NULL_); - writeProto(footer_); + writeProto(footer_->getDwrfPtr()); const auto footerLength = writerSink_->size() - pos; // write postscript pos = writerSink_->size(); - proto::PostScript ps; - ps.set_writerversion(writerVersion); - ps.set_footerlength(footerLength); - ps.set_compression( - static_cast(context_->compression())); + auto dwrfPostScript = ArenaCreate(arena_.get()); + std::unique_ptr ps = + std::make_unique(dwrfPostScript); + ps->setWriterVersion(writerVersion); + ps->setFooterLength(footerLength); + ps->setCompression(context_->compression()); if (context_->compression() != common::CompressionKind::CompressionKind_NONE) { - ps.set_compressionblocksize(context_->compressionBlockSize()); + ps->setCompressionBlockSize(context_->compressionBlockSize()); } - ps.set_cachemode( - static_cast(writerSink_->getCacheMode())); - ps.set_cachesize(cacheSize); + + ps->setCacheMode(writerSink_->getCacheMode()); + ps->setCacheSize(cacheSize); writeProto(ps, common::CompressionKind::CompressionKind_NONE); auto psLength = writerSink_->size() - pos; DWIO_ENSURE_LE(psLength, 0xff, "PostScript is too large: ", psLength); @@ -80,14 +81,14 @@ void WriterBase::writeFooter(const Type& type) { void WriterBase::writeUserMetadata(uint32_t writerVersion) { // add writer version - userMetadata_[std::string{WRITER_NAME_KEY}] = kDwioWriter; - userMetadata_[std::string{WRITER_VERSION_KEY}] = + userMetadata_[std::string{kWriterNameKey}] = kDwioWriter; + userMetadata_[std::string{kWriterVersionKey}] = folly::to(writerVersion); - userMetadata_[std::string{WRITER_HOSTNAME_KEY}] = process::getHostName(); + userMetadata_[std::string{kWriterHostnameKey}] = process::getHostName(); std::for_each(userMetadata_.begin(), userMetadata_.end(), [&](auto& pair) { - auto item = footer_.add_metadata(); - item->set_name(pair.first); - item->set_value(pair.second); + auto item = footer_->addMetadata(); + item.setName(pair.first); + item.setValue(pair.second); }); } diff --git a/velox/dwio/dwrf/writer/WriterBase.h b/velox/dwio/dwrf/writer/WriterBase.h index e79b9293db9b..abb9d4f16c78 100644 --- a/velox/dwio/dwrf/writer/WriterBase.h +++ b/velox/dwio/dwrf/writer/WriterBase.h @@ -17,15 +17,19 @@ #pragma once #include "velox/common/base/GTestMacros.h" +#include "velox/dwio/common/Arena.h" #include "velox/dwio/dwrf/writer/WriterContext.h" #include "velox/dwio/dwrf/writer/WriterSink.h" namespace facebook::velox::dwrf { +using dwio::common::ArenaCreate; + class WriterBase { public: explicit WriterBase(std::unique_ptr sink) - : sink_{std::move(sink)} { + : sink_{std::move(sink)}, + arena_(std::make_unique()) { VELOX_CHECK_NOT_NULL(sink_); } @@ -88,6 +92,8 @@ class WriterBase { *sink_, context_->getMemoryPool(MemoryUsageCategory::OUTPUT_STREAM), context_->getConfigs()); + auto dwrfFooter_ = ArenaCreate(arena_.get()); + footer_ = std::make_unique(dwrfFooter_); } void initBuffers(); @@ -107,7 +113,7 @@ class WriterBase { auto holder = context_->newDataBufferHolder(); auto stream = context_->newStream(kind, *holder); - t.SerializeToZeroCopyStream(stream.get()); + t->SerializeToZeroCopyStream(stream.get()); stream->flush(); writerSink_->addBuffers(*holder); @@ -131,23 +137,23 @@ class WriterBase { } } - proto::StripeInformation& addStripeInfo() { - auto stripe = footer_.add_stripes(); - stripe->set_numberofrows(context_->stripeRowCount()); + StripeInformationWriteWrapper addStripeInfo() { + auto stripe = footer_->addStripes(); + stripe.setNumberOfRows(context_->stripeRowCount()); if (context_->stripeRawSize() > 0 || context_->stripeRowCount() == 0) { // ColumnTransformWriter, when rewriting presto written // file does not have rawSize. - stripe->set_rawdatasize(context_->stripeRawSize()); + stripe.setRawDataSize(context_->stripeRawSize()); } auto* checksum = writerSink_->getChecksum(); if (checksum != nullptr) { - stripe->set_checksum(checksum->getDigest()); + stripe.setChecksum(checksum->getDigest()); } - return *stripe; + return stripe; } - proto::Footer& getFooter() { + std::unique_ptr& getFooter() { return footer_; } @@ -170,8 +176,10 @@ class WriterBase { std::unique_ptr context_; std::unique_ptr sink_; std::unique_ptr writerSink_; - proto::Footer footer_; + std::unique_ptr footer_; + proto::orc::Metadata metadata_; std::unordered_map userMetadata_; + std::unique_ptr arena_; friend class WriterTest; VELOX_FRIEND_TEST(WriterBaseTest, FlushWriterSinkUponClose); diff --git a/velox/dwio/dwrf/writer/WriterContext.h b/velox/dwio/dwrf/writer/WriterContext.h index 9ba444d53175..d31d1813e3e7 100644 --- a/velox/dwio/dwrf/writer/WriterContext.h +++ b/velox/dwio/dwrf/writer/WriterContext.h @@ -30,6 +30,7 @@ #include "velox/vector/DecodedVector.h" namespace facebook::velox::dwrf { + using dwio::common::BufferedOutputStream; using dwio::common::DataBufferHolder; using dwio::common::compression::CompressionBufferPool; @@ -50,16 +51,14 @@ class WriterContext : public CompressionBufferPool { ~WriterContext() override; bool hasStream(const DwrfStreamIdentifier& stream) const { - return streams_.find(stream) != streams_.end(); + return streams_.find(stream) != streams_.cend(); } const DataBufferHolder& getStream(const DwrfStreamIdentifier& stream) const { return streams_.at(stream); } - void addBuffer( - const DwrfStreamIdentifier& stream, - folly::StringPiece buffer) { + void addBuffer(const DwrfStreamIdentifier& stream, std::string_view buffer) { streams_.at(stream).take(buffer); } @@ -115,7 +114,7 @@ class WriterContext : public CompressionBufferPool { velox::memory::MemoryPool& dictionaryPool, velox::memory::MemoryPool& generalPool) { auto result = dictEncoders_.find(encodingKey); - if (result == dictEncoders_.end()) { + if (result == dictEncoders_.cend()) { auto emplaceResult = dictEncoders_.emplace( encodingKey, std::make_unique>( @@ -233,7 +232,7 @@ class WriterContext : public CompressionBufferPool { void removeAllIntDictionaryEncodersOnNode( std::function predicate) { auto iter = dictEncoders_.begin(); - while (iter != dictEncoders_.end()) { + while (iter != dictEncoders_.cend()) { if (predicate(iter->first.node())) { iter = dictEncoders_.erase(iter); } else { diff --git a/velox/dwio/orc/test/CMakeLists.txt b/velox/dwio/orc/test/CMakeLists.txt index 3d93874d12aa..dca4b326a525 100644 --- a/velox/dwio/orc/test/CMakeLists.txt +++ b/velox/dwio/orc/test/CMakeLists.txt @@ -16,7 +16,8 @@ add_executable(velox_dwio_orc_reader_test ReaderTest.cpp) add_test( NAME velox_dwio_orc_reader_test COMMAND velox_dwio_orc_reader_test - WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} +) target_link_libraries( velox_dwio_orc_reader_test @@ -24,13 +25,15 @@ target_link_libraries( velox_dwio_common_test_utils GTest::gtest GTest::gtest_main - GTest::gmock) + GTest::gmock +) add_executable(velox_dwio_orc_reader_filter_test ReaderFilterTest.cpp) add_test( NAME velox_dwio_orc_reader_filter_test COMMAND velox_dwio_orc_reader_filter_test - WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} +) target_link_libraries( velox_dwio_orc_reader_filter_test @@ -38,7 +41,7 @@ target_link_libraries( velox_dwio_common_test_utils GTest::gtest GTest::gtest_main - GTest::gmock) + GTest::gmock +) -file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/examples - DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) +file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/examples DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) diff --git a/velox/dwio/orc/test/ReaderTest.cpp b/velox/dwio/orc/test/ReaderTest.cpp index cad607b5761d..a357acd0798e 100644 --- a/velox/dwio/orc/test/ReaderTest.cpp +++ b/velox/dwio/orc/test/ReaderTest.cpp @@ -239,9 +239,7 @@ TEST_F(OrcReaderTest, testOrcReadAllType) { auto mapValues = mapCol->mapValues()->as>(); EXPECT_EQ(mapKeys->size(), 2); EXPECT_EQ(mapKeys->size(), mapValues->size()); - EXPECT_EQ( - mapCol->toString(0, 2, ",", false), - "2 elements starting at 0 {foo => 1, bar => 2}"); + EXPECT_EQ(mapCol->toString(0, 2, ",", false), "{foo => 1, bar => 2}"); EXPECT_EQ(structCol->size(), 1); EXPECT_EQ(structCol->type()->toString(), "ROW"); @@ -392,8 +390,8 @@ TEST_P( auto rowReader = reader->createRowReader(rowReaderOptions); for (std::map::const_iterator itr = - GetParam().userMeta.begin(); - itr != GetParam().userMeta.end(); + GetParam().userMeta.cbegin(); + itr != GetParam().userMeta.cend(); ++itr) { ASSERT_EQ(true, reader->hasMetadataValue(itr->first)); std::string val = reader->getMetadataValue(itr->first); diff --git a/velox/dwio/parquet/CMakeLists.txt b/velox/dwio/parquet/CMakeLists.txt index 2baaaeee11f5..c40f352d0bbe 100644 --- a/velox/dwio/parquet/CMakeLists.txt +++ b/velox/dwio/parquet/CMakeLists.txt @@ -27,8 +27,6 @@ velox_add_library(velox_dwio_parquet_reader RegisterParquetReader.cpp) velox_add_library(velox_dwio_parquet_writer RegisterParquetWriter.cpp) if(VELOX_ENABLE_PARQUET) - velox_link_libraries(velox_dwio_parquet_reader - velox_dwio_native_parquet_reader xsimd) - velox_link_libraries(velox_dwio_parquet_writer - velox_dwio_arrow_parquet_writer) + velox_link_libraries(velox_dwio_parquet_reader velox_dwio_native_parquet_reader xsimd) + velox_link_libraries(velox_dwio_parquet_writer velox_dwio_arrow_parquet_writer) endif() diff --git a/velox/dwio/parquet/common/BloomFilter.cpp b/velox/dwio/parquet/common/BloomFilter.cpp index 6a59a1962327..211b842cf9ee 100644 --- a/velox/dwio/parquet/common/BloomFilter.cpp +++ b/velox/dwio/parquet/common/BloomFilter.cpp @@ -30,8 +30,6 @@ namespace facebook::velox::parquet { -constexpr uint32_t BlockSplitBloomFilter::SALT[kBitsSetPerBlock]; - BlockSplitBloomFilter::BlockSplitBloomFilter(memory::MemoryPool* pool) : pool_(pool), hashStrategy_(HashStrategy::XXHASH), diff --git a/velox/dwio/parquet/common/CMakeLists.txt b/velox/dwio/parquet/common/CMakeLists.txt index 4e3edf6687ce..159067602e9e 100644 --- a/velox/dwio/parquet/common/CMakeLists.txt +++ b/velox/dwio/parquet/common/CMakeLists.txt @@ -17,7 +17,8 @@ velox_add_library( BloomFilter.cpp XxHasher.cpp LevelComparison.cpp - LevelConversion.cpp) + LevelConversion.cpp +) velox_link_libraries( velox_dwio_parquet_common @@ -30,4 +31,5 @@ velox_link_libraries( Folly::folly Snappy::snappy thrift - zstd::zstd) + zstd::zstd +) diff --git a/velox/dwio/parquet/reader/CMakeLists.txt b/velox/dwio/parquet/reader/CMakeLists.txt index 2bda9d61d06b..292285da9b53 100644 --- a/velox/dwio/parquet/reader/CMakeLists.txt +++ b/velox/dwio/parquet/reader/CMakeLists.txt @@ -25,7 +25,8 @@ velox_add_library( RleBpDecoder.cpp StructColumnReader.cpp StringColumnReader.cpp - SemanticVersion.cpp) + SemanticVersion.cpp +) velox_link_libraries( velox_dwio_native_parquet_reader @@ -38,4 +39,5 @@ velox_link_libraries( arrow Snappy::snappy thrift - zstd::zstd) + zstd::zstd +) diff --git a/velox/dwio/parquet/reader/FloatingPointColumnReader.h b/velox/dwio/parquet/reader/FloatingPointColumnReader.h index cac475c0ee94..1a3fc9c4c58c 100644 --- a/velox/dwio/parquet/reader/FloatingPointColumnReader.h +++ b/velox/dwio/parquet/reader/FloatingPointColumnReader.h @@ -36,6 +36,13 @@ class FloatingPointColumnReader ParquetParams& params, common::ScanSpec& scanSpec); + // Parquet floating point reader always supports a bulk path + static constexpr bool kHasBulkPath = true; + + bool hasBulkPath() const override { + return kHasBulkPath; + } + void seekToRowGroup(int64_t index) override { base::seekToRowGroup(index); this->scanState().clear(); @@ -66,7 +73,16 @@ FloatingPointColumnReader::FloatingPointColumnReader( requestedType, std::move(fileType), params, - scanSpec) {} + scanSpec) { + VELOX_DCHECK( + (this->requestedType_->kind() == TypeKind::REAL && + std::is_same_v) || + (this->requestedType_->kind() == TypeKind::DOUBLE && + std::is_same_v), + "TRequested type mismatch: template parameter is {}, but requestedType is {}", + folly::demangle(typeid(TRequested)), + this->requestedType_->toString()); +} template uint64_t FloatingPointColumnReader::skip( diff --git a/velox/dwio/parquet/reader/Metadata.cpp b/velox/dwio/parquet/reader/Metadata.cpp index 8920b0ea400f..7c5c1c1a6893 100644 --- a/velox/dwio/parquet/reader/Metadata.cpp +++ b/velox/dwio/parquet/reader/Metadata.cpp @@ -156,22 +156,16 @@ common::CompressionKind thriftCodecToCompressionKind( switch (codec) { case thrift::CompressionCodec::UNCOMPRESSED: return common::CompressionKind::CompressionKind_NONE; - break; case thrift::CompressionCodec::SNAPPY: return common::CompressionKind::CompressionKind_SNAPPY; - break; case thrift::CompressionCodec::GZIP: return common::CompressionKind::CompressionKind_GZIP; - break; case thrift::CompressionCodec::LZO: return common::CompressionKind::CompressionKind_LZO; - break; case thrift::CompressionCodec::LZ4: return common::CompressionKind::CompressionKind_LZ4; - break; case thrift::CompressionCodec::ZSTD: return common::CompressionKind::CompressionKind_ZSTD; - break; case thrift::CompressionCodec::LZ4_RAW: return common::CompressionKind::CompressionKind_LZ4; default: @@ -321,8 +315,9 @@ FileMetaDataPtr::FileMetaDataPtr(const void* metadata) : ptr_(metadata) {} FileMetaDataPtr::~FileMetaDataPtr() = default; RowGroupMetaDataPtr FileMetaDataPtr::rowGroup(int i) const { - return RowGroupMetaDataPtr(reinterpret_cast( - &thriftFileMetaDataPtr(ptr_)->row_groups[i])); + return RowGroupMetaDataPtr( + reinterpret_cast( + &thriftFileMetaDataPtr(ptr_)->row_groups[i])); } int64_t FileMetaDataPtr::numRows() const { @@ -359,4 +354,8 @@ std::string FileMetaDataPtr::keyValueMetadataValue( VELOX_FAIL(fmt::format("Input key {} is not in the key value metadata", key)); } +std::string FileMetaDataPtr::createdBy() const { + return thriftFileMetaDataPtr(ptr_)->created_by; +} + } // namespace facebook::velox::parquet diff --git a/velox/dwio/parquet/reader/Metadata.h b/velox/dwio/parquet/reader/Metadata.h index a0241828b546..69fee1e5bdc4 100644 --- a/velox/dwio/parquet/reader/Metadata.h +++ b/velox/dwio/parquet/reader/Metadata.h @@ -153,6 +153,9 @@ class FileMetaDataPtr { /// Returns the value inside the key/value metadata if the key is present. std::string keyValueMetadataValue(const std::string_view key) const; + /// Return the Parquet writer created_by string. + std::string createdBy() const; + private: const void* ptr_; }; diff --git a/velox/dwio/parquet/reader/PageReader.cpp b/velox/dwio/parquet/reader/PageReader.cpp index 115a77b7bf36..d0c6a0fcf7d5 100644 --- a/velox/dwio/parquet/reader/PageReader.cpp +++ b/velox/dwio/parquet/reader/PageReader.cpp @@ -17,11 +17,11 @@ #include "velox/dwio/parquet/reader/PageReader.h" #include "velox/common/testutil/TestValue.h" +#include "velox/common/time/Timer.h" #include "velox/dwio/common/BufferUtil.h" #include "velox/dwio/common/ColumnVisitors.h" #include "velox/dwio/parquet/common/LevelConversion.h" #include "velox/dwio/parquet/thrift/ThriftTransport.h" - #include "velox/vector/FlatVector.h" #include // @manual @@ -87,7 +87,12 @@ PageHeader PageReader::readPageHeader() { if (bufferEnd_ == bufferStart_) { const void* buffer; int32_t size; - inputStream_->Next(&buffer, &size); + uint64_t readUs{0}; + { + MicrosecondTimer timer(&readUs); + inputStream_->Next(&buffer, &size); + } + stats_.pageLoadTimeNs.increment(readUs * 1'000); bufferStart_ = reinterpret_cast(buffer); bufferEnd_ = bufferStart_ + size; } @@ -106,26 +111,31 @@ PageHeader PageReader::readPageHeader() { } const char* PageReader::readBytes(int32_t size, BufferPtr& copy) { - if (bufferEnd_ == bufferStart_) { - const void* buffer = nullptr; - int32_t bufferSize = 0; - if (!inputStream_->Next(&buffer, &bufferSize)) { - VELOX_FAIL("Read past end"); + uint64_t readUs{0}; + { + MicrosecondTimer timer(&readUs); + if (bufferEnd_ == bufferStart_) { + const void* buffer = nullptr; + int32_t bufferSize = 0; + if (!inputStream_->Next(&buffer, &bufferSize)) { + VELOX_FAIL("Read past end"); + } + bufferStart_ = reinterpret_cast(buffer); + bufferEnd_ = bufferStart_ + bufferSize; } - bufferStart_ = reinterpret_cast(buffer); - bufferEnd_ = bufferStart_ + bufferSize; - } - if (bufferEnd_ - bufferStart_ >= size) { - bufferStart_ += size; - return bufferStart_ - size; - } - dwio::common::ensureCapacity(copy, size, &pool_); - dwio::common::readBytes( - size, - inputStream_.get(), - copy->asMutable(), - bufferStart_, - bufferEnd_); + if (bufferEnd_ - bufferStart_ >= size) { + bufferStart_ += size; + return bufferStart_ - size; + } + dwio::common::ensureCapacity(copy, size, &pool_); + dwio::common::readBytes( + size, + inputStream_.get(), + copy->asMutable(), + bufferStart_, + bufferEnd_); + } + stats_.pageLoadTimeNs.increment(readUs * 1'000); return copy->as(); } @@ -288,9 +298,15 @@ void PageReader::prepareDataPageV2(const PageHeader& pageHeader, int64_t row) { } if (maxDefine_ > 0) { - defineDecoder_ = std::make_unique( - pageData_ + repeatLength, - pageData_ + repeatLength + defineLength, + if (maxDefine_ == 1) { + defineDecoder_ = std::make_unique( + pageData_ + repeatLength, + pageData_ + repeatLength + defineLength, + ::arrow::bit_util::NumRequiredBits(maxDefine_)); + } + wideDefineDecoder_ = std::make_unique( + reinterpret_cast(pageData_ + repeatLength), + defineLength, ::arrow::bit_util::NumRequiredBits(maxDefine_)); } auto levelsSize = repeatLength + defineLength; @@ -362,12 +378,17 @@ void PageReader::prepareDictionary(const PageHeader& pageHeader) { if (pageData_) { memcpy(dictionary_.values->asMutable(), pageData_, numBytes); } else { - dwio::common::readBytes( - numBytes, - inputStream_.get(), - dictionary_.values->asMutable(), - bufferStart_, - bufferEnd_); + uint64_t readUs{0}; + { + MicrosecondTimer timer(&readUs); + dwio::common::readBytes( + numBytes, + inputStream_.get(), + dictionary_.values->asMutable(), + bufferStart_, + bufferEnd_); + } + stats_.pageLoadTimeNs.increment(readUs * 1'000); } if (type_->type()->isShortDecimal() && parquetType == thrift::Type::INT32) { @@ -397,12 +418,17 @@ void PageReader::prepareDictionary(const PageHeader& pageHeader) { if (pageData_) { memcpy(dictionary_.values->asMutable(), pageData_, numBytes); } else { - dwio::common::readBytes( - numBytes, - inputStream_.get(), - dictionary_.values->asMutable(), - bufferStart_, - bufferEnd_); + uint64_t readUs{0}; + { + MicrosecondTimer timer(&readUs); + dwio::common::readBytes( + numBytes, + inputStream_.get(), + dictionary_.values->asMutable(), + bufferStart_, + bufferEnd_); + } + stats_.pageLoadTimeNs.increment(readUs * 1'000); } // Expand the Parquet type length values to Velox type length. // We start from the end to allow in-place expansion. @@ -429,8 +455,13 @@ void PageReader::prepareDictionary(const PageHeader& pageHeader) { if (pageData_) { memcpy(strings, pageData_, numBytes); } else { - dwio::common::readBytes( - numBytes, inputStream_.get(), strings, bufferStart_, bufferEnd_); + uint64_t readUs{0}; + { + MicrosecondTimer timer(&readUs); + dwio::common::readBytes( + numBytes, inputStream_.get(), strings, bufferStart_, bufferEnd_); + } + stats_.pageLoadTimeNs.increment(readUs * 1'000); } auto header = strings; for (auto i = 0; i < dictionary_.numValues; ++i) { @@ -452,12 +483,17 @@ void PageReader::prepareDictionary(const PageHeader& pageHeader) { if (pageData_) { memcpy(data, pageData_, numParquetBytes); } else { - dwio::common::readBytes( - numParquetBytes, - inputStream_.get(), - data, - bufferStart_, - bufferEnd_); + uint64_t readUs{0}; + { + MicrosecondTimer timer(&readUs); + dwio::common::readBytes( + numParquetBytes, + inputStream_.get(), + data, + bufferStart_, + bufferEnd_); + } + stats_.pageLoadTimeNs.increment(readUs * 1'000); } if (type_->type()->isShortDecimal()) { // Parquet decimal values have a fixed typeLength_ and are in big-endian @@ -745,7 +781,7 @@ void PageReader::makeDecoder() { std::make_unique(pageData_); break; } - FMT_FALLTHROUGH; + [[fallthrough]]; default: VELOX_UNSUPPORTED("Encoding not supported yet: {}", encoding_); } diff --git a/velox/dwio/parquet/reader/PageReader.h b/velox/dwio/parquet/reader/PageReader.h index c377100428a7..423added8930 100644 --- a/velox/dwio/parquet/reader/PageReader.h +++ b/velox/dwio/parquet/reader/PageReader.h @@ -42,6 +42,7 @@ class PageReader { ParquetTypeWithIdPtr fileType, common::CompressionKind codec, int64_t chunkSize, + dwio::common::ColumnReaderStatistics& stats, const tz::TimeZone* sessionTimezone) : pool_(pool), inputStream_(std::move(stream)), @@ -52,6 +53,7 @@ class PageReader { codec_(codec), chunkSize_(chunkSize), nullConcatenation_(pool_), + stats_(stats), sessionTimezone_(sessionTimezone) { type_->makeLevelInfo(leafInfo_); } @@ -62,6 +64,7 @@ class PageReader { memory::MemoryPool& pool, common::CompressionKind codec, int64_t chunkSize, + dwio::common::ColumnReaderStatistics& stats, const tz::TimeZone* sessionTimezone = nullptr) : pool_(pool), inputStream_(std::move(stream)), @@ -71,6 +74,7 @@ class PageReader { codec_(codec), chunkSize_(chunkSize), nullConcatenation_(pool_), + stats_(stats), sessionTimezone_(sessionTimezone) {} /// Advances 'numRows' top level rows. @@ -264,7 +268,7 @@ class PageReader { template < typename Visitor, typename std::enable_if< - !std::is_same_v && + !std::is_same_v && !std::is_same_v, int>::type = 0> void @@ -300,7 +304,7 @@ class PageReader { template < typename Visitor, typename std::enable_if< - std::is_same_v, + std::is_same_v, int>::type = 0> void callDecoder(const uint64_t* nulls, bool& nullsFromFastPath, Visitor visitor) { @@ -359,14 +363,17 @@ class PageReader { // Returns the number of passed rows/values gathered by // 'reader'. Only numRows() is set for a filter-only case, only // numValues() is set for a non-filtered case. - template - static int32_t numRowsInReader( - const dwio::common::SelectiveColumnReader& reader) { + template + static int32_t numValuesRead( + const dwio::common::SelectiveColumnReader& reader, + const int32_t numPageRowsRead) { + if (hasHook) { + return numPageRowsRead; + } if (hasFilter) { return reader.numRows(); - } else { - return reader.numValues(); } + return reader.numValues(); } memory::MemoryPool& pool_; @@ -502,6 +509,8 @@ class PageReader { // Base values of dictionary when reading a string dictionary. VectorPtr dictionaryValues_; + dwio::common::ColumnReaderStatistics& stats_; + const tz::TimeZone* sessionTimezone_{nullptr}; // Decoders. Only one will be set at a time. @@ -537,6 +546,10 @@ void PageReader::readWithVisitor(Visitor& visitor) { !std::is_same_v; constexpr bool filterOnly = std::is_same_v; + constexpr bool hasHook = Visitor::kHasHook; + static_assert( + !(hasFilter && hasHook), "hasFilter and hasHook cannot both be true"); + bool mayProduceNulls = !filterOnly && visitor.allowNulls(); auto rows = visitor.rows(); auto numRows = visitor.numRows(); @@ -544,11 +557,13 @@ void PageReader::readWithVisitor(Visitor& visitor) { startVisit(folly::Range(rows, numRows)); rowsCopy_ = &visitor.rowsCopy(); folly::Range pageRows; + int32_t numPageRowsRead = 0; const uint64_t* nulls = nullptr; bool isMultiPage = false; while (rowsForPage(reader, hasFilter, mayProduceNulls, pageRows, nulls)) { bool nullsFromFastPath = false; - int32_t numValuesBeforePage = numRowsInReader(reader); + const int32_t numValuesBeforePage = + numValuesRead(reader, numPageRowsRead); visitor.setNumValuesBias(numValuesBeforePage); visitor.setRows(pageRows); callDecoder(nulls, nullsFromFastPath, visitor); @@ -571,20 +586,23 @@ void PageReader::readWithVisitor(Visitor& visitor) { } if (!nulls) { nullConcatenation_.appendOnes( - numRowsInReader(reader) - numValuesBeforePage); + numValuesRead(reader, numPageRowsRead) - + numValuesBeforePage); } else if (reader.returnReaderNulls()) { // Nulls from decoding go directly to result. nullConcatenation_.append( reader.nullsInReadRange()->template as(), 0, - numRowsInReader(reader) - numValuesBeforePage); + numValuesRead(reader, numPageRowsRead) - + numValuesBeforePage); } else { // Add the nulls produced from the decoder to the result. auto firstNullIndex = nullsFromFastPath ? 0 : numValuesBeforePage; nullConcatenation_.append( reader.mutableNulls(0), firstNullIndex, - firstNullIndex + numRowsInReader(reader) - + firstNullIndex + + numValuesRead(reader, numPageRowsRead) - numValuesBeforePage); } } @@ -598,6 +616,7 @@ void PageReader::readWithVisitor(Visitor& visitor) { if (hasFilter && rowNumberBias_) { reader.offsetOutputRows(numValuesBeforePage, rowNumberBias_); } + numPageRowsRead += pageRows.size(); } if (isMultiPage) { reader.setNulls(mayProduceNulls ? nullConcatenation_.buffer() : nullptr); diff --git a/velox/dwio/parquet/reader/ParquetColumnReader.cpp b/velox/dwio/parquet/reader/ParquetColumnReader.cpp index 0b69f446280d..ef3ff4c9d467 100644 --- a/velox/dwio/parquet/reader/ParquetColumnReader.cpp +++ b/velox/dwio/parquet/reader/ParquetColumnReader.cpp @@ -51,8 +51,13 @@ std::unique_ptr ParquetColumnReader::build( requestedType, fileType, params, scanSpec); case TypeKind::REAL: - return std::make_unique>( - requestedType, fileType, params, scanSpec); + if (requestedType->kind() == TypeKind::REAL) { + return std::make_unique>( + requestedType, fileType, params, scanSpec); + } else { + return std::make_unique>( + requestedType, fileType, params, scanSpec); + } case TypeKind::DOUBLE: return std::make_unique>( requestedType, fileType, params, scanSpec); @@ -65,9 +70,11 @@ std::unique_ptr ParquetColumnReader::build( case TypeKind::VARCHAR: return std::make_unique(fileType, params, scanSpec); - case TypeKind::ARRAY: + case TypeKind::ARRAY: { + VELOX_CHECK(requestedType->isArray(), "Requested type must be array"); return std::make_unique( columnReaderOptions, requestedType, fileType, params, scanSpec); + } case TypeKind::MAP: return std::make_unique( @@ -97,7 +104,7 @@ std::unique_ptr ParquetColumnReader::build( default: VELOX_FAIL( "buildReader unhandled type: " + - mapTypeKindToName(fileType->type()->kind())); + std::string(TypeKindName::toName(fileType->type()->kind()))); } } diff --git a/velox/dwio/parquet/reader/ParquetData.cpp b/velox/dwio/parquet/reader/ParquetData.cpp index 29a593da414c..572c53acc2db 100644 --- a/velox/dwio/parquet/reader/ParquetData.cpp +++ b/velox/dwio/parquet/reader/ParquetData.cpp @@ -25,7 +25,7 @@ std::unique_ptr ParquetParams::toFormatData( const std::shared_ptr& type, const common::ScanSpec& /*scanSpec*/) { return std::make_unique( - type, metaData_, pool(), sessionTimezone_); + type, metaData_, pool(), runtimeStatistics(), sessionTimezone_); } void ParquetData::filterRowGroups( @@ -70,7 +70,9 @@ void ParquetData::filterRowGroups( } } -bool ParquetData::rowGroupMatches(uint32_t rowGroupId, common::Filter* filter) { +bool ParquetData::rowGroupMatches( + uint32_t rowGroupId, + const common::Filter* filter) { auto column = type_->column(); auto type = type_->type(); auto rowGroup = fileMetaDataPtr_.rowGroup(rowGroupId); @@ -126,6 +128,7 @@ dwio::common::PositionProvider ParquetData::seekToRowGroup(int64_t index) { type_, metadata.compression(), metadata.totalCompressedSize(), + stats_, sessionTimezone_); return dwio::common::PositionProvider(empty); } diff --git a/velox/dwio/parquet/reader/ParquetData.h b/velox/dwio/parquet/reader/ParquetData.h index fe8020f57c65..9926202491d6 100644 --- a/velox/dwio/parquet/reader/ParquetData.h +++ b/velox/dwio/parquet/reader/ParquetData.h @@ -63,6 +63,7 @@ class ParquetData : public dwio::common::FormatData { const std::shared_ptr& type, const FileMetaDataPtr fileMetadataPtr, memory::MemoryPool& pool, + dwio::common::ColumnReaderStatistics& stats, const tz::TimeZone* sessionTimezone) : pool_(pool), type_(std::static_pointer_cast(type)), @@ -70,6 +71,7 @@ class ParquetData : public dwio::common::FormatData { maxDefine_(type_->maxDefine_), maxRepeat_(type_->maxRepeat_), rowsInRowGroup_(-1), + stats_(stats), sessionTimezone_(sessionTimezone) {} /// Prepares to read data for 'index'th row group. @@ -90,8 +92,9 @@ class ParquetData : public dwio::common::FormatData { return reader_.get(); } - // Reads null flags for 'numValues' next top level rows. The first 'numValues' - // bits of 'nulls' are set and the reader is advanced by numValues'. + // Reads null flags for 'numValues' next top level rows. The first + // 'numValues' bits of 'nulls' are set and the reader is advanced by + // numValues'. void readNullsOnly(int32_t numValues, BufferPtr& nulls) { reader_->readNullsOnly(numValues, nulls); } @@ -100,8 +103,9 @@ class ParquetData : public dwio::common::FormatData { return maxDefine_ > 0; } - /// Sets nulls to be returned by readNulls(). Nulls for non-leaf readers come - /// from leaf repdefs which are gathered before descending the reader tree. + /// Sets nulls to be returned by readNulls(). Nulls for non-leaf readers + /// come from leaf repdefs which are gathered before descending the reader + /// tree. void setNulls(BufferPtr& nulls, int32_t numValues) { if (nulls || numValues) { VELOX_CHECK_EQ(presetNullsConsumed_, presetNullsSize_); @@ -120,8 +124,8 @@ class ParquetData : public dwio::common::FormatData { const uint64_t* incomingNulls, BufferPtr& nulls, bool nullsOnly = false) override { - // If the query accesses only nulls, read the nulls from the pages in range. - // If nulls are preread, return those minus any skipped. + // If the query accesses only nulls, read the nulls from the pages in + // range. If nulls are preread, return those minus any skipped. if (presetNulls_) { VELOX_CHECK_LE(numValues, presetNullsSize_ - presetNullsConsumed_); if (!presetNullsConsumed_ && numValues == presetNullsSize_) { @@ -144,8 +148,8 @@ class ParquetData : public dwio::common::FormatData { readNullsOnly(numValues, nulls); return; } - // There are no column-level nulls in Parquet, only page-level ones, so this - // is always non-null. + // There are no column-level nulls in Parquet, only page-level ones, so + // this is always non-null. nulls = nullptr; } @@ -206,7 +210,7 @@ class ParquetData : public dwio::common::FormatData { private: /// True if 'filter' may have hits for the column of 'this' according to the /// stats in 'rowGroup'. - bool rowGroupMatches(uint32_t rowGroupId, common::Filter* filter); + bool rowGroupMatches(uint32_t rowGroupId, const common::Filter* filter); protected: memory::MemoryPool& pool_; @@ -219,6 +223,7 @@ class ParquetData : public dwio::common::FormatData { const uint32_t maxDefine_; const uint32_t maxRepeat_; int64_t rowsInRowGroup_; + dwio::common::ColumnReaderStatistics& stats_; const tz::TimeZone* sessionTimezone_; std::unique_ptr reader_; diff --git a/velox/dwio/parquet/reader/ParquetReader.cpp b/velox/dwio/parquet/reader/ParquetReader.cpp index 325713272bb5..63700fa6bfe3 100644 --- a/velox/dwio/parquet/reader/ParquetReader.cpp +++ b/velox/dwio/parquet/reader/ParquetReader.cpp @@ -21,6 +21,7 @@ #include "velox/dwio/parquet/reader/ParquetColumnReader.h" #include "velox/dwio/parquet/reader/StructColumnReader.h" #include "velox/dwio/parquet/thrift/ThriftTransport.h" +#include "velox/functions/lib/string/StringImpl.h" namespace facebook::velox::parquet { @@ -30,12 +31,30 @@ bool isParquetReservedKeyword( std::string name, uint32_t parentSchemaIdx, uint32_t curSchemaIdx) { - return ((parentSchemaIdx == 0 && curSchemaIdx == 0) || name == "key_value" || - name == "key" || name == "value" || name == "list" || - name == "element" || name == "bag" || name == "array_element") + // We skip this for the top-level nodes. + return ((parentSchemaIdx == 0 && curSchemaIdx == 0) || + (parentSchemaIdx != 0 && + (name == "key_value" || name == "key" || name == "value" || + name == "list" || name == "element" || name == "bag" || + name == "array_element"))) ? true : false; } + +// An unannotated array in Parquet is a repeated field that is not explicitly +// marked as a LIST logical type. If current schema element is a repeated field +// and the requested type is an array, we treat the current schema element as an +// unannotated array, and returns true if the element type is compatible with +// the physical type. +bool isCompatible( + const TypePtr& requestedType, + bool isRepeated, + const std::function& isCompatibleFunc) { + return isCompatibleFunc(requestedType) || + (requestedType->isArray() && isRepeated && + isCompatibleFunc(requestedType->asArray().elementType())); +} + } // namespace /// Metadata and options for reading Parquet. @@ -304,11 +323,10 @@ std::unique_ptr ReaderBase::getParquetColumnInfo( auto name = schemaElement.name; if (isFileColumnNamesReadAsLowerCase()) { - folly::toLowerAscii(name); + name = functions::stringImpl::utf8StrToLowerCopy(name); } - if ((!options_.useColumnNamesForColumnMapping()) && - (options_.fileSchema() != nullptr)) { + if (!options_.useColumnNamesForColumnMapping() && options_.fileSchema()) { if (isParquetReservedKeyword(name, parentSchemaIdx, curSchemaIdx)) { columnNames.push_back(name); } @@ -331,26 +349,42 @@ std::unique_ptr ReaderBase::getParquetColumnInfo( ++schemaIdx; auto childName = schema[schemaIdx].name; if (isFileColumnNamesReadAsLowerCase()) { - folly::toLowerAscii(childName); + childName = functions::stringImpl::utf8StrToLowerCopy(childName); } TypePtr childRequestedType = nullptr; bool followChild = true; - if (requestedType && requestedType->isRow()) { - auto requestedRowType = - std::dynamic_pointer_cast(requestedType); - if (options_.useColumnNamesForColumnMapping()) { - auto fileTypeIdx = requestedRowType->getChildIdxIfExists(childName); - if (fileTypeIdx.has_value()) { - childRequestedType = requestedRowType->childAt(*fileTypeIdx); + + { + RowTypePtr requestedRowType = nullptr; + if (requestedType) { + if (requestedType->isRow()) { + requestedRowType = + std::dynamic_pointer_cast(requestedType); + } else if ( + requestedType->isArray() && isRepeated && + requestedType->asArray().elementType()->isRow()) { + // Handle the case of unannotated array of structs (repeated group + // without LIST annotation). + requestedRowType = std::dynamic_pointer_cast( + requestedType->asArray().elementType()); } - } else { - // Handle schema evolution. - if (i < requestedRowType->size()) { - columnNames.push_back(requestedRowType->nameOf(i)); - childRequestedType = requestedRowType->childAt(i); + } + + if (requestedRowType) { + if (options_.useColumnNamesForColumnMapping()) { + auto fileTypeIdx = requestedRowType->getChildIdxIfExists(childName); + if (fileTypeIdx.has_value()) { + childRequestedType = requestedRowType->childAt(*fileTypeIdx); + } } else { - followChild = false; + // Handle schema evolution. + if (i < requestedRowType->size()) { + columnNames.push_back(requestedRowType->nameOf(i)); + childRequestedType = requestedRowType->childAt(i); + } else { + followChild = false; + } } } } @@ -533,20 +567,21 @@ std::unique_ptr ReaderBase::getParquetColumnInfo( // In this legacy case, there is no middle layer between "array" // node and the children nodes. Below creates this dummy middle // layer to mimic the non-legacy case and fill the gap. - rowChildren.emplace_back(std::make_unique( - childrenRowType, - std::move(children), - curSchemaIdx, - maxSchemaElementIdx, - ParquetTypeWithId::kNonLeaf, - "dummy", - std::nullopt, - std::nullopt, - std::nullopt, - maxRepeat, - maxDefine, - isOptional, - isRepeated)); + rowChildren.emplace_back( + std::make_unique( + childrenRowType, + std::move(children), + curSchemaIdx, + maxSchemaElementIdx, + ParquetTypeWithId::kNonLeaf, + "dummy", + std::nullopt, + std::nullopt, + std::nullopt, + maxRepeat, + maxDefine, + isOptional, + isRepeated)); auto res = std::make_unique( TypeFactory::create(childrenRowType), std::move(rowChildren), @@ -597,20 +632,21 @@ std::unique_ptr ReaderBase::getParquetColumnInfo( // In this legacy case, there is no middle layer between "array" // node and the children nodes. Below creates this dummy middle // layer to mimic the non-legacy case and fill the gap. - rowChildren.emplace_back(std::make_unique( - childrenRowType, - std::move(children), - curSchemaIdx, - maxSchemaElementIdx, - ParquetTypeWithId::kNonLeaf, - "dummy", - std::nullopt, - std::nullopt, - std::nullopt, - maxRepeat, - maxDefine, - isOptional, - isRepeated)); + rowChildren.emplace_back( + std::make_unique( + childrenRowType, + std::move(children), + curSchemaIdx, + maxSchemaElementIdx, + ParquetTypeWithId::kNonLeaf, + "dummy", + std::nullopt, + std::nullopt, + std::nullopt, + maxRepeat, + maxDefine, + isOptional, + isRepeated)); return std::make_unique( TypeFactory::create(childrenRowType), std::move(rowChildren), @@ -721,6 +757,8 @@ TypePtr ReaderBase::convertType( static std::string_view kTypeMappingErrorFmtStr = "Converted type {} is not allowed for requested type {}"; + const bool isRepeated = schemaElement.__isset.repetition_type && + schemaElement.repetition_type == thrift::FieldRepetitionType::REPEATED; if (schemaElement.__isset.converted_type) { switch (schemaElement.converted_type) { case thrift::ConvertedType::INT_8: @@ -731,10 +769,16 @@ TypePtr ReaderBase::convertType( "{} converted type can only be set for value of thrift::Type::INT32", schemaElement.converted_type); VELOX_CHECK( - !requestedType || requestedType->kind() == TypeKind::TINYINT || - requestedType->kind() == TypeKind::SMALLINT || - requestedType->kind() == TypeKind::INTEGER || - requestedType->kind() == TypeKind::BIGINT, + !requestedType || + isCompatible( + requestedType, + isRepeated, + [](const TypePtr& type) { + return type->kind() == TypeKind::TINYINT || + type->kind() == TypeKind::SMALLINT || + type->kind() == TypeKind::INTEGER || + type->kind() == TypeKind::BIGINT; + }), kTypeMappingErrorFmtStr, "TINYINT", requestedType->toString()); @@ -748,9 +792,15 @@ TypePtr ReaderBase::convertType( "{} converted type can only be set for value of thrift::Type::INT32", schemaElement.converted_type); VELOX_CHECK( - !requestedType || requestedType->kind() == TypeKind::SMALLINT || - requestedType->kind() == TypeKind::INTEGER || - requestedType->kind() == TypeKind::BIGINT, + !requestedType || + isCompatible( + requestedType, + isRepeated, + [](const TypePtr& type) { + return type->kind() == TypeKind::SMALLINT || + type->kind() == TypeKind::INTEGER || + type->kind() == TypeKind::BIGINT; + }), kTypeMappingErrorFmtStr, "SMALLINT", requestedType->toString()); @@ -764,8 +814,14 @@ TypePtr ReaderBase::convertType( "{} converted type can only be set for value of thrift::Type::INT32", schemaElement.converted_type); VELOX_CHECK( - !requestedType || requestedType->kind() == TypeKind::INTEGER || - requestedType->kind() == TypeKind::BIGINT, + !requestedType || + isCompatible( + requestedType, + isRepeated, + [](const TypePtr& type) { + return type->kind() == TypeKind::INTEGER || + type->kind() == TypeKind::BIGINT; + }), kTypeMappingErrorFmtStr, "INTEGER", requestedType->toString()); @@ -779,7 +835,13 @@ TypePtr ReaderBase::convertType( "{} converted type can only be set for value of thrift::Type::INT32", schemaElement.converted_type); VELOX_CHECK( - !requestedType || requestedType->kind() == TypeKind::BIGINT, + !requestedType || + isCompatible( + requestedType, + isRepeated, + [](const TypePtr& type) { + return type->kind() == TypeKind::BIGINT; + }), kTypeMappingErrorFmtStr, "BIGINT", requestedType->toString()); @@ -791,7 +853,11 @@ TypePtr ReaderBase::convertType( thrift::Type::INT32, "DATE converted type can only be set for value of thrift::Type::INT32"); VELOX_CHECK( - !requestedType || requestedType->isDate(), + !requestedType || + isCompatible( + requestedType, + isRepeated, + [](const TypePtr& type) { return type->isDate(); }), kTypeMappingErrorFmtStr, "DATE", requestedType->toString()); @@ -804,7 +870,13 @@ TypePtr ReaderBase::convertType( thrift::Type::INT64, "TIMESTAMP_MICROS or TIMESTAMP_MILLIS converted type can only be set for value of thrift::Type::INT64"); VELOX_CHECK( - !requestedType || requestedType->kind() == TypeKind::TIMESTAMP, + !requestedType || + isCompatible( + requestedType, + isRepeated, + [](const TypePtr& type) { + return type->kind() == TypeKind::TIMESTAMP; + }), kTypeMappingErrorFmtStr, "TIMESTAMP", requestedType->toString()); @@ -822,7 +894,10 @@ TypePtr ReaderBase::convertType( auto type = DECIMAL(schemaElementPrecision, schemaElementScale); if (requestedType) { VELOX_CHECK( - requestedType->isDecimal(), + isCompatible( + requestedType, + isRepeated, + [](const TypePtr& type) { return type->isDecimal(); }), kTypeMappingErrorFmtStr, "DECIMAL", requestedType->toString()); @@ -831,20 +906,30 @@ TypePtr ReaderBase::convertType( // the scale of the file type and requested type must match while // precision may be larger. if (requestedType->isShortDecimal()) { - const auto& shortDecimalType = requestedType->asShortDecimal(); VELOX_CHECK( - type->isShortDecimal() && - shortDecimalType.precision() >= schemaElementPrecision && - shortDecimalType.scale() == schemaElementScale, + isCompatible( + requestedType, + isRepeated, + [&](const TypePtr& type) { + return type->isShortDecimal() && + type->asShortDecimal().precision() >= + schemaElementPrecision && + type->asShortDecimal().scale() == schemaElementScale; + }), kTypeMappingErrorFmtStr, type->toString(), requestedType->toString()); } else { - const auto& longDecimalType = requestedType->asLongDecimal(); VELOX_CHECK( - type->isLongDecimal() && - longDecimalType.precision() >= schemaElementPrecision && - longDecimalType.scale() == schemaElementScale, + isCompatible( + requestedType, + isRepeated, + [&](const TypePtr& type) { + return type->isLongDecimal() && + type->asLongDecimal().precision() >= + schemaElementPrecision && + type->asLongDecimal().scale() == schemaElementScale; + }), kTypeMappingErrorFmtStr, type->toString(), requestedType->toString()); @@ -858,7 +943,13 @@ TypePtr ReaderBase::convertType( case thrift::Type::BYTE_ARRAY: case thrift::Type::FIXED_LEN_BYTE_ARRAY: VELOX_CHECK( - !requestedType || requestedType->kind() == TypeKind::VARCHAR, + !requestedType || + isCompatible( + requestedType, + isRepeated, + [](const TypePtr& type) { + return type->kind() == TypeKind::VARCHAR; + }), kTypeMappingErrorFmtStr, "VARCHAR", requestedType->toString()); @@ -873,7 +964,13 @@ TypePtr ReaderBase::convertType( thrift::Type::BYTE_ARRAY, "ENUM converted type can only be set for value of thrift::Type::BYTE_ARRAY"); VELOX_CHECK( - !requestedType || requestedType->kind() == TypeKind::VARCHAR, + !requestedType || + isCompatible( + requestedType, + isRepeated, + [](const TypePtr& type) { + return type->kind() == TypeKind::VARCHAR; + }), kTypeMappingErrorFmtStr, "VARCHAR", requestedType->toString()); @@ -896,15 +993,27 @@ TypePtr ReaderBase::convertType( switch (schemaElement.type) { case thrift::Type::type::BOOLEAN: VELOX_CHECK( - !requestedType || requestedType->kind() == TypeKind::BOOLEAN, + !requestedType || + isCompatible( + requestedType, + isRepeated, + [](const TypePtr& type) { + return type->kind() == TypeKind::BOOLEAN; + }), kTypeMappingErrorFmtStr, "BOOLEAN", requestedType->toString()); return BOOLEAN(); case thrift::Type::type::INT32: VELOX_CHECK( - !requestedType || requestedType->kind() == TypeKind::INTEGER || - requestedType->kind() == TypeKind::BIGINT, + !requestedType || + isCompatible( + requestedType, + isRepeated, + [](const TypePtr& type) { + return type->kind() == TypeKind::INTEGER || + type->kind() == TypeKind::BIGINT; + }), kTypeMappingErrorFmtStr, "INTEGER", requestedType->toString()); @@ -914,47 +1023,84 @@ TypePtr ReaderBase::convertType( if (schemaElement.__isset.logicalType && schemaElement.logicalType.__isset.TIMESTAMP) { VELOX_CHECK( - !requestedType || requestedType->kind() == TypeKind::TIMESTAMP, + !requestedType || + isCompatible( + requestedType, + isRepeated, + [](const TypePtr& type) { + return type->kind() == TypeKind::TIMESTAMP; + }), kTypeMappingErrorFmtStr, "TIMESTAMP", requestedType->toString()); return TIMESTAMP(); } VELOX_CHECK( - !requestedType || requestedType->kind() == TypeKind::BIGINT, + !requestedType || + isCompatible( + requestedType, + isRepeated, + [](const TypePtr& type) { + return type->kind() == TypeKind::BIGINT; + }), kTypeMappingErrorFmtStr, "BIGINT", requestedType->toString()); return BIGINT(); case thrift::Type::type::INT96: VELOX_CHECK( - !requestedType || requestedType->kind() == TypeKind::TIMESTAMP, + !requestedType || + isCompatible( + requestedType, + isRepeated, + [](const TypePtr& type) { + return type->kind() == TypeKind::TIMESTAMP; + }), kTypeMappingErrorFmtStr, "TIMESTAMP", requestedType->toString()); return TIMESTAMP(); // INT96 only maps to a timestamp case thrift::Type::type::FLOAT: VELOX_CHECK( - !requestedType || requestedType->kind() == TypeKind::REAL || - requestedType->kind() == TypeKind::DOUBLE, + !requestedType || + isCompatible( + requestedType, + isRepeated, + [](const TypePtr& type) { + return type->kind() == TypeKind::REAL || + type->kind() == TypeKind::DOUBLE; + }), kTypeMappingErrorFmtStr, "REAL", requestedType->toString()); return REAL(); case thrift::Type::type::DOUBLE: VELOX_CHECK( - !requestedType || requestedType->kind() == TypeKind::DOUBLE, + !requestedType || + isCompatible( + requestedType, + isRepeated, + [](const TypePtr& type) { + return type->kind() == TypeKind::DOUBLE; + }), kTypeMappingErrorFmtStr, "DOUBLE", requestedType->toString()); return DOUBLE(); case thrift::Type::type::BYTE_ARRAY: case thrift::Type::type::FIXED_LEN_BYTE_ARRAY: - if (requestedType && requestedType->isVarchar()) { + if (requestedType && + isCompatible(requestedType, isRepeated, [](const TypePtr& type) { + return type->isVarchar(); + })) { return VARCHAR(); } else { VELOX_CHECK( - !requestedType || requestedType->isVarbinary(), + !requestedType || + isCompatible( + requestedType, + isRepeated, + [](const TypePtr& type) { return type->isVarbinary(); }), kTypeMappingErrorFmtStr, "VARBINARY", requestedType->toString()); @@ -977,7 +1123,7 @@ std::shared_ptr ReaderBase::createRowType( for (auto& child : children) { auto childName = static_cast(*child).name_; if (fileColumnNamesReadAsLowerCase) { - folly::toLowerAscii(childName); + childName = functions::stringImpl::utf8StrToLowerCopy(childName); } childNames.push_back(std::move(childName)); childTypes.push_back(child->type()); @@ -1177,14 +1323,21 @@ class ParquetRowReader::Impl { std::optional estimatedRowSize() const { auto index = nextRowGroupIdsIdx_ < 1 ? 0 : rowGroupIds_[nextRowGroupIdsIdx_ - 1]; - return readerBase_->rowGroupUncompressedSize( - index, *readerBase_->schemaWithId()) / + if (index == lastRowGroupWithRowEstimate_) { + return estimatedRowSize_; + } + estimatedRowSize_ = readerBase_->rowGroupUncompressedSize( + index, *readerBase_->schemaWithId()) / rowGroups_[index].num_rows; + lastRowGroupWithRowEstimate_ = index; + return estimatedRowSize_; } void updateRuntimeStats(dwio::common::RuntimeStatistics& stats) const { stats.skippedStrides += skippedStrides_; stats.processedStrides += rowGroupIds_.size(); + stats.columnReaderStatistics.pageLoadTimeNs.merge( + columnReaderStats_.pageLoadTimeNs); } void resetFilterCaches() { @@ -1236,6 +1389,9 @@ class ParquetRowReader::Impl { ParquetStatsContext parquetStatsContext_; dwio::common::ColumnReaderStatistics columnReaderStats_; + + mutable std::optional estimatedRowSize_; + mutable int32_t lastRowGroupWithRowEstimate_{-1}; }; ParquetRowReader::ParquetRowReader( diff --git a/velox/dwio/parquet/reader/ParquetTypeWithId.cpp b/velox/dwio/parquet/reader/ParquetTypeWithId.cpp index 1581fa639348..aab2e71fba7e 100644 --- a/velox/dwio/parquet/reader/ParquetTypeWithId.cpp +++ b/velox/dwio/parquet/reader/ParquetTypeWithId.cpp @@ -53,23 +53,24 @@ ParquetTypeWithId::moveChildren() const&& { auto precision = parquetChild->precision_; auto scale = parquetChild->scale_; auto typeLength = parquetChild->typeLength_; - children.push_back(std::make_unique( - std::move(type), - std::move(*parquetChild).moveChildren(), - id, - maxId, - column, - std::move(name), - parquetType, - std::move(logicalType), - std::move(convertedType), - maxRepeat, - maxDefine, - isOptional, - isRepeated, - precision, - scale, - typeLength)); + children.push_back( + std::make_unique( + std::move(type), + std::move(*parquetChild).moveChildren(), + id, + maxId, + column, + std::move(name), + parquetType, + std::move(logicalType), + std::move(convertedType), + maxRepeat, + maxDefine, + isOptional, + isRepeated, + precision, + scale, + typeLength)); } return children; } diff --git a/velox/dwio/parquet/reader/RepeatedColumnReader.cpp b/velox/dwio/parquet/reader/RepeatedColumnReader.cpp index 8cd75156747a..e10dd4182a8f 100644 --- a/velox/dwio/parquet/reader/RepeatedColumnReader.cpp +++ b/velox/dwio/parquet/reader/RepeatedColumnReader.cpp @@ -153,8 +153,6 @@ void MapColumnReader::seekToRowGroup(int64_t index) { BufferPtr noBuffer; formatData_->as().setNulls(noBuffer, 0); lengths_.setLengths(nullptr); - keyReader_->seekToRowGroup(index); - elementReader_->seekToRowGroup(index); } void MapColumnReader::skipUnreadLengths() { diff --git a/velox/dwio/parquet/reader/RleBpDecoder.h b/velox/dwio/parquet/reader/RleBpDecoder.h index ac07ed76ad0c..5856de775e2b 100644 --- a/velox/dwio/parquet/reader/RleBpDecoder.h +++ b/velox/dwio/parquet/reader/RleBpDecoder.h @@ -37,7 +37,7 @@ class RleBpDecoder { /// Decode @param numValues number of values and copy the decoded values into /// @param outputBuffer template - void next(T* FOLLY_NONNULL& outputBuffer, uint64_t numValues) { + void next(T * FOLLY_NONNULL & outputBuffer, uint64_t numValues) { while (numValues > 0) { if (numRemainingUnpackedValues_ > 0) { auto numValuesToRead = @@ -103,7 +103,7 @@ class RleBpDecoder { template inline void copyRemainingUnpackedValues( - T* FOLLY_NONNULL& outputBuffer, + T * FOLLY_NONNULL & outputBuffer, int8_t numValues) { VELOX_CHECK_LE(numValues, numRemainingUnpackedValues_); diff --git a/velox/dwio/parquet/reader/StringColumnReader.cpp b/velox/dwio/parquet/reader/StringColumnReader.cpp index ac678b7f0a39..7ce4f68d41ae 100644 --- a/velox/dwio/parquet/reader/StringColumnReader.cpp +++ b/velox/dwio/parquet/reader/StringColumnReader.cpp @@ -35,7 +35,7 @@ void StringColumnReader::read( int64_t offset, const RowSet& rows, const uint64_t* incomingNulls) { - prepareRead(offset, rows, incomingNulls); + prepareRead(offset, rows, incomingNulls); dwio::common::StringColumnReadWithVisitorHelper( *this, rows)([&](auto visitor) { formatData_->as().readWithVisitor(visitor); @@ -83,7 +83,7 @@ void StringColumnReader::dedictionarize() { } auto& view = dict->valueAt(indices[i]); numValues_ = i; - addStringValue(folly::StringPiece(view.data(), view.size())); + addStringValue(std::string_view(view.data(), view.size())); } numValues_ = numValues; } diff --git a/velox/dwio/parquet/reader/StringDecoder.h b/velox/dwio/parquet/reader/StringDecoder.h index 2bff5285e3c4..03805b5659fb 100644 --- a/velox/dwio/parquet/reader/StringDecoder.h +++ b/velox/dwio/parquet/reader/StringDecoder.h @@ -90,15 +90,15 @@ class StringDecoder { return *reinterpret_cast(buffer); } - folly::StringPiece readString() { + std::string_view readString() { auto length = lengthAt(bufferStart_); bufferStart_ += length + sizeof(int32_t); - return folly::StringPiece(bufferStart_ - length, length); + return std::string_view(bufferStart_ - length, length); } - folly::StringPiece readFixedString() { + std::string_view readFixedString() { bufferStart_ += fixedLength_; - return folly::StringPiece(bufferStart_ - fixedLength_, fixedLength_); + return std::string_view(bufferStart_ - fixedLength_, fixedLength_); } const char* bufferStart_; diff --git a/velox/dwio/parquet/reader/StructColumnReader.cpp b/velox/dwio/parquet/reader/StructColumnReader.cpp index 694f334c51a2..64740bf04108 100644 --- a/velox/dwio/parquet/reader/StructColumnReader.cpp +++ b/velox/dwio/parquet/reader/StructColumnReader.cpp @@ -46,12 +46,13 @@ StructColumnReader::StructColumnReader( auto childFileType = fileType_->childByName(childSpec->fieldName()); auto childRequestedType = requestedType_->asRow().findChild(childSpec->fieldName()); - addChild(ParquetColumnReader::build( - columnReaderOptions, - childRequestedType, - childFileType, - params, - *childSpec)); + addChild( + ParquetColumnReader::build( + columnReaderOptions, + childRequestedType, + childFileType, + params, + *childSpec)); childSpecs[i]->setSubscript(children_.size() - 1); } diff --git a/velox/dwio/parquet/reader/TimestampColumnReader.h b/velox/dwio/parquet/reader/TimestampColumnReader.h index df832d3daeac..6ee0cf7fc5ef 100644 --- a/velox/dwio/parquet/reader/TimestampColumnReader.h +++ b/velox/dwio/parquet/reader/TimestampColumnReader.h @@ -163,10 +163,10 @@ class TimestampColumnReader : public IntegerColumnReader { bool isDense, typename ExtractValues> void readHelper( - velox::common::Filter* filter, + const velox::common::Filter* filter, const RowSet& rows, ExtractValues extractValues) { - if (auto* range = dynamic_cast(filter)) { + if (auto* range = dynamic_cast(filter)) { ParquetTimestampRange newRange{ range->lower(), range->upper(), range->nullAllowed(), filePrecision_}; this->readWithVisitor( @@ -176,12 +176,37 @@ class TimestampColumnReader : public IntegerColumnReader { common::TimestampRange, ExtractValues, isDense>(newRange, this, rows, extractValues)); + } else if ( + auto* multiRange = dynamic_cast(filter)) { + std::vector> filters; + filters.reserve(multiRange->filters().size()); + for (const auto& filter : multiRange->filters()) { + if (auto* range = dynamic_cast(filter.get())) { + filters.emplace_back( + std::make_unique>( + range->lower(), + range->upper(), + range->nullAllowed(), + filePrecision_)); + } else { + filters.emplace_back(filter->clone(range->nullAllowed())); + } + } + auto newMultiRange = + common::MultiRange(std::move(filters), multiRange->nullAllowed()); + this->readWithVisitor( + rows, + dwio::common::ColumnVisitor< + int128_t, + common::MultiRange, + ExtractValues, + isDense>(newMultiRange, this, rows, extractValues)); } else { this->readWithVisitor( rows, dwio::common:: ColumnVisitor( - *reinterpret_cast(filter), + *static_cast(filter), this, rows, extractValues)); diff --git a/velox/dwio/parquet/tests/CMakeLists.txt b/velox/dwio/parquet/tests/CMakeLists.txt index feeff62cb789..dc09a85c4fdc 100644 --- a/velox/dwio/parquet/tests/CMakeLists.txt +++ b/velox/dwio/parquet/tests/CMakeLists.txt @@ -12,18 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -set(TEST_LINK_LIBS - velox_dwio_common_test_utils - velox_vector_test_lib - velox_exec_test_lib - velox_dwio_parquet_reader - velox_dwio_parquet_writer - velox_temp_path - GTest::gtest - GTest::gtest_main - GTest::gmock - gflags::gflags - glog::glog) +set( + TEST_LINK_LIBS + velox_dwio_common_test_utils + velox_vector_test_lib + velox_exec_test_lib + velox_dwio_parquet_reader + velox_dwio_parquet_writer + velox_temp_path + GTest::gtest + GTest::gtest_main + GTest::gmock + gflags::gflags + glog::glog +) add_subdirectory(common) add_subdirectory(reader) @@ -34,13 +36,14 @@ add_executable(velox_dwio_parquet_tpch_test ParquetTpchTest.cpp) add_test( NAME velox_dwio_parquet_tpch_test COMMAND velox_dwio_parquet_tpch_test - WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} +) target_link_libraries( velox_dwio_parquet_tpch_test velox_tpch_connector velox_aggregates velox_tpch_gen - ${TEST_LINK_LIBS}) + ${TEST_LINK_LIBS} +) -file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/examples - DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) +file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/examples DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) diff --git a/velox/dwio/parquet/tests/ParquetTestBase.h b/velox/dwio/parquet/tests/ParquetTestBase.h index c867978a0a5c..77662eb625da 100644 --- a/velox/dwio/parquet/tests/ParquetTestBase.h +++ b/velox/dwio/parquet/tests/ParquetTestBase.h @@ -19,6 +19,7 @@ #include #include #include "velox/common/base/Fs.h" +#include "velox/common/file/File.h" #include "velox/dwio/common/FileSink.h" #include "velox/dwio/common/Reader.h" #include "velox/dwio/common/tests/utils/DataFiles.h" @@ -191,10 +192,51 @@ class ParquetTestBase : public testing::Test, "velox/dwio/parquet/tests/reader", "../examples/" + fileName); } + dwio::common::MemorySink* write( + const RowVectorPtr& data, + const WriterOptions& writerOptions) { + auto sink = std::make_unique( + 200 * 1024 * 1024, + dwio::common::FileSink::Options{.pool = leafPool_.get()}); + auto* sinkPtr = sink.get(); + auto writer = std::make_unique( + std::move(sink), writerOptions, data->rowType()); + writer->write(data); + writer->close(); + writers_.push_back(std::move(writer)); + return sinkPtr; + } + + dwio::common::MemorySink* write( + const RowVectorPtr& data, + std::unordered_map configFromFile = {}, + std::unordered_map sessionProperties = {}) { + parquet::WriterOptions writerOptions; + writerOptions.memoryPool = rootPool_.get(); + auto connectorConfig = config::ConfigBase(std::move(configFromFile)); + auto connectorSessionProperties = + config::ConfigBase(std::move(sessionProperties)); + writerOptions.processConfigs(connectorConfig, connectorSessionProperties); + return write(data, writerOptions); + } + + std::unique_ptr createReaderInMemory( + const dwio::common::MemorySink& sink, + const dwio::common::ReaderOptions& opts) { + std::string data(sink.data(), sink.size()); + return std::make_unique( + std::make_unique( + std::make_shared(std::move(data)), + opts.memoryPool()), + opts); + } + static constexpr uint64_t kRowsInRowGroup = 10'000; static constexpr uint64_t kBytesInRowGroup = 128 * 1'024 * 1'024; std::shared_ptr rootPool_; std::shared_ptr leafPool_; std::shared_ptr tempPath_; + // Stores writers created by write() helper to keep sinks alive for reading. + std::vector> writers_; }; } // namespace facebook::velox::parquet diff --git a/velox/dwio/parquet/tests/ParquetTpchTest.cpp b/velox/dwio/parquet/tests/ParquetTpchTest.cpp index 1ef7a9775b65..0a728abe8d4f 100644 --- a/velox/dwio/parquet/tests/ParquetTpchTest.cpp +++ b/velox/dwio/parquet/tests/ParquetTpchTest.cpp @@ -18,6 +18,7 @@ #include #include "velox/common/file/FileSystems.h" +#include "velox/connectors/hive/HiveConnector.h" #include "velox/connectors/tpch/TpchConnector.h" #include "velox/dwio/parquet/RegisterParquetReader.h" #include "velox/dwio/parquet/RegisterParquetWriter.h" @@ -54,26 +55,18 @@ class ParquetTpchTest : public testing::Test { parquet::registerParquetReaderFactory(); parquet::registerParquetWriterFactory(); - connector::registerConnectorFactory( - std::make_shared()); - auto hiveConnector = - connector::getConnectorFactory( - connector::hive::HiveConnectorFactory::kHiveConnectorName) - ->newConnector( - kHiveConnectorId, - std::make_shared( - std::unordered_map())); + connector::hive::HiveConnectorFactory hiveFactory; + auto hiveConnector = hiveFactory.newConnector( + kHiveConnectorId, + std::make_shared( + std::unordered_map())); connector::registerConnector(hiveConnector); - connector::registerConnectorFactory( - std::make_shared()); - auto tpchConnector = - connector::getConnectorFactory( - connector::tpch::TpchConnectorFactory::kTpchConnectorName) - ->newConnector( - kTpchConnectorId, - std::make_shared( - std::unordered_map())); + connector::tpch::TpchConnectorFactory tpchFactory; + auto tpchConnector = tpchFactory.newConnector( + kTpchConnectorId, + std::make_shared( + std::unordered_map())); connector::registerConnector(tpchConnector); saveTpchTablesAsParquet(); @@ -81,10 +74,6 @@ class ParquetTpchTest : public testing::Test { } static void TearDownTestSuite() { - connector::unregisterConnectorFactory( - connector::hive::HiveConnectorFactory::kHiveConnectorName); - connector::unregisterConnectorFactory( - connector::tpch::TpchConnectorFactory::kTpchConnectorName); connector::unregisterConnector(kHiveConnectorId); connector::unregisterConnector(kTpchConnectorId); parquet::unregisterParquetReaderFactory(); @@ -105,8 +94,8 @@ class ParquetTpchTest : public testing::Test { auto plan = PlanBuilder() .tpchTableScan(table, std::move(columnNames), 0.01) .planNode(); - auto split = - exec::Split(std::make_shared( + auto split = exec::Split( + std::make_shared( kTpchConnectorId, /*cacheable=*/true, 1, 0)); auto rows = diff --git a/velox/dwio/parquet/tests/common/CMakeLists.txt b/velox/dwio/parquet/tests/common/CMakeLists.txt index c6d2593e9219..19cf296e68fa 100644 --- a/velox/dwio/parquet/tests/common/CMakeLists.txt +++ b/velox/dwio/parquet/tests/common/CMakeLists.txt @@ -21,6 +21,8 @@ target_link_libraries( arrow thrift velox_link_libs + velox_exec GTest::gtest GTest::gmock - GTest::gtest_main) + GTest::gtest_main +) diff --git a/velox/dwio/parquet/tests/common/LevelConversionTest.cpp b/velox/dwio/parquet/tests/common/LevelConversionTest.cpp index e3ff3a7e6e9c..f7274b7c9744 100644 --- a/velox/dwio/parquet/tests/common/LevelConversionTest.cpp +++ b/velox/dwio/parquet/tests/common/LevelConversionTest.cpp @@ -160,38 +160,37 @@ MultiLevelTestData TriplyNestedList() { // [[[]], [[], [1, 2]], null, [[3]]], // null, // [] - return MultiLevelTestData{ - /*defLevels=*/std::vector{ - 2, - 7, - 6, - 7, - 5, - 3, // first row - 5, - 5, - 7, - 7, - 2, - 7, // second row - 0, // third row - 1}, - /*repLevels=*/ - std::vector{ - 0, - 1, - 3, - 3, - 2, - 1, // first row - 0, - 1, - 2, - 3, - 1, - 1, // second row - 0, - 0}}; + return MultiLevelTestData{/*defLevels=*/std::vector{ + 2, + 7, + 6, + 7, + 5, + 3, // first row + 5, + 5, + 7, + 7, + 2, + 7, // second row + 0, // third row + 1}, + /*repLevels=*/ + std::vector{ + 0, + 1, + 3, + 3, + 2, + 1, // first row + 0, + 1, + 2, + 3, + 1, + 1, // second row + 0, + 0}}; } template diff --git a/velox/dwio/parquet/tests/examples/complex_type_v2_page.parquet b/velox/dwio/parquet/tests/examples/complex_type_v2_page.parquet new file mode 100644 index 000000000000..5b2af2256e8a Binary files /dev/null and b/velox/dwio/parquet/tests/examples/complex_type_v2_page.parquet differ diff --git a/velox/dwio/parquet/tests/examples/nested_array_struct.parquet b/velox/dwio/parquet/tests/examples/nested_array_struct.parquet new file mode 100644 index 000000000000..41a43fa35d39 Binary files /dev/null and b/velox/dwio/parquet/tests/examples/nested_array_struct.parquet differ diff --git a/velox/dwio/parquet/tests/examples/proto_repeated_string.parquet b/velox/dwio/parquet/tests/examples/proto_repeated_string.parquet new file mode 100644 index 000000000000..8a7eea601d01 Binary files /dev/null and b/velox/dwio/parquet/tests/examples/proto_repeated_string.parquet differ diff --git a/velox/dwio/parquet/tests/reader/CMakeLists.txt b/velox/dwio/parquet/tests/reader/CMakeLists.txt index 0402842a4ceb..ef72bf26b99f 100644 --- a/velox/dwio/parquet/tests/reader/CMakeLists.txt +++ b/velox/dwio/parquet/tests/reader/CMakeLists.txt @@ -16,10 +16,14 @@ add_executable(velox_dwio_parquet_page_reader_test ParquetPageReaderTest.cpp) add_test( NAME velox_dwio_parquet_page_reader_test COMMAND velox_dwio_parquet_page_reader_test - WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} +) target_link_libraries( - velox_dwio_parquet_page_reader_test velox_dwio_native_parquet_reader - velox_link_libs ${TEST_LINK_LIBS}) + velox_dwio_parquet_page_reader_test + velox_dwio_native_parquet_reader + velox_link_libs + ${TEST_LINK_LIBS} +) add_executable(velox_parquet_e2e_filter_test E2EFilterTest.cpp) add_test(velox_parquet_e2e_filter_test velox_parquet_e2e_filter_test) @@ -29,10 +33,10 @@ target_link_libraries( velox_dwio_parquet_writer velox_dwio_native_parquet_reader lz4::lz4 - lzo2::lzo2 zstd::zstd ZLIB::ZLIB - ${TEST_LINK_LIBS}) + ${TEST_LINK_LIBS} +) add_library(velox_dwio_parquet_reader_benchmark_lib ParquetReaderBenchmark.cpp) target_link_libraries( @@ -44,49 +48,62 @@ target_link_libraries( velox_hive_connector ${TEST_LINK_LIBS} Folly::follybenchmark - Folly::folly) + Folly::folly +) if(VELOX_ENABLE_BENCHMARKS) - add_executable(velox_dwio_parquet_reader_benchmark - ParquetReaderBenchmarkMain.cpp) - target_link_libraries(velox_dwio_parquet_reader_benchmark - velox_dwio_parquet_reader_benchmark_lib) + add_executable(velox_dwio_parquet_reader_benchmark ParquetReaderBenchmarkMain.cpp) + target_link_libraries(velox_dwio_parquet_reader_benchmark velox_dwio_parquet_reader_benchmark_lib) endif() add_executable( velox_dwio_parquet_reader_test - ParquetReaderTest.cpp ParquetReaderBenchmarkTest.cpp BloomFilterTest.cpp) + ParquetReaderTest.cpp + ParquetReaderBenchmarkTest.cpp + BloomFilterTest.cpp +) add_test( NAME velox_dwio_parquet_reader_test COMMAND velox_dwio_parquet_reader_test - WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} +) target_link_libraries( - velox_dwio_parquet_reader_test velox_dwio_parquet_common - velox_dwio_parquet_reader_benchmark_lib velox_link_libs) + velox_dwio_parquet_reader_test + velox_dwio_parquet_common + velox_dwio_parquet_reader_benchmark_lib + velox_hive_connector + velox_link_libs +) -add_executable(velox_dwio_parquet_structure_decoder_test - NestedStructureDecoderTest.cpp) +add_executable(velox_dwio_parquet_structure_decoder_test NestedStructureDecoderTest.cpp) add_test( NAME velox_dwio_parquet_structure_decoder_test COMMAND velox_dwio_parquet_structure_decoder_test - WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} +) target_link_libraries( - velox_dwio_parquet_structure_decoder_test velox_dwio_native_parquet_reader - velox_link_libs ${TEST_LINK_LIBS}) + velox_dwio_parquet_structure_decoder_test + velox_dwio_native_parquet_reader + velox_link_libs + ${TEST_LINK_LIBS} +) if(VELOX_ENABLE_BENCHMARKS) - add_executable(velox_dwio_parquet_structure_decoder_benchmark - NestedStructureDecoderBenchmark.cpp) + add_executable(velox_dwio_parquet_structure_decoder_benchmark NestedStructureDecoderBenchmark.cpp) target_link_libraries( velox_dwio_parquet_structure_decoder_benchmark - velox_dwio_native_parquet_reader Folly::folly Folly::follybenchmark) + velox_dwio_native_parquet_reader + Folly::folly + Folly::follybenchmark + ) endif() add_executable(velox_dwio_parquet_table_scan_test ParquetTableScanTest.cpp) add_test( NAME velox_dwio_parquet_table_scan_test COMMAND velox_dwio_parquet_table_scan_test - WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} +) target_link_libraries( velox_dwio_parquet_table_scan_test velox_dwio_parquet_reader @@ -96,20 +113,21 @@ target_link_libraries( velox_hive_connector velox_link_libs velox_type_tz - ${TEST_LINK_LIBS}) + ${TEST_LINK_LIBS} +) if(${VELOX_ENABLE_ARROW}) - add_executable(velox_dwio_parquet_rlebp_decoder_test RleBpDecoderTest.cpp) add_test( NAME velox_dwio_parquet_rlebp_decoder_test COMMAND velox_dwio_parquet_rlebp_decoder_test - WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + ) target_link_libraries( velox_dwio_parquet_rlebp_decoder_test velox_dwio_native_parquet_reader arrow velox_link_libs - ${TEST_LINK_LIBS}) - + ${TEST_LINK_LIBS} + ) endif() diff --git a/velox/dwio/parquet/tests/reader/E2EFilterTest.cpp b/velox/dwio/parquet/tests/reader/E2EFilterTest.cpp index f5f608d812aa..b8d4e324b784 100644 --- a/velox/dwio/parquet/tests/reader/E2EFilterTest.cpp +++ b/velox/dwio/parquet/tests/reader/E2EFilterTest.cpp @@ -99,8 +99,9 @@ class E2EFilterTest : public E2EFilterTestBase, TEST_F(E2EFilterTest, writerMagic) { rowType_ = ROW({"c0"}, {INTEGER()}); std::vector batches; - batches.push_back(std::static_pointer_cast( - test::BatchMaker::createBatch(rowType_, 20000, *leafPool_, nullptr, 0))); + batches.push_back( + std::static_pointer_cast(test::BatchMaker::createBatch( + rowType_, 20000, *leafPool_, nullptr, 0))); writeToMemory(rowType_, batches, false); auto data = sinkData_.data(); auto size = sinkData_.size(); @@ -657,8 +658,9 @@ TEST_F(E2EFilterTest, largeMetadata) { rowType_ = ROW({"c0"}, {INTEGER()}); std::vector batches; - batches.push_back(std::static_pointer_cast( - test::BatchMaker::createBatch(rowType_, 1000, *leafPool_, nullptr, 0))); + batches.push_back( + std::static_pointer_cast(test::BatchMaker::createBatch( + rowType_, 1000, *leafPool_, nullptr, 0))); writeToMemory(rowType_, batches, false); dwio::common::ReaderOptions readerOpts{leafPool_.get()}; readerOpts.setFooterEstimatedSize(1024); @@ -694,8 +696,9 @@ TEST_F(E2EFilterTest, combineRowGroup) { rowType_ = ROW({"c0"}, {INTEGER()}); std::vector batches; for (int i = 0; i < 5; i++) { - batches.push_back(std::static_pointer_cast( - test::BatchMaker::createBatch(rowType_, 1, *leafPool_, nullptr, 0))); + batches.push_back( + std::static_pointer_cast(test::BatchMaker::createBatch( + rowType_, 1, *leafPool_, nullptr, 0))); } writeToMemory(rowType_, batches, false); dwio::common::ReaderOptions readerOpts{leafPool_.get()}; @@ -711,7 +714,7 @@ TEST_F(E2EFilterTest, writeDecimalAsInteger) { auto rowVector = makeRowVector( {makeFlatVector({1, 2}, DECIMAL(8, 2)), makeFlatVector({1, 2}, DECIMAL(10, 2)), - makeFlatVector({1, 2}, DECIMAL(19, 2))}); + makeFlatVector({1, 2}, DECIMAL(19, 2))}); writeToMemory(rowVector->type(), {rowVector}, false); dwio::common::ReaderOptions readerOpts{leafPool_.get()}; auto input = std::make_unique( diff --git a/velox/dwio/parquet/tests/reader/ParquetPageReaderTest.cpp b/velox/dwio/parquet/tests/reader/ParquetPageReaderTest.cpp index 5145dcfdc8ca..d037cf44b91a 100644 --- a/velox/dwio/parquet/tests/reader/ParquetPageReaderTest.cpp +++ b/velox/dwio/parquet/tests/reader/ParquetPageReaderTest.cpp @@ -31,11 +31,13 @@ TEST_F(ParquetPageReaderTest, smallPage) { auto headerSize = file->getLength(); auto inputStream = std::make_unique( std::move(file), 0, headerSize, *leafPool_, LogType::TEST); + dwio::common::ColumnReaderStatistics stats; auto pageReader = std::make_unique( std::move(inputStream), *leafPool_, common::CompressionKind::CompressionKind_GZIP, - headerSize); + headerSize, + stats); auto header = pageReader->readPageHeader(); EXPECT_EQ(header.type, thrift::PageType::type::DATA_PAGE); EXPECT_EQ(header.uncompressed_page_size, 16950); @@ -50,6 +52,7 @@ TEST_F(ParquetPageReaderTest, smallPage) { auto maxValue = header.data_page_header.statistics.max_value; EXPECT_EQ(minValue, expectedMinValue); EXPECT_EQ(maxValue, expectedMaxValue); + EXPECT_GT(stats.pageLoadTimeNs.sum(), 0); } TEST_F(ParquetPageReaderTest, largePage) { @@ -59,11 +62,13 @@ TEST_F(ParquetPageReaderTest, largePage) { auto headerSize = file->getLength(); auto inputStream = std::make_unique( std::move(file), 0, headerSize, *leafPool_, LogType::TEST); + dwio::common::ColumnReaderStatistics stats; auto pageReader = std::make_unique( std::move(inputStream), *leafPool_, common::CompressionKind::CompressionKind_GZIP, - headerSize); + headerSize, + stats); auto header = pageReader->readPageHeader(); EXPECT_EQ(header.type, thrift::PageType::type::DATA_PAGE); @@ -79,6 +84,7 @@ TEST_F(ParquetPageReaderTest, largePage) { auto maxValue = header.data_page_header.statistics.max_value; EXPECT_EQ(minValue, expectedMinValue); EXPECT_EQ(maxValue, expectedMaxValue); + EXPECT_GT(stats.pageLoadTimeNs.sum(), 0); } TEST_F(ParquetPageReaderTest, corruptedPageHeader) { @@ -92,11 +98,13 @@ TEST_F(ParquetPageReaderTest, corruptedPageHeader) { // In the corrupted_page_header, the min_value length is set incorrectly on // purpose. This is to simulate the situation where the Parquet Page Header is // corrupted. And an error is expected to be thrown. + dwio::common::ColumnReaderStatistics stats; auto pageReader = std::make_unique( std::move(inputStream), *leafPool_, common::CompressionKind::CompressionKind_GZIP, - headerSize); + headerSize, + stats); EXPECT_THROW(pageReader->readPageHeader(), VeloxException); } diff --git a/velox/dwio/parquet/tests/reader/ParquetReaderTest.cpp b/velox/dwio/parquet/tests/reader/ParquetReaderTest.cpp index f1ee25766dbf..adb6ffc7ab55 100644 --- a/velox/dwio/parquet/tests/reader/ParquetReaderTest.cpp +++ b/velox/dwio/parquet/tests/reader/ParquetReaderTest.cpp @@ -15,6 +15,7 @@ */ #include "velox/common/base/tests/GTestUtils.h" +#include "velox/dwio/common/Mutation.h" #include "velox/dwio/parquet/tests/ParquetTestBase.h" #include "velox/expression/ExprToSubfieldFilter.h" #include "velox/vector/tests/utils/VectorMaker.h" @@ -595,8 +596,9 @@ TEST_F(ParquetReaderTest, parseUnsignedInt1) { {TINYINT(), SMALLINT(), INTEGER(), BIGINT()}); RowReaderOptions rowReaderOpts; - rowReaderOpts.select(std::make_shared( - rowType, rowType->names())); + rowReaderOpts.select( + std::make_shared( + rowType, rowType->names())); rowReaderOpts.setScanSpec(makeScanSpec(rowType)); auto rowReader = reader->createRowReader(rowReaderOpts); @@ -1456,7 +1458,6 @@ TEST_F(ParquetReaderTest, readFixedLenBinaryAsStringFromUuid) { } TEST_F(ParquetReaderTest, testV2PageWithZeroMaxDefRep) { - // enum_type.parquet contains 1 column (ENUM) with 3 rows. const std::string sample(getExampleFilePath("v2_page.parquet")); dwio::common::ReaderOptions readerOptions{leafPool_.get()}; @@ -1480,6 +1481,31 @@ TEST_F(ParquetReaderTest, testV2PageWithZeroMaxDefRep) { outputRowType, *rowReader, expected, *leafPool_); } +TEST_F(ParquetReaderTest, readComplexTypeWithV2Page) { + const std::string sample(getExampleFilePath("complex_type_v2_page.parquet")); + + dwio::common::ReaderOptions readerOptions{leafPool_.get()}; + auto reader = createReader(sample, readerOptions); + EXPECT_EQ(reader->numberOfRows(), 1ULL); + + auto rowType = reader->typeWithId(); + EXPECT_EQ(rowType->type()->kind(), TypeKind::ROW); + EXPECT_EQ(rowType->size(), 2ULL); + + auto outputRowType = + ROW({"nums", "props"}, {ARRAY(INTEGER()), MAP(VARCHAR(), INTEGER())}); + auto rowReaderOpts = getReaderOpts(outputRowType); + rowReaderOpts.setScanSpec(makeScanSpec(outputRowType)); + auto rowReader = reader->createRowReader(rowReaderOpts); + + auto expected = makeRowVector( + {makeArrayVectorFromJson({"[4 ,5]"}), + makeMapVectorFromJson( + {"{\"x\": 99, \"y\": 100}"})}); + assertReadWithReaderAndExpected( + outputRowType, *rowReader, expected, *leafPool_); +} + TEST_F(ParquetReaderTest, arrayOfMapOfIntKeyArrayValue) { // The Schema is of type // message hive_schema { @@ -1733,3 +1759,355 @@ TEST_F(ParquetReaderTest, fileColumnVarcharToMetadataColumnMismatchTest) { runVarcharColTest(type); } } + +TEST_F(ParquetReaderTest, readerWithSchema) { + // Create an in-memory writer. + auto sink = std::make_unique( + 1024 * 1024, dwio::common::FileSink::Options{.pool = leafPool_.get()}); + auto sinkPtr = sink.get(); + const auto data = makeRowVector( + {makeFlatVector({1}), + makeArrayVectorFromJson({"[4 ,5]"})}); + parquet::WriterOptions writerOptions; + writerOptions.memoryPool = leafPool_.get(); + + // key, element are Parquet reserved keywords. + // Ensure we handle them properly during the schema inference. + auto schema = ROW({"key", "element"}, {BIGINT(), ARRAY(INTEGER())}); + + auto writer = std::make_unique( + std::move(sink), writerOptions, rootPool_, schema); + writer->write(data); + writer->close(); + + // Create the reader. + dwio::common::ReaderOptions readerOptions{leafPool_.get()}; + readerOptions.setFileSchema(schema); + std::string dataBuf(sinkPtr->data(), sinkPtr->size()); + auto file = std::make_shared(std::move(dataBuf)); + auto buffer = std::make_unique( + file, readerOptions.memoryPool()); + ParquetReader reader(std::move(buffer), readerOptions); + + EXPECT_EQ(reader.rowType()->toString(), schema->toString()); +} + +// Comprehensive test matrix covering all combinations: +// - Nulls: No nulls, With nulls +// - Dictionary: Enabled, Disabled +// - Filter: None, IsNull, IsNotNull, Value filter +// - Density: Dense (no deletions), Non-dense (with deletions/mutations) + +enum class FloatToDoubleFilter { + kNone, + kIsNull, + kIsNotNull, + kGreaterThanOrEqual, // Value filter: greater than or equal to a threshold + kMultiRange, // MultiRange filter: a < X OR a > Y +}; + +struct FloatToDoubleSpec { + std::vector> values; + std::vector ids; + bool enableDictionary{true}; + FloatToDoubleFilter filter{FloatToDoubleFilter::kNone}; + std::optional filterValue; // Value for value-based filters + std::optional filterLowerValue; // Lower bound for MultiRange filter + std::optional filterUpperValue; // Upper bound for MultiRange filter + std::vector deletedRows; +}; + +struct FloatToDoubleTestParam { + bool hasNulls; + bool enableDictionary; + FloatToDoubleFilter filter; + bool isDense; + + std::string toString() const { + return fmt::format( + "Nulls_{}_Dict_{}_Filter_{}_Dense_{}", + hasNulls ? "Yes" : "No", + enableDictionary ? "Yes" : "No", + filterName(filter), + isDense ? "Yes" : "No"); + } + + static std::string filterName(FloatToDoubleFilter filter) { + switch (filter) { + case FloatToDoubleFilter::kNone: + return "None"; + case FloatToDoubleFilter::kIsNull: + return "IsNull"; + case FloatToDoubleFilter::kIsNotNull: + return "IsNotNull"; + case FloatToDoubleFilter::kGreaterThanOrEqual: + return "GreaterThanOrEqual"; + case FloatToDoubleFilter::kMultiRange: + return "MultiRange"; + default: + return "Unknown"; + } + } +}; + +class FloatToDoubleEvolutionTest + : public ParquetReaderTest, + public testing::WithParamInterface { + public: + static std::vector getTestParams() { + std::vector params; + for (bool hasNulls : {false, true}) { + for (bool enableDictionary : {false, true}) { + // When hasNulls is false, only test kNone, kGreaterThanOrEqual, and + // kMultiRange filter (kIsNull would match nothing, kIsNotNull is + // equivalent to kNone) + std::vector filters; + if (hasNulls) { + filters = { + FloatToDoubleFilter::kNone, + FloatToDoubleFilter::kIsNull, + FloatToDoubleFilter::kIsNotNull, + FloatToDoubleFilter::kGreaterThanOrEqual, + FloatToDoubleFilter::kMultiRange}; + } else { + filters = { + FloatToDoubleFilter::kNone, + FloatToDoubleFilter::kGreaterThanOrEqual, + FloatToDoubleFilter::kMultiRange}; + } + + for (auto filter : filters) { + for (bool isDense : {true, false}) { + params.push_back({hasNulls, enableDictionary, filter, isDense}); + } + } + } + } + return params; + } + + void runFloatToDoubleScenario(const FloatToDoubleSpec& spec); +}; + +void FloatToDoubleEvolutionTest::runFloatToDoubleScenario( + const FloatToDoubleSpec& spec) { + ASSERT_EQ(spec.values.size(), spec.ids.size()); + const vector_size_t numRows = spec.ids.size(); + + auto floatVector = makeNullableFlatVector(spec.values); + auto idVector = + makeFlatVector(numRows, [&](auto row) { return spec.ids[row]; }); + + RowVectorPtr writeData = makeRowVector({floatVector, idVector}); + RowTypePtr writeSchema = ROW({"float_col", "id"}, {REAL(), BIGINT()}); + + auto sink = std::make_unique( + 1024 * 1024, dwio::common::FileSink::Options{.pool = leafPool_.get()}); + auto sinkPtr = sink.get(); + + parquet::WriterOptions writerOptions; + writerOptions.memoryPool = leafPool_.get(); + writerOptions.enableDictionary = spec.enableDictionary; + + auto writer = std::make_unique( + std::move(sink), writerOptions, rootPool_, writeSchema); + writer->write(writeData); + writer->close(); + + RowTypePtr readSchema = ROW({"float_col", "id"}, {DOUBLE(), BIGINT()}); + + dwio::common::ReaderOptions readerOptions{leafPool_.get()}; + readerOptions.setFileSchema(readSchema); + + std::string dataBuf(sinkPtr->data(), sinkPtr->size()); + auto file = std::make_shared(std::move(dataBuf)); + auto buffer = std::make_unique( + file, readerOptions.memoryPool()); + auto reader = + std::make_unique(std::move(buffer), readerOptions); + + RowReaderOptions rowReaderOpts; + rowReaderOpts.select( + std::make_shared( + readSchema, readSchema->names())); + auto scanSpec = makeScanSpec(readSchema); + + // Apply IsNull or IsNotNull filter if specified + switch (spec.filter) { + case FloatToDoubleFilter::kNone: + break; + case FloatToDoubleFilter::kIsNull: { + auto* floatChild = + scanSpec->getOrCreateChild(common::Subfield("float_col")); + floatChild->setFilter(exec::isNull()); + break; + } + case FloatToDoubleFilter::kIsNotNull: { + auto* floatChild = + scanSpec->getOrCreateChild(common::Subfield("float_col")); + floatChild->setFilter(exec::isNotNull()); + break; + } + case FloatToDoubleFilter::kGreaterThanOrEqual: { + ASSERT_TRUE(spec.filterValue.has_value()); + auto* floatChild = + scanSpec->getOrCreateChild(common::Subfield("float_col")); + floatChild->setFilter( + exec::greaterThanOrEqualDouble(spec.filterValue.value())); + break; + } + case FloatToDoubleFilter::kMultiRange: { + ASSERT_TRUE(spec.filterLowerValue.has_value()); + ASSERT_TRUE(spec.filterUpperValue.has_value()); + auto* floatChild = + scanSpec->getOrCreateChild(common::Subfield("float_col")); + // Create a MultiRange filter: a < lower OR a > upper + floatChild->setFilter( + exec::orFilter( + exec::lessThanDouble(spec.filterLowerValue.value()), + exec::greaterThanDouble(spec.filterUpperValue.value()))); + break; + } + } + + rowReaderOpts.setScanSpec(scanSpec); + auto rowReader = reader->createRowReader(rowReaderOpts); + + std::vector deletedFlags(numRows, false); + for (auto index : spec.deletedRows) { + ASSERT_LT(index, numRows); + deletedFlags[index] = true; + } + + std::vector expectedIndices; + expectedIndices.reserve(numRows); + for (vector_size_t row = 0; row < numRows; ++row) { + if (deletedFlags[row]) { + continue; + } + + bool passes = false; + switch (spec.filter) { + case FloatToDoubleFilter::kNone: + passes = true; + break; + case FloatToDoubleFilter::kIsNull: + passes = !spec.values[row].has_value(); + break; + case FloatToDoubleFilter::kIsNotNull: + passes = spec.values[row].has_value(); + break; + case FloatToDoubleFilter::kGreaterThanOrEqual: + passes = spec.values[row].has_value() && + static_cast(*spec.values[row]) >= spec.filterValue.value(); + break; + case FloatToDoubleFilter::kMultiRange: + passes = spec.values[row].has_value() && + (static_cast(*spec.values[row]) < + spec.filterLowerValue.value() || + static_cast(*spec.values[row]) > + spec.filterUpperValue.value()); + break; + } + + if (passes) { + expectedIndices.push_back(row); + } + } + + std::vector> expectedDoubles(expectedIndices.size()); + for (size_t i = 0; i < expectedIndices.size(); ++i) { + const auto originalIndex = expectedIndices[i]; + if (!spec.values[originalIndex].has_value()) { + expectedDoubles[i] = std::nullopt; + } else { + expectedDoubles[i] = static_cast(*spec.values[originalIndex]); + } + } + + auto expectedFloat = makeNullableFlatVector(expectedDoubles); + auto expectedId = makeFlatVector( + expectedIndices.size(), + [&](auto row) { return spec.ids[expectedIndices[row]]; }); + RowVectorPtr expected = makeRowVector({expectedFloat, expectedId}); + + if (spec.deletedRows.empty() && spec.filter != FloatToDoubleFilter::kIsNull && + spec.filter != FloatToDoubleFilter::kIsNotNull && + spec.filter != FloatToDoubleFilter::kGreaterThanOrEqual && + spec.filter != FloatToDoubleFilter::kMultiRange) { + assertReadWithReaderAndExpected( + readSchema, *rowReader, expected, *leafPool_); + return; + } + + VectorPtr result = BaseVector::create(readSchema, 0, leafPool_.get()); + vector_size_t scanned = 0; + std::vector deleted(bits::nwords(numRows), 0); + if (spec.deletedRows.empty()) { + scanned = rowReader->next(numRows, result); + } else { + for (auto index : spec.deletedRows) { + bits::setBit(deleted.data(), index); + } + dwio::common::Mutation mutation; + mutation.deletedRows = deleted.data(); + scanned = rowReader->next(numRows, result, &mutation); + } + + EXPECT_GT(scanned, 0); + EXPECT_GE(scanned, expected->size()); + ASSERT_TRUE(result != nullptr); + auto rowVector = result->as(); + ASSERT_TRUE(rowVector != nullptr); + ASSERT_EQ(rowVector->size(), expected->size()); + assertEqualVectorPart(expected, result, 0); +} + +TEST_P(FloatToDoubleEvolutionTest, readFloatToDouble) { + const auto& param = GetParam(); + FloatToDoubleSpec spec; + constexpr vector_size_t kSize = 200; + spec.enableDictionary = param.enableDictionary; + spec.values.resize(kSize); + spec.ids.resize(kSize); + + for (vector_size_t row = 0; row < kSize; ++row) { + if (param.hasNulls && row % 5 == 0) { + spec.values[row] = std::nullopt; + } else { + // Use a value pattern that works for both dictionary and direct encoding + float val = + static_cast(row % 10) * 1.1f + static_cast(row) * 0.01f; + spec.values[row] = val; + } + spec.ids[row] = row; + } + + spec.filter = param.filter; + + // Set filter value for value-based filters + if (param.filter == FloatToDoubleFilter::kGreaterThanOrEqual) { + // Filter values greater than or equal to 5.0 (this should match + // approximately half the rows) + spec.filterValue = 5.0; + } else if (param.filter == FloatToDoubleFilter::kMultiRange) { + // Filter values < 3.0 OR > 7.0 + spec.filterLowerValue = 3.0; + spec.filterUpperValue = 7.0; + } + + if (!param.isDense) { + // Add some deleted rows scattered throughout + spec.deletedRows = {5, 20, 55, 99, 150, 199}; + } + + runFloatToDoubleScenario(spec); +} + +INSTANTIATE_TEST_SUITE_P( + FloatToDoubleEvolution, + FloatToDoubleEvolutionTest, + testing::ValuesIn(FloatToDoubleEvolutionTest::getTestParams()), + [](const testing::TestParamInfo& info) { + return info.param.toString(); + }); diff --git a/velox/dwio/parquet/tests/reader/ParquetTableScanTest.cpp b/velox/dwio/parquet/tests/reader/ParquetTableScanTest.cpp index c272fa43a8bc..9a2690e79a63 100644 --- a/velox/dwio/parquet/tests/reader/ParquetTableScanTest.cpp +++ b/velox/dwio/parquet/tests/reader/ParquetTableScanTest.cpp @@ -69,8 +69,7 @@ class ParquetTableScanTest : public HiveConnectorTestBase { void assertSelectWithAssignments( std::vector&& outputColumnNames, - std::unordered_map>& - assignments, + const connector::ColumnHandleMap& assignments, const std::string& sql) { auto rowType = getRowType(std::move(outputColumnNames)); auto plan = PlanBuilder() @@ -84,9 +83,7 @@ class ParquetTableScanTest : public HiveConnectorTestBase { const std::vector& subfieldFilters, const std::string& remainingFilter, const std::string& sql, - const std::unordered_map< - std::string, - std::shared_ptr>& assignments = {}) { + const connector::ColumnHandleMap& assignments = {}) { auto rowType = getRowType(std::move(outputColumnNames)); parse::ParseOptions options; options.parseDecimalAsDouble = false; @@ -293,6 +290,11 @@ class ParquetTableScanTest : public HiveConnectorTestBase { {}, "t == TIMESTAMP '2022-12-23 03:56:01'", "SELECT t from tmp where t == TIMESTAMP '2022-12-23 03:56:01'"); + assertSelectWithFilter( + {"t"}, + {}, + "not(eq(t, TIMESTAMP '2000-09-12 22:36:29'))", + "SELECT t from tmp where t != TIMESTAMP '2000-09-12 22:36:29'"); } private: @@ -426,6 +428,39 @@ TEST_F(ParquetTableScanTest, aggregatePushdown) { assertEqualVectors(rows->childAt(1), valuesVector); } +TEST_F(ParquetTableScanTest, aggregatePushdownToSmallPages) { + const std::vector columnNames = {"a", "b", "c"}; + const auto expectedRowVector = makeRowVector( + {makeFlatVector({1, 2, 4}), + makeFlatVector({7, 9, 13})}); + const auto outputType = ROW(columnNames, {SMALLINT(), SMALLINT(), VARCHAR()}); + std::vector data; + for (auto row = 0; row < 10; ++row) { + data.emplace_back(makeRowVector( + columnNames, + { + makeFlatVector({static_cast(row % 5)}), + makeFlatVector({static_cast(row)}), + makeFlatVector({std::to_string(row)}), + })); + } + const auto filePath = TempFilePath::create(); + WriterOptions options; + options.dataPageSize = 1; + writeToParquetFile(filePath->getPath(), data, options); + const auto plan = + PlanBuilder(pool()) + .tableScan( + outputType, + {}, + "c <> '' AND a in (1::smallint, 2::smallint, 4::smallint)") + .singleAggregation({"a"}, {"sum(b) as s"}) + .planNode(); + AssertQueryBuilder(plan) + .split(makeSplit(filePath->getPath())) + .assertResults(expectedRowVector); +} + TEST_F(ParquetTableScanTest, countStar) { // sample.parquet holds two columns (a: BIGINT, b: DOUBLE) and // 20 rows. @@ -537,8 +572,77 @@ TEST_F(ParquetTableScanTest, array) { vector, })); - assertSelectWithFilter( - {"repeatedInt"}, {}, "", "SELECT UNNEST(array[array[1,2,3]])"); + assertSelectWithFilter({"repeatedInt"}, {}, "", "SELECT [1,2,3]"); + + // Set the requested type for unannotated array. + auto rowType = ROW({"repeatedInt"}, {ARRAY(INTEGER())}); + auto plan = PlanBuilder(pool_.get()) + .tableScan(rowType, {}, "", rowType, {}) + .planNode(); + + AssertQueryBuilder(plan, duckDbQueryRunner_) + .splits({makeSplit(getExampleFilePath("old_repeated_int.parquet"))}) + .assertResults("SELECT [1,2,3]"); + + // Throws when reading repeated values as scalar type. + rowType = ROW({"repeatedInt"}, {INTEGER()}); + plan = PlanBuilder(pool_.get()) + .tableScan(rowType, {}, "", rowType, {}) + .planNode(); + VELOX_ASSERT_THROW( + AssertQueryBuilder(plan, duckDbQueryRunner_) + .splits({makeSplit(getExampleFilePath("old_repeated_int.parquet"))}) + .assertResults(""), + "Requested type must be array"); + + rowType = ROW({"mystring"}, {ARRAY(VARCHAR())}); + plan = PlanBuilder(pool_.get()) + .tableScan(rowType, {}, "", rowType, {}) + .planNode(); + + AssertQueryBuilder(plan, duckDbQueryRunner_) + .splits({makeSplit(getExampleFilePath("proto_repeated_string.parquet"))}) + .assertResults( + "SELECT UNNEST(array[array['hello', 'world'], array['good','bye'], array['one', 'two', 'three']])"); + + rowType = + ROW({"primitive", "myComplex"}, + {INTEGER(), + ARRAY( + ROW({"id", "repeatedMessage"}, + {INTEGER(), ARRAY(ROW({"someId"}, {INTEGER()}))}))}); + plan = PlanBuilder(pool_.get()) + .tableScan(rowType, {}, "", rowType, {}) + .planNode(); + + // Construct the expected vector. + auto someIdVector = makeArrayOfRowVector( + ROW({"someId"}, {INTEGER()}), + { + {variant::row({3})}, + {variant::row({6})}, + {variant::row({9})}, + }); + auto rowVector = makeRowVector( + {"id", "repeatedMessage"}, + { + makeFlatVector({1, 4, 7}), + someIdVector, + }); + auto expected = makeRowVector( + {"primitive", "myComplex"}, + { + makeFlatVector({2, 5, 8}), + makeArrayVector({0, 1, 2}, rowVector), + }); + + AssertQueryBuilder(plan, duckDbQueryRunner_) + .connectorSessionProperty( + kHiveConnectorId, + connector::hive::HiveConfig::kParquetUseColumnNamesSession, + "true") + .splits({makeSplit(getExampleFilePath("nested_array_struct.parquet"))}) + .assertResults(expected); } // Optional array with required elements. @@ -655,44 +759,55 @@ TEST_F(ParquetTableScanTest, filterOnNestedArray) { } TEST_F(ParquetTableScanTest, readAsLowerCase) { - auto plan = PlanBuilder(pool_.get()) - .tableScan(ROW({"a"}, {BIGINT()}), {}, "") - .planNode(); - CursorParameters params; - std::shared_ptr executor = - std::make_shared( - std::thread::hardware_concurrency()); - std::shared_ptr queryCtx = - core::QueryCtx::create(executor.get()); - std::unordered_map session = { - {std::string( - connector::hive::HiveConfig::kFileColumnNamesReadAsLowerCaseSession), - "true"}}; - queryCtx->setConnectorSessionOverridesUnsafe( - kHiveConnectorId, std::move(session)); - params.queryCtx = queryCtx; - params.planNode = plan; - const int numSplitsPerFile = 1; - - auto addSplits = [&](exec::TaskCursor* taskCursor) { - if (taskCursor->noMoreSplits()) { - return; - } - auto& task = taskCursor->task(); - auto const splits = HiveConnectorTestBase::makeHiveConnectorSplits( - {getExampleFilePath("upper.parquet")}, - numSplitsPerFile, - dwio::common::FileFormat::PARQUET); - for (const auto& split : splits) { - task->addSplit("0", exec::Split(split)); - } - task->noMoreSplits("0"); - taskCursor->setNoMoreSplits(); - }; - auto result = readCursor(params, addSplits); - ASSERT_TRUE(waitForTaskCompletion(result.first->task().get())); - assertEqualResults( - result.second, {makeRowVector({"a"}, {makeFlatVector({0, 1})})}); + auto vectors = {makeRowVector( + {"A", "b"}, + { + makeFlatVector(20, [](auto row) { return row + 1; }), + makeFlatVector(20, [](auto row) { return row + 1; }), + })}; + auto filePath = TempFilePath::create(); + WriterOptions options; + writeToParquetFile(filePath->getPath(), vectors, options); + createDuckDbTable(vectors); + + auto plan = PlanBuilder().tableScan(ROW({"a"}, {BIGINT()})).planNode(); + auto split = makeSplit(filePath->getPath()); + + AssertQueryBuilder(plan, duckDbQueryRunner_) + .connectorSessionProperty( + kHiveConnectorId, + connector::hive::HiveConfig::kFileColumnNamesReadAsLowerCaseSession, + "true") + .split(split) + .assertResults("SELECT A FROM tmp"); + + // Test reading table with non-ascii names. + auto vectorsNonAsciiNames = {makeRowVector( + {"Товары", "国Ⅵ", "\uFF21", "\uFF22"}, + { + makeFlatVector(20, [](auto row) { return row + 1; }), + makeFlatVector(20, [](auto row) { return row + 1; }), + makeFlatVector(20, [](auto row) { return row + 1; }), + makeFlatVector(20, [](auto row) { return row + 1; }), + })}; + filePath = TempFilePath::create(); + writeToParquetFile(filePath->getPath(), vectorsNonAsciiNames, options); + createDuckDbTable(vectorsNonAsciiNames); + + plan = PlanBuilder() + .tableScan( + ROW({"товары", "国ⅵ", "\uFF41", "\uFF42"}, + {BIGINT(), DOUBLE(), REAL(), INTEGER()})) + .planNode(); + split = makeSplit(filePath->getPath()); + + AssertQueryBuilder(plan, duckDbQueryRunner_) + .connectorSessionProperty( + kHiveConnectorId, + connector::hive::HiveConfig::kFileColumnNamesReadAsLowerCaseSession, + "true") + .split(split) + .assertResults("SELECT * FROM tmp"); } TEST_F(ParquetTableScanTest, rowIndex) { @@ -714,8 +829,7 @@ TEST_F(ParquetTableScanTest, rowIndex) { }), std::nullopt, std::unordered_map{{kPath, filePath}}); - std::unordered_map> - assignments; + connector::ColumnHandleMap assignments; assignments["a"] = std::make_shared( "a", connector::hive::HiveColumnHandle::ColumnType::kRegular, @@ -812,16 +926,14 @@ TEST_F(ParquetTableScanTest, filterNullIcebergPartition) { {"c1 IS NOT NULL"}, "", "SELECT c0, c1 FROM tmp WHERE c1 IS NOT NULL", - std::unordered_map>{ - {"c0", c0}, {"c1", c1}}); + connector::ColumnHandleMap{{"c0", c0}, {"c1", c1}}); assertSelectWithFilter( {"c0", "c1"}, {"c1 IS NULL"}, "", "SELECT c0, c1 FROM tmp WHERE c1 IS NULL", - std::unordered_map>{ - {"c0", c0}, {"c1", c1}}); + connector::ColumnHandleMap{{"c0", c0}, {"c1", c1}}); } TEST_F(ParquetTableScanTest, sessionTimezone) { @@ -1334,6 +1446,43 @@ TEST_F(ParquetTableScanTest, shortAndLongDecimalReadWithLargerPrecision) { assertEqualVectors(expectedDecimalVectors->childAt(1), rows->childAt(1)); } +TEST_F(ParquetTableScanTest, inFilter) { + auto vectors = {makeRowVector( + {"name"}, + { + makeNullableFlatVector( + {"mary", "martin", "lucy", "alex", std::nullopt, "mary", "dan"}), + })}; + auto filePath = TempFilePath::create(); + WriterOptions options; + writeToParquetFile(filePath->getPath(), vectors, options); + createDuckDbTable(vectors); + + // Test in. + auto plan = PlanBuilder(pool_.get()) + .tableScan( + ROW({"name"}, {VARCHAR()}), + {"name in ('alex', 'leo', 'mary', null, 'victor')"}, + "") + .planNode(); + AssertQueryBuilder(plan, duckDbQueryRunner_) + .split(makeSplit(filePath->getPath())) + .assertResults( + "SELECT name FROM tmp where name in ('alex', 'leo', 'mary', null, 'victor')"); + + // Test not in. + plan = PlanBuilder(pool_.get()) + .tableScan( + ROW({"name"}, {VARCHAR()}), + {"name not in ('alex', 'leo', 'mary', null, 'victor')"}, + "") + .planNode(); + AssertQueryBuilder(plan, duckDbQueryRunner_) + .split(makeSplit(filePath->getPath())) + .assertResults( + "SELECT name FROM tmp where name not in ('alex', 'leo', 'mary', null, 'victor')"); +} + int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); folly::Init init{&argc, &argv, false}; diff --git a/velox/dwio/parquet/tests/thrift/CMakeLists.txt b/velox/dwio/parquet/tests/thrift/CMakeLists.txt index 3c3af4113645..65f89cc69357 100644 --- a/velox/dwio/parquet/tests/thrift/CMakeLists.txt +++ b/velox/dwio/parquet/tests/thrift/CMakeLists.txt @@ -21,4 +21,5 @@ target_link_libraries( thrift velox_link_libs GTest::gtest - GTest::gtest_main) + GTest::gtest_main +) diff --git a/velox/dwio/parquet/tests/writer/CMakeLists.txt b/velox/dwio/parquet/tests/writer/CMakeLists.txt index 93a0538650c7..6948dbfcf70e 100644 --- a/velox/dwio/parquet/tests/writer/CMakeLists.txt +++ b/velox/dwio/parquet/tests/writer/CMakeLists.txt @@ -17,35 +17,40 @@ add_executable(velox_parquet_writer_sink_test SinkTests.cpp) add_test( NAME velox_parquet_writer_sink_test COMMAND velox_parquet_writer_sink_test - WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} +) target_link_libraries( velox_parquet_writer_sink_test velox_dwio_parquet_writer velox_dwio_common_test_utils velox_vector_fuzzer - Boost::regex + velox_caching velox_link_libs Folly::folly ${TEST_LINK_LIBS} GTest::gtest - fmt::fmt) + fmt::fmt +) -add_executable(velox_parquet_writer_test ParquetWriterTest.cpp) +add_executable(velox_parquet_writer_test ParquetWriterFieldIdTest.cpp ParquetWriterTest.cpp) add_test( NAME velox_parquet_writer_test COMMAND velox_parquet_writer_test - WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} +) target_link_libraries( velox_parquet_writer_test + velox_dwio_arrow_parquet_writer_test_lib velox_dwio_parquet_writer velox_dwio_parquet_reader velox_dwio_common_test_utils + velox_caching velox_link_libs - Boost::regex Folly::folly ${TEST_LINK_LIBS} GTest::gtest - fmt::fmt) + fmt::fmt +) diff --git a/velox/dwio/parquet/tests/writer/ParquetWriterFieldIdTest.cpp b/velox/dwio/parquet/tests/writer/ParquetWriterFieldIdTest.cpp new file mode 100644 index 000000000000..d81aac342939 --- /dev/null +++ b/velox/dwio/parquet/tests/writer/ParquetWriterFieldIdTest.cpp @@ -0,0 +1,124 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/dwio/parquet/writer/arrow/tests/TestUtil.h" + +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/dwio/parquet/tests/ParquetTestBase.h" +#include "velox/dwio/parquet/writer/arrow/Schema.h" +#include "velox/dwio/parquet/writer/arrow/tests/FileReader.h" + +namespace { + +using namespace facebook::velox; +using namespace facebook::velox::dwio::common; +using namespace facebook::velox::parquet; + +class ParquetWriterFieldIdTest : public ParquetTestBase, + public ::testing::WithParamInterface {}; + +TEST_P(ParquetWriterFieldIdTest, fieldIds) { + auto schema = + ROW({"p", "s", "a", "m"}, + {BIGINT(), + ROW({"x", "y"}, {INTEGER(), VARCHAR()}), + ARRAY(INTEGER()), + MAP(VARCHAR(), INTEGER())}); + constexpr int32_t kRows = 10; + auto data = makeRowVector( + {"p", "s", "a", "m"}, + {makeFlatVector(kRows, [](auto row) { return row; }), + makeRowVector( + {"x", "y"}, + {makeFlatVector(kRows, [](auto row) { return row; }), + makeFlatVector(kRows, [](auto) { return "z"; })}), + makeArrayVectorFromJson(std::vector(kRows, "[3]")), + makeMapVectorFromJson( + std::vector(kRows, R"({"k": 4})"))}); + + parquet::WriterOptions writerOptions; + writerOptions.memoryPool = rootPool_.get(); + + if (GetParam()) { + // Provide Parquet field IDs aligned with the Velox schema tree. + // p -> 10. + // s -> 20, children: x -> 21, y -> 22. + // a -> 30, list element -> 31. + // m -> 40, children: key -> 41, value -> 42. + writerOptions.parquetFieldIds = { + ParquetFieldId{10, {}}, + ParquetFieldId{20, {ParquetFieldId{21, {}}, ParquetFieldId{22, {}}}}, + ParquetFieldId{30, {ParquetFieldId{31, {}}}}, + ParquetFieldId{40, {ParquetFieldId{41, {}}, ParquetFieldId{42, {}}}}, + }; + } + + auto* sinkPtr = write(data, writerOptions); + + dwio::common::ReaderOptions readerOptions{leafPool_.get()}; + auto parquetReader = createReaderInMemory(*sinkPtr, readerOptions); + EXPECT_EQ(parquetReader->numberOfRows(), kRows); + auto veloxRowType = parquetReader->rowType(); + EXPECT_EQ(*veloxRowType, *schema); + + std::string_view sinkData(sinkPtr->data(), sinkPtr->size()); + auto arrowBufferReader = std::make_shared<::arrow::io::BufferReader>( + std::make_shared<::arrow::Buffer>( + reinterpret_cast(sinkData.data()), sinkData.size())); + + auto fileReader = parquet::arrow::ParquetFileReader::Open(arrowBufferReader); + auto metadata = fileReader->metadata(); + auto* descr = metadata->schema(); + auto* root = descr->group_node(); + + ASSERT_EQ(root->field_count(), 4); + + auto exp = [&](int32_t expectedFieldId) { + return GetParam() ? expectedFieldId : -1; + }; + + // Top-level field IDs. + EXPECT_EQ(root->field(0)->field_id(), exp(10)); + EXPECT_EQ(root->field(1)->field_id(), exp(20)); + EXPECT_EQ(root->field(2)->field_id(), exp(30)); + EXPECT_EQ(root->field(3)->field_id(), exp(40)); + + using GroupNode = parquet::arrow::schema::GroupNode; + auto* s = static_cast(root->field(1).get()); + EXPECT_EQ(s->field(0)->field_id(), exp(21)); + EXPECT_EQ(s->field(1)->field_id(), exp(22)); + + auto* a = static_cast(root->field(2).get()); + // LIST logical group has one repeated child (the array entries); dive once + // more to the element. + auto* listEntries = a->field(0).get(); + auto* listGroup = static_cast(listEntries); + auto* element = listGroup->field(0).get(); + EXPECT_EQ(element->field_id(), exp(31)); + + auto* m = static_cast(root->field(3).get()); + auto* keyValue = m->field(0).get(); + auto* keyValueGroup = static_cast(keyValue); + EXPECT_EQ(keyValueGroup->field(0)->field_id(), exp(41)); + EXPECT_EQ(keyValueGroup->field(1)->field_id(), exp(42)); +} + +VELOX_INSTANTIATE_TEST_SUITE_P( + ParquetWriterFieldIdTest, + ParquetWriterFieldIdTest, + ::testing::Values(false, true)); + +} // namespace diff --git a/velox/dwio/parquet/tests/writer/ParquetWriterTest.cpp b/velox/dwio/parquet/tests/writer/ParquetWriterTest.cpp index 3dc9f93990eb..8f666f6aa350 100644 --- a/velox/dwio/parquet/tests/writer/ParquetWriterTest.cpp +++ b/velox/dwio/parquet/tests/writer/ParquetWriterTest.cpp @@ -22,6 +22,7 @@ #include "velox/common/testutil/TestValue.h" #include "velox/connectors/hive/HiveConnector.h" // @manual #include "velox/core/QueryCtx.h" +#include "velox/dwio/common/tests/utils/BatchMaker.h" #include "velox/dwio/parquet/RegisterParquetWriter.h" // @manual #include "velox/dwio/parquet/reader/PageReader.h" #include "velox/dwio/parquet/tests/ParquetTestBase.h" @@ -45,19 +46,19 @@ class ParquetWriterTest : public ParquetTestBase { memory::MemoryManager::testingSetInstance(memory::MemoryManager::Options{}); testutil::TestValue::enable(); filesystems::registerLocalFileSystem(); - connector::registerConnectorFactory( - std::make_shared()); - auto hiveConnector = - connector::getConnectorFactory( - connector::hive::HiveConnectorFactory::kHiveConnectorName) - ->newConnector( - kHiveConnectorId, - std::make_shared( - std::unordered_map())); + connector::hive::HiveConnectorFactory factory; + auto hiveConnector = factory.newConnector( + kHiveConnectorId, + std::make_shared( + std::unordered_map())); connector::registerConnector(hiveConnector); parquet::registerParquetWriterFactory(); } + void TearDown() override { + writers_.clear(); + } + std::unique_ptr createRowReaderWithSchema( const std::unique_ptr reader, const RowTypePtr& rowType) { @@ -66,20 +67,109 @@ class ParquetWriterTest : public ParquetTestBase { rowReaderOpts.setScanSpec(scanSpec); auto rowReader = reader->createRowReader(rowReaderOpts); return rowReader; - }; + } - std::unique_ptr createReaderInMemory( - const dwio::common::MemorySink& sink, - const dwio::common::ReaderOptions& opts) { - std::string data(sink.data(), sink.size()); - return std::make_unique( - std::make_unique( - std::make_shared(std::move(data)), - opts.memoryPool()), - opts); - }; + RowVectorPtr makeSmallintTestData(int64_t rows) { + auto data = makeRowVector({ + makeFlatVector(rows, [](auto row) { return row + 1; }), + }); + return data; + } + + RowVectorPtr makeTimestampTestData(int64_t rows) { + auto data = makeRowVector({makeFlatVector( + rows, [](auto row) { return Timestamp(row, row); })}); + return data; + } + + thrift::PageHeader readPageHeader( + MemorySink* sinkPtr, + int64_t offsetFromDataPage) { + dwio::common::ReaderOptions readerOptions{leafPool_.get()}; + auto reader = createReaderInMemory(*sinkPtr, readerOptions); + + auto colChunkPtr = reader->fileMetaData().rowGroup(0).columnChunk(0); + std::string_view sinkData(sinkPtr->data(), sinkPtr->size()); + + auto readFile = std::make_shared(sinkData); + auto file = std::make_shared(std::move(readFile)); + + auto inputStream = std::make_unique( + std::move(file), + colChunkPtr.dataPageOffset() + offsetFromDataPage, + 150, + *leafPool_, + LogType::TEST); + auto pageReader = std::make_unique( + std::move(inputStream), + *leafPool_, + colChunkPtr.compression(), + colChunkPtr.totalCompressedSize(), + stats); + return pageReader->readPageHeader(); + } inline static const std::string kHiveConnectorId = "test-hive"; + dwio::common::ColumnReaderStatistics stats; +}; + +class ArrowMemoryPool final : public ::arrow::MemoryPool { + public: + explicit ArrowMemoryPool() : allocated_(0) {} + + ~ArrowMemoryPool() = default; + + ::arrow::Status Allocate(int64_t size, int64_t alignment, uint8_t** out) + override { + *out = reinterpret_cast(malloc(size)); + VELOX_CHECK_NOT_NULL(*out, "Failed to allocate memory in ArrowMemoryPool."); + + allocated_ += size; + return ::arrow::Status::OK(); + } + + ::arrow::Status Reallocate( + int64_t oldSize, + int64_t newSize, + int64_t alignment, + uint8_t** ptr) override { + uint8_t* newBuffer = reinterpret_cast(realloc(*ptr, newSize)); + VELOX_CHECK_NOT_NULL( + newBuffer, "Failed to reallocate memory in ArrowMemoryPool."); + + *ptr = newBuffer; + allocated_ = allocated_ - oldSize + newSize; + return ::arrow::Status::OK(); + } + + void Free(uint8_t* buffer, int64_t size, int64_t alignment) override { + free(buffer); + allocated_ -= size; + } + + int64_t bytes_allocated() const override { + return allocated_; + ; + } + + int64_t max_memory() const override { + VELOX_UNSUPPORTED("ArrowMemoryPool#max_memory() unsupported"); + } + + int64_t total_bytes_allocated() const override { + VELOX_UNSUPPORTED("ArrowMemoryPool#total_bytes_allocated() unsupported"); + } + + int64_t num_allocations() const override { + VELOX_UNSUPPORTED("ArrowMemoryPool#num_allocations() unsupported"); + } + + std::string backend_name() const override { + return "arrow memory pool"; + } + + private: + int64_t allocated_; }; std::vector params = { @@ -91,76 +181,24 @@ std::vector params = { }; TEST_F(ParquetWriterTest, dictionaryEncodingWithDictionaryPageSize) { - const auto schema = ROW({"c0"}, {SMALLINT()}); constexpr int64_t kRows = 10'000; - const auto data = makeRowVector({ - makeFlatVector(kRows, [](auto row) { return row + 1; }), - }); + const auto data = makeSmallintTestData(kRows); // Write Parquet test data, then read and return the DataPage // (thrift::PageType::type) used. const auto testEnableDictionaryAndDictionaryPageSizeToGetPageHeader = [&](std::unordered_map configFromFile, std::unordered_map sessionProperties, - bool isFirstPageOrSecondPage) { - // Create an in-memory writer. - auto sink = std::make_unique( - 200 * 1024 * 1024, - dwio::common::FileSink::Options{.pool = leafPool_.get()}); - auto sinkPtr = sink.get(); - parquet::WriterOptions writerOptions; - writerOptions.memoryPool = leafPool_.get(); - - auto connectorConfig = config::ConfigBase(std::move(configFromFile)); - auto connectorSessionProperties = - config::ConfigBase(std::move(sessionProperties)); - - writerOptions.processConfigs( - connectorConfig, connectorSessionProperties); - auto writer = std::make_unique( - std::move(sink), writerOptions, rootPool_, schema); - writer->write(data); - writer->close(); - - // Read to identify DataPage used. - dwio::common::ReaderOptions readerOptions{leafPool_.get()}; - auto reader = createReaderInMemory(*sinkPtr, readerOptions); - - auto colChunkPtr = reader->fileMetaData().rowGroup(0).columnChunk(0); - std::string_view sinkData(sinkPtr->data(), sinkPtr->size()); - - auto readFile = std::make_shared(sinkData); - auto file = std::make_shared(std::move(readFile)); - - if (isFirstPageOrSecondPage) { - auto inputStream = std::make_unique( - std::move(file), - colChunkPtr.dataPageOffset(), - 150, - *leafPool_, - LogType::TEST); - auto pageReader = std::make_unique( - std::move(inputStream), - *leafPool_, - colChunkPtr.compression(), - colChunkPtr.totalCompressedSize()); - return pageReader->readPageHeader(); + bool isFirstPage) { + auto* sinkPtr = write( + data, std::move(configFromFile), std::move(sessionProperties)); + if (isFirstPage) { + return readPageHeader(sinkPtr, 0); } constexpr int64_t kFirstDataPageCompressedSize = 1291; constexpr int64_t kFirstDataPageHeaderSize = 48; - auto inputStream = std::make_unique( - std::move(file), - colChunkPtr.dataPageOffset() + kFirstDataPageCompressedSize + - kFirstDataPageHeaderSize, - 150, - *leafPool_, - LogType::TEST); - auto pageReader = std::make_unique( - std::move(inputStream), - *leafPool_, - colChunkPtr.compression(), - colChunkPtr.totalCompressedSize()); - return pageReader->readPageHeader(); + return readPageHeader( + sinkPtr, kFirstDataPageCompressedSize + kFirstDataPageHeaderSize); }; // Test default config (i.e., no explicit config) @@ -261,58 +299,17 @@ TEST_F(ParquetWriterTest, dictionaryEncodingWithDictionaryPageSize) { } TEST_F(ParquetWriterTest, dictionaryEncodingOff) { - const auto schema = ROW({"c0"}, {SMALLINT()}); constexpr int64_t kRows = 10'000; - const auto data = makeRowVector({ - makeFlatVector(kRows, [](auto row) { return row + 1; }), - }); + const auto data = makeSmallintTestData(kRows); // Write Parquet test data, then read and return the DataPage // (thrift::PageType::type) used. const auto testEnableDictionaryAndDictionaryPageSizeToGetPageHeader = [&](std::unordered_map configFromFile, std::unordered_map sessionProperties) { - // Create an in-memory writer. - auto sink = std::make_unique( - 200 * 1024 * 1024, - dwio::common::FileSink::Options{.pool = leafPool_.get()}); - auto sinkPtr = sink.get(); - parquet::WriterOptions writerOptions; - writerOptions.memoryPool = leafPool_.get(); - - auto connectorConfig = config::ConfigBase(std::move(configFromFile)); - auto connectorSessionProperties = - config::ConfigBase(std::move(sessionProperties)); - - writerOptions.processConfigs( - connectorConfig, connectorSessionProperties); - auto writer = std::make_unique( - std::move(sink), writerOptions, rootPool_, schema); - writer->write(data); - writer->close(); - - // Read to identify DataPage used. - dwio::common::ReaderOptions readerOptions{leafPool_.get()}; - auto reader = createReaderInMemory(*sinkPtr, readerOptions); - - auto colChunkPtr = reader->fileMetaData().rowGroup(0).columnChunk(0); - std::string_view sinkData(sinkPtr->data(), sinkPtr->size()); - - auto readFile = std::make_shared(sinkData); - auto file = std::make_shared(std::move(readFile)); - - auto inputStream = std::make_unique( - std::move(file), - colChunkPtr.dataPageOffset(), - 150, - *leafPool_, - LogType::TEST); - auto pageReader = std::make_unique( - std::move(inputStream), - *leafPool_, - colChunkPtr.compression(), - colChunkPtr.totalCompressedSize()); - return pageReader->readPageHeader(); + auto* sinkPtr = write( + data, std::move(configFromFile), std::move(sessionProperties)); + return readPageHeader(sinkPtr, 0); }; // Test only dictionary off without dictionary page size configured @@ -390,25 +387,16 @@ TEST_F(ParquetWriterTest, compression) { makeFlatVector(kRows, [](auto row) { return row - 25; }), }); - // Create an in-memory writer - auto sink = std::make_unique( - 200 * 1024 * 1024, - dwio::common::FileSink::Options{.pool = leafPool_.get()}); - auto sinkPtr = sink.get(); - facebook::velox::parquet::WriterOptions writerOptions; - writerOptions.memoryPool = leafPool_.get(); + parquet::WriterOptions writerOptions; + writerOptions.memoryPool = rootPool_.get(); writerOptions.compressionKind = CompressionKind::CompressionKind_SNAPPY; const auto& fieldNames = schema->names(); - for (int i = 0; i < params.size(); i++) { writerOptions.columnCompressionsMap[fieldNames[i]] = params[i]; } - auto writer = std::make_unique( - std::move(sink), writerOptions, rootPool_, schema); - writer->write(data); - writer->close(); + auto* sinkPtr = write(data, writerOptions); dwio::common::ReaderOptions readerOptions{leafPool_.get()}; auto reader = createReaderInMemory(*sinkPtr, readerOptions); @@ -425,61 +413,20 @@ TEST_F(ParquetWriterTest, compression) { auto rowReader = createRowReaderWithSchema(std::move(reader), schema); assertReadWithReaderAndExpected(schema, *rowReader, data, *leafPool_); -}; +} TEST_F(ParquetWriterTest, testPageSizeAndBatchSizeConfiguration) { - const auto schema = ROW({"c0"}, {SMALLINT()}); constexpr int64_t kRows = 10'000; - const auto data = makeRowVector({ - makeFlatVector(kRows, [](auto row) { return row + 1; }), - }); + const auto data = makeSmallintTestData(kRows); // Write Parquet test data, then read and return the DataPage // (thrift::PageType::type) used. const auto testPageSizeAndBatchSizeToGetPageHeader = [&](std::unordered_map configFromFile, std::unordered_map sessionProperties) { - // Create an in-memory writer. - auto sink = std::make_unique( - 200 * 1024 * 1024, - dwio::common::FileSink::Options{.pool = leafPool_.get()}); - auto sinkPtr = sink.get(); - parquet::WriterOptions writerOptions; - writerOptions.memoryPool = leafPool_.get(); - - auto connectorConfig = config::ConfigBase(std::move(configFromFile)); - auto connectorSessionProperties = - config::ConfigBase(std::move(sessionProperties)); - - writerOptions.processConfigs( - connectorConfig, connectorSessionProperties); - auto writer = std::make_unique( - std::move(sink), writerOptions, rootPool_, schema); - writer->write(data); - writer->close(); - - // Read to identify DataPage used. - dwio::common::ReaderOptions readerOptions{leafPool_.get()}; - auto reader = createReaderInMemory(*sinkPtr, readerOptions); - - auto colChunkPtr = reader->fileMetaData().rowGroup(0).columnChunk(0); - std::string_view sinkData(sinkPtr->data(), sinkPtr->size()); - - auto readFile = std::make_shared(sinkData); - auto file = std::make_shared(std::move(readFile)); - - auto inputStream = std::make_unique( - std::move(file), - colChunkPtr.dataPageOffset(), - 150, - *leafPool_, - LogType::TEST); - auto pageReader = std::make_unique( - std::move(inputStream), - *leafPool_, - colChunkPtr.compression(), - colChunkPtr.totalCompressedSize()); - return pageReader->readPageHeader(); + auto* sinkPtr = write( + data, std::move(configFromFile), std::move(sessionProperties)); + return readPageHeader(sinkPtr, 0); }; // Test default config (i.e., no explicit config) @@ -573,7 +520,6 @@ TEST_F(ParquetWriterTest, testPageSizeAndBatchSizeConfiguration) { } TEST_F(ParquetWriterTest, toggleDataPageVersion) { - auto schema = ROW({"c0"}, {INTEGER()}); const int64_t kRows = 1; const auto data = makeRowVector({ makeFlatVector(kRows, [](auto row) { return 987; }), @@ -584,50 +530,9 @@ TEST_F(ParquetWriterTest, toggleDataPageVersion) { const auto testDataPageVersion = [&](std::unordered_map configFromFile, std::unordered_map sessionProperties) { - // Create an in-memory writer. - auto sink = std::make_unique( - 200 * 1024 * 1024, - dwio::common::FileSink::Options{.pool = leafPool_.get()}); - auto sinkPtr = sink.get(); - parquet::WriterOptions writerOptions; - writerOptions.memoryPool = leafPool_.get(); - - // Simulate setting of Hive config & connector session properties, then - // write test data. - auto connectorConfig = config::ConfigBase(std::move(configFromFile)); - auto connectorSessionProperties = - config::ConfigBase(std::move(sessionProperties)); - - writerOptions.processConfigs( - connectorConfig, connectorSessionProperties); - auto writer = std::make_unique( - std::move(sink), writerOptions, rootPool_, schema); - writer->write(data); - writer->close(); - - // Read to identify DataPage used. - dwio::common::ReaderOptions readerOptions{leafPool_.get()}; - auto reader = createReaderInMemory(*sinkPtr, readerOptions); - - auto colChunkPtr = reader->fileMetaData().rowGroup(0).columnChunk(0); - std::string_view sinkData(sinkPtr->data(), sinkPtr->size()); - - auto readFile = std::make_shared(sinkData); - auto file = std::make_shared(std::move(readFile)); - - auto inputStream = std::make_unique( - std::move(file), - colChunkPtr.dataPageOffset(), - 150, - *leafPool_, - LogType::TEST); - auto pageReader = std::make_unique( - std::move(inputStream), - *leafPool_, - colChunkPtr.compression(), - colChunkPtr.totalCompressedSize()); - - return pageReader->readPageHeader().type; + auto* sinkPtr = write( + data, std::move(configFromFile), std::move(sessionProperties)); + return readPageHeader(sinkPtr, 0).type; }; // Test default behavior - DataPage should be V1. @@ -702,24 +607,16 @@ DEBUG_ONLY_TEST_F(ParquetWriterTest, unitFromWriterOptions) { ASSERT_EQ(tsType->timezone(), "America/Los_Angeles"); }))); - const auto data = makeRowVector({makeFlatVector( - 10'000, [](auto row) { return Timestamp(row, row); })}); + const auto data = makeTimestampTestData(10'000); parquet::WriterOptions writerOptions; - writerOptions.memoryPool = leafPool_.get(); + writerOptions.memoryPool = rootPool_.get(); writerOptions.parquetWriteTimestampUnit = TimestampPrecision::kMicroseconds; writerOptions.parquetWriteTimestampTimeZone = "America/Los_Angeles"; - // Create an in-memory writer. - auto sink = std::make_unique( - 200 * 1024 * 1024, - dwio::common::FileSink::Options{.pool = leafPool_.get()}); - auto writer = std::make_unique( - std::move(sink), writerOptions, rootPool_, ROW({"c0"}, {TIMESTAMP()})); - writer->write(data); - writer->close(); -}; + write(data, writerOptions); +} -TEST_F(ParquetWriterTest, parquetWriteTimestampTimeZoneWithDefault) { +DEBUG_ONLY_TEST_F(ParquetWriterTest, parquetWriteTimestampTimeZoneWithDefault) { SCOPED_TESTVALUE_SET( "facebook::velox::parquet::Writer::write", std::function( @@ -731,25 +628,26 @@ TEST_F(ParquetWriterTest, parquetWriteTimestampTimeZoneWithDefault) { ASSERT_EQ(tsType->timezone(), ""); }))); - const auto data = makeRowVector({makeFlatVector( - 10'000, [](auto row) { return Timestamp(row, row); })}); + const auto data = makeTimestampTestData(10'000); parquet::WriterOptions writerOptions; - writerOptions.memoryPool = leafPool_.get(); + writerOptions.memoryPool = rootPool_.get(); writerOptions.parquetWriteTimestampUnit = TimestampPrecision::kMicroseconds; - // Create an in-memory writer. - auto sink = std::make_unique( - 200 * 1024 * 1024, - dwio::common::FileSink::Options{.pool = leafPool_.get()}); - auto writer = std::make_unique( - std::move(sink), writerOptions, rootPool_, ROW({"c0"}, {TIMESTAMP()})); - writer->write(data); - writer->close(); -}; + write(data, writerOptions); +} + +TEST_F(ParquetWriterTest, parquetWriteWithArrowMemoryPool) { + const auto data = makeTimestampTestData(10'000); + parquet::WriterOptions writerOptions; + writerOptions.memoryPool = rootPool_.get(); + writerOptions.arrowMemoryPool = std::make_shared(); + + write(data, writerOptions); +} TEST_F(ParquetWriterTest, updateWriterOptionsFromHiveConfig) { std::unordered_map configFromFile = { - {parquet::WriterOptions::kParquetSessionWriteTimestampUnit, "3"}}; + {parquet::WriterOptions::kParquetHiveConnectorWriteTimestampUnit, "3"}}; const config::ConfigBase connectorConfig(std::move(configFromFile)); const config::ConfigBase connectorSessionProperties({}); @@ -804,6 +702,91 @@ DEBUG_ONLY_TEST_F(ParquetWriterTest, timestampUnitAndTimeZone) { } #endif +TEST_F(ParquetWriterTest, dictionaryEncodedVector) { + const auto randomIndices = [this](vector_size_t size) { + BufferPtr indices = + AlignedBuffer::allocate(size, leafPool_.get()); + auto rawIndices = indices->asMutable(); + for (int32_t i = 0; i < size; i++) { + rawIndices[i] = folly::Random::rand32(size); + } + return indices; + }; + + const auto wrapDictionaryVectors = + [&](const std::vector& vectors) { + std::vector wrappedVectors; + wrappedVectors.reserve(vectors.size()); + + for (const auto& vector : vectors) { + auto wrappedVector = BaseVector::wrapInDictionary( + BufferPtr(nullptr), + randomIndices(vector->size()), + vector->size(), + vector); + EXPECT_EQ( + wrappedVector->encoding(), VectorEncoding::Simple::DICTIONARY); + wrappedVectors.emplace_back(wrappedVector); + } + return wrappedVectors; + }; + + // Dictionary encoded vectors with complex type. + const auto size = 10'000; + auto wrappedVectors = wrapDictionaryVectors({ + facebook::velox::test::BatchMaker::createVector( + MAP(VARCHAR(), INTEGER()), size, *leafPool_), + facebook::velox::test::BatchMaker::createVector( + ARRAY(VARCHAR()), size, *leafPool_), + facebook::velox::test::BatchMaker::createVector( + ROW({"c0", "c1"}, + {BIGINT(), ROW({"id", "name"}, {INTEGER(), VARCHAR()})}), + size, + *leafPool_), + }); + + auto data = makeRowVector(wrappedVectors); + write(data); + + // Dictionary encoded constant vector of scalar type. + const auto constantVector = makeConstant(static_cast(123'456), size); + const auto wrappedVector = std::make_shared>( + leafPool_.get(), nullptr, size, constantVector, randomIndices(size)); + EXPECT_EQ(wrappedVector->encoding(), VectorEncoding::Simple::DICTIONARY); + VELOX_CHECK_NOT_NULL(wrappedVector->valueVector()); + EXPECT_FALSE(wrappedVector->wrappedVector()->isFlatEncoding()); + + data = makeRowVector({wrappedVector}); + write(data); +} + +TEST_F(ParquetWriterTest, allNulls) { + auto schema = ROW({"c0"}, {INTEGER()}); + const int64_t kRows = 4096; + // Create a column with all elements being null. + auto nulls = makeNulls(kRows, [](auto /*row*/) { return true; }); + auto flatVector = std::make_shared>( + pool_.get(), + schema->childAt(0), + nulls, + kRows, + /*values=*/nullptr, + std::vector()); + auto data = std::make_shared( + pool_.get(), schema, nullptr, kRows, std::vector{flatVector}); + + auto* sinkPtr = write(data); + + dwio::common::ReaderOptions readerOptions{leafPool_.get()}; + auto reader = createReaderInMemory(*sinkPtr, readerOptions); + + ASSERT_EQ(reader->numberOfRows(), kRows); + ASSERT_EQ(*reader->rowType(), *schema); + + auto rowReader = createRowReaderWithSchema(std::move(reader), schema); + assertReadWithReaderAndExpected(schema, *rowReader, data, *leafPool_); +} + } // namespace int main(int argc, char** argv) { diff --git a/velox/dwio/parquet/thrift/CMakeLists.txt b/velox/dwio/parquet/thrift/CMakeLists.txt index cb0b6ea91881..ee6fe26e8c80 100644 --- a/velox/dwio/parquet/thrift/CMakeLists.txt +++ b/velox/dwio/parquet/thrift/CMakeLists.txt @@ -13,9 +13,4 @@ # limitations under the License. velox_add_library(velox_dwio_parquet_thrift ParquetThriftTypes.cpp) -velox_link_libraries( - velox_dwio_parquet_thrift - arrow - thrift - Boost::headers - fmt::fmt) +velox_link_libraries(velox_dwio_parquet_thrift arrow thrift Boost::headers fmt::fmt) diff --git a/velox/dwio/parquet/writer/CMakeLists.txt b/velox/dwio/parquet/writer/CMakeLists.txt index 6d2e42434c4b..b2bc9c96703a 100644 --- a/velox/dwio/parquet/writer/CMakeLists.txt +++ b/velox/dwio/parquet/writer/CMakeLists.txt @@ -22,5 +22,7 @@ velox_link_libraries( velox_dwio_arrow_parquet_writer_util_lib velox_dwio_common velox_arrow_bridge + velox_exec arrow - fmt::fmt) + fmt::fmt +) diff --git a/velox/dwio/parquet/writer/Writer.cpp b/velox/dwio/parquet/writer/Writer.cpp index 895f88a57ae9..053ee290c39a 100644 --- a/velox/dwio/parquet/writer/Writer.cpp +++ b/velox/dwio/parquet/writer/Writer.cpp @@ -22,6 +22,7 @@ #include "velox/common/config/Config.h" #include "velox/common/testutil/TestValue.h" #include "velox/core/QueryConfig.h" +#include "velox/dwio/parquet/writer/arrow/ArrowSchema.h" #include "velox/dwio/parquet/writer/arrow/Properties.h" #include "velox/dwio/parquet/writer/arrow/Writer.h" #include "velox/exec/MemoryReclaimer.h" @@ -164,12 +165,17 @@ std::shared_ptr getArrowParquetWriterOptions( properties = properties->data_page_version(arrow::ParquetDataPageVersion::V1); } + if (options.createdBy.has_value()) { + properties = properties->created_by(options.createdBy.value()); + } return properties->build(); } -void validateSchemaRecursive(const RowTypePtr& schema) { - // Check the schema's field names is not empty and unique. - VELOX_USER_CHECK_NOT_NULL(schema, "Field schema must not be empty."); +void validateSchemaRecursive( + const RowTypePtr& schema, + const std::vector& parquetFieldIds) { + // Check the schema's field names are not empty and unique. + VELOX_USER_CHECK_NOT_NULL(schema, "Schema must not be empty."); const auto& fieldNames = schema->names(); folly::F14FastSet uniqueNames; @@ -182,55 +188,100 @@ void validateSchemaRecursive(const RowTypePtr& schema) { name); } + if (!parquetFieldIds.empty()) { + VELOX_USER_CHECK_EQ(parquetFieldIds.size(), schema->size()); + } + for (auto i = 0; i < schema->size(); ++i) { - if (auto childSchema = - std::dynamic_pointer_cast(schema->childAt(i))) { - validateSchemaRecursive(childSchema); + const auto& childType = schema->childAt(i); + const auto& childFieldIds = + parquetFieldIds.empty() ? parquetFieldIds : parquetFieldIds[i].children; + + if (childType->isRow()) { + validateSchemaRecursive( + std::dynamic_pointer_cast(childType), childFieldIds); + } else if (childType->isArray()) { + if (!parquetFieldIds.empty()) { + VELOX_USER_CHECK_EQ(parquetFieldIds[i].children.size(), 1); + } + const auto& elementType = childType->asArray().elementType(); + if (elementType->isRow()) { + validateSchemaRecursive( + std::dynamic_pointer_cast(elementType), + childFieldIds.empty() ? childFieldIds : childFieldIds[0].children); + } + } else if (childType->isMap()) { + if (!parquetFieldIds.empty()) { + VELOX_USER_CHECK_EQ(parquetFieldIds[i].children.size(), 2); + } + const auto& mapType = childType->asMap(); + if (mapType.keyType()->isRow()) { + validateSchemaRecursive( + std::dynamic_pointer_cast(mapType.keyType()), + childFieldIds.empty() ? childFieldIds : childFieldIds[0].children); + } + if (mapType.valueType()->isRow()) { + validateSchemaRecursive( + std::dynamic_pointer_cast(mapType.valueType()), + childFieldIds.empty() ? childFieldIds : childFieldIds[1].children); + } } } } -std::shared_ptr<::arrow::Field> updateFieldNameRecursive( +std::shared_ptr<::arrow::Field> updateFieldNameAndIdRecursive( const std::shared_ptr<::arrow::Field>& field, const Type& type, + const ParquetFieldId* fieldId, const std::string& name = "") { + auto newField = name.empty() ? field : field->WithName(name); + + if (fieldId) { + newField = + newField->WithMetadata(arrow::arrow::fieldIdMetadata(fieldId->fieldId)); + } + if (type.isRow()) { auto& rowType = type.asRow(); - auto newField = field->WithName(name); auto structType = std::dynamic_pointer_cast<::arrow::StructType>(newField->type()); auto childrenSize = rowType.size(); + VELOX_CHECK(!fieldId || childrenSize <= fieldId->children.size()); std::vector> newFields; newFields.reserve(childrenSize); - for (auto i = 0; i < childrenSize; i++) { - newFields.push_back(updateFieldNameRecursive( - structType->fields()[i], *rowType.childAt(i), rowType.nameOf(i))); + for (auto i = 0; i < childrenSize; ++i) { + const auto* childSetting = fieldId ? &fieldId->children.at(i) : nullptr; + newFields.push_back(updateFieldNameAndIdRecursive( + structType->fields()[i], + *rowType.childAt(i), + childSetting, + rowType.nameOf(i))); } - return newField->WithType(::arrow::struct_(newFields)); + newField = newField->WithType(::arrow::struct_(newFields)); } else if (type.isArray()) { - auto newField = field->WithName(name); auto listType = std::dynamic_pointer_cast<::arrow::BaseListType>(newField->type()); auto elementType = type.asArray().elementType(); auto elementField = listType->value_field(); - return newField->WithType( - ::arrow::list(updateFieldNameRecursive(elementField, *elementType))); + const auto* childSetting = fieldId ? &fieldId->children.at(0) : nullptr; + auto updatedElementField = + updateFieldNameAndIdRecursive(elementField, *elementType, childSetting); + newField = newField->WithType(::arrow::list(updatedElementField)); } else if (type.isMap()) { auto mapType = type.asMap(); - auto newField = field->WithName(name); auto arrowMapType = std::dynamic_pointer_cast<::arrow::MapType>(newField->type()); - auto newKeyField = - updateFieldNameRecursive(arrowMapType->key_field(), *mapType.keyType()); - auto newValueField = updateFieldNameRecursive( - arrowMapType->item_field(), *mapType.valueType()); - return newField->WithType( - ::arrow::map(newKeyField->type(), newValueField->type())); - } else if (name != "") { - return field->WithName(name); - } else { - return field; + const auto* keySetting = fieldId ? &fieldId->children.at(0) : nullptr; + const auto* valueSetting = fieldId ? &fieldId->children.at(1) : nullptr; + auto newKeyField = updateFieldNameAndIdRecursive( + arrowMapType->key_field(), *mapType.keyType(), keySetting); + auto newValueField = updateFieldNameAndIdRecursive( + arrowMapType->item_field(), *mapType.valueType(), valueSetting); + newField = newField->WithType( + std::make_shared<::arrow::MapType>(newKeyField, newValueField)); } + + return newField; } std::optional getTimestampUnit( @@ -246,6 +297,17 @@ std::optional getTimestampUnit( return std::nullopt; } +// Converts a string to TimestampPrecision. Accepts numeric values "3" (milli), +// "6" (micro), or "9" (nano). +TimestampPrecision stringToTimestampPrecision(const std::string& value) { + auto unit = folly::to(value); + VELOX_CHECK( + unit == 3 /*milli*/ || unit == 6 /*micro*/ || unit == 9 /*nano*/, + "Invalid timestamp unit: {}", + unit); + return static_cast(unit); +} + std::optional getTimestampTimeZone( const config::ConfigBase& config, const char* configKey) { @@ -306,6 +368,15 @@ std::optional getParquetBatchSize( return std::nullopt; } +std::optional getParquetCreatedBy( + const config::ConfigBase& config, + const char* configKey) { + if (config.get(configKey).has_value()) { + return config.get(configKey).value(); + } + return std::nullopt; +} + } // namespace Writer::Writer( @@ -315,13 +386,14 @@ Writer::Writer( RowTypePtr schema) : pool_(std::move(pool)), generalPool_{pool_->addLeafChild(".general")}, - stream_(std::make_shared( - std::move(sink), - *generalPool_, - options.bufferGrowRatio)), + stream_( + std::make_shared( + std::move(sink), + *generalPool_, + options.bufferGrowRatio)), arrowContext_(std::make_shared()), schema_(std::move(schema)) { - validateSchemaRecursive(schema_); + validateSchemaRecursive(schema_, options.parquetFieldIds); if (options.flushPolicyFactory) { castUniquePointer(options.flushPolicyFactory(), flushPolicy_); @@ -338,6 +410,8 @@ Writer::Writer( getArrowParquetWriterOptions(options, flushPolicy_); setMemoryReclaimers(); writeInt96AsTimestamp_ = options.writeInt96AsTimestamp; + arrowMemoryPool_ = options.arrowMemoryPool; + parquetFieldIds_ = std::move(options.parquetFieldIds); } Writer::Writer( @@ -347,9 +421,10 @@ Writer::Writer( : Writer{ std::move(sink), options, - options.memoryPool->addAggregateChild(fmt::format( - "writer_node_{}", - folly::to(folly::Random::rand64()))), + options.memoryPool->addAggregateChild( + fmt::format( + "writer_node_{}", + folly::to(folly::Random::rand64()))), std::move(schema)} {} void Writer::flush() { @@ -364,7 +439,7 @@ void Writer::flush() { arrowContext_->writer, FileWriter::Open( *arrowContext_->schema.get(), - ::arrow::default_memory_pool(), + arrowMemoryPool_.get(), stream_, arrowContext_->properties, arrowProperties)); @@ -415,10 +490,15 @@ void Writer::write(const VectorPtr& data) { data->type()->equivalent(*schema_), "The file schema type should be equal with the input rowvector type."); + VectorPtr exportData = data; + if (needFlatten(exportData)) { + BaseVector::flattenVector(exportData); + } + ArrowArray array; ArrowSchema schema; - exportToArrow(data, array, generalPool_.get(), options_); - exportToArrow(data, schema, options_); + exportToArrow(exportData, array, generalPool_.get(), options_); + exportToArrow(exportData, schema, options_); // Convert the arrow schema to Schema and then update the column names based // on schema_. @@ -427,9 +507,15 @@ void Writer::write(const VectorPtr& data) { "facebook::velox::parquet::Writer::write", arrowSchema.get()); std::vector> newFields; auto childSize = schema_->size(); + if (!parquetFieldIds_.empty()) { + VELOX_CHECK(childSize == parquetFieldIds_.size()); + } for (auto i = 0; i < childSize; i++) { - newFields.push_back(updateFieldNameRecursive( - arrowSchema->fields()[i], *schema_->childAt(i), schema_->nameOf(i))); + newFields.push_back(updateFieldNameAndIdRecursive( + arrowSchema->fields()[i], + *schema_->childAt(i), + !parquetFieldIds_.empty() ? &parquetFieldIds_.at(i) : nullptr, + schema_->nameOf(i))); } PARQUET_ASSIGN_OR_THROW( @@ -502,6 +588,22 @@ void Writer::setMemoryReclaimers() { generalPool_->setReclaimer(exec::MemoryReclaimer::create()); } +bool Writer::needFlatten(const VectorPtr& data) const { + auto rowVector = std::dynamic_pointer_cast(data); + VELOX_CHECK_NOT_NULL( + rowVector, "Arrow export expects a RowVector as input data."); + + const auto& children = rowVector->children(); + return std::any_of(children.begin(), children.end(), [](const auto& child) { + bool isNestedWrapped = + (child->encoding() == VectorEncoding::Simple::DICTIONARY || + child->encoding() == VectorEncoding::Simple::CONSTANT) && + child->valueVector() && !child->wrappedVector()->isFlatEncoding(); + bool isComplex = !child->isScalar(); + return isNestedWrapped || isComplex; + }); +} + std::unique_ptr ParquetWriterFactory::createWriter( std::unique_ptr sink, const std::shared_ptr& options) { @@ -526,11 +628,30 @@ void WriterOptions::processConfigs( VELOX_CHECK_NOT_NULL( parquetWriterOptions, "Expected a Parquet WriterOptions object."); + // Check serdeParameters for timestamp settings first (highest priority). + auto serdeTimestampUnitIt = serdeParameters.find(kParquetSerdeTimestampUnit); + if (serdeTimestampUnitIt != serdeParameters.end()) { + parquetWriteTimestampUnit = + stringToTimestampPrecision(serdeTimestampUnitIt->second); + } + + auto serdeTimestampTimezoneIt = + serdeParameters.find(kParquetSerdeTimestampTimezone); + if (serdeTimestampTimezoneIt != serdeParameters.end()) { + // Empty string means no timezone conversion (nullopt). + if (serdeTimestampTimezoneIt->second.empty()) { + parquetWriteTimestampTimeZone = std::nullopt; + } else { + parquetWriteTimestampTimeZone = serdeTimestampTimezoneIt->second; + } + } + if (!parquetWriteTimestampUnit) { parquetWriteTimestampUnit = getTimestampUnit(session, kParquetSessionWriteTimestampUnit).has_value() ? getTimestampUnit(session, kParquetSessionWriteTimestampUnit) - : getTimestampUnit(connectorConfig, kParquetSessionWriteTimestampUnit); + : getTimestampUnit( + connectorConfig, kParquetHiveConnectorWriteTimestampUnit); } if (!parquetWriteTimestampTimeZone) { parquetWriteTimestampTimeZone = parquetWriterOptions->sessionTimezoneName; @@ -578,6 +699,11 @@ void WriterOptions::processConfigs( : getParquetBatchSize( connectorConfig, kParquetHiveConnectorWriteBatchSize); } + + if (!createdBy) { + createdBy = + getParquetCreatedBy(connectorConfig, kParquetHiveConnectorCreatedBy); + } } } // namespace facebook::velox::parquet diff --git a/velox/dwio/parquet/writer/Writer.h b/velox/dwio/parquet/writer/Writer.h index 460f01a7f240..922e56110820 100644 --- a/velox/dwio/parquet/writer/Writer.h +++ b/velox/dwio/parquet/writer/Writer.h @@ -16,6 +16,7 @@ #pragma once +#include "arrow/memory_pool.h" #include "velox/common/compression/Compression.h" #include "velox/common/config/Config.h" #include "velox/dwio/common/DataBuffer.h" @@ -87,6 +88,14 @@ class LambdaFlushPolicy : public DefaultFlushPolicy { std::function lambda_; }; +/// Parquet field IDs during write operations. Each ID must be unique positive +/// number, do not need to be sequential. +/// Used to explicitly control field ID assignment in the Parquet schema. +struct ParquetFieldId { + int32_t fieldId; + std::vector children; +}; + struct WriterOptions : public dwio::common::WriterOptions { // Growth ratio passed to ArrowDataBufferSink. The default value is a // heuristic borrowed from @@ -111,6 +120,15 @@ struct WriterOptions : public dwio::common::WriterOptions { std::optional dictionaryPageSizeLimit; std::optional enableDictionary; std::optional useParquetDataPageV2; + std::optional createdBy; + + std::shared_ptr arrowMemoryPool; + + /// Optional field IDs to assign to columns in the Parquet schema. + /// If provided, the writer will use these IDs for the schema fields. + /// If not provided, the field_id will be -1. + /// The structure should match the schema hierarchy with nested children. + std::vector parquetFieldIds; // Parsing session and hive configs. @@ -140,6 +158,17 @@ struct WriterOptions : public dwio::common::WriterOptions { "hive.parquet.writer.batch_size"; static constexpr const char* kParquetHiveConnectorWriteBatchSize = "hive.parquet.writer.batch-size"; + static constexpr const char* kParquetHiveConnectorCreatedBy = + "hive.parquet.writer.created-by"; + + // Serde parameter keys for timestamp settings. These can be set via + // serdeParameters map to override the default timestamp behavior. + // The timezone key accepts a timezone string or empty string to disable + // timezone conversion. + static constexpr const char* kParquetSerdeTimestampUnit = + "parquet.writer.timestamp.unit"; + static constexpr const char* kParquetSerdeTimestampTimezone = + "parquet.writer.timestamp.timezone"; // Process hive connector and session configs. void processConfigs( @@ -193,15 +222,23 @@ class Writer : public dwio::common::Writer { // Sets the memory reclaimers for all the memory pools used by this writer. void setMemoryReclaimers(); + // Checks if the input data contains a nested wrapped vector or complex + // vector. If so, flatten the input to make it compatible with + // 'exportFlattenedVector' in Arrow export. + bool needFlatten(const VectorPtr& data) const; + // Pool for 'stream_'. std::shared_ptr pool_; std::shared_ptr generalPool_; + std::shared_ptr arrowMemoryPool_; // Temporary Arrow stream for capturing the output. std::shared_ptr stream_; std::shared_ptr arrowContext_; + std::vector parquetFieldIds_; + std::unique_ptr flushPolicy_; const RowTypePtr schema_; diff --git a/velox/dwio/parquet/writer/arrow/ArrowSchema.cpp b/velox/dwio/parquet/writer/arrow/ArrowSchema.cpp index de5a2382198a..d4a7193a6646 100644 --- a/velox/dwio/parquet/writer/arrow/ArrowSchema.cpp +++ b/velox/dwio/parquet/writer/arrow/ArrowSchema.cpp @@ -346,14 +346,6 @@ static Status GetTimestampMetadata( static constexpr char FIELD_ID_KEY[] = "PARQUET:field_id"; -std::shared_ptr<::arrow::KeyValueMetadata> FieldIdMetadata(int field_id) { - if (field_id >= 0) { - return ::arrow::key_value_metadata({FIELD_ID_KEY}, {ToChars(field_id)}); - } else { - return nullptr; - } -} - int FieldIdFromMetadata( const std::shared_ptr& metadata) { if (!metadata) { @@ -676,7 +668,7 @@ Status GroupToStruct( node.name(), struct_type, node.is_optional(), - FieldIdMetadata(node.field_id())); + fieldIdMetadata(node.field_id())); out->level_info = current_levels; return Status::OK(); } @@ -761,14 +753,14 @@ Status MapToSchemaField( group.name(), ::arrow::struct_({key_field->field, value_field->field}), /*nullable=*/false, - FieldIdMetadata(key_value.field_id())); + fieldIdMetadata(key_value.field_id())); key_value_field->level_info = current_levels; out->field = ::arrow::field( group.name(), std::make_shared<::arrow::MapType>(key_value_field->field), group.is_optional(), - FieldIdMetadata(group.field_id())); + fieldIdMetadata(group.field_id())); out->level_info = current_levels; // At this point current levels contains the def level for this list, // we need to reset to the prior parent. @@ -854,7 +846,7 @@ Status ListToSchemaField( list_node.name(), type, /*nullable=*/false, - FieldIdMetadata(list_node.field_id())); + fieldIdMetadata(list_node.field_id())); RETURN_NOT_OK(PopulateLeaf( column_index, item_field, current_levels, ctx, out, child_field)); } @@ -862,7 +854,7 @@ Status ListToSchemaField( group.name(), ::arrow::list(child_field->field), group.is_optional(), - FieldIdMetadata(group.field_id())); + fieldIdMetadata(group.field_id())); out->level_info = current_levels; // At this point current levels contains the def level for this list, // we need to reset to the prior parent. @@ -898,7 +890,7 @@ Status GroupToSchemaField( node.name(), ::arrow::list(out->children[0].field), /*nullable=*/false, - FieldIdMetadata(node.field_id())); + fieldIdMetadata(node.field_id())); ctx->LinkParent(&out->children[0], out); out->level_info = current_levels; @@ -961,7 +953,7 @@ Status NodeToSchemaField( node.name(), ::arrow::list(child_field), /*nullable=*/false, - FieldIdMetadata(node.field_id())); + fieldIdMetadata(node.field_id())); out->level_info = current_levels; // At this point current_levels has consider this list the ancestor so // restore the actual ancestor. @@ -976,7 +968,7 @@ Status NodeToSchemaField( node.name(), type, node.is_optional(), - FieldIdMetadata(node.field_id())), + fieldIdMetadata(node.field_id())), current_levels, ctx, parent, @@ -1155,8 +1147,9 @@ Result ApplyOriginalStorageMetadata( // so no need to recurse on value types. const auto& dict_origin_type = checked_cast(*origin_type); - inferred->field = inferred->field->WithType(::arrow::dictionary( - ::arrow::int32(), inferred_type, dict_origin_type.ordered())); + inferred->field = inferred->field->WithType( + ::arrow::dictionary( + ::arrow::int32(), inferred_type, dict_origin_type.ordered())); modified = true; } @@ -1222,6 +1215,14 @@ Result ApplyOriginalMetadata( } // namespace +std::shared_ptr<::arrow::KeyValueMetadata> fieldIdMetadata(int field_id) { + if (field_id >= 0) { + return ::arrow::key_value_metadata({FIELD_ID_KEY}, {ToChars(field_id)}); + } else { + return nullptr; + } +} + Status FieldToNode( const std::shared_ptr& field, const WriterProperties& properties, diff --git a/velox/dwio/parquet/writer/arrow/ArrowSchema.h b/velox/dwio/parquet/writer/arrow/ArrowSchema.h index 8302bc1cdb19..394e51d339fe 100644 --- a/velox/dwio/parquet/writer/arrow/ArrowSchema.h +++ b/velox/dwio/parquet/writer/arrow/ArrowSchema.h @@ -195,5 +195,7 @@ struct PARQUET_EXPORT SchemaManifest { } }; +std::shared_ptr<::arrow::KeyValueMetadata> fieldIdMetadata(int32_t field_id); + } // namespace arrow } // namespace facebook::velox::parquet::arrow diff --git a/velox/dwio/parquet/writer/arrow/CMakeLists.txt b/velox/dwio/parquet/writer/arrow/CMakeLists.txt index fb0e4a4bb7df..b70c5f25e9b6 100644 --- a/velox/dwio/parquet/writer/arrow/CMakeLists.txt +++ b/velox/dwio/parquet/writer/arrow/CMakeLists.txt @@ -38,7 +38,8 @@ velox_add_library( Schema.cpp Statistics.cpp Types.cpp - Writer.cpp) + Writer.cpp +) velox_link_libraries( velox_dwio_arrow_parquet_writer_lib @@ -48,4 +49,5 @@ velox_link_libraries( velox_dwio_common velox_arrow_bridge arrow - fmt::fmt) + fmt::fmt +) diff --git a/velox/dwio/parquet/writer/arrow/ColumnWriter.cpp b/velox/dwio/parquet/writer/arrow/ColumnWriter.cpp index 55825c5957ce..320b1bc03786 100644 --- a/velox/dwio/parquet/writer/arrow/ColumnWriter.cpp +++ b/velox/dwio/parquet/writer/arrow/ColumnWriter.cpp @@ -628,12 +628,13 @@ class SerializedPageWriter : public PageWriter { void UpdateEncryption(int8_t module_type) { switch (module_type) { case encryption::kColumnMetaData: { - meta_encryptor_->UpdateAad(encryption::CreateModuleAad( - meta_encryptor_->file_aad(), - module_type, - row_group_ordinal_, - column_ordinal_, - kNonPageOrdinal)); + meta_encryptor_->UpdateAad( + encryption::CreateModuleAad( + meta_encryptor_->file_aad(), + module_type, + row_group_ordinal_, + column_ordinal_, + kNonPageOrdinal)); break; } case encryption::kDataPage: { @@ -647,21 +648,23 @@ class SerializedPageWriter : public PageWriter { break; } case encryption::kDictionaryPageHeader: { - meta_encryptor_->UpdateAad(encryption::CreateModuleAad( - meta_encryptor_->file_aad(), - module_type, - row_group_ordinal_, - column_ordinal_, - kNonPageOrdinal)); + meta_encryptor_->UpdateAad( + encryption::CreateModuleAad( + meta_encryptor_->file_aad(), + module_type, + row_group_ordinal_, + column_ordinal_, + kNonPageOrdinal)); break; } case encryption::kDictionaryPage: { - data_encryptor_->UpdateAad(encryption::CreateModuleAad( - data_encryptor_->file_aad(), - module_type, - row_group_ordinal_, - column_ordinal_, - kNonPageOrdinal)); + data_encryptor_->UpdateAad( + encryption::CreateModuleAad( + data_encryptor_->file_aad(), + module_type, + row_group_ordinal_, + column_ordinal_, + kNonPageOrdinal)); break; } default: @@ -1575,7 +1578,7 @@ class TypedColumnWriterImpl : public ColumnWriterImpl, batch_num_values, batch_num_spaced_values, bits_buffer_->data(), - /*offset=*/0, + /*valid_bits_offset=*/0, /*num_levels=*/batch_size, null_count); } else { @@ -1867,8 +1870,12 @@ class TypedColumnWriterImpl : public ColumnWriterImpl, if (array->data()->offset > 0) { RETURN_NOT_OK(util::VisitArrayInline(*array, &slicer, &buffers[1])); } - return ::arrow::MakeArray(std::make_shared( - array->type(), array->length(), std::move(buffers), new_null_count)); + return ::arrow::MakeArray( + std::make_shared( + array->type(), + array->length(), + std::move(buffers), + new_null_count)); } void WriteLevelsSpaced( diff --git a/velox/dwio/parquet/writer/arrow/Encoding.cpp b/velox/dwio/parquet/writer/arrow/Encoding.cpp index 14125875e389..f4a5990e5f4b 100644 --- a/velox/dwio/parquet/writer/arrow/Encoding.cpp +++ b/velox/dwio/parquet/writer/arrow/Encoding.cpp @@ -1315,7 +1315,7 @@ int PlainBooleanDecoder::DecodeArrow( null_count, [&]() { bool value; - ARROW_IGNORE_EXPR(bit_reader_->GetValue(1, &value)); + ((void)(bit_reader_->GetValue(1, &value))); builder->UnsafeAppend(value); }, [&]() { builder->UnsafeAppendNull(); }); @@ -1378,8 +1378,9 @@ struct ArrowBinaryHelper { Status Prepare(std::optional estimated_data_length = {}) { RETURN_NOT_OK(acc_->builder->Reserve(entries_remaining_)); if (estimated_data_length.has_value()) { - RETURN_NOT_OK(acc_->builder->ReserveData(std::min( - *estimated_data_length, ::arrow::kBinaryMemoryLimit))); + RETURN_NOT_OK(acc_->builder->ReserveData( + std::min( + *estimated_data_length, ::arrow::kBinaryMemoryLimit))); } return Status::OK(); } @@ -1392,8 +1393,9 @@ struct ArrowBinaryHelper { RETURN_NOT_OK(PushChunk()); RETURN_NOT_OK(acc_->builder->Reserve(entries_remaining_)); if (estimated_remaining_data_length.has_value()) { - RETURN_NOT_OK(acc_->builder->ReserveData(std::min( - *estimated_remaining_data_length, chunk_space_remaining_))); + RETURN_NOT_OK(acc_->builder->ReserveData( + std::min( + *estimated_remaining_data_length, chunk_space_remaining_))); } } return Status::OK(); @@ -1737,7 +1739,7 @@ class DictDecoderImpl : public DecoderImpl, virtual public DictDecoder { num_values_ = num_values; if (len == 0) { // Initialize dummy decoder to avoid crashes later on - idx_decoder_ = RleDecoder(data, len, /*bit_width=*/1); + idx_decoder_ = RleDecoder(data, len, /*bitWidth=*/1); return; } uint8_t bit_width = *data; @@ -3330,13 +3332,14 @@ class RleBooleanEncoder final : public EncoderImpl, buffered_append_values_.push_back(boolean_array.Value(i)); } } else { - PARQUET_THROW_NOT_OK(::arrow::VisitArraySpanInline<::arrow::BooleanType>( - *boolean_array.data(), - [&](bool value) { - buffered_append_values_.push_back(value); - return Status::OK(); - }, - []() { return Status::OK(); })); + PARQUET_THROW_NOT_OK( + ::arrow::VisitArraySpanInline<::arrow::BooleanType>( + *boolean_array.data(), + [&](bool value) { + buffered_append_values_.push_back(value); + return Status::OK(); + }, + []() { return Status::OK(); })); } } diff --git a/velox/dwio/parquet/writer/arrow/Encryption.h b/velox/dwio/parquet/writer/arrow/Encryption.h index df310589a433..82fd6f104083 100644 --- a/velox/dwio/parquet/writer/arrow/Encryption.h +++ b/velox/dwio/parquet/writer/arrow/Encryption.h @@ -300,8 +300,9 @@ class PARQUET_EXPORT FileDecryptionProperties { /// invocation of the retriever callback. /// If an explicit key is available for a footer or a column, /// its key metadata will be ignored. - Builder* column_keys(const ColumnPathToDecryptionPropertiesMap& - column_decryption_properties); + Builder* column_keys( + const ColumnPathToDecryptionPropertiesMap& + column_decryption_properties); /// Set a key retriever callback. Its also possible to /// set explicit footer or column keys on this file property object. diff --git a/velox/dwio/parquet/writer/arrow/Exception.h b/velox/dwio/parquet/writer/arrow/Exception.h index 927df3407464..b71cfc926aba 100644 --- a/velox/dwio/parquet/writer/arrow/Exception.h +++ b/velox/dwio/parquet/writer/arrow/Exception.h @@ -154,9 +154,10 @@ class ParquetInvalidOrCorruptedFileException : public ParquetStatusException { int>::type = 0, typename... Args> explicit ParquetInvalidOrCorruptedFileException(Arg arg, Args&&... args) - : ParquetStatusException(::arrow::Status::Invalid( - std::forward(arg), - std::forward(args)...)) {} + : ParquetStatusException( + ::arrow::Status::Invalid( + std::forward(arg), + std::forward(args)...)) {} }; template diff --git a/velox/dwio/parquet/writer/arrow/FileEncryptorInternal.cpp b/velox/dwio/parquet/writer/arrow/FileEncryptorInternal.cpp index 6192156c7b27..cf02fa99c679 100644 --- a/velox/dwio/parquet/writer/arrow/FileEncryptorInternal.cpp +++ b/velox/dwio/parquet/writer/arrow/FileEncryptorInternal.cpp @@ -172,8 +172,9 @@ encryption::AesEncryptor* InternalFileEncryptor::GetMetaAesEncryptor( int key_len = static_cast(key_size); int index = MapKeyLenToEncryptorArrayIndex(key_len); if (meta_encryptor_[index] == nullptr) { - meta_encryptor_[index].reset(encryption::AesEncryptor::Make( - algorithm, key_len, true, &all_encryptors_)); + meta_encryptor_[index].reset( + encryption::AesEncryptor::Make( + algorithm, key_len, true, &all_encryptors_)); } return meta_encryptor_[index].get(); } @@ -184,8 +185,9 @@ encryption::AesEncryptor* InternalFileEncryptor::GetDataAesEncryptor( int key_len = static_cast(key_size); int index = MapKeyLenToEncryptorArrayIndex(key_len); if (data_encryptor_[index] == nullptr) { - data_encryptor_[index].reset(encryption::AesEncryptor::Make( - algorithm, key_len, false, &all_encryptors_)); + data_encryptor_[index].reset( + encryption::AesEncryptor::Make( + algorithm, key_len, false, &all_encryptors_)); } return data_encryptor_[index].get(); } diff --git a/velox/dwio/parquet/writer/arrow/FileWriter.cpp b/velox/dwio/parquet/writer/arrow/FileWriter.cpp index ea89c4838eba..1123db370a09 100644 --- a/velox/dwio/parquet/writer/arrow/FileWriter.cpp +++ b/velox/dwio/parquet/writer/arrow/FileWriter.cpp @@ -492,8 +492,9 @@ class FileSerializer : public ParquetFileWriter::Contents { return AppendRowGroup(true); } - void AddKeyValueMetadata(const std::shared_ptr& - key_value_metadata) override { + void AddKeyValueMetadata( + const std::shared_ptr& key_value_metadata) + override { if (key_value_metadata_ == nullptr) { key_value_metadata_ = key_value_metadata; } else if (key_value_metadata != nullptr) { diff --git a/velox/dwio/parquet/writer/arrow/Metadata.cpp b/velox/dwio/parquet/writer/arrow/Metadata.cpp index ef978a01a747..a56c40e5fdf9 100644 --- a/velox/dwio/parquet/writer/arrow/Metadata.cpp +++ b/velox/dwio/parquet/writer/arrow/Metadata.cpp @@ -182,10 +182,10 @@ std::unique_ptr ColumnCryptoMetaData::Make( } ColumnCryptoMetaData::ColumnCryptoMetaData(const uint8_t* metadata) - : impl_(std::make_unique( - reinterpret_cast< - const facebook::velox::parquet::thrift::ColumnCryptoMetaData*>( - metadata))) {} + : impl_( + std::make_unique( + reinterpret_cast(metadata))) {} ColumnCryptoMetaData::~ColumnCryptoMetaData() = default; @@ -389,6 +389,10 @@ class ColumnChunkMetaData::ColumnChunkMetaDataImpl { return std::nullopt; } + inline int32_t field_id() const { + return descr_->schema_node()->field_id(); + } + private: mutable std::shared_ptr possible_stats_; std::vector encodings_; @@ -535,6 +539,10 @@ int64_t ColumnChunkMetaData::total_compressed_size() const { return impl_->total_compressed_size(); } +int32_t ColumnChunkMetaData::field_id() const { + return impl_->field_id(); +} + std::unique_ptr ColumnChunkMetaData::crypto_metadata() const { return impl_->crypto_metadata(); @@ -1021,8 +1029,9 @@ class FileMetaData::FileMetaDataImpl { if (metadata_->schema.empty()) { throw ParquetException("Empty file schema (no root)"); } - schema_.Init(schema::Unflatten( - &metadata_->schema[0], static_cast(metadata_->schema.size()))); + schema_.Init( + schema::Unflatten( + &metadata_->schema[0], static_cast(metadata_->schema.size()))); } void InitColumnOrders() { @@ -1745,8 +1754,8 @@ class ColumnChunkMetaDataBuilder::ColumnChunkMetaDataBuilderImpl { [&thrift_encodings]( facebook::velox::parquet::thrift::Encoding::type value) { auto it = std::find( - thrift_encodings.begin(), thrift_encodings.end(), value); - if (it == thrift_encodings.end()) { + thrift_encodings.cbegin(), thrift_encodings.cend(), value); + if (it == thrift_encodings.cend()) { thrift_encodings.push_back(value); } }; diff --git a/velox/dwio/parquet/writer/arrow/Metadata.h b/velox/dwio/parquet/writer/arrow/Metadata.h index c69ee5a03d41..7cb7670a0387 100644 --- a/velox/dwio/parquet/writer/arrow/Metadata.h +++ b/velox/dwio/parquet/writer/arrow/Metadata.h @@ -187,6 +187,7 @@ class PARQUET_EXPORT ColumnChunkMetaData { int64_t index_page_offset() const; int64_t total_compressed_size() const; int64_t total_uncompressed_size() const; + int32_t field_id() const; std::unique_ptr crypto_metadata() const; std::optional GetColumnIndexLocation() const; std::optional GetOffsetIndexLocation() const; diff --git a/velox/dwio/parquet/writer/arrow/PageIndex.cpp b/velox/dwio/parquet/writer/arrow/PageIndex.cpp index 465d6e1afecd..2af2a6535a43 100644 --- a/velox/dwio/parquet/writer/arrow/PageIndex.cpp +++ b/velox/dwio/parquet/writer/arrow/PageIndex.cpp @@ -201,10 +201,11 @@ class OffsetIndexImpl : public OffsetIndex { const facebook::velox::parquet::thrift::OffsetIndex& offset_index) { page_locations_.reserve(offset_index.page_locations.size()); for (const auto& page_location : offset_index.page_locations) { - page_locations_.emplace_back(PageLocation{ - page_location.offset, - page_location.compressed_page_size, - page_location.first_row_index}); + page_locations_.emplace_back( + PageLocation{ + page_location.offset, + page_location.compressed_page_size, + page_location.first_row_index}); } } diff --git a/velox/dwio/parquet/writer/arrow/PathInternal.cpp b/velox/dwio/parquet/writer/arrow/PathInternal.cpp index 885a8254edd4..f43736067334 100644 --- a/velox/dwio/parquet/writer/arrow/PathInternal.cpp +++ b/velox/dwio/parquet/writer/arrow/PathInternal.cpp @@ -888,10 +888,10 @@ class PathBuilder { return VisitInline(*array.storage()); } -#define NOT_IMPLEMENTED_VISIT(ArrowTypePrefix) \ - Status Visit(const ::arrow::ArrowTypePrefix##Array& array) { \ - return Status::NotImplemented("Level generation for " #ArrowTypePrefix \ - " not supported yet"); \ +#define NOT_IMPLEMENTED_VISIT(ArrowTypePrefix) \ + Status Visit(const ::arrow::ArrowTypePrefix##Array& array) { \ + return Status::NotImplemented( \ + "Level generation for " #ArrowTypePrefix " not supported yet"); \ } // Types not yet supported in Parquet. diff --git a/velox/dwio/parquet/writer/arrow/Properties.h b/velox/dwio/parquet/writer/arrow/Properties.h index 356d88066079..815e2c7d6f5f 100644 --- a/velox/dwio/parquet/writer/arrow/Properties.h +++ b/velox/dwio/parquet/writer/arrow/Properties.h @@ -36,6 +36,8 @@ // Define the parquet created by version. #define CREATED_BY_VERSION "parquet-cpp-velox" +// Velox has no versioning yet. Set default 0.0.0. +#define VELOX_VERSION "0.0.0" namespace facebook::velox::parquet::arrow { @@ -326,8 +328,7 @@ class PARQUET_EXPORT WriterProperties { version_(ParquetVersion::PARQUET_2_6), data_page_version_(ParquetDataPageVersion::V1), created_by_( - std::string("parquet-mr version 2.6.0 (build ") + - DEFAULT_CREATED_BY + ")"), + DEFAULT_CREATED_BY + std::string(" version ") + VELOX_VERSION), store_decimal_as_integer_(false), page_checksum_enabled_(false) {} virtual ~Builder() {} diff --git a/velox/dwio/parquet/writer/arrow/Statistics.cpp b/velox/dwio/parquet/writer/arrow/Statistics.cpp index 757ac36b62e8..4a3ef281befd 100644 --- a/velox/dwio/parquet/writer/arrow/Statistics.cpp +++ b/velox/dwio/parquet/writer/arrow/Statistics.cpp @@ -94,10 +94,16 @@ struct CompareHelper { "T is an unsigned numeric"); constexpr static T DefaultMin() { + if constexpr (std::is_floating_point_v) { + return std::numeric_limits::infinity(); + } return std::numeric_limits::max(); } constexpr static T DefaultMax() { - return std::numeric_limits::lowest(); + if constexpr (std::is_floating_point_v) { + return -std::numeric_limits::infinity(); + } + return std::numeric_limits::min(); } // MSVC17 fix, isnan is not overloaded for IntegralType as per C++11 @@ -361,6 +367,7 @@ CleanStatistic(std::pair min_max) { // In case of floating point types, the following rules are applied (as per // upstream parquet-mr): // - If any of min/max is NaN, return nothing. +// - If min is infinity and max is -infinity, return nothing. // - If min is 0.0f, replace with -0.0f // - If max is -0.0f, replace with 0.0f template @@ -375,8 +382,8 @@ ::arrow:: return ::std::nullopt; } - if (min == std::numeric_limits::max() && - max == std::numeric_limits::lowest()) { + if (min == std::numeric_limits::infinity() && + max == -std::numeric_limits::infinity()) { return ::std::nullopt; } diff --git a/velox/dwio/parquet/writer/arrow/Types.cpp b/velox/dwio/parquet/writer/arrow/Types.cpp index 929f664f3097..9f1baec35125 100644 --- a/velox/dwio/parquet/writer/arrow/Types.cpp +++ b/velox/dwio/parquet/writer/arrow/Types.cpp @@ -771,8 +771,10 @@ class LogicalType::Impl::Compatible : public virtual LogicalType::Impl { } \ } -#define reset_decimal_metadata(m___) \ - { set_decimal_metadata(m___, false, -1, -1); } +#define reset_decimal_metadata(m___) \ + { \ + set_decimal_metadata(m___, false, -1, -1); \ + } // For logical types that always translate to the same converted type class LogicalType::Impl::SimpleCompatible diff --git a/velox/dwio/parquet/writer/arrow/Types.h b/velox/dwio/parquet/writer/arrow/Types.h index 2645a883a4df..24727b1c05cf 100644 --- a/velox/dwio/parquet/writer/arrow/Types.h +++ b/velox/dwio/parquet/writer/arrow/Types.h @@ -603,10 +603,6 @@ inline bool operator==(const SortingColumn& left, const SortingColumn& right) { left.column_idx == right.column_idx; } -inline bool operator!=(const SortingColumn& left, const SortingColumn& right) { - return !(left == right); -} - // ---------------------------------------------------------------------- struct ByteArray { @@ -631,10 +627,6 @@ inline bool operator==(const ByteArray& left, const ByteArray& right) { (left.len == 0 || std::memcmp(left.ptr, right.ptr, left.len) == 0); } -inline bool operator!=(const ByteArray& left, const ByteArray& right) { - return !(left == right); -} - struct FixedLenByteArray { FixedLenByteArray() : ptr(NULLPTR) {} explicit FixedLenByteArray(const uint8_t* ptr) : ptr(ptr) {} @@ -665,10 +657,6 @@ inline bool operator==(const Int96& left, const Int96& right) { return std::equal(left.value, left.value + 3, right.value); } -inline bool operator!=(const Int96& left, const Int96& right) { - return !(left == right); -} - static inline std::string ByteArrayToString(const ByteArray& a) { return std::string(reinterpret_cast(a.ptr), a.len); } diff --git a/velox/dwio/parquet/writer/arrow/Writer.cpp b/velox/dwio/parquet/writer/arrow/Writer.cpp index e6572bb764dc..f3a101a02810 100644 --- a/velox/dwio/parquet/writer/arrow/Writer.cpp +++ b/velox/dwio/parquet/writer/arrow/Writer.cpp @@ -480,12 +480,13 @@ class FileWriterImpl : public FileWriter { if (arrow_properties_->use_threads()) { VELOX_DCHECK_EQ(parallel_column_write_contexts_.size(), writers.size()); - RETURN_NOT_OK(::arrow::internal::ParallelFor( - static_cast(writers.size()), - [&](int i) { - return writers[i]->Write(¶llel_column_write_contexts_[i]); - }, - arrow_properties_->executor())); + RETURN_NOT_OK( + ::arrow::internal::ParallelFor( + static_cast(writers.size()), + [&](int i) { + return writers[i]->Write(¶llel_column_write_contexts_[i]); + }, + arrow_properties_->executor())); } return Status::OK(); @@ -658,16 +659,18 @@ Result> FileWriter::Open( Status WriteFileMetaData( const FileMetaData& file_metadata, ::arrow::io::OutputStream* sink) { - PARQUET_CATCH_NOT_OK(::facebook::velox::parquet::arrow::WriteFileMetaData( - file_metadata, sink)); + PARQUET_CATCH_NOT_OK( + ::facebook::velox::parquet::arrow::WriteFileMetaData( + file_metadata, sink)); return Status::OK(); } Status WriteMetaDataFile( const FileMetaData& file_metadata, ::arrow::io::OutputStream* sink) { - PARQUET_CATCH_NOT_OK(::facebook::velox::parquet::arrow::WriteMetaDataFile( - file_metadata, sink)); + PARQUET_CATCH_NOT_OK( + ::facebook::velox::parquet::arrow::WriteMetaDataFile( + file_metadata, sink)); return Status::OK(); } diff --git a/velox/dwio/parquet/writer/arrow/tests/BloomFilter.cpp b/velox/dwio/parquet/writer/arrow/tests/BloomFilter.cpp index eeae731f9d0f..866aa8ada61f 100644 --- a/velox/dwio/parquet/writer/arrow/tests/BloomFilter.cpp +++ b/velox/dwio/parquet/writer/arrow/tests/BloomFilter.cpp @@ -31,8 +31,6 @@ namespace facebook::velox::parquet::arrow { -constexpr uint32_t BlockSplitBloomFilter::SALT[kBitsSetPerBlock]; - BlockSplitBloomFilter::BlockSplitBloomFilter(::arrow::MemoryPool* pool) : pool_(pool), hash_strategy_(HashStrategy::XXHASH), diff --git a/velox/dwio/parquet/writer/arrow/tests/BloomFilterTest.cpp b/velox/dwio/parquet/writer/arrow/tests/BloomFilterTest.cpp index 644fda5a678f..03b56d5c9bc9 100644 --- a/velox/dwio/parquet/writer/arrow/tests/BloomFilterTest.cpp +++ b/velox/dwio/parquet/writer/arrow/tests/BloomFilterTest.cpp @@ -425,10 +425,12 @@ TYPED_TEST(TestBatchBloomFilter, Basic) { // Bloom filter fpp parameter const double fpp = 0.05; - filter.Init(BlockSplitBloomFilter::OptimalNumOfBytes( - TestFixture::kTestDataSize, fpp)); - batch_insert_filter.Init(BlockSplitBloomFilter::OptimalNumOfBytes( - TestFixture::kTestDataSize, fpp)); + filter.Init( + BlockSplitBloomFilter::OptimalNumOfBytes( + TestFixture::kTestDataSize, fpp)); + batch_insert_filter.Init( + BlockSplitBloomFilter::OptimalNumOfBytes( + TestFixture::kTestDataSize, fpp)); std::vector hashes; for (const Type& value : test_data) { diff --git a/velox/dwio/parquet/writer/arrow/tests/CMakeLists.txt b/velox/dwio/parquet/writer/arrow/tests/CMakeLists.txt index c8c6e24c33fa..5906a75516d4 100644 --- a/velox/dwio/parquet/writer/arrow/tests/CMakeLists.txt +++ b/velox/dwio/parquet/writer/arrow/tests/CMakeLists.txt @@ -25,10 +25,10 @@ add_executable( PropertiesTest.cpp SchemaTest.cpp StatisticsTest.cpp - TypesTest.cpp) + TypesTest.cpp +) -add_test(velox_dwio_arrow_parquet_writer_test - velox_dwio_arrow_parquet_writer_test) +add_test(velox_dwio_arrow_parquet_writer_test velox_dwio_arrow_parquet_writer_test) target_link_libraries( velox_dwio_arrow_parquet_writer_test @@ -39,7 +39,8 @@ target_link_libraries( arrow arrow_testing velox_dwio_native_parquet_reader - velox_temp_path) + velox_temp_path +) add_library( velox_dwio_arrow_parquet_writer_test_lib @@ -49,8 +50,12 @@ add_library( ColumnScanner.cpp FileReader.cpp TestUtil.cpp - XxHasher.cpp) + XxHasher.cpp +) target_link_libraries( - velox_dwio_arrow_parquet_writer_test_lib arrow - velox_dwio_arrow_parquet_writer_lib GTest::gtest) + velox_dwio_arrow_parquet_writer_test_lib + arrow + velox_dwio_arrow_parquet_writer_lib + GTest::gtest +) diff --git a/velox/dwio/parquet/writer/arrow/tests/ColumnReaderTest.cpp b/velox/dwio/parquet/writer/arrow/tests/ColumnReaderTest.cpp index e22c8fcad530..0deb1540244c 100644 --- a/velox/dwio/parquet/writer/arrow/tests/ColumnReaderTest.cpp +++ b/velox/dwio/parquet/writer/arrow/tests/ColumnReaderTest.cpp @@ -506,7 +506,7 @@ TEST_F(TestPrimitiveReader, TestReadValuesMissing) { std::shared_ptr data_page = MakeDataPage( &descr, values, - /*num_values=*/2, + /*num_vals=*/2, Encoding::PLAIN, /*indices=*/{}, /*indices_size=*/0, @@ -600,7 +600,7 @@ TEST_F(TestPrimitiveReader, TestReadValuesMissingWithDictionary) { std::shared_ptr data_page = MakeDataPage( &descr, {}, - /*num_values=*/2, + /*num_vals=*/2, Encoding::RLE_DICTIONARY, /*indices=*/{}, /*indices_size=*/0, @@ -915,7 +915,7 @@ TEST_P(RecordReaderPrimitiveTypeTest, ReadRequired) { std::shared_ptr page = MakeDataPage( descr_, values, - /*num_values=*/static_cast(def_levels.size()), + /*num_vals=*/static_cast(def_levels.size()), Encoding::PLAIN, /*indices=*/{}, /*indices_size=*/0, @@ -985,7 +985,7 @@ TEST_P(RecordReaderPrimitiveTypeTest, ReadOptional) { std::shared_ptr page = MakeDataPage( descr_, values, - /*num_values=*/static_cast(def_levels.size()), + /*num_vals=*/static_cast(def_levels.size()), Encoding::PLAIN, /*indices=*/{}, /*indices_size=*/0, @@ -1113,7 +1113,7 @@ TEST_P(RecordReaderPrimitiveTypeTest, ReadRequiredRepeated) { std::shared_ptr page = MakeDataPage( descr_, values, - /*num_values=*/static_cast(def_levels.size()), + /*num_vals=*/static_cast(def_levels.size()), Encoding::PLAIN, /*indices=*/{}, /*indices_size=*/0, @@ -1191,7 +1191,7 @@ TEST_P(RecordReaderPrimitiveTypeTest, ReadNullableRepeated) { std::shared_ptr page = MakeDataPage( descr_, values, - /*num_values=*/static_cast(def_levels.size()), + /*num_vals=*/static_cast(def_levels.size()), Encoding::PLAIN, /*indices=*/{}, /*indices_size=*/0, @@ -1320,7 +1320,7 @@ TEST_P(RecordReaderPrimitiveTypeTest, SkipRequiredTopLevel) { std::shared_ptr page = MakeDataPage( descr_, values, - /*num_values=*/static_cast(values.size()), + /*num_vals=*/static_cast(values.size()), Encoding::PLAIN, /*indices=*/{}, /*indices_size=*/0, @@ -1371,7 +1371,7 @@ TEST_P(RecordReaderPrimitiveTypeTest, SkipOptional) { std::shared_ptr page = MakeDataPage( descr_, values, - /*num_values=*/static_cast(values.size()), + /*num_vals=*/static_cast(values.size()), Encoding::PLAIN, /*indices=*/{}, /*indices_size=*/0, @@ -1485,7 +1485,7 @@ TEST_P(RecordReaderPrimitiveTypeTest, SkipRepeated) { std::shared_ptr page = MakeDataPage( descr_, values, - /*num_values=*/static_cast(values.size()), + /*num_vals=*/static_cast(values.size()), Encoding::PLAIN, /*indices=*/{}, /*indices_size=*/0, @@ -1590,7 +1590,7 @@ TEST_P(RecordReaderPrimitiveTypeTest, SkipRepeatedConsumeBufferFirst) { std::shared_ptr page = MakeDataPage( descr_, values, - /*num_values=*/static_cast(values.size()), + /*num_vals=*/static_cast(values.size()), Encoding::PLAIN, /*indices=*/{}, /*indices_size=*/0, @@ -1647,7 +1647,7 @@ TEST_P(RecordReaderPrimitiveTypeTest, ReadPartialRecord) { std::shared_ptr page = MakeDataPage( descr_, /*values=*/{10, 20, 20, 20}, - /*num_values=*/4, + /*num_vals=*/4, Encoding::PLAIN, /*indices=*/{}, /*indices_size=*/0, @@ -1663,7 +1663,7 @@ TEST_P(RecordReaderPrimitiveTypeTest, ReadPartialRecord) { std::shared_ptr page = MakeDataPage( descr_, /*values=*/{20, 20}, - /*num_values=*/2, + /*num_vals=*/2, Encoding::PLAIN, /*indices=*/{}, /*indices_size=*/0, @@ -1679,7 +1679,7 @@ TEST_P(RecordReaderPrimitiveTypeTest, ReadPartialRecord) { std::shared_ptr page = MakeDataPage( descr_, /*values=*/{20, 30}, - /*num_values=*/2, + /*num_vals=*/2, Encoding::PLAIN, /*indices=*/{}, /*indices_size=*/0, @@ -1751,7 +1751,7 @@ TEST_P(RecordReaderPrimitiveTypeTest, SkipPartialRecord) { std::shared_ptr page = MakeDataPage( descr_, /*values=*/{10, 20, 20, 20}, - /*num_values=*/4, + /*num_vals=*/4, Encoding::PLAIN, /*indices=*/{}, /*indices_size=*/0, @@ -1767,7 +1767,7 @@ TEST_P(RecordReaderPrimitiveTypeTest, SkipPartialRecord) { std::shared_ptr page = MakeDataPage( descr_, /*values=*/{20, 20}, - /*num_values=*/2, + /*num_vals=*/2, Encoding::PLAIN, /*indices=*/{}, /*indices_size=*/0, @@ -1783,7 +1783,7 @@ TEST_P(RecordReaderPrimitiveTypeTest, SkipPartialRecord) { std::shared_ptr page = MakeDataPage( descr_, /*values=*/{20, 30}, - /*num_values=*/2, + /*num_vals=*/2, Encoding::PLAIN, /*indices=*/{}, /*indices_size=*/0, diff --git a/velox/dwio/parquet/writer/arrow/tests/ColumnWriterTest.cpp b/velox/dwio/parquet/writer/arrow/tests/ColumnWriterTest.cpp index 4d9b10f35686..cb9f85196f5d 100644 --- a/velox/dwio/parquet/writer/arrow/tests/ColumnWriterTest.cpp +++ b/velox/dwio/parquet/writer/arrow/tests/ColumnWriterTest.cpp @@ -226,7 +226,7 @@ class TestPrimitiveWriter : public PrimitiveTypedTest { ASSERT_EQ(this->values_, this->values_out_); std::vector encodings_vector = this->metadata_encodings(); std::set encodings( - encodings_vector.begin(), encodings_vector.end()); + encodings_vector.cbegin(), encodings_vector.cend()); if (this->type_num() == Type::BOOLEAN) { // Dictionary encoding is not allowed for boolean type diff --git a/velox/dwio/parquet/writer/arrow/tests/EncodingTest.cpp b/velox/dwio/parquet/writer/arrow/tests/EncodingTest.cpp index b18c9bb7dc45..47d501aef905 100644 --- a/velox/dwio/parquet/writer/arrow/tests/EncodingTest.cpp +++ b/velox/dwio/parquet/writer/arrow/tests/EncodingTest.cpp @@ -1706,7 +1706,7 @@ class DictDecoderImpl : public DecoderImpl, virtual public DictDecoder { num_values_ = num_values; if (len == 0) { // Initialize dummy decoder to avoid crashes later on - idx_decoder_ = RleDecoder(data, len, /*bit_width=*/1); + idx_decoder_ = RleDecoder(data, len, /*bitWidth=*/1); return; } uint8_t bit_width = *data; @@ -3295,13 +3295,14 @@ class RleBooleanEncoder final : public EncoderImpl, buffered_append_values_.push_back(boolean_array.Value(i)); } } else { - PARQUET_THROW_NOT_OK(::arrow::VisitArraySpanInline<::arrow::BooleanType>( - *boolean_array.data(), - [&](bool value) { - buffered_append_values_.push_back(value); - return Status::OK(); - }, - []() { return Status::OK(); })); + PARQUET_THROW_NOT_OK( + ::arrow::VisitArraySpanInline<::arrow::BooleanType>( + *boolean_array.data(), + [&](bool value) { + buffered_append_values_.push_back(value); + return Status::OK(); + }, + []() { return Status::OK(); })); } } @@ -3411,7 +3412,7 @@ class RleBooleanDecoder : public DecoderImpl, virtual public BooleanDecoder { num_bytes, /*bit_width=*/1); } else { - decoder_->Reset(decoder_data, num_bytes, /*bit_width=*/1); + decoder_->Reset(decoder_data, num_bytes, /*bitWidth=*/1); } } diff --git a/velox/dwio/parquet/writer/arrow/tests/FileDeserializeTest.cpp b/velox/dwio/parquet/writer/arrow/tests/FileDeserializeTest.cpp index 9c1447391584..e4beb4ef077b 100644 --- a/velox/dwio/parquet/writer/arrow/tests/FileDeserializeTest.cpp +++ b/velox/dwio/parquet/writer/arrow/tests/FileDeserializeTest.cpp @@ -57,8 +57,9 @@ AddDummyStats(int stat_size, H& header, bool fill_all_stats = false) { std::string(reinterpret_cast(stat_bytes.data()), stat_size)); if (fill_all_stats) { - header.statistics.__set_min(std::string( - reinterpret_cast(stat_bytes.data()), stat_size)); + header.statistics.__set_min( + std::string( + reinterpret_cast(stat_bytes.data()), stat_size)); header.statistics.__set_null_count(42); header.statistics.__set_distinct_count(1); } diff --git a/velox/dwio/parquet/writer/arrow/tests/MetadataTest.cpp b/velox/dwio/parquet/writer/arrow/tests/MetadataTest.cpp index 32aac3599582..a0d0d2e5cb55 100644 --- a/velox/dwio/parquet/writer/arrow/tests/MetadataTest.cpp +++ b/velox/dwio/parquet/writer/arrow/tests/MetadataTest.cpp @@ -428,6 +428,9 @@ TEST(Metadata, TestAddKeyValueMetadata) { } // Verify keys that were added after file writer was closed are not present. EXPECT_FALSE(reader->fileMetaData().keyValueMetadataContains("test_key_4")); + ASSERT_EQ( + CREATED_BY_VERSION + std::string(" version ") + VELOX_VERSION, + reader->fileMetaData().createdBy()); } // TODO: disabled as they require Arrow parquet data dir. @@ -502,10 +505,12 @@ TEST(Metadata, TestSortingColumns) { sortingColumns.push_back(sortingColumn); } + auto createdBy = CREATED_BY_VERSION + std::string(" version 1.0"); auto sink = CreateOutputStream(); auto writerProps = WriterProperties::Builder() .disable_dictionary() ->set_sorting_columns(sortingColumns) + ->created_by(createdBy) ->build(); EXPECT_EQ(sortingColumns, writerProps->sorting_columns()); @@ -537,6 +542,7 @@ TEST(Metadata, TestSortingColumns) { EXPECT_EQ(sortingColumns[0].column_idx, rowGroup.sortingColumnIdx(0)); EXPECT_EQ(sortingColumns[0].descending, rowGroup.sortingColumnDescending(0)); EXPECT_EQ(sortingColumns[0].nulls_first, rowGroup.sortingColumnNullsFirst(0)); + ASSERT_EQ(createdBy, reader->fileMetaData().createdBy()); } TEST(ApplicationVersion, Basics) { diff --git a/velox/dwio/parquet/writer/arrow/tests/PageIndexTest.cpp b/velox/dwio/parquet/writer/arrow/tests/PageIndexTest.cpp index fd175e2a8213..28f13a40f588 100644 --- a/velox/dwio/parquet/writer/arrow/tests/PageIndexTest.cpp +++ b/velox/dwio/parquet/writer/arrow/tests/PageIndexTest.cpp @@ -270,10 +270,11 @@ TEST(PageIndex, WriteOffsetIndex) { auto sink = CreateOutputStream(); builder->WriteTo(sink.get()); PARQUET_ASSIGN_OR_THROW(auto buffer, sink->Finish()); - offset_indexes.emplace_back(OffsetIndex::Make( - buffer->data(), - static_cast(buffer->size()), - default_reader_properties())); + offset_indexes.emplace_back( + OffsetIndex::Make( + buffer->data(), + static_cast(buffer->size()), + default_reader_properties())); /// Verify the data of the offset index. for (const auto& offset_index : offset_indexes) { @@ -309,11 +310,12 @@ void TestWriteTypedColumnIndex( auto sink = CreateOutputStream(); builder->WriteTo(sink.get()); PARQUET_ASSIGN_OR_THROW(auto buffer, sink->Finish()); - column_indexes.emplace_back(ColumnIndex::Make( - *descr, - buffer->data(), - static_cast(buffer->size()), - default_reader_properties())); + column_indexes.emplace_back( + ColumnIndex::Make( + *descr, + buffer->data(), + static_cast(buffer->size()), + default_reader_properties())); /// Verify the data of the column index. for (const auto& column_index : column_indexes) { @@ -666,10 +668,9 @@ TEST_F(PageIndexBuilderTest, SingleRowGroup) { EncodedStatistics().set_null_count(0).set_min("A").set_max("B")}}; const std::vector> page_locations = { /*row_group_id=0*/ - {/*column_id=0*/ { - /*offset=*/128, - /*compressed_page_size=*/512, - /*first_row_index=*/0}, + {/*column_id=0*/ {/*offset=*/128, + /*compressed_page_size=*/512, + /*first_row_index=*/0}, /*column_id=1*/ {/*offset=*/1024, /*compressed_page_size=*/512, @@ -713,19 +714,17 @@ TEST_F(PageIndexBuilderTest, TwoRowGroups) { EncodedStatistics().set_null_count(0).set_min("bar").set_max("foo")}}; const std::vector> page_locations = { /*row_group_id=0*/ - {/*column_id=0*/ { - /*offset=*/128, - /*compressed_page_size=*/512, - /*first_row_index=*/0}, + {/*column_id=0*/ {/*offset=*/128, + /*compressed_page_size=*/512, + /*first_row_index=*/0}, /*column_id=1*/ {/*offset=*/1024, /*compressed_page_size=*/512, /*first_row_index=*/0}}, /*row_group_id=0*/ - {/*column_id=0*/ { - /*offset=*/128, - /*compressed_page_size=*/512, - /*first_row_index=*/0}, + {/*column_id=0*/ {/*offset=*/128, + /*compressed_page_size=*/512, + /*first_row_index=*/0}, /*column_id=1*/ {/*offset=*/1024, /*compressed_page_size=*/512, diff --git a/velox/dwio/parquet/writer/arrow/tests/SchemaTest.cpp b/velox/dwio/parquet/writer/arrow/tests/SchemaTest.cpp index f312ec7e7b47..2c8980898220 100644 --- a/velox/dwio/parquet/writer/arrow/tests/SchemaTest.cpp +++ b/velox/dwio/parquet/writer/arrow/tests/SchemaTest.cpp @@ -310,19 +310,22 @@ TEST_F(TestPrimitiveNode, Equals) { } TEST_F(TestPrimitiveNode, PhysicalLogicalMapping) { - ASSERT_NO_THROW(PrimitiveNode::Make( - "foo", Repetition::REQUIRED, Type::INT32, ConvertedType::INT_32)); - ASSERT_NO_THROW(PrimitiveNode::Make( - "foo", Repetition::REQUIRED, Type::BYTE_ARRAY, ConvertedType::JSON)); + ASSERT_NO_THROW( + PrimitiveNode::Make( + "foo", Repetition::REQUIRED, Type::INT32, ConvertedType::INT_32)); + ASSERT_NO_THROW( + PrimitiveNode::Make( + "foo", Repetition::REQUIRED, Type::BYTE_ARRAY, ConvertedType::JSON)); ASSERT_THROW( PrimitiveNode::Make( "foo", Repetition::REQUIRED, Type::INT32, ConvertedType::JSON), ParquetException); - ASSERT_NO_THROW(PrimitiveNode::Make( - "foo", - Repetition::REQUIRED, - Type::INT64, - ConvertedType::TIMESTAMP_MILLIS)); + ASSERT_NO_THROW( + PrimitiveNode::Make( + "foo", + Repetition::REQUIRED, + Type::INT64, + ConvertedType::TIMESTAMP_MILLIS)); ASSERT_THROW( PrimitiveNode::Make( "foo", Repetition::REQUIRED, Type::INT32, ConvertedType::INT_64), @@ -345,8 +348,9 @@ TEST_F(TestPrimitiveNode, PhysicalLogicalMapping) { Type::FIXED_LEN_BYTE_ARRAY, ConvertedType::ENUM), ParquetException); - ASSERT_NO_THROW(PrimitiveNode::Make( - "foo", Repetition::REQUIRED, Type::BYTE_ARRAY, ConvertedType::ENUM)); + ASSERT_NO_THROW( + PrimitiveNode::Make( + "foo", Repetition::REQUIRED, Type::BYTE_ARRAY, ConvertedType::ENUM)); ASSERT_THROW( PrimitiveNode::Make( "foo", @@ -407,20 +411,22 @@ TEST_F(TestPrimitiveNode, PhysicalLogicalMapping) { 2, 4), ParquetException); - ASSERT_NO_THROW(PrimitiveNode::Make( - "foo", - Repetition::REQUIRED, - Type::FIXED_LEN_BYTE_ARRAY, - ConvertedType::DECIMAL, - 10, - 6, - 4)); - ASSERT_NO_THROW(PrimitiveNode::Make( - "foo", - Repetition::REQUIRED, - Type::FIXED_LEN_BYTE_ARRAY, - ConvertedType::INTERVAL, - 12)); + ASSERT_NO_THROW( + PrimitiveNode::Make( + "foo", + Repetition::REQUIRED, + Type::FIXED_LEN_BYTE_ARRAY, + ConvertedType::DECIMAL, + 10, + 6, + 4)); + ASSERT_NO_THROW( + PrimitiveNode::Make( + "foo", + Repetition::REQUIRED, + Type::FIXED_LEN_BYTE_ARRAY, + ConvertedType::INTERVAL, + 12)); ASSERT_THROW( PrimitiveNode::Make( "foo", @@ -602,10 +608,12 @@ TEST_F(TestSchemaConverter, NestedExample) { // 3-level list encoding NodePtr item = Int64("item", Repetition::OPTIONAL, 4); - NodePtr list(GroupNode::Make( - "b", Repetition::REPEATED, {item}, ConvertedType::LIST, 3)); - NodePtr bag(GroupNode::Make( - "bag", Repetition::OPTIONAL, {list}, /*logical_type=*/nullptr, 2)); + NodePtr list( + GroupNode::Make( + "b", Repetition::REPEATED, {item}, ConvertedType::LIST, 3)); + NodePtr bag( + GroupNode::Make( + "bag", Repetition::OPTIONAL, {list}, /*logical_type=*/nullptr, 2)); fields.push_back(bag); NodePtr schema = GroupNode::Make( @@ -751,14 +759,16 @@ TEST_F(TestSchemaFlatten, NestedExample) { // 3-level list encoding NodePtr item = Int64("item", Repetition::OPTIONAL, 4); - NodePtr list(GroupNode::Make( - "b", Repetition::REPEATED, {item}, ConvertedType::LIST, 3)); - NodePtr bag(GroupNode::Make( - "bag", - Repetition::OPTIONAL, - {list}, - /*logical_type=*/nullptr, - 2)); + NodePtr list( + GroupNode::Make( + "b", Repetition::REPEATED, {item}, ConvertedType::LIST, 3)); + NodePtr bag( + GroupNode::Make( + "bag", + Repetition::OPTIONAL, + {list}, + /*logical_type=*/nullptr, + 2)); fields.push_back(bag); NodePtr schema = GroupNode::Make( @@ -853,11 +863,12 @@ TEST_F(TestSchemaDescriptor, Equals) { NodePtr item1 = Int64("item1", Repetition::REQUIRED); NodePtr item2 = Boolean("item2", Repetition::OPTIONAL); NodePtr item3 = Int32("item3", Repetition::REPEATED); - NodePtr list(GroupNode::Make( - "records", - Repetition::REPEATED, - {item1, item2, item3}, - ConvertedType::LIST)); + NodePtr list( + GroupNode::Make( + "records", + Repetition::REPEATED, + {item1, item2, item3}, + ConvertedType::LIST)); NodePtr bag(GroupNode::Make("bag", Repetition::OPTIONAL, {list})); NodePtr bag2(GroupNode::Make("bag", Repetition::REQUIRED, {list})); @@ -869,13 +880,15 @@ TEST_F(TestSchemaDescriptor, Equals) { ASSERT_TRUE(descr1.Equals(descr1)); SchemaDescriptor descr2; - descr2.Init(GroupNode::Make( - "schema", Repetition::REPEATED, {inta, intb, intc, bag2})); + descr2.Init( + GroupNode::Make( + "schema", Repetition::REPEATED, {inta, intb, intc, bag2})); ASSERT_FALSE(descr1.Equals(descr2)); SchemaDescriptor descr3; - descr3.Init(GroupNode::Make( - "schema", Repetition::REPEATED, {inta, intb2, intc, bag})); + descr3.Init( + GroupNode::Make( + "schema", Repetition::REPEATED, {inta, intb2, intc, bag})); ASSERT_FALSE(descr1.Equals(descr3)); // Robust to name of parent node @@ -885,8 +898,9 @@ TEST_F(TestSchemaDescriptor, Equals) { ASSERT_TRUE(descr1.Equals(descr4)); SchemaDescriptor descr5; - descr5.Init(GroupNode::Make( - "schema", Repetition::REPEATED, {inta, intb, intc, bag, intb2})); + descr5.Init( + GroupNode::Make( + "schema", Repetition::REPEATED, {inta, intb, intc, bag, intb2})); ASSERT_FALSE(descr1.Equals(descr5)); // Different max repetition / definition levels @@ -912,11 +926,12 @@ TEST_F(TestSchemaDescriptor, BuildTree) { NodePtr item1 = Int64("item1", Repetition::REQUIRED); NodePtr item2 = Boolean("item2", Repetition::OPTIONAL); NodePtr item3 = Int32("item3", Repetition::REPEATED); - NodePtr list(GroupNode::Make( - "records", - Repetition::REPEATED, - {item1, item2, item3}, - ConvertedType::LIST)); + NodePtr list( + GroupNode::Make( + "records", + Repetition::REPEATED, + {item1, item2, item3}, + ConvertedType::LIST)); NodePtr bag(GroupNode::Make("bag", Repetition::OPTIONAL, {list})); fields.push_back(bag); @@ -994,11 +1009,12 @@ TEST_F(TestSchemaDescriptor, HasRepeatedFields) { NodePtr item1 = Int64("item1", Repetition::REQUIRED); NodePtr item2 = Boolean("item2", Repetition::OPTIONAL); NodePtr item3 = Int32("item3", Repetition::REPEATED); - NodePtr list(GroupNode::Make( - "records", - Repetition::REPEATED, - {item1, item2, item3}, - ConvertedType::LIST)); + NodePtr list( + GroupNode::Make( + "records", + Repetition::REPEATED, + {item1, item2, item3}, + ConvertedType::LIST)); NodePtr bag(GroupNode::Make("bag", Repetition::OPTIONAL, {list})); fields.push_back(bag); @@ -1009,8 +1025,12 @@ TEST_F(TestSchemaDescriptor, HasRepeatedFields) { // 3-level list encoding NodePtr item_key = Int64("key", Repetition::REQUIRED); NodePtr item_value = Boolean("value", Repetition::OPTIONAL); - NodePtr map(GroupNode::Make( - "map", Repetition::REPEATED, {item_key, item_value}, ConvertedType::MAP)); + NodePtr map( + GroupNode::Make( + "map", + Repetition::REPEATED, + {item_key, item_value}, + ConvertedType::MAP)); NodePtr my_map(GroupNode::Make("my_map", Repetition::OPTIONAL, {map})); fields.push_back(my_map); @@ -1034,29 +1054,33 @@ TEST(TestSchemaPrinter, Examples) { // 3-level list encoding NodePtr item1 = Int64("item1", Repetition::OPTIONAL, 4); NodePtr item2 = Boolean("item2", Repetition::REQUIRED, 5); - NodePtr list(GroupNode::Make( - "b", Repetition::REPEATED, {item1, item2}, ConvertedType::LIST, 3)); - NodePtr bag(GroupNode::Make( - "bag", Repetition::OPTIONAL, {list}, /*logical_type=*/nullptr, 2)); + NodePtr list( + GroupNode::Make( + "b", Repetition::REPEATED, {item1, item2}, ConvertedType::LIST, 3)); + NodePtr bag( + GroupNode::Make( + "bag", Repetition::OPTIONAL, {list}, /*logical_type=*/nullptr, 2)); fields.push_back(bag); - fields.push_back(PrimitiveNode::Make( - "c", - Repetition::REQUIRED, - Type::INT32, - ConvertedType::DECIMAL, - -1, - 3, - 2, - 6)); + fields.push_back( + PrimitiveNode::Make( + "c", + Repetition::REQUIRED, + Type::INT32, + ConvertedType::DECIMAL, + -1, + 3, + 2, + 6)); - fields.push_back(PrimitiveNode::Make( - "d", - Repetition::REQUIRED, - DecimalLogicalType::Make(10, 5), - Type::INT64, - /*length=*/-1, - 7)); + fields.push_back( + PrimitiveNode::Make( + "d", + Repetition::REQUIRED, + DecimalLogicalType::Make(10, 5), + Type::INT64, + /*length=*/-1, + 7)); NodePtr schema = GroupNode::Make( "schema", @@ -2059,86 +2083,100 @@ TEST(TestSchemaNodeCreation, FactoryExceptions) { // create an object if compatibility conditions are not met // Nested logical type on non-group node ... - ASSERT_ANY_THROW(PrimitiveNode::Make( - "map", Repetition::REQUIRED, MapLogicalType::Make(), Type::INT64)); + ASSERT_ANY_THROW( + PrimitiveNode::Make( + "map", Repetition::REQUIRED, MapLogicalType::Make(), Type::INT64)); // Incompatible primitive type ... - ASSERT_ANY_THROW(PrimitiveNode::Make( - "string", - Repetition::REQUIRED, - StringLogicalType::Make(), - Type::BOOLEAN)); + ASSERT_ANY_THROW( + PrimitiveNode::Make( + "string", + Repetition::REQUIRED, + StringLogicalType::Make(), + Type::BOOLEAN)); // Incompatible primitive length ... - ASSERT_ANY_THROW(PrimitiveNode::Make( - "interval", - Repetition::REQUIRED, - IntervalLogicalType::Make(), - Type::FIXED_LEN_BYTE_ARRAY, - 11)); + ASSERT_ANY_THROW( + PrimitiveNode::Make( + "interval", + Repetition::REQUIRED, + IntervalLogicalType::Make(), + Type::FIXED_LEN_BYTE_ARRAY, + 11)); // Scale is greater than precision. - ASSERT_ANY_THROW(PrimitiveNode::Make( - "decimal", - Repetition::REQUIRED, - DecimalLogicalType::Make(10, 11), - Type::INT64)); - ASSERT_ANY_THROW(PrimitiveNode::Make( - "decimal", - Repetition::REQUIRED, - DecimalLogicalType::Make(17, 18), - Type::INT64)); + ASSERT_ANY_THROW( + PrimitiveNode::Make( + "decimal", + Repetition::REQUIRED, + DecimalLogicalType::Make(10, 11), + Type::INT64)); + ASSERT_ANY_THROW( + PrimitiveNode::Make( + "decimal", + Repetition::REQUIRED, + DecimalLogicalType::Make(17, 18), + Type::INT64)); // Primitive too small for given precision ... - ASSERT_ANY_THROW(PrimitiveNode::Make( - "decimal", - Repetition::REQUIRED, - DecimalLogicalType::Make(16, 6), - Type::INT32)); - ASSERT_ANY_THROW(PrimitiveNode::Make( - "decimal", - Repetition::REQUIRED, - DecimalLogicalType::Make(10, 9), - Type::INT32)); - ASSERT_ANY_THROW(PrimitiveNode::Make( - "decimal", - Repetition::REQUIRED, - DecimalLogicalType::Make(19, 17), - Type::INT64)); - ASSERT_ANY_THROW(PrimitiveNode::Make( - "decimal", - Repetition::REQUIRED, - DecimalLogicalType::Make(308, 6), - Type::FIXED_LEN_BYTE_ARRAY, - 128)); + ASSERT_ANY_THROW( + PrimitiveNode::Make( + "decimal", + Repetition::REQUIRED, + DecimalLogicalType::Make(16, 6), + Type::INT32)); + ASSERT_ANY_THROW( + PrimitiveNode::Make( + "decimal", + Repetition::REQUIRED, + DecimalLogicalType::Make(10, 9), + Type::INT32)); + ASSERT_ANY_THROW( + PrimitiveNode::Make( + "decimal", + Repetition::REQUIRED, + DecimalLogicalType::Make(19, 17), + Type::INT64)); + ASSERT_ANY_THROW( + PrimitiveNode::Make( + "decimal", + Repetition::REQUIRED, + DecimalLogicalType::Make(308, 6), + Type::FIXED_LEN_BYTE_ARRAY, + 128)); // Length is too long - ASSERT_ANY_THROW(PrimitiveNode::Make( - "decimal", - Repetition::REQUIRED, - DecimalLogicalType::Make(10, 6), - Type::FIXED_LEN_BYTE_ARRAY, - 891723283)); + ASSERT_ANY_THROW( + PrimitiveNode::Make( + "decimal", + Repetition::REQUIRED, + DecimalLogicalType::Make(10, 6), + Type::FIXED_LEN_BYTE_ARRAY, + 891723283)); // Incompatible primitive length ... - ASSERT_ANY_THROW(PrimitiveNode::Make( - "uuid", - Repetition::REQUIRED, - UUIDLogicalType::Make(), - Type::FIXED_LEN_BYTE_ARRAY, - 64)); + ASSERT_ANY_THROW( + PrimitiveNode::Make( + "uuid", + Repetition::REQUIRED, + UUIDLogicalType::Make(), + Type::FIXED_LEN_BYTE_ARRAY, + 64)); // Non-positive length argument for fixed length binary ... - ASSERT_ANY_THROW(PrimitiveNode::Make( - "negative_length", - Repetition::REQUIRED, - NoLogicalType::Make(), - Type::FIXED_LEN_BYTE_ARRAY, - -16)); + ASSERT_ANY_THROW( + PrimitiveNode::Make( + "negative_length", + Repetition::REQUIRED, + NoLogicalType::Make(), + Type::FIXED_LEN_BYTE_ARRAY, + -16)); // Non-positive length argument for fixed length binary ... - ASSERT_ANY_THROW(PrimitiveNode::Make( - "zero_length", - Repetition::REQUIRED, - NoLogicalType::Make(), - Type::FIXED_LEN_BYTE_ARRAY, - 0)); + ASSERT_ANY_THROW( + PrimitiveNode::Make( + "zero_length", + Repetition::REQUIRED, + NoLogicalType::Make(), + Type::FIXED_LEN_BYTE_ARRAY, + 0)); // Non-nested logical type on group node ... - ASSERT_ANY_THROW(GroupNode::Make( - "list", Repetition::REPEATED, {}, JSONLogicalType::Make())); + ASSERT_ANY_THROW( + GroupNode::Make( + "list", Repetition::REPEATED, {}, JSONLogicalType::Make())); // nullptr logical type arguments convert to NoLogicalType/ConvertedType::NONE std::shared_ptr empty; diff --git a/velox/dwio/parquet/writer/arrow/tests/StatisticsTest.cpp b/velox/dwio/parquet/writer/arrow/tests/StatisticsTest.cpp index 85603e04b15f..14814ba42c39 100644 --- a/velox/dwio/parquet/writer/arrow/tests/StatisticsTest.cpp +++ b/velox/dwio/parquet/writer/arrow/tests/StatisticsTest.cpp @@ -464,7 +464,7 @@ class TestStatistics : public PrimitiveTypedTest { std::vector definitionLevels(batchNullCount, 0); definitionLevels.insert( definitionLevels.end(), batchNumValues - batchNullCount, 1); - auto beg = this->values_.begin() + i * numValues / 2; + auto beg = this->values_.cbegin() + i * numValues / 2; auto end = beg + batchNumValues; std::vector batch = GetDeepCopy(std::vector(beg, end)); c_type* batchValuesPtr = GetValuesPointer(batch); @@ -892,23 +892,29 @@ TEST(CorruptStatistics, Basics) { schema::NodePtr node; std::vector fields; // Test Physical Types - fields.push_back(schema::PrimitiveNode::Make( - "col1", Repetition::OPTIONAL, Type::INT32, ConvertedType::NONE)); - fields.push_back(schema::PrimitiveNode::Make( - "col2", Repetition::OPTIONAL, Type::BYTE_ARRAY, ConvertedType::NONE)); + fields.push_back( + schema::PrimitiveNode::Make( + "col1", Repetition::OPTIONAL, Type::INT32, ConvertedType::NONE)); + fields.push_back( + schema::PrimitiveNode::Make( + "col2", Repetition::OPTIONAL, Type::BYTE_ARRAY, ConvertedType::NONE)); // Test Logical Types - fields.push_back(schema::PrimitiveNode::Make( - "col3", Repetition::OPTIONAL, Type::INT32, ConvertedType::DATE)); - fields.push_back(schema::PrimitiveNode::Make( - "col4", Repetition::OPTIONAL, Type::INT32, ConvertedType::UINT_32)); - fields.push_back(schema::PrimitiveNode::Make( - "col5", - Repetition::OPTIONAL, - Type::FIXED_LEN_BYTE_ARRAY, - ConvertedType::INTERVAL, - 12)); - fields.push_back(schema::PrimitiveNode::Make( - "col6", Repetition::OPTIONAL, Type::BYTE_ARRAY, ConvertedType::UTF8)); + fields.push_back( + schema::PrimitiveNode::Make( + "col3", Repetition::OPTIONAL, Type::INT32, ConvertedType::DATE)); + fields.push_back( + schema::PrimitiveNode::Make( + "col4", Repetition::OPTIONAL, Type::INT32, ConvertedType::UINT_32)); + fields.push_back( + schema::PrimitiveNode::Make( + "col5", + Repetition::OPTIONAL, + Type::FIXED_LEN_BYTE_ARRAY, + ConvertedType::INTERVAL, + 12)); + fields.push_back( + schema::PrimitiveNode::Make( + "col6", Repetition::OPTIONAL, Type::BYTE_ARRAY, ConvertedType::UTF8)); node = schema::GroupNode::Make("schema", Repetition::REQUIRED, fields); schema.Init(node); @@ -932,23 +938,29 @@ TEST(CorrectStatistics, Basics) { schema::NodePtr node; std::vector fields; // Test Physical Types - fields.push_back(schema::PrimitiveNode::Make( - "col1", Repetition::OPTIONAL, Type::INT32, ConvertedType::NONE)); - fields.push_back(schema::PrimitiveNode::Make( - "col2", Repetition::OPTIONAL, Type::BYTE_ARRAY, ConvertedType::NONE)); + fields.push_back( + schema::PrimitiveNode::Make( + "col1", Repetition::OPTIONAL, Type::INT32, ConvertedType::NONE)); + fields.push_back( + schema::PrimitiveNode::Make( + "col2", Repetition::OPTIONAL, Type::BYTE_ARRAY, ConvertedType::NONE)); // Test Logical Types - fields.push_back(schema::PrimitiveNode::Make( - "col3", Repetition::OPTIONAL, Type::INT32, ConvertedType::DATE)); - fields.push_back(schema::PrimitiveNode::Make( - "col4", Repetition::OPTIONAL, Type::INT32, ConvertedType::UINT_32)); - fields.push_back(schema::PrimitiveNode::Make( - "col5", - Repetition::OPTIONAL, - Type::FIXED_LEN_BYTE_ARRAY, - ConvertedType::INTERVAL, - 12)); - fields.push_back(schema::PrimitiveNode::Make( - "col6", Repetition::OPTIONAL, Type::BYTE_ARRAY, ConvertedType::UTF8)); + fields.push_back( + schema::PrimitiveNode::Make( + "col3", Repetition::OPTIONAL, Type::INT32, ConvertedType::DATE)); + fields.push_back( + schema::PrimitiveNode::Make( + "col4", Repetition::OPTIONAL, Type::INT32, ConvertedType::UINT_32)); + fields.push_back( + schema::PrimitiveNode::Make( + "col5", + Repetition::OPTIONAL, + Type::FIXED_LEN_BYTE_ARRAY, + ConvertedType::INTERVAL, + 12)); + fields.push_back( + schema::PrimitiveNode::Make( + "col6", Repetition::OPTIONAL, Type::BYTE_ARRAY, ConvertedType::UTF8)); node = schema::GroupNode::Make("schema", Repetition::REQUIRED, fields); schema.Init(node); @@ -973,8 +985,12 @@ class TestStatisticsSortOrder : public ::testing::Test { using c_type = typename TestType::c_type; void AddNodes(std::string name) { - fields_.push_back(schema::PrimitiveNode::Make( - name, Repetition::REQUIRED, TestType::type_num, ConvertedType::NONE)); + fields_.push_back( + schema::PrimitiveNode::Make( + name, + Repetition::REQUIRED, + TestType::type_num, + ConvertedType::NONE)); } void SetUpSchema() { @@ -1053,11 +1069,13 @@ using CompareTestTypes = ::testing:: template <> void TestStatisticsSortOrder::AddNodes(std::string name) { // UINT_32 logical type to set Unsigned Statistics - fields_.push_back(schema::PrimitiveNode::Make( - name, Repetition::REQUIRED, Type::INT32, ConvertedType::UINT_32)); + fields_.push_back( + schema::PrimitiveNode::Make( + name, Repetition::REQUIRED, Type::INT32, ConvertedType::UINT_32)); // INT_32 logical type to set Signed Statistics - fields_.push_back(schema::PrimitiveNode::Make( - name, Repetition::REQUIRED, Type::INT32, ConvertedType::INT_32)); + fields_.push_back( + schema::PrimitiveNode::Make( + name, Repetition::REQUIRED, Type::INT32, ConvertedType::INT_32)); } template <> @@ -1068,28 +1086,34 @@ void TestStatisticsSortOrder::SetValues() { // Write UINT32 min/max values stats_[0] - .set_min(std::string( - reinterpret_cast(&values_[5]), sizeof(c_type))) - .set_max(std::string( - reinterpret_cast(&values_[4]), sizeof(c_type))); + .set_min( + std::string( + reinterpret_cast(&values_[5]), sizeof(c_type))) + .set_max( + std::string( + reinterpret_cast(&values_[4]), sizeof(c_type))); // Write INT32 min/max values stats_[1] - .set_min(std::string( - reinterpret_cast(&values_[0]), sizeof(c_type))) - .set_max(std::string( - reinterpret_cast(&values_[9]), sizeof(c_type))); + .set_min( + std::string( + reinterpret_cast(&values_[0]), sizeof(c_type))) + .set_max( + std::string( + reinterpret_cast(&values_[9]), sizeof(c_type))); } // TYPE::INT64 template <> void TestStatisticsSortOrder::AddNodes(std::string name) { // UINT_64 logical type to set Unsigned Statistics - fields_.push_back(schema::PrimitiveNode::Make( - name, Repetition::REQUIRED, Type::INT64, ConvertedType::UINT_64)); + fields_.push_back( + schema::PrimitiveNode::Make( + name, Repetition::REQUIRED, Type::INT64, ConvertedType::UINT_64)); // INT_64 logical type to set Signed Statistics - fields_.push_back(schema::PrimitiveNode::Make( - name, Repetition::REQUIRED, Type::INT64, ConvertedType::INT_64)); + fields_.push_back( + schema::PrimitiveNode::Make( + name, Repetition::REQUIRED, Type::INT64, ConvertedType::INT_64)); } template <> @@ -1100,17 +1124,21 @@ void TestStatisticsSortOrder::SetValues() { // Write UINT64 min/max values stats_[0] - .set_min(std::string( - reinterpret_cast(&values_[5]), sizeof(c_type))) - .set_max(std::string( - reinterpret_cast(&values_[4]), sizeof(c_type))); + .set_min( + std::string( + reinterpret_cast(&values_[5]), sizeof(c_type))) + .set_max( + std::string( + reinterpret_cast(&values_[4]), sizeof(c_type))); // Write INT64 min/max values stats_[1] - .set_min(std::string( - reinterpret_cast(&values_[0]), sizeof(c_type))) - .set_max(std::string( - reinterpret_cast(&values_[9]), sizeof(c_type))); + .set_min( + std::string( + reinterpret_cast(&values_[0]), sizeof(c_type))) + .set_max( + std::string( + reinterpret_cast(&values_[9]), sizeof(c_type))); } // TYPE::FLOAT @@ -1123,10 +1151,12 @@ void TestStatisticsSortOrder::SetValues() { // Write Float min/max values stats_[0] - .set_min(std::string( - reinterpret_cast(&values_[0]), sizeof(c_type))) - .set_max(std::string( - reinterpret_cast(&values_[9]), sizeof(c_type))); + .set_min( + std::string( + reinterpret_cast(&values_[0]), sizeof(c_type))) + .set_max( + std::string( + reinterpret_cast(&values_[9]), sizeof(c_type))); } // TYPE::DOUBLE @@ -1139,18 +1169,21 @@ void TestStatisticsSortOrder::SetValues() { // Write Double min/max values stats_[0] - .set_min(std::string( - reinterpret_cast(&values_[0]), sizeof(c_type))) - .set_max(std::string( - reinterpret_cast(&values_[9]), sizeof(c_type))); + .set_min( + std::string( + reinterpret_cast(&values_[0]), sizeof(c_type))) + .set_max( + std::string( + reinterpret_cast(&values_[9]), sizeof(c_type))); } // TYPE::ByteArray template <> void TestStatisticsSortOrder::AddNodes(std::string name) { // UTF8 logical type to set Unsigned Statistics - fields_.push_back(schema::PrimitiveNode::Make( - name, Repetition::REQUIRED, Type::BYTE_ARRAY, ConvertedType::UTF8)); + fields_.push_back( + schema::PrimitiveNode::Make( + name, Repetition::REQUIRED, Type::BYTE_ARRAY, ConvertedType::UTF8)); } template <> @@ -1180,22 +1213,26 @@ void TestStatisticsSortOrder::SetValues() { // Write String min/max values stats_[0] - .set_min(std::string( - reinterpret_cast(vals[2].c_str()), vals[2].length())) - .set_max(std::string( - reinterpret_cast(vals[9].c_str()), vals[9].length())); + .set_min( + std::string( + reinterpret_cast(vals[2].c_str()), vals[2].length())) + .set_max( + std::string( + reinterpret_cast(vals[9].c_str()), + vals[9].length())); } // TYPE::FLBAArray template <> void TestStatisticsSortOrder::AddNodes(std::string name) { // FLBA has only Unsigned Statistics - fields_.push_back(schema::PrimitiveNode::Make( - name, - Repetition::REQUIRED, - Type::FIXED_LEN_BYTE_ARRAY, - ConvertedType::NONE, - FLBA_LENGTH)); + fields_.push_back( + schema::PrimitiveNode::Make( + name, + Repetition::REQUIRED, + Type::FIXED_LEN_BYTE_ARRAY, + ConvertedType::NONE, + FLBA_LENGTH)); } template <> @@ -1274,14 +1311,15 @@ TEST(TestByteArrayStatisticsFromArrow, LargeStringType) { using TestStatisticsSortOrderFLBA = TestStatisticsSortOrder; TEST_F(TestStatisticsSortOrderFLBA, decimalSortOrder) { - this->fields_.push_back(schema::PrimitiveNode::Make( - "Column 0", - Repetition::REQUIRED, - Type::FIXED_LEN_BYTE_ARRAY, - ConvertedType::DECIMAL, - FLBA_LENGTH, - 12, - 2)); + this->fields_.push_back( + schema::PrimitiveNode::Make( + "Column 0", + Repetition::REQUIRED, + Type::FIXED_LEN_BYTE_ARRAY, + ConvertedType::DECIMAL, + FLBA_LENGTH, + 12, + 2)); this->SetUpSchema(); this->WriteParquet(); @@ -1598,6 +1636,172 @@ TEST(TestStatistics, DoubleNegativeZero) { CheckNegativeZeroStats(); } +// Test infinity handling in statistics. +template +void CheckInfinityStats() { + using T = typename ParquetType::c_type; + + constexpr int32_t kNumValues = 8; + NodePtr node = PrimitiveNode::Make( + "infinity_test", Repetition::OPTIONAL, ParquetType::type_num); + ColumnDescriptor descr(node, 1, 1); + + constexpr T posInf = std::numeric_limits::infinity(); + constexpr T negInf = -std::numeric_limits::infinity(); + constexpr T min = -1.0f; + constexpr T max = 1.0f; + + { + std::array allPosInf{ + posInf, posInf, posInf, posInf, posInf, posInf, posInf, posInf}; + auto stats = MakeStatistics(&descr); + AssertMinMaxAre(stats, allPosInf, posInf, posInf); + } + + { + std::array allNegInf{ + negInf, negInf, negInf, negInf, negInf, negInf, negInf, negInf}; + auto stats = MakeStatistics(&descr); + AssertMinMaxAre(stats, allNegInf, negInf, negInf); + } + + { + std::array mixedInf{ + posInf, negInf, posInf, negInf, posInf, negInf, posInf, negInf}; + auto stats = MakeStatistics(&descr); + AssertMinMaxAre(stats, mixedInf, negInf, posInf); + } + + { + std::array mixedValues{ + posInf, max, min, min, negInf, max, min, posInf}; + auto stats = MakeStatistics(&descr); + AssertMinMaxAre(stats, mixedValues, negInf, posInf); + } + + { + constexpr T nan = std::numeric_limits::quiet_NaN(); + std::array mixedWithNan{ + posInf, nan, max, negInf, nan, min, posInf, nan}; + auto stats = MakeStatistics(&descr); + AssertMinMaxAre(stats, mixedWithNan, negInf, posInf); + } +} + +TEST(TestStatistics, FloatInfinityValues) { + CheckInfinityStats(); +} + +TEST(TestStatistics, DoubleInfinityValues) { + CheckInfinityStats(); +} + +// Test infinity values with validity bitmap. +TEST(TestStatistics, InfinityWithNullBitmap) { + constexpr int kNumValues = 8; + NodePtr node = PrimitiveNode::Make( + "infinity_null_test", Repetition::OPTIONAL, Type::FLOAT); + ColumnDescriptor descr(node, 1, 1); + + constexpr float posInf = std::numeric_limits::infinity(); + constexpr float negInf = -std::numeric_limits::infinity(); + + // Test with some infinity values marked as null. + std::array valuesWithNulls{ + posInf, negInf, 1.0f, 2.0f, posInf, -1.0f, 3.0f, negInf}; + + // Bitmap: exclude first posInf and last negInf (01111110 = 0x7E). + uint8_t validBitmap = 0x7E; + + auto stats = MakeStatistics(&descr); + AssertMinMaxAre(stats, valuesWithNulls, &validBitmap, negInf, posInf); + valuesWithNulls = {posInf, 0.0f, 1.0f, 2.0f, -2.0f, -1.0f, 3.0f, negInf}; + + stats = MakeStatistics(&descr); + AssertMinMaxAre(stats, valuesWithNulls, &validBitmap, -2.0f, 3.0f); +} + +// Test merging statistics with infinity values. +TEST(TestStatistics, MergeInfinityStatistics) { + NodePtr node = + PrimitiveNode::Make("merge_infinity", Repetition::OPTIONAL, Type::DOUBLE); + ColumnDescriptor descr(node, 1, 1); + + constexpr double posInf = std::numeric_limits::infinity(); + constexpr double negInf = -std::numeric_limits::infinity(); + + auto stats1 = MakeStatistics(&descr); + std::array normalValues{-1.0f, 0.0f, 1.0f}; + AssertMinMaxAre(stats1, normalValues, -1.0f, 1.0f); + + auto stats2 = MakeStatistics(&descr); + std::array infinityValues{negInf, posInf}; + AssertMinMaxAre(stats2, infinityValues, negInf, posInf); + + auto mergedStats = MakeStatistics(&descr); + mergedStats->Merge(*stats1); + mergedStats->Merge(*stats2); + + // Result should have infinity bounds. + ASSERT_TRUE(mergedStats->HasMinMax()); + ASSERT_EQ(negInf, mergedStats->min()); + ASSERT_EQ(posInf, mergedStats->max()); +} + +TEST(TestStatistics, CleanInfinityStatistics) { + constexpr int kNumValues = 4; + NodePtr node = PrimitiveNode::Make( + "clean_stat_nullopt", Repetition::OPTIONAL, Type::FLOAT); + ColumnDescriptor descr(node, 1, 1); + + constexpr float nan = std::numeric_limits::quiet_NaN(); + + { + std::array allNans{nan, nan, nan, nan}; + auto stats = MakeStatistics(&descr); + AssertUnsetMinMax(stats, allNans); + } + + { + std::array values{1.0f, 2.0f, 3.0f, 4.0f}; + uint8_t allNullBitmap = 0x00; + + auto stats = MakeStatistics(&descr); + AssertUnsetMinMax(stats, values, &allNullBitmap); + } + + { + std::array mixedNans{nan, 1.0f, nan, 2.0f}; + uint8_t partialNullBitmap = 0x05; + + auto stats = MakeStatistics(&descr); + AssertUnsetMinMax(stats, mixedNans, &partialNullBitmap); + } +} + +TEST(TestStatistics, InfinityCleanStatisticValid) { + constexpr int kNumValues = 4; + NodePtr node = PrimitiveNode::Make( + "clean_stat_valid", Repetition::OPTIONAL, Type::DOUBLE); + ColumnDescriptor descr(node, 1, 1); + + constexpr double posInf = std::numeric_limits::infinity(); + constexpr double negInf = -std::numeric_limits::infinity(); + constexpr double nan = std::numeric_limits::quiet_NaN(); + + { + std::array mixedValues{posInf, nan, negInf, nan}; + auto stats = MakeStatistics(&descr); + AssertMinMaxAre(stats, mixedValues, negInf, posInf); + } + + { + std::array singleInf{negInf}; + auto stats = MakeStatistics(&descr); + AssertMinMaxAre(stats, singleInf, negInf, negInf); + } +} + // TODO: disabled as it requires Arrow parquet data dir. // Test statistics for binary column with UNSIGNED sort order /* diff --git a/velox/dwio/parquet/writer/arrow/tests/TestUtil.h b/velox/dwio/parquet/writer/arrow/tests/TestUtil.h index 79f3dc3992cc..4b4c11002aff 100644 --- a/velox/dwio/parquet/writer/arrow/tests/TestUtil.h +++ b/velox/dwio/parquet/writer/arrow/tests/TestUtil.h @@ -366,8 +366,9 @@ class DataPageBuilder { ParquetException::NYI("only rle encoding currently implemented"); } - std::vector encode_buffer(LevelEncoder::MaxBufferSize( - Encoding::RLE, max_level, static_cast(levels.size()))); + std::vector encode_buffer( + LevelEncoder::MaxBufferSize( + Encoding::RLE, max_level, static_cast(levels.size()))); // We encode into separate memory from the output stream because the // RLE-encoded bytes have to be preceded in the stream by their absolute @@ -807,12 +808,13 @@ class PrimitiveTypedTest : public ::testing::Test { for (int i = 0; i < num_columns; ++i) { std::string name = TestColumnName(i); - fields.push_back(schema::PrimitiveNode::Make( - name, - repetition, - TestType::type_num, - ConvertedType::NONE, - FLBA_LENGTH)); + fields.push_back( + schema::PrimitiveNode::Make( + name, + repetition, + TestType::type_num, + ConvertedType::NONE, + FLBA_LENGTH)); } node_ = schema::GroupNode::Make("schema", Repetition::REQUIRED, fields); schema_.Init(node_); diff --git a/velox/dwio/parquet/writer/arrow/util/ByteStreamSplitInternal.h b/velox/dwio/parquet/writer/arrow/util/ByteStreamSplitInternal.h index 45e5b025ff1e..bf30c6dd5d4f 100644 --- a/velox/dwio/parquet/writer/arrow/util/ByteStreamSplitInternal.h +++ b/velox/dwio/parquet/writer/arrow/util/ByteStreamSplitInternal.h @@ -78,8 +78,9 @@ void ByteStreamSplitDecodeSse2( for (int64_t i = 0; i < num_blocks; ++i) { for (size_t j = 0; j < kNumStreams; ++j) { - stage[0][j] = _mm_loadu_si128(reinterpret_cast( - &data[i * sizeof(__m128i) + j * stride])); + stage[0][j] = _mm_loadu_si128( + reinterpret_cast( + &data[i * sizeof(__m128i) + j * stride])); } for (size_t step = 0; step < kNumStreamsLog2; ++step) { for (size_t j = 0; j < kNumStreamsHalf; ++j) { @@ -227,8 +228,9 @@ void ByteStreamSplitDecodeAvx2( for (int64_t i = 0; i < num_blocks; ++i) { for (size_t j = 0; j < kNumStreams; ++j) { - stage[0][j] = _mm256_loadu_si256(reinterpret_cast( - &data[i * sizeof(__m256i) + j * stride])); + stage[0][j] = _mm256_loadu_si256( + reinterpret_cast( + &data[i * sizeof(__m256i) + j * stride])); } for (size_t step = 0; step < kNumStreamsLog2; ++step) { @@ -405,8 +407,9 @@ void ByteStreamSplitDecodeAvx512( for (int64_t i = 0; i < num_blocks; ++i) { for (size_t j = 0; j < kNumStreams; ++j) { - stage[0][j] = _mm512_loadu_si512(reinterpret_cast( - &data[i * sizeof(__m512i) + j * stride])); + stage[0][j] = _mm512_loadu_si512( + reinterpret_cast( + &data[i * sizeof(__m512i) + j * stride])); } for (size_t step = 0; step < kNumStreamsLog2; ++step) { diff --git a/velox/dwio/parquet/writer/arrow/util/CMakeLists.txt b/velox/dwio/parquet/writer/arrow/util/CMakeLists.txt index b971d33c25c4..2aff416fcdb3 100644 --- a/velox/dwio/parquet/writer/arrow/util/CMakeLists.txt +++ b/velox/dwio/parquet/writer/arrow/util/CMakeLists.txt @@ -20,7 +20,8 @@ velox_add_library( CompressionZlib.cpp CompressionLZ4.cpp Hashing.cpp - Crc32.cpp) + Crc32.cpp +) velox_link_libraries( velox_dwio_arrow_parquet_writer_util_lib @@ -29,4 +30,5 @@ velox_link_libraries( Snappy::snappy zstd::zstd ZLIB::ZLIB - lz4::lz4) + lz4::lz4 +) diff --git a/velox/dwio/parquet/writer/arrow/util/Compression.cpp b/velox/dwio/parquet/writer/arrow/util/Compression.cpp index f5f77313d101..d61f557a2a1f 100644 --- a/velox/dwio/parquet/writer/arrow/util/Compression.cpp +++ b/velox/dwio/parquet/writer/arrow/util/Compression.cpp @@ -24,7 +24,6 @@ #include "arrow/result.h" #include "arrow/status.h" -#include "arrow/util/logging.h" #include "velox/dwio/parquet/writer/arrow/util/CompressionInternal.h" namespace facebook::velox::parquet::arrow::util { diff --git a/velox/dwio/parquet/writer/arrow/util/Hashing.h b/velox/dwio/parquet/writer/arrow/util/Hashing.h index cc8c0ff07701..cadca7c68900 100644 --- a/velox/dwio/parquet/writer/arrow/util/Hashing.h +++ b/velox/dwio/parquet/writer/arrow/util/Hashing.h @@ -41,7 +41,6 @@ #include "arrow/util/bit_util.h" #include "arrow/util/bitmap_builders.h" #include "arrow/util/endian.h" -#include "arrow/util/logging.h" #include "arrow/util/macros.h" #include "arrow/util/ubsan.h" diff --git a/velox/dwio/parquet/writer/arrow/util/safe-math.h b/velox/dwio/parquet/writer/arrow/util/safe-math.h index 661b62887bc7..70506eefec60 100644 --- a/velox/dwio/parquet/writer/arrow/util/safe-math.h +++ b/velox/dwio/parquet/writer/arrow/util/safe-math.h @@ -156,13 +156,13 @@ #define PSNIP_SAFE_DEFINE_LARGER_BINARY_OP(T, name, op_name, op) \ PSNIP_SAFE__FUNCTION psnip_safe_##name##_larger \ - psnip_safe_larger_##name##_##op_name(T a, T b) { \ + psnip_safe_larger_##name##_##op_name(T a, T b) { \ return ((psnip_safe_##name##_larger)a)op((psnip_safe_##name##_larger)b); \ } #define PSNIP_SAFE_DEFINE_LARGER_UNARY_OP(T, name, op_name, op) \ PSNIP_SAFE__FUNCTION psnip_safe_##name##_larger \ - psnip_safe_larger_##name##_##op_name(T value) { \ + psnip_safe_larger_##name##_##op_name(T value) { \ return (op((psnip_safe_##name##_larger)value)); \ } diff --git a/velox/dwio/text/CMakeLists.txt b/velox/dwio/text/CMakeLists.txt index 844a12ffd601..d11825c4e30d 100644 --- a/velox/dwio/text/CMakeLists.txt +++ b/velox/dwio/text/CMakeLists.txt @@ -21,3 +21,9 @@ add_subdirectory(writer) velox_add_library(velox_dwio_text_writer_register RegisterTextWriter.cpp) velox_link_libraries(velox_dwio_text_writer_register velox_dwio_text_writer) + +add_subdirectory(reader) + +velox_add_library(velox_dwio_text_reader_register RegisterTextReader.cpp) + +velox_link_libraries(velox_dwio_text_reader_register velox_dwio_text_reader) diff --git a/velox/dwio/text/RegisterTextReader.cpp b/velox/dwio/text/RegisterTextReader.cpp new file mode 100644 index 000000000000..6f631880809e --- /dev/null +++ b/velox/dwio/text/RegisterTextReader.cpp @@ -0,0 +1,36 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/dwio/text/RegisterTextReader.h" +#include "velox/dwio/text/reader/TextReader.h" + +namespace facebook::velox::text { + +std::unique_ptr TextReaderFactory::createReader( + std::unique_ptr input, + const ReaderOptions& options) { + return std::make_unique(options, std::move(input)); +} + +void registerTextReaderFactory() { + registerReaderFactory(std::make_shared()); +} + +void unregisterTextReaderFactory() { + unregisterReaderFactory(dwio::common::FileFormat::TEXT); +} + +} // namespace facebook::velox::text diff --git a/velox/dwio/text/RegisterTextReader.h b/velox/dwio/text/RegisterTextReader.h new file mode 100644 index 000000000000..65b9aeb8ae1b --- /dev/null +++ b/velox/dwio/text/RegisterTextReader.h @@ -0,0 +1,36 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "velox/dwio/common/ReaderFactory.h" + +namespace facebook::velox::text { + +class TextReaderFactory : public dwio::common::ReaderFactory { + public: + TextReaderFactory() : ReaderFactory(dwio::common::FileFormat::TEXT) {} + + std::unique_ptr createReader( + std::unique_ptr, + const dwio::common::ReaderOptions&) override; +}; + +void registerTextReaderFactory(); + +void unregisterTextReaderFactory(); + +} // namespace facebook::velox::text diff --git a/velox/dwio/text/reader/CMakeLists.txt b/velox/dwio/text/reader/CMakeLists.txt new file mode 100644 index 000000000000..d65f1ffbc227 --- /dev/null +++ b/velox/dwio/text/reader/CMakeLists.txt @@ -0,0 +1,23 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +velox_add_library(velox_dwio_text_reader TextReader.cpp) + +velox_link_libraries( + velox_dwio_text_reader + velox_type_fbhive + velox_dwio_common_compression + velox_encode + fmt::fmt +) diff --git a/velox/dwio/text/reader/TextReader.cpp b/velox/dwio/text/reader/TextReader.cpp new file mode 100644 index 000000000000..d382dbae491b --- /dev/null +++ b/velox/dwio/text/reader/TextReader.cpp @@ -0,0 +1,1645 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/dwio/text/reader/TextReader.h" + +#include +#include + +#include "velox/common/encode/Base64.h" +#include "velox/dwio/common/exception/Exceptions.h" +#include "velox/type/fbhive/HiveTypeParser.h" + +namespace facebook::velox::text { +namespace { + +using common::CompressionKind; + +using dwio::common::EOFError; +using dwio::common::RowReader; +using dwio::common::verify; + +static constexpr std::string_view kTextfileCompressionExtensionGzip{".gz"}; +static constexpr std::string_view kTextfileCompressionExtensionDeflate{ + ".deflate"}; +static constexpr std::string_view kTextfileCompressionExtensionZst{".zst"}; +static constexpr std::string_view kTextfileCompressionExtensionLz4{".lz4"}; +static constexpr std::string_view kTextfileCompressionExtensionLzo{".lzo"}; +static constexpr std::string_view kTextfileCompressionExtensionSnappy{ + ".snappy"}; + +static std::string emptyString = std::string(); + +constexpr const int32_t kDecompressionBufferFactor = 3; + +void resizeVector( + BaseVector* FOLLY_NULLABLE data, + const vector_size_t insertionIdx) { + if (data == nullptr) { + return; + } + + auto dataSize = data->size(); + if (dataSize == 0) { + data->resize(10); + } else if (dataSize <= insertionIdx) { + if (data->type()->kind() == TypeKind::ARRAY) { + auto oldSize = dataSize; + auto newSize = dataSize * 2; + data->resize(newSize); + + auto arrayVector = data->asChecked(); + auto rawOffsets = arrayVector->offsets()->asMutable(); + auto rawSizes = arrayVector->sizes()->asMutable(); + + auto lastOffset = oldSize > 0 ? rawOffsets[oldSize - 1] : 0; + auto lastSize = oldSize > 0 ? rawSizes[oldSize - 1] : 0; + auto newOffset = oldSize > 0 ? lastOffset + lastSize : 0; + + for (auto i = oldSize; i < newSize; ++i) { + rawSizes[i] = 0; + rawOffsets[i] = newOffset; + } + } else if (data->type()->kind() == TypeKind::MAP) { + auto oldSize = dataSize; + auto newSize = dataSize * 2; + data->resize(newSize); + + auto mapVector = data->asChecked(); + auto rawOffsets = mapVector->offsets()->asMutable(); + auto rawSizes = mapVector->sizes()->asMutable(); + + auto lastOffset = oldSize > 0 ? rawOffsets[oldSize - 1] : 0; + auto lastSize = oldSize > 0 ? rawSizes[oldSize - 1] : 0; + auto newOffset = oldSize > 0 ? lastOffset + lastSize : 0; + + for (auto i = oldSize; i < newSize; ++i) { + rawSizes[i] = 0; + rawOffsets[i] = newOffset; + } + } else { + data->resize(dataSize * 2); + } + } +} + +void setCompressionSettings( + const std::string& filename, + CompressionKind& kind, + dwio::common::compression::CompressionOptions& compressionOptions) { + if (filename.ends_with(kTextfileCompressionExtensionLz4) || + filename.ends_with(kTextfileCompressionExtensionLzo) || + filename.ends_with(kTextfileCompressionExtensionSnappy)) { + VELOX_FAIL("Unsupported compression extension for file: {}", filename); + } + if (filename.ends_with(kTextfileCompressionExtensionGzip)) { + kind = CompressionKind::CompressionKind_GZIP; + compressionOptions.format.zlib.windowBits = + 15; // 2^15-byte deflate window size + } else if (filename.ends_with(kTextfileCompressionExtensionDeflate)) { + kind = CompressionKind::CompressionKind_ZLIB; + compressionOptions.format.zlib.windowBits = + -15; // raw deflate, 2^15-byte window size + } else if (filename.ends_with(kTextfileCompressionExtensionZst)) { + kind = CompressionKind::CompressionKind_ZSTD; + } else { + kind = CompressionKind::CompressionKind_NONE; + } +} + +} // namespace + +FileContents::FileContents( + MemoryPool& pool, + const std::shared_ptr& t) + : schema{t}, + input{nullptr}, + pool{pool}, + fileLength{0}, + compression{CompressionKind::CompressionKind_NONE}, + compressionOptions{}, + needsEscape{} { + needsEscape.fill(false); + needsEscape.at(0) = true; +} + +TextRowReader::TextRowReader( + std::shared_ptr fileContents, + const RowReaderOptions& opts) + : RowReader(), + contents_{fileContents}, + schemaWithId_{TypeWithId::create(fileContents->schema)}, + scanSpec_{opts.scanSpec()}, + selectedSchema_{nullptr}, + options_{opts}, + columnSelector_{ + ColumnSelector::apply(opts.selector(), contents_->schema)}, + currentRow_{0}, + pos_{opts.offset()}, + atEOL_{false}, + atEOF_{false}, + atSOL_{false}, + atPhysicalEOF_{false}, + depth_{0}, + unreadIdx_{0}, + limit_{opts.limit()}, + fileLength_{getStreamLength()}, + varBinBuf_{ + std::make_shared>(contents_->pool)} { + // Seek to first line at or after the specified region. + if (contents_->compression == CompressionKind::CompressionKind_NONE) { + // TODO: Inconsistent row skipping behavior (kept for Presto compatibility) + // Issue: When reading from byte offset > 0, we skip rows inclusively at the + // start position, but when reading from byte 0, no rows are skipped. This + // creates inconsistent behavior where a row at the boundary may be skipped + // when it should be included. + // + // Example: If pos_ = 10 is the first byte of row 2, that entire row gets + // skipped, even though it should be read. + // + // Proposed fix: streamPosition_ = (pos_ == 0) ? 0 : --pos_; + // This would skip rows exclusively of pos_, ensuring consistent behavior. + const auto streamPosition_ = pos_; + + contents_->inputStream = contents_->input->read( + streamPosition_, + contents_->fileLength - streamPosition_, + dwio::common::LogType::STREAM); + + if (pos_ != 0) { + unreadData_.clear(); + (void)skipLine(); + } + if (opts.skipRows() > 0) { + (void)seekToRow(opts.skipRows()); + } + } else { + // compressed text files, the first split reads the whole file, rest read 0 + if (pos_ != 0) { + atEOF_ = true; + } + limit_ = std::numeric_limits::max(); + + contents_->inputStream = contents_->input->loadCompleteFile(); + auto name = contents_->inputStream->getName(); + contents_->decompressedInputStream = createDecompressor( + contents_->compression, + std::move(contents_->inputStream), + // An estimated value used as the output buffer size for the zlib + // decompressor, and as the fallback value of the decompressed length + // for other decompressors. + kDecompressionBufferFactor * contents_->fileLength, + contents_->pool, + contents_->compressionOptions, + fmt::format("Text Reader: Stream {}", name), + nullptr, + true, + contents_->fileLength); + + if (opts.skipRows() > 0) { + (void)seekToRow(opts.skipRows()); + } + } +} + +uint64_t TextRowReader::next( + uint64_t rows, + VectorPtr& result, + const Mutation* mutation) { + if (atEOF_) { + return 0; + } + + auto& t = schemaWithId_; + verify( + t->type()->isRow(), + "Top-level TypeKind of schema is not Row for file %s", + getStreamNameData()); + + auto projectSelectedType = options_.projectSelectedType(); + auto reqT = + (projectSelectedType ? getSelectedType() : TypeWithId::create(getType())); + verify( + reqT->type()->isRow(), + "Top-level TypeKind of schema is not Row for file %s", + getStreamNameData()); + + auto childCount = t->size(); + auto reqChildCount = reqT->size(); + + // create top level RowVector + auto rowVecPtr = BaseVector::create( + reqT->type(), (vector_size_t)rows, &contents_->pool); + + vector_size_t rowsRead = 0; + const auto initialPos = pos_; + while (!atEOF_ && rowsRead < rows) { + resetLine(); + uint64_t colIndex = 0; + for (vector_size_t i = 0; i < childCount; i++) { + if (colIndex >= reqT->size()) { + break; + } + + DelimType delim = DelimTypeNone; + const auto& ct = t->childAt(i); + const auto& rct = reqT->childAt(i); + auto childVector = rowVecPtr->childAt(i).get(); + + if (isSelectedField(ct)) { + ++colIndex; + } else if (colIndex < reqChildCount && !projectSelectedType) { + // not selected and not projecting: set to null + if (childVector != nullptr) { + rowVecPtr->setNull(i, true); + childVector = nullptr; + } + ++colIndex; + } else { + // not selected and projecting: just discard the field + childVector = nullptr; + } + + resizeVector(childVector, rowsRead); + readElement(ct->type(), rct->type(), childVector, rowsRead, delim); + } + + // set null property + for (uint64_t i = colIndex; i < reqChildCount; i++) { + auto childVector = rowVecPtr->childAt(i).get(); + + if (childVector != nullptr) { + rowVecPtr->setNull(static_cast(i), true); + } + } + (void)skipLine(); + ++currentRow_; + ++rowsRead; + + bool eof = false; + if (contents_->compression == CompressionKind::CompressionKind_NONE) { + eof = pos_ >= getLength(); + } else if (atPhysicalEOF_) { + eof = pos_ >= contents_->decompressedInputStream->ByteCount(); + } + + if (eof) { + setEOF(); + } + + // handle empty file + if (initialPos == pos_ && atEOF_) { + currentRow_ = 0; + rowsRead = 0; + } + } + + // Resize the row vector to the actual number of rows read. + // Handled here for both cases: pos_ > fileLength_ and pos_ > limit_ + rowVecPtr->resize(rowsRead); + result = projectColumns(rowVecPtr, *scanSpec_, mutation); + + return rowsRead; +} + +int64_t TextRowReader::nextRowNumber() { + return atEOF_ ? -1 : static_cast(currentRow_) + 1; +} + +int64_t TextRowReader::nextReadSize(uint64_t size) { + return static_cast(std::min(fileLength_ - currentRow_, size)); +} + +void TextRowReader::updateRuntimeStats( + dwio::common::RuntimeStatistics& /*stats*/) const { + // No-op for non-selective reader. +} + +void TextRowReader::resetFilterCaches() { + // No-op for non-selective reader. +} + +std::optional TextRowReader::estimatedRowSize() const { + return std::nullopt; +} + +const ColumnSelector& TextRowReader::getColumnSelector() const { + return columnSelector_; +} + +std::shared_ptr TextRowReader::getSelectedType() const { + if (!selectedSchema_) { + selectedSchema_ = columnSelector_.buildSelected(); + } + return selectedSchema_; +} + +uint64_t TextRowReader::getRowNumber() const { + return currentRow_; +} + +uint64_t TextRowReader::seekToRow(uint64_t rowNumber) { + VELOX_CHECK_GT( + rowNumber, currentRow_, "Text file cannot seek to earlier row"); + + while (currentRow_ < rowNumber && !skipLine()) { + currentRow_++; + resetLine(); + } + + return currentRow_; +} + +const RowReaderOptions& TextRowReader::getDefaultOpts() { + static RowReaderOptions defaultOpts; + return defaultOpts; +} + +bool TextRowReader::isSelectedField( + const std::shared_ptr& type) { + auto ci = type->id(); + return columnSelector_.shouldReadNode(ci); +} + +const char* TextRowReader::getStreamNameData() const { + return contents_->input->getName().data(); +} + +uint64_t TextRowReader::getLength() { + if (fileLength_ == std::numeric_limits::max()) { + fileLength_ = getStreamLength(); + } + return fileLength_; +} + +uint64_t TextRowReader::getStreamLength() const { + return contents_->input->getInputStream()->getLength(); +} + +void TextRowReader::setEOF() { + atEOF_ = true; + atEOL_ = true; +} + +/// TODO: Update maximum depth after fixing issue with deeply nested complex +/// types +void TextRowReader::incrementDepth() { + if (depth_ > 4) { + dwio::common::parse_error("Schema nesting too deep"); + } + depth_++; +} + +void TextRowReader::decrementDepth(DelimType& delim) { + if (depth_ == 0) { + dwio::common::logic_error("Attempt to decrement nesting depth of 0"); + } + depth_--; + auto d = depth_ + DelimTypeEOR; + if (delim > d) { + setNone(delim); + } +} + +void TextRowReader::setEOE(DelimType& delim) { + // Set delim if it is currently None or a more deeply + // delimiter, to simply the code where aggregates + // parse nested aggregates. + auto d = depth_ + DelimTypeEOE; + if (isNone(delim) || d < delim) { + delim = d; + } +} + +void TextRowReader::resetEOE(DelimType& delim) { + // Reset delim it is EOE or above. + auto d = depth_ + DelimTypeEOE; + if (delim >= d) { + setNone(delim); + } +} + +bool TextRowReader::isEOE(DelimType delim) { + // Test if delim is the EOE at the current depth. + return (delim == (depth_ + DelimTypeEOE)); +} + +void TextRowReader::setEOR(DelimType& delim) { + // Set delim if it is currently None or a more + // deeply nested delimiter. + auto d = depth_ + DelimTypeEOR; + if (isNone(delim) || delim > d) { + delim = d; + } +} + +bool TextRowReader::isEOR(DelimType delim) { + // Return true if delim is the EOR for the current depth + // or a less deeply nested depth. + return (delim != DelimTypeNone && delim <= (depth_ + DelimTypeEOR)); +} + +bool TextRowReader::isOuterEOR(DelimType delim) { + // Return true if delim is the EOR for the enclosing object. + // For example, when parsing ARRAY elements, which leave delim + // set to the EOR for their depth on return, isOuterEOR will + // return true if we have reached the ARRAY EOR delimiter at + // the end of the latest element. + return (delim != DelimTypeNone && delim < (depth_ + DelimTypeEOR)); +} + +bool TextRowReader::isEOEorEOR(DelimType delim) { + return (!isNone(delim) && delim <= (depth_ + DelimTypeEOE)); +} + +void TextRowReader::setNone(DelimType& delim) { + delim = DelimTypeNone; +} + +bool TextRowReader::isNone(DelimType delim) { + return (delim == DelimTypeNone); +} + +std::string& +TextRowReader::getString(TextRowReader& th, bool& isNull, DelimType& delim) { + if (th.atEOL_) { + delim = DelimTypeEOR; // top-level EOR + } + + if (th.isEOEorEOR(delim)) { + isNull = true; + return emptyString; + } + + bool wasEscaped = false; + th.ownedString_.clear(); + + // Processing has to be done character by characater instad of chunk by chunk. + // This is to avoid edge case handling if escape character(s) are cut off at + // the end of the chunk. + while (true) { + auto v = th.getByteOptimized(delim); + if (!th.isNone(delim)) { + break; + } + + if (th.contents_->serDeOptions.isEscaped && + v == th.contents_->serDeOptions.escapeChar) { + wasEscaped = true; + th.ownedString_.append(1, static_cast(v)); + v = th.getByteUncheckedOptimized(delim); + if (!th.isNone(delim)) { + break; + } + } + th.ownedString_.append(1, static_cast(v)); + } + + if (th.ownedString_ == th.contents_->serDeOptions.nullString) { + isNull = true; + return emptyString; + } + + if (wasEscaped) { + // We need to copy the data byte by byte only if there is at least one + // escaped byte. + uint64_t j = 0; + for (uint64_t i = 0; i < th.ownedString_.length(); i++) { + if (th.ownedString_[i] == th.contents_->serDeOptions.escapeChar && + i < th.ownedString_.length() - 1) { + // Check if it's '\r' or '\n'. + i++; + if (th.ownedString_[i] == 'r') { + th.ownedString_[j++] = '\r'; + } else if (th.ownedString_[i] == 'n') { + th.ownedString_[j++] = '\n'; + } else { + // Keep the next byte. + th.ownedString_[j++] = th.ownedString_[i]; + } + } else { + th.ownedString_[j++] = th.ownedString_[i]; + } + } + th.ownedString_.resize(j); + } + + return th.ownedString_; +} + +template +void TextRowReader::setValueFromString( + const std::string& str, + BaseVector* data, + vector_size_t insertionRow, + std::function(const std::string&)> convert) { + if ((atEOF_ && atSOL_) || data == nullptr) { + return; + } + auto flatVector = data->asChecked>(); + auto result = str.empty() ? std::nullopt : convert(str); + if (result) { + flatVector->set(insertionRow, *result); + } else { + flatVector->setNull(insertionRow, true); + } +} + +uint8_t TextRowReader::getByteOptimized(DelimType& delim) { + setNone(delim); + auto v = getByteUncheckedOptimized(delim); + if (isNone(delim)) { + if (v == '\r') { + v = getByteUncheckedOptimized( + delim); // always returns '\n' in this case + } + delim = getDelimType(v); + } + return v; +} + +DelimType TextRowReader::getDelimType(uint8_t v) { + DelimType delim = DelimTypeNone; + + if (v == '\n') { + atEOL_ = true; + delim = DelimTypeEOR; // top level EOR + + /// TODO: Logically should be >=, kept as it is to align with presto reader. + if (pos_ > limit_) { + atEOF_ = true; + delim = DelimTypeEOR; + } + } else if (v == contents_->serDeOptions.separators.at(depth_)) { + setEOE(delim); + } else { + setNone(delim); + uint64_t i = depth_; + while (i > 0) { + i--; + if (v == contents_->serDeOptions.separators.at(i)) { + delim = i + DelimTypeEOR; // level-based EOR + break; + } + } + } + return delim; +} + +template +char TextRowReader::getByteUncheckedOptimized(DelimType& delim) { + if (atEOL_) { + if (!skipLF) { + delim = DelimTypeEOR; // top level EOR + } + return '\n'; + } + + try { + char v; + if (contents_->compression != CompressionKind::CompressionKind_NONE && + preLoadedUnreadData_.empty()) { + int length = 0; + const void* buffer = nullptr; + atPhysicalEOF_ = + !contents_->decompressedInputStream->Next(&buffer, &length); + if (!atPhysicalEOF_) { + preLoadedUnreadData_ = + std::string_view(reinterpret_cast(buffer), length); + } + } + + if (unreadData_.empty() || unreadIdx_ >= unreadData_.size()) { + bool updated = false; + if (contents_->compression != CompressionKind::CompressionKind_NONE) { + unreadData_.assign( + preLoadedUnreadData_.data(), preLoadedUnreadData_.size()); + preLoadedUnreadData_ = {}; + updated = !unreadData_.empty(); + } else { + int length = 0; + const void* buffer = nullptr; + if (contents_->inputStream->Next(&buffer, &length) && length > 0) { + VELOX_CHECK_NOT_NULL(buffer); + unreadData_.assign(reinterpret_cast(buffer), length); + updated = true; + } + } + + if (!updated) { + setEOF(); + delim = DelimTypeEOR; + return '\0'; + } + unreadIdx_ = 0; + } + + v = unreadData_[unreadIdx_++]; + pos_++; + + // only when previous char == '\r' + if (skipLF) { + if (v != '\n') { + pos_--; + return '\n'; + } + } else { + atSOL_ = false; + } + return v; + } catch (EOFError&) { + } catch (std::runtime_error& e) { + if (std::string(e.what()).find("Short read of") != 0 && !skipLF) { + throw; + } + } + if (!skipLF) { + setEOF(); + delim = DelimTypeEOR; + } + return '\n'; +} + +bool TextRowReader::getEOR(DelimType& delim, bool& isNull) { + if (isEOR(delim)) { + isNull = true; + return true; + } + if (atEOL_) { + delim = DelimTypeEOR; // top-level EOR + isNull = true; + return true; + } + bool wasAtSOL = atSOL_; + setNone(delim); + ownedString_.clear(); + const auto& ns = contents_->serDeOptions.nullString; + uint8_t v = 0; + while (true) { + v = getByteUncheckedOptimized(delim); + if (isNone(delim)) { + if (v == '\r') { + // always returns '\n' in this case + v = getByteUncheckedOptimized(delim); + } + delim = getDelimType(v); + } + + if (isEOR(delim) || atEOL_) { + if (ownedString_ == ns) { + isNull = true; + } else if (!ownedString_.empty()) { + break; + } + setEOR(delim); + return true; + } + if (ownedString_.size() >= ns.size() || + static_cast(v) != ns[ownedString_.size()]) { + break; + } + ownedString_.push_back(static_cast(v)); + } + + unreadData_.insert(0, 1, static_cast(v)); + pos_--; + if (!ownedString_.empty()) { + unreadData_.insert(0, ownedString_); + pos_ -= ownedString_.size(); + } + atEOL_ = false; + atSOL_ = wasAtSOL; + setNone(delim); + return false; +} + +bool TextRowReader::skipLine() { + DelimType delim = DelimTypeNone; + while (!atEOL_) { + (void)getByteOptimized(delim); + } + /// TODO: Logically should be >=, kept as it is to align with presto reader + if (pos_ > limit_) { + setEOF(); + delim = DelimTypeEOR; + } + return atEOF_; +} + +void TextRowReader::resetLine() { + if (!atEOF_) { + atEOL_ = false; + VELOX_CHECK_EQ(depth_, 0); + } + atSOL_ = true; +} + +template +T TextRowReader::getInteger(TextRowReader& th, bool& isNull, DelimType& delim) { + const std::string& str = getString(th, isNull, delim); + + if (str.empty()) { + isNull = true; + } + if (isNull) { + return 0; + } + + // Test if s is not acceptable integer format for + // the warehouse, for cases accepted by stol(). + char c = str[0]; + if (c != '-' && !std::isdigit(static_cast(c))) { + isNull = true; + return 0; + } + + int64_t v = 0; + unsigned long long scanPos = 0; + errno = 0; + auto scanCount = sscanf(str.c_str(), "%" SCNd64 "%lln", &v, &scanPos); + if (scanCount != 1 || errno == ERANGE) { + isNull = true; + return 0; + } + if (scanPos < str.size()) { + // Check if the string is a valid decimal. + for (uint64_t i = scanPos; i < str.size(); i++) { + if (i == scanPos && str[i] == '.') { + continue; + } + if (str[i] >= '0' && str[i] <= '9') { + continue; + } + isNull = true; + return 0; + } + } + + if (!std::is_same::value) { + if (static_cast(static_cast(v)) != v) { + isNull = true; + return 0; + } + } + return static_cast(v); +} + +namespace { + +static constexpr std::string_view kTrueStringView{"TRUE"}; +static constexpr std::string_view kFalseStringView{"FALSE"}; + +} // namespace + +bool TextRowReader::getBoolean( + TextRowReader& th, + bool& isNull, + DelimType& delim) { + const std::string& str = getString(th, isNull, delim); + if (str.empty()) { + isNull = true; + } + if (isNull) { + return false; + } + if (str.compare(kTrueStringView) == 0) { + return true; + } + if (str.compare(kFalseStringView) == 0) { + return false; + } + + switch (str.size()) { + case 4: + if (boost::algorithm::iequals(str, kTrueStringView)) { + return true; + } + break; + case 5: + if (boost::algorithm::iequals(str, kFalseStringView)) { + return false; + } + break; + default: + break; + } + + isNull = true; + return false; +} + +namespace { + +static constexpr std::string_view kNaNStringView{"NaN"}; +static constexpr std::string_view kInfinityStringView{"Infinity"}; +static constexpr std::string_view kShortInfinityStringView{"Inf"}; +static constexpr std::string_view kNegInfinityStringView{"-Infinity"}; +static constexpr std::string_view kShortNegInfinityStringView{"-Inf"}; + +bool unacceptableFloatingPoint(std::string& s) { + for (int i = 0; i < s.size(); ++i) { + char c = s.data()[i]; + if (!(std::isalpha(c) || c == '-')) { + return false; + } + } + + bool isNaN = boost::algorithm::iequals(s, kNaNStringView); + + bool isInf = boost::algorithm::iequals(s, kInfinityStringView); + bool isShortInf = boost::algorithm::iequals(s, kShortInfinityStringView); + + bool isNegInf = boost::algorithm::iequals(s, kNegInfinityStringView); + bool isShortNegInf = + boost::algorithm::iequals(s, kShortNegInfinityStringView); + + return (!isNaN && !isInf && !isShortInf && !isNegInf && !isShortNegInf); +} + +void trimStringInPlace(std::string& s) { + const auto isNotSpace = [](unsigned char ch) { return ch > 0x20; }; + size_t start = 0; + size_t end = s.size(); + + // Find first non-whitespace character + while (start < end && !isNotSpace(s[start])) { + ++start; + } + + // If the string is all whitespace + if (start == end) { + s.clear(); + return; + } + + // Find last non-whitespace character + size_t last = end - 1; + while (last > start && !isNotSpace(s[last])) { + --last; + } + + // Erase leading and trailing whitespace + s = s.substr(start, last - start + 1); +} + +} // namespace + +float TextRowReader::getFloat( + TextRowReader& th, + bool& isNull, + DelimType& delim) { + std::string& str = getString(th, isNull, delim); + if (str.empty()) { + isNull = true; + } + if (isNull) { + return 0; + } + + trimStringInPlace(str); + + if (str.data()[0] == '.') { + th.ownedString_.insert(th.ownedString_.begin(), '0'); + str = th.ownedString_; + } + + if (unacceptableFloatingPoint(str)) { + isNull = true; + return 0.0; + } + + float v = 0.0; + unsigned long long scanPos = 0; + // We ignore ERANGE, since denormalized floats and + // infinities are acceptable. + auto scanCount = sscanf(str.c_str(), "%f%lln", &v, &scanPos); + if (scanCount != 1 || scanPos < str.size()) { + isNull = true; + return 0.0; + } + return v; +} + +double +TextRowReader::getDouble(TextRowReader& th, bool& isNull, DelimType& delim) { + std::string& str = getString(th, isNull, delim); + if (str.empty()) { + isNull = true; + } + + if (isNull) { + return 0.0; + } + + trimStringInPlace(str); + + if (str.data()[0] == '.') { + th.ownedString_.insert(th.ownedString_.begin(), '0'); + str = th.ownedString_; + } + + // Filter out values from non-warehouse sources which + // other readers translate to null. Warehouse + // readers require upper-case values. + if (unacceptableFloatingPoint(str)) { + isNull = true; + return 0.0; + } + + double v = 0.0; + unsigned long long scanPos = 0; + // We ignore ERANGE, since denormalized doubles and + // infinities are acceptable. + auto scanCount = sscanf(str.c_str(), "%lf%lln", &v, &scanPos); + if (scanCount != 1 || scanPos < str.size()) { + isNull = true; + return 0.0; + } + return v; +} + +/// TODO: Reconsider error handling strategy for malformed data +/// Currently, all read functions convert invalid/malformed data to NULL values. +/// This approach may produce incorrect query results, particularly for +/// aggregate operations where a high volume of NULLs can significantly skew +/// calculations (e.g., COUNT, AVG, SUM). Consider alternative strategies such +/// as throwing exceptions, logging warnings, or providing configurable error +/// handling modes. +void TextRowReader::readElement( + const std::shared_ptr& t, + const std::shared_ptr& reqT, + BaseVector* FOLLY_NULLABLE data, + vector_size_t insertionRow, + DelimType& delim) { + bool isNull = false; + switch (t->kind()) { + case TypeKind::INTEGER: + switch (reqT->kind()) { + case TypeKind::BIGINT: + putValue( + getInteger, data, insertionRow, delim); + break; + case TypeKind::INTEGER: + if (reqT->isDate()) { + const std::string& str = getString(*this, isNull, delim); + setValueFromString( + str, + data, + insertionRow, + [](const std::string& s) -> std::optional { + return DATE()->toDays(s); + }); + } else { + putValue( + getInteger, data, insertionRow, delim); + } + break; + default: + VELOX_FAIL( + "Requested type {} is not supported to be read as type {}", + reqT->toString(), + t->toString()); + break; + } + break; + + case TypeKind::BIGINT: + if (reqT->isShortDecimal()) { + const std::string& str = getString(*this, isNull, delim); + auto decimalParams = getDecimalPrecisionScale(*reqT); + const auto precision = decimalParams.first; + const auto scale = decimalParams.second; + setValueFromString( + str, + data, + insertionRow, + [precision, scale](const std::string& s) -> std::optional { + int64_t v = 0; + const auto status = DecimalUtil::castFromString( + StringView(s.data(), static_cast(s.size())), + precision, + scale, + v); + return status.ok() ? std::optional(v) : std::nullopt; + }); + } else { + putValue( + getInteger, data, insertionRow, delim); + } + break; + + case TypeKind::HUGEINT: { + const std::string& str = getString(*this, isNull, delim); + if (reqT->isLongDecimal()) { + auto decimalParams = getDecimalPrecisionScale(*reqT); + const auto precision = decimalParams.first; + const auto scale = decimalParams.second; + setValueFromString( + str, + data, + insertionRow, + [precision, + scale](const std::string& s) -> std::optional { + int128_t v = 0; + const auto status = DecimalUtil::castFromString( + StringView(s.data(), static_cast(s.size())), + precision, + scale, + v); + return status.ok() ? std::optional(v) : std::nullopt; + }); + } else { + setValueFromString( + str, + data, + insertionRow, + [](const std::string& s) -> std::optional { + return HugeInt::parse(s); + }); + } + break; + } + case TypeKind::SMALLINT: + switch (reqT->kind()) { + case TypeKind::BIGINT: + putValue( + getInteger, data, insertionRow, delim); + break; + case TypeKind::INTEGER: + putValue( + getInteger, data, insertionRow, delim); + break; + case TypeKind::SMALLINT: + putValue( + getInteger, data, insertionRow, delim); + break; + default: + VELOX_FAIL( + "Requested type {} is not supported to be read as type {}", + reqT->toString(), + t->toString()); + break; + } + break; + + case TypeKind::VARBINARY: { + const std::string& str = getString(*this, isNull, delim); + + // Early return if no data vector or at EOF + if ((atEOF_ && atSOL_) || (data == nullptr)) { + return; + } + + const auto& flatVector = data->asChecked>(); + if (!flatVector) { + VELOX_FAIL( + "Vector for column type does not match: expected FlatVector, got {}", + data ? data->type()->toString() : "null"); + return; + } + + // Allocate a blob buffer + size_t len = str.size(); + const auto blen = encoding::Base64::calculateDecodedSize(str.data(), len); + varBinBuf_->resize(blen.value_or(0)); + + // decode from base64 to the blob buffer. + Status status = encoding::Base64::decode( + str.data(), str.size(), varBinBuf_->data(), blen.value_or(0)); + + if (status.code() == StatusCode::kOK) { + flatVector->set( + insertionRow, + StringView(varBinBuf_->data(), static_cast(blen.value()))); + } else { + // Not valid base64: just copy as-is for compatibility. + // + // Note that some warehouse file have simply binary data + // in what should be a base64-encoded field, and which + // may result in extra rows. Other readers behave as + // below, so this provides compatibility, even if all + // readers should really reject these files. + varBinBuf_->resize(str.size()); + + VELOX_CHECK_NOT_NULL(str.data()); + + len = str.size(); + memcpy(varBinBuf_->data(), str.data(), str.size()); + + // Use StringView, set(vector_size_t idx, T value) fails because + // strlen(varBinBuf_->data()) is undefined due to lack of null + // terminator + flatVector->set( + insertionRow, + StringView(varBinBuf_->data(), static_cast(str.size()))); + } + + if (isNull) { + flatVector->setNull(insertionRow, true); + } + + break; + } + case TypeKind::VARCHAR: { + const std::string& str = getString(*this, isNull, delim); + + // Early return if no data vector or at EOF + if ((atEOF_ && atSOL_) || (data == nullptr)) { + return; + } + + const auto& flatVector = data->asChecked>(); + if (!flatVector) { + VELOX_FAIL( + "Vector for column type does not match: expected FlatVector, got {}", + data ? data->type()->toString() : "null"); + return; + } + + flatVector->set( + insertionRow, + StringView(str.data(), static_cast(str.size()))); + + if (isNull) { + flatVector->setNull(insertionRow, true); + } + + break; + } + + case TypeKind::BOOLEAN: + switch (reqT->kind()) { + case TypeKind::BIGINT: + putValue(getBoolean, data, insertionRow, delim); + break; + case TypeKind::INTEGER: + putValue(getBoolean, data, insertionRow, delim); + break; + case TypeKind::SMALLINT: + putValue(getBoolean, data, insertionRow, delim); + break; + case TypeKind::TINYINT: + putValue(getBoolean, data, insertionRow, delim); + break; + case TypeKind::BOOLEAN: + putValue(getBoolean, data, insertionRow, delim); + break; + default: + VELOX_FAIL( + "Requested type {} is not supported to be read as type {}", + reqT->toString(), + t->toString()); + break; + } + break; + + case TypeKind::TINYINT: + switch (reqT->kind()) { + case TypeKind::BIGINT: + putValue( + getInteger, data, insertionRow, delim); + break; + case TypeKind::INTEGER: + putValue( + getInteger, data, insertionRow, delim); + break; + case TypeKind::SMALLINT: + putValue( + getInteger, data, insertionRow, delim); + break; + case TypeKind::TINYINT: + putValue( + getInteger, data, insertionRow, delim); + break; + default: + VELOX_FAIL( + "Requested type {} is not supported to be read as type {}", + reqT->toString(), + t->toString()); + break; + } + break; + + case TypeKind::ARRAY: { + const auto& ct = t->childAt(0); + const auto& arrayVector = data ? data->asChecked() : nullptr; + + incrementDepth(); + (void)getEOR(delim, isNull); + + if (arrayVector != nullptr) { + auto rawSizes = arrayVector->sizes()->asMutable(); + auto rawOffsets = arrayVector->offsets()->asMutable(); + + rawOffsets[insertionRow] = insertionRow > 0 + ? rawOffsets[insertionRow - 1] + rawSizes[insertionRow - 1] + : 0; + const int startElementIdx = rawOffsets[insertionRow]; + + vector_size_t elementCount = 0; + if (isNull) { + arrayVector->setNull(insertionRow, isNull); + rawSizes[insertionRow] = 0; + } else { + // Read elements until we reach the end of the array. + while (!isOuterEOR(delim)) { + setNone(delim); + auto elementsVector = arrayVector->elements().get(); + resizeVector(elementsVector, startElementIdx + elementCount); + + readElement( + ct, + reqT->childAt(0), + elementsVector, + startElementIdx + elementCount, + delim); + + // Update size on every iteration to allow the right size + // inheritance in resizeVector. + rawSizes[insertionRow] = ++elementCount; + + if (atEOF_ && atSOL_) { + decrementDepth(delim); + return; + } + } + } + + } else { + // Skip over array data to maintain correct stream position. + while (!isOuterEOR(delim)) { + setNone(delim); + readElement(ct, reqT->childAt(0), nullptr, 0, delim); + } + } + decrementDepth(delim); + break; + } + + case TypeKind::ROW: { + const auto& childCount = t->size(); + const auto& rowVector = data ? data->asChecked() : nullptr; + incrementDepth(); + + if (rowVector != nullptr) { + if (isNull) { + rowVector->setNull(insertionRow, isNull); + } else { + for (uint64_t j = 0; j < childCount; j++) { + if (!isOuterEOR(delim)) { + setNone(delim); + } + + // Get the child vector for this field. + BaseVector* childVector = nullptr; + if (j < reqT->size()) { + childVector = rowVector->childAt(j).get(); + } + resizeVector(childVector, insertionRow); + readElement( + t->childAt(j), + j < reqT->size() ? reqT->childAt(j) : t->childAt(j), + childVector, + insertionRow, + delim); + + if (atEOF_ && atSOL_) { + decrementDepth(delim); + return; + } + } + } + } else { + // Skip over row data to maintain correct stream position. + for (uint64_t j = 0; j < childCount; j++) { + if (!isOuterEOR(delim)) { + setNone(delim); + } + readElement(t->childAt(j), reqT->childAt(j), nullptr, 0, delim); + } + } + + decrementDepth(delim); + setEOE(delim); + break; + } + + case TypeKind::MAP: { + const auto& mapt = t->asMap(); + const auto& key = mapt.keyType(); + const auto& value = mapt.valueType(); + const auto& mapVector = data ? data->asChecked() : nullptr; + incrementDepth(); + (void)getEOR(delim, isNull); + + if (mapVector != nullptr) { + auto rawOffsets = mapVector->offsets()->asMutable(); + auto rawSizes = mapVector->sizes()->asMutable(); + + rawOffsets[insertionRow] = insertionRow > 0 + ? rawOffsets[insertionRow - 1] + rawSizes[insertionRow - 1] + : 0; + const int startElementIdx = rawOffsets[insertionRow]; + + vector_size_t elementCount = 0; + if (isNull) { + mapVector->setNull(insertionRow, isNull); + rawSizes[insertionRow] = 0; + } else { + while (!isOuterEOR(delim)) { + // Decode another element. + setNone(delim); + incrementDepth(); + + // insert key + auto keysVector = mapVector->mapKeys().get(); + resizeVector(keysVector, startElementIdx + elementCount); + + readElement( + key, + reqT->childAt(0), + keysVector, + startElementIdx + elementCount, + delim); + + // Case for no value key. + if (atEOF_ && atSOL_) { + rawSizes[insertionRow] = elementCount; + rawOffsets[insertionRow + 1] = startElementIdx + elementCount; + decrementDepth(delim); + decrementDepth(delim); + return; + } + resetEOE(delim); + + // insert value + auto valsVector = mapVector->mapValues().get(); + resizeVector(valsVector, startElementIdx + elementCount); + + readElement( + value, + reqT->childAt(1), + valsVector, + startElementIdx + elementCount, + delim); + + rawSizes[insertionRow] = ++elementCount; + + decrementDepth(delim); + } + } + + } else { + // Skip over map data to maintain correct stream position. + while (!isOuterEOR(delim)) { + setNone(delim); + incrementDepth(); + readElement(key, reqT->childAt(0), nullptr, 0, delim); + resetEOE(delim); + readElement(value, reqT->childAt(1), nullptr, 0, delim); + decrementDepth(delim); + } + } + decrementDepth(delim); + break; + } + + case TypeKind::REAL: + switch (reqT->kind()) { + case TypeKind::REAL: + putValue(getFloat, data, insertionRow, delim); + break; + case TypeKind::DOUBLE: + putValue(getDouble, data, insertionRow, delim); + break; + default: + VELOX_FAIL( + "Requested type {} is not supported to be read as type {}", + reqT->toString(), + t->toString()); + break; + } + break; + + case TypeKind::DOUBLE: + putValue(getDouble, data, insertionRow, delim); + break; + + case TypeKind::TIMESTAMP: { + const std::string& s = getString(*this, isNull, delim); + + // Early return if no data vector or at EOF + if ((atEOF_ && atSOL_) || (data == nullptr)) { + return; + } + + auto flatVector = data->asChecked>(); + if (!flatVector) { + VELOX_FAIL( + "Vector for column type does not match: expected FlatVector, got {}", + data ? data->type()->toString() : "null"); + return; + } + + if (s.empty()) { + isNull = true; + flatVector->setNull(insertionRow, true); + } else { + auto ts = util::Converter::tryCast(s).thenOrThrow( + folly::identity, + [&](const Status& status) { VELOX_USER_FAIL(status.message()); }); + ts.toGMT(Timestamp::defaultTimezone()); + flatVector->set( + insertionRow, Timestamp{ts.getSeconds(), ts.getNanos()}); + } + + break; + } + + default: + VELOX_NYI("readElement unhandled type (kind code {})", t->kind()); + } + + ownedString_.clear(); +} + +uint64_t maxStreamsForType(const std::shared_ptr& type) { + switch (type->kind()) { + case TypeKind::ROW: + case TypeKind::REAL: + case TypeKind::DOUBLE: + case TypeKind::BOOLEAN: + case TypeKind::TINYINT: + case TypeKind::ARRAY: + case TypeKind::MAP: + case TypeKind::VARBINARY: + case TypeKind::TIMESTAMP: + case TypeKind::INTEGER: + case TypeKind::BIGINT: + case TypeKind::SMALLINT: + case TypeKind::VARCHAR: + return 1; + default: + return 0; + } +} + +template +void TextRowReader::putValue( + const std::function& + f, + BaseVector* FOLLY_NULLABLE data, + vector_size_t insertionRow, + DelimType& delim) { + bool isNull = false; + T v; + if (isEOR(delim)) { + isNull = true; + v = 0; + } else { + v = f(*this, isNull, delim); + } + + // Early return if no data vector or at EOF + if ((atEOF_ && atSOL_) || (data == nullptr)) { + return; + } + + // Cast to FlatVector + auto flatVector = data ? data->asChecked>() : nullptr; + if (!flatVector) { + VELOX_FAIL("Vector for column type does not match"); + return; + } + + // Handle null property. + if (isNull) { + flatVector->setNull(insertionRow, isNull); + return; + } + + flatVector->set(insertionRow, v); +} + +const std::shared_ptr& TextRowReader::getType() const { + return contents_->schema; +} + +TextReader::TextReader( + const ReaderOptions& options, + std::unique_ptr input) + : options_{options} { + auto schema = options_.fileSchema(); + VELOX_USER_CHECK_NOT_NULL(schema, "File schema for TEXT must be set."); + + if (!schema) { + // Create dummy for testing. + internalSchema_ = std::dynamic_pointer_cast( + type::fbhive::HiveTypeParser().parse("struct")); + DWIO_ENSURE_NOT_NULL(internalSchema_.get()); + schema = internalSchema_; + } + schemaWithId_ = TypeWithId::create(schema); + contents_ = std::make_shared(options_.memoryPool(), schema); + + if (!contents_->schema->isRow()) { + throw std::invalid_argument("file schema must be a ROW type"); + } + + contents_->input = std::move(input); + + // Find the size of the file using the option or filesystem. + contents_->fileLength = std::min( + options_.tailLocation(), + static_cast(contents_->input->getInputStream()->getLength())); + + /** + * We are now allowing delimiters/separators and escape characters to be the + * same. This could be error prone because we are checking for delimiters + * before escape characters. + * + * Example: + * delim = ','; escapeChar = ',' + * dataToParse = "1,,2" + * Schema = ROW(ARRAY(VARCHAR())) + * + * Scenario 1: Check delimiter before escape (current implementation) + * Output: ["1", NULL, "2"] + * + * Scenario 2: Check escape before delim + * Output: ["1,2"] + * + * TODO: This is not a bug but would be good to be able to handle this + * ambiguity + */ + + // Set the SerDe options. + contents_->serDeOptions = options_.serDeOptions(); + if (contents_->serDeOptions.isEscaped) { + for (auto delim : contents_->serDeOptions.separators) { + contents_->needsEscape.at(delim) = true; + } + contents_->needsEscape.at(contents_->serDeOptions.escapeChar) = true; + } + + // Validate SerDe options. + VELOX_CHECK( + contents_->serDeOptions.nullString.compare("\r") != 0, + "\'\\r\' is not allowed to be nullString"); + VELOX_CHECK( + contents_->serDeOptions.nullString.compare("\n") != 0, + "\'\\n\n is not allowed to be nullString"); + + setCompressionSettings( + contents_->input->getName(), + contents_->compression, + contents_->compressionOptions); +} + +std::optional TextReader::numberOfRows() const { + return std::nullopt; +} + +std::unique_ptr TextReader::columnStatistics( + uint32_t /*index*/) const { + return nullptr; +} + +const std::shared_ptr& TextReader::rowType() const { + return contents_->schema; +} + +CompressionKind TextReader::getCompression() const { + return contents_->compression; +} + +const std::shared_ptr& TextReader::typeWithId() const { + if (!typeWithId_) { + typeWithId_ = TypeWithId::create(rowType()); + } + return typeWithId_; +} + +std::unique_ptr TextReader::createRowReader( + const RowReaderOptions& opts) const { + return std::make_unique(contents_, opts); +} + +uint64_t TextReader::getFileLength() const { + return contents_->fileLength; +} + +} // namespace facebook::velox::text diff --git a/velox/dwio/text/reader/TextReader.h b/velox/dwio/text/reader/TextReader.h new file mode 100644 index 000000000000..435de81c35a7 --- /dev/null +++ b/velox/dwio/text/reader/TextReader.h @@ -0,0 +1,238 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +#include "folly/CppAttributes.h" +#include "velox/dwio/common/BufferedInput.h" +#include "velox/dwio/common/Reader.h" +#include "velox/dwio/common/TypeWithId.h" +#include "velox/dwio/common/compression/Compression.h" + +namespace facebook::velox::text { + +using common::CompressionKind; +using common::ScanSpec; +using dwio::common::BufferedInput; +using dwio::common::ColumnSelector; +using dwio::common::ColumnStatistics; +using dwio::common::Mutation; +using dwio::common::ReaderOptions; +using dwio::common::RowReaderOptions; +using dwio::common::SerDeOptions; +using dwio::common::TypeWithId; +using memory::MemoryPool; + +// Shared state for a file between TextReader and TextRowReader +struct FileContents { + FileContents(MemoryPool& pool, const std::shared_ptr& t); + + const size_t COLUMN_POSITION_INVALID = std::numeric_limits::max(); + const std::shared_ptr schema; + + std::unique_ptr input; + std::unique_ptr inputStream; + std::unique_ptr decompressedInputStream; + MemoryPool& pool; + uint64_t fileLength; + CompressionKind compression; + dwio::common::compression::CompressionOptions compressionOptions; + SerDeOptions serDeOptions; + std::array needsEscape; +}; + +using DelimType = uint8_t; +constexpr DelimType DelimTypeNone = 0; +constexpr DelimType DelimTypeEOR = 1; +constexpr DelimType DelimTypeEOE = 2; + +class TextReader : public dwio::common::Reader { + public: + TextReader( + const ReaderOptions& options, + std::unique_ptr input); + + std::optional numberOfRows() const override; + + std::unique_ptr columnStatistics( + uint32_t index) const override; + + const RowTypePtr& rowType() const override; + + CompressionKind getCompression() const; + + const std::shared_ptr& typeWithId() const override; + + std::unique_ptr createRowReader( + const RowReaderOptions& options) const override; + + uint64_t getFileLength() const; + + private: + ReaderOptions options_; + mutable std::shared_ptr typeWithId_; + std::shared_ptr contents_; + std::shared_ptr schemaWithId_; + std::shared_ptr internalSchema_; +}; + +class TextRowReader : public dwio::common::RowReader { + public: + TextRowReader( + std::shared_ptr fileContents, + const RowReaderOptions& options); + + uint64_t next( + uint64_t size, + VectorPtr& result, + const Mutation* mutation = nullptr) override; + + int64_t nextRowNumber() override; + + int64_t nextReadSize(uint64_t size) override; + + void updateRuntimeStats( + dwio::common::RuntimeStatistics& stats) const override; + + void resetFilterCaches() override; + + std::optional estimatedRowSize() const override; + + const ColumnSelector& getColumnSelector() const; + + std::shared_ptr getSelectedType() const; + + uint64_t getRowNumber() const; + + uint64_t seekToRow(uint64_t rowNumber); + + private: + const RowReaderOptions& getDefaultOpts(); + + const std::shared_ptr& getType() const; + + bool isSelectedField(const std::shared_ptr& t); + + const char* getStreamNameData() const; + + uint64_t getLength(); + + uint64_t getStreamLength() const; + + void setEOF(); + + void incrementDepth(); + + void decrementDepth(DelimType& delim); + + void setEOE(DelimType& delim); + + void resetEOE(DelimType& delim); + + bool isEOE(DelimType delim); + + void setEOR(DelimType& delim); + + bool isEOR(DelimType delim); + + bool isOuterEOR(DelimType delim); + + bool isEOEorEOR(DelimType delim); + + void setNone(DelimType& delim); + + bool isNone(DelimType delim); + + DelimType getDelimType(uint8_t v); + + template + char getByteUnchecked(DelimType& delim); + + template + char getByteUncheckedOptimized(DelimType& delim); + + uint8_t getByte(DelimType& delim); + uint8_t getByteOptimized(DelimType& delim); + + bool getEOR(DelimType& delim, bool& isNull); + + bool skipLine(); + + void resetLine(); + + static std::string& + getString(TextRowReader& th, bool& isNull, DelimType& delim); + + template + static T getInteger(TextRowReader& th, bool& isNull, DelimType& delim); + + static bool getBoolean(TextRowReader& th, bool& isNull, DelimType& delim); + + static float getFloat(TextRowReader& th, bool& isNull, DelimType& delim); + + static double getDouble(TextRowReader& th, bool& isNull, DelimType& delim); + + void readElement( + const std::shared_ptr& t, + const std::shared_ptr& reqT, + BaseVector* FOLLY_NULLABLE data, + vector_size_t insertionRow, + DelimType& delim); + + template + void putValue( + const std::function& + f, + BaseVector* FOLLY_NULLABLE data, + vector_size_t insertionRow, + DelimType& delim); + + template + void setValueFromString( + const std::string& str, + BaseVector* FOLLY_NULLABLE data, + vector_size_t insertionRow, + std::function(const std::string&)> convert); + + const std::shared_ptr contents_; + const std::shared_ptr schemaWithId_; + const std::shared_ptr& scanSpec_; + + mutable std::shared_ptr selectedSchema_; + + RowReaderOptions options_; + ColumnSelector columnSelector_; + uint64_t currentRow_; + uint64_t pos_; + bool atEOL_; + bool atEOF_; + bool atSOL_; + bool atPhysicalEOF_; + uint8_t depth_; + std::string unreadData_; + std::string_view preLoadedUnreadData_; + int unreadIdx_; + uint64_t limit_; // lowest offset not in the range + uint64_t fileLength_; + std::string ownedString_; + std::shared_ptr> varBinBuf_; +}; + +} // namespace facebook::velox::text diff --git a/velox/dwio/text/tests/CMakeLists.txt b/velox/dwio/text/tests/CMakeLists.txt index 34a05424d366..032c209e75a7 100644 --- a/velox/dwio/text/tests/CMakeLists.txt +++ b/velox/dwio/text/tests/CMakeLists.txt @@ -12,15 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -set(TEST_LINK_LIBS - velox_dwio_common_test_utils - velox_vector_test_lib - velox_exec_test_lib - velox_temp_path - GTest::gtest - GTest::gtest_main - GTest::gmock - gflags::gflags - glog::glog) +set( + TEST_LINK_LIBS + velox_dwio_common_test_utils + velox_vector_test_lib + velox_exec_test_lib + velox_temp_path + GTest::gtest + GTest::gtest_main + GTest::gmock + gflags::gflags + glog::glog +) + +add_subdirectory(reader) add_subdirectory(writer) diff --git a/velox/dwio/text/tests/reader/CMakeLists.txt b/velox/dwio/text/tests/reader/CMakeLists.txt new file mode 100644 index 000000000000..55ec4ff35ed1 --- /dev/null +++ b/velox/dwio/text/tests/reader/CMakeLists.txt @@ -0,0 +1,25 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +add_executable(velox_text_reader_test TextReaderTest.cpp) + +add_test( + NAME velox_text_reader_test + COMMAND velox_text_reader_test + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} +) + +target_link_libraries(velox_text_reader_test velox_dwio_text_reader_register ${TEST_LINK_LIBS}) + +file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/examples DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) diff --git a/velox/dwio/text/tests/reader/TextReaderTest.cpp b/velox/dwio/text/tests/reader/TextReaderTest.cpp new file mode 100644 index 000000000000..15e4e595af8a --- /dev/null +++ b/velox/dwio/text/tests/reader/TextReaderTest.cpp @@ -0,0 +1,1940 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/dwio/common/tests/utils/DataFiles.h" +#include "velox/dwio/text/RegisterTextReader.h" +#include "velox/vector/tests/utils/VectorTestBase.h" + +extern int daylight; +extern long timezone; + +using namespace facebook::velox; +using namespace facebook::velox::test; + +namespace facebook::velox::text { + +namespace { + +int32_t parseDate(const std::string& text) { + return DATE()->toDays(text); +} + +class TextReaderTest : public testing::Test, + public velox::test::VectorTestBase { + protected: + static void SetUpTestSuite() { + memory::MemoryManager::testingSetInstance(memory::MemoryManager::Options{}); + } + + void SetUp() override { + registerTextReaderFactory(); + } + + void TearDown() override { + unregisterTextReaderFactory(); + } + + memory::MemoryPool& poolRef() { + return *pool(); + } + + void setScanSpec(const Type& type, dwio::common::RowReaderOptions& options) { + auto spec = std::make_shared("root"); + spec->addAllChildFields(type); + options.setScanSpec(spec); + } + + private: + std::shared_ptr readFile_; +}; + +struct TestCompressionParam { + std::string filepath; + std::string compression; +}; + +class TextReaderDecompressionTest + : public TextReaderTest, + public testing::WithParamInterface {}; + +TEST_F(TextReaderTest, basic) { + auto expected = makeRowVector({ + makeFlatVector( + {"FOO", + "FOO", + "FOO", + "FOO", + "BAR", + "BAR", + "BAR", + "BAR", + "BAZ", + "BAZ", + "BAZ", + "BAZ"}), + makeFlatVector({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}), + makeFlatVector( + {1.123, + 2.333, + -6.1, + 4.2, + 47.2, + 79.5, + 3.1415926, + -221.145, + 93.12, + -4123.11, + 950.2, + 43.66}), + makeFlatVector( + {true, + true, + false, + true, + false, + false, + true, + true, + false, + false, + false, + true}), + }); + + auto type = ROW( + {{"col_string", VARCHAR()}, + {"col_int", INTEGER()}, + {"col_float", DOUBLE()}, + {"col_bool", BOOLEAN()}}); + auto factory = dwio::common::getReaderFactory(dwio::common::FileFormat::TEXT); + + auto path = velox::test::getDataFilePath( + "velox/dwio/text/tests/reader/", + "examples/simple_types_compressed_file.gz"); + auto readFile = std::make_shared(path); + + auto readerOptions = dwio::common::ReaderOptions(pool()); + readerOptions.setFileSchema(type); + + auto input = + std::make_unique(readFile, poolRef()); + auto reader = factory->createReader(std::move(input), readerOptions); + dwio::common::RowReaderOptions rowReaderOptions; + setScanSpec(*type, rowReaderOptions); + auto rowReader = reader->createRowReader(rowReaderOptions); + + EXPECT_EQ(*reader->rowType(), *type); + + VectorPtr result; + + // Try reading 10 rows each time. + ASSERT_EQ(rowReader->next(10, result), 10); + for (int i = 0; i < 10; ++i) { + EXPECT_TRUE(result->equalValueAt(expected.get(), i, i)); + } + ASSERT_EQ(rowReader->next(10, result), 2); + for (int i = 0; i < 2; ++i) { + EXPECT_TRUE(result->equalValueAt(expected.get(), i, 10 + i)); + } + ASSERT_EQ(rowReader->next(10, result), 0); + + input = std::make_unique(readFile, poolRef()); + reader = factory->createReader(std::move(input), readerOptions); + rowReader = reader->createRowReader(rowReaderOptions); + // Try reading 2, 3, 4, 5 rows at a time. + ASSERT_EQ(rowReader->next(2, result), 2); + for (int i = 0; i < 2; ++i) { + EXPECT_TRUE(result->equalValueAt(expected.get(), i, i)); + } + ASSERT_EQ(rowReader->next(3, result), 3); + for (int i = 0; i < 3; ++i) { + EXPECT_TRUE(result->equalValueAt(expected.get(), i, 2 + i)); + } + ASSERT_EQ(rowReader->next(4, result), 4); + for (int i = 0; i < 4; ++i) { + EXPECT_TRUE(result->equalValueAt(expected.get(), i, 5 + i)); + } + ASSERT_EQ(rowReader->next(5, result), 3); + for (int i = 0; i < 3; ++i) { + EXPECT_TRUE(result->equalValueAt(expected.get(), i, 9 + i)); + } + ASSERT_EQ(rowReader->next(10, result), 0); +} + +TEST_F(TextReaderTest, headerAndCustomNullString) { + tzset(); + const auto tzOffsetPST = 28'800; + const auto tzOffsetPDT = 25'200; + auto expected = makeRowVector({ + makeFlatVector( + {"FOO", + "FOO", + "FOO", + "FOO", + "BAR", + "BAR", + "BAR", + "BAR", + "BAZ", + "BAZ", + "BAZ", + "BAZ", + ""}), + makeFlatVector({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13}), + makeNullableFlatVector( + {1.123, + 2.333, + -6.1, + 4.2, + 47.2, + 79.5, + 3.1415926, + -221.145, + 93.12, + -4123.11, + 950.2, + 43.66, + std::nullopt}), + makeNullableFlatVector({ + Timestamp{1'695'378'095 + tzOffsetPDT, 148'000'000}, + Timestamp{1'695'690'351 + tzOffsetPDT, 0}, + Timestamp{1'695'686'400 + tzOffsetPDT, 0}, + Timestamp{1'695'083'400 + tzOffsetPDT, 0}, + std::nullopt, + Timestamp{1'695'657'091 + tzOffsetPDT, 209'000'000}, + Timestamp{1'695'690'437 + tzOffsetPDT, 469'123'000}, + Timestamp{1'696'540'679 + tzOffsetPDT, 976'000'000}, + Timestamp{1'695'657'171 + tzOffsetPDT, 637'000'000}, + Timestamp{1'695'693'225 + tzOffsetPDT, 745'123'000}, + std::nullopt, + Timestamp{1'695'406'246 + tzOffsetPDT, 0}, + Timestamp{1'699'392'124 + tzOffsetPST, 736'000'000}, + }), + }); + + auto type = ROW( + {{"col_string", VARCHAR()}, + {"col_int", INTEGER()}, + {"col_float", DOUBLE()}, + {"col_ts", TIMESTAMP()}}); + auto factory = dwio::common::getReaderFactory(dwio::common::FileFormat::TEXT); + + auto path = velox::test::getDataFilePath( + "velox/dwio/text/tests/reader/", "examples/simple_types_with_header"); + auto readFile = std::make_shared(path); + + auto readerOptions = dwio::common::ReaderOptions(pool()); + readerOptions.setFileSchema(type); + auto rowReaderOptions = dwio::common::RowReaderOptions(); + setScanSpec(*type, rowReaderOptions); + rowReaderOptions.setSkipRows(1); + + auto input = + std::make_unique(readFile, poolRef()); + auto reader = factory->createReader(std::move(input), readerOptions); + auto rowReader = reader->createRowReader(rowReaderOptions); + + EXPECT_EQ(*reader->rowType(), *type); + + VectorPtr result; + + // Try reading 10 rows each time. + ASSERT_EQ(rowReader->next(10, result), 10); + for (int i = 0; i < 10; ++i) { + EXPECT_TRUE(result->equalValueAt(expected.get(), i, i)); + } + ASSERT_EQ(rowReader->next(10, result), 3); + for (int i = 0; i < 3; ++i) { + EXPECT_TRUE(result->equalValueAt(expected.get(), i, 10 + i)); + } + ASSERT_EQ(rowReader->next(10, result), 0); + + input = std::make_unique(readFile, poolRef()); + reader = factory->createReader(std::move(input), readerOptions); + rowReader = reader->createRowReader(rowReaderOptions); + // Try reading 2, 3, 4, 5 rows at a time. + ASSERT_EQ(rowReader->next(2, result), 2); + for (int i = 0; i < 2; ++i) { + EXPECT_TRUE(result->equalValueAt(expected.get(), i, i)); + } + ASSERT_EQ(rowReader->next(3, result), 3); + for (int i = 0; i < 3; ++i) { + EXPECT_TRUE(result->equalValueAt(expected.get(), i, 2 + i)); + } + ASSERT_EQ(rowReader->next(4, result), 4); + for (int i = 0; i < 4; ++i) { + EXPECT_TRUE(result->equalValueAt(expected.get(), i, 5 + i)); + } + ASSERT_EQ(rowReader->next(5, result), 4); + for (int i = 0; i < 3; ++i) { + EXPECT_TRUE(result->equalValueAt(expected.get(), i, 9 + i)); + } + ASSERT_EQ(rowReader->next(10, result), 0); + + // Try reading with an empty NULL string. + auto serDeOptions = readerOptions.serDeOptions(); + serDeOptions.nullString = ""; + readerOptions.setSerDeOptions(serDeOptions); + input = std::make_unique(readFile, poolRef()); + reader = factory->createReader(std::move(input), readerOptions); + rowReader = reader->createRowReader(rowReaderOptions); + ASSERT_EQ(rowReader->next(15, result), 13); + expected->childAt(0)->setNull(12, true); + for (int i = 0; i < 13; ++i) { + EXPECT_TRUE(result->equalValueAt(expected.get(), i, i)); + } + + // Try reading with a custom NULL string. + serDeOptions.nullString = "BAR"; + readerOptions.setSerDeOptions(serDeOptions); + input = std::make_unique(readFile, poolRef()); + reader = factory->createReader(std::move(input), readerOptions); + rowReader = reader->createRowReader(rowReaderOptions); + ASSERT_EQ(rowReader->next(15, result), 13); + expected->childAt(0)->setNull(12, false); + expected->childAt(0)->setNull(4, true); + expected->childAt(0)->setNull(5, true); + expected->childAt(0)->setNull(6, true); + expected->childAt(0)->setNull(7, true); + for (int i = 0; i < 13; ++i) { + EXPECT_TRUE(result->equalValueAt(expected.get(), i, i)); + } +} + +TEST_F(TextReaderTest, complexTypesWithCustomDelimiters) { + const vector_size_t length = 13; + const auto keyVector = makeFlatVector( + {1, 111, 22, 22222, 333, 33, 44, 5, 555, 66, 7777, 7, + 777, 8888, 88, 9, 99, 10, 10000, 111, 1, 11, 11122222, 123142}); + const auto valueVector = makeFlatVector( + {false, true, true, false, false, true, false, true, + false, true, false, true, false, false, true, true, + false, true, false, true, false, true, true, true}); + BufferPtr sizes = facebook::velox::allocateOffsets(length, pool()); + BufferPtr offsets = facebook::velox::allocateOffsets(length, pool()); + auto rawSizes = sizes->asMutable(); + auto rawOffsets = offsets->asMutable(); + rawSizes[0] = 2; + rawSizes[1] = 2; + rawSizes[2] = 2; + rawSizes[3] = 1; + rawSizes[4] = 2; + rawSizes[5] = 1; + rawSizes[6] = 3; + rawSizes[7] = 2; + rawSizes[8] = 2; + rawSizes[9] = 2; + rawSizes[10] = 3; + rawSizes[11] = 1; + rawSizes[12] = 1; + for (int i = 1; i < length; i++) { + rawOffsets[i] = rawOffsets[i - 1] + rawSizes[i - 1]; + } + + const auto expected = makeRowVector({ + makeFlatVector( + {"FOO", + "FOO", + "FOO", + "FOO", + "BAR", + "BAR", + "BAR", + "BAR", + "BAZ", + "BAZ", + "BAZ", + "FOO\\nBAZ", + "FOO\n\nBAR\nBAZ"}), + makeArrayVector( + {{1, 11, 111}, + {22, 22222}, + {333, 33}, + {4444, 44}, + {5, 555}, + {666, 66, 66}, + {7777, 7, 777}, + {8888, 88}, + {9, 99}, + {10, 10000}, + {111, 1, 111}, + {12, 11122222, 222}, + {13, 11133333, 333}}), + makeArrayVector( + {{1.123, 1.3123}, + {2.333, -5512, 1.23}, + {-6.1, 65.777}, + {4.2, 24, 324.11}, + {47.2, 213.23}, + {79.5, -44.11}, + {3.1415926, 441.124}, + {-221.145, 878.43, -11}, + {93.12, 632}, + {-4123.11, -177.1}, + {950.2, -4412}, + {43.66, 33121.43}, + {-42.11, -123.43}}), + std::make_shared( + pool(), + MAP(keyVector->type(), valueVector->type()), + nullptr, + length, + offsets, + sizes, + keyVector, + valueVector), + }); + + const auto type = ROW( + {{"col_string", VARCHAR()}, + {"col_bigint_arr", ARRAY(BIGINT())}, + {"col_double_arr", ARRAY(DOUBLE())}, + {"col_map", MAP(BIGINT(), BOOLEAN())}}); + auto factory = dwio::common::getReaderFactory(dwio::common::FileFormat::TEXT); + + auto path = velox::test::getDataFilePath( + "velox/dwio/text/tests/reader/", "examples/custom_delimiters_file"); + auto readFile = std::make_shared(path); + + auto serDeOptions = dwio::common::SerDeOptions('\t', '|', '#', '\\', true); + auto readerOptions = dwio::common::ReaderOptions(pool()); + readerOptions.setFileSchema(type); + readerOptions.setSerDeOptions(serDeOptions); + + auto input = + std::make_unique(readFile, poolRef()); + auto reader = factory->createReader(std::move(input), readerOptions); + auto rowReaderOptions = dwio::common::RowReaderOptions(); + setScanSpec(*type, rowReaderOptions); + rowReaderOptions.range(0, 544); + auto rowReader = reader->createRowReader(rowReaderOptions); + + EXPECT_EQ(*reader->rowType(), *type); + + VectorPtr result; + + // Try reading 10 rows each time. + ASSERT_EQ(rowReader->next(10, result), 10); + for (int i = 0; i < 10; ++i) { + EXPECT_TRUE(result->equalValueAt(expected.get(), i, i)); + } + ASSERT_EQ(rowReader->next(10, result), 3); + for (int i = 0; i < 3; ++i) { + EXPECT_TRUE(result->equalValueAt(expected.get(), i, 10 + i)); + } + ASSERT_EQ(rowReader->next(10, result), 0); + + input = std::make_unique(readFile, poolRef()); + reader = factory->createReader(std::move(input), readerOptions); + auto rowReaderOptions2 = dwio::common::RowReaderOptions(); + setScanSpec(*type, rowReaderOptions2); + rowReader = reader->createRowReader(rowReaderOptions2); + // Try reading 2, 3, 4, 5 rows at a time. + ASSERT_EQ(rowReader->next(2, result), 2); + for (int i = 0; i < 2; ++i) { + EXPECT_TRUE(result->equalValueAt(expected.get(), i, i)); + } + ASSERT_EQ(rowReader->next(3, result), 3); + for (int i = 0; i < 3; ++i) { + EXPECT_TRUE(result->equalValueAt(expected.get(), i, 2 + i)); + } + ASSERT_EQ(rowReader->next(4, result), 4); + for (int i = 0; i < 4; ++i) { + EXPECT_TRUE(result->equalValueAt(expected.get(), i, 5 + i)); + } + ASSERT_EQ(rowReader->next(5, result), 4); + for (int i = 0; i < 4; ++i) { + EXPECT_TRUE(result->equalValueAt(expected.get(), i, 9 + i)); + } + ASSERT_EQ(rowReader->next(10, result), 0); +} + +TEST_F(TextReaderTest, projectComplexTypesWithCustomDelimiters) { + const auto type = ROW( + {{"col_string", VARCHAR()}, + {"col_bigint_arr", ARRAY(BIGINT())}, + {"col_double_arr", ARRAY(DOUBLE())}, + {"col_map", MAP(BIGINT(), BOOLEAN())}}); + + auto factory = dwio::common::getReaderFactory(dwio::common::FileFormat::TEXT); + auto path = velox::test::getDataFilePath( + "velox/dwio/text/tests/reader/", "examples/custom_delimiters_file"); + auto readFile = std::make_shared(path); + + auto serDeOptions = dwio::common::SerDeOptions('\t', '|', '#', '\\', true); + auto readerOptions = dwio::common::ReaderOptions(pool()); + readerOptions.setFileSchema(type); + readerOptions.setSerDeOptions(serDeOptions); + + auto input = + std::make_unique(readFile, poolRef()); + auto reader = factory->createReader(std::move(input), readerOptions); + + auto spec = std::make_shared(""); + spec->addField("ds", 0)->setConstantValue( + BaseVector::createConstant(VARCHAR(), "2023-07-18", 1, pool())); + spec->addField("col_string", 1); + spec->addField("col_map", 2); + + dwio::common::RowReaderOptions rowOptions; + rowOptions.setScanSpec(spec); + rowOptions.select( + std::make_shared( + type, std::vector({"col_string", "col_map"}))); + auto rowReader = reader->createRowReader(rowOptions); + + VectorPtr result; + + ASSERT_EQ(rowReader->next(13, result), 13); + ASSERT_EQ( + *result->type(), + *ROW( + {"ds", "col_string", "col_map"}, + {VARCHAR(), VARCHAR(), MAP(BIGINT(), BOOLEAN())})); + + const vector_size_t length = 13; + const auto keyVector = makeFlatVector( + {1, 111, 22, 22222, 333, 33, 44, 5, 555, 66, 7777, 7, + 777, 8888, 88, 9, 99, 10, 10000, 111, 1, 11, 11122222, 123142}); + const auto valueVector = makeFlatVector( + {false, true, true, false, false, true, false, true, + false, true, false, true, false, false, true, true, + false, true, false, true, false, true, true, true}); + BufferPtr sizes = facebook::velox::allocateOffsets(length, pool()); + BufferPtr offsets = facebook::velox::allocateOffsets(length, pool()); + auto rawSizes = sizes->asMutable(); + auto rawOffsets = offsets->asMutable(); + rawSizes[0] = 2; + rawSizes[1] = 2; + rawSizes[2] = 2; + rawSizes[3] = 1; + rawSizes[4] = 2; + rawSizes[5] = 1; + rawSizes[6] = 3; + rawSizes[7] = 2; + rawSizes[8] = 2; + rawSizes[9] = 2; + rawSizes[10] = 3; + rawSizes[11] = 1; + rawSizes[12] = 1; + for (int i = 1; i < length; i++) { + rawOffsets[i] = rawOffsets[i - 1] + rawSizes[i - 1]; + } + + auto expected = makeRowVector({ + std::make_shared>( + pool(), 13, false, VARCHAR(), "2023-07-18"), + makeFlatVector( + {"FOO", + "FOO", + "FOO", + "FOO", + "BAR", + "BAR", + "BAR", + "BAR", + "BAZ", + "BAZ", + "BAZ", + "FOO\\nBAZ", + "FOO\n\nBAR\nBAZ"}), + std::make_shared( + pool(), + MAP(keyVector->type(), valueVector->type()), + nullptr, + length, + offsets, + sizes, + keyVector, + valueVector), + }); + + ASSERT_EQ(result->size(), expected->size()); + for (int i = 0; i < 13; ++i) { + ASSERT_TRUE(result->equalValueAt(expected.get(), i, i)); + } +} + +TEST_F(TextReaderTest, projectPrimitiveTypes) { + auto type = ROW( + {{"col_int", INTEGER()}, + {"col_huge", BIGINT()}, + {"col_tiny", TINYINT()}, + {"col_double", DOUBLE()}}); + auto factory = dwio::common::getReaderFactory(dwio::common::FileFormat::TEXT); + auto path = velox::test::getDataFilePath( + "velox/dwio/text/tests/reader/", "examples/simple_types"); + auto readFile = std::make_shared(path); + auto readerOptions = dwio::common::ReaderOptions(pool()); + readerOptions.setFileSchema(type); + auto input = + std::make_unique(readFile, poolRef()); + auto reader = factory->createReader(std::move(input), readerOptions); + + // Test projection of multiple primitive types: VARCHAR, INTEGER, and BOOLEAN + auto spec = std::make_shared(""); + spec->addField("col_tiny", 0); + spec->addField("col_int", 1); + spec->addField("col_double", 2); + + dwio::common::RowReaderOptions rowOptions; + rowOptions.setScanSpec(spec); + rowOptions.select( + std::make_shared( + type, + std::vector({"col_tiny", "col_int", "col_double"}))); + auto rowReader = reader->createRowReader(rowOptions); + + VectorPtr result; + ASSERT_EQ(rowReader->next(20, result), 16); + ASSERT_EQ( + *result->type(), + *ROW( + {"col_tiny", "col_int", "col_double"}, + {TINYINT(), INTEGER(), DOUBLE()})); + + auto expected = makeRowVector({ + makeFlatVector({0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}), + makeFlatVector( + {11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26}), + makeNullableFlatVector( + {1.123, + 2.333, + -6.1, + 4.2, + 0.0000513, + 79.5, + std::nullopt, + 3.1415926, + 93.12, + -221.145, + std::nullopt, + 950.2, + -4123.11, + 43.66, + std::nullopt, + std::nullopt}), + }); + + ASSERT_EQ(result->size(), expected->size()); + for (int i = 0; i < 16; ++i) { + ASSERT_TRUE(result->equalValueAt(expected.get(), i, i)); + } +} + +TEST_F(TextReaderTest, projectColumns) { + auto type = ROW( + {{"col_string", VARCHAR()}, + {"col_int", INTEGER()}, + {"col_float", DOUBLE()}, + {"col_bool", BOOLEAN()}}); + auto factory = dwio::common::getReaderFactory(dwio::common::FileFormat::TEXT); + auto path = velox::test::getDataFilePath( + "velox/dwio/text/tests/reader/", + "examples/simple_types_compressed_file.gz"); + auto readFile = std::make_shared(path); + auto readerOptions = dwio::common::ReaderOptions(pool()); + readerOptions.setFileSchema(type); + auto input = + std::make_unique(readFile, poolRef()); + auto reader = factory->createReader(std::move(input), readerOptions); + auto spec = std::make_shared(""); + spec->addField("ds", 0)->setConstantValue( + BaseVector::createConstant(VARCHAR(), "2023-07-18", 1, pool())); + spec->addField("col_float", 1); + dwio::common::RowReaderOptions rowOptions; + rowOptions.setScanSpec(spec); + rowOptions.select( + std::make_shared( + type, std::vector({"col_float"}))); + auto rowReader = reader->createRowReader(rowOptions); + VectorPtr result; + ASSERT_EQ(rowReader->next(10, result), 10); + ASSERT_EQ(*result->type(), *ROW({"ds", "col_float"}, {VARCHAR(), DOUBLE()})); + auto expected = makeRowVector({ + std::make_shared>( + pool(), 10, false, VARCHAR(), "2023-07-18"), + makeFlatVector( + {1.123, + 2.333, + -6.1, + 4.2, + 47.2, + 79.5, + 3.1415926, + -221.145, + 93.12, + -4123.11, + 950.2, + 43.66}), + }); + ASSERT_EQ(result->size(), expected->size()); + for (int i = 0; i < 10; ++i) { + ASSERT_TRUE(result->equalValueAt(expected.get(), i, i)); + } +} + +TEST_F(TextReaderTest, projectNone) { + // Tests the case where none of the columns are projected, e.g. a basic + // count(*) query. + auto type = ROW( + {{"col_int", INTEGER()}, + {"col_big_int", BIGINT()}, + {"col_tiny_int", TINYINT()}, + {"col_double", DOUBLE()}}); + auto factory = dwio::common::getReaderFactory(dwio::common::FileFormat::TEXT); + + auto path = velox::test::getDataFilePath( + "velox/dwio/text/tests/reader/", "examples/simple_types"); + auto readFile = std::make_shared(path); + + auto readerOptions = dwio::common::ReaderOptions(pool()); + readerOptions.setFileSchema(type); + + dwio::common::RowReaderOptions rowReaderOptions; + // Project none of the columns. + setScanSpec(*ROW({}, {}), rowReaderOptions); + rowReaderOptions.select( + std::make_shared(ROW({}, {}))); + + auto input = + std::make_unique(readFile, poolRef()); + auto reader = factory->createReader(std::move(input), readerOptions); + auto rowReader = reader->createRowReader(rowReaderOptions); + + VectorPtr result; + // We expect to get 16 rows. + ASSERT_EQ(rowReader->next(16, result), 16); + ASSERT_EQ(rowReader->next(10, result), 0); +} + +TEST_F(TextReaderTest, compressedProjectNone) { + // Tests the case where none of the columns are projected, e.g. a basic + // count(*) query. + auto type = ROW( + {{"col_string", VARCHAR()}, + {"col_int", INTEGER()}, + {"col_float", DOUBLE()}, + {"col_bool", BOOLEAN()}}); + auto factory = dwio::common::getReaderFactory(dwio::common::FileFormat::TEXT); + + auto path = velox::test::getDataFilePath( + "velox/dwio/text/tests/reader/", + "examples/simple_types_compressed_file.gz"); + auto readFile = std::make_shared(path); + + auto readerOptions = dwio::common::ReaderOptions(pool()); + readerOptions.setFileSchema(type); + + dwio::common::RowReaderOptions rowReaderOptions; + // Project none of the columns. + setScanSpec(*ROW({}, {}), rowReaderOptions); + rowReaderOptions.select( + std::make_shared(ROW({}, {}))); + + auto input = + std::make_unique(readFile, poolRef()); + auto reader = factory->createReader(std::move(input), readerOptions); + auto rowReader = reader->createRowReader(rowReaderOptions); + + VectorPtr result; + // We expect to get 12 rows. + ASSERT_EQ(rowReader->next(12, result), 12); + ASSERT_EQ(rowReader->next(10, result), 0); +} + +TEST_F(TextReaderTest, compressedFilter) { + auto type = ROW( + {{"col_string", VARCHAR()}, + {"col_int", INTEGER()}, + {"col_float", DOUBLE()}, + {"col_bool", BOOLEAN()}}); + auto factory = dwio::common::getReaderFactory(dwio::common::FileFormat::TEXT); + auto path = velox::test::getDataFilePath( + "velox/dwio/text/tests/reader/", + "examples/simple_types_compressed_file.gz"); + auto readFile = std::make_shared(path); + auto readerOptions = dwio::common::ReaderOptions(pool()); + readerOptions.setFileSchema(type); + auto input = + std::make_unique(readFile, poolRef()); + auto reader = factory->createReader(std::move(input), readerOptions); + auto spec = std::make_shared(""); + spec->addField("ds", 0)->setConstantValue( + BaseVector::createConstant(VARCHAR(), "2023-07-18", 1, pool())); + spec->addField("col_int", 1); + spec->getOrCreateChild(common::Subfield("col_string")) + ->setFilter( + std::make_unique( + std::vector({"BAR"}), false)); + dwio::common::RowReaderOptions rowOptions; + rowOptions.setScanSpec(spec); + rowOptions.select( + std::make_shared(type, type->names())); + auto rowReader = reader->createRowReader(rowOptions); + VectorPtr result; + ASSERT_EQ(rowReader->next(10, result), 10); + ASSERT_EQ(*result->type(), *ROW({"ds", "col_int"}, {VARCHAR(), INTEGER()})); + auto expected = makeRowVector({ + std::make_shared>( + pool(), 4, false, VARCHAR(), "2023-07-18"), + makeFlatVector({5, 6, 7, 8}), + }); + ASSERT_EQ(result->size(), expected->size()); + for (int i = 0; i < expected->size(); ++i) { + ASSERT_TRUE(result->equalValueAt(expected.get(), i, i)); + } +} + +TEST_F(TextReaderTest, filter) { + auto type = ROW( + {{"col_string", VARCHAR()}, + {"col_big_int", BIGINT()}, + {"col_bool", BOOLEAN()}, + {"col_timestamp", TIMESTAMP()}}); + + auto factory = dwio::common::getReaderFactory(dwio::common::FileFormat::TEXT); + auto path = velox::test::getDataFilePath( + "velox/dwio/text/tests/reader/", "examples/more_simple_types"); + auto readFile = std::make_shared(path); + + auto readerOptions = dwio::common::ReaderOptions(pool()); + readerOptions.setFileSchema(type); + + auto input = + std::make_unique(readFile, poolRef()); + auto reader = factory->createReader(std::move(input), readerOptions); + + auto spec = std::make_shared(""); + spec->addField("ds", 0)->setConstantValue( + BaseVector::createConstant(VARCHAR(), "2023-07-18", 1, pool())); + spec->addField("col_big_int", 1); + spec->getOrCreateChild(common::Subfield("col_string")) + ->setFilter( + std::make_unique( + std::vector({"BAR", "BAZ"}), false)); + + dwio::common::RowReaderOptions rowOptions; + rowOptions.setScanSpec(spec); + rowOptions.select( + std::make_shared(type, type->names())); + + auto rowReader = reader->createRowReader(rowOptions); + VectorPtr result; + + ASSERT_EQ(rowReader->next(15, result), 13); + + ASSERT_EQ( + *result->type(), *ROW({"ds", "col_big_int"}, {VARCHAR(), BIGINT()})); + auto expected = makeRowVector({ + std::make_shared>( + pool(), 7, false, VARCHAR(), "2023-07-18"), + makeFlatVector({ + 4192, + 4193, + 4192, + 4192, + 4194, + 4192, + 4195, + }), + }); + + ASSERT_EQ(result->size(), expected->size()); + for (int i = 0; i < expected->size(); ++i) { + ASSERT_TRUE(result->equalValueAt(expected.get(), i, i)); + } +} + +TEST_F(TextReaderTest, shrinkBatch) { + auto type = ROW( + {{"col_string", VARCHAR()}, + {"col_int", INTEGER()}, + {"col_float", DOUBLE()}, + {"col_bool", BOOLEAN()}}); + auto factory = dwio::common::getReaderFactory(dwio::common::FileFormat::TEXT); + auto path = velox::test::getDataFilePath( + "velox/dwio/text/tests/reader/", "examples/simple_types"); + auto readFile = std::make_shared(path); + auto readerOptions = dwio::common::ReaderOptions(pool()); + readerOptions.setFileSchema(type); + auto input = + std::make_unique(readFile, poolRef()); + auto reader = factory->createReader(std::move(input), readerOptions); + auto spec = std::make_shared(""); + dwio::common::RowReaderOptions rowOptions; + rowOptions.setScanSpec(spec); + rowOptions.select( + std::make_shared(ROW({}, {}))); + auto rowReader = reader->createRowReader(rowOptions); + VectorPtr result; + + ASSERT_EQ(rowReader->next(6, result), 6); + ASSERT_EQ(result->size(), 6); + ASSERT_EQ(rowReader->next(4, result), 4); + ASSERT_EQ(result->size(), 4); + ASSERT_EQ(rowReader->next(6, result), 6); + ASSERT_EQ(result->size(), 6); + ASSERT_EQ(rowReader->next(4, result), 0); +} + +TEST_F(TextReaderTest, compressedShrinkBatch) { + auto type = ROW( + {{"col_string", VARCHAR()}, + {"col_int", INTEGER()}, + {"col_float", DOUBLE()}, + {"col_bool", BOOLEAN()}}); + auto factory = dwio::common::getReaderFactory(dwio::common::FileFormat::TEXT); + auto path = velox::test::getDataFilePath( + "velox/dwio/text/tests/reader/", + "examples/simple_types_compressed_file.gz"); + auto readFile = std::make_shared(path); + auto readerOptions = dwio::common::ReaderOptions(pool()); + readerOptions.setFileSchema(type); + auto input = + std::make_unique(readFile, poolRef()); + auto reader = factory->createReader(std::move(input), readerOptions); + auto spec = std::make_shared(""); + dwio::common::RowReaderOptions rowOptions; + rowOptions.setScanSpec(spec); + rowOptions.select( + std::make_shared(ROW({}, {}))); + auto rowReader = reader->createRowReader(rowOptions); + VectorPtr result; + ASSERT_EQ(rowReader->next(6, result), 6); + ASSERT_EQ(result->size(), 6); + ASSERT_EQ(rowReader->next(4, result), 4); + ASSERT_EQ(result->size(), 4); + ASSERT_EQ(rowReader->next(4, result), 2); + ASSERT_EQ(result->size(), 2); + ASSERT_EQ(rowReader->next(4, result), 0); +} + +TEST_F(TextReaderTest, emptyFile) { + auto type = ROW({ + {"transaction_id", VARCHAR()}, + {"serial_number", VARCHAR()}, + }); + auto factory = dwio::common::getReaderFactory(dwio::common::FileFormat::TEXT); + auto path = velox::test::getDataFilePath( + "velox/dwio/text/tests/reader/", "examples/empty.gz"); + auto readFile = std::make_shared(path); + auto readerOptions = dwio::common::ReaderOptions(pool()); + readerOptions.setFileSchema(type); + auto rowReaderOptions = dwio::common::RowReaderOptions(); + setScanSpec(*type, rowReaderOptions); + rowReaderOptions.setSkipRows(1); + auto input = + std::make_unique(readFile, poolRef()); + auto reader = factory->createReader(std::move(input), readerOptions); + auto rowReader = reader->createRowReader(rowReaderOptions); + EXPECT_EQ(*reader->rowType(), *type); + VectorPtr result; + // Try reading 10 rows each time. + ASSERT_EQ(rowReader->next(10, result), 0); +} + +TEST_F(TextReaderTest, readRanges) { + auto expected = makeRowVector( + {makeFlatVector( + {11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25}), + makeFlatVector({ + 4191, + 4192, + 4192, + 4196, + 4192, + 4193, + 4192, + 4192, + 4194, + 4192, + 4195, + 4192, + 4192, + 4192, + 4192, + }), + makeFlatVector({0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0})}); + + auto type = + ROW({{"id", BIGINT()}, {"org_id", BIGINT()}, {"deleted", TINYINT()}}); + auto factory = dwio::common::getReaderFactory(dwio::common::FileFormat::TEXT); + + auto path = velox::test::getDataFilePath( + "velox/dwio/text/tests/reader/", + "examples/simple_types_10_bytes_per_row"); + auto readFile = std::make_shared(path); + + auto readerOptions = dwio::common::ReaderOptions(pool()); + readerOptions.setFileSchema(type); + + auto input = + std::make_unique(readFile, poolRef()); + auto reader = factory->createReader(std::move(input), readerOptions); + const int bytesPerRows = 10; + + dwio::common::RowReaderOptions rowReaderOptions; + VectorPtr result; + + // read from 1st row to 6th row + rowReaderOptions.range(0, 5 * bytesPerRows); + setScanSpec(*type, rowReaderOptions); + auto rowReader = reader->createRowReader(rowReaderOptions); + + uint64_t scanned = rowReader->next(1024, result); + EXPECT_EQ(scanned, 6); + for (int i = 0; i < 6; ++i) { + EXPECT_TRUE(result->equalValueAt(expected.get(), i, i)); + } + + // read from 6th row to 10th row + rowReaderOptions.range(5 * bytesPerRows, 5 * bytesPerRows); + setScanSpec(*type, rowReaderOptions); + rowReader = reader->createRowReader(rowReaderOptions); + scanned = rowReader->next(1024, result); + EXPECT_EQ(scanned, 5); + for (int i = 0; i < 5; ++i) { + EXPECT_TRUE(result->equalValueAt(expected.get(), i, 6 + i)); + } + + // read from 11th row to 15th row + rowReaderOptions.range(10 * bytesPerRows, 5 * bytesPerRows); + setScanSpec(*type, rowReaderOptions); + rowReader = reader->createRowReader(rowReaderOptions); + scanned = rowReader->next(1024, result); + EXPECT_EQ(scanned, 4); + for (int i = 0; i < 4; ++i) { + EXPECT_TRUE(result->equalValueAt(expected.get(), i, 11 + i)); + } +} + +TEST_F(TextReaderTest, readFloatAsInt) { + auto expected = makeRowVector({ + makeFlatVector( + {11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26}), + makeFlatVector( + {4191, + 4192, + 4192, + 4196, + 4192, + 4193, + 4192, + 4192, + 4194, + 4192, + 4195, + 4192, + 4192, + 4192, + 4192, + 4192, + 4192}), + makeFlatVector({0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}), + makeNullableFlatVector( + {1, + 2, + -6, + 4, + std::nullopt, + 79, + std::nullopt, + 3, + 93, + -221, + std::nullopt, + 950, + -4123, + 43, + std::nullopt, + std::nullopt}), + }); + + auto type = ROW( + {{"id", INTEGER()}, + {"org_id", BIGINT()}, + {"deleted", TINYINT()}, + {"ratio", INTEGER()}}); + + auto factory = dwio::common::getReaderFactory(dwio::common::FileFormat::TEXT); + + auto path = velox::test::getDataFilePath( + "velox/dwio/text/tests/reader/", "examples/simple_types"); + auto readFile = std::make_shared(path); + + auto readerOptions = dwio::common::ReaderOptions(pool()); + readerOptions.setFileSchema(type); + + auto input = + std::make_unique(readFile, poolRef()); + auto reader = factory->createReader(std::move(input), readerOptions); + dwio::common::RowReaderOptions rowReaderOptions; + setScanSpec(*type, rowReaderOptions); + auto rowReader = reader->createRowReader(rowReaderOptions); + + EXPECT_EQ(*reader->rowType(), *type); + + VectorPtr result; + + ASSERT_EQ(rowReader->next(10, result), 10); + for (int i = 0; i < 10; ++i) { + EXPECT_TRUE(result->equalValueAt(expected.get(), i, i)); + } + ASSERT_EQ(rowReader->next(6, result), 6); + for (int i = 0; i < 6; ++i) { + EXPECT_TRUE(result->equalValueAt(expected.get(), i, 10 + i)); + } +} + +TEST_F(TextReaderTest, simpleTypes) { + const auto tzOffsetPST = 28'800; + const auto tzOffsetPDT = 25'200; + + auto expected = makeRowVector({ + makeFlatVector( + {"FOO", + "FOO", + "FOO", + "FOO", + "BAR", + "BAR", + "BAR", + "BAR", + "BAZ", + "BAZ", + "BAZ", + "FOO", + "FOOBARBAZ"}), + makeFlatVector( + {4191, + 4192, + 4192, + 4196, + 4192, + 4193, + 4192, + 4192, + 4194, + 4192, + 4195, + 4192, + 4192}), + makeFlatVector( + {true, + true, + true, + true, + true, + true, + false, + false, + false, + false, + true, + false, + true}), + makeNullableFlatVector( + {Timestamp{1'695'378'095 + tzOffsetPDT, 148'000'000}, + Timestamp{1'695'690'351 + tzOffsetPDT, 0}, + Timestamp{1'695'686'400 + tzOffsetPDT, 0}, + Timestamp{1'695'083'400 + tzOffsetPDT, 0}, + std::nullopt, + Timestamp{1'695'657'091 + tzOffsetPDT, 209'000'000}, + Timestamp{1'695'690'437 + tzOffsetPDT, 469'123'000}, + Timestamp{1'696'540'679 + tzOffsetPDT, 976'000'000}, + Timestamp{1'695'657'171 + tzOffsetPDT, 637'000'000}, + Timestamp{1'695'693'225 + tzOffsetPDT, 745'123'000}, + std::nullopt, + Timestamp{1'695'406'246 + tzOffsetPDT, 0}, + Timestamp{1'699'392'124 + tzOffsetPST, 736'000'000}}), + }); + + auto type = ROW( + {{"col_string", VARCHAR()}, + {"col_big_int", BIGINT()}, + {"col_bool", BOOLEAN()}, + {"col_timestamp", TIMESTAMP()}}); + + auto factory = dwio::common::getReaderFactory(dwio::common::FileFormat::TEXT); + + auto path = velox::test::getDataFilePath( + "velox/dwio/text/tests/reader/", "examples/more_simple_types"); + auto readFile = std::make_shared(path); + + auto readerOptions = dwio::common::ReaderOptions(pool()); + readerOptions.setFileSchema(type); + + auto input = + std::make_unique(readFile, poolRef()); + auto reader = factory->createReader(std::move(input), readerOptions); + dwio::common::RowReaderOptions rowReaderOptions; + setScanSpec(*type, rowReaderOptions); + auto rowReader = reader->createRowReader(rowReaderOptions); + + EXPECT_EQ(*reader->rowType(), *type); + + VectorPtr result; + + ASSERT_EQ(rowReader->next(11, result), 11); + for (int i = 0; i < 11; ++i) { + EXPECT_TRUE(result->equalValueAt(expected.get(), i, i)); + } + ASSERT_EQ(rowReader->next(6, result), 2); + for (int i = 0; i < 2; ++i) { + EXPECT_TRUE(result->equalValueAt(expected.get(), i, 11 + i)); + } +} + +TEST_F(TextReaderTest, primitiveLimitsStressTest) { + // Create expected vectors with 100 min values followed by 100 max values + std::vector tinyintValues; + std::vector integerValues; + std::vector bigintValues; + + // Add 100 minimum values + for (int i = 0; i < 100; ++i) { + tinyintValues.push_back(std::numeric_limits::min()); + integerValues.push_back(std::numeric_limits::min()); + bigintValues.push_back(std::numeric_limits::min()); + } + + // Add 100 maximum values + for (int i = 0; i < 100; ++i) { + tinyintValues.push_back(std::numeric_limits::max()); + integerValues.push_back(std::numeric_limits::max()); + bigintValues.push_back(std::numeric_limits::max()); + } + + auto expected = makeRowVector( + {makeFlatVector(tinyintValues), + makeFlatVector(integerValues), + makeFlatVector(bigintValues)}); + + auto type = ROW( + {{"col_tinyint", TINYINT()}, + {"col_integer", INTEGER()}, + {"col_bigint", BIGINT()}}); + + auto factory = dwio::common::getReaderFactory(dwio::common::FileFormat::TEXT); + + auto path = velox::test::getDataFilePath( + "velox/dwio/text/tests/reader/", "examples/primitive_limits"); + auto readFile = std::make_shared(path); + + auto readerOptions = dwio::common::ReaderOptions(pool()); + auto serDeOptions = dwio::common::SerDeOptions('\t', '=', '|', '\\', true); + readerOptions.setFileSchema(type); + readerOptions.setSerDeOptions(serDeOptions); + + auto input = + std::make_unique(readFile, poolRef()); + auto reader = factory->createReader(std::move(input), readerOptions); + dwio::common::RowReaderOptions rowReaderOptions; + setScanSpec(*type, rowReaderOptions); + auto rowReader = reader->createRowReader(rowReaderOptions); + + EXPECT_EQ(*reader->rowType(), *type); + + VectorPtr result; + + // Test reading all 200 rows at once + ASSERT_EQ(rowReader->next(250, result), 200); + for (int i = 0; i < 200; ++i) { + EXPECT_TRUE(result->equalValueAt(expected.get(), i, i)); + } + ASSERT_EQ(rowReader->next(10, result), 0); + + // Test reading in smaller batches + input = std::make_unique(readFile, poolRef()); + reader = factory->createReader(std::move(input), readerOptions); + rowReader = reader->createRowReader(rowReaderOptions); + + ASSERT_EQ(rowReader->next(50, result), 50); + for (int i = 0; i < 50; ++i) { + EXPECT_TRUE(result->equalValueAt(expected.get(), i, i)); + } + ASSERT_EQ(rowReader->next(75, result), 75); + for (int i = 0; i < 75; ++i) { + EXPECT_TRUE(result->equalValueAt(expected.get(), i, 50 + i)); + } + ASSERT_EQ(rowReader->next(100, result), 75); + for (int i = 0; i < 75; ++i) { + EXPECT_TRUE(result->equalValueAt(expected.get(), i, 125 + i)); + } + ASSERT_EQ(rowReader->next(10, result), 0); +} + +TEST_F(TextReaderTest, DISABLED_nestedComplexTypesWithCustomDelimiters) { + // Inner maps for the arrays + const auto innerMapKeys1 = makeFlatVector({1, 11, 22}); + const auto innerMapValues1 = makeFlatVector({true, false, true}); + const auto innerMapKeys2 = makeFlatVector({33, 44}); + const auto innerMapValues2 = makeFlatVector({false, true}); + const auto innerMapKeys3 = makeFlatVector({55, 66, 77, 88}); + const auto innerMapValues3 = makeFlatVector({true, true, false, true}); + const auto innerMapKeys4 = makeFlatVector(std::vector{99}); + const auto innerMapValues4 = makeFlatVector(std::vector{false}); + const auto innerMapKeys5 = makeFlatVector({100, 200}); + const auto innerMapValues5 = makeFlatVector({true, false}); + const auto innerMapKeys6 = makeFlatVector({300, 400, 500}); + const auto innerMapValues6 = makeFlatVector({false, false, true}); + + // Combine all inner map keys and values + const auto allInnerMapKeys = makeFlatVector( + {1, 11, 22, 33, 44, 55, 66, 77, 88, 99, 100, 200, 300, 400, 500}); + const auto allInnerMapValues = makeFlatVector( + {true, + false, + true, + false, + true, + true, + true, + false, + true, + false, + true, + false, + false, + false, + true}); + + // Create inner maps with proper offsets and sizes + BufferPtr innerMapSizes = allocateOffsets(6, pool()); + BufferPtr innerMapOffsets = allocateOffsets(6, pool()); + auto rawInnerMapSizes = innerMapSizes->asMutable(); + auto rawInnerMapOffsets = innerMapOffsets->asMutable(); + + rawInnerMapSizes[0] = 3; // {1:true, 11:false, 22:true} + rawInnerMapSizes[1] = 2; // {33:false, 44:true} + rawInnerMapSizes[2] = 4; // {55:true, 66:true, 77:false, 88:true} + rawInnerMapSizes[3] = 1; // {99:false} + rawInnerMapSizes[4] = 2; // {100:true, 200:false} + rawInnerMapSizes[5] = 3; // {300:false, 400:false, 500:true} + + rawInnerMapOffsets[0] = 0; + for (int i = 1; i < 6; i++) { + rawInnerMapOffsets[i] = rawInnerMapOffsets[i - 1] + rawInnerMapSizes[i - 1]; + } + + auto innerMapsVector = std::make_shared( + pool(), + MAP(BIGINT(), BOOLEAN()), + nullptr, + 6, + innerMapOffsets, + innerMapSizes, + allInnerMapKeys, + allInnerMapValues); + + // Create arrays containing the inner maps + // Array 1: [innerMap0, innerMap1] (maps at indices 0, 1) + // Array 2: [innerMap2, innerMap3, innerMap4] (maps at indices 2, 3, 4) + // Array 3: [innerMap5] (map at index 5) + auto arrayVector = makeArrayVector({0, 2, 5, 6}, innerMapsVector); + + // Create the outer map keys + const auto outerMapKeys = makeFlatVector({10, 20, 30}); + + // Create the final data structure + auto outerMapOffsets = allocateOffsets(3, pool()); + auto outerMapSizes = allocateOffsets(3, pool()); + auto rawOuterMapOffsets = outerMapOffsets->asMutable(); + auto rawOuterMapSizes = outerMapSizes->asMutable(); + + // Set offsets: [0, 1, 2] + rawOuterMapOffsets[0] = 0; + rawOuterMapOffsets[1] = 1; + rawOuterMapOffsets[2] = 2; + + // Set sizes: [1, 1, 1] (each outer map entry contains 1 array) + rawOuterMapSizes[0] = 1; + rawOuterMapSizes[1] = 1; + rawOuterMapSizes[2] = 1; + + const auto expected = makeRowVector({std::make_shared( + pool(), + MAP(BIGINT(), ARRAY(MAP(BIGINT(), BOOLEAN()))), + nullptr, + 3, + outerMapOffsets, + outerMapSizes, + outerMapKeys, + arrayVector)}); + + const auto type = + ROW({{"col_nested_map", MAP(BIGINT(), ARRAY(MAP(BIGINT(), BOOLEAN())))}}); + auto factory = dwio::common::getReaderFactory(dwio::common::FileFormat::TEXT); + + auto path = velox::test::getDataFilePath( + "velox/dwio/text/tests/reader/", + "examples/custom_delimiter_nested_complex_types"); + auto readFile = std::make_shared(path); + + // Set up custom delimiters: + // - Tab ('\t') for field separation (depth 0) + // - Pipe ('|') for array element separation (depth 1) + // - Hash ('#') for map key-value separation (depth 2) + // - ',' for inner map key-value pair separation (depth 3) + auto serDeOptions = dwio::common::SerDeOptions('\t', '=', '|', '\\', true); + serDeOptions.separators[3] = ','; + serDeOptions.separators[4] = ':'; + auto readerOptions = dwio::common::ReaderOptions(pool()); + readerOptions.setFileSchema(type); + readerOptions.setSerDeOptions(serDeOptions); + + auto input = + std::make_unique(readFile, poolRef()); + auto reader = factory->createReader(std::move(input), readerOptions); + auto rowReaderOptions = dwio::common::RowReaderOptions(); + setScanSpec(*type, rowReaderOptions); + auto rowReader = reader->createRowReader(rowReaderOptions); + + EXPECT_EQ(*reader->rowType(), *type); + + VectorPtr result; + + // Read all 3 rows + ASSERT_EQ(rowReader->next(10, result), 3); + for (int i = 0; i < 3; ++i) { + EXPECT_TRUE(result->equalValueAt(expected.get(), i, i)); + } + ASSERT_EQ(rowReader->next(10, result), 0); +} + +TEST_F(TextReaderTest, nestedArraysWithCustomDelimiters) { + // Create expected nested array structure + // Row 1: [[1,2,3], [4,5], [6,7,8,9]] + // Row 2: [[10,20], [30,40,50], [60]] + // Row 3: [[100], [200,300,400], [500,600]] + + // Create inner arrays for each row + auto innerArraysRow1 = + makeArrayVector({{1, 2, 3}, {4, 5}, {6, 7, 8, 9}}); + auto innerArraysRow2 = + makeArrayVector({{10, 20}, {30, 40, 50}, {60}}); + auto innerArraysRow3 = + makeArrayVector({{100}, {200, 300, 400}, {500, 600}}); + + // Combine all inner arrays into a single vector + auto allInnerArrays = makeArrayVector({ + {1, 2, 3}, + {4, 5}, + {6, 7, 8, 9}, // Row 1 inner arrays + {10, 20}, + {30, 40, 50}, + {60}, // Row 2 inner arrays + {100}, + {200, 300, 400}, + {500, 600} // Row 3 inner arrays + }); + + // Create the outer array structure with proper offsets + // Row 1: uses inner arrays 0, 1, 2 + // Row 2: uses inner arrays 3, 4, 5 + // Row 3: uses inner arrays 6, 7, 8 + auto outerArray = makeArrayVector({0, 3, 6, 9}, allInnerArrays); + + const auto expected = makeRowVector({outerArray}); + + const auto type = ROW({{"col_nested_array", ARRAY(ARRAY(BIGINT()))}}); + auto factory = dwio::common::getReaderFactory(dwio::common::FileFormat::TEXT); + + auto path = velox::test::getDataFilePath( + "velox/dwio/text/tests/reader/", "examples/nested_arrays_file"); + auto readFile = std::make_shared(path); + + // Set up custom delimiters for nested arrays: + // - Tab ('\t') for field separation (depth 0) - not used in single-column + // case + // - Pipe ('|') for outer array element separation (depth 1) + // - Comma (',') for inner array element separation (depth 2) + auto serDeOptions = dwio::common::SerDeOptions('\t', '|', ',', '\\', true); + auto readerOptions = dwio::common::ReaderOptions(pool()); + readerOptions.setFileSchema(type); + readerOptions.setSerDeOptions(serDeOptions); + + auto input = + std::make_unique(readFile, poolRef()); + auto reader = factory->createReader(std::move(input), readerOptions); + auto rowReaderOptions = dwio::common::RowReaderOptions(); + setScanSpec(*type, rowReaderOptions); + auto rowReader = reader->createRowReader(rowReaderOptions); + + EXPECT_EQ(*reader->rowType(), *type); + + VectorPtr result; + + // Read all 3 rows + ASSERT_EQ(rowReader->next(10, result), 3); + for (int i = 0; i < 3; ++i) { + EXPECT_TRUE(result->equalValueAt(expected.get(), i, i)); + } + ASSERT_EQ(rowReader->next(10, result), 0); +} + +TEST_F(TextReaderTest, tripleNestedArraysWithCustomDelimiters) { + // Create expected triple nested array structure + // Row 1: [[[1,2], [3,4]], [[5,6,7], [8,9]], [[10,11,12,13], [14,15]]] + // Row 2: [[[20,21], [22,23,24]], [[25,26]], [[27,28,29], [30,31,32]]] + // Row 3: [[[100,101,102]], [[200,201], [300,301,302,303]], [[400,401,402], + // [500,501]]] + + // Create the innermost arrays (level 3) + auto innermostArrays = makeArrayVector({ + {1, 2}, + {3, 4}, // Row 1, outer array 0 + {5, 6, 7}, + {8, 9}, // Row 1, outer array 1 + {10, 11, 12, 13}, + {14, 15}, // Row 1, outer array 2 + {20, 21}, + {22, 23, 24}, // Row 2, outer array 0 + {25, 26}, // Row 2, outer array 1 + {27, 28, 29}, + {30, 31, 32}, // Row 2, outer array 2 + {100, 101, 102}, // Row 3, outer array 0 + {200, 201}, + {300, 301, 302, 303}, // Row 3, outer array 1 + {400, 401, 402}, + {500, 501} // Row 3, outer array 2 + }); + + // Create middle level arrays (level 2) - each contains innermost arrays + auto middleArrays = makeArrayVector( + { + 0, // Row 1, outer array 0: contains innermost arrays [0,1] + 2, // Row 1, outer array 1: contains innermost arrays [2,3] + 4, // Row 1, outer array 2: contains innermost arrays [4,5] + 6, // Row 2, outer array 0: contains innermost arrays [6,7] + 8, // Row 2, outer array 1: contains innermost arrays [8] + 9, // Row 2, outer array 2: contains innermost arrays [9,10] + 11, // Row 3, outer array 0: contains innermost arrays [11] + 12, // Row 3, outer array 1: contains innermost arrays [12,13] + 14, // Row 3, outer array 2: contains innermost arrays [14,15] + 16 // End marker + }, + innermostArrays); + + // Create outermost arrays (level 1) - each row contains middle arrays + auto outerArray = makeArrayVector( + { + 0, 3, 6, 9 // Row boundaries: Row 1 [0-2], Row 2 [3-5], Row 3 [6-8] + }, + middleArrays); + + const auto expected = makeRowVector({outerArray}); + + const auto type = + ROW({{"col_triple_nested_array", ARRAY(ARRAY(ARRAY(BIGINT())))}}); + auto factory = dwio::common::getReaderFactory(dwio::common::FileFormat::TEXT); + + auto path = velox::test::getDataFilePath( + "velox/dwio/text/tests/reader/", "examples/triple_nested_arrays_file"); + auto readFile = std::make_shared(path); + + // Set up custom delimiters for triple nested arrays: + // - Tab ('\t') for field separation (depth 0) - not used in single-column + // case + // - Pipe ('|') for outermost array element separation (depth 1) + // - Comma (',') for middle array element separation (depth 2) + // - Hash ('#') for innermost array element separation (depth 3) + auto serDeOptions = dwio::common::SerDeOptions('\t', '|', ',', '\\', true); + serDeOptions.separators[3] = '#'; + auto readerOptions = dwio::common::ReaderOptions(pool()); + readerOptions.setFileSchema(type); + readerOptions.setSerDeOptions(serDeOptions); + + auto input = + std::make_unique(readFile, poolRef()); + auto reader = factory->createReader(std::move(input), readerOptions); + auto rowReaderOptions = dwio::common::RowReaderOptions(); + setScanSpec(*type, rowReaderOptions); + auto rowReader = reader->createRowReader(rowReaderOptions); + + EXPECT_EQ(*reader->rowType(), *type); + + VectorPtr result; + + // Read all 3 rows + ASSERT_EQ(rowReader->next(10, result), 3); + for (int i = 0; i < 3; ++i) { + EXPECT_TRUE(result->equalValueAt(expected.get(), i, i)); + } + ASSERT_EQ(rowReader->next(10, result), 0); +} + +TEST_F(TextReaderTest, varbinarySuccessfulDecoding) { + // Test successful Base64 decoding for VARBINARY type + + auto type = ROW({{"col_binary", VARBINARY()}}); + auto factory = dwio::common::getReaderFactory(dwio::common::FileFormat::TEXT); + + auto path = velox::test::getDataFilePath( + "velox/dwio/text/tests/reader/", "examples/varbinary"); + + auto readFile = std::make_shared(path); + + auto readerOptions = dwio::common::ReaderOptions(pool()); + readerOptions.setFileSchema(type); + + auto input = + std::make_unique(readFile, poolRef()); + auto reader = factory->createReader(std::move(input), readerOptions); + + dwio::common::RowReaderOptions rowReaderOptions; + setScanSpec(*type, rowReaderOptions); + auto rowReader = reader->createRowReader(rowReaderOptions); + + EXPECT_EQ(*reader->rowType(), *type); + + VectorPtr result; + ASSERT_EQ(rowReader->next(10, result), 2); + + auto rowVector = std::static_pointer_cast(result); + auto binaryVector = rowVector->childAt(0)->as>(); + + // Verify the decoded binary data + EXPECT_EQ(binaryVector->valueAt(0), StringView("Hello World")); + EXPECT_EQ(binaryVector->valueAt(1), StringView("TestData")); +} + +TEST_F(TextReaderTest, varbinaryUnsuccessfulDecoding) { + // Test unsuccessful Base64 decoding for VARBINARY type + + auto type = ROW({{"col_binary", VARBINARY()}}); + auto factory = dwio::common::getReaderFactory(dwio::common::FileFormat::TEXT); + + auto path = velox::test::getDataFilePath( + "velox/dwio/text/tests/reader/", "examples/varbinary_unsuccessful"); + + auto readFile = std::make_shared(path); + + auto readerOptions = dwio::common::ReaderOptions(pool()); + readerOptions.setFileSchema(type); + + auto input = + std::make_unique(readFile, poolRef()); + auto reader = factory->createReader(std::move(input), readerOptions); + + dwio::common::RowReaderOptions rowReaderOptions; + setScanSpec(*type, rowReaderOptions); + auto rowReader = reader->createRowReader(rowReaderOptions); + + EXPECT_EQ(*reader->rowType(), *type); + + VectorPtr result; + ASSERT_EQ(rowReader->next(10, result), 2); + + auto rowVector = std::static_pointer_cast(result); + auto binaryVector = rowVector->childAt(0)->as>(); + + // Verify the data was copied as-is (fallback behavior) + // When Base64 decoding fails, the original string is copied directly + EXPECT_EQ(binaryVector->valueAt(0), StringView("InvalidBase64!")); + EXPECT_EQ(binaryVector->valueAt(1), StringView("Another@Invalid#String")); +} + +TEST_F(TextReaderTest, logicalTypes) { + auto expected = makeRowVector( + {makeNullableFlatVector( + {0, + 123, + -1234567, + 999999999999999, + std::nullopt, + 4242, + -1, + std::nullopt, + 314159265358979, + 77777, + 100000000000000, + -5432199, + std::nullopt, + 1234, + -999999999999999, + 999999999999999}, + DECIMAL(15, 2)), + makeNullableFlatVector( + {0, + HugeInt::parse("999999999999999999999"), + HugeInt::parse("123456789012345678901234567890"), + HugeInt::parse("-99999999999999999999999999"), + HugeInt::parse("88888888888888888888"), + std::nullopt, + 1, + std::nullopt, + HugeInt::parse("27182818284590452353612"), + HugeInt::parse("-123456789012345678999"), + HugeInt::parse("12345678901234567890123456789012345678"), + 987654321012, + std::nullopt, + 5678, + -123, + HugeInt::parse("99999999999999999999999999999999")}, + DECIMAL(38, 2)), + makeNullableFlatVector( + { + parseDate("1970-01-01"), + parseDate("2024-02-29"), + parseDate("1900-01-01"), + parseDate("2099-12-31"), + parseDate("2001-09-11"), + parseDate("2025-09-10"), + std::nullopt, + std::nullopt, + parseDate("1999-12-31"), + parseDate("2012-12-21"), + parseDate("2200-01-01"), + parseDate("1988-08-08"), + parseDate("1969-07-20"), + parseDate("2000-01-01"), + parseDate("1800-06-15"), + parseDate("2500-12-31"), + }, + DATE())}); + + auto type = + ROW({{"c0", DECIMAL(15, 2)}, {"c1", DECIMAL(38, 2)}, {"c2", DATE()}}); + + auto factory = dwio::common::getReaderFactory(dwio::common::FileFormat::TEXT); + auto path = velox::test::getDataFilePath( + "velox/dwio/text/tests/reader/", "examples/logical_types.gz"); + + auto readFile = std::make_shared(path); + auto readerOptions = dwio::common::ReaderOptions(pool()); + readerOptions.setFileSchema(type); + + auto input = + std::make_unique(readFile, poolRef()); + auto reader = factory->createReader(std::move(input), readerOptions); + dwio::common::RowReaderOptions rowReaderOptions; + setScanSpec(*type, rowReaderOptions); + auto rowReader = reader->createRowReader(rowReaderOptions); + EXPECT_EQ(*reader->rowType(), *type); + + VectorPtr result; + ASSERT_EQ(rowReader->next(10, result), 10); + for (int i = 0; i < 10; ++i) { + EXPECT_TRUE(result->equalValueAt(expected.get(), i, i)); + } + ASSERT_EQ(rowReader->next(10, result), 6); + for (int i = 0; i < 6; ++i) { + EXPECT_TRUE(result->equalValueAt(expected.get(), i, 10 + i)); + } +} + +TEST_F(TextReaderTest, nestedRows) { + auto nestedRowChildren = std::vector{ + makeFlatVector({42, 100, -5, 0, 999}), + makeFlatVector({true, false, true, false, true}), + makeArrayVector( + {{3.14159, 2.71828}, + {2.71828, 1.41421, 0.0}, + {1.41421, -123.456}, + {0.0, 999.999}, + {-123.456, 42.0, 3.14159}})}; + auto nestedRowVector = makeRowVector( + {"nested_int", "nested_bool", "nested_arr_double"}, nestedRowChildren); + + const auto expected = makeRowVector( + {makeFlatVector( + {"hello", "world", "test", "sample", "data"}), + nestedRowVector, + makeFlatVector({false, true, false, true, false})}); + + auto type = ROW( + {{"col_varchar", VARCHAR()}, + {"col_nested_row", + ROW( + {{"nested_int", INTEGER()}, + {"nested_bool", BOOLEAN()}, + {"nested_arr_double", ARRAY(DOUBLE())}})}, + {"col_bool", BOOLEAN()}}); + auto factory = dwio::common::getReaderFactory(dwio::common::FileFormat::TEXT); + + auto path = velox::test::getDataFilePath( + "velox/dwio/text/tests/reader/", "examples/nested_row"); + + auto readFile = std::make_shared(path); + auto serDeOptions = dwio::common::SerDeOptions('&', ',', '#', '\\', true); + + auto readerOptions = dwio::common::ReaderOptions(pool()); + readerOptions.setFileSchema(type); + readerOptions.setSerDeOptions(serDeOptions); + + auto input = + std::make_unique(readFile, poolRef()); + auto reader = factory->createReader(std::move(input), readerOptions); + + dwio::common::RowReaderOptions rowReaderOptions; + setScanSpec(*type, rowReaderOptions); + auto rowReader = reader->createRowReader(rowReaderOptions); + + EXPECT_EQ(*reader->rowType(), *type); + + VectorPtr result; + ASSERT_EQ(rowReader->next(10, result), 5); + + for (int i = 0; i < 5; ++i) { + EXPECT_TRUE(result->equalValueAt(expected.get(), i, i)); + } + ASSERT_EQ(rowReader->next(10, result), 0); +} + +TEST_P(TextReaderDecompressionTest, tests) { + auto [filepath, format] = GetParam(); + auto expected = makeRowVector({ + makeFlatVector( + {"FOO", + "FOO", + "FOO", + "FOO", + "BAR", + "BAR", + "BAR", + "BAR", + "BAZ", + "BAZ", + "BAZ", + "BAZ"}), + makeFlatVector({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}), + makeFlatVector( + {1.123, + 2.333, + -6.1, + 4.2, + 47.2, + 79.5, + 3.1415926, + -221.145, + 93.12, + -4123.11, + 950.2, + 43.66}), + makeFlatVector( + {true, + true, + false, + true, + false, + false, + true, + true, + false, + false, + false, + true}), + }); + + auto type = ROW( + {{"col_string", VARCHAR()}, + {"col_int", INTEGER()}, + {"col_float", DOUBLE()}, + {"col_bool", BOOLEAN()}}); + auto factory = dwio::common::getReaderFactory(dwio::common::FileFormat::TEXT); + + auto path = + velox::test::getDataFilePath("velox/dwio/text/tests/reader/", filepath); + auto readFile = std::make_shared(path); + + auto readerOptions = dwio::common::ReaderOptions(pool()); + readerOptions.setFileSchema(type); + + auto input = + std::make_unique(readFile, poolRef()); + auto reader = factory->createReader(std::move(input), readerOptions); + dwio::common::RowReaderOptions rowReaderOptions; + setScanSpec(*type, rowReaderOptions); + auto rowReader = reader->createRowReader(rowReaderOptions); + + EXPECT_EQ(*reader->rowType(), *type); + + VectorPtr result; + + // Try reading 10 rows each time. + ASSERT_EQ(rowReader->next(10, result), 10); + for (int i = 0; i < 10; ++i) { + EXPECT_TRUE(result->equalValueAt(expected.get(), i, i)); + } + ASSERT_EQ(rowReader->next(10, result), 2); + for (int i = 0; i < 2; ++i) { + EXPECT_TRUE(result->equalValueAt(expected.get(), i, 10 + i)); + } + ASSERT_EQ(rowReader->next(10, result), 0); + + input = std::make_unique(readFile, poolRef()); + reader = factory->createReader(std::move(input), readerOptions); + rowReader = reader->createRowReader(rowReaderOptions); + // Try reading 2, 3, 4, 5 rows at a time. + ASSERT_EQ(rowReader->next(2, result), 2); + for (int i = 0; i < 2; ++i) { + EXPECT_TRUE(result->equalValueAt(expected.get(), i, i)); + } + ASSERT_EQ(rowReader->next(3, result), 3); + for (int i = 0; i < 3; ++i) { + EXPECT_TRUE(result->equalValueAt(expected.get(), i, 2 + i)); + } + ASSERT_EQ(rowReader->next(4, result), 4); + for (int i = 0; i < 4; ++i) { + EXPECT_TRUE(result->equalValueAt(expected.get(), i, 5 + i)); + } + ASSERT_EQ(rowReader->next(5, result), 3); + for (int i = 0; i < 3; ++i) { + EXPECT_TRUE(result->equalValueAt(expected.get(), i, 9 + i)); + } + ASSERT_EQ(rowReader->next(10, result), 0); +} + +std::vector params = { + {"examples/simple_types_compressed_file.deflate", "deflate"}, + {"examples/simple_types_compressed_file.zst", "zstd"}}; + +INSTANTIATE_TEST_SUITE_P( + TextReaderDecompressionTests, + TextReaderDecompressionTest, + testing::ValuesIn(params), + [](const auto& paramInfo) { return paramInfo.param.compression; }); + +TEST_F(TextReaderTest, unsupportedCompressedKind) { + auto type = ROW( + {{"col_string", VARCHAR()}, + {"col_int", INTEGER()}, + {"col_float", DOUBLE()}, + {"col_bool", BOOLEAN()}}); + auto factory = dwio::common::getReaderFactory(dwio::common::FileFormat::TEXT); + const std::string kBaseDir = "velox/dwio/text/tests/reader/"; + std::vector paths = { + getDataFilePath(kBaseDir, "examples/simple_types_compressed_file.lz4"), + getDataFilePath(kBaseDir, "examples/simple_types_compressed_file.lzo"), + getDataFilePath( + kBaseDir, "examples/simple_types_compressed_file.snappy")}; + for (const auto& path : paths) { + auto readFile = std::make_shared(path); + auto readerOptions = dwio::common::ReaderOptions(pool()); + readerOptions.setFileSchema(type); + auto input = + std::make_unique(readFile, poolRef()); + EXPECT_THROW( + factory->createReader(std::move(input), readerOptions), + VeloxRuntimeError); + } +} + +} // namespace + +} // namespace facebook::velox::text diff --git a/velox/dwio/text/tests/reader/examples/custom_delimiter_nested_complex_types b/velox/dwio/text/tests/reader/examples/custom_delimiter_nested_complex_types new file mode 100644 index 000000000000..3e8ef4177cc7 --- /dev/null +++ b/velox/dwio/text/tests/reader/examples/custom_delimiter_nested_complex_types @@ -0,0 +1,3 @@ +10=1:true,11:false,22:true|33:false,44:true +20=55:true,66:true,77:false,88:true|99:false|100:true,200:false +30=300:false,400:false,500:true diff --git a/velox/dwio/text/tests/reader/examples/custom_delimiters_file b/velox/dwio/text/tests/reader/examples/custom_delimiters_file new file mode 100644 index 000000000000..6913c764383f --- /dev/null +++ b/velox/dwio/text/tests/reader/examples/custom_delimiters_file @@ -0,0 +1,13 @@ +FOO 1|11|111 1.123|1.3123 1#false|111#true +FOO 22|22222 2.333|-5512|1.23 22#true|22222#false +FOO 333|33 -6.1|65.777 333#false|33#true +FOO 4444|44 4.2|24|324.11 44#false +BAR 5|555 47.2|213.23 5#true|555#false +BAR 666|66|66 79.5|-44.11 66#true +BAR 7777|7|777 3.1415926|441.124 7777#false|7#true|777#false +BAR 8888|88 -221.145|878.43|-11 8888#false|88#true +BAZ 9|99 93.12|632 9#true|99#false +BAZ 10|10000 -4123.11|-177.1 10#true|10000#false +BAZ 111|1|111 950.2|-4412 111#true|1#false|11#true +FOO\\nBAZ 12|11122222|222 43.66|33121.43 11122222#true +FOO\n\nBAR\nBAZ 13|11133333|333 -42.11|-123.43 123142#true diff --git a/velox/dwio/text/tests/reader/examples/empty.gz b/velox/dwio/text/tests/reader/examples/empty.gz new file mode 100644 index 000000000000..048927fa6500 Binary files /dev/null and b/velox/dwio/text/tests/reader/examples/empty.gz differ diff --git a/velox/dwio/text/tests/reader/examples/logical_types.gz b/velox/dwio/text/tests/reader/examples/logical_types.gz new file mode 100644 index 000000000000..3a20dc8ceeef Binary files /dev/null and b/velox/dwio/text/tests/reader/examples/logical_types.gz differ diff --git a/velox/dwio/text/tests/reader/examples/more_simple_types b/velox/dwio/text/tests/reader/examples/more_simple_types new file mode 100644 index 000000000000..674f64b8725a --- /dev/null +++ b/velox/dwio/text/tests/reader/examples/more_simple_types @@ -0,0 +1,13 @@ +FOO4191true2023-09-22 10:21:35.148 +FOO4192true2023-09-26 01:05:51.000 +FOO4192true2023-09-26 +FOO4196true2023-09-19 00:30:00 +BAR4192true +BAR4193true2023-09-25 15:51:31.209 +BAR4192false2023-09-26 01:07:17.469123 +BAR4192false2023-10-05 21:17:59.976 +BAZ4194false2023-09-25 15:52:51.637 +BAZ4192false2023-09-26 01:53:45.745123456 +BAZ4195TRUE +FOO4192FALSE2023-09-22 18:10:46.000000001 +FOOBARBAZ4192TRUE2023-11-07 21:22:04.736 diff --git a/velox/dwio/text/tests/reader/examples/nested_arrays_file b/velox/dwio/text/tests/reader/examples/nested_arrays_file new file mode 100644 index 000000000000..af9a7919a099 --- /dev/null +++ b/velox/dwio/text/tests/reader/examples/nested_arrays_file @@ -0,0 +1,3 @@ +1,2,3|4,5|6,7,8,9 +10,20|30,40,50|60 +100|200,300,400|500,600 diff --git a/velox/dwio/text/tests/reader/examples/nested_row b/velox/dwio/text/tests/reader/examples/nested_row new file mode 100644 index 000000000000..98c65a885aa7 --- /dev/null +++ b/velox/dwio/text/tests/reader/examples/nested_row @@ -0,0 +1,5 @@ +hello&42,true,3.14159#2.71828&false +world&100,false,2.71828#1.41421#.0&true +test&-5,true,1.41421#-123.456&false +sample&0,false,.0#999.999&true +data&999,true,-123.456#42.0#3.14159&false diff --git a/velox/dwio/text/tests/reader/examples/primitive_limits b/velox/dwio/text/tests/reader/examples/primitive_limits new file mode 100644 index 000000000000..064cd7799ba5 --- /dev/null +++ b/velox/dwio/text/tests/reader/examples/primitive_limits @@ -0,0 +1,200 @@ +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +-128 -2147483648 -9223372036854775808 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 +127 2147483647 9223372036854775807 diff --git a/velox/dwio/text/tests/reader/examples/simple_types b/velox/dwio/text/tests/reader/examples/simple_types new file mode 100644 index 000000000000..6a102f6446bb --- /dev/null +++ b/velox/dwio/text/tests/reader/examples/simple_types @@ -0,0 +1,16 @@ +11419101.123 +12419202.333 +1341920-6.1 +14419604.2 +15419205.13e-05 +164193079.5 +1741920 +18419203.1415926 +194194093.12 +2041920-221.145 +2141950 +2241920950.2 +2341920-4123.11 +244192043.66 +2541920 +264192054.aa diff --git a/velox/dwio/text/tests/reader/examples/simple_types_10_bytes_per_row b/velox/dwio/text/tests/reader/examples/simple_types_10_bytes_per_row new file mode 100644 index 000000000000..bef0c0eafb8e --- /dev/null +++ b/velox/dwio/text/tests/reader/examples/simple_types_10_bytes_per_row @@ -0,0 +1,15 @@ +1141910 +1241920 +1341920 +1441960 +1541920 +1641930 +1741920 +1841920 +1941940 +2041920 +2141950 +2241920 +2341920 +2441920 +2541920 diff --git a/velox/dwio/text/tests/reader/examples/simple_types_compressed_file.deflate b/velox/dwio/text/tests/reader/examples/simple_types_compressed_file.deflate new file mode 100644 index 000000000000..5f34cf5e6e46 --- /dev/null +++ b/velox/dwio/text/tests/reader/examples/simple_types_compressed_file.deflate @@ -0,0 +1,2 @@ +U�I�0 P�m����l%�B�Tʦ�?F5,�<|�m��Eq���B�(R&F[�#��u��u� �T� +��ث�k�$��?;QQc�u�g$���yTF������o \ No newline at end of file diff --git a/velox/dwio/text/tests/reader/examples/simple_types_compressed_file.gz b/velox/dwio/text/tests/reader/examples/simple_types_compressed_file.gz new file mode 100644 index 000000000000..fd1eb227accf Binary files /dev/null and b/velox/dwio/text/tests/reader/examples/simple_types_compressed_file.gz differ diff --git a/velox/dwio/text/tests/reader/examples/simple_types_compressed_file.lz4 b/velox/dwio/text/tests/reader/examples/simple_types_compressed_file.lz4 new file mode 100644 index 000000000000..fa1adc4234d1 Binary files /dev/null and b/velox/dwio/text/tests/reader/examples/simple_types_compressed_file.lz4 differ diff --git a/velox/dwio/text/tests/reader/examples/simple_types_compressed_file.lzo b/velox/dwio/text/tests/reader/examples/simple_types_compressed_file.lzo new file mode 100644 index 000000000000..951ed722424a Binary files /dev/null and b/velox/dwio/text/tests/reader/examples/simple_types_compressed_file.lzo differ diff --git a/velox/dwio/text/tests/reader/examples/simple_types_compressed_file.snappy b/velox/dwio/text/tests/reader/examples/simple_types_compressed_file.snappy new file mode 100644 index 000000000000..3343cdfc53ea Binary files /dev/null and b/velox/dwio/text/tests/reader/examples/simple_types_compressed_file.snappy differ diff --git a/velox/dwio/text/tests/reader/examples/simple_types_compressed_file.zst b/velox/dwio/text/tests/reader/examples/simple_types_compressed_file.zst new file mode 100644 index 000000000000..fbeda9616019 Binary files /dev/null and b/velox/dwio/text/tests/reader/examples/simple_types_compressed_file.zst differ diff --git a/velox/dwio/text/tests/reader/examples/simple_types_with_header b/velox/dwio/text/tests/reader/examples/simple_types_with_header new file mode 100644 index 000000000000..6403c7678819 --- /dev/null +++ b/velox/dwio/text/tests/reader/examples/simple_types_with_header @@ -0,0 +1,14 @@ +nameidpricedate +FOO1 1.123 2023-09-22 10:21:35.148 +FOO2 2.3332023-09-26 01:05:51.000 +FOO3-6.12023-09-26 +FOO44.22023-09-19 00:30:00 +BAR5 47.2  +BAR679.52023-09-25 15:51:31.209 +BAR73.1415926 2023-09-26 01:07:17.469123 +BAR8-221.1452023-10-05 21:17:59.976 +BAZ993.122023-09-25 15:52:51.637 +BAZ10 -4123.11 2023-09-26 01:53:45.745123456 +BAZ11950.2 +BAZ1243.662023-09-22 18:10:46.000000001 +132023-11-07 21:22:04.736 diff --git a/velox/dwio/text/tests/reader/examples/triple_nested_arrays_file b/velox/dwio/text/tests/reader/examples/triple_nested_arrays_file new file mode 100644 index 000000000000..6f0e3e5b9a86 --- /dev/null +++ b/velox/dwio/text/tests/reader/examples/triple_nested_arrays_file @@ -0,0 +1,3 @@ +1#2,3#4|5#6#7,8#9|10#11#12#13,14#15 +20#21,22#23#24|25#26|27#28#29,30#31#32 +100#101#102|200#201,300#301#302#303|400#401#402,500#501 diff --git a/velox/dwio/text/tests/reader/examples/varbinary b/velox/dwio/text/tests/reader/examples/varbinary new file mode 100644 index 000000000000..5888f62049e9 --- /dev/null +++ b/velox/dwio/text/tests/reader/examples/varbinary @@ -0,0 +1,2 @@ +SGVsbG8gV29ybGQ= +VGVzdERhdGE= diff --git a/velox/dwio/text/tests/reader/examples/varbinary_unsuccessful b/velox/dwio/text/tests/reader/examples/varbinary_unsuccessful new file mode 100644 index 000000000000..3bf6b863db2b --- /dev/null +++ b/velox/dwio/text/tests/reader/examples/varbinary_unsuccessful @@ -0,0 +1,2 @@ +InvalidBase64! +Another@Invalid#String diff --git a/velox/dwio/text/tests/writer/BufferedWriterSinkTest.cpp b/velox/dwio/text/tests/writer/BufferedWriterSinkTest.cpp index fb86463f36ea..a7372b7b6a74 100644 --- a/velox/dwio/text/tests/writer/BufferedWriterSinkTest.cpp +++ b/velox/dwio/text/tests/writer/BufferedWriterSinkTest.cpp @@ -15,10 +15,8 @@ */ #include -#include "velox/common/base/Fs.h" #include "velox/common/file/FileSystems.h" #include "velox/dwio/text/tests/writer/FileReaderUtil.h" -#include "velox/dwio/text/writer/TextWriter.h" #include "velox/exec/tests/utils/TempDirectoryPath.h" #include "velox/vector/tests/utils/VectorTestBase.h" @@ -46,30 +44,40 @@ class BufferedWriterSinkTest : public testing::Test, }; TEST_F(BufferedWriterSinkTest, write) { - auto filePath = fs::path( - fmt::format("{}/test_buffered_writer.txt", tempPath_->getPath())); + const auto tempPath = tempPath_->getPath(); + const auto filename = "test_buffered_writer.txt"; + + auto filePath = fs::path(fmt::format("{}/{}", tempPath, filename)); auto sink = std::make_unique( filePath, dwio::common::FileSink::Options{.pool = leafPool_.get()}); + auto bufferedWriterSink = std::make_unique( std::move(sink), rootPool_->addLeafChild("bufferedWriterSinkTest"), 15); + bufferedWriterSink->write("hello world", 10); bufferedWriterSink->write("this is writer", 10); bufferedWriterSink->close(); - std::string result = readFile(filePath); - EXPECT_EQ(result.size(), 20); + + uint64_t result = readFile(tempPath, filename); + EXPECT_EQ(result, 20); } TEST_F(BufferedWriterSinkTest, abort) { - auto filePath = - fs::path(fmt::format("{}/test_buffered_abort.txt", tempPath_->getPath())); + const auto tempPath = tempPath_->getPath(); + const auto filename = "test_buffered_abort.txt"; + + auto filePath = fs::path(fmt::format("{}/{}", tempPath, filename)); auto sink = std::make_unique( filePath, dwio::common::FileSink::Options{.pool = leafPool_.get()}); + auto bufferedWriterSink = std::make_unique( std::move(sink), rootPool_->addLeafChild("bufferedWriterSinkTest"), 15); + bufferedWriterSink->write("hello world", 10); bufferedWriterSink->write("this is writer", 10); bufferedWriterSink->abort(); - std::string result = readFile(filePath); - EXPECT_EQ(result.size(), 10); + + uint64_t result = readFile(tempPath_->getPath(), filename); + EXPECT_EQ(result, 10); } } // namespace facebook::velox::text diff --git a/velox/dwio/text/tests/writer/CMakeLists.txt b/velox/dwio/text/tests/writer/CMakeLists.txt index ec3b2127b132..2d42a84a7f91 100644 --- a/velox/dwio/text/tests/writer/CMakeLists.txt +++ b/velox/dwio/text/tests/writer/CMakeLists.txt @@ -12,21 +12,29 @@ # See the License for the specific language governing permissions and # limitations under the License. -add_executable(velox_text_writer_test - TextWriterTest.cpp BufferedWriterSinkTest.cpp FileReaderUtil.cpp) +add_executable( + velox_text_writer_test + TextWriterTest.cpp + BufferedWriterSinkTest.cpp + FileReaderUtil.cpp +) add_test( NAME velox_text_writer_test COMMAND velox_text_writer_test - WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} +) target_link_libraries( velox_text_writer_test velox_dwio_text_writer + velox_dwio_text_reader velox_dwio_common_test_utils + velox_dwio_text_reader_register + velox_dwio_text_writer_register velox_link_libs - Boost::regex Folly::folly ${TEST_LINK_LIBS} GTest::gtest - fmt::fmt) + fmt::fmt +) diff --git a/velox/dwio/text/tests/writer/FileReaderUtil.cpp b/velox/dwio/text/tests/writer/FileReaderUtil.cpp index db5a62f96498..a7d20e661f90 100644 --- a/velox/dwio/text/tests/writer/FileReaderUtil.cpp +++ b/velox/dwio/text/tests/writer/FileReaderUtil.cpp @@ -15,29 +15,46 @@ */ #include "velox/dwio/text/tests/writer/FileReaderUtil.h" +#include "velox/common/file/File.h" +#include "velox/common/file/FileSystems.h" namespace facebook::velox::text { -std::string readFile(const std::string& name) { - std::ifstream file(name); - std::string line; +using dwio::common::SerDeOptions; - std::stringstream ss; - while (std::getline(file, line)) { - ss << line; - } - return ss.str(); +uint64_t readFile(const std::string& path, const std::string& name) { + const auto fs = filesystems::getFileSystem(path, nullptr); + auto filepath = fs::path(fmt::format("{}/{}", path, name)); + const auto& file = fs->openFileForRead(filepath.string()); + + return file->size(); } -std::vector> parseTextFile(const std::string& name) { - std::ifstream file(name); +std::vector> parseTextFile( + const std::string& path, + const std::string& name, + void* buffer, + SerDeOptions serDeOptions) { + const auto fs = filesystems::getFileSystem(path, nullptr); + auto filepath = fs::path(fmt::format("{}/{}", path, name)); + const auto& file = fs->openFileForRead(filepath.string()); + std::string line; std::vector> table; - while (std::getline(file, line)) { - std::vector row = splitTextLine(line, TextFileTraits::kSOH); - table.push_back(row); + auto fileSize = file->size(); + if (fileSize > 0) { + file->pread(0, fileSize, buffer); + std::string content(static_cast(buffer), fileSize); + + std::istringstream stream(content); + while (std::getline(stream, line)) { + std::vector row = + splitTextLine(line, serDeOptions.separators[0]); + table.push_back(row); + } } + return table; } diff --git a/velox/dwio/text/tests/writer/FileReaderUtil.h b/velox/dwio/text/tests/writer/FileReaderUtil.h index 42113c1a4bcd..65de295ef4d7 100644 --- a/velox/dwio/text/tests/writer/FileReaderUtil.h +++ b/velox/dwio/text/tests/writer/FileReaderUtil.h @@ -19,7 +19,16 @@ #include "velox/dwio/text/writer/TextWriter.h" namespace facebook::velox::text { -std::string readFile(const std::string& name); -std::vector> parseTextFile(const std::string& name); + +using dwio::common::SerDeOptions; + +uint64_t readFile(const std::string& filepath, const std::string& name); + +std::vector> parseTextFile( + const std::string& path, + const std::string& name, + void* buffer, + SerDeOptions serDeOptions = SerDeOptions()); + std::vector splitTextLine(const std::string& str, char delimiter); } // namespace facebook::velox::text diff --git a/velox/dwio/text/tests/writer/TextWriterTest.cpp b/velox/dwio/text/tests/writer/TextWriterTest.cpp index b1bc16bb243a..d007c9323624 100644 --- a/velox/dwio/text/tests/writer/TextWriterTest.cpp +++ b/velox/dwio/text/tests/writer/TextWriterTest.cpp @@ -15,31 +15,51 @@ */ #include "velox/dwio/text/writer/TextWriter.h" -#include -#include "velox/common/base/Fs.h" +#include "velox/buffer/Buffer.h" #include "velox/common/file/FileSystems.h" +#include "velox/dwio/text/RegisterTextReader.h" +#include "velox/dwio/text/RegisterTextWriter.h" #include "velox/dwio/text/tests/writer/FileReaderUtil.h" #include "velox/exec/tests/utils/TempDirectoryPath.h" #include "velox/vector/tests/utils/VectorTestBase.h" +#include + +/// TODO: Add fuzzer test. + namespace facebook::velox::text { -// TODO: add fuzzer test once text reader is move in OSS class TextWriterTest : public testing::Test, public velox::test::VectorTestBase { public: void SetUp() override { velox::filesystems::registerLocalFileSystem(); - dwio::common::LocalFileSink::registerFactory(); + registerTextWriterFactory(); + registerTextReaderFactory(); rootPool_ = memory::memoryManager()->addRootPool("TextWriterTests"); leafPool_ = rootPool_->addLeafChild("TextWriterTests"); tempPath_ = exec::test::TempDirectoryPath::create(); } + void TearDown() override { + unregisterTextWriterFactory(); + unregisterTextReaderFactory(); + } + protected: static void SetUpTestCase() { memory::MemoryManager::testingSetInstance(memory::MemoryManager::Options{}); } + void setScanSpec(const Type& type, dwio::common::RowReaderOptions& options) { + auto spec = std::make_shared("root"); + spec->addAllChildFields(type); + options.setScanSpec(spec); + } + + memory::MemoryPool& poolRef() { + return *pool(); + } + constexpr static float kInf = std::numeric_limits::infinity(); constexpr static double kNaN = std::numeric_limits::quiet_NaN(); std::shared_ptr rootPool_; @@ -78,8 +98,10 @@ TEST_F(TextWriterTest, write) { WriterOptions writerOptions; writerOptions.memoryPool = rootPool_.get(); - auto filePath = - fs::path(fmt::format("{}/test_text_writer.txt", tempPath_->getPath())); + + const auto tempPath = tempPath_->getPath(); + const auto filename = "test_text_writer.txt"; + auto filePath = fs::path(fmt::format("{}/{}", tempPath, filename)); auto sink = std::make_unique( filePath, dwio::common::FileSink::Options{.pool = leafPool_.get()}); auto writer = std::make_unique( @@ -89,9 +111,18 @@ TEST_F(TextWriterTest, write) { writer->write(data); writer->close(); - std::vector> result = parseTextFile(filePath); + const auto fs = filesystems::getFileSystem(tempPath, nullptr); + const auto& file = fs->openFileForRead(filePath.string()); + auto fileSize = file->size(); + + BufferPtr charBuf = AlignedBuffer::allocate(fileSize, pool()); + auto rawCharBuf = charBuf->asMutable(); + std::vector> result = + parseTextFile(tempPath, filename, rawCharBuf); + EXPECT_EQ(result.size(), 3); EXPECT_EQ(result[0].size(), 10); + // bool type EXPECT_EQ(result[0][0], "true"); EXPECT_EQ(result[1][0], "true"); @@ -128,9 +159,9 @@ TEST_F(TextWriterTest, write) { EXPECT_EQ(result[2][6], "3.100000"); // timestamp - EXPECT_EQ(result[0][7], "1970-01-01 00:00:00.000"); - EXPECT_EQ(result[1][7], "1970-01-01 00:00:01.001"); - EXPECT_EQ(result[2][7], "1970-01-01 00:00:02.002"); + EXPECT_EQ(result[0][7], "1969-12-31 16:00:00.000"); + EXPECT_EQ(result[1][7], "1969-12-31 16:00:01.001"); + EXPECT_EQ(result[2][7], "1969-12-31 16:00:02.002"); // varchar EXPECT_EQ(result[0][8], "hello"); @@ -143,6 +174,1088 @@ TEST_F(TextWriterTest, write) { EXPECT_EQ(result[2][9], "Y3Bw"); } +TEST_F(TextWriterTest, verifyWriteWithTextReader) { + auto schema = + ROW({"c0", "c1", "c2", "c3", "c4", "c5", "c6", "c7", "c8", "c9"}, + { + TIMESTAMP(), + BOOLEAN(), + TINYINT(), + SMALLINT(), + INTEGER(), + BIGINT(), + REAL(), + DOUBLE(), + VARCHAR(), + VARBINARY(), + }); + auto data = makeRowVector( + {"c0", "c1", "c2", "c3", "c4", "c5", "c6", "c7", "c8", "c9"}, + { + makeFlatVector( + 3, [](auto i) { return Timestamp(i, i * 1'000'000); }), + makeConstant(true, 3), + makeFlatVector({1, 2, 3}), + makeFlatVector({1, 2, 3}), // TODO null + makeFlatVector({1, 2, 3}), + makeFlatVector({1, 2, 3}), + makeFlatVector({1.1, kInf, 3.1}), + makeFlatVector({1.1, kNaN, 3.1}), + makeFlatVector({"hello", "world", "cpp"}, VARCHAR()), + makeFlatVector({"hello", "world", "cpp"}, VARBINARY()), + }); + + WriterOptions writerOptions; + writerOptions.memoryPool = rootPool_.get(); + + const auto tempPath = tempPath_->getPath(); + const auto filename = "test_text_writer.txt"; + auto filePath = fs::path(fmt::format("{}/{}", tempPath, filename)); + auto sink = std::make_unique( + filePath.string(), + dwio::common::FileSink::Options{.pool = leafPool_.get()}); + auto writer = std::make_unique( + schema, + std::move(sink), + std::make_shared(writerOptions)); + writer->write(data); + writer->close(); + + // Set up reader. + auto readerFactory = + dwio::common::getReaderFactory(dwio::common::FileFormat::TEXT); + auto readFile = std::make_shared(filePath.string()); + auto readerOptions = dwio::common::ReaderOptions(pool()); + readerOptions.setFileSchema(schema); + auto input = + std::make_unique(readFile, poolRef()); + auto reader = readerFactory->createReader(std::move(input), readerOptions); + dwio::common::RowReaderOptions rowReaderOptions; + setScanSpec(*schema, rowReaderOptions); + auto rowReader = reader->createRowReader(rowReaderOptions); + + EXPECT_EQ(*reader->rowType(), *schema); + + VectorPtr result; + + ASSERT_EQ(rowReader->next(3, result), 3); + for (int i = 0; i < 3; ++i) { + EXPECT_TRUE(result->equalValueAt(data.get(), i, i)); + } +} + +TEST_F(TextWriterTest, mapAndArrayComplexTypes) { + const vector_size_t length = 13; + const auto keyVector = makeFlatVector( + {1, 111, 22, 22222, 333, 33, 44, 5, 555, 66, 7777, 7, + 777, 8888, 88, 9, 99, 10, 10000, 111, 1, 11, 11122222, 123142}); + const auto valueVector = makeFlatVector( + {false, true, true, false, false, true, false, true, + false, true, false, true, false, false, true, true, + false, true, false, true, false, true, true, true}); + BufferPtr sizes = facebook::velox::allocateOffsets(length, pool()); + BufferPtr offsets = facebook::velox::allocateOffsets(length, pool()); + auto rawSizes = sizes->asMutable(); + auto rawOffsets = offsets->asMutable(); + rawSizes[0] = 2; + rawSizes[1] = 2; + rawSizes[2] = 2; + rawSizes[3] = 1; + rawSizes[4] = 2; + rawSizes[5] = 1; + rawSizes[6] = 3; + rawSizes[7] = 2; + rawSizes[8] = 2; + rawSizes[9] = 2; + rawSizes[10] = 3; + rawSizes[11] = 1; + rawSizes[12] = 1; + for (int i = 1; i < length; i++) { + rawOffsets[i] = rawOffsets[i - 1] + rawSizes[i - 1]; + } + + const auto data = makeRowVector( + {makeArrayVector( + {{1, 11, 111}, + {22, 22222}, + {333, 33}, + {4444, 44}, + {5, 555}, + {666, 66, 66}, + {7777, 7, 777}, + {8888, 88}, + {9, 99}, + {10, 10000}, + {111, 1, 111}, + {12, 11122222, 222}, + {13, 11133333, 333}}), + makeArrayVector( + {{1.123, 1.3123}, + {2.333, -5512, 1.23}, + {-6.1, 65.777}, + {4.2, 24, 324.11}, + {47.2, 213.23}, + {79.5, -44.11}, + {3.1415926, 441.124}, + {-221.145, 878.43, -11}, + {93.12, 632}, + {-4123.11, -177.1}, + {950.2, -4412}, + {43.66, 33121.43}, + {-42.11, -123.43}}), + std::make_shared( + pool(), + MAP(keyVector->type(), valueVector->type()), + nullptr, + length, + offsets, + sizes, + keyVector, + valueVector)}); + + const auto schema = ROW( + {{"col_bigint_arr", ARRAY(BIGINT())}, + {"col_double_arr", ARRAY(DOUBLE())}, + {"col_map", MAP(BIGINT(), BOOLEAN())}}); + + WriterOptions writerOptions; + writerOptions.memoryPool = rootPool_.get(); + const auto tempPath = tempPath_->getPath(); + const auto filename = "test_text_writer.txt"; + auto filePath = fs::path(fmt::format("{}/{}", tempPath, filename)); + + auto sink = std::make_unique( + filePath.string(), + dwio::common::FileSink::Options{.pool = leafPool_.get()}); + + // use traits to specify delimiters when it is not nested + const auto serDeOptions = dwio::common::SerDeOptions('\x01', '|', '#'); + auto writer = std::make_unique( + schema, + std::move(sink), + std::make_shared(writerOptions), + serDeOptions); + writer->write(data); + writer->close(); + + const auto fs = filesystems::getFileSystem(tempPath, nullptr); + const auto& file = fs->openFileForRead(filePath.string()); + auto fileSize = file->size(); + + BufferPtr charBuf = AlignedBuffer::allocate(fileSize, pool()); + auto rawCharBuf = charBuf->asMutable(); + std::vector> result = + parseTextFile(tempPath, filename, rawCharBuf, serDeOptions); + + EXPECT_EQ(result.size(), 13); + EXPECT_EQ(result[0].size(), 3); + + // col_bigint_arr + EXPECT_EQ(result[0][0], "1|11|111"); + EXPECT_EQ(result[1][0], "22|22222"); + EXPECT_EQ(result[2][0], "333|33"); + EXPECT_EQ(result[3][0], "4444|44"); + EXPECT_EQ(result[4][0], "5|555"); + EXPECT_EQ(result[5][0], "666|66|66"); + EXPECT_EQ(result[6][0], "7777|7|777"); + EXPECT_EQ(result[7][0], "8888|88"); + EXPECT_EQ(result[8][0], "9|99"); + EXPECT_EQ(result[9][0], "10|10000"); + EXPECT_EQ(result[10][0], "111|1|111"); + EXPECT_EQ(result[11][0], "12|11122222|222"); + EXPECT_EQ(result[12][0], "13|11133333|333"); + + // col_double_arr + EXPECT_EQ(result[0][1], "1.123000|1.312300"); + EXPECT_EQ(result[1][1], "2.333000|-5512.000000|1.230000"); + EXPECT_EQ(result[2][1], "-6.100000|65.777000"); + EXPECT_EQ(result[3][1], "4.200000|24.000000|324.110000"); + EXPECT_EQ(result[4][1], "47.200000|213.230000"); + EXPECT_EQ(result[5][1], "79.500000|-44.110000"); + EXPECT_EQ(result[6][1], "3.141593|441.124000"); + EXPECT_EQ(result[7][1], "-221.145000|878.430000|-11.000000"); + EXPECT_EQ(result[8][1], "93.120000|632.000000"); + EXPECT_EQ(result[9][1], "-4123.110000|-177.100000"); + EXPECT_EQ(result[10][1], "950.200000|-4412.000000"); + EXPECT_EQ(result[11][1], "43.660000|33121.430000"); + EXPECT_EQ(result[12][1], "-42.110000|-123.430000"); + + // col_map + EXPECT_EQ(result[0][2], "1#false|111#true"); + EXPECT_EQ(result[1][2], "22#true|22222#false"); + EXPECT_EQ(result[2][2], "333#false|33#true"); + EXPECT_EQ(result[3][2], "44#false"); + EXPECT_EQ(result[4][2], "5#true|555#false"); + EXPECT_EQ(result[5][2], "66#true"); + EXPECT_EQ(result[6][2], "7777#false|7#true|777#false"); + EXPECT_EQ(result[7][2], "8888#false|88#true"); + EXPECT_EQ(result[8][2], "9#true|99#false"); + EXPECT_EQ(result[9][2], "10#true|10000#false"); + EXPECT_EQ(result[10][2], "111#true|1#false|11#true"); + EXPECT_EQ(result[11][2], "11122222#true"); + EXPECT_EQ(result[12][2], "123142#true"); +} + +TEST_F(TextWriterTest, verifyMapAndArrayComplexTypesWithTextReader) { + const vector_size_t length = 13; + const auto keyVector = makeFlatVector( + {1, 111, 22, 22222, 333, 33, 44, 5, 555, 66, 7777, 7, + 777, 8888, 88, 9, 99, 10, 10000, 111, 1, 11, 11122222, 123142}); + const auto valueVector = makeFlatVector( + {false, true, true, false, false, true, false, true, + false, true, false, true, false, false, true, true, + false, true, false, true, false, true, true, true}); + BufferPtr sizes = facebook::velox::allocateOffsets(length, pool()); + BufferPtr offsets = facebook::velox::allocateOffsets(length, pool()); + auto rawSizes = sizes->asMutable(); + auto rawOffsets = offsets->asMutable(); + rawSizes[0] = 2; + rawSizes[1] = 2; + rawSizes[2] = 2; + rawSizes[3] = 1; + rawSizes[4] = 2; + rawSizes[5] = 1; + rawSizes[6] = 3; + rawSizes[7] = 2; + rawSizes[8] = 2; + rawSizes[9] = 2; + rawSizes[10] = 3; + rawSizes[11] = 1; + rawSizes[12] = 1; + for (int i = 1; i < length; i++) { + rawOffsets[i] = rawOffsets[i - 1] + rawSizes[i - 1]; + } + + const auto data = makeRowVector( + {makeArrayVector( + {{1, 11, 111}, + {22, 22222}, + {333, 33}, + {4444, 44}, + {5, 555}, + {666, 66, 66}, + {7777, 7, 777}, + {8888, 88}, + {9, 99}, + {10, 10000}, + {111, 1, 111}, + {12, 11122222, 222}, + {13, 11133333, 333}}), + makeArrayVector( + {{1.123, 1.3123}, + {2.333, -5512, 1.23}, + {-6.1, 65.777}, + {4.2, 24, 324.11}, + {47.2, 213.23}, + {79.5, -44.11}, + {3.1415926, 441.124}, + {-221.145, 878.43, -11}, + {93.12, 632}, + {-4123.11, -177.1}, + {950.2, -4412}, + {43.66, 33121.43}, + {-42.11, -123.43}}), + std::make_shared( + pool(), + MAP(keyVector->type(), valueVector->type()), + nullptr, + length, + offsets, + sizes, + keyVector, + valueVector)}); + + const auto schema = ROW( + {{"col_bigint_arr", ARRAY(BIGINT())}, + {"col_double_arr", ARRAY(DOUBLE())}, + {"col_map", MAP(BIGINT(), BOOLEAN())}}); + + WriterOptions writerOptions; + writerOptions.memoryPool = rootPool_.get(); + auto filePath = + fs::path(fmt::format("{}/test_text_writer.txt", tempPath_->getPath())); + + auto sink = std::make_unique( + filePath.string(), + dwio::common::FileSink::Options{.pool = leafPool_.get()}); + + // use traits to specify delimiters when it is not nested + const auto serDeOptions = dwio::common::SerDeOptions('\x01', '|', '#'); + auto writer = std::make_unique( + schema, + std::move(sink), + std::make_shared(writerOptions), + serDeOptions); + writer->write(data); + writer->close(); + + // Set up reader. + auto readerFactory = + dwio::common::getReaderFactory(dwio::common::FileFormat::TEXT); + auto readFile = std::make_shared(filePath.string()); + auto readerOptions = dwio::common::ReaderOptions(pool()); + readerOptions.setFileSchema(schema); + readerOptions.setSerDeOptions(serDeOptions); + auto input = + std::make_unique(readFile, poolRef()); + auto reader = readerFactory->createReader(std::move(input), readerOptions); + dwio::common::RowReaderOptions rowReaderOptions; + setScanSpec(*schema, rowReaderOptions); + auto rowReader = reader->createRowReader(rowReaderOptions); + + // Change the expected value + const auto expected = makeRowVector( + {makeArrayVector( + {{1, 11, 111}, + {22, 22222}, + {333, 33}, + {4444, 44}, + {5, 555}, + {666, 66, 66}, + {7777, 7, 777}, + {8888, 88}, + {9, 99}, + {10, 10000}, + {111, 1, 111}, + {12, 11122222, 222}, + {13, 11133333, 333}}), + makeArrayVector( + {{1.123, 1.3123}, + {2.333, -5512, 1.23}, + {-6.1, 65.777}, + {4.2, 24, 324.11}, + {47.2, 213.23}, + {79.5, -44.11}, + {3.141593, 441.124}, + {-221.145, 878.43, -11}, + {93.12, 632}, + {-4123.11, -177.1}, + {950.2, -4412}, + {43.66, 33121.43}, + {-42.11, -123.43}}), + std::make_shared( + pool(), + MAP(keyVector->type(), valueVector->type()), + nullptr, + length, + offsets, + sizes, + keyVector, + valueVector)}); + + EXPECT_EQ(*reader->rowType(), *schema); + + VectorPtr result; + ASSERT_EQ(rowReader->next(13, result), 13); + for (int i = 0; i < 13; ++i) { + EXPECT_TRUE(result->equalValueAt(expected.get(), i, i)); + } +} + +TEST_F(TextWriterTest, arrayTypes) { + // Test specifically for ARRAY types with various element types + const auto data = makeRowVector({ + makeArrayVector( + {{1, 2, 3}, + {10, 20}, + {100, 200, 300, 400}, + {}, // empty array + {42}}), + makeArrayVector( + {{"hello", "world"}, + {"foo", "bar", "baz"}, + {"single"}, + {}, // empty array + {"test", "array", "string"}}), + makeArrayVector( + {{1.1, 2.2, 3.3}, + {10.5, 20.7}, + {}, // empty array + {99.99}, + {1.0, 2.0, 3.0, 4.0, 5.0}}), + }); + + const auto schema = ROW( + {{"int_array", ARRAY(INTEGER())}, + {"string_array", ARRAY(VARCHAR())}, + {"double_array", ARRAY(DOUBLE())}}); + + WriterOptions writerOptions; + writerOptions.memoryPool = rootPool_.get(); + const auto tempPath = tempPath_->getPath(); + const auto filename = "test_array_writer.txt"; + + auto filePath = fs::path(fmt::format("{}/{}", tempPath, filename)); + + auto sink = std::make_unique( + filePath, dwio::common::FileSink::Options{.pool = leafPool_.get()}); + + const auto serDeOptions = dwio::common::SerDeOptions('\x01', '|', '#'); + auto writer = std::make_unique( + schema, + std::move(sink), + std::make_shared(writerOptions), + serDeOptions); + writer->write(data); + writer->close(); + + const auto fs = filesystems::getFileSystem(tempPath, nullptr); + const auto& file = fs->openFileForRead(filePath.string()); + auto fileSize = file->size(); + + BufferPtr charBuf = AlignedBuffer::allocate(fileSize, pool()); + auto rawCharBuf = charBuf->asMutable(); + std::vector> result = + parseTextFile(tempPath, filename, rawCharBuf, serDeOptions); + + EXPECT_EQ(result.size(), 5); + EXPECT_EQ(result[0].size(), 3); + + // int array type + EXPECT_EQ(result[0][0], "1|2|3"); + EXPECT_EQ(result[1][0], "10|20"); + EXPECT_EQ(result[2][0], "100|200|300|400"); + EXPECT_EQ(result[3][0], ""); + EXPECT_EQ(result[4][0], "42"); + + // varchar array type + EXPECT_EQ(result[0][1], "hello|world"); + EXPECT_EQ(result[1][1], "foo|bar|baz"); + EXPECT_EQ(result[2][1], "single"); + EXPECT_EQ(result[3][1], ""); + EXPECT_EQ(result[4][1], "test|array|string"); + + // double array type + EXPECT_EQ(result[0][2], "1.100000|2.200000|3.300000"); + EXPECT_EQ(result[1][2], "10.500000|20.700000"); + EXPECT_EQ(result[2][2], ""); + EXPECT_EQ(result[3][2], "99.990000"); + EXPECT_EQ(result[4][2], "1.000000|2.000000|3.000000|4.000000|5.000000"); +} + +TEST_F(TextWriterTest, verifyArrayTypesWithTextReader) { + // Test specifically for ARRAY types with various element types + const auto data = makeRowVector({ + makeArrayVector( + {{1, 2, 3}, + {10, 20}, + {100, 200, 300, 400}, + {}, // empty array + {42}}), + makeArrayVector( + {{"hello", "world"}, + {"foo", "bar", "baz"}, + {"single"}, + {}, // empty array + {"test", "array", "string"}}), + makeArrayVector( + {{1.1, 2.2, 3.3}, + {10.5, 20.7}, + {}, // empty array + {99.99}, + {1.0, 2.0, 3.0, 4.0, 5.0}}), + }); + + const auto schema = ROW( + {{"int_array", ARRAY(INTEGER())}, + {"string_array", ARRAY(VARCHAR())}, + {"double_array", ARRAY(DOUBLE())}}); + + WriterOptions writerOptions; + writerOptions.memoryPool = rootPool_.get(); + const auto tempPath = tempPath_->getPath(); + const auto filename = "test_array_writer.txt"; + + auto filePath = fs::path(fmt::format("{}/{}", tempPath, filename)); + + auto sink = std::make_unique( + filePath, dwio::common::FileSink::Options{.pool = leafPool_.get()}); + + const auto serDeOptions = dwio::common::SerDeOptions('\x01', '|', '#'); + auto writer = std::make_unique( + schema, + std::move(sink), + std::make_shared(writerOptions), + serDeOptions); + writer->write(data); + writer->close(); + + // Set up reader. + auto readerFactory = + dwio::common::getReaderFactory(dwio::common::FileFormat::TEXT); + auto readFile = std::make_shared(filePath.string()); + auto readerOptions = dwio::common::ReaderOptions(pool()); + readerOptions.setFileSchema(schema); + readerOptions.setSerDeOptions(serDeOptions); + + auto input = + std::make_unique(readFile, poolRef()); + auto reader = readerFactory->createReader(std::move(input), readerOptions); + dwio::common::RowReaderOptions rowReaderOptions; + setScanSpec(*schema, rowReaderOptions); + + auto rowReader = reader->createRowReader(rowReaderOptions); + + EXPECT_EQ(*reader->rowType(), *schema); + + VectorPtr result; + + ASSERT_EQ(rowReader->next(5, result), 5); + for (int i = 0; i < 5; ++i) { + EXPECT_TRUE(result->equalValueAt(data.get(), i, i)); + } + ASSERT_EQ(rowReader->next(10, result), 0); +} + +TEST_F(TextWriterTest, mapTypes) { + // Test specifically for MAP types with various key-value combinations + const vector_size_t length = 5; + + // Create key and value vectors for the maps + const auto keyVector = + makeFlatVector({1, 2, 3, 10, 20, 100, 42, 99, 123, 456}); + const auto valueVector = makeFlatVector( + {true, false, true, false, true, true, false, true, false, true}); + + // Set up offsets and sizes for each row + BufferPtr sizes = facebook::velox::allocateOffsets(length, pool()); + BufferPtr offsets = facebook::velox::allocateOffsets(length, pool()); + auto rawSizes = sizes->asMutable(); + auto rawOffsets = offsets->asMutable(); + + rawSizes[0] = 3; // Row 0: 3 key-value pairs + rawSizes[1] = 2; // Row 1: 2 key-value pairs + rawSizes[2] = 1; // Row 2: 1 key-value pair + rawSizes[3] = 0; // Row 3: 0 key-value pairs (empty map) + rawSizes[4] = 4; // Row 4: 4 key-value pairs + + rawOffsets[0] = 0; + for (int i = 1; i < length; i++) { + rawOffsets[i] = rawOffsets[i - 1] + rawSizes[i - 1]; + } + + // Create string keys for a second map column + const auto stringKeyVector = makeFlatVector( + {"key1", "key2", "foo", "bar", "baz", "qux", "test"}); + const auto intValueVector = + makeFlatVector({10, 20, 100, 1, 2, 3, 999}); + + // Set up offsets and sizes for string key map + BufferPtr stringSizes = facebook::velox::allocateOffsets(length, pool()); + BufferPtr stringOffsets = facebook::velox::allocateOffsets(length, pool()); + auto rawStringSizes = stringSizes->asMutable(); + auto rawStringOffsets = stringOffsets->asMutable(); + + rawStringSizes[0] = 2; // Row 0: 2 key-value pairs + rawStringSizes[1] = 1; // Row 1: 1 key-value pair + rawStringSizes[2] = 3; // Row 2: 3 key-value pairs + rawStringSizes[3] = 0; // Row 3: 0 key-value pairs (empty map) + rawStringSizes[4] = 1; // Row 4: 1 key-value pair + + rawStringOffsets[0] = 0; + for (int i = 1; i < length; i++) { + rawStringOffsets[i] = rawStringOffsets[i - 1] + rawStringSizes[i - 1]; + } + + const auto data = makeRowVector( + {std::make_shared( + pool(), + MAP(BIGINT(), BOOLEAN()), + nullptr, + length, + offsets, + sizes, + keyVector, + valueVector), + std::make_shared( + pool(), + MAP(VARCHAR(), INTEGER()), + nullptr, + length, + stringOffsets, + stringSizes, + stringKeyVector, + intValueVector)}); + + const auto schema = ROW( + {{"int_bool_map", MAP(BIGINT(), BOOLEAN())}, + {"string_int_map", MAP(VARCHAR(), INTEGER())}}); + + WriterOptions writerOptions; + writerOptions.memoryPool = rootPool_.get(); + const auto tempPath = tempPath_->getPath(); + const auto filename = "test_map_writer.txt"; + + auto filePath = fs::path(fmt::format("{}/{}", tempPath, filename)); + auto sink = std::make_unique( + filePath, dwio::common::FileSink::Options{.pool = leafPool_.get()}); + + const auto serDeOptions = dwio::common::SerDeOptions('\x01', '|', '#'); + auto writer = std::make_unique( + schema, + std::move(sink), + std::make_shared(writerOptions), + serDeOptions); + writer->write(data); + writer->close(); + + const auto fs = filesystems::getFileSystem(tempPath, nullptr); + const auto& file = fs->openFileForRead(filePath.string()); + auto fileSize = file->size(); + + BufferPtr charBuf = AlignedBuffer::allocate(fileSize, pool()); + auto rawCharBuf = charBuf->asMutable(); + std::vector> result = + parseTextFile(tempPath, filename, rawCharBuf, serDeOptions); + + EXPECT_EQ(result.size(), 5); + EXPECT_EQ(result[0].size(), 2); + + // int_bool_map + EXPECT_EQ(result[0][0], "1#true|2#false|3#true"); + EXPECT_EQ(result[1][0], "10#false|20#true"); + EXPECT_EQ(result[2][0], "100#true"); + EXPECT_EQ(result[3][0], ""); + EXPECT_EQ(result[4][0], "42#false|99#true|123#false|456#true"); + + // string_int_map + EXPECT_EQ(result[0][1], "key1#10|key2#20"); + EXPECT_EQ(result[1][1], "foo#100"); + EXPECT_EQ(result[2][1], "bar#1|baz#2|qux#3"); + EXPECT_EQ(result[3][1], ""); + EXPECT_EQ(result[4][1], "test#999"); +} + +TEST_F(TextWriterTest, verifyMapTypesWithTextReader) { + // Test specifically for MAP types with various key-value combinations + const vector_size_t length = 5; + + // Create key and value vectors for the maps + const auto keyVector = + makeFlatVector({1, 2, 3, 10, 20, 100, 42, 99, 123, 456}); + const auto valueVector = makeFlatVector( + {true, false, true, false, true, true, false, true, false, true}); + + // Set up offsets and sizes for each row + BufferPtr sizes = facebook::velox::allocateOffsets(length, pool()); + BufferPtr offsets = facebook::velox::allocateOffsets(length, pool()); + auto rawSizes = sizes->asMutable(); + auto rawOffsets = offsets->asMutable(); + + rawSizes[0] = 3; // Row 0: 3 key-value pairs + rawSizes[1] = 2; // Row 1: 2 key-value pairs + rawSizes[2] = 1; // Row 2: 1 key-value pair + rawSizes[3] = 0; // Row 3: 0 key-value pairs (empty map) + rawSizes[4] = 4; // Row 4: 4 key-value pairs + + rawOffsets[0] = 0; + for (int i = 1; i < length; i++) { + rawOffsets[i] = rawOffsets[i - 1] + rawSizes[i - 1]; + } + + // Create string keys for a second map column + const auto stringKeyVector = makeFlatVector( + {"key1", "key2", "foo", "bar", "baz", "qux", "test"}); + const auto intValueVector = + makeFlatVector({10, 20, 100, 1, 2, 3, 999}); + + // Set up offsets and sizes for string key map + BufferPtr stringSizes = facebook::velox::allocateOffsets(length, pool()); + BufferPtr stringOffsets = facebook::velox::allocateOffsets(length, pool()); + auto rawStringSizes = stringSizes->asMutable(); + auto rawStringOffsets = stringOffsets->asMutable(); + + rawStringSizes[0] = 2; // Row 0: 2 key-value pairs + rawStringSizes[1] = 1; // Row 1: 1 key-value pair + rawStringSizes[2] = 3; // Row 2: 3 key-value pairs + rawStringSizes[3] = 0; // Row 3: 0 key-value pairs (empty map) + rawStringSizes[4] = 1; // Row 4: 1 key-value pair + + rawStringOffsets[0] = 0; + for (int i = 1; i < length; i++) { + rawStringOffsets[i] = rawStringOffsets[i - 1] + rawStringSizes[i - 1]; + } + + const auto data = makeRowVector( + {std::make_shared( + pool(), + MAP(BIGINT(), BOOLEAN()), + nullptr, + length, + offsets, + sizes, + keyVector, + valueVector), + std::make_shared( + pool(), + MAP(VARCHAR(), INTEGER()), + nullptr, + length, + stringOffsets, + stringSizes, + stringKeyVector, + intValueVector)}); + + const auto schema = ROW( + {{"int_bool_map", MAP(BIGINT(), BOOLEAN())}, + {"string_int_map", MAP(VARCHAR(), INTEGER())}}); + + WriterOptions writerOptions; + writerOptions.memoryPool = rootPool_.get(); + const auto tempPath = tempPath_->getPath(); + const auto filename = "test_map_writer.txt"; + + auto filePath = fs::path(fmt::format("{}/{}", tempPath, filename)); + auto sink = std::make_unique( + filePath, dwio::common::FileSink::Options{.pool = leafPool_.get()}); + + const auto serDeOptions = dwio::common::SerDeOptions('\x01', '|', '#'); + auto writer = std::make_unique( + schema, + std::move(sink), + std::make_shared(writerOptions), + serDeOptions); + writer->write(data); + writer->close(); + + // Set up reader + auto readerFactory = + dwio::common::getReaderFactory(dwio::common::FileFormat::TEXT); + auto readFile = std::make_shared(filePath.string()); + auto readerOptions = dwio::common::ReaderOptions(pool()); + readerOptions.setFileSchema(schema); + readerOptions.setSerDeOptions(serDeOptions); + + auto input = + std::make_unique(readFile, poolRef()); + auto reader = readerFactory->createReader(std::move(input), readerOptions); + dwio::common::RowReaderOptions rowReaderOptions; + setScanSpec(*schema, rowReaderOptions); + + auto rowReader = reader->createRowReader(rowReaderOptions); + + EXPECT_EQ(*reader->rowType(), *schema); + + VectorPtr result; + + ASSERT_EQ(rowReader->next(5, result), 5); + for (int i = 0; i < 5; ++i) { + EXPECT_TRUE(result->equalValueAt(data.get(), i, i)); + } + ASSERT_EQ(rowReader->next(10, result), 0); +} + +TEST_F(TextWriterTest, nestedRowTypes) { + // Test specifically for nested ROW types + auto nestedRowChildren = std::vector{ + makeFlatVector({42, 100, -5, 0, 999}), + makeFlatVector({true, false, true, false, true}), + makeArrayVector( + {{3.14159, 2.71828}, + {2.71828, 1.41421, 0.0}, + {1.41421, -123.456}, + {0.0, 999.999}, + {-123.456, 42.0, 3.14159}})}; + auto nestedRowVector = makeRowVector( + {"nested_int", "nested_bool", "nested_arr_double"}, nestedRowChildren); + + const auto data = makeRowVector( + {makeFlatVector( + {"hello", "world", "test", "sample", "data"}), + nestedRowVector, + makeFlatVector({false, true, false, true, false})}); + + const auto schema = ROW( + {{"col_varchar", VARCHAR()}, + {"col_nested_row", + ROW( + {{"nested_int", INTEGER()}, + {"nested_bool", BOOLEAN()}, + {"nested_arr_double", ARRAY(DOUBLE())}})}, + {"col_bool", BOOLEAN()}}); + + WriterOptions writerOptions; + writerOptions.memoryPool = rootPool_.get(); + const auto tempPath = tempPath_->getPath(); + const auto filename = "test_nested_row_writer.txt"; + + auto filePath = fs::path(fmt::format("{}/{}", tempPath, filename)); + auto sink = std::make_unique( + filePath, dwio::common::FileSink::Options{.pool = leafPool_.get()}); + + const auto serDeOptions = dwio::common::SerDeOptions(',', '|', '#'); + auto writer = std::make_unique( + schema, + std::move(sink), + std::make_shared(writerOptions), + serDeOptions); + writer->write(data); + writer->close(); + + const auto fs = filesystems::getFileSystem(tempPath, nullptr); + const auto& file = fs->openFileForRead(filePath.string()); + auto fileSize = file->size(); + + BufferPtr charBuf = AlignedBuffer::allocate(fileSize, pool()); + auto rawCharBuf = charBuf->asMutable(); + std::vector> result = + parseTextFile(tempPath, filename, rawCharBuf, serDeOptions); + + EXPECT_EQ(result.size(), 5); + EXPECT_EQ(result[0].size(), 3); + + EXPECT_EQ(result[0][0], "hello"); + EXPECT_EQ(result[1][0], "world"); + EXPECT_EQ(result[2][0], "test"); + EXPECT_EQ(result[3][0], "sample"); + EXPECT_EQ(result[4][0], "data"); + + // nested row + EXPECT_EQ(result[0][1], "42|true|3.141590#2.718280"); + EXPECT_EQ(result[1][1], "100|false|2.718280#1.414210#0.000000"); + EXPECT_EQ(result[2][1], "-5|true|1.414210#-123.456000"); + EXPECT_EQ(result[3][1], "0|false|0.000000#999.999000"); + EXPECT_EQ(result[4][1], "999|true|-123.456000#42.000000#3.141590"); + + EXPECT_EQ(result[0][2], "false"); + EXPECT_EQ(result[1][2], "true"); + EXPECT_EQ(result[2][2], "false"); + EXPECT_EQ(result[3][2], "true"); + EXPECT_EQ(result[4][2], "false"); +} + +TEST_F(TextWriterTest, verifyNestedRowTypesWithTextReader) { + // Test specifically for nested ROW types + auto nestedRowChildren = std::vector{ + makeFlatVector({42, 100, -5, 0, 999}), + makeFlatVector({true, false, true, false, true}), + makeArrayVector( + {{3.14159, 2.71828}, + {2.71828, 1.41421, 0.0}, + {1.41421, -123.456}, + {0.0, 999.999}, + {-123.456, 42.0, 3.14159}})}; + auto nestedRowVector = makeRowVector( + {"nested_int", "nested_bool", "nested_arr_double"}, nestedRowChildren); + + const auto data = makeRowVector( + {makeFlatVector( + {"hello", "world", "test", "sample", "data"}), + nestedRowVector, + makeFlatVector({false, true, false, true, false})}); + + const auto schema = ROW( + {{"col_varchar", VARCHAR()}, + {"col_nested_row", + ROW( + {{"nested_int", INTEGER()}, + {"nested_bool", BOOLEAN()}, + {"nested_arr_double", ARRAY(DOUBLE())}})}, + {"col_bool", BOOLEAN()}}); + + WriterOptions writerOptions; + writerOptions.memoryPool = rootPool_.get(); + const auto tempPath = tempPath_->getPath(); + const auto filename = "test_nested_row_writer.txt"; + + auto filePath = fs::path(fmt::format("{}/{}", tempPath, filename)); + auto sink = std::make_unique( + filePath, dwio::common::FileSink::Options{.pool = leafPool_.get()}); + + // Use custom delimiters: field separator '\x01', nested row field separator + // '\x02' + const auto serDeOptions = dwio::common::SerDeOptions('\x01', '\x02', '#'); + auto writer = std::make_unique( + schema, + std::move(sink), + std::make_shared(writerOptions), + serDeOptions); + writer->write(data); + writer->close(); + + // Set up reader + auto readerFactory = + dwio::common::getReaderFactory(dwio::common::FileFormat::TEXT); + auto readFile = std::make_shared(filePath.string()); + auto readerOptions = dwio::common::ReaderOptions(pool()); + readerOptions.setFileSchema(schema); + readerOptions.setSerDeOptions(serDeOptions); + + auto input = + std::make_unique(readFile, poolRef()); + auto reader = readerFactory->createReader(std::move(input), readerOptions); + dwio::common::RowReaderOptions rowReaderOptions; + setScanSpec(*schema, rowReaderOptions); + + auto rowReader = reader->createRowReader(rowReaderOptions); + + EXPECT_EQ(*reader->rowType(), *schema); + + VectorPtr result; + + ASSERT_EQ(rowReader->next(5, result), 5); + for (int i = 0; i < 5; ++i) { + EXPECT_TRUE(result->equalValueAt(data.get(), i, i)); + } +} + +TEST_F(TextWriterTest, DISABLED_deeplyNestedComplexTypes) { + // Inner maps for the arrays + const auto innerMapKeys1 = makeFlatVector({1, 11, 22}); + const auto innerMapValues1 = makeFlatVector({true, false, true}); + const auto innerMapKeys2 = makeFlatVector({33, 44}); + const auto innerMapValues2 = makeFlatVector({false, true}); + const auto innerMapKeys3 = makeFlatVector({55, 66, 77, 88}); + const auto innerMapValues3 = makeFlatVector({true, true, false, true}); + const auto innerMapKeys4 = makeFlatVector(std::vector{99}); + const auto innerMapValues4 = makeFlatVector(std::vector{false}); + const auto innerMapKeys5 = makeFlatVector({100, 200}); + const auto innerMapValues5 = makeFlatVector({true, false}); + const auto innerMapKeys6 = makeFlatVector({300, 400, 500}); + const auto innerMapValues6 = makeFlatVector({false, false, true}); + + // Combine all inner map keys and values + const auto allInnerMapKeys = makeFlatVector( + {1, 11, 22, 33, 44, 55, 66, 77, 88, 99, 100, 200, 300, 400, 500}); + const auto allInnerMapValues = makeFlatVector( + {true, + false, + true, + false, + true, + true, + true, + false, + true, + false, + true, + false, + false, + false, + true}); + + // Create inner maps with proper offsets and sizes + BufferPtr innerMapSizes = allocateOffsets(6, pool()); + BufferPtr innerMapOffsets = allocateOffsets(6, pool()); + auto rawInnerMapSizes = innerMapSizes->asMutable(); + auto rawInnerMapOffsets = innerMapOffsets->asMutable(); + + rawInnerMapSizes[0] = 3; // {1:true, 11:false, 22:true} + rawInnerMapSizes[1] = 2; // {33:false, 44:true} + rawInnerMapSizes[2] = 4; // {55:true, 66:true, 77:false, 88:true} + rawInnerMapSizes[3] = 1; // {99:false} + rawInnerMapSizes[4] = 2; // {100:true, 200:false} + rawInnerMapSizes[5] = 3; // {300:false, 400:false, 500:true} + + rawInnerMapOffsets[0] = 0; + for (int i = 1; i < 6; i++) { + rawInnerMapOffsets[i] = rawInnerMapOffsets[i - 1] + rawInnerMapSizes[i - 1]; + } + + auto innerMapsVector = std::make_shared( + pool(), + MAP(BIGINT(), BOOLEAN()), + nullptr, + 6, + innerMapOffsets, + innerMapSizes, + allInnerMapKeys, + allInnerMapValues); + + // Create arrays containing the inner maps + // Array 1: [innerMap0, innerMap1] (maps at indices 0, 1) + // Array 2: [innerMap2, innerMap3, innerMap4] (maps at indices 2, 3, 4) + // Array 3: [innerMap5] (map at index 5) + auto arrayVector = makeArrayVector({0, 2, 5, 6}, innerMapsVector); + + // Create the outer map keys + const auto outerMapKeys = makeFlatVector({10, 20, 30}); + + // Create the final data structure + auto outerMapOffsets = allocateOffsets(3, pool()); + auto outerMapSizes = allocateOffsets(3, pool()); + auto rawOuterMapOffsets = outerMapOffsets->asMutable(); + auto rawOuterMapSizes = outerMapSizes->asMutable(); + + // Set offsets: [0, 1, 2] + rawOuterMapOffsets[0] = 0; + rawOuterMapOffsets[1] = 1; + rawOuterMapOffsets[2] = 2; + + // Set sizes: [1, 1, 1] (each outer map entry contains 1 array) + rawOuterMapSizes[0] = 1; + rawOuterMapSizes[1] = 1; + rawOuterMapSizes[2] = 1; + + const auto data = makeRowVector({std::make_shared( + pool(), + MAP(BIGINT(), ARRAY(MAP(BIGINT(), BOOLEAN()))), + nullptr, + 3, + outerMapOffsets, + outerMapSizes, + outerMapKeys, + arrayVector)}); + + const auto schema = + ROW({{"col_nested_map", MAP(BIGINT(), ARRAY(MAP(BIGINT(), BOOLEAN())))}}); + + WriterOptions writerOptions; + writerOptions.memoryPool = rootPool_.get(); + + const auto tempPath = tempPath_->getPath(); + const auto filename = "test_text_writer.txt"; + auto filePath = fs::path(fmt::format("{}/{}", tempPath, filename)); + auto sink = std::make_unique( + filePath.string(), + dwio::common::FileSink::Options{.pool = leafPool_.get()}); + + // Define hierarchical delimiters for nested data structures: + // '\x01' - field separator (between columns) + // '|' - nesting level 1 + // '#' - nesting level 2 + // '!', '@', '$', '%', '^' - separators for deeper nesting levels + // This creates a delimiter hierarchy to properly serialize complex nested + // types + auto serDeOptions = dwio::common::SerDeOptions('\x01', '|', '#'); + serDeOptions.separators[3] = '!'; + serDeOptions.separators[4] = '@'; + serDeOptions.separators[5] = '$'; + serDeOptions.separators[6] = '%'; + serDeOptions.separators[7] = '^'; + + auto writer = std::make_unique( + schema, + std::move(sink), + std::make_shared(writerOptions), + serDeOptions); + + writer->write(data); + writer->close(); + + BufferPtr charBuf = AlignedBuffer::allocate(1024, pool()); + auto rawCharBuf = charBuf->asMutable(); + std::vector> result = + parseTextFile(tempPath, filename, rawCharBuf, serDeOptions); + + EXPECT_EQ(result.size(), 3); + EXPECT_EQ(result[0].size(), 1); + EXPECT_EQ(result[1].size(), 1); + EXPECT_EQ(result[2].size(), 1); + + // Row 0: {10: [{1:true, 11:false, 22:true}, {33:false, 44:true}]} + EXPECT_EQ(result[0][0], "10#1@true!11@false!22@true#33@false!44@true"); + + // Row 1: {20: [{55:true, 66:true, 77:false, 88:true}, {99:false}, {100:true, + // 200:false}]} + EXPECT_EQ( + result[1][0], + "20#55@true!66@true!77@false!88@true#99@false#100@true!200@false"); + + // Row 2: {30: [{300:false, 400:false, 500:true}]} + EXPECT_EQ(result[2][0], "30#300@false!400@false!500@true"); +} + TEST_F(TextWriterTest, abort) { auto schema = ROW({"c0", "c1"}, {BIGINT(), BOOLEAN()}); auto data = makeRowVector( @@ -155,10 +1268,14 @@ TEST_F(TextWriterTest, abort) { WriterOptions writerOptions; writerOptions.memoryPool = rootPool_.get(); writerOptions.defaultFlushCount = 10; - auto filePath = fs::path( - fmt::format("{}/test_text_writer_abort.txt", tempPath_->getPath())); + + const auto tempPath = tempPath_->getPath(); + const auto filename = "test_text_writer_abort.txt"; + auto filePath = fs::path(fmt::format("{}/{}", tempPath, filename)); + auto sink = std::make_unique( - filePath, dwio::common::FileSink::Options{.pool = leafPool_.get()}); + filePath.string(), + dwio::common::FileSink::Options{.pool = leafPool_.get()}); auto writer = std::make_unique( schema, std::move(sink), @@ -166,13 +1283,592 @@ TEST_F(TextWriterTest, abort) { writer->write(data); writer->abort(); - std::string result = readFile(filePath); + uint64_t result = readFile(tempPath, filePath.filename().string()); + // With defaultFlushCount as 10, it will trigger two times of flushes before // abort, and abort will discard the remaining 5 characters in buffer. The // written file would have: - // 1^Atrue - // 2^Atrue + // 1^Atrue\n + // 2^Atrue\n // 3^A - EXPECT_EQ(result.size(), 14); + EXPECT_EQ(result, 16); } + +TEST_F(TextWriterTest, tripleNestedArraysWithCustomDelimiters) { + // Create expected triple nested array structure + // Row 1: [[[1,2], [3,4]], [[5,6,7], [8,9]], [[10,11,12,13], [14,15]]] + // Row 2: [[[20,21], [22,23,24]], [[25,26]], [[27,28,29], [30,31,32]]] + // Row 3: [[[100,101,102]], [[200,201], [300,301,302,303]], [[400,401,402], + // [500,501]]] + + // Create the innermost arrays (level 3) + auto innermostArrays = makeArrayVector({ + {1, 2}, + {3, 4}, // Row 1, outer array 0 + {5, 6, 7}, + {8, 9}, // Row 1, outer array 1 + {10, 11, 12, 13}, + {14, 15}, // Row 1, outer array 2 + {20, 21}, + {22, 23, 24}, // Row 2, outer array 0 + {25, 26}, // Row 2, outer array 1 + {27, 28, 29}, + {30, 31, 32}, // Row 2, outer array 2 + {100, 101, 102}, // Row 3, outer array 0 + {200, 201}, + {300, 301, 302, 303}, // Row 3, outer array 1 + {400, 401, 402}, + {500, 501} // Row 3, outer array 2 + }); + + // Create middle level arrays (level 2) - each contains innermost arrays + auto middleArrays = makeArrayVector( + { + 0, // Row 1, outer array 0: contains innermost arrays [0,1] + 2, // Row 1, outer array 1: contains innermost arrays [2,3] + 4, // Row 1, outer array 2: contains innermost arrays [4,5] + 6, // Row 2, outer array 0: contains innermost arrays [6,7] + 8, // Row 2, outer array 1: contains innermost arrays [8] + 9, // Row 2, outer array 2: contains innermost arrays [9,10] + 11, // Row 3, outer array 0: contains innermost arrays [11] + 12, // Row 3, outer array 1: contains innermost arrays [12,13] + 14, // Row 3, outer array 2: contains innermost arrays [14,15] + }, + innermostArrays); + + // Create outermost arrays (level 1) - each row contains middle arrays + auto outerArray = makeArrayVector({0, 3, 6}, middleArrays); + + const auto data = makeRowVector({outerArray}); + + const auto schema = + ROW({{"col_triple_nested_array", ARRAY(ARRAY(ARRAY(BIGINT())))}}); + + WriterOptions writerOptions; + writerOptions.memoryPool = rootPool_.get(); + + const auto tempPath = tempPath_->getPath(); + const auto filename = "test_text_writer.txt"; + auto filePath = fs::path(fmt::format("{}/{}", tempPath, filename)); + auto sink = std::make_unique( + filePath.string(), + dwio::common::FileSink::Options{.pool = leafPool_.get()}); + + auto serDeOptions = dwio::common::SerDeOptions('\x01', '|', ','); + serDeOptions.separators[3] = '#'; + + auto writer = std::make_unique( + schema, + std::move(sink), + std::make_shared(writerOptions), + serDeOptions); + writer->write(data); + writer->close(); + + const auto fs = filesystems::getFileSystem(tempPath, nullptr); + const auto& file = fs->openFileForRead(filePath.string()); + auto fileSize = file->size(); + + BufferPtr charBuf = AlignedBuffer::allocate(fileSize, pool()); + auto rawCharBuf = charBuf->asMutable(); + std::vector> result = + parseTextFile(tempPath, filename, rawCharBuf, serDeOptions); + + EXPECT_EQ(result.size(), 3); + EXPECT_EQ(result[0].size(), 1); + + EXPECT_EQ(result[0][0], "1#2,3#4|5#6#7,8#9|10#11#12#13,14#15"); + EXPECT_EQ(result[1][0], "20#21,22#23#24|25#26|27#28#29,30#31#32"); + EXPECT_EQ( + result[2][0], "100#101#102|200#201,300#301#302#303|400#401#402,500#501"); +} + +TEST_F( + TextWriterTest, + verifyTripleNestedArraysWithCustomDelimitersWithTextReader) { + // Create expected triple nested array structure + // Row 1: [[[1,2], [3,4]], [[5,6,7], [8,9]], [[10,11,12,13], [14,15]]] + // Row 2: [[[20,21], [22,23,24]], [[25,26]], [[27,28,29], [30,31,32]]] + // Row 3: [[[100,101,102]], [[200,201], [300,301,302,303]], [[400,401,402], + // [500,501]]] + + // Create the innermost arrays (level 3) + auto innermostArrays = makeArrayVector({ + {1, 2}, + {3, 4}, // Row 1, outer array 0 + {5, 6, 7}, + {8, 9}, // Row 1, outer array 1 + {10, 11, 12, 13}, + {14, 15}, // Row 1, outer array 2 + {20, 21}, + {22, 23, 24}, // Row 2, outer array 0 + {25, 26}, // Row 2, outer array 1 + {27, 28, 29}, + {30, 31, 32}, // Row 2, outer array 2 + {100, 101, 102}, // Row 3, outer array 0 + {200, 201}, + {300, 301, 302, 303}, // Row 3, outer array 1 + {400, 401, 402}, + {500, 501} // Row 3, outer array 2 + }); + + // Create middle level arrays (level 2) - each contains innermost arrays + auto middleArrays = makeArrayVector( + { + 0, // Row 1, outer array 0: contains innermost arrays [0,1] + 2, // Row 1, outer array 1: contains innermost arrays [2,3] + 4, // Row 1, outer array 2: contains innermost arrays [4,5] + 6, // Row 2, outer array 0: contains innermost arrays [6,7] + 8, // Row 2, outer array 1: contains innermost arrays [8] + 9, // Row 2, outer array 2: contains innermost arrays [9,10] + 11, // Row 3, outer array 0: contains innermost arrays [11] + 12, // Row 3, outer array 1: contains innermost arrays [12,13] + 14, // Row 3, outer array 2: contains innermost arrays [14,15] + // 16 // End marker + }, + innermostArrays); + + // Create outermost arrays (level 1) - each row contains middle arrays + auto outerArray = makeArrayVector( + { + 0, 3, 6 + // , 9 // Row boundaries: Row 1 [0-2], Row 2 [3-5], Row 3 [6-8] + }, + middleArrays); + + const auto data = makeRowVector({outerArray}); + + const auto schema = + ROW({{"col_triple_nested_array", ARRAY(ARRAY(ARRAY(BIGINT())))}}); + + WriterOptions writerOptions; + writerOptions.memoryPool = rootPool_.get(); + + auto filePath = + fs::path(fmt::format("{}/test_text_writer.txt", tempPath_->getPath())); + auto sink = std::make_unique( + filePath.string(), + dwio::common::FileSink::Options{.pool = leafPool_.get()}); + + auto serDeOptions = dwio::common::SerDeOptions('\x01', '|', ','); + serDeOptions.separators[3] = '#'; + + auto writer = std::make_unique( + schema, + std::move(sink), + std::make_shared(writerOptions), + serDeOptions); + writer->write(data); + writer->close(); + + // Set up reader. + auto readerFactory = + dwio::common::getReaderFactory(dwio::common::FileFormat::TEXT); + auto readFile = std::make_shared(filePath.string()); + auto readerOptions = dwio::common::ReaderOptions(pool()); + readerOptions.setFileSchema(schema); + readerOptions.setSerDeOptions(serDeOptions); + + auto input = + std::make_unique(readFile, poolRef()); + auto reader = readerFactory->createReader(std::move(input), readerOptions); + dwio::common::RowReaderOptions rowReaderOptions; + setScanSpec(*schema, rowReaderOptions); + auto rowReader = reader->createRowReader(rowReaderOptions); + + EXPECT_EQ(*reader->rowType(), *schema); + + VectorPtr result; + ASSERT_EQ(rowReader->next(3, result), 3); + for (int i = 0; i < 3; ++i) { + EXPECT_TRUE(result->equalValueAt(data.get(), i, i)); + } +} + +TEST_F(TextWriterTest, simpleEscapeCharTest) { + auto schema = ROW({"col_string", "col_array"}, {VARCHAR(), ARRAY(VARCHAR())}); + + // Create test data with strings containing comma delimiter characters + // Field delimiter: ',', Collection delimiter: '|', Escape char: '\' + auto data = makeRowVector( + {"col_string", "col_array"}, + {makeFlatVector( + {"engineer,manager", "developer,senior", "analyst,junior"}, + VARCHAR()), + // Array column with strings containing comma delimiters + makeArrayVector( + {{"role,title", "position,level"}, // Array with comma delimiters + {"job,description", "work,type"}, + {"career,path", "skill,set"}})}); + + WriterOptions writerOptions; + writerOptions.memoryPool = rootPool_.get(); + + const auto tempPath = tempPath_->getPath(); + const auto filename = "test_escape_writer.txt"; + auto filePath = fs::path(fmt::format("{}/{}", tempPath, filename)); + + auto sink = std::make_unique( + filePath, dwio::common::FileSink::Options{.pool = leafPool_.get()}); + + // Configure SerDeOptions with comma as field delimiter and escaping enabled + auto serDeOptions = dwio::common::SerDeOptions( + ',', // field delimiter (comma) + '|', // collection delimiter + '#', // map key delimiter + '\\', // escape character + true // isEscaped = true + ); + + auto writer = std::make_unique( + schema, + std::move(sink), + std::make_shared(writerOptions), + serDeOptions); + + writer->write(data); + writer->close(); + + const auto fs = filesystems::getFileSystem(tempPath, nullptr); + const auto& file = fs->openFileForRead(filePath.string()); + auto fileSize = file->size(); + + BufferPtr charBuf = AlignedBuffer::allocate(fileSize, pool()); + auto rawCharBuf = charBuf->asMutable(); + std::vector> result = + parseTextFile(tempPath, filename, rawCharBuf); + EXPECT_EQ(result.size(), 3); + + // Verify that commas in strings are properly escaped with backslashes + // Temporary parse because parseTextFile is not handling escape characters + EXPECT_EQ(result[0][0], "engineer\\,manager,role\\,title|position\\,level"); + EXPECT_EQ(result[1][0], "developer\\,senior,job\\,description|work\\,type"); + EXPECT_EQ(result[2][0], "analyst\\,junior,career\\,path|skill\\,set"); +} + +TEST_F(TextWriterTest, verifySimpleEscapeCharTestWithTextReader) { + auto schema = ROW({"col_string", "col_array"}, {VARCHAR(), ARRAY(VARCHAR())}); + + // Create test data with strings containing comma delimiter characters + // Field delimiter: ',', Collection delimiter: '|', Escape char: '\' + auto data = makeRowVector( + {"col_string", "col_array"}, + {makeFlatVector( + {"engineer,manager", "developer,senior", "analyst,junior"}, + VARCHAR()), + // Array column with strings containing comma delimiters + makeArrayVector( + {{"role,title", "position,level"}, // Array with comma delimiters + {"job,description", "work,type"}, + {"career,path", "skill,set"}})}); + + WriterOptions writerOptions; + writerOptions.memoryPool = rootPool_.get(); + + const auto tempPath = tempPath_->getPath(); + const auto filename = "test_escape_writer.txt"; + auto filePath = fs::path(fmt::format("{}/{}", tempPath, filename)); + + auto sink = std::make_unique( + filePath, dwio::common::FileSink::Options{.pool = leafPool_.get()}); + + // Configure SerDeOptions with comma as field delimiter and escaping enabled + auto serDeOptions = dwio::common::SerDeOptions( + ',', // field delimiter (comma) + '|', // collection delimiter + '#', // map key delimiter + '\\', // escape character + true // isEscaped = true + ); + + auto writer = std::make_unique( + schema, + std::move(sink), + std::make_shared(writerOptions), + serDeOptions); + + writer->write(data); + writer->close(); + + // Set up reader + auto readerFactory = + dwio::common::getReaderFactory(dwio::common::FileFormat::TEXT); + auto readFile = std::make_shared(filePath.string()); + auto readerOptions = dwio::common::ReaderOptions(pool()); + readerOptions.setFileSchema(schema); + readerOptions.setSerDeOptions(serDeOptions); + + auto input = + std::make_unique(readFile, poolRef()); + auto reader = readerFactory->createReader(std::move(input), readerOptions); + dwio::common::RowReaderOptions rowReaderOptions; + setScanSpec(*schema, rowReaderOptions); + + auto rowReader = reader->createRowReader(rowReaderOptions); + + EXPECT_EQ(*reader->rowType(), *schema); + + VectorPtr result; + + ASSERT_EQ(rowReader->next(5, result), 3); + for (int i = 0; i < 3; ++i) { + EXPECT_TRUE(result->equalValueAt(data.get(), i, i)); + } + ASSERT_EQ(rowReader->next(10, result), 0); +} + +TEST_F(TextWriterTest, customEscapeCharTest) { + auto schema = ROW({"col_string", "col_array"}, {VARCHAR(), ARRAY(VARCHAR())}); + + // Create test data with strings containing comma delimiter characters + // Field delimiter: ',', Collection delimiter: '|', Escape char: '\' + auto data = makeRowVector( + {"col_string", "col_array"}, + {makeFlatVector( + {"engineer,manager", "developer,senior", "analyst,junior"}, + VARCHAR()), + // Array column with strings containing comma delimiters + makeArrayVector( + {{"role,title", "position,level"}, // Array with comma delimiters + {"job,description", "work,type"}, + {"career,path", "skill,set"}})}); + + WriterOptions writerOptions; + writerOptions.memoryPool = rootPool_.get(); + + const auto tempPath = tempPath_->getPath(); + const auto filename = "test_escape_writer.txt"; + auto filePath = fs::path(fmt::format("{}/{}", tempPath, filename)); + + auto sink = std::make_unique( + filePath, dwio::common::FileSink::Options{.pool = leafPool_.get()}); + + // Configure SerDeOptions with comma as field delimiter and escaping enabled + const auto serDeOptions = dwio::common::SerDeOptions( + ',', // field delimiter (comma) + '|', // collection delimiter + '#', // map key delimiter + '@', // escape character + true // isEscaped = true + ); + + auto writer = std::make_unique( + schema, + std::move(sink), + std::make_shared(writerOptions), + serDeOptions); + + writer->write(data); + writer->close(); + + const auto fs = filesystems::getFileSystem(tempPath, nullptr); + const auto& file = fs->openFileForRead(filePath.string()); + auto fileSize = file->size(); + + BufferPtr charBuf = AlignedBuffer::allocate(fileSize, pool()); + auto rawCharBuf = charBuf->asMutable(); + std::vector> result = + parseTextFile(tempPath, filename, rawCharBuf); + + EXPECT_EQ(result.size(), 3); + + // Verify that commas in strings are properly escaped with @ character + EXPECT_EQ(result[0][0], "engineer@,manager,role@,title|position@,level"); + EXPECT_EQ(result[1][0], "developer@,senior,job@,description|work@,type"); + EXPECT_EQ(result[2][0], "analyst@,junior,career@,path|skill@,set"); +} + +TEST_F(TextWriterTest, verifyCustomEscapeCharTestWithTextReader) { + auto schema = ROW({"col_string", "col_array"}, {VARCHAR(), ARRAY(VARCHAR())}); + + // Create test data with strings containing comma delimiter characters + // Field delimiter: ',', Collection delimiter: '|', Escape char: '\' + auto data = makeRowVector( + {"col_string", "col_array"}, + {makeFlatVector( + {"engineer,manager", "developer,senior", "analyst,junior"}, + VARCHAR()), + // Array column with strings containing comma delimiters + makeArrayVector( + {{"role,title", "position,level"}, // Array with comma delimiters + {"job,description", "work,type"}, + {"career,path", "skill,set"}})}); + + WriterOptions writerOptions; + writerOptions.memoryPool = rootPool_.get(); + + const auto tempPath = tempPath_->getPath(); + const auto filename = "test_escape_writer.txt"; + auto filePath = fs::path(fmt::format("{}/{}", tempPath, filename)); + + auto sink = std::make_unique( + filePath, dwio::common::FileSink::Options{.pool = leafPool_.get()}); + + // Configure SerDeOptions with comma as field delimiter and escaping enabled + const auto serDeOptions = dwio::common::SerDeOptions( + ',', // field delimiter (comma) + '|', // collection delimiter + '#', // map key delimiter + '@', // escape character + true // isEscaped = true + ); + + auto writer = std::make_unique( + schema, + std::move(sink), + std::make_shared(writerOptions), + serDeOptions); + + writer->write(data); + writer->close(); + + // Set up reader + auto readerFactory = + dwio::common::getReaderFactory(dwio::common::FileFormat::TEXT); + auto readFile = std::make_shared(filePath.string()); + auto readerOptions = dwio::common::ReaderOptions(pool()); + readerOptions.setFileSchema(schema); + readerOptions.setSerDeOptions(serDeOptions); + + auto input = + std::make_unique(readFile, poolRef()); + auto reader = readerFactory->createReader(std::move(input), readerOptions); + dwio::common::RowReaderOptions rowReaderOptions; + setScanSpec(*schema, rowReaderOptions); + + auto rowReader = reader->createRowReader(rowReaderOptions); + + EXPECT_EQ(*reader->rowType(), *schema); + + VectorPtr result; + + ASSERT_EQ(rowReader->next(3, result), 3); + for (int i = 0; i < 3; ++i) { + EXPECT_TRUE(result->equalValueAt(data.get(), i, i)); + } + ASSERT_EQ(rowReader->next(10, result), 0); +} + +TEST_F(TextWriterTest, headerTest) { + auto schema = ROW({"name", "age", "city"}, {VARCHAR(), INTEGER(), VARCHAR()}); + + // Create simple test data with 3 rows + auto data = makeRowVector( + {"name", "age", "city"}, + {makeFlatVector({"Alice", "Bob", "Charlie"}, VARCHAR()), + makeFlatVector({25, 30, 35}), + makeFlatVector({"NYC", "LA", "Chicago"}, VARCHAR())}); + + WriterOptions writerOptions; + writerOptions.memoryPool = rootPool_.get(); + writerOptions.headerLineCount = 1; + + const auto tempPath = tempPath_->getPath(); + const auto filename = "test_header_writer.txt"; + auto filePath = fs::path(fmt::format("{}/{}", tempPath, filename)); + + auto sink = std::make_unique( + filePath, dwio::common::FileSink::Options{.pool = leafPool_.get()}); + + // Use comma as field delimiter to match header format + const auto serDeOptions = dwio::common::SerDeOptions(',', '|', '#'); + auto writer = std::make_unique( + schema, + std::move(sink), + std::make_shared(writerOptions), + serDeOptions); + + writer->write(data); + writer->close(); + + const auto fs = filesystems::getFileSystem(tempPath, nullptr); + const auto& file = fs->openFileForRead(filePath.string()); + auto fileSize = file->size(); + + BufferPtr charBuf = AlignedBuffer::allocate(fileSize, pool()); + auto rawCharBuf = charBuf->asMutable(); + std::vector> result = + parseTextFile(tempPath, filename, rawCharBuf, serDeOptions); + + EXPECT_EQ(result.size(), 4); // Header + 3 data rows + EXPECT_EQ(result[0].size(), 3); + + // Verify header is first row + EXPECT_EQ(result[0][0], "name"); + EXPECT_EQ(result[0][1], "age"); + EXPECT_EQ(result[0][2], "city"); + + // Verify data rows + EXPECT_EQ(result[1][0], "Alice"); + EXPECT_EQ(result[1][1], "25"); + EXPECT_EQ(result[1][2], "NYC"); + + EXPECT_EQ(result[2][0], "Bob"); + EXPECT_EQ(result[2][1], "30"); + EXPECT_EQ(result[2][2], "LA"); + + EXPECT_EQ(result[3][0], "Charlie"); + EXPECT_EQ(result[3][1], "35"); + EXPECT_EQ(result[3][2], "Chicago"); +} + +TEST_F(TextWriterTest, verifyHeaderTestWithTextReader) { + auto schema = ROW({"name", "age", "city"}, {VARCHAR(), INTEGER(), VARCHAR()}); + + // Create simple test data with 3 rows + auto data = makeRowVector( + {"name", "age", "city"}, + {makeFlatVector({"Alice", "Bob", "Charlie"}, VARCHAR()), + makeFlatVector({25, 30, 35}), + makeFlatVector({"NYC", "LA", "Chicago"}, VARCHAR())}); + + WriterOptions writerOptions; + writerOptions.memoryPool = rootPool_.get(); + writerOptions.headerLineCount = 1; + + const auto tempPath = tempPath_->getPath(); + const auto filename = "test_header_writer.txt"; + auto filePath = fs::path(fmt::format("{}/{}", tempPath, filename)); + + auto sink = std::make_unique( + filePath, dwio::common::FileSink::Options{.pool = leafPool_.get()}); + + // Use comma as field delimiter to match header format + const auto serDeOptions = dwio::common::SerDeOptions(',', '|', '#'); + auto writer = std::make_unique( + schema, + std::move(sink), + std::make_shared(writerOptions), + serDeOptions); + + writer->write(data); + writer->close(); + + // Set up reader + auto readerFactory = + dwio::common::getReaderFactory(dwio::common::FileFormat::TEXT); + auto readFile = std::make_shared(filePath.string()); + auto readerOptions = dwio::common::ReaderOptions(pool()); + readerOptions.setFileSchema(schema); + readerOptions.setSerDeOptions(serDeOptions); + + auto input = + std::make_unique(readFile, poolRef()); + auto reader = readerFactory->createReader(std::move(input), readerOptions); + dwio::common::RowReaderOptions rowReaderOptions; + rowReaderOptions.setSkipRows(1); + setScanSpec(*schema, rowReaderOptions); + + auto rowReader = reader->createRowReader(rowReaderOptions); + EXPECT_EQ(*reader->rowType(), *schema); + + VectorPtr result; + + ASSERT_EQ(rowReader->next(3, result), 3); + for (int i = 0; i < 3; ++i) { + EXPECT_TRUE(result->equalValueAt(data.get(), i, i)); + } + ASSERT_EQ(rowReader->next(10, result), 0); +} + } // namespace facebook::velox::text diff --git a/velox/dwio/text/writer/TextWriter.cpp b/velox/dwio/text/writer/TextWriter.cpp index 2b3c79ab9a90..0293fa4ee832 100644 --- a/velox/dwio/text/writer/TextWriter.cpp +++ b/velox/dwio/text/writer/TextWriter.cpp @@ -15,13 +15,14 @@ */ #include "velox/dwio/text/writer/TextWriter.h" +#include "velox/common/encode/Base64.h" #include -#include "velox/common/base/Pointers.h" -#include "velox/common/encode/Base64.h" -#include "velox/exec/MemoryReclaimer.h" namespace facebook::velox::text { + +using dwio::common::SerDeOptions; + template std::optional toTextStr(T val) { return std::optional(std::to_string(val)); @@ -57,6 +58,7 @@ std::optional toTextStr(double val) { template <> std::optional toTextStr(Timestamp val) { TimestampToStringOptions options; + val.toTimezone(Timestamp::defaultTimezone()); options.dateTimeSeparator = ' '; options.precision = TimestampPrecision::kMilliseconds; return {val.toString(options)}; @@ -65,24 +67,83 @@ std::optional toTextStr(Timestamp val) { TextWriter::TextWriter( RowTypePtr schema, std::unique_ptr sink, - const std::shared_ptr& options) + const std::shared_ptr& options, + const SerDeOptions& serDeOptions) : schema_(std::move(schema)), - bufferedWriterSink_(std::make_unique( - std::move(sink), - options->memoryPool->addLeafChild(fmt::format( - "{}.text_writer_node.{}", - options->memoryPool->name(), - folly::to(folly::Random::rand64()))), - options->defaultFlushCount)) {} + bufferedWriterSink_( + std::make_unique( + std::move(sink), + options->memoryPool->addLeafChild( + fmt::format( + "{}.text_writer_node.{}", + options->memoryPool->name(), + folly::to(folly::Random::rand64()))), + options->defaultFlushCount)), + headerLineCount_(options->headerLineCount), + serDeOptions_(serDeOptions) { + VELOX_CHECK_LE(headerLineCount_, 1, "Header line count must be <= 1"); +} + +uint8_t TextWriter::getDelimiterForDepth(uint8_t depth) const { + VELOX_CHECK_LT( + depth, + serDeOptions_.separators.size(), + "Depth {} exceeds maximum supported depth", + depth); + return serDeOptions_.separators[depth]; +} + +// Adds escape characters before separator/delimiting characters in the input +// string. +std::string TextWriter::addEscapeChar(std::string&& data, uint8_t depth) { + if (!serDeOptions_.isEscaped) { + return std::move(data); + } + + std::string escapedData; + escapedData.reserve(data.length() * 2); + + for (size_t i = 0; i < data.length(); ++i) { + // Break out of the loop earlier if we count down. + for (int j = depth - 1; j >= 0; --j) { + if (data[i] == serDeOptions_.separators[j]) { + escapedData += static_cast(serDeOptions_.escapeChar); + break; + } + } + escapedData += data[i]; + } + + return escapedData; +} void TextWriter::write(const VectorPtr& data) { VELOX_CHECK_EQ( data->encoding(), VectorEncoding::Simple::ROW, "Text writer expects row vector input"); + VELOX_CHECK( data->type()->equivalent(*schema_), "The file schema type should be equal with the input row vector type."); + + // write 1 row of header + if (headerLineCount_ == 1) { + const auto numCols = schema_->size(); + for (column_index_t col = 0; col < numCols; ++col) { + if (col != 0) { + bufferedWriterSink_->write((char)serDeOptions_.separators[0]); + } + + std::string escapedcolName = + addEscapeChar(std::string(schema_->nameOf(col)), 0); + bufferedWriterSink_->write( + escapedcolName.data(), escapedcolName.length()); + } + + bufferedWriterSink_->write((char)serDeOptions_.newLine); + } + const RowVector* dataRowVector = data->as(); std::vector> decodedColumnVectors; @@ -94,15 +155,19 @@ void TextWriter::write(const VectorPtr& data) { decodedColumnVectors.push_back(std::move(decodedColumnVector)); } + std::optional delimiter; for (vector_size_t row = 0; row < data->size(); ++row) { for (size_t column = 0; column < numColumns; ++column) { - if (column != 0) { - bufferedWriterSink_->write(TextFileTraits::kSOH); - } + delimiter = (column == 0) ? std::nullopt + : std::optional(serDeOptions_.separators[0]); writeCellValue( - decodedColumnVectors.at(column), schema_->childAt(column), row); + decodedColumnVectors.at(column), + schema_->childAt(column)->kind(), + row, + 0, + delimiter); } - bufferedWriterSink_->write(TextFileTraits::kNewLine); + bufferedWriterSink_->write((char)serDeOptions_.newLine); } } @@ -120,73 +185,186 @@ void TextWriter::abort() { void TextWriter::writeCellValue( const std::shared_ptr& decodedColumnVector, - const TypePtr& type, - vector_size_t row) { - std::optional dataStr; - std::optional dataSV; + const TypeKind type, + vector_size_t row, + uint8_t depth, + std::optional delimiter) { + if (delimiter.has_value()) { + bufferedWriterSink_->write((char)delimiter.value()); + } if (decodedColumnVector->isNullAt(row)) { + std::string escapedNullString = + addEscapeChar(std::string(serDeOptions_.nullString), depth); bufferedWriterSink_->write( - TextFileTraits::kNullData.data(), TextFileTraits::kNullData.length()); + escapedNullString.data(), escapedNullString.length()); return; } - switch (type->kind()) { - case TypeKind::BOOLEAN: + + ++depth; + + /// TODO: Increase supported depth in future + VELOX_CHECK_LE(depth, 4, "Depth {} exceeds maximum supported depth", 4); + + std::optional dataStr = std::nullopt; + std::optional dataSV = std::nullopt; + + switch (type) { + case TypeKind::BOOLEAN: { dataStr = toTextStr(folly::to(decodedColumnVector->valueAt(row))); break; - case TypeKind::TINYINT: + } + case TypeKind::TINYINT: { dataStr = toTextStr(decodedColumnVector->valueAt(row)); break; - case TypeKind::SMALLINT: + } + case TypeKind::SMALLINT: { dataStr = toTextStr(decodedColumnVector->valueAt(row)); break; - case TypeKind::INTEGER: + } + case TypeKind::INTEGER: { dataStr = toTextStr(decodedColumnVector->valueAt(row)); break; - case TypeKind::BIGINT: + } + case TypeKind::BIGINT: { dataStr = toTextStr(decodedColumnVector->valueAt(row)); break; - case TypeKind::REAL: + } + case TypeKind::REAL: { dataStr = toTextStr(decodedColumnVector->valueAt(row)); break; - case TypeKind::DOUBLE: + } + case TypeKind::DOUBLE: { dataStr = toTextStr(decodedColumnVector->valueAt(row)); break; - case TypeKind::TIMESTAMP: + } + case TypeKind::TIMESTAMP: { dataStr = toTextStr(decodedColumnVector->valueAt(row)); break; - case TypeKind::VARCHAR: + } + case TypeKind::VARCHAR: { dataSV = std::optional(decodedColumnVector->valueAt(row)); break; + } case TypeKind::VARBINARY: { auto data = decodedColumnVector->valueAt(row); dataStr = std::optional(encoding::Base64::encode(data.data(), data.size())); break; } - // TODO Add support for complex types - case TypeKind::ARRAY: + case TypeKind::ARRAY: { + // ARRAY vector members + const auto& arrVecPtr = decodedColumnVector->base()->as(); + const auto& indices = decodedColumnVector->indices(); + const auto& size = arrVecPtr->sizeAt(indices[row]); + const auto& offset = arrVecPtr->offsetAt(indices[row]); + + auto slice = arrVecPtr->elements()->slice(offset, size); + auto decodedElement = + std::make_shared(DecodedVector(*slice)); + for (int i = 0; i < size; ++i) { + delimiter = (i == 0) ? std::nullopt + : std::optional(getDelimiterForDepth(depth)); + writeCellValue( + decodedElement, + arrVecPtr->elements().get()->typeKind(), + i, + depth, + delimiter); + } + return; + } + case TypeKind::MAP: { + // MAP vector members + const auto& mapVecPtr = decodedColumnVector->base()->as(); + const auto& indices = decodedColumnVector->indices(); + const auto& size = mapVecPtr->sizeAt(indices[row]); + const auto& offset = mapVecPtr->offsetAt(indices[row]); + + auto keySlice = mapVecPtr->mapKeys()->slice(offset, size); + auto decodedKeys = + std::make_shared(DecodedVector(*keySlice)); + + auto valSlice = mapVecPtr->mapValues()->slice(offset, size); + auto decodedValues = + std::make_shared(DecodedVector(*valSlice)); + + for (int i = 0; i < size; ++i) { + delimiter = (i == 0) ? std::nullopt + : std::optional(getDelimiterForDepth(depth)); + writeCellValue( + decodedKeys, + mapVecPtr->mapKeys().get()->typeKind(), + i, + depth, + delimiter); + + delimiter = std::optional(getDelimiterForDepth(depth + 1)); + writeCellValue( + decodedValues, + mapVecPtr->mapValues().get()->typeKind(), + i, + depth, + delimiter); + } + + return; + } + case TypeKind::ROW: { + const RowVector* rowVecPtr = decodedColumnVector->base()->as(); + const auto& indices = decodedColumnVector->indices(); + const auto actualRowIndex = indices[row]; + + std::vector> decodedColumnVectors; + const auto numColumns = rowVecPtr->childrenSize(); + for (size_t column = 0; column < numColumns; ++column) { + auto decodedColumnVector = + std::make_shared(DecodedVector( + *rowVecPtr->childAt(column), + SelectivityVector(rowVecPtr->size()))); + decodedColumnVectors.push_back(std::move(decodedColumnVector)); + } + + std::optional nestedRowDelimiter; + for (size_t column = 0; column < numColumns; ++column) { + nestedRowDelimiter = (column == 0) + ? std::nullopt + : std::optional(getDelimiterForDepth(depth)); + writeCellValue( + decodedColumnVectors.at(column), + rowVecPtr->childAt(column)->typeKind(), + actualRowIndex, + depth, + nestedRowDelimiter); + } + return; + } + case TypeKind::UNKNOWN: [[fallthrough]]; - case TypeKind::MAP: + case TypeKind::FUNCTION: [[fallthrough]]; - case TypeKind::ROW: + case TypeKind::OPAQUE: [[fallthrough]]; - case TypeKind::UNKNOWN: + case TypeKind::INVALID: [[fallthrough]]; default: - VELOX_NYI("{} is not supported yet in TextWriter", type->kind()); + VELOX_NYI( + "Text writer does not support type {}", TypeKindName::toName(type)); } - if (dataStr.has_value()) { - VELOX_CHECK(!dataSV.has_value()); - bufferedWriterSink_->write( - dataStr.value().data(), dataStr.value().length()); - return; - } + VELOX_CHECK( + dataStr.has_value() ^ dataSV.has_value(), + "Exactly one of dataStr or dataSV must be set. Currently dataStr is {} and dataSV is {}", + dataStr.has_value(), + dataSV.has_value()); + + std::string data = dataStr.has_value() + ? dataStr.value() + : std::string(dataSV.value().data(), dataSV.value().size()); - VELOX_CHECK(dataSV.has_value()); - bufferedWriterSink_->write(dataSV.value().data(), dataSV.value().size()); + std::string escapedData = addEscapeChar(std::move(data), depth); + bufferedWriterSink_->write(escapedData.data(), escapedData.length()); } std::unique_ptr TextWriterFactory::createWriter( diff --git a/velox/dwio/text/writer/TextWriter.h b/velox/dwio/text/writer/TextWriter.h index 43389a6fbf06..2ec11b82685b 100644 --- a/velox/dwio/text/writer/TextWriter.h +++ b/velox/dwio/text/writer/TextWriter.h @@ -16,38 +16,20 @@ #pragma once -#include "velox/common/compression/Compression.h" -#include "velox/common/config/Config.h" -#include "velox/dwio/common/DataBuffer.h" #include "velox/dwio/common/FileSink.h" -#include "velox/dwio/common/FlushPolicy.h" #include "velox/dwio/common/Options.h" #include "velox/dwio/common/Writer.h" #include "velox/dwio/common/WriterFactory.h" #include "velox/dwio/text/writer/BufferedWriterSink.h" -#include "velox/vector/ComplexVector.h" namespace facebook::velox::text { +using dwio::common::SerDeOptions; + struct WriterOptions : public dwio::common::WriterOptions { int64_t defaultFlushCount = 10 << 10; -}; - -// TODO: move to a separate file to be shared with text reader once it is in oss -class TextFileTraits { - public: - //// The following constants define the delimiters used by TextFile format. - /// Each row is separated by 'kNewLine'. - /// Each column is separated by 'kSOH' within each row. - - /// String for null data. - static inline const std::string kNullData = "\\N"; - - /// Delimiter between columns. - static const char kSOH = '\x01'; - - /// Delimiter between rows. - static const char kNewLine = '\n'; + uint8_t headerLineCount = + 0; // number of lines in the header, currently only support 0 or 1 }; /// Encodes Velox vectors in TextFormat and writes into a FileSink. @@ -58,10 +40,12 @@ class TextWriter : public dwio::common::Writer { /// non-null. /// @param sink output sink /// @param options writer options + /// @param serDeOptions specifies the serialization options TextWriter( RowTypePtr schema, std::unique_ptr sink, - const std::shared_ptr& options); + const std::shared_ptr& options, + const SerDeOptions& serDeOptions = SerDeOptions()); ~TextWriter() override = default; @@ -79,13 +63,22 @@ class TextWriter : public dwio::common::Writer { void abort() override; private: + uint8_t getDelimiterForDepth(uint8_t depth) const; + void writeCellValue( const std::shared_ptr& decodedColumnVector, - const TypePtr& type, - vector_size_t row); + TypeKind type, + vector_size_t row, + uint8_t depth, + std::optional delimiter); + + std::string addEscapeChar(std::string&& dataToWrite, uint8_t depth); const RowTypePtr schema_; const std::unique_ptr bufferedWriterSink_; + + uint8_t headerLineCount_; + SerDeOptions serDeOptions_; }; class TextWriterFactory : public dwio::common::WriterFactory { diff --git a/velox/examples/CMakeLists.txt b/velox/examples/CMakeLists.txt index bdad72ce9d4d..a4b29fc0efef 100644 --- a/velox/examples/CMakeLists.txt +++ b/velox/examples/CMakeLists.txt @@ -20,7 +20,8 @@ target_link_libraries( velox_core velox_expression re2::re2 - simdjson::simdjson) + simdjson::simdjson +) add_executable(velox_example_expression_eval ExpressionEval.cpp) @@ -30,7 +31,8 @@ target_link_libraries( velox_vector velox_caching velox_memory - velox_expression) + velox_expression +) add_executable(velox_example_opaque_type OpaqueType.cpp) target_link_libraries( @@ -39,7 +41,8 @@ target_link_libraries( velox_vector velox_caching velox_expression - velox_memory) + velox_memory +) # This is disabled temporarily until we figure out why g++ is crashing linking # it on linux builds. @@ -59,12 +62,12 @@ target_link_libraries( velox_exec velox_exec_test_lib velox_hive_connector - velox_memory) + velox_memory +) add_executable(velox_example_vector_reader_writer VectorReaderWriter.cpp) -target_link_libraries( - velox_example_vector_reader_writer velox_expression velox_type velox_vector) +target_link_libraries(velox_example_vector_reader_writer velox_expression velox_type velox_vector) add_executable(velox_example_operator_extensibility OperatorExtensibility.cpp) target_link_libraries( @@ -74,4 +77,5 @@ target_link_libraries( velox_exec velox_exec_test_lib velox_memory - velox_vector_test_lib) + velox_vector_test_lib +) diff --git a/velox/examples/ExpressionEval.cpp b/velox/examples/ExpressionEval.cpp index 4fcf101ee599..ef12f250b7e3 100644 --- a/velox/examples/ExpressionEval.cpp +++ b/velox/examples/ExpressionEval.cpp @@ -15,6 +15,7 @@ */ #include "velox/common/memory/Memory.h" +#include "velox/core/Expressions.h" #include "velox/functions/Udf.h" #include "velox/type/Type.h" #include "velox/vector/BaseVector.h" @@ -102,9 +103,7 @@ int main(int argc, char** argv) { // would be automatically and recursively generated based on some input IDL // (or by a SQL string parser). auto exprTree = std::make_shared( - BIGINT(), - std::vector{fieldAccessExprNode}, - "times_two"); + BIGINT(), "times_two", fieldAccessExprNode); // Lastly, ExprSet contains the main expression evaluation logic. It takes a // vector of expression trees (if there are multiple expressions to be diff --git a/velox/examples/OpaqueType.cpp b/velox/examples/OpaqueType.cpp index 2ed05105e6e9..96757c08e1e5 100644 --- a/velox/examples/OpaqueType.cpp +++ b/velox/examples/OpaqueType.cpp @@ -15,6 +15,7 @@ */ #include "velox/common/memory/Memory.h" +#include "velox/core/Expressions.h" #include "velox/expression/VectorFunction.h" #include "velox/functions/Udf.h" #include "velox/type/Type.h" @@ -323,9 +324,9 @@ VectorPtr evaluate( auto exprPlan = std::make_shared( OPAQUE(), - std::vector{ - fieldAccessExprNode1, fieldAccessExprNode2}, - functionName); + functionName, + fieldAccessExprNode1, + fieldAccessExprNode2); exec::ExprSet exprSet({exprPlan}, &execCtx); exec::EvalCtx evalCtx(&execCtx, &exprSet, rowVector.get()); diff --git a/velox/examples/OperatorExtensibility.cpp b/velox/examples/OperatorExtensibility.cpp index 6f48b4e9bbd6..a6e4aaec682b 100644 --- a/velox/examples/OperatorExtensibility.cpp +++ b/velox/examples/OperatorExtensibility.cpp @@ -124,8 +124,9 @@ class DuplicateRowOperator : public exec::Operator { outputChildren.reserve(input->childrenSize()); for (const auto& child : input->children()) { - outputChildren.push_back(BaseVector::wrapInDictionary( - BufferPtr(), indices, outputSize, child)); + outputChildren.push_back( + BaseVector::wrapInDictionary( + BufferPtr(), indices, outputSize, child)); } return std::make_shared( pool(), diff --git a/velox/examples/ScanAndSort.cpp b/velox/examples/ScanAndSort.cpp index 67aa80c67a32..6e27f5d01db6 100644 --- a/velox/examples/ScanAndSort.cpp +++ b/velox/examples/ScanAndSort.cpp @@ -23,13 +23,13 @@ #include "velox/dwio/dwrf/RegisterDwrfReader.h" #include "velox/dwio/dwrf/RegisterDwrfWriter.h" #include "velox/exec/Task.h" -#include "velox/exec/tests/utils/HiveConnectorTestBase.h" #include "velox/exec/tests/utils/PlanBuilder.h" #include "velox/exec/tests/utils/TempDirectoryPath.h" #include "velox/type/Type.h" #include "velox/vector/BaseVector.h" #include +#include #include using namespace facebook::velox; @@ -85,18 +85,12 @@ int main(int argc, char** argv) { // We need a connector id string to identify the connector. const std::string kHiveConnectorId = "test-hive"; - // Register the Hive Connector Factory. - connector::registerConnectorFactory( - std::make_shared()); - // Create a new connector instance from the connector factory and register - // it: - auto hiveConnector = - connector::getConnectorFactory( - connector::hive::HiveConnectorFactory::kHiveConnectorName) - ->newConnector( - kHiveConnectorId, - std::make_shared( - std::unordered_map())); + // Create a new connector instance and register it. + connector::hive::HiveConnectorFactory factory; + auto hiveConnector = factory.newConnector( + kHiveConnectorId, + std::make_shared( + std::unordered_map())); connector::registerConnector(hiveConnector); // To be able to read local files, we need to register the local file @@ -131,7 +125,7 @@ int main(int argc, char** argv) { std::shared_ptr executor( std::make_shared( - std::thread::hardware_concurrency())); + folly::hardware_concurrency())); // Task is the top-level execution concept. A task needs a taskId (as a // string), the plan fragment to execute, a destination (only used for @@ -141,7 +135,8 @@ int main(int argc, char** argv) { writerPlanFragment, /*destination=*/0, core::QueryCtx::create(executor.get()), - exec::Task::ExecutionMode::kSerial); + exec::Task::ExecutionMode::kSerial, + exec::Consumer{}); // next() starts execution using the client thread. The loop pumps output // vectors out of the task (there are none in this query fragment). @@ -171,7 +166,8 @@ int main(int argc, char** argv) { readPlanFragment, /*destination=*/0, core::QueryCtx::create(executor.get()), - exec::Task::ExecutionMode::kSerial); + exec::Task::ExecutionMode::kSerial, + exec::Consumer{}); // Now that we have the query fragment and Task structure set up, we will // add data to it via `splits`. diff --git a/velox/examples/SimpleFunctions.cpp b/velox/examples/SimpleFunctions.cpp index 4c12cb4bf510..e34b0598a227 100644 --- a/velox/examples/SimpleFunctions.cpp +++ b/velox/examples/SimpleFunctions.cpp @@ -341,7 +341,7 @@ struct MyRegexpMatchFunction { // quite expensive to compile it on a per-row basis. In this example we // support both modes (const and non-const). if (pattern != nullptr) { - re_.emplace(*pattern); + re_.emplace(std::string_view(*pattern)); } // Optionally, one could also inspect the session configs in `QueryConfig`. @@ -359,7 +359,8 @@ struct MyRegexpMatchFunction { // > `my_regexp_match(col1, col2)` result = re_.has_value() ? RE2::PartialMatch(toStringPiece(input), *re_) - : RE2::PartialMatch(toStringPiece(input), ::re2::RE2(pattern)); + : RE2::PartialMatch( + toStringPiece(input), ::re2::RE2(std::string_view(pattern))); return true; } diff --git a/velox/exec/Aggregate.cpp b/velox/exec/Aggregate.cpp index b56158145d88..003f669e66ae 100644 --- a/velox/exec/Aggregate.cpp +++ b/velox/exec/Aggregate.cpp @@ -20,7 +20,6 @@ #include "velox/exec/AggregateCompanionAdapter.h" #include "velox/exec/AggregateCompanionSignatures.h" #include "velox/exec/AggregateWindow.h" -#include "velox/expression/SignatureBinder.h" namespace facebook::velox::exec { @@ -244,32 +243,6 @@ std::vector getCompanionSignaturesWithSuffix( return entries; } -// Selects the signature by name and argument types. Returns the resolved result -// type for the type signature retrieved by the callback. Throws a user -// exception if the corresponding signature isn't found. -TypePtr resolveResultType( - const std::string& name, - const std::vector& argTypes, - const std::function& getResultType) { - auto signatures = exec::getAggregateFunctionSignatures(name); - if (!signatures.has_value()) { - VELOX_USER_FAIL("Aggregate function not registered: {}", name); - } - for (auto& signature : signatures.value()) { - SignatureBinder binder(*signature, argTypes); - if (binder.tryBind()) { - return binder.tryResolveType(getResultType(*signature)); - } - } - - std::stringstream error; - error << "Aggregate function signature is not supported: " - << toString(name, argTypes) - << ". Supported signatures: " << toString(signatures.value()) << "."; - VELOX_USER_FAIL(error.str()); -} - } // namespace std::optional getCompanionFunctionSignatures( @@ -321,24 +294,6 @@ std::unique_ptr Aggregate::create( const std::vector& argTypes, const TypePtr& resultType, const core::QueryConfig& config) { - // TODO(timaou, kletkavrubashku): Reneable the validation once "regr_slope" - // signature is fixed - // - // Validate the result type. if (isPartialOutput(step)) { - // auto intermediateType = Aggregate::intermediateType(name, argTypes); - // VELOX_CHECK( - // resultType->equivalent(*intermediateType), - // "Intermediate type mismatch. Expected: {}, actual: {}", - // intermediateType->toString(), - // resultType->toString()); - // } else { - // auto finalType = Aggregate::finalType(name, argTypes); - // VELOX_CHECK( - // resultType->equivalent(*finalType), - // "Final type mismatch. Expected: {}, actual: {}", - // finalType->toString(), - // resultType->toString()); - // } // Lookup the function in the new registry first. if (auto func = getAggregateFunctionEntry(name)) { return func->factory(step, argTypes, resultType, config); @@ -347,40 +302,6 @@ std::unique_ptr Aggregate::create( VELOX_USER_FAIL("Aggregate function not registered: {}", name); } -// static -TypePtr Aggregate::intermediateType( - const std::string& name, - const std::vector& argTypes) { - auto type = resolveResultType( - name, - argTypes, - [](const AggregateFunctionSignature& signature) -> const TypeSignature& { - return signature.intermediateType(); - }); - VELOX_USER_CHECK( - type, - "Cannot resolve intermediate type for aggregate function {}", - toString(name, argTypes)); - return type; -} - -// static -TypePtr Aggregate::finalType( - const std::string& name, - const std::vector& argTypes) { - auto type = resolveResultType( - name, - argTypes, - [](const AggregateFunctionSignature& signature) -> const TypeSignature& { - return signature.returnType(); - }); - VELOX_USER_CHECK( - type, - "Cannot resolve final type for aggregate function {}", - toString(name, argTypes)); - return type; -} - void Aggregate::setLambdaExpressions( std::vector lambdaExpressions, std::shared_ptr expressionEvaluator) { diff --git a/velox/exec/Aggregate.h b/velox/exec/Aggregate.h index 8ed6c88a02e4..a3052ea9daca 100644 --- a/velox/exec/Aggregate.h +++ b/velox/exec/Aggregate.h @@ -341,18 +341,6 @@ class Aggregate { const TypePtr& resultType, const core::QueryConfig& config); - // Returns the intermediate type for 'name' with signature - // 'argTypes'. Throws if cannot resolve. - static TypePtr intermediateType( - const std::string& name, - const std::vector& argTypes); - - // Returns the final type for 'name' with signature - // 'argTypes'. Throws if cannot resolve. - static TypePtr finalType( - const std::string& name, - const std::vector& argTypes); - protected: virtual void setAllocatorInternal(HashStringAllocator* allocator); @@ -446,7 +434,7 @@ class Aggregate { if (isInitialized(group)) { auto accumulator = value(group); std::destroy_at(accumulator); - ::memset(accumulator, 0, sizeof(T)); + ::memset(static_cast(accumulator), 0, sizeof(T)); } } @@ -517,6 +505,10 @@ using AggregateFunctionFactory = std::function( const core::QueryConfig& config)>; struct AggregateFunctionMetadata { + /// True if results of the aggregation ignore duplicate values. + /// For example, min and max ignore duplicates while sum does not. + bool ignoreDuplicates{false}; + /// True if results of the aggregation depend on the order of inputs. For /// example, array_agg is order sensitive while count is not. bool orderSensitive{true}; diff --git a/velox/exec/AggregateCompanionAdapter.cpp b/velox/exec/AggregateCompanionAdapter.cpp index ea3c64c5c9ea..9aab4fbac9b0 100644 --- a/velox/exec/AggregateCompanionAdapter.cpp +++ b/velox/exec/AggregateCompanionAdapter.cpp @@ -340,25 +340,35 @@ bool registerMergeExtractFunctionInternal( mergeExtractFunctionName, std::move(mergeExtractSignatures), [name, mergeExtractFunctionName]( - core::AggregationNode::Step /*step*/, + core::AggregationNode::Step step, const std::vector& argTypes, const TypePtr& resultType, const core::QueryConfig& config) -> std::unique_ptr { - const auto& [originalResultType, _] = - resolveAggregateFunction(mergeExtractFunctionName, argTypes); - if (!originalResultType) { - // TODO: limitation -- result type must be resolvable given - // intermediate type of the original UDAF. - VELOX_UNREACHABLE( - "Signatures whose result types are not resolvable given intermediate types should have been excluded."); + TypePtr functionResultType; + if (step == core::AggregationNode::Step::kFinal || + step == core::AggregationNode::Step::kSingle) { + functionResultType = resultType; + } else { + // When step is kPartial or kIntermediate, 'resultType' is + // the intermediate type and the original result type needs to + // be resolved for the aggregate function creation. + const auto& originalResultType = + resolveResultType(mergeExtractFunctionName, argTypes); + if (!originalResultType) { + // Result type must be resolvable given intermediate type of + // the original UDAF. + VELOX_FAIL( + "Signatures' result types must be resolvable given intermediate types."); + } + functionResultType = originalResultType; } if (auto func = getAggregateFunctionEntry(name)) { auto fn = func->factory( core::AggregationNode::Step::kFinal, argTypes, - originalResultType, + functionResultType, config); VELOX_CHECK_NOT_NULL(fn); return std::make_unique< diff --git a/velox/exec/AggregateCompanionAdapter.h b/velox/exec/AggregateCompanionAdapter.h index 8a6af66ff514..fc0a5909f8a6 100644 --- a/velox/exec/AggregateCompanionAdapter.h +++ b/velox/exec/AggregateCompanionAdapter.h @@ -169,21 +169,35 @@ struct AggregateCompanionAdapter { }; }; +/// In Velox, "Step" is a property of the aggregate operator, whereas in +/// Spark, it is tied to individual aggregate functions. Spark executes +/// aggregates using a mix of partial, intermediate, and final aggregate +/// functions. To bridge the two systems, the planner translates Spark's +/// aggregate modes into corresponding Velox companion functions and assigns +/// the "single" step to Velox’s AggregationNode. These companion functions +/// are intended for internal use within the aggregate operator and are not +/// designed to be used as standalone functions, and their result types +/// may not always be inferable from intermediate types. More details can be +/// found in +/// https://github.com/facebookincubator/velox/pull/11999#issuecomment-3274577979 +/// and https://github.com/facebookincubator/velox/issues/12830. class CompanionFunctionsRegistrar { public: - // Register the partial companion function for an aggregation function of - // `name` and `signatures`. When there is already a function of the same name, - // if `overwrite` is true, the registration is replaced. Otherwise, return - // false without overwriting the registry. + /// Register the partial companion function for an aggregate function of + /// `name` and `signatures`. When there is already a function of the same + /// name, if `overwrite` is true, the registration is replaced. Otherwise, + /// return false without overwriting the registry. This function supports + /// generating Spark compatible companion functions. static bool registerPartialFunction( const std::string& name, const std::vector& signatures, const AggregateFunctionMetadata& metadata, bool overwrite = false); - // When there is already a function of the same name as the merge companion - // function, if `overwrite` is true, the registration is replaced. Otherwise, - // return false without overwriting the registry. + /// When there is already a function of the same name as the merge companion + /// function, if `overwrite` is true, the registration is replaced. Otherwise, + /// return false without overwriting the registry. This function supports + /// generating Spark compatible companion functions. static bool registerMergeFunction( const std::string& name, const std::vector& signatures, @@ -204,14 +218,16 @@ class CompanionFunctionsRegistrar { const std::vector& signatures, bool overwrite = false); - // Similar to registerExtractFunction(), the result type of the original - // aggregation function is required to be resolvable given its intermediate - // type. If there are multiple signatures of the original aggregation function - // with the same intermediate type, register merge-extract functions with + /// If there are multiple signatures of the original aggregate function + /// with the same intermediate type, register merge-extract functions with // suffix of their result types in the function names for each of them. When - // there is already a function of the same name as the merge-extract companion - // function, if `overwrite` is true, the registration is replaced. Otherwise, - // return false without overwriting the registry. + /// there is already a function of the same name as the merge-extract + /// companion function, if `overwrite` is true, the registration is replaced. + /// Otherwise, return false without overwriting the registry. This function + /// supports generating Spark compatible companion functions only when the + /// return types are explicitly specified (typically in "single" or "final" + /// steps). It will throw an exception if return types are not provided and + /// cannot be resolved from the intermediate types. static bool registerMergeExtractFunction( const std::string& name, const std::vector& signatures, diff --git a/velox/exec/AggregateCompanionSignatures.cpp b/velox/exec/AggregateCompanionSignatures.cpp index 2fa7e8d11dc6..c6da4488dc62 100644 --- a/velox/exec/AggregateCompanionSignatures.cpp +++ b/velox/exec/AggregateCompanionSignatures.cpp @@ -102,20 +102,18 @@ CompanionSignatures::partialFunctionSignatures( const std::vector& signatures) { std::vector partialSignatures; for (const auto& signature : signatures) { - if (!isResultTypeResolvableGivenIntermediateType(signature)) { - continue; - } std::vector usedTypes = signature->argumentTypes(); usedTypes.push_back(signature->intermediateType()); auto variables = usedTypeVariables(usedTypes, signature->variables()); - partialSignatures.push_back(std::make_shared( - /*variables*/ variables, - /*returnType*/ signature->intermediateType(), - /*intermediateType*/ signature->intermediateType(), - /*argumentTypes*/ signature->argumentTypes(), - /*constantArguments*/ signature->constantArguments(), - /*variableArity*/ signature->variableArity())); + partialSignatures.push_back( + std::make_shared( + /*variables*/ variables, + /*returnType*/ signature->intermediateType(), + /*intermediateType*/ signature->intermediateType(), + /*argumentTypes*/ signature->argumentTypes(), + /*constantArguments*/ signature->constantArguments(), + /*variableArity*/ signature->variableArity())); } return partialSignatures; } @@ -126,10 +124,6 @@ std::string CompanionSignatures::partialFunctionName(const std::string& name) { AggregateFunctionSignaturePtr CompanionSignatures::mergeFunctionSignature( const AggregateFunctionSignaturePtr& signature) { - if (!isResultTypeResolvableGivenIntermediateType(signature)) { - return nullptr; - } - std::vector usedTypes = {signature->intermediateType()}; auto variables = usedTypeVariables(usedTypes, signature->variables()); return std::make_shared( @@ -172,10 +166,6 @@ bool CompanionSignatures::hasSameIntermediateTypesAcrossSignatures( AggregateFunctionSignaturePtr CompanionSignatures::mergeExtractFunctionSignature( const AggregateFunctionSignaturePtr& signature) { - if (!isResultTypeResolvableGivenIntermediateType(signature)) { - return nullptr; - } - std::vector usedTypes = { signature->intermediateType(), signature->returnType()}; auto variables = usedTypeVariables(usedTypes, signature->variables()); diff --git a/velox/exec/AggregateCompanionSignatures.h b/velox/exec/AggregateCompanionSignatures.h index ea7d4b2d6ce7..c0557adc0917 100644 --- a/velox/exec/AggregateCompanionSignatures.h +++ b/velox/exec/AggregateCompanionSignatures.h @@ -106,8 +106,9 @@ class CompanionSignatures { normalizeType(signature->intermediateType(), signature->variables()); auto normalizedReturnType = normalizeType(signature->returnType(), signature->variables()); - if (distinctIntermediateAndResultTypes.count(std::make_pair( - normalizedIntermediateType, normalizedReturnType))) { + if (distinctIntermediateAndResultTypes.count( + std::make_pair( + normalizedIntermediateType, normalizedReturnType))) { continue; } diff --git a/velox/exec/AggregateFunctionRegistry.cpp b/velox/exec/AggregateFunctionRegistry.cpp index 4214c8490f7b..d1ec1dd987d9 100644 --- a/velox/exec/AggregateFunctionRegistry.cpp +++ b/velox/exec/AggregateFunctionRegistry.cpp @@ -16,28 +16,125 @@ #include "velox/exec/AggregateFunctionRegistry.h" #include "velox/exec/Aggregate.h" -#include "velox/expression/FunctionSignature.h" #include "velox/expression/SignatureBinder.h" #include "velox/type/Type.h" namespace facebook::velox::exec { -std::pair resolveAggregateFunction( - const std::string& functionName, +namespace { +std::string makeSignatureNotSupportedError( + const std::string& name, + const std::vector& argTypes, + const std::vector>& + signatures) { + std::stringstream error; + error << "Aggregate function signature is not supported: " + << toString(name, argTypes) + << ". Supported signatures: " << toString(signatures) << "."; + return error.str(); +} +} // namespace + +TypePtr resolveResultType( + const std::string& name, + const std::vector& argTypes) { + if (auto signatures = getAggregateFunctionSignatures(name)) { + for (const auto& signature : signatures.value()) { + SignatureBinder binder(*signature, argTypes); + if (binder.tryBind()) { + return binder.tryResolveReturnType(); + } + } + + VELOX_USER_FAIL( + makeSignatureNotSupportedError(name, argTypes, signatures.value())); + } else { + VELOX_USER_FAIL("Aggregate function not registered: {}", name); + } +} + +namespace { +bool hasCoercion(const std::vector& coercions) { + for (const auto& coercion : coercions) { + if (coercion.type != nullptr) { + return true; + } + } + + return false; +} +} // namespace + +TypePtr resolveResultTypeWithCoercions( + const std::string& name, + const std::vector& argTypes, + std::vector& coercions) { + coercions.clear(); + + std::vector, TypePtr>> candidates; + if (auto signatures = getAggregateFunctionSignatures(name)) { + for (const auto& signature : signatures.value()) { + exec::SignatureBinder binder(*signature, argTypes); + std::vector requiredCoercions; + if (binder.tryBindWithCoercions(requiredCoercions)) { + auto type = binder.tryResolveReturnType(); + if (!hasCoercion(requiredCoercions)) { + coercions.resize(argTypes.size(), nullptr); + return type; + } + + candidates.emplace_back(requiredCoercions, type); + } + } + + if (auto index = Coercion::pickLowestCost(candidates)) { + const auto& requiredCoercions = candidates[index.value()].first; + coercions.reserve(requiredCoercions.size()); + for (const auto& coercion : requiredCoercions) { + coercions.push_back(coercion.type); + } + + return candidates[index.value()].second; + } + + VELOX_USER_FAIL( + makeSignatureNotSupportedError(name, argTypes, signatures.value())); + } else { + VELOX_USER_FAIL("Aggregate function not registered: {}", name); + } +} + +TypePtr resolveIntermediateType( + const std::string& name, const std::vector& argTypes) { - if (auto aggregateFunctionSignatures = - getAggregateFunctionSignatures(functionName)) { - for (const auto& signature : aggregateFunctionSignatures.value()) { + if (auto signatures = getAggregateFunctionSignatures(name)) { + for (const auto& signature : signatures.value()) { SignatureBinder binder(*signature, argTypes); if (binder.tryBind()) { - return std::make_pair( - binder.tryResolveReturnType(), - binder.tryResolveType(signature->intermediateType())); + return binder.tryResolveType(signature->intermediateType()); } } + + std::stringstream error; + error << "Aggregate function signature is not supported: " + << toString(name, argTypes) + << ". Supported signatures: " << toString(signatures.value()) << "."; + VELOX_USER_FAIL(error.str()); + } else { + VELOX_USER_FAIL("Aggregate function not registered: {}", name); } +} + +std::vector getAggregateFunctionNames() { + std::vector names; + exec::aggregateFunctions().withRLock([&](const auto& map) { + names.reserve(map.size()); + for (const auto& function : map) { + names.push_back(function.first); + } + }); - return std::make_pair(nullptr, nullptr); + return names; } } // namespace facebook::velox::exec diff --git a/velox/exec/AggregateFunctionRegistry.h b/velox/exec/AggregateFunctionRegistry.h index b8d116989480..4adea534887b 100644 --- a/velox/exec/AggregateFunctionRegistry.h +++ b/velox/exec/AggregateFunctionRegistry.h @@ -22,11 +22,41 @@ namespace facebook::velox::exec { -/// Given a name of aggregate function and argument types, returns a pair of the -/// return type and intermediate type if the function exists. Returns a pair of -/// nullptr otherwise. -std::pair resolveAggregateFunction( - const std::string& functionName, +/// Given a name of aggregate function and argument types, returns the result +/// type if the function exists. Throws if function doesn't exist or doesn't +/// support specified argument types. Since aggregate functions can be +/// integrated into internal steps of an aggregate operator — rather than +/// always being used as standalone functions at the SQL level — their result +/// types may not always be inferable from the intermediate types. As a +/// result, an exception might be thrown during the type resolution process. In +/// such cases, the caller should explicitly specify the result type. More +/// details can be found in +/// https://github.com/facebookincubator/velox/pull/11999#issuecomment-3274577979 +/// and https://github.com/facebookincubator/velox/issues/12830. +TypePtr resolveResultType( + const std::string& name, const std::vector& argTypes); +/// Like 'resolveResultType', but with support for applying type conversions if +/// a function signature doesn't match 'argTypes' exactly. +/// +/// @param coercions A list of optional type coercions that were applied to +/// resolve a function successfully. Contains one entry per argument. The entry +/// is null if no coercion is required for that argument. The entry is not null +/// if coercion is necessary. +TypePtr resolveResultTypeWithCoercions( + const std::string& name, + const std::vector& argTypes, + std::vector& coercions); + +/// Given a name of aggregate function and argument types, returns the +/// intermediate type if the function exists. Throws if function doesn't exist +/// or doesn't support specified argument types. +TypePtr resolveIntermediateType( + const std::string& name, + const std::vector& argTypes); + +/// Returns all the registered aggregation function names. +std::vector getAggregateFunctionNames(); + } // namespace facebook::velox::exec diff --git a/velox/exec/AggregateInfo.cpp b/velox/exec/AggregateInfo.cpp index 9b19ab687965..27a7ae551961 100644 --- a/velox/exec/AggregateInfo.cpp +++ b/velox/exec/AggregateInfo.cpp @@ -16,6 +16,7 @@ #include "velox/exec/AggregateInfo.h" #include "velox/exec/Aggregate.h" +#include "velox/exec/AggregateFunctionRegistry.h" #include "velox/exec/Operator.h" #include "velox/expression/Expr.h" @@ -80,10 +81,10 @@ std::vector toAggregateInfo( arg->toString()); } } + const auto& name = aggregate.call->name(); - info.distinct = aggregate.distinct; - info.intermediateType = Aggregate::intermediateType( - aggregate.call->name(), aggregate.rawInputTypes); + info.intermediateType = + resolveIntermediateType(name, aggregate.rawInputTypes); // Setup aggregation mask: convert the Variable Reference name to the // channel (projection) index, if there is a mask. @@ -96,7 +97,7 @@ std::vector toAggregateInfo( auto index = numKeys + i; const auto& aggResultType = outputType->childAt(index); info.function = Aggregate::create( - aggregate.call->name(), + name, isPartialOutput(step) ? core::AggregationNode::Step::kPartial : core::AggregationNode::Step::kSingle, aggregate.rawInputTypes, @@ -112,10 +113,13 @@ std::vector toAggregateInfo( info.function->setLambdaExpressions(lambdas, expressionEvaluator); } - // Ignore sorting properties if aggregate function is not sensitive to the - // order of inputs. - auto* entry = getAggregateFunctionEntry(aggregate.call->name()); + // 1. Ignore duplicates property + // if aggregate function is not sensitive to duplicates. + // 2. Ignore sorting properties + // if aggregate function is not sensitive to the order of inputs. + auto* entry = getAggregateFunctionEntry(name); const auto& metadata = entry->metadata; + info.distinct = !metadata.ignoreDuplicates && aggregate.distinct; if (metadata.orderSensitive) { // Sorting keys and orders. const auto numSortingKeys = aggregate.sortingKeys.size(); diff --git a/velox/exec/AssignUniqueId.h b/velox/exec/AssignUniqueId.h index 7ce3daffe8d9..a26d3da68687 100644 --- a/velox/exec/AssignUniqueId.h +++ b/velox/exec/AssignUniqueId.h @@ -37,7 +37,7 @@ class AssignUniqueId : public Operator { } bool needsInput() const override { - return true; + return input_ == nullptr; } void addInput(RowVectorPtr input) override; @@ -48,6 +48,11 @@ class AssignUniqueId : public Operator { return BlockingReason::kNotBlocked; } + bool startDrain() override { + // No need to drain for assignUniqueId operator. + return false; + } + bool isFinished() override; private: diff --git a/velox/exec/BlockingReason.cpp b/velox/exec/BlockingReason.cpp new file mode 100644 index 000000000000..ffe07aa41d84 --- /dev/null +++ b/velox/exec/BlockingReason.cpp @@ -0,0 +1,46 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/BlockingReason.h" + +namespace facebook::velox::exec { + +namespace { +const auto& blockingReasonNames() { + static const folly::F14FastMap kNames = { + {BlockingReason::kNotBlocked, "kNotBlocked"}, + {BlockingReason::kWaitForConsumer, "kWaitForConsumer"}, + {BlockingReason::kWaitForSplit, "kWaitForSplit"}, + {BlockingReason::kWaitForProducer, "kWaitForProducer"}, + {BlockingReason::kWaitForJoinBuild, "kWaitForJoinBuild"}, + {BlockingReason::kWaitForJoinProbe, "kWaitForJoinProbe"}, + {BlockingReason::kWaitForMergeJoinRightSide, + "kWaitForMergeJoinRightSide"}, + {BlockingReason::kWaitForMemory, "kWaitForMemory"}, + {BlockingReason::kWaitForConnector, "kWaitForConnector"}, + {BlockingReason::kYield, "kYield"}, + {BlockingReason::kWaitForArbitration, "kWaitForArbitration"}, + {BlockingReason::kWaitForScanScaleUp, "kWaitForScanScaleUp"}, + {BlockingReason::kWaitForIndexLookup, "kWaitForIndexLookup"}, + }; + return kNames; +} + +} // namespace + +VELOX_DEFINE_ENUM_NAME(BlockingReason, blockingReasonNames) + +} // namespace facebook::velox::exec diff --git a/velox/exec/BlockingReason.h b/velox/exec/BlockingReason.h new file mode 100644 index 000000000000..f89e808b0999 --- /dev/null +++ b/velox/exec/BlockingReason.h @@ -0,0 +1,71 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "velox/common/Enums.h" + +namespace facebook::velox::exec { + +enum class BlockingReason { + kNotBlocked, + kWaitForConsumer, + kWaitForSplit, + /// Some operators can get blocked due to the producer(s) (they are + /// currently waiting data from) not having anything produced. Used by + /// LocalExchange, LocalMergeExchange, Exchange and MergeExchange operators. + kWaitForProducer, + kWaitForJoinBuild, + /// For a build operator, it is blocked waiting for the probe operators to + /// finish probing before build the next hash table from one of the + /// previously spilled partition data. For a probe operator, it is blocked + /// waiting for all its peer probe operators to finish probing before + /// notifying the build operators to build the next hash table from the + /// previously spilled data. + kWaitForJoinProbe, + /// Used by MergeJoin operator, indicating that it was blocked by the right + /// side input being unavailable. + kWaitForMergeJoinRightSide, + kWaitForMemory, + kWaitForConnector, + /// Some operators (like Table Scan) may run long loops and can 'voluntarily' + /// exit them because Task requested to yield or stop or after a certain time. + /// This is the blocking reason used in such cases. + kYield, + /// Operator is blocked waiting for its associated query memory arbitration to + /// finish. + kWaitForArbitration, + /// For a table scan operator, it is blocked waiting for the scan controller + /// to increase the number of table scan processing threads to start + /// processing. + kWaitForScanScaleUp, + /// Used by IndexLookupJoin operator, indicating that it was blocked by the + /// async index lookup. + kWaitForIndexLookup, +}; + +VELOX_DECLARE_ENUM_NAME(BlockingReason); +} // namespace facebook::velox::exec + +template <> +struct fmt::formatter + : formatter { + auto format(facebook::velox::exec::BlockingReason b, format_context& ctx) + const { + return formatter::format( + facebook::velox::exec::BlockingReasonName::toName(b), ctx); + } +}; diff --git a/velox/exec/CMakeLists.txt b/velox/exec/CMakeLists.txt index 7ad8f77e1166..9dd6d4ad065c 100644 --- a/velox/exec/CMakeLists.txt +++ b/velox/exec/CMakeLists.txt @@ -19,11 +19,13 @@ velox_add_library( AggregateCompanionSignatures.cpp AggregateFunctionRegistry.cpp AggregateInfo.cpp - AggregationMasks.cpp AggregateWindow.cpp + AggregationMasks.cpp ArrowStream.cpp AssignUniqueId.cpp + BlockingReason.cpp CallbackSink.cpp + ColumnStatsCollector.cpp ContainerRowSerde.cpp DistinctAggregations.cpp Driver.cpp @@ -32,6 +34,7 @@ velox_add_library( ExchangeClient.cpp ExchangeQueue.cpp ExchangeSource.cpp + SerializedPage.cpp Expand.cpp FilterProject.cpp GroupId.cpp @@ -42,6 +45,7 @@ velox_add_library( HashPartitionFunction.cpp HashProbe.cpp HashTable.cpp + HashTableCache.cpp IndexLookupJoin.cpp JoinBridge.cpp Limit.cpp @@ -54,32 +58,35 @@ velox_add_library( MergeSource.cpp NestedLoopJoinBuild.cpp NestedLoopJoinProbe.cpp + SpatialIndex.cpp + SpatialJoinBuild.cpp + SpatialJoinProbe.cpp Operator.cpp + OperatorTraceReader.cpp + OperatorTraceScan.cpp + OperatorTraceWriter.cpp OperatorUtils.cpp OrderBy.cpp OutputBuffer.cpp OutputBufferManager.cpp - OperatorTraceReader.cpp - OperatorTraceScan.cpp - OperatorTraceWriter.cpp - TaskTraceReader.cpp - TaskTraceWriter.cpp - Trace.cpp - TraceUtil.cpp - PartitionedOutput.cpp + ParallelProject.cpp PartitionFunction.cpp PartitionStreamingWindowBuild.cpp + PartitionedOutput.cpp PlanNodeStats.cpp PrefixSort.cpp ProbeOperatorState.cpp - RowsStreamingWindowBuild.cpp + SubPartitionedSortWindowBuild.cpp RowContainer.cpp RowNumber.cpp - ScaledScanController.cpp + RowsStreamingWindowBuild.cpp ScaleWriterLocalPartition.cpp + ScaledScanController.cpp SortBuffer.cpp - SortedAggregations.cpp SortWindowBuild.cpp + SortedAggregations.cpp + SpatialJoinBuild.cpp + SpatialJoinProbe.cpp Spill.cpp SpillFile.cpp Spiller.cpp @@ -89,44 +96,40 @@ velox_add_library( TableWriteMerge.cpp TableWriter.cpp Task.cpp + TaskStructs.cpp + TaskTraceReader.cpp + TaskTraceWriter.cpp TopN.cpp TopNRowNumber.cpp + Trace.cpp + TraceUtil.cpp Unnest.cpp Values.cpp VectorHasher.cpp Window.cpp WindowBuild.cpp WindowFunction.cpp - WindowPartition.cpp) + WindowPartition.cpp +) velox_link_libraries( velox_exec - velox_file + velox_arrow_bridge + velox_common_base + velox_common_compression velox_core - velox_vector velox_connector velox_expression + velox_file + velox_presto_serializer velox_time - velox_common_base velox_test_util - velox_arrow_bridge - velox_common_compression) + velox_vector +) velox_add_library(velox_cursor Cursor.cpp) -velox_link_libraries( - velox_cursor - velox_core - velox_exception - velox_expression - velox_dwio_common - velox_dwio_dwrf_reader - velox_dwio_dwrf_writer - velox_type_fbhive - velox_hive_connector - velox_tpch_connector - velox_presto_serializer - velox_functions_prestosql - velox_aggregates) + +velox_link_libraries(velox_cursor velox_core velox_exception velox_expression) if(${VELOX_BUILD_TESTING}) add_subdirectory(fuzzer) diff --git a/velox/exec/CallbackSink.h b/velox/exec/CallbackSink.h index dc8afd64763a..775592e2242c 100644 --- a/velox/exec/CallbackSink.h +++ b/velox/exec/CallbackSink.h @@ -26,8 +26,9 @@ class CallbackSink : public Operator { int32_t operatorId, DriverCtx* driverCtx, Consumer consumeCb, - std::function startedCb = nullptr) - : Operator(driverCtx, nullptr, operatorId, "N/A", "CallbackSink"), + std::function startedCb = nullptr, + const std::string& planNodeId = "N/A") + : Operator(driverCtx, nullptr, operatorId, planNodeId, "CallbackSink"), startedCb_{std::move(startedCb)}, consumeCb_{std::move(consumeCb)} {} diff --git a/velox/exec/ColumnStatsCollector.cpp b/velox/exec/ColumnStatsCollector.cpp new file mode 100644 index 000000000000..d1bc68b1fa9f --- /dev/null +++ b/velox/exec/ColumnStatsCollector.cpp @@ -0,0 +1,213 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/ColumnStatsCollector.h" +#include "velox/exec/Aggregate.h" +#include "velox/exec/AggregateFunctionRegistry.h" +#include "velox/exec/AggregateInfo.h" +#include "velox/exec/VectorHasher.h" + +namespace facebook::velox::exec { + +ColumnStatsCollector::ColumnStatsCollector( + const core::ColumnStatsSpec& statsSpec, + const RowTypePtr& inputType, + const core::QueryConfig* queryConfig, + memory::MemoryPool* pool, + tsan_atomic* nonReclaimableSection) + : statsSpec_(statsSpec), + inputType_(inputType), + queryConfig_(queryConfig), + pool_(pool), + nonReclaimableSection_(nonReclaimableSection), + maxOutputBatchRows_( + statsSpec_.groupingKeys.empty() ? 1 : kMaxOutputBatchRows) { + VELOX_CHECK_NOT_NULL(inputType_); + VELOX_CHECK_NOT_NULL(queryConfig_); + VELOX_CHECK_NOT_NULL(pool_); + VELOX_CHECK_NOT_NULL(nonReclaimableSection_); +} + +void ColumnStatsCollector::initialize() { + if (initialized_) { + return; + } + SCOPE_EXIT { + initialized_ = true; + }; + setOutputType(); + VELOX_CHECK_NOT_NULL(outputType_); + createGroupingSet(); + VELOX_CHECK_NOT_NULL(groupingSet_); +} + +void ColumnStatsCollector::setOutputType() { + VELOX_CHECK_NULL(outputType_); + outputType_ = statsSpec_.outputType(); +} + +std::pair, std::vector> +ColumnStatsCollector::setupGroupingKeyChannelProjections() const { + const auto& groupingKeys = statsSpec_.groupingKeys; + std::vector groupingKeyInputChannels; + groupingKeyInputChannels.reserve(groupingKeys.size()); + for (auto i = 0; i < groupingKeys.size(); ++i) { + groupingKeyInputChannels.push_back( + exprToChannel(groupingKeys[i].get(), inputType_)); + } + + std::vector groupingKeyOutputChannels(groupingKeys.size()); + std::iota( + groupingKeyOutputChannels.begin(), groupingKeyOutputChannels.end(), 0); + return std::make_pair(groupingKeyInputChannels, groupingKeyOutputChannels); +} + +std::vector ColumnStatsCollector::createAggregates( + size_t numGroupingKeys) { + VELOX_CHECK_NOT_NULL(outputType_); + const auto step = statsSpec_.aggregationStep; + const auto numAggregates = statsSpec_.aggregates.size(); + std::vector aggregateInfos; + aggregateInfos.reserve(numAggregates); + for (auto i = 0; i < numAggregates; ++i) { + const auto& aggregate = statsSpec_.aggregates[i]; + AggregateInfo info; + auto& channels = info.inputs; + auto& constants = info.constantInputs; + for (const auto& arg : aggregate.call->inputs()) { + if (auto field = + dynamic_cast(arg.get())) { + channels.push_back(inputType_->getChildIdx(field->name())); + constants.push_back(nullptr); + } else { + VELOX_FAIL( + "Aggregation expression must be field access for column stats collection: {}", + arg->toString()); + } + } + VELOX_CHECK(!aggregate.distinct); + info.intermediateType = resolveIntermediateType( + aggregate.call->name(), aggregate.rawInputTypes); + // Column stats collection doesn't support aggregation mask. + VELOX_CHECK_NULL(aggregate.mask); + info.mask = std::nullopt; + const auto outputChannel = numGroupingKeys + i; + const auto& aggResultType = outputType_->childAt(outputChannel); + info.function = Aggregate::create( + aggregate.call->name(), + isPartialOutput(step) ? core::AggregationNode::Step::kPartial + : core::AggregationNode::Step::kSingle, + aggregate.rawInputTypes, + aggResultType, + *queryConfig_); + VELOX_CHECK(aggregate.sortingKeys.empty()); + VELOX_CHECK(aggregate.sortingOrders.empty()); + info.output = outputChannel; + aggregateInfos.emplace_back(std::move(info)); + } + return aggregateInfos; +} + +void ColumnStatsCollector::createGroupingSet() { + VELOX_CHECK_NULL(groupingSet_); + + auto [groupingKeyInputChannels, groupingKeyOutputChannels] = + setupGroupingKeyChannelProjections(); + + auto hashers = createVectorHashers(inputType_, groupingKeyInputChannels); + const auto numHashers = hashers.size(); + + // Setup aggregates based on the column stats specifications + auto aggregateInfos = createAggregates(numHashers); + // Create the grouping set for aggregation execution. + groupingSet_ = std::make_unique( + inputType_, + std::move(hashers), + /*preGroupedKeys=*/std::vector{}, + std::move(groupingKeyOutputChannels), + std::move(aggregateInfos), + /*ignoreNullKey=*/false, + /*isPartial=*/isPartialOutput(statsSpec_.aggregationStep), + /*isRawInput=*/isRawInput(statsSpec_.aggregationStep), + /*globalGroupingSets=*/std::vector{}, + /*groupIdChannel=*/std::nullopt, + /*spillConfig=*/nullptr, + nonReclaimableSection_, + queryConfig_, + pool_, + /*spillStats=*/nullptr); +} + +void ColumnStatsCollector::addInput(RowVectorPtr input) { + VELOX_CHECK_NOT_NULL(input); + VELOX_CHECK(initialized_); + + if (input->size() == 0) { + return; + } + + // Add input to the grouping set + groupingSet_->addInput(input, /*mayPushdown=*/false); +} + +void ColumnStatsCollector::noMoreInput() { + if (!noMoreInput_) { + noMoreInput_ = true; + if (groupingSet_) { + groupingSet_->noMoreInput(); + } + } +} + +void ColumnStatsCollector::prepareOutput() { + if (output_) { + VectorPtr output = std::move(output_); + BaseVector::prepareForReuse(output, maxOutputBatchRows_); + output_ = std::static_pointer_cast(output); + } else { + output_ = std::static_pointer_cast( + BaseVector::create(outputType_, maxOutputBatchRows_, pool_)); + } +} + +RowVectorPtr ColumnStatsCollector::getOutput() { + VELOX_CHECK(initialized_); + + if (!groupingSet_ || !noMoreInput_ || finished_) { + return nullptr; + } + + prepareOutput(); + const bool hasMoreOutput = groupingSet_->getOutput( + maxOutputBatchRows_, kMaxOutputBatchBytes, outputIterator_, output_); + if (!hasMoreOutput) { + finished_ = true; + return nullptr; + } + return output_; +} + +bool ColumnStatsCollector::finished() const { + return finished_; +} + +void ColumnStatsCollector::close() { + if (groupingSet_) { + groupingSet_.reset(); + } +} + +} // namespace facebook::velox::exec diff --git a/velox/exec/ColumnStatsCollector.h b/velox/exec/ColumnStatsCollector.h new file mode 100644 index 000000000000..5c76f9af3915 --- /dev/null +++ b/velox/exec/ColumnStatsCollector.h @@ -0,0 +1,120 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "velox/core/PlanNode.h" +#include "velox/exec/GroupingSet.h" + +namespace facebook::velox::exec { + +/// ColumnStatsCollector collects column statistics during table write +/// operations by leveraging the GroupingSet aggregation framework. It supports +/// both partitioned and unpartitioned tables, computing aggregated statistics +/// such as min, max, count, etc. for specified columns. +/// +/// Usage pattern: +/// 1. Create instance with ColumnStatsSpec defining desired statistics +/// 2. Call initialize() once before processing any input +/// 3. Call addInput() repeatedly with data batches +/// 4. Call noMoreInput() when all input has been provided +/// 5. Call getOutput() until finished() returns true to retrieve all results +/// 6. Call close() to cleanup resources +class ColumnStatsCollector { + public: + /// Constructs a ColumnStatsCollector for collecting statistics during table + /// writes. + /// @param statsSpec Specification defining grouping keys, aggregation step, + /// and aggregation functions to compute + /// @param inputType Schema of the input data that will be processed + /// @param queryConfig Query configuration settings + /// @param pool Memory pool for allocations + /// @param nonReclaimableSection Atomic flag indicating non-reclaimable memory + /// section + ColumnStatsCollector( + const core::ColumnStatsSpec& statsSpec, + const RowTypePtr& inputType, + const core::QueryConfig* queryConfig, + memory::MemoryPool* pool, + tsan_atomic* nonReclaimableSection); + + /// Initializes the stats collector. Must be called exactly once before + /// adding any input data. Sets up internal aggregation structures based + /// on the provided ColumnStatsSpec. + void initialize(); + + /// Adds a batch of input data for statistics collection. Can be called + /// multiple times with different batches until all input data has been + /// processed. + /// @param input Batch of input data to process for statistics collection + void addInput(RowVectorPtr input); + + /// Signals that no more input data will be provided. Must be called after + /// all addInput() calls to finalize the aggregation computation. + void noMoreInput(); + + /// Retrieves the computed column statistics. For partitioned tables, results + /// are returned one partition at a time, so this method may need to be called + /// multiple times until finished() returns true. + RowVectorPtr getOutput(); + + /// Checks whether all computed statistics have been returned. For partitioned + /// tables, there is one output row per partition, so multiple getOutput() + /// calls may be required. Returns true when all statistics have been + /// retrieved. + bool finished() const; + + /// Cleans up and releases all resources used by the stats collector. + /// Should be called after all statistics have been retrieved and the + /// collector is no longer needed. + void close(); + + private: + void setOutputType(); + + // Creates the grouping key channel projections for the column stats + // collection for partitioned table write with one group per each table + // partition. + std::pair, std::vector> + setupGroupingKeyChannelProjections() const; + + void createGroupingSet(); + + std::vector createAggregates(size_t numGroupingKeys); + + void prepareOutput(); + + static const int kMaxOutputBatchRows = 512; + static const int kMaxOutputBatchBytes = 128 << 20; + + const core::ColumnStatsSpec statsSpec_; + const RowTypePtr inputType_; + const core::QueryConfig* const queryConfig_; + memory::MemoryPool* const pool_; + tsan_atomic* const nonReclaimableSection_; + const vector_size_t maxOutputBatchRows_; + + bool initialized_{false}; + bool noMoreInput_{false}; + bool finished_{false}; + + std::unique_ptr groupingSet_; + RowTypePtr outputType_; + RowVectorPtr output_; + RowContainerIterator outputIterator_; +}; + +} // namespace facebook::velox::exec diff --git a/velox/exec/ContainerRowSerde.cpp b/velox/exec/ContainerRowSerde.cpp index fc85ad61327e..b571a0a7b95c 100644 --- a/velox/exec/ContainerRowSerde.cpp +++ b/velox/exec/ContainerRowSerde.cpp @@ -388,11 +388,8 @@ std::optional compareSwitch( template < bool typeProvidesCustomComparison, TypeKind Kind, - std::enable_if_t< - Kind != TypeKind::VARCHAR && Kind != TypeKind::VARBINARY && - Kind != TypeKind::ARRAY && Kind != TypeKind::MAP && - Kind != TypeKind::ROW, - int32_t> = 0> + std::enable_if_t = + 0> std::optional compare( ByteInputStream& left, const BaseVector& right, @@ -693,11 +690,8 @@ std::optional compareSwitch( template < bool typeProvidesCustomComparison, TypeKind Kind, - std::enable_if_t< - Kind != TypeKind::VARCHAR && Kind != TypeKind::VARBINARY && - Kind != TypeKind::ARRAY && Kind != TypeKind::MAP && - Kind != TypeKind::ROW, - int32_t> = 0> + std::enable_if_t = + 0> std::optional compare( ByteInputStream& left, ByteInputStream& right, @@ -898,11 +892,8 @@ uint64_t hashSwitch(ByteInputStream& stream, const Type* type); template < bool typeProvidesCustomComparison, TypeKind Kind, - std::enable_if_t< - Kind != TypeKind::VARBINARY && Kind != TypeKind::VARCHAR && - Kind != TypeKind::ARRAY && Kind != TypeKind::MAP && - Kind != TypeKind::ROW, - int32_t> = 0> + std::enable_if_t = + 0> uint64_t hashOne(ByteInputStream& stream, const Type* type) { using T = typename TypeTraits::NativeType; @@ -936,21 +927,22 @@ uint64_t hashOne(ByteInputStream& stream, const Type* /*type*/) { return folly::hasher()(readStringView(stream, storage)); } -template -uint64_t -hashArray(ByteInputStream& in, uint64_t hash, const Type* elementType) { - auto size = in.read(); +template +void hashArray( + ByteInputStream& in, + int32_t size, + const Type* elementType, + const TFunc& consumer) { NullsReader nulls(in, size); for (auto i = 0; i < size; ++i) { - uint64_t value; + uint64_t hash; if (bits::isBitSet(nulls.data(), i)) { - value = BaseVector::kNullHash; + hash = BaseVector::kNullHash; } else { - value = hashSwitch(in, elementType); + hash = hashSwitch(in, elementType); } - hash = bits::commutativeHashMix(hash, value); + consumer(hash); } - return hash; } template < @@ -985,11 +977,20 @@ template < uint64_t hashOne(ByteInputStream& in, const Type* type) { const auto& elementType = type->childAt(0); + auto size = in.read(); + + uint64_t hash = BaseVector::kNullHash; if (elementType->providesCustomComparison()) { - return hashArray(in, BaseVector::kNullHash, elementType.get()); + hashArray(in, size, elementType.get(), [&](uint64_t value) { + hash = bits::hashMix(hash, value); + }); } else { - return hashArray(in, BaseVector::kNullHash, elementType.get()); + hashArray(in, size, elementType.get(), [&](uint64_t value) { + hash = bits::hashMix(hash, value); + }); } + + return hash; } template < @@ -1000,18 +1001,39 @@ uint64_t hashOne(ByteInputStream& in, const Type* type) { const auto& keyType = type->childAt(0); const auto& valueType = type->childAt(1); - uint64_t hash; + auto size = in.read(); + + std::vector keyHashes; + keyHashes.reserve(size); if (keyType->providesCustomComparison()) { - hash = hashArray(in, BaseVector::kNullHash, keyType.get()); + hashArray(in, size, keyType.get(), [&](uint64_t value) { + keyHashes.push_back(value); + }); } else { - hash = hashArray(in, BaseVector::kNullHash, keyType.get()); + hashArray(in, size, keyType.get(), [&](uint64_t value) { + keyHashes.push_back(value); + }); } + uint64_t hash = BaseVector::kNullHash; + size_t i = 0; + + auto updateHash = [&](uint64_t valueHash) { + hash = bits::commutativeHashMix( + hash, bits::hashMix(keyHashes.at(i), valueHash)); + ++i; + }; + + auto valuesSize = in.read(); + VELOX_CHECK_EQ(size, valuesSize); + if (valueType->providesCustomComparison()) { - return hashArray(in, hash, valueType.get()); + hashArray(in, size, valueType.get(), updateHash); } else { - return hashArray(in, hash, valueType.get()); + hashArray(in, size, valueType.get(), updateHash); } + + return hash; } template diff --git a/velox/exec/Cursor.cpp b/velox/exec/Cursor.cpp index 1e3de91686c1..a81d74f2ac5c 100644 --- a/velox/exec/Cursor.cpp +++ b/velox/exec/Cursor.cpp @@ -14,8 +14,8 @@ * limitations under the License. */ #include "velox/exec/Cursor.h" +#include #include "velox/common/file/FileSystems.h" -#include "velox/exec/Operator.h" #include @@ -68,7 +68,8 @@ exec::BlockingReason TaskQueue::enqueue( consumerPromise_.setValue(); } if (totalBytes_ > maxBytes_) { - auto [unblockPromise, unblockFuture] = makeVeloxContinuePromiseContract(); + auto [unblockPromise, unblockFuture] = + makeVeloxContinuePromiseContract("TaskQueue::enqueue"); producerUnblockPromises_.emplace_back(std::move(unblockPromise)); *future = std::move(unblockFuture); return exec::BlockingReason::kWaitForConsumer; @@ -100,7 +101,7 @@ RowVectorPtr TaskQueue::dequeue() { } if (!vector) { consumerBlocked_ = true; - consumerPromise_ = ContinuePromise(); + consumerPromise_ = ContinuePromise("TaskQueue::dequeue"); consumerFuture_ = consumerPromise_.getFuture(); } } @@ -175,22 +176,25 @@ class TaskCursorBase : public TaskCursor { if (!params.spillDirectory.empty()) { taskSpillDirectory_ = params.spillDirectory + "/" + taskId_; - auto fileSystem = - velox::filesystems::getFileSystem(taskSpillDirectory_, nullptr); - VELOX_CHECK_NOT_NULL(fileSystem, "File System is null!"); - try { - fileSystem->mkdir(taskSpillDirectory_); - } catch (...) { - LOG(ERROR) << "Faield to create task spill directory " - << taskSpillDirectory_ << " base director " - << params.spillDirectory << " exists[" - << std::filesystem::exists(taskSpillDirectory_) << "]"; - - std::rethrow_exception(std::current_exception()); - } + taskSpillDirectoryCb_ = params.spillDirectoryCallback; + if (taskSpillDirectoryCb_ == nullptr) { + auto fileSystem = + velox::filesystems::getFileSystem(taskSpillDirectory_, nullptr); + VELOX_CHECK_NOT_NULL(fileSystem, "File System is null!"); + try { + fileSystem->mkdir(taskSpillDirectory_); + } catch (...) { + LOG(ERROR) << "Faield to create task spill directory " + << taskSpillDirectory_ << " base director " + << params.spillDirectory << " exists[" + << std::filesystem::exists(taskSpillDirectory_) << "]"; + + std::rethrow_exception(std::current_exception()); + } - LOG(INFO) << "Task spill directory[" << taskSpillDirectory_ - << "] created"; + LOG(INFO) << "Task spill directory[" << taskSpillDirectory_ + << "] created"; + } } } @@ -199,6 +203,7 @@ class TaskCursorBase : public TaskCursor { std::shared_ptr queryCtx_; core::PlanFragment planFragment_; std::string taskSpillDirectory_; + std::function taskSpillDirectoryCb_; private: std::shared_ptr executor_; @@ -210,7 +215,7 @@ class MultiThreadedTaskCursor : public TaskCursorBase { : TaskCursorBase( params, std::make_shared( - std::thread::hardware_concurrency())), + folly::hardware_concurrency())), maxDrivers_{params.maxDrivers}, numConcurrentSplitGroups_{params.numConcurrentSplitGroups}, numSplitGroups_{params.numSplitGroups} { @@ -223,7 +228,14 @@ class MultiThreadedTaskCursor : public TaskCursorBase { std::make_shared(params.bufferedBytes, params.outputPool); // Captured as a shared_ptr by the consumer callback of task_. - auto queue = queue_; + auto queueHolder = std::weak_ptr(queue_); + std::optional spillDiskOpts; + if (!taskSpillDirectory_.empty()) { + spillDiskOpts = common::SpillDiskOptions{ + .spillDirPath = taskSpillDirectory_, + .spillDirCreated = taskSpillDirectoryCb_ == nullptr, + .spillDirCreateCb = taskSpillDirectoryCb_}; + } task_ = Task::create( taskId_, std::move(planFragment_), @@ -231,10 +243,15 @@ class MultiThreadedTaskCursor : public TaskCursorBase { std::move(queryCtx_), Task::ExecutionMode::kParallel, // consumer - [queue, copyResult = params.copyResult]( + [queueHolder, copyResult = params.copyResult, taskId = taskId_]( const RowVectorPtr& vector, bool drained, velox::ContinueFuture* future) { + auto queue = queueHolder.lock(); + if (queue == nullptr) { + LOG(ERROR) << "TaskQueue has been destroyed, taskId: " << taskId; + return exec::BlockingReason::kNotBlocked; + } VELOX_CHECK( !drained, "Unexpected drain in multithreaded task cursor"); if (!vector || !copyResult) { @@ -250,16 +267,18 @@ class MultiThreadedTaskCursor : public TaskCursorBase { return queue->enqueue(std::move(copy), future); }, 0, - [queue](std::exception_ptr) { + std::move(spillDiskOpts), + [queueHolder, taskId = taskId_](std::exception_ptr) { // onError close the queue to unblock producers and consumers. // moveNext will handle rethrowing the error once it's // unblocked. + auto queue = queueHolder.lock(); + if (queue == nullptr) { + LOG(ERROR) << "TaskQueue has been destroyed, taskId: " << taskId; + return; + } queue->close(); }); - - if (!taskSpillDirectory_.empty()) { - task_->setSpillDirectory(taskSpillDirectory_); - } } ~MultiThreadedTaskCursor() override { @@ -372,17 +391,22 @@ class SingleThreadedTaskCursor : public TaskCursorBase { VELOX_CHECK( !queryCtx_->isExecutorSupplied(), "Executor should not be set in serial task cursor"); - + std::optional spillDiskOpts; + if (!taskSpillDirectory_.empty()) { + spillDiskOpts = common::SpillDiskOptions{ + .spillDirPath = taskSpillDirectory_, + .spillDirCreated = true, + .spillDirCreateCb = taskSpillDirectoryCb_}; + } task_ = Task::create( taskId_, std::move(planFragment_), params.destination, std::move(queryCtx_), - Task::ExecutionMode::kSerial); - - if (!taskSpillDirectory_.empty()) { - task_->setSpillDirectory(taskSpillDirectory_); - } + Task::ExecutionMode::kSerial, + std::function{}, + 0, + std::move(spillDiskOpts)); VELOX_CHECK( task_->supportSerialExecutionMode(), diff --git a/velox/exec/Cursor.h b/velox/exec/Cursor.h index ac067703f85d..7e5f48802c02 100644 --- a/velox/exec/Cursor.h +++ b/velox/exec/Cursor.h @@ -14,8 +14,9 @@ * limitations under the License. */ #pragma once -#include + #include "velox/core/PlanNode.h" +#include "velox/exec/Driver.h" #include "velox/exec/Task.h" namespace facebook::velox::exec { @@ -70,6 +71,12 @@ struct CursorParameters { /// would be built from it. std::string spillDirectory; + /// Callback function to dynamically create or determine the spill directory + /// path at runtime. If provided, this callback is invoked when spilling is + /// needed and must return a valid directory path. This allows for dynamic + /// spill directory creation or path resolution based on runtime conditions. + std::function spillDirectoryCallback; + bool copyResult = true; /// If true, use serial execution mode. Use parallel execution mode @@ -134,7 +141,7 @@ class TaskQueue { std::mutex mutex_; std::vector producerUnblockPromises_; bool consumerBlocked_ = false; - ContinuePromise consumerPromise_; + ContinuePromise consumerPromise_{ContinuePromise::makeEmpty()}; ContinueFuture consumerFuture_; bool closed_ = false; }; diff --git a/velox/exec/DistinctAggregations.cpp b/velox/exec/DistinctAggregations.cpp index f6f9db5a2454..f29cb91d8f3e 100644 --- a/velox/exec/DistinctAggregations.cpp +++ b/velox/exec/DistinctAggregations.cpp @@ -32,9 +32,10 @@ class TypedDistinctAggregations : public DistinctAggregations { : pool_{pool}, aggregates_{std::move(aggregates)}, inputs_{aggregates_[0]->inputs}, - inputType_(TypedDistinctAggregations::makeInputTypeForAccumulator( - inputType, - inputs_)) {} + inputType_( + TypedDistinctAggregations::makeInputTypeForAccumulator( + inputType, + inputs_)) {} /// Returns metadata about the accumulator used to store unique inputs. Accumulator accumulator() const override { @@ -49,6 +50,9 @@ class TypedDistinctAggregations : public DistinctAggregations { }, [this](folly::Range groups) { for (auto* group : groups) { + if (!isInitialized(group)) { + continue; + } auto* accumulator = reinterpret_cast(group + offset_); accumulator->free(*allocator_); @@ -126,11 +130,11 @@ class TypedDistinctAggregations : public DistinctAggregations { // Overwrite empty groups over the destructed groups to keep the container // in a well formed state. - raw_vector temp; + raw_vector indices(pool_); aggregate.function->initializeNewGroups( groups.data(), folly::Range( - iota(groups.size(), temp), groups.size())); + iota(groups.size(), indices), groups.size())); } } diff --git a/velox/exec/DistinctAggregations.h b/velox/exec/DistinctAggregations.h index ffb6d1b1a4c3..fdd13a89fd02 100644 --- a/velox/exec/DistinctAggregations.h +++ b/velox/exec/DistinctAggregations.h @@ -101,6 +101,10 @@ class DistinctAggregations { char** groups, folly::Range indices) = 0; + bool isInitialized(char* group) const { + return group[initializedByte_] & initializedMask_; + } + HashStringAllocator* allocator_; int32_t offset_; int32_t nullByte_; diff --git a/velox/exec/Driver.cpp b/velox/exec/Driver.cpp index 7a9d3654ae9a..ae96bc6c1bdc 100644 --- a/velox/exec/Driver.cpp +++ b/velox/exec/Driver.cpp @@ -18,6 +18,7 @@ #include "velox/common/process/TraceContext.h" #include "velox/exec/Task.h" +#include "velox/vector/LazyVector.h" using facebook::velox::common::testutil::TestValue; @@ -141,10 +142,12 @@ std::optional DriverCtx::makeSpillConfig( queryConfig.maxSpillRunRows(), queryConfig.writerFlushThresholdBytes(), queryConfig.spillCompressionKind(), + queryConfig.spillNumMaxMergeFiles(), queryConfig.spillPrefixSortEnabled() ? std::optional(prefixSortConfig()) : std::nullopt, - queryConfig.spillFileCreateConfig()); + queryConfig.spillFileCreateConfig(), + queryConfig.windowSpillMinReadBatchRows()); } std::atomic_uint64_t BlockingState::numBlockedDrivers_{0}; @@ -158,9 +161,10 @@ BlockingState::BlockingState( future_(std::move(future)), operator_(op), reason_(reason), - sinceUs_(std::chrono::duration_cast( - std::chrono::high_resolution_clock::now().time_since_epoch()) - .count()) { + sinceUs_( + std::chrono::duration_cast( + std::chrono::high_resolution_clock::now().time_since_epoch()) + .count()) { // Set before leaving the thread. driver_->state().hasBlockingFuture = true; numBlockedDrivers_++; @@ -247,6 +251,8 @@ void Driver::init( std::vector> operators) { VELOX_CHECK_NULL(ctx_); ctx_ = std::move(ctx); + enableOperatorBatchSizeStats_ = + ctx_->queryConfig().enableOperatorBatchSizeStats(); cpuSliceMs_ = task()->driverCpuTimeSliceLimitMs(); VELOX_CHECK(operators_.empty()); operators_ = std::move(operators); @@ -264,55 +270,6 @@ void Driver::initializeOperators() { } } -void Driver::pushdownFilters(int operatorIndex) { - auto* op = operators_[operatorIndex].get(); - const auto& filters = op->getDynamicFilters(); - if (filters.empty()) { - return; - } - const auto& planNodeId = op->planNodeId(); - - op->addRuntimeStat("dynamicFiltersProduced", RuntimeCounter(filters.size())); - - // Walk operator list upstream and find a place to install the filters. - for (const auto& entry : filters) { - auto channel = entry.first; - for (auto i = operatorIndex - 1; i >= 0; --i) { - auto prevOp = operators_[i].get(); - - if (i == 0) { - // Source operator. - VELOX_CHECK( - prevOp->canAddDynamicFilter(), - "Cannot push down dynamic filters produced by {}", - op->toString()); - prevOp->addDynamicFilter(planNodeId, channel, entry.second); - prevOp->addRuntimeStat("dynamicFiltersAccepted", RuntimeCounter(1)); - break; - } - - const auto& identityProjections = prevOp->identityProjections(); - const auto inputChannel = - getIdentityProjection(identityProjections, channel); - if (!inputChannel.has_value()) { - // Filter channel is not an identity projection. - VELOX_CHECK( - prevOp->canAddDynamicFilter(), - "Cannot push down dynamic filters produced by {}", - op->toString()); - prevOp->addDynamicFilter(planNodeId, channel, entry.second); - prevOp->addRuntimeStat("dynamicFiltersAccepted", RuntimeCounter(1)); - break; - } - - // Continue walking upstream. - channel = inputChannel.value(); - } - } - - op->clearDynamicFilters(); -} - RowVectorPtr Driver::next( ContinueFuture* future, Operator*& blockingOp, @@ -456,11 +413,12 @@ CpuWallTiming Driver::processLazyIoStats( cpuDelta = std::min(cpuDelta, timing.cpuNanos); wallDelta = std::min(wallDelta, timing.wallNanos); lockStats = operators_[0]->stats().wlock(); - lockStats->getOutputTiming.add(CpuWallTiming{ - 1, - static_cast(wallDelta), - static_cast(cpuDelta), - }); + lockStats->getOutputTiming.add( + CpuWallTiming{ + 1, + static_cast(wallDelta), + static_cast(cpuDelta), + }); lockStats->inputBytes += inputBytesDelta; lockStats->outputBytes += inputBytesDelta; return CpuWallTiming{ @@ -482,6 +440,12 @@ bool Driver::checkUnderArbitration(ContinueFuture* future) { } namespace { +inline void addInput(Operator* op, const RowVectorPtr& input) { + if (FOLLY_LIKELY(!op->dryRun())) { + op->addInput(input); + } +} + inline void getOutput(Operator* op, RowVectorPtr& result) { result = op->getOutput(); if (FOLLY_UNLIKELY(op->shouldDropOutput())) { @@ -618,15 +582,14 @@ StopReason Driver::runInternal( kOpMethodGetOutput); if (intermediateResult) { validateOperatorOutputResult(intermediateResult, *op); - resultBytes = intermediateResult->estimateFlatSize(); - { - auto lockedStats = op->stats().wlock(); - lockedStats->addOutputVector( - resultBytes, intermediateResult->size()); + if (enableOperatorBatchSizeStats()) { + resultBytes = intermediateResult->estimateFlatSize(); } + auto lockedStats = op->stats().wlock(); + lockedStats->addOutputVector( + resultBytes, intermediateResult->size()); } }); - pushdownFilters(i); if (intermediateResult) { withDeltaCpuWallTimer( nextOp, &OperatorStats::addInputTiming, [&]() { @@ -641,7 +604,7 @@ StopReason Driver::runInternal( nextOp); CALL_OPERATOR( - nextOp->addInput(intermediateResult), + addInput(nextOp, intermediateResult), nextOp, curOperatorId_ + 1, kOpMethodAddInput); @@ -708,12 +671,12 @@ StopReason Driver::runInternal( getOutput(op, result), op, curOperatorId_, kOpMethodGetOutput); if (result) { validateOperatorOutputResult(result, *op); - - { - auto lockedStats = op->stats().wlock(); - lockedStats->addOutputVector( - result->estimateFlatSize(), result->size()); + vector_size_t resultByteSize{0}; + if (enableOperatorBatchSizeStats()) { + resultByteSize = result->estimateFlatSize(); } + auto lockedStats = op->stats().wlock(); + lockedStats->addOutputVector(resultByteSize, result->size()); } }); @@ -737,14 +700,13 @@ StopReason Driver::runInternal( close(); return StopReason::kAtEnd; } - pushdownFilters(i); continue; } } } } catch (velox::VeloxException&) { task()->setError(std::current_exception()); - // The CancelPoolGuard will close 'self' and remove from Task. + // The CancelGuard will close 'self' and remove from Task. return StopReason::kAlreadyTerminated; } catch (std::exception&) { task()->setError(std::current_exception()); @@ -840,6 +802,19 @@ void Driver::closeOperators() { for (auto& op : operators_) { auto stats = op->stats(true); stats.numDrivers = 1; + + // Calculate this driver's CPU time for this specific operator and add it as + // a runtime stat. This will be aggregated across all drivers, with the max + // field containing the CPU time from the longest running driver. + uint64_t operatorCpuNanos = stats.addInputTiming.cpuNanos + + stats.getOutputTiming.cpuNanos + stats.finishTiming.cpuNanos + + stats.isBlockedTiming.cpuNanos; + + if (operatorCpuNanos > 0) { + stats.runtimeStats[OperatorStats::kDriverCpuTime] = + RuntimeMetric(operatorCpuNanos, RuntimeCounter::Unit::kNanos); + } + task()->addOperatorStats(stats); } } @@ -968,22 +943,23 @@ bool Driver::mayPushdownAggregation(Operator* aggregation) const { aggregation->toString()); } -std::unordered_set Driver::canPushdownFilters( - const Operator* filterSource, - const std::vector& channels) const { - int filterSourceIndex = -1; +int Driver::operatorIndex(const Operator* op) const { + int index = -1; for (auto i = 0; i < operators_.size(); ++i) { - auto op = operators_[i].get(); - if (filterSource == op) { - filterSourceIndex = i; + if (op == operators_[i].get()) { + index = i; break; } } VELOX_CHECK_GE( - filterSourceIndex, - 0, - "Operator not found in its Driver: {}", - filterSource->toString()); + index, 0, "Operator not found in its Driver: {}", op->toString()); + return index; +} + +std::unordered_set Driver::canPushdownFilters( + const Operator* filterSource, + const std::vector& channels) const { + const int filterSourceIndex = operatorIndex(filterSource); std::unordered_set supportedChannels; for (auto i = 0; i < channels.size(); ++i) { @@ -1018,6 +994,73 @@ std::unordered_set Driver::canPushdownFilters( return supportedChannels; } +int Driver::pushdownFilters( + Operator* filterSource, + const std::vector& channels, + const std::function& makeFilter) { + const int filterSourceIndex = operatorIndex(filterSource); + int numFiltersProduced = 0; + std::vector numFiltersAccepted(filterSourceIndex); + for (auto i = 0; i < channels.size(); ++i) { + auto channel = channels[i]; + int j = -1; + for (j = filterSourceIndex - 1; j >= 0; --j) { + auto* prevOp = operators_[j].get(); + if (j == 0) { + // Source operator. + break; + } + const auto& identityProjections = prevOp->identityProjections(); + const auto inputChannel = + getIdentityProjection(identityProjections, channel); + if (!inputChannel.has_value()) { + // Filter channel is not an identity projection. + break; + } + // Continue walking upstream. + channel = inputChannel.value(); + } + if (!(j >= 0 && operators_[j]->canAddDynamicFilter())) { + continue; + } + common::FilterPtr filter; + auto lkSource = pushdownFilters_->at(filterSourceIndex).wlock(); + if (makeFilter(i, filter)) { + if (filter) { + // A new filter is generated. + auto lkTarget = pushdownFilters_->at(j).wlock(); + common::Filter::merge(filter, lkTarget->filters[channel]); + lkTarget->dynamicFilteredColumns.insert(channel); + } else { + // Same filter is already generated by another operator on the same + // node. Just do some sanity check here. + auto lkTarget = pushdownFilters_->at(j).rlock(); + VELOX_CHECK( + lkTarget->filters.at(channel) && + lkTarget->dynamicFilteredColumns.contains(channel)); + } + ++numFiltersProduced; + ++numFiltersAccepted[j]; + } + } + for (int j = 0; j < filterSourceIndex; ++j) { + if (numFiltersAccepted[j] == 0) { + continue; + } + { + auto lk = pushdownFilters_->at(j).rlock(); + operators_[j]->addDynamicFilterLocked(filterSource->planNodeId(), *lk); + } + operators_[j]->addRuntimeStat( + "dynamicFiltersAccepted", RuntimeCounter(numFiltersAccepted[j])); + } + if (numFiltersProduced > 0) { + filterSource->addRuntimeStat( + "dynamicFiltersProduced", RuntimeCounter(numFiltersProduced)); + } + return numFiltersProduced; +} + Operator* Driver::findOperator(std::string_view planNodeId) const { for (auto& op : operators_) { if (op->planNodeId() == planNodeId) { @@ -1065,7 +1108,7 @@ std::string Driver::toString() const { std::string blockedOp = (blockedOperatorId_ < operators_.size()) ? operators_[blockedOperatorId_]->toString() : ""; - out << "blocked (" << blockingReasonToString(blockingReason_) << " " + out << "blocked (" << BlockingReasonName::toName(blockingReason_) << " " << blockedOp << "), "; } else if (state_.isEnqueued) { out << "enqueued "; @@ -1076,9 +1119,13 @@ std::string Driver::toString() const { } out << "{Operators: "; - for (auto& op : operators_) { - out << op->toString() << ", "; - } + std::vector opStrs; + opStrs.reserve(operators_.size()); + std::ranges::transform( + operators_, std::back_inserter(opStrs), [](const auto& op) { + return op->toString(); + }); + out << folly::join(", ", opStrs); out << "}"; const auto ocs = opCallStatus(); if (!ocs.empty()) { @@ -1103,7 +1150,7 @@ Driver::CancelGuard::~CancelGuard() { folly::dynamic Driver::toJson() const { folly::dynamic obj = folly::dynamic::object; - obj["blockingReason"] = blockingReasonToString(blockingReason_); + obj["blockingReason"] = BlockingReasonName::toName(blockingReason_); obj["state"] = state_.toJson(); obj["closed"] = closed_.load(); obj["queueTimeStartMicros"] = queueTimeStartUs_; @@ -1198,40 +1245,6 @@ std::string Driver::label() const { return fmt::format("", task()->taskId(), ctx_->driverId); } -std::string blockingReasonToString(BlockingReason reason) { - switch (reason) { - case BlockingReason::kNotBlocked: - return "kNotBlocked"; - case BlockingReason::kWaitForConsumer: - return "kWaitForConsumer"; - case BlockingReason::kWaitForSplit: - return "kWaitForSplit"; - case BlockingReason::kWaitForProducer: - return "kWaitForProducer"; - case BlockingReason::kWaitForJoinBuild: - return "kWaitForJoinBuild"; - case BlockingReason::kWaitForJoinProbe: - return "kWaitForJoinProbe"; - case BlockingReason::kWaitForMergeJoinRightSide: - return "kWaitForMergeJoinRightSide"; - case BlockingReason::kWaitForMemory: - return "kWaitForMemory"; - case BlockingReason::kWaitForConnector: - return "kWaitForConnector"; - case BlockingReason::kYield: - return "kYield"; - case BlockingReason::kWaitForArbitration: - return "kWaitForArbitration"; - case BlockingReason::kWaitForScanScaleUp: - return "kWaitForScanScaleUp"; - case BlockingReason::kWaitForIndexLookup: - return "kWaitForIndexLookup"; - default: - VELOX_UNREACHABLE( - fmt::format("Unknown blocking reason {}", static_cast(reason))); - } -} - DriverThreadContext* driverThreadContext() { return driverThreadCtx; } diff --git a/velox/exec/Driver.h b/velox/exec/Driver.h index b59457fc77e5..517bf3df9a89 100644 --- a/velox/exec/Driver.h +++ b/velox/exec/Driver.h @@ -27,7 +27,7 @@ #include "velox/common/base/TraceConfig.h" #include "velox/common/time/CpuWallTimer.h" #include "velox/core/PlanFragment.h" -#include "velox/core/QueryCtx.h" +#include "velox/exec/BlockingReason.h" namespace facebook::velox::exec { @@ -61,14 +61,6 @@ std::string stopReasonString(StopReason reason); std::ostream& operator<<(std::ostream& out, const StopReason& reason); -struct DriverStats { - static constexpr const char* kTotalPauseTime = "totalDriverPauseWallNanos"; - static constexpr const char* kTotalOffThreadTime = - "totalDriverOffThreadWallNanos"; - - std::unordered_map runtimeStats; -}; - /// Represents a Driver's state. This is used for cancellation, forcing /// release of and for waiting for memory. The fields are serialized on /// the mutex of the Driver's Task. @@ -94,7 +86,7 @@ struct DriverStats { /// Terminated - 'isTerminated' is set. The Driver cannot run after this and /// the state is final. /// -/// CancelPool allows terminating or pausing a set of Drivers. The Task API +/// Task allows terminating or pausing a set of Drivers. The Task API /// allows starting or resuming Drivers. When terminate is requested the request /// is successful when all Drivers are off thread, blocked or suspended. When /// pause is requested, we have success when all Drivers are either enqueued, @@ -182,45 +174,6 @@ struct ThreadState { } }; -enum class BlockingReason { - kNotBlocked, - kWaitForConsumer, - kWaitForSplit, - /// Some operators can get blocked due to the producer(s) (they are - /// currently waiting data from) not having anything produced. Used by - /// LocalExchange, LocalMergeExchange, Exchange and MergeExchange operators. - kWaitForProducer, - kWaitForJoinBuild, - /// For a build operator, it is blocked waiting for the probe operators to - /// finish probing before build the next hash table from one of the - /// previously spilled partition data. For a probe operator, it is blocked - /// waiting for all its peer probe operators to finish probing before - /// notifying the build operators to build the next hash table from the - /// previously spilled data. - kWaitForJoinProbe, - /// Used by MergeJoin operator, indicating that it was blocked by the right - /// side input being unavailable. - kWaitForMergeJoinRightSide, - kWaitForMemory, - kWaitForConnector, - /// Some operators (like Table Scan) may run long loops and can 'voluntarily' - /// exit them because Task requested to yield or stop or after a certain time. - /// This is the blocking reason used in such cases. - kYield, - /// Operator is blocked waiting for its associated query memory arbitration to - /// finish. - kWaitForArbitration, - /// For a table scan operator, it is blocked waiting for the scan controller - /// to increase the number of table scan processing threads to start - /// processing. - kWaitForScanScaleUp, - /// Used by IndexLookupJoin operator, indicating that it was blocked by the - /// async index lookup. - kWaitForIndexLookup, -}; - -std::string blockingReasonToString(BlockingReason reason); - class BlockingState { public: BlockingState( @@ -363,6 +316,26 @@ struct OpCallStatus { std::atomic method{kOpMethodNone}; }; +struct PushdownFilters { + /// Keep a single instance across drivers so that we do not need to repeatedly + /// merge them in different drivers. + folly::F14FastMap filters; + + /// Indices added here will never be removed. + folly::F14FastSet dynamicFilteredColumns; + + /// Whether static filters has been added to filters. This only needs to be + /// done once per node by the first driver. + bool staticFiltersInitialized = false; +}; + +/// Pushdown filters on nodes in the pipeline. Locks must be acquired in the +/// order from downstream to upstream (i.e. it's forbidden that we acquire the +/// upstream node lock first, and then acquire the downstream node lock while we +/// hold the upstream lock). +using PipelinePushdownFilters = + std::vector>; + class Driver : public std::enable_shared_from_this { public: static void enqueue(std::shared_ptr instance); @@ -411,6 +384,11 @@ class Driver : public std::enable_shared_from_this { /// time slice limit if set. bool shouldYield() const; + /// Inline function to check if operator batch size stats are enabled. + inline bool enableOperatorBatchSizeStats() const { + return enableOperatorBatchSizeStats_; + } + /// Checks if the associated query is under memory arbitration or not. The /// function returns true if it is and set future which is fulfilled when the /// memory arbitration finishes. @@ -431,6 +409,30 @@ class Driver : public std::enable_shared_from_this { const Operator* filterSource, const std::vector& channels) const; + /// Try to add new dynamic filters from `filterSource' to its upstream + /// operator which accept dynamic filters. `channels' are the inputs for + /// `filterSource'. + /// + /// `makeFilter' is called with a lock held on the node of `filterSource' in + /// `pushdownFilters_'. It should return whether a filter should be added, + /// and set the FilterPtr output parameter with a new filter if one is + /// generated. If `makeFilter' returns true but FilterPtr is not set, it + /// means a filter is already generated by another operator on the same node, + /// and we just need to set the new merged filter on the accepting operator. + /// + /// Return the number of filters produced. + int pushdownFilters( + Operator* filterSource, + const std::vector& channels, + const std::function& + makeFilter); + + int operatorIndex(const Operator* op) const; + + const std::shared_ptr& pushdownFilters() const { + return pushdownFilters_; + } + /// Returns the Operator with 'planNodeId' or nullptr if not found. For /// example, hash join probe accesses the corresponding build by id. Operator* findOperator(std::string_view planNodeId) const; @@ -600,10 +602,6 @@ class Driver : public std::enable_shared_from_this { void close(); - // Push down dynamic filters produced by the operator at the specified - // position in the pipeline. - void pushdownFilters(int operatorIndex); - using TimingMemberPtr = CpuWallTiming OperatorStats::*; template void withDeltaCpuWallTimer( @@ -631,6 +629,10 @@ class Driver : public std::enable_shared_from_this { std::unique_ptr ctx_; + // If set, the operator output batch size stats will be collected during + // driver execution. + bool enableOperatorBatchSizeStats_{false}; + // If not zero, specifies the driver cpu time slice. size_t cpuSliceMs_{0}; @@ -664,6 +666,10 @@ class Driver : public std::enable_shared_from_this { // of DriverFactory::createDriver(). bool isAdaptable_{true}; + // Pushdown filters on the pipeline. This is generated per split group per + // pipeline. + std::shared_ptr pushdownFilters_; + friend struct DriverFactory; }; @@ -728,6 +734,7 @@ struct DriverFactory { std::shared_ptr createDriver( std::unique_ptr ctx, std::shared_ptr exchangeClient, + std::shared_ptr filters, std::function numDrivers); /// Replaces operators at indices 'begin' to 'end - 1' with @@ -811,6 +818,10 @@ struct DriverFactory { /// based on this pipeline. std::vector needsNestedLoopJoinBridges() const; + /// Returns plan node IDs for which Spatial Join Bridges must be created + /// based on this pipeline. + std::vector needsSpatialJoinBridges() const; + static std::vector adapters; }; @@ -850,16 +861,6 @@ DriverThreadContext* driverThreadContext(); } // namespace facebook::velox::exec -template <> -struct fmt::formatter - : formatter { - auto format(facebook::velox::exec::BlockingReason b, format_context& ctx) - const { - return formatter::format( - facebook::velox::exec::blockingReasonToString(b), ctx); - } -}; - template <> struct fmt::formatter : formatter { diff --git a/velox/exec/DriverStats.h b/velox/exec/DriverStats.h new file mode 100644 index 000000000000..3b047f4c9e07 --- /dev/null +++ b/velox/exec/DriverStats.h @@ -0,0 +1,32 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include "velox/common/base/RuntimeMetrics.h" + +namespace facebook::velox::exec { + +struct DriverStats { + static constexpr const char* kTotalPauseTime = "totalDriverPauseWallNanos"; + static constexpr const char* kTotalOffThreadTime = + "totalDriverOffThreadWallNanos"; + + std::unordered_map runtimeStats; +}; + +} // namespace facebook::velox::exec diff --git a/velox/exec/Exchange.cpp b/velox/exec/Exchange.cpp index 8582c804124f..6585961cb915 100644 --- a/velox/exec/Exchange.cpp +++ b/velox/exec/Exchange.cpp @@ -14,23 +14,47 @@ * limitations under the License. */ #include "velox/exec/Exchange.h" - +#include "velox/common/Casts.h" +#include "velox/common/serialization/Serializable.h" +#include "velox/exec/OperatorUtils.h" #include "velox/exec/Task.h" #include "velox/serializers/CompactRowSerializer.h" namespace facebook::velox::exec { +folly::dynamic RemoteConnectorSplit::serialize() const { + folly::dynamic obj = folly::dynamic::object; + obj["name"] = "RemoteConnectorSplit"; + obj["taskId"] = taskId; + return obj; +} + +// static +std::shared_ptr RemoteConnectorSplit::create( + const folly::dynamic& obj) { + const auto taskId = obj["taskId"].asString(); + return std::make_shared(taskId); +} + +// static +void RemoteConnectorSplit::registerSerDe() { + auto& registry = DeserializationRegistryForSharedPtr(); + registry.Register("RemoteConnectorSplit", RemoteConnectorSplit::create); +} + namespace { -std::unique_ptr getVectorSerdeOptions( - const core::QueryConfig& queryConfig, - VectorSerde::Kind kind) { - std::unique_ptr options = - kind == VectorSerde::Kind::kPresto - ? std::make_unique() - : std::make_unique(); - options->compressionKind = - common::stringToCompressionKind(queryConfig.shuffleCompressionKind()); - return options; +std::unique_ptr mergePages( + std::vector>& pages) { + VELOX_CHECK(!pages.empty()); + std::unique_ptr mergedBufs; + for (const auto& page : pages) { + if (mergedBufs == nullptr) { + mergedBufs = page->getIOBuf(); + } else { + mergedBufs->appendToChain(page->getIOBuf()); + } + } + return mergedBufs; } } // namespace @@ -50,7 +74,9 @@ Exchange::Exchange( driverCtx->queryConfig().preferredOutputBatchBytes()}, serdeKind_{exchangeNode->serdeKind()}, serdeOptions_{getVectorSerdeOptions( - operatorCtx_->driverCtx()->queryConfig(), + common::stringToCompressionKind(operatorCtx_->driverCtx() + ->queryConfig() + .shuffleCompressionKind()), serdeKind_)}, processSplits_{operatorCtx_->driverCtx()->driverId == 0}, driverId_{driverCtx->driverId}, @@ -64,39 +90,42 @@ void Exchange::addRemoteTaskIds(std::vector& remoteTaskIds) { stats_.wlock()->numSplits += remoteTaskIds.size(); } -bool Exchange::getSplits(ContinueFuture* future) { +void Exchange::getSplits(ContinueFuture* future) { if (!processSplits_) { - return false; + return; } if (noMoreSplits_) { - return false; + return; } std::vector remoteTaskIds; for (;;) { exec::Split split; - auto reason = operatorCtx_->task()->getSplitOrFuture( + const auto reason = operatorCtx_->task()->getSplitOrFuture( operatorCtx_->driverCtx()->splitGroupId, planNodeId(), split, *future); - if (reason == BlockingReason::kNotBlocked) { - if (split.hasConnectorSplit()) { - auto remoteSplit = std::dynamic_pointer_cast( - split.connectorSplit); - VELOX_CHECK_NOT_NULL(remoteSplit, "Wrong type of split"); - remoteTaskIds.push_back(remoteSplit->taskId); - } else { - addRemoteTaskIds(remoteTaskIds); - exchangeClient_->noMoreRemoteTasks(); - noMoreSplits_ = true; - if (atEnd_) { - operatorCtx_->task()->multipleSplitsFinished( - false, stats_.rlock()->numSplits, 0); - recordExchangeClientStats(); - } - return false; - } - } else { + if (reason != BlockingReason::kNotBlocked) { addRemoteTaskIds(remoteTaskIds); - return true; + return; + } + + if (split.hasConnectorSplit()) { + auto remoteSplit = + checkedPointerCast(split.connectorSplit); + if (FOLLY_UNLIKELY(splitTracer_ != nullptr)) { + splitTracer_->write(split); + } + remoteTaskIds.push_back(remoteSplit->taskId); + continue; + } + + addRemoteTaskIds(remoteTaskIds); + exchangeClient_->noMoreRemoteTasks(); + noMoreSplits_ = true; + if (atEnd_) { + operatorCtx_->task()->multipleSplitsFinished( + false, stats_.rlock()->numSplits, 0); + recordExchangeClientStats(); } + return; } } @@ -106,7 +135,6 @@ BlockingReason Exchange::isBlocked(ContinueFuture* future) { } // Start fetching data right away. Do not wait for all splits to be available. - if (!splitFuture_.valid()) { getSplits(&splitFuture_); } @@ -135,6 +163,7 @@ BlockingReason Exchange::isBlocked(ContinueFuture* future) { } // Block until data becomes available. + VELOX_CHECK(dataFuture.valid()); *future = std::move(dataFuture); return BlockingReason::kWaitForProducer; } @@ -143,102 +172,142 @@ bool Exchange::isFinished() { return atEnd_ && currentPages_.empty(); } -namespace { -std::unique_ptr mergePages( - std::vector>& pages) { - VELOX_CHECK(!pages.empty()); - std::unique_ptr mergedBufs; - for (const auto& page : pages) { - if (mergedBufs == nullptr) { - mergedBufs = page->getIOBuf(); - } else { - mergedBufs->appendToChain(page->getIOBuf()); - } - } - return mergedBufs; -} -} // namespace - RowVectorPtr Exchange::getOutput() { auto* serde = getSerde(); if (serde->supportsAppendInDeserialize()) { - uint64_t rawInputBytes{0}; - if (currentPages_.empty()) { - return nullptr; - } - vector_size_t resultOffset = 0; - for (const auto& page : currentPages_) { + return getOutputFromColumnarPages(serde); + } + return getOutputFromRowPages(serde); +} + +RowVectorPtr Exchange::getOutputFromColumnarPages(VectorSerde* serde) { + if (currentPages_.empty()) { + return nullptr; + } + + // Calculate target row count based on estimated row size, similar to + // getOutputFromRowPages. + // Start conservatively, then use estimates. + const auto numRows = estimatedRowSize_.has_value() + ? std::max( + (preferredOutputBatchBytes_ / estimatedRowSize_.value()), + kInitialOutputRows) + : kInitialOutputRows; + + // Process pages one-by-one from currentPages_ pointed by columnarPageIdx_. + // Within each page, deserialize vectors incrementally until we hit the target + // batch size. + uint64_t rawInputBytes = 0; + vector_size_t resultOffset{0}; + + // Should be either starting fresh or continuing from a previous partial page + VELOX_CHECK( + inputStream_ == nullptr || columnarPageIdx_ < currentPages_.size()); + + // Iterate through pages + while (columnarPageIdx_ < currentPages_.size()) { + auto& page = currentPages_[columnarPageIdx_]; + + if (!inputStream_) { + // NOTE: 'rawInputBytes' only counts bytes from pages processed from the + // beginning in this call. If processing resumes from the middle of a + // page, that page's bytes are not counted. This ensures each page is + // counted only once in 'rawInputBytes' across multiple calls. rawInputBytes += page->size(); + inputStream_ = page->prepareStreamForDeserialize(); + } - auto inputStream = page->prepareStreamForDeserialize(); - while (!inputStream->atEnd()) { - serde->deserialize( - inputStream.get(), - pool(), - outputType_, - &result_, - resultOffset, - serdeOptions_.get()); - resultOffset = result_->size(); - } + // Inner loop: deserialize vectors from current page until batch is full + // or page is exhausted. + while (!inputStream_->atEnd() && resultOffset < numRows) { + serde->deserialize( + inputStream_.get(), + pool(), + outputType_, + &result_, + resultOffset, + serdeOptions_.get()); + + resultOffset = result_->size(); + } + + if (inputStream_->atEnd()) { + // Page is fully consumed, free memory immediately, and move to the next. + inputStream_ = nullptr; + page.reset(); + ++columnarPageIdx_; + } + + // Stop if accumulated enough rows for this batch. + if (resultOffset >= numRows) { + break; } - currentPages_.clear(); - recordInputStats(rawInputBytes); - return result_; - } - if (serde->kind() == VectorSerde::Kind::kCompactRow) { - return getOutputFromCompactRows(serde); } - if (serde->kind() == VectorSerde::Kind::kUnsafeRow) { - return getOutputFromUnsafeRows(serde); + + const auto numOutputRows = result_->size(); + VELOX_CHECK_GT(numOutputRows, 0); + + estimatedRowSize_ = std::max( + result_->estimateFlatSize() / numOutputRows, + estimatedRowSize_.value_or(1L)); + + // If processed all pages, clear the vector and reset state. + if (columnarPageIdx_ >= currentPages_.size()) { + VELOX_CHECK_NULL(inputStream_); + currentPages_.clear(); + columnarPageIdx_ = 0; } - VELOX_UNREACHABLE( - "Unsupported serde kind: {}", VectorSerde::kindName(serde->kind())); + + recordInputStats(rawInputBytes); + return result_; } -RowVectorPtr Exchange::getOutputFromCompactRows(VectorSerde* serde) { +RowVectorPtr Exchange::getOutputFromRowPages(VectorSerde* serde) { uint64_t rawInputBytes{0}; if (currentPages_.empty()) { - VELOX_CHECK_NULL(compactRowInputStream_); - VELOX_CHECK_NULL(compactRowIterator_); + VELOX_CHECK_NULL(inputStream_); + VELOX_CHECK_NULL(rowIterator_); return nullptr; } - if (compactRowInputStream_ == nullptr) { + if (inputStream_ == nullptr) { std::unique_ptr mergedBufs = mergePages(currentPages_); rawInputBytes += mergedBufs->computeChainDataLength(); - compactRowPages_ = std::make_unique(std::move(mergedBufs)); - compactRowInputStream_ = compactRowPages_->prepareStreamForDeserialize(); + mergedRowPage_ = + std::make_unique(std::move(mergedBufs)); + inputStream_ = mergedRowPage_->prepareStreamForDeserialize(); } - auto numRows = kInitialOutputCompactRows; - if (estimatedCompactRowSize_.has_value()) { + auto numRows = kInitialOutputRows; + if (estimatedRowSize_.has_value()) { numRows = std::max( - (preferredOutputBatchBytes_ / estimatedCompactRowSize_.value()), - kInitialOutputCompactRows); + (preferredOutputBatchBytes_ / estimatedRowSize_.value()), + kInitialOutputRows); } + // Check if the serde supports batched deserialization serde->deserialize( - compactRowInputStream_.get(), - compactRowIterator_, + inputStream_.get(), + rowIterator_, numRows, outputType_, &result_, pool(), serdeOptions_.get()); + const auto numOutputRows = result_->size(); VELOX_CHECK_GT(numOutputRows, 0); - estimatedCompactRowSize_ = std::max( + estimatedRowSize_ = std::max( result_->estimateFlatSize() / numOutputRows, - estimatedCompactRowSize_.value_or(1L)); + estimatedRowSize_.value_or(1L)); - if (compactRowInputStream_->atEnd() && compactRowIterator_ == nullptr) { + if (inputStream_->atEnd() && rowIterator_ == nullptr) { // only clear the input stream if we have reached the end of the row // iterator because row iterator may depend on input stream if serialized // rows are not compressed. - compactRowInputStream_ = nullptr; - compactRowPages_ = nullptr; + inputStream_ = nullptr; + mergedRowPage_ = nullptr; currentPages_.clear(); } @@ -246,22 +315,6 @@ RowVectorPtr Exchange::getOutputFromCompactRows(VectorSerde* serde) { return result_; } -RowVectorPtr Exchange::getOutputFromUnsafeRows(VectorSerde* serde) { - uint64_t rawInputBytes{0}; - if (currentPages_.empty()) { - return nullptr; - } - std::unique_ptr mergedBufs = mergePages(currentPages_); - rawInputBytes += mergedBufs->computeChainDataLength(); - auto mergedPages = std::make_unique(std::move(mergedBufs)); - auto source = mergedPages->prepareStreamForDeserialize(); - serde->deserialize( - source.get(), pool(), outputType_, &result_, serdeOptions_.get()); - currentPages_.clear(); - recordInputStats(rawInputBytes); - return result_; -} - void Exchange::recordInputStats(uint64_t rawInputBytes) { auto lockedStats = stats_.wlock(); lockedStats->rawInputBytes += rawInputBytes; @@ -273,6 +326,13 @@ void Exchange::close() { SourceOperator::close(); currentPages_.clear(); result_ = nullptr; + + // Clean up stateful deserialization state + inputStream_ = nullptr; + mergedRowPage_ = nullptr; + rowIterator_ = nullptr; + columnarPageIdx_ = 0; + if (exchangeClient_) { recordExchangeClientStats(); exchangeClient_->close(); @@ -301,14 +361,12 @@ void Exchange::recordExchangeClientStats() { lockedStats->runtimeStats.insert({name, value}); } - auto backgroundCpuTimeMs = - exchangeClientStats.find(ExchangeClient::kBackgroundCpuTimeMs); - if (backgroundCpuTimeMs != exchangeClientStats.end()) { + const auto iter = exchangeClientStats.find(Operator::kBackgroundCpuTimeNanos); + if (iter != exchangeClientStats.end()) { const CpuWallTiming backgroundTiming{ - static_cast(backgroundCpuTimeMs->second.count), + static_cast(iter->second.count), 0, - static_cast(backgroundCpuTimeMs->second.sum) * - Timestamp::kNanosecondsInMillisecond}; + static_cast(iter->second.sum)}; lockedStats->backgroundTiming.clear(); lockedStats->backgroundTiming.add(backgroundTiming); } diff --git a/velox/exec/Exchange.h b/velox/exec/Exchange.h index a6f7c261d50f..ea4bc95f19c6 100644 --- a/velox/exec/Exchange.h +++ b/velox/exec/Exchange.h @@ -31,6 +31,13 @@ struct RemoteConnectorSplit : public connector::ConnectorSplit { explicit RemoteConnectorSplit(const std::string& remoteTaskId) : ConnectorSplit(""), taskId(remoteTaskId) {} + static std::shared_ptr create( + const folly::dynamic& obj); + + static void registerSerDe(); + + folly::dynamic serialize() const override; + std::string toString() const override { return fmt::format("Remote: {}", taskId); } @@ -60,12 +67,11 @@ class Exchange : public SourceOperator { protected: virtual VectorSerde* getSerde(); - private: - // When 'estimatedCompactRowSize_' is unset, meaning we haven't materialized + // When 'estimatedRowSize_' is unset, meaning we haven't materialized // and returned any output from this exchange operator, we return this // conservative number of output rows, to make sure memory does not grow too // much. - static constexpr uint64_t kInitialOutputCompactRows = 64; + static constexpr uint64_t kInitialOutputRows = 64; // Invoked to create exchange client for remote tasks. The function shuffles // the source task ids first to randomize the source tasks we fetch data from. @@ -75,11 +81,8 @@ class Exchange : public SourceOperator { // Fetches splits from the task until there are no more splits or task returns // a future that will be complete when more splits arrive. Adds splits to - // exchangeClient_. Returns true if received a future from the task and sets - // the 'future' parameter. Returns false if fetched all splits or if this - // operator is not the first operator in the pipeline and therefore is not - // responsible for fetching splits and adding them to the exchangeClient_. - bool getSplits(ContinueFuture* future); + // exchangeClient_. + void getSplits(ContinueFuture* future); // Fetches runtime stats from ExchangeClient and replaces these in this // operator's stats. @@ -87,9 +90,9 @@ class Exchange : public SourceOperator { void recordInputStats(uint64_t rawInputBytes); - RowVectorPtr getOutputFromCompactRows(VectorSerde* serde); + RowVectorPtr getOutputFromColumnarPages(VectorSerde* serde); - RowVectorPtr getOutputFromUnsafeRows(VectorSerde* serde); + RowVectorPtr getOutputFromRowPages(VectorSerde* serde); const uint64_t preferredOutputBatchBytes_; @@ -114,19 +117,25 @@ class Exchange : public SourceOperator { // Reusable result vector. RowVectorPtr result_; - std::vector> currentPages_; + std::vector> currentPages_; bool atEnd_{false}; std::default_random_engine rng_{std::random_device{}()}; - // Memory holders needed by compact row serde to perform cursor like reads - // across 'getOutputFromCompactRows' calls. - std::unique_ptr compactRowPages_; - std::unique_ptr compactRowInputStream_; - std::unique_ptr compactRowIterator_; + // Memory holders for deserialization across 'getOutput' calls. + // The merged pages for row serialization. + std::unique_ptr mergedRowPage_; + std::unique_ptr rowIterator_; + + // State for columnar page deserialization. + // Index of the current page in 'currentPages_' being processed. + size_t columnarPageIdx_{0}; + + // Stream for deserialization used by both row and columnar. + std::unique_ptr inputStream_; // The estimated bytes per row of the output of this exchange operator // computed from the last processed output. - std::optional estimatedCompactRowSize_; + std::optional estimatedRowSize_; }; } // namespace facebook::velox::exec diff --git a/velox/exec/ExchangeClient.cpp b/velox/exec/ExchangeClient.cpp index 0a8b92b2f018..a138f0552045 100644 --- a/velox/exec/ExchangeClient.cpp +++ b/velox/exec/ExchangeClient.cpp @@ -53,7 +53,16 @@ void ExchangeClient::addRemoteTaskId(const std::string& remoteTaskId) { sources_.push_back(source); queue_->addSourceLocked(); emptySources_.push(source); - requestSpecs = pickSourcesToRequestLocked(); + // When lazyFetching_ is true, I/O will be triggered lazily when next() is + // called from Exchange::isBlocked(). This allows waiter tasks using + // cached hash tables to skip I/O entirely when the table is already + // cached - the HashBuild operator will finish before + // Exchange::isBlocked() is ever called, so no unnecessary data fetching + // occurs. + if (!lazyFetching_) { + // Start fetching data immediately. + requestSpecs = pickSourcesToRequestLocked(); + } } } @@ -78,6 +87,11 @@ void ExchangeClient::close() { if (closed_) { return; } + + // Capture stats BEFORE clearing sources_. + // This allows stats() to return meaningful data even after close(). + stats_ = collectStatsLocked(); + closed_ = true; sources = std::move(sources_); producingSources = std::move(producingSources_); @@ -91,9 +105,17 @@ void ExchangeClient::close() { queue_->close(); } -folly::F14FastMap ExchangeClient::stats() const { - folly::F14FastMap stats; +folly::F14FastMap ExchangeClient::stats() { std::lock_guard l(queue_->mutex()); + if (stats_.empty()) { + stats_ = collectStatsLocked(); + } + return stats_; +} + +folly::F14FastMap +ExchangeClient::collectStatsLocked() const { + folly::F14FastMap stats; for (const auto& source : sources_) { if (source->supportsMetrics()) { @@ -119,13 +141,13 @@ folly::F14FastMap ExchangeClient::stats() const { return stats; } -std::vector> ExchangeClient::next( +std::vector> ExchangeClient::next( int consumerId, uint32_t maxBytes, bool* atEnd, ContinueFuture* future) { std::vector requestSpecs; - std::vector> pages; + std::vector> pages; ContinuePromise stalePromise = ContinuePromise::makeEmpty(); { std::lock_guard l(queue_->mutex()); @@ -185,7 +207,7 @@ void ExchangeClient::request(std::vector&& requestSpecs) { RECORD_METRIC_VALUE(kMetricExchangeDataCount); } - bool pauseCurrentSource = false; + bool pauseCurrentSource{false}; std::vector requestSpecs; std::shared_ptr currentSource = spec.source; { @@ -232,6 +254,9 @@ ExchangeClient::pickSourcesToRequestLocked() { if (closed_) { return {}; } + if (skipRequestDataSizeWithSingleSource()) { + return pickupSingleSourceToRequestLocked(); + } std::vector requestSpecs; while (!emptySources_.empty()) { auto& source = emptySources_.front(); @@ -283,6 +308,43 @@ ExchangeClient::pickSourcesToRequestLocked() { return requestSpecs; } +std::vector +ExchangeClient::pickupSingleSourceToRequestLocked() { + VELOX_CHECK_EQ(sources_.size(), 1); + VELOX_CHECK(!closed_); + if (emptySources_.empty() && producingSources_.empty()) { + return {}; + } + + VELOX_CHECK_EQ(totalPendingBytes_, 0); + VELOX_CHECK_LE(!!emptySources_.empty() + !!producingSources_.empty(), 1); + const auto requestBytes = maxQueuedBytes_ - queue_->totalBytes(); + + if (requestBytes <= 0) { + return {}; + } + std::vector requestSpecs; + SCOPE_EXIT { + totalPendingBytes_ += requestBytes; + }; + if (!emptySources_.empty()) { + VELOX_CHECK_EQ(emptySources_.size(), 1); + auto& source = emptySources_.front(); + VELOX_CHECK(source->shouldRequestLocked()); + requestSpecs.push_back({std::move(source), requestBytes}); + emptySources_.pop(); + return requestSpecs; + } + + VELOX_CHECK_EQ(producingSources_.size(), 1); + auto& source = producingSources_.front().source; + VELOX_CHECK(source->shouldRequestLocked()); + VELOX_CHECK(!producingSources_.front().remainingBytes.empty()); + requestSpecs.push_back({std::move(source), requestBytes}); + producingSources_.pop(); + return requestSpecs; +} + ExchangeClient::~ExchangeClient() { close(); } diff --git a/velox/exec/ExchangeClient.h b/velox/exec/ExchangeClient.h index b99fc49e8851..3b8e3df3bdb7 100644 --- a/velox/exec/ExchangeClient.h +++ b/velox/exec/ExchangeClient.h @@ -26,7 +26,6 @@ class ExchangeClient : public std::enable_shared_from_this { public: static constexpr int32_t kDefaultMaxQueuedBytes = 32 << 20; // 32 MB. static constexpr std::chrono::milliseconds kRequestDataMaxWait{100}; - static inline const std::string kBackgroundCpuTimeMs = "backgroundCpuTimeMs"; ExchangeClient( std::string taskId, @@ -36,23 +35,29 @@ class ExchangeClient : public std::enable_shared_from_this { uint64_t minOutputBatchBytes, memory::MemoryPool* pool, folly::Executor* executor, - int32_t requestDataSizesMaxWaitSec = 10) + int32_t requestDataSizesMaxWaitSec = 10, + bool skipRequestDataSizeWithSingleSource = false, + bool lazyFetching = false) : taskId_{std::move(taskId)}, destination_(destination), maxQueuedBytes_{maxQueuedBytes}, kRequestDataSizesMaxWaitSec_{requestDataSizesMaxWaitSec}, pool_(pool), executor_(executor), - queue_(std::make_shared( - numberOfConsumers, - minOutputBatchBytes)), + queue_( + std::make_shared( + numberOfConsumers, + minOutputBatchBytes)), // See comment in 'pickSourcesToRequestLocked' for why this is needed // for 'minOutputBatchBytes_'. Note: ExchangeQueue does not need max(1, // minOutputBatchBytes) because for 'MergeExchangeSource', we want // ExchangeQueue 'minOutputBatchBytes' to be be 0 so that it always // unblocks. In short, 0 has a special meaning for ExchangeQueue minOutputBatchBytes_( - std::max(static_cast(1), minOutputBatchBytes)) { + std::max(static_cast(1), minOutputBatchBytes)), + skipRequestDataSizeWithSingleSource_( + skipRequestDataSizeWithSingleSource), + lazyFetching_(lazyFetching) { VELOX_CHECK_NOT_NULL(pool_); VELOX_CHECK_NOT_NULL(executor_); // NOTE: the executor is used to run async response callback from the @@ -87,8 +92,8 @@ class ExchangeClient : public std::enable_shared_from_this { // Returns runtime statistics aggregated across all of the exchange sources. // ExchangeClient is expected to report background CPU time by including a - // runtime metric named ExchangeClient::kBackgroundCpuTimeMs. - folly::F14FastMap stats() const; + // runtime metric named Operator::kBackgroundCpuTimeNanos. + folly::F14FastMap stats(); const std::shared_ptr& queue() const { return queue_; @@ -102,7 +107,7 @@ class ExchangeClient : public std::enable_shared_from_this { /// /// The data may be compressed, in which case 'maxBytes' applies to compressed /// size. - std::vector> + std::vector> next(int consumerId, uint32_t maxBytes, bool* atEnd, ContinueFuture* future); std::string toString() const; @@ -131,10 +136,37 @@ class ExchangeClient : public std::enable_shared_from_this { std::vector remainingBytes; }; + // Selects exchange sources to request data from based on available queue + // capacity. Handles multiple sources by first requesting data sizes from all + // empty sources, then requesting actual data from producing sources based on + // their remaining bytes and available capacity. May initiate out-of-band + // transfers for large pages that exceed capacity to avoid deadlock + // situations. For single source case, delegates to + // pickupSingleSourceToRequestLocked which sets max request bytes based on + // available queue space instead of reported remaining bytes from exchange + // sources. std::vector pickSourcesToRequestLocked(); + // Specialized single-source request picker for single-source exchange + // clients. Sets the max request bytes based on available space in the queue + // rather than the reported remaining bytes from exchange sources. The reason + // is that single source has no other alternative so just fetch as much as + // possible from that source. Returns a request spec for the single source + // when there is available capacity in the queue and no pending requests. If + // capacity is unavailable or requests are already pending, returns empty + // vector. + std::vector pickupSingleSourceToRequestLocked(); void request(std::vector&& requestSpecs); + /// Returns true if skip request data size optimization is enabled for single + /// source exchanges. + bool skipRequestDataSizeWithSingleSource() const { + return skipRequestDataSizeWithSingleSource_ && queue_->hasNoMoreSources() && + sources_.size() == 1; + } + + folly::F14FastMap collectStatsLocked() const; + // Handy for ad-hoc logging. const std::string taskId_; const int destination_; @@ -149,10 +181,21 @@ class ExchangeClient : public std::enable_shared_from_this { std::vector> sources_; bool closed_{false}; + folly::F14FastMap stats_; + // The minimum byte size the consumer is expected to consume from // the exchange queue. const uint64_t minOutputBatchBytes_; + // Enable single source exchange optimization query config flag + // when there is only one exchange source. + const bool skipRequestDataSizeWithSingleSource_; + + // If true, defer fetching until next() is called. + // If false (default), start fetching data immediately when remote tasks are + // added. + const bool lazyFetching_; + // Total number of bytes in flight. int64_t totalPendingBytes_{0}; diff --git a/velox/exec/ExchangeQueue.cpp b/velox/exec/ExchangeQueue.cpp index b7d1100dbbbb..9a9114ae2c43 100644 --- a/velox/exec/ExchangeQueue.cpp +++ b/velox/exec/ExchangeQueue.cpp @@ -16,42 +16,18 @@ #include "velox/exec/ExchangeQueue.h" #include -namespace facebook::velox::exec { - -SerializedPage::SerializedPage( - std::unique_ptr iobuf, - std::function onDestructionCb, - std::optional numRows) - : iobuf_(std::move(iobuf)), - iobufBytes_(chainBytes(*iobuf_.get())), - numRows_(numRows), - onDestructionCb_(onDestructionCb) { - VELOX_CHECK_NOT_NULL(iobuf_); - for (auto& buf : *iobuf_) { - int32_t bufSize = buf.size(); - ranges_.push_back(ByteRange{ - const_cast(reinterpret_cast(buf.data())), - bufSize, - 0}); - } -} +#include "velox/common/testutil/TestValue.h" -SerializedPage::~SerializedPage() { - if (onDestructionCb_) { - onDestructionCb_(*iobuf_.get()); - } -} +using facebook::velox::common::testutil::TestValue; -std::unique_ptr SerializedPage::prepareStreamForDeserialize() { - return std::make_unique(std::move(ranges_)); -} +namespace facebook::velox::exec { void ExchangeQueue::noMoreSources() { std::vector promises; { std::lock_guard l(mutex_); noMoreSources_ = true; - promises = checkCompleteLocked(); + promises = checkNoMoreInput(); } clearPromises(promises); } @@ -66,20 +42,21 @@ void ExchangeQueue::close() { } int64_t ExchangeQueue::minOutputBatchBytesLocked() const { - // always allow to unblock when at end - if (atEnd_) { + // Allow to unblock if no more input. + if (noMoreInput_) { return 0; } - // At most 1% of received bytes so far to minimize latency for small exchanges + // At most 1% of received bytes so far to minimize latency for small + // exchanges. return std::min(minOutputBatchBytes_, receivedBytes_ / 100); } void ExchangeQueue::enqueueLocked( - std::unique_ptr&& page, + std::unique_ptr&& page, std::vector& promises) { if (page == nullptr) { ++numCompleted_; - auto completedPromises = checkCompleteLocked(); + auto completedPromises = checkNoMoreInput(); promises.reserve(promises.size() + completedPromises.size()); for (auto& promise : completedPromises) { promises.push_back(std::move(promise)); @@ -124,18 +101,20 @@ void ExchangeQueue::addPromiseLocked( *stalePromise = std::move(it->second); it->second = std::move(promise); } else { - promises_[consumerId] = std::move(promise); + promises_.emplace(consumerId, std::move(promise)); } VELOX_CHECK_LE(promises_.size(), numberOfConsumers_); } -std::vector> ExchangeQueue::dequeueLocked( +std::vector> ExchangeQueue::dequeueLocked( int consumerId, uint32_t maxBytes, bool* atEnd, ContinueFuture* future, ContinuePromise* stalePromise) { VELOX_CHECK_NOT_NULL(future); + TestValue::adjust( + "facebook::velox::exec::ExchangeQueue::dequeueLocked", this); if (!error_.empty()) { *atEnd = true; VELOX_FAIL(error_); @@ -150,11 +129,11 @@ std::vector> ExchangeQueue::dequeueLocked( return {}; } - std::vector> pages; + std::vector> pages; uint32_t pageBytes = 0; for (;;) { if (queue_.empty()) { - if (atEnd_) { + if (noMoreInput_) { *atEnd = true; } else if (pages.empty()) { addPromiseLocked(consumerId, future, stalePromise); @@ -183,9 +162,9 @@ void ExchangeQueue::setError(const std::string& error) { return; } error_ = error; - atEnd_ = true; - // NOTE: clear the serialized page queue as we won't consume from an - // errored queue. + noMoreInput_ = true; + // NOTE: clear the serialized page queue as we won't consume from an errored + // queue. queue_.clear(); promises = clearAllPromisesLocked(); } diff --git a/velox/exec/ExchangeQueue.h b/velox/exec/ExchangeQueue.h index 4f77360fdbc7..91a633d5366b 100644 --- a/velox/exec/ExchangeQueue.h +++ b/velox/exec/ExchangeQueue.h @@ -15,66 +15,11 @@ */ #pragma once -#include "velox/common/memory/ByteStream.h" +#include "velox/exec/SerializedPage.h" -namespace facebook::velox::exec { - -/// Corresponds to Presto SerializedPage, i.e. a container for serialize vectors -/// in Presto wire format. -class SerializedPage { - public: - /// Construct from IOBuf chain. - explicit SerializedPage( - std::unique_ptr iobuf, - std::function onDestructionCb = nullptr, - std::optional numRows = std::nullopt); - - ~SerializedPage(); +#include - /// Returns the size of the serialized data in bytes. - uint64_t size() const { - return iobufBytes_; - } - - std::optional numRows() const { - return numRows_; - } - - /// Makes 'input' ready for deserializing 'this' with - /// VectorStreamGroup::read(). - std::unique_ptr prepareStreamForDeserialize(); - - std::unique_ptr getIOBuf() const { - return iobuf_->clone(); - } - - private: - static int64_t chainBytes(folly::IOBuf& iobuf) { - int64_t size = 0; - for (auto& range : iobuf) { - size += range.size(); - } - return size; - } - - // Buffers containing the serialized data. The memory is owned by 'iobuf_'. - std::vector ranges_; - - // IOBuf holding the data in 'ranges_. - std::unique_ptr iobuf_; - - // Number of payload bytes in 'iobuf_'. - const int64_t iobufBytes_; - - // Number of payload rows, if provided. - const std::optional numRows_; - - // Callback that will be called on destruction of the SerializedPage, - // primarily used to free externally allocated memory backing folly::IOBuf - // from caller. Caller is responsible to pass in proper cleanup logic to - // prevent any memory leak. - std::function onDestructionCb_; -}; +namespace facebook::velox::exec { /// Queue of results retrieved from source. Owned by shared_ptr by /// Exchange and client threads and registered callbacks waiting @@ -108,7 +53,7 @@ class ExchangeQueue { /// returned in 'promises'. When 'page' is nullptr and the queue is not /// completed serving data, no 'promises' will be added and returned. void enqueueLocked( - std::unique_ptr&& page, + std::unique_ptr&& page, std::vector& promises); /// If data is permanently not available, e.g. the source cannot be @@ -127,7 +72,7 @@ class ExchangeQueue { /// /// The data may be compressed, in which case 'maxBytes' applies to compressed /// size. - std::vector> dequeueLocked( + std::vector> dequeueLocked( int consumerId, uint32_t maxBytes, bool* atEnd, @@ -162,6 +107,10 @@ class ExchangeQueue { void noMoreSources(); + bool hasNoMoreSources() const { + return noMoreSources_; + } + void close(); private: @@ -170,9 +119,9 @@ class ExchangeQueue { return clearAllPromisesLocked(); } - std::vector checkCompleteLocked() { + std::vector checkNoMoreInput() { if (noMoreSources_ && numCompleted_ == numSources_) { - atEnd_ = true; + noMoreInput_ = true; return clearAllPromisesLocked(); } return {}; @@ -193,7 +142,9 @@ class ExchangeQueue { } std::vector clearAllPromisesLocked() { - std::vector promises(promises_.size()); + std::vector promises; + promises.reserve(promises_.size()); + auto it = promises_.begin(); while (it != promises_.end()) { promises.push_back(std::move(it->second)); @@ -216,11 +167,14 @@ class ExchangeQueue { int numCompleted_{0}; int numSources_{0}; - bool noMoreSources_{false}; - bool atEnd_{false}; + tsan_atomic noMoreSources_{false}; + // True if no more pages will be enqueued. This can be due to all sources + // completing normally or an error. Note that the queue itself may still + // contain data to be consumed. + bool noMoreInput_{false}; std::mutex mutex_; - std::deque> queue_; + std::deque> queue_; // The map from consumer id to the waiting promise folly::F14FastMap promises_; diff --git a/velox/exec/ExchangeSource.h b/velox/exec/ExchangeSource.h index 2b0f74fa2057..79ec65781b89 100644 --- a/velox/exec/ExchangeSource.h +++ b/velox/exec/ExchangeSource.h @@ -106,7 +106,7 @@ class ExchangeSource : public std::enable_shared_from_this { // Returns runtime statistics. ExchangeSource is expected to report // background CPU time by including a runtime metric named - // ExchangeClient::kBackgroundCpuTimeMs. + // Operator::kBackgroundCpuTimeNanos. virtual folly::F14FastMap stats() const { VELOX_UNREACHABLE(); } diff --git a/velox/exec/Expand.cpp b/velox/exec/Expand.cpp index 5d866b888c18..44b6215819e0 100644 --- a/velox/exec/Expand.cpp +++ b/velox/exec/Expand.cpp @@ -31,6 +31,7 @@ Expand::Expand( const auto numRows = expandNode->projections().size(); fieldProjections_.reserve(numRows); constantProjections_.reserve(numRows); + constantOutputs_.reserve(numRows); const auto numColumns = expandNode->names().size(); for (const auto& rowProjections : expandNode->projections()) { std::vector rowProjection; @@ -58,6 +59,25 @@ Expand::Expand( } } +void Expand::initialize() { + if (constantProjections_.empty()) { + return; + } + const auto numColumns = constantProjections_[0].size(); + for (const auto& projections : constantProjections_) { + std::vector constantOutput; + constantOutput.reserve(numColumns); + for (const auto& constant : projections) { + if (constant) { + constantOutput.push_back(constant->toConstantVector(pool())); + } else { + constantOutput.push_back(nullptr); + } + } + constantOutputs_.emplace_back(std::move(constantOutput)); + } +} + bool Expand::needsInput() const { return !noMoreInput_ && input_ == nullptr; } @@ -81,21 +101,13 @@ RowVectorPtr Expand::getOutput() { std::vector outputColumns(outputType_->size()); const auto& rowProjection = fieldProjections_[rowIndex_]; - const auto& constantProjection = constantProjections_[rowIndex_]; + const auto& constantProjection = constantOutputs_[rowIndex_]; const auto numColumns = rowProjection.size(); for (auto i = 0; i < numColumns; ++i) { if (rowProjection[i] == kConstantChannel) { - const auto& constantExpr = constantProjection[i]; - if (constantExpr->value().isNull()) { - // Add null column. - outputColumns[i] = BaseVector::createNullConstant( - outputType_->childAt(i), numInput, pool()); - } else { - // Add constant column. - outputColumns[i] = BaseVector::createConstant( - constantExpr->type(), constantExpr->value(), numInput, pool()); - } + outputColumns[i] = + BaseVector::wrapInConstant(numInput, 0, constantProjection[i]); } else { outputColumns[i] = input_->childAt(rowProjection[i]); } diff --git a/velox/exec/Expand.h b/velox/exec/Expand.h index 97c737c1d1f3..adf87a715268 100644 --- a/velox/exec/Expand.h +++ b/velox/exec/Expand.h @@ -42,11 +42,15 @@ class Expand : public Operator { } private: + void initialize() override; + std::vector> fieldProjections_; std::vector>> constantProjections_; + std::vector> constantOutputs_; + // Used to indicate the index of fieldProjections_. int32_t rowIndex_{0}; }; diff --git a/velox/exec/FilterProject.cpp b/velox/exec/FilterProject.cpp index 4b095f3fb55c..b8189539d8f4 100644 --- a/velox/exec/FilterProject.cpp +++ b/velox/exec/FilterProject.cpp @@ -86,10 +86,15 @@ FilterProject::FilterProject( project ? project->id() : filter->id(), "FilterProject"), hasFilter_(filter != nullptr), + lazyDereference_( + dynamic_cast(project.get()) != + nullptr), project_(project), filter_(filter) { + VELOX_CHECK(!(lazyDereference_ && filter)); if (filter_ != nullptr && project_ != nullptr) { - stats().withWLock([&](auto& stats) { + folly::Synchronized& opStats = Operator::stats(); + opStats.withWLock([&](auto& stats) { stats.setStatSplitter( [filterId = filter_->id()](const auto& combinedStats) { return splitStats(combinedStats, filterId); @@ -123,7 +128,8 @@ void FilterProject::initialize() { isIdentityProjection_ = true; } numExprs_ = allExprs.size(); - exprs_ = makeExprSetFromFlag(std::move(allExprs), operatorCtx_->execCtx()); + exprs_ = makeExprSetFromFlag( + std::move(allExprs), operatorCtx_->execCtx(), lazyDereference_); if (numExprs_ > 0 && !identityProjections_.empty()) { const auto inputType = project_ ? project_->sources()[0]->outputType() @@ -146,44 +152,37 @@ void FilterProject::initialize() { void FilterProject::addInput(RowVectorPtr input) { input_ = std::move(input); - numProcessedInputRows_ = 0; -} - -bool FilterProject::allInputProcessed() { - if (!input_) { - return true; - } - if (numProcessedInputRows_ == input_->size()) { - input_ = nullptr; - return true; - } - return false; } bool FilterProject::isFinished() { - return noMoreInput_ && allInputProcessed(); + return noMoreInput_ && !input_; } RowVectorPtr FilterProject::getOutput() { - if (allInputProcessed()) { + if (!input_) { return nullptr; } + SCOPE_EXIT { + input_.reset(); + }; vector_size_t size = input_->size(); LocalSelectivityVector localRows(*operatorCtx_->execCtx(), size); auto* rows = localRows.get(); VELOX_DCHECK_NOT_NULL(rows); rows->setAll(); - EvalCtx evalCtx(operatorCtx_->execCtx(), exprs_.get(), input_.get()); - - // Pre-load lazy vectors which are referenced by both expressions and identity - // projections. - for (auto fieldIdx : multiplyReferencedFieldIndices_) { - evalCtx.ensureFieldLoaded(fieldIdx, *rows); + EvalCtx evalCtx( + operatorCtx_->execCtx(), exprs_.get(), input_.get(), lazyDereference_); + + if (!lazyDereference_) { + // Pre-load lazy vectors which are referenced by both expressions and + // identity projections. + for (auto fieldIdx : multiplyReferencedFieldIndices_) { + evalCtx.ensureFieldLoaded(fieldIdx, *rows); + } } if (!hasFilter_) { - numProcessedInputRows_ = size; VELOX_CHECK(!isIdentityProjection_); auto results = project(*rows, evalCtx); return fillOutput(size, nullptr, results); @@ -191,7 +190,6 @@ RowVectorPtr FilterProject::getOutput() { // evaluate filter auto numOut = filter(evalCtx, *rows); - numProcessedInputRows_ = size; if (numOut == 0) { // no rows passed the filer input_ = nullptr; return nullptr; @@ -229,4 +227,16 @@ vector_size_t FilterProject::filter( exprs_->eval(0, 1, true, allRows, evalCtx, results); return processFilterResults(results[0], allRows, filterEvalCtx_, pool()); } + +OperatorStats FilterProject::stats(bool clear) { + auto stats = Operator::stats(clear); + if (operatorCtx() + ->driverCtx() + ->queryConfig() + .operatorTrackExpressionStats() && + exprs_ != nullptr) { + stats.expressionStats = exprs_->stats(true /*excludeSpecialForm*/); + } + return stats; +} } // namespace facebook::velox::exec diff --git a/velox/exec/FilterProject.h b/velox/exec/FilterProject.h index d82f2cb2a400..b5ae557e0fd6 100644 --- a/velox/exec/FilterProject.h +++ b/velox/exec/FilterProject.h @@ -78,12 +78,17 @@ class FilterProject : public Operator { void initialize() override; - private: - // Tests if 'numProcessedRows_' equals to the length of input_ and clears - // outstanding references to input_ if done. Returns true if getOutput - // should return nullptr. - bool allInputProcessed(); + /// Ensures that expression stats are added to the operator stats if their + /// tracking is enabled via query config. + OperatorStats stats(bool clear) override; + + /// Returns the filterNode, call this function before initialize the operator, + /// this field is reset in function initialize. + const std::shared_ptr& filterNode() const { + return filter_; + } + private: // Evaluate filter on all rows. Return number of rows that passed the filter. // Populate filterEvalCtx_.selectedBits and selectedIndices with the indices // of the passing rows if only some rows pass the filter. If all or no rows @@ -100,6 +105,8 @@ class FilterProject : public Operator { // If true exprs_[0] is a filter and the other expressions are projections const bool hasFilter_{false}; + const bool lazyDereference_; + // Cached filter and project node for lazy initialization. After // initialization, they will be reset, and initialized_ will be set to true. std::shared_ptr project_; @@ -111,8 +118,6 @@ class FilterProject : public Operator { FilterEvalCtx filterEvalCtx_; - vector_size_t numProcessedInputRows_{0}; - // Indices for fields/input columns that are both an identity projection and // are referenced by either a filter or project expression. This is used to // identify fields that need to be preloaded before evaluating filters or diff --git a/velox/exec/GroupingSet.cpp b/velox/exec/GroupingSet.cpp index c9da8832986c..152290f9f283 100644 --- a/velox/exec/GroupingSet.cpp +++ b/velox/exec/GroupingSet.cpp @@ -15,6 +15,7 @@ */ #include "velox/exec/GroupingSet.h" #include "velox/common/testutil/TestValue.h" +#include "velox/exec/OperatorUtils.h" #include "velox/exec/Task.h" using facebook::velox::common::testutil::TestValue; @@ -53,7 +54,8 @@ GroupingSet::GroupingSet( const std::optional& groupIdChannel, const common::SpillConfig* spillConfig, tsan_atomic* nonReclaimableSection, - OperatorCtx* operatorCtx, + const core::QueryConfig* queryConfig, + memory::MemoryPool* pool, folly::Synchronized* spillStats) : preGroupedKeyChannels_(std::move(preGroupedKeys)), groupingKeyOutputProjections_(std::move(groupingKeyOutputProjections)), @@ -61,7 +63,8 @@ GroupingSet::GroupingSet( isGlobal_(hashers_.empty()), isPartial_(isPartial), isRawInput_(isRawInput), - queryConfig_(operatorCtx->task()->queryCtx()->queryConfig()), + queryConfig_(queryConfig), + pool_(pool), aggregates_(std::move(aggregates)), masks_(extractMaskChannels(aggregates_)), ignoreNullKeys_(ignoreNullKeys), @@ -69,13 +72,12 @@ GroupingSet::GroupingSet( groupIdChannel_(groupIdChannel), spillConfig_(spillConfig), nonReclaimableSection_(nonReclaimableSection), - stringAllocator_(operatorCtx->pool()), - rows_(operatorCtx->pool()), - isAdaptive_(queryConfig_.hashAdaptivityEnabled()), - pool_(*operatorCtx->pool()), + stringAllocator_(pool_), + rows_(pool_), + isAdaptive_(queryConfig_->hashAdaptivityEnabled()), spillStats_(spillStats) { VELOX_CHECK_NOT_NULL(nonReclaimableSection_); - VELOX_CHECK(pool_.trackUsage()); + VELOX_CHECK(pool_->trackUsage()); for (auto& hasher : hashers_) { keyChannels_.push_back(hasher->channel()); @@ -104,7 +106,7 @@ GroupingSet::GroupingSet( } sortedAggregations_ = - SortedAggregations::create(aggregates_, inputType, &pool_); + SortedAggregations::create(aggregates_, inputType, pool_); if (isPartial_) { VELOX_USER_CHECK_NULL( sortedAggregations_, @@ -117,7 +119,7 @@ GroupingSet::GroupingSet( !isPartial_, "Partial aggregations over distinct inputs are not supported"); distinctAggregations_.emplace_back( - DistinctAggregations::create({&aggregate}, inputType, &pool_)); + DistinctAggregations::create({&aggregate}, inputType, pool_)); } else { distinctAggregations_.push_back(nullptr); } @@ -148,7 +150,8 @@ std::unique_ptr GroupingSet::createForMarkDistinct( /*groupIdColumn=*/std::nullopt, /*spillConfig=*/nullptr, nonReclaimableSection, - operatorCtx, + &operatorCtx->driverCtx()->queryConfig(), + operatorCtx->pool(), /*spillStats=*/nullptr); }; @@ -374,10 +377,10 @@ std::vector GroupingSet::accumulators(bool excludeToIntermediate) { void GroupingSet::createHashTable() { if (ignoreNullKeys_) { table_ = HashTable::createForAggregation( - std::move(hashers_), accumulators(false), &pool_); + std::move(hashers_), accumulators(false), pool_); } else { table_ = HashTable::createForAggregation( - std::move(hashers_), accumulators(false), &pool_); + std::move(hashers_), accumulators(false), pool_); } RowContainer& rows = *table_->rows(); @@ -416,7 +419,7 @@ void GroupingSet::createHashTable() { } } - lookup_ = std::make_unique(table_->hashers(), &pool_); + lookup_ = std::make_unique(table_->hashers(), pool_); if (!isAdaptive_ && table_->hashMode() != BaseHashTable::HashMode::kHash) { table_->forceGenericHashMode(BaseHashTable::kNoSpillInputStartPartitionBit); } @@ -427,7 +430,7 @@ void GroupingSet::initializeGlobalAggregation() { return; } - lookup_ = std::make_unique(hashers_, &pool_); + lookup_ = std::make_unique(hashers_, pool_); lookup_->reset(1); // Row layout is: @@ -642,7 +645,7 @@ bool GroupingSet::getDefaultGlobalGroupingSetOutput( } auto globalAggregatesRow = - BaseVector::create(result->type(), 1, &pool_); + BaseVector::create(result->type(), 1, pool_); VELOX_CHECK(getGlobalAggregationOutput(iterator, globalAggregatesRow)); @@ -878,16 +881,16 @@ void GroupingSet::ensureInputFits(const RowVectorPtr& input) { const int64_t flatBytes = input->estimateFlatSize(); // Test-only spill path. - if (testingTriggerSpill(pool_.name())) { + if (testingTriggerSpill(pool_->name())) { memory::ReclaimableSectionGuard guard(nonReclaimableSection_); - memory::testingRunArbitration(&pool_); + memory::testingRunArbitration(pool_); return; } - const auto currentUsage = pool_.usedBytes(); + const auto currentUsage = pool_->usedBytes(); const auto minReservationBytes = currentUsage * spillConfig_->minSpillableReservationPct / 100; - const auto availableReservationBytes = pool_.availableReservation(); + const auto availableReservationBytes = pool_->availableReservation(); const auto tableIncrementBytes = table_->hashTableSizeIncrease(input->size()); const auto incrementBytes = rows->sizeIncrement(input->size(), outOfLineBytes ? flatBytes * 2 : 0) + @@ -924,14 +927,14 @@ void GroupingSet::ensureInputFits(const RowVectorPtr& input) { currentUsage * spillConfig_->spillableReservationGrowthPct / 100); { memory::ReclaimableSectionGuard guard(nonReclaimableSection_); - if (pool_.maybeReserve(targetIncrementBytes)) { + if (pool_->maybeReserve(targetIncrementBytes)) { return; } } LOG(WARNING) << "Failed to reserve " << succinctBytes(targetIncrementBytes) - << " for memory pool " << pool_.name() - << ", usage: " << succinctBytes(pool_.usedBytes()) - << ", reservation: " << succinctBytes(pool_.reservedBytes()); + << " for memory pool " << pool_->name() + << ", usage: " << succinctBytes(pool_->usedBytes()) + << ", reservation: " << succinctBytes(pool_->reservedBytes()); } void GroupingSet::ensureOutputFits() { @@ -945,33 +948,33 @@ void GroupingSet::ensureOutputFits() { } // Test-only spill path. - if (testingTriggerSpill(pool_.name())) { + if (testingTriggerSpill(pool_->name())) { memory::ReclaimableSectionGuard guard(nonReclaimableSection_); - memory::testingRunArbitration(&pool_); + memory::testingRunArbitration(pool_); return; } const uint64_t outputBufferSizeToReserve = - queryConfig_.preferredOutputBatchBytes() * 1.2; + queryConfig_->preferredOutputBatchBytes() * 1.2; { memory::ReclaimableSectionGuard guard(nonReclaimableSection_); - if (pool_.maybeReserve(outputBufferSizeToReserve)) { + if (pool_->maybeReserve(outputBufferSizeToReserve)) { if (hasSpilled()) { // If reservation triggers spilling on the 'GroupingSet' itself, we will // no longer need the reserved memory for output processing as the // output processing will be conducted from unspilled data through // 'getOutputWithSpill()', and it does not require this amount of memory // to process. - pool_.release(); + pool_->release(); } return; } } LOG(WARNING) << "Failed to reserve " << succinctBytes(outputBufferSizeToReserve) - << " for memory pool " << pool_.name() - << ", usage: " << succinctBytes(pool_.usedBytes()) - << ", reservation: " << succinctBytes(pool_.reservedBytes()); + << " for memory pool " << pool_->name() + << ", usage: " << succinctBytes(pool_->usedBytes()) + << ", reservation: " << succinctBytes(pool_->reservedBytes()); } RowTypePtr GroupingSet::makeSpillType() const { @@ -1013,7 +1016,7 @@ void GroupingSet::spill() { auto* rows = table_->rows(); VELOX_CHECK_NULL(outputSpiller_); if (inputSpiller_ == nullptr) { - VELOX_DCHECK(pool_.trackUsage()); + VELOX_DCHECK(pool_->trackUsage()); VELOX_CHECK(numDistinctSpillFilesPerPartition_.empty()); const auto sortingKeys = SpillState::makeSortingKeys( std::vector(rows->keyTypes().size())); @@ -1061,7 +1064,7 @@ void GroupingSet::spill(const RowContainerIterator& rowIterator) { } auto* rows = table_->rows(); - VELOX_CHECK(pool_.trackUsage()); + VELOX_CHECK(pool_->trackUsage()); outputSpiller_ = std::make_unique( rows, makeSpillType(), spillConfig_, spillStats_); @@ -1099,7 +1102,8 @@ bool GroupingSet::getOutputWithSpill( false, false, false, - &pool_); + false, + pool_); initializeAggregates(aggregates_, *mergeRows_, false); } @@ -1134,8 +1138,7 @@ bool GroupingSet::prepareNextSpillPartitionOutput() { auto it = spillPartitionSet_.begin(); VELOX_CHECK_NE(outputSpillPartition_, it->first.partitionNumber()); outputSpillPartition_ = it->first.partitionNumber(); - merge_ = it->second->createOrderedReader( - spillConfig_->readBufferSize, &pool_, spillStats_); + merge_ = it->second->createOrderedReader(*spillConfig_, pool_, spillStats_); spillPartitionSet_.erase(it); return true; } @@ -1230,7 +1233,7 @@ void GroupingSet::prepareSpillResultWithoutAggregates( spillResultWithoutAggregates_ = BaseVector::create( std::make_shared(std::move(names), std::move(types)), maxOutputRows, - &pool_); + pool_); } else { VectorPtr spillResultWithoutAggregates = std::move(spillResultWithoutAggregates_); @@ -1244,6 +1247,9 @@ void GroupingSet::prepareSpillResultWithoutAggregates( spillResultWithoutAggregates_->childAt(groupingKeyOutputProjections_[i]) = std::move(result->childAt(i)); } + + spillSources_.resize(maxOutputRows); + spillSourceRows_.resize(maxOutputRows); } void GroupingSet::projectResult(const RowVectorPtr& result) { @@ -1264,6 +1270,7 @@ bool GroupingSet::mergeNextWithoutAggregates( VELOX_CHECK_EQ( numDistinctSpillFilesPerPartition_.size(), 1 << spillConfig_->numPartitionBits); + VELOX_CHECK(pool_ == result->pool()); // We are looping over sorted rows produced by tree-of-losers. We logically // split the stream into runs of duplicate rows. As we process each run we @@ -1278,12 +1285,15 @@ bool GroupingSet::mergeNextWithoutAggregates( // less than 'numDistinctSpillFilesPerPartition_'. bool newDistinct{true}; int32_t numOutputRows{0}; + int32_t outputSize{0}; + bool endOfBatch = false; prepareSpillResultWithoutAggregates(maxOutputRows, result); - while (numOutputRows < maxOutputRows) { + while (numOutputRows + outputSize < maxOutputRows) { const auto next = merge_->nextWithEquals(); auto* stream = next.first; if (stream == nullptr) { + VELOX_CHECK_EQ(outputSize, 0); if (numOutputRows > 0) { break; } @@ -1298,17 +1308,40 @@ bool GroupingSet::mergeNextWithoutAggregates( numDistinctSpillFilesPerPartition_[outputSpillPartition_]) { newDistinct = false; } - if (next.second) { - stream->pop(); - continue; - } - if (newDistinct) { + auto index = stream->currentIndex(&endOfBatch); + if (!next.second && newDistinct) { // Yield result for new distinct. - spillResultWithoutAggregates_->copy( - &stream->current(), numOutputRows++, stream->currentIndex(), 1); + spillSources_[outputSize] = &stream->current(); + spillSourceRows_[outputSize] = index; + ++outputSize; + } + + if (FOLLY_UNLIKELY(endOfBatch)) { + // The stream is at end of input batch. Need to copy out the rows before + // fetching next batch in 'pop'. + gatherCopy( + spillResultWithoutAggregates_.get(), + numOutputRows, + outputSize, + spillSources_, + spillSourceRows_); + numOutputRows += outputSize; + outputSize = 0; } stream->pop(); - newDistinct = true; + // Reset newDistinct flag for new row. + if (!next.second) { + newDistinct = true; + } + } + if (FOLLY_LIKELY(outputSize != 0)) { + gatherCopy( + spillResultWithoutAggregates_.get(), + numOutputRows, + outputSize, + spillSources_, + spillSourceRows_); + numOutputRows += outputSize; } spillResultWithoutAggregates_->resize(numOutputRows); projectResult(result); @@ -1405,7 +1438,8 @@ void GroupingSet::abandonPartialAggregation() { false, false, false, - &pool_); + false, + pool_); initializeAggregates(aggregates_, *intermediateRows_, true); table_.reset(); } @@ -1500,8 +1534,9 @@ void GroupingSet::toIntermediate( &aggregateVector); } if (intermediateRows_) { - intermediateRows_->eraseRows(folly::Range( - intermediateGroups_.data(), intermediateGroups_.size())); + intermediateRows_->eraseRows( + folly::Range( + intermediateGroups_.data(), intermediateGroups_.size())); } // It's unnecessary to call function->clear() to reset the internal states of diff --git a/velox/exec/GroupingSet.h b/velox/exec/GroupingSet.h index a599fb9b628d..f84a327f5f71 100644 --- a/velox/exec/GroupingSet.h +++ b/velox/exec/GroupingSet.h @@ -15,13 +15,13 @@ */ #pragma once +#include "velox/common/base/TreeOfLosers.h" #include "velox/exec/AggregateInfo.h" #include "velox/exec/AggregationMasks.h" #include "velox/exec/DistinctAggregations.h" #include "velox/exec/HashTable.h" #include "velox/exec/SortedAggregations.h" #include "velox/exec/Spiller.h" -#include "velox/exec/TreeOfLosers.h" #include "velox/exec/VectorHasher.h" namespace facebook::velox::exec { @@ -43,7 +43,8 @@ class GroupingSet { const std::optional& groupIdChannel, const common::SpillConfig* spillConfig, tsan_atomic* nonReclaimableSection, - OperatorCtx* operatorCtx, + const core::QueryConfig* queryConfig, + memory::MemoryPool* pool, folly::Synchronized* spillStats); ~GroupingSet(); @@ -148,7 +149,7 @@ class GroupingSet { RowContainerIterator& iterator, RowVectorPtr& result); - memory::MemoryPool& testingPool() const { + memory::MemoryPool* testingPool() const { return pool_; } @@ -302,7 +303,8 @@ class GroupingSet { const bool isGlobal_; const bool isPartial_; const bool isRawInput_; - const core::QueryConfig& queryConfig_; + const core::QueryConfig* const queryConfig_; + memory::MemoryPool* const pool_; std::vector aggregates_; AggregationMasks masks_; @@ -358,6 +360,11 @@ class GroupingSet { // result. RowVectorPtr spillResultWithoutAggregates_{nullptr}; + // Records the source rows to copy to 'output_' in order. + std::vector spillSources_; + + std::vector spillSourceRows_; + // The value of mayPushdown flag specified in addInput() for the // 'remainingInput_'. bool remainingMayPushdown_; @@ -391,9 +398,6 @@ class GroupingSet { // to merge. SelectivityVector mergeSelection_; - // Pool of the OperatorCtx. Used for spilling. - memory::MemoryPool& pool_; - // True if partial aggregation has been given up as non-productive. bool abandonedPartialAggregation_{false}; diff --git a/velox/exec/HashAggregation.cpp b/velox/exec/HashAggregation.cpp index 14e87f336ec2..bcc3b4b97cbd 100644 --- a/velox/exec/HashAggregation.cpp +++ b/velox/exec/HashAggregation.cpp @@ -84,7 +84,7 @@ void HashAggregation::initialize() { "Unexpected result type for an aggregation: {}, expected {}, step {}", aggResultType->toString(), expectedType->toString(), - core::AggregationNode::stepName(aggregationNode_->step())); + core::AggregationNode::toName(aggregationNode_->step())); } for (auto i = 0; i < hashers.size(); ++i) { @@ -112,8 +112,9 @@ void HashAggregation::initialize() { groupIdChannel, spillConfig_.has_value() ? &spillConfig_.value() : nullptr, &nonReclaimableSection_, - operatorCtx_.get(), - &spillStats_); + &operatorCtx_->driverCtx()->queryConfig(), + operatorCtx_->pool(), + spillStats_.get()); aggregationNode_.reset(); } diff --git a/velox/exec/HashBitRange.h b/velox/exec/HashBitRange.h index b4ebcd30a597..3d2af0f57dcb 100644 --- a/velox/exec/HashBitRange.h +++ b/velox/exec/HashBitRange.h @@ -55,14 +55,10 @@ class HashBitRange { return 1 << numBits(); } - inline bool operator==(const HashBitRange& other) const { + bool operator==(const HashBitRange& other) const { return std::tie(begin_, end_) == std::tie(other.begin_, other.end_); } - inline bool operator!=(const HashBitRange& other) const { - return !(*this == other); - } - private: // Low bit number of hash number bit range. uint8_t begin_; diff --git a/velox/exec/HashBuild.cpp b/velox/exec/HashBuild.cpp index 3cd021bbb7ec..389065a448c2 100644 --- a/velox/exec/HashBuild.cpp +++ b/velox/exec/HashBuild.cpp @@ -18,8 +18,10 @@ #include "velox/common/base/Counters.h" #include "velox/common/base/StatsReporter.h" #include "velox/common/testutil/TestValue.h" +#include "velox/exec/HashTableCache.h" #include "velox/exec/OperatorUtils.h" #include "velox/exec/Task.h" +#include "velox/exec/VectorHasher.h" #include "velox/expression/FieldReference.h" using facebook::velox::common::testutil::TestValue; @@ -62,6 +64,13 @@ HashBuild::HashBuild( joinType_{joinNode_->joinType()}, nullAware_{joinNode_->isNullAware()}, needProbedFlagSpill_{needRightSideJoin(joinType_)}, + dropDuplicates_(joinNode_->canDropDuplicates()), + vectorHasherMaxNumDistinct_( + driverCtx->queryConfig().joinBuildVectorHasherMaxNumDistinct()), + abandonHashBuildDedupMinRows_( + driverCtx->queryConfig().abandonHashBuildDedupMinRows()), + abandonHashBuildDedupMinPct_( + driverCtx->queryConfig().abandonHashBuildDedupMinPct()), joinBridge_(operatorCtx_->task()->getHashJoinBridgeLocked( operatorCtx_->driverCtx()->splitGroupId, planNodeId())), @@ -85,36 +94,127 @@ HashBuild::HashBuild( // Identify the non-key build side columns and make a decoder for each. const int32_t numDependents = inputType->size() - numKeys; - if (numDependents > 0) { - // Number of join keys (numKeys) may be less then number of input columns - // (inputType->size()). In this case numDependents is negative and cannot be - // used to call 'reserve'. This happens when we join different probe side - // keys with the same build side key: SELECT * FROM t LEFT JOIN u ON t.k1 = - // u.k AND t.k2 = u.k. - dependentChannels_.reserve(numDependents); - decoders_.reserve(numDependents); - } - for (auto i = 0; i < inputType->size(); ++i) { - if (keyChannelMap_.find(i) == keyChannelMap_.end()) { - dependentChannels_.emplace_back(i); - decoders_.emplace_back(std::make_unique()); + if (!dropDuplicates_) { + if (numDependents > 0) { + // Number of join keys (numKeys) may be less then number of input columns + // (inputType->size()). In this case numDependents is negative and cannot + // be used to call 'reserve'. This happens when we join different probe + // side keys with the same build side key: SELECT * FROM t LEFT JOIN u ON + // t.k1 = u.k AND t.k2 = u.k. + dependentChannels_.reserve(numDependents); + decoders_.reserve(numDependents); + } + + for (auto i = 0; i < inputType->size(); ++i) { + if (keyChannelMap_.find(i) == keyChannelMap_.end()) { + dependentChannels_.emplace_back(i); + decoders_.emplace_back(std::make_unique()); + } } } tableType_ = hashJoinTableType(joinNode_); - setupTable(); - setupSpiller(); + stateCleared_ = false; } void HashBuild::initialize() { Operator::initialize(); + if (setupCachedHashTable()) { + return; + } + + // Set up table and spiller now that cache state is initialized. + // This ensures tableMemoryPool() returns the cache's tablePool when enabled. + setupTable(); + setupSpiller(); + if (isAntiJoin(joinType_) && joinNode_->filter()) { setupFilterForAntiJoins(keyChannelMap_); } } +bool HashBuild::setupCachedHashTable() { + if (!joinNode_->useHashTableCache()) { + return false; + } + + const auto& queryId = operatorCtx_->task()->queryCtx()->queryId(); + cacheKey_ = fmt::format("{}:{}", queryId, planNodeId()); + + // Get or create the cache entry (which includes the pool). + // If another task is already building, future_ will be set. + auto* cache = HashTableCache::instance(); + auto* queryCtx = operatorCtx_->task()->queryCtx().get(); + cacheEntry_ = cache->get(cacheKey_, taskId(), queryCtx, &future_); + VELOX_CHECK_NOT_NULL(cacheEntry_); + VELOX_CHECK_NOT_NULL(cacheEntry_->tablePool); + + // Check if table is already built. + if (cacheEntry_->buildComplete) { + noMoreInput(); + return true; + } + + // Check if we're a waiter task (future was set by get). + if (future_.valid()) { + setState(State::kWaitForBuild); + return true; + } + + // This is the builder task - proceed with building. + return false; +} + +bool HashBuild::getHashTableFromCache() { + if (!useHashTableCache()) { + return false; + } + + if (!cacheEntry_->buildComplete) { + // Cache miss - we need to build the table. + stats_.wlock()->addRuntimeStat( + BaseHashTable::kHashTableCacheMiss, RuntimeCounter(1)); + return false; + } + + // Table already built by a previous task! Use it directly. + // Notify the bridge with the cached table. + // We pass a shared_ptr copy (not std::move) since the cache retains + // ownership. + joinBridge_->setHashTable( + cacheEntry_->table, {}, cacheEntry_->hasNullKeys, nullptr); + // Record cache hit metric. + stats_.wlock()->addRuntimeStat( + BaseHashTable::kHashTableCacheHit, RuntimeCounter(1)); + return true; +} + +void HashBuild::maybeSetHashTableInCache( + const std::shared_ptr& table) { + if (!useHashTableCache()) { + return; + } + auto* cache = HashTableCache::instance(); + cache->put(cacheKey(), table, joinHasNullKeys_); +} + +bool HashBuild::receivedCachedHashTable() { + if (!useHashTableCache() || future_.valid()) { + return false; + } + // We were waiting on cached table from another task. + // Ensure that table is ready. + VELOX_CHECK( + cacheEntry_->buildComplete, + "Signalled that cache table is ready but it is not built yet."); + // Proceed through normal noMoreInput flow which will use the cache. + setRunning(); + noMoreInput(); + return true; +} + void HashBuild::setupTable() { VELOX_CHECK_NULL(table_); @@ -132,6 +232,7 @@ void HashBuild::setupTable() { for (int i = numKeys; i < tableType_->size(); ++i) { dependentTypes.emplace_back(tableType_->childAt(i)); } + auto& queryConfig = operatorCtx_->driverCtx()->queryConfig(); if (joinNode_->isRightJoin() || joinNode_->isFullJoin() || joinNode_->isRightSemiProjectJoin()) { // Do not ignore null keys. @@ -140,16 +241,9 @@ void HashBuild::setupTable() { dependentTypes, true, // allowDuplicates true, // hasProbedFlag - operatorCtx_->driverCtx() - ->queryConfig() - .minTableRowsForParallelJoinBuild(), - pool()); + queryConfig.minTableRowsForParallelJoinBuild(), + tableMemoryPool()); } else { - // (Left) semi and anti join with no extra filter only needs to know whether - // there is a match. Hence, no need to store entries with duplicate keys. - const bool dropDuplicates = !joinNode_->filter() && - (joinNode_->isLeftSemiFilterJoin() || - joinNode_->isLeftSemiProjectJoin() || isAntiJoin(joinType_)); // Right semi join needs to tag build rows that were probed. const bool needProbedFlag = joinNode_->isRightSemiFilterJoin(); if (isLeftNullAwareJoinWithFilter(joinNode_)) { @@ -158,26 +252,32 @@ void HashBuild::setupTable() { table_ = HashTable::createForJoin( std::move(keyHashers), dependentTypes, - !dropDuplicates, // allowDuplicates + !dropDuplicates_, // allowDuplicates needProbedFlag, // hasProbedFlag - operatorCtx_->driverCtx() - ->queryConfig() - .minTableRowsForParallelJoinBuild(), - pool()); + queryConfig.minTableRowsForParallelJoinBuild(), + tableMemoryPool()); } else { // Ignore null keys table_ = HashTable::createForJoin( std::move(keyHashers), dependentTypes, - !dropDuplicates, // allowDuplicates + !dropDuplicates_, // allowDuplicates needProbedFlag, // hasProbedFlag - operatorCtx_->driverCtx() - ->queryConfig() - .minTableRowsForParallelJoinBuild(), - pool()); + queryConfig.minTableRowsForParallelJoinBuild(), + tableMemoryPool(), + queryConfig.hashProbeBloomFilterPushdownMaxSize()); } } analyzeKeys_ = table_->hashMode() != BaseHashTable::HashMode::kHash; + if (abandonHashBuildDedupMinPct_ == 0) { + // Building a HashTable without duplicates is disabled if + // abandonBuildNoDupHashMinPct_ is 0. + abandonHashBuildDedup_ = true; + table_->setAllowDuplicates(true); + return; + } + // Only create HashLookup when dedup is enabled. + lookup_ = std::make_unique(table_->hashers(), pool()); } void HashBuild::setupSpiller(SpillPartition* spillPartition) { @@ -203,7 +303,7 @@ void HashBuild::setupSpiller(SpillPartition* spillPartition) { uint8_t startPartitionBit = config->startPartitionBit; if (spillPartition != nullptr) { spillInputReader_ = spillPartition->createUnorderedReader( - config->readBufferSize, pool(), &spillStats_); + config->readBufferSize, pool(), spillStats_.get()); VELOX_CHECK(!restoringPartitionId_.has_value()); restoringPartitionId_ = spillPartition->id(); const auto numPartitionBits = config->numPartitionBits; @@ -218,7 +318,7 @@ void HashBuild::setupSpiller(SpillPartition* spillPartition) { LOG(WARNING) << "Exceeded spill level limit: " << config->maxSpillLevel << ", and disable spilling for memory pool: " << pool()->name(); - ++spillStats_.wlock()->spillMaxLevelExceededCount; + ++spillStats_->wlock()->spillMaxLevelExceededCount; exceededMaxSpillLevelLimit_ = true; return; } @@ -233,7 +333,7 @@ void HashBuild::setupSpiller(SpillPartition* spillPartition) { HashBitRange( startPartitionBit, startPartitionBit + config->numPartitionBits), config, - &spillStats_); + spillStats_.get()); const int32_t numPartitions = spiller_->hashBits().numPartitions(); spillInputIndicesBuffers_.resize(numPartitions); @@ -308,6 +408,11 @@ void HashBuild::removeInputRowsForAntiJoinFilter() { void HashBuild::addInput(RowVectorPtr input) { checkRunning(); + + VELOX_CHECK( + !useHashTableCache() || + (cacheEntry_->builderTaskId == taskId() && !cacheEntry_->buildComplete)); + ensureInputFits(input); TestValue::adjust("facebook::velox::exec::HashBuild::addInput", this); @@ -376,6 +481,25 @@ void HashBuild::addInput(RowVectorPtr input) { return; } + if (dropDuplicates_ && !abandonHashBuildDedup_) { + const bool abandonEarly = abandonHashBuildDedupEarly(table_->numDistinct()); + if (!abandonEarly) { + numHashInputRows_ += activeRows_.countSelected(); + table_->prepareForGroupProbe( + *lookup_, + input, + activeRows_, + BaseHashTable::kNoSpillInputStartPartitionBit); + if (lookup_->rows.empty()) { + return; + } + table_->groupProbe( + *lookup_, BaseHashTable::kNoSpillInputStartPartitionBit); + return; + } + abandonHashBuildDedup(); + } + if (analyzeKeys_ && hashes_.size() < activeRows_.end()) { hashes_.resize(activeRows_.end()); } @@ -626,6 +750,7 @@ void HashBuild::noMoreInput() { if (noMoreInput_) { return; } + Operator::noMoreInput(); noMoreInputInternal(); @@ -670,6 +795,10 @@ bool HashBuild::finishHashBuild() { } }; + if (getHashTableFromCache()) { + return true; + } + if (joinHasNullKeys_ && isAntiJoin(joinType_) && nullAware_ && !joinNode_->filter()) { joinBridge_->setAntiJoinHasNullKeys(); @@ -754,6 +883,8 @@ bool HashBuild::finishHashBuild() { std::move(otherTables), isInputFromSpill() ? spillConfig()->startPartitionBit : BaseHashTable::kNoSpillInputStartPartitionBit, + vectorHasherMaxNumDistinct_, + dropDuplicates_, allowParallelJoinBuild ? operatorCtx_->task()->queryCtx()->executor() : nullptr); } @@ -768,26 +899,31 @@ bool HashBuild::finishHashBuild() { HashJoinTableSpillFunc tableSpillFunc; if (canReclaim()) { VELOX_CHECK_NOT_NULL(spiller_); - tableSpillFunc = [hashBitRange = spiller_->hashBits(), - restoringPartitionId = restoringPartitionId_, - joinNode = joinNode_, - spillConfig = spillConfig(), - spillStats = - &spillStats_](std::shared_ptr table) { - return spillHashJoinTable( - table, - restoringPartitionId, - hashBitRange, - joinNode, - spillConfig, - spillStats); - }; - } + tableSpillFunc = + [hashBitRange = spiller_->hashBits(), + restoringPartitionId = restoringPartitionId_, + joinNode = joinNode_, + spillConfig = spillConfig(), + spillStats = spillStats_.get()](std::shared_ptr table) { + return spillHashJoinTable( + table, + restoringPartitionId, + hashBitRange, + joinNode, + spillConfig, + spillStats); + }; + } + + // For hash table caching: the last driver caches the merged table. + std::shared_ptr table = std::move(table_); + maybeSetHashTableInCache(table); joinBridge_->setHashTable( - std::move(table_), + table, std::move(spillPartitions), joinHasNullKeys_, std::move(tableSpillFunc)); + if (canSpill()) { stateCleared_ = true; } @@ -839,7 +975,6 @@ void HashBuild::ensureTableFits(uint64_t numRows) { void HashBuild::postHashBuildProcess() { checkRunning(); - if (!canSpill()) { setState(State::kFinish); return; @@ -878,6 +1013,7 @@ void HashBuild::setupSpillInput(HashJoinBridge::SpillInput spillInput) { setupTable(); setupSpiller(spillInput.spillPartition.get()); stateCleared_ = false; + numHashInputRows_ = 0; // Start to process spill input. processSpillInput(); @@ -891,7 +1027,7 @@ void HashBuild::processSpillInput() { if (!isRunning()) { return; } - if (operatorCtx_->driver()->shouldYield()) { + if (shouldYield()) { state_ = State::kYield; future_ = ContinueFuture{folly::Unit{}}; return; @@ -928,6 +1064,36 @@ void HashBuild::addRuntimeStats() { RuntimeCounter(timing.cpuNanos, RuntimeCounter::Unit::kNanos)); } + for (const auto& timing : + table_->parallelJoinBuildStats().bloomFilterPartitionTimings) { + lockedStats->getOutputTiming.add(timing); + if (timing.wallNanos > 0) { + lockedStats->addRuntimeStat( + BaseHashTable::kParallelJoinBloomFilterPartitionWallNanos, + RuntimeCounter(timing.wallNanos, RuntimeCounter::Unit::kNanos)); + } + if (timing.cpuNanos > 0) { + lockedStats->addRuntimeStat( + BaseHashTable::kParallelJoinBloomFilterPartitionCpuNanos, + RuntimeCounter(timing.cpuNanos, RuntimeCounter::Unit::kNanos)); + } + } + + for (const auto& timing : + table_->parallelJoinBuildStats().bloomFilterBuildTimings) { + lockedStats->getOutputTiming.add(timing); + if (timing.wallNanos > 0) { + lockedStats->addRuntimeStat( + BaseHashTable::kParallelJoinBloomFilterBuildWallNanos, + RuntimeCounter(timing.wallNanos, RuntimeCounter::Unit::kNanos)); + } + if (timing.cpuNanos > 0) { + lockedStats->addRuntimeStat( + BaseHashTable::kParallelJoinBloomFilterBuildCpuNanos, + RuntimeCounter(timing.cpuNanos, RuntimeCounter::Unit::kNanos)); + } + } + for (auto i = 0; i < hashers.size(); i++) { hashers[i]->cardinality(0, asRange, asDistinct); if (asRange != VectorHasher::kRangeTooLarge) { @@ -958,6 +1124,12 @@ void HashBuild::addRuntimeStats() { RuntimeCounter( spillConfig()->spillLevel(spiller_->hashBits().begin()))); } + + lockedStats->addRuntimeStat( + BaseHashTable::kVectorHasherMergeCpuNanos, + RuntimeCounter( + table_->vectorHasherMergeTiming().cpuNanos, + RuntimeCounter::Unit::kNanos)); } BlockingReason HashBuild::isBlocked(ContinueFuture* future) { @@ -975,6 +1147,11 @@ BlockingReason HashBuild::isBlocked(ContinueFuture* future) { case State::kFinish: break; case State::kWaitForBuild: + if (receivedCachedHashTable()) { + break; + } + // We were waiting for peer drivers to finish - fall through to + // kWaitForProbe which has the same logic. [[fallthrough]]; case State::kWaitForProbe: if (!future_.valid()) { @@ -1058,6 +1235,11 @@ bool HashBuild::canSpill() const { if (!Operator::canSpill()) { return false; } + // For Cached hash table, we don't support spill either by the + // task thats building or by the task that is re-using it + if (useHashTableCache()) { + return false; + } if (operatorCtx_->task()->hasMixedExecutionGroupJoin(joinNode_.get())) { return operatorCtx_->driverCtx() ->queryConfig() @@ -1149,6 +1331,19 @@ void HashBuild::reclaim( } } +memory::MemoryPool* HashBuild::tableMemoryPool() const { + if (useHashTableCache()) { + // Cached hash tables use a leaf pool under the query pool (from cache + // entry). This allows the table to outlive the task while still supporting + // allocations. + VELOX_CHECK_NOT_NULL(cacheEntry_); + VELOX_CHECK_NOT_NULL(cacheEntry_->tablePool); + return cacheEntry_->tablePool.get(); + } + // Regular joins use operator pool + return pool(); +} + bool HashBuild::nonReclaimableState() const { // Apart from being in the nonReclaimable section, it's also not reclaimable // if: @@ -1239,4 +1434,21 @@ void HashBuildSpiller::extractSpill( rows.data(), rows.size(), false, false, result->childAt(types.size())); } } + +bool HashBuild::abandonHashBuildDedupEarly(int64_t numDistinct) const { + VELOX_CHECK(dropDuplicates_); + return numHashInputRows_ > abandonHashBuildDedupMinRows_ && + 100 * numDistinct / numHashInputRows_ >= abandonHashBuildDedupMinPct_; +} + +void HashBuild::abandonHashBuildDedup() { + // The hash table is no longer directly constructed in addInput. The data + // that was previously inserted into the hash table is already in the + // RowContainer. + addRuntimeStat("abandonBuildNoDupHash", RuntimeCounter(1)); + abandonHashBuildDedup_ = true; + table_->setAllowDuplicates(true); + lookup_.reset(); +} + } // namespace facebook::velox::exec diff --git a/velox/exec/HashBuild.h b/velox/exec/HashBuild.h index b2df3dc4a283..61315d15f9d2 100644 --- a/velox/exec/HashBuild.h +++ b/velox/exec/HashBuild.h @@ -17,12 +17,11 @@ #include "velox/exec/HashJoinBridge.h" #include "velox/exec/HashTable.h" +#include "velox/exec/HashTableCache.h" #include "velox/exec/Operator.h" #include "velox/exec/Spill.h" #include "velox/exec/Spiller.h" #include "velox/exec/UnorderedStreamReader.h" -#include "velox/exec/VectorHasher.h" -#include "velox/expression/Expr.h" namespace facebook::velox::exec { class HashBuildSpiller; @@ -109,6 +108,25 @@ class HashBuild final : public Operator { // Invoked to set up hash table to build. void setupTable(); + // Sets up hash table caching if enabled. Returns true if the cached table + // is already available or if this operator should wait for another task + // to build it, in which case further initialization should be skipped. + // Returns false if this operator should proceed with building the table. + bool setupCachedHashTable(); + + // Checks if a cached hash table is available and uses it if so. + // Returns true if the cached table was used (build can be skipped). + // Returns false if we need to build the table (cache miss). + bool getHashTableFromCache(); + + // Called when waiting for a cached hash table from another task. + // Returns true if the cached table was received and noMoreInput was called. + bool receivedCachedHashTable(); + + // Stores the built hash table in the cache for reuse by other tasks. + // No-op if hash table caching is not enabled. + void maybeSetHashTableInCache(const std::shared_ptr& table); + // Invoked when operator has finished processing the build input and wait for // all the other drivers to finish the processing. The last driver that // reaches to the hash build barrier, is responsible to build the hash table @@ -206,6 +224,35 @@ class HashBuild final : public Operator { // not. bool nonReclaimableState() const; + // True if we have enough rows and not enough duplicate join keys, i.e. more + // than 'abandonHashBuildDedupMinRows_' rows and more than + // 'abandonHashBuildDedupMinPct_' % of rows are unique. + bool abandonHashBuildDedupEarly(int64_t numDistinct) const; + + // Invoked to abandon build deduped hash table. + void abandonHashBuildDedup(); + + // Returns true if this operator is using a cached hash table. + // When enabled, the hash table is built once and cached for reuse + // by other tasks within the same query and stage. + bool useHashTableCache() const { + return !cacheKey_.empty(); + } + + // Returns the hash table cache key for this operator. + // Only valid if useHashTableCache() returns true. + const std::string& cacheKey() const { + VELOX_CHECK( + useHashTableCache(), + "cacheKey() called when table caching is not enabled"); + return cacheKey_; + } + + // Determines the memory pool to use for the hash table. + // For cached hash tables, uses query-level pool so the table can + // outlive the task. For regular joins, uses operator pool. + memory::MemoryPool* tableMemoryPool() const; + const std::shared_ptr joinNode_; const core::JoinType joinType_; @@ -219,12 +266,37 @@ class HashBuild final : public Operator { // not. const bool needProbedFlagSpill_; + // Indicates whether drop duplicate rows. Rows containing duplicate keys + // can be removed for left semi and anti join. + const bool dropDuplicates_; + + // Maximum number of distinct values to keep when merging vector hashers + const size_t vectorHasherMaxNumDistinct_; + + // Minimum number of rows to see before deciding to give up build no + // duplicates hash table. + const int32_t abandonHashBuildDedupMinRows_; + + // Min unique rows pct for give up build deduped hash table. If more + // than this many rows are unique, build hash table in addInput phase is not + // worthwhile. + const int32_t abandonHashBuildDedupMinPct_; + std::shared_ptr joinBridge_; tsan_atomic exceededMaxSpillLevelLimit_{false}; State state_{State::kRunning}; + // For hash table caching: the cache key passed in at construction. + // If set, this operator coordinates via HashTableCache. + // Key format: "queryId:planNodeId" + std::string cacheKey_; + + // For hash table caching: cached entry containing the shared table and pool. + // Retrieved from HashTableCache. + std::shared_ptr cacheEntry_; + // The row type used for hash table build and disk spilling. RowTypePtr tableType_; @@ -244,6 +316,9 @@ class HashBuild final : public Operator { // Container for the rows being accumulated. std::unique_ptr table_; + // Used for building hash table while adding input rows. + std::unique_ptr lookup_; + // Key channels in 'input_' std::vector keyChannels_; @@ -271,6 +346,10 @@ class HashBuild final : public Operator { // at least one entry with null join keys. bool joinHasNullKeys_{false}; + // Whether to abandon building a HashTable without duplicates in HashBuild + // addInput phase for left semi/anti join. + bool abandonHashBuildDedup_{false}; + // The type used to spill hash table which might attach a boolean column to // record the probed flag if 'needProbedFlagSpill_' is true. RowTypePtr spillType_; @@ -312,6 +391,10 @@ class HashBuild final : public Operator { // Maps key channel in 'input_' to channel in key. folly::F14FastMap keyChannelMap_; + + // Count the number of hash table input rows for building deduped + // hash table. It will not be updated after abandonBuildNoDupHash_ is true. + int64_t numHashInputRows_ = 0; }; inline std::ostream& operator<<(std::ostream& os, HashBuild::State state) { diff --git a/velox/exec/HashJoinBridge.cpp b/velox/exec/HashJoinBridge.cpp index 7c359c924c3e..8fe29d0e61b2 100644 --- a/velox/exec/HashJoinBridge.cpp +++ b/velox/exec/HashJoinBridge.cpp @@ -15,7 +15,6 @@ */ #include "velox/exec/HashJoinBridge.h" -#include "velox/common/memory/MemoryArbitrator.h" #include "velox/exec/HashBuild.h" namespace facebook::velox::exec { @@ -43,13 +42,18 @@ RowTypePtr hashJoinTableType( types.emplace_back(inputType->childAt(channel)); } + if (joinNode->canDropDuplicates()) { + // For left semi and anti join with no extra filter, hash table does not + // store dependent columns. + return ROW(std::move(names), std::move(types)); + } + for (auto i = 0; i < inputType->size(); ++i) { if (keyChannelSet.find(i) == keyChannelSet.end()) { names.emplace_back(inputType->nameOf(i)); types.emplace_back(inputType->childAt(i)); } } - return ROW(std::move(names), std::move(types)); } @@ -212,7 +216,7 @@ SpillPartitionSet spillHashJoinTable( } void HashJoinBridge::setHashTable( - std::unique_ptr table, + std::shared_ptr table, SpillPartitionSet spillPartitionSet, bool hasNullKeys, HashJoinTableSpillFunc&& tableSpillFunc) { @@ -453,11 +457,11 @@ uint64_t HashJoinMemoryReclaimer::reclaim( } bool isHashBuildMemoryPool(const memory::MemoryPool& pool) { - return folly::StringPiece(pool.name()).endsWith("HashBuild"); + return pool.name().ends_with("HashBuild"); } bool isHashProbeMemoryPool(const memory::MemoryPool& pool) { - return folly::StringPiece(pool.name()).endsWith("HashProbe"); + return pool.name().ends_with("HashProbe"); } bool needRightSideJoin(core::JoinType joinType) { diff --git a/velox/exec/HashJoinBridge.h b/velox/exec/HashJoinBridge.h index 879eab6801f8..a79e6b5c754f 100644 --- a/velox/exec/HashJoinBridge.h +++ b/velox/exec/HashJoinBridge.h @@ -53,8 +53,9 @@ class HashJoinBridge : public JoinBridge { /// Invoked by the build operator to set the built hash table. /// 'spillPartitionSet' contains the spilled partitions while building /// 'table' which only applies if the disk spilling is enabled. + /// Accepts both unique_ptr (regular joins) and shared_ptr (broadcast joins). void setHashTable( - std::unique_ptr table, + std::shared_ptr table, SpillPartitionSet spillPartitionSet, bool hasNullKeys, HashJoinTableSpillFunc&& tableSpillFunc); @@ -143,6 +144,17 @@ class HashJoinBridge : public JoinBridge { bool testingHasMoreSpilledPartitions(); + /// Return the next unclaimed row container id in the current hash table for + /// HashProbe drivers to output build-side rows. + int getAndIncrementUnclaimedRowContainerId() { + return unclaimedRowContainerId_.fetch_add(1); + } + + /// Reset the next unclaimed row container id to 0. + void resetUnclaimedRowContainerId() { + unclaimedRowContainerId_.store(0); + } + private: void appendSpilledHashTablePartitionsLocked( SpillPartitionSet&& spillPartitionSet); @@ -183,6 +195,13 @@ class HashJoinBridge : public JoinBridge { // processing. bool probeStarted_; + // Keep track of the next row container id in a hash table that has not been + // processed by any hash probe driver. This is used when hash probe drivers + // output build-side rows. When drivers are allowed to output build-side rows + // in parallel, drivers call getAndIncrementClaimedRowContainerId() to ensure + // the row containers they process do not overlap with each other. + std::atomic unclaimedRowContainerId_{0}; + friend test::HashJoinBridgeTestHelper; }; diff --git a/velox/exec/HashProbe.cpp b/velox/exec/HashProbe.cpp index d6dd38926564..155b2946b0db 100644 --- a/velox/exec/HashProbe.cpp +++ b/velox/exec/HashProbe.cpp @@ -21,6 +21,7 @@ #include "velox/exec/OperatorUtils.h" #include "velox/exec/Task.h" #include "velox/expression/FieldReference.h" +#include "velox/vector/LazyVector.h" using facebook::velox::common::testutil::TestValue; @@ -130,7 +131,9 @@ HashProbe::HashProbe( operatorCtx_->driverCtx()->splitGroupId, planNodeId())), filterResult_(1), - outputTableRowsCapacity_(outputBatchSize_) { + outputTableRowsCapacity_(outputBatchSize_), + parallelJoinBuildRowsEnabled_( + driverCtx->queryConfig().parallelOutputJoinBuildRowsEnabled()) { VELOX_CHECK_NOT_NULL(joinBridge_); } @@ -249,7 +252,7 @@ void HashProbe::maybeSetupInputSpiller( restoringPartitionId_, HashBitRange(bitOffset, bitOffset + spillConfig()->numPartitionBits), spillConfig(), - &spillStats_); + spillStats_.get()); // Set the spill partitions to the corresponding ones at the build side. We // only spill the seen partitions from the build side. For the ones not seen @@ -288,7 +291,7 @@ void HashProbe::maybeSetupSpillInputReader( VELOX_CHECK_EQ(partition->id(), restoredPartitionId.value()); restoringPartitionId_ = restoredPartitionId; spillInputReader_ = partition->createUnorderedReader( - spillConfig_->readBufferSize, pool(), &spillStats_); + spillConfig_->readBufferSize, pool(), spillStats_.get()); inputSpillPartitionSet_.erase(iter); } @@ -373,6 +376,46 @@ void HashProbe::initializeResultIter() { rowSizeEstimation); } +void HashProbe::pushdownDynamicFilters() { + auto* driver = operatorCtx_->driverCtx()->driver; + auto numFilters = driver->pushdownFilters( + this, + keyChannels_, + [&](column_index_t sourceChannel, + std::shared_ptr& filter) { + if (dynamicFiltersProducedOnChannels_.contains(sourceChannel)) { + return true; + } + auto& hasher = *table_->hashers()[sourceChannel]; + filter = hasher.getFilter(false); + if (!filter) { + filter = hasher.getBloomFilter(); + if (!filter) { + return false; + } + auto* bloomFilter = + checkedPointerCast( + filter.get()); + addRuntimeStat( + "bloomFilterSize", RuntimeCounter(bloomFilter->blocksByteSize())); + } + dynamicFiltersProducedOnChannels_.insert(sourceChannel); + for (auto* peer : findPeerOperators()) { + peer->dynamicFiltersProducedOnChannels_.insert(sourceChannel); + } + return true; + }); + // The join can be completely replaced with a pushed down filter when the + // following conditions are met: + // * hash table has a single key with unique values, + // * build side has no dependent columns. + if (keyChannels_.size() == 1 && !table_->hasDuplicateKeys() && + tableOutputProjections_.empty() && !filter_ && numFilters > 0 && + !table_->hashers()[0]->getBloomFilter() && !isRightJoin(joinType_)) { + canReplaceWithDynamicFilter_ = true; + } +} + void HashProbe::asyncWaitForHashTable() { checkRunning(); VELOX_CHECK_NULL(table_); @@ -429,6 +472,9 @@ void HashProbe::asyncWaitForHashTable() { (isRightSemiProjectJoin(joinType_) && !nullAware_) || isRightJoin(joinType_)) && table_->hashMode() != BaseHashTable::HashMode::kHash && !isSpillInput() && + operatorCtx_->driverCtx() + ->queryConfig() + .hashProbeDynamicFilterPushdownEnabled() && !hasMoreSpillData()) { // Find out whether there are any upstream operators that can accept dynamic // filters on all or a subset of the join keys. Create dynamic filters to @@ -438,18 +484,7 @@ void HashProbe::asyncWaitForHashTable() { // probe input is read from spilled data and there is no upstream operators // involved; (2) if there is spill data to restore, then we can't filter // probe inputs solely based on the current table's join keys. - const auto& buildHashers = table_->hashers(); - const auto channels = operatorCtx_->driverCtx()->driver->canPushdownFilters( - this, keyChannels_); - - for (auto i = 0; i < keyChannels_.size(); ++i) { - if (channels.find(keyChannels_[i]) != channels.end()) { - if (auto filter = buildHashers[i]->getFilter(/*nullAllowed=*/false)) { - dynamicFilters_.emplace(keyChannels_[i], std::move(filter)); - } - } - } - hasGeneratedDynamicFilters_ = !dynamicFilters_.empty(); + pushdownDynamicFilters(); } } @@ -473,7 +508,6 @@ void HashProbe::prepareForSpillRestore() { restoringPartitionId_.reset(); spillInputPartitionIds_.clear(); spillOutputReader_.reset(); - lastProbeIterator_.reset(); VELOX_CHECK(promises_.empty() || lastProber_); if (!lastProber_) { @@ -614,7 +648,6 @@ BlockingReason HashProbe::isBlocked(ContinueFuture* future) { } break; case ProbeOperatorState::kWaitForPeers: - VELOX_CHECK(canSpill()); if (!future_.valid()) { setRunning(); } @@ -633,20 +666,6 @@ BlockingReason HashProbe::isBlocked(ContinueFuture* future) { return fromStateToBlockingReason(state_); } -void HashProbe::clearDynamicFilters() { - // The join can be completely replaced with a pushed down filter when the - // following conditions are met: - // * hash table has a single key with unique values, - // * build side has no dependent columns. - if (keyChannels_.size() == 1 && !table_->hasDuplicateKeys() && - tableOutputProjections_.empty() && !filter_ && !dynamicFilters_.empty() && - !isRightJoin(joinType_)) { - canReplaceWithDynamicFilter_ = true; - } - - Operator::clearDynamicFilters(); -} - void HashProbe::decodeAndDetectNonNullKeys() { nonNullInputRows_.resize(input_->size()); nonNullInputRows_.setAll(); @@ -857,28 +876,49 @@ void HashProbe::fillOutput(vector_size_t size) { } RowVectorPtr HashProbe::getBuildSideOutput() { - auto* outputTableRows = + if (buildSideOutputRowContainerId_ == -1) { + buildSideOutputRowContainerId_ = + joinBridge_->getAndIncrementUnclaimedRowContainerId(); + lastProbeIterator_.reset(); + } + if (buildSideOutputRowContainerId_ >= table_->numRowContainers()) { + return nullptr; + } + + char** outputTableRows = initBuffer(outputTableRows_, outputTableRowsCapacity_, pool()); - int32_t numOut; - if (isRightSemiFilterJoin(joinType_)) { - numOut = table_->listProbedRows( - &lastProbeIterator_, - outputTableRowsCapacity_, - RowContainer::kUnlimited, - outputTableRows); - } else if (isRightSemiProjectJoin(joinType_)) { - numOut = table_->listAllRows( - &lastProbeIterator_, - outputTableRowsCapacity_, - RowContainer::kUnlimited, - outputTableRows); - } else { - // Must be a right join or full join. - numOut = table_->listNotProbedRows( - &lastProbeIterator_, - outputTableRowsCapacity_, - RowContainer::kUnlimited, - outputTableRows); + int32_t numOut{0}; + while (numOut == 0 && + buildSideOutputRowContainerId_ < table_->numRowContainers()) { + if (isRightSemiFilterJoin(joinType_)) { + numOut = table_->listProbedRows( + lastProbeIterator_, + buildSideOutputRowContainerId_, + outputTableRowsCapacity_, + RowContainer::kUnlimited, + outputTableRows); + } else if (isRightSemiProjectJoin(joinType_)) { + numOut = table_->listAllRows( + lastProbeIterator_, + buildSideOutputRowContainerId_, + outputTableRowsCapacity_, + RowContainer::kUnlimited, + outputTableRows); + + } else { + // Must be a right join or full join. + numOut = table_->listNotProbedRows( + lastProbeIterator_, + buildSideOutputRowContainerId_, + outputTableRowsCapacity_, + RowContainer::kUnlimited, + outputTableRows); + } + if (numOut == 0) { + buildSideOutputRowContainerId_ = + joinBridge_->getAndIncrementUnclaimedRowContainerId(); + lastProbeIterator_.reset(); + } } if (numOut == 0) { return nullptr; @@ -902,9 +942,9 @@ RowVectorPtr HashProbe::getBuildSideOutput() { if (isRightSemiProjectJoin(joinType_)) { // Populate 'match' column. - if (noInput_) { + if (noInput_ && nullAware_) { // Probe side is empty. All rows should return 'match = false', even ones - // with a null join key. + // with a null join key. (This applies to null-aware joins only.) matchColumn() = createConstantFalse(numOut, pool()); } else { table_->rows()->extractProbedFlags( @@ -942,6 +982,12 @@ bool HashProbe::canSpill() const { if (!Operator::canSpill()) { return false; } + // Hash table caching is incompatible with spilling. When the table is + // cached and shared across tasks, clearing it after probe would corrupt + // the cache for subsequent tasks. + if (joinNode_->useHashTableCache()) { + return false; + } if (operatorCtx_->task()->hasMixedExecutionGroupJoin(joinNode_.get())) { return operatorCtx_->driverCtx() ->queryConfig() @@ -972,16 +1018,11 @@ void HashProbe::checkStateTransition(ProbeOperatorState state) { VELOX_CHECK_NE(state_, state); switch (state) { case ProbeOperatorState::kRunning: - if (!canSpill()) { - VELOX_CHECK_EQ(state_, ProbeOperatorState::kWaitForBuild); - } else { - VELOX_CHECK( - state_ == ProbeOperatorState::kWaitForBuild || - state_ == ProbeOperatorState::kWaitForPeers); - } + VELOX_CHECK( + state_ == ProbeOperatorState::kWaitForBuild || + state_ == ProbeOperatorState::kWaitForPeers); break; case ProbeOperatorState::kWaitForPeers: - VELOX_CHECK(canSpill()); [[fallthrough]]; case ProbeOperatorState::kWaitForBuild: [[fallthrough]]; @@ -1027,7 +1068,7 @@ RowVectorPtr HashProbe::getOutputInternal(bool toSpillOutput) { return nullptr; } - if (needLastProbe() && lastProber_) { + if (needLastProbe() && (outputBuildRowsInParallel_ || lastProber_)) { auto output = getBuildSideOutput(); if (output != nullptr) { return output; @@ -1170,6 +1211,13 @@ RowVectorPtr HashProbe::getOutputInternal(bool toSpillOutput) { numOut = evalFilter(numOut); if (numOut == 0) { + // The hash probe might get stuck in the output loop if the filter is + // highly selective. This does not apply if the call is made during + // spilling, because we cannot break out and resume when the operator is + // undergoing spilling. + if (!toSpillOutput && shouldYield()) { + return nullptr; + } continue; } @@ -1694,21 +1742,34 @@ void HashProbe::noMoreInputInternal() { spillInputPartitionIds_.size(), inputSpiller_->state().spilledPartitionIdSet().size()); inputSpiller_->finishSpill(inputSpillPartitionSet_); - VELOX_CHECK_EQ(spillStats_.rlock()->spillSortTimeNanos, 0); + VELOX_CHECK_EQ(spillStats_->rlock()->spillSortTimeNanos, 0); } const bool hasSpillEnabled = canSpill(); std::vector promises; std::vector> peers; - // The last operator to finish processing inputs is responsible for - // producing build-side rows based on the join. + + // Reset flags about outputing build-side rows in parallel. + buildSideOutputRowContainerId_ = -1; + outputBuildRowsInParallel_ = + parallelJoinBuildRowsEnabled_ && needLastProbe() && !hasSpillEnabled; + + // NOTE: if 'hasSpillEnabled' is false and 'outputBuildRowsInParallel_' is + // false too, then a hash probe operator doesn't need to wait for all the + // other peers to finish probe processing. If 'hasSpillEnabled' is true, it + // needs to wait and might expect spill gets triggered by the other probe + // operators, or there is previously spilled table partition(s) that needs to + // restore. If 'outputBuildRowsInParallel_', it needs to wait to all drivers + // start outputing build-side rows in parallel only after all drivers finish + // probe processing. + const bool shouldBlock = hasSpillEnabled || outputBuildRowsInParallel_; if (!operatorCtx_->task()->allPeersFinished( planNodeId(), operatorCtx_->driver(), - hasSpillEnabled ? &future_ : nullptr, - hasSpillEnabled ? promises_ : promises, + shouldBlock ? &future_ : nullptr, + shouldBlock ? promises_ : promises, peers)) { - if (hasSpillEnabled) { + if (shouldBlock) { VELOX_CHECK(future_.valid()); setState(ProbeOperatorState::kWaitForPeers); VELOX_DCHECK(promises_.empty()); @@ -1719,13 +1780,14 @@ void HashProbe::noMoreInputInternal() { } VELOX_CHECK(promises.empty()); - // NOTE: if 'hasSpillEnabled' is false, then a hash probe operator doesn't - // need to wait for all the other peers to finish probe processing. - // Otherwise, it needs to wait and might expect spill gets triggered by the - // other probe operators, or there is previously spilled table partition(s) - // that needs to restore. - VELOX_CHECK(hasSpillEnabled || peers.empty()); lastProber_ = true; + joinBridge_->resetUnclaimedRowContainerId(); + // If 'outputBuildRowsInParallel_' is true, wake up all peers to start + // outputing build-side rows in parallel. Otherwise, only let the last prober + // proceed. + if (outputBuildRowsInParallel_) { + wakeupPeerOperators(); + } } bool HashProbe::isFinished() { @@ -1769,8 +1831,9 @@ void HashProbe::ensureOutputFits() { } // We only need to reserve memory for output if need. - if (input_ == nullptr && - (hasMoreInput() || !(needLastProbe() && lastProber_))) { + bool outputBuildSideRows = + outputBuildRowsInParallel_ || (needLastProbe() && lastProber_); + if (input_ == nullptr && (hasMoreInput() || !outputBuildSideRows)) { return; } @@ -1888,7 +1951,7 @@ void HashProbe::reclaim( tableSpillHashBits_, joinNode_, spillConfig(), - &spillStats_); + spillStats_.get()); VELOX_CHECK(!spillPartitionSet.empty()); } const auto spillPartitionIdSet = toSpillPartitionIdSet(spillPartitionSet); @@ -1970,7 +2033,11 @@ void HashProbe::spillOutput() { } // We spill all the outputs produced from 'input_' into a single partition. auto outputSpiller = std::make_unique( - outputType_, std::nullopt, HashBitRange{}, spillConfig(), &spillStats_); + outputType_, + std::nullopt, + HashBitRange{}, + spillConfig(), + spillStats_.get()); outputSpiller->setPartitionsSpilled({SpillPartitionId(0)}); RowVectorPtr output{nullptr}; @@ -2012,7 +2079,7 @@ void HashProbe::maybeSetupSpillOutputReader() { spillOutputReader_ = spillOutputPartitionSet_.begin()->second->createUnorderedReader( - spillConfig_->readBufferSize, pool(), &spillStats_); + spillConfig_->readBufferSize, pool(), spillStats_.get()); spillOutputPartitionSet_.clear(); } @@ -2038,7 +2105,7 @@ void HashProbe::checkMaxSpillLevel( << "Exceeded spill level limit: " << config->maxSpillLevel << ", and disable spilling for memory pool: " << pool()->name(); exceededMaxSpillLevelLimit_ = true; - ++spillStats_.wlock()->spillMaxLevelExceededCount; + ++spillStats_->wlock()->spillMaxLevelExceededCount; return; } } @@ -2060,7 +2127,7 @@ void HashProbe::close() { spillOutputReader_.reset(); clearBuffers(); - // Fullfill any pending promises + // Fulfill any pending promises if (lastProber_) { wakeupPeerOperators(); } diff --git a/velox/exec/HashProbe.h b/velox/exec/HashProbe.h index 4eeaf0f3a25a..a45ab7d8b555 100644 --- a/velox/exec/HashProbe.h +++ b/velox/exec/HashProbe.h @@ -16,7 +16,6 @@ #pragma once #include "velox/exec/HashBuild.h" -#include "velox/exec/HashPartitionFunction.h" #include "velox/exec/HashTable.h" #include "velox/exec/Operator.h" #include "velox/exec/ProbeOperatorState.h" @@ -64,8 +63,6 @@ class HashProbe : public Operator { void close() override; - void clearDynamicFilters() override; - bool canReclaim() const override; const std::vector& tableOutputProjections() const { @@ -130,6 +127,8 @@ class HashProbe : public Operator { // Applies only for mixed grouped execution mode. bool allProbeGroupFinished() const; + void pushdownDynamicFilters(); + // Invoked to wait for the hash table to be built by the hash build operators // asynchronously. The function also sets up the internal state for // potentially spilling input or reading spilled input or recursively spill @@ -413,11 +412,7 @@ class HashProbe : public Operator { // Channel of probe keys in 'input_'. std::vector keyChannels_; - // True if we have generated dynamic filters from the hash build join keys. - // - // NOTE: 'dynamicFilters_' might have been cleared once they have been pushed - // down to the upstream operators. - tsan_atomic hasGeneratedDynamicFilters_{false}; + folly::F14FastSet dynamicFiltersProducedOnChannels_; // True if the join can become a no-op starting with the next batch of input. bool canReplaceWithDynamicFilter_{false}; @@ -635,7 +630,7 @@ class HashProbe : public Operator { std::optional currentRowPassed; }; - BaseHashTable::RowsIterator lastProbeIterator_; + RowContainerIterator lastProbeIterator_; // For left and anti join with filter, tracks the probe side rows which had // matches on the build side but didn't pass the filter. @@ -733,6 +728,23 @@ class HashProbe : public Operator { // Input vector used for listing rows with null keys. VectorPtr nullKeyProbeInput_; + + // Flag to indicate whether the query config allows hash probe drivers to + // output build-side rows in parallel. + const bool parallelJoinBuildRowsEnabled_; + + // Flag to indicate whether this hash probe operator is outputing build-side + // rows in parallel with the peer operators for the current hash table. + // Outputing build-side rows in parallel is not enabled in either of the + // following cases: + // 1. parallelJoinBuildRowsEnabled_ is false. + // 2. This join type does not need build-side rows. + // 3. There are more spilled data to restore. + bool outputBuildRowsInParallel_{false}; + + // The index of the row container in the current hash table that this hash + // probe oprator is processing to output build-side rows. + int buildSideOutputRowContainerId_{-1}; }; inline std::ostream& operator<<(std::ostream& os, ProbeOperatorState state) { diff --git a/velox/exec/HashTable.cpp b/velox/exec/HashTable.cpp index 3717d6631c08..6dd4f967346f 100644 --- a/velox/exec/HashTable.cpp +++ b/velox/exec/HashTable.cpp @@ -28,6 +28,7 @@ using facebook::velox::common::testutil::TestValue; namespace facebook::velox::exec { + // static std::string BaseHashTable::modeString(HashMode mode) { switch (mode) { @@ -52,16 +53,20 @@ HashTable::HashTable( bool isJoinBuild, bool hasProbedFlag, uint32_t minTableSizeForParallelJoinBuild, - memory::MemoryPool* pool) + memory::MemoryPool* pool, + uint64_t bloomFilterMaxSize) : BaseHashTable(std::move(hashers)), pool_(pool), minTableSizeForParallelJoinBuild_(minTableSizeForParallelJoinBuild), + bloomFilterMaxSize_(bloomFilterMaxSize), isJoinBuild_(isJoinBuild), + allowDuplicates_(allowDuplicates), buildPartitionBounds_(raw_vector(pool)) { + VELOX_CHECK(bloomFilterMaxSize_ == 0 || (isJoinBuild && ignoreNullKeys)); std::vector keys; for (auto& hasher : hashers_) { keys.push_back(hasher->type()); - if (!VectorHasher::typeKindSupportsValueIds(hasher->typeKind())) { + if (!hasher->typeSupportsValueIds()) { hashMode_ = HashMode::kHash; } } @@ -75,6 +80,7 @@ HashTable::HashTable( isJoinBuild, hasProbedFlag, hashMode_ != HashMode::kHash, + /*useListRowIndex=*/false, pool); nextOffset_ = rows_->nextOffset(); } @@ -332,6 +338,8 @@ char* HashTable::insertEntry( HashLookup& lookup, uint64_t index, vector_size_t row) { + TestValue::adjust( + "facebook::velox::exec::HashTable::insertEntry", rows_->pool()); char* group = rows_->newRow(); lookup.hits[row] = group; // NOLINT storeKeys(lookup, row); @@ -358,8 +366,13 @@ bool HashTable::compareKeys( int32_t i = 0; do { auto& hasher = lookup.hashers[i]; - if (!rows_->equals( - group, rows_->columnAt(i), hasher->decodedVector(), row)) { + if (rows_->compare( + group, + rows_->columnAt(i), + hasher->decodedVector(), + row, + CompareFlags::equality( + CompareFlags::NullHandlingMode::kNullAsValue)) != 0) { return false; } } while (++i < numKeys); @@ -417,9 +430,12 @@ FOLLY_ALWAYS_INLINE void HashTable::fullProbe( } namespace { + // Group prefetch size for join build & probe. constexpr int32_t kPrefetchSize = 64; +constexpr int32_t kHashBatchSize = 1024; + // Normalized keys have non0-random bits. Bits need to be propagated // up to make a tag byte and down so that non-lowest bits of // normalized key affect the hash table index. @@ -834,12 +850,87 @@ bool HashTable::hashRows( } namespace { + +template +void partitionBloomFilterRowsImpl( + int32_t offset, + const common::BigintValuesUsingBloomFilter& filter, + const RowContainer& rowContainer, + uint8_t partitionMask, + RowPartitions& rowPartitions) { + char* rows[kHashBatchSize]; + uint8_t partitions[kHashBatchSize]; + RowContainerIterator iter; + while (auto numRows = rowContainer.listRows( + &iter, kHashBatchSize, RowContainer::kUnlimited, rows)) { + for (int i = 0; i < numRows; ++i) { + auto value = folly::loadUnaligned(rows[i] + offset); + partitions[i] = filter.blockIndex(value) & partitionMask; + } + rowPartitions.appendPartitions( + folly::Range(partitions, numRows)); + } +} + +void partitionBloomFilterRows( + const VectorHasher& hasher, + int32_t offset, + const common::BigintValuesUsingBloomFilter& filter, + const RowContainer& rowContainer, + uint8_t numPartitions, + RowPartitions& rowPartitions) { + VELOX_DCHECK(hasher.supportsBloomFilter()); + VELOX_DCHECK(bits::isPowerOfTwo(numPartitions)); + switch (hasher.typeKind()) { + case TypeKind::INTEGER: + partitionBloomFilterRowsImpl( + offset, filter, rowContainer, numPartitions - 1, rowPartitions); + break; + case TypeKind::BIGINT: + partitionBloomFilterRowsImpl( + offset, filter, rowContainer, numPartitions - 1, rowPartitions); + break; + default: + VELOX_UNREACHABLE(); + } +} + +template +void buildBloomFilterImpl( + int32_t offset, + char** rows, + int numRows, + common::BigintValuesUsingBloomFilter& filter) { + for (int i = 0; i < numRows; ++i) { + filter.insert(folly::loadUnaligned(rows[i] + offset)); + } +} + +void buildBloomFilter( + const VectorHasher& hasher, + int32_t offset, + char** rows, + int numRows, + common::BigintValuesUsingBloomFilter& filter) { + VELOX_DCHECK(hasher.supportsBloomFilter()); + switch (hasher.typeKind()) { + case TypeKind::INTEGER: + buildBloomFilterImpl(offset, rows, numRows, filter); + break; + case TypeKind::BIGINT: + buildBloomFilterImpl(offset, rows, numRows, filter); + break; + default: + VELOX_UNREACHABLE(); + } +} + template void syncWorkItems( std::vector>& items, - std::exception_ptr& error, std::vector& timings, - bool log = false) { + bool throwError) { + std::exception_ptr error; // All items must be synced also in case of error because the items // hold references to the table and rows which could be destructed // if unwinding the stack did not pause to sync. @@ -850,15 +941,42 @@ void syncWorkItems( timings.push_back(item->prepareTiming()); } } catch (const std::exception& e) { - if (log) { + if (!throwError) { LOG(ERROR) << "Error in async hash build: " << e.what(); + } else { + error = std::current_exception(); } - error = std::current_exception(); } } + if (error) { + std::rethrow_exception(error); + } } + } // namespace +template <> +bool HashTable::bloomFilterSupported() const { + if (!(bloomFilterMaxSize_ > 0 && + common::BigintValuesUsingBloomFilter::numBlocks(numDistinct_) * + sizeof(SplitBlockBloomFilter::Block) <= + bloomFilterMaxSize_)) { + return false; + } + for (auto& hasher : hashers_) { + if (hasher->supportsBloomFilter()) { + return true; + } + } + return false; +} + +template <> +bool HashTable::bloomFilterSupported() const { + VELOX_CHECK_EQ(bloomFilterMaxSize_, 0); + return false; +} + template bool HashTable::canApplyParallelJoinBuild() const { if (!isJoinBuild_ || buildExecutor_ == nullptr) { @@ -908,23 +1026,56 @@ void HashTable::parallelJoinBuild() { buildPartitionBounds_.back() = sizeMask_ + 1; std::vector>> partitionSteps; std::vector>> buildSteps; + std::vector>> bloomFilterPartitionSteps; + std::vector>> bloomFilterBuildSteps; // rowPartitions are used in the async threads, so declare them before the // sync guard. std::vector> rowPartitions; auto sync = folly::makeGuard([&]() { // This is executed on returning path, possibly in unwinding, so must not // throw. - std::exception_ptr error; syncWorkItems( - partitionSteps, error, parallelJoinBuildStats_.partitionTimings, true); + partitionSteps, parallelJoinBuildStats_.partitionTimings, false); + syncWorkItems(buildSteps, parallelJoinBuildStats_.buildTimings, false); syncWorkItems( - buildSteps, error, parallelJoinBuildStats_.buildTimings, true); + bloomFilterPartitionSteps, + parallelJoinBuildStats_.bloomFilterPartitionTimings, + false); + syncWorkItems( + bloomFilterBuildSteps, + parallelJoinBuildStats_.bloomFilterBuildTimings, + false); + // Release the partition bounds to reduce memory usage. + buildPartitionBounds_ = raw_vector(pool_); }); + // Passing driver context directly to avoid cross thread access to thread + // local driver thread context. + const DriverCtx* driverCtx{nullptr}; + if (const auto* driverThreadCtx = driverThreadContext()) { + driverCtx = driverThreadCtx->driverCtx(); + } + const auto getTable = [this](size_t i) INLINE_LAMBDA { return i == 0 ? this : otherTables_[i - 1].get(); }; + const auto runStep = [&](auto& steps, auto&& work, bool runInCurrentThread) { + auto step = std::make_shared>([work = std::move(work)] { + work(); + return std::make_unique(true); + }); + steps.push_back(step); + if (runInCurrentThread) { + step->prepare(); + } else { + buildExecutor_->add([driverCtx, step]() { + ScopedDriverThreadContext scopedDriverThreadContext(driverCtx); + step->prepare(); + }); + } + }; + // This step can involve large memory allocations, so there is a chance of // OOMs here. Do it before any async work is started to reduce the chances of // concurrency issues. @@ -934,63 +1085,102 @@ void HashTable::parallelJoinBuild() { rowPartitions.push_back(table->rows()->createRowPartitions(*rows_->pool())); } - // Passing driver context directly to avoid cross thread access to thread - // local driver thread context. - const DriverCtx* driverCtx{nullptr}; - if (const auto* driverThreadCtx = driverThreadContext()) { - driverCtx = driverThreadCtx->driverCtx(); - } - // The parallel table partitioning step. for (auto i = 0; i < numPartitions; ++i) { auto* table = getTable(i); - partitionSteps.push_back(std::make_shared>( - [this, table, rawRowPartitions = rowPartitions[i].get()]() { + bool last = i == numPartitions - 1; + runStep( + partitionSteps, + [this, table, rawRowPartitions = rowPartitions[i].get()] { partitionRows(*table, *rawRowPartitions); - return std::make_unique(true); - })); - VELOX_CHECK(!partitionSteps.empty()); - buildExecutor_->add([driverCtx, step = partitionSteps.back()]() { - ScopedDriverThreadContext scopedDriverThreadContext(driverCtx); - step->prepare(); - }); - } - - std::exception_ptr error; - syncWorkItems( - partitionSteps, error, parallelJoinBuildStats_.partitionTimings); - if (error != nullptr) { - std::rethrow_exception(error); + }, + // run last partition on current thread to avoid wasting current thread + // on just waiting + last); } + syncWorkItems(partitionSteps, parallelJoinBuildStats_.partitionTimings, true); // The parallel table building step. std::vector> overflowPerPartition(numPartitions); for (auto i = 0; i < numPartitions; ++i) { - buildSteps.push_back(std::make_shared>( - [this, i, &overflowPerPartition, &rowPartitions]() { + bool last = i == numPartitions - 1; + runStep( + buildSteps, + [this, i, &overflowPerPartition, &rowPartitions] { buildJoinPartition(i, rowPartitions, overflowPerPartition[i]); - return std::make_unique(true); - })); - VELOX_CHECK(!buildSteps.empty()); - buildExecutor_->add([driverCtx, step = buildSteps.back()]() { - ScopedDriverThreadContext scopedDriverThreadContext(driverCtx); - step->prepare(); - }); - } - syncWorkItems(buildSteps, error, parallelJoinBuildStats_.buildTimings); - - if (error != nullptr) { - std::rethrow_exception(error); + }, + // run last partition on current thread to avoid wasting current thread + // on just waiting + last); + } + syncWorkItems(buildSteps, parallelJoinBuildStats_.buildTimings, true); + + if (bloomFilterSupported()) { + const auto numBloomFilterPartitions = bits::isPowerOfTwo(numPartitions) + ? numPartitions + : bits::nextPowerOfTwo(numPartitions) / 2; + VELOX_CHECK_GT(numBloomFilterPartitions, 0); + for (int i = 0; i < hashers_.size(); ++i) { + if (!hashers_[i]->supportsBloomFilter()) { + continue; + } + auto filter = std::make_shared( + numDistinct_, false); + hashers_[i]->setBloomFilter(filter); + for (auto j = 0; j < numPartitions; ++j) { + bool last = j == numPartitions - 1; + auto* rows = getTable(j)->rows(); + rowPartitions[j]->reset(); + runStep( + bloomFilterPartitionSteps, + [hasher = hashers_[i].get(), + offset = rows->columnAt(i).offset(), + filter, + rows, + numBloomFilterPartitions, + rowPartitions = rowPartitions[j].get()] { + partitionBloomFilterRows( + *hasher, + offset, + *filter, + *rows, + numBloomFilterPartitions, + *rowPartitions); + }, + // run last partition on current thread to avoid wasting current + // thread on just waiting + last); + } + syncWorkItems( + bloomFilterPartitionSteps, + parallelJoinBuildStats_.bloomFilterPartitionTimings, + true); + for (auto j = 0; j < numBloomFilterPartitions; ++j) { + bool last = j == numBloomFilterPartitions - 1; + runStep( + bloomFilterBuildSteps, + [this, i, j, &rowPartitions] { + buildBloomFilterPartition(i, j, rowPartitions); + }, + // run last partition on current thread to avoid wasting current + // thread on just waiting + last); + } + syncWorkItems( + bloomFilterBuildSteps, + parallelJoinBuildStats_.bloomFilterBuildTimings, + true); + } } - raw_vector hashes; + raw_vector hashes(pool_); for (auto i = 0; i < numPartitions; ++i) { auto& overflows = overflowPerPartition[i]; hashes.resize(overflows.size()); - hashRows( + VELOX_CHECK(hashRows( folly::Range(overflows.data(), overflows.size()), false, - hashes); + hashes)); auto table = i == 0 ? this : otherTables_[i - 1].get(); insertForJoin(overflows.data(), hashes.data(), overflows.size(), nullptr); VELOX_CHECK_EQ(table->rows()->numRows(), table->numParallelBuildRows_); @@ -1023,14 +1213,14 @@ template void HashTable::partitionRows( HashTable& subtable, RowPartitions& rowPartitions) { - constexpr int32_t kBatch = 1024; - raw_vector rows(kBatch); - raw_vector hashes(kBatch); - raw_vector partitions(kBatch); + raw_vector rows(kHashBatchSize, pool_); + raw_vector hashes(kHashBatchSize, pool_); + raw_vector partitions(kHashBatchSize, pool_); RowContainerIterator iter; while (auto numRows = subtable.rows_->listRows( - &iter, kBatch, RowContainer::kUnlimited, rows.data())) { - hashRows(folly::Range(rows.data(), numRows), true, hashes); + &iter, kHashBatchSize, RowContainer::kUnlimited, rows.data())) { + VELOX_CHECK( + hashRows(folly::Range(rows.data(), numRows), true, hashes)); VELOX_DCHECK_EQ( 0, buildPartitionBounds_.capacity() % @@ -1051,9 +1241,8 @@ void HashTable::buildJoinPartition( uint8_t partition, const std::vector>& rowPartitions, std::vector& overflow) { - constexpr int32_t kBatch = 1024; - raw_vector rows(kBatch); - raw_vector hashes(kBatch); + raw_vector rows(kHashBatchSize, pool_); + raw_vector hashes(kHashBatchSize, pool_); const int32_t numPartitions = 1 + otherTables_.size(); TableInsertPartitionInfo partitionInfo{ buildPartitionBounds_[partition], @@ -1062,15 +1251,44 @@ void HashTable::buildJoinPartition( for (auto i = 0; i < numPartitions; ++i) { auto* table = i == 0 ? this : otherTables_[i - 1].get(); RowContainerIterator iter; - while (const auto numRows = table->rows_->listPartitionRows( - iter, partition, kBatch, *rowPartitions[i], rows.data())) { - hashRows(folly::Range(rows.data(), numRows), false, hashes); + while ( + const auto numRows = table->rows_->listPartitionRows( + iter, partition, kHashBatchSize, *rowPartitions[i], rows.data())) { + VELOX_CHECK(hashRows(folly::Range(rows.data(), numRows), false, hashes)); insertForJoin(rows.data(), hashes.data(), numRows, &partitionInfo); table->numParallelBuildRows_ += numRows; } } } +template <> +void HashTable::buildBloomFilterPartition( + column_index_t columnIndex, + uint8_t partition, + const std::vector>& rowPartitions) { + char* rows[kHashBatchSize]; + for (auto i = 0; i < 1 + otherTables_.size(); ++i) { + auto* table = i == 0 ? this : otherTables_[i - 1].get(); + auto rowColumn = table->rows_->columnAt(columnIndex); + auto* filter = checkedPointerCast( + hashers_[columnIndex]->getBloomFilter().get()); + RowContainerIterator iter; + while (auto numRows = table->rows_->listPartitionRows( + iter, partition, kHashBatchSize, *rowPartitions[i], rows)) { + buildBloomFilter( + *hashers_[columnIndex], rowColumn.offset(), rows, numRows, *filter); + } + } +} + +template <> +void HashTable::buildBloomFilterPartition( + column_index_t /*columnIndex*/, + uint8_t /*partition*/, + const std::vector>& /*rowPartitions*/) { + VELOX_UNREACHABLE(); +} + template bool HashTable::insertBatch( char** groups, @@ -1091,7 +1309,7 @@ bool HashTable::insertBatch( template void HashTable::insertForGroupBy( char** groups, - uint64_t* hashes, + const uint64_t* hashes, int32_t numGroups) { if (hashMode_ == HashMode::kArray) { for (auto i = 0; i < numGroups; ++i) { @@ -1225,7 +1443,7 @@ template template FOLLY_ALWAYS_INLINE void HashTable::insertForJoinWithPrefetch( char** groups, - uint64_t* hashes, + const uint64_t* hashes, int32_t numGroups, TableInsertPartitionInfo* partitionInfo) { auto i = 0; @@ -1261,7 +1479,7 @@ FOLLY_ALWAYS_INLINE void HashTable::insertForJoinWithPrefetch( template void HashTable::insertForJoin( char** groups, - uint64_t* hashes, + const uint64_t* hashes, int32_t numGroups, TableInsertPartitionInfo* partitionInfo) { // The insertable rows are in the table, all get put in the hash table or @@ -1286,30 +1504,55 @@ void HashTable::rehash( bool initNormalizedKeys, int8_t spillInputStartPartitionBit) { ++numRehashes_; - constexpr int32_t kHashBatchSize = 1024; if (canApplyParallelJoinBuild()) { parallelJoinBuild(); return; } - raw_vector hashes; + raw_vector hashes(pool_); hashes.resize(kHashBatchSize); char* groups[kHashBatchSize]; + const bool shouldBuildBloomFilter = bloomFilterSupported(); + std::vector bloomFilters; + if (shouldBuildBloomFilter) { + bloomFilters.resize(hashers_.size()); + for (int i = 0; i < hashers_.size(); ++i) { + if (!hashers_[i]->supportsBloomFilter()) { + continue; + } + auto filter = std::make_shared( + numDistinct_, false); + bloomFilters[i] = filter.get(); + hashers_[i]->setBloomFilter(filter); + } + } // A join build can have multiple payload tables. Loop over 'this' // and the possible other tables and put all the data in the table // of 'this'. for (int32_t i = 0; i <= otherTables_.size(); ++i) { RowContainerIterator iterator; int32_t numGroups; + auto* table = i == 0 ? this : otherTables_[i - 1].get(); do { - numGroups = (i == 0 ? this : otherTables_[i - 1].get()) - ->rows() - ->listRows(&iterator, kHashBatchSize, groups); + numGroups = table->rows()->listRows(&iterator, kHashBatchSize, groups); if (!insertBatch( groups, numGroups, hashes, initNormalizedKeys || i != 0)) { VELOX_CHECK_NE(hashMode_, HashMode::kHash); setHashMode(HashMode::kHash, 0, spillInputStartPartitionBit); return; } + if (shouldBuildBloomFilter) { + for (int j = 0; j < hashers_.size(); ++j) { + if (!hashers_[j]->supportsBloomFilter()) { + continue; + } + buildBloomFilter( + *hashers_[j], + table->rows()->columnAt(j).offset(), + groups, + numGroups, + *bloomFilters[j]); + } + } } while (numGroups > 0); } } @@ -1348,7 +1591,6 @@ void HashTable::setHashMode( template bool HashTable::analyze() { - constexpr int32_t kHashBatchSize = 1024; // @lint-ignore CLANGTIDY char* groups[kHashBatchSize]; RowContainerIterator iterator; @@ -1485,7 +1727,9 @@ void HashTable::decideHashMode( return; } disableRangeArrayHash_ |= disableRangeArrayHash; - if (numDistinct_ && !isJoinBuild_) { + if (numDistinct_ && (!isJoinBuild_ || joinBuildNoDuplicates())) { + // If the join type is left semi and anti, allowDuplicates_ will be false, + // and join build is building hash table while adding input rows. if (!analyze()) { setHashMode(HashMode::kHash, numNew, spillInputStartPartitionBit); return; @@ -1572,7 +1816,10 @@ void HashTable::checkHashBitsOverlap( int8_t spillInputStartPartitionBit) { if (spillInputStartPartitionBit != kNoSpillInputStartPartitionBit && hashMode() != HashMode::kArray) { - VELOX_CHECK_LT(sizeBits_ - 1, spillInputStartPartitionBit); + VELOX_CHECK_LT( + sizeBits_ - 1, + spillInputStartPartitionBit, + "The size bits of the hash table must be lower than the spilling partition bits to avoid overlap"); } } @@ -1704,12 +1951,26 @@ template void HashTable::prepareJoinTable( std::vector> tables, int8_t spillInputStartPartitionBit, + size_t vectorHasherMaxNumDistinct, + bool dropDuplicates, folly::Executor* executor) { buildExecutor_ = executor; + if (dropDuplicates) { + if (table_ != nullptr) { + // Reset table_ and capacity_ to trigger rehash. + rows_->pool()->freeContiguous(tableAllocation_); + table_ = nullptr; + capacity_ = 0; + } + // Call analyze to insert all unique values in row container to the + // table hashers' uniqueValues_; + analyze(); + } otherTables_.reserve(tables.size()); for (auto& table : tables) { - otherTables_.emplace_back(std::unique_ptr>( - dynamic_cast*>(table.release()))); + otherTables_.emplace_back( + std::unique_ptr>( + dynamic_cast*>(table.release()))); } // If there are multiple tables, we need to merge the 'columnHasNulls' flags @@ -1727,6 +1988,7 @@ void HashTable::prepareJoinTable( bool useValueIds = mayUseValueIds(*this); if (useValueIds) { + CpuWallTimer timer(vectorHasherMergeTiming_); for (auto& other : otherTables_) { if (!mayUseValueIds(*other)) { useValueIds = false; @@ -1735,8 +1997,13 @@ void HashTable::prepareJoinTable( } if (useValueIds) { for (auto& other : otherTables_) { + if (dropDuplicates) { + // Before merging with the current hashers, all values in the row + // containers of other table need to be inserted into uniqueValues_. + other->analyze(); + } for (auto i = 0; i < hashers_.size(); ++i) { - hashers_[i]->merge(*other->hashers_[i]); + hashers_[i]->merge(*other->hashers_[i], vectorHasherMaxNumDistinct); if (!hashers_[i]->mayUseValueIds()) { useValueIds = false; break; @@ -1797,7 +2064,7 @@ int32_t HashTable::listJoinResults( uint64_t totalBytes{0}; while (iter.lastRowIndex < iter.rows->size()) { if (!iter.nextHit) { - auto row = (*iter.rows)[iter.lastRowIndex]; + const auto row = (*iter.rows)[iter.lastRowIndex]; iter.nextHit = (*iter.hits)[row]; // NOLINT if (!iter.nextHit) { ++iter.lastRowIndex; @@ -1962,6 +2229,58 @@ int32_t HashTable::listAllRows( return listRows(iter, maxRows, maxBytes, rows); } +template +template +int32_t HashTable::listRows( + RowContainerIterator& rowContainerIterator, + int rowContainerId, + int32_t maxRows, + uint64_t maxBytes, + char** rows) { + const auto& rowContainer = rowContainerId == 0 + ? rows_.get() + : otherTables_[rowContainerId - 1]->rows(); + const auto numRows = rowContainer->template listRows( + &rowContainerIterator, maxRows, maxBytes, rows); + if (numRows > 0) { + return numRows; + } + return 0; +} + +template +int32_t HashTable::listNotProbedRows( + RowContainerIterator& rowContainerIterator, + int rowContainerId, + int32_t maxRows, + uint64_t maxBytes, + char** rows) { + return listRows( + rowContainerIterator, rowContainerId, maxRows, maxBytes, rows); +} + +template +int32_t HashTable::listProbedRows( + RowContainerIterator& rowContainerIterator, + int rowContainerId, + int32_t maxRows, + uint64_t maxBytes, + char** rows) { + return listRows( + rowContainerIterator, rowContainerId, maxRows, maxBytes, rows); +} + +template +int32_t HashTable::listAllRows( + RowContainerIterator& rowContainerIterator, + int rowContainerId, + int32_t maxRows, + uint64_t maxBytes, + char** rows) { + return listRows( + rowContainerIterator, rowContainerId, maxRows, maxBytes, rows); +} + template <> int32_t HashTable::listNullKeyRows( NullKeyRowsIterator* iter, @@ -2007,7 +2326,7 @@ int32_t HashTable::listNullKeyRows( template void HashTable::erase(folly::Range rows) { auto numRows = rows.size(); - raw_vector hashes; + raw_vector hashes(pool_); hashes.resize(numRows); for (int32_t i = 0; i < hashers_.size(); ++i) { @@ -2063,7 +2382,7 @@ void HashTable::eraseWithHashes( } numDistinct_ -= numRows; if (!otherTables_.empty()) { - raw_vector containerRows; + raw_vector containerRows(pool_); containerRows.resize(rows.size()); for (auto& other : otherTables_) { const auto numContainerRows = diff --git a/velox/exec/HashTable.h b/velox/exec/HashTable.h index 870ab7f01db0..8bcf3df5bd42 100644 --- a/velox/exec/HashTable.h +++ b/velox/exec/HashTable.h @@ -115,6 +115,8 @@ struct HashTableStats { struct ParallelJoinBuildStats { std::vector partitionTimings; std::vector buildTimings; + std::vector bloomFilterPartitionTimings; + std::vector bloomFilterBuildTimings; }; class BaseHashTable { @@ -127,6 +129,9 @@ class BaseHashTable { using MaskType = uint16_t; + /// The load factor of the hash table. + static constexpr double kHashTableLoadFactor = 0.7; + /// 2M entries, i.e. 16MB is the largest array based hash table. static constexpr uint64_t kArrayHashMaxSize = 2L << 20; @@ -152,6 +157,18 @@ class BaseHashTable { "hashtable.parallelJoinBuildWallNanos"}; static inline const std::string kParallelJoinBuildCpuNanos{ "hashtable.parallelJoinBuildCpuNanos"}; + static inline const std::string kParallelJoinBloomFilterPartitionWallNanos{ + "hashtable.parallelJoinBloomFilterPartitionWallNanos"}; + static inline const std::string kParallelJoinBloomFilterPartitionCpuNanos{ + "hashtable.parallelJoinBloomFilterPartitionCpuNanos"}; + static inline const std::string kParallelJoinBloomFilterBuildWallNanos{ + "hashtable.parallelJoinBloomFilterBuildWallNanos"}; + static inline const std::string kParallelJoinBloomFilterBuildCpuNanos{ + "hashtable.parallelJoinBloomFilterBuildCpuNanos"}; + static inline const std::string kVectorHasherMergeCpuNanos{ + "hashtable.vectorHasherMergeCpuNanos"}; + static inline const std::string kHashTableCacheHit{"hashtable.cacheHit"}; + static inline const std::string kHashTableCacheMiss{"hashtable.cacheMiss"}; /// Returns the string of the given 'mode'. static std::string modeString(HashMode mode); @@ -282,6 +299,15 @@ class BaseHashTable { uint64_t maxBytes, char** rows) = 0; + /// Same as above, but only return rows from the row container of + /// 'rowContainerId'. + virtual int32_t listNotProbedRows( + RowContainerIterator& rowContainerIterator, + int rowContainerId, + int32_t maxRows, + uint64_t maxBytes, + char** rows) = 0; + /// Returns rows with 'probed' flag set. Used by the right semi join. virtual int32_t listProbedRows( RowsIterator* iter, @@ -289,6 +315,15 @@ class BaseHashTable { uint64_t maxBytes, char** rows) = 0; + /// Same as above, but only return rows from the row container of + /// 'rowContainerId'. + virtual int32_t listProbedRows( + RowContainerIterator& rowContainerIterator, + int rowContainerId, + int32_t maxRows, + uint64_t maxBytes, + char** rows) = 0; + /// Returns all rows. Used by the right semi join project. virtual int32_t listAllRows( RowsIterator* iter, @@ -296,6 +331,15 @@ class BaseHashTable { uint64_t maxBytes, char** rows) = 0; + /// Same as above, but only return rows from the row container of + /// 'rowContainerId'. + virtual int32_t listAllRows( + RowContainerIterator& rowContainerIterator, + int rowContainerId, + int32_t maxRows, + uint64_t maxBytes, + char** rows) = 0; + /// Returns all rows with null keys. Used by null-aware joins (e.g. anti or /// left semi project). virtual int32_t listNullKeyRows( @@ -307,8 +351,22 @@ class BaseHashTable { virtual void prepareJoinTable( std::vector> tables, int8_t spillInputStartPartitionBit, + size_t vectorHasherMaxNumDistinct, + bool dropDuplicates = false, folly::Executor* executor = nullptr) = 0; + /// The hash table used for join build in left semi and anti join may not + /// retain duplicate join keys when allowDuplicates_ is false. This is + /// achieved by constructing the hash table in the addInput phase to eliminate + /// duplicate join keys. When the percentage of duplicate data is small, it + /// will adaptively adjust to not build the hash table in the addInput phase. + /// Instead, it operates like other join types by reading all the data before + /// building the hash table. This function is used to change the behavior of + /// building hash table, if allowDuplicates is true, the join hash table will + /// not be built during the addInput phase, and the input data will also not + /// be deduplicated, but it will not impact the containing row container. + virtual void setAllowDuplicates(bool allowDuplicates) = 0; + /// Returns the memory footprint in bytes for any data structures /// owned by 'this'. virtual int64_t allocatedBytes() const = 0; @@ -325,6 +383,9 @@ class BaseHashTable { /// side. This is used for sizing the internal hash table. virtual uint64_t numDistinct() const = 0; + /// Returns the number of row containers in this hash table. + virtual int32_t numRowContainers() const = 0; + /// Return a number of current stats that can help with debugging and /// profiling. virtual HashTableStats stats() const = 0; @@ -411,8 +472,7 @@ class BaseHashTable { __attribute__((__no_sanitize__("thread"))) #endif #endif - static TagVector - loadTags(uint8_t* tags, int64_t tagIndex) { + static TagVector loadTags(uint8_t* tags, int64_t tagIndex) { // Cannot use xsimd::batch::unaligned here because we need to skip TSAN. auto src = tags + tagIndex; #if XSIMD_WITH_SSE2 @@ -426,6 +486,10 @@ class BaseHashTable { return parallelJoinBuildStats_; } + const CpuWallTiming& vectorHasherMergeTiming() const { + return vectorHasherMergeTiming_; + } + /// Copies the values at 'columnIndex' into 'result' for the 'rows.size' rows /// pointed to by 'rows'. If an entry in 'rows' is null, sets corresponding /// row in 'result' to null. @@ -449,6 +513,7 @@ class BaseHashTable { std::unique_ptr rows_; ParallelJoinBuildStats parallelJoinBuildStats_; + CpuWallTiming vectorHasherMergeTiming_; }; FOLLY_ALWAYS_INLINE std::ostream& operator<<( @@ -482,7 +547,8 @@ class HashTable : public BaseHashTable { bool isJoinBuild, bool hasProbedFlag, uint32_t minTableSizeForParallelJoinBuild, - memory::MemoryPool* pool); + memory::MemoryPool* pool, + uint64_t bloomFilterMaxSize = 0); ~HashTable() override = default; @@ -507,7 +573,8 @@ class HashTable : public BaseHashTable { bool allowDuplicates, bool hasProbedFlag, uint32_t minTableSizeForParallelJoinBuild, - memory::MemoryPool* pool) { + memory::MemoryPool* pool, + uint64_t bloomFilterMaxSize = 0) { return std::make_unique( std::move(hashers), std::vector{}, @@ -516,7 +583,8 @@ class HashTable : public BaseHashTable { true, // isJoinBuild hasProbedFlag, minTableSizeForParallelJoinBuild, - pool); + pool, + bloomFilterMaxSize); } void groupProbe(HashLookup& lookup, int8_t spillInputStartPartitionBit) @@ -549,6 +617,27 @@ class HashTable : public BaseHashTable { uint64_t maxBytes, char** rows) override; + int32_t listNotProbedRows( + RowContainerIterator& rowContainerIterator, + int rowContainerId, + int32_t maxRows, + uint64_t maxBytes, + char** rows) override; + + int32_t listProbedRows( + RowContainerIterator& rowContainerIterator, + int rowContainerId, + int32_t maxRows, + uint64_t maxBytes, + char** rows) override; + + int32_t listAllRows( + RowContainerIterator& rowContainerIterator, + int rowContainerId, + int32_t maxRows, + uint64_t maxBytes, + char** rows) override; + int32_t listNullKeyRows( NullKeyRowsIterator* iter, int32_t maxRows, @@ -575,6 +664,10 @@ class HashTable : public BaseHashTable { return numDistinct_; } + int32_t numRowContainers() const override { + return otherTables_.size() + 1; + } + HashTableStats stats() const override { return HashTableStats{ capacity_, numRehashes_, numDistinct_, numTombstones_}; @@ -584,6 +677,10 @@ class HashTable : public BaseHashTable { return hasDuplicates_.check(); } + void setAllowDuplicates(const bool allowDuplicates) override { + allowDuplicates_ = allowDuplicates; + } + HashMode hashMode() const override { return hashMode_; } @@ -608,6 +705,8 @@ class HashTable : public BaseHashTable { void prepareJoinTable( std::vector> tables, int8_t spillInputStartPartitionBit, + size_t vectorHasherMaxNumDistinct, + bool dropDuplicates = false, folly::Executor* executor = nullptr) override; void prepareForJoinProbe( @@ -742,8 +841,7 @@ class HashTable : public BaseHashTable { // Returns the number of entries after which the table gets rehashed. static uint64_t rehashSize(int64_t size) { - // This implements the F14 load factor: Resize if less than 1/8 unoccupied. - return size - (size / 8); + return size * kHashTableLoadFactor; } // Returns the number of entries with 'numNew' and existing 'numDistincts' @@ -764,6 +862,14 @@ class HashTable : public BaseHashTable { int32_t listRows(RowsIterator* iter, int32_t maxRows, uint64_t maxBytes, char** rows); + template + int32_t listRows( + RowContainerIterator& rowContainerIterator, + int rowContainerId, + int32_t maxRows, + uint64_t maxBytes, + char** rows); + char*& nextRow(char* row) { return *reinterpret_cast(row + nextOffset_); } @@ -844,7 +950,7 @@ class HashTable : public BaseHashTable { // to the end of 'overflows' in 'partitionInfo'. void insertForJoin( char** groups, - uint64_t* hashes, + const uint64_t* hashes, int32_t numGroups, TableInsertPartitionInfo* partitionInfo); @@ -852,7 +958,8 @@ class HashTable : public BaseHashTable { // contents in a RowContainer owned by 'this'. 'hashes' are the hash // numbers or array indices (if kArray mode) for each // group. 'groups' is expected to have no duplicate keys. - void insertForGroupBy(char** groups, uint64_t* hashes, int32_t numGroups); + void + insertForGroupBy(char** groups, const uint64_t* hashes, int32_t numGroups); // Checks if we can apply parallel table build optimization for hash join. // The function returns true if all of the following conditions: @@ -888,6 +995,20 @@ class HashTable : public BaseHashTable { HashTable& subtable, RowPartitions& rowPartitions); + // Whether we should build Bloom filters. If Bloom filter pushdown is + // enabled, and the size fits, this returns true if any of the key columns + // support it. The actual build should build a Bloom filter for each key + // column that supports it, and skip the ones that do not. + bool bloomFilterSupported() const; + + // Populate the Bloom filter for the key column with `columnIndex` and rows + // with certain `partition`. The partitions information is stored in + // `rowPartitions`. + void buildBloomFilterPartition( + column_index_t columnIndex, + uint8_t partition, + const std::vector>& rowPartitions); + // Calculates hashes for 'rows' and returns them in 'hashes'. If // 'initNormalizedKeys' is true, the normalized keys are stored below each row // in the container. If 'initNormalizedKeys' is false and the table is in @@ -948,7 +1069,7 @@ class HashTable : public BaseHashTable { template void insertForJoinWithPrefetch( char** groups, - uint64_t* hashes, + const uint64_t* hashes, int32_t numGroups, TableInsertPartitionInfo* partitionInfo); @@ -965,7 +1086,13 @@ class HashTable : public BaseHashTable { // or distinct mode VectorHashers in a group by hash table. 0 for // join build sides. int32_t reservePct() const { - return isJoinBuild_ ? 0 : 50; + return (isJoinBuild_ && allowDuplicates_) ? 0 : 50; + } + + // Used to indicate whether it is a HashTable that does not contain duplicate + // join keys. + bool joinBuildNoDuplicates() const { + return isJoinBuild_ && !allowDuplicates_; } // Returns the byte offset of the bucket for 'hash' starting from 'table_'. @@ -1033,8 +1160,11 @@ class HashTable : public BaseHashTable { // The min table size in row to trigger parallel join table build. const uint32_t minTableSizeForParallelJoinBuild_; + const uint64_t bloomFilterMaxSize_; + int8_t sizeBits_; bool isJoinBuild_ = false; + bool allowDuplicates_ = true; // Set at join build time if the table has duplicates, meaning that // the join can be cardinality increasing. Atomic for tsan because diff --git a/velox/exec/HashTableCache.cpp b/velox/exec/HashTableCache.cpp new file mode 100644 index 000000000000..1bb0048bf38e --- /dev/null +++ b/velox/exec/HashTableCache.cpp @@ -0,0 +1,133 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/HashTableCache.h" + +#include + +#include "velox/core/QueryCtx.h" + +namespace facebook::velox::exec { + +HashTableCache* HashTableCache::instance() { + static HashTableCache instance; + return &instance; +} + +std::shared_ptr HashTableCache::get( + const std::string& key, + const std::string& taskId, + core::QueryCtx* queryCtx, + ContinueFuture* future) { + VELOX_CHECK_NOT_NULL(future, "future parameter must not be null"); + VELOX_CHECK_NOT_NULL(queryCtx, "queryCtx parameter must not be null"); + + std::lock_guard guard(lock_); + + auto it = tables_.find(key); + if (it == tables_.end()) { + // No entry exists - create a placeholder for this task to build the table. + auto* queryPool = queryCtx->pool(); + auto entry = std::make_shared( + key, + taskId, + queryPool->addLeafChild(fmt::format("cached_table_{}", key))); + tables_.insert({key, entry}); + + // Register callback to clean up this cache entry when QueryCtx is + // destroyed. This ensures tablePool memory is freed before the query + // pool is destroyed. + queryCtx->addReleaseCallback( + [cacheKey = key]() { HashTableCache::instance()->drop(cacheKey); }); + + // Return entry with pool, table will be filled later. + return entry; + } + + auto& entry = it->second; + + // Check if build is complete + if (entry->buildComplete) { + return entry; + } + + // If this is the builder task, don't wait - all drivers of the builder task + // should proceed to build (they coordinate via JoinBridge, not here). + if (entry->builderTaskId == taskId) { + return entry; + } + + auto [promise, _future] = + makeVeloxContinuePromiseContract(fmt::format("HashTableCache::{}", key)); + entry->buildPromises.push_back(std::move(promise)); + *future = std::move(_future); + + return entry; +} + +void HashTableCache::put( + const std::string& key, + std::shared_ptr table, + bool hasNullKeys) { + std::vector promises; + + { + std::lock_guard guard(lock_); + + auto it = tables_.find(key); + VELOX_CHECK( + it != tables_.end(), + "Cache entry for key '{}' must be created by get() before put()", + key); + + auto& entry = it->second; + VELOX_CHECK(!entry->buildComplete); + VELOX_CHECK_NULL(entry->table); + // Update the entry with the built table + entry->table = std::move(table); + entry->hasNullKeys = hasNullKeys; + entry->buildComplete = true; + + // Collect promises to notify waiters + promises = std::move(entry->buildPromises); + } + + // Notify all waiting tasks outside the lock + for (auto& promise : promises) { + promise.setValue(); + } +} + +void HashTableCache::drop(const std::string& key) { + std::shared_ptr entry; + { + std::lock_guard guard(lock_); + auto it = tables_.find(key); + if (it != tables_.end()) { + entry = std::move(it->second); + tables_.erase(it); + } + } + + // Clear the table outside the lock to free memory before the entry + // is destroyed. This ensures the tablePool's memory is released + // before any parent pools are destroyed. + if (entry) { + entry->table.reset(); + } +} + +} // namespace facebook::velox::exec diff --git a/velox/exec/HashTableCache.h b/velox/exec/HashTableCache.h new file mode 100644 index 000000000000..e1d9717e89e6 --- /dev/null +++ b/velox/exec/HashTableCache.h @@ -0,0 +1,83 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include + +#include "velox/exec/HashTable.h" + +namespace facebook::velox::core { +class QueryCtx; +} + +namespace facebook::velox::exec { + +/// Cached hash table entry with build coordination metadata. +struct HashTableCacheEntry { + HashTableCacheEntry( + std::string _cacheKey, + std::string _builderTaskId, + std::shared_ptr _tablePool) + : cacheKey(std::move(_cacheKey)), + builderTaskId(std::move(_builderTaskId)), + tablePool(std::move(_tablePool)) {} + + const std::string cacheKey; + const std::string builderTaskId; + const std::shared_ptr tablePool; + std::shared_ptr table; + bool hasNullKeys{false}; + tsan_atomic buildComplete{false}; + std::vector buildPromises; +}; + +/// Global cache for hash tables shared across tasks within the same query. +/// First task builds the table, subsequent tasks wait and reuse it. +class HashTableCache { + public: + static HashTableCache* instance(); + + /// Gets or creates a cache entry. First caller becomes the builder. + /// Subsequent callers from different tasks get a future to wait on. + /// When a new entry is created, a release callback is registered on queryCtx + /// to clean up the entry when the query completes. + /// @param future Must be non-null; set if caller needs to wait. + std::shared_ptr get( + const std::string& key, + const std::string& taskId, + core::QueryCtx* queryCtx, + ContinueFuture* future); + + /// Stores a built hash table and notifies waiting tasks. + void put( + const std::string& key, + std::shared_ptr table, + bool hasNullKeys); + + /// Removes a cache entry. + void drop(const std::string& key); + + private: + HashTableCache() = default; + + std::mutex lock_; + std::unordered_map> tables_; +}; + +} // namespace facebook::velox::exec diff --git a/velox/exec/HilbertIndex.h b/velox/exec/HilbertIndex.h new file mode 100644 index 000000000000..9d8fd096f953 --- /dev/null +++ b/velox/exec/HilbertIndex.h @@ -0,0 +1,156 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +// Based off of https://threadlocalmutex.com/?p=126 + +#pragma once + +#include +#include + +#include "velox/common/base/Exceptions.h" + +namespace facebook::velox::exec { + +class HilbertIndex { + public: + /// Construct a Hilber index. If a min value is greater than the max value, + /// this will panic. + HilbertIndex(float minX, float minY, float maxX, float maxY) + : minX_(minX), minY_(minY), maxX_(maxX), maxY_(maxY) { + VELOX_CHECK(minX_ <= maxX_); + VELOX_CHECK(minY_ <= maxY_); + + float deltaX = maxX_ - minX_; + // Subnormals cause numerical instability. + // NOLINTNEXTLINE(facebook-hte-FloatingPointMin) + if (deltaX < std::numeric_limits::min()) { + xScale_ = 0; + } else { + xScale_ = kHilbertMax / deltaX; + } + + float deltaY = maxY_ - minY_; + // Subnormals cause numerical instability. + // NOLINTNEXTLINE(facebook-hte-FloatingPointMin) + if (deltaY < std::numeric_limits::min()) { + yScale_ = 0; + } else { + yScale_ = kHilbertMax / deltaY; + } + } + + uint32_t inline indexOf(float x, float y) const { + if (!(x >= minX_ && x <= maxX_ && y >= minY_ && y <= maxY_)) { + // Put things outside the bounds at the end of the Hilbert curve. + // Negation handles NaNs + return std::numeric_limits::max(); + } + + float maxFloat = static_cast(std::numeric_limits::max()); + + uint32_t xInt = static_cast( + std::clamp(xScale_ * (x - minX_), 0.0f, maxFloat)); + uint32_t yInt = static_cast( + std::clamp(yScale_ * (y - minY_), 0.0f, maxFloat)); + return discreteIndexOf(xInt, yInt); + } + + private: + static inline uint32_t interleave(uint32_t x) { + x = (x | (x << 8)) & 0x00FF00FF; + x = (x | (x << 4)) & 0x0F0F0F0F; + x = (x | (x << 2)) & 0x33333333; + x = (x | (x << 1)) & 0x55555555; + return x; + } + + static inline uint32_t discreteIndexOf(uint32_t x, uint32_t y) { + uint32_t A, B, C, D; + + // Initial prefix scan round, prime with x and y + { + uint32_t a = x ^ y; + uint32_t b = 0xFFFF ^ a; + uint32_t c = 0xFFFF ^ (x | y); + uint32_t d = x & (y ^ 0xFFFF); + + A = a | (b >> 1); + B = (a >> 1) ^ a; + + C = ((c >> 1) ^ (b & (d >> 1))) ^ c; + D = ((a & (c >> 1)) ^ (d >> 1)) ^ d; + } + + { + uint32_t a = A; + uint32_t b = B; + uint32_t c = C; + uint32_t d = D; + + A = ((a & (a >> 2)) ^ (b & (b >> 2))); + B = ((a & (b >> 2)) ^ (b & ((a ^ b) >> 2))); + + C ^= ((a & (c >> 2)) ^ (b & (d >> 2))); + D ^= ((b & (c >> 2)) ^ ((a ^ b) & (d >> 2))); + } + + { + uint32_t a = A; + uint32_t b = B; + uint32_t c = C; + uint32_t d = D; + + A = ((a & (a >> 4)) ^ (b & (b >> 4))); + B = ((a & (b >> 4)) ^ (b & ((a ^ b) >> 4))); + + C ^= ((a & (c >> 4)) ^ (b & (d >> 4))); + D ^= ((b & (c >> 4)) ^ ((a ^ b) & (d >> 4))); + } + + // Final round and projection + { + uint32_t a = A; + uint32_t b = B; + uint32_t c = C; + uint32_t d = D; + + C ^= ((a & (c >> 8)) ^ (b & (d >> 8))); + D ^= ((b & (c >> 8)) ^ ((a ^ b) & (d >> 8))); + } + + // Undo transformation prefix scan + uint32_t a = C ^ (C >> 1); + uint32_t b = D ^ (D >> 1); + + // Recover index bits + uint32_t i0 = x ^ y; + uint32_t i1 = b | (0xFFFF ^ (i0 | a)); + + return (interleave(i1) << 1) | interleave(i0); + } + + static const int8_t kHilbertBits = 16; + static constexpr float kHilbertMax = (1 << kHilbertBits) - 1; + + const float minX_; + const float minY_; + const float maxX_; + const float maxY_; + float xScale_; + float yScale_; +}; + +} // namespace facebook::velox::exec diff --git a/velox/exec/IndexLookupJoin.cpp b/velox/exec/IndexLookupJoin.cpp index 1fe92ce84416..bfcf0868b80b 100644 --- a/velox/exec/IndexLookupJoin.cpp +++ b/velox/exec/IndexLookupJoin.cpp @@ -17,6 +17,7 @@ #include "velox/buffer/Buffer.h" #include "velox/connectors/Connector.h" +#include "velox/exec/OperatorUtils.h" #include "velox/exec/Task.h" #include "velox/expression/Expr.h" #include "velox/expression/FieldReference.h" @@ -85,8 +86,9 @@ bool addBetweenConditionBound( lookupInputChannels, lookupInputNameSet); } else { - VELOX_USER_CHECK(core::TypedExprs::asConstant(typeExpr)->type()->equivalent( - *indexKeyType)); + VELOX_USER_CHECK( + core::TypedExprs::asConstant(typeExpr)->type()->equivalent( + *indexKeyType)); } return isConstant; } @@ -138,15 +140,20 @@ IndexLookupJoin::IndexLookupJoin( operatorId, joinNode->id(), "IndexLookupJoin"), + splitOutput_{driverCtx->queryConfig().indexLookupJoinSplitOutput()}, // TODO: support to update output batch size with output size stats during // the lookup processing. - outputBatchSize_{outputBatchRows()}, + outputBatchSize_{ + driverCtx->queryConfig().indexLookupJoinSplitOutput() + ? outputBatchRows() + : std::numeric_limits::max()}, joinType_{joinNode->joinType()}, + hasMarker_(joinNode->hasMarker()), numKeys_{joinNode->leftKeys().size()}, probeType_{joinNode->sources()[0]->outputType()}, lookupType_{joinNode->lookupSource()->outputType()}, lookupTableHandle_{joinNode->lookupSource()->tableHandle()}, - lookupConditions_{joinNode->joinConditions()}, + joinConditions_{joinNode->joinConditions()}, lookupColumnHandles_(joinNode->lookupSource()->assignments()), connectorQueryCtx_{operatorCtx_->createConnectorQueryCtx( lookupTableHandle_->connectorId(), @@ -179,11 +186,12 @@ void IndexLookupJoin::initialize() { initLookupInput(); initLookupOutput(); initOutputProjections(); + initFilter(); indexSource_ = connector_->createIndexSource( lookupInputType_, numKeys_, - lookupConditions_, + joinConditions_, lookupOutputType_, lookupTableHandle_, lookupColumnHandles_, @@ -194,10 +202,7 @@ void IndexLookupJoin::ensureInputLoaded(const InputBatchState& batch) { VELOX_CHECK_GT(numInputBatches(), 0); // Ensure each input vector are lazy loaded before process next batch. This is // to ensure the ordered lazy materialization in the source readers. - auto& input = batch.input; - for (auto i = 0; i < input->childrenSize(); ++i) { - input->childAt(i)->loadedVector(); - } + loadColumns(batch.input, *operatorCtx_->execCtx()); } void IndexLookupJoin::initInputBatches() { @@ -213,14 +218,13 @@ void IndexLookupJoin::initLookupInput() { VELOX_CHECK(lookupInputChannels_.empty()); std::vector lookupInputNames; - lookupInputNames.reserve(numKeys_ + lookupConditions_.size()); + lookupInputNames.reserve(numKeys_ + joinConditions_.size()); std::vector lookupInputTypes; - lookupInputTypes.reserve(numKeys_ + lookupConditions_.size()); - lookupInputChannels_.reserve(numKeys_ + lookupConditions_.size()); + lookupInputTypes.reserve(numKeys_ + joinConditions_.size()); + lookupInputChannels_.reserve(numKeys_ + joinConditions_.size()); SCOPE_EXIT { - VELOX_CHECK_GE( - lookupInputNames.size(), numKeys_ + lookupConditions_.size()); + VELOX_CHECK_GE(lookupInputNames.size(), numKeys_ + joinConditions_.size()); VELOX_CHECK_EQ(lookupInputNames.size(), lookupInputChannels_.size()); lookupInputType_ = ROW(std::move(lookupInputNames), std::move(lookupInputTypes)); @@ -249,19 +253,25 @@ void IndexLookupJoin::initLookupInput() { lookupInputColumnSet); } - if (lookupConditions_.empty()) { + SCOPE_EXIT { + VELOX_CHECK(lookupKeyOrConditionHashers_.empty()); + VELOX_CHECK(!lookupInputChannels_.empty()); + lookupKeyOrConditionHashers_ = + createVectorHashers(probeType_, lookupInputChannels_); + }; + if (joinConditions_.empty()) { return; } - for (const auto& lookupCondition : lookupConditions_) { - const auto indexKeyName = getColumnName(lookupCondition->key); + for (const auto& joinCondition : joinConditions_) { + const auto indexKeyName = getColumnName(joinCondition->key); VELOX_USER_CHECK_EQ(lookupIndexColumnSet.count(indexKeyName), 0); lookupIndexColumnSet.insert(indexKeyName); const auto indexKeyType = lookupType_->findChild(indexKeyName); if (const auto inCondition = std::dynamic_pointer_cast( - lookupCondition)) { + joinCondition)) { const auto conditionInputName = getColumnName(inCondition->list); const auto conditionInputChannel = probeType_->getChildIdx(conditionInputName); @@ -282,7 +292,7 @@ void IndexLookupJoin::initLookupInput() { if (const auto betweenCondition = std::dynamic_pointer_cast( - lookupCondition)) { + joinCondition)) { addBetweenCondition( betweenCondition, probeType_, @@ -292,6 +302,21 @@ void IndexLookupJoin::initLookupInput() { lookupInputChannels_, lookupInputColumnSet); } + + if (const auto equalCondition = + std::dynamic_pointer_cast( + joinCondition)) { + // Process an equal join condition by validating that the value is + // constant. Equal conditions only support constant values for filtering. + VELOX_USER_CHECK( + core::TypedExprs::isConstant(equalCondition->value), + "Equal condition value must be constant: {}", + equalCondition->toString()); + VELOX_USER_CHECK( + core::TypedExprs::asConstant(equalCondition->value) + ->type() + ->equivalent(*indexKeyType)); + } } } @@ -349,11 +374,63 @@ void IndexLookupJoin::initOutputProjections() { } lookupOutputProjections_.emplace_back(i, outputChannelOpt.value()); } + if (hasMarker_) { + matchOutputChannel_ = outputType_->size() - 1; + } + VELOX_USER_CHECK_EQ( - probeOutputProjections_.size() + lookupOutputProjections_.size(), + probeOutputProjections_.size() + lookupOutputProjections_.size() + + !!matchOutputChannel_.has_value(), outputType_->size()); } +void IndexLookupJoin::initFilter() { + VELOX_CHECK_NULL(filter_); + + if (joinNode_->filter() == nullptr) { + return; + } + + std::vector filters = {joinNode_->filter()}; + filter_ = + std::make_unique(std::move(filters), operatorCtx_->execCtx()); + + std::vector names; + std::vector types; + const auto numFields = filter_->expr(0)->distinctFields().size(); + names.reserve(numFields); + types.reserve(numFields); + + column_index_t filterChannel{0}; + const auto addChannel = [&](column_index_t channel, + const RowTypePtr& inputType, + std::vector& projections) { + names.emplace_back(inputType->nameOf(channel)); + types.emplace_back(inputType->childAt(channel)); + projections.emplace_back(channel, filterChannel++); + }; + + for (const auto& field : filter_->expr(0)->distinctFields()) { + const auto& name = field->field(); + auto channel = probeType_->getChildIdxIfExists(name); + if (channel.has_value()) { + addChannel(channel.value(), probeType_, filterProbeInputProjections_); + continue; + } + channel = lookupOutputType_->getChildIdxIfExists(name); + if (channel.has_value()) { + addChannel( + channel.value(), lookupOutputType_, filterLookupInputProjections_); + continue; + } + VELOX_FAIL( + "Index lookup join filter field not found in either left or right input: {}", + field->toString()); + } + + filterInputType_ = ROW(std::move(names), std::move(types)); +} + bool IndexLookupJoin::startDrain() { return numInputBatches() != 0; } @@ -415,11 +492,10 @@ void IndexLookupJoin::addInput(RowVectorPtr input) { auto& batch = nextInputBatch(); VELOX_CHECK_LE(numInputBatches(), maxNumInputBatches_); batch.input = std::move(input); - if (numInputBatches() > 0) { - ensureInputLoaded(batch); - prepareLookup(batch); - startLookup(batch); - } + ensureInputLoaded(batch); + decodeAndDetectNonNullKeys(batch); + prepareLookup(batch); + startLookup(batch); } RowVectorPtr IndexLookupJoin::getOutput() { @@ -450,19 +526,175 @@ RowVectorPtr IndexLookupJoin::getOutput() { void IndexLookupJoin::prepareLookup(InputBatchState& batch) { VELOX_CHECK_GT(numInputBatches(), 0); VELOX_CHECK_NOT_NULL(batch.input); + const size_t numLookupRows = batch.lookupInputHasNullKeys + ? batch.nonNullInputRows.countSelected() + : batch.input->size(); if (batch.lookupInput == nullptr) { - batch.lookupInput = BaseVector::create( - lookupInputType_, batch.input->size(), pool()); + batch.lookupInput = + BaseVector::create(lookupInputType_, numLookupRows, pool()); } else { VectorPtr lookupInputVector = std::move(batch.lookupInput); - BaseVector::prepareForReuse(lookupInputVector, batch.input->size()); + BaseVector::prepareForReuse(lookupInputVector, numLookupRows); batch.lookupInput = std::static_pointer_cast(lookupInputVector); } + if (!batch.lookupInputHasNullKeys) { + for (auto i = 0; i < lookupInputType_->size(); ++i) { + batch.input->childAt(lookupInputChannels_[i])->loadedVector(); + batch.lookupInput->childAt(i) = + batch.input->childAt(lookupInputChannels_[i]); + } + return; + } + + if (batch.lookupInput->size() == 0) { + return; + } + + const auto mappingByteSize = numLookupRows * sizeof(vector_size_t); + if ((batch.nonNullInputMappings == nullptr) || + !batch.nonNullInputMappings->unique() || + (batch.nonNullInputMappings->capacity() < mappingByteSize)) { + batch.nonNullInputMappings = allocateIndices(numLookupRows, pool()); + batch.rawNonNullInputMappings = + batch.nonNullInputMappings->asMutable(); + } + batch.nonNullInputMappings->setSize(mappingByteSize); + + size_t lookupRow = 0; + batch.nonNullInputRows.applyToSelected( + [&](auto row) { batch.rawNonNullInputMappings[lookupRow++] = row; }); + VELOX_CHECK_EQ(lookupRow, numLookupRows); + for (auto i = 0; i < lookupInputType_->size(); ++i) { - batch.lookupInput->childAt(i) = - batch.input->childAt(lookupInputChannels_[i]); - batch.lookupInput->childAt(i)->loadedVector(); + batch.input->childAt(lookupInputChannels_[i])->loadedVector(); + batch.lookupInput->childAt(i) = BaseVector::wrapInDictionary( + nullptr, + batch.nonNullInputMappings, + numLookupRows, + batch.input->childAt(lookupInputChannels_[i])); + } +} + +void IndexLookupJoin::mergeLookupResults(InputBatchState& batch) { + VELOX_CHECK(!batch.partialOutputs.empty()); + VELOX_CHECK_NULL(batch.lookupResult); + + if (batch.partialOutputs.size() == 1) { + batch.lookupResult = std::move(batch.partialOutputs[0]); + SCOPE_EXIT { + VELOX_CHECK_NOT_NULL(batch.lookupResult); + batch.partialOutputs.clear(); + }; + return; + } + + // Calculate total size. + vector_size_t totalSize = 0; + for (const auto& result : batch.partialOutputs) { + totalSize += static_cast(result->size()); + } + + // Merge inputHits buffers. + auto mergedInputHits = allocateIndices(totalSize, pool()); + auto* rawMergedInputHits = mergedInputHits->asMutable(); + vector_size_t offset = 0; + for (const auto& result : batch.partialOutputs) { + std::memcpy( + rawMergedInputHits + offset, + result->inputHits->as(), + result->size() * sizeof(vector_size_t)); + offset += static_cast(result->size()); + } + + // Merge output RowVectors. + // NOTE: Uncommon path for connectors that do not respect output batch size + // properly + auto mergedOutput = BaseVector::create( + batch.partialOutputs[0]->output->type(), totalSize, pool()); + vector_size_t outputOffset = 0; + for (const auto& result : batch.partialOutputs) { + mergedOutput->copy(result->output.get(), outputOffset, 0, result->size()); + outputOffset += static_cast(result->size()); + } + + batch.lookupResult = std::make_unique( + std::move(mergedInputHits), std::move(mergedOutput)); + batch.partialOutputs.clear(); +} + +bool IndexLookupJoin::getLookupResults(InputBatchState& batch) { + VELOX_CHECK_NOT_NULL(batch.lookupInput); + VELOX_CHECK_NOT_NULL(batch.lookupResultIter); + VELOX_CHECK(!batch.lookupFuture.valid()); + + // Result is ready. + if (batch.lookupResult != nullptr) { + return true; + } + + // Fetch the first result if not already fetched. + if (batch.lookupResult == nullptr && batch.partialOutputs.empty()) { + auto lookupResultOr = + batch.lookupResultIter->next(outputBatchSize_, batch.lookupFuture); + if (!lookupResultOr.has_value()) { + VELOX_CHECK(batch.lookupFuture.valid()); + return false; + } + VELOX_CHECK(!batch.lookupFuture.valid()); + + // Either splitOutput_ is true, or no more results, or first result is null. + if (splitOutput_ || !lookupResultOr.has_value() || + !batch.lookupResultIter->hasNext()) { + batch.lookupResult = std::move(lookupResultOr).value(); + return true; + } + + // Otherwise start accumulating results. + batch.partialOutputs.push_back(std::move(lookupResultOr).value()); + } + + // Continue accumulating remaining results when splitOutput_ is false. + // This handles both initial accumulation and resuming after async + // interruption. + VELOX_CHECK(!splitOutput_); + VELOX_CHECK(!batch.partialOutputs.empty()); + VELOX_CHECK_NULL(batch.lookupResult); + + while (batch.lookupResultIter->hasNext()) { + auto nextResultOr = + batch.lookupResultIter->next(outputBatchSize_, batch.lookupFuture); + if (!nextResultOr.has_value()) { + // Need to wait for async operation. + VELOX_CHECK(batch.lookupFuture.valid()); + return false; + } + VELOX_CHECK(!batch.lookupFuture.valid()); + auto nextResult = std::move(nextResultOr).value(); + if (nextResult != nullptr) { + batch.partialOutputs.push_back(std::move(nextResult)); + } + } + + // All results accumulated, merge them. + mergeLookupResults(batch); + return true; +} + +void IndexLookupJoin::decodeAndDetectNonNullKeys(InputBatchState& batch) { + const auto numRows = batch.input->size(); + batch.nonNullInputRows.resize(numRows); + batch.nonNullInputRows.setAll(); + + for (auto i = 0; i < lookupKeyOrConditionHashers_.size(); ++i) { + const auto* key = + batch.input->childAt(lookupKeyOrConditionHashers_[i]->channel()) + ->loadedVector(); + lookupKeyOrConditionHashers_[i]->decode(*key, batch.nonNullInputRows); + } + deselectRowsWithNulls(lookupKeyOrConditionHashers_, batch.nonNullInputRows); + if (batch.nonNullInputRows.countSelected() < numRows) { + batch.lookupInputHasNullKeys = true; } } @@ -470,21 +702,25 @@ void IndexLookupJoin::startLookup(InputBatchState& batch) { VELOX_CHECK_GT(numInputBatches(), 0); VELOX_CHECK_NOT_NULL(batch.input); VELOX_CHECK_NOT_NULL(batch.lookupInput); - VELOX_CHECK_EQ(batch.lookupInput->size(), batch.input->size()); + if (batch.lookupInputHasNullKeys) { + VELOX_CHECK_LT(batch.lookupInput->size(), batch.input->size()); + } else { + VELOX_CHECK_EQ(batch.lookupInput->size(), batch.input->size()); + } VELOX_CHECK_NULL(batch.lookupResultIter); VELOX_CHECK_NULL(batch.lookupResult); VELOX_CHECK(!batch.lookupFuture.valid()); - batch.lookupResultIter = indexSource_->lookup( - connector::IndexSource::LookupRequest{batch.lookupInput}); - auto lookupResultOr = - batch.lookupResultIter->next(outputBatchSize_, batch.lookupFuture); - if (!lookupResultOr.has_value()) { - VELOX_CHECK(batch.lookupFuture.valid()); + if (batch.lookupInput->size() == 0) { + // No need to start lookup for empty lookup input. return; } - VELOX_CHECK(!batch.lookupFuture.valid()); - batch.lookupResult = std::move(lookupResultOr).value(); + + // Create the lookup result iterator. + batch.lookupResultIter = indexSource_->lookup( + connector::IndexSource::LookupRequest{batch.lookupInput}); + + getLookupResults(batch); } RowVectorPtr IndexLookupJoin::getOutputFromLookupResult( @@ -492,33 +728,38 @@ RowVectorPtr IndexLookupJoin::getOutputFromLookupResult( VELOX_CHECK(!batch.empty()); VELOX_CHECK(!batch.lookupFuture.valid() || batch.lookupFuture.isReady()); batch.lookupFuture = ContinueFuture::makeEmpty(); - VELOX_CHECK_NOT_NULL(batch.lookupResultIter); - if (batch.lookupResult == nullptr) { - auto resultOptional = - batch.lookupResultIter->next(outputBatchSize_, batch.lookupFuture); - if (!resultOptional.has_value()) { - VELOX_CHECK(batch.lookupFuture.valid()); - return nullptr; - } - VELOX_CHECK(!batch.lookupFuture.valid()); + if (batch.lookupInput->size() == 0) { + return produceRemainingOutput(batch); + } - batch.lookupResult = std::move(resultOptional).value(); - if (batch.lookupResult == nullptr) { - if (hasRemainingOutputForLeftJoin(batch)) { - return produceRemainingOutputForLeftJoin(batch); - } - finishInput(batch); - return nullptr; + if (!getLookupResults(batch)) { + // Async operation pending, need to wait. + VELOX_CHECK(batch.lookupFuture.valid()); + return nullptr; + } + + VELOX_CHECK(!batch.lookupFuture.valid()); + VELOX_CHECK(batch.partialOutputs.empty()); + + if (batch.lookupResult == nullptr) { + if (hasRemainingOutputForLeftJoin(batch)) { + return produceRemainingOutputForLeftJoin(batch); } - rawLookupInputHitIndices_ = - batch.lookupResult->inputHits->as(); - } else if (rawLookupInputHitIndices_ == nullptr) { - rawLookupInputHitIndices_ = - batch.lookupResult->inputHits->as(); + finishInput(batch); + return nullptr; } + + prepareLookupResult(batch); VELOX_CHECK_NOT_NULL(batch.lookupResult); + if (!applyFilterOnLookupResult(batch)) { + VELOX_CHECK_NULL(batch.lookupResult); + // All rows in lookup result are filtered out, and fetch next lookup result + // batch. + return nullptr; + } + SCOPE_EXIT { maybeFinishLookupResult(batch); }; @@ -528,15 +769,79 @@ RowVectorPtr IndexLookupJoin::getOutputFromLookupResult( return produceOutputForLeftJoin(batch); } +RowVectorPtr IndexLookupJoin::produceRemainingOutput(InputBatchState& batch) { + if (hasRemainingOutputForLeftJoin(batch)) { + return produceRemainingOutputForLeftJoin(batch); + } + finishInput(batch); + return nullptr; +} + +void IndexLookupJoin::prepareLookupResult(InputBatchState& batch) { + VELOX_CHECK_NOT_NULL(batch.lookupResult); + if (rawLookupInputHitIndices_ != nullptr) { + return; + } + + if (!batch.lookupInputHasNullKeys) { + rawLookupInputHitIndices_ = + batch.lookupResult->inputHits->as(); + return; + } + VELOX_CHECK_NOT_NULL(batch.nonNullInputMappings); + vector_size_t* rawLookupInputHitIndices = + batch.ensureInputHitsWritable(pool()); + for (auto i = 0; i < batch.lookupResult->size(); ++i) { + rawLookupInputHitIndices[i] = + batch.rawNonNullInputMappings[rawLookupInputHitIndices[i]]; +#ifdef NDEBUG + if (i > 0) { + VELOX_DCHECK_LE( + rawLookupInputHitIndices[i - 1], rawLookupInputHitIndices[i]); + } +#endif + } + rawLookupInputHitIndices_ = rawLookupInputHitIndices; +} + +vector_size_t* IndexLookupJoin::InputBatchState::ensureInputHitsWritable( + memory::MemoryPool* pool) { + VELOX_CHECK_NOT_NULL(lookupResult); + if (lookupResult->inputHits->isMutable()) { + return lookupResult->inputHits->asMutable(); + } + + const auto indicesByteSize = lookupResult->size() * sizeof(vector_size_t); + if ((resultInputHitIndices == nullptr) || + !resultInputHitIndices->isMutable() || + (resultInputHitIndices->capacity() < indicesByteSize)) { + resultInputHitIndices = allocateIndices(indicesByteSize, pool); + } else { + resultInputHitIndices->setSize(indicesByteSize); + } + auto* rawLookupInputHitIndices = + resultInputHitIndices->asMutable(); + std::memcpy( + rawLookupInputHitIndices, + lookupResult->inputHits->as(), + indicesByteSize); + lookupResult->inputHits = resultInputHitIndices; + return rawLookupInputHitIndices; +} + void IndexLookupJoin::maybeFinishLookupResult(InputBatchState& batch) { VELOX_CHECK_NOT_NULL(batch.lookupResult); if (nextOutputResultRow_ == batch.lookupResult->size()) { - batch.lookupResult = nullptr; - nextOutputResultRow_ = 0; - rawLookupInputHitIndices_ = nullptr; + finishLookupResult(batch); } } +void IndexLookupJoin::finishLookupResult(InputBatchState& batch) { + batch.lookupResult = nullptr; + nextOutputResultRow_ = 0; + rawLookupInputHitIndices_ = nullptr; +} + bool IndexLookupJoin::hasRemainingOutputForLeftJoin( const InputBatchState& batch) const { if (joinType_ != core::JoinType::kLeft) { @@ -550,12 +855,14 @@ bool IndexLookupJoin::hasRemainingOutputForLeftJoin( void IndexLookupJoin::finishInput(InputBatchState& batch) { VELOX_CHECK_NOT_NULL(batch.input); - VELOX_CHECK_NOT_NULL(batch.lookupResultIter); + VELOX_CHECK_EQ( + batch.lookupInput->size() == 0, batch.lookupResultIter == nullptr); VELOX_CHECK(!batch.lookupFuture.valid()); batch.input = nullptr; batch.lookupResultIter = nullptr; batch.lookupResult = nullptr; + batch.lookupInputHasNullKeys = false; lastProcessedInputRow_ = std::nullopt; nextOutputResultRow_ = 0; ++startBatchIndex_; @@ -565,10 +872,9 @@ void IndexLookupJoin::finishInput(InputBatchState& batch) { VELOX_CHECK(!nextBatch.empty()); if (nextBatch.lookupResult != nullptr) { VELOX_CHECK(!nextBatch.lookupFuture.valid()); - rawLookupInputHitIndices_ = - nextBatch.lookupResult->inputHits->as(); } else { - VELOX_CHECK(nextBatch.lookupFuture.valid()); + VELOX_CHECK_EQ( + nextBatch.lookupInput->size() != 0, nextBatch.lookupFuture.valid()); } } } @@ -622,12 +928,28 @@ RowVectorPtr IndexLookupJoin::produceOutputForInnerJoin( ->slice(nextOutputResultRow_, numOutputRows); } } - nextOutputResultRow_ += numOutputRows; VELOX_CHECK_LE(nextOutputResultRow_, batch.lookupResult->size()); return output_; } +void IndexLookupJoin::fillOutputMatchRows( + vector_size_t offset, + vector_size_t size, + bool match) { + VELOX_CHECK_EQ(joinType_, core::JoinType::kLeft); + bits::fillBits( + rawLookupOutputNulls_, + offset, + offset + size, + match ? bits::kNotNull : bits::kNull); + if (!hasMarker_) { + return; + } + VELOX_CHECK_NOT_NULL(rawMatchValues_); + bits::fillBits(rawMatchValues_, offset, offset + size, match); +} + RowVectorPtr IndexLookupJoin::produceOutputForLeftJoin( const InputBatchState& batch) { VELOX_CHECK_EQ(joinType_, core::JoinType::kLeft); @@ -636,32 +958,39 @@ RowVectorPtr IndexLookupJoin::produceOutputForLeftJoin( VELOX_CHECK_NOT_NULL(rawLookupInputHitIndices_); VELOX_CHECK_NOT_NULL(batch.input); - prepareOutputRowMappings(outputBatchSize_); + const auto startOutputRow = nextOutputResultRow_; + int32_t lastProcessedInputRow = lastProcessedInputRow_.value_or(-1); + const auto startProcessInputRow = lastProcessedInputRow + 1; + // Set 'maxOutputRows' to max number of output rows that can be produced + // considering the possible missed input rows. + const size_t maxOutputRows = std::min( + outputBatchSize_, + batch.lookupResult->size() - nextOutputResultRow_ + batch.input->size() - + startProcessInputRow); + prepareOutputRowMappings(maxOutputRows); VELOX_CHECK_NOT_NULL(rawLookupOutputNulls_); - size_t numOutputRows{0}; size_t totalMissedInputRows{0}; - int32_t lastProcessedInputRow = lastProcessedInputRow_.value_or(-1); - for (; numOutputRows < outputBatchSize_ && + for (; numOutputRows < maxOutputRows && nextOutputResultRow_ < batch.lookupResult->size();) { VELOX_CHECK_GE( - rawLookupInputHitIndices_[nextOutputResultRow_], lastProcessedInputRow); + rawLookupInputHitIndices_[nextOutputResultRow_], + lastProcessedInputRow, + "nextOutputResultRow_ {}, batch.lookupResult->size() {}", + nextOutputResultRow_, + batch.lookupResult->size()); const vector_size_t numMissedInputRows = rawLookupInputHitIndices_[nextOutputResultRow_] - lastProcessedInputRow - 1; VELOX_CHECK_GE(numMissedInputRows, -1); if (numMissedInputRows > 0) { if (totalMissedInputRows == 0) { - bits::fillBits( - rawLookupOutputNulls_, 0, outputBatchSize_, bits::kNotNull); + ensureMatchColumn(maxOutputRows); + fillOutputMatchRows(0, maxOutputRows, true); } const auto numOutputMissedInputRows = std::min( - numMissedInputRows, outputBatchSize_ - numOutputRows); - bits::fillBits( - rawLookupOutputNulls_, - numOutputRows, - numOutputRows + numOutputMissedInputRows, - bits::kNull); + numMissedInputRows, maxOutputRows - numOutputRows); + fillOutputMatchRows(numOutputRows, numOutputMissedInputRows, false); for (auto i = 0; i < numOutputMissedInputRows; ++i) { rawProbeOutputRowIndices_[numOutputRows++] = ++lastProcessedInputRow; } @@ -677,13 +1006,14 @@ RowVectorPtr IndexLookupJoin::produceOutputForLeftJoin( ++numOutputRows; } VELOX_CHECK( - numOutputRows == outputBatchSize_ || + numOutputRows == maxOutputRows || nextOutputResultRow_ == batch.lookupResult->size()); VELOX_CHECK_LE(nextOutputResultRow_, batch.lookupResult->size()); lastProcessedInputRow_ = lastProcessedInputRow; if (totalMissedInputRows > 0) { lookupOutputNulls_->setSize(bits::nbytes(numOutputRows)); + setMatchColumnSize(numOutputRows); } probeOutputRowMapping_->setSize(numOutputRows * sizeof(vector_size_t)); lookupOutputRowMapping_->setSize(numOutputRows * sizeof(vector_size_t)); @@ -693,63 +1023,124 @@ RowVectorPtr IndexLookupJoin::produceOutputForLeftJoin( } prepareOutput(numOutputRows); - for (const auto& projection : probeOutputProjections_) { - output_->childAt(projection.outputChannel) = BaseVector::wrapInDictionary( - nullptr, - probeOutputRowMapping_, - numOutputRows, - batch.input->childAt(projection.inputChannel)); + const auto numInputRows = lastProcessedInputRow - startProcessInputRow + 1; + if (numInputRows == numOutputRows) { + if (startProcessInputRow == 0 && numInputRows == batch.input->size()) { + for (const auto& projection : probeOutputProjections_) { + output_->childAt(projection.outputChannel) = + batch.input->childAt(projection.inputChannel); + } + } else { + for (const auto& projection : probeOutputProjections_) { + output_->childAt(projection.outputChannel) = + batch.input->childAt(projection.inputChannel) + ->slice(startProcessInputRow, numInputRows); + } + } + } else { + for (const auto& projection : probeOutputProjections_) { + output_->childAt(projection.outputChannel) = BaseVector::wrapInDictionary( + nullptr, + probeOutputRowMapping_, + numOutputRows, + batch.input->childAt(projection.inputChannel)); + } } - for (const auto& projection : lookupOutputProjections_) { - output_->childAt(projection.outputChannel) = BaseVector::wrapInDictionary( - totalMissedInputRows > 0 ? lookupOutputNulls_ : nullptr, - lookupOutputRowMapping_, - numOutputRows, - batch.lookupResult->output->childAt(projection.inputChannel)); + + if (totalMissedInputRows > 0) { + for (const auto& projection : lookupOutputProjections_) { + output_->childAt(projection.outputChannel) = BaseVector::wrapInDictionary( + lookupOutputNulls_, + lookupOutputRowMapping_, + numOutputRows, + batch.lookupResult->output->childAt(projection.inputChannel)); + } + if (hasMarker_) { + output_->childAt(matchOutputChannel_.value()) = matchColumn_; + } + } else { + if (startOutputRow == 0 && + numOutputRows == batch.lookupResult->output->size()) { + for (const auto& projection : lookupOutputProjections_) { + output_->childAt(projection.outputChannel) = + batch.lookupResult->output->childAt(projection.inputChannel); + } + } else { + for (const auto& projection : lookupOutputProjections_) { + output_->childAt(projection.outputChannel) = + batch.lookupResult->output->childAt(projection.inputChannel) + ->slice(startOutputRow, numOutputRows); + } + } + if (hasMarker_) { + output_->childAt(matchOutputChannel_.value()) = + BaseVector::createConstant(BOOLEAN(), true, numOutputRows, pool()); + } } return output_; } +void IndexLookupJoin::ensureMatchColumn(vector_size_t maxOutputRows) { + if (!hasMarker_) { + return; + } + if (matchColumn_) { + VectorPtr matchColumn = std::move(matchColumn_); + BaseVector::prepareForReuse(matchColumn, maxOutputRows); + matchColumn_ = std::dynamic_pointer_cast>(matchColumn); + } else { + matchColumn_ = + BaseVector::create>(BOOLEAN(), maxOutputRows, pool()); + } + VELOX_CHECK_NOT_NULL(matchColumn_); + rawMatchValues_ = matchColumn_->mutableRawValues(); +} + +void IndexLookupJoin::setMatchColumnSize(vector_size_t numOutputRows) { + if (!hasMarker_) { + return; + } + VELOX_CHECK_NOT_NULL(matchColumn_); + matchColumn_->resize(numOutputRows); +} + RowVectorPtr IndexLookupJoin::produceRemainingOutputForLeftJoin( const InputBatchState& batch) { VELOX_CHECK_EQ(joinType_, core::JoinType::kLeft); VELOX_CHECK(!batch.empty()); VELOX_CHECK(hasRemainingOutputForLeftJoin(batch)); VELOX_CHECK_NULL(rawLookupInputHitIndices_); - prepareOutputRowMappings(outputBatchSize_); - VELOX_CHECK_NOT_NULL(rawLookupOutputNulls_); size_t lastProcessedInputRow = lastProcessedInputRow_.value_or(-1); + const auto startProcessInputRow = lastProcessedInputRow + 1; const size_t numOutputRows = std::min( - outputBatchSize_, batch.input->size() - lastProcessedInputRow - 1); + outputBatchSize_, batch.input->size() - startProcessInputRow); VELOX_CHECK_GT(numOutputRows, 0); - bits::fillBits(rawLookupOutputNulls_, 0, numOutputRows, bits::kNull); - for (auto outputRow = 0; outputRow < numOutputRows; ++outputRow) { - rawProbeOutputRowIndices_[outputRow] = ++lastProcessedInputRow; - } - lookupOutputNulls_->setSize(bits::nbytes(numOutputRows)); - probeOutputRowMapping_->setSize(numOutputRows * sizeof(vector_size_t)); - lookupOutputRowMapping_->setSize(numOutputRows * sizeof(vector_size_t)); - + VELOX_CHECK_LE(numOutputRows, batch.input->size()); prepareOutput(numOutputRows); - for (const auto& projection : probeOutputProjections_) { - output_->childAt(projection.outputChannel) = BaseVector::wrapInDictionary( - nullptr, - probeOutputRowMapping_, - numOutputRows, - batch.input->childAt(projection.inputChannel)); + if (numOutputRows != batch.input->size()) { + for (const auto& projection : probeOutputProjections_) { + output_->childAt(projection.outputChannel) = + batch.input->childAt(projection.inputChannel) + ->slice(startProcessInputRow, numOutputRows); + } + } else { + for (const auto& projection : probeOutputProjections_) { + output_->childAt(projection.outputChannel) = + batch.input->childAt(projection.inputChannel); + } } for (const auto& projection : lookupOutputProjections_) { - output_->childAt(projection.outputChannel) = BaseVector::wrapInDictionary( - lookupOutputNulls_, - lookupOutputRowMapping_, + output_->childAt(projection.outputChannel) = BaseVector::createNullConstant( + output_->type()->childAt(projection.outputChannel), numOutputRows, - BaseVector::createNullConstant( - output_->type()->childAt(projection.outputChannel), - numOutputRows, - pool())); + pool()); } - lastProcessedInputRow_ = lastProcessedInputRow; + if (hasMarker_) { + output_->childAt(matchOutputChannel_.value()) = + BaseVector::createConstant(BOOLEAN(), false, numOutputRows, pool()); + } + lastProcessedInputRow_ = lastProcessedInputRow + numOutputRows; return output_; } @@ -798,6 +1189,111 @@ void IndexLookupJoin::close() { Operator::close(); } +bool IndexLookupJoin::applyFilterOnLookupResult(InputBatchState& batch) { + VELOX_CHECK_NOT_NULL(batch.lookupResult); + if (!filter_) { + return true; + } + if (batch.lookupResult->size() == 0) { + return true; + } + + const auto numResultRows = batch.lookupResult->size(); + + // Prepare filter input vector + filterRows_.resize(numResultRows); + filterRows_.setAll(); + + if (!filterInput_) { + filterInput_ = + BaseVector::create(filterInputType_, numResultRows, pool()); + } else { + VectorPtr filterInputVector = std::move(filterInput_); + BaseVector::prepareForReuse(filterInputVector, numResultRows); + filterInput_ = std::static_pointer_cast(filterInputVector); + } + + // Populate filter input from probe input. + for (const auto& projection : filterProbeInputProjections_) { + // Get the probe input column and dictionary-wrap it with hit indices + filterInput_->childAt(projection.outputChannel) = + BaseVector::wrapInDictionary( + nullptr, + batch.lookupResult->inputHits, + numResultRows, + batch.input->childAt(projection.inputChannel)); + } + + // Populate filter input from lookup result. + for (const auto& projection : filterLookupInputProjections_) { + filterInput_->childAt(projection.outputChannel) = + batch.lookupResult->output->childAt(projection.inputChannel); + } + + // Evaluate filter + filterResult_.resize(1); + EvalCtx evalCtx(operatorCtx_->execCtx(), filter_.get(), filterInput_.get()); + filter_->eval(filterRows_, evalCtx, filterResult_); + decodedFilterResult_.decode(*filterResult_[0], filterRows_); + + const auto indicesByteSize = numResultRows * sizeof(vector_size_t); + if (!filteredIndices_ || !filteredIndices_->isMutable() || + filteredIndices_->capacity() < indicesByteSize) { + filteredIndices_ = allocateIndices(numResultRows, pool()); + } else { + filteredIndices_->setSize(indicesByteSize); + } + auto* rawFilteredIndices = filteredIndices_->asMutable(); + + vector_size_t numPassed{0}; + for (auto i = 0; i < numResultRows; ++i) { + if (!decodedFilterResult_.isNullAt(i) && + decodedFilterResult_.valueAt(i)) { + rawFilteredIndices[numPassed++] = i; + } + } + + if (numPassed == 0) { + finishLookupResult(batch); + return false; + } + + if (numPassed == numResultRows) { + return true; + } + + // Some rows passed - create filtered lookup result. + filteredIndices_->setSize(numPassed * sizeof(vector_size_t)); + + // Update the inputHits buffer. + auto* rawLookupInputHitIndices = batch.ensureInputHitsWritable(pool()); + for (auto i = 0; i < numPassed; ++i) { + rawLookupInputHitIndices[i] = + rawLookupInputHitIndices_[rawFilteredIndices[i]]; +#ifdef NDEBUG + if (i > 0) { + VELOX_DCHECK_LE( + rawLookupInputHitIndices[i - 1], rawLookupInputHitIndices[i]); + } +#endif + } + batch.lookupResult->inputHits->setSize(numPassed * sizeof(vector_size_t)); + rawLookupInputHitIndices_ = rawLookupInputHitIndices; + + // Create the filtered result vector. + auto filteredOutput = BaseVector::create( + batch.lookupResult->output->type(), numPassed, pool()); + for (auto i = 0; i < batch.lookupResult->output->childrenSize(); ++i) { + filteredOutput->childAt(i) = BaseVector::wrapInDictionary( + nullptr, + filteredIndices_, + numPassed, + batch.lookupResult->output->childAt(i)); + } + batch.lookupResult->output = std::move(filteredOutput); + return true; +} + void IndexLookupJoin::recordConnectorStats() { if (indexSource_ == nullptr) { // NOTE: index join might fail to create index source so skip record stats @@ -814,8 +1310,8 @@ void IndexLookupJoin::recordConnectorStats() { const CpuWallTiming backgroundTiming{ static_cast(connectorStats[kConnectorLookupWallTime].count), static_cast(connectorStats[kConnectorLookupWallTime].sum), - // NOTE: this might not be accurate as it doesn't include the time spent - // inside the index storage client. + // NOTE: this might not be accurate as it doesn't include the time + // spent inside the index storage client. static_cast(connectorStats[kConnectorResultPrepareTime].sum) + connectorStats[kClientRequestProcessTime].sum + connectorStats[kClientResultProcessTime].sum}; diff --git a/velox/exec/IndexLookupJoin.h b/velox/exec/IndexLookupJoin.h index 30860f032e89..73700b830bb5 100644 --- a/velox/exec/IndexLookupJoin.h +++ b/velox/exec/IndexLookupJoin.h @@ -15,6 +15,7 @@ */ #pragma once #include "velox/exec/Operator.h" +#include "velox/exec/VectorHasher.h" namespace facebook::velox::exec { @@ -77,6 +78,9 @@ class IndexLookupJoin : public Operator { /// the raw data received from the remote storage lookup. static inline const std::string kClientLookupResultSize{ "clientLookupResultSize"}; + /// The number of lookup results received from remote storage with error. + static inline const std::string kClientNumErrorResults{ + "clientNumErrorResults"}; private: using LookupResultIter = connector::IndexSource::LookupResultIterator; @@ -86,6 +90,16 @@ class IndexLookupJoin : public Operator { struct InputBatchState { // The input batch to process. RowVectorPtr input; + + // If true, it indicates that the probe input has null join keys. + bool lookupInputHasNullKeys{false}; + // Select input rows with non-null join keys. + SelectivityVector nonNullInputRows; + // The map from lookup input row to the corresponding probe input row. It is + // used to handle the case that probe input has null keys. + BufferPtr nonNullInputMappings; + vector_size_t* rawNonNullInputMappings{nullptr}; + // The reusable vector projected from 'input' as index lookup input. RowVectorPtr lookupInput; // Used to fetch lookup results for an input batch. @@ -97,6 +111,19 @@ class IndexLookupJoin : public Operator { // output processing. We might split the output result into multiple output // batches based on the operator's output batch size limit. std::unique_ptr lookupResult; + // Specifies the indices of input row in 'input' that have matches in + // 'output' from 'lookupResult'. This is only used in case + // 'lookupInputHasNullKeys' is true in which 'inputHits' in 'lookupResult' + // points to rows in 'lookupInput' which might be different from 'input'. + // To ease the rest of index lookup join result processing, we need to + // redirect the lookup input hit row from 'lookupInput' to the corresponding + // row in 'input' through the mapping specified by 'nonNullInputMappings'. + // The redirect input hit indices are stored in 'resultInputHitIndices'. + BufferPtr resultInputHitIndices; + // When splitOutput_ is false, this tracks partially accumulated results + // that are waiting for async operations to complete before continuing + // accumulation. + std::vector> partialOutputs; InputBatchState() : lookupFuture(ContinueFuture::makeEmpty()) {} @@ -105,12 +132,20 @@ class IndexLookupJoin : public Operator { lookupResultIter = nullptr; lookupFuture = ContinueFuture::makeEmpty(); lookupResult = nullptr; + partialOutputs.clear(); } // Indicates if this input batch is empty. bool empty() const { return input == nullptr; } + + // Ensures that the lookup result's inputHits buffer is writable and returns + // a mutable pointer. If the buffer is already mutable, returns it directly. + // Otherwise, creates a new writable buffer by copying the existing data and + // returns a pointer to the new buffer. This is needed when filters or null + // key handling requires modifying the input hit indices. + vector_size_t* ensureInputHitsWritable(memory::MemoryPool* pool); }; void initInputBatches(); @@ -118,17 +153,40 @@ class IndexLookupJoin : public Operator { void initLookupInput(); void initLookupOutput(); void initOutputProjections(); + void initFilter(); + + // Applies the join filter directly on the lookup result, updating the + // lookup result to only include rows that pass the filter. Returns true if + // some rows passed the filter, otherwise false. + bool applyFilterOnLookupResult(InputBatchState& batch); + void ensureInputLoaded(const InputBatchState& batch); // Prepare index source lookup for a given 'input_'. void prepareLookup(InputBatchState& batch); void startLookup(InputBatchState& batch); + // Helper function to merge batch.partialOutputs into a single + // batch.lookupResult. This is used when splitOutput_ is false to ensure all + // results from an iterator are combined into one output batch. + void mergeLookupResults(InputBatchState& batch); + // Helper function to get all lookup results. Fetches the first result if not + // already fetched, and when splitOutput_ is false, accumulates all remaining + // results into a single batch. Handles both initial lookup and resuming + // accumulation after async interruption. Returns true if results are ready, + // false if an async operation is pending. + bool getLookupResults(InputBatchState& batch); + void startLookupBlockWait(); void endLookupBlockWait(); RowVectorPtr getOutputFromLookupResult(InputBatchState& batch); RowVectorPtr produceOutputForInnerJoin(const InputBatchState& batch); RowVectorPtr produceOutputForLeftJoin(const InputBatchState& batch); + // Handles production of remaining output after lookup result processing is + // complete. For left joins, this ensures unmatched rows from the probe side + // are included in the output with null values for lookup columns. For inner + // joins, this simply finishes the input batch. + RowVectorPtr produceRemainingOutput(InputBatchState& batch); // Produces output for the remaining input rows that has no matches from the // lookup at the end of current input batch processing. RowVectorPtr produceRemainingOutputForLeftJoin(const InputBatchState& batch); @@ -140,8 +198,10 @@ class IndexLookupJoin : public Operator { bool hasRemainingOutputForLeftJoin(const InputBatchState& batch) const; // Checks if we have finished processing the current 'lookupResult_'. If so, - // we reset 'lookupResult_' and corresponding processing state. + // call 'finishLookupResult' to reset 'lookupResult_' and corresponding + // processing state. void maybeFinishLookupResult(InputBatchState& batch); + void finishLookupResult(InputBatchState& batch); // Invoked after finished processing the current 'input_' batch. The function // resets the input batch and the lookup result states. @@ -151,9 +211,31 @@ class IndexLookupJoin : public Operator { // 'outputBatchSize'. This is only used by left join which needs to fill nulls // for output rows without lookup matches. void prepareOutputRowMappings(size_t outputBatchSize); + // Prepare 'output_' for the next output batch with size of 'numOutputRows'. void prepareOutput(vector_size_t numOutputRows); + // Invoked to ensure the match column is created to store the output match + // result for the left join. + void ensureMatchColumn(vector_size_t maxOutputRows); + + // Invoked to fill the match column and output nulls with the match result for + // the left join. + void + fillOutputMatchRows(vector_size_t offset, vector_size_t size, bool match); + + // Invoked to set the match column with the actual output size. + void setMatchColumnSize(vector_size_t numOutputRows); + + // Invoked to decode the probe input keys to detect if there are any null + // keys. + void decodeAndDetectNonNullKeys(InputBatchState& batch); + + // Invoked to prepare the lookup result for processing. If the probe input has + // null keys, it maps the hit rows in lookup result to the corresponding probe + // input rows. + void prepareLookupResult(InputBatchState& batch); + // Invoked at operator close to record the lookup stats. void recordConnectorStats(); @@ -186,17 +268,20 @@ class IndexLookupJoin : public Operator { return inputBatches_[startBatchIndex_ % maxNumInputBatches_]; } + // If true, allows one input row to produce multiple output rows. + // If false, enforces one-to-one mapping. + const bool splitOutput_; // Maximum number of rows in the output batch. const vector_size_t outputBatchSize_; // Type of join. const core::JoinType joinType_; + const bool hasMarker_; const size_t numKeys_; const RowTypePtr probeType_; const RowTypePtr lookupType_; - const std::shared_ptr lookupTableHandle_; - const std::vector lookupConditions_; - std::unordered_map> - lookupColumnHandles_; + const connector::ConnectorTableHandlePtr lookupTableHandle_; + const std::vector joinConditions_; + const connector::ColumnHandleMap lookupColumnHandles_; const std::shared_ptr connectorQueryCtx_; const std::shared_ptr connector_; const size_t maxNumInputBatches_; @@ -211,6 +296,10 @@ class IndexLookupJoin : public Operator { // The column channels in probe 'input_' referenced by 'lookupInputType_'. std::vector lookupInputChannels_; + // Used to decode and check if any probe-side input key or condition columns + // have nulls. + std::vector> lookupKeyOrConditionHashers_; + // The input batches to process with ranges pointed by 'startBatchIndex_' and // 'endBatchIndex_'. std::vector inputBatches_; @@ -227,6 +316,7 @@ class IndexLookupJoin : public Operator { // Used to project output columns from the probe input and lookup output. std::vector probeOutputProjections_; std::vector lookupOutputProjections_; + std::optional matchOutputChannel_; std::shared_ptr indexSource_; @@ -253,8 +343,28 @@ class IndexLookupJoin : public Operator { BufferPtr lookupOutputNulls_; uint64_t* rawLookupOutputNulls_{nullptr}; + // Join filter. + std::unique_ptr filter_; + + // Join filter input type. + RowTypePtr filterInputType_; + + // Maps probe-side input channels to channels in 'filterInputType_'. + std::vector filterProbeInputProjections_; + // Maps lookup-side input channels to channels in 'filterInputType_', + std::vector filterLookupInputProjections_; + + // Reusable memory for filter evaluations. + RowVectorPtr filterInput_; + SelectivityVector filterRows_; + std::vector filterResult_; + DecodedVector decodedFilterResult_; + BufferPtr filteredIndices_; + // The reusable output vector for the join output. RowVectorPtr output_; + FlatVectorPtr matchColumn_{nullptr}; + uint64_t* rawMatchValues_{nullptr}; // The start time of the current lookup driver block wait, and reset after the // driver wait completes. diff --git a/velox/exec/Limit.cpp b/velox/exec/Limit.cpp index adab0530b30d..0ed5bbf73861 100644 --- a/velox/exec/Limit.cpp +++ b/velox/exec/Limit.cpp @@ -37,6 +37,10 @@ Limit::Limit( } } +bool Limit::startDrain() { + return false; +} + bool Limit::needsInput() const { return !finished_ && input_ == nullptr; } @@ -47,15 +51,19 @@ void Limit::addInput(RowVectorPtr input) { } RowVectorPtr Limit::getOutput() { - if (input_ == nullptr || (remainingOffset_ == 0 && remainingLimit_ == 0)) { + VELOX_DCHECK(!isDraining()); + + if ((input_ == nullptr) || (remainingOffset_ == 0 && remainingLimit_ == 0)) { return nullptr; } + SCOPE_EXIT { + input_ = nullptr; + }; const auto inputSize = input_->size(); if (remainingOffset_ >= inputSize) { remainingOffset_ -= inputSize; - input_ = nullptr; return nullptr; } @@ -71,7 +79,6 @@ RowVectorPtr Limit::getOutput() { auto output = fillOutput(outputSize, indices); remainingOffset_ = 0; remainingLimit_ -= outputSize; - input_ = nullptr; if (remainingLimit_ == 0) { finished_ = true; } @@ -85,7 +92,6 @@ RowVectorPtr Limit::getOutput() { if (remainingLimit_ >= inputSize) { remainingLimit_ -= inputSize; auto output = input_; - input_.reset(); return output; } @@ -95,7 +101,6 @@ RowVectorPtr Limit::getOutput() { input_->nulls(), remainingLimit_, input_->children()); - input_.reset(); remainingLimit_ = 0; return output; } diff --git a/velox/exec/Limit.h b/velox/exec/Limit.h index 7e38eda64090..24a5c00fc508 100644 --- a/velox/exec/Limit.h +++ b/velox/exec/Limit.h @@ -28,6 +28,8 @@ class Limit : public Operator { void addInput(RowVectorPtr input) override; + bool startDrain() override; + RowVectorPtr getOutput() override; BlockingReason isBlocked(ContinueFuture* /*future*/) override { diff --git a/velox/exec/LocalPartition.cpp b/velox/exec/LocalPartition.cpp index 1fbce8b8960c..ec682dabf59f 100644 --- a/velox/exec/LocalPartition.cpp +++ b/velox/exec/LocalPartition.cpp @@ -15,6 +15,7 @@ */ #include "velox/exec/LocalPartition.h" +#include "velox/common/Casts.h" #include "velox/exec/Task.h" #include "velox/vector/EncodedVectorCopy.h" @@ -420,6 +421,7 @@ RowVectorPtr LocalPartition::wrapChildren( void LocalPartition::copy( const RowVectorPtr& input, const folly::Range& ranges, + const size_t partition, VectorPtr& target) { if (ranges.empty()) { return; @@ -432,57 +434,153 @@ void LocalPartition::copy( } if (!target) { - target = BaseVector::create(outputType_, 0, pool()); + target = getOrCreateVector(partition); } target->resize(target->size() + ranges.size()); target->copyRanges(input.get(), ranges); } -RowVectorPtr LocalPartition::processPartition( +VectorPtr LocalPartition::getOrCreateVector(const size_t partition) { + auto reusable = queues_[partition]->getVector(); + if (reusable) { + VELOX_CHECK_EQ(reusable->type(), outputType_); + reusable->unsafeResize(0); + for (auto i = 0; i < reusable->childrenSize(); ++i) { + reusable->childAt(i) = nullptr; + } + return reusable; + } else { + return BaseVector::create(outputType_, 0, pool()); + } +} + +void LocalPartition::populatePartitionBuffer( const RowVectorPtr& input, - vector_size_t size, - int partition, - const BufferPtr& indices, - const vector_size_t* rawIndices) { + const vector_size_t numPartitionRows, + const size_t partition, + const vector_size_t* rawIndices, + uint64_t& totalPartitionBufferSizeExcludingString, + uint64_t& totalPartitionStringBufferSize) { + VELOX_CHECK_GT(singlePartitionBufferSize_, 0); + copyRanges_.resize(numPartitionRows); + + auto& partitionBuffer = partitionBuffers_[partition]; + auto targetIndex = 0; + if (partitionBuffer) { + targetIndex = partitionBuffer->size(); + } + for (int i = 0; i < numPartitionRows; i++) { + copyRanges_[i] = {rawIndices[i], targetIndex, 1}; + targetIndex++; + } + + copy(input, copyRanges_, partition, partitionBuffer); + + if (partitionBuffer) { + uint64_t stringBufferSize{0}; + auto totalSize = partitionBuffer->retainedSize(stringBufferSize); + totalPartitionBufferSizeExcludingString += totalSize - stringBufferSize; + totalPartitionStringBufferSize += stringBufferSize; + } +} + +RowVectorPtr LocalPartition::createPartition( + const RowVectorPtr& input, + const vector_size_t numPartitionRows, + const size_t partition, + const BufferPtr& indices) { RowVectorPtr partitionData{nullptr}; if (singlePartitionBufferSize_ > 0) { - if (partitionBuffers_.empty()) { - partitionBuffers_.resize(numPartitions_); - } - if (copyRanges_.size() < size) { - copyRanges_.resize(size); - } - auto& partitionBuffer = partitionBuffers_[partition]; - auto targetIndex = 0; if (partitionBuffer) { - targetIndex = partitionBuffer->size(); - } - for (int i = 0; i < size; i++) { - copyRanges_[i] = {rawIndices[i], targetIndex, 1}; - targetIndex++; + partitionData = + checkedPointerCast(partitionBuffer); + partitionBuffers_[partition] = nullptr; } + } else if (numPartitionRows > 0) { + partitionData = wrapChildren( + input, numPartitionRows, indices, queues_[partition]->getVector()); + } + return partitionData; +} - copy( - input, - folly::Range{copyRanges_.data(), static_cast(size)}, - partitionBuffer); - - if (partitionBuffer && - partitionBuffer->retainedSize() >= singlePartitionBufferSize_) { - partitionData = std::dynamic_pointer_cast(partitionBuffer); - VELOX_CHECK(partitionData); - partitionBuffers_[partition] = nullptr; +void LocalPartition::populateAndEnqueuePartitions( + RowVectorPtr input, + const std::vector& numRowsPerPartition, + const std::vector& indexBuffers, + const std::vector& rawIndicesBuffers) { + uint64_t totalPartitionBufferSizeExcludingString = 0; + uint64_t totalPartitionStringBufferSize = 0; + uint16_t nonEmptyPartitionCount = 0; + + // Populate partition buffers if in buffer mode. + if (singlePartitionBufferSize_ > 0) { + if (partitionBuffers_.empty()) { + partitionBuffers_.resize(numPartitions_); + } + for (auto partition = 0; partition < numPartitions_; partition++) { + populatePartitionBuffer( + input, + numRowsPerPartition[partition], + partition, + rawIndicesBuffers[partition], + totalPartitionBufferSizeExcludingString, + totalPartitionStringBufferSize); + if (partitionBuffers_[partition]) { + nonEmptyPartitionCount++; + } } } else { - partitionData = - wrapChildren(input, size, indices, queues_[partition]->getVector()); + nonEmptyPartitionCount = numPartitions_ - + std::count(numRowsPerPartition.begin(), numRowsPerPartition.end(), 0); + } + VELOX_CHECK_GT( + nonEmptyPartitionCount, + 0, + "Input rows should be assigned to at least one partition"); + + // Calculate the partition buffer size across all partitions with amortized + // string buffer sizes. + auto balancedTotalPartitionBufferSize = + totalPartitionBufferSizeExcludingString + + (totalPartitionStringBufferSize / nonEmptyPartitionCount); + auto inputRetainedSize = input->retainedSize(); + + // Enqueue all partitions if one of the following conditions is met: + // 1. This operator is not in buffer mode. + // 2. This operator is in buffer mode and the total buffer size across all + // partitions exceeds 'singlePartitionBufferSize_ * numPartitions_'. + if (singlePartitionBufferSize_ == 0 || + balancedTotalPartitionBufferSize >= + singlePartitionBufferSize_ * numPartitions_) { + auto perPartitionAmortizedSize = + (singlePartitionBufferSize_ > 0 ? balancedTotalPartitionBufferSize + : inputRetainedSize) / + nonEmptyPartitionCount; + for (auto partition = 0; partition < numPartitions_; partition++) { + auto partitionSize = numRowsPerPartition[partition]; + auto partitionData = createPartition( + input, partitionSize, partition, indexBuffers[partition]); + if (!partitionData) { + continue; + } + + ContinueFuture future; + auto reason = queues_[partition]->enqueue( + std::move(partitionData), perPartitionAmortizedSize, &future); + if (reason != BlockingReason::kNotBlocked) { + blockingReasons_.push_back(reason); + futures_.push_back(std::move(future)); + } + } } - return partitionData; } void LocalPartition::addInput(RowVectorPtr input) { prepareForInput(input); + if (input->size() == 0) { + return; + } const auto singlePartition = numPartitions_ == 1 ? 0 @@ -512,31 +610,7 @@ void LocalPartition::addInput(RowVectorPtr input) { ++maxIndex[partition]; } - const int64_t totalSize = input->retainedSize(); - for (auto partition = 0; partition < numPartitions_; partition++) { - auto partitionSize = maxIndex[partition]; - if (partitionSize == 0) { - // Do not enqueue empty partitions. - continue; - } - - auto partitionData = processPartition( - input, - partitionSize, - partition, - indexBuffers_[partition], - rawIndices_[partition]); - - if (partitionData) { - ContinueFuture future; - auto reason = queues_[partition]->enqueue( - partitionData, totalSize * partitionSize / numInput, &future); - if (reason != BlockingReason::kNotBlocked) { - blockingReasons_.push_back(reason); - futures_.push_back(std::move(future)); - } - } - } + populateAndEnqueuePartitions(input, maxIndex, indexBuffers_, rawIndices_); } void LocalPartition::prepareForInput(RowVectorPtr& input) { @@ -566,19 +640,36 @@ BlockingReason LocalPartition::isBlocked(ContinueFuture* future) { void LocalPartition::noMoreInput() { Operator::noMoreInput(); if (!partitionBuffers_.empty()) { + uint64_t totalPartitionBufferSizeExcludingString = 0; + uint64_t totalPartitionStringBufferSize = 0; + uint16_t nonEmptyPartitionCount = 0; for (auto partition = 0; partition < numPartitions_; partition++) { - if (partitionBuffers_[partition] && - partitionBuffers_[partition]->size() > 0) { - auto partitionData = - std::dynamic_pointer_cast(partitionBuffers_[partition]); - VELOX_CHECK(partitionData); - ContinueFuture future; - queues_[partition]->enqueue( - partitionData, - partitionBuffers_[partition]->retainedSize(), - &future); + if (partitionBuffers_[partition]) { + uint64_t stringBufferSize{0}; + auto totalSize = + partitionBuffers_[partition]->retainedSize(stringBufferSize); + totalPartitionBufferSizeExcludingString += totalSize - stringBufferSize; + totalPartitionStringBufferSize += stringBufferSize; + nonEmptyPartitionCount++; + } + } + if (nonEmptyPartitionCount > 0) { + auto balancedPartitionBufferSize = + totalPartitionBufferSizeExcludingString + + (totalPartitionStringBufferSize / nonEmptyPartitionCount); + for (auto partition = 0; partition < numPartitions_; partition++) { + if (partitionBuffers_[partition]) { + auto partitionData = checkedPointerCast( + partitionBuffers_[partition]); + ContinueFuture future; + + queues_[partition]->enqueue( + partitionData, + balancedPartitionBufferSize / nonEmptyPartitionCount, + &future); + } + partitionBuffers_[partition] = nullptr; } - partitionBuffers_[partition] = nullptr; } partitionBuffers_.resize(0); copyRanges_.resize(0); diff --git a/velox/exec/LocalPartition.h b/velox/exec/LocalPartition.h index bdf7ab42df5e..f2b1da8de504 100644 --- a/velox/exec/LocalPartition.h +++ b/velox/exec/LocalPartition.h @@ -240,12 +240,21 @@ class LocalPartition : public Operator { void allocateIndexBuffers(const std::vector& sizes); - RowVectorPtr processPartition( - const RowVectorPtr& input, - vector_size_t size, - int partition, - const BufferPtr& indices, - const vector_size_t* rawIndices); + /// Create partitions from 'input' according to 'numRowsPerPartition' and + /// 'indexBuffers', and enqueue the partitions to LocalExchangeQueues. The + /// behavior of partition vector creation varies depending on + /// 'singlePartitionBufferSize_'. If 'singlePartitionBufferSize_' is non-zero, + /// append rows from 'input' to 'partitionBuffers_' for every partition. When + /// the total size of all partition buffer vectors exceeds + /// 'singlePartitionBufferSize_ * numPartitions_', flush all partitionBuffers_ + /// vectors to LocalExchangeQueues. If 'singlePartitionBufferSize_' is zero, + /// create partition vectors by wrapping 'input' with indexBuffers and flush + /// them to LocalExchangeQueues immediately. + void populateAndEnqueuePartitions( + RowVectorPtr input, + const std::vector& numRowsPerPartition, + const std::vector& indexBuffers, + const std::vector& rawIndicesBuffers); const std::vector> queues_; const size_t numPartitions_; @@ -261,6 +270,11 @@ class LocalPartition : public Operator { std::vector rawIndices_; private: + // Try getting a reusable vector for 'partition' from the corresponding + // local-exchange vector pool of this partition. If none is available, create + // a new vector. + VectorPtr getOrCreateVector(const size_t partition); + RowVectorPtr wrapChildren( const RowVectorPtr& input, vector_size_t size, @@ -270,8 +284,30 @@ class LocalPartition : public Operator { void copy( const RowVectorPtr& input, const folly::Range& ranges, + const size_t partition, VectorPtr& target); + /// Add rows from 'input' to 'partitionBuffers_' that every row belongs to. + /// Also set 'totalPartitionBufferSizeExcludingString' to be the total size of + /// all partition buffer vectors excluding string buffers inside them, and set + /// 'totalPartitionStringBufferSize' to be the total size of all string + /// buffers in all partition buffer vectors. + void populatePartitionBuffer( + const RowVectorPtr& input, + const vector_size_t numPartitionRows, + const size_t partition, + const vector_size_t* rawIndices, + uint64_t& totalPartitionBufferSizeExcludingString, + uint64_t& totalPartitionStringBufferSize); + + /// Return the partition vector to be added to LocalExchangeQueue. This method + /// returns nullptr if no row belongs to 'partition' + RowVectorPtr createPartition( + const RowVectorPtr& input, + const vector_size_t numPartitionRows, + const size_t partition, + const BufferPtr& indices); + const uint64_t singlePartitionBufferSize_; std::vector copyRanges_; std::vector partitionBuffers_; diff --git a/velox/exec/LocalPlanner.cpp b/velox/exec/LocalPlanner.cpp index ae0361d3a8e7..10525a92917e 100644 --- a/velox/exec/LocalPlanner.cpp +++ b/velox/exec/LocalPlanner.cpp @@ -35,10 +35,13 @@ #include "velox/exec/NestedLoopJoinProbe.h" #include "velox/exec/OperatorTraceScan.h" #include "velox/exec/OrderBy.h" +#include "velox/exec/ParallelProject.h" #include "velox/exec/PartitionedOutput.h" #include "velox/exec/RoundRobinPartitionFunction.h" #include "velox/exec/RowNumber.h" #include "velox/exec/ScaleWriterLocalPartition.h" +#include "velox/exec/SpatialJoinBuild.h" +#include "velox/exec/SpatialJoinProbe.h" #include "velox/exec/StreamingAggregation.h" #include "velox/exec/TableScan.h" #include "velox/exec/TableWriteMerge.h" @@ -121,7 +124,10 @@ OperatorSupplier makeOperatorSupplier( std::dynamic_pointer_cast(planNode)) { return [localMerge](int32_t operatorId, DriverCtx* ctx) { auto mergeSource = ctx->task->addLocalMergeSource( - ctx->splitGroupId, localMerge->id(), localMerge->outputType()); + ctx->splitGroupId, + localMerge->id(), + localMerge->outputType(), + ctx->queryConfig().localMergeSourceQueueSize()); auto consumerCb = [mergeSource]( RowVectorPtr input, bool drained, ContinueFuture* future) { @@ -160,7 +166,7 @@ OperatorSupplier makeOperatorSupplier( VELOX_UNSUPPORTED( "Hash join currently does not support mixed grouped execution for join " "type {}", - core::joinTypeName(join->joinType())); + core::JoinTypeName::toName(join->joinType())); } return std::make_unique(operatorId, ctx, join); }; @@ -173,6 +179,13 @@ OperatorSupplier makeOperatorSupplier( }; } + if (auto join = + std::dynamic_pointer_cast(planNode)) { + return [join](int32_t operatorId, DriverCtx* ctx) { + return std::make_unique(operatorId, ctx, join); + }; + } + if (auto join = std::dynamic_pointer_cast(planNode)) { auto planNodeId = planNode->id(); @@ -190,7 +203,15 @@ OperatorSupplier makeOperatorSupplier( return source->enqueue(std::move(input), future); } }; - return std::make_unique(operatorId, ctx, consumer); + // NOTE: Pass planNodeId to associate CallbackSink with the MergeJoin + // node for proper operator identification and input collection. + // Operator::maybeSetTracer() uses this to enable tracing. + return std::make_unique( + operatorId, + ctx, + consumer, + nullptr, + ctx->queryConfig().queryTraceEnabled() ? planNodeId : "N/A"); }; } @@ -204,10 +225,11 @@ void plan( OperatorSupplier operatorSupplier, std::vector>* driverFactories) { if (!currentPlanNodes) { - driverFactories->push_back(std::make_unique()); - currentPlanNodes = &driverFactories->back()->planNodes; - driverFactories->back()->operatorSupplier = std::move(operatorSupplier); - driverFactories->back()->consumerNode = consumerNode; + auto driverFactory = std::make_unique(); + currentPlanNodes = &driverFactory->planNodes; + driverFactory->operatorSupplier = std::move(operatorSupplier); + driverFactory->consumerNode = consumerNode; + driverFactories->push_back(std::move(driverFactory)); } const auto& sources = planNode->sources(); @@ -298,8 +320,9 @@ uint32_t maxDrivers( return 1; } else if ( auto join = std::dynamic_pointer_cast(node)) { - // Right semi project doesn't support multi-threaded execution. - if (join->isRightSemiProjectJoin()) { + // Null-aware right semi project doesn't support multi-threaded + // execution. + if (join->isRightSemiProjectJoin() && join->isNullAware()) { return 1; } } else if ( @@ -351,7 +374,7 @@ void LocalPlanner::plan( planFragment.planNode, nullptr, nullptr, - detail::makeOperatorSupplier(consumerSupplier), + detail::makeOperatorSupplier(std::move(consumerSupplier)), driverFactories); (*driverFactories)[0]->outputDriver = true; @@ -401,14 +424,14 @@ void LocalPlanner::determineGroupedExecutionPipelines( size_t numGroupedExecutionSources{0}; for (const auto& sourceNode : localPartitionNode->sources()) { for (auto& anotherFactory : driverFactories) { - if (sourceNode == anotherFactory->planNodes.back() and + if (sourceNode == anotherFactory->planNodes.back() && anotherFactory->groupedExecution) { ++numGroupedExecutionSources; break; } } } - if (numGroupedExecutionSources > 0 and + if (numGroupedExecutionSources > 0 && numGroupedExecutionSources == localPartitionNode->sources().size()) { factory->groupedExecution = true; } @@ -456,6 +479,11 @@ void LocalPlanner::markMixedJoinBridges( break; } } + } else if ( + auto spatialJoinNode = + std::dynamic_pointer_cast( + planNode)) { + VELOX_FAIL("Spatial joins do not support grouped execution."); } } } @@ -464,6 +492,7 @@ void LocalPlanner::markMixedJoinBridges( std::shared_ptr DriverFactory::createDriver( std::unique_ptr ctx, std::shared_ptr exchangeClient, + std::shared_ptr filters, std::function numDrivers) { auto driver = std::shared_ptr(new Driver()); ctx->driver = driver.get(); @@ -481,8 +510,9 @@ std::shared_ptr DriverFactory::createDriver( auto next = planNodes[i + 1]; if (auto projectNode = std::dynamic_pointer_cast(next)) { - operators.push_back(std::make_unique( - id, ctx.get(), filterNode, projectNode)); + operators.push_back( + std::make_unique( + id, ctx.get(), filterNode, projectNode)); i++; continue; } @@ -494,6 +524,12 @@ std::shared_ptr DriverFactory::createDriver( std::dynamic_pointer_cast(planNode)) { operators.push_back( std::make_unique(id, ctx.get(), nullptr, projectNode)); + } else if ( + auto projectNode = + std::dynamic_pointer_cast( + planNode)) { + operators.push_back( + std::make_unique(id, ctx.get(), projectNode)); } else if ( auto valuesNode = std::dynamic_pointer_cast(planNode)) { @@ -517,8 +553,9 @@ std::shared_ptr DriverFactory::createDriver( auto tableWriteMergeNode = std::dynamic_pointer_cast( planNode)) { - operators.push_back(std::make_unique( - id, ctx.get(), tableWriteMergeNode)); + operators.push_back( + std::make_unique( + id, ctx.get(), tableWriteMergeNode)); } else if ( auto mergeExchangeNode = std::dynamic_pointer_cast( @@ -530,14 +567,16 @@ std::shared_ptr DriverFactory::createDriver( std::dynamic_pointer_cast(planNode)) { // NOTE: the exchange client can only be used by one operator in a driver. VELOX_CHECK_NOT_NULL(exchangeClient); - operators.push_back(std::make_unique( - id, ctx.get(), exchangeNode, std::move(exchangeClient))); + operators.push_back( + std::make_unique( + id, ctx.get(), exchangeNode, std::move(exchangeClient))); } else if ( auto partitionedOutputNode = std::dynamic_pointer_cast( planNode)) { - operators.push_back(std::make_unique( - id, ctx.get(), partitionedOutputNode, eagerFlush(*planNode))); + operators.push_back( + std::make_unique( + id, ctx.get(), partitionedOutputNode, eagerFlush(*planNode))); } else if ( auto joinNode = std::dynamic_pointer_cast(planNode)) { @@ -548,6 +587,11 @@ std::shared_ptr DriverFactory::createDriver( planNode)) { operators.push_back( std::make_unique(id, ctx.get(), joinNode)); + } else if ( + auto spatialJoinNode = + std::dynamic_pointer_cast(planNode)) { + operators.push_back( + std::make_unique(id, ctx.get(), spatialJoinNode)); } else if ( auto joinNode = std::dynamic_pointer_cast( @@ -558,8 +602,9 @@ std::shared_ptr DriverFactory::createDriver( auto aggregationNode = std::dynamic_pointer_cast(planNode)) { if (aggregationNode->isPreGrouped()) { - operators.push_back(std::make_unique( - id, ctx.get(), aggregationNode)); + operators.push_back( + std::make_unique( + id, ctx.get(), aggregationNode)); } else { operators.push_back( std::make_unique(id, ctx.get(), aggregationNode)); @@ -622,12 +667,13 @@ std::shared_ptr DriverFactory::createDriver( auto localPartitionNode = std::dynamic_pointer_cast( planNode)) { - operators.push_back(std::make_unique( - id, - ctx.get(), - localPartitionNode->outputType(), - localPartitionNode->id(), - ctx->partitionId)); + operators.push_back( + std::make_unique( + id, + ctx.get(), + localPartitionNode->outputType(), + localPartitionNode->id(), + ctx->partitionId)); } else if ( auto unnest = std::dynamic_pointer_cast(planNode)) { @@ -642,17 +688,19 @@ std::shared_ptr DriverFactory::createDriver( auto assignUniqueIdNode = std::dynamic_pointer_cast( planNode)) { - operators.push_back(std::make_unique( - id, - ctx.get(), - assignUniqueIdNode, - assignUniqueIdNode->taskUniqueId(), - assignUniqueIdNode->uniqueIdCounter())); + operators.push_back( + std::make_unique( + id, + ctx.get(), + assignUniqueIdNode, + assignUniqueIdNode->taskUniqueId(), + assignUniqueIdNode->uniqueIdCounter())); } else if ( const auto traceScanNode = std::dynamic_pointer_cast(planNode)) { - operators.push_back(std::make_unique( - id, ctx.get(), traceScanNode)); + operators.push_back( + std::make_unique( + id, ctx.get(), traceScanNode)); } else { std::unique_ptr extended; if (planNode->requiresExchangeClient()) { @@ -672,6 +720,11 @@ std::shared_ptr DriverFactory::createDriver( operators.push_back(operatorSupplier(operators.size(), ctx.get())); } + if (filters->empty()) { + filters->resize(operators.size()); + } else { + VELOX_CHECK_EQ(filters->size(), operators.size()); + } driver->init(std::move(ctx), std::move(operators)); for (auto& adapter : adapters) { if (adapter.adapt(*this, *driver)) { @@ -679,6 +732,7 @@ std::shared_ptr DriverFactory::createDriver( } } driver->isAdaptable_ = false; + driver->pushdownFilters_ = std::move(filters); return driver; } @@ -762,6 +816,19 @@ std::vector DriverFactory::needsNestedLoopJoinBridges() return planNodeIds; } +std::vector DriverFactory::needsSpatialJoinBridges() const { + std::vector planNodeIds; + for (const auto& planNode : planNodes) { + if (auto joinNode = + std::dynamic_pointer_cast(planNode)) { + // Grouped execution pipelines should not create cross-mode bridges. + planNodeIds.emplace_back(joinNode->id()); + } + } + + return planNodeIds; +} + // static void DriverFactory::registerAdapter(DriverAdapter adapter) { adapters.push_back(std::move(adapter)); diff --git a/velox/exec/MemoryReclaimer.cpp b/velox/exec/MemoryReclaimer.cpp index 4aee429883db..1837da7881b9 100644 --- a/velox/exec/MemoryReclaimer.cpp +++ b/velox/exec/MemoryReclaimer.cpp @@ -108,9 +108,10 @@ uint64_t ParallelMemoryReclaimer::reclaim( if (!reclaimableBytesOpt.has_value()) { continue; } - candidates.push_back(Candidate{ - std::move(child), - static_cast(reclaimableBytesOpt.value())}); + candidates.push_back( + Candidate{ + std::move(child), + static_cast(reclaimableBytesOpt.value())}); } } } @@ -134,21 +135,23 @@ uint64_t ParallelMemoryReclaimer::reclaim( if (candidate.reclaimableBytes == 0) { continue; } - reclaimTasks.push_back(memory::createAsyncMemoryReclaimTask( - [&, reclaimPool = candidate.pool]() { - try { - Stats reclaimStats; - const auto bytes = - reclaimPool->reclaim(targetBytes, maxWaitMs, reclaimStats); - return std::make_unique( - bytes, std::move(reclaimStats)); - } catch (const std::exception& e) { - VELOX_MEM_LOG(ERROR) << "Reclaim from memory pool " << pool->name() - << " failed: " << e.what(); - // The exception is captured and thrown by the caller. - return std::make_unique(std::current_exception()); - } - })); + reclaimTasks.push_back( + memory::createAsyncMemoryReclaimTask( + [&, reclaimPool = candidate.pool]() { + try { + Stats reclaimStats; + const auto bytes = + reclaimPool->reclaim(targetBytes, maxWaitMs, reclaimStats); + return std::make_unique( + bytes, std::move(reclaimStats)); + } catch (const std::exception& e) { + VELOX_MEM_LOG(ERROR) << "Reclaim from memory pool " + << pool->name() << " failed: " << e.what(); + // The exception is captured and thrown by the caller. + return std::make_unique( + std::current_exception()); + } + })); if (reclaimTasks.size() > 1) { executor_->add([source = reclaimTasks.back()]() { source->prepare(); }); } diff --git a/velox/exec/Merge.cpp b/velox/exec/Merge.cpp index 51b46fdcaca7..fda9deb44cc0 100644 --- a/velox/exec/Merge.cpp +++ b/velox/exec/Merge.cpp @@ -15,25 +15,15 @@ */ #include "velox/exec/Merge.h" +#include +#include #include "velox/common/testutil/TestValue.h" +#include "velox/exec/OperatorUtils.h" #include "velox/exec/Task.h" using facebook::velox::common::testutil::TestValue; namespace facebook::velox::exec { -namespace { -std::unique_ptr getVectorSerdeOptions( - const core::QueryConfig& queryConfig, - VectorSerde::Kind kind) { - std::unique_ptr options = - kind == VectorSerde::Kind::kPresto - ? std::make_unique() - : std::make_unique(); - options->compressionKind = - common::stringToCompressionKind(queryConfig.shuffleCompressionKind()); - return options; -} -} // namespace Merge::Merge( int32_t operatorId, @@ -43,47 +33,42 @@ Merge::Merge( sortingKeys, const std::vector& sortingOrders, const std::string& planNodeId, - const std::string& operatorType) + const std::string& operatorType, + const std::optional& spillConfig) : SourceOperator( driverCtx, std::move(outputType), operatorId, planNodeId, - operatorType), - outputBatchSize_{outputBatchRows()} { - auto numKeys = sortingKeys.size(); - sortingKeys_.reserve(numKeys); - for (int i = 0; i < numKeys; ++i) { - auto channel = exprToChannel(sortingKeys[i].get(), outputType_); - VELOX_CHECK_NE( - channel, - kConstantChannel, - "Merge doesn't allow constant grouping keys"); - sortingKeys_.emplace_back( - channel, - CompareFlags{ - sortingOrders[i].isNullsFirst(), - sortingOrders[i].isAscending(), - false}); - } -} - -void Merge::initializeTreeOfLosers() { - std::vector> sourceCursors; - sourceCursors.reserve(sources_.size()); - for (auto& source : sources_) { - sourceCursors.push_back(std::make_unique( - source.get(), sortingKeys_, outputBatchSize_)); - } - - // Save the pointers to cursors before moving these into the TreeOfLosers. - streams_.reserve(sources_.size()); - for (auto& cursor : sourceCursors) { - streams_.push_back(cursor.get()); - } - - treeOfLosers_ = - std::make_unique>(std::move(sourceCursors)); + operatorType, + spillConfig), + maxOutputBatchRows_{outputBatchRows()}, + maxOutputBatchBytes_{ + driverCtx->queryConfig().preferredOutputBatchBytes()}, + sortingKeys_([&]() { + auto numKeys = sortingKeys.size(); + std::vector keys; + keys.reserve(numKeys); + for (int i = 0; i < numKeys; ++i) { + auto channel = exprToChannel(sortingKeys[i].get(), outputType_); + VELOX_CHECK_NE( + channel, + kConstantChannel, + "Merge doesn't allow constant grouping keys"); + keys.emplace_back( + channel, + CompareFlags{ + sortingOrders[i].isNullsFirst(), + sortingOrders[i].isAscending(), + false}); + } + return keys; + }()) {} + +void Merge::initialize() { + Operator::initialize(); + VELOX_CHECK_EQ(mergeStats_.streamingSourceReadStartTimeUs, 0); + mergeStats_.streamingSourceReadStartTimeUs = getCurrentTimeMicro(); } BlockingReason Merge::isBlocked(ContinueFuture* future) { @@ -101,17 +86,10 @@ BlockingReason Merge::isBlocked(ContinueFuture* future) { return BlockingReason::kNotBlocked; } - startSources(); + maybeStartNextMergeSourceGroup(); - // No merging is needed if there is only one source. - if (streams_.empty() && sources_.size() > 1) { - initializeTreeOfLosers(); - } - - if (sourceBlockingFutures_.empty()) { - for (auto& cursor : streams_) { - cursor->isBlocked(sourceBlockingFutures_); - } + if (sourceMerger_ != nullptr) { + sourceMerger_->isBlocked(sourceBlockingFutures_); } if (sourceBlockingFutures_.empty()) { @@ -123,25 +101,171 @@ BlockingReason Merge::isBlocked(ContinueFuture* future) { return BlockingReason::kWaitForProducer; } -void Merge::startSources() { - VELOX_CHECK_LE(numStartedSources_, sources_.size()); - // Start the merge source once. - if (numStartedSources_ >= sources_.size()) { +bool Merge::isFinished() { + return finished_; +} + +void Merge::maybeSetupOutputSpiller() { + VELOX_CHECK(canSpill()); + VELOX_CHECK(spillConfig_.has_value()); + if (mergeOutputSpiller_ != nullptr) { return; } - VELOX_CHECK_EQ(numStartedSources_, 0); - VELOX_CHECK(streams_.empty()); - VELOX_CHECK(sourceBlockingFutures_.empty()); - // TODO: support lazy start for local merge with a large number of sources - // to cap the memory usage. - for (auto& source : sources_) { + + mergeOutputSpiller_ = std::make_unique( + outputType_, + std::nullopt, + HashBitRange{}, + sortingKeys_, + &spillConfig_.value(), + spillStats_.get()); +} + +void Merge::spill() { + if (output_ == nullptr) { + return; + } + maybeSetupOutputSpiller(); + numSpilledRows_ += output_->size(); + mergeOutputSpiller_->spill(SpillPartitionId{0}, output_); + output_ = nullptr; +} + +void Merge::finishMergeSourceGroup() { + sourceMerger_ = nullptr; + if (mergeOutputSpiller_ == nullptr) { + return; + } + VELOX_CHECK(needSpill()); + VELOX_CHECK_GT(numSpilledRows_, 0); + // Finishes spill if it has happened and setup spill merger if no more source + // to merge. + SpillPartitionSet spillPartitionSet; + mergeOutputSpiller_->finishSpill(spillPartitionSet); + mergeOutputSpiller_ = nullptr; + VELOX_CHECK_EQ(spillPartitionSet.size(), 1); + auto spillFiles = spillPartitionSet.begin()->second->files(); + VELOX_CHECK(!spillFiles.empty()); + spillFileGroups_.push_back(std::move(spillFiles)); +} + +void Merge::setupSpillMerger() { + VELOX_CHECK(!spillFileGroups_.empty()); + VELOX_CHECK_NULL(spillMerger_); + VELOX_CHECK(spillConfig_.has_value()); + std::vector>> spillReadFilesGroups; + spillReadFilesGroups.reserve(spillFileGroups_.size()); + for (const auto& spillFiles : spillFileGroups_) { + std::vector> spillReadFiles; + spillReadFiles.reserve(spillFiles.size()); + for (const auto& spillFile : spillFiles) { + spillReadFiles.emplace_back( + SpillReadFile::create( + spillFile, + spillConfig_->readBufferSize, + pool(), + spillStats_.get())); + } + spillReadFilesGroups.push_back(std::move(spillReadFiles)); + } + spillFileGroups_.clear(); + spillMerger_ = std::make_shared( + sortingKeys_, + outputType_, + std::move(spillReadFilesGroups), + maxOutputBatchRows_, + maxOutputBatchBytes_, + operatorCtx_->driverCtx()->queryConfig().localMergeSourceQueueSize(), + &spillConfig_.value(), + spillStats_, + pool()); + spillMerger_->start(); +} + +void Merge::maybeStartNextMergeSourceGroup() { + if (sourceMerger_ != nullptr || numStartedSources_ >= sources_.size()) { + return; + } + + // Gets the merge sources for the next partial merge run. + std::vector sources; + for (auto i = numStartedSources_; i < + (std::min(sources_.size(), numStartedSources_ + maxNumMergeSources_)); + ++i) { + sources.push_back(sources_[i].get()); + } + + // Initializes the source merger. + std::vector> cursors; + cursors.reserve(sources.size()); + for (auto* source : sources) { + cursors.push_back( + std::make_unique( + source, sortingKeys_, maxOutputBatchRows_)); + } + + // TODO: consider to provide a config other than the regular operator batch + // size to tune the batch size of the streaming source merge output as the + // merge operator is single threaded. + sourceMerger_ = std::make_unique( + outputType_, + std::move(cursors), + maxOutputBatchRows_, + maxOutputBatchBytes_, + pool()); + // Start sources. + for (const auto& source : sources) { source->start(); } - numStartedSources_ = sources_.size(); + numStartedSources_ += sources.size(); } -bool Merge::isFinished() { - return finished_; +RowVectorPtr Merge::getOutputFromSpill() { + VELOX_CHECK_NOT_NULL(spillMerger_); + VELOX_CHECK_NULL(sourceMerger_); + bool atEnd = false; + output_ = spillMerger_->getOutput(sourceBlockingFutures_, atEnd); + SCOPE_EXIT { + if (!atEnd) { + return; + } + finished_ = true; + VELOX_CHECK_EQ(mergeStats_.spilledSourceReadEndTimeUs, 0); + mergeStats_.spilledSourceReadEndTimeUs = getCurrentTimeMicro(); + }; + return std::move(output_); +} + +RowVectorPtr Merge::getOutputFromSource() { + VELOX_CHECK_NULL(spillMerger_); + bool atEnd = false; + output_ = sourceMerger_->getOutput(sourceBlockingFutures_, atEnd); + if (needSpill()) { + spill(); + VELOX_CHECK_NULL(output_); + } + + if (!atEnd) { + return std::move(output_); + } + + finishMergeSourceGroup(); + if (numStartedSources_ < sources_.size()) { + VELOX_CHECK_NULL(output_); + return nullptr; + } + + VELOX_CHECK_EQ(mergeStats_.streamingSourceReadEndTimeUs, 0); + mergeStats_.streamingSourceReadEndTimeUs = getCurrentTimeMicro(); + + if (numSpilledRows_ > 0) { + setupSpillMerger(); + VELOX_CHECK_NULL(output_); + return nullptr; + } + + finished_ = true; + return std::move(output_); } RowVectorPtr Merge::getOutput() { @@ -149,78 +273,167 @@ RowVectorPtr Merge::getOutput() { return nullptr; } - VELOX_CHECK_EQ(numStartedSources_, sources_.size()); + // Read from spill. + if (spillMerger_ != nullptr) { + return getOutputFromSpill(); + } - // No merging is needed if there is only one source. - if (sources_.size() == 1) { - ContinueFuture future; - RowVectorPtr data; - auto reason = sources_[0]->next(data, &future); - if (reason != BlockingReason::kNotBlocked) { - sourceBlockingFutures_.emplace_back(std::move(future)); - return nullptr; + return getOutputFromSource(); +} + +void Merge::close() { + recordMergeStats(); + for (auto& source : sources_) { + source->close(); + } + Operator::close(); +} + +void Merge::recordMergeStats() { + auto lockedStats = stats_.wlock(); + if (mergeStats_.streamingSourceReadEndTimeUs > 0) { + VELOX_CHECK_GT(mergeStats_.streamingSourceReadStartTimeUs, 0); + VELOX_CHECK_GE( + mergeStats_.streamingSourceReadEndTimeUs, + mergeStats_.streamingSourceReadStartTimeUs); + lockedStats->addRuntimeStat( + kStreamingSourceReadWallNanos, + RuntimeCounter( + (mergeStats_.streamingSourceReadEndTimeUs - + mergeStats_.streamingSourceReadStartTimeUs) * + 1'000, + RuntimeCounter::Unit::kNanos)); + } + if (mergeStats_.spilledSourceReadEndTimeUs > 0) { + VELOX_CHECK_GT(mergeStats_.streamingSourceReadEndTimeUs, 0); + VELOX_CHECK_GE( + mergeStats_.spilledSourceReadEndTimeUs, + mergeStats_.streamingSourceReadEndTimeUs); + VELOX_CHECK_GT(numSpilledRows_, 0); + lockedStats->addRuntimeStat( + kSpilledSourceReadWallNanos, + RuntimeCounter( + (mergeStats_.spilledSourceReadEndTimeUs - + mergeStats_.streamingSourceReadEndTimeUs) * + 1'000, + RuntimeCounter::Unit::kNanos)); + } +} + +SourceMerger::SourceMerger( + const RowTypePtr& type, + std::vector> sourceStreams, + vector_size_t maxOutputBatchRows, + uint64_t maxOutputBatchBytes, + velox::memory::MemoryPool* pool) + : type_(type), + maxOutputBatchRows_(maxOutputBatchRows), + maxOutputBatchBytes_(maxOutputBatchBytes), + streams_([&sourceStreams]() { + std::vector streams; + for (auto& cursor : sourceStreams) { + streams.push_back(cursor.get()); + } + return streams; + }()), + merger_( + std::make_unique>( + std::move(sourceStreams))), + pool_(pool) {} + +void SourceMerger::isBlocked( + std::vector& sourceBlockingFutures) const { + if (sourceBlockingFutures.empty()) { + for (auto* stream : streams_) { + stream->isBlocked(sourceBlockingFutures); } + } +} - finished_ = data == nullptr; - return data; +void SourceMerger::setOutputBatchSize() { + if (outputBatchRows_ != 0) { + return; + } + size_t numEstimations{0}; + int64_t estimateRowSizeSum{0}; + for (auto* stream : streams_) { + const auto estimateRowSize = stream->estimateRowSize(); + if (estimateRowSize.has_value()) { + ++numEstimations; + estimateRowSizeSum += estimateRowSize.value(); + } } + if (numEstimations == 0) { + outputBatchRows_ = maxOutputBatchRows_; + return; + } + + const auto estimateRowSize = + std::max(1, estimateRowSizeSum / numEstimations); + outputBatchRows_ = std::min( + std::max(1, maxOutputBatchBytes_ / estimateRowSize), + maxOutputBatchRows_); +} + +RowVectorPtr SourceMerger::getOutput( + std::vector& sourceBlockingFutures, + bool& atEnd) { + VELOX_CHECK_NOT_NULL(merger_); + atEnd = false; + setOutputBatchSize(); + VELOX_CHECK_GT(outputBatchRows_, 0); + if (!output_) { - output_ = BaseVector::create( - outputType_, outputBatchSize_, operatorCtx_->pool()); + output_ = BaseVector::create(type_, outputBatchRows_, pool_); for (auto& child : output_->children()) { - child->resize(outputBatchSize_); + child->resize(outputBatchRows_); } } for (;;) { - auto stream = treeOfLosers_->next(); + auto stream = merger_->next(); if (!stream) { - finished_ = true; + atEnd = true; // Return nullptr if there is no data. - if (outputSize_ == 0) { + if (outputRows_ == 0) { return nullptr; } - - output_->resize(outputSize_); + output_->resize(outputRows_); return std::move(output_); } - if (stream->setOutputRow(outputSize_)) { + if (stream->setOutputRow(outputRows_)) { // The stream is at end of input batch. Need to copy out the rows before // fetching next batch in 'pop'. stream->copyToOutput(output_); + TestValue::adjust( + "facebook::velox::exec::SourceMerger::getOutput", + &sourceBlockingFutures); } - ++outputSize_; + ++outputRows_; // Advance the stream. - stream->pop(sourceBlockingFutures_); + stream->pop(sourceBlockingFutures); - if (outputSize_ == outputBatchSize_) { + if (outputRows_ == outputBatchRows_) { // Copy out data from all sources. for (auto& s : streams_) { s->copyToOutput(output_); } - - outputSize_ = 0; + outputRows_ = 0; return std::move(output_); } - if (!sourceBlockingFutures_.empty()) { + if (!sourceBlockingFutures.empty()) { return nullptr; } } } -void Merge::close() { - for (auto& source : sources_) { - source->close(); - } -} - bool SourceStream::operator<(const MergeStream& other) const { const auto& otherCursor = static_cast(other); for (auto i = 0; i < sortingKeys_.size(); ++i) { @@ -301,6 +514,205 @@ bool SourceStream::fetchMoreData(std::vector& futures) { return false; } +SpillMerger::SpillMerger( + const std::vector& sortingKeys, + const RowTypePtr& type, + std::vector>> + spillReadFilesGroup, + vector_size_t maxOutputBatchRows, + uint64_t maxOutputBatchBytes, + int mergeSourceQueueSize, + const common::SpillConfig* spillConfig, + const std::shared_ptr>& spillStats, + velox::memory::MemoryPool* pool) + : executor_(spillConfig->executor), + spillStats_(spillStats), + pool_(pool->shared_from_this()), + sources_( + createMergeSources(spillReadFilesGroup.size(), mergeSourceQueueSize)), + batchStreams_(createBatchStreams(std::move(spillReadFilesGroup))), + // TODO: consider to provide a config other than the regular operator + // batch size to tune the batch size of the spilled source merge output as + // the merge operator is single threaded. + sourceMerger_(createSourceMerger( + sortingKeys, + type, + sources_, + maxOutputBatchRows, + maxOutputBatchBytes, + pool)) {} + +SpillMerger::~SpillMerger() { + sourceMerger_.reset(); + batchStreams_.clear(); + sources_.clear(); +} + +void SpillMerger::start() { + VELOX_CHECK_NOT_NULL( + executor_, + "SpillMerge require configure executor to run async spill file stream producer"); + scheduleAsyncSpillFileStreamReads(); +} + +RowVectorPtr SpillMerger::getOutput( + std::vector& sourceBlockingFutures, + bool& atEnd) { + TestValue::adjust( + "facebook::velox::exec::SpillMerger::getOutput", &sourceBlockingFutures); + sourceMerger_->isBlocked(sourceBlockingFutures); + if (!sourceBlockingFutures.empty()) { + return nullptr; + } + // SpillMerger::getOutput waits for all readers to finish, reaches EOF, + // and rethrows any captured error. Centralizing error propagation here + // helps prevent potential resource leaks. + auto output = sourceMerger_->getOutput(sourceBlockingFutures, atEnd); + if (atEnd) { + checkError(); + } + return output; +} + +std::vector> SpillMerger::createMergeSources( + size_t numSpillSources, + int queueSize) { + std::vector> sources; + sources.reserve(numSpillSources); + for (auto i = 0; i < numSpillSources; ++i) { + sources.push_back(MergeSource::createLocalMergeSource(queueSize)); + } + for (const auto& source : sources) { + source->start(); + } + return sources; +} + +std::vector> SpillMerger::createBatchStreams( + std::vector>> + spillReadFilesGroup) { + const auto numStreams = spillReadFilesGroup.size(); + std::vector> batchStreams; + batchStreams.reserve(numStreams); + for (auto i = 0; i < numStreams; ++i) { + batchStreams.emplace_back( + ConcatFilesSpillBatchStream::create(std::move(spillReadFilesGroup[i]))); + } + return batchStreams; +} + +std::unique_ptr SpillMerger::createSourceMerger( + const std::vector& sortingKeys, + const RowTypePtr& type, + const std::vector>& sources, + vector_size_t maxOutputBatchRows, + uint64_t maxOutputBatchBytes, + velox::memory::MemoryPool* pool) { + std::vector> streams; + streams.reserve(sources.size()); + for (const auto& source : sources) { + streams.push_back( + std::make_unique( + source.get(), sortingKeys, maxOutputBatchRows)); + } + return std::make_unique( + type, std::move(streams), maxOutputBatchRows, maxOutputBatchBytes, pool); +} + +void SpillMerger::finishSource(size_t streamIdx) const { + ContinueFuture future{ContinueFuture::makeEmpty()}; + sources_[streamIdx]->enqueue(nullptr, &future); + VELOX_CHECK(!future.valid()); +} + +void SpillMerger::readFromSpillFileStream( + const std::weak_ptr& mergeHolder, + size_t streamIdx) { + TestValue::adjust( + "facebook::velox::exec::SpillMerger::readFromSpillFileStream", nullptr); + const auto merger = mergeHolder.lock(); + if (merger == nullptr) { + LOG(ERROR) << "SpillMerger is destroyed, abandon reading from batch stream"; + return; + } + + try { + if (hasError()) { + finishSource(streamIdx); + return; + } + + RowVectorPtr vector; + if (!batchStreams_[streamIdx]->nextBatch(vector)) { + VELOX_CHECK_NULL(vector); + finishSource(streamIdx); + return; + } + + ContinueFuture future{ContinueFuture::makeEmpty()}; + const auto blockingReason = + sources_[streamIdx]->enqueue(std::move(vector), &future); + if (blockingReason == BlockingReason::kNotBlocked) { + VELOX_CHECK(!future.valid()); + readFromSpillFileStream(mergeHolder, streamIdx); + } else { + VELOX_CHECK(future.valid()); + std::move(future) + .via(executor_) + .thenValue([this, mergeHolder, streamIdx](auto&&) { + readFromSpillFileStream(mergeHolder, streamIdx); + }) + .thenError( + folly::tag_t{}, + [this, mergeHolder, streamIdx](const std::exception& e) { + const auto merger = mergeHolder.lock(); + if (merger != nullptr) { + LOG(ERROR) << "Stop the " << streamIdx + << " th source on error: " << e.what(); + setError(std::make_exception_ptr(e)); + finishSource(streamIdx); + } + }); + } + } catch (const std::exception& e) { + LOG(ERROR) << "The " << streamIdx + << " spill stream failed with error: " << e.what(); + setError(std::current_exception()); + finishSource(streamIdx); + } +} + +void SpillMerger::scheduleAsyncSpillFileStreamReads() { + VELOX_CHECK_EQ(batchStreams_.size(), sources_.size()); + for (auto i = 0; i < batchStreams_.size(); ++i) { + executor_->add([&, streamIdx = i]() { + readFromSpillFileStream(std::weak_ptr(shared_from_this()), streamIdx); + }); + } +} + +void SpillMerger::setError(const std::exception_ptr& exception) { + std::lock_guard l(mutex_); + if (exception_ != nullptr) { + return; + } + exception_ = exception; +} + +bool SpillMerger::hasError() const { + std::lock_guard l(mutex_); + return exception_ != nullptr; +} + +void SpillMerger::checkError() { + if (hasError()) { + sourceMerger_.reset(); + batchStreams_.clear(); + sources_.clear(); + std::rethrow_exception(exception_); + } +} + LocalMerge::LocalMerge( int32_t operatorId, DriverCtx* driverCtx, @@ -312,11 +724,22 @@ LocalMerge::LocalMerge( localMergeNode->sortingKeys(), localMergeNode->sortingOrders(), localMergeNode->id(), - "LocalMerge") { + "LocalMerge", + localMergeNode->canSpill(driverCtx->queryConfig()) + ? driverCtx->makeSpillConfig(operatorId) + : std::nullopt) { VELOX_CHECK_EQ( operatorCtx_->driverCtx()->driverId, 0, "LocalMerge needs to run single-threaded"); + // Enable local merge spill iff spill is enabled and the spill executor is + // provided. + if (spillConfig_.has_value() && spillConfig_->executor != nullptr) { + maxNumMergeSources_ = operatorCtx_->task() + ->queryCtx() + ->queryConfig() + .localMergeMaxNumMergeSources(); + } } BlockingReason LocalMerge::addMergeSources(ContinueFuture* /* future */) { @@ -341,7 +764,8 @@ MergeExchange::MergeExchange( "MergeExchange"), serde_(getNamedVectorSerde(mergeExchangeNode->serdeKind())), serdeOptions_(getVectorSerdeOptions( - driverCtx->queryConfig(), + common::stringToCompressionKind( + driverCtx->queryConfig().shuffleCompressionKind()), mergeExchangeNode->serdeKind())) {} BlockingReason MergeExchange::addMergeSources(ContinueFuture* future) { @@ -386,13 +810,14 @@ BlockingReason MergeExchange::addMergeSources(ContinueFuture* future) { operatorCtx_->planNodeId(), operatorCtx_->driverCtx()->pipelineId, remoteSourceIndex); - sources_.emplace_back(MergeSource::createMergeExchangeSource( - this, - remoteSourceTaskIds_[remoteSourceIndex], - operatorCtx_->task()->destination(), - maxQueuedBytesPerSource, - pool, - operatorCtx_->task()->queryCtx()->executor())); + sources_.emplace_back( + MergeSource::createMergeExchangeSource( + this, + remoteSourceTaskIds_[remoteSourceIndex], + operatorCtx_->task()->destination(), + maxQueuedBytesPerSource, + pool, + operatorCtx_->task()->queryCtx()->executor())); } } // TODO Delay this call until all input data has been processed. diff --git a/velox/exec/Merge.h b/velox/exec/Merge.h index 841988d5f799..b688b42cf268 100644 --- a/velox/exec/Merge.h +++ b/velox/exec/Merge.h @@ -15,14 +15,17 @@ */ #pragma once +#include "velox/common/base/TreeOfLosers.h" #include "velox/exec/Exchange.h" #include "velox/exec/MergeSource.h" #include "velox/exec/Spill.h" -#include "velox/exec/TreeOfLosers.h" +#include "velox/exec/Spiller.h" namespace facebook::velox::exec { class SourceStream; +class SourceMerger; +class SpillMerger; // Merge operator Implementation: This implementation uses priority queue // to perform a k-way merge of its inputs. It stops merging if any one of @@ -37,7 +40,10 @@ class Merge : public SourceOperator { sortingKeys, const std::vector& sortingOrders, const std::string& planNodeId, - const std::string& operatorType); + const std::string& operatorType, + const std::optional& spillConfig = std::nullopt); + + void initialize() override; BlockingReason isBlocked(ContinueFuture* future) override; @@ -51,39 +57,132 @@ class Merge : public SourceOperator { return outputType_; } + /// The name of runtime stats specific to merge. + /// The running wall time of the merge operator reading from the streaming + /// source. If spilling is enabled for local merge, this also includes the + /// time that writes to the spilled source. + static inline const std::string kStreamingSourceReadWallNanos{ + "streamingSourceReadWallNanos"}; + /// The running wall time of the merge operator reading from the spilled + /// source to produce the final output. This only applies when spilling is + /// enabled for local merge. + static inline const std::string kSpilledSourceReadWallNanos{ + "spilledSourceReadWallNanos"}; + protected: virtual BlockingReason addMergeSources(ContinueFuture* future) = 0; std::vector> sources_; size_t numStartedSources_{0}; + /// Maximum number of merge sources per run. + uint32_t maxNumMergeSources_{std::numeric_limits::max()}; private: - void startSources(); + // Tracks the internal execution stats for a merge operator. + struct Stats { + // The time point that a merge operator starts reading from the streaming + // source. + uint64_t streamingSourceReadStartTimeUs{0}; + // The time point that a merge operator finishes read from the streaming + // source. This includes the time that writes to the spilled source for + // recursive merge when spilling is enabled for local merge. + uint64_t streamingSourceReadEndTimeUs{0}; + // The time point that a merge operator finishes read from the spilled + // source. This only applies when spilling is enabled for local merge. + uint64_t spilledSourceReadEndTimeUs{0}; + }; + void recordMergeStats(); + + // Start sources for this merge run, it may start either all the sources at + // once or a portion of the sources at a time to cap the memory usage. + void maybeStartNextMergeSourceGroup(); + + // Returns true if needs to spill the merged source output if all sources can + // not be merged at once. + bool needSpill() const { + return maxNumMergeSources_ < sources_.size(); + } - void initializeTreeOfLosers(); + void maybeSetupOutputSpiller(); - /// Maximum number of rows in the output batch. - const vector_size_t outputBatchSize_; + // Spill the output of a partial merge sources. + void spill(); - std::vector sortingKeys_; + // Invoked at the end for each partial merge run to ensure the order within + // each spill file. + void finishMergeSourceGroup(); - /// A list of cursors over batches of ordered source data. One per source. - /// Aligned with 'sources'. - std::vector streams_; + // Create spillMerger_ exactly once if spill has happened. + void setupSpillMerger(); - /// Used to merge data from two or more sources. - std::unique_ptr> treeOfLosers_; + RowVectorPtr getOutputFromSpill(); - RowVectorPtr output_; + RowVectorPtr getOutputFromSource(); + + // Maximum number of rows in the output batch. + const vector_size_t maxOutputBatchRows_; + // Maximum number of bytes in the output batch. + const uint64_t maxOutputBatchBytes_; + const std::vector sortingKeys_; + + Stats mergeStats_; + RowVectorPtr output_; /// Number of rows accumulated in 'output_' so far. vector_size_t outputSize_{0}; - bool finished_{false}; /// A list of blocking futures for sources. These are populates when a given /// source is blocked waiting for the next batch of data. std::vector sourceBlockingFutures_; + + std::unique_ptr sourceMerger_; + std::shared_ptr spillMerger_; + std::unique_ptr mergeOutputSpiller_; + // Number of total spilled rows, it must be equal to the input rows. + uint64_t numSpilledRows_{0}; + // SpillFiles group for all the partial merge runs. + std::vector spillFileGroups_; +}; + +/// A utility class for sort-merging data from upstream sources of the +/// `LocalMerge` operator. The `LocalMerge` operator may start only a portion of +/// the sources at a time to cap the memory usage, hence it might perform +/// multiple sort-merge operations with a subset of merge sources. +class SourceMerger { + public: + SourceMerger( + const RowTypePtr& type, + std::vector> sourceStreams, + vector_size_t maxOutputBatchRows, + uint64_t maxOutputBatchBytes, + velox::memory::MemoryPool* pool); + + void isBlocked(std::vector& sourceBlockingFutures) const; + + RowVectorPtr getOutput( + std::vector& sourceBlockingFutures, + bool& atEnd); + + private: + void setOutputBatchSize(); + + const RowTypePtr type_; + const vector_size_t maxOutputBatchRows_; + const uint64_t maxOutputBatchBytes_; + const std::vector streams_; + const std::unique_ptr> merger_; + velox::memory::MemoryPool* const pool_; + + // The max number of rows in an output vector which is determined by + // 'setOutputBatchSize'. The calculation is based on the actual estimated row + // size and capped by 'maxOutputBatchRows_' and 'maxOutputBatchBytes_'. + vector_size_t outputBatchRows_{0}; + + // Reusable output vector. + RowVectorPtr output_; + // The number of rows in 'output_' vector. + uint64_t outputRows_{0}; }; class SourceStream final : public MergeStream { @@ -112,6 +211,15 @@ class SourceStream final : public MergeStream { return !atEnd_; } + // Returns the estimated row size based on the vector received from the + // merge source. + std::optional estimateRowSize() const { + if (data_ == nullptr || data_->size() == 0) { + return std::nullopt; + } + return data_->estimateFlatSize() / data_->size(); + } + /// Returns true if current source row is less then current source row in /// 'other'. bool operator<(const MergeStream& other) const override; @@ -171,6 +279,76 @@ class SourceStream final : public MergeStream { std::vector sourceRows_; }; +/// A utility class for sort-merging data from data spilled by the `LocalMerge` +/// operator. +class SpillMerger : public std::enable_shared_from_this { + public: + SpillMerger( + const std::vector& sortingKeys, + const RowTypePtr& type, + std::vector>> + spillReadFilesGroup, + vector_size_t maxOutputBatchRows, + uint64_t maxOutputBatchBytes, + int mergeSourceQueueSize, + const common::SpillConfig* spillConfig, + const std::shared_ptr>& + spillStats, + velox::memory::MemoryPool* pool); + + ~SpillMerger(); + + void start(); + + RowVectorPtr getOutput( + std::vector& sourceBlockingFutures, + bool& atEnd); + + private: + static std::vector> createMergeSources( + size_t numSpillSources, + int queueSize); + + static std::vector> createBatchStreams( + std::vector>> + spillReadFilesGroup); + + static std::unique_ptr createSourceMerger( + const std::vector& sortingKeys, + const RowTypePtr& type, + const std::vector>& sources, + vector_size_t maxOutputBatchRows, + uint64_t maxOutputBatchBytes, + velox::memory::MemoryPool* pool); + + void finishSource(size_t streamIdx) const; + + void readFromSpillFileStream( + const std::weak_ptr& mergeHolder, + size_t streamIdx); + + void scheduleAsyncSpillFileStreamReads(); + + // Sets 'exception_' when an async reader throws. + void setError(const std::exception_ptr& exception); + + // Returns true if any async reader has thrown an exception. + bool hasError() const; + + // If any async reader has thrown an exception, rethrows it. + void checkError(); + + folly::Executor* const executor_; + const std::shared_ptr> spillStats_; + const std::shared_ptr pool_; + + std::vector> sources_; + std::vector> batchStreams_; + std::unique_ptr sourceMerger_; + mutable std::timed_mutex mutex_; + std::exception_ptr exception_ = nullptr; +}; + // LocalMerge merges its source's output into a single stream of // sorted rows. It runs single threaded. The sources may run multi-threaded and // in the same task. diff --git a/velox/exec/MergeJoin.cpp b/velox/exec/MergeJoin.cpp index 550c059cc81f..382448465121 100644 --- a/velox/exec/MergeJoin.cpp +++ b/velox/exec/MergeJoin.cpp @@ -20,6 +20,25 @@ namespace facebook::velox::exec { +namespace { +void copyRow( + const RowVectorPtr& source, + vector_size_t sourceIndex, + const RowVectorPtr& target, + vector_size_t targetIndex, + const std::vector& projections) { + for (const auto& projection : projections) { + const auto& sourceChild = source->childAt(projection.inputChannel); + const auto& targetChild = target->childAt(projection.outputChannel); + targetChild->copy(sourceChild.get(), targetIndex, sourceIndex, 1); + } +} + +bool isSemiFilterJoin(core::JoinType joinType) { + return isLeftSemiFilterJoin(joinType) || isRightSemiFilterJoin(joinType); +} +} // namespace + MergeJoin::MergeJoin( int32_t operatorId, DriverCtx* driverCtx, @@ -36,9 +55,9 @@ MergeJoin::MergeJoin( rightNodeId_{joinNode->sources()[1]->id()}, joinNode_(joinNode) { VELOX_USER_CHECK( - core::MergeJoinNode::isSupported(joinNode_->joinType()), - "The join type is not supported by merge join: ", - joinTypeName(joinNode_->joinType())); + core::MergeJoinNode::isSupported(joinType_), + "The join type is not supported by merge join: {}", + core::JoinTypeName::toName(joinType_)); } void MergeJoin::initialize() { @@ -90,7 +109,8 @@ void MergeJoin::initialize() { initializeFilter(joinNode_->filter(), leftType, rightType); if (joinNode_->isLeftJoin() || joinNode_->isAntiJoin() || - joinNode_->isRightJoin() || joinNode_->isFullJoin()) { + joinNode_->isRightJoin() || joinNode_->isFullJoin() || + isSemiFilterJoin(joinType_)) { joinTracker_ = JoinTracker(outputBatchSize_, pool()); } } else if (joinNode_->isAntiJoin()) { @@ -181,8 +201,8 @@ void MergeJoin::initializeFilter( } BlockingReason MergeJoin::isBlocked(ContinueFuture* future) { - if (futureRightSideInput_.valid()) { - *future = std::move(futureRightSideInput_); + if (rightSideInputFuture_.valid()) { + *future = std::move(rightSideInputFuture_); return BlockingReason::kWaitForMergeJoinRightSide; } @@ -190,16 +210,15 @@ BlockingReason MergeJoin::isBlocked(ContinueFuture* future) { } bool MergeJoin::needsInput() const { - if (isRightJoin(joinType_)) { - return (input_ == nullptr || rightInput_ == nullptr); - } return input_ == nullptr; } void MergeJoin::addInput(RowVectorPtr input) { + VELOX_CHECK_NULL(input_); input_ = std::move(input); + // TODO: support selective lazy loading both sides. + loadColumns(input_, *operatorCtx_->execCtx()); leftRowIndex_ = 0; - if (joinTracker_) { joinTracker_->resetLastVector(); } @@ -207,28 +226,37 @@ void MergeJoin::addInput(RowVectorPtr input) { // static int32_t MergeJoin::compare( - const std::vector& keys, - const RowVectorPtr& batch, - vector_size_t index, - const std::vector& otherKeys, - const RowVectorPtr& otherBatch, - vector_size_t otherIndex) { - for (auto i = 0; i < keys.size(); ++i) { + const std::vector& leftKeys, + const RowVectorPtr& leftBatch, + vector_size_t leftIndex, + const std::vector& rightKeys, + const RowVectorPtr& rightBatch, + vector_size_t rightIndex) { + for (auto i = 0; i < leftKeys.size(); ++i) { static const CompareFlags kCompareFlags = { .equalsOnly = true, .nullHandlingMode = CompareFlags::NullHandlingMode::kNullAsIndeterminate}; - const auto compare = batch->childAt(keys[i])->compare( - otherBatch->childAt(otherKeys[i]).get(), - index, - otherIndex, - kCompareFlags); + const auto compare = leftBatch->childAt(leftKeys[i]) + ->compare( + rightBatch->childAt(rightKeys[i]).get(), + leftIndex, + rightIndex, + kCompareFlags); - // Comparing null with anything will return std::nullopt. if (!compare.has_value()) { - // The SQL semantics of Presto and Spark will always return false if - // comparing a NULL value with any other value. - return -1; + // Under CompareFlags::NullHandlingMode::kNullAsIndeterminate, + // std::nullopt is returned in three cases: + // 1) Both the left key and the right key are null. + // 2) The left key is null, and the right key is not null. + // 3) The left key is not null, and the right key is null. + // + // However, the comparison result semantics differ: + // - Cases (1) and (2): return -1, meaning input_ should catch up with + // rightInput_. + // - Case (3): return 1, indicating the left key is considered greater, + // so rightInput_ should catch up with input_ in the subsequent steps. + return leftBatch->childAt(leftKeys[i])->isNullAt(leftIndex) ? -1 : 1; } else if (compare.value() != 0) { return compare.value(); } @@ -245,9 +273,8 @@ bool MergeJoin::findEndOfMatch( return true; } - auto prevInput = match.inputs.back(); - auto prevIndex = prevInput->size() - 1; - + const auto prevInput = match.inputs.back(); + const auto prevIndex = prevInput->size() - 1; const auto numInputRows = input->size(); vector_size_t endRow = 0; @@ -257,9 +284,6 @@ bool MergeJoin::findEndOfMatch( } if (endRow == numInputRows) { - // Inputs are kept past getting a new batch of inputs. LazyVectors - // must be loaded before advancing to the next batch. - loadColumns(input, *operatorCtx_->execCtx()); match.inputs.push_back(input); match.endRowIndex = endRow; return false; @@ -274,21 +298,6 @@ bool MergeJoin::findEndOfMatch( return true; } -namespace { -void copyRow( - const RowVectorPtr& source, - vector_size_t sourceIndex, - const RowVectorPtr& target, - vector_size_t targetIndex, - const std::vector& projections) { - for (const auto& projection : projections) { - const auto& sourceChild = source->childAt(projection.inputChannel); - const auto& targetChild = target->childAt(projection.outputChannel); - targetChild->copy(sourceChild.get(), targetIndex, sourceIndex, 1); - } -} -} // namespace - inline void addNull( VectorPtr& target, vector_size_t index, @@ -338,17 +347,23 @@ bool MergeJoin::tryAddOutputRowForLeftJoin() { } ++outputSize_; - return true; } bool MergeJoin::tryAddOutputRowForRightJoin() { - VELOX_USER_CHECK(isRightJoin(joinType_) || isFullJoin(joinType_)); + VELOX_CHECK(isRightJoin(joinType_) || isFullJoin(joinType_)); if (outputSize_ == outputBatchSize_) { return false; } - rawRightOutputIndices_[outputSize_] = rightRowIndex_++; + if (!isRightFlattened_) { + // All right side projections share the same dictionary indices + // (rightIndices_). + rawRightOutputIndices_[outputSize_] = rightRowIndex_++; + } else { + copyRow( + rightInput_, rightRowIndex_++, output_, outputSize_, rightProjections_); + } for (const auto& projection : leftProjections_) { auto& target = output_->childAt(projection.outputChannel); @@ -366,7 +381,6 @@ bool MergeJoin::tryAddOutputRowForRightJoin() { } ++outputSize_; - return true; } @@ -421,7 +435,7 @@ bool MergeJoin::tryAddOutputRow( filterRightInputProjections_); if (joinTracker_) { - if (isRightJoin(joinType_)) { + if (isRightJoin(joinType_) || isRightSemiFilterJoin(joinType_)) { // Record right-side row with a match on the left-side. joinTracker_->addMatch(rightBatch, rightRow, outputSize_); } else { @@ -440,7 +454,6 @@ bool MergeJoin::tryAddOutputRow( } ++outputSize_; - return true; } @@ -548,17 +561,16 @@ bool MergeJoin::prepareOutput( if (filter_ != nullptr && filterInput_ == nullptr) { std::vector inputs(filterInputType_->size()); - for (const auto [filterInputChannel, outputChannel] : - filterInputToOutputChannel_) { - inputs[filterInputChannel] = output_->childAt(outputChannel); - } for (auto i = 0; i < filterInputType_->size(); ++i) { if (filterInputToOutputChannel_.find(i) != filterInputToOutputChannel_.end()) { - continue; + inputs[i] = output_->childAt(filterInputToOutputChannel_[i]); + } else { + inputs[i] = BaseVector::create( + filterInputType_->childAt(i), + outputBatchSize_, + operatorCtx_->pool()); } - inputs[i] = BaseVector::create( - filterInputType_->childAt(i), outputBatchSize_, operatorCtx_->pool()); } filterInput_ = std::make_shared( @@ -579,69 +591,98 @@ bool MergeJoin::addToOutput() { } } -bool MergeJoin::addToOutputForLeftJoin() { - size_t firstLeftBatch; - vector_size_t leftStartRowIndex; - if (leftMatch_->cursor) { - firstLeftBatch = leftMatch_->cursor->batchIndex; - leftStartRowIndex = leftMatch_->cursor->rowIndex; +template +bool MergeJoin::addToOutputImpl() { + // For left join: outerMatch=left, innerMatch=right + // For right join: outerMatch=right, innerMatch=left + auto& outerMatch = IsLeftJoin ? leftMatch_ : rightMatch_; + auto& innerMatch = IsLeftJoin ? rightMatch_ : leftMatch_; + + size_t outerFirstBatch; + vector_size_t outerStartRowIndex; + if (outerMatch->cursor) { + outerFirstBatch = outerMatch->cursor->batchIndex; + outerStartRowIndex = outerMatch->cursor->rowIndex; } else { - firstLeftBatch = 0; - leftStartRowIndex = leftMatch_->startRowIndex; + outerFirstBatch = 0; + outerStartRowIndex = outerMatch->startRowIndex; } - const size_t numLeftBatches = leftMatch_->inputs.size(); - for (size_t l = firstLeftBatch; l < numLeftBatches; ++l) { - const auto leftBatch = leftMatch_->inputs[l]; - const auto leftStartRow = l == firstLeftBatch ? leftStartRowIndex : 0; - const auto leftEndRow = - l == numLeftBatches - 1 ? leftMatch_->endRowIndex : leftBatch->size(); - - for (auto i = leftStartRow; i < leftEndRow; ++i) { - const auto firstRightBatch = - (l == firstLeftBatch && i == leftStartRow && rightMatch_->cursor) - ? rightMatch_->cursor->batchIndex + const size_t numOuterBatches = outerMatch->inputs.size(); + for (size_t outerBatchIndex = outerFirstBatch; + outerBatchIndex < numOuterBatches; + ++outerBatchIndex) { + const auto& outerBatch = outerMatch->inputs[outerBatchIndex]; + const auto outerStartRow = + outerBatchIndex == outerFirstBatch ? outerStartRowIndex : 0; + const auto outerEndRow = outerBatchIndex == numOuterBatches - 1 + ? outerMatch->endRowIndex + : outerBatch->size(); + + for (auto outerRow = outerStartRow; outerRow < outerEndRow; ++outerRow) { + const bool outerFirstRow = + outerBatchIndex == outerFirstBatch && outerRow == outerStartRow; + const auto innerFirstBatch = (outerFirstRow && innerMatch->cursor) + ? innerMatch->cursor->batchIndex : 0; - const auto rightStartRowIndex = - (l == firstLeftBatch && i == leftStartRow && rightMatch_->cursor) - ? rightMatch_->cursor->rowIndex - : rightMatch_->startRowIndex; + const auto innerStartRowIndex = (outerFirstRow && innerMatch->cursor) + ? innerMatch->cursor->rowIndex + : innerMatch->startRowIndex; - const auto numRightBatches = rightMatch_->inputs.size(); + const auto numInnerBatches = innerMatch->inputs.size(); // TODO: Since semi joins only require determining if there is at least // one match on the other side, we could explore specialized algorithms // or data structures that short-circuit the join process once a match // is found. - for (size_t r = isLeftSemiFilterJoin(joinType_) ? numRightBatches - 1 - : firstRightBatch; - r < numRightBatches; - ++r) { - const auto rightBatch = rightMatch_->inputs[r]; - auto rightStartRow = r == firstRightBatch ? rightStartRowIndex : 0; - const auto rightEndRow = r == numRightBatches - 1 - ? rightMatch_->endRowIndex - : rightBatch->size(); - if (isLeftSemiFilterJoin(joinType_)) { - rightStartRow = rightEndRow - 1; + // Handle semi-filter join optimization + const bool isSemiFilter = IsLeftJoin + ? (isLeftSemiFilterJoin(joinType_) && !filter_) + : (isRightSemiFilterJoin(joinType_) && !filter_); + + for (size_t innerBatchIndex = isSemiFilter ? numInnerBatches - 1 + : innerFirstBatch; + innerBatchIndex < numInnerBatches; + ++innerBatchIndex) { + const auto& innerBatch = innerMatch->inputs[innerBatchIndex]; + auto innerStartRow = + innerBatchIndex == innerFirstBatch ? innerStartRowIndex : 0; + const auto innerEndRow = innerBatchIndex == numInnerBatches - 1 + ? innerMatch->endRowIndex + : innerBatch->size(); + + if (isSemiFilter) { + innerStartRow = innerEndRow - 1; } + + // Determine the correct order for prepareOutput and tryAddOutputRow + const auto& leftBatch = IsLeftJoin ? outerBatch : innerBatch; + const auto& rightBatch = IsLeftJoin ? innerBatch : outerBatch; + if (prepareOutput(leftBatch, rightBatch)) { output_->resize(outputSize_); - leftMatch_->setCursor(l, i); - rightMatch_->setCursor(r, rightStartRow); + outerMatch->setCursor(outerBatchIndex, outerRow); + innerMatch->setCursor(innerBatchIndex, innerStartRow); return true; } - for (auto j = rightStartRow; j < rightEndRow; ++j) { - if (!tryAddOutputRow(leftBatch, i, rightBatch, j)) { - // If we run out of space in the current output_, we will need to - // produce a buffer and continue processing left later. In this - // case, we cannot leave left as a lazy vector, since we cannot have - // two dictionaries wrapping the same lazy vector. - loadColumns(currentLeft_, *operatorCtx_->execCtx()); - leftMatch_->setCursor(l, i); - rightMatch_->setCursor(r, j); - return true; + if (IsLeftJoin) { + for (auto innerRow = innerStartRow; innerRow < innerEndRow; + ++innerRow) { + if (!tryAddOutputRow(leftBatch, outerRow, rightBatch, innerRow)) { + outerMatch->setCursor(outerBatchIndex, outerRow); + innerMatch->setCursor(innerBatchIndex, innerRow); + return true; + } + } + } else { + for (auto innerRow = innerStartRow; innerRow < innerEndRow; + ++innerRow) { + if (!tryAddOutputRow(leftBatch, innerRow, rightBatch, outerRow)) { + outerMatch->setCursor(outerBatchIndex, outerRow); + innerMatch->setCursor(innerBatchIndex, innerRow); + return true; + } } } } @@ -650,100 +691,15 @@ bool MergeJoin::addToOutputForLeftJoin() { leftMatch_.reset(); rightMatch_.reset(); - - // If the current key match finished, but there are still records to be - // processed in the left, we need to load lazy vectors (see comment above). - if (input_ && leftRowIndex_ != input_->size()) { - loadColumns(currentLeft_, *operatorCtx_->execCtx()); - } return outputSize_ == outputBatchSize_; } -bool MergeJoin::addToOutputForRightJoin() { - size_t firstRightBatch; - vector_size_t rightStartRowIndex; - if (rightMatch_->cursor) { - firstRightBatch = rightMatch_->cursor->batchIndex; - rightStartRowIndex = rightMatch_->cursor->rowIndex; - } else { - firstRightBatch = 0; - rightStartRowIndex = rightMatch_->startRowIndex; - } - - const size_t numRightBatches = rightMatch_->inputs.size(); - for (size_t r = firstRightBatch; r < numRightBatches; ++r) { - const auto rightBatch = rightMatch_->inputs[r]; - const auto rightStartRow = r == firstRightBatch ? rightStartRowIndex : 0; - const auto rightEndRow = r == numRightBatches - 1 ? rightMatch_->endRowIndex - : rightBatch->size(); - - for (auto i = rightStartRow; i < rightEndRow; ++i) { - const auto firstLeftBatch = - (r == firstRightBatch && i == rightStartRow && leftMatch_->cursor) - ? leftMatch_->cursor->batchIndex - : 0; - - const auto leftStartRowIndex = - (r == firstRightBatch && i == rightStartRow && leftMatch_->cursor) - ? leftMatch_->cursor->rowIndex - : leftMatch_->startRowIndex; - - const auto numLeftBatches = leftMatch_->inputs.size(); - // TODO: Since semi joins only require determining if there is at least - // one match on the other side, we could explore specialized algorithms - // or data structures that short-circuit the join process once a match - // is found. - for (size_t l = isRightSemiFilterJoin(joinType_) ? numLeftBatches - 1 - : firstLeftBatch; - l < numLeftBatches; - ++l) { - const auto leftBatch = leftMatch_->inputs[l]; - auto leftStartRow = l == firstLeftBatch ? leftStartRowIndex : 0; - const auto leftEndRow = l == numLeftBatches - 1 - ? leftMatch_->endRowIndex - : leftBatch->size(); - if (isRightSemiFilterJoin(joinType_)) { - // RightSemiFilter produce each row from the right at most once. - leftStartRow = leftEndRow - 1; - } - - if (prepareOutput(leftBatch, rightBatch)) { - // Differently from left joins, for right joins we need to load lazies - // (from the left) whenever we detect we have to move to the next - // right batch, since this means that we will produce this buffer, but - // we may have subsequent matches. - loadColumns(leftBatch, *operatorCtx_->execCtx()); - output_->resize(outputSize_); - leftMatch_->setCursor(l, leftStartRow); - rightMatch_->setCursor(r, i); - return true; - } - - for (auto j = leftStartRow; j < leftEndRow; ++j) { - if (!tryAddOutputRow(leftBatch, j, rightBatch, i)) { - // If we run out of space in the current output_, we will need to - // produce a buffer and continue processing left later. In this - // case, we cannot leave left as a lazy vector, since we cannot have - // two dictionaries wrapping the same lazy vector. - loadColumns(currentLeft_, *operatorCtx_->execCtx()); - rightMatch_->setCursor(r, i); - leftMatch_->setCursor(l, j); - return true; - } - } - } - } - } - - leftMatch_.reset(); - rightMatch_.reset(); +bool MergeJoin::addToOutputForLeftJoin() { + return addToOutputImpl(); +} - // If the current key match finished, but there are still records to be - // processed in the left, we need to load lazy vectors (see comment above). - if (rightInput_ && rightRowIndex_ != rightInput_->size()) { - loadColumns(currentLeft_, *operatorCtx_->execCtx()); - } - return outputSize_ == outputBatchSize_; +bool MergeJoin::addToOutputForRightJoin() { + return addToOutputImpl(); } namespace { @@ -751,19 +707,18 @@ vector_size_t firstNonNull( const RowVectorPtr& rowVector, const std::vector& keys, vector_size_t start = 0) { - for (auto i = start; i < rowVector->size(); ++i) { + for (auto row = start; row < rowVector->size(); ++row) { bool hasNull = false; for (auto key : keys) { - if (rowVector->childAt(key)->isNullAt(i)) { + if (rowVector->childAt(key)->isNullAt(row)) { hasNull = true; break; } } if (!hasNull) { - return i; + return row; } } - return rowVector->size(); } } // namespace @@ -813,104 +768,125 @@ RowVectorPtr MergeJoin::getOutput() { } return output; } - // No rows survived the filter. Get more rows. continue; - } else if (isAntiJoin(joinType_)) { - output = filterOutputForAntiJoin(output); - if (output) { - return output; - } + } - // No rows survived the filter for anti join. Get more rows. - continue; - } else { + if (!isAntiJoin(joinType_)) { return output; } + + output = filterOutputForAntiJoin(output); + if (output != nullptr && output->size() > 0) { + return output; + } + // No rows survived the filter for anti join. Get more rows. + continue; } - if (rightHasDrained_ && leftHasDrained_) { - finishDrain(); + if (processDrain()) { return nullptr; } - if (leftHasDrained_ && !input_) { - if (isInnerJoin(joinType_) && (!rightMatch_ || rightMatch_->complete)) { - operatorCtx_->task()->dropInput(rightNodeId_); - } - if (isLeftJoin(joinType_) || isAntiJoin(joinType_)) { - operatorCtx_->task()->dropInput(rightNodeId_); + + // Check if we need to get more data from the right side. + if (needsInputFromRightSide()) { + if (!getNextFromRightSide()) { + return nullptr; } + continue; } - if (rightHasDrained_ && !rightInput_) { - if (isInnerJoin(joinType_) && (!leftMatch_ || leftMatch_->complete)) { - operatorCtx_->task()->dropInput(this); - } - if (isRightJoin(joinType_)) { - operatorCtx_->task()->dropInput(this); - } + + return nullptr; + } + VELOX_UNREACHABLE(); +} + +bool MergeJoin::processDrain() { + if (rightHasDrained_ && leftHasDrained_) { + finishDrain(); + return true; + } + + if (leftHasDrained_ && !input_) { + if (isInnerJoin(joinType_) && (!rightMatch_ || rightMatch_->complete)) { + operatorCtx_->task()->dropInput(rightNodeId_); } - // Check if we need to get more data from the right side. - if (!rightHasNoInput() && !futureRightSideInput_.valid() && !rightInput_) { - if (!rightSource_) { - rightSource_ = operatorCtx_->task()->getMergeJoinSource( - operatorCtx_->driverCtx()->splitGroupId, planNodeId()); - } + if (isLeftJoin(joinType_) || isAntiJoin(joinType_)) { + operatorCtx_->task()->dropInput(rightNodeId_); + } + } - while (!rightHasNoInput() && !rightInput_) { - const auto blockingReason = rightSource_->next( - &futureRightSideInput_, &rightInput_, rightHasDrained_); - if (blockingReason != BlockingReason::kNotBlocked) { - VELOX_CHECK(!rightHasDrained_); - return nullptr; - } + if (rightHasDrained_ && !rightInput_) { + if (isInnerJoin(joinType_) && (!leftMatch_ || leftMatch_->complete)) { + operatorCtx_->task()->dropInput(this); + } + if (isRightJoin(joinType_)) { + operatorCtx_->task()->dropInput(this); + } + } + return false; +} - if (rightInput_) { - if (isFullJoin(joinType_) || isRightJoin(joinType_)) { - rightRowIndex_ = 0; - } else { - rightRowIndex_ = firstNonNull(rightInput_, rightKeyChannels_); - if (finishedRightBatch()) { - // Ran out of rows on the right side. - rightInput_ = nullptr; - } - } - } else if (!rightHasDrained_) { - noMoreRightInput_ = true; +bool MergeJoin::needsInputFromRightSide() const { + return !rightHasNoInput() && !rightSideInputFuture_.valid() && !rightInput_; +} + +bool MergeJoin::getNextFromRightSide() { + VELOX_CHECK(needsInputFromRightSide()); + if (rightSource_ == nullptr) { + rightSource_ = operatorCtx_->task()->getMergeJoinSource( + operatorCtx_->driverCtx()->splitGroupId, planNodeId()); + } + + while (!rightHasNoInput() && !rightInput_) { + const auto blockingReason = rightSource_->next( + &rightSideInputFuture_, &rightInput_, rightHasDrained_); + if (blockingReason != BlockingReason::kNotBlocked) { + VELOX_CHECK(!rightHasDrained_); + return false; + } + + if (rightInput_) { + if (isFullJoin(joinType_) || isRightJoin(joinType_)) { + rightRowIndex_ = 0; + } else { + rightRowIndex_ = firstNonNull(rightInput_, rightKeyChannels_); + if (rightBatchFinished()) { + // Ran out of rows on the right side. + clearRightInput(); } } - continue; + } else if (!rightHasDrained_) { + noMoreRightInput_ = true; } - - return nullptr; } - VELOX_UNREACHABLE(); + return true; } RowVectorPtr MergeJoin::handleRightSideNullRows() { + if (!isRightJoin(joinType_) && !isFullJoin(joinType_)) { + return nullptr; + } const auto rightFirstNonNullIndex = - firstNonNull(rightInput_, rightKeyChannels_); - if ((isRightJoin(joinType_) || isFullJoin(joinType_)) && - rightFirstNonNullIndex > rightRowIndex_) { - if (prepareOutput(nullptr, rightInput_)) { + firstNonNull(rightInput_, rightKeyChannels_, rightRowIndex_); + if (rightFirstNonNullIndex <= rightRowIndex_) { + return nullptr; + } + if (prepareOutput(nullptr, rightInput_)) { + output_->resize(outputSize_); + return std::move(output_); + } + for (int row = rightRowIndex_; row < rightFirstNonNullIndex; ++row) { + if (!tryAddOutputRowForRightJoin()) { output_->resize(outputSize_); return std::move(output_); } - for (int i = rightRowIndex_; i < rightFirstNonNullIndex; ++i) { - if (!tryAddOutputRowForRightJoin()) { - rightRowIndex_ = i; - return std::move(output_); - } - - if (finishedRightBatch()) { - // Ran out of rows on the right side. - rightInput_ = nullptr; - return nullptr; - } - } - - rightRowIndex_ = rightFirstNonNullIndex; } - + VELOX_CHECK_EQ(rightRowIndex_, rightFirstNonNullIndex); + if (rightBatchFinished()) { + // Ran out of rows on the right side. + clearRightInput(); + } return nullptr; } @@ -933,53 +909,11 @@ RowVectorPtr MergeJoin::doGetOutput() { // match. if (leftMatch_) { VELOX_CHECK(rightMatch_); - - if (input_) { - // Look for continuation of a match on the left and/or right sides. - if (!findEndOfMatch(input_, leftKeyChannels_, leftMatch_.value())) { - // Continue looking for the end of the match. - input_ = nullptr; - return nullptr; - } - VELOX_CHECK(leftMatch_->complete); - - if (leftMatch_->inputs.back() == input_) { - leftRowIndex_ = leftMatch_->endRowIndex; - } - } else if (leftHasNoInput()) { - leftMatch_->complete = true; - } else { - // Need more input. - return nullptr; - } - - if (rightInput_) { - if (!findEndOfMatch( - rightInput_, rightKeyChannels_, rightMatch_.value())) { - VELOX_CHECK(!rightMatch_->complete); - // Continue looking for the end of the match. - rightInput_ = nullptr; - return nullptr; - } - VELOX_CHECK(rightMatch_->complete); - - if (rightMatch_->inputs.back() == rightInput_) { - if (isFullJoin(joinType_) || isRightJoin(joinType_)) { - rightRowIndex_ = rightMatch_->endRowIndex; - } else { - rightRowIndex_ = firstNonNull( - rightInput_, rightKeyChannels_, rightMatch_->endRowIndex); - if (rightRowIndex_ == rightInput_->size()) { - rightInput_ = nullptr; - } - } - } - } else if (rightHasNoInput()) { - rightMatch_->complete = true; - } else { - // Need more input. + if (!advanceMatch()) { return nullptr; } + VELOX_CHECK(leftMatch_->complete); + VELOX_CHECK(rightMatch_->complete); } // There is no output-in-progress match, but there can be a complete match @@ -994,131 +928,10 @@ RowVectorPtr MergeJoin::doGetOutput() { } if (!input_ || !rightInput_) { - if (isLeftJoin(joinType_) || isAntiJoin(joinType_)) { - if (input_ && rightHasNoInput()) { - // If output_ is currently wrapping a different buffer, return it - // first. - if (prepareOutput(input_, nullptr)) { - output_->resize(outputSize_); - return std::move(output_); - } - while (true) { - if (!tryAddOutputRowForLeftJoin()) { - return std::move(output_); - } - - if (finishedLeftBatch()) { - input_ = nullptr; - return produceOutput(); - } - } - VELOX_UNREACHABLE(); - } - - if (leftHasNoInput()) { - if (output_ != nullptr) { - output_->resize(outputSize_); - return std::move(output_); - } - if (input_ == nullptr) { - rightInput_ = nullptr; - } - } - } else if (isRightJoin(joinType_)) { - if (rightInput_ && leftHasNoInput()) { - // If output_ is currently wrapping a different buffer, return it - // first. - if (prepareOutput(nullptr, rightInput_)) { - output_->resize(outputSize_); - return std::move(output_); - } - - while (true) { - if (!tryAddOutputRowForRightJoin()) { - return std::move(output_); - } - - if (finishedRightBatch()) { - // Ran out of rows on the right side. - rightInput_ = nullptr; - return nullptr; - } - } - VELOX_UNREACHABLE(); - } - - if (rightHasNoInput() && output_) { - output_->resize(outputSize_); - return std::move(output_); - } - } else if (isFullJoin(joinType_)) { - if (input_ && rightHasNoInput()) { - // If output_ is currently wrapping a different buffer, return it - // first. - if (prepareOutput(input_, nullptr)) { - output_->resize(outputSize_); - return std::move(output_); - } - - while (true) { - if (!tryAddOutputRowForLeftJoin()) { - return std::move(output_); - } - - if (finishedLeftBatch()) { - input_ = nullptr; - return produceOutput(); - } - } - VELOX_UNREACHABLE(); - } - - if (leftHasNoInput() && output_) { - output_->resize(outputSize_); - return std::move(output_); - } - - if (rightInput_ && leftHasNoInput()) { - // If output_ is currently wrapping a different buffer, return it - // first. - if (prepareOutput(nullptr, rightInput_)) { - output_->resize(outputSize_); - return std::move(output_); - } - - while (true) { - if (!tryAddOutputRowForRightJoin()) { - return std::move(output_); - } - - if (finishedRightBatch()) { - // Ran out of rows on the right side. - rightInput_ = nullptr; - return nullptr; - } - } - VELOX_UNREACHABLE(); - } - - if (rightHasNoInput() && output_) { - output_->resize(outputSize_); - return std::move(output_); - } - } else { - if (leftHasNoInput() || rightHasNoInput()) { - if (output_) { - output_->resize(outputSize_); - return std::move(output_); - } - input_ = nullptr; - rightInput_ = nullptr; - } - } - - return nullptr; + return handleSingleSideOutput(); } - const auto output = handleRightSideNullRows(); + auto output = handleRightSideNullRows(); if (output != nullptr || rightInput_ == nullptr) { return output; } @@ -1147,8 +960,8 @@ RowVectorPtr MergeJoin::doGetOutput() { firstNonNull(input_, leftKeyChannels_, leftRowIndex_ + 1); } - if (finishedLeftBatch()) { - input_ = nullptr; + if (leftBatchFinished()) { + clearLeftInput(); return produceOutput(); } compareResult = compare(); @@ -1163,7 +976,6 @@ RowVectorPtr MergeJoin::doGetOutput() { output_->resize(outputSize_); return std::move(output_); } - if (!tryAddOutputRowForRightJoin()) { return std::move(output_); } @@ -1172,9 +984,9 @@ RowVectorPtr MergeJoin::doGetOutput() { firstNonNull(rightInput_, rightKeyChannels_, rightRowIndex_ + 1); } - if (finishedRightBatch()) { - rightInput_ = nullptr; - return produceOutput(); + if (rightBatchFinished()) { + clearRightInput(); + return nullptr; } compareResult = compare(); } @@ -1186,11 +998,6 @@ RowVectorPtr MergeJoin::doGetOutput() { while (leftEndRow < input_->size() && compareLeft(leftEndRow) == 0) { ++leftEndRow; } - - if (leftEndRow == input_->size()) { - // Matches continue in subsequent input. Load all lazies. - loadColumns(input_, *operatorCtx_->execCtx()); - } leftMatch_ = Match{ {input_}, leftRowIndex_, @@ -1198,43 +1005,47 @@ RowVectorPtr MergeJoin::doGetOutput() { leftEndRow < input_->size(), std::nullopt}; - vector_size_t endRightRow = rightRowIndex_ + 1; - while (endRightRow < rightInput_->size() && - compareRight(endRightRow) == 0) { - ++endRightRow; + vector_size_t rightEndRow = rightRowIndex_ + 1; + while (rightEndRow < rightInput_->size() && + compareRight(rightEndRow) == 0) { + ++rightEndRow; } rightMatch_ = Match{ {rightInput_}, rightRowIndex_, - endRightRow, - endRightRow < rightInput_->size(), + rightEndRow, + rightEndRow < rightInput_->size(), std::nullopt}; + // Track matched rows for this key match. + matchedLeftRows_ += leftEndRow - leftMatch_->startRowIndex; + matchedRightRows_ += rightEndRow - rightMatch_->startRowIndex; + if (!leftMatch_->complete || !rightMatch_->complete) { if (!leftMatch_->complete) { // Need to continue looking for the end of match. - input_ = nullptr; + clearLeftInput(); } if (!rightMatch_->complete) { // Need to continue looking for the end of match. - rightInput_ = nullptr; + clearRightInput(); } return nullptr; } leftRowIndex_ = leftEndRow; if (isFullJoin(joinType_) || isRightJoin(joinType_)) { - rightRowIndex_ = endRightRow; + rightRowIndex_ = rightEndRow; } else { rightRowIndex_ = - firstNonNull(rightInput_, rightKeyChannels_, endRightRow); + firstNonNull(rightInput_, rightKeyChannels_, rightEndRow); } - if (finishedRightBatch()) { + if (rightBatchFinished()) { // Ran out of rows on the right side. - rightInput_ = nullptr; + clearRightInput(); } if (addToOutput()) { @@ -1252,6 +1063,223 @@ RowVectorPtr MergeJoin::doGetOutput() { VELOX_UNREACHABLE(); } +RowVectorPtr MergeJoin::handleSingleSideOutput() { + VELOX_CHECK(!input_ || !rightInput_); + + if (isLeftJoin(joinType_) || isAntiJoin(joinType_)) { + if (input_ && rightHasNoInput()) { + // If output_ is currently wrapping a different buffer, return it + // first. + if (prepareOutput(input_, nullptr)) { + output_->resize(outputSize_); + return std::move(output_); + } + while (true) { + if (!tryAddOutputRowForLeftJoin()) { + return std::move(output_); + } + + if (leftBatchFinished()) { + clearLeftInput(); + return produceOutput(); + } + } + VELOX_UNREACHABLE(); + } + + if (leftHasNoInput()) { + if (output_ != nullptr) { + output_->resize(outputSize_); + return std::move(output_); + } + if (input_ == nullptr) { + clearRightInput(); + } + } + } else if (isRightJoin(joinType_)) { + if (rightInput_ && leftHasNoInput()) { + // If output_ is currently wrapping a different buffer, return it + // first. + if (prepareOutput(nullptr, rightInput_)) { + output_->resize(outputSize_); + return std::move(output_); + } + + while (true) { + if (!tryAddOutputRowForRightJoin()) { + return std::move(output_); + } + + if (rightBatchFinished()) { + // Ran out of rows on the right side. + clearRightInput(); + return nullptr; + } + } + VELOX_UNREACHABLE(); + } + + if (rightHasNoInput()) { + if (output_ != nullptr) { + output_->resize(outputSize_); + return std::move(output_); + } + if (rightInput_ == nullptr) { + clearLeftInput(); + } + } + } else if (isFullJoin(joinType_)) { + if (input_ && rightHasNoInput()) { + // If output_ is currently wrapping a different buffer, return it + // first. + if (prepareOutput(input_, nullptr)) { + output_->resize(outputSize_); + return std::move(output_); + } + + while (true) { + if (!tryAddOutputRowForLeftJoin()) { + return std::move(output_); + } + + if (leftBatchFinished()) { + clearLeftInput(); + return produceOutput(); + } + } + VELOX_UNREACHABLE(); + } + + if (leftHasNoInput() && output_) { + output_->resize(outputSize_); + return std::move(output_); + } + + if (rightInput_ && leftHasNoInput()) { + // If output_ is currently wrapping a different buffer, return it + // first. + if (prepareOutput(nullptr, rightInput_)) { + output_->resize(outputSize_); + return std::move(output_); + } + + while (true) { + if (!tryAddOutputRowForRightJoin()) { + return std::move(output_); + } + + if (rightBatchFinished()) { + // Ran out of rows on the right side. + clearRightInput(); + return nullptr; + } + } + VELOX_UNREACHABLE(); + } + + if (rightHasNoInput() && output_) { + output_->resize(outputSize_); + return std::move(output_); + } + } else { + if (leftHasNoInput() || rightHasNoInput()) { + if (output_) { + output_->resize(outputSize_); + return std::move(output_); + } + clearLeftInput(); + clearRightInput(); + } + } + return nullptr; +} + +bool MergeJoin::advanceMatch() { + VELOX_CHECK(leftMatch_); + VELOX_CHECK(rightMatch_); + + if (!advanceLeftMatch()) { + return false; + } + return advanceRightMatch(); +} + +// Template implementation for advancing left or right match completion. +// Consolidates logic for finding the end of incomplete matches. +template +bool MergeJoin::advanceMatchImpl() { + auto& match = IsLeft ? leftMatch_ : rightMatch_; + auto& input = IsLeft ? input_ : rightInput_; + auto& keyChannels = IsLeft ? leftKeyChannels_ : rightKeyChannels_; + auto& rowIndex = IsLeft ? leftRowIndex_ : rightRowIndex_; + + VELOX_CHECK(match); + + if (input) { + // Look for continuation of a match. + if (!findEndOfMatch(input, keyChannels, match.value())) { + VELOX_CHECK(!match->complete); + // Continue looking for the end of the match. + if constexpr (IsLeft) { + clearLeftInput(); + } else { + clearRightInput(); + } + return false; + } + VELOX_CHECK(match->complete); + + if (match->inputs.back() == input) { + if (IsLeft || isFullJoin(joinType_) || isRightJoin(joinType_)) { + rowIndex = match->endRowIndex; + } else { + rowIndex = firstNonNull(input, keyChannels, match->endRowIndex); + if (rowIndex == input->size()) { + if constexpr (IsLeft) { + clearLeftInput(); + } else { + clearRightInput(); + } + } + } + } + return true; + } + + if constexpr (IsLeft) { + if (leftHasNoInput()) { + match->complete = true; + return true; + } + } else { + if (rightHasNoInput()) { + match->complete = true; + return true; + } + } + + // Need more input. + return false; +} + +// Delegates to the template implementation for left match advancement. +bool MergeJoin::advanceLeftMatch() { + return advanceMatchImpl(); +} + +// Delegates to the template implementation for right match advancement. +bool MergeJoin::advanceRightMatch() { + return advanceMatchImpl(); +} + +void MergeJoin::clearLeftInput() { + input_ = nullptr; +} + +void MergeJoin::clearRightInput() { + rightInput_ = nullptr; +} + RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) { const auto numRows = output->size(); @@ -1274,7 +1302,7 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) { // If all matches for a given left-side row fail the filter, add a row to // the output with nulls for the right-side columns. const auto onMiss = [&](auto row) { - if (isAntiJoin(joinType_)) { + if (isSemiFilterJoin(joinType_)) { return; } rawIndices[numPassed++] = row; @@ -1346,22 +1374,21 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) { } }; + auto onMatch = [&](auto row, bool firstMatch) { + const bool isNonSemiAntiJoin = + !isSemiFilterJoin(joinType_) && !isAntiJoin(joinType_); + + if ((isSemiFilterJoin(joinType_) && firstMatch) || isNonSemiAntiJoin) { + rawIndices[numPassed++] = row; + } + }; + for (auto i = 0; i < numRows; ++i) { if (filterRows.isValid(i)) { const bool passed = !decodedFilterResult_.isNullAt(i) && decodedFilterResult_.valueAt(i); - joinTracker_->processFilterResult(i, passed, onMiss); - - if (isAntiJoin(joinType_)) { - if (!passed) { - rawIndices[numPassed++] = i; - } - } else { - if (passed) { - rawIndices[numPassed++] = i; - } - } + joinTracker_->processFilterResult(i, passed, onMiss, onMatch); } else { // This row doesn't have a match on the right side. Keep it // unconditionally. @@ -1371,19 +1398,19 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) { // Every time we start a new left key match, `processFilterResult()` will // check if at least one row from the previous match passed the filter. If - // none did, it calls onMiss to add a record with null right projections to - // the output. + // none did, it calls onMiss to add a record with null right projections + // to the output. // // Before we leave the current buffer, since we may not have seen the next - // left key match yet, the last key match may still be pending to produce a - // row (because `processFilterResult()` was not called yet). + // left key match yet, the last key match may still be pending to produce + // a row (because `processFilterResult()` was not called yet). // // To handle this, we need to call `noMoreFilterResults()` unless the - // same current left key match may continue in the next buffer. So there are - // two cases to check: + // same current left key match may continue in the next buffer. So there + // are two cases to check: // - // 1. If leftMatch_ is nullopt, there for sure the next buffer will contain - // a different key match. + // 1. If leftMatch_ is nullopt, there for sure the next buffer will + // contain a different key match. // // 2. leftMatch_ may not be nullopt, but may be related to a different // (subsequent) left key. So we check if the last row in the batch has the @@ -1450,6 +1477,14 @@ bool MergeJoin::isFinished() { } void MergeJoin::close() { + // Report match ratio statistics. + { + auto lockedStats = stats_.wlock(); + lockedStats->addRuntimeStat( + "matchedLeftRows", RuntimeCounter(matchedLeftRows_)); + lockedStats->addRuntimeStat( + "matchedRightRows", RuntimeCounter(matchedRightRows_)); + } if (rightSource_) { rightSource_->close(); } @@ -1468,8 +1503,8 @@ void MergeJoin::finishDrain() { leftHasDrained_ = false; rightHasDrained_ = false; - input_ = nullptr; - rightInput_ = nullptr; + clearLeftInput(); + clearRightInput(); leftMatch_.reset(); rightMatch_.reset(); if (joinTracker_.has_value()) { diff --git a/velox/exec/MergeJoin.h b/velox/exec/MergeJoin.h index 47dea482554d..67dfadf917fe 100644 --- a/velox/exec/MergeJoin.h +++ b/velox/exec/MergeJoin.h @@ -235,6 +235,12 @@ class MergeJoin : public Operator { // right. bool addToOutputForRightJoin(); + private: + // Template function to consolidate addToOutputForLeftJoin and + // addToOutputForRightJoin + template + bool addToOutputImpl(); + // Tries to add one row of output by writing to the indices of the output // dictionaries. By default, this operator returns dictionaries wrapped around // the input columns from the left and right. If `isRightFlattened_`, the @@ -274,16 +280,62 @@ class MergeJoin : public Operator { // rightRowIndex_ are unchanged. bool tryAddOutputRowForRightJoin(); + // Checks if we need to fetch more input from the right side of the join. + // Returns true if the right side has not been exhausted and we don't have + // a pending future for right side input and there's no current right input + // batch available for processing. + bool needsInputFromRightSide() const; + + // Attempts to get the next batch of input from the right side. This method + // will keep trying to get input until either: + // 1. A right input batch is received + // 2. The right side is exhausted + // 3. A blocking operation is encountered + // Returns true if successful (got input or exhausted), false if blocked. + bool getNextFromRightSide(); + + // Processes the draining state when one or both sides of the join have been + // exhausted. Handles cleanup operations like dropping input sources for + // different join types (inner, left, right, anti) when appropriate. + // Returns true if draining is complete and the operator should return + // nullptr, false if draining is still in progress or there is no draining. + bool processDrain(); + // If all rows from the current left batch have been processed. - bool finishedLeftBatch() const { + bool leftBatchFinished() const { return leftRowIndex_ == input_->size(); } // If all rows from the current right batch have been processed. - bool finishedRightBatch() const { + bool rightBatchFinished() const { return rightRowIndex_ == rightInput_->size(); } + // Tries to complete incomplete matches on both left and right sides. + // Returns true if both matches are complete, false if more input is needed. + bool advanceMatch(); + + // Tries to complete an incomplete match on the left side by finding the end + // of matching key sequence. Returns true if complete, false if more input + // needed. + bool advanceLeftMatch(); + + // Tries to complete an incomplete match on the right side by finding the end + // of matching key sequence. Returns true if complete, false if more input + // needed. + bool advanceRightMatch(); + + // Template function to consolidate advanceLeftMatch and advanceRightMatch + // logic. Uses compile-time template parameter to handle left vs right + // differences. + template + bool advanceMatchImpl(); + + // Handles output generation when only one side of the join has data + // available. Processes unmatched rows for outer joins when the other side is + // exhausted. + RowVectorPtr handleSingleSideOutput(); + // Properly resizes and produces the current output vector if one is // available. RowVectorPtr produceOutput() { @@ -310,6 +362,16 @@ class MergeJoin : public Operator { // rows from the left side that have a match on the right. RowVectorPtr filterOutputForAntiJoin(const RowVectorPtr& output); + // Clears the current left input batch (input_) by setting it to nullptr. + // Called when the left batch has been fully processed or when resetting + // state during match processing. + void clearLeftInput(); + + // Clears the current right input batch (rightInput_) by setting it to + // nullptr. Called when the right batch has been fully processed or when + // resetting state during match processing. + void clearRightInput(); + // As we populate the results of the join, we track whether a given // output row is a result of a match between left and right sides or a miss. // We use JoinTracker::addMatch and addMiss methods for that. @@ -394,11 +456,12 @@ class MergeJoin : public Operator { // rows that correspond to a single left-side row. Use // 'noMoreFilterResults' to make sure 'onMiss' is called for the last // left-side row. - template + template void processFilterResult( vector_size_t outputIndex, bool passed, - TOnMiss onMiss) { + const TOnMiss& onMiss, + const TOnMatch& onMatch) { const auto rowNumber = rawLeftRowNumbers_[outputIndex]; if (currentLeftRowNumber_ != rowNumber) { if (currentRow_ != -1 && !currentRowPassed_) { @@ -412,6 +475,7 @@ class MergeJoin : public Operator { } if (passed) { + onMatch(outputIndex, /*firstMatch=*/!currentRowPassed_); currentRowPassed_ = true; } } @@ -561,12 +625,18 @@ class MergeJoin : public Operator { vector_size_t outputSize_; // A future that will be completed when right side input becomes available. - ContinueFuture futureRightSideInput_{ContinueFuture::makeEmpty()}; + ContinueFuture rightSideInputFuture_{ContinueFuture::makeEmpty()}; // True if all the right side data has been received. bool noMoreRightInput_{false}; bool leftHasDrained_{false}; bool rightHasDrained_{false}; + + // Stats for tracking matched rows from the left side + uint64_t matchedLeftRows_{0}; + + // Stats for tracking matched rows from the right side + uint64_t matchedRightRows_{0}; }; } // namespace facebook::velox::exec diff --git a/velox/exec/MergeSource.cpp b/velox/exec/MergeSource.cpp index 411decd83e6e..b2e4005ebf62 100644 --- a/velox/exec/MergeSource.cpp +++ b/velox/exec/MergeSource.cpp @@ -198,15 +198,16 @@ class MergeExchangeSource : public MergeSource { memory::MemoryPool* pool, folly::Executor* executor) : mergeExchange_(mergeExchange), - client_(std::make_shared( - mergeExchange->taskId(), - destination, - maxQueuedBytes, - 1, - // Deliver right away to avoid blocking other sources - 0, - pool, - executor)) { + client_( + std::make_shared( + mergeExchange->taskId(), + destination, + maxQueuedBytes, + 1, + // Deliver right away to avoid blocking other sources + 0, + pool, + executor)) { client_->addRemoteTaskId(taskId); client_->noMoreRemoteTasks(); } @@ -286,16 +287,14 @@ class MergeExchangeSource : public MergeSource { std::shared_ptr client_; std::unique_ptr inputStream_; - std::unique_ptr currentPage_; + std::unique_ptr currentPage_; bool atEnd_ = false; }; } // namespace -std::shared_ptr MergeSource::createLocalMergeSource() { - // Buffer up to 2 vectors from each source before blocking to wait - // for consumers. - static const int kDefaultQueueSize = 2; - return std::make_shared(kDefaultQueueSize); +std::shared_ptr MergeSource::createLocalMergeSource( + int queueSize) { + return std::make_shared(queueSize); } std::shared_ptr MergeSource::createMergeExchangeSource( diff --git a/velox/exec/MergeSource.h b/velox/exec/MergeSource.h index e5892add6ed9..ed9f3d1f37fa 100644 --- a/velox/exec/MergeSource.h +++ b/velox/exec/MergeSource.h @@ -46,7 +46,7 @@ class MergeSource { virtual void close() = 0; // Factory methods to create MergeSources. - static std::shared_ptr createLocalMergeSource(); + static std::shared_ptr createLocalMergeSource(int queueSize); static std::shared_ptr createMergeExchangeSource( MergeExchange* mergeExchange, diff --git a/velox/exec/NestedLoopJoinProbe.cpp b/velox/exec/NestedLoopJoinProbe.cpp index 370c57ac103a..3c72407796cc 100644 --- a/velox/exec/NestedLoopJoinProbe.cpp +++ b/velox/exec/NestedLoopJoinProbe.cpp @@ -14,6 +14,7 @@ * limitations under the License. */ #include "velox/exec/NestedLoopJoinProbe.h" +#include "velox/exec/DriverStats.h" #include "velox/exec/OperatorUtils.h" #include "velox/exec/Task.h" #include "velox/expression/FieldReference.h" @@ -250,6 +251,17 @@ RowVectorPtr NestedLoopJoinProbe::getOutput() { // Generate actual join output by processing probe and build matches, and // probe mismaches (for left joins). output = generateOutput(); + + // generateOutput() can be computationally expensive when + // processing large buildVectors_, as it iterates through all build rows for + // each probe row. The probe moves slowly through buildVectors_, making the + // numOutputRows_ >= outputBatchSize_ check harder to reach, which can cause + // driver threads to get stuck in long-running processing loops that exceed + // their allocated CPU time slice limits, preventing other tasks from being + // scheduled and degrading overall system responsiveness. + if (FOLLY_UNLIKELY(shouldYield())) { + break; + } } return output; } diff --git a/velox/exec/NestedLoopJoinProbe.h b/velox/exec/NestedLoopJoinProbe.h index aa2a0f6b4080..cbec6cc4dae0 100644 --- a/velox/exec/NestedLoopJoinProbe.h +++ b/velox/exec/NestedLoopJoinProbe.h @@ -124,7 +124,7 @@ class NestedLoopJoinProbe : public Operator { // true), or if the output is full (returns false). If it returns false, a // valid vector with more than zero records will be available at `output_`; if // it returns true, either nullptr or zero records may be placed at `output_`. - // Also if it returns true, it's the caller's responsiblity to deicide when to + // Also if it returns true, it's the caller's responsiblity to decide when to // set `output_` size. // // Also updates `buildMatched_` if the build records that received a match, so diff --git a/velox/exec/OneWayStatusFlag.h b/velox/exec/OneWayStatusFlag.h index 8610585d155f..4eea37c24713 100644 --- a/velox/exec/OneWayStatusFlag.h +++ b/velox/exec/OneWayStatusFlag.h @@ -16,53 +16,28 @@ #pragma once -#include #include namespace facebook::velox::exec { -/// A simple one way status flag that uses a non atomic flag to avoid -/// unnecessary atomic operations. class OneWayStatusFlag { public: - bool check() const { -#if defined(__x86_64__) - /// This flag is can only go from false to true, and must is only checked at - /// the end of a loop. Given that once a flag is true it can never go back - /// to false, we are ok to use this in a non synchronized manner to avoid - /// the overhead. As such we consciously exempt ourselves here from TSAN - /// detection. - folly::annotate_ignore_thread_sanitizer_guard g(__FILE__, __LINE__); - return fastStatus_ || atomicStatus_.load(); -#else - return atomicStatus_.load(std::memory_order_relaxed) || - atomicStatus_.load(); -#endif + bool check() const noexcept { + return status_.load(std::memory_order_acquire); } - void set() { -#if defined(__x86_64__) - if (!fastStatus_) { - atomicStatus_.store(true); - fastStatus_ = true; + void set() noexcept { + if (!status_.load(std::memory_order_relaxed)) { + status_.store(true, std::memory_order_release); } -#else - if (!atomicStatus_.load(std::memory_order_relaxed)) { - atomicStatus_.store(true); - } -#endif } - /// Operator overload to convert OneWayStatusFlag to bool - operator bool() const { + explicit operator bool() const noexcept { return check(); } private: -#if defined(__x86_64__) - bool fastStatus_{false}; -#endif - std::atomic_bool atomicStatus_{false}; + std::atomic_bool status_{false}; }; } // namespace facebook::velox::exec diff --git a/velox/exec/Operator.cpp b/velox/exec/Operator.cpp index 29bfbef7cb5e..1729247854f7 100644 --- a/velox/exec/Operator.cpp +++ b/velox/exec/Operator.cpp @@ -68,9 +68,12 @@ OperatorCtx::createConnectorQueryCtx( driverCtx_->driverId, driverCtx_->queryConfig().sessionTimezone(), driverCtx_->queryConfig().adjustTimestampToTimezone(), - task->getCancellationToken()); + task->getCancellationToken(), + task->queryCtx()->fsTokenProvider()); connectorQueryCtx->setSelectiveNimbleReaderEnabled( driverCtx_->queryConfig().selectiveNimbleReaderEnabled()); + connectorQueryCtx->setRowSizeTrackingMode( + driverCtx_->queryConfig().rowSizeTrackingMode()); return connectorQueryCtx; } @@ -81,18 +84,23 @@ Operator::Operator( std::string planNodeId, std::string operatorType, std::optional spillConfig) - : operatorCtx_(std::make_unique( - driverCtx, - planNodeId, - operatorId, - operatorType)), + : operatorCtx_( + std::make_unique( + driverCtx, + planNodeId, + operatorId, + operatorType)), outputType_(std::move(outputType)), spillConfig_(std::move(spillConfig)), - stats_(OperatorStats{ - operatorId, - driverCtx->pipelineId, - std::move(planNodeId), - std::move(operatorType)}) {} + dryRun_( + operatorCtx_->driverCtx()->traceConfig().has_value() && + operatorCtx_->driverCtx()->traceConfig()->dryRun), + stats_( + OperatorStats{ + operatorId, + driverCtx->pipelineId, + std::move(planNodeId), + std::move(operatorType)}) {} void Operator::maybeSetReclaimer() { VELOX_CHECK_NULL(pool()->reclaimer()); @@ -111,7 +119,7 @@ void Operator::maybeSetTracer() { } const auto nodeId = planNodeId(); - if (traceConfig->queryNodes.count(nodeId) == 0) { + if (traceConfig->queryNodeId.empty() || traceConfig->queryNodeId != nodeId) { return; } @@ -141,7 +149,7 @@ void Operator::maybeSetTracer() { opTraceDirPath, operatorCtx_->driverCtx()->queryConfig().opTraceDirectoryCreateConfig()); - if (operatorType() == "TableScan") { + if (dynamic_cast(this) != nullptr) { setupSplitTracer(opTraceDirPath); } else { setupInputTracer(opTraceDirPath); @@ -375,7 +383,7 @@ void Operator::recordBlockingTime(uint64_t start, BlockingReason reason) { std::chrono::high_resolution_clock::now().time_since_epoch()) .count(); const auto wallNanos = (now - start) * 1000; - const auto blockReason = blockingReasonToString(reason).substr(1); + const auto blockReason = BlockingReasonName::toName(reason).substr(1); auto lockedStats = stats_.wlock(); lockedStats->blockedWallNanos += wallNanos; @@ -387,7 +395,7 @@ void Operator::recordBlockingTime(uint64_t start, BlockingReason reason) { } void Operator::recordSpillStats() { - const auto lockedSpillStats = spillStats_.wlock(); + const auto lockedSpillStats = spillStats_->wlock(); auto lockedStats = stats_.wlock(); lockedStats->spilledInputBytes += lockedSpillStats->spilledInputBytes; lockedStats->spilledBytes += lockedSpillStats->spilledBytes; @@ -589,6 +597,14 @@ void OperatorStats::add(const OperatorStats& other) { } } + for (const auto& [name, exprStats] : other.expressionStats) { + if (UNLIKELY(expressionStats.count(name) == 0)) { + expressionStats.insert(std::make_pair(name, exprStats)); + } else { + expressionStats.at(name).add(exprStats); + } + } + numDrivers += other.numDrivers; spilledInputBytes += other.spilledInputBytes; spilledBytes += other.spilledBytes; @@ -625,6 +641,7 @@ void OperatorStats::clear() { memoryStats.clear(); runtimeStats.clear(); + expressionStats.clear(); numDrivers = 0; spilledInputBytes = 0; diff --git a/velox/exec/Operator.h b/velox/exec/Operator.h index fd3bb3234ac5..7af6e4cd34d5 100644 --- a/velox/exec/Operator.h +++ b/velox/exec/Operator.h @@ -17,11 +17,11 @@ #include #include "velox/core/PlanNode.h" +#include "velox/core/QueryCtx.h" #include "velox/exec/Driver.h" #include "velox/exec/JoinBridge.h" #include "velox/exec/OperatorStats.h" #include "velox/exec/OperatorTraceWriter.h" -#include "velox/type/Filter.h" namespace facebook::velox::exec { @@ -152,6 +152,11 @@ class Operator : public BaseRuntimeStatWriter { } }; + /// The name for background cpu time metric if operator has background cpu + /// usages outside its driver thread. + static inline const std::string kBackgroundCpuTimeNanos = + "backgroundCpuTimeNanos"; + /// The name of the runtime spill stats collected and reported by operators /// that support spilling. @@ -283,40 +288,31 @@ class Operator : public BaseRuntimeStatWriter { /// build side is empty. virtual bool isFinished() = 0; + /// True if the operator is in dry run mode which is only used by input + /// trace collection for crash debugging. + bool dryRun() const { + return dryRun_; + } + /// Traces input batch of the operator. virtual void traceInput(const RowVectorPtr&); /// Finishes tracing of the operator. virtual void finishTrace(); - /// Returns single-column dynamically generated filters to be pushed down to - /// upstream operators. Used to push down filters on join keys from broadcast - /// hash join into probe-side table scan. Can also be used to push down TopN - /// cutoff. - virtual const std:: - unordered_map>& - getDynamicFilters() const { - return dynamicFilters_; - } - - /// Clears dynamically generated filters. Called after filters were pushed - /// down. - virtual void clearDynamicFilters() { - dynamicFilters_.clear(); - } - /// Returns true if this operator would accept a filter dynamically generated /// by a downstream operator. virtual bool canAddDynamicFilter() const { return false; } - /// Adds a filter dynamically generated by a downstream operator. Called only - /// if canAddFilter() returns true. - virtual void addDynamicFilter( + /// Adds pending filters dynamically generated by a downstream + /// operator. Called only if canAddDynamicFilter() returns true. Shared lock + /// on the PushdownFilters is already held by current thread when this method + /// is called by driver. + virtual void addDynamicFilterLocked( const core::PlanNodeId& /*producer*/, - column_index_t /*outputChannel*/, - const std::shared_ptr& /*filter*/) { + const PushdownFilters& /*filters*/) { VELOX_UNSUPPORTED( "This operator doesn't support dynamic filter pushdown: {}", toString()); @@ -632,10 +628,13 @@ class Operator : public BaseRuntimeStatWriter { /// the fs dir path to store spill files), otherwise null. const std::optional spillConfig_; + const bool dryRun_; + bool initialized_{false}; folly::Synchronized stats_; - folly::Synchronized spillStats_; + std::shared_ptr> spillStats_ = + std::make_shared>(); /// NOTE: only one of the two could be set for an operator for tracing . /// 'splitTracer_' is only set for table scan to record the processed split @@ -664,8 +663,14 @@ class Operator : public BaseRuntimeStatWriter { /// could copy directly from input to output if no cardinality change. bool isIdentityProjection_ = false; - std::unordered_map> - dynamicFilters_; + /// Returns true if the driver should yield execution to prevent getting + /// stuck in long processing loops that exceed their allocated CPU time + /// slice limits. Operators should call this method as yield points during + /// time-consuming operations to allow other tasks to be scheduled and + /// maintain system responsiveness. + bool shouldYield() const { + return operatorCtx_->driverCtx()->driver->shouldYield(); + } private: // Setup 'inputTracer_' to record the processed input vectors. @@ -698,13 +703,15 @@ class SourceOperator : public Operator { RowTypePtr outputType, int32_t operatorId, const std::string& planNodeId, - const std::string& operatorType) + const std::string& operatorType, + const std::optional& spillConfig = std::nullopt) : Operator( driverCtx, std::move(outputType), operatorId, planNodeId, - operatorType) {} + operatorType, + spillConfig) {} bool needsInput() const override { return false; diff --git a/velox/exec/OperatorStats.h b/velox/exec/OperatorStats.h index f833b1b10403..9c95bf25eb41 100644 --- a/velox/exec/OperatorStats.h +++ b/velox/exec/OperatorStats.h @@ -15,8 +15,11 @@ */ #pragma once +#include "velox/common/base/RuntimeMetrics.h" #include "velox/common/memory/MemoryPool.h" #include "velox/common/time/CpuWallTimer.h" +#include "velox/core/PlanNode.h" +#include "velox/expression/ExprStats.h" namespace facebook::velox::exec { @@ -86,6 +89,11 @@ struct DynamicFilterStats { }; struct OperatorStats { + /// Runtime stat name for per-driver CPU time (actual work time, not including + /// blocked time) for this operator. The max field will contain the CPU time + /// from the longest running single driver. + static constexpr const char* kDriverCpuTime = "driverCpuTimeNanos"; + /// Initial ordinal position in the operator's pipeline. int32_t operatorId = 0; int32_t pipelineId = 0; @@ -181,6 +189,11 @@ struct OperatorStats { std::unordered_map runtimeStats; + // A map of expression name to its respective stats. + // These are only populated when a copy of the stats is returned via + // Operator::stats(bool) API. + std::unordered_map expressionStats; + int numDrivers = 0; OperatorStats() = default; diff --git a/velox/exec/OperatorTraceWriter.h b/velox/exec/OperatorTraceWriter.h index 437f43daba0b..189577dee1f0 100644 --- a/velox/exec/OperatorTraceWriter.h +++ b/velox/exec/OperatorTraceWriter.h @@ -59,7 +59,7 @@ class OperatorTraceInputWriter { true, common::CompressionKind::CompressionKind_ZSTD, 0.8, - /*nullsFirst=*/true}; + /*_nullsFirst=*/true}; const std::shared_ptr fs_; memory::MemoryPool* const pool_; VectorSerde* const serde_; diff --git a/velox/exec/OperatorUtils.cpp b/velox/exec/OperatorUtils.cpp index 89f2eb1d0322..ba386ade1192 100644 --- a/velox/exec/OperatorUtils.cpp +++ b/velox/exec/OperatorUtils.cpp @@ -14,10 +14,13 @@ * limitations under the License. */ #include "velox/exec/OperatorUtils.h" +#include "velox/exec/PartitionedOutput.h" #include "velox/exec/VectorHasher.h" #include "velox/expression/EvalCtx.h" +#include "velox/serializers/PrestoSerializer.h" #include "velox/vector/ConstantVector.h" #include "velox/vector/FlatVector.h" +#include "velox/vector/LazyVector.h" namespace facebook::velox::exec { @@ -100,7 +103,11 @@ void gatherCopy( const std::vector& sources, const std::vector& sourceIndices, column_index_t sourceChannel) { - if (target->isScalar()) { + const bool flattenSources = + std::all_of(sources.begin(), sources.end(), [](const auto& source) { + return source->isFlatEncoding(); + }); + if (target->isScalar() && flattenSources) { VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH( scalarGatherCopy, target->type()->kind(), @@ -121,10 +128,11 @@ void gatherCopy( bool shouldAggregateRuntimeMetric(const std::string& name) { static const folly::F14FastSet metricNames{ "dataSourceAddSplitWallNanos", - "dataSourceReadWallNanos", "dataSourceLazyWallNanos", "queuedWallNanos", - "flushTimes"}; + "flushTimes", + "driverCpuTimeNanos", + "ioWaitWallNanos"}; if (metricNames.contains(name)) { return true; } @@ -230,7 +238,7 @@ vector_size_t processEncodedFilterResults( const SelectivityVector& rows, FilterEvalCtx& filterEvalCtx, memory::MemoryPool* pool) { - auto size = rows.size(); + const auto size = rows.size(); DecodedVector& decoded = filterEvalCtx.decodedResult; decoded.decode(*filterResult.get(), rows); @@ -456,6 +464,14 @@ std::string makeOperatorSpillPath( return fmt::format("{}/{}_{}_{}", spillDir, pipelineId, driverId, operatorId); } +void setOperatorRuntimeStats( + const std::string& name, + const RuntimeCounter& value, + std::unordered_map& stats) { + stats[name] = RuntimeMetric(value.unit); + stats[name].addValue(value.value); +} + void addOperatorRuntimeStats( const std::string& name, const RuntimeCounter& value, @@ -505,12 +521,24 @@ void projectChildren( const std::vector& projections, int32_t size, const BufferPtr& mapping) { + int maxInputChannel = -1; + int maxOutputChannel = -1; for (auto [inputChannel, outputChannel] : projections) { - if (outputChannel >= projectedChildren.size()) { - projectedChildren.resize(outputChannel + 1); + maxInputChannel = std::max(maxInputChannel, inputChannel); + maxOutputChannel = std::max(maxOutputChannel, outputChannel); + } + // Cache for already wrapped children to avoid wrapping the same child + // multiple times. + std::vector wrappedChildren(1 + maxInputChannel); + if (1 + maxOutputChannel > projectedChildren.size()) { + projectedChildren.resize(1 + maxOutputChannel); + } + for (auto [inputChannel, outputChannel] : projections) { + auto& wrapped = wrappedChildren[inputChannel]; + if (!wrapped) { + wrapped = wrapChild(size, mapping, src[inputChannel]); } - projectedChildren[outputChannel] = - wrapChild(size, mapping, src[inputChannel]); + projectedChildren[outputChannel] = wrapped; } } @@ -549,4 +577,20 @@ std::unique_ptr BlockedOperatorFactory::toOperator( } return nullptr; } + +std::unique_ptr getVectorSerdeOptions( + common::CompressionKind compressionKind, + VectorSerde::Kind kind, + std::optional minCompressionRatio) { + std::unique_ptr options = + kind == VectorSerde::Kind::kPresto + ? std::make_unique() + : std::make_unique(); + options->compressionKind = compressionKind; + if (minCompressionRatio.has_value()) { + options->minCompressionRatio = minCompressionRatio.value(); + } + return options; +} + } // namespace facebook::velox::exec diff --git a/velox/exec/OperatorUtils.h b/velox/exec/OperatorUtils.h index 1f7c25efda32..605edba4a288 100644 --- a/velox/exec/OperatorUtils.h +++ b/velox/exec/OperatorUtils.h @@ -15,8 +15,11 @@ */ #pragma once +#include +#include "velox/core/QueryConfig.h" #include "velox/exec/Operator.h" #include "velox/exec/Spiller.h" +#include "velox/vector/VectorStream.h" namespace facebook::velox::exec { @@ -145,6 +148,12 @@ std::string makeOperatorSpillPath( int driverId, int32_t operatorId); +/// Set a named runtime metric in operator 'stats'. +void setOperatorRuntimeStats( + const std::string& name, + const RuntimeCounter& value, + std::unordered_map& stats); + /// Add a named runtime metric to operator 'stats'. void addOperatorRuntimeStats( const std::string& name, @@ -168,7 +177,8 @@ folly::Range initializeRowNumberMapping( /// Projects children of 'src' row vector according to 'projections'. Optionally /// takes a 'mapping' and 'size' that represent the indices and size, /// respectively, of a dictionary wrapping that should be applied to the -/// projections. The output param 'projectedChildren' will contain all the final +/// projections. Dictionary wrapping of the same child vector is cached for +/// reuse. The output param 'projectedChildren' will contain all the final /// projections at the expected channel index. Indices not specified in /// 'projections' will be left untouched in 'projectedChildren'. void projectChildren( @@ -300,4 +310,12 @@ class BlockedOperatorFactory : public Operator::PlanNodeTranslator { private: BlockedOperatorCb blockedCb_{nullptr}; }; + +/// Creates VectorSerde::Options for the given VectorSerde kind with compression +/// settings. Optionally configures minimum compression ratio. +std::unique_ptr getVectorSerdeOptions( + common::CompressionKind compressionKind, + VectorSerde::Kind kind, + std::optional minCompressionRatio = std::nullopt); + } // namespace facebook::velox::exec diff --git a/velox/exec/OrderBy.cpp b/velox/exec/OrderBy.cpp index 93fd410c3307..dc6b5d54809b 100644 --- a/velox/exec/OrderBy.cpp +++ b/velox/exec/OrderBy.cpp @@ -67,7 +67,7 @@ OrderBy::OrderBy( &nonReclaimableSection_, driverCtx->prefixSortConfig(), spillConfig_.has_value() ? &(spillConfig_.value()) : nullptr, - &spillStats_); + spillStats_.get()); } void OrderBy::addInput(RowVectorPtr input) { diff --git a/velox/exec/OutputBuffer.cpp b/velox/exec/OutputBuffer.cpp index eb00d51852f4..f43fab2b8cc3 100644 --- a/velox/exec/OutputBuffer.cpp +++ b/velox/exec/OutputBuffer.cpp @@ -30,10 +30,10 @@ void ArbitraryBuffer::noMoreData() { pages_.push_back(nullptr); } -void ArbitraryBuffer::enqueue(std::unique_ptr page) { +void ArbitraryBuffer::enqueue(std::unique_ptr page) { VELOX_CHECK_NOT_NULL(page, "Unexpected null page"); VELOX_CHECK(!hasNoMoreData(), "Arbitrary buffer has set no more data marker"); - pages_.push_back(std::shared_ptr(page.release())); + pages_.push_back(std::shared_ptr(page.release())); } void ArbitraryBuffer::getAvailablePageSizes(std::vector& out) const { @@ -45,7 +45,7 @@ void ArbitraryBuffer::getAvailablePageSizes(std::vector& out) const { } } -std::vector> ArbitraryBuffer::getPages( +std::vector> ArbitraryBuffer::getPages( uint64_t maxBytes) { if (maxBytes == 0 && !pages_.empty() && pages_.front() == nullptr) { // Always give out an end marker when this buffer is finished and fully @@ -57,7 +57,7 @@ std::vector> ArbitraryBuffer::getPages( VELOX_CHECK_EQ(pages_.size(), 1); return {nullptr}; } - std::vector> pages; + std::vector> pages; uint64_t bytesRemoved{0}; while (bytesRemoved < maxBytes && !pages_.empty()) { if (pages_.front() == nullptr) { @@ -81,7 +81,7 @@ std::string ArbitraryBuffer::toString() const { hasNoMoreData()); } -void DestinationBuffer::Stats::recordEnqueue(const SerializedPage& data) { +void DestinationBuffer::Stats::recordEnqueue(const SerializedPageBase& data) { const auto numRows = data.numRows(); VELOX_CHECK(numRows.has_value(), "SerializedPage's numRows must be valid"); bytesBuffered += data.size(); @@ -89,7 +89,8 @@ void DestinationBuffer::Stats::recordEnqueue(const SerializedPage& data) { ++pagesBuffered; } -void DestinationBuffer::Stats::recordAcknowledge(const SerializedPage& data) { +void DestinationBuffer::Stats::recordAcknowledge( + const SerializedPageBase& data) { const auto numRows = data.numRows(); VELOX_CHECK(numRows.has_value(), "SerializedPage's numRows must be valid"); const int64_t size = data.size(); @@ -104,7 +105,7 @@ void DestinationBuffer::Stats::recordAcknowledge(const SerializedPage& data) { ++pagesSent; } -void DestinationBuffer::Stats::recordDelete(const SerializedPage& data) { +void DestinationBuffer::Stats::recordDelete(const SerializedPageBase& data) { recordAcknowledge(data); } @@ -185,7 +186,7 @@ DestinationBuffer::Data DestinationBuffer::getData( return {std::move(data), std::move(remainingBytes), true}; } -void DestinationBuffer::enqueue(std::shared_ptr data) { +void DestinationBuffer::enqueue(std::shared_ptr data) { // Drop duplicate end markers. if (data == nullptr && !data_.empty() && data_.back() == nullptr) { return; @@ -245,7 +246,7 @@ void DestinationBuffer::loadData(ArbitraryBuffer* buffer, uint64_t maxBytes) { } } -std::vector> DestinationBuffer::acknowledge( +std::vector> DestinationBuffer::acknowledge( int64_t sequence, bool fromGetData) { const int64_t numDeleted = sequence - sequence_; @@ -268,7 +269,7 @@ std::vector> DestinationBuffer::acknowledge( VELOX_CHECK_LE( numDeleted, data_.size(), "Ack received for a not yet produced item"); - std::vector> freed; + std::vector> freed; for (auto i = 0; i < numDeleted; ++i) { if (data_[i] == nullptr) { VELOX_CHECK_EQ(i, data_.size() - 1, "null marker found in the middle"); @@ -282,9 +283,9 @@ std::vector> DestinationBuffer::acknowledge( return freed; } -std::vector> +std::vector> DestinationBuffer::deleteResults() { - std::vector> freed; + std::vector> freed; for (auto i = 0; i < data_.size(); ++i) { if (data_[i] == nullptr) { VELOX_CHECK_EQ(i, data_.size() - 1, "null marker found in the middle"); @@ -314,7 +315,7 @@ namespace { // that we do the expensive free outside and only then continue the // producers which will allocate more memory. void releaseAfterAcknowledge( - std::vector>& freed, + std::vector>& freed, std::vector& promises) { freed.clear(); for (auto& promise : promises) { @@ -445,7 +446,7 @@ void OutputBuffer::updateTotalBufferedBytesMsLocked() { bool OutputBuffer::enqueue( int destination, - std::unique_ptr data, + std::unique_ptr data, ContinueFuture* future) { VELOX_CHECK_NOT_NULL(data); VELOX_CHECK( @@ -471,8 +472,6 @@ bool OutputBuffer::enqueue( enqueuePartitionedOutputLocked( destination, std::move(data), dataAvailableCallbacks); break; - default: - VELOX_UNREACHABLE(PartitionedOutputNode::kindString(kind_)); } if (bufferedBytes_ >= maxSize_ && future) { @@ -494,13 +493,13 @@ bool OutputBuffer::enqueue( } void OutputBuffer::enqueueBroadcastOutputLocked( - std::unique_ptr data, + std::unique_ptr data, std::vector& dataAvailableCbs) { VELOX_DCHECK(isBroadcast()); VELOX_CHECK_NULL(arbitraryBuffer_); VELOX_DCHECK(dataAvailableCbs.empty()); - std::shared_ptr sharedData(data.release()); + std::shared_ptr sharedData(data.release()); for (auto& buffer : buffers_) { if (buffer != nullptr) { buffer->enqueue(sharedData); @@ -516,7 +515,7 @@ void OutputBuffer::enqueueBroadcastOutputLocked( } void OutputBuffer::enqueueArbitraryOutputLocked( - std::unique_ptr data, + std::unique_ptr data, std::vector& dataAvailableCbs) { VELOX_DCHECK(isArbitrary()); VELOX_DCHECK_NOT_NULL(arbitraryBuffer_); @@ -543,7 +542,7 @@ void OutputBuffer::enqueueArbitraryOutputLocked( void OutputBuffer::enqueuePartitionedOutputLocked( int destination, - std::unique_ptr data, + std::unique_ptr data, std::vector& dataAvailableCbs) { VELOX_DCHECK(isPartitioned()); VELOX_CHECK_NULL(arbitraryBuffer_); @@ -633,7 +632,7 @@ bool OutputBuffer::isFinishedLocked() { } void OutputBuffer::acknowledge(int destination, int64_t sequence) { - std::vector> freed; + std::vector> freed; std::vector promises; { std::lock_guard l(mutex_); @@ -651,7 +650,7 @@ void OutputBuffer::acknowledge(int destination, int64_t sequence) { } void OutputBuffer::updateAfterAcknowledgeLocked( - const std::vector>& freed, + const std::vector>& freed, std::vector& promises) { uint64_t freedBytes{0}; int freedPages{0}; @@ -675,7 +674,7 @@ void OutputBuffer::updateAfterAcknowledgeLocked( } bool OutputBuffer::deleteResults(int destination) { - std::vector> freed; + std::vector> freed; std::vector promises; bool isFinished; DataAvailable dataAvailable; @@ -719,7 +718,7 @@ void OutputBuffer::getData( DataAvailableCallback notify, DataConsumerActiveCheckCallback activeCheck) { DestinationBuffer::Data data; - std::vector> freed; + std::vector> freed; std::vector promises; { std::lock_guard l(mutex_); diff --git a/velox/exec/OutputBuffer.h b/velox/exec/OutputBuffer.h index 640c1bfcd8e7..6b9667fece62 100644 --- a/velox/exec/OutputBuffer.h +++ b/velox/exec/OutputBuffer.h @@ -75,11 +75,11 @@ class ArbitraryBuffer { /// appends a null page at the end of 'pages_' as end marker. void noMoreData(); - void enqueue(std::unique_ptr page); + void enqueue(std::unique_ptr page); /// Returns a number of pages with total bytes no less than 'maxBytes' if /// there are sufficient buffered pages. - std::vector> getPages(uint64_t maxBytes); + std::vector> getPages(uint64_t maxBytes); /// Append the available page sizes to `out'. void getAvailablePageSizes(std::vector& out) const; @@ -87,7 +87,7 @@ class ArbitraryBuffer { std::string toString() const; private: - std::deque> pages_; + std::deque> pages_; }; class DestinationBuffer { @@ -98,11 +98,11 @@ class DestinationBuffer { /// 2. Sent: the data is removed from the buffer after it is acked or /// deleted. struct Stats { - void recordEnqueue(const SerializedPage& data); + void recordEnqueue(const SerializedPageBase& data); - void recordAcknowledge(const SerializedPage& data); + void recordAcknowledge(const SerializedPageBase& data); - void recordDelete(const SerializedPage& data); + void recordDelete(const SerializedPageBase& data); bool finished{false}; @@ -117,7 +117,7 @@ class DestinationBuffer { int64_t pagesSent{0}; }; - void enqueue(std::shared_ptr data); + void enqueue(std::shared_ptr data); /// Invoked to load data with up to 'notifyMaxBytes_' bytes from arbitrary /// 'buffer' if there is pending fetch from this destination in which case @@ -165,12 +165,12 @@ class DestinationBuffer { /// do not give a warning for the case where no data is removed, otherwise we /// expect that data does get freed. We cannot assert that data gets deleted /// because acknowledge messages can arrive out of order. - std::vector> acknowledge( + std::vector> acknowledge( int64_t sequence, bool fromGetData); /// Removes all remaining data from the queue and returns the removed data. - std::vector> deleteResults(); + std::vector> deleteResults(); /// Returns and clears the notify callback, if any, along with arguments for /// the callback. @@ -187,7 +187,7 @@ class DestinationBuffer { private: void clearNotify(); - std::vector> data_; + std::vector> data_; // The sequence number of the first in 'data_'. int64_t sequence_ = 0; DataAvailableCallback notify_{nullptr}; @@ -280,7 +280,7 @@ class OutputBuffer { bool enqueue( int destination, - std::unique_ptr data, + std::unique_ptr data, ContinueFuture* future); void noMoreData(); @@ -345,7 +345,7 @@ class OutputBuffer { // Updates buffered size and returns possibly continuable producer promises // in 'promises'. void updateAfterAcknowledgeLocked( - const std::vector>& freed, + const std::vector>& freed, std::vector& promises); /// Given an updated total number of broadcast buffers, add any missing ones @@ -353,16 +353,16 @@ class OutputBuffer { void addOutputBuffersLocked(int numBuffers); void enqueueBroadcastOutputLocked( - std::unique_ptr data, + std::unique_ptr data, std::vector& dataAvailableCbs); void enqueueArbitraryOutputLocked( - std::unique_ptr data, + std::unique_ptr data, std::vector& dataAvailableCbs); void enqueuePartitionedOutputLocked( int destination, - std::unique_ptr data, + std::unique_ptr data, std::vector& dataAvailableCbs); std::string toStringLocked() const; @@ -401,7 +401,7 @@ class OutputBuffer { // While noMoreBuffers_ is false, stores the enqueued data to // broadcast to destinations that have not yet been initialized. Cleared // after receiving no-more-broadcast-buffers signal. - std::vector> dataToBroadcast_; + std::vector> dataToBroadcast_; std::mutex mutex_; // Actual data size in 'buffers_'. diff --git a/velox/exec/OutputBufferManager.cpp b/velox/exec/OutputBufferManager.cpp index 4773911a530d..f8183218fb4f 100644 --- a/velox/exec/OutputBufferManager.cpp +++ b/velox/exec/OutputBufferManager.cpp @@ -57,7 +57,7 @@ uint64_t OutputBufferManager::numBuffers() const { bool OutputBufferManager::enqueue( const std::string& taskId, int destination, - std::unique_ptr data, + std::unique_ptr data, ContinueFuture* future) { return getBuffer(taskId)->enqueue(destination, std::move(data), future); } diff --git a/velox/exec/OutputBufferManager.h b/velox/exec/OutputBufferManager.h index 8affa6f90512..ef9487ee87b5 100644 --- a/velox/exec/OutputBufferManager.h +++ b/velox/exec/OutputBufferManager.h @@ -53,7 +53,7 @@ class OutputBufferManager { bool enqueue( const std::string& taskId, int destination, - std::unique_ptr data, + std::unique_ptr data, ContinueFuture* future); void noMoreData(const std::string& taskId); diff --git a/velox/exec/ParallelProject.cpp b/velox/exec/ParallelProject.cpp new file mode 100644 index 000000000000..d632553968a9 --- /dev/null +++ b/velox/exec/ParallelProject.cpp @@ -0,0 +1,194 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/ParallelProject.h" +#include "velox/common/base/AsyncSource.h" +#include "velox/exec/Operator.h" +#include "velox/exec/Task.h" + +namespace facebook::velox::exec { + +ParallelProject::ParallelProject( + int32_t operatorId, + DriverCtx* driverCtx, + const core::ParallelProjectNodePtr& node) + : Operator( + driverCtx, + node->outputType(), + operatorId, + node->id(), + "ParallelProject"), + node_(node) {} + +namespace { +bool checkAddIdentityProjection( + const core::TypedExprPtr& projection, + const RowTypePtr& inputType, + column_index_t outputChannel, + std::vector& identityProjections) { + if (auto field = core::TypedExprs::asFieldAccess(projection)) { + const auto& inputs = field->inputs(); + if (inputs.empty() || + (inputs.size() == 1 && + dynamic_cast(inputs[0].get()))) { + const auto inputChannel = inputType->getChildIdx(field->name()); + identityProjections.emplace_back(inputChannel, outputChannel); + return true; + } + } + + return false; +} +} // namespace + +void ParallelProject::initialize() { + Operator::initialize(); + std::vector allExprs; + + const auto& inputType = node_->sources()[0]->outputType(); + const auto& exprGroups = node_->exprGroups(); + + int32_t unitIdx = 0; + int32_t exprIdx = 0; + int32_t unitSize = exprGroups[unitIdx].size(); + work_.emplace_back(); + work_.back().execCtx = std::make_unique( + operatorCtx_->pool(), operatorCtx_->driverCtx()->task->queryCtx().get()); + + std::vector unitExprs; + for (column_index_t i = 0; i < node_->exprNames().size(); i++) { + const auto& projection = exprGroups[unitIdx][exprIdx]; + bool identityProjection = checkAddIdentityProjection( + projection, inputType, i, identityProjections_); + if (!identityProjection) { + unitExprs.push_back(projection); + work_.back().resultProjections.emplace_back(unitExprs.size() - 1, i); + } else { + work_.back().loadOnly.push_back(identityProjections_.back().inputChannel); + } + ++exprIdx; + if (exprIdx == unitSize) { + // It may be that the only work is loading lazies. + auto tempExprs = + makeExprSetFromFlag(std::move(unitExprs), operatorCtx_->execCtx()); + std::shared_ptr shared(tempExprs.release()); + work_.back().exprSet = std::move(shared); + ++unitIdx; + exprIdx = 0; + if (unitIdx == exprGroups.size()) { + break; + } + unitSize = exprGroups[unitIdx].size(); + work_.emplace_back(); + work_.back().execCtx = std::make_unique( + operatorCtx_->pool(), + operatorCtx_->driverCtx()->task->queryCtx().get()); + } + } + + int32_t outputIdx = node_->exprNames().size(); + auto sourceType = node_->sources()[0]->outputType(); + for (auto& name : node_->noLoadIdentities()) { + auto idx = sourceType->getChildIdx(name); + identityProjections_.emplace_back(idx, outputIdx++); + } +} + +void ParallelProject::addInput(RowVectorPtr input) { + input_ = std::move(input); + numProcessedInputRows_ = 0; +} + +bool ParallelProject::allInputProcessed() { + if (!input_) { + return true; + } + if (numProcessedInputRows_ == input_->size()) { + input_ = nullptr; + return true; + } + return false; +} + +bool ParallelProject::isFinished() { + return noMoreInput_ && allInputProcessed(); +} + +RowVectorPtr ParallelProject::getOutput() { + if (allInputProcessed()) { + return nullptr; + } + + vector_size_t size = input_->size(); + allRows_.resize(size); + allRows_.setAll(); + std::vector>> pending; + std::vector results(outputType_->size()); + + for (auto i = 0; i < work_.size(); ++i) { + pending.push_back( + std::make_shared>( + [i, &results, this]() { return doWork(i, results); })); + auto item = pending.back(); + operatorCtx_->task()->queryCtx()->executor()->add( + [item]() { item->prepare(); }); + } + std::exception_ptr error; + for (auto i = 0; i < pending.size(); ++i) { + auto result = pending[i]->move(); + stats_.wlock()->getOutputTiming.add(pending[i]->prepareTiming()); + if (!error && result->error) { + error = result->error; + } + } + if (error) { + std::rethrow_exception(error); + } + + for (auto& projection : identityProjections_) { + results[projection.outputChannel] = + input_->childAt(projection.inputChannel); + } + numProcessedInputRows_ = size; + input_.reset(); + return std::make_shared( + operatorCtx_->pool(), outputType_, nullptr, size, std::move(results)); +} + +std::unique_ptr ParallelProject::doWork( + int32_t workIdx, + std::vector& results) { + auto& work = work_[workIdx]; + EvalCtx evalCtx(work.execCtx.get(), work.exprSet.get(), input_.get()); + try { + for (auto channel : work.loadOnly) { + evalCtx.ensureFieldLoaded(channel, allRows_); + } + + std::vector localResults; + work.exprSet->eval( + 0, work.exprSet->exprs().size(), true, allRows_, evalCtx, localResults); + for (auto& projection : work.resultProjections) { + results[projection.outputChannel] = + std::move(localResults[projection.inputChannel]); + } + } catch (const std::exception&) { + return std::make_unique(std::current_exception()); + } + return std::make_unique(nullptr); +} + +} // namespace facebook::velox::exec diff --git a/velox/exec/ParallelProject.h b/velox/exec/ParallelProject.h new file mode 100644 index 000000000000..55be1c4002e6 --- /dev/null +++ b/velox/exec/ParallelProject.h @@ -0,0 +1,100 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "velox/exec/Driver.h" +#include "velox/exec/Operator.h" +#include "velox/expression/Expr.h" + +namespace facebook::velox::exec { + +class ParallelProject : public Operator { + public: + ParallelProject( + int32_t operatorId, + DriverCtx* driverCtx, + const core::ParallelProjectNodePtr& node); + + bool isFilter() const override { + return false; + } + + bool preservesOrder() const override { + return true; + } + + bool needsInput() const override { + return !input_; + } + + void addInput(RowVectorPtr input) override; + + RowVectorPtr getOutput() override; + + BlockingReason isBlocked(ContinueFuture* /* unused */) override { + return BlockingReason::kNotBlocked; + } + + bool isFinished() override; + + void close() override { + Operator::close(); + for (auto& work : work_) { + if (work.exprSet) { + work.exprSet->clear(); + } + } + } + + void initialize() override; + + private: + struct WorkUnit { + // Maps from result channel of exprSet to channel in the node's output type. + std::vector resultProjections; + // Positions in input which are to be loaded by this group. + std::vector loadOnly; + std::unique_ptr execCtx; + std::shared_ptr exprSet; + }; + + struct WorkResult { + WorkResult(std::exception_ptr e) : error(std::move(e)) {} + std::exception_ptr error; + }; + + // Tests if 'numProcessedRows_' equals to the length of input_ and clears + // outstanding references to input_ if done. Returns true if getOutput + // should return nullptr. + bool allInputProcessed(); + + std::unique_ptr doWork( + int32_t workIdx, + std::vector& result); + + // Cached ParallelProject node for lazy initialization. After + // initialization, they will be reset, and initialized_ will be set to true. + const core::ParallelProjectNodePtr node_; + + bool initialized_{false}; + + std::vector work_; + SelectivityVector allRows_; + int32_t numProcessedInputRows_{0}; +}; + +} // namespace facebook::velox::exec diff --git a/velox/exec/PartitionedOutput.cpp b/velox/exec/PartitionedOutput.cpp index 28878ef21d15..86fd98db7961 100644 --- a/velox/exec/PartitionedOutput.cpp +++ b/velox/exec/PartitionedOutput.cpp @@ -15,24 +15,11 @@ */ #include "velox/exec/PartitionedOutput.h" +#include "velox/exec/OperatorUtils.h" #include "velox/exec/OutputBufferManager.h" #include "velox/exec/Task.h" namespace facebook::velox::exec { -namespace { -std::unique_ptr getVectorSerdeOptions( - const core::QueryConfig& queryConfig, - VectorSerde::Kind kind) { - std::unique_ptr options = - kind == VectorSerde::Kind::kPresto - ? std::make_unique() - : std::make_unique(); - options->compressionKind = - common::stringToCompressionKind(queryConfig.shuffleCompressionKind()); - options->minCompressionRatio = PartitionedOutput::minCompressionRatio(); - return options; -} -} // namespace namespace detail { Destination::Destination( @@ -145,7 +132,7 @@ BlockingReason Destination::flush( bool blocked = bufferManager.enqueue( taskId_, destination_, - std::make_unique( + std::make_unique( stream.getIOBuf(bufferReleaseFn), nullptr, flushedRows), future); @@ -203,8 +190,11 @@ PartitionedOutput::PartitionedOutput( eagerFlush_(eagerFlush), serde_(getNamedVectorSerde(planNode->serdeKind())), serdeOptions_(getVectorSerdeOptions( - operatorCtx_->driverCtx()->queryConfig(), - planNode->serdeKind())) { + common::stringToCompressionKind(operatorCtx_->driverCtx() + ->queryConfig() + .shuffleCompressionKind()), + planNode->serdeKind(), + PartitionedOutput::minCompressionRatio())) { if (!planNode->isPartitioned()) { VELOX_USER_CHECK_EQ(numDestinations_, 1); } @@ -256,17 +246,18 @@ void PartitionedOutput::initializeDestinations() { if (destinations_.empty()) { auto taskId = operatorCtx_->taskId(); for (int i = 0; i < numDestinations_; ++i) { - destinations_.push_back(std::make_unique( - taskId, - i, - serde_, - serdeOptions_.get(), - pool(), - eagerFlush_, - [&](uint64_t bytes, uint64_t rows) { - auto lockedStats = stats_.wlock(); - lockedStats->addOutputVector(bytes, rows); - })); + destinations_.push_back( + std::make_unique( + taskId, + i, + serde_, + serdeOptions_.get(), + pool(), + eagerFlush_, + [&](uint64_t bytes, uint64_t rows) { + auto lockedStats = stats_.wlock(); + lockedStats->addOutputVector(bytes, rows); + })); } } } @@ -286,7 +277,7 @@ void PartitionedOutput::initializeSizeBuffers() { void PartitionedOutput::estimateRowSizes() { const auto numInput = input_->size(); std::fill(rowSize_.begin(), rowSize_.end(), 0); - raw_vector storage; + raw_vector storage(pool()); const auto numbers = iota(numInput, storage); const auto rows = folly::Range(numbers, numInput); if (serde_->kind() == VectorSerde::Kind::kCompactRow) { diff --git a/velox/exec/PlanNodeStats.cpp b/velox/exec/PlanNodeStats.cpp index 5306a642b5a0..76c47b3db399 100644 --- a/velox/exec/PlanNodeStats.cpp +++ b/velox/exec/PlanNodeStats.cpp @@ -59,6 +59,14 @@ PlanNodeStats& PlanNodeStats::operator+=(const PlanNodeStats& another) { } } + for (const auto& [name, exprStats] : another.expressionStats) { + auto const [it, inserted] = + this->expressionStats.try_emplace(name, exprStats); + if (!inserted) { + it->second.add(exprStats); + } + } + // Populating number of drivers for plan nodes with multiple operators is not // useful. Each operator could have been executed in different pipelines with // different number of drivers. @@ -131,6 +139,14 @@ void PlanNodeStats::addTotals(const OperatorStats& stats) { } } + for (const auto& [name, exprStats] : stats.expressionStats) { + if (UNLIKELY(this->expressionStats.count(name) == 0)) { + this->expressionStats.insert(std::make_pair(name, exprStats)); + } else { + this->expressionStats.at(name).add(exprStats); + } + } + // Populating number of drivers for plan nodes with multiple operators is not // useful. Each operator could have been executed in different pipelines with // different number of drivers. @@ -290,7 +306,7 @@ namespace { void printCustomStats( const std::unordered_map& stats, const std::string& indentation, - std::stringstream& stream) { + std::ostream& stream) { int width = 0; for (const auto& entry : stats) { if (width < entry.first.size()) { @@ -306,9 +322,9 @@ void printCustomStats( } for (const auto& [name, metric] : orderedStats) { - stream << std::endl; stream << indentation << std::left << std::setw(width) << name; metric.printMetric(stream); + stream << std::endl; } } } // namespace @@ -330,16 +346,15 @@ std::string printPlanWithStats( // Print input rows and sizes only for leaf plan nodes. Including this // information for other plan nodes is redundant as it is the same as // output of the source nodes. - const bool includeInputStats = leafPlanNodes.count(planNodeId) > 0; - stream << stats.toString(includeInputStats); + const bool includeInputStats = leafPlanNodes.contains(planNodeId); + stream << indentation << stats.toString(includeInputStats) << std::endl; // Include break down by operator type for plan nodes with multiple // operators. Print input rows and sizes for all such nodes. if (stats.isMultiOperatorTypeNode()) { for (const auto& entry : stats.operatorStats) { - stream << std::endl; stream << indentation << entry.first << ": " - << entry.second->toString(true); + << entry.second->toString(true) << std::endl; if (includeCustomStats) { printCustomStats( diff --git a/velox/exec/PlanNodeStats.h b/velox/exec/PlanNodeStats.h index 8c41d7a0a46d..c1c6dbc00ee9 100644 --- a/velox/exec/PlanNodeStats.h +++ b/velox/exec/PlanNodeStats.h @@ -18,6 +18,7 @@ #include #include "velox/common/time/CpuWallTimer.h" #include "velox/exec/Operator.h" +#include "velox/expression/ExprStats.h" namespace facebook::velox::exec { struct TaskStats; @@ -142,6 +143,9 @@ struct PlanNodeStats { /// Total spilled files. uint32_t spilledFiles{0}; + /// A map of expression name to its respective stats. + std::unordered_map expressionStats; + /// Add stats for a single operator instance. void add(const OperatorStats& stats); diff --git a/velox/exec/PrefixSort.cpp b/velox/exec/PrefixSort.cpp index 45bafec927eb..6578429a02e4 100644 --- a/velox/exec/PrefixSort.cpp +++ b/velox/exec/PrefixSort.cpp @@ -101,7 +101,7 @@ FOLLY_ALWAYS_INLINE void extractRowColumnToPrefix( default: VELOX_UNSUPPORTED( "prefix-sort does not support type kind: {}", - mapTypeKindToName(typeKind)); + TypeKindName::toName(typeKind)); } } @@ -310,7 +310,7 @@ void PrefixSort::extractRowAndEncodePrefixKeys(char* row, char* prefixBuffer) { } // static. -uint32_t PrefixSort::maxRequiredBytes( +uint64_t PrefixSort::maxRequiredBytes( const RowContainer* rowContainer, const std::vector& compareFlags, const velox::common::PrefixSortConfig& config, @@ -345,14 +345,15 @@ void PrefixSort::stdSort( }); } -uint32_t PrefixSort::maxRequiredBytes() const { +uint64_t PrefixSort::maxRequiredBytes() const { const auto numRows = rowContainer_->numRows(); const auto numPages = memory::AllocationTraits::numPages(numRows * sortLayout_.entrySize); // Prefix data size + swap buffer size. return memory::AllocationTraits::pageBytes(numPages) + - pool_->preferredSize(checkedPlus( - sortLayout_.entrySize, AlignedBuffer::kPaddedSize)) + + pool_->preferredSize( + checkedPlus( + sortLayout_.entrySize, AlignedBuffer::kPaddedSize)) + 2 * pool_->alignment(); } diff --git a/velox/exec/PrefixSort.h b/velox/exec/PrefixSort.h index 3d1fa46c2a5d..0bd888412c19 100644 --- a/velox/exec/PrefixSort.h +++ b/velox/exec/PrefixSort.h @@ -153,7 +153,7 @@ class PrefixSort { /// The std::sort won't require bytes while prefix sort may require buffers /// such as prefix data. The logic is similar to the above function /// PrefixSort::sort but returns the maximum buffer the sort may need. - static uint32_t maxRequiredBytes( + static uint64_t maxRequiredBytes( const RowContainer* rowContainer, const std::vector& compareFlags, const velox::common::PrefixSortConfig& config, @@ -205,7 +205,7 @@ class PrefixSort { // Estimates the memory required for prefix sort such as prefix buffer and // swap buffer. - uint32_t maxRequiredBytes() const; + uint64_t maxRequiredBytes() const; void sortInternal(std::vector>& rows); diff --git a/velox/exec/RowContainer.cpp b/velox/exec/RowContainer.cpp index d2c76d7c11d3..9671b6bf57a1 100644 --- a/velox/exec/RowContainer.cpp +++ b/velox/exec/RowContainer.cpp @@ -42,8 +42,7 @@ static int32_t typeKindSize(TypeKind kind) { __attribute__((__no_sanitize__("thread"))) #endif #endif -inline void -setBit(char* bits, uint32_t idx) { +inline void setBit(char* bits, uint32_t idx) { auto bitsAs8Bit = reinterpret_cast(bits); bitsAs8Bit[idx / 8] |= (1 << (idx % 8)); } @@ -137,14 +136,17 @@ RowContainer::RowContainer( bool isJoinBuild, bool hasProbedFlag, bool hasNormalizedKeys, + bool useListRowIndex, memory::MemoryPool* pool) : keyTypes_(keyTypes), nullableKeys_(nullableKeys), isJoinBuild_(isJoinBuild), hasNormalizedKeys_(hasNormalizedKeys), + useListRowIndex_(useListRowIndex), stringAllocator_(std::make_unique(pool)), accumulators_(accumulators), - rows_(pool) { + rows_(pool), + rowPointers_(StlAllocator(stringAllocator_.get())) { // Compute the layout of the payload row. The row has keys, null flags, // accumulators, dependent fields. All fields are fixed width. If variable // width data is referenced, this is done with StringView(for VARCHAR) and @@ -174,17 +176,17 @@ RowContainer::RowContainer( // bits. 'numRowsWithNormalizedKey_' gives the number of rows with // the extra field. int32_t offset = 0; - int32_t nullOffset = 0; + int32_t flagOffset = 0; bool isVariableWidth = false; for (auto& type : keyTypes_) { typeKinds_.push_back(type->kind()); types_.push_back(type); offsets_.push_back(offset); offset += typeKindSize(type->kind()); - nullOffsets_.push_back(nullOffset); + nullOffsets_.push_back(flagOffset); isVariableWidth |= !type->isFixedWidth(); if (nullableKeys_) { - ++nullOffset; + ++flagOffset; } } // Make offset at least sizeof pointer so that there is space for a @@ -192,18 +194,16 @@ RowContainer::RowContainer( offset = std::max(offset, sizeof(void*)); const int32_t firstAggregateOffset = offset; if (!accumulators.empty()) { - // This moves nullOffset to the start of the next byte. + // This moves flagOffset to the start of the next byte. // This is to guarantee the null and initialized bits for an aggregate // always appear in the same byte. - nullOffset = (nullOffset + 7) & -8; + flagOffset = (flagOffset + 7) & -8; } for (const auto& accumulator : accumulators) { - // Initialized bit. Set when the accumulator is initialized. - nullOffsets_.push_back(nullOffset); - ++nullOffset; // Null bit. - nullOffsets_.push_back(nullOffset); - ++nullOffset; + nullOffsets_.push_back(flagOffset); + // Increment for two bits: null bit and following initialized bit. + flagOffset += kNumAccumulatorFlags; isVariableWidth |= !accumulator.isFixedSize(); usesExternalMemory_ |= accumulator.usesExternalMemory(); alignment_ = combineAlignments(accumulator.alignment(), alignment_); @@ -211,21 +211,19 @@ RowContainer::RowContainer( for (auto& type : dependentTypes) { types_.push_back(type); typeKinds_.push_back(type->kind()); - nullOffsets_.push_back(nullOffset); - ++nullOffset; + nullOffsets_.push_back(flagOffset); + ++flagOffset; isVariableWidth |= !type->isFixedWidth(); } if (hasProbedFlag) { - nullOffsets_.push_back(nullOffset); - probedFlagOffset_ = nullOffset + firstAggregateOffset * 8; - ++nullOffset; + probedFlagOffset_ = flagOffset + firstAggregateOffset * 8; + ++flagOffset; } // Free flag. - nullOffsets_.push_back(nullOffset); - freeFlagOffset_ = nullOffset + firstAggregateOffset * 8; - ++nullOffset; + freeFlagOffset_ = flagOffset + firstAggregateOffset * 8; + ++flagOffset; // Add 1 to the last null offset to get the number of bits. - flagBytes_ = bits::nbytes(nullOffsets_.back() + 1); + flagBytes_ = bits::nbytes(flagOffset); // Fixup 'nullOffsets_' to be the bit number from the start of the row. for (int32_t i = 0; i < nullOffsets_.size(); ++i) { nullOffsets_[i] += firstAggregateOffset * 8; @@ -250,14 +248,6 @@ RowContainer::RowContainer( offset += sizeof(void*); } fixedRowSize_ = bits::roundUp(offset, alignment_); - // A distinct hash table has no aggregates and if the hash table has - // no nulls, it may be that there are no null flags. - if (!nullOffsets_.empty()) { - // All flags like free and probed flags and null flags for keys and non-keys - // start as 0. This is also used to mark aggregates as uninitialized on row - // creation. - initialNulls_.resize(flagBytes_, 0x0); - } originalNormalizedKeySize_ = hasNormalizedKeys_ ? bits::roundUp(sizeof(normalized_key_t), alignment_) : 0; @@ -268,16 +258,7 @@ RowContainer::RowContainer( offsets_[i], (nullableKeys_ || i >= keyTypes_.size()) ? nullOffsets_[nullOffsetsPos] : RowColumn::kNotNullOffset); - - // offsets_ contains the offsets for keys, then accumulators, then dependent - // columns. This captures the case where i is the index of an accumulator. - if (!accumulators.empty() && i >= keyTypes_.size() && - i < keyTypes_.size() + accumulators.size()) { - // Aggregates have null flags and initialized flags. - nullOffsetsPos += kNumAccumulatorFlags; - } else { - ++nullOffsetsPos; - } + ++nullOffsetsPos; } rowColumnsStats_.resize(types_.size()); } @@ -301,10 +282,24 @@ char* RowContainer::newRow() { if (normalizedKeySize_) { ++numRowsWithNormalizedKey_; } + + if (useListRowIndex_) { + rowPointers_.push_back(row); + } } return initializeRow(row, false /* reuse */); } +void RowContainer::setAllNull(char* row) { + VELOX_CHECK(!bits::isBitSet(row, freeFlagOffset_)); + removeOrUpdateRowColumnStats(row, /*setToNull=*/true); + if (!nullOffsets_.empty()) { + for (auto i : nullOffsets_) { + row[nullByte(i)] |= nullMask(i); + } + } +} + char* RowContainer::initializeRow(char* row, bool reuse) { if (reuse) { auto rows = folly::Range(&row, 1); @@ -317,10 +312,9 @@ char* RowContainer::initializeRow(char* row, bool reuse) { ::memset(row, 0, fixedRowSize_); } if (!nullOffsets_.empty()) { - ::memcpy( - row + nullByte(nullOffsets_[0]), - initialNulls_.data(), - initialNulls_.size()); + // Sets all null and initialized bits to 0 (for each accumulator, + // initialized bit follows the null bit). + ::memset(row + nullByte(nullOffsets_[0]), 0x0, flagBytes_); } if (rowSizeOffset_) { variableRowSize(row) = 0; @@ -361,7 +355,7 @@ void RowContainer::eraseRows(folly::Range rows) { } int32_t RowContainer::findRows(folly::Range rows, char** result) const { - raw_vector> ranges; + raw_vector> ranges(pool()); ranges.resize(rows_.numRanges()); for (auto i = 0; i < rows_.numRanges(); ++i) { ranges[i] = rows_.rangeAt(i); @@ -370,8 +364,8 @@ int32_t RowContainer::findRows(folly::Range rows, char** result) const { ranges.begin(), ranges.end(), [](const auto& left, const auto& right) { return left.data() < right.data(); }); - raw_vector starts; - raw_vector sizes; + raw_vector starts(pool()); + raw_vector sizes(pool()); starts.reserve(ranges.size()); sizes.reserve(ranges.size()); for (const auto& range : ranges) { @@ -599,7 +593,7 @@ int32_t RowContainer::variableSizeAt(const char* row, column_index_t column) } const auto typeKind = typeKinds_[column]; - if (typeKind == TypeKind::VARCHAR || typeKind == TypeKind::VARBINARY) { + if (is_string_kind(typeKind)) { return reinterpret_cast(row + rowColumn.offset()) ->size(); } else { @@ -625,7 +619,7 @@ int32_t RowContainer::extractVariableSizeAt( } const auto typeKind = typeKinds_[column]; - if (typeKind == TypeKind::VARCHAR || typeKind == TypeKind::VARBINARY) { + if (is_string_kind(typeKind)) { const auto value = valueAt(row, rowColumn.offset()); const auto size = value.size(); ::memcpy(output, &size, 4); @@ -663,7 +657,7 @@ int32_t RowContainer::storeVariableSizeAt( // First 4 bytes is the size of the data. const auto size = *reinterpret_cast(data); - if (typeKind == TypeKind::VARCHAR || typeKind == TypeKind::VARBINARY) { + if (is_string_kind(typeKind)) { if (size > 0) { stringAllocator_->copyMultipart( StringView(data + 4, size), row, rowColumn.offset()); @@ -901,13 +895,11 @@ void RowContainer::hashTyped( : BaseVector::kNullHash; } else { uint64_t hash; - if constexpr (Kind == TypeKind::VARCHAR || Kind == TypeKind::VARBINARY) { + if constexpr (is_string_kind(Kind)) { hash = folly::hasher()(HashStringAllocator::contiguousString( valueAt(row, offset), storage)); - } else if constexpr ( - Kind == TypeKind::ROW || Kind == TypeKind::ARRAY || - Kind == TypeKind::MAP) { + } else if constexpr (is_nested_kind(Kind)) { auto in = prepareRead(row, offset); hash = ContainerRowSerde::hash(in, type); } else if constexpr (typeProvidesCustomComparison) { @@ -977,6 +969,8 @@ void RowContainer::clear() { hasDuplicateRows_ = false; rows_.clear(); + rowPointers_.clear(); + rowPointers_.shrink_to_fit(); stringAllocator_->clear(); numRows_ = 0; numRowsWithNormalizedKey_ = 0; @@ -1038,7 +1032,8 @@ std::optional RowContainer::estimateRowSize() const { } int64_t freeBytes = rows_.freeBytes() + fixedRowSize_ * numFreeRows_; int64_t usedSize = rows_.allocatedBytes() - freeBytes + - stringAllocator_->retainedSize() - stringAllocator_->freeSpace(); + stringAllocator_->retainedSize() - stringAllocator_->freeSpace() - + rowPointers_.capacity() * sizeof(char*); int64_t rowSize = usedSize / numRows_; VELOX_CHECK_GT( rowSize, 0, "Estimated row size of the RowContainer must be positive."); @@ -1275,9 +1270,9 @@ RowComparator::RowComparator( } } -bool RowComparator::operator()(const char* lhs, const char* rhs) { +int32_t RowComparator::compare(const char* lhs, const char* rhs) { if (lhs == rhs) { - return false; + return 0; } for (auto& key : keyInfo_) { if (auto result = rowContainer_->compare( @@ -1285,26 +1280,37 @@ bool RowComparator::operator()(const char* lhs, const char* rhs) { rhs, key.first, {key.second.isNullsFirst(), key.second.isAscending(), false})) { - return result < 0; + return result; } } - return false; + return 0; } -bool RowComparator::operator()( +bool RowComparator::operator()(const char* lhs, const char* rhs) { + return compare(lhs, rhs) < 0; +} + +int32_t RowComparator::compare( const std::vector& decodedVectors, vector_size_t index, - const char* rhs) { + const char* other) { for (auto& key : keyInfo_) { if (auto result = rowContainer_->compare( - rhs, + other, rowContainer_->columnAt(key.first), decodedVectors[key.first], index, {key.second.isNullsFirst(), key.second.isAscending(), false})) { - return result > 0; + return -result; } } - return false; + return 0; +} + +bool RowComparator::operator()( + const std::vector& decodedVectors, + vector_size_t index, + const char* other) { + return compare(decodedVectors, index, other) < 0; } } // namespace facebook::velox::exec diff --git a/velox/exec/RowContainer.h b/velox/exec/RowContainer.h index 84783a19ee97..db336a2d9aef 100644 --- a/velox/exec/RowContainer.h +++ b/velox/exec/RowContainer.h @@ -83,6 +83,8 @@ struct RowContainerIterator { char* rowBegin{nullptr}; /// First byte after the end of the range containing 'currentRow'. char* endOfRun{nullptr}; + /// Cursor of the list row operation. + int32_t listRowCursor{0}; /// Returns the current row, skipping a possible normalized key below the /// first byte of row. @@ -118,6 +120,10 @@ class RowPartitions { return size_; } + void reset() { + size_ = 0; + } + private: const int32_t capacity_; @@ -273,6 +279,21 @@ class RowContainer { const std::vector& keyTypes, const std::vector& dependentTypes, memory::MemoryPool* pool) + : RowContainer( + keyTypes, + dependentTypes, + /*useListRowIndex=*/false, + pool) {} + + /// If 'useListRowIndex' is true, the container maintains an internal array of + /// row pointers so that listRowsFast() can return rows without scanning + /// underlying allocations or checking free/probe flags. It is intended to be + /// used in SortBuffer and SortInputSpiller to improve performance. + RowContainer( + const std::vector& keyTypes, + const std::vector& dependentTypes, + bool useListRowIndex, + memory::MemoryPool* pool) : RowContainer( keyTypes, true, // nullableKeys @@ -282,6 +303,7 @@ class RowContainer { false, // isJoinBuild false, // hasProbedFlag false, // hasNormalizedKey + useListRowIndex, pool) {} ~RowContainer(); @@ -313,6 +335,7 @@ class RowContainer { bool isJoinBuild, bool hasProbedFlag, bool hasNormalizedKey, + bool useListRowIndex, memory::MemoryPool* pool); /// Allocates a new row and initializes possible aggregates to null. @@ -328,13 +351,7 @@ class RowContainer { /// Sets all fields, aggregates, keys and dependents to null. Used when making /// a row with uninitialized keys for aggregates with no-op partial /// aggregation. - void setAllNull(char* row) { - removeOrUpdateRowColumnStats(row, /*setToNull=*/true); - if (!nullOffsets_.empty()) { - memset(row + nullByte(nullOffsets_[0]), 0xff, initialNulls_.size()); - bits::clearBit(row, freeFlagOffset_); - } - } + void setAllNull(char* row); /// The row size excluding any out-of-line stored variable length values. int32_t fixedRowSize() const { @@ -579,8 +596,7 @@ class RowContainer { __attribute__((__no_sanitize__("thread"))) #endif #endif - int32_t - listRows( + int32_t listRows( RowContainerIterator* iter, int32_t maxRows, uint64_t maxBytes, @@ -644,6 +660,20 @@ class RowContainer { return count; } + /// Fast path for `listRows` that returns `rowPointers_` directly. Used by + /// `SortBuffer` and `SortInputSpiller`, so it skips checking the free and + /// probe flags. + int32_t listRowsFast(RowContainerIterator* iter, int32_t maxRows, char** rows) + const { + int32_t count = 0; + while (count < maxRows && iter->listRowCursor < rowPointers_.size()) { + char* row = rowPointers_[iter->listRowCursor]; + rows[count++] = row; + ++iter->listRowCursor; + } + return count; + } + /// Extracts up to 'maxRows' rows starting at the position of 'iter'. A /// default constructed or reset iter starts at the beginning. Returns the /// number of rows written to 'rows'. Returns 0 when at end. Stops after the @@ -658,6 +688,9 @@ class RowContainer { int32_t listRows(RowContainerIterator* iter, int32_t maxRows, char** rows) const { + if (useListRowIndex_) { + return listRowsFast(iter, maxRows, rows); + } return listRows(iter, maxRows, kUnlimited, rows); } @@ -674,21 +707,13 @@ class RowContainer { __attribute__((__no_sanitize__("thread"))) #endif #endif - void - setProbedFlag(char** rows, int32_t numRows); - - /// Returns true if 'row' at 'column' equals the value at 'index' in - /// 'decoded'. 'mayHaveNulls' specifies if nulls need to be checked. This is a - /// fast path for compare(). - template - bool equals( - const char* row, - RowColumn column, - const DecodedVector& decoded, - vector_size_t index) const; + void setProbedFlag(char** rows, int32_t numRows); /// Compares the value at 'column' in 'row' with the value at 'index' in /// 'decoded'. Returns 0 for equal, < 0 for 'row' < 'decoded', > 0 otherwise. + /// 'mayHaveNulls' specifies if nulls need to be checked. This is a fast path + /// for compare(). + template int32_t compare( const char* row, RowColumn column, @@ -806,6 +831,10 @@ class RowContainer { return 0; } + const std::vector>& testingRowPointers() const { + return rowPointers_; + } + memory::MemoryPool* pool() const { return stringAllocator_->pool(); } @@ -1102,7 +1131,7 @@ class RowContainer { BufferPtr& nullBuffer = result->mutableNulls(maxRows, true); auto nulls = nullBuffer->asMutable(); - BufferPtr valuesBuffer = result->mutableValues(maxRows); + BufferPtr valuesBuffer = result->mutableValues(); [[maybe_unused]] auto values = valuesBuffer->asMutableRange(); for (int32_t i = 0; i < numRows; ++i) { const char* row; @@ -1134,9 +1163,9 @@ class RowContainer { int32_t offset, int32_t resultOffset, FlatVector* result) { - auto maxRows = numRows + resultOffset; + [[maybe_unused]] auto maxRows = numRows + resultOffset; VELOX_DCHECK_LE(maxRows, result->size()); - BufferPtr valuesBuffer = result->mutableValues(maxRows); + BufferPtr valuesBuffer = result->mutableValues(); [[maybe_unused]] auto values = valuesBuffer->asMutableRange(); for (int32_t i = 0; i < numRows; ++i) { const char* row; @@ -1173,56 +1202,29 @@ class RowContainer { bool mix, uint64_t* result) const; - template - inline bool equalsWithNulls( - const char* row, - int32_t offset, - int32_t nullByte, - uint8_t nullMask, - const DecodedVector& decoded, - vector_size_t index) const { - bool rowIsNull = isNullAt(row, nullByte, nullMask); - bool indexIsNull = decoded.isNullAt(index); - if (rowIsNull || indexIsNull) { - return rowIsNull == indexIsNull; - } - - return equalsNoNulls( - row, offset, decoded, index); - } - - template - inline bool equalsNoNulls( + template + inline int compare( const char* row, - int32_t offset, + RowColumn column, const DecodedVector& decoded, - vector_size_t index) const { - using T = typename KindToFlatVector::HashRowType; - - if constexpr ( - Kind == TypeKind::ROW || Kind == TypeKind::ARRAY || - Kind == TypeKind::MAP) { - return compareComplexType(row, offset, decoded, index) == 0; - } else if constexpr ( - Kind == TypeKind::VARCHAR || Kind == TypeKind::VARBINARY) { - return compareStringAsc( - valueAt(row, offset), decoded, index) == 0; - } else if constexpr (typeProvidesCustomComparison) { - return SimpleVector::template comparePrimitiveAscWithCustomComparison< - Kind>( - decoded.base()->type().get(), - decoded.valueAt(index), - valueAt(row, offset)) == 0; + vector_size_t index, + CompareFlags flags) const { + if (decoded.base()->typeUsesCustomComparison()) { + return compare( + row, column, decoded, index, flags); } else { - return SimpleVector::comparePrimitiveAsc( - decoded.valueAt(index), valueAt(row, offset)) == 0; + return compare( + row, column, decoded, index, flags); } } template < bool typeProvidesCustomComparison, TypeKind Kind, - std::enable_if_t = 0> + bool mayHaveNulls, + std::enable_if_t< + Kind != TypeKind::OPAQUE && Kind != TypeKind::UNKNOWN, + int32_t> = 0> inline int compare( const char* row, RowColumn column, @@ -1230,20 +1232,23 @@ class RowContainer { vector_size_t index, CompareFlags flags) const { using T = typename KindToFlatVector::HashRowType; - bool rowIsNull = isNullAt(row, column.nullByte(), column.nullMask()); - bool indexIsNull = decoded.isNullAt(index); - if (rowIsNull) { - return indexIsNull ? 0 : flags.nullsFirst ? -1 : 1; - } - if (indexIsNull) { - return flags.nullsFirst ? 1 : -1; + + if constexpr (mayHaveNulls) { + bool rowIsNull = isNullAt(row, column.nullByte(), column.nullMask()); + bool indexIsNull = decoded.isNullAt(index); + if (rowIsNull) { + return indexIsNull ? 0 : flags.nullsFirst ? -1 : 1; + } + if (indexIsNull) { + return flags.nullsFirst ? 1 : -1; + } } + if constexpr ( Kind == TypeKind::ROW || Kind == TypeKind::ARRAY || Kind == TypeKind::MAP) { return compareComplexType(row, column.offset(), decoded, index, flags); - } else if constexpr ( - Kind == TypeKind::VARCHAR || Kind == TypeKind::VARBINARY) { + } else if constexpr (is_string_kind(Kind)) { auto result = compareStringAsc( valueAt(row, column.offset()), decoded, index); return flags.ascending ? result : result * -1; @@ -1267,6 +1272,22 @@ class RowContainer { template < bool typeProvidesCustomComparison, TypeKind Kind, + bool mayHaveNulls, + std::enable_if_t = 0> + inline int compare( + const char* row, + RowColumn column, + const DecodedVector& /*decoded*/, + vector_size_t /*index*/, + CompareFlags flags) const { + const bool rowIsNull = isNullAt(row, column.nullByte(), column.nullMask()); + return rowIsNull ? 0 : flags.nullsFirst ? 1 : -1; + } + + template < + bool typeProvidesCustomComparison, + TypeKind Kind, + bool mayHaveNulls, std::enable_if_t = 0> inline int compare( const char* /*row*/, @@ -1307,8 +1328,7 @@ class RowContainer { Kind == TypeKind::MAP) { return compareComplexType( left, right, type, leftOffset, rightOffset, flags); - } else if constexpr ( - Kind == TypeKind::VARCHAR || Kind == TypeKind::VARBINARY) { + } else if constexpr (is_string_kind(Kind)) { auto leftValue = valueAt(left, leftOffset); auto rightValue = valueAt(right, rightOffset); auto result = compareStringAsc(leftValue, rightValue); @@ -1493,7 +1513,8 @@ class RowContainer { const bool isJoinBuild_; // True if normalized keys are enabled in initial state. const bool hasNormalizedKeys_; - + // True if use 'listRowsFast'. + const bool useListRowIndex_; const std::unique_ptr stringAllocator_; // Indicates if we can add new row to this row container. It is set to false @@ -1512,7 +1533,7 @@ class RowContainer { // Indicates if this row container has rows with duplicate keys. This only // applies if 'nextOffset_' is set. tsan_atomic hasDuplicateRows_{false}; - // Bit position of null bit in the row. 0 if no null flag. Order is keys, + // Bit position of null bit in the row. 0 if no null flag. Order is keys, // accumulators, dependent. std::vector nullOffsets_; // Position of field or accumulator. Corresponds 1:1 to 'nullOffset_'. @@ -1543,15 +1564,13 @@ class RowContainer { // Extra bytes to reserve before each added row for a normalized key. Set to // 0 after deciding not to use normalized keys. int normalizedKeySize_; - // Copied over the null bits of each row on initialization. Keys are - // not null, aggregates are null. - std::vector initialNulls_; uint64_t numRows_ = 0; // Head of linked list of free rows. char* firstFreeRow_ = nullptr; uint64_t numFreeRows_ = 0; memory::AllocationPool rows_; + std::vector> rowPointers_; int alignment_ = 1; @@ -1739,60 +1758,21 @@ inline void RowContainer::extractNulls( } template -inline bool RowContainer::equals( - const char* row, - RowColumn column, - const DecodedVector& decoded, - vector_size_t index) const { - auto typeKind = decoded.base()->typeKind(); - if (typeKind == TypeKind::UNKNOWN) { - return isNullAt(row, column.nullByte(), column.nullMask()); - } - - if constexpr (!mayHaveNulls) { - return VELOX_DYNAMIC_TEMPLATE_TYPE_DISPATCH( - equalsNoNulls, false, typeKind, row, column.offset(), decoded, index); - } else { - return VELOX_DYNAMIC_TEMPLATE_TYPE_DISPATCH( - equalsWithNulls, - false, - typeKind, - row, - column.offset(), - column.nullByte(), - column.nullMask(), - decoded, - index); - } -} - inline int RowContainer::compare( const char* row, RowColumn column, const DecodedVector& decoded, vector_size_t index, CompareFlags flags) const { - if (decoded.base()->typeUsesCustomComparison()) { - return VELOX_DYNAMIC_TEMPLATE_TYPE_DISPATCH_ALL( - compare, - true, - decoded.base()->typeKind(), - row, - column, - decoded, - index, - flags); - } else { - return VELOX_DYNAMIC_TEMPLATE_TYPE_DISPATCH_ALL( - compare, - false, - decoded.base()->typeKind(), - row, - column, - decoded, - index, - flags); - } + return VELOX_DYNAMIC_TEMPLATE_TYPE_DISPATCH_ALL( + compare, + mayHaveNulls, + decoded.base()->typeKind(), + row, + column, + decoded, + index, + flags); } inline int RowContainer::compare( @@ -1873,11 +1853,21 @@ class RowComparator { /// Returns true if lhs < rhs, false otherwise. bool operator()(const char* lhs, const char* rhs); - /// Returns true if decodeVectors[index] < rhs, false otherwise. + /// Returns 0 for equal, < 0 for lhs < rhs, > 0 otherwise. + int compare(const char* lhs, const char* rhs); + + /// Returns true if decodedVectors[index] < other, false otherwise. bool operator()( const std::vector& decodedVectors, vector_size_t index, - const char* rhs); + const char* other); + + /// Returns 0 for equal, < 0 for decodedVectors[index] < other, + /// > 0 otherwise. + int32_t compare( + const std::vector& decodedVectors, + vector_size_t index, + const char* other); private: std::vector> keyInfo_; diff --git a/velox/exec/RowNumber.cpp b/velox/exec/RowNumber.cpp index 31421d0d29b5..e57424da59df 100644 --- a/velox/exec/RowNumber.cpp +++ b/velox/exec/RowNumber.cpp @@ -120,13 +120,13 @@ void RowNumber::restoreNextSpillPartition() { auto it = spillInputPartitionSet_.begin(); restoringPartitionId_ = it->first; spillInputReader_ = it->second->createUnorderedReader( - spillConfig_->readBufferSize, pool(), &spillStats_); + spillConfig_->readBufferSize, pool(), spillStats_.get()); // Find matching partition for the hash table. auto hashTableIt = spillHashTablePartitionSet_.find(it->first); if (hashTableIt != spillHashTablePartitionSet_.end()) { spillHashTableReader_ = hashTableIt->second->createUnorderedReader( - spillConfig_->readBufferSize, pool(), &spillStats_); + spillConfig_->readBufferSize, pool(), spillStats_.get()); setSpillPartitionBits(&(it->first)); @@ -388,7 +388,7 @@ void RowNumber::reclaim( << spillConfig_->maxSpillLevel << ", and abandon spilling for memory pool: " << pool()->name(); - ++spillStats_.wlock()->spillMaxLevelExceededCount; + ++spillStats_->wlock()->spillMaxLevelExceededCount; return; } @@ -408,7 +408,7 @@ SpillPartitionIdSet RowNumber::spillHashTable() { tableType, spillPartitionBits_, &spillConfig, - &spillStats_); + spillStats_.get()); hashTableSpiller->spill(); hashTableSpiller->finishSpill(spillHashTablePartitionSet_); @@ -429,7 +429,7 @@ void RowNumber::setupInputSpiller( restoringPartitionId_, spillPartitionBits_, &spillConfig, - &spillStats_); + spillStats_.get()); const auto& hashers = table_->hashers(); @@ -509,7 +509,7 @@ void RowNumber::recursiveSpillInput() { while (spillInputReader_->nextBatch(unspilledInput)) { spillInput(unspilledInput, pool()); - if (operatorCtx_->driver()->shouldYield()) { + if (shouldYield()) { yield_ = true; return; } diff --git a/velox/exec/RowsStreamingWindowBuild.cpp b/velox/exec/RowsStreamingWindowBuild.cpp index a3dedd8a60cd..7f03a465b6c3 100644 --- a/velox/exec/RowsStreamingWindowBuild.cpp +++ b/velox/exec/RowsStreamingWindowBuild.cpp @@ -50,8 +50,9 @@ bool RowsStreamingWindowBuild::needsInput() { void RowsStreamingWindowBuild::ensureInputPartition() { if (windowPartitions_.empty() || windowPartitions_.back()->complete()) { - windowPartitions_.emplace_back(std::make_shared( - data_.get(), inversedInputChannels_, sortKeyInfo_)); + windowPartitions_.emplace_back( + std::make_shared( + data_.get(), inversedInputChannels_, sortKeyInfo_)); } } diff --git a/velox/exec/ScaleWriterLocalPartition.cpp b/velox/exec/ScaleWriterLocalPartition.cpp index 08e3b3bad311..7530ff403a0c 100644 --- a/velox/exec/ScaleWriterLocalPartition.cpp +++ b/velox/exec/ScaleWriterLocalPartition.cpp @@ -174,28 +174,11 @@ void ScaleWriterPartitioningLocalPartition::addInput(RowVectorPtr input) { row; } - for (auto i = 0; i < numPartitions_; ++i) { - const auto writerRowCount = writerAssignmentCounts_[i]; - if (writerRowCount == 0) { - continue; - } - - auto writerInput = processPartition( - input, - writerRowCount, - i, - std::move(writerAssignmmentIndicesBuffers_[i]), - rawWriterAssignmmentIndicesBuffers_[i]); - if (writerInput != nullptr) { - ContinueFuture future; - auto reason = queues_[i]->enqueue( - writerInput, totalInputBytes * writerRowCount / numInput, &future); - if (reason != BlockingReason::kNotBlocked) { - blockingReasons_.push_back(reason); - futures_.push_back(std::move(future)); - } - } - } + populateAndEnqueuePartitions( + input, + writerAssignmentCounts_, + writerAssignmmentIndicesBuffers_, + rawWriterAssignmmentIndicesBuffers_); } // Only update the scaling state if the memory used is below the @@ -223,10 +206,9 @@ uint32_t ScaleWriterPartitioningLocalPartition::getNextWriterId( void ScaleWriterPartitioningLocalPartition::close() { LocalPartition::close(); - // The last driver operator reports the shared table partition rebalancer - // stats. We expect one reference hold by this operator and one referenced by - // the task. - if (tablePartitionRebalancer_.use_count() != 2) { + // The first driver operator instance reports the shared table partition + // rebalancer stats. + if (operatorCtx_->driverCtx()->driverId != 0) { return; } diff --git a/velox/exec/SerializedPage.cpp b/velox/exec/SerializedPage.cpp new file mode 100644 index 000000000000..7ba6a3af0f7f --- /dev/null +++ b/velox/exec/SerializedPage.cpp @@ -0,0 +1,50 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/exec/SerializedPage.h" + +namespace facebook::velox::exec { + +PrestoSerializedPage::PrestoSerializedPage( + std::unique_ptr iobuf, + std::function onDestructionCb, + std::optional numRows) + : iobuf_(std::move(iobuf)), + iobufBytes_(chainBytes(*iobuf_.get())), + numRows_(numRows), + onDestructionCb_(onDestructionCb) { + VELOX_CHECK_NOT_NULL(iobuf_); + for (auto& buf : *iobuf_) { + int32_t bufSize = buf.size(); + ranges_.push_back( + ByteRange{ + const_cast(reinterpret_cast(buf.data())), + bufSize, + 0}); + } +} + +PrestoSerializedPage::~PrestoSerializedPage() { + if (onDestructionCb_) { + onDestructionCb_(*iobuf_.get()); + } +} + +std::unique_ptr +PrestoSerializedPage::prepareStreamForDeserialize() { + return std::make_unique(std::move(ranges_)); +} + +} // namespace facebook::velox::exec diff --git a/velox/exec/SerializedPage.h b/velox/exec/SerializedPage.h new file mode 100644 index 000000000000..3f93de8012ad --- /dev/null +++ b/velox/exec/SerializedPage.h @@ -0,0 +1,98 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "velox/common/memory/ByteStream.h" + +namespace facebook::velox::exec { + +/// Interface for serialized pages. +class SerializedPageBase { + public: + virtual ~SerializedPageBase() = default; + + /// Returns the size of the serialized data in bytes. + virtual uint64_t size() const = 0; + + /// Returns the number of rows if available. + virtual std::optional numRows() const = 0; + + /// Makes 'input' ready for deserializing 'this' with + /// VectorStreamGroup::read(). + virtual std::unique_ptr prepareStreamForDeserialize() = 0; + + /// Returns a clone of the IOBuf. + virtual std::unique_ptr getIOBuf() const = 0; +}; + +/// Corresponds to Presto SerializedPage, i.e. a container for serialized +/// vectors in Presto wire format. +class PrestoSerializedPage : public SerializedPageBase { + public: + /// Construct from IOBuf chain. + explicit PrestoSerializedPage( + std::unique_ptr iobuf, + std::function onDestructionCb = nullptr, + std::optional numRows = std::nullopt); + + ~PrestoSerializedPage() override; + + uint64_t size() const override { + return iobufBytes_; + } + + std::optional numRows() const override { + return numRows_; + } + + std::unique_ptr prepareStreamForDeserialize() override; + + std::unique_ptr getIOBuf() const override { + return iobuf_->clone(); + } + + private: + static int64_t chainBytes(folly::IOBuf& iobuf) { + int64_t size = 0; + for (auto& range : iobuf) { + size += range.size(); + } + return size; + } + + // Buffers containing the serialized data. The memory is owned by 'iobuf_'. + std::vector ranges_; + + // IOBuf holding the data in 'ranges_. + std::unique_ptr iobuf_; + + // Number of payload bytes in 'iobuf_'. + const int64_t iobufBytes_; + + // Number of payload rows, if provided. + const std::optional numRows_; + + // Callback that will be called on destruction of the PrestoSerializedPage, + // primarily used to free externally allocated memory backing folly::IOBuf + // from caller. Caller is responsible to pass in proper cleanup logic to + // prevent any memory leak. + std::function onDestructionCb_; +}; + +// TODO: Remove after fully migration to new SerializedPageBase and +// PrestoSerializedPage API. +using SerializedPage = PrestoSerializedPage; +} // namespace facebook::velox::exec diff --git a/velox/exec/SerializedPageSpiller.cpp b/velox/exec/SerializedPageSpiller.cpp deleted file mode 100644 index b565968c1c97..000000000000 --- a/velox/exec/SerializedPageSpiller.cpp +++ /dev/null @@ -1,227 +0,0 @@ -/* - * Copyright (c) Facebook, Inc. and its affiliates. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "velox/exec/SerializedPageSpiller.h" -#include "velox/exec/SpillFile.h" - -namespace facebook::velox::exec { - -void SerializedPageSpiller::spill( - const std::vector>& pages) { - writeWithBufferControl([&]() { - if (pages.empty()) { - return 0UL; - } - - if (bufferStream_ == nullptr) { - bufferStream_ = std::make_unique(*pool_); - } - - // Spill file layout: - // --- Page 0 --- - // - // (1 Byte) is null page - // (8 Bytes) payload size - // (1 Bytes) has num rows - // (8 Bytes) num rows - // (x Bytes) payload - // - // --- Page 1 --- - // ... - // --- Page n --- - // ... - const auto totalBytesBeforeSpill = totalBytes_; - uint64_t totalRows{0}; - for (const auto& page : pages) { - if (page != nullptr) { - const auto pageSize = page->size(); - totalBytes_ += pageSize; - totalRows += page->numRows().value_or(0); - } - - // Spill payload headers. - const uint8_t isNull = (page == nullptr) ? 1 : 0; - bufferStream_->write( - reinterpret_cast(&isNull), sizeof(uint8_t)); - if (page == nullptr) { - continue; - } - - auto pageBytes = static_cast(page->size()); - bufferStream_->write( - reinterpret_cast(&pageBytes), sizeof(int64_t)); - - const auto numRowsOpt = page->numRows(); - uint8_t hasNumRows = numRowsOpt.has_value() ? 1 : 0; - bufferStream_->write(reinterpret_cast(&hasNumRows), 1); - if (numRowsOpt.has_value()) { - const int64_t numRows = numRowsOpt.value(); - bufferStream_->write( - reinterpret_cast(&numRows), sizeof(int64_t)); - } - - // Spill payload. - auto iobuf = page->getIOBuf(); - for (auto range = iobuf->begin(); range != iobuf->end(); ++range) { - bufferStream_->write( - reinterpret_cast(range->data()), - static_cast(range->size())); - } - } - - VELOX_CHECK_GE(totalBytes_, totalBytesBeforeSpill); - totalPages_ += pages.size(); - return totalRows; - }); -} - -SerializedPageSpiller::Result SerializedPageSpiller::finishSpill() { - auto spillFiles = finish(); - return {std::move(spillFiles), totalPages_}; -} - -void SerializedPageSpiller::flushBuffer( - SpillWriteFile* file, - uint64_t& writtenBytes, - uint64_t& flushTimeNs, - uint64_t& writeTimeNs) { - flushTimeNs = 0; - { - NanosecondTimer timer(&writeTimeNs); - writtenBytes = file->write(bufferStream_->getIOBuf()); - } -} - -bool SerializedPageSpiller::bufferEmpty() const { - return bufferStream_ == nullptr; -} - -uint64_t SerializedPageSpiller::bufferSize() const { - return bufferStream_ == nullptr - ? 0 - : static_cast(bufferStream_->tellp()); -} - -void SerializedPageSpiller::addFinishedFile(SpillWriteFile* file) { - SpillFileInfo spillFileInfo; - spillFileInfo.path = file->path(); - finishedFiles_.push_back(std::move(spillFileInfo)); -} - -bool SerializedPageSpillReader::empty() const { - return numPages_ == 0; -} - -uint64_t SerializedPageSpillReader::numPages() const { - return numPages_; -} - -std::shared_ptr SerializedPageSpillReader::at(uint64_t index) { - ensurePages(index); - return bufferedPages_[index]; -} - -void SerializedPageSpillReader::deleteAll() { - spillFilePaths_.clear(); - curFileStream_.reset(); - bufferedPages_.clear(); - numPages_ = 0; -} - -void SerializedPageSpillReader::deleteFront(uint64_t numPages) { - if (numPages == 0) { - return; - } - ensurePages(numPages - 1); - bufferedPages_.erase( - bufferedPages_.begin(), bufferedPages_.begin() + numPages); - numPages_ -= numPages; -} - -void SerializedPageSpillReader::ensureSpillFile() { - if (curFileStream_ != nullptr) { - return; - } - VELOX_CHECK(!spillFilePaths_.empty()); - auto filePath = spillFilePaths_.front(); - auto fs = filesystems::getFileSystem(filePath, nullptr); - auto file = fs->openFileForRead(filePath); - curFileStream_ = std::make_unique( - std::move(file), readBufferSize_, pool_); - spillFilePaths_.pop_front(); -} - -void SerializedPageSpillReader::ensurePages(uint64_t index) { - VELOX_CHECK_LT(index, numPages_); - if (index < bufferedPages_.size()) { - return; - } - - while (index >= bufferedPages_.size()) { - bufferedPages_.push_back(unspillNextPage()); - } -} - -namespace { -struct FreeData { - const std::shared_ptr pool; - int64_t bytesToFree{0}; -}; - -void freeFunc(void* data, void* userData) { - auto* freeData = reinterpret_cast(userData); - freeData->pool->free(data, freeData->bytesToFree); - delete freeData; -} -} // namespace - -std::shared_ptr SerializedPageSpillReader::unspillNextPage() { - VELOX_CHECK(!empty()); - ensureSpillFile(); - - SCOPE_EXIT { - if (curFileStream_->atEnd()) { - curFileStream_.reset(); - } - }; - - // Read payload headers - const auto isNull = !!(curFileStream_->read()); - if (isNull) { - return nullptr; - } - const auto iobufBytes = curFileStream_->read(); - const auto hasNumRows = curFileStream_->read() == 0 ? false : true; - int64_t numRows{0}; - if (hasNumRows) { - numRows = curFileStream_->read(); - } - - // Read payload - VELOX_CHECK_GE(curFileStream_->remainingSize(), iobufBytes); - void* rawBuf = pool_->allocate(iobufBytes); - curFileStream_->readBytes(reinterpret_cast(rawBuf), iobufBytes); - - auto* userData = new FreeData{pool_->shared_from_this(), iobufBytes}; - auto iobuf = - folly::IOBuf::takeOwnership(rawBuf, iobufBytes, freeFunc, userData, true); - - return std::make_shared( - std::move(iobuf), - nullptr, - hasNumRows ? std::optional(numRows) : std::nullopt); -} -} // namespace facebook::velox::exec diff --git a/velox/exec/SerializedPageSpiller.h b/velox/exec/SerializedPageSpiller.h deleted file mode 100644 index b1974f5bc66b..000000000000 --- a/velox/exec/SerializedPageSpiller.h +++ /dev/null @@ -1,151 +0,0 @@ -/* - * Copyright (c) Facebook, Inc. and its affiliates. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "velox/common/base/SpillStats.h" -#include "velox/common/file/FileInputStream.h" -#include "velox/exec/ExchangeQueue.h" -#include "velox/exec/SpillFile.h" - -namespace facebook::velox::exec { -namespace test { -class SerializedPageSpillerHelper; -class SerializedPageSpillReaderHelper; -} // namespace test - -/// Used for spilling a sequence of 'SerializedPage'. The spiller preserves the -/// order of the pages. -class SerializedPageSpiller : public SpillWriterBase { - public: - struct Result { - SpillFiles spillFiles; - uint64_t totalPages; - }; - - SerializedPageSpiller( - uint64_t writeBufferSize, - uint64_t targetFileSize, - const std::string& pathPrefix, - const std::string& fileCreateConfig, - common::UpdateAndCheckSpillLimitCB& updateAndCheckSpillLimitCb, - memory::MemoryPool* pool, - folly::Synchronized* stats) - : SpillWriterBase( - writeBufferSize, - targetFileSize, - pathPrefix, - fileCreateConfig, - updateAndCheckSpillLimitCb, - pool, - stats) { - VELOX_CHECK_NOT_NULL(pool_); - } - - /// Spill 'pages' to disk. The method does not free the original in-memory - /// structure of 'pages'. It is caller's responsibility to free them. - void spill(const std::vector>& pages); - - /// Finishes the spilling and return the spilled result. - Result finishSpill(); - - private: - void flushBuffer( - SpillWriteFile* file, - uint64_t& writtenBytes, - uint64_t& flushTimeNs, - uint64_t& writeTimeNs) override; - - bool bufferEmpty() const override; - - uint64_t bufferSize() const override; - - void addFinishedFile(SpillWriteFile* file) override; - - uint64_t totalBytes_{0}; - - uint64_t totalPages_{0}; - - std::unique_ptr bufferStream_; - - friend class test::SerializedPageSpillerHelper; -}; - -/// Used for reading a sequence of 'SerializedPage' that were spilled by -/// 'SerializedPageSpiller'. The reader preserves the page order, and provides -/// random index access functionality in a load-on-read manner. It is used by -/// 'DestinationBuffer' and provides convenient APIs for reading and deleting -/// the pages, with unspilling handled transparently. -class SerializedPageSpillReader { - public: - SerializedPageSpillReader( - SerializedPageSpiller::Result&& spillResult, - uint64_t readBufferSize, - memory::MemoryPool* pool, - folly::Synchronized* spillStats) - : readBufferSize_(readBufferSize), - pool_(pool), - spillStats_(spillStats), - numPages_(spillResult.totalPages) { - for (const auto& spillFileInfo : spillResult.spillFiles) { - spillFilePaths_.push_back(spillFileInfo.path); - } - } - - /// Returns true if there are any remaining pages left to read. - bool empty() const; - - /// Returns the current number of pages in the reader. - uint64_t numPages() const; - - /// Returns the page at 'index' in the reader. - std::shared_ptr at(uint64_t index); - - /// Delete 'numPages' from the front. - void deleteFront(uint64_t numPages); - - /// Delete all pages from the reader. - void deleteAll(); - - private: - // Ensures the current file stream is open. If not, opens the next file. - void ensureSpillFile(); - - // Ensures all pages up to 'index' are loaded in memory. - void ensurePages(uint64_t index); - - // Unspills one serialized page and returns it. - std::shared_ptr unspillNextPage(); - - const uint64_t readBufferSize_; - - memory::MemoryPool* const pool_; - - folly::Synchronized* const spillStats_; - - // The current file stream. - std::unique_ptr curFileStream_; - - // A small number of front pages buffered in memory from spilled pages. - // These pages will be kept in memory and won't be spilled again. - std::vector> bufferedPages_; - - std::deque spillFilePaths_; - - uint64_t numPages_; - - friend class test::SerializedPageSpillReaderHelper; -}; -} // namespace facebook::velox::exec diff --git a/velox/exec/SortBuffer.cpp b/velox/exec/SortBuffer.cpp index d64f858d95f1..15bf8890b5bf 100644 --- a/velox/exec/SortBuffer.cpp +++ b/velox/exec/SortBuffer.cpp @@ -37,7 +37,7 @@ SortBuffer::SortBuffer( spillConfig_(spillConfig), spillStats_(spillStats), sortedRows_(0, memory::StlAllocator(*pool)) { - VELOX_CHECK_GE(input_->size(), sortCompareFlags_.size()); + VELOX_CHECK_GE(input_->children().size(), sortCompareFlags_.size()); VELOX_CHECK_GT(sortCompareFlags_.size(), 0); VELOX_CHECK_EQ(sortColumnIndices.size(), sortCompareFlags_.size()); VELOX_CHECK_NOT_NULL(nonReclaimableSection_); @@ -74,7 +74,7 @@ SortBuffer::SortBuffer( } data_ = std::make_unique( - sortedColumnTypes, nonSortedColumnTypes, pool_); + sortedColumnTypes, nonSortedColumnTypes, /*useListRowIndex=*/true, pool_); spillerStoreType_ = ROW(std::move(sortedSpillColumnNames), std::move(sortedSpillColumnTypes)); } @@ -90,12 +90,12 @@ void SortBuffer::addInput(const VectorPtr& input) { VELOX_CHECK(!noMoreInput_); ensureInputFits(input); - SelectivityVector allRows(input->size()); + const SelectivityVector allRows(input->size()); std::vector rows(input->size()); for (int row = 0; row < input->size(); ++row) { rows[row] = data_->newRow(); } - auto* inputRow = input->as(); + const auto* inputRow = input->as(); for (const auto& columnProjection : columnMap_) { DecodedVector decoded( *inputRow->childAt(columnProjection.outputChannel), allRows); @@ -128,6 +128,7 @@ void SortBuffer::noMoreInput() { updateEstimatedOutputRowSize(); // Sort the pointers to the rows in RowContainer (data_) instead of sorting // the rows. + // TODO: Reuse 'RowContainer::rowPointers_'. sortedRows_.resize(numInputRows_); RowContainerIterator iter; data_->listRows(&iter, numInputRows_, sortedRows_.data()); @@ -310,7 +311,7 @@ void SortBuffer::ensureSortFits() { } // The memory for std::vector sorted rows and prefix sort required buffer. - uint64_t sortBufferToReserve = + const auto sortBufferToReserve = numInputRows_ * sizeof(char*) + PrefixSort::maxRequiredBytes( data_.get(), sortCompareFlags_, prefixSortConfig_, pool_); @@ -389,10 +390,6 @@ void SortBuffer::prepareOutput(vector_size_t batchSize) { BaseVector::create(input_, batchSize, pool_)); } - for (auto& child : output_->children()) { - child->resize(batchSize); - } - if (hasSpilled()) { spillSources_.resize(batchSize); spillSourceRows_.resize(batchSize); @@ -488,7 +485,7 @@ void SortBuffer::prepareOutputWithSpill() { VELOX_CHECK_EQ(spillPartitionSet_.size(), 1); spillMerger_ = spillPartitionSet_.begin()->second->createOrderedReader( - spillConfig_->readBufferSize, pool(), spillStats_); + *spillConfig_, pool(), spillStats_); spillPartitionSet_.clear(); } } // namespace facebook::velox::exec diff --git a/velox/exec/SortWindowBuild.cpp b/velox/exec/SortWindowBuild.cpp index f25175cc2cfa..eeadc8b6d9a7 100644 --- a/velox/exec/SortWindowBuild.cpp +++ b/velox/exec/SortWindowBuild.cpp @@ -16,6 +16,7 @@ #include "velox/exec/SortWindowBuild.h" #include "velox/exec/MemoryReclaimer.h" +#include "velox/exec/Window.h" namespace facebook::velox::exec { @@ -45,16 +46,19 @@ SortWindowBuild::SortWindowBuild( common::PrefixSortConfig&& prefixSortConfig, const common::SpillConfig* spillConfig, tsan_atomic* nonReclaimableSection, + folly::Synchronized* opStats, folly::Synchronized* spillStats) : WindowBuild(node, pool, spillConfig, nonReclaimableSection), numPartitionKeys_{node->partitionKeys().size()}, compareFlags_{makeCompareFlags(numPartitionKeys_, node->sortingOrders())}, pool_(pool), prefixSortConfig_(prefixSortConfig), + opStats_(opStats), spillStats_(spillStats), sortedRows_(0, memory::StlAllocator(*pool)), partitionStartRows_(0, memory::StlAllocator(*pool)) { VELOX_CHECK_NOT_NULL(pool_); + VELOX_CHECK_NOT_NULL(opStats_); allKeyInfo_.reserve(partitionKeyInfo_.size() + sortKeyInfo_.size()); allKeyInfo_.insert( allKeyInfo_.cend(), partitionKeyInfo_.begin(), partitionKeyInfo_.end()); @@ -72,13 +76,20 @@ void SortWindowBuild::addInput(RowVectorPtr input) { // Add all the rows into the RowContainer. for (auto row = 0; row < input->size(); ++row) { - char* newRow = data_->newRow(); + addDecodedInputRow(decodedInputVectors_, row); + } +} - for (auto col = 0; col < input->childrenSize(); ++col) { - data_->store(decodedInputVectors_[col], row, newRow, col); - } +void SortWindowBuild::addDecodedInputRow( + std::vector& decodedInputVectors, + vector_size_t row) { + char* newRow = data_->newRow(); + + for (auto col = 0; col < inputChannels_.size(); ++col) { + data_->store(decodedInputVectors[col], row, newRow, col); } - numRows_ += input->size(); + + numRows_++; } void SortWindowBuild::ensureInputFits(const RowVectorPtr& input) { @@ -280,7 +291,7 @@ void SortWindowBuild::noMoreInput() { spiller_->finishSpill(spillPartitionSet); VELOX_CHECK_EQ(spillPartitionSet.size(), 1); merge_ = spillPartitionSet.begin()->second->createOrderedReader( - spillConfig_->readBufferSize, pool_, spillStats_); + *spillConfig_, pool_, spillStats_); } else { // At this point we have seen all the input rows. The operator is // being prepared to output rows now. @@ -294,14 +305,31 @@ void SortWindowBuild::noMoreInput() { pool_->release(); } -void SortWindowBuild::loadNextPartitionFromSpill() { +void SortWindowBuild::loadNextPartitionBatchFromSpill() { + // Check if current partition batch still has available partitions. If so, + // return directly. + if (currentPartition_ < static_cast(partitionStartRows_.size() - 2)) { + return; + } + + const int minReadBatchRows = spillConfig_->windowMinReadBatchRows; sortedRows_.clear(); - sortedRows_.shrink_to_fit(); + sortedRows_.reserve(minReadBatchRows); data_->clear(); + partitionStartRows_.clear(); + partitionStartRows_.reserve(minReadBatchRows); + partitionStartRows_.push_back(0); + currentPartition_ = -1; + numSpillReadBatches_++; + // Load at least #minReadBatchRows rows and a complete partition. The rows + // might contain multiple partitions. Record the partition boundaries as + // inMemory case. In this way, the logic of getting window partitions would be + // identical between inMemory and spill. for (;;) { auto next = merge_->next(); if (next == nullptr) { + partitionStartRows_.push_back(sortedRows_.size()); break; } @@ -324,7 +352,10 @@ void SortWindowBuild::loadNextPartitionFromSpill() { } if (newPartition) { - break; + partitionStartRows_.push_back(sortedRows_.size()); + if (sortedRows_.size() >= minReadBatchRows) { + break; + } } auto* newRow = data_->newRow(); @@ -334,16 +365,19 @@ void SortWindowBuild::loadNextPartitionFromSpill() { sortedRows_.push_back(newRow); next->pop(); } -} -std::shared_ptr SortWindowBuild::nextPartition() { - if (merge_ != nullptr) { - VELOX_CHECK(!sortedRows_.empty(), "No window partitions available"); - auto partition = folly::Range(sortedRows_.data(), sortedRows_.size()); - return std::make_shared( - data_.get(), partition, inversedInputChannels_, sortKeyInfo_); + // No more partition batches. All data is consumed. + if (sortedRows_.empty()) { + partitionStartRows_.clear(); + numSpillReadBatches_--; + + auto lockedOpStats = opStats_->wlock(); + lockedOpStats->runtimeStats[Window::kWindowSpillReadNumBatches] = + RuntimeMetric(numSpillReadBatches_); } +} +std::shared_ptr SortWindowBuild::nextPartition() { VELOX_CHECK(!partitionStartRows_.empty(), "No window partitions available"); currentPartition_++; @@ -364,8 +398,7 @@ std::shared_ptr SortWindowBuild::nextPartition() { bool SortWindowBuild::hasNextPartition() { if (merge_ != nullptr) { - loadNextPartitionFromSpill(); - return !sortedRows_.empty(); + loadNextPartitionBatchFromSpill(); } return partitionStartRows_.size() > 0 && diff --git a/velox/exec/SortWindowBuild.h b/velox/exec/SortWindowBuild.h index 72875094007a..c6ddffebfbb6 100644 --- a/velox/exec/SortWindowBuild.h +++ b/velox/exec/SortWindowBuild.h @@ -32,6 +32,7 @@ class SortWindowBuild : public WindowBuild { common::PrefixSortConfig&& prefixSortConfig, const common::SpillConfig* spillConfig, tsan_atomic* nonReclaimableSection, + folly::Synchronized* opStats, folly::Synchronized* spillStats); ~SortWindowBuild() override { @@ -45,6 +46,10 @@ class SortWindowBuild : public WindowBuild { void addInput(RowVectorPtr input) override; + void addDecodedInputRow( + std::vector& decodedInputVectors, + vector_size_t row); + void spill() override; std::optional spilledStats() const override; @@ -55,9 +60,9 @@ class SortWindowBuild : public WindowBuild { std::shared_ptr nextPartition() override; - private: void ensureInputFits(const RowVectorPtr& input); + private: void ensureSortFits(); void setupSpiller(); @@ -75,8 +80,10 @@ class SortWindowBuild : public WindowBuild { // Find the next partition start row from start. vector_size_t findNextPartitionStartRow(vector_size_t start); - // Reads next partition from spilled data into 'data_' and 'sortedRows_'. - void loadNextPartitionFromSpill(); + // Load the next partition batch if needed. If current partition batch is not + // entirely consumed, return directly. Otherwise, read next partition batch + // from spilled data into 'data_' and set pointers in 'sortedRows_'. + void loadNextPartitionBatchFromSpill(); const size_t numPartitionKeys_; @@ -92,6 +99,8 @@ class SortWindowBuild : public WindowBuild { // Config for Prefix-sort. const common::PrefixSortConfig prefixSortConfig_; + folly::Synchronized* const opStats_; + folly::Synchronized* const spillStats_; // allKeyInfo_ is a combination of (partitionKeyInfo_ and sortKeyInfo_). @@ -121,5 +130,8 @@ class SortWindowBuild : public WindowBuild { // Used to sort-merge spilled data. std::unique_ptr> merge_; + + // Number of batches of whole partitions read from spilled data. + uint64_t numSpillReadBatches_ = 0; }; } // namespace facebook::velox::exec diff --git a/velox/exec/SortedAggregations.cpp b/velox/exec/SortedAggregations.cpp index 79fbbedb885a..419c8102d435 100644 --- a/velox/exec/SortedAggregations.cpp +++ b/velox/exec/SortedAggregations.cpp @@ -63,7 +63,8 @@ struct RowPointers { SortedAggregations::SortedAggregations( const std::vector& aggregates, const RowTypePtr& inputType, - memory::MemoryPool* pool) { + memory::MemoryPool* pool) + : pool_(pool) { // Collect inputs and sorting keys from all aggregates. std::unordered_set allInputs; for (const auto* aggregate : aggregates) { @@ -366,7 +367,7 @@ vector_size_t SortedAggregations::extractSingleGroup( void SortedAggregations::extractValues( folly::Range groups, const RowVectorPtr& result) { - raw_vector temp; + raw_vector indices(pool_); SelectivityVector rows; std::vector groupRows; for (const auto& [sortingSpec, aggregates] : aggregates_) { @@ -432,7 +433,7 @@ void SortedAggregations::extractValues( aggregate->function->initializeNewGroups( groups.data(), folly::Range( - iota(groups.size(), temp), groups.size())); + iota(groups.size(), indices), groups.size())); } } } diff --git a/velox/exec/SortedAggregations.h b/velox/exec/SortedAggregations.h index 85acbdf4c9db..10ccba2ff233 100644 --- a/velox/exec/SortedAggregations.h +++ b/velox/exec/SortedAggregations.h @@ -142,6 +142,8 @@ class SortedAggregations { } }; + memory::MemoryPool* const pool_; + // Aggregates grouped by sorting keys and orders. folly:: F14FastMap, Hash, EqualTo> diff --git a/velox/exec/SpatialIndex.cpp b/velox/exec/SpatialIndex.cpp new file mode 100644 index 000000000000..52fee815af25 --- /dev/null +++ b/velox/exec/SpatialIndex.cpp @@ -0,0 +1,181 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include "velox/common/base/Exceptions.h" +#include "velox/exec/HilbertIndex.h" +#include "velox/exec/SpatialIndex.h" + +namespace facebook::velox::exec { + +std::vector RTreeLevel::query( + const Envelope& queryEnv, + const std::vector& branchIndices) const { + std::vector result; + + for (size_t branchIdx : branchIndices) { + size_t startIdx = branchIdx * branchSize_; + size_t endIdx = std::min(startIdx + branchSize_, minXs_.size()); + for (size_t idx = startIdx; idx < endIdx; ++idx) { + bool intersects = (queryEnv.maxX >= minXs_[idx]) && + (queryEnv.maxY >= minYs_[idx]) && (queryEnv.minX <= maxXs_[idx]) && + (queryEnv.minY <= maxYs_[idx]); + if (intersects) { + result.push_back(idx); + } + } + } + + return result; +} + +namespace { +std::pair> buildLevel( + uint32_t branchSize, + const std::vector& envelopes) { + std::vector minXs; + minXs.reserve(envelopes.size()); + std::vector minYs; + minYs.reserve(envelopes.size()); + std::vector maxXs; + maxXs.reserve(envelopes.size()); + std::vector maxYs; + maxYs.reserve(envelopes.size()); + + std::vector parentEnvelopes; + parentEnvelopes.reserve((envelopes.size() + branchSize - 1) / branchSize); + Envelope currentBounds = Envelope::empty(); + + uint32_t idx = 0; + for (const auto& env : envelopes) { + ++idx; + currentBounds.maxX = std::max(currentBounds.maxX, env.maxX); + currentBounds.maxY = std::max(currentBounds.maxY, env.maxY); + currentBounds.minX = std::min(currentBounds.minX, env.minX); + currentBounds.minY = std::min(currentBounds.minY, env.minY); + if (idx % branchSize == 0) { + parentEnvelopes.push_back(currentBounds); + currentBounds = Envelope::empty(); + } + + minXs.push_back(env.minX); + minYs.push_back(env.minY); + maxXs.push_back(env.maxX); + maxYs.push_back(env.maxY); + } + + if (!currentBounds.isEmpty()) { + parentEnvelopes.push_back(currentBounds); + } + + return { + RTreeLevel( + branchSize, + std::move(minXs), + std::move(minYs), + std::move(maxXs), + std::move(maxYs)), + std::move(parentEnvelopes)}; +} +} // namespace + +SpatialIndex::SpatialIndex( + Envelope bounds, + std::vector envelopes, + uint32_t branchSize) + : branchSize_(branchSize), bounds_(std::move(bounds)) { + VELOX_CHECK_GT(branchSize_, 1); + + if (!bounds_.isEmpty()) { + HilbertIndex hilbert( + bounds_.minX, bounds_.minY, bounds_.maxX, bounds_.maxY); + + std::sort( + envelopes.begin(), envelopes.end(), [&](const auto& a, const auto& b) { + return hilbert.indexOf(a.minX, a.minY) < + hilbert.indexOf(b.minX, b.minY); + }); + } + + rowIndices_.reserve(envelopes.size()); + for (const auto& env : envelopes) { + VELOX_CHECK(env.minX >= bounds_.minX); + VELOX_CHECK(env.minY >= bounds_.minY); + VELOX_CHECK(env.maxX <= bounds_.maxX); + VELOX_CHECK(env.maxY <= bounds_.maxY); + rowIndices_.push_back(env.rowIndex); + } + + if (envelopes.size() > 0) { + size_t numLevels = + std::ceil(std::log(envelopes.size()) / std::log(branchSize_)); + levels_.reserve(numLevels); + } + + while (envelopes.size() > branchSize_) { + auto [level, parentEnvelopes] = buildLevel(branchSize_, envelopes); + levels_.push_back(std::move(level)); + envelopes = std::move(parentEnvelopes); + } + + if (!envelopes.empty() && (envelopes.size() > 1 || levels_.empty())) { + levels_.push_back(buildLevel(branchSize_, envelopes).first); + } + + if (!levels_.empty()) { + VELOX_CHECK_GT(branchSize_ + 1, levels_.back().size()); + } +} + +std::vector SpatialIndex::query(const Envelope& queryEnv) const { + std::vector result; + if (levels_.empty() || !Envelope::intersects(queryEnv, bounds_)) { + return result; + } + + size_t thisLevel = levels_.size() - 1; + VELOX_CHECK_GT(levels_[thisLevel].size(), 0); + VELOX_CHECK_GT(branchSize_ + 1, levels_[thisLevel].size()); + + // The top level should have only one branch. + std::vector childIndices = {0}; + for (; thisLevel > 0; --thisLevel) { + // Avoiding thisLevel = 0 due to int underflow + childIndices = levels_[thisLevel].query(queryEnv, childIndices); + // If we have no matches, return. + if (childIndices.empty()) { + return result; + } + } + + // We're at level 0 now. The indices index into rowIndices. + VELOX_DCHECK_EQ(thisLevel, 0); + childIndices = levels_[thisLevel].query(queryEnv, childIndices); + result.reserve(childIndices.size()); + for (auto idx : childIndices) { + result.push_back(rowIndices_[idx]); + } + + return result; +} + +Envelope SpatialIndex::bounds() const { + return bounds_; +} + +} // namespace facebook::velox::exec diff --git a/velox/exec/SpatialIndex.h b/velox/exec/SpatialIndex.h new file mode 100644 index 000000000000..d6b2ad6d83b4 --- /dev/null +++ b/velox/exec/SpatialIndex.h @@ -0,0 +1,231 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include "velox/vector/TypeAliases.h" + +namespace facebook::velox::exec { + +/// A minimal envelope for a geometry. +/// It also includes an index for the geometry for later reference. This can +/// be -1 if the geometry is not indexed. +/// +/// Envelopes use float32s instead of float64s so that SIMD loops can be +/// twice as fast. Our geometries use float64 coordinates, so we have to +/// downcast them for the envelope. The loss of precision is theoretically fine +/// because the envelope checks are already approximate: either they don't +/// intersect or they might intersect. Thus, expanding the envelopes slightly +/// does not affect correctness (but it might affect efficiency slightly). +/// +/// We want to show that if two envelopes expressed with float64 precision would +/// intersect, the envelopes with float32 precision would also intersect. +/// +/// Define +/// ``` +/// nextUp(f) = std::nextafter(f, std::numeric_limits::infinity()) +/// nextDown(f) = std::nextafter(f, -std::numeric_limits::infinity()) +/// ``` +/// which move a float up or down one ulp (unit in the last place). +/// +/// Since the conditions are all of the form `maxX >= minX` for float64s maxX +/// and minX, we need to show that this implies `nextUp((float) maxX) >= +/// nextDown((float) minX)`. +/// +/// Assume you have a double `d` and two adjacent floats `f0` and `f1`, such +/// that `d` is "between" `f0` and `f1`: +/// +/// 1. `(double) f0 <= d <= (double) f1` +/// 2. `nextup(f0) == f1 && f0 = nextdown(f0)` +/// +/// This implies `nextdown((float) d) <= f0 && nextup((float) d) >= f1`. +/// +/// Let double `minX` have two adjacent floats `f0`, `f1` as above, and `maxX` +/// have two adjacent floats `g0`, `g1`. Then +/// ``` +/// (double) nextDown((float) minX) +/// <= (double) f0 +/// <= minX +/// <= maxX +/// <= (double) g1 +/// <= (double) nextUp((float) maxX) +/// ``` +/// +/// And this implies `nextDown((float) minX) <= nextUp((float) maxX)` as +/// desired. The same argument applies to all for members, so if we construct +/// the float32 precision envelope by applying nextDown to the minX/Ys and +/// nextUp to maxX/Ys, the float32 envelope intersects in all cases that the +/// float64 envelope would (but not necessarily the converse). +struct Envelope { + float minX{std::numeric_limits::infinity()}; + float minY{std::numeric_limits::infinity()}; + float maxX{-std::numeric_limits::infinity()}; + float maxY{-std::numeric_limits::infinity()}; + vector_size_t rowIndex = -1; + + /// Returns true if the intersection of two envelopes is not empty. + static inline bool intersects(const Envelope& left, const Envelope& right) { + return (left.maxX >= right.minX) && (left.minX <= right.maxX) && + (left.maxY >= right.minY) && (left.minY <= right.maxY); + } + + /// Returns true if the envelope contains at least one point. + /// An envelope of a point is not empty. + inline bool isEmpty() const { + // This negation handles NaNs correctly. + return !((minX <= maxX) && (minY <= maxY)); + } + + /// Expands this Envelope to also contain the other. + inline void merge(const Envelope& other) { + minX = std::min(minX, other.minX); + minY = std::min(minY, other.minY); + maxX = std::max(maxX, other.maxX); + maxY = std::max(maxY, other.maxY); + } + + /// Construct an empty envelope. + static constexpr inline Envelope empty() { + return Envelope{ + .minX = std::numeric_limits::infinity(), + .minY = std::numeric_limits::infinity(), + .maxX = -std::numeric_limits::infinity(), + .maxY = -std::numeric_limits::infinity()}; + } + + static constexpr inline Envelope from( + double minX, + double minY, + double maxX, + double maxY, + vector_size_t rowIndex = -1) { + return Envelope{ + .minX = std::nextafterf( + static_cast(minX), -std::numeric_limits::infinity()), + .minY = std::nextafterf( + static_cast(minY), -std::numeric_limits::infinity()), + .maxX = std::nextafterf( + static_cast(maxX), std::numeric_limits::infinity()), + .maxY = std::nextafterf( + static_cast(maxY), std::numeric_limits::infinity()), + .rowIndex = rowIndex}; + } + + static inline Envelope of(const std::vector& envelopes) { + Envelope result = Envelope::empty(); + for (const auto& envelope : envelopes) { + result.merge(envelope); + } + return result; + } +}; + +/// A single level of an R-tree. It is a set of envelopes that can be linearly +/// scanned for envelope intersection. +class RTreeLevel { + public: + RTreeLevel(const RTreeLevel&) = delete; + RTreeLevel& operator=(const RTreeLevel&) = delete; + + RTreeLevel() = default; + RTreeLevel(RTreeLevel&&) = default; + RTreeLevel& operator=(RTreeLevel&&) = default; + ~RTreeLevel() = default; + + explicit RTreeLevel( + size_t branchSize, + std::vector minXs, + std::vector minYs, + std::vector maxXs, + std::vector maxYs) + : branchSize_{branchSize}, + minXs_(std::move(minXs)), + minYs_(std::move(minYs)), + maxXs_(std::move(maxXs)), + maxYs_(std::move(maxYs)) {} + + /// Returns the internal indices of all envelopes that probeEnv intersects. + /// Order of the returned indices is an implementation detail and cannot be + /// relied upon. + /// This does not do a short-circuit bounds check: the caller should do that + /// first. + std::vector query( + const Envelope& queryEnv, + const std::vector& branchIndices) const; + + size_t size() const { + return minXs_.size(); + } + + private: + size_t branchSize_{}; + Envelope bounds_; + std::vector minXs_{}; + std::vector minYs_{}; + std::vector maxXs_{}; + std::vector maxYs_{}; +}; + +/// A spatial index for a set of geometries. The index only cares about the +/// envelopes of the geometries, and an index into the geometries (not stored in +/// SpatialIndex). +/// +/// The contract is that SpatialIndex::probe returns the indices of all +/// envelopes that probeEnv intersects. The form of the index is an +/// implementation detail. The order of the returned indicies is an +/// implementation detail. +class SpatialIndex { + public: + SpatialIndex(const SpatialIndex&) = delete; + SpatialIndex& operator=(const SpatialIndex&) = delete; + + SpatialIndex() = default; + SpatialIndex(SpatialIndex&&) = default; + SpatialIndex& operator=(SpatialIndex&&) = default; + ~SpatialIndex() = default; + + static const uint32_t kDefaultRTreeBranchSize = 32; + + /// Constructs a spatial index from envelopes contained with `bounds`. + /// `bounds` must contain all envelopes in `envelopes`, otherwise the + /// an assertio will fail. Envelopes should not contain NaN coordinates. + explicit SpatialIndex( + Envelope bounds, + std::vector envelopes, + uint32_t branchSize = kDefaultRTreeBranchSize); + + /// Returns the row indices of all envelopes that probeEnv intersects. + /// Order of the returned indices is an implementation detail and cannot be + /// relied upon. + std::vector query(const Envelope& queryEnv) const; + + /// Returns the envelope of the all envelopes in the index. + /// The returned envelope will have index = -1. + Envelope bounds() const; + + private: + uint32_t branchSize_ = kDefaultRTreeBranchSize; + + Envelope bounds_ = Envelope::empty(); + std::vector levels_{}; + std::vector rowIndices_{}; +}; + +} // namespace facebook::velox::exec diff --git a/velox/exec/SpatialJoinBuild.cpp b/velox/exec/SpatialJoinBuild.cpp new file mode 100644 index 000000000000..18dea2179caa --- /dev/null +++ b/velox/exec/SpatialJoinBuild.cpp @@ -0,0 +1,264 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/SpatialJoinBuild.h" +#include +#include "velox/common/geospatial/GeometryConstants.h" +#ifdef VELOX_ENABLE_GEO +#include "velox/common/geospatial/GeometrySerde.h" +#endif +#include "velox/exec/Task.h" + +namespace facebook::velox::exec { + +using velox::common::geospatial::GeometrySerializationType; + +void SpatialJoinBridge::setData(SpatialBuildResult buildResult) { + std::vector promises; + { + std::lock_guard l(mutex_); + VELOX_CHECK(!buildResult_.has_value(), "setData must be called only once"); + buildResult_ = std::move(buildResult); + promises = std::move(promises_); + } + notify(std::move(promises)); +} + +std::optional SpatialJoinBridge::dataOrFuture( + ContinueFuture* future) { + std::lock_guard l(mutex_); + VELOX_CHECK(!cancelled_, "Getting data after the build side is aborted"); + if (buildResult_.has_value()) { + return buildResult_.value(); + } + promises_.emplace_back("SpatialJoinBridge::dataOrFuture"); + *future = promises_.back().getSemiFuture(); + return std::nullopt; +} + +SpatialJoinBuild::SpatialJoinBuild( + int32_t operatorId, + DriverCtx* driverCtx, + std::shared_ptr joinNode) + : Operator( + driverCtx, + nullptr, + operatorId, + joinNode->id(), + "SpatialJoinBuild") { + const auto& buildType = joinNode->rightNode()->outputType(); + buildGeometryChannel_ = + buildType->getChildIdx(joinNode->buildGeometry()->name()); + VELOX_CHECK_EQ( + buildType->childAt(buildGeometryChannel_), + joinNode->buildGeometry()->type()); + if (joinNode->radius().has_value()) { + auto radiusVar = joinNode->radius().value(); + uint32_t radiusChannel = buildType->getChildIdx(radiusVar->name()); + VELOX_CHECK_EQ(buildType->childAt(radiusChannel), radiusVar->type()); + radiusChannel_ = radiusChannel; + } +} + +void SpatialJoinBuild::addInput(RowVectorPtr input) { + if (input->size() > 0) { + // Load lazy vectors before storing. + for (auto& child : input->children()) { + child->loadedVector(); + } + dataVectors_.emplace_back(std::move(input)); + } +} + +BlockingReason SpatialJoinBuild::isBlocked(ContinueFuture* future) { + if (!future_.valid()) { + return BlockingReason::kNotBlocked; + } + *future = std::move(future_); + return BlockingReason::kWaitForJoinBuild; +} + +// Merge adjacent vectors to larger vectors as long as the result do not exceed +// the size limit. This is important for performance because each small vector +// here would be duplicated by the number of rows on probe side, result in huge +// number of small vectors in the output. +std::vector SpatialJoinBuild::mergeDataVectors() const { + const auto maxBatchRows = + operatorCtx_->task()->queryCtx()->queryConfig().maxOutputBatchRows(); + std::vector merged; + for (size_t i = 0; i < dataVectors_.size();) { + // convert int32_t to int64_t to avoid sum overflow + int64_t batchSize = static_cast(dataVectors_[i]->size()); + auto j = i + 1; + while (j < dataVectors_.size() && + batchSize + dataVectors_[j]->size() <= maxBatchRows) { + batchSize += dataVectors_[j++]->size(); + } + if (j == i + 1) { + merged.push_back(dataVectors_[i++]); + } else { + auto batch = BaseVector::create( + dataVectors_[i]->type(), + static_cast(batchSize), + pool()); + batchSize = 0; + while (i < j) { + auto* source = dataVectors_[i++].get(); + batch->copy( + source, static_cast(batchSize), 0, source->size()); + batchSize += source->size(); + } + merged.push_back(std::move(batch)); + } + } + return merged; +} + +Envelope SpatialJoinBuild::readEnvelope( + const StringView& geometryBytes, + double radius) { +#ifdef VELOX_ENABLE_GEO + radius = std::max(radius, 0.0); + auto geosEnvelope = + common::geospatial::GeometryDeserializer::deserializeEnvelope( + geometryBytes); + if (geosEnvelope->isNull()) { + return Envelope::empty(); + } else { + return Envelope::from( + geosEnvelope->getMinX() - radius, + geosEnvelope->getMinY() - radius, + geosEnvelope->getMaxX() + radius, + geosEnvelope->getMaxY() + radius); + } +#else + // When VELOX_ENABLE_GEO is not set, return an envelope of infinite area + // to ensure all geometries are considered for spatial join + return Envelope::from( + -std::numeric_limits::infinity(), + -std::numeric_limits::infinity(), + std::numeric_limits::infinity(), + std::numeric_limits::infinity()); +#endif +} + +SpatialIndex SpatialJoinBuild::buildSpatialIndex( + const std::vector& data, + column_index_t geometryIdx, + std::optional radiusIdx) { + size_t numRows = 0; + for (auto& vector : data) { + numRows += vector->size(); + } + std::vector envelopes; + // TODO: Chunk the data to avoid allocating a large vector. + envelopes.reserve(numRows); + + DecodedVector radiusCol; + DecodedVector geometryCol; + vector_size_t offset = 0; + Envelope bounds = Envelope::empty(); + + for (auto& vector : data) { + const auto& rawGeometryCol = + vector->childAt(geometryIdx)->asChecked>(); + geometryCol.decode(*rawGeometryCol); + + auto constantZero = velox::BaseVector::createConstant( + velox::DOUBLE(), 0.0, vector->size(), pool()); + if (radiusIdx.has_value()) { + const auto& rawRadiusCol = + vector->childAt(radiusIdx.value())->asChecked>(); + radiusCol.decode(*rawRadiusCol); + } else { + radiusCol.decode(*constantZero); + } + + // TODO: Make a selectivity vector based on nulls and use for DecodedVector. + for (vector_size_t i = 0; i < vector->size(); ++i) { + if (geometryCol.isNullAt(i) || radiusCol.isNullAt(i)) { + // If geometry or radius is null, it will not match the predicate and so + // we should skip the envelope. + continue; + } + double radius = radiusCol.valueAt(i); + const StringView geometryBytes = geometryCol.valueAt(i); + Envelope envelope = SpatialJoinBuild::readEnvelope(geometryBytes, radius); + if (FOLLY_UNLIKELY(envelope.isEmpty())) { + continue; + } + envelope.rowIndex = offset + geometryCol.index(i); + bounds.merge(envelope); + envelopes.push_back(std::move(envelope)); + } + offset += vector->size(); + } + return SpatialIndex(std::move(bounds), std::move(envelopes)); +} + +void SpatialJoinBuild::noMoreInput() { + Operator::noMoreInput(); + std::vector promises; + std::vector> peers; + // The last Driver to hit SpatialJoinBuild::finish gathers the data from + // all build Drivers and hands it over to the probe side. At this + // point all build Drivers are continued and will free their + // state. allPeersFinished is true only for the last Driver of the + // build pipeline. + if (!operatorCtx_->task()->allPeersFinished( + planNodeId(), operatorCtx_->driver(), &future_, promises, peers)) { + return; + } + + { + auto promisesGuard = folly::makeGuard([&]() { + // Realize the promises so that the other Drivers (which were not + // the last to finish) can continue from the barrier and finish. + peers.clear(); + for (auto& promise : promises) { + promise.setValue(); + } + }); + + for (auto& peer : peers) { + auto op = peer->findOperator(planNodeId()); + auto* build = dynamic_cast(op); + VELOX_CHECK_NOT_NULL(build); + dataVectors_.insert( + dataVectors_.end(), + std::make_move_iterator(build->dataVectors_.begin()), + std::make_move_iterator(build->dataVectors_.end())); + } + } + + dataVectors_ = mergeDataVectors(); + SpatialIndex spatialIndex = + buildSpatialIndex(dataVectors_, buildGeometryChannel_, radiusChannel_); + SpatialBuildResult buildResult; + buildResult.spatialIndex = + std::make_shared(std::move(spatialIndex)); + buildResult.buildVectors = std::move(dataVectors_); + + operatorCtx_->task() + ->getSpatialJoinBridge( + operatorCtx_->driverCtx()->splitGroupId, planNodeId()) + ->setData(std::move(buildResult)); +} + +bool SpatialJoinBuild::isFinished() { + return !future_.valid() && noMoreInput_; +} +} // namespace facebook::velox::exec diff --git a/velox/exec/SpatialJoinBuild.h b/velox/exec/SpatialJoinBuild.h new file mode 100644 index 000000000000..daceada9f296 --- /dev/null +++ b/velox/exec/SpatialJoinBuild.h @@ -0,0 +1,92 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "velox/core/PlanNode.h" +#include "velox/exec/JoinBridge.h" +#include "velox/exec/Operator.h" +#include "velox/exec/SpatialIndex.h" + +namespace facebook::velox::exec { + +struct SpatialBuildResult { + std::vector buildVectors; + std::shared_ptr spatialIndex; +}; + +class SpatialJoinBridge : public JoinBridge { + public: + void setData(SpatialBuildResult buildResult); + + std::optional dataOrFuture(ContinueFuture* future); + + private: + std::optional buildResult_; +}; + +class SpatialJoinBuild : public Operator { + public: + SpatialJoinBuild( + int32_t operatorId, + DriverCtx* driverCtx, + std::shared_ptr joinNode); + + void addInput(RowVectorPtr input) override; + + RowVectorPtr getOutput() override { + return nullptr; + } + + bool needsInput() const override { + return !noMoreInput_; + } + + void noMoreInput() override; + + BlockingReason isBlocked(ContinueFuture* future) override; + + bool isFinished() override; + + void close() override { + dataVectors_.clear(); + Operator::close(); + } + + static Envelope readEnvelope( + const StringView& serializedGeometry, + double radius); + + private: + std::vector mergeDataVectors() const; + + SpatialIndex buildSpatialIndex( + const std::vector& data, + column_index_t geometryIdx, + std::optional radiusIdx); + + std::vector dataVectors_; + + // Channel of geometry variable used to build spatial index + column_index_t buildGeometryChannel_; + // Channel (if set) of radius variable used to build spatial index + std::optional radiusChannel_{}; + + // Future for synchronizing with other Drivers of the same pipeline. All build + // Drivers must be completed before making data available for the probe side. + ContinueFuture future_{ContinueFuture::makeEmpty()}; +}; + +} // namespace facebook::velox::exec diff --git a/velox/exec/SpatialJoinProbe.cpp b/velox/exec/SpatialJoinProbe.cpp new file mode 100644 index 000000000000..2d398dc35ebb --- /dev/null +++ b/velox/exec/SpatialJoinProbe.cpp @@ -0,0 +1,560 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/exec/SpatialJoinProbe.h" +#include "velox/exec/OperatorUtils.h" +#include "velox/exec/SpatialJoinBuild.h" +#include "velox/exec/Task.h" +#include "velox/expression/FieldReference.h" + +namespace facebook::velox::exec { +namespace { + +bool needsProbeMismatch(core::JoinType joinType) { + return isLeftJoin(joinType); +} + +std::vector extractProjections( + const RowTypePtr& srcType, + const RowTypePtr& destType) { + std::vector projections; + for (auto i = 0; i < srcType->size(); ++i) { + auto name = srcType->nameOf(i); + auto outIndex = destType->getChildIdxIfExists(name); + if (outIndex.has_value()) { + projections.emplace_back(i, outIndex.value()); + } + } + return projections; +} + +} // namespace + +////////////////// +// OUTPUT BUILDER + +void SpatialJoinOutputBuilder::initializeOutput( + const RowVectorPtr& input, + memory::MemoryPool* pool) { + if (output_ == nullptr) { + output_ = + BaseVector::create(outputType_, outputBatchSize_, pool); + } else { + VectorPtr outputVector = std::move(output_); + BaseVector::prepareForReuse(outputVector, outputBatchSize_); + output_ = std::static_pointer_cast(outputVector); + } + probeOutputIndices_ = allocateIndices(outputBatchSize_, pool); + rawProbeOutputIndices_ = probeOutputIndices_->asMutable(); + + // Add probe side projections as dictionary vectors + for (const auto& projection : probeProjections_) { + output_->childAt(projection.outputChannel) = wrapChild( + outputBatchSize_, + probeOutputIndices_, + input->childAt(projection.inputChannel)); + } + + // Add build side projections as uninitialized vectors + for (const auto& projection : buildProjections_) { + auto child = output_->childAt(projection.outputChannel); + if (child == nullptr) { + child = BaseVector::create( + outputType_->childAt(projection.outputChannel), + outputBatchSize_, + operatorCtx_.pool()); + } + } +} + +void SpatialJoinOutputBuilder::addOutputRow( + vector_size_t probeRow, + vector_size_t buildRow) { + VELOX_CHECK_NOT_NULL(probeOutputIndices_); + // Probe side is always a dictionary; just populate the index. + rawProbeOutputIndices_[outputRow_] = probeRow; + + // For the build side, we accumulate the ranges to copy, then copy all of + // them at once. Consecutive records are copied in one memcpy. + if (!buildCopyRanges_.empty() && + (buildCopyRanges_.back().sourceIndex + buildCopyRanges_.back().count) == + buildRow) { + ++buildCopyRanges_.back().count; + } else { + buildCopyRanges_.push_back({buildRow, outputRow_, 1}); + } + ++outputRow_; +} + +void SpatialJoinOutputBuilder::copyBuildValues( + const RowVectorPtr& buildVector) { + if (buildCopyRanges_.empty()) { + return; + } + + VELOX_CHECK_NOT_NULL(output_); + + for (const auto& projection : buildProjections_) { + const auto& buildChild = buildVector->childAt(projection.inputChannel); + const auto& outputChild = output_->childAt(projection.outputChannel); + outputChild->copyRanges(buildChild.get(), buildCopyRanges_); + } + buildCopyRanges_.clear(); +} + +void SpatialJoinOutputBuilder::addProbeMismatchRow(vector_size_t probeRow) { + VELOX_CHECK_NOT_NULL(output_); + + // Probe side is always a dictionary; just populate the index. + rawProbeOutputIndices_[outputRow_] = probeRow; + + // Null out build projections. + for (const auto& projection : buildProjections_) { + const auto& outputChild = output_->childAt(projection.outputChannel); + outputChild->setNull(outputRow_, true); + } + ++outputRow_; +} + +RowVectorPtr SpatialJoinOutputBuilder::takeOutput() { + VELOX_CHECK(buildCopyRanges_.empty()); + if (outputRow_ == 0 || !output_) { + return nullptr; + } + RowVectorPtr output = std::move(output_); + output->resize(outputRow_); + output_ = nullptr; + outputRow_ = 0; + return output; +} + +//////////////////// +// SpatialJoinProbe + +SpatialJoinProbe::SpatialJoinProbe( + int32_t operatorId, + DriverCtx* driverCtx, + const std::shared_ptr& joinNode) + : Operator( + driverCtx, + joinNode->outputType(), + operatorId, + joinNode->id(), + "SpatialJoinProbe"), + joinType_(joinNode->joinType()), + outputBatchSize_{outputBatchRows()}, + joinNode_(joinNode), + buildProjections_(extractProjections( + joinNode_->rightNode()->outputType(), + outputType_)), + outputBuilder_{ + outputBatchSize_, + outputType_, + extractProjections( + joinNode_->leftNode()->outputType(), + outputType_), // these are the identity Projections + buildProjections_, + *operatorCtx_} { + auto probeType = joinNode_->leftNode()->outputType(); + identityProjections_ = extractProjections(probeType, outputType_); + probeGeometryChannel_ = + probeType->getChildIdx(joinNode_->probeGeometry()->name()); + VELOX_CHECK_EQ( + probeType->childAt(probeGeometryChannel_), + joinNode_->probeGeometry()->type()); +} + +///////// +// SETUP + +void SpatialJoinProbe::initialize() { + Operator::initialize(); + + VELOX_CHECK_NOT_NULL(joinNode_); + if (joinNode_->joinCondition() != nullptr) { + initializeFilter( + joinNode_->joinCondition(), + joinNode_->leftNode()->outputType(), + joinNode_->rightNode()->outputType()); + } + + joinNode_.reset(); +} + +void SpatialJoinProbe::initializeFilter( + const core::TypedExprPtr& filter, + const RowTypePtr& probeType, + const RowTypePtr& buildType) { + VELOX_CHECK_NULL(joinCondition_); + + std::vector filters = {filter}; + joinCondition_ = + std::make_unique(std::move(filters), operatorCtx_->execCtx()); + + column_index_t filterChannel = 0; + std::vector names; + std::vector types; + const auto numFields = joinCondition_->expr(0)->distinctFields().size(); + names.reserve(numFields); + types.reserve(numFields); + + for (const auto& field : joinCondition_->expr(0)->distinctFields()) { + const auto& name = field->field(); + auto channel = probeType->getChildIdxIfExists(name); + if (channel.has_value()) { + auto channelValue = channel.value(); + filterProbeProjections_.emplace_back(channelValue, filterChannel++); + names.emplace_back(probeType->nameOf(channelValue)); + types.emplace_back(probeType->childAt(channelValue)); + continue; + } + channel = buildType->getChildIdxIfExists(name); + if (channel.has_value()) { + auto channelValue = channel.value(); + filterBuildProjections_.emplace_back(channelValue, filterChannel++); + names.emplace_back(buildType->nameOf(channelValue)); + types.emplace_back(buildType->childAt(channelValue)); + continue; + } + VELOX_FAIL( + "Spatial join filter field {} not in probe or build input, filter: {}", + field->toString(), + filter->toString()); + } + + filterInputType_ = ROW(std::move(names), std::move(types)); +} + +BlockingReason SpatialJoinProbe::isBlocked(ContinueFuture* future) { + switch (state_) { + case ProbeOperatorState::kRunning: + [[fallthrough]]; + case ProbeOperatorState::kFinish: + return BlockingReason::kNotBlocked; + case ProbeOperatorState::kWaitForPeers: + if (future_.valid()) { + *future = std::move(future_); + return BlockingReason::kWaitForJoinProbe; + } + setState(ProbeOperatorState::kFinish); + return BlockingReason::kNotBlocked; + case ProbeOperatorState::kWaitForBuild: { + VELOX_CHECK(!buildVectors_.has_value()); + if (!getBuildData(future)) { + return BlockingReason::kWaitForJoinBuild; + } + VELOX_CHECK(buildVectors_.has_value()); + setState(ProbeOperatorState::kRunning); + return BlockingReason::kNotBlocked; + } + default: + VELOX_UNREACHABLE(probeOperatorStateName(state_)); + } +} + +void SpatialJoinProbe::close() { + if (joinCondition_ != nullptr) { + joinCondition_->clear(); + } + buildVectors_.reset(); + spatialIndex_.reset(); + Operator::close(); +} + +void SpatialJoinProbe::noMoreInput() { + Operator::noMoreInput(); + if (state_ == ProbeOperatorState::kRunning && input_ == nullptr) { + setState(ProbeOperatorState::kFinish); + } +} + +bool SpatialJoinProbe::getBuildData(ContinueFuture* future) { + VELOX_CHECK(!buildVectors_.has_value()); + + auto buildData = + operatorCtx_->task() + ->getSpatialJoinBridge( + operatorCtx_->driverCtx()->splitGroupId, planNodeId()) + ->dataOrFuture(future); + if (!buildData.has_value()) { + return false; + } + + buildVectors_ = buildData.value().buildVectors; + spatialIndex_ = buildData.value().spatialIndex; + return true; +} + +void SpatialJoinProbe::checkStateTransition(ProbeOperatorState state) { + VELOX_CHECK_NE(state_, state); + switch (state) { + case ProbeOperatorState::kRunning: + VELOX_CHECK_EQ(state_, ProbeOperatorState::kWaitForBuild); + break; + case ProbeOperatorState::kWaitForBuild: + [[fallthrough]]; + case ProbeOperatorState::kFinish: + VELOX_CHECK_EQ(state_, ProbeOperatorState::kRunning); + break; + default: + VELOX_UNREACHABLE(probeOperatorStateName(state_)); + break; + } +} + +//////////////// +// INPUT/OUTPUT + +void SpatialJoinProbe::addInput(RowVectorPtr input) { + VELOX_CHECK_NULL(input_); + VELOX_CHECK_EQ(probeRow_, 0); + VELOX_CHECK(!probeHasMatch_); + VELOX_CHECK_EQ(buildVectorIndex_, 0); + VELOX_CHECK_EQ(candidateIndex_, 0); + + // In getOutput(), we are going to wrap input in dictionaries a few rows at a + // time. Since lazy vectors cannot be wrapped in different dictionaries, we + // are going to load them here. + for (auto& child : input->children()) { + child->loadedVector(); + } + input_ = std::move(input); + decodedGeometryCol_.decode(*input_->childAt(probeGeometryChannel_) + ->asChecked>()); + ++probeCount_; +} + +RowVectorPtr SpatialJoinProbe::getOutput() { + if (state_ == ProbeOperatorState::kFinish || + state_ == ProbeOperatorState::kWaitForPeers) { + return nullptr; + } + + RowVectorPtr output{nullptr}; + while (output == nullptr) { + // Need more input. + if (input_ == nullptr) { + break; + } + + // If the task owning this operator isn't running, there is no point + // to continue executing this procedure, which may be long in degenerate + // cases. Exit the working loop and let the Driver handle exiting + // gracefully in its own loop. + if (!operatorCtx_->task()->isRunning()) { + break; + } + + if (shouldYield()) { + break; + } + + // Generate actual join output by processing probe and build matches, and + // probe mismaches (for left joins). + output = generateOutput(); + } + + if (output != nullptr) { + ++outputCount_; + } + return output; +} + +RowVectorPtr SpatialJoinProbe::generateOutput() { + VELOX_CHECK_NOT_NULL(input_); + VELOX_CHECK_GT(input_->size(), probeRow_); + outputBuilder_.initializeOutput(input_, pool()); + + while (!isOutputDone()) { + // Fill output_ with the results from one row. This may produce too + // much output and only partially complete. If so, the next time we + // call this we'll get the next chunk. + // + // addProbeRowOutput is responsible for advancing probeRow_. + addProbeRowOutput(); + } + + // If we've exhausted the input, release it. + if (probeRow_ >= input_->size()) { + finishProbeInput(); + } + + return outputBuilder_.takeOutput(); +} + +// Return true if adding output stops early because output is full. +void SpatialJoinProbe::addProbeRowOutput() { + VELOX_CHECK(buildVectors_.has_value()); + VELOX_CHECK(!outputBuilder_.isOutputFull()); + + // Find the candidates for each probe row from the spatial index. Only do + // this at the start for each row. + if (buildVectorIndex_ == 0 && candidateIndex_ == 0) { + candidateBuildRows_ = querySpatialIndex(); + } + + while (!isProbeRowDone()) { + addBuildVectorOutput(buildVectors_.value()[buildVectorIndex_]); + if (outputBuilder_.isOutputFull()) { + // If full, don't advance buildVectorIndex_ because we may not have + // exhausted the current vector. Return instead of breaking so that we + // can add a mismatch row later if necessary. + return; + } + advanceBuildVector(); + } + + // Now that we have finished the probe row, check if we need to add a probe + // mismatch record. + if (!probeHasMatch_ && needsProbeMismatch(joinType_)) { + outputBuilder_.addProbeMismatchRow(probeRow_); + } + // Advance here instead of the loop in generateOutput so that early return on + // full doesn't advance the probe. + advanceProbeRow(); +} + +void SpatialJoinProbe::addBuildVectorOutput(const RowVectorPtr& buildVector) { + if (FOLLY_UNLIKELY(needsFilterEvaluated_)) { + // Evaluate join filter for the whole vector just once. + evaluateJoinFilter(buildVector); + needsFilterEvaluated_ = false; + } + + // Start where we left off: after the last buildRow_ that was processed. + while (!isBuildVectorDone(buildVector)) { + vector_size_t buildRow = relativeBuildRow(candidateIndex_); + if (isJoinConditionMatch(candidateIndex_)) { + outputBuilder_.addOutputRow(probeRow_, buildRow); + probeHasMatch_ = true; + } + + // Advance candidateIndex_ even if full, since we're finished with this row. + ++candidateIndex_; + } + + // Since we are copying from the current buildVector, we must copy here. + outputBuilder_.copyBuildValues(buildVector); +} + +std::vector SpatialJoinProbe::querySpatialIndex() { + VELOX_CHECK(spatialIndex_.has_value()); + VELOX_CHECK_NOT_NULL(spatialIndex_.value()); + + if (decodedGeometryCol_.isNullAt(probeRow_)) { + return std::vector{}; + } + + // Always apply radius to build side, not probe side. + Envelope envelope = SpatialJoinBuild::readEnvelope( + decodedGeometryCol_.valueAt(probeRow_), 0 /* radius */); + std::vector candidates = spatialIndex_.value()->query(envelope); + std::sort(candidates.begin(), candidates.end()); + + return candidates; +} + +BufferPtr SpatialJoinProbe::makeBuildVectorIndices(vector_size_t vectorSize) { + // Find the slice of candidates that are in this build vector. + vector_size_t endIndex = candidateIndex_; + for (; endIndex < candidateBuildRows_.size(); ++endIndex) { + if (relativeBuildRow(endIndex) >= vectorSize) { + break; + } + } + + // Make an index vector to fit the candidates. Populate each entry with its + // relative build row. + vector_size_t indexCount = + static_cast(endIndex - candidateIndex_); + auto rowIndices = allocateIndices(indexCount, operatorCtx_->pool()); + auto rawIndices = rowIndices->asMutable(); + for (vector_size_t idx = 0; idx < indexCount; ++idx) { + rawIndices[idx] = relativeBuildRow(idx + candidateIndex_); + } + + return rowIndices; +} + +void SpatialJoinProbe::evaluateJoinFilter(const RowVectorPtr& buildVector) { + // Get the indices of the rows in the build vector that are candidates. + auto candidateRowsBuffer = makeBuildVectorIndices(buildVector->size()); + + // Now get the input for the spatial join filter, one row per candidate. + auto filterInput = getNextJoinBatch( + buildVector, + filterInputType_, + filterProbeProjections_, + filterBuildProjections_, + candidateRowsBuffer); + + if (filterInputRows_.size() != filterInput->size()) { + filterInputRows_.resizeFill(filterInput->size(), true); + } + VELOX_CHECK(filterInputRows_.isAllSelected()); + + std::vector filterResult; + EvalCtx evalCtx( + operatorCtx_->execCtx(), joinCondition_.get(), filterInput.get()); + joinCondition_->eval(0, 1, true, filterInputRows_, evalCtx, filterResult); + VELOX_CHECK_GT(filterResult.size(), 0); + filterOutput_ = filterResult[0]; + decodedFilterResult_.decode(*filterOutput_, filterInputRows_); +} + +RowVectorPtr SpatialJoinProbe::getNextJoinBatch( + const RowVectorPtr& buildVector, + const RowTypePtr& outputType, + const std::vector& probeProjections, + const std::vector& buildProjections, + BufferPtr candidateRows) const { + VELOX_CHECK_GT(buildVector->size(), 0); + // candidateRows is a buffer of vector_size_t indices into buildVector + const vector_size_t numOutputRows = + candidateRows->size() / sizeof(vector_size_t); + if (numOutputRows == 0) { + return RowVector::createEmpty(outputType, pool()); + } + + std::vector projectedChildren(outputType->size()); + // Project columns from the build side. + projectChildren( + projectedChildren, + buildVector, + buildProjections, + numOutputRows, + candidateRows); + + // Wrap projections from the probe side as constants. + for (const auto [inputChannel, outputChannel] : probeProjections) { + projectedChildren[outputChannel] = BaseVector::wrapInConstant( + numOutputRows, probeRow_, input_->childAt(inputChannel)); + } + + return std::make_shared( + pool(), outputType, nullptr, numOutputRows, std::move(projectedChildren)); +} + +void SpatialJoinProbe::finishProbeInput() { + VELOX_CHECK_NOT_NULL(input_); + input_.reset(); + probeRow_ = 0; + + if (noMoreInput_) { + setState(ProbeOperatorState::kFinish); + } +} + +} // namespace facebook::velox::exec diff --git a/velox/exec/SpatialJoinProbe.h b/velox/exec/SpatialJoinProbe.h new file mode 100644 index 000000000000..e9fc2eea13ab --- /dev/null +++ b/velox/exec/SpatialJoinProbe.h @@ -0,0 +1,396 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "velox/core/PlanNode.h" +#include "velox/exec/Operator.h" +#include "velox/exec/ProbeOperatorState.h" +#include "velox/exec/SpatialIndex.h" + +namespace facebook::velox::exec { + +class SpatialJoinOutputBuilder { + public: + SpatialJoinOutputBuilder( + vector_size_t outputBatchSize, + RowTypePtr outputType, + std::vector probeProjections, + std::vector buildProjections, + const OperatorCtx& operatorCtx) + : outputBatchSize_{outputBatchSize}, + outputType_{std::move(outputType)}, + probeProjections_{std::move(probeProjections)}, + buildProjections_{std::move(buildProjections)}, + operatorCtx_{operatorCtx} { + VELOX_CHECK_GT(outputBatchSize_, 0); + } + + void initializeOutput(const RowVectorPtr& input, memory::MemoryPool* pool); + + bool isOutputFull() const { + return outputRow_ >= outputBatchSize_; + } + + void addOutputRow(vector_size_t probeRow, vector_size_t buildRow); + + /// Checks if it is required to add a probe mismatch row, and does it if + /// needed. The caller needs to ensure there is available space in `output_` + /// for the new record, which has nulled out build projections. + void addProbeMismatchRow(vector_size_t probeRow); + + void copyBuildValues(const RowVectorPtr& buildVector); + + RowVectorPtr takeOutput(); + + private: + // Initialization parameters + const vector_size_t outputBatchSize_; + const RowTypePtr outputType_; + const std::vector probeProjections_; + const std::vector buildProjections_; + const OperatorCtx& operatorCtx_; + + // Output state + RowVectorPtr output_; + vector_size_t outputRow_{0}; + // Dictionary indices for probe columns for output vector. + BufferPtr probeOutputIndices_; + // Mutable pointer to probeOutputIndices_ + vector_size_t* rawProbeOutputIndices_{}; + + // Stores the ranges of build values to be copied to the output vector (we + // batch them and copy once, instead of copying them row-by-row). + std::vector buildCopyRanges_{}; +}; + +/// Implements a Spatial Join between records from the probe (input_) +/// and build (SpatialJoinBridge) sides. It supports inner and left joins. +/// +/// This class is designed to evaluate spatial join conditions (e.g. +/// ST_INTERSECTS, ST_CONTAINS, ST_WITHIN) between geometric data types. It +/// can also implement spatial cross-join semantics if joinCondition is +/// nullptr. +/// +/// The output follows the order of the probe side rows (for inner and left +/// joins). All build vectors are materialized upfront (check buildVectors_), +/// but probe batches are processed one-by-one as a stream. +/// +/// To produce output, the operator processes each probe record from probe +/// input, using the following steps: +/// +/// 1. Materialize a cross-product batch across probe and build. +/// 2. Evaluate the spatial join condition. +/// 3. Add spatial matches to the output. +/// 4. Once all build vectors are processed for a particular probe row, check +/// if +/// a probe mismatch is needed (only for left and full outer joins). +/// 5. Once all probe and build inputs are processed, check if build +/// mismatches +/// are needed (only for right and full outer joins). +/// 6. If so, signal other peer operators; only a single operator instance +/// will +/// collect all build matches at the end, and emit any records that haven't +/// been matched by any of the peers. +/// +/// Spatial joins typically use spatial indexing for performance optimization, +/// but this implementation follows the nested loop pattern for compatibility +/// with the existing join framework. +class SpatialJoinProbe : public Operator { + public: + SpatialJoinProbe( + int32_t operatorId, + DriverCtx* driverCtx, + const std::shared_ptr& joinNode); + + void initialize() override; + + void addInput(RowVectorPtr input) override; + + RowVectorPtr getOutput() override; + + bool needsInput() const override { + return state_ == ProbeOperatorState::kRunning && input_ == nullptr && + !noMoreInput_; + } + + void noMoreInput() override; + + BlockingReason isBlocked(ContinueFuture* future) override; + + bool isFinished() override { + return state_ == ProbeOperatorState::kFinish; + } + + void close() override; + + private: + void checkStateTransition(ProbeOperatorState state); + + void setState(ProbeOperatorState state) { + checkStateTransition(state); + state_ = state; + } + + // Initialize spatial filter for evaluating spatial join conditions. + void initializeFilter( + const core::TypedExprPtr& filter, + const RowTypePtr& leftType, + const RowTypePtr& rightType); + + // Materializes build data from spatial join bridge into `buildVectors_`. + // Returns whether the data has been materialized and is ready for use. + // Spatial join requires all build data to be materialized and available in + // `buildVectors_` before it can produce output. + bool getBuildData(ContinueFuture* future); + + // Produce as much output as possible for the current input. + RowVectorPtr generateOutput(); + + // Returns true if the input is exhausted or the output is full. + bool isOutputDone() const { + return probeRow_ >= input_->size() || outputBuilder_.isOutputFull(); + } + + // Called when we are done processing the current probe batch, to signal we + // are ready for the next one. + // + // If this is the last probe batch (and this is a right or full outer join), + // change the operator state to signal peers. + void finishProbeInput(); + + // Add the output for a single probe row. This will return early if the + // output vector is full. + void addProbeRowOutput(); + + // Returns true if all output for the current probe row has been produced. + bool isProbeRowDone() const { + return candidateIndex_ >= candidateBuildRows_.size() || + buildVectorIndex_ >= buildVectors_.value().size(); + } + + // Increment probeRow_ and reset associated fields + void advanceProbeRow() { + ++probeRow_; + probeHasMatch_ = false; + buildVectorIndex_ = 0; + candidateIndex_ = 0; + candidateOffsetForCurrentBuildVector_ = 0; + buildRowOffset_ = 0; + needsFilterEvaluated_ = true; + } + + // Add the output for a single build vector for a single probe row. This will + // return early if the output vector is full. + void addBuildVectorOutput(const RowVectorPtr& buildVector); + + // Returns true if all the rows for the current build vector have been + // processed, or the output is full. + bool isBuildVectorDone(const RowVectorPtr& buildVector) const { + // Note that candidateBuildRows_ entries are row numbers across + // all build vectors. + return candidateIndex_ >= candidateBuildRows_.size() || + relativeBuildRow(candidateIndex_) >= buildVector->size() || + outputBuilder_.isOutputFull(); + } + + // Increment buildVectorIndex_ and reset associated fields + void advanceBuildVector() { + VELOX_CHECK(buildVectors_.has_value()); + + buildRowOffset_ += buildVectors_.value()[buildVectorIndex_]->size(); + ++buildVectorIndex_; + needsFilterEvaluated_ = true; + candidateOffsetForCurrentBuildVector_ = candidateIndex_; + } + + // Calculate candidate build rows from spatialIndex_ for the current probe + // row. This should be done each time the probe is advanced. + std::vector querySpatialIndex(); + + // Evaluates the spatial joinCondition for a given build vector. This method + // sets `filterOutput_` and `decodedFilterResult_`, which will be ready to + // be used by `isSpatialJoinConditionMatch()` below. + // This only evaluates rows that are in the candidateBuildRows_, restricted to + // those in the current build vector. Thus we must index into this with + // candidateIndex_. + void evaluateJoinFilter(const RowVectorPtr& buildVector); + + // Checks if the spatial join condition matched for a particular row. + bool isJoinConditionMatch(vector_size_t candidateIndex) const { + vector_size_t relativeIndex = + candidateIndex - candidateOffsetForCurrentBuildVector_; + VELOX_CHECK_GT(decodedFilterResult_.size(), relativeIndex); + return ( + !decodedFilterResult_.isNullAt(relativeIndex) && + decodedFilterResult_.valueAt(relativeIndex)); + } + + // Generates the next batch of a cross product between probe and build using + // the supplied projections. It uses the current probe row as constant, and + // flat copied data for build records. + RowVectorPtr getNextJoinBatch( + const RowVectorPtr& buildVector, + const RowTypePtr& outputType, + const std::vector& probeProjections, + const std::vector& buildProjections, + BufferPtr candidateRows) const; + + // Given a candidate index, return the row index into the current build + // vector. For example, if we have candidates [2, 50, 81] and have processed + // two build vectors with size 30 and 40, then `relativeBuildRow(2) == 11` + // (81 - (30 + 40)). + vector_size_t relativeBuildRow(vector_size_t candidateRow) const { + return candidateBuildRows_[candidateRow] - buildRowOffset_; + } + + // Make the indices of build vector candidates suitable for creating a + // DictionaryVector. + BufferPtr makeBuildVectorIndices(vector_size_t vectorSize); + + ///////// + // SETUP + // Variables set during operator setup that are used during execution. + // These should not be modified after the operator is initialized. + + const core::JoinType joinType_; + + // Maximum number of rows in the output batch. + const vector_size_t outputBatchSize_; + + // Join metadata and state. + std::shared_ptr joinNode_; + + // Spatial join condition expression. + // Must not be null + std::unique_ptr joinCondition_; + + // Input type for the spatial join condition expression. + RowTypePtr filterInputType_; + + // List of output projections from the build side. Note that the list of + // projections from the probe side is available at `identityProjections_`. + std::vector buildProjections_; + + // Projections needed as input to the filter to evaluation spatial join + // filter conditions. Note that if this is a cross-join, filter projections + // are the same as output projections. + std::vector filterProbeProjections_; + std::vector filterBuildProjections_; + + // Stores the build spatial index for the join + std::optional> spatialIndex_; + // Stores the data for build vectors (right side of the join). + std::optional> buildVectors_; + + // Channel of geometry variable used to probe spatial index + column_index_t probeGeometryChannel_; + + ////////////////// + // OPERATOR STATE + // Variables used to track the general operator state during exection. + // These will change throughout setup and execution. + + ProbeOperatorState state_{ProbeOperatorState::kWaitForBuild}; + ContinueFuture future_{ContinueFuture::makeEmpty()}; + + // The information needed to produce an output RowVectorPtr. It is stored + // for all execution, but is reset on each output batch. + SpatialJoinOutputBuilder outputBuilder_; + + // Count of output batches produced (1-indexed). Primarily for debugging. + size_t outputCount_{0}; + + // This is always set to all true, but we need it for eval/etc. Reuse between + // evaluations. + SelectivityVector filterInputRows_; + // The output result of the join condition evaluation on the **current** + // build vector. We must index into this with + // `candidateIndex_ - candidateOffsetForCurrentBuildVector_`. + VectorPtr filterOutput_; + // Decoded filterOutput: remove recursive dictionary/etc encodings. + // Like filterOutput_, this is only for the current build vector and we + // must index into this with + // `candidateIndex_ - candidateOffsetForCurrentBuildVector_`. + DecodedVector decodedFilterResult_; + + // Decoded geometry vector. Must be reset whenever input_ is changed (it + // maintains a pointer to input_). + DecodedVector decodedGeometryCol_{}; + + /////////////// + // PROBE STATE + // Variables used to track the probe-side state state during exection. + // These will change throughout setup and execution. + + // Count of probe batches added (1-indexed). Primarily for debugging. + size_t probeCount_{0}; + + // Probe row being currently processed (related to `input_`). + vector_size_t probeRow_{0}; + + // Whether the current probeRow_ has found a match. Needed for left join. + bool probeHasMatch_{false}; + + /////////////// + // BUILD STATE + // Variables used to track the build-side state state during exection. + // These will change throughout setup and execution. + // + // The build rows are stored in a vector of RowVectorPtrs. These are + // conceptually indexed by an absolute build row, which indexes into a + // flattened vector of rows. buildVectorIndex_ is the index to the current + // RowVectorPtr in buildVectors_, buildRowOffset_ is the sum of the sizes + // of the previous build vectors and should be subtracted from buildRow + // to index into the current build vector. + // + // We primarily use candidateBuildRows_, which is a vector of (absolute) + // build rows. candidateIndex_ indexes the entry in candidateBuildRows_, + // so candidateBuildrows_[candidiateIndex_] is the absolute build row + // of the current candidate. + + // Whether we need to evaluate the join filter on this build vector. It + // should be done once per build vector/probe row pair. + bool needsFilterEvaluated_{true}; + + // Index into `buildVectors_` for the build vector being currently + // processed. + size_t buildVectorIndex_{0}; + + // Keep track of how many build rows we've traversed in previous build + // RowVectors. Subtract this from the current element in candidateBuildRows_ + // to index into the current build RowVector. + vector_size_t buildRowOffset_{0}; + + // Build rows returned from the spatial index. + // The value is the row number over all build vectors, so if the have two + // build vectors of size 100 and 200, candidate row 50 is the 50th entry of + // the first vector, and 101 is the 2nd entry of the second vector. + std::vector candidateBuildRows_{}; + + // Index of candidate currently being processed from + // `buildVectors_[buildIndex_]`. + vector_size_t candidateIndex_{0}; + + // How many candidates were in previous build vectors. + // This is important because for each build vector, we calculate a + // decodedFilterResult_ with only the rows from from the candidates in + // that build vector. candidateIndex_ indexes over _all_ candidates, so + // we must substract candidateOffsetForCurrentBuildVector_ to index into the + // candidates for this build vector. + vector_size_t candidateOffsetForCurrentBuildVector_{0}; +}; + +} // namespace facebook::velox::exec diff --git a/velox/exec/Spill.cpp b/velox/exec/Spill.cpp index 9b9f14eb75d6..15c077f743ce 100644 --- a/velox/exec/Spill.cpp +++ b/velox/exec/Spill.cpp @@ -15,14 +15,82 @@ */ #include "velox/exec/Spill.h" +#include "velox/common/Casts.h" #include "velox/common/base/RuntimeMetrics.h" #include "velox/common/file/FileSystems.h" #include "velox/common/testutil/TestValue.h" +#include "velox/exec/OperatorUtils.h" #include "velox/serializers/PrestoSerializer.h" using facebook::velox::common::testutil::TestValue; namespace facebook::velox::exec { + +namespace { +/// gatherMerge merges & sorts with the mergeTree and gatherCopy the +/// results into target. 'target' is the result RowVector, and the copying +/// starts from row 0 up to row target.size(). 'mergeTree' is the data source. +/// 'totalNumRows' is the actual num of rows that is copied to target. +/// 'bufferSources' and 'bufferSourceIndices' are buffering vectors that could +/// be reused across calls. +void gatherMerge( + RowVectorPtr& target, + TreeOfLosers& mergeTree, + int32_t& totalNumRows, + std::vector& bufferSources, + std::vector& bufferSourceIndices) { + VELOX_CHECK_GE(bufferSources.size(), target->size()); + VELOX_CHECK_GE(bufferSourceIndices.size(), target->size()); + totalNumRows = 0; + int32_t numBatchRows = 0; + bool endOfBatch = false; + for (auto currentStream = mergeTree.next(); + currentStream != nullptr && totalNumRows + numBatchRows < target->size(); + currentStream = mergeTree.next()) { + bufferSources[numBatchRows] = ¤tStream->current(); + bufferSourceIndices[numBatchRows] = + currentStream->currentIndex(&endOfBatch); + ++numBatchRows; + if (FOLLY_UNLIKELY(endOfBatch)) { + // The stream is at end of input batch. Need to copy out the rows before + // fetching next batch in 'pop'. + gatherCopy( + target.get(), + totalNumRows, + numBatchRows, + bufferSources, + bufferSourceIndices); + totalNumRows += numBatchRows; + numBatchRows = 0; + } + // Advance the stream. + currentStream->pop(); + } + VELOX_CHECK_LE(totalNumRows + numBatchRows, target->size()); + + if (FOLLY_LIKELY(numBatchRows != 0)) { + gatherCopy( + target.get(), + totalNumRows, + numBatchRows, + bufferSources, + bufferSourceIndices); + totalNumRows += numBatchRows; + numBatchRows = 0; + } +} +} // namespace + +void testingGatherMerge( + RowVectorPtr& target, + TreeOfLosers& mergeTree, + int32_t& totalNumRows, + std::vector& bufferSources, + std::vector& bufferSourceIndices) { + gatherMerge( + target, mergeTree, totalNumRows, bufferSources, bufferSourceIndices); +} + void SpillMergeStream::pop() { VELOX_CHECK(!closed_); if (++index_ >= size_) { @@ -119,10 +187,11 @@ void SpillState::validateSpillBytesSize(uint64_t bytes) { static constexpr uint64_t kMaxSpillBytesPerWrite = std::numeric_limits::max(); if (bytes >= kMaxSpillBytesPerWrite) { - VELOX_GENERIC_SPILL_FAILURE(fmt::format( - "Spill bytes will overflow. Bytes {}, kMaxSpillBytesPerWrite: {}", - bytes, - kMaxSpillBytesPerWrite)); + VELOX_GENERIC_SPILL_FAILURE( + fmt::format( + "Spill bytes will overflow. Bytes {}, kMaxSpillBytesPerWrite: {}", + bytes, + kMaxSpillBytesPerWrite)); } } @@ -283,8 +352,9 @@ SpillPartition::createUnorderedReader( std::vector> streams; streams.reserve(files_.size()); for (auto& fileInfo : files_) { - streams.push_back(FileSpillBatchStream::create( - SpillReadFile::create(fileInfo, bufferSize, pool, spillStats))); + streams.push_back( + FileSpillBatchStream::create( + SpillReadFile::create(fileInfo, bufferSize, pool, spillStats))); } files_.clear(); return std::make_unique>( @@ -292,15 +362,16 @@ SpillPartition::createUnorderedReader( } std::unique_ptr> -SpillPartition::createOrderedReader( +SpillPartition::createOrderedReaderInternal( uint64_t bufferSize, memory::MemoryPool* pool, folly::Synchronized* spillStats) { std::vector> streams; streams.reserve(files_.size()); for (auto& fileInfo : files_) { - streams.push_back(FileSpillMergeStream::create( - SpillReadFile::create(fileInfo, bufferSize, pool, spillStats))); + streams.push_back( + FileSpillMergeStream::create( + SpillReadFile::create(fileInfo, bufferSize, pool, spillStats))); } files_.clear(); // Check if the partition is empty or not. @@ -310,6 +381,174 @@ SpillPartition::createOrderedReader( return std::make_unique>(std::move(streams)); } +namespace { +size_t estimateOutputBatchRows( + const std::vector>& streams, + vector_size_t maxRows, + size_t maxBytes) { + size_t numEstimations{0}; + int64_t totalEstimatedBytes{0}; + for (const auto& stream : streams) { + const auto streamEstimateRowSize = stream->estimateRowSize(); + if (streamEstimateRowSize.has_value()) { + ++numEstimations; + totalEstimatedBytes += streamEstimateRowSize.value(); + } + } + + if (numEstimations == 0) { + return maxRows; + } + + const auto estimateRowSize = + std::max(1, totalEstimatedBytes / numEstimations); + return std::min( + std::max(1, maxBytes / estimateRowSize), maxRows); +} + +// This contains batching parameters and various kinds of batching buffers that +// are reused across multiple merging rounds. +struct SpillFileMergeParams { + static constexpr size_t kDefaultMaxBatchRows = 1'000; + static constexpr size_t kDefaultMaxBatchBytes = 64 * 1024; + + SpillFileMergeParams( + const TypePtr& type, + memory::MemoryPool* pool, + const vector_size_t _maxBatchRows = kDefaultMaxBatchRows, + const size_t _maxBatchBytes = kDefaultMaxBatchBytes) + : maxBatchRows(_maxBatchRows), maxBatchBytes(_maxBatchBytes) { + rowVector = std::static_pointer_cast( + BaseVector::create(type, maxBatchRows, pool)); + spillSources.resize(maxBatchRows); + spillSourceRows.resize(maxBatchRows); + } + + const vector_size_t maxBatchRows; + const size_t maxBatchBytes; + RowVectorPtr rowVector; + std::vector spillSources; + std::vector spillSourceRows; +}; + +SpillFileInfo mergeSpillFiles( + const std::vector& files, + const std::string& pathPrefix, + const common::UpdateAndCheckSpillLimitCB& updateAndCheckSpillLimitCb, + const std::string& fileCreateConfig, + uint64_t readBufferSize, + uint64_t writeBufferSize, + SpillFileMergeParams& mergeParams, + memory::MemoryPool* pool, + folly::Synchronized* spillStats) { + VELOX_CHECK_GT(files.size(), 0); + std::vector> streams; + streams.reserve(files.size()); + for (const auto& fileInfo : files) { + streams.push_back( + FileSpillMergeStream::create( + SpillReadFile::create(fileInfo, readBufferSize, pool, spillStats))); + } + const auto batchRows = estimateOutputBatchRows( + streams, mergeParams.maxBatchRows, mergeParams.maxBatchBytes); + + auto mergeTree = + std::make_unique>(std::move(streams)); + const auto type = files[0].type; + + auto writer = std::make_unique( + type, + files[0].sortingKeys, + files[0].compressionKind, + pathPrefix, + std::numeric_limits::max(), + writeBufferSize, + fileCreateConfig, + updateAndCheckSpillLimitCb, + pool, + spillStats); + + while (mergeTree->next()) { + VectorPtr tmpRowVector = std::move(mergeParams.rowVector); + BaseVector::prepareForReuse(tmpRowVector, batchRows); + mergeParams.rowVector = checkedPointerCast(tmpRowVector); + mergeParams.rowVector->resize(batchRows); + int32_t outputRow = 0; + gatherMerge( + mergeParams.rowVector, + *mergeTree, + outputRow, + mergeParams.spillSources, + mergeParams.spillSourceRows); + + IndexRange range{0, outputRow}; + writer->write(mergeParams.rowVector, folly::Range(&range, 1)); + } + auto resultFiles = writer->finish(); + VELOX_CHECK_EQ(resultFiles.size(), 1); + return std::move(resultFiles[0]); +} + +struct SpillFileCompare { + bool operator()(const SpillFileInfo& lhs, const SpillFileInfo& rhs) const { + return lhs.size > rhs.size; + } +}; +using SpillFileHeap = std:: + priority_queue, SpillFileCompare>; +} // namespace + +std::unique_ptr> +SpillPartition::createOrderedReader( + const common::SpillConfig& spillConfig, + memory::MemoryPool* pool, + folly::Synchronized* spillStats) { + const auto numMaxMergeFiles = spillConfig.numMaxMergeFiles; + VELOX_CHECK_NE(numMaxMergeFiles, 1); + if (numMaxMergeFiles == 0 || files_.size() <= numMaxMergeFiles) { + return createOrderedReaderInternal( + spillConfig.readBufferSize, pool, spillStats); + } + + SpillFileHeap orderedFiles(files_.begin(), files_.end()); + SpillFiles files; + files.reserve(numMaxMergeFiles); + const auto mergeFilePathPrefix = files_[0].path; + SpillFileMergeParams mergeParams(files_[0].type, pool); + + // Recursively merge the files. + for (uint32_t round = 0; orderedFiles.size() > numMaxMergeFiles; ++round) { + const uint64_t numMergeFiles = std::min( + static_cast(numMaxMergeFiles), + static_cast(orderedFiles.size() + 1 - numMaxMergeFiles)); + // Choose the top 'numMergeFiles' smallest files for merging to minimize IO. + for (uint32_t i = 0; i < numMergeFiles; i++) { + files.push_back(orderedFiles.top()); + orderedFiles.pop(); + } + auto mergedFile = mergeSpillFiles( + files, + fmt::format("{}-merge-round-{}", mergeFilePathPrefix, round), + spillConfig.updateAndCheckSpillLimitCb, + spillConfig.fileCreateConfig, + spillConfig.readBufferSize, + spillConfig.writeBufferSize, + mergeParams, + pool, + spillStats); + orderedFiles.push(mergedFile); + files.clear(); + } + + files_.clear(); + while (!orderedFiles.empty()) { + files_.push_back(orderedFiles.top()); + orderedFiles.pop(); + } + return createOrderedReaderInternal( + spillConfig.readBufferSize, pool, spillStats); +} + IterableSpillPartitionSet::IterableSpillPartitionSet() { spillPartitionIter_ = spillPartitions_.begin(); } @@ -429,13 +668,38 @@ const std::vector& ConcatFilesSpillMergeStream::sortingKeys() return spillFiles_[fileIndex_]->sortingKeys(); } +std::unique_ptr ConcatFilesSpillBatchStream::create( + std::vector> spillFiles) { + auto* spillStream = new ConcatFilesSpillBatchStream(std::move(spillFiles)); + return std::unique_ptr(spillStream); +} + +bool ConcatFilesSpillBatchStream::nextBatch(RowVectorPtr& batch) { + TestValue::adjust( + "facebook::velox::exec::ConcatFilesSpillBatchStream::nextBatch", nullptr); + VELOX_CHECK_NULL(batch); + VELOX_CHECK(!atEnd_); + for (; fileIndex_ < spillFiles_.size(); ++fileIndex_) { + VELOX_CHECK_NOT_NULL(spillFiles_[fileIndex_]); + if (spillFiles_[fileIndex_]->nextBatch(batch)) { + VELOX_CHECK_NOT_NULL(batch); + return true; + } + spillFiles_[fileIndex_].reset(); + } + spillFiles_.clear(); + atEnd_ = true; + return false; +} + SpillPartitionId::SpillPartitionId(uint32_t partitionNumber) : encodedId_(partitionNumber) { if (FOLLY_UNLIKELY(partitionNumber >= (1 << kMaxPartitionBits))) { - VELOX_FAIL(fmt::format( - "Partition number {} exceeds max partition number {}", - partitionNumber, - 1 << kMaxPartitionBits)); + VELOX_FAIL( + fmt::format( + "Partition number {} exceeds max partition number {}", + partitionNumber, + 1 << kMaxPartitionBits)); } } @@ -444,10 +708,11 @@ SpillPartitionId::SpillPartitionId( uint32_t partitionNumber) { const auto childSpillLevel = parent.spillLevel() + 1; if (FOLLY_UNLIKELY(childSpillLevel > kMaxSpillLevel)) { - VELOX_FAIL(fmt::format( - "Spill level {} exceeds max spill level {}", - childSpillLevel, - kMaxSpillLevel)); + VELOX_FAIL( + fmt::format( + "Spill level {} exceeds max spill level {}", + childSpillLevel, + kMaxSpillLevel)); } encodedId_ = parent.encodedId_; encodedId_ = encodedId_ & ~kSpillLevelBitMask; @@ -459,15 +724,10 @@ SpillPartitionId::SpillPartitionId( encodedId_ |= partitionNumber << (kNumPartitionBits * childSpillLevel); } -bool SpillPartitionId::operator==(const SpillPartitionId& other) const { - return encodedId_ == other.encodedId_; -} - -bool SpillPartitionId::operator!=(const SpillPartitionId& other) const { - return !(*this == other); -} - bool SpillPartitionId::operator<(const SpillPartitionId& other) const { + if (*this == other) { + return false; + } for (auto i = 0; i <= std::min(spillLevel(), other.spillLevel()); ++i) { const auto selfPartitionNum = partitionNumber(i); const auto otherPartitionNum = other.partitionNumber(i); diff --git a/velox/exec/Spill.h b/velox/exec/Spill.h index ec7605af9593..128e1f938a45 100644 --- a/velox/exec/Spill.h +++ b/velox/exec/Spill.h @@ -21,11 +21,11 @@ #include #include "velox/common/base/SpillConfig.h" #include "velox/common/base/SpillStats.h" +#include "velox/common/base/TreeOfLosers.h" #include "velox/common/compression/Compression.h" #include "velox/common/file/File.h" #include "velox/common/file/FileSystems.h" #include "velox/exec/SpillFile.h" -#include "velox/exec/TreeOfLosers.h" #include "velox/exec/UnorderedStreamReader.h" #include "velox/exec/VectorHasher.h" #include "velox/vector/ComplexVector.h" @@ -33,6 +33,23 @@ #include "velox/vector/VectorStream.h" namespace facebook::velox::exec { + +class SpillMergeStream; + +/// Testing gatherMerge without exposing the interface in the header. Used in +/// test only. gatherMerge merges & sorts with the mergeTree and gatherCopy the +/// results into target. 'target' is the result RowVector, and the copying +/// starts from row 0 up to row target.size(). 'mergeTree' is the data source. +/// 'totalNumRows' is the actual num of rows that is copied to target. +/// 'bufferSources' and 'bufferSourceIndices' are buffering vectors that could +/// be reused across calls. +void testingGatherMerge( + RowVectorPtr& target, + TreeOfLosers& mergeTree, + int32_t& totalNumRows, + std::vector& bufferSources, + std::vector& bufferSourceIndices); + class VectorHasher; /// A source of sorted spilled RowVectors coming either from a file or memory. @@ -80,6 +97,15 @@ class SpillMergeStream : public MergeStream { return decoded_[index]; } + /// Returns the estimated row size based on the vector received from the + /// merge source. + std::optional estimateRowSize() const { + if (rowVector_ == nullptr || rowVector_->size() == 0) { + return std::nullopt; + } + return rowVector_->estimateFlatSize() / rowVector_->size(); + } + protected: virtual const std::vector& sortingKeys() const = 0; @@ -192,6 +218,28 @@ class FileSpillBatchStream : public BatchStream { std::unique_ptr spillFile_; }; +/// A source of spilled RowVectors coming from a sequence of files. +/// +/// NOTE: this object is not thread-safe. +class ConcatFilesSpillBatchStream final : public BatchStream { + public: + static std::unique_ptr create( + std::vector> spillFiles); + + bool nextBatch(RowVectorPtr& batch) override; + + private: + explicit ConcatFilesSpillBatchStream( + std::vector> spillFiles) + : spillFiles_(std::move(spillFiles)) { + VELOX_CHECK(!spillFiles_.empty()); + } + + std::vector> spillFiles_; + size_t fileIndex_{0}; + bool atEnd_{false}; +}; + /// A SpillMergeStream that contains a sequence of sorted spill files, the /// sorted keys are ordered both within each file and across files. class ConcatFilesSpillMergeStream final : public SpillMergeStream { @@ -244,9 +292,7 @@ class SpillPartitionId { /// Constructs a child spill level id, descending from provided 'parent'. SpillPartitionId(SpillPartitionId parent, uint32_t partitionNumber); - bool operator==(const SpillPartitionId& other) const; - - bool operator!=(const SpillPartitionId& other) const; + bool operator==(const SpillPartitionId& other) const = default; /// Customize the compare operator for recursive spilling control. It /// ensures the order such that: @@ -267,7 +313,7 @@ class SpillPartitionId { bool operator<(const SpillPartitionId& other) const; bool operator>(const SpillPartitionId& other) const { - return (*this != other) && !(*this < other); + return other < *this; } std::string toString() const; @@ -459,20 +505,31 @@ class SpillPartition { memory::MemoryPool* pool, folly::Synchronized* spillStats); + /// Create an ordered stream reader from this spill partition. If the + /// partition has more than spillConfig.numMaxMergeFiles files, the files will + /// be pre-merged recursively to make sure the final ordered reader reads no + /// more than numMaxMergeFiles files. This behavior is to avoid OOM problem + /// when opening and reading too many files at the same time. If + /// numMaxMergeFiles < 2, the merge way is unlimited. + std::unique_ptr> createOrderedReader( + const common::SpillConfig& spillConfig, + memory::MemoryPool* pool, + folly::Synchronized* spillStats); + + std::string toString() const; + + private: /// Invoked to create an ordered stream reader from this spill partition. /// The created reader will take the ownership of the spill files. /// 'bufferSize' specifies the read size from the storage. If the file /// system supports async read mode, then reader allocates two buffers with /// one buffer prefetch ahead. 'spillStats' is provided to collect the spill /// stats when reading data from spilled files. - std::unique_ptr> createOrderedReader( + std::unique_ptr> createOrderedReaderInternal( uint64_t bufferSize, memory::MemoryPool* pool, folly::Synchronized* spillStats); - std::string toString() const; - - private: SpillPartitionId id_; SpillFiles files_; // Counts the total file size in bytes from this spilled partition. diff --git a/velox/exec/SpillFile.cpp b/velox/exec/SpillFile.cpp index 0e612af59d40..c3cc3fea76a9 100644 --- a/velox/exec/SpillFile.cpp +++ b/velox/exec/SpillFile.cpp @@ -16,8 +16,7 @@ #include "velox/exec/SpillFile.h" #include "velox/common/base/RuntimeMetrics.h" -#include "velox/common/file/FileSystems.h" -#include "velox/vector/VectorStream.h" +#include "velox/serializers/SerializedPageFile.h" namespace facebook::velox::exec { namespace { @@ -29,245 +28,90 @@ namespace { static const bool kDefaultUseLosslessTimestamp = true; } // namespace -std::unique_ptr SpillWriteFile::create( - uint32_t id, - const std::string& pathPrefix, - const std::string& fileCreateConfig) { - return std::unique_ptr( - new SpillWriteFile(id, pathPrefix, fileCreateConfig)); -} - -SpillWriteFile::SpillWriteFile( - uint32_t id, +SpillWriter::SpillWriter( + const RowTypePtr& type, + const std::vector& sortingKeys, + common::CompressionKind compressionKind, const std::string& pathPrefix, - const std::string& fileCreateConfig) - : id_(id), path_(fmt::format("{}-{}", pathPrefix, ordinalCounter_++)) { - auto fs = filesystems::getFileSystem(path_, nullptr); - file_ = fs->openFileForWrite( - path_, - filesystems::FileOptions{ - {{filesystems::FileOptions::kFileCreateConfig.toString(), - fileCreateConfig}}, - nullptr, - std::nullopt}); -} - -void SpillWriteFile::finish() { - VELOX_CHECK_NOT_NULL(file_); - size_ = file_->size(); - file_->close(); - file_ = nullptr; -} - -uint64_t SpillWriteFile::size() const { - if (file_ != nullptr) { - return file_->size(); - } - return size_; -} - -uint64_t SpillWriteFile::write(std::unique_ptr iobuf) { - auto writtenBytes = iobuf->computeChainDataLength(); - file_->append(std::move(iobuf)); - return writtenBytes; -} - -SpillWriterBase::SpillWriterBase( - uint64_t writeBufferSize, uint64_t targetFileSize, - const std::string& pathPrefix, + uint64_t writeBufferSize, const std::string& fileCreateConfig, - common::UpdateAndCheckSpillLimitCB& updateAndCheckSpillLimitCb, + const common::UpdateAndCheckSpillLimitCB& updateAndCheckSpillLimitCb, memory::MemoryPool* pool, folly::Synchronized* stats) - : pool_(pool), + : serializer::SerializedPageFileWriter( + pathPrefix, + targetFileSize, + writeBufferSize, + fileCreateConfig, + std::make_unique< + serializer::presto::PrestoVectorSerde::PrestoOptions>( + kDefaultUseLosslessTimestamp, + compressionKind, + 0.8, + /*_nullsFirst=*/true), + getNamedVectorSerde(VectorSerde::Kind::kPresto), + pool), + type_(type), + sortingKeys_(sortingKeys), stats_(stats), - updateAndCheckSpillLimitCb_(updateAndCheckSpillLimitCb), - fileCreateConfig_(fileCreateConfig), - pathPrefix_(pathPrefix), - writeBufferSize_(writeBufferSize), - targetFileSize_(targetFileSize) {} + updateAndCheckLimitCb_(updateAndCheckSpillLimitCb) {} -SpillFiles SpillWriterBase::finish() { - checkNotFinished(); - SCOPE_EXIT { - finished_ = true; - }; - finishFile(); - return std::move(finishedFiles_); -} - -void SpillWriterBase::finishFile() { - checkNotFinished(); - flush(); - closeFile(); - VELOX_CHECK_NULL(currentFile_); -} - -uint64_t SpillWriterBase::flush() { - if (bufferEmpty()) { - return 0; - } - - auto* file = ensureFile(); - VELOX_CHECK_NOT_NULL(file); - uint64_t writtenBytes{0}; - uint64_t flushTimeNs{0}; - uint64_t writeTimeNs{0}; - flushBuffer(file, writtenBytes, flushTimeNs, writeTimeNs); - updateWriteStats(writtenBytes, flushTimeNs, writeTimeNs); - updateAndCheckSpillLimitCb_(writtenBytes); - return writtenBytes; -} - -SpillWriteFile* SpillWriterBase::ensureFile() { - if ((currentFile_ != nullptr) && (currentFile_->size() > targetFileSize_)) { - closeFile(); - } - if (currentFile_ == nullptr) { - currentFile_ = SpillWriteFile::create( - nextFileId_++, - fmt::format("{}-{}", pathPrefix_, finishedFiles_.size()), - fileCreateConfig_); - } - return currentFile_.get(); -} - -void SpillWriterBase::closeFile() { - if (currentFile_ == nullptr) { - return; - } - currentFile_->finish(); - updateSpilledFileStats(currentFile_->size()); - addFinishedFile(currentFile_.get()); - currentFile_.reset(); -} - -uint64_t SpillWriterBase::writeWithBufferControl( - const std::function& writeCb) { - checkNotFinished(); - - uint64_t timeNs{0}; - uint64_t rowsWritten{0}; - { - NanosecondTimer timer(&timeNs); - rowsWritten = writeCb(); - } - updateAppendStats(rowsWritten, timeNs); - - if (bufferSize() < writeBufferSize_) { - return 0; - } - return flush(); +void SpillWriter::updateAppendStats( + uint64_t numRows, + uint64_t serializationTimeNs) { + auto statsLocked = stats_->wlock(); + statsLocked->spilledRows += numRows; + statsLocked->spillSerializationTimeNanos += serializationTimeNs; + common::updateGlobalSpillAppendStats(numRows, serializationTimeNs); } -void SpillWriterBase::updateWriteStats( +void SpillWriter::updateWriteStats( uint64_t spilledBytes, uint64_t flushTimeNs, - uint64_t writeTimeNs) { + uint64_t fileWriteTimeNs) { auto statsLocked = stats_->wlock(); statsLocked->spilledBytes += spilledBytes; statsLocked->spillFlushTimeNanos += flushTimeNs; - statsLocked->spillWriteTimeNanos += writeTimeNs; + statsLocked->spillWriteTimeNanos += fileWriteTimeNs; ++statsLocked->spillWrites; - common::updateGlobalSpillWriteStats(spilledBytes, flushTimeNs, writeTimeNs); + common::updateGlobalSpillWriteStats( + spilledBytes, flushTimeNs, fileWriteTimeNs); + updateAndCheckLimitCb_(spilledBytes); } -void SpillWriterBase::updateSpilledFileStats(uint64_t fileSize) { +void SpillWriter::updateFileStats( + const serializer::SerializedPageFile::FileInfo& file) { ++stats_->wlock()->spilledFiles; addThreadLocalRuntimeStat( - "spillFileSize", RuntimeCounter(fileSize, RuntimeCounter::Unit::kBytes)); + "spillFileSize", RuntimeCounter(file.size, RuntimeCounter::Unit::kBytes)); common::incrementGlobalSpilledFiles(); } -void SpillWriterBase::updateAppendStats( - uint64_t numRows, - uint64_t serializationTimeNs) { - auto statsLocked = stats_->wlock(); - statsLocked->spilledRows += numRows; - statsLocked->spillSerializationTimeNanos += serializationTimeNs; - common::updateGlobalSpillAppendStats(numRows, serializationTimeNs); -} - -SpillWriter::SpillWriter( - const RowTypePtr& type, - const std::vector& sortingKeys, - common::CompressionKind compressionKind, - const std::string& pathPrefix, - uint64_t targetFileSize, - uint64_t writeBufferSize, - const std::string& fileCreateConfig, - common::UpdateAndCheckSpillLimitCB& updateAndCheckSpillLimitCb, - memory::MemoryPool* pool, - folly::Synchronized* stats) - : SpillWriterBase( - writeBufferSize, - targetFileSize, - pathPrefix, - fileCreateConfig, - updateAndCheckSpillLimitCb, - pool, - stats), - type_(type), - sortingKeys_(sortingKeys), - compressionKind_(compressionKind), - serde_(getNamedVectorSerde(VectorSerde::Kind::kPresto)) {} - -void SpillWriter::flushBuffer( - SpillWriteFile* file, - uint64_t& writtenBytes, - uint64_t& flushTimeNs, - uint64_t& writeTimeNs) { - IOBufOutputStream out( - *pool_, nullptr, std::max(64 * 1024, batch_->size())); - { - NanosecondTimer timer(&flushTimeNs); - batch_->flush(&out); - } - batch_.reset(); - - auto iobuf = out.getIOBuf(); - { - NanosecondTimer timer(&writeTimeNs); - writtenBytes = file->write(std::move(iobuf)); +SpillFiles SpillWriter::finish() { + const auto serializedPageFiles = + serializer::SerializedPageFileWriter::finish(); + SpillFiles spillFiles; + spillFiles.reserve(serializedPageFiles.size()); + for (const auto& fileInfo : serializedPageFiles) { + spillFiles.push_back( + SpillFileInfo{ + .id = fileInfo.id, + .type = type_, + .path = fileInfo.path, + .size = fileInfo.size, + .sortingKeys = sortingKeys_, + .compressionKind = serdeOptions_->compressionKind}); } -} - -uint64_t SpillWriter::write( - const RowVectorPtr& rows, - const folly::Range& indices) { - return writeWithBufferControl([&]() { - if (batch_ == nullptr) { - serializer::presto::PrestoVectorSerde::PrestoOptions options = { - kDefaultUseLosslessTimestamp, - compressionKind_, - 0.8, - /*nullsFirst=*/true}; - batch_ = std::make_unique(pool_, serde_); - batch_->createStreamTree( - std::static_pointer_cast(rows->type()), - 1'000, - &options); - } - batch_->append(rows, indices); - return rows->size(); - }); -} - -void SpillWriter::addFinishedFile(SpillWriteFile* file) { - finishedFiles_.push_back(SpillFileInfo{ - .id = file->id(), - .type = type_, - .path = file->path(), - .size = file->size(), - .sortingKeys = sortingKeys_, - .compressionKind = compressionKind_}); + return spillFiles; } std::vector SpillWriter::testingSpilledFilePaths() const { checkNotFinished(); std::vector spilledFilePaths; + spilledFilePaths.reserve( + finishedFiles_.size() + (currentFile_ != nullptr ? 1 : 0)); for (auto& file : finishedFiles_) { spilledFilePaths.push_back(file.path); } @@ -317,44 +161,25 @@ SpillReadFile::SpillReadFile( common::CompressionKind compressionKind, memory::MemoryPool* pool, folly::Synchronized* stats) - : id_(id), + : serializer::SerializedPageFileReader( + path, + bufferSize, + type, + getNamedVectorSerde(VectorSerde::Kind::kPresto), + std::make_unique< + serializer::presto::PrestoVectorSerde::PrestoOptions>( + kDefaultUseLosslessTimestamp, + compressionKind, + 0.8, + /*_nullsFirst=*/true), + pool), + id_(id), path_(path), size_(size), - type_(type), sortingKeys_(sortingKeys), - compressionKind_(compressionKind), - readOptions_{ - kDefaultUseLosslessTimestamp, - compressionKind_, - 0.8, - /*nullsFirst=*/true}, - pool_(pool), - serde_(getNamedVectorSerde(VectorSerde::Kind::kPresto)), - stats_(stats) { - auto fs = filesystems::getFileSystem(path_, nullptr); - auto file = fs->openFileForRead(path_); - input_ = std::make_unique( - std::move(file), bufferSize, pool_); -} + stats_(stats) {} -bool SpillReadFile::nextBatch(RowVectorPtr& rowVector) { - if (input_->atEnd()) { - recordSpillStats(); - return false; - } - - uint64_t timeNs{0}; - { - NanosecondTimer timer{&timeNs}; - VectorStreamGroup::read( - input_.get(), pool_, type_, serde_, &rowVector, &readOptions_); - } - stats_->wlock()->spillDeserializationTimeNanos += timeNs; - common::updateGlobalSpillDeserializationTimeNs(timeNs); - return true; -} - -void SpillReadFile::recordSpillStats() { +void SpillReadFile::updateFinalStats() { VELOX_CHECK(input_->atEnd()); const auto readStats = input_->stats(); common::updateGlobalSpillReadStats( @@ -363,5 +188,11 @@ void SpillReadFile::recordSpillStats() { lockedSpillStats->spillReads += readStats.numReads; lockedSpillStats->spillReadTimeNanos += readStats.readTimeNs; lockedSpillStats->spillReadBytes += readStats.readBytes; -} +}; + +void SpillReadFile::updateSerializationTimeStats(uint64_t timeNs) { + stats_->wlock()->spillDeserializationTimeNanos += timeNs; + common::updateGlobalSpillDeserializationTimeNs(timeNs); +}; + } // namespace facebook::velox::exec diff --git a/velox/exec/SpillFile.h b/velox/exec/SpillFile.h index 369a1c623399..6f194aa7c324 100644 --- a/velox/exec/SpillFile.h +++ b/velox/exec/SpillFile.h @@ -17,14 +17,16 @@ #pragma once #include +#include #include "velox/common/base/SpillConfig.h" #include "velox/common/base/SpillStats.h" +#include "velox/common/base/TreeOfLosers.h" #include "velox/common/compression/Compression.h" #include "velox/common/file/File.h" #include "velox/common/file/FileInputStream.h" -#include "velox/exec/TreeOfLosers.h" #include "velox/serializers/PrestoSerializer.h" +#include "velox/serializers/SerializedPageFile.h" #include "velox/vector/ComplexVector.h" #include "velox/vector/DecodedVector.h" #include "velox/vector/VectorStream.h" @@ -32,55 +34,6 @@ namespace facebook::velox::exec { using SpillSortKey = std::pair; -/// Represents a spill file for writing the serialized spilled data into a disk -/// file. -class SpillWriteFile { - public: - static std::unique_ptr create( - uint32_t id, - const std::string& pathPrefix, - const std::string& fileCreateConfig); - - uint32_t id() const { - return id_; - } - - /// Returns the file size in bytes. - uint64_t size() const; - - const std::string& path() const { - return path_; - } - - uint64_t write(std::unique_ptr iobuf); - - void write(const char* data, uint64_t bytes); - - WriteFile* file() { - return file_.get(); - } - - /// Finishes writing and flushes any unwritten data. - void finish(); - - private: - static inline std::atomic ordinalCounter_{0}; - - SpillWriteFile( - uint32_t id, - const std::string& pathPrefix, - const std::string& fileCreateConfig); - - // The spill file id which is monotonically increasing and unique for each - // associated spill partition. - const uint32_t id_; - const std::string path_; - - std::unique_ptr file_; - // Byte size of the backing file. Set when finishing writing. - uint64_t size_{0}; -}; - /// Records info of a finished spill file which is used for read. struct SpillFileInfo { uint32_t id; @@ -94,113 +47,10 @@ struct SpillFileInfo { using SpillFiles = std::vector; -/// Used to write the spilled data to a sequence of files for one partition. -/// This base class provides the functionality of managing buffer and write -/// files. The derived classes are responsible for: -/// 1. Creating write API to accommodate the type of data to be spilled. -/// 2. Implementing various buffer APIs and manage the buffer. -class SpillWriterBase { - public: - using WriteCb = std::function; - - SpillWriterBase( - uint64_t writeBufferSize, - uint64_t targetFileSize, - const std::string& pathPrefix, - const std::string& fileCreateConfig, - common::UpdateAndCheckSpillLimitCB& updateAndCheckSpillLimitCb, - memory::MemoryPool* pool, - folly::Synchronized* stats); - - virtual ~SpillWriterBase() = default; - - /// Finishes the current file writing. - void finishFile(); - - /// Finishes this file writer. Further writes to this spill writer are not - /// allowed after this call. - SpillFiles finish(); - - uint64_t numFinishedFiles() const { - return finishedFiles_.size(); - } - - protected: - virtual void flushBuffer( - SpillWriteFile* file, - uint64_t& writtenBytes, - uint64_t& flushTimeNs, - uint64_t& writeTimeNs) = 0; - - virtual bool bufferEmpty() const = 0; - - virtual uint64_t bufferSize() const = 0; - - virtual void addFinishedFile(SpillWriteFile* file) = 0; - - // Write wrapper with buffer control. Derived class needs to implement - // 'writeCb' that needs to only write data to buffer, without taking care of - // the buffer limit. 'args' are the arguments to be passed to 'writeCb'. - uint64_t writeWithBufferControl(const std::function& writeCb); - - // Writes data from buffer to the current output file. Returns the actual - // written size. - uint64_t flush(); - - FOLLY_ALWAYS_INLINE void checkNotFinished() const { - VELOX_CHECK(!finished_, "SpillWriter has finished"); - } - - memory::MemoryPool* const pool_; - - folly::Synchronized* const stats_; - - std::unique_ptr currentFile_; - - SpillFiles finishedFiles_; - - private: - // Returns an open spill file for write. If there is no open spill file, then - // the function creates a new one. If the current open spill file exceeds the - // target file size limit, then it first closes the current one and then - // creates a new one. 'currentFile_' points to the current open spill file. - SpillWriteFile* ensureFile(); - - // Closes the current open spill file pointed by 'currentFile_'. - void closeFile(); - - // Invoked to update the disk write stats. - void updateWriteStats( - uint64_t spilledBytes, - uint64_t flushTimeNs, - uint64_t writeTimeNs); - - // Invoked to increment the number of spilled files and the file size. - void updateSpilledFileStats(uint64_t fileSize); - - // Invoked to update the number of spilled rows. - void updateAppendStats(uint64_t numRows, uint64_t serializationTimeUs); - - // Updates the aggregated spill bytes of this query, and throws if exceeds - // the max spill bytes limit. - const common::UpdateAndCheckSpillLimitCB updateAndCheckSpillLimitCb_; - - const std::string fileCreateConfig_; - - const std::string pathPrefix_; - - const uint64_t writeBufferSize_; - - const uint64_t targetFileSize_; - - uint64_t nextFileId_{0}; - - bool finished_{false}; -}; - -/// If data is sorted, each file is sorted. The globally sorted order is -/// produced by merging the constituent files. -class SpillWriter : public SpillWriterBase { +/// Used to write the spilled data to a sequence of files for one partition. If +/// data is sorted, each file is sorted. The globally sorted order is produced +/// by merging the constituent files. +class SpillWriter : public serializer::SerializedPageFileWriter { public: /// 'type' is a RowType describing the content. 'numSortKeys' is the number /// of leading columns on which the data is sorted. 'path' is a file path @@ -221,49 +71,43 @@ class SpillWriter : public SpillWriterBase { uint64_t targetFileSize, uint64_t writeBufferSize, const std::string& fileCreateConfig, - common::UpdateAndCheckSpillLimitCB& updateAndCheckSpillLimitCb, + const common::UpdateAndCheckSpillLimitCB& updateAndCheckSpillLimitCb, memory::MemoryPool* pool, folly::Synchronized* stats); - /// Adds 'rows' for the positions in 'indices' into 'this'. The indices - /// must produce a view where the rows are sorted if sorting is desired. - /// Consecutive calls must have sorted data so that the first row of the - /// next call is not less than the last row of the previous call. - /// Returns the size to write. - uint64_t write( - const RowVectorPtr& rows, - const folly::Range& indices); + /// Finishes this file writer and returns the written spill files info. + /// + /// NOTE: we don't allow write to a spill writer after finish + SpillFiles finish(); std::vector testingSpilledFilePaths() const; std::vector testingSpilledFileIds() const; private: - bool bufferEmpty() const override { - return batch_ == nullptr; - } - - uint64_t bufferSize() const override { - return batch_->size(); - } + // Invoked to increment the number of spilled files and the file size. + void updateFileStats( + const serializer::SerializedPageFile::FileInfo& fileInfo) override; - void flushBuffer( - SpillWriteFile* file, - uint64_t& writtenBytes, - uint64_t& flushTimeNs, - uint64_t& writeTimeNs) override; + // Invoked to update the number of spilled rows. + void updateAppendStats(uint64_t numRows, uint64_t serializationTimeUs) + override; - void addFinishedFile(SpillWriteFile* file) override; + // Invoked to update the disk write stats. + void updateWriteStats( + uint64_t spilledBytes, + uint64_t flushTimeUs, + uint64_t writeTimeUs) override; const RowTypePtr type_; const std::vector sortingKeys_; - const common::CompressionKind compressionKind_; - - VectorSerde* const serde_; + folly::Synchronized* const stats_; - std::unique_ptr batch_; + // Updates the aggregated bytes of this query, and throws if exceeds + // the max bytes limit. + const common::UpdateAndCheckSpillLimitCB updateAndCheckLimitCb_; }; /// Represents a spill file for read which turns the serialized spilled data @@ -273,7 +117,7 @@ class SpillWriter : public SpillWriterBase { /// needs to remove the unused spill files at some point later. For example, a /// query Task deletes all the generated spill files in one operation using /// rmdir() call. -class SpillReadFile { +class SpillReadFile : public serializer::SerializedPageFileReader { public: static std::unique_ptr create( const SpillFileInfo& fileInfo, @@ -289,8 +133,6 @@ class SpillReadFile { return sortingKeys_; } - bool nextBatch(RowVectorPtr& rowVector); - /// Returns the file size in bytes. uint64_t size() const { return size_; @@ -312,24 +154,23 @@ class SpillReadFile { memory::MemoryPool* pool, folly::Synchronized* stats); - // Invoked to record spill read stats at the end of read input. - void recordSpillStats(); + // Records spill read stats at the end of read input. + void updateFinalStats() override; + + void updateSerializationTimeStats(uint64_t timeNs) override; // The spill file id which is monotonically increasing and unique for each // associated spill partition. const uint32_t id_; + const std::string path_; + // The file size in bytes. const uint64_t size_; - // The data type of spilled data. - const RowTypePtr type_; + const std::vector sortingKeys_; - const common::CompressionKind compressionKind_; - const serializer::presto::PrestoVectorSerde::PrestoOptions readOptions_; - memory::MemoryPool* const pool_; - VectorSerde* const serde_; - folly::Synchronized* const stats_; - std::unique_ptr input_; + folly::Synchronized* const stats_; }; + } // namespace facebook::velox::exec diff --git a/velox/exec/Spiller.cpp b/velox/exec/Spiller.cpp index d8dcec644e80..06675ede53aa 100644 --- a/velox/exec/Spiller.cpp +++ b/velox/exec/Spiller.cpp @@ -101,8 +101,9 @@ bool SpillerBase::fillSpillRuns(RowContainerIterator* iterator) { uint64_t totalRows{0}; for (;;) { - const auto numRows = container_->listRows( - iterator, rows.size(), RowContainer::kUnlimited, rows.data()); + // TODO: Reuse 'RowContainer::rowPointers_'. + const auto numRows = + container_->listRows(iterator, rows.size(), rows.data()); if (numRows == 0) { lastRun = true; break; @@ -154,8 +155,9 @@ void SpillerBase::runSpill(bool lastRun) { if (spillRun.rows.empty()) { continue; } - writes.push_back(memory::createAsyncMemoryReclaimTask( - [partitionId = id, this]() { return writeSpill(partitionId); })); + writes.push_back( + memory::createAsyncMemoryReclaimTask( + [partitionId = id, this]() { return writeSpill(partitionId); })); if ((writes.size() > 1) && executor_ != nullptr) { executor_->add([source = writes.back()]() { source->prepare(); }); } @@ -386,17 +388,13 @@ NoRowContainerSpiller::NoRowContainerSpiller( RowTypePtr rowType, std::optional parentId, HashBitRange bits, - const std::vector& sortingKeys, const common::SpillConfig* spillConfig, folly::Synchronized* spillStats) - : SpillerBase( - nullptr, + : NoRowContainerSpiller( std::move(rowType), - bits, - sortingKeys, - spillConfig->maxFileSize, - 0, parentId, + bits, + {}, spillConfig, spillStats) {} @@ -404,13 +402,17 @@ NoRowContainerSpiller::NoRowContainerSpiller( RowTypePtr rowType, std::optional parentId, HashBitRange bits, + const std::vector& sortingKeys, const common::SpillConfig* spillConfig, folly::Synchronized* spillStats) - : NoRowContainerSpiller( + : SpillerBase( + nullptr, std::move(rowType), - parentId, bits, - {}, + sortingKeys, + spillConfig->maxFileSize, + 0, + parentId, spillConfig, spillStats) {} diff --git a/velox/exec/Spiller.h b/velox/exec/Spiller.h index 5fa5a913fa7c..9151a7eaf53f 100644 --- a/velox/exec/Spiller.h +++ b/velox/exec/Spiller.h @@ -202,14 +202,6 @@ class NoRowContainerSpiller : public SpillerBase { public: static constexpr std::string_view kType = "NoRowContainerSpiller"; - NoRowContainerSpiller( - RowTypePtr rowType, - std::optional parentId, - HashBitRange bits, - const std::vector& sortingKeys, - const common::SpillConfig* spillConfig, - folly::Synchronized* spillStats); - NoRowContainerSpiller( RowTypePtr rowType, std::optional parentId, @@ -227,6 +219,15 @@ class NoRowContainerSpiller : public SpillerBase { } } + protected: + NoRowContainerSpiller( + RowTypePtr rowType, + std::optional parentId, + HashBitRange bits, + const std::vector& sortingKeys, + const common::SpillConfig* spillConfig, + folly::Synchronized* spillStats); + private: std::string type() const override { return std::string(kType); @@ -237,6 +238,24 @@ class NoRowContainerSpiller : public SpillerBase { } }; +class MergeSpiller final : public NoRowContainerSpiller { + public: + MergeSpiller( + RowTypePtr rowType, + std::optional parentId, + HashBitRange bits, + const std::vector& sortingKeys, + const common::SpillConfig* spillConfig, + folly::Synchronized* spillStats) + : NoRowContainerSpiller( + std::move(rowType), + parentId, + bits, + sortingKeys, + spillConfig, + spillStats) {} +}; + class SortInputSpiller : public SpillerBase { public: static constexpr std::string_view kType = "SortInputSpiller"; diff --git a/velox/exec/StreamingAggregation.cpp b/velox/exec/StreamingAggregation.cpp index 241e8c584778..9cb071ebaadf 100644 --- a/velox/exec/StreamingAggregation.cpp +++ b/velox/exec/StreamingAggregation.cpp @@ -41,8 +41,11 @@ StreamingAggregation::StreamingAggregation( ->queryConfig() .streamingAggregationMinOutputBatchRows()) : maxOutputBatchSize_}, + maxOutputBatchBytes_{ + operatorCtx_->driverCtx()->queryConfig().preferredOutputBatchBytes()}, aggregationNode_{aggregationNode}, - step_{aggregationNode->step()} { + step_{aggregationNode->step()}, + noGroupsSpanBatches_{aggregationNode_->noGroupsSpanBatches()} { if (aggregationNode_->ignoreNullKeys()) { VELOX_UNSUPPORTED( "Streaming aggregation doesn't support ignoring null keys yet"); @@ -200,12 +203,13 @@ RowVectorPtr StreamingAggregation::createOutput(size_t numGroups) { return output; } -void StreamingAggregation::assignGroups() { +bool StreamingAggregation::assignGroups() { const auto numInput = input_->size(); VELOX_CHECK_GT(numInput, 0); inputGroups_.resize(numInput); + bool prevGroupAssigned{false}; // Look for the end of the last group. vector_size_t index = 0; if (prevInput_ != nullptr) { @@ -213,6 +217,7 @@ void StreamingAggregation::assignGroups() { auto* prevGroup = groups_[numGroups_ - 1]; for (; index < numInput; ++index) { if (equalKeys(groupingKeys_, prevInput_, prevIndex, input_, index)) { + prevGroupAssigned = true; inputGroups_[index] = prevGroup; } else { break; @@ -246,6 +251,7 @@ void StreamingAggregation::assignGroups() { } } groupBoundaries_.push_back(numInput); + return prevGroupAssigned; } const SelectivityVector& StreamingAggregation::getSelectivityVector( @@ -280,7 +286,8 @@ void StreamingAggregation::evaluateAggregates() { std::vector args; for (auto j = 0; j < inputs.size(); ++j) { if (inputs[j] == kConstantChannel) { - args.push_back(constantInputs[j]); + args.push_back( + BaseVector::wrapInConstant(input_->size(), 0, constantInputs[j])); } else { args.push_back(input_->childAt(inputs[j])); } @@ -326,13 +333,20 @@ bool StreamingAggregation::isFinished() { RowVectorPtr StreamingAggregation::getOutput() { if (!input_) { + SCOPE_EXIT { + outputFirstGroup_ = false; + }; if ((noMoreInput_ || isDraining()) && numGroups_ > 0) { - return createOutput( - std::min(numGroups_, static_cast(maxOutputBatchSize_))); + return createOutput(numGroups_); + } + if (outputFirstGroup_) { + VELOX_CHECK_GT(numGroups_, 1); + return createOutput(1); } maybeFinishDrain(); return nullptr; } + VELOX_CHECK(!outputFirstGroup_); const auto numInput = input_->size(); inputRows_.resize(numInput); @@ -341,17 +355,44 @@ RowVectorPtr StreamingAggregation::getOutput() { masks_->addInput(input_, inputRows_); const auto numPrevGroups = numGroups_; - - assignGroups(); + const bool prevGroupAssigned = assignGroups(); initializeNewGroups(numPrevGroups); evaluateAggregates(); + const auto estimatedRowBytes = rows_->estimateRowSize(); + const auto estimatedBatchBytes = + estimatedRowBytes.value_or(0) * rows_->numRows(); + RowVectorPtr output; - if (numGroups_ > minOutputBatchSize_) { - output = createOutput( - std::min(numGroups_ - 1, static_cast(maxOutputBatchSize_))); + + const bool outputDueToBatchSize = numGroups_ > minOutputBatchSize_; + const bool outputDueToBatchBytes = + numGroups_ > 1 && estimatedBatchBytes > maxOutputBatchBytes_; + if ((noGroupsSpanBatches_ || numPrevGroups > 0) && + (outputDueToBatchSize || outputDueToBatchBytes)) { + size_t numOutputGroups{0}; + if (noGroupsSpanBatches_) { + numOutputGroups = numGroups_; + } else { + // NOTE: we only want to apply the single group output optimization if + // 'minOutputBatchSize_' is set to one for eagerly streaming output + // producing. + if (!prevGroupAssigned || numPrevGroups == 1 || + minOutputBatchSize_ != 1) { + numOutputGroups = std::min(numGroups_ - 1, numPrevGroups); + } else { + numOutputGroups = std::min(numGroups_ - 1, numPrevGroups - 1); + outputFirstGroup_ = (numGroups_ - numOutputGroups) > 1; + } + } + VELOX_CHECK_GT(numOutputGroups, 0); + output = createOutput(numOutputGroups); } prevInput_ = input_; + if (numGroups_ == 0) { + VELOX_CHECK(noGroupsSpanBatches_); + prevInput_ = nullptr; + } input_ = nullptr; return output; } @@ -384,6 +425,7 @@ std::unique_ptr StreamingAggregation::makeRowContainer( false, false, false, + false, pool()); } diff --git a/velox/exec/StreamingAggregation.h b/velox/exec/StreamingAggregation.h index 951cab1a4105..5fcf5d4b7294 100644 --- a/velox/exec/StreamingAggregation.h +++ b/velox/exec/StreamingAggregation.h @@ -40,7 +40,9 @@ class StreamingAggregation : public Operator { RowVectorPtr getOutput() override; bool needsInput() const override { - return true; + // We don't need input if the first group is ready to output which has mixed + // input sources across streaming input batches. + return !outputFirstGroup_; } bool startDrain() override; @@ -71,8 +73,9 @@ class StreamingAggregation : public Operator { RowVectorPtr createOutput(size_t numGroups); // Assign input rows to groups based on values of the grouping keys. Store the - // assignments in inputGroups_. - void assignGroups(); + // assignments in inputGroups_. Returns true if there is input rows have been + // assigned to the previously last group. + bool assignGroups(); // Add input data to accumulators. void evaluateAggregates(); @@ -87,17 +90,26 @@ class StreamingAggregation : public Operator { // Initialize the aggregations setting allocator and offsets. void initializeAggregates(uint32_t numKeys); - /// Maximum number of rows in the output batch. + // Maximum number of rows in the output batch. const vector_size_t maxOutputBatchSize_; - /// Maximum number of rows in the output batch. + // Maximum number of rows in the output batch. const vector_size_t minOutputBatchSize_; + // If the size of the data in the RowContainer exceeds this value, we will + // output a batch regardless of the number of rows. + const uint64_t maxOutputBatchBytes_; + // Used at initialize() and gets reset() afterward. std::shared_ptr aggregationNode_; const core::AggregationNode::Step step_; + // When true, indicates that no sort group spans across input batches. Each + // input batch contains complete data for its groups. This allows the + // streaming aggregation operator to produce all group results for each input. + const bool noGroupsSpanBatches_; + std::vector groupingKeys_; std::vector aggregates_; std::unique_ptr sortedAggregations_; @@ -119,6 +131,17 @@ class StreamingAggregation : public Operator { // remaining entries are re-usable. size_t numGroups_{0}; + // If true, we want to output the first group which has inputs across + // different batches. Hence the next output could only contain the input from + // a single streaming input batch. This is used to help avoid data copy in + // streaming aggregation function processing which is only applicable if all + // the sources are from the same input batch. + // + // NOTE: the streaming aggregation operator must have at-least more than one + // groups in this case. Also we only enable this optimization if + // 'minOutputBatchSize_' is set to one for eagerly streaming output producing. + bool outputFirstGroup_{false}; + // Reusable memory. // Pointers to groups for all input rows. diff --git a/velox/exec/SubPartitionedSortWindowBuild.cpp b/velox/exec/SubPartitionedSortWindowBuild.cpp new file mode 100644 index 000000000000..bb380abeaad0 --- /dev/null +++ b/velox/exec/SubPartitionedSortWindowBuild.cpp @@ -0,0 +1,192 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/SubPartitionedSortWindowBuild.h" +#include "velox/exec/MemoryReclaimer.h" + +namespace facebook::velox::exec { + +SubPartitionedSortWindowBuild::SubPartitionedSortWindowBuild( + const std::shared_ptr& node, + int32_t numSubPartitions, + velox::memory::MemoryPool* pool, + common::PrefixSortConfig&& prefixSortConfig, + const common::SpillConfig* spillConfig, + tsan_atomic* nonReclaimableSection, + folly::Synchronized* opStats, + folly::Synchronized* spillStats) + : WindowBuild(node, pool, spillConfig, nonReclaimableSection), + numSubPartitions_(numSubPartitions), + numPartitionKeys_{node->partitionKeys().size()}, + pool_(pool), + spillStats_(spillStats) { + VELOX_CHECK_NOT_NULL(pool_); + data_.reset(); + + std::vector keyChannels(numPartitionKeys_); + for (int i = 0; i < numPartitionKeys_; i++) { + keyChannels[i] = inputChannels_[i]; + } + subPartitioningFunction_ = std::make_unique( + false, numSubPartitions_, node->inputType(), keyChannels); + subWindowBuilds_.resize(numSubPartitions_); + for (int i = 0; i < numSubPartitions_; i++) { + subWindowBuilds_[i] = std::make_unique( + node, + pool, + common::PrefixSortConfig(prefixSortConfig), + spillConfig, + nonReclaimableSection, + opStats, + spillStats); + } +} + +void SubPartitionedSortWindowBuild::addInput(RowVectorPtr input) { + VELOX_CHECK_LT(currentSubPartition_, 0); + + subPartitionIdsBuffer_.resize(input->size()); + subPartitioningFunction_->partition(*input, subPartitionIdsBuffer_); + + for (auto i = 0; i < inputChannels_.size(); ++i) { + decodedInputVectors_[i].decode(*input->childAt(inputChannels_[i])); + } + + ensureInputFits(input); + + for (auto row = 0; row < input->size(); ++row) { + auto& windowBuild = subWindowBuilds_[subPartitionIdsBuffer_[row]]; + windowBuild->addDecodedInputRow(decodedInputVectors_, row); + } + + numRows_ += input->size(); +} + +bool SubPartitionedSortWindowBuild::switchToNextSubPartition() { + if (currentSubPartition_ >= numSubPartitions_) { + return false; + } + + if (currentSubPartition_ >= 0) { + subWindowBuilds_[currentSubPartition_].reset(); + } + currentSubPartition_++; + if (currentSubPartition_ >= numSubPartitions_) { + return false; + } + + VELOX_CHECK_NOT_NULL(subWindowBuilds_[currentSubPartition_]); + // WindowBuild starts processing the partitions when 'noMoreInput' is called, + // which allocates additional memory. We want to defer the memory allocation + // as late as possible to reduce memory usage, so we don't call 'noMoreInput' + // until the sub partition's data is to be consumed. + subWindowBuilds_[currentSubPartition_]->noMoreInput(); + return true; +} + +void SubPartitionedSortWindowBuild::ensureInputFits(const RowVectorPtr& input) { + if (spillConfig_ == nullptr) { + // Spilling is disabled. + return; + } + + if (numRows_ == 0) { + // Nothing to spill. + return; + } + + // Test-only spill path. + if (testingTriggerSpill(pool_->name())) { + spill(); + return; + } + + VELOX_CHECK_LT(currentSubPartition_, 0); + for (auto& windowBuild : subWindowBuilds_) { + windowBuild->ensureInputFits(input); + } +} + +void SubPartitionedSortWindowBuild::spill() { + VELOX_CHECK_LT(currentSubPartition_, 0); + for (auto& windowBuild : subWindowBuilds_) { + windowBuild->spill(); + } + spilled_ = true; +} + +std::optional SubPartitionedSortWindowBuild::spilledStats() + const { + if (!spilled_) { + return std::nullopt; + } + return {spillStats_->copy()}; +} + +void SubPartitionedSortWindowBuild::noMoreInput() { + if (numRows_ == 0) { + return; + } + + if (spilled_) { + // Spill remaining data to avoid running out of memory while sort-merging + // spilled data. + spill(); + } + + switchToNextSubPartition(); + + VELOX_CHECK_EQ(currentSubPartition_, 0); +} + +std::shared_ptr +SubPartitionedSortWindowBuild::nextPartition() { + VELOX_CHECK_GE(currentSubPartition_, 0); + VELOX_CHECK_LT(currentSubPartition_, numSubPartitions_); + VELOX_CHECK_NOT_NULL(subWindowBuilds_[currentSubPartition_]); + return subWindowBuilds_[currentSubPartition_]->nextPartition(); +} + +std::optional SubPartitionedSortWindowBuild::estimateRowSize() { + auto subPartition = std::max(currentSubPartition_, 0); + if (subPartition >= numSubPartitions_) { + return std::nullopt; + } + + if (subWindowBuilds_[subPartition]) { + return subWindowBuilds_[subPartition]->estimateRowSize(); + } + + return std::nullopt; +} + +bool SubPartitionedSortWindowBuild::hasNextPartition() { + // Check if the build hasn't begun or has finished. + if (currentSubPartition_ < 0 || currentSubPartition_ >= numSubPartitions_) { + return false; + } + + VELOX_CHECK_NOT_NULL(subWindowBuilds_[currentSubPartition_]); + if (subWindowBuilds_[currentSubPartition_]->hasNextPartition()) { + return true; + } + + if (switchToNextSubPartition()) { + return hasNextPartition(); + } + return false; +} +} // namespace facebook::velox::exec diff --git a/velox/exec/SubPartitionedSortWindowBuild.h b/velox/exec/SubPartitionedSortWindowBuild.h new file mode 100644 index 000000000000..667db5e2d70b --- /dev/null +++ b/velox/exec/SubPartitionedSortWindowBuild.h @@ -0,0 +1,95 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "velox/exec/HashPartitionFunction.h" +#include "velox/exec/PrefixSort.h" +#include "velox/exec/SortWindowBuild.h" +#include "velox/exec/Spiller.h" + +namespace facebook::velox::exec { +// Divides the input data into several sub partitions by partition keys, then +// sequentially sorts input data of each sub partition by {partition keys, sort +// keys} to identify window partitions with SortWindowBuild. As each sub +// partition has a smaller working set, the memory used by sorting is reduced. +// Besides, once a sub partition is completely consumed, its memory could be +// released immediately. +class SubPartitionedSortWindowBuild : public WindowBuild { + public: + SubPartitionedSortWindowBuild( + const std::shared_ptr& node, + int32_t numSubPartitions, + velox::memory::MemoryPool* pool, + common::PrefixSortConfig&& prefixSortConfig, + const common::SpillConfig* spillConfig, + tsan_atomic* nonReclaimableSection, + folly::Synchronized* opStats, + folly::Synchronized* spillStats); + + ~SubPartitionedSortWindowBuild() override { + pool_->release(); + } + + bool needsInput() override { + // No sub partitions are available yet, so can consume input rows. + return currentSubPartition_ < 0; + } + + void addInput(RowVectorPtr input) override; + + void spill() override; + + std::optional spilledStats() const override; + + void noMoreInput() override; + + bool hasNextPartition() override; + + std::shared_ptr nextPartition() override; + + std::optional estimateRowSize() override; + + private: + // The current sub partition's WindowBuild has finished producing all the + // data. Release all the memory of current sub partition's WindowBuild, and + // then switch to next sub partition's WindowBuild as the new current one. + bool switchToNextSubPartition(); + + void ensureInputFits(const RowVectorPtr& input); + + const int32_t numSubPartitions_; + + const size_t numPartitionKeys_; + + memory::MemoryPool* const pool_; + + folly::Synchronized* const spillStats_; + + // Divide input rows to the corresponding sub partitions. + std::unique_ptr subPartitioningFunction_; + + // WindowBuilds for each sub partition. + std::vector> subWindowBuilds_; + + bool spilled_{false}; + + // Buffers the subPartitionIds for each row. Reused across addInput calls. + std::vector subPartitionIdsBuffer_; + + int32_t currentSubPartition_ = -1; +}; +} // namespace facebook::velox::exec diff --git a/velox/exec/TableScan.cpp b/velox/exec/TableScan.cpp index e5f4c6f76e8a..6c19c1b89510 100644 --- a/velox/exec/TableScan.cpp +++ b/velox/exec/TableScan.cpp @@ -17,13 +17,56 @@ #include "velox/common/testutil/TestValue.h" #include "velox/common/time/Timer.h" #include "velox/exec/Task.h" -#include "velox/exec/TraceUtil.h" -#include "velox/expression/Expr.h" using facebook::velox::common::testutil::TestValue; namespace facebook::velox::exec { +namespace { + +std::unique_ptr createDataSource( + folly::Synchronized& pushdownFilters, + connector::Connector& connector, + const RowTypePtr& outputType, + const connector::ConnectorTableHandlePtr& tableHandle, + const connector::ColumnHandleMap& columnHandles, + connector::ConnectorQueryCtx* connectorQueryCtx) { + auto dataSource = connector.createDataSource( + outputType, tableHandle, columnHandles, connectorQueryCtx); + auto* staticFilters = dataSource->getFilters(); + if (!staticFilters) { + VELOX_CHECK(!connector.canAddDynamicFilter()); + return dataSource; + } + { + auto lk = pushdownFilters.rlock(); + if (lk->staticFiltersInitialized) { + for (auto outIndex : lk->dynamicFilteredColumns) { + dataSource->addDynamicFilter(outIndex, lk->filters.at(outIndex)); + } + return dataSource; + } + } + auto lk = pushdownFilters.wlock(); + if (!lk->staticFiltersInitialized) { + for (column_index_t i = 0, size = outputType->size(); i < size; ++i) { + auto handle = columnHandles.find(outputType->nameOf(i)); + VELOX_CHECK(handle != columnHandles.end()); + auto field = common::Subfield::create(handle->second->name()); + if (auto it = staticFilters->find(*field); it != staticFilters->end()) { + common::Filter::merge(it->second, lk->filters[i]); + } + } + lk->staticFiltersInitialized = true; + } + for (auto outIndex : lk->dynamicFilteredColumns) { + dataSource->addDynamicFilter(outIndex, lk->filters.at(outIndex)); + } + return dataSource; +} + +} // namespace + TableScan::TableScan( int32_t operatorId, DriverCtx* driverCtx, @@ -55,6 +98,11 @@ TableScan::TableScan( readBatchSize_ = driverCtx_->queryConfig().preferredOutputBatchRows(); } +void TableScan::initialize() { + SourceOperator::initialize(); + VELOX_CHECK_EQ(driverCtx_->driver->operatorIndex(this), 0); +} + bool TableScan::shouldYield(StopReason taskStopReason, size_t startTimeMs) const { // Checks task-level yield signal, driver-level yield signal and table scan @@ -132,15 +180,16 @@ RowVectorPtr TableScan::getOutput() { } continue; } - const auto estimatedRowSize = dataSource_->estimatedRowSize(); - readBatchSize_ = - estimatedRowSize == connector::DataSource::kUnknownRowSize - ? outputBatchRows() - : outputBatchRows(estimatedRowSize); } VELOX_CHECK(!needNewSplit_); VELOX_CHECK(!hasDrained()); + const auto estimatedRowSize = dataSource_->estimatedRowSize(); + // TODO: Expose this to operator stats. + VLOG(1) << "estimatedRowSize = " << estimatedRowSize; + readBatchSize_ = estimatedRowSize == connector::DataSource::kUnknownRowSize + ? outputBatchRows() + : outputBatchRows(estimatedRowSize); int32_t readBatchSize = readBatchSize_; if (maxFilteringRatio_ > 0) { readBatchSize = std::min( @@ -151,6 +200,7 @@ RowVectorPtr TableScan::getOutput() { std::optional dataOptional; { MicrosecondTimer timer(&ioTimeUs); + auto lk = driverCtx_->driver->pushdownFilters()->at(0).rlock(); dataOptional = dataSource_->next(readBatchSize, blockingFuture_); } @@ -174,20 +224,23 @@ RowVectorPtr TableScan::getOutput() { // at least read one batch from a split to trigger split fetch inside Meta // internal data source connector. if (data != nullptr && !shouldDropOutput()) { + constexpr int kMaxSelectiveBatchSizeMultiplier = 4; if (data->size() > 0) { lockedStats->addInputVector(data->estimateFlatSize(), data->size()); - constexpr int kMaxSelectiveBatchSizeMultiplier = 4; maxFilteringRatio_ = std::max( {maxFilteringRatio_, 1.0 * data->size() / readBatchSize, 1.0 / kMaxSelectiveBatchSizeMultiplier}); if (ioTimeUs > 0) { - RECORD_HISTOGRAM_METRIC_VALUE( + RECORD_METRIC_VALUE( velox::kMetricTableScanBatchProcessTimeMs, ioTimeUs / 1'000); } - RECORD_HISTOGRAM_METRIC_VALUE( + RECORD_METRIC_VALUE( velox::kMetricTableScanBatchBytes, data->estimateFlatSize()); return data; + } else { + maxFilteringRatio_ = std::max( + maxFilteringRatio_, 1.0 / kMaxSelectiveBatchSizeMultiplier); } continue; } @@ -250,17 +303,16 @@ bool TableScan::getSplit() { if (!split.hasConnectorSplit()) { noMoreSplits_ = true; - dynamicFilters_.clear(); if (dataSource_) { - const auto connectorStats = dataSource_->runtimeStats(); + const auto connectorStats = dataSource_->getRuntimeStats(); auto lockedStats = stats_.wlock(); - for (const auto& [name, counter] : connectorStats) { + for (const auto& [name, metric] : connectorStats) { if (FOLLY_UNLIKELY(lockedStats->runtimeStats.count(name) == 0)) { - lockedStats->runtimeStats.emplace(name, RuntimeMetric(counter.unit)); + lockedStats->runtimeStats.emplace(name, RuntimeMetric(metric.unit)); } else { - VELOX_CHECK_EQ(lockedStats->runtimeStats.at(name).unit, counter.unit); + VELOX_CHECK_EQ(lockedStats->runtimeStats.at(name).unit, metric.unit); } - lockedStats->runtimeStats.at(name).addValue(counter.value); + lockedStats->runtimeStats.at(name).merge(metric); } } return false; @@ -269,6 +321,9 @@ bool TableScan::getSplit() { if (FOLLY_UNLIKELY(splitTracer_ != nullptr)) { splitTracer_->write(split); } + + stats_.wlock()->addRuntimeStat( + "connectorSplitSize", RuntimeCounter(split.connectorSplit->size())); const auto& connectorSplit = split.connectorSplit; currentSplitWeight_ = connectorSplit->splitWeight; needNewSplit_ = false; @@ -285,11 +340,13 @@ bool TableScan::getSplit() { if (dataSource_ == nullptr) { connectorQueryCtx_ = operatorCtx_->createConnectorQueryCtx( connectorSplit->connectorId, planNodeId(), connectorPool_); - dataSource_ = connector_->createDataSource( - outputType_, tableHandle_, columnHandles_, connectorQueryCtx_.get()); - for (const auto& entry : dynamicFilters_) { - dataSource_->addDynamicFilter(entry.first, entry.second); - } + dataSource_ = createDataSource( + driverCtx_->driver->pushdownFilters()->at(0), + *connector_, + outputType_, + tableHandle_, + columnHandles_, + connectorQueryCtx_.get()); } debugString_ = fmt::format( @@ -308,9 +365,17 @@ bool TableScan::getSplit() { // The AsyncSource returns a unique_ptr to a shared_ptr. The unique_ptr // will be nullptr if there was a cancellation. numReadyPreloadedSplits_ += connectorSplit->dataSource->hasValue(); + auto startTimeNs = getCurrentTimeNano(); auto preparedDataSource = connectorSplit->dataSource->move(); - stats_.wlock()->getOutputTiming.add( - connectorSplit->dataSource->prepareTiming()); + auto endTimeNs = getCurrentTimeNano(); + stats_.wlock()->addRuntimeStat( + "waitForPreloadSplitNanos", + RuntimeCounter(endTimeNs - startTimeNs, RuntimeCounter::Unit::kNanos)); + stats_.wlock()->addRuntimeStat( + "preloadSplitPrepareTimeNanos", + RuntimeCounter( + connectorSplit->dataSource->prepareTiming().wallNanos, + RuntimeCounter::Unit::kNanos)); if (!preparedDataSource) { // There must be a cancellation. VELOX_CHECK(operatorCtx_->task()->isCancelled()); @@ -321,6 +386,7 @@ bool TableScan::getSplit() { uint64_t addSplitTimeUs{0}; { MicrosecondTimer timer(&addSplitTimeUs); + auto lk = driverCtx_->driver->pushdownFilters()->at(0).rlock(); dataSource_->addSplit(connectorSplit); } stats_.wlock()->addRuntimeStat( @@ -369,7 +435,7 @@ void TableScan::preload( ctx = operatorCtx_->createConnectorQueryCtx( split->connectorId, planNodeId(), connectorPool_), task = operatorCtx_->task(), - dynamicFilters = dynamicFilters_, + pushdownFilters = driverCtx_->driver->pushdownFilters(), split]() -> std::unique_ptr { if (task->isCancelled()) { return nullptr; @@ -382,22 +448,27 @@ void TableScan::preload( }, &debugString}); - auto dataSource = - connector->createDataSource(type, table, columns, ctx.get()); + auto dataSource = createDataSource( + pushdownFilters->at(0), + *connector, + type, + table, + columns, + ctx.get()); if (task->isCancelled()) { return nullptr; } - for (const auto& entry : dynamicFilters) { - dataSource->addDynamicFilter(entry.first, entry.second); + { + auto lk = pushdownFilters->at(0).rlock(); + dataSource->addSplit(split); } - dataSource->addSplit(split); return dataSource; }); } void TableScan::checkPreload() { - auto* executor = connector_->executor(); - if (maxSplitPreloadPerDriver_ == 0 || !executor || + auto* ioExecutor = connector_->ioExecutor(); + if (maxSplitPreloadPerDriver_ == 0 || !ioExecutor || !connector_->supportsSplitPreload()) { return; } @@ -406,11 +477,11 @@ void TableScan::checkPreload() { maxSplitPreloadPerDriver_; if (!splitPreloader_) { splitPreloader_ = - [executor, + [ioExecutor, this](const std::shared_ptr& split) { preload(split); - executor->add([connectorSplit = split]() mutable { + ioExecutor->add([connectorSplit = split]() mutable { connectorSplit->dataSource->prepare(); connectorSplit.reset(); }); @@ -423,18 +494,13 @@ bool TableScan::isFinished() { return noMoreSplits_; } -void TableScan::addDynamicFilter( +void TableScan::addDynamicFilterLocked( const core::PlanNodeId& producer, - column_index_t outputChannel, - const std::shared_ptr& filter) { + const PushdownFilters& filters) { if (dataSource_) { - dataSource_->addDynamicFilter(outputChannel, filter); - } - auto& currentFilter = dynamicFilters_[outputChannel]; - if (currentFilter) { - currentFilter = currentFilter->mergeWith(filter.get()); - } else { - currentFilter = filter; + for (auto channel : filters.dynamicFilteredColumns) { + dataSource_->addDynamicFilter(channel, filters.filters.at(channel)); + } } stats_.wlock()->dynamicFilterStats.producerNodeIds.emplace(producer); } diff --git a/velox/exec/TableScan.h b/velox/exec/TableScan.h index daebdd85cd1e..69bff9ddc457 100644 --- a/velox/exec/TableScan.h +++ b/velox/exec/TableScan.h @@ -28,6 +28,8 @@ class TableScan : public SourceOperator { DriverCtx* driverCtx, const std::shared_ptr& tableScanNode); + void initialize() override; + RowVectorPtr getOutput() override; BlockingReason isBlocked(ContinueFuture* future) override { @@ -50,10 +52,9 @@ class TableScan : public SourceOperator { return connector_->canAddDynamicFilter(); } - void addDynamicFilter( + void addDynamicFilterLocked( const core::PlanNodeId& producer, - column_index_t outputChannel, - const std::shared_ptr& filter) override; + const PushdownFilters& filters) override; /// The name of runtime stats specific to table scan. /// The number of running table scan drivers. @@ -98,10 +99,8 @@ class TableScan : public SourceOperator { // processing or not. void tryScaleUp(); - const std::shared_ptr tableHandle_; - const std:: - unordered_map> - columnHandles_; + const connector::ConnectorTableHandlePtr tableHandle_; + const connector::ColumnHandleMap columnHandles_; DriverCtx* const driverCtx_; const int32_t maxSplitPreloadPerDriver_{0}; const vector_size_t maxReadBatchSize_; @@ -124,9 +123,6 @@ class TableScan : public SourceOperator { std::shared_ptr connectorQueryCtx_; std::unique_ptr dataSource_; bool noMoreSplits_ = false; - // Dynamic filters to add to the data source when it gets created. - std::unordered_map> - dynamicFilters_; int32_t maxPreloadedSplits_{0}; diff --git a/velox/exec/TableWriteMerge.cpp b/velox/exec/TableWriteMerge.cpp index 69528e9338e3..cf64f84405ff 100644 --- a/velox/exec/TableWriteMerge.cpp +++ b/velox/exec/TableWriteMerge.cpp @@ -22,6 +22,7 @@ namespace facebook::velox::exec { namespace { + bool isSameCommitContext( const folly::dynamic& first, const folly::dynamic& second) { @@ -44,6 +45,7 @@ bool containsNonNullRows(const VectorPtr& vector) { } return false; } + } // namespace TableWriteMerge::TableWriteMerge( @@ -56,18 +58,26 @@ TableWriteMerge::TableWriteMerge( operatorId, tableWriteMergeNode->id(), "TableWriteMerge") { - VELOX_USER_CHECK(outputType_->equivalent( - *TableWriteTraits::outputType(tableWriteMergeNode->aggregationNode()))); - if (tableWriteMergeNode->aggregationNode() != nullptr) { - aggregation_ = std::make_unique( - operatorId, driverCtx, tableWriteMergeNode->aggregationNode()); + if (tableWriteMergeNode->outputType()->size() == 1) { + VELOX_USER_CHECK(!tableWriteMergeNode->hasColumnStatsSpec()); + } else { + VELOX_USER_CHECK(tableWriteMergeNode->outputType()->equivalent(*( + TableWriteTraits::outputType(tableWriteMergeNode->columnStatsSpec())))); + } + if (tableWriteMergeNode->hasColumnStatsSpec()) { + statsCollector_ = std::make_unique( + tableWriteMergeNode->columnStatsSpec().value(), + tableWriteMergeNode->sources()[0]->outputType(), + &operatorCtx_->driverCtx()->queryConfig(), + operatorCtx_->pool(), + &nonReclaimableSection_); } } void TableWriteMerge::initialize() { Operator::initialize(); - if (aggregation_ != nullptr) { - aggregation_->initialize(); + if (statsCollector_ != nullptr) { + statsCollector_->initialize(); } } @@ -76,8 +86,8 @@ void TableWriteMerge::addInput(RowVectorPtr input) { VELOX_CHECK_GT(input->size(), 0); if (isStatistics(input)) { - VELOX_CHECK_NOT_NULL(aggregation_); - aggregation_->addInput(input); + VELOX_CHECK_NOT_NULL(statsCollector_); + statsCollector_->addInput(input); return; } @@ -105,10 +115,16 @@ void TableWriteMerge::addInput(RowVectorPtr input) { void TableWriteMerge::noMoreInput() { Operator::noMoreInput(); - if (aggregation_ != nullptr) { - aggregation_->noMoreInput(); + if (statsCollector_ != nullptr) { + statsCollector_->noMoreInput(); + } +} + +void TableWriteMerge::close() { + if (statsCollector_ != nullptr) { + statsCollector_->close(); } - close(); + Operator::close(); } RowVectorPtr TableWriteMerge::getOutput() { @@ -121,11 +137,11 @@ RowVectorPtr TableWriteMerge::getOutput() { return nullptr; } - if (aggregation_ != nullptr && !aggregation_->isFinished()) { + if (statsCollector_ != nullptr && !statsCollector_->finished()) { const std::string commitContext = createTableCommitContext(false); return TableWriteTraits::createAggregationStatsOutput( outputType_, - aggregation_->getOutput(), + statsCollector_->getOutput(), StringView(commitContext), pool()); } diff --git a/velox/exec/TableWriteMerge.h b/velox/exec/TableWriteMerge.h index 9e915aa6574c..c5a0ade2ccdc 100644 --- a/velox/exec/TableWriteMerge.h +++ b/velox/exec/TableWriteMerge.h @@ -17,6 +17,7 @@ #pragma once #include "velox/core/PlanNode.h" +#include "velox/exec/ColumnStatsCollector.h" #include "velox/exec/Operator.h" namespace facebook::velox::exec { @@ -51,6 +52,8 @@ class TableWriteMerge : public Operator { return finished_; } + void close() override; + private: // Creates non-last output with fragments and last commit context only. RowVectorPtr createFragmentsOutput(); @@ -65,7 +68,7 @@ class TableWriteMerge : public Operator { // Check if the input is statistics input. bool isStatistics(RowVectorPtr input); - std::unique_ptr aggregation_; + std::unique_ptr statsCollector_; bool finished_{false}; // The sum of written rows. int64_t numRows_{0}; diff --git a/velox/exec/TableWriter.cpp b/velox/exec/TableWriter.cpp index 6ce173412350..7a202a8c4869 100644 --- a/velox/exec/TableWriter.cpp +++ b/velox/exec/TableWriter.cpp @@ -15,8 +15,6 @@ */ #include "velox/exec/TableWriter.h" - -#include "HashAggregation.h" #include "velox/exec/Task.h" namespace facebook::velox::exec { @@ -24,7 +22,7 @@ namespace facebook::velox::exec { TableWriter::TableWriter( int32_t operatorId, DriverCtx* driverCtx, - const std::shared_ptr& tableWriteNode) + const core::TableWriteNodePtr& tableWriteNode) : Operator( driverCtx, tableWriteNode->outputType(), @@ -44,18 +42,22 @@ TableWriter::TableWriter( insertTableHandle_( tableWriteNode->insertTableHandle()->connectorInsertTableHandle()), commitStrategy_(tableWriteNode->commitStrategy()), - createTimeUs_(getCurrentTimeNano()) { + createTimeNs_(getCurrentTimeNano()) { setConnectorMemoryReclaimer(); if (tableWriteNode->outputType()->size() == 1) { - VELOX_USER_CHECK_NULL(tableWriteNode->aggregationNode()); + VELOX_USER_CHECK(!tableWriteNode->columnStatsSpec().has_value()); } else { VELOX_USER_CHECK(tableWriteNode->outputType()->equivalent( - *(TableWriteTraits::outputType(tableWriteNode->aggregationNode())))); + *(TableWriteTraits::outputType(tableWriteNode->columnStatsSpec())))); } - if (tableWriteNode->aggregationNode() != nullptr) { - aggregation_ = std::make_unique( - operatorId, driverCtx, tableWriteNode->aggregationNode()); + if (tableWriteNode->columnStatsSpec().has_value()) { + statsCollector_ = std::make_unique( + tableWriteNode->columnStatsSpec().value(), + tableWriteNode->sources()[0]->outputType(), + &operatorCtx_->driverCtx()->queryConfig(), + operatorCtx_->pool(), + &nonReclaimableSection_); } const auto& connectorId = tableWriteNode->insertTableHandle()->connectorId(); connector_ = connector::getConnector(connectorId); @@ -68,7 +70,7 @@ TableWriter::TableWriter( } void TableWriter::setTypeMappings( - const std::shared_ptr& tableWriteNode) { + const core::TableWriteNodePtr& tableWriteNode) { auto outputNames = tableWriteNode->columnNames(); auto outputTypes = tableWriteNode->columns()->children(); @@ -95,8 +97,8 @@ void TableWriter::initialize() { Operator::initialize(); VELOX_CHECK_NULL(dataSink_); createDataSink(); - if (aggregation_ != nullptr) { - aggregation_->initialize(); + if (statsCollector_ != nullptr) { + statsCollector_->initialize(); } } @@ -159,15 +161,15 @@ void TableWriter::addInput(RowVectorPtr input) { numWrittenRows_ += input->size(); updateStats(dataSink_->stats()); - if (aggregation_ != nullptr) { - aggregation_->addInput(input); + if (statsCollector_ != nullptr) { + statsCollector_->addInput(input); } } void TableWriter::noMoreInput() { Operator::noMoreInput(); - if (aggregation_ != nullptr) { - aggregation_->noMoreInput(); + if (statsCollector_ != nullptr) { + statsCollector_->noMoreInput(); } } @@ -185,11 +187,11 @@ RowVectorPtr TableWriter::getOutput() { return nullptr; } - if (aggregation_ != nullptr && !aggregation_->isFinished()) { + if (statsCollector_ != nullptr && !statsCollector_->finished()) { const std::string commitContext = createTableCommitContext(false); return TableWriteTraits::createAggregationStatsOutput( outputType_, - aggregation_->getOutput(), + statsCollector_->getOutput(), StringView(commitContext), pool()); } @@ -253,11 +255,12 @@ RowVectorPtr TableWriter::getOutput() { writtenRowsVector, fragmentsVector, commitContextVector}; // 4. Set null statistics columns. - if (aggregation_ != nullptr) { + if (statsCollector_ != nullptr) { for (int i = TableWriteTraits::kStatsChannel; i < outputType_->size(); ++i) { - columns.push_back(BaseVector::createNullConstant( - outputType_->childAt(i), writtenRowsVector->size(), pool())); + columns.push_back( + BaseVector::createNullConstant( + outputType_->childAt(i), writtenRowsVector->size(), pool())); } } @@ -271,14 +274,14 @@ std::string TableWriter::createTableCommitContext(bool lastOutput) { folly::dynamic::object (TableWriteTraits::kLifeSpanContextKey, "TaskWide") (TableWriteTraits::kTaskIdContextKey, connectorQueryCtx_->taskId()) - (TableWriteTraits::kCommitStrategyContextKey, commitStrategyToString(commitStrategy_)) + (TableWriteTraits::kCommitStrategyContextKey, connector::CommitStrategyName::toName(commitStrategy_)) (TableWriteTraits::klastPageContextKey, lastOutput)); // clang-format on } void TableWriter::updateStats(const connector::DataSink::Stats& stats) { const auto currentTimeNs = getCurrentTimeNano(); - VELOX_CHECK_GE(currentTimeNs, createTimeUs_); + VELOX_CHECK_GE(currentTimeNs, createTimeNs_); { auto lockedStats = stats_.wlock(); lockedStats->physicalWrittenBytes = stats.numWrittenBytes; @@ -312,10 +315,10 @@ void TableWriter::updateStats(const connector::DataSink::Stats& stats) { lockedStats->addRuntimeStat( kRunningWallNanos, RuntimeCounter( - currentTimeNs - createTimeUs_, RuntimeCounter::Unit::kNanos)); + currentTimeNs - createTimeNs_, RuntimeCounter::Unit::kNanos)); } if (!stats.spillStats.empty()) { - *spillStats_.wlock() += stats.spillStats; + *spillStats_->wlock() += stats.spillStats; } } @@ -325,8 +328,8 @@ void TableWriter::close() { // regular close. abortDataSink(); } - if (aggregation_ != nullptr) { - aggregation_->close(); + if (statsCollector_ != nullptr) { + statsCollector_->close(); } Operator::close(); } @@ -334,8 +337,9 @@ void TableWriter::close() { void TableWriter::setConnectorMemoryReclaimer() { VELOX_CHECK_NOT_NULL(connectorPool_); if (connectorPool_->parent()->reclaimer() != nullptr) { - connectorPool_->setReclaimer(TableWriter::ConnectorReclaimer::create( - spillConfig_, operatorCtx_->driverCtx(), this)); + connectorPool_->setReclaimer( + TableWriter::ConnectorReclaimer::create( + spillConfig_, operatorCtx_->driverCtx(), this)); } } @@ -403,105 +407,4 @@ uint64_t TableWriter::ConnectorReclaimer::reclaim( return ParallelMemoryReclaimer::reclaim(pool, targetBytes, maxWaitMs, stats); } -// static -RowVectorPtr TableWriteTraits::createAggregationStatsOutput( - RowTypePtr outputType, - RowVectorPtr aggregationOutput, - StringView tableCommitContext, - velox::memory::MemoryPool* pool) { - // TODO: record aggregation stats output time. - if (aggregationOutput == nullptr) { - return nullptr; - } - VELOX_CHECK_GT(aggregationOutput->childrenSize(), 0); - const vector_size_t numOutputRows = aggregationOutput->childAt(0)->size(); - std::vector columns; - for (int channel = 0; channel < outputType->size(); channel++) { - if (channel < TableWriteTraits::kContextChannel) { - // 1. Set null rows column. - // 2. Set null fragments column. - columns.push_back(BaseVector::createNullConstant( - outputType->childAt(channel), numOutputRows, pool)); - continue; - } - if (channel == TableWriteTraits::kContextChannel) { - // 3. Set commitcontext column. - columns.push_back(std::make_shared>( - pool, - numOutputRows, - false /*isNull*/, - VARBINARY(), - std::move(tableCommitContext))); - continue; - } - // 4. Set statistics columns. - columns.push_back( - aggregationOutput->childAt(channel - TableWriteTraits::kStatsChannel)); - } - return std::make_shared( - pool, outputType, nullptr, numOutputRows, columns); -} - -std::string TableWriteTraits::rowCountColumnName() { - static const std::string kRowCountName = "rows"; - return kRowCountName; -} - -std::string TableWriteTraits::fragmentColumnName() { - static const std::string kFragmentName = "fragments"; - return kFragmentName; -} - -std::string TableWriteTraits::contextColumnName() { - static const std::string kContextName = "commitcontext"; - return kContextName; -} - -const TypePtr& TableWriteTraits::rowCountColumnType() { - static const TypePtr kRowCountType = BIGINT(); - return kRowCountType; -} - -const TypePtr& TableWriteTraits::fragmentColumnType() { - static const TypePtr kFragmentType = VARBINARY(); - return kFragmentType; -} - -const TypePtr& TableWriteTraits::contextColumnType() { - static const TypePtr kContextType = VARBINARY(); - return kContextType; -} - -const RowTypePtr TableWriteTraits::outputType( - const std::shared_ptr& aggregationNode) { - static const auto kOutputTypeWithoutStats = - ROW({rowCountColumnName(), fragmentColumnName(), contextColumnName()}, - {rowCountColumnType(), fragmentColumnType(), contextColumnType()}); - if (aggregationNode == nullptr) { - return kOutputTypeWithoutStats; - } - return kOutputTypeWithoutStats->unionWith(aggregationNode->outputType()); -} - -folly::dynamic TableWriteTraits::getTableCommitContext( - const RowVectorPtr& input) { - VELOX_CHECK_GT(input->size(), 0); - auto* contextVector = - input->childAt(kContextChannel)->as>(); - return folly::parseJson(contextVector->valueAt(input->size() - 1)); -} - -int64_t TableWriteTraits::getRowCount(const RowVectorPtr& output) { - VELOX_CHECK_GT(output->size(), 0); - auto rowCountVector = - output->childAt(kRowCountChannel)->asFlatVector(); - VELOX_CHECK_NOT_NULL(rowCountVector); - int64_t rowCount{0}; - for (int i = 0; i < output->size(); ++i) { - if (!rowCountVector->isNullAt(i)) { - rowCount += rowCountVector->valueAt(i); - } - } - return rowCount; -} } // namespace facebook::velox::exec diff --git a/velox/exec/TableWriter.h b/velox/exec/TableWriter.h index 70505cb3fb5e..4692faec3129 100644 --- a/velox/exec/TableWriter.h +++ b/velox/exec/TableWriter.h @@ -16,93 +16,20 @@ #pragma once -#include "OperatorUtils.h" #include "velox/core/PlanNode.h" +#include "velox/core/TableWriteTraits.h" +#include "velox/exec/ColumnStatsCollector.h" #include "velox/exec/MemoryReclaimer.h" #include "velox/exec/Operator.h" namespace facebook::velox::exec { -/// Defines table writer output related config properties that are shared -/// between TableWriter and TableWriteMerger. -/// -/// TODO: the table write output processing is Prestissimo specific. Consider -/// move these part logic to Prestissimo and pass to Velox through a customized -/// output processing callback. -class TableWriteTraits { - public: - /// Defines the column names/types in table write output. - static std::string rowCountColumnName(); - static std::string fragmentColumnName(); - static std::string contextColumnName(); - - static const TypePtr& rowCountColumnType(); - static const TypePtr& fragmentColumnType(); - static const TypePtr& contextColumnType(); - - /// Defines the column channels in table write output. - /// Both the statistics and the row_count + fragments are transferred over the - /// same communication link between the TableWriter and TableFinish. Thus the - /// multiplexing is needed. - /// - /// The transferred page layout looks like: - /// [row_count_channel], [fragment_channel], [context_channel], - /// [statistic_channel_1] ... [statistic_channel_N]] - /// - /// [row_count_channel] - contains number of rows processed by a TableWriter - /// [fragment_channel] - contains data provided by the DataSink#finish - /// [statistic_channel_1] ...[statistic_channel_N] - - /// contain aggregated statistics computed by the statistics aggregation - /// within the TableWriter - /// - /// For convenience, we never set both: [row_count_channel] + - /// [fragment_channel] and the [statistic_channel_1] ... - /// [statistic_channel_N]. - /// - /// If this is a row that holds statistics - the [row_count_channel] + - /// [fragment_channel] will be NULL. - /// - /// If this is a row that holds the row count - /// or the fragment - all the statistics channels will be set to NULL. - static constexpr int32_t kRowCountChannel = 0; - static constexpr int32_t kFragmentChannel = 1; - static constexpr int32_t kContextChannel = 2; - static constexpr int32_t kStatsChannel = 3; - - /// Defines the names of metadata in commit context in table writer output. - static constexpr std::string_view kLifeSpanContextKey = "lifespan"; - static constexpr std::string_view kTaskIdContextKey = "taskId"; - static constexpr std::string_view kCommitStrategyContextKey = - "pageSinkCommitStrategy"; - static constexpr std::string_view klastPageContextKey = "lastPage"; - - static const RowTypePtr outputType( - const std::shared_ptr& aggregationNode = nullptr); - - /// Returns the parsed commit context from table writer 'output'. - static folly::dynamic getTableCommitContext(const RowVectorPtr& output); - - /// Returns the sum of row counts from table writer 'output'. - static int64_t getRowCount(const RowVectorPtr& output); - - /// Creates the statistics output. - /// Statistics page layout (aggregate by partition): - /// row fragments context [partition] stats1 stats2 ... - /// null null X [X] X X - /// null null X [X] X X - static RowVectorPtr createAggregationStatsOutput( - RowTypePtr outputType, - RowVectorPtr aggregationOutput, - StringView tableCommitContext, - velox::memory::MemoryPool* pool); -}; - class TableWriter : public Operator { public: TableWriter( int32_t operatorId, DriverCtx* driverCtx, - const std::shared_ptr& tableWriteNode); + const core::TableWriteNodePtr& tableWriteNode); BlockingReason isBlocked(ContinueFuture* future) override; @@ -137,6 +64,14 @@ class TableWriter : public Operator { // the table writer operator pool. So we report the memory usage from // 'connectorPool_'. stats.memoryStats = MemoryStats::memStatsFromPool(connectorPool_); + + if (FOLLY_LIKELY(dataSink_ != nullptr)) { + const auto connectorStats = dataSink_->runtimeStats(); + for (const auto& [name, counter] : connectorStats) { + stats.runtimeStats[name] = RuntimeMetric(counter.value, counter.unit); + } + } + return stats; } @@ -214,8 +149,7 @@ class TableWriter : public Operator { // Sets type mappings in `inputMapping_`, `mappedInputType_`, and // `mappedOutputType_`. - void setTypeMappings( - const std::shared_ptr& tableWriteNode); + void setTypeMappings(const core::TableWriteNodePtr& tableWriteNode); std::string createTableCommitContext(bool lastOutput); @@ -223,15 +157,14 @@ class TableWriter : public Operator { const DriverCtx* const driverCtx_; memory::MemoryPool* const connectorPool_; - const std::shared_ptr - insertTableHandle_; + const connector::ConnectorInsertTableHandlePtr insertTableHandle_; const connector::CommitStrategy commitStrategy_; // Records the writer operator creation time in ns. This is used to record - // the running wall time of a writer operator. This can helps to detect the + // the running wall time of a writer operator. This can help to detect the // slow scaled writer scheduling in Prestissimo. - const uint64_t createTimeUs_{0}; + const uint64_t createTimeNs_{0}; - std::unique_ptr aggregation_; + std::unique_ptr statsCollector_; std::shared_ptr connector_; std::shared_ptr connectorQueryCtx_; std::unique_ptr dataSink_; @@ -254,4 +187,10 @@ class TableWriter : public Operator { bool closed_{false}; vector_size_t numWrittenRows_{0}; }; + +// TODO: TableWriteTraits got moved to velox/core as it pertains to plan +// metadata, not execution. Maintaining the alias here in order not to break +// backward compatibility. +using core::TableWriteTraits; + } // namespace facebook::velox::exec diff --git a/velox/exec/Task.cpp b/velox/exec/Task.cpp index 3e6b49019271..62f8a81390d1 100644 --- a/velox/exec/Task.cpp +++ b/velox/exec/Task.cpp @@ -1,3 +1,4 @@ +#include "velox/common/Casts.h" /* * Copyright (c) Facebook, Inc. and its affiliates. * @@ -31,7 +32,10 @@ #include "velox/exec/OperatorUtils.h" #include "velox/exec/OutputBufferManager.h" #include "velox/exec/PlanNodeStats.h" +#include "velox/exec/SpatialJoinBuild.h" +#include "velox/exec/TableScan.h" #include "velox/exec/Task.h" +#include "velox/exec/TaskTraceWriter.h" #include "velox/exec/TraceUtil.h" using facebook::velox::common::testutil::TestValue; @@ -39,6 +43,7 @@ using facebook::velox::common::testutil::TestValue; namespace facebook::velox::exec { namespace { + // RAII helper class to satisfy given promises and notify listeners of an event // connected to the promises outside of the mutex that guards the promises. // Inactive on creation. Must be activated explicitly by calling 'activate'. @@ -55,7 +60,7 @@ class EventCompletionNotifier { std::vector promises, std::function callback = nullptr) { active_ = true; - callback_ = callback; + callback_ = std::move(callback); promises_ = std::move(promises); } @@ -89,6 +94,13 @@ folly::Synchronized>>& listeners() { return kListeners; } +folly::Synchronized>>& +splitListenerFactories() { + static folly::Synchronized>> + kListenerFactories; + return kListenerFactories; +} + std::string errorMessageImpl(const std::exception_ptr& exception) { if (!exception) { return ""; @@ -164,20 +176,52 @@ bool isHashJoinOperator(const std::string& operatorType) { return (operatorType == "HashBuild") || (operatorType == "HashProbe"); } -// Moves split promises from one vector to another. -void movePromisesOut( - std::vector& from, - std::vector& to) { - if (to.empty()) { - to.swap(from); - return; +class QueueSplitsStore : public SplitsStore { + public: + using SplitsStore::SplitsStore; + + void requestBarrier(std::vector& promises) override { + addSplit(Split::createBarrier(), promises); + } + + bool nextSplit( + Split& split, + ContinueFuture& future, + int maxPreloadSplits, + const ConnectorSplitPreloadFunc& preload) override { + if (!splits_.empty()) { + split = getSplit(maxPreloadSplits, preload); + return true; + } + if (noMoreSplits_) { + return true; + } + future = makeFuture(); + return false; + } + + bool allSplitsConsumed() const override { + return noMoreSplits_ && splits_.empty(); } +}; - for (auto& promise : from) { - to.emplace_back(std::move(promise)); +void noMoreSplitsForStore( + SplitsStore* splitsStore, + std::vector& promises) { + if (!splitsStore) { + return; + } + auto newPromises = splitsStore->noMoreSplits(); + if (promises.empty()) { + promises.swap(newPromises); + return; + } + promises.reserve(promises.size() + newPromises.size()); + for (auto& promise : newPromises) { + promises.push_back(std::move(promise)); } - from.clear(); } + } // namespace std::string executionModeString(Task::ExecutionMode mode) { @@ -239,6 +283,34 @@ bool unregisterTaskListener(const std::shared_ptr& listener) { }); } +bool registerSplitListenerFactory( + const std::shared_ptr& factory) { + return splitListenerFactories().withWLock([&](auto& factories) { + for (const auto& existingFactory : factories) { + if (existingFactory == factory) { + // Listener already registered. Do not register again. + return false; + } + } + factories.emplace_back(factory); + return true; + }); +} + +bool unregisterSplitListenerFactory( + const std::shared_ptr& factory) { + return splitListenerFactories().withWLock([&](auto& factories) { + for (auto it = factories.begin(); it != factories.end(); ++it) { + if ((*it) == factory) { + factories.erase(it); + return true; + } + } + // Listener not found. + return false; + }); +} + // static std::shared_ptr Task::create( const std::string& taskId, @@ -248,8 +320,8 @@ std::shared_ptr Task::create( ExecutionMode mode, Consumer consumer, int32_t memoryArbitrationPriority, + std::optional spillDiskOpts, std::function onError) { - VELOX_CHECK_NOT_NULL(planFragment.planNode); return Task::create( taskId, std::move(planFragment), @@ -259,6 +331,7 @@ std::shared_ptr Task::create( (consumer ? [c = std::move(consumer)]() { return c; } : ConsumerSupplier{}), memoryArbitrationPriority, + std::move(spillDiskOpts), std::move(onError)); } @@ -271,6 +344,7 @@ std::shared_ptr Task::create( ExecutionMode mode, ConsumerSupplier consumerSupplier, int32_t memoryArbitrationPriority, + std::optional spillDiskOpts, std::function onError) { VELOX_CHECK_NOT_NULL(planFragment.planNode); auto task = std::shared_ptr(new Task( @@ -282,7 +356,7 @@ std::shared_ptr Task::create( std::move(consumerSupplier), memoryArbitrationPriority, std::move(onError))); - task->initTaskPool(); + task->init(std::move(spillDiskOpts)); task->addToTaskList(); return task; } @@ -303,9 +377,8 @@ Task::Task( memoryArbitrationPriority_(memoryArbitrationPriority), queryCtx_(std::move(queryCtx)), planFragment_(std::move(planFragment)), - supportBarrier_( - (mode_ == Task::ExecutionMode::kSerial) && - planFragment_.supportsBarrier()), + firstNodeNotSupportingBarrier_( + planFragment_.firstNodeNotSupportingBarrier()), traceConfig_(maybeMakeTraceConfig()), consumerSupplier_(std::move(consumerSupplier)), onError_(std::move(onError)), @@ -319,6 +392,19 @@ Task::Task( dynamic_cast(queryCtx_->executor())); } maybeInitTrace(); + + initSplitListeners(); +} + +void Task::initSplitListeners() { + splitListenerFactories().withRLock([&](const auto& factories) { + for (const auto& factory : factories) { + auto listener = factory->create(taskId_, uuid_, queryCtx_->queryConfig()); + if (listener != nullptr) { + splitListeners_.emplace_back(std::move(listener)); + } + } + }); } Task::~Task() { @@ -372,6 +458,85 @@ Task::~Task() { } } +void Task::ensureBarrierSupport() const { + VELOX_CHECK_EQ( + mode_, + Task::ExecutionMode::kSerial, + "Task doesn't support barriered execution."); + + VELOX_CHECK_NULL( + firstNodeNotSupportingBarrier_, + "Task doesn't support barriered execution. Name of the first node that " + "doesn't support barriered execution: {}", + firstNodeNotSupportingBarrier_->name()); +} + +void Task::init(std::optional&& spillDiskOpts) { + VELOX_CHECK(driverFactories_.empty()); + initTaskPool(); + + setSpillDiskConfig(std::move(spillDiskOpts)); + + if (mode_ != Task::ExecutionMode::kSerial) { + return; + } + + // Create drivers. + VELOX_CHECK_NULL( + consumerSupplier_, + "Serial execution mode doesn't support delivering results to a " + "callback"); + + taskStats_.executionStartTimeMs = getCurrentTimeMs(); + LocalPlanner::plan( + planFragment_, nullptr, &driverFactories_, queryCtx_->queryConfig(), 1); + exchangeClients_.resize(driverFactories_.size()); + + // In Task::next() we always assume ungrouped execution. + for (const auto& factory : driverFactories_) { + VELOX_CHECK(factory->supportsSerialExecution()); + numDriversUngrouped_ += factory->numDrivers; + numTotalDrivers_ += factory->numTotalDrivers; + taskStats_.pipelineStats.emplace_back( + factory->inputDriver, factory->outputDriver); + } + + // Create drivers. + createSplitGroupStateLocked(kUngroupedGroupId); + std::vector> drivers = + createDriversLocked(kUngroupedGroupId); + if (pool_->reservedBytes() != 0) { + VELOX_FAIL( + "Unexpected memory pool allocations during task[{}] driver initialization: {}", + taskId_, + pool_->treeMemoryUsage()); + } + + drivers_ = std::move(drivers); + driverBlockingStates_.reserve(drivers_.size()); + for (auto i = 0; i < drivers_.size(); ++i) { + driverBlockingStates_.emplace_back( + std::make_unique(drivers_[i].get())); + } +} + +void Task::setSpillDiskConfig( + std::optional&& spillDiskOpts) { + if (!spillDiskOpts.has_value()) { + return; + } + VELOX_CHECK( + !spillDiskOpts->spillDirPath.empty(), "Spill directory can't be empty"); + VELOX_CHECK( + spillDiskOpts->spillDirCreated || spillDiskOpts->spillDirCreateCb); + VELOX_CHECK_NULL(spillDirectoryCallback_); + VELOX_CHECK(!spillDirectoryCreated_); + VELOX_CHECK(spillDirectory_.empty()); + spillDirectory_ = std::move(spillDiskOpts->spillDirPath); + spillDirectoryCreated_ = spillDiskOpts->spillDirCreated; + spillDirectoryCallback_ = std::move(spillDiskOpts->spillDirCreateCb); +} + Task::TaskList& Task::taskList() { static TaskList taskList; return taskList; @@ -450,7 +615,7 @@ bool Task::allNodesReceivedNoMoreSplitsMessageLocked() const { const std::string& Task::getOrCreateSpillDirectory() { VELOX_CHECK( !spillDirectory_.empty() || spillDirectoryCallback_, - "Spill directory or spill directory callback must be set "); + "Spill directory or spill directory callback must be set"); if (spillDirectoryCreated_) { return spillDirectory_; } @@ -584,8 +749,9 @@ velox::memory::MemoryPool* Task::addOperatorPool( } else { nodePool = getOrAddNodePool(planNodeId); } - childPools_.push_back(nodePool->addLeafChild(fmt::format( - "op.{}.{}.{}.{}", planNodeId, pipelineId, driverId, operatorType))); + childPools_.push_back(nodePool->addLeafChild( + fmt::format( + "op.{}.{}.{}.{}", planNodeId, pipelineId, driverId, operatorType))); return childPools_.back().get(); } @@ -596,13 +762,14 @@ velox::memory::MemoryPool* Task::addConnectorPoolLocked( const std::string& operatorType, const std::string& connectorId) { auto* nodePool = getOrAddNodePool(planNodeId); - childPools_.push_back(nodePool->addAggregateChild(fmt::format( - "op.{}.{}.{}.{}.{}", - planNodeId, - pipelineId, - driverId, - operatorType, - connectorId))); + childPools_.push_back(nodePool->addAggregateChild( + fmt::format( + "op.{}.{}.{}.{}.{}", + planNodeId, + pipelineId, + driverId, + operatorType, + connectorId))); return childPools_.back().get(); } @@ -672,48 +839,8 @@ RowVectorPtr Task::next(ContinueFuture* future) { } } - // On first call, create the drivers. - if (driverFactories_.empty()) { - VELOX_CHECK_NULL( - consumerSupplier_, - "Serial execution mode doesn't support delivering results to a " - "callback"); - - taskStats_.executionStartTimeMs = getCurrentTimeMs(); - LocalPlanner::plan( - planFragment_, nullptr, &driverFactories_, queryCtx_->queryConfig(), 1); - exchangeClients_.resize(driverFactories_.size()); - - // In Task::next() we always assume ungrouped execution. - for (const auto& factory : driverFactories_) { - VELOX_CHECK(factory->supportsSerialExecution()); - numDriversUngrouped_ += factory->numDrivers; - numTotalDrivers_ += factory->numTotalDrivers; - taskStats_.pipelineStats.emplace_back( - factory->inputDriver, factory->outputDriver); - } - - // Create drivers. - createSplitGroupStateLocked(kUngroupedGroupId); - std::vector> drivers = - createDriversLocked(kUngroupedGroupId); - if (pool_->reservedBytes() != 0) { - VELOX_FAIL( - "Unexpected memory pool allocations during task[{}] driver initialization: {}", - taskId_, - pool_->treeMemoryUsage()); - } - - drivers_ = std::move(drivers); - driverBlockingStates_.reserve(drivers_.size()); - for (auto i = 0; i < drivers_.size(); ++i) { - driverBlockingStates_.emplace_back( - std::make_unique(drivers_[i].get())); - } - if (underBarrier()) { - startDriverBarriersLocked(); - } - } + VELOX_CHECK_EQ( + state_, TaskState::kRunning, "Task has already finished processing."); // Run drivers one at a time. If a driver blocks, continue running the other // drivers. Running other drivers is expected to unblock some or all blocked @@ -1164,6 +1291,8 @@ void Task::createSplitGroupStateLocked(uint32_t splitGroupId) { addHashJoinBridgesLocked(splitGroupId, factory->needsHashJoinBridges()); addNestedLoopJoinBridgesLocked( splitGroupId, factory->needsNestedLoopJoinBridges()); + addSpatialJoinBridgesLocked( + splitGroupId, factory->needsSpatialJoinBridges()); addCustomJoinBridgesLocked(splitGroupId, factory->planNodes); core::PlanNodeId tableScanNodeId; @@ -1196,6 +1325,7 @@ std::vector> Task::createDriversLocked( // execution, from the split group id. const uint32_t driverIdOffset = factory->numDrivers * (groupedExecutionDrivers ? splitGroupId : 0); + auto filters = std::make_shared(); for (uint32_t partitionId = 0; partitionId < factory->numDrivers; ++partitionId) { drivers.emplace_back(factory->createDriver( @@ -1206,6 +1336,7 @@ std::vector> Task::createDriversLocked( splitGroupId, partitionId), getExchangeClientLocked(pipeline), + filters, [self](size_t i) { return i < self->driverFactories_.size() ? self->driverFactories_[i]->numTotalDrivers @@ -1358,14 +1489,23 @@ void Task::setMaxSplitSequenceId( } } +void Task::onAddSplit( + const core::PlanNodeId& planNodeId, + const exec::Split& split) { + for (auto& listener : splitListeners_) { + listener->onAddSplit(planNodeId, split); + } +} + bool Task::addSplitWithSequence( const core::PlanNodeId& planNodeId, exec::Split&& split, long sequenceId) { RECORD_METRIC_VALUE(kMetricTaskSplitsCount, 1); - std::unique_ptr promise; + std::vector promises; bool added = false; bool isTaskRunning; + bool shouldLogSplit = false; { std::lock_guard l(mutex_); isTaskRunning = isRunningLocked(); @@ -1375,14 +1515,15 @@ bool Task::addSplitWithSequence( // duplicate splits would be ignored. auto& splitsState = getPlanNodeSplitsStateLocked(planNodeId); if (sequenceId > splitsState.maxSequenceId) { - promise = addSplitLocked(splitsState, std::move(split)); + shouldLogSplit = true; + addSplitLocked(splitsState, split, promises); added = true; } } } - if (promise) { - promise->setValue(); + for (auto& promise : promises) { + promise.setValue(); } if (!isTaskRunning) { @@ -1391,24 +1532,29 @@ bool Task::addSplitWithSequence( addRemoteSplit(planNodeId, split); } + if (shouldLogSplit) { + onAddSplit(planNodeId, split); + } + return added; } void Task::addSplit(const core::PlanNodeId& planNodeId, exec::Split&& split) { RECORD_METRIC_VALUE(kMetricTaskSplitsCount, 1); bool isTaskRunning; - std::unique_ptr promise; + bool shouldLogSplit = false; + std::vector promises; { std::lock_guard l(mutex_); isTaskRunning = isRunningLocked(); if (isTaskRunning) { - promise = addSplitLocked( - getPlanNodeSplitsStateLocked(planNodeId), std::move(split)); + shouldLogSplit = true; + addSplitLocked(getPlanNodeSplitsStateLocked(planNodeId), split, promises); } } - if (promise) { - promise->setValue(); + for (auto& promise : promises) { + promise.setValue(); } if (!isTaskRunning) { @@ -1416,6 +1562,10 @@ void Task::addSplit(const core::PlanNodeId& planNodeId, exec::Split&& split) { // @lint-ignore CLANGTIDY bugprone-use-after-move addRemoteSplit(planNodeId, split); } + + if (shouldLogSplit) { + onAddSplit(planNodeId, split); + } } void Task::addRemoteSplit( @@ -1432,15 +1582,16 @@ void Task::addRemoteSplit( } } -std::unique_ptr Task::addSplitLocked( +void Task::addSplitLocked( SplitsState& splitsState, - exec::Split&& split) { + const exec::Split& split, + std::vector& promises) { if (split.isBarrier()) { - VELOX_CHECK(supportBarrier_); + ensureBarrierSupport(); VELOX_CHECK(splitsState.sourceIsTableScan); VELOX_CHECK(!splitsState.noMoreSplits); - return addSplitToStoreLocked( - splitsState.groupSplitsStores[kUngroupedGroupId], std::move(split)); + addSplitToStoreLocked(splitsState, kUngroupedGroupId, split, promises); + return; } VELOX_CHECK( !barrierRequested_, "Can't add new split under barrier processing"); @@ -1458,8 +1609,8 @@ std::unique_ptr Task::addSplitLocked( } if (!split.hasGroup()) { - return addSplitToStoreLocked( - splitsState.groupSplitsStores[kUngroupedGroupId], std::move(split)); + addSplitToStoreLocked(splitsState, kUngroupedGroupId, split, promises); + return; } const auto splitGroupId = split.groupId; @@ -1471,21 +1622,27 @@ std::unique_ptr Task::addSplitLocked( // We might have some free driver slots to process this split group. ensureSplitGroupsAreBeingProcessedLocked(); } - return addSplitToStoreLocked( - splitsState.groupSplitsStores[splitGroupId], std::move(split)); + addSplitToStoreLocked(splitsState, splitGroupId, split, promises); } -std::unique_ptr Task::addSplitToStoreLocked( - SplitsStore& splitsStore, - exec::Split&& split) { - splitsStore.splits.push_back(split); - if (splitsStore.splitPromises.empty()) { - return nullptr; +void Task::addSplitToStoreLocked( + SplitsState& splitsState, + uint32_t groupId, + const exec::Split& split, + std::vector& promises) { + auto& splitsStore = splitsState.groupSplitsStores[groupId]; + if (!splitsStore) { + setSplitsStore( + splitsStore, + std::make_unique(!splitsState.sourceIsTableScan)); + } + if (split.isBarrier()) { + splitsStore->requestBarrier(promises); + return; } - auto promise = std::make_unique( - std::move(splitsStore.splitPromises.back())); - splitsStore.splitPromises.pop_back(); - return promise; + auto* queueSplitsStore = + checkedPointerCast(splitsStore.get()); + queueSplitsStore->addSplit(split, promises); } void Task::noMoreSplitsForGroup( @@ -1497,9 +1654,8 @@ void Task::noMoreSplitsForGroup( std::lock_guard l(mutex_); auto& splitsState = getPlanNodeSplitsStateLocked(planNodeId); - auto& splitsStore = splitsState.groupSplitsStores[splitGroupId]; - splitsStore.noMoreSplits = true; - promises = std::move(splitsStore.splitPromises); + noMoreSplitsForStore( + splitsState.groupSplitsStores[splitGroupId].get(), promises); // There were no splits in this group, hence, no active drivers. Mark the // group complete. @@ -1538,21 +1694,23 @@ void Task::noMoreSplits(const core::PlanNodeId& planNodeId) { "Expect 1 split store in a plan node in ungrouped execution mode, has {}", splitsState.groupSplitsStores.size()); auto it = splitsState.groupSplitsStores.begin(); - it->second.noMoreSplits = true; - splitPromises.swap(it->second.splitPromises); + noMoreSplitsForStore(it->second.get(), splitPromises); } else { // For an ungrouped execution plan node, in the unlikely case when there // are no split stores created (this means there were no splits at all), // we create one. - splitsState.groupSplitsStores.emplace( - kUngroupedGroupId, SplitsStore{{}, true, {}}); + auto queueSplitsStore = + std::make_unique(!splitsState.sourceIsTableScan); + queueSplitsStore->noMoreSplits(); + setSplitsStore( + splitsState.groupSplitsStores[kUngroupedGroupId], + std::move(queueSplitsStore)); } } else { // Grouped execution branch. // Mark all split stores as 'no more splits'. for (auto& it : splitsState.groupSplitsStores) { - it.second.noMoreSplits = true; - movePromisesOut(it.second.splitPromises, splitPromises); + noMoreSplitsForStore(it.second.get(), splitPromises); } } @@ -1576,17 +1734,35 @@ void Task::noMoreSplits(const core::PlanNodeId& planNodeId) { } } +void Task::setSplitsStore( + const core::PlanNodeId& planNodeId, + std::unique_ptr newSplitsStore) { + std::lock_guard lk(mutex_); + auto& splitsState = getPlanNodeSplitsStateLocked(planNodeId); + auto& splitsStore = splitsState.groupSplitsStores[kUngroupedGroupId]; + VELOX_CHECK_NULL(splitsStore); + setSplitsStore(splitsStore, std::move(newSplitsStore)); +} + +void Task::setSplitsStore( + std::unique_ptr& splitsStore, + std::unique_ptr newSplitsStore) { + splitsStore = std::move(newSplitsStore); + splitsStore->setTaskStats(taskStats_); + splitsStore->setPreloadingSplits(preloadingSplits_); +} + ContinueFuture Task::requestBarrier() { - VELOX_CHECK(supportBarrier_, "Task doesn't support barrier"); + ensureBarrierSupport(); return startBarrier("Task::requestBarrier"); } ContinueFuture Task::startBarrier(std::string_view comment) { - VELOX_CHECK(supportBarrier_); - std::vector> promises; + ensureBarrierSupport(); + std::vector promises; SCOPE_EXIT { for (auto& promise : promises) { - promise->setValue(); + promise.setValue(); } }; @@ -1595,7 +1771,7 @@ ContinueFuture Task::startBarrier(std::string_view comment) { auto [promise, future] = makeVeloxContinuePromiseContract(std::string{comment}); if (!isRunningLocked()) { - promises.push_back(std::make_unique(std::move(promise))); + promises.push_back(std::move(promise)); return std::move(future); } @@ -1620,10 +1796,7 @@ ContinueFuture Task::startBarrier(std::string_view comment) { for (const auto& leafPlanNode : leafPlanNodeIds) { auto barrierSplit = Split::createBarrier(); auto& splitState = getPlanNodeSplitsStateLocked(leafPlanNode); - auto promise = addSplitLocked(splitState, std::move(barrierSplit)); - if (promise != nullptr) { - promises.push_back(std::move(promise)); - } + addSplitLocked(splitState, barrierSplit, promises); } startDriverBarriersLocked(); return std::move(future); @@ -1771,6 +1944,11 @@ bool Task::checkNoMoreSplitGroupsLocked() { return false; } +bool Task::testingAllSplitsFinished() { + std::lock_guard l(mutex_); + return isAllSplitsFinishedLocked(); +} + bool Task::isAllSplitsFinishedLocked() { return (taskStats_.numFinishedSplits == taskStats_.numTotalSplits) && allNodesReceivedNoMoreSplitsMessageLocked(); @@ -1785,42 +1963,22 @@ BlockingReason Task::getSplitOrFuture( const ConnectorSplitPreloadFunc& preload) { std::lock_guard l(mutex_); auto& splitsState = getPlanNodeSplitsStateLocked(planNodeId); - return getSplitOrFutureLocked( - splitsState.sourceIsTableScan, - splitsState.groupSplitsStores[splitGroupId], - split, - future, - maxPreloadSplits, - preload); -} - -BlockingReason Task::getSplitOrFutureLocked( - bool forTableScan, - SplitsStore& splitsStore, - exec::Split& split, - ContinueFuture& future, - int32_t maxPreloadSplits, - const ConnectorSplitPreloadFunc& preload) { - if (splitsStore.splits.empty()) { - if (splitsStore.noMoreSplits) { - return BlockingReason::kNotBlocked; - } - auto [splitPromise, splitFuture] = makeVeloxContinuePromiseContract( - fmt::format("Task::getSplitOrFuture {}", taskId_)); - future = std::move(splitFuture); - splitsStore.splitPromises.push_back(std::move(splitPromise)); - return BlockingReason::kWaitForSplit; + auto& splitsStore = splitsState.groupSplitsStores[splitGroupId]; + if (!splitsStore) { + setSplitsStore( + splitsStore, + std::make_unique(!splitsState.sourceIsTableScan)); } - - split = getSplitLocked(forTableScan, splitsStore, maxPreloadSplits, preload); - return BlockingReason::kNotBlocked; + return splitsStore->nextSplit(split, future, maxPreloadSplits, preload) + ? BlockingReason::kNotBlocked + : BlockingReason::kWaitForSplit; } bool Task::testingHasDriverWaitForSplit() const { std::lock_guard l(mutex_); for (const auto& splitState : splitsStates_) { - for (const auto& splitStore : splitState.second.groupSplitsStores) { - if (!splitStore.second.splitPromises.empty()) { + for (const auto& [_, splitStore] : splitState.second.groupSplitsStores) { + if (splitStore && splitStore->numWaiters() > 0) { return true; } } @@ -1828,54 +1986,6 @@ bool Task::testingHasDriverWaitForSplit() const { return false; } -exec::Split Task::getSplitLocked( - bool forTableScan, - SplitsStore& splitsStore, - int32_t maxPreloadSplits, - const ConnectorSplitPreloadFunc& preload) { - int32_t readySplitIndex = -1; - if (maxPreloadSplits > 0) { - for (auto i = 0; i < splitsStore.splits.size() && i < maxPreloadSplits; - ++i) { - if (splitsStore.splits[i].isBarrier()) { - continue; - } - auto& connectorSplit = splitsStore.splits[i].connectorSplit; - if (!connectorSplit->dataSource) { - // Initializes split->dataSource. - preload(connectorSplit); - preloadingSplits_.emplace(connectorSplit); - } else if ( - (readySplitIndex == -1) && (connectorSplit->dataSource->hasValue())) { - readySplitIndex = i; - preloadingSplits_.erase(connectorSplit); - } - } - } - if (readySplitIndex == -1) { - readySplitIndex = 0; - } - VELOX_CHECK(!splitsStore.splits.empty()); - auto split = std::move(splitsStore.splits[readySplitIndex]); - splitsStore.splits.erase(splitsStore.splits.begin() + readySplitIndex); - - --taskStats_.numQueuedSplits; - ++taskStats_.numRunningSplits; - if (forTableScan && split.connectorSplit) { - --taskStats_.numQueuedTableScanSplits; - ++taskStats_.numRunningTableScanSplits; - taskStats_.queuedTableScanSplitWeights -= split.connectorSplit->splitWeight; - taskStats_.runningTableScanSplitWeights += - split.connectorSplit->splitWeight; - } - taskStats_.lastSplitStartTimeMs = getCurrentTimeMs(); - if (taskStats_.firstSplitStartTimeMs == 0) { - taskStats_.firstSplitStartTimeMs = taskStats_.lastSplitStartTimeMs; - } - - return split; -} - std::shared_ptr Task::getScaledScanControllerLocked( uint32_t splitGroupId, const core::PlanNodeId& planNodeId) { @@ -1915,9 +2025,6 @@ void Task::splitFinished(bool fromTableScan, int64_t splitWeight) { --taskStats_.numRunningTableScanSplits; taskStats_.runningTableScanSplitWeights -= splitWeight; } - if (isAllSplitsFinishedLocked()) { - taskStats_.executionEndTimeMs = getCurrentTimeMs(); - } } void Task::multipleSplitsFinished( @@ -1931,9 +2038,6 @@ void Task::multipleSplitsFinished( taskStats_.numRunningTableScanSplits -= numSplits; taskStats_.runningTableScanSplitWeights -= splitsWeight; } - if (isAllSplitsFinishedLocked()) { - taskStats_.executionEndTimeMs = getCurrentTimeMs(); - } } bool Task::isGroupedExecution() const { @@ -1983,7 +2087,7 @@ bool Task::allSplitsConsumedHelper(const core::PlanNode* planNode) const { VELOX_CHECK_NE(splitsStates_.count(planNodeId), 0); for (const auto& [_, splitsStore] : splitsStates_.at(planNodeId).groupSplitsStores) { - if (!splitsStore.splits.empty()) { + if (splitsStore && !splitsStore->allSplitsConsumed()) { return false; } } @@ -2081,13 +2185,6 @@ bool Task::checkIfFinishedLocked() { if (splitGroupStates_[kUngroupedGroupId].numFinishedOutputDrivers == numDrivers(outputPipelineId)) { allFinished = true; - - if (taskStats_.executionEndTimeMs == 0) { - // In case we haven't set executionEndTimeMs due to all splits - // depleted, we set it here. This can happen due to task error or task - // being cancelled. - taskStats_.executionEndTimeMs = getCurrentTimeMs(); - } } } @@ -2220,6 +2317,20 @@ void Task::addNestedLoopJoinBridgesLocked( } } +void Task::addSpatialJoinBridgesLocked( + uint32_t splitGroupId, + const std::vector& planNodeIds) { + auto& splitGroupState = splitGroupStates_[splitGroupId]; + for (const auto& planNodeId : planNodeIds) { + auto const inserted = + splitGroupState.bridges + .emplace(planNodeId, std::make_shared()) + .second; + VELOX_CHECK( + inserted, "Join bridge for node {} is already present", planNodeId); + } +} + std::shared_ptr Task::getHashJoinBridge( uint32_t splitGroupId, const core::PlanNodeId& planNodeId) { @@ -2239,6 +2350,12 @@ std::shared_ptr Task::getNestedLoopJoinBridge( return getJoinBridgeInternal(splitGroupId, planNodeId); } +std::shared_ptr Task::getSpatialJoinBridge( + uint32_t splitGroupId, + const core::PlanNodeId& planNodeId) { + return getJoinBridgeInternal(splitGroupId, planNodeId); +} + template std::shared_ptr Task::getJoinBridgeInternal( uint32_t splitGroupId, @@ -2252,7 +2369,7 @@ template std::shared_ptr Task::getJoinBridgeInternalLocked( uint32_t splitGroupId, const core::PlanNodeId& planNodeId, - MemberType SplitGroupState::*bridges_member) { + MemberType SplitGroupState::* bridges_member) { const auto& splitGroupState = splitGroupStates_[splitGroupId]; auto it = (splitGroupState.*bridges_member).find(planNodeId); @@ -2335,9 +2452,9 @@ ContinueFuture Task::terminate(TaskState terminalState) { cancellationSource_.requestCancellation(); } - LOG(INFO) << "Terminating task " << taskId() << " with state " - << taskStateString(state_) << " after running for " - << succinctMillis(timeSinceStartMsLocked()); + VLOG(1) << "Terminating task " << taskId() << " with state " + << taskStateString(state_) << " after running for " + << succinctMillis(timeSinceStartMsLocked()); taskCompletionNotifier.activate( std::move(taskCompletionPromises_), [&]() { onTaskCompletion(); }); @@ -2413,25 +2530,27 @@ ContinueFuture Task::terminate(TaskState terminalState) { } // Collect all outstanding split promises from all splits state structures. - for (auto& pair : splitsStates_) { - auto& splitState = pair.second; - for (auto& it : pair.second.groupSplitsStores) { - movePromisesOut(it.second.splitPromises, splitPromises); + for (auto& [nodeId, state] : splitsStates_) { + for (auto& [_, store] : state.groupSplitsStores) { + noMoreSplitsForStore(store.get(), splitPromises); } // Process remaining remote splits. - if (getExchangeClientLocked(pair.first) != nullptr) { + if (getExchangeClientLocked(nodeId) != nullptr) { std::vector splits; - for (auto& [groupId, store] : splitState.groupSplitsStores) { - while (!store.splits.empty()) { - splits.emplace_back(getSplitLocked( - splitState.sourceIsTableScan, store, 0, nullptr)); + for (auto& [groupId, store] : state.groupSplitsStores) { + if (!store) { + continue; + } + while (!store->allSplitsConsumed()) { + auto future = ContinueFuture::makeEmpty(); + VELOX_CHECK( + store->nextSplit(splits.emplace_back(), future, 0, nullptr)); } } if (!splits.empty()) { remainingRemoteSplits.emplace( - pair.first, - std::make_pair(std::move(splits), splitState.noMoreSplits)); + nodeId, std::make_pair(std::move(splits), state.noMoreSplits)); } } } @@ -2691,6 +2810,10 @@ void Task::onTaskCompletion() { exchangeClientByPlanNode_); } }); + + for (auto& listener : splitListeners_) { + listener->onTaskCompletion(); + } } ContinueFuture Task::stateChangeFuture(uint64_t maxWaitMicros) { @@ -2850,8 +2973,9 @@ folly::dynamic Task::toJson() const { std::shared_ptr Task::addLocalMergeSource( uint32_t splitGroupId, const core::PlanNodeId& planNodeId, - const RowTypePtr& rowType) { - auto source = MergeSource::createLocalMergeSource(); + const RowTypePtr& rowType, + int queueSize) { + auto source = MergeSource::createLocalMergeSource(queueSize); splitGroupStates_[splitGroupId].localMergeSources[planNodeId].push_back( source); return source; @@ -2912,8 +3036,9 @@ void Task::createLocalExchangeQueuesLocked( queryCtx_->queryConfig().maxLocalExchangeBufferSize()); exchange.queues.reserve(numPartitions); for (auto i = 0; i < numPartitions; ++i) { - exchange.queues.emplace_back(std::make_shared( - exchange.memoryManager, exchange.vectorPool, i)); + exchange.queues.emplace_back( + std::make_shared( + exchange.memoryManager, exchange.vectorPool, i)); } const auto partitionNode = @@ -3302,7 +3427,9 @@ void Task::createExchangeClientLocked( queryCtx()->queryConfig().minExchangeOutputBatchBytes(), addExchangeClientPool(planNodeId, pipelineId), queryCtx()->executor(), - queryCtx()->queryConfig().requestDataSizesMaxWaitSec()); + queryCtx()->queryConfig().requestDataSizesMaxWaitSec(), + queryCtx()->queryConfig().singleSourceExchangeOptimizationEnabled(), + queryCtx()->queryConfig().exchangeLazyFetchingEnabled()); exchangeClientByPlanNode_.emplace(planNodeId, exchangeClients_[pipelineId]); } @@ -3339,40 +3466,23 @@ std::optional Task::maybeMakeTraceConfig() const { return std::nullopt; } - const auto traceNodes = queryConfig.queryTraceNodeIds(); - VELOX_USER_CHECK(!traceNodes.empty(), "Query trace nodes are not set"); + const auto traceNodeId = queryConfig.queryTraceNodeId(); + VELOX_USER_CHECK(!traceNodeId.empty(), "Query trace node ID are not set"); const auto traceDir = trace::getTaskTraceDirectory( queryConfig.queryTraceDir(), queryCtx_->queryId(), taskId_); - std::vector traceNodeIds; - folly::split(',', traceNodes, traceNodeIds); - std::unordered_set traceNodeIdSet( - traceNodeIds.begin(), traceNodeIds.end()); - VELOX_USER_CHECK_EQ( - traceNodeIdSet.size(), - traceNodeIds.size(), - "Duplicate trace nodes found: {}", - folly::join(", ", traceNodeIds)); - - bool foundTraceNode{false}; - for (const auto& traceNodeId : traceNodeIds) { - if (core::PlanNode::findFirstNode( - planFragment_.planNode.get(), - [traceNodeId](const core::PlanNode* node) -> bool { - return node->id() == traceNodeId; - })) { - foundTraceNode = true; - break; - } - } - VELOX_USER_CHECK( - foundTraceNode, - "Trace plan nodes not found from task {}: {}", - taskId_, - folly::join(",", traceNodeIdSet)); - - LOG(INFO) << "Trace input for plan nodes " << traceNodes << " from task " + VELOX_USER_CHECK_NOT_NULL( + core::PlanNode::findFirstNode( + planFragment_.planNode.get(), + [traceNodeId](const core::PlanNode* node) -> bool { + return node->id() == traceNodeId; + }), + "Trace plan node ID = {} not found from task {}", + traceNodeId, + taskId_); + + LOG(INFO) << "Trace input for plan nodes " << traceNodeId << " from task " << taskId_; UpdateAndCheckTraceLimitCB updateAndCheckTraceLimitCB = @@ -3380,10 +3490,11 @@ std::optional Task::maybeMakeTraceConfig() const { queryCtx_->updateTracedBytesAndCheckLimit(bytes); }; return TraceConfig( - std::move(traceNodeIdSet), + traceNodeId, traceDir, std::move(updateAndCheckTraceLimitCB), - queryConfig.queryTraceTaskRegExp()); + queryConfig.queryTraceTaskRegExp(), + queryConfig.queryTraceDryRun()); } void Task::maybeInitTrace() { @@ -3394,7 +3505,9 @@ void Task::maybeInitTrace() { trace::createTraceDirectory(traceConfig_->queryTraceDir); const auto metadataWriter = std::make_unique( traceConfig_->queryTraceDir, memory::traceMemoryPool()); - metadataWriter->write(queryCtx_, planFragment_.planNode); + auto traceNode = + trace::getTraceNode(planFragment_.planNode, traceConfig_->queryNodeId); + metadataWriter->write(queryCtx_, traceNode); } void Task::testingVisitDrivers(const std::function& callback) { @@ -3605,8 +3718,8 @@ bool Task::DriverBlockingState::blocked(ContinueFuture* future) { VELOX_CHECK(promises_.empty()); return false; } - auto [blockPromise, blockFuture] = - makeVeloxContinuePromiseContract(fmt::format( + auto [blockPromise, blockFuture] = makeVeloxContinuePromiseContract( + fmt::format( "DriverBlockingState {} from task {}", driver_->driverCtx()->driverId, driver_->task()->taskId())); diff --git a/velox/exec/Task.h b/velox/exec/Task.h index 9cbb73d98eff..72bc266d1028 100644 --- a/velox/exec/Task.h +++ b/velox/exec/Task.h @@ -15,6 +15,8 @@ */ #pragma once +#include + #include "velox/common/base/SkewedPartitionBalancer.h" #include "velox/common/base/TraceConfig.h" #include "velox/core/PlanFragment.h" @@ -23,11 +25,9 @@ #include "velox/exec/LocalPartition.h" #include "velox/exec/MemoryReclaimer.h" #include "velox/exec/MergeSource.h" -#include "velox/exec/Split.h" -#include "velox/exec/TableScan.h" +#include "velox/exec/ScaledScanController.h" #include "velox/exec/TaskStats.h" #include "velox/exec/TaskStructs.h" -#include "velox/exec/TaskTraceWriter.h" #include "velox/vector/ComplexVector.h" namespace facebook::velox::exec { @@ -36,9 +36,8 @@ class OutputBufferManager; class HashJoinBridge; class NestedLoopJoinBridge; - -using ConnectorSplitPreloadFunc = - std::function&)>; +class SpatialJoinBridge; +class SplitListener; class Task : public std::enable_shared_from_this { public: @@ -69,11 +68,16 @@ class Task : public std::enable_shared_from_this { /// @param consumer Optional factory function to get callbacks to pass the /// results of the execution. In a parallel execution mode, results from each /// thread are passed on to a separate consumer. + /// @param memoryArbitrationPriority Priority used by the memory arbitrator + /// to determine which task should have its memory reclaimed first when the + /// system is under memory pressure. Higher values indicate higher priority + /// (lower likelihood of being reclaimed). Default is 0. + /// @param spillDiskOpts Optional configuration for spill disk storage. When + /// provided, allows operators to spill intermediate data to disk during + /// execution when memory pressure is high. Includes spill directory path + /// and callback options. Default is std::nullopt (no spilling). /// @param onError Optional callback to receive an exception if task /// execution fails. - /// @param memoryArbitrationPriority Optional priority on task that, in a - /// multi task system, is used for memory arbitration to decide the order of - /// reclaiming. static std::shared_ptr create( const std::string& taskId, core::PlanFragment planFragment, @@ -82,6 +86,7 @@ class Task : public std::enable_shared_from_this { ExecutionMode mode, Consumer consumer = nullptr, int32_t memoryArbitrationPriority = 0, + std::optional spillDiskOpts = std::nullopt, std::function onError = nullptr); static std::shared_ptr create( @@ -92,6 +97,7 @@ class Task : public std::enable_shared_from_this { ExecutionMode mode, ConsumerSupplier consumerSupplier, int32_t memoryArbitrationPriority = 0, + std::optional spillDiskOpts = std::nullopt, std::function onError = nullptr); /// Convenience function for shortening a Presto taskId. To be used @@ -100,22 +106,6 @@ class Task : public std::enable_shared_from_this { ~Task(); - /// Specify directory to which data will be spilled if spilling is enabled and - /// required. Set 'alreadyCreated' to true if the directory has already been - /// created by the caller. - void setSpillDirectory( - const std::string& spillDirectory, - bool alreadyCreated = true) { - spillDirectory_ = spillDirectory; - spillDirectoryCreated_ = alreadyCreated; - } - - void setCreateSpillDirectoryCb( - std::function spillDirectoryCallback) { - VELOX_CHECK_NULL(spillDirectoryCallback_); - spillDirectoryCallback_ = std::move(spillDirectoryCallback); - } - /// Returns human-friendly representation of the plan augmented with runtime /// statistics. The implementation invokes exec::printPlanWithStats(). /// @@ -271,6 +261,14 @@ class Task : public std::enable_shared_from_this { /// corresponding to plan node with specified ID. void noMoreSplits(const core::PlanNodeId& planNodeId); + /// Use a customized split store implementation to replace the default queue + /// split store for the specified node. The customized split store should + /// have its own source of splits and no longer need addSplit() call from + /// the task. + void setSplitsStore( + const core::PlanNodeId& planNodeId, + std::unique_ptr splitsStore); + /// Updates the total number of output buffers to broadcast or arbitrarily /// distribute the results of the execution to. Used when plan tree ends with /// a PartitionedOutputNode with broadcast of arbitrary output type. @@ -465,7 +463,8 @@ class Task : public std::enable_shared_from_this { std::shared_ptr addLocalMergeSource( uint32_t splitGroupId, const core::PlanNodeId& planNodeId, - const RowTypePtr& rowType); + const RowTypePtr& rowType, + int queueSize); /// Returns all MergeSource's for the specified splitGroupId and planNodeId. const std::vector>& getLocalMergeSources( @@ -514,7 +513,7 @@ class Task : public std::enable_shared_from_this { void setError(const std::string& message); /// Returns all the peer operators of the 'caller' operator from a given - /// 'pipelindId' in this task. + /// 'pipelineId' in this task. std::vector findPeerOperators(int pipelineId, Operator* caller); /// Synchronizes completion of an Operator across Drivers of 'this'. @@ -555,6 +554,11 @@ class Task : public std::enable_shared_from_this { uint32_t splitGroupId, const std::vector& planNodeIds); + /// Adds SpatialJoinBridge's for all the specified plan node IDs. + void addSpatialJoinBridgesLocked( + uint32_t splitGroupId, + const std::vector& planNodeIds); + /// Adds custom join bridges for all the specified plan nodes. void addCustomJoinBridgesLocked( uint32_t splitGroupId, @@ -577,6 +581,11 @@ class Task : public std::enable_shared_from_this { uint32_t splitGroupId, const core::PlanNodeId& planNodeId); + /// Returns a SpatialJoinBridge for 'planNodeId'. + std::shared_ptr getSpatialJoinBridge( + uint32_t splitGroupId, + const core::PlanNodeId& planNodeId); + /// Returns a custom join bridge for 'planNodeId'. std::shared_ptr getCustomJoinBridge( uint32_t splitGroupId, @@ -610,7 +619,7 @@ class Task : public std::enable_shared_from_this { StopReason enterForTerminateLocked(ThreadState& state); /// Marks that the Driver is not on thread. If no more Drivers in the - /// CancelPool are on thread, this realizes threadFinishFutures_. These allow + /// Task are on thread, this realizes threadFinishFutures_. /// syncing with pause or termination. The Driver may go off thread because of /// hasBlockingFuture or pause requested or terminate requested. The /// return value indicates the reason. If kTerminate is returned, the @@ -669,10 +678,10 @@ class Task : public std::enable_shared_from_this { /// Invoked once by a driver thread to signal it has finished the barrier /// processing when all its operators have drained their output and the sink - /// operator has propogated the drain signal to all its downstream pipelines + /// operator has propagated the drain signal to all its downstream pipelines /// through the connected queues. Upon the last driver thread call, the task /// finishes the current barrier processing and 'barrierFinishPromises_' are - /// fullfiled to resume barrier request caller. + /// fulfilled to resume barrier request caller. void finishDriverBarrier(); /// Requests the Task to stop activity. The returned future is @@ -777,6 +786,9 @@ class Task : public std::enable_shared_from_this { /// split. bool testingHasDriverWaitForSplit() const; + /// Returns true if all the splits have finished. + bool testingAllSplitsFinished(); + private: // Hook of system-wide running task list. struct TaskListEntry { @@ -802,6 +814,13 @@ class Task : public std::enable_shared_from_this { int32_t memoryArbitrationPriority = 0, std::function onError = nullptr); + // Invoked to do post-create initialization. + void init(std::optional&& spillDiskOpts); + + // Invoked to initialize the spill storage config for this task. + void setSpillDiskConfig( + std::optional&& spillDiskOpts); + // Invoked to add this to the system-wide running task list on task creation. void addToTaskList(); @@ -840,7 +859,7 @@ class Task : public std::enable_shared_from_this { // message. bool allNodesReceivedNoMoreSplitsMessageLocked() const; - // Recursive helper for 'allSpilitsConsumed()' method. + // Recursive helper for 'allSplitsConsumed()' method. bool allSplitsConsumedHelper(const core::PlanNode* planNode) const; // Remove the spill directory, if the Task was creating it for potential @@ -925,7 +944,7 @@ class Task : public std::enable_shared_from_this { VELOX_CHECK_NOT_NULL(task); } - // Gets the shared pointer to the driver to ensure its liveness during the + // Gets the shared pointer to the task to ensure its liveness during the // memory reclaim operation. // // NOTE: a task's memory pool might outlive the task itself. @@ -957,7 +976,7 @@ class Task : public std::enable_shared_from_this { std::shared_ptr getJoinBridgeInternalLocked( uint32_t splitGroupId, const core::PlanNodeId& planNodeId, - MemberType SplitGroupState::*bridges_member); + MemberType SplitGroupState::* bridges_member); std::shared_ptr getCustomJoinBridgeInternal( uint32_t splitGroupId, @@ -969,23 +988,6 @@ class Task : public std::enable_shared_from_this { const core::PlanNodeId& planNodeId, const exec::Split& split); - /// Retrieve a split or split future from the given split store structure. - BlockingReason getSplitOrFutureLocked( - bool forTableScan, - SplitsStore& splitsStore, - exec::Split& split, - ContinueFuture& future, - int32_t maxPreloadSplits, - const ConnectorSplitPreloadFunc& preload); - - /// Returns next split from the store. The caller must ensure the store is not - /// empty. - exec::Split getSplitLocked( - bool forTableScan, - SplitsStore& splitsStore, - int32_t maxPreloadSplits, - const ConnectorSplitPreloadFunc& preload); - // Creates for the given split group and fills up the 'SplitGroupState' // structure, which stores inter-operator state (local exchange, bridges). void createSplitGroupStateLocked(uint32_t splitGroupId); @@ -1015,17 +1017,26 @@ class Task : public std::enable_shared_from_this { // Notifies listeners that the task is now complete. void onTaskCompletion(); + void onAddSplit(const core::PlanNodeId& planNodeId, const exec::Split& split); + // Returns true if all splits are finished processing and there are no more // splits coming for the task. bool isAllSplitsFinishedLocked(); - std::unique_ptr addSplitLocked( + void addSplitLocked( SplitsState& splitsState, - exec::Split&& split); + const exec::Split& split, + std::vector& promises); - std::unique_ptr addSplitToStoreLocked( - SplitsStore& splitsStore, - exec::Split&& split); + void addSplitToStoreLocked( + SplitsState& splitsState, + uint32_t groupId, + const exec::Split& split, + std::vector& promises); + + void setSplitsStore( + std::unique_ptr& splitsStore, + std::unique_ptr newSplitsStore); // Invoked when all the driver threads are off thread. The function returns // 'threadFinishPromises_' to fulfill. @@ -1120,6 +1131,13 @@ class Task : public std::enable_shared_from_this { void recordBatchStartTime(); void recordBatchEndTime(); + void initSplitListeners(); + + /// Checks if the task supports barrier processing. Barrier processing is + /// supported when the task is under single threaded execution mode and all + /// its plan nodes support barrier processing. + void ensureBarrierSupport() const; + // Universally unique identifier of the task. Used to identify the task when // calling TaskListener. const std::string uuid_; @@ -1142,10 +1160,10 @@ class Task : public std::enable_shared_from_this { core::PlanFragment planFragment_; - // Indicates if this task supports barrier processing. It is set to true if - // the task is under single threaded execution mode and all its plan nodes - // support barrier processing. - const bool supportBarrier_; + // First node in the plan fragment that does not support barrier processing. + // Barrier is supported when the task is under single threaded execution mode + // and all its plan nodes support barrier processing. + const core::PlanNode* firstNodeNotSupportingBarrier_{}; const std::optional traceConfig_; @@ -1158,7 +1176,7 @@ class Task : public std::enable_shared_from_this { // to pool_ must be defined after pool_, childPools_. std::shared_ptr pool_; - // Keep driver and operator memory pools alive for the duration of the task + // Keep plan node and operator memory pools alive for the duration of the task // to allow for sharing vectors across drivers without copy. std::vector> childPools_; @@ -1197,7 +1215,7 @@ class Task : public std::enable_shared_from_this { ConsumerSupplier consumerSupplier_; // The function that is executed when the task encounters its first error, - // that is, serError() is called for the first time. + // that is, setError() is called for the first time. std::function onError_; std::vector> driverFactories_; @@ -1359,7 +1377,7 @@ class Task : public std::enable_shared_from_this { // The promises for the futures returned to callers of requestBarrier(). std::vector barrierFinishPromises_; - std::atomic toYield_ = 0; + std::atomic_int32_t toYield_ = 0; int32_t numThreads_ = 0; // Microsecond real time when 'this' last went from no threads to // one thread running. Used to decide if continuous run should be @@ -1391,6 +1409,8 @@ class Task : public std::enable_shared_from_this { preloadingSplits_; folly::CancellationSource cancellationSource_; + + std::vector> splitListeners_; }; /// Listener invoked on task completion. @@ -1431,6 +1451,50 @@ bool registerTaskListener(std::shared_ptr listener); /// unregistered successfuly, false if listener was not found. bool unregisterTaskListener(const std::shared_ptr& listener); +/// Listener invoked when splits are added to Task. +class SplitListener { + public: + SplitListener(const std::string& taskId, const std::string& taskUuid) + : taskId_(taskId), taskUuid_(taskUuid) {} + + virtual ~SplitListener() = default; + + // Called when a split is added to a task for a given plan node. + virtual void onAddSplit( + const core::PlanNodeId& planNodeId, + const exec::Split& split) = 0; + + /// Called on task completion. Provides the information about success or + /// failure as well as runtime statistics about task execution. + virtual void onTaskCompletion() = 0; + + protected: + const std::string taskId_; + const std::string taskUuid_; +}; + +class SplitListenerFactory { + public: + virtual ~SplitListenerFactory() = default; + + /// Create and return an std::unique_ptr to be used by a Task + /// of the given taskId, taskUuid and config. The Task constructor calls this + /// method and holds the SplitListener. The SplitListener is destroyed when + /// the Task is destructed. This method can return a nullptr, e.g., if taskId + /// doesn't satisfy certain criteria or if config.maxNumSplitsListenedTo() is + /// 0. In this situation, the Task doesn't hold this SplitListener. + virtual std::unique_ptr create( + const std::string& taskId, + const std::string& taskUuid, + const core::QueryConfig& config) = 0; +}; + +bool registerSplitListenerFactory( + const std::shared_ptr& factory); + +bool unregisterSplitListenerFactory( + const std::shared_ptr& factory); + std::string executionModeString(Task::ExecutionMode mode); std::ostream& operator<<(std::ostream& out, Task::ExecutionMode mode); diff --git a/velox/exec/TaskStats.h b/velox/exec/TaskStats.h index fe8289770f58..2bef2a894e41 100644 --- a/velox/exec/TaskStats.h +++ b/velox/exec/TaskStats.h @@ -20,7 +20,9 @@ #include #include -#include "velox/exec/Driver.h" +#include "velox/exec/BlockingReason.h" +#include "velox/exec/DriverStats.h" +#include "velox/exec/OperatorStats.h" #include "velox/exec/OutputBuffer.h" namespace facebook::velox::exec { diff --git a/velox/exec/TaskStructs.cpp b/velox/exec/TaskStructs.cpp new file mode 100644 index 000000000000..21786166125c --- /dev/null +++ b/velox/exec/TaskStructs.cpp @@ -0,0 +1,88 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/TaskStructs.h" + +namespace facebook::velox::exec { + +void SplitsStore::addSplit( + Split split, + std::vector& promises) { + VELOX_CHECK(!noMoreSplits_); + VELOX_CHECK(!(remoteSplit_ && split.isBarrier())); + splits_.push_back(std::move(split)); + if (promises_.empty()) { + return; + } + promises.push_back(std::move(promises_.back())); + promises_.pop_back(); +} + +ContinueFuture SplitsStore::makeFuture() { + auto [promise, future] = + makeVeloxContinuePromiseContract("SplitsStore::makeFuture"); + promises_.push_back(std::move(promise)); + return std::move(future); +} + +Split SplitsStore::getSplit( + int maxPreloadSplits, + const ConnectorSplitPreloadFunc& preload) { + int readySplitIndex = -1; + if (maxPreloadSplits > 0) { + for (int i = 0, end = std::min(maxPreloadSplits, splits_.size()); + i < end; + ++i) { + if (splits_[i].isBarrier()) { + VELOX_CHECK(!remoteSplit_); + continue; + } + auto& connectorSplit = splits_[i].connectorSplit; + if (!connectorSplit->dataSource) { + // Initializes split->dataSource. + preload(connectorSplit); + preloadingSplits_->insert(connectorSplit); + } else if ( + readySplitIndex == -1 && connectorSplit->dataSource->hasValue()) { + readySplitIndex = i; + preloadingSplits_->erase(connectorSplit); + } + } + } + if (readySplitIndex == -1) { + readySplitIndex = 0; + } + VELOX_CHECK(!splits_.empty()); + auto split = std::move(splits_[readySplitIndex]); + splits_.erase(splits_.begin() + readySplitIndex); + --taskStats_->numQueuedSplits; + ++taskStats_->numRunningSplits; + if (!remoteSplit_ && split.connectorSplit) { + --taskStats_->numQueuedTableScanSplits; + ++taskStats_->numRunningTableScanSplits; + taskStats_->queuedTableScanSplitWeights -= + split.connectorSplit->splitWeight; + taskStats_->runningTableScanSplitWeights += + split.connectorSplit->splitWeight; + } + taskStats_->lastSplitStartTimeMs = getCurrentTimeMs(); + if (taskStats_->firstSplitStartTimeMs == 0) { + taskStats_->firstSplitStartTimeMs = taskStats_->lastSplitStartTimeMs; + } + return split; +} + +} // namespace facebook::velox::exec diff --git a/velox/exec/TaskStructs.h b/velox/exec/TaskStructs.h index bf3418a3f25b..ba99890d3a0c 100644 --- a/velox/exec/TaskStructs.h +++ b/velox/exec/TaskStructs.h @@ -14,7 +14,21 @@ * limitations under the License. */ #pragma once + +#include "velox/common/base/SkewedPartitionBalancer.h" +#include "velox/common/future/VeloxPromise.h" +#include "velox/connectors/Connector.h" +#include "velox/exec/LocalPartition.h" +#include "velox/exec/ScaledScanController.h" +#include "velox/exec/Split.h" +#include "velox/exec/TaskStats.h" + +#include + #include +#include +#include +#include #include #include @@ -25,7 +39,6 @@ class JoinBridge; class LocalExchangeMemoryManager; class MergeSource; class MergeJoinSource; -struct Split; /// Corresponds to Presto TaskState, needed for reporting query completion. enum class TaskState : int { @@ -55,14 +68,90 @@ struct BarrierState { std::vector allPeersFinishedPromises; }; -/// Structure to accumulate splits for distribution. -struct SplitsStore { - /// Arrived (added), but not distributed yet, splits. - std::deque splits; - /// Signal, that no more splits will arrive. - bool noMoreSplits{false}; - /// Blocking promises given out when out of splits to distribute. - std::vector splitPromises; +using ConnectorSplitPreloadFunc = + std::function&)>; + +/// A split store that either can accumulate splits through addSplit() or +/// generating splits from its own source. +/// +/// This object is not multi-thread safe. +class SplitsStore { + public: + explicit SplitsStore(bool remoteSplit) : remoteSplit_(remoteSplit) {} + + virtual ~SplitsStore() = default; + + /// Add a barrier split to this store. Some of the waiters previously waiting + /// on the ContinueFuture after the nextSplit() call should be notified by the + /// caller of this function by setting values on the `promises` output + /// parameter. + /// + /// `promises` should be set by caller (potentially outside a lock), to notify + /// any waiters on the splits. + virtual void requestBarrier(std::vector& promises) = 0; + + /// Return true when split is set or there is no more splits; false when + /// caller should retry when the future is fulfilled. + virtual bool nextSplit( + Split& split, + ContinueFuture& future, + int maxPreloadSplits, + const ConnectorSplitPreloadFunc& preload) = 0; + + /// Return whether all splits has been consumed and there will be no more + /// splits. + virtual bool allSplitsConsumed() const = 0; + + /// Add new split into this store. Some of the waiters previously waiting on + /// the ContinueFuture after the nextSplit() call should be notified by the + /// caller of this function by setting values on the `promises` output + /// parameter. + /// + /// `promises` should be set by caller (potentially outside a lock), to notify + /// any waiters on the splits. + void addSplit(Split split, std::vector& promises); + + /// Return the number of waiters waiting for new splits. + int numWaiters() const { + return promises_.size(); + } + + /// Signal there will not be any more splits added to this store. + std::vector noMoreSplits() { + noMoreSplits_ = true; + return std::move(promises_); + } + + void setTaskStats(TaskStats& taskStats) { + taskStats_ = &taskStats; + } + + void setPreloadingSplits( + folly::F14FastSet>& + preloadingSplits) { + preloadingSplits_ = &preloadingSplits; + } + + protected: + Split getSplit( + int maxPreloadSplits, + const ConnectorSplitPreloadFunc& preload); + + ContinueFuture makeFuture(); + + const bool remoteSplit_; + TaskStats* taskStats_{}; + folly::F14FastSet>* + preloadingSplits_{}; + + // Arrived (added), but not distributed yet, splits. + std::deque splits_; + + // Signal, that no more splits will arrive. + bool noMoreSplits_{false}; + + // Blocking promises given out when out of splits to distribute. + std::vector promises_; }; /// Structure contains the current info on splits for a particular plan node. @@ -77,7 +166,7 @@ struct SplitsState { long maxSequenceId{std::numeric_limits::min()}; /// Map split group id -> split store. - std::unordered_map groupSplitsStores; + std::unordered_map> groupSplitsStores; /// We need these due to having promises in the structure. SplitsState() = default; diff --git a/velox/exec/TaskTraceReader.cpp b/velox/exec/TaskTraceReader.cpp index b972c90cff8c..24125b01b78b 100644 --- a/velox/exec/TaskTraceReader.cpp +++ b/velox/exec/TaskTraceReader.cpp @@ -31,9 +31,10 @@ TaskTraceMetadataReader::TaskTraceMetadataReader( traceFilePath_(getTaskTraceMetaFilePath(traceDir_)), pool_(pool), metadataObj_(getTaskMetadata(traceFilePath_, fs_)), - tracePlanNode_(ISerializable::deserialize( - metadataObj_[TraceTraits::kPlanNodeKey], - pool_)) {} + tracePlanNode_( + ISerializable::deserialize( + metadataObj_[TraceTraits::kPlanNodeKey], + pool_)) {} std::unordered_map TaskTraceMetadataReader::queryConfigs() const { @@ -66,22 +67,41 @@ core::PlanNodePtr TaskTraceMetadataReader::queryPlan() const { } std::string TaskTraceMetadataReader::nodeName(const std::string& nodeId) const { - const auto* traceNode = core::PlanNode::findFirstNode( - tracePlanNode_.get(), - [&nodeId](const core::PlanNode* node) { return node->id() == nodeId; }); + LOG(ERROR) << "node id " << nodeId << " trace plan node " + << tracePlanNode_->toString(true, true); + const auto* traceNode = + core::PlanNode::findNodeById(tracePlanNode_.get(), nodeId); + VELOX_CHECK_NOT_NULL( + traceNode, "trace node id {} not found in the trace plan", nodeId); return std::string(traceNode->name()); } -std::string TaskTraceMetadataReader::connectorId( +std::optional TaskTraceMetadataReader::connectorId( const std::string& nodeId) const { - const auto* traceNode = core::PlanNode::findFirstNode( - tracePlanNode_.get(), - [&nodeId](const core::PlanNode* node) { return node->id() == nodeId; }); - const auto* tableScanNode = - dynamic_cast(traceNode); - VELOX_CHECK_NOT_NULL(tableScanNode); - const auto connectorId = tableScanNode->tableHandle()->connectorId(); - VELOX_CHECK(!connectorId.empty()); - return connectorId; + const auto* traceNode = + core::PlanNode::findNodeById(tracePlanNode_.get(), nodeId); + VELOX_CHECK_NOT_NULL( + traceNode, + "trace node id {} not found in the trace plan: {}", + nodeId, + tracePlanNode_->toString(true, true)); + + if (const auto* indexLookupJoinNode = + dynamic_cast(traceNode)) { + const auto indexLookupConnectorId = + indexLookupJoinNode->lookupSource()->tableHandle()->connectorId(); + VELOX_CHECK(!indexLookupConnectorId.empty()); + return indexLookupConnectorId; + } + + if (const auto* tableScanNode = + dynamic_cast(traceNode)) { + VELOX_CHECK_NOT_NULL(tableScanNode); + const auto connectorId = tableScanNode->tableHandle()->connectorId(); + VELOX_CHECK(!connectorId.empty()); + return connectorId; + } + + return std::nullopt; } } // namespace facebook::velox::exec::trace diff --git a/velox/exec/TaskTraceReader.h b/velox/exec/TaskTraceReader.h index f8475f966341..51ee04d2813a 100644 --- a/velox/exec/TaskTraceReader.h +++ b/velox/exec/TaskTraceReader.h @@ -37,8 +37,9 @@ class TaskTraceMetadataReader { /// Returns node name in the trace query plan by ID. std::string nodeName(const std::string& nodeId) const; - /// Returns connector ID in the TableScanNode. - std::string connectorId(const std::string& nodeId) const; + /// Returns optional of connector ID in the TableScanNode. If nullptr, then no + /// connector will be registered. + std::optional connectorId(const std::string& nodeId) const; private: const std::string traceDir_; diff --git a/velox/exec/TopNRowNumber.cpp b/velox/exec/TopNRowNumber.cpp index b4b059cd72df..e0da3986dff8 100644 --- a/velox/exec/TopNRowNumber.cpp +++ b/velox/exec/TopNRowNumber.cpp @@ -19,6 +19,28 @@ namespace facebook::velox::exec { namespace { +#define RANK_FUNCTION_DISPATCH(TEMPLATE_FUNC, functionKind, ...) \ + [&]() { \ + switch (functionKind) { \ + case core::TopNRowNumberNode::RankFunction::kRowNumber: { \ + return TEMPLATE_FUNC< \ + core::TopNRowNumberNode::RankFunction::kRowNumber>(__VA_ARGS__); \ + } \ + case core::TopNRowNumberNode::RankFunction::kRank: { \ + return TEMPLATE_FUNC( \ + __VA_ARGS__); \ + } \ + case core::TopNRowNumberNode::RankFunction::kDenseRank: { \ + return TEMPLATE_FUNC< \ + core::TopNRowNumberNode::RankFunction::kDenseRank>(__VA_ARGS__); \ + } \ + default: \ + VELOX_FAIL( \ + "not a rank function kind: {}", \ + core::TopNRowNumberNode::rankFunctionName(functionKind)); \ + } \ + }() + std::vector reorderInputChannels( const RowTypePtr& inputType, const std::vector& partitionKeys, @@ -113,9 +135,11 @@ TopNRowNumber::TopNRowNumber( node->canSpill(driverCtx->queryConfig()) ? driverCtx->makeSpillConfig(operatorId) : std::nullopt), + rankFunction_(node->rankFunction()), limit_{node->limit()}, generateRowNumber_{node->generateRowNumber()}, numPartitionKeys_{node->partitionKeys().size()}, + numSortingKeys_{node->sortingKeys().size()}, inputChannels_{reorderInputChannels( node->inputType(), node->partitionKeys(), @@ -127,13 +151,14 @@ TopNRowNumber::TopNRowNumber( driverCtx->queryConfig().abandonPartialTopNRowNumberMinRows()), abandonPartialMinPct_( driverCtx->queryConfig().abandonPartialTopNRowNumberMinPct()), - data_(std::make_unique( - slice(inputType_->children(), 0, spillCompareFlags_.size()), - slice( - inputType_->children(), - spillCompareFlags_.size(), - inputType_->size()), - pool())), + data_( + std::make_unique( + slice(inputType_->children(), 0, spillCompareFlags_.size()), + slice( + inputType_->children(), + spillCompareFlags_.size(), + inputType_->size()), + pool())), comparator_( inputType_, node->sortingKeys(), @@ -200,19 +225,30 @@ void TopNRowNumber::addInput(RowVectorPtr input) { SelectivityVector rows(numInput); table_->prepareForGroupProbe( *lookup_, input, rows, BaseHashTable::kNoSpillInputStartPartitionBit); - table_->groupProbe(*lookup_, BaseHashTable::kNoSpillInputStartPartitionBit); + try { + table_->groupProbe( + *lookup_, BaseHashTable::kNoSpillInputStartPartitionBit); + } catch (...) { + // If groupProbe throws (e.g., due to OOM), we need to clean up the new + // groups that were inserted but not yet initialized by + // initializeNewPartitions(). Otherwise, close() will crash when trying to + // destroy uninitialized TopRows structures. + cleanupNewPartitions(); + throw; + } // Initialize new partitions. initializeNewPartitions(); - // Process input rows. For each row, lookup the partition. If number of rows - // in that partition is less than limit, add the new row. Otherwise, check - // if row should replace an existing row or be discarded. - for (auto i = 0; i < numInput; ++i) { - auto& partition = partitionAt(lookup_->hits[i]); - processInputRow(i, partition); - } + // Process input rows. For each row, lookup the partition. If the highest + // (top) rank in that partition is less than limit, add the new row. + // Otherwise, check if row should replace an existing row or be discarded. + RANK_FUNCTION_DISPATCH(processInputRowLoop, rankFunction_, numInput); + // It is determined that the TopNRowNumber (as a partial) is not rejecting + // enough input rows to make the duplicate detection worthwhile. Hence, + // abandon the processing at this partial TopN and let the final TopN do + // the processing. if (abandonPartialEarly()) { abandonedPartial_ = true; addRuntimeStat("abandonedPartial", RuntimeCounter(1)); @@ -222,9 +258,7 @@ void TopNRowNumber::addInput(RowVectorPtr input) { outputRows_.resize(outputBatchSize_); } } else { - for (auto i = 0; i < numInput; ++i) { - processInputRow(i, *singlePartition_); - } + RANK_FUNCTION_DISPATCH(processInputRowLoop, rankFunction_, numInput); } } @@ -249,25 +283,167 @@ void TopNRowNumber::initializeNewPartitions() { } } -void TopNRowNumber::processInputRow(vector_size_t index, TopRows& partition) { +void TopNRowNumber::cleanupNewPartitions() { + std::vector newRows(lookup_->newGroups.size()); + for (auto i = 0; i < lookup_->newGroups.size(); ++i) { + newRows[i] = lookup_->hits[lookup_->newGroups[i]]; + } + table_->erase(folly::Range(newRows.data(), newRows.size())); + lookup_->newGroups.clear(); +} + +template <> +char* TopNRowNumber::processRowWithinLimit< + core::TopNRowNumberNode::RankFunction::kRank>( + vector_size_t index, + TopRows& partition) { + // The topRanks queue is not filled yet. + auto& topRows = partition.rows; + if (topRows.empty()) { + partition.topRank = 1; + } else { + // Rank assigns all peer rows the same rank, but the rank increments by + // the number of peers when moving between peers. So when adding a new + // row: + // If row == top rank then top rank is unchanged. + // If row < top rank then top rank += 1. + // If row > top, then rank += number of peers of top rank. + auto* topRow = topRows.top(); + auto result = comparator_.compare(decodedVectors_, index, topRow); + if (result < 0) { + partition.topRank += 1; + } else if (result > 0) { + partition.topRank += partition.numTopRankRows(); + } + } + return data_->newRow(); +} + +template <> +char* TopNRowNumber::processRowWithinLimit< + core::TopNRowNumberNode::RankFunction::kDenseRank>( + vector_size_t index, + TopRows& partition) { + // The topRanks queue is not filled yet. + // dense_rank will add this row to its partition. But the top rank is + // incremented only if the new row is not a peer of any other existing + // row in the partition queue. + if (!partition.isDuplicate(decodedVectors_, index)) { + partition.topRank++; + } + return data_->newRow(); +} + +template <> +char* TopNRowNumber::processRowWithinLimit< + core::TopNRowNumberNode::RankFunction::kRowNumber>( + vector_size_t /*index*/, + TopRows& partition) { + // row_number accumulates the new row in the partition, and the top rank is + // incremented by 1 as row_number increases by 1 at each new row. + ++partition.topRank; + return data_->newRow(); +} + +template <> +char* TopNRowNumber::processRowExceedingLimit< + core::TopNRowNumberNode::RankFunction::kRank>( + vector_size_t index, + TopRows& partition) { auto& topRows = partition.rows; + // The new row < top rank + // For rank, the new row gets assigned its rank as per its position in the + // queue. But the ranks of all subsequent rows increment by 1. + // So we can remove the rows at the top rank as its rank > limit now. + char* topRow = partition.removeTopRankRows(); + char* newRow = data_->initializeRow(topRow, /*reuse=*/true); + if (topRows.empty()) { + partition.topRank = 1; + } else { + // The new top rank value depends on the number of peers of the top ranking + // row. If the current row also has the same value as the new top ranking + // row then it has to be counted as a peer as well. + auto numNewTopRankRows = partition.numTopRankRows(); + topRow = topRows.top(); + if (comparator_.compare(decodedVectors_, index, topRow) == 0) { + partition.topRank = topRows.size() - numNewTopRankRows + 1; + } else { + partition.topRank = topRows.size() - numNewTopRankRows + 2; + } + } + return newRow; +} +template <> +char* TopNRowNumber::processRowExceedingLimit< + core::TopNRowNumberNode::RankFunction::kDenseRank>( + vector_size_t index, + TopRows& partition) { char* newRow = nullptr; - if (topRows.size() < limit_) { + // The new row < top rank + // For dense_rank: + // i) If the row is a peer of an existing row in the queue, then it + // has the same rank as it. The ranks of other rows are unchanged. So its + // only added to the queue. + // ii) If the row is a distinct new value in the queue, then it is assigned + // a rank as per its position, and the ranks of all subsequent rows += 1. + // So the current top rank rows can be removed from the queue as their new + // rank > limit. + if (partition.isDuplicate(decodedVectors_, index)) { newRow = data_->newRow(); } else { + char* topRow = partition.removeTopRankRows(); + newRow = data_->initializeRow(topRow, /*reuse=*/true); + } + return newRow; +} + +template <> +char* TopNRowNumber::processRowExceedingLimit< + core::TopNRowNumberNode::RankFunction::kRowNumber>( + vector_size_t /*index*/, + TopRows& partition) { + // The new row has rank < highest (aka top) rank at 'limit' function value. + // For row_number, such rows are added to the accumulator queue and the + // top rank row is popped out. The topRank remains the same. + auto& topRows = partition.rows; + char* topRow = topRows.top(); + topRows.pop(); + // Reuses the space of the popped row itself for the new row. + return data_->initializeRow(topRow, true /* reuse */); +} + +template +void TopNRowNumber::processInputRow(vector_size_t index, TopRows& partition) { + auto& topRows = partition.rows; + + char* newRow = nullptr; + if (partition.topRank < limit_) { + newRow = processRowWithinLimit(index, partition); + } else { + // The partition has now accumulated >= limit rows. So the new rows can be + // rejected or replace existing rows based on the order_by values. char* topRow = topRows.top(); - if (!comparator_(decodedVectors_, index, topRow)) { - // Drop this input row. + const auto result = comparator_.compare(decodedVectors_, index, topRow); + if (result > 0) { + // The new row is bigger than the top rank so far, so this row is ignored. return; } - // Replace existing row. - topRows.pop(); + // This row has the same value as the top rank row. row_number rejects + // such rows, but are added to the queue for rank and dense_rank. The top + // rank remains unchanged. + else if (result == 0) { + if (rankFunction_ == core::TopNRowNumberNode::RankFunction::kRowNumber) { + return; + } + newRow = data_->newRow(); + } - // Reuse the topRow's memory. - newRow = data_->initializeRow(topRow, true /* reuse */); + else if (result < 0) { + newRow = processRowExceedingLimit(index, partition); + } } for (auto col = 0; col < decodedVectors_.size(); ++col) { @@ -277,6 +453,19 @@ void TopNRowNumber::processInputRow(vector_size_t index, TopRows& partition) { topRows.push(newRow); } +template +void TopNRowNumber::processInputRowLoop(vector_size_t numInput) { + if (table_) { + for (auto i = 0; i < numInput; ++i) { + processInputRow(i, partitionAt(lookup_->hits[i])); + } + } else { + for (auto i = 0; i < numInput; ++i) { + processInputRow(i, *singlePartition_); + } + } +} + void TopNRowNumber::noMoreInput() { Operator::noMoreInput(); @@ -293,7 +482,7 @@ void TopNRowNumber::noMoreInput() { spiller_->finishSpill(spillPartitionSet); VELOX_CHECK_EQ(spillPartitionSet.size(), 1); merge_ = spillPartitionSet.begin()->second->createOrderedReader( - spillConfig_->readBufferSize, pool(), &spillStats_); + *spillConfig_, pool(), spillStats_.get()); } else { outputRows_.resize(outputBatchSize_); } @@ -318,10 +507,46 @@ void TopNRowNumber::updateEstimatedOutputRowSize() { } } +// This function handles a special case when determining the starting +// rank value for the 'rank' function. +// If there are many peer rows for the highest rank, then topRank could +// oscillate between the two cases of topRank < limit and topRank > limit +// as rows are added +// E.g. If the input rows are 0, 0, 0, 5, 0, 0, 6 and we want rank <= 5, then +// at 0, 0, 0, 5 : +// topRows.pq - 0, 0, 0, 5 topRank -> 4 +// 0 is added. +// topRows.pq - 0, 0, 0, 0, 5 topRank -> 5 +// topRank = limit now. +// So when the next 0 is added, the last 5 is popped from TopRows and 0 is added +// topRows.pq - 0, 0, 0, 0, 0, topRank -> 1 +// This makes topRank < 5 and so when 6 comes by, 6 is pushed +// topRows.pq - 0, 0, 0, 0, 0, 6 topRank -> 6 +// So when doing getOutput, we need to adjust this case. +// Since topRank > limit, then the highest rank is popped and the +// topRank is adjusted as length(pq) - number_of_duplicates_of_new_top_row + 1. +vector_size_t TopNRowNumber::fixTopRank(TopRows& partition) { + if (rankFunction_ == core::TopNRowNumberNode::RankFunction::kRank) { + if (partition.topRank > limit_) { + partition.removeTopRankRows(); + auto numNewTopRankRows = partition.numTopRankRows(); + partition.topRank = partition.rows.size() - numNewTopRankRows + 1; + } + } + + return partition.topRank; +} + TopNRowNumber::TopRows* TopNRowNumber::nextPartition() { + auto setNextRankAndPeer = [&](TopRows& partition) { + nextRank_ = fixTopRank(partition); + numPeers_ = 1; + }; + if (!table_) { if (!outputPartitionNumber_) { outputPartitionNumber_ = 0; + setNextRankAndPeer(*singlePartition_); return singlePartition_.get(); } return nullptr; @@ -337,7 +562,6 @@ TopNRowNumber::TopRows* TopNRowNumber::nextPartition() { // No more partitions. return nullptr; } - outputPartitionNumber_ = 0; } else { ++outputPartitionNumber_.value(); @@ -347,24 +571,58 @@ TopNRowNumber::TopRows* TopNRowNumber::nextPartition() { } } - return &partitionAt(partitions_[outputPartitionNumber_.value()]); + auto partition = &partitionAt(partitions_[outputPartitionNumber_.value()]); + setNextRankAndPeer(*partition); + return partition; +} + +template +void TopNRowNumber::computeNextRankInMemory( + const TopRows& partition, + vector_size_t outputIndex) { + if constexpr (TRank == core::TopNRowNumberNode::RankFunction::kRowNumber) { + nextRank_ -= 1; + return; + } + + // This is the logic for rank() and dense_rank(). + // If the next row is a peer of the current one, then the rank remains the + // same, but the number of peers is incremented. + if (comparator_.compare(outputRows_[outputIndex], partition.rows.top()) == + 0) { + numPeers_ += 1; + return; + } + + // The new row is not a peer of the current one. So dense_rank drops the + // rank by 1, but rank drops it by the number of peers (which is then + // reset). + if constexpr (TRank == core::TopNRowNumberNode::RankFunction::kDenseRank) { + nextRank_ -= 1; + } else { + nextRank_ -= numPeers_; + numPeers_ = 1; + } } +template void TopNRowNumber::appendPartitionRows( TopRows& partition, vector_size_t numRows, vector_size_t outputOffset, - FlatVector* rowNumbers) { + FlatVector* rankValues) { // The partition.rows priority queue pops rows in order of reverse - // row numbers. - auto rowNumber = partition.rows.size(); + // ranks. Output rows based on nextRank_ and update it with each row. for (auto i = 0; i < numRows; ++i) { - const auto index = outputOffset + i; - if (rowNumbers) { - rowNumbers->set(index, rowNumber--); + auto index = outputOffset + i; + if (rankValues) { + rankValues->set(index, nextRank_); } outputRows_[index] = partition.rows.top(); partition.rows.pop(); + if (!partition.rows.empty()) { + computeNextRankInMemory(partition, index); + } } } @@ -380,7 +638,7 @@ RowVectorPtr TopNRowNumber::getOutput() { return output; } - // We may have input accumulated in 'data_'. + // There could be older rows accumulated in 'data_'. if (data_->numRows() > 0) { return getOutputFromMemory(); } @@ -389,6 +647,7 @@ RowVectorPtr TopNRowNumber::getOutput() { finished_ = true; } + // There is no data to return at this moment. return nullptr; } @@ -396,9 +655,11 @@ RowVectorPtr TopNRowNumber::getOutput() { return nullptr; } + // All the input data is received, so the operator can start producing + // output. RowVectorPtr output; if (merge_ != nullptr) { - output = getOutputFromSpill(); + output = RANK_FUNCTION_DISPATCH(getOutputFromSpill, rankFunction_); } else { output = getOutputFromMemory(); } @@ -435,20 +696,27 @@ RowVectorPtr TopNRowNumber::getOutputFromMemory() { const auto numOutputRowsLeft = outputBatchSize_ - offset; if (outputPartition_->rows.size() > numOutputRowsLeft) { - // Only a partial partition can be output in this getOutput() call. // Output as many rows as possible. - // NOTE: the partial output partition erases the yielded output rows - // and next getOutput() call starts with the remaining rows. - appendPartitionRows( - *outputPartition_, numOutputRowsLeft, offset, rowNumbers); + RANK_FUNCTION_DISPATCH( + appendPartitionRows, + rankFunction_, + *outputPartition_, + numOutputRowsLeft, + offset, + rowNumbers); offset += numOutputRowsLeft; break; } // Add all partition rows. - auto numPartitionRows = outputPartition_->rows.size(); - appendPartitionRows( - *outputPartition_, numPartitionRows, offset, rowNumbers); + const auto numPartitionRows = outputPartition_->rows.size(); + RANK_FUNCTION_DISPATCH( + appendPartitionRows, + rankFunction_, + *outputPartition_, + numPartitionRows, + offset, + rowNumbers); offset += numPartitionRows; outputPartition_ = nullptr; } @@ -475,13 +743,15 @@ RowVectorPtr TopNRowNumber::getOutputFromMemory() { return output; } -bool TopNRowNumber::isNewPartition( +bool TopNRowNumber::compareSpillRowColumns( const RowVectorPtr& output, vector_size_t index, - SpillMergeStream* next) { + const SpillMergeStream* next, + vector_size_t startColumn, + vector_size_t endColumn) { VELOX_CHECK_GT(index, 0); - for (auto i = 0; i < numPartitionKeys_; ++i) { + for (auto i = startColumn; i < endColumn; ++i) { if (!output->childAt(inputChannels_[i]) ->equalValueAt( next->current().childAt(i).get(), @@ -493,22 +763,79 @@ bool TopNRowNumber::isNewPartition( return false; } -void TopNRowNumber::setupNextOutput( +// Compares the partition keys for new partitions. +bool TopNRowNumber::isNewPartition( const RowVectorPtr& output, - int32_t rowNumber) { - auto* lookAhead = merge_->next(); - if (lookAhead == nullptr) { - nextRowNumber_ = 0; + vector_size_t index, + const SpillMergeStream* next) { + return compareSpillRowColumns(output, index, next, 0, numPartitionKeys_); +} + +// Compares the sorting keys for determining peers. +bool TopNRowNumber::isNewRank( + const RowVectorPtr& output, + vector_size_t index, + const SpillMergeStream* next) { + return compareSpillRowColumns( + output, + index, + next, + numPartitionKeys_, + numPartitionKeys_ + numSortingKeys_); +} + +template +void TopNRowNumber::computeNextRankInSpill( + const RowVectorPtr& output, + vector_size_t index, + const SpillMergeStream* next) { + if (isNewPartition(output, index, next)) { + nextRank_ = 1; + numPeers_ = 1; + return; + } + + if constexpr (TRank == core::TopNRowNumberNode::RankFunction::kRowNumber) { + nextRank_ += 1; + return; + } + + // The function is either rank or dense_rank. + // This row belongs to the same partition as the previous row. However, + // it should be determined if it is a peer row as well. If its a peer, + // then increase numPeers_ but the rank remains unchanged. + if (!isNewRank(output, index, next)) { + numPeers_ += 1; + return; + } + + // The row is not a peer, so increment the rank and peers accordingly. + if constexpr (TRank == core::TopNRowNumberNode::RankFunction::kDenseRank) { + nextRank_ += 1; + numPeers_ = 1; return; } - if (isNewPartition(output, output->size(), lookAhead)) { - nextRowNumber_ = 0; + // Rank function increments by number of peers. + nextRank_ += numPeers_; + numPeers_ = 1; +} + +template +void TopNRowNumber::setupNextOutput(const RowVectorPtr& output) { + auto resetNextRankAndPeer = [this]() { + nextRank_ = 1; + numPeers_ = 1; + }; + + auto* lookAhead = merge_->next(); + if (lookAhead == nullptr) { + resetNextRankAndPeer(); return; } - nextRowNumber_ = rowNumber; - if (nextRowNumber_ < limit_) { + computeNextRankInSpill(output, output->size(), lookAhead); + if (nextRank_ <= limit_) { return; } @@ -516,16 +843,17 @@ void TopNRowNumber::setupNextOutput( lookAhead->pop(); while (auto* next = merge_->next()) { if (isNewPartition(output, output->size(), next)) { - nextRowNumber_ = 0; + resetNextRankAndPeer(); return; } next->pop(); } // This partition is the last partition. - nextRowNumber_ = 0; + resetNextRankAndPeer(); } +template RowVectorPtr TopNRowNumber::getOutputFromSpill() { VELOX_CHECK_NOT_NULL(merge_); @@ -533,37 +861,32 @@ RowVectorPtr TopNRowNumber::getOutputFromSpill() { // All rows from the same partition will appear together. // We'll identify partition boundaries by comparing partition keys of the // current row with the previous row. When new partition starts, we'll reset - // row number to zero. Once row number reaches the 'limit_', we'll start + // nextRank_ and numPeers_. Once rank reaches the 'limit_', we'll start // dropping rows until the next partition starts. // We'll emit output every time we accumulate 'outputBatchSize_' rows. - auto output = BaseVector::create(outputType_, outputBatchSize_, pool()); - FlatVector* rowNumbers = nullptr; + FlatVector* rankValues = nullptr; if (generateRowNumber_) { - rowNumbers = output->children().back()->as>(); + rankValues = output->children().back()->as>(); } // Index of the next row to append to output. vector_size_t index = 0; - - // Row number of the next row in the current partition. - vector_size_t rowNumber = nextRowNumber_; - VELOX_CHECK_LT(rowNumber, limit_); + VELOX_CHECK_LE(nextRank_, limit_); for (;;) { auto next = merge_->next(); if (next == nullptr) { break; } - // Check if this row comes from a new partition. - if (index > 0 && isNewPartition(output, index, next)) { - rowNumber = 0; + if (index > 0) { + computeNextRankInSpill(output, index, next); } // Copy this row to the output buffer if this partition has // < limit_ rows output. - if (rowNumber < limit_) { + if (nextRank_ <= limit_) { for (auto i = 0; i < inputChannels_.size(); ++i) { output->childAt(inputChannels_[i]) ->copy( @@ -572,12 +895,11 @@ RowVectorPtr TopNRowNumber::getOutputFromSpill() { next->currentIndex(), 1); } - if (rowNumbers) { - // Row numbers start with 1. - rowNumbers->set(index, rowNumber + 1); + + if (rankValues) { + rankValues->set(index, nextRank_); } ++index; - ++rowNumber; } // Pop this row from the spill. @@ -588,8 +910,8 @@ RowVectorPtr TopNRowNumber::getOutputFromSpill() { // Prepare the next batch : // i) If 'limit_' is reached for this partition, then skip the rows // until the next partition. - // ii) If the next row is from a new partition, then reset rowNumber_. - setupNextOutput(output, rowNumber); + // ii) If the next row is from a new partition, then reset nextRank_. + setupNextOutput(output); return output; } } @@ -704,14 +1026,15 @@ void TopNRowNumber::ensureInputFits(const RowVectorPtr& input) { if ((tableIncrementBytes == 0) && (freeRows > input->size()) && (outOfLineBytes == 0 || outOfLineFreeBytes >= outOfLineBytesPerRow * input->size())) { - // Enough free rows for input rows and enough variable length free space. + // Enough free rows for input rows and enough variable length free + // space. return; } } - // Check if we can increase reservation. The increment is the largest of twice - // the maximum increment from this input and 'spillableReservationGrowthPct_' - // of the current memory usage. + // Check if we can increase reservation. The increment is the largest of + // twice the maximum increment from this input and + // 'spillableReservationGrowthPct_' of the current memory usage. const auto targetIncrementBytes = std::max( incrementBytes * 2, currentUsage * spillConfig_->spillableReservationGrowthPct / 100); @@ -750,6 +1073,68 @@ void TopNRowNumber::setupSpiller() { inputType_, sortingKeys, &spillConfig_.value(), - &spillStats_); + spillStats_.get()); +} + +// Using the underlying vector of the priority queue for the algorithms to +// check duplicates and count the number of top rank rows. This makes the +// algorithms O(n). There could be other approaches to make the +// algorithms O(1), but would trade memory efficiency. +namespace { +template +S& PriorityQueueVector(std::priority_queue& q) { + struct PrivateQueue : private std::priority_queue { + static S& Container(std::priority_queue& q) { + return q.*&PrivateQueue::c; + } + }; + return PrivateQueue::Container(q); +} +} // namespace + +char* TopNRowNumber::TopRows::removeTopRankRows() { + VELOX_CHECK(!rows.empty()); + + char* topRow = rows.top(); + rows.pop(); + + while (!rows.empty()) { + char* newTopRow = rows.top(); + if (rowComparator.compare(topRow, newTopRow) != 0) { + return topRow; + } + rows.pop(); + } + return topRow; +} + +vector_size_t TopNRowNumber::TopRows::numTopRankRows() { + VELOX_CHECK(!rows.empty()); + char* topRow = rows.top(); + vector_size_t numRows = 0; + const std::vector> partitionRowsVector = + PriorityQueueVector(rows); + for (const char* row : partitionRowsVector) { + if (rowComparator.compare(topRow, row) == 0) { + numRows += 1; + } else { + break; + } + } + return numRows; +} + +bool TopNRowNumber::TopRows::isDuplicate( + const std::vector& decodedVectors, + vector_size_t index) { + const std::vector> partitionRowsVector = + PriorityQueueVector(rows); + for (const char* row : partitionRowsVector) { + if (rowComparator.compare(decodedVectors, index, row) == 0) { + return true; + } + } + return false; } + } // namespace facebook::velox::exec diff --git a/velox/exec/TopNRowNumber.h b/velox/exec/TopNRowNumber.h index a89f967f9a04..649869439671 100644 --- a/velox/exec/TopNRowNumber.h +++ b/velox/exec/TopNRowNumber.h @@ -22,20 +22,58 @@ namespace facebook::velox::exec { class TopNRowNumberSpiller; -/// Partitions the input using specified partitioning keys, sorts rows within -/// partitions using specified sorting keys, assigns row numbers and returns up -/// to specified number of rows per partition. +/// TopNRowNumber is an optimized version of a Window operator with a +/// single row_number or rank or dense_rank window function followed by a +/// rank <= N filter. N must be >= 0. If the TopNRowNumber has no partition +/// keys, then all the rows belong to a single partition. However, the +/// TopNRowNumber should have at least one sorting key specified. /// -/// It is allowed to not specify partitioning keys. In this case the whole input -/// is treated as a single partition. +/// TopNRowNumber is more efficient than a general Window operator as it does +/// not store all rows of a partition. Instead, it only keeps the top N +/// rows of the partition at any point. /// -/// At least one sorting key must be specified. +/// The operator partitions the input using specified partitioning keys, +/// and maintains a TopRows structure per partition in a HashTable. The TopRows +/// maintains a priority queue of row pointers. The priority queue is +/// kept ordered by sorting keys of the TopNRowNumber. The TopRows only retains +/// rows whose ranks satisfy the filter condition (so rank <= N). N is also +/// called the limit of the operator. To aid this filtering, the TopRows tracks +/// the greatest rank seen for each partition. /// -/// The limit (maximum number of rows to return per partition) must be greater -/// than zero. +/// The operator processes all input rows before beginning to output rows. /// -/// This is an optimized version of a Window operator with a single row_number -/// window function followed by a row_number <= N filter. +/// For each input row, it retrieves the TopRows corresponding to the partition +/// keys. The TopRows is first filled until it has N rows. Thereafter, new rows +/// are compared with the top row in the TopRows priority queue. +/// If the new rows order by values are less than (for ASC) or greater than +/// (for DESC) so row rank <= topRank, then the row is added to TopRows. +/// For each outcome, the greatest rank of the TopRows is updated as per the +/// ranking function logic. +/// For each function type, the rank maintenance logic is in: +/// - processRowWithinLimit() function when the TopRows is filling the first +/// N rows. +/// - processRowExceedingLimit() function when the TopRows already has N rows. +/// +/// After processing all the input rows, the operator proceeds to output the +/// rows. The rows might all be in memory or spilled to disk if memory +/// reclamation was triggered during processing. +/// +/// If the rows are in memory, then the operator iterates over each partition +/// in the HashTable, and starts outputting rows from the partition. The +/// TopRows structure maintains the rows in descending order of their ranks +/// (greatest rank at the top of the priority queue). So when outputting, +/// the operator first fixes the top rank of the partition using fixTopRank() +/// and then computes the ranks of each row using computeNextRankInMemory(). +/// The logic of the next rank differs based on the ranking function. +/// +/// If the rows are in the spill, then the spiller iterates over each spilled +/// partition in order of the ranks. For each row from the spill, the next +/// rank is computed using computeNextRankInSpill() function. The logic of +/// the next rank differs based on the ranking function. +/// Note : The spill could have > limit rows for a partition as each spill +/// resets the TopRows for the partition. So stop outputting rows after +/// reaching the limit for each partition. + class TopNRowNumber : public Operator { public: TopNRowNumber( @@ -71,7 +109,23 @@ class TopNRowNumber : public Operator { override; private: - // A priority queue to keep track of top 'limit' rows for a given partition. + // This structure holds the top rows for a partition. It uses a priority + // queue to maintain the top rows in order of their ranks. Note the rank + // logic depends on the respective function (row_number, rank or dense_rank). + // However, a common requirement across all three is to maintain the rows in + // order of their sort keys so that the greatest rank row is always at the top + // of the queue. This ordering is done using the RowComparator passed to the + // TopRows. + // + // The number of rows in TopRows are limited to 'limit' specified for the + // operator. The greatest rank of the rows in TopRows is maintained in the + // 'topRank' variable. + // + // The TopRows structure is first filled in order to collect 'limit' + // rows. Thereafter, new rows are compared with the top row and either kept + // or discarded and the new top rank is updated. The rank computation differs + // based on the ranking function. This structure has methods for abstractions + // used for the top rank maintenance algorithms. struct TopRows { struct Compare { RowComparator& comparator; @@ -84,12 +138,39 @@ class TopNRowNumber : public Operator { std::priority_queue>, Compare> rows; + RowComparator& rowComparator; + + // This is the greatest rank seen so far in the input rows. Note: rank is + // the result of the respective function computation (row_number, rank or + // dense_rank). It is compared with the expected limit for the operator. + int64_t topRank = 0; + + // Number of rows with the highest rank in the partition. + vector_size_t numTopRankRows(); + + // Remove all rows with the highest rank in the partition. + // Returns a pointer to the last removed row. + char* removeTopRankRows(); + + // Returns true if the row at position index in decodedVectors + // has the same order by keys as another row in the TopRows + // priority_vector. + bool isDuplicate( + const std::vector& decodedVectors, + vector_size_t index); + TopRows(HashStringAllocator* allocator, RowComparator& comparator) - : rows{{comparator}, StlAllocator(allocator)} {} + : rows{{comparator}, StlAllocator(allocator)}, + rowComparator(comparator) {} }; void initializeNewPartitions(); + // Cleans up any newly inserted but uninitialized partitions from the hash + // table. This is called when groupProbe throws (e.g., due to OOM) to ensure + // close() doesn't crash trying to destroy uninitialized TopRows structures. + void cleanupNewPartitions(); + TopRows& partitionAt(char* group) { return *reinterpret_cast(group + partitionOffset_); } @@ -97,17 +178,46 @@ class TopNRowNumber : public Operator { // Decodes and potentially loads input if lazy vector. void prepareInput(RowVectorPtr& input); + // Handles input row when the partition has not yet accumulated 'limit' rows. + // Returns a pointer to the row to add to the partition accumulator. + template + char* processRowWithinLimit(vector_size_t index, TopRows& partition); + + // Handles input row when the partition has already accumulated 'limit' rows. + // Returns a pointer to the row to add to the partition accumulator. + template + char* processRowExceedingLimit(vector_size_t index, TopRows& partition); + + // Loop to process the numInput input rows received by the operator. + template + void processInputRowLoop(vector_size_t numInput); + // Adds input row to a partition or discards the row. + template void processInputRow(vector_size_t index, TopRows& partition); // Returns next partition to add to output or nullptr if there are no // partitions left. TopRows* nextPartition(); - // Appends numRows of the output partition the output. Note: The rows are - // popped in reverse order of the row_number. + // If there are many rows with the highest rank, then the topRank + // of the partition can oscillate between a very small value and a + // value > limit. Fix the partition for this condition before starting to + // output the partition. + vector_size_t fixTopRank(TopRows& partition); + + // Computes the rank for the next row to be output + // (all output rows in memory). + template + void computeNextRankInMemory( + const TopRows& partition, + vector_size_t rowIndex); + + // Appends numRows of the current partition to the output. Note: The rows are + // popped in reverse order of the rank. // NOTE: This function erases the yielded output rows from the partition // and the next call starts with the remaining rows. + template void appendPartitionRows( TopRows& partition, vector_size_t numRows, @@ -125,6 +235,7 @@ class TopNRowNumber : public Operator { void setupSpiller(); + template RowVectorPtr getOutputFromSpill(); RowVectorPtr getOutputFromMemory(); @@ -134,17 +245,42 @@ class TopNRowNumber : public Operator { bool isNewPartition( const RowVectorPtr& output, vector_size_t index, - SpillMergeStream* next); + const SpillMergeStream* next); + + // Returns true if 'next' row is a new rank (rows differ on order by keys) + // of the previous row in the partition (at output[index] of the + // output block). + bool isNewRank( + const RowVectorPtr& output, + vector_size_t index, + const SpillMergeStream* next); + + // Utility method to compare values from startColumn to endColumn for + // 'next' row from SpillMergeStream with current row of output (at index). + bool compareSpillRowColumns( + const RowVectorPtr& output, + vector_size_t index, + const SpillMergeStream* next, + vector_size_t startColumn, + vector_size_t endColumn); + + // Computes next rank value for spill output. + template + inline void computeNextRankInSpill( + const RowVectorPtr& output, + vector_size_t index, + const SpillMergeStream* next); - // Sets nextRowNumber_ to rowNumber. Checks if next row in 'merge_' belongs to - // a different partition than last row in 'output' and if so updates - // nextRowNumber_ to 0. Also, checks current partition reached the limit on - // number of rows and if so advances 'merge_' to the first row on the next - // partition and sets nextRowNumber_ to 0. + // Checks if next row in 'merge_' belongs to a different partition than last + // row in 'output' and if so updates nextRank_ and numPeers_ to 1. + // Also, checks current partition reached the limit on rank and + // if so advances 'merge_' to the first row on the next + // partition and sets nextRank_ and numPeers_ to 0. // // @post 'merge_->next()' is either at end or points to a row that should be - // included in the next output batch using 'nextRowNumber_'. - void setupNextOutput(const RowVectorPtr& output, int32_t rowNumber); + // included in the next output batch using 'nextRank_'. + template + void setupNextOutput(const RowVectorPtr& output); // Called in noMoreInput() and spill(). void updateEstimatedOutputRowSize(); @@ -153,11 +289,15 @@ class TopNRowNumber : public Operator { // cardinality sufficiently. Returns false if spilling was triggered earlier. bool abandonPartialEarly() const; + // Rank function semantics of operator. + const core::TopNRowNumberNode::RankFunction rankFunction_; + const int32_t limit_; const bool generateRowNumber_; const size_t numPartitionKeys_; + const size_t numSortingKeys_; // Input columns in the order of: partition keys, sorting keys, the rest. const std::vector inputChannels_; @@ -244,7 +384,11 @@ class TopNRowNumber : public Operator { // Used to sort-merge spilled data. std::unique_ptr> merge_; - // Row number for the first row in the next output batch from the spiller. - int32_t nextRowNumber_{0}; + // Row number/rank or dense_rank for the first row in the next output batch + // from the spiller. + vector_size_t nextRank_{1}; + // Number of peers of first row in the previous output batch. This is used + // in rank calculation. + vector_size_t numPeers_{1}; }; } // namespace facebook::velox::exec diff --git a/velox/exec/TraceUtil.cpp b/velox/exec/TraceUtil.cpp index 10388724e1bc..22b1c7d04c46 100644 --- a/velox/exec/TraceUtil.cpp +++ b/velox/exec/TraceUtil.cpp @@ -18,10 +18,12 @@ #include -#include +#include + #include "velox/common/base/Exceptions.h" #include "velox/common/file/File.h" #include "velox/common/file/FileSystems.h" +#include "velox/exec/TableWriter.h" #include "velox/exec/Trace.h" namespace facebook::velox::exec::trace { @@ -35,6 +37,11 @@ std::string findLastPathNode(const std::string& path) { VELOX_CHECK(!pathNodes.empty(), "No valid path nodes found from {}", path); return pathNodes.back(); } + +std::unordered_map& traceNodeRegistry() { + static std::unordered_map registry; + return registry; +} } // namespace void createTraceDirectory( @@ -66,7 +73,13 @@ void createTraceDirectory( std::string getQueryTraceDirectory( const std::string& traceDir, const std::string& queryId) { - return fmt::format("{}/{}", traceDir, queryId); + // Remove trailing slash from traceDir if present + std::string normalizedTraceDir = traceDir; + if (!normalizedTraceDir.empty() && normalizedTraceDir.back() == '/') { + normalizedTraceDir.pop_back(); + } + + return fmt::format("{}/{}", normalizedTraceDir, queryId); } std::string getTaskTraceDirectory( @@ -80,8 +93,9 @@ std::string getTaskTraceDirectory( const std::string& traceDir, const std::string& queryId, const std::string& taskId) { - return fmt::format( - "{}/{}", getQueryTraceDirectory(traceDir, queryId), taskId); + auto queryTraceDir = getQueryTraceDirectory(traceDir, queryId); + + return fmt::format("{}/{}", queryTraceDir, taskId); } std::string getTaskTraceMetaFilePath(const std::string& taskTraceDir) { @@ -166,10 +180,8 @@ RowTypePtr getDataType( const core::PlanNodePtr& tracedPlan, const std::string& tracedNodeId, size_t sourceIndex) { - const auto* traceNode = core::PlanNode::findFirstNode( - tracedPlan.get(), [&tracedNodeId](const core::PlanNode* node) { - return node->id() == tracedNodeId; - }); + const auto* traceNode = + core::PlanNode::findNodeById(tracedPlan.get(), tracedNodeId); VELOX_CHECK_NOT_NULL( traceNode, "traced node id {} not found in the traced plan", @@ -232,13 +244,256 @@ std::vector extractDriverIds(const std::string& driverIds) { bool canTrace(const std::string& operatorType) { static const std::unordered_set kSupportedOperatorTypes{ "Aggregation", + "CallbackSink", + "Exchange", "FilterProject", "HashBuild", "HashProbe", + "IndexLookupJoin", + "MergeExchange", + "MergeJoin", + "OrderBy", "PartialAggregation", "PartitionedOutput", "TableScan", - "TableWrite"}; - return kSupportedOperatorTypes.count(operatorType) > 0; + "TableWrite", + "TopNRowNumber", + "Unnest"}; + if (kSupportedOperatorTypes.count(operatorType) > 0 || + traceNodeRegistry().count(operatorType) > 0) { + return true; + } + return false; +} + +core::PlanNodePtr getTraceNode( + const core::PlanNodePtr& plan, + core::PlanNodeId nodeId) { + const auto* traceNode = core::PlanNode::findNodeById(plan.get(), nodeId); + VELOX_CHECK_NOT_NULL(traceNode, "Failed to find node with id {}", nodeId); + if (const auto* hashJoinNode = + dynamic_cast(traceNode)) { + return std::make_shared( + nodeId, + hashJoinNode->joinType(), + hashJoinNode->isNullAware(), + hashJoinNode->leftKeys(), + hashJoinNode->rightKeys(), + hashJoinNode->filter(), + std::make_shared( + hashJoinNode->sources()[0]->outputType()), + std::make_shared( + hashJoinNode->sources()[1]->outputType()), + hashJoinNode->outputType()); + } + + if (const auto* mergeJoinNode = + dynamic_cast(traceNode)) { + return std::make_shared( + nodeId, + mergeJoinNode->joinType(), + mergeJoinNode->leftKeys(), + mergeJoinNode->rightKeys(), + mergeJoinNode->filter(), + std::make_shared( + mergeJoinNode->sources()[0]->outputType()), + std::make_shared( + mergeJoinNode->sources()[1]->outputType()), + mergeJoinNode->outputType()); + } + + if (const auto* filterNode = + dynamic_cast(traceNode)) { + // Single FilterNode. + return std::make_shared( + nodeId, + filterNode->filter(), + std::make_shared( + filterNode->sources().front()->outputType())); + } + + if (const auto* projectNode = + dynamic_cast(traceNode)) { + // A standalone ProjectNode. + if (projectNode->sources().empty() || + projectNode->sources().front()->name() != "Filter") { + return std::make_shared( + nodeId, + projectNode->names(), + projectNode->projections(), + std::make_shared( + projectNode->sources().front()->outputType())); + } + + // -- ProjectNode [nodeId] + // -- FilterNode [nodeId - 1] + const auto originalFilterNode = + std::dynamic_pointer_cast( + projectNode->sources().front()); + VELOX_CHECK_NOT_NULL(originalFilterNode); + + auto filterNode = std::make_shared( + originalFilterNode->id(), + originalFilterNode->filter(), + std::make_shared( + originalFilterNode->sources().front()->outputType())); + return std::make_shared( + nodeId, + projectNode->names(), + projectNode->projections(), + std::move(filterNode)); + } + + if (const auto* aggregationNode = + dynamic_cast(traceNode)) { + return std::make_shared( + nodeId, + aggregationNode->step(), + aggregationNode->groupingKeys(), + aggregationNode->preGroupedKeys(), + aggregationNode->aggregateNames(), + aggregationNode->aggregates(), + aggregationNode->globalGroupingSets(), + aggregationNode->groupId(), + aggregationNode->ignoreNullKeys(), + aggregationNode->noGroupsSpanBatches(), + std::make_shared( + aggregationNode->sources().front()->outputType())); + } + + if (const auto* partitionedOutputNode = + dynamic_cast(traceNode)) { + return std::make_shared( + nodeId, + partitionedOutputNode->kind(), + partitionedOutputNode->keys(), + partitionedOutputNode->numPartitions(), + partitionedOutputNode->isReplicateNullsAndAny(), + partitionedOutputNode->partitionFunctionSpecPtr(), + partitionedOutputNode->outputType(), + VectorSerde::Kind::kPresto, + std::make_shared( + partitionedOutputNode->sources().front()->outputType())); + } + + if (const auto* indexLookupJoinNode = + dynamic_cast(traceNode)) { + return std::make_shared( + nodeId, + indexLookupJoinNode->joinType(), + indexLookupJoinNode->leftKeys(), + indexLookupJoinNode->rightKeys(), + indexLookupJoinNode->joinConditions(), + indexLookupJoinNode->filter(), + indexLookupJoinNode->hasMarker(), + std::make_shared( + indexLookupJoinNode->sources().front()->outputType()), // Probe side + indexLookupJoinNode->lookupSource(), // Index side + indexLookupJoinNode->outputType()); + } + + if (const auto* tableScanNode = + dynamic_cast(traceNode)) { + return std::make_shared( + nodeId, + tableScanNode->outputType(), + tableScanNode->tableHandle(), + tableScanNode->assignments()); + } + + if (const auto* tableWriteNode = + dynamic_cast(traceNode)) { + return std::make_shared( + nodeId, + tableWriteNode->columns(), + tableWriteNode->columnNames(), + tableWriteNode->columnStatsSpec(), + tableWriteNode->insertTableHandle(), + tableWriteNode->hasPartitioningScheme(), + TableWriteTraits::outputType(tableWriteNode->columnStatsSpec()), + tableWriteNode->commitStrategy(), + std::make_shared( + tableWriteNode->sources().front()->outputType())); + } + + if (const auto* unnestNode = + dynamic_cast(traceNode)) { + return std::make_shared( + nodeId, + unnestNode->replicateVariables(), + unnestNode->unnestVariables(), + unnestNode->unnestNames(), + unnestNode->ordinalityName(), + unnestNode->markerName(), + std::make_shared( + unnestNode->sources().front()->outputType())); + } + + if (const auto* orderByNode = + dynamic_cast(traceNode)) { + return std::make_shared( + nodeId, + orderByNode->sortingKeys(), + orderByNode->sortingOrders(), + orderByNode->isPartial(), + std::make_shared( + orderByNode->sources().front()->outputType())); + } + + if (const auto* topNRowNumberNode = + dynamic_cast(traceNode)) { + const auto generateRowNumber = topNRowNumberNode->generateRowNumber(); + return std::make_shared( + nodeId, + topNRowNumberNode->rankFunction(), + topNRowNumberNode->partitionKeys(), + topNRowNumberNode->sortingKeys(), + topNRowNumberNode->sortingOrders(), + generateRowNumber ? std::make_optional( + topNRowNumberNode->outputType()->names().back()) + : std::nullopt, + topNRowNumberNode->limit(), + std::make_shared( + topNRowNumberNode->sources().front()->outputType())); + } + + if (const auto* exchangeNode = + dynamic_cast(traceNode)) { + // Check if it's a MergeExchangeNode + if (const auto* mergeExchangeNode = + dynamic_cast(traceNode)) { + return std::make_shared( + nodeId, + mergeExchangeNode->outputType(), + mergeExchangeNode->sortingKeys(), + mergeExchangeNode->sortingOrders(), + mergeExchangeNode->serdeKind()); + } + // Regular ExchangeNode + return std::make_shared( + nodeId, exchangeNode->outputType(), exchangeNode->serdeKind()); + } + + for (const auto& factory : traceNodeRegistry()) { + if (auto node = factory.second(traceNode, nodeId)) { + return node; + } + } + + VELOX_UNSUPPORTED( + fmt::format("Unsupported trace node: {}", traceNode->name())); +} + +void registerTraceNodeFactory( + const std::string& operatorType, + TraceNodeFactory&& factory) { + auto& registry = traceNodeRegistry(); + VELOX_CHECK_EQ(registry.count(operatorType), 0); + registry.emplace(operatorType, std::move(factory)); +} + +void registerDummySourceSerDe() { + auto& registry = DeserializationWithContextRegistryForSharedPtr(); + registry.Register("DummySource", DummySourceNode::create); } } // namespace facebook::velox::exec::trace diff --git a/velox/exec/TraceUtil.h b/velox/exec/TraceUtil.h index 2a4d6d870d1b..96837ff8b171 100644 --- a/velox/exec/TraceUtil.h +++ b/velox/exec/TraceUtil.h @@ -27,6 +27,45 @@ namespace facebook::velox::exec::trace { +static const std::vector kEmptySources; + +class DummySourceNode final : public core::PlanNode { + public: + explicit DummySourceNode(RowTypePtr outputType) + : PlanNode(""), outputType_(std::move(outputType)) {} + + const RowTypePtr& outputType() const override { + return outputType_; + } + + const std::vector& sources() const override { + return kEmptySources; + } + + std::string_view name() const override { + return "DummySource"; + } + + folly::dynamic serialize() const override { + folly::dynamic obj = folly::dynamic::object; + obj["name"] = "DummySource"; + obj["outputType"] = outputType_->serialize(); + return obj; + } + + static core::PlanNodePtr create(const folly::dynamic& obj, void* context) { + return std::make_shared( + ISerializable::deserialize(obj["outputType"])); + } + + private: + void addDetails(std::stringstream& stream) const override { + // Nothing to add. + } + + const RowTypePtr outputType_; +}; + /// Creates a directory to store the query trace metdata and data. void createTraceDirectory( const std::string& traceDir, @@ -135,4 +174,19 @@ folly::dynamic getTaskMetadata( /// Checks whether the operator can be traced. bool canTrace(const std::string& operatorType); + +/// Gets the specified the trace node from 'plan'. In the returned trace node, +/// we replace its source nodes with DummySourceNode for replay. +core::PlanNodePtr getTraceNode( + const core::PlanNodePtr& plan, + core::PlanNodeId nodeId); + +using TraceNodeFactory = std::function< + core::PlanNodePtr(const core::PlanNode*, const core::PlanNodeId&)>; + +void registerTraceNodeFactory( + const std::string& operatorType, + TraceNodeFactory&& factory); + +void registerDummySourceSerDe(); } // namespace facebook::velox::exec::trace diff --git a/velox/exec/Unnest.cpp b/velox/exec/Unnest.cpp index 3891de18550f..2155ef616ac0 100644 --- a/velox/exec/Unnest.cpp +++ b/velox/exec/Unnest.cpp @@ -19,6 +19,18 @@ #include "velox/vector/FlatVector.h" namespace facebook::velox::exec { +namespace { +#ifndef NDEBUG +void debugCheckOutput(const RowVectorPtr& output) { + for (auto i = 0; i < output->childrenSize(); ++i) { + VELOX_CHECK_EQ(output->size(), output->childAt(i)->size()); + } +} +#else +void debugCheckOutput(const RowVectorPtr& output) {} +#endif +} // namespace + Unnest::Unnest( int32_t operatorId, DriverCtx* driverCtx, @@ -29,22 +41,35 @@ Unnest::Unnest( operatorId, unnestNode->id(), "Unnest"), - withOrdinality_(unnestNode->withOrdinality()), - maxOutputSize_(outputBatchRows()) { + withOrdinality_(unnestNode->hasOrdinality()), + withMarker_(unnestNode->hasMarker()), + maxOutputSize_( + driverCtx->queryConfig().unnestSplitOutput() + ? outputBatchRows() + : std::numeric_limits::max()) { const auto& inputType = unnestNode->sources()[0]->outputType(); const auto& unnestVariables = unnestNode->unnestVariables(); for (const auto& variable : unnestVariables) { if (!variable->type()->isArray() && !variable->type()->isMap()) { - VELOX_UNSUPPORTED("Unnest operator supports only ARRAY and MAP types"); + VELOX_UNSUPPORTED( + "Unnest operator supports only ARRAY and MAP types, the actual type is {}", + variable->type()->toString()); } unnestChannels_.push_back(inputType->getChildIdx(variable->name())); } - unnestDecoded_.resize(unnestVariables.size()); + column_index_t checkOutputChannel = outputType_->size() - 1; + if (withMarker_) { + VELOX_CHECK_EQ( + outputType_->childAt(checkOutputChannel), + BOOLEAN(), + "Marker column should be BOOLEAN type."); + --checkOutputChannel; + } if (withOrdinality_) { VELOX_CHECK_EQ( - outputType_->children().back(), + outputType_->childAt(checkOutputChannel), BIGINT(), "Ordinality column should be BIGINT type."); } @@ -83,11 +108,13 @@ void Unnest::addInput(RowVectorPtr input) { if (unnestVector->typeKind() == TypeKind::ARRAY) { const auto* unnestBaseArray = currentDecoded.base()->as(); + VELOX_CHECK_NOT_NULL(unnestBaseArray); rawSizes_[channel] = unnestBaseArray->rawSizes(); rawOffsets_[channel] = unnestBaseArray->rawOffsets(); } else { - VELOX_CHECK(unnestVector->typeKind() == TypeKind::MAP); + VELOX_CHECK_EQ(unnestVector->typeKind(), TypeKind::MAP); const auto* unnestBaseMap = currentDecoded.base()->as(); + VELOX_CHECK_NOT_NULL(unnestBaseMap); rawSizes_[channel] = unnestBaseMap->rawSizes(); rawOffsets_[channel] = unnestBaseMap->rawOffsets(); } @@ -98,9 +125,7 @@ void Unnest::addInput(RowVectorPtr input) { for (auto row = 0; row < size; ++row) { if (!currentDecoded.isNullAt(row)) { const auto unnestSize = currentSizes[currentIndices[row]]; - if (rawMaxSizes_[row] < unnestSize) { - rawMaxSizes_[row] = unnestSize; - } + rawMaxSizes_[row] = std::max(rawMaxSizes_[row], unnestSize); } } } @@ -118,70 +143,89 @@ RowVectorPtr Unnest::getOutput() { return nullptr; } - const auto size = input_->size(); - VELOX_DCHECK_LT(nextInputRow_, size); + const auto numInputRows = input_->size(); + VELOX_DCHECK_LT(nextInputRow_, numInputRows); // Limit the number of input rows to keep output batch size within // 'maxOutputSize_'. When the output size is 'maxOutputSize_', the // first and last row might not be processed completely, and their output // might be split into multiple batches. - auto rowRange = extractRowRange(size); - if (rowRange.numElements == 0) { - // All arrays/maps are null or empty. - input_ = nullptr; - nextInputRow_ = 0; + const auto rowRange = extractRowRange(numInputRows); + if (rowRange.numInnerRows == 0) { + finishInput(); maybeFinishDrain(); return nullptr; } const auto output = generateOutput(rowRange); VELOX_CHECK_NOT_NULL(output); - if (rowRange.lastRowEnd.has_value()) { + if (rowRange.lastInnerRowEnd.has_value()) { // The last row is not processed completely. - firstRowStart_ = rowRange.lastRowEnd.value(); - nextInputRow_ += rowRange.size - 1; + firstInnerRowStart_ = rowRange.lastInnerRowEnd.value(); + nextInputRow_ += rowRange.numInputRows - 1; } else { - firstRowStart_ = 0; - nextInputRow_ += rowRange.size; + firstInnerRowStart_ = 0; + nextInputRow_ += rowRange.numInputRows; } - if (nextInputRow_ >= size) { - input_ = nullptr; - nextInputRow_ = 0; + if (nextInputRow_ >= numInputRows) { + finishInput(); } + debugCheckOutput(output); return output; } -Unnest::RowRange Unnest::extractRowRange(vector_size_t size) const { - vector_size_t numInput = 0; - vector_size_t numElements = 0; - std::optional lastRowEnd; - for (auto row = nextInputRow_; row < size; ++row) { - const bool isFirstRow = (row == nextInputRow_); - const vector_size_t remainingSize = - isFirstRow ? rawMaxSizes_[row] - firstRowStart_ : rawMaxSizes_[row]; - ++numInput; - if (numElements + remainingSize > maxOutputSize_) { +void Unnest::finishInput() { + input_ = nullptr; + nextInputRow_ = 0; + firstInnerRowStart_ = 0; +} + +Unnest::RowRange Unnest::extractRowRange(vector_size_t inputSize) const { + vector_size_t numInputRows{0}; + vector_size_t numInnerRows{0}; + std::optional lastInnerRowEnd; + bool hasEmptyUnnestValue{false}; + for (auto inputRow = nextInputRow_; inputRow < inputSize; ++inputRow) { + const bool isFirstRow = (inputRow == nextInputRow_); + vector_size_t remainingInnerRows = isFirstRow + ? rawMaxSizes_[inputRow] - firstInnerRowStart_ + : rawMaxSizes_[inputRow]; + if (rawMaxSizes_[inputRow] == 0) { + VELOX_CHECK_EQ(remainingInnerRows, 0); + hasEmptyUnnestValue = true; + if (withMarker_) { + remainingInnerRows = 1; + } + } + ++numInputRows; + if (numInnerRows + remainingInnerRows > maxOutputSize_) { // A single row's output needs to be split into multiple batches. // Determines the range to process the first and last rows partially, // rather than processing from 0 to 'rawMaxSizes_[row]'. if (isFirstRow) { - lastRowEnd = firstRowStart_ + maxOutputSize_ - numElements; + lastInnerRowEnd = firstInnerRowStart_ + maxOutputSize_ - numInnerRows; } else { - lastRowEnd = maxOutputSize_ - numElements; + lastInnerRowEnd = maxOutputSize_ - numInnerRows; } // Process maxOutputSize_ in this getOutput. - numElements = maxOutputSize_; + numInnerRows = maxOutputSize_; break; } // Process this row completely. - numElements += remainingSize; - if (numElements == maxOutputSize_) { + numInnerRows += remainingInnerRows; + if (numInnerRows == maxOutputSize_) { break; } } - VELOX_DCHECK_LE(numElements, maxOutputSize_); - return {nextInputRow_, numInput, lastRowEnd, numElements}; + VELOX_DCHECK_GE(numInnerRows, 0); + VELOX_DCHECK_LE(numInnerRows, maxOutputSize_); + return { + nextInputRow_, + numInputRows, + lastInnerRowEnd, + numInnerRows, + hasEmptyUnnestValue}; }; void Unnest::generateRepeatedColumns( @@ -189,26 +233,48 @@ void Unnest::generateRepeatedColumns( std::vector& outputs) { // Create "indices" buffer to repeat rows as many times as there are elements // in the array (or map) in unnestDecoded. - auto repeatedIndices = allocateIndices(range.numElements, pool()); - auto* rawRepeatedIndices = repeatedIndices->asMutable(); - vector_size_t index = 0; - VELOX_DCHECK_GT(range.size, 0); + auto repeatedIndices = allocateIndices(range.numInnerRows, pool()); + vector_size_t* rawRepeatedIndices = + repeatedIndices->asMutable(); + + const bool generateMarker = withMarker_ && range.hasEmptyUnnestValue; + vector_size_t index{0}; + VELOX_CHECK_GT(range.numInputRows, 0); // Record the row number to process. - range.forEachRow( - [&](vector_size_t row, vector_size_t /*start*/, vector_size_t size) { - std::fill( - rawRepeatedIndices + index, rawRepeatedIndices + index + size, row); - index += size; - }, - rawMaxSizes_, - firstRowStart_); + if (generateMarker) { + range.forEachRow( + [&](vector_size_t row, vector_size_t /*start*/, vector_size_t size) { + if (FOLLY_UNLIKELY(size == 0)) { + rawRepeatedIndices[index++] = row; + } else { + std::fill( + rawRepeatedIndices + index, + rawRepeatedIndices + index + size, + row); + index += size; + } + }, + rawMaxSizes_, + firstInnerRowStart_); + } else { + range.forEachRow( + [&](vector_size_t row, vector_size_t /*start*/, vector_size_t size) { + std::fill( + rawRepeatedIndices + index, + rawRepeatedIndices + index + size, + row); + index += size; + }, + rawMaxSizes_, + firstInnerRowStart_); + } // Wrap "replicated" columns in a dictionary using 'repeatedIndices'. for (const auto& projection : identityProjections_) { outputs.at(projection.outputChannel) = BaseVector::wrapInDictionary( - nullptr /*nulls*/, + /*nulls=*/nullptr, repeatedIndices, - range.numElements, + range.numInnerRows, input_->childAt(projection.inputChannel)); } } @@ -216,11 +282,11 @@ void Unnest::generateRepeatedColumns( const Unnest::UnnestChannelEncoding Unnest::generateEncodingForChannel( column_index_t channel, const RowRange& range) { - BufferPtr elementIndices = allocateIndices(range.numElements, pool()); - auto* rawElementIndices = elementIndices->asMutable(); + BufferPtr innerRowIndices = allocateIndices(range.numInnerRows, pool()); + auto* rawInnerRowIndices = innerRowIndices->asMutable(); - auto nulls = allocateNulls(range.numElements, pool()); - auto rawNulls = nulls->asMutable(); + auto nulls = allocateNulls(range.numInnerRows, pool()); + auto* rawNulls = nulls->asMutable(); auto& currentDecoded = unnestDecoded_[channel]; auto* currentSizes = rawSizes_[channel]; @@ -230,12 +296,15 @@ const Unnest::UnnestChannelEncoding Unnest::generateEncodingForChannel( // Make dictionary index for elements column since they may be out of order. vector_size_t index = 0; bool identityMapping = true; - VELOX_DCHECK_GT(range.size, 0); + VELOX_DCHECK_GT(range.numInputRows, 0); range.forEachRow( [&](vector_size_t row, vector_size_t start, vector_size_t size) { const auto end = start + size; - if (!currentDecoded.isNullAt(row)) { + if (size == 0 && withMarker_) { + identityMapping = false; + bits::setNull(rawNulls, index++, true); + } else if (!currentDecoded.isNullAt(row)) { const auto offset = currentOffsets[currentIndices[row]]; const auto unnestSize = currentSizes[currentIndices[row]]; // The 'identityMapping' is false when there exists a partially @@ -244,47 +313,100 @@ const Unnest::UnnestChannelEncoding Unnest::generateEncodingForChannel( unnestSize < end) { identityMapping = false; } - auto currentUnnestSize = std::min(end, unnestSize); - for (auto i = start; i < currentUnnestSize; i++) { - rawElementIndices[index++] = offset + i; + const auto currentUnnestSize = std::min(end, unnestSize); + for (auto i = start; i < currentUnnestSize; ++i) { + rawInnerRowIndices[index++] = offset + i; } - for (auto i = std::max(start, currentUnnestSize); i < end; ++i) { bits::setNull(rawNulls, index++, true); } } else if (size > 0) { identityMapping = false; - for (auto i = start; i < end; ++i) { bits::setNull(rawNulls, index++, true); } } }, rawMaxSizes_, - firstRowStart_); + firstInnerRowStart_); - return {elementIndices, nulls, identityMapping}; + return {innerRowIndices, nulls, identityMapping}; } VectorPtr Unnest::generateOrdinalityVector(const RowRange& range) { + VELOX_DCHECK_GT(range.numInputRows, 0); + auto ordinalityVector = BaseVector::create>( - BIGINT(), range.numElements, pool()); + BIGINT(), range.numInnerRows, pool()); // Set the ordinality at each result row to be the index of the element in // the original array (or map) plus one. auto* rawOrdinality = ordinalityVector->mutableRawValues(); + const bool hasMarker = withMarker_ && range.hasEmptyUnnestValue; + if (!hasMarker) { + range.forEachRow( + [&](vector_size_t /*row*/, vector_size_t start, vector_size_t size) { + std::iota(rawOrdinality, rawOrdinality + size, start + 1); + rawOrdinality += size; + }, + rawMaxSizes_, + firstInnerRowStart_); + } else { + range.forEachRow( + [&](vector_size_t /*row*/, vector_size_t start, vector_size_t size) { + if (FOLLY_LIKELY(size > 0)) { + std::iota(rawOrdinality, rawOrdinality + size, start + 1); + rawOrdinality += size; + } else { + // Set ordinality to 0 for output row with empty unnest value. + // + // NOTE: for non-empty unnest value row, the ordinality starts + // from 1. + VELOX_DCHECK_EQ(size, 0); + *rawOrdinality++ = 0; + } + }, + rawMaxSizes_, + firstInnerRowStart_); + } + return ordinalityVector; +} - VELOX_DCHECK_GT(range.size, 0); +VectorPtr Unnest::generateMarkerVector(const RowRange& range) { + VELOX_CHECK(withMarker_); + VELOX_DCHECK_GT(range.numInputRows, 0); + if (!range.hasEmptyUnnestValue) { + return BaseVector::createConstant( + BOOLEAN(), true, range.numInnerRows, pool()); + } + + // Create a vector with all elements set to true initially assuming most + // output rows have non-empty unnest values. + auto markerBuffer = + velox::AlignedBuffer::allocate(range.numInnerRows, pool(), true); + auto markerVector = std::make_shared>( + pool(), + /*type=*/BOOLEAN(), + /*nulls=*/nullptr, + range.numInnerRows, + /*values=*/std::move(markerBuffer), + /*stringBuffers=*/std::vector{}); + // Set each output row with empty unnest values to false. + auto* const rawMarker = markerVector->mutableRawValues(); + size_t index{0}; range.forEachRow( [&](vector_size_t /*row*/, vector_size_t start, vector_size_t size) { - std::iota(rawOrdinality, rawOrdinality + size, start + 1); - rawOrdinality += size; + if (size > 0) { + index += size; + } else { + VELOX_DCHECK_EQ(size, 0); + bits::setBit(rawMarker, index++, false); + } }, rawMaxSizes_, - firstRowStart_); - - return ordinalityVector; + firstInnerRowStart_); + return markerVector; } RowVectorPtr Unnest::generateOutput(const RowRange& range) { @@ -292,39 +414,42 @@ RowVectorPtr Unnest::generateOutput(const RowRange& range) { generateRepeatedColumns(range, outputs); // Create unnest columns. - vector_size_t outputsIndex = identityProjections_.size(); + column_index_t outputColumnIndex = identityProjections_.size(); for (auto channel = 0; channel < unnestChannels_.size(); ++channel) { const auto unnestChannelEncoding = generateEncodingForChannel(channel, range); - auto& currentDecoded = unnestDecoded_[channel]; + const auto& currentDecoded = unnestDecoded_[channel]; if (currentDecoded.base()->typeKind() == TypeKind::ARRAY) { // Construct unnest column using Array elements wrapped using above // created dictionary. const auto* unnestBaseArray = currentDecoded.base()->as(); - outputs[outputsIndex++] = unnestChannelEncoding.wrap( - unnestBaseArray->elements(), range.numElements); + outputs[outputColumnIndex++] = unnestChannelEncoding.wrap( + unnestBaseArray->elements(), range.numInnerRows); } else { // Construct two unnest columns for Map keys and values vectors wrapped // using above created dictionary. const auto* unnestBaseMap = currentDecoded.base()->as(); - outputs[outputsIndex++] = unnestChannelEncoding.wrap( - unnestBaseMap->mapKeys(), range.numElements); - outputs[outputsIndex++] = unnestChannelEncoding.wrap( - unnestBaseMap->mapValues(), range.numElements); + outputs[outputColumnIndex++] = unnestChannelEncoding.wrap( + unnestBaseMap->mapKeys(), range.numInnerRows); + outputs[outputColumnIndex++] = unnestChannelEncoding.wrap( + unnestBaseMap->mapValues(), range.numInnerRows); } } + // 'Ordinality' and 'EmptyUnnestValue' columns are always at the end. if (withOrdinality_) { - // Ordinality column is always at the end. - outputs.back() = generateOrdinalityVector(range); + outputs[outputColumnIndex++] = generateOrdinalityVector(range); + } + if (withMarker_) { + outputs[outputColumnIndex++] = generateMarkerVector(range); } return std::make_shared( pool(), outputType_, - BufferPtr(nullptr), - range.numElements, + /*nulls=*/nullptr, + range.numInnerRows, std::move(outputs)); } @@ -332,7 +457,11 @@ VectorPtr Unnest::UnnestChannelEncoding::wrap( const VectorPtr& base, vector_size_t wrapSize) const { if (identityMapping) { - return base; + if (wrapSize == base->size()) { + return base; + } + auto* rawIndices = indices->asMutable(); + return base->slice(rawIndices[0], wrapSize); } const auto result = @@ -360,29 +489,31 @@ bool Unnest::isFinished() { } void Unnest::RowRange::forEachRow( - std::function func, - const vector_size_t* const rawMaxSizes, - vector_size_t firstRowStart) const { + vector_size_t /*size*/)>& func, + const vector_size_t* rawMaxSizes, + vector_size_t firstInnerRowStart) const { // Process the first row. - const auto firstRowEnd = size == 1 && lastRowEnd.has_value() - ? lastRowEnd.value() - : rawMaxSizes[start]; - func(start, firstRowStart, firstRowEnd - firstRowStart); + const auto firstInnerRowEnd = numInputRows == 1 && lastInnerRowEnd.has_value() + ? lastInnerRowEnd.value() + : rawMaxSizes[startInputRow]; + func( + startInputRow, firstInnerRowStart, firstInnerRowEnd - firstInnerRowStart); + const auto lastInputRow = startInputRow + numInputRows - 1; // Process the middle rows. - for (auto row = start + 1; row < start + size - 1; ++row) { - func(row, 0, rawMaxSizes[row]); + for (auto inputRow = startInputRow + 1; inputRow < lastInputRow; ++inputRow) { + func(inputRow, 0, rawMaxSizes[inputRow]); } // Process the last row if exists. - if (size > 1) { - if (lastRowEnd.has_value()) { - func(start + size - 1, 0, lastRowEnd.value()); + if (numInputRows > 1) { + if (lastInnerRowEnd.has_value()) { + func(lastInputRow, 0, lastInnerRowEnd.value()); } else { - func(start + size - 1, 0, rawMaxSizes[start + size - 1]); + func(lastInputRow, 0, rawMaxSizes[lastInputRow]); } } } diff --git a/velox/exec/Unnest.h b/velox/exec/Unnest.h index e67dd6ed875f..e1fb967525c3 100644 --- a/velox/exec/Unnest.h +++ b/velox/exec/Unnest.h @@ -46,21 +46,20 @@ class Unnest : public Operator { void maybeFinishDrain(); // Represents the range of rows to process and indicates that the first and - // last - // rows may need to be processed partially to match the configured output + // last rows may need to be processed partially to match the configured output // batch size. When processing a single row, the range is from - // 'firstRowStart_' to 'lastRowEnd'. For multiple rows, the range for the - // first row is from 'firstRowStart_' to 'rawMaxSizes_[firstRow]', and for the - // last row, it is from 0 to 'lastRowEnd', unless the last row is processed - // fully, in which case' rawMaxSizes_[lastRow]' is used as the end of the last - // row. + // 'firstInnerRowStart_' to 'lastRowEnd'. For multiple rows, the range for the + // first row is from 'firstInnerRowStart_' to 'rawMaxSizes_[firstRow]', and + // for the last row, it is from 0 to 'lastRowEnd', unless the last row is + // processed fully, in which case' rawMaxSizes_[lastRow]' is used as the end + // of the last row. // // Single row: - // firstRowStart_ firstRowEnd = lastRowEnd + // firstInnerRowStart_ firstRowEnd = lastRowEnd //---|----------------|--- start, size = 1 // // Multiple rows: - // firstRowStart_ firstRowEnd = rawMaxSizes_[start] + // firstInnerRowStart_ firstRowEnd = rawMaxSizes_[start] //---|-------------------| start //----------------------- //----------------------- @@ -73,34 +72,40 @@ class Unnest : public Operator { // number in the '[start, start + size)' range, 'start' is the row number to // start processing, and 'size' is the number of rows to process.. // @param rawMaxSizes Used to compute the end of each row. - // @param firstRowStart The index to start processing the first row. Same - // with Unnest member firstRowStart_. + // @param firstInnerRowStart The index to start processing the first row. + // Same with Unnest member firstInnerRowStart_. void forEachRow( - std::function func, - const vector_size_t* const rawMaxSizes, - vector_size_t firstRowStart) const; + vector_size_t /*size*/)>& func, + const vector_size_t* rawMaxSizes, + vector_size_t firstInnerRowStart) const; - // First input row to be included in the output. - const vector_size_t start; + // First input row in 'input_' to be included in the output. + const vector_size_t startInputRow; - // Number of input rows to be included in the output. - const vector_size_t size; + // Number of input rows to be included in the output starting from + // 'startInputRow'. + const vector_size_t numInputRows; - // The processing of the last input row starts at index 'firstRowStart_' or - // 0, depending on whether it is the first row being processed, and ends at - // 'lastRowEnd'. It is nullopt when the last row is processed completely. - const std::optional lastRowEnd; + // The processing of the last input row starts at index + // 'firstInnerRowStart_' or 0, depending on whether it is the first row + // being processed, and ends at 'lastRowEnd'. It is nullopt when the last + // row is processed completely. + const std::optional lastInnerRowEnd; // Total number of inner rows in the range. - const vector_size_t numElements; + const vector_size_t numInnerRows; + + // True if the range has input row in which all the unnest columns are + // either null or empty. + const bool hasEmptyUnnestValue; }; // Extract the range of rows to process. // @param size The size of input RowVector. - RowRange extractRowRange(vector_size_t size) const; + RowRange extractRowRange(vector_size_t inputSize) const; // Generate output for 'rowRange' represented rows. // @param rowRange Range of rows to process. @@ -129,18 +134,27 @@ class Unnest : public Operator { // Invoked by generateOutput for the ordinality column. VectorPtr generateOrdinalityVector(const RowRange& rowRange); + // Invoked by generateOutput for the marker column. + VectorPtr generateMarkerVector(const RowRange& rowRange); + + // Invoked when finish one input batch processing to reset the internal + // execution state for the next batch. + void finishInput(); + const bool withOrdinality_; + const bool withMarker_; + // The maximum number of output batch rows. + const vector_size_t maxOutputSize_; + std::vector unnestChannels_; std::vector unnestDecoded_; - // The maximum number of output batch rows. - const uint32_t maxOutputSize_; BufferPtr maxSizes_; vector_size_t* rawMaxSizes_{nullptr}; // The index to start processing the first row. - vector_size_t firstRowStart_ = 0; + vector_size_t firstInnerRowStart_{0}; std::vector rawSizes_; std::vector rawOffsets_; diff --git a/velox/exec/Values.cpp b/velox/exec/Values.cpp index 2660732979ae..dc544fba799e 100644 --- a/velox/exec/Values.cpp +++ b/velox/exec/Values.cpp @@ -45,8 +45,9 @@ void Values::initialize() { // If this is parallelizable, copy the values to prevent Vectors from // being shared across threads. Note that the contract in ValuesNode is // that this should only be enabled for testing. - values_.emplace_back(std::static_pointer_cast( - vector->testingCopyPreserveEncodings())); + values_.emplace_back( + std::static_pointer_cast( + vector->testingCopyPreserveEncodings())); } else { values_.emplace_back(vector); } diff --git a/velox/exec/VectorHasher.cpp b/velox/exec/VectorHasher.cpp index dc4c921fb035..18d00146ed72 100644 --- a/velox/exec/VectorHasher.cpp +++ b/velox/exec/VectorHasher.cpp @@ -23,32 +23,35 @@ namespace facebook::velox::exec { -#define VALUE_ID_TYPE_DISPATCH(TEMPLATE_FUNC, typeKind, ...) \ - [&]() { \ - switch (typeKind) { \ - case TypeKind::BOOLEAN: { \ - return TEMPLATE_FUNC(__VA_ARGS__); \ - } \ - case TypeKind::TINYINT: { \ - return TEMPLATE_FUNC(__VA_ARGS__); \ - } \ - case TypeKind::SMALLINT: { \ - return TEMPLATE_FUNC(__VA_ARGS__); \ - } \ - case TypeKind::INTEGER: { \ - return TEMPLATE_FUNC(__VA_ARGS__); \ - } \ - case TypeKind::BIGINT: { \ - return TEMPLATE_FUNC(__VA_ARGS__); \ - } \ - case TypeKind::VARCHAR: \ - case TypeKind::VARBINARY: { \ - return TEMPLATE_FUNC(__VA_ARGS__); \ - } \ - default: \ - VELOX_UNREACHABLE( \ - "Unsupported value ID type: ", mapTypeKindToName(typeKind)); \ - } \ +#define VALUE_ID_TYPE_DISPATCH(TEMPLATE_FUNC, typeKind, ...) \ + [&]() { \ + switch (typeKind) { \ + case TypeKind::BOOLEAN: { \ + return TEMPLATE_FUNC(__VA_ARGS__); \ + } \ + case TypeKind::TINYINT: { \ + return TEMPLATE_FUNC(__VA_ARGS__); \ + } \ + case TypeKind::SMALLINT: { \ + return TEMPLATE_FUNC(__VA_ARGS__); \ + } \ + case TypeKind::INTEGER: { \ + return TEMPLATE_FUNC(__VA_ARGS__); \ + } \ + case TypeKind::BIGINT: { \ + return TEMPLATE_FUNC(__VA_ARGS__); \ + } \ + case TypeKind::VARCHAR: \ + case TypeKind::VARBINARY: { \ + return TEMPLATE_FUNC(__VA_ARGS__); \ + } \ + case TypeKind::TIMESTAMP: { \ + return TEMPLATE_FUNC(__VA_ARGS__); \ + } \ + default: \ + VELOX_UNREACHABLE( \ + "Unsupported value ID type: ", TypeKindName::toName(typeKind)); \ + } \ }() namespace { @@ -82,7 +85,6 @@ void VectorHasher::hashValues( const SelectivityVector& rows, bool mix, uint64_t* result) { - using T = typename TypeTraits::NativeType; if (decoded_.isConstantMapping()) { auto hash = decoded_.isNullAt(rows.begin()) ? kNullHash @@ -349,6 +351,8 @@ bool VectorHasher::makeValueIdsDecoded( bool VectorHasher::computeValueIds( const SelectivityVector& rows, raw_vector& result) { + checkTypeSupportsValueIds(); + return VALUE_ID_TYPE_DISPATCH(makeValueIds, typeKind_, rows, result.data()); } @@ -359,6 +363,8 @@ bool VectorHasher::computeValueIdsForRows( int32_t nullByte, uint8_t nullMask, raw_vector& result) { + checkTypeSupportsValueIds(); + return VALUE_ID_TYPE_DISPATCH( makeValueIdsForRows, typeKind_, @@ -531,6 +537,8 @@ void VectorHasher::lookupValueIds( SelectivityVector& rows, ScratchMemory& scratchMemory, raw_vector& result) const { + checkTypeSupportsValueIds(); + scratchMemory.decoded.decode(values, rows); VALUE_ID_TYPE_DISPATCH( lookupValueIdsTyped, @@ -550,7 +558,7 @@ void VectorHasher::hash( result[row] = mix ? bits::hashMix(result[row], kNullHash) : kNullHash; }); } else { - if (type_->providesCustomComparison()) { + if (typeProvidesCustomComparison_) { VELOX_DYNAMIC_TEMPLATE_TYPE_DISPATCH( hashValues, true, typeKind_, rows, mix, result.data()); } else { @@ -579,7 +587,7 @@ void VectorHasher::precompute(const BaseVector& value) { const SelectivityVector rows(1, true); decoded_.decode(value, rows); - if (type_->providesCustomComparison()) { + if (typeProvidesCustomComparison_) { precomputedHash_ = VELOX_DYNAMIC_TEMPLATE_TYPE_DISPATCH( hashOne, true, typeKind_, decoded_, 0); } else { @@ -594,6 +602,8 @@ void VectorHasher::analyze( int32_t offset, int32_t nullByte, uint8_t nullMask) { + checkTypeSupportsValueIds(); + VALUE_ID_TYPE_DISPATCH( analyzeTyped, typeKind_, groups, numGroups, offset, nullByte, nullMask); } @@ -666,6 +676,10 @@ void VectorHasher::setRangeOverflow() { std::unique_ptr VectorHasher::getFilter( bool nullAllowed) const { + if (typeProvidesCustomComparison_) { + return nullptr; + } + switch (typeKind_) { case TypeKind::TINYINT: [[fallthrough]]; @@ -736,6 +750,7 @@ void extendRange( case TypeKind::BIGINT: case TypeKind::VARCHAR: case TypeKind::VARBINARY: + case TypeKind::TIMESTAMP: extendRange(reserve, min, max); break; @@ -764,6 +779,12 @@ void VectorHasher::cardinality( int32_t reservePct, uint64_t& asRange, uint64_t& asDistincts) { + if (!typeSupportsValueIds()) { + asRange = kRangeTooLarge; + asDistincts = kRangeTooLarge; + return; + } + if (typeKind_ == TypeKind::BOOLEAN) { hasRange_ = true; asRange = 3; @@ -805,6 +826,8 @@ uint64_t VectorHasher::enableValueIds(uint64_t multiplier, int32_t reservePct) { typeKind_, TypeKind::BOOLEAN, "A boolean VectorHasher should always be by range"); + checkTypeSupportsValueIds(); + multiplier_ = multiplier; rangeSize_ = addIdReserve(uniqueValues_.size(), reservePct) + 1; isRange_ = false; @@ -818,6 +841,8 @@ uint64_t VectorHasher::enableValueIds(uint64_t multiplier, int32_t reservePct) { uint64_t VectorHasher::enableValueRange( uint64_t multiplier, int32_t reservePct) { + checkTypeSupportsValueIds(); + multiplier_ = multiplier; VELOX_CHECK_LE(0, reservePct); VELOX_CHECK(hasRange_); @@ -846,7 +871,7 @@ void VectorHasher::copyStatsFrom(const VectorHasher& other) { uniqueValues_ = other.uniqueValues_; } -void VectorHasher::merge(const VectorHasher& other) { +void VectorHasher::merge(const VectorHasher& other, size_t maxNumDistinct) { if (typeKind_ == TypeKind::BOOLEAN) { return; } @@ -864,18 +889,25 @@ void VectorHasher::merge(const VectorHasher& other) { } else { setRangeOverflow(); } - if (!distinctOverflow_ && !other.distinctOverflow_) { - // Unique values can be merged without dispatch on type. All the - // merged hashers must stay live for string type columns. - for (UniqueValue value : other.uniqueValues_) { - // Assign a new id at end of range for the case 'value' is not - // in 'uniqueValues_'. We do not set overflow here because the - // memory is already allocated and there is a known cap on size. - value.setId(uniqueValues_.size() + 1); - uniqueValues_.insert(value); - } - } else { + if (distinctOverflow_) { + return; + } + if (other.distinctOverflow_) { setDistinctOverflow(); + return; + } + // Unique values can be merged without dispatch on type. All the + // merged hashers must stay live for string type columns. + for (UniqueValue value : other.uniqueValues_) { + // Assign a new id at end of range for the case 'value' is not + // in 'uniqueValues_'. We do not set overflow here because the + // memory is already allocated and there is a known cap on size. + value.setId(uniqueValues_.size() + 1); + if (uniqueValues_.insert(value).second && + uniqueValues_.size() > maxNumDistinct) { + setDistinctOverflow(); + break; + } } } diff --git a/velox/exec/VectorHasher.h b/velox/exec/VectorHasher.h index 6a18df611735..29628099c17e 100644 --- a/velox/exec/VectorHasher.h +++ b/velox/exec/VectorHasher.h @@ -131,8 +131,15 @@ class VectorHasher { static constexpr int32_t kNoLimit = -1; VectorHasher(TypePtr type, column_index_t channel) - : channel_(channel), type_(std::move(type)), typeKind_(type_->kind()) { - if (typeKind_ == TypeKind::BOOLEAN) { + : channel_(channel), + type_(std::move(type)), + typeKind_(type_->kind()), + typeProvidesCustomComparison_(type_->providesCustomComparison()) { + if (!typeSupportsValueIds()) { + // Ensure any range or unique value based hashing is disabled. + setRangeOverflow(); + setDistinctOverflow(); + } else if (typeKind_ == TypeKind::BOOLEAN) { // We do not need samples to know the cardinality or limits of a bool // vector. hasRange_ = true; @@ -235,15 +242,39 @@ class VectorHasher { ScratchMemory& scratchMemory, raw_vector& result) const; - // Returns true if either range or distinct values have not overflowed. + // Returns true if either range or distinct values have not overflowed and the + // type doesn't support custom comparison. bool mayUseValueIds() const { - return hasRange_ || !distinctOverflow_; + return typeSupportsValueIds() && (hasRange_ || !distinctOverflow_); } // Returns an instance of the filter corresponding to a set of unique values. // Returns null if distinctOverflow_ is true. std::unique_ptr getFilter(bool nullAllowed) const; + bool supportsBloomFilter() const { + if (typeProvidesCustomComparison_) { + return false; + } + switch (typeKind_) { + // Smaller integers would never overflow 100'000 distinct values. + case TypeKind::INTEGER: + case TypeKind::BIGINT: + return distinctOverflow_; + default: + return false; + } + } + + void setBloomFilter(common::FilterPtr filter) { + VELOX_DCHECK(supportsBloomFilter()); + bloomFilter_ = std::move(filter); + } + + const common::FilterPtr& getBloomFilter() const { + return bloomFilter_; + } + void resetStats() { uniqueValues_.clear(); uniqueValuesStorage_.clear(); @@ -284,8 +315,12 @@ class VectorHasher { return isRange_; } - static bool typeKindSupportsValueIds(TypeKind kind) { - switch (kind) { + bool typeSupportsValueIds() const { + if (typeProvidesCustomComparison_) { + return false; + } + + switch (typeKind_) { case TypeKind::BOOLEAN: case TypeKind::TINYINT: case TypeKind::SMALLINT: @@ -293,6 +328,7 @@ class VectorHasher { case TypeKind::BIGINT: case TypeKind::VARCHAR: case TypeKind::VARBINARY: + case TypeKind::TIMESTAMP: return true; default: return false; @@ -301,7 +337,7 @@ class VectorHasher { // Merges the value ids information of 'other' into 'this'. Ranges // and distinct values are unioned. - void merge(const VectorHasher& other); + void merge(const VectorHasher& other, size_t maxNumDistinct); // true if no values have been added. bool empty() const { @@ -332,6 +368,10 @@ class VectorHasher { return value; } + inline int64_t toInt64(Timestamp timestamp) const { + return timestamp.toMillis(); + } + // Sets the data statistics from 'other'. Does not set the mapping mode. void copyStatsFrom(const VectorHasher& other); @@ -530,6 +570,13 @@ class VectorHasher { void setRangeOverflow(); + inline void checkTypeSupportsValueIds() const { + VELOX_DCHECK( + typeSupportsValueIds(), + "Value IDs cannot be used, the type {} is not supported.", + type_->toString()); + } + static inline bool isNullAt(const char* group, int32_t nullByte, uint8_t nullMask) { return (group[nullByte] & nullMask) != 0; @@ -546,6 +593,7 @@ class VectorHasher { const column_index_t channel_; const TypePtr type_; const TypeKind typeKind_; + const bool typeProvidesCustomComparison_; DecodedVector decoded_; raw_vector cachedHashes_; @@ -582,6 +630,8 @@ class VectorHasher { // Memory for unique string values. std::vector uniqueValuesStorage_; uint64_t distinctStringsBytes_ = 0; + + common::FilterPtr bloomFilter_; }; template <> @@ -661,10 +711,29 @@ inline uint64_t VectorHasher::lookupValueId(StringView value) const { return kUnmappable; } +template <> +inline uint64_t VectorHasher::lookupValueId(Timestamp timestamp) const { + return timestamp.getNanos() % 1'000'000 != 0 + ? kUnmappable + : lookupValueId(timestamp.toMillis()); +} + template <> inline uint64_t VectorHasher::valueId(bool value) { return value ? 2 : 1; } +template <> +inline uint64_t VectorHasher::valueId(Timestamp value) { + if (FOLLY_UNLIKELY( + value.getNanos() % Timestamp::kNanosecondsInMillisecond != 0)) { + // The timestamp is in nanosecond or microsecond precision. The values are + // not mappable to milliseconds without precision loss. + setRangeOverflow(); + setDistinctOverflow(); + return kUnmappable; + } + return valueId(value.toMillis()); +} template <> inline bool VectorHasher::tryMapToRange( diff --git a/velox/exec/Window.cpp b/velox/exec/Window.cpp index 1f1a91331a95..358290b4e92c 100644 --- a/velox/exec/Window.cpp +++ b/velox/exec/Window.cpp @@ -18,10 +18,22 @@ #include "velox/exec/PartitionStreamingWindowBuild.h" #include "velox/exec/RowsStreamingWindowBuild.h" #include "velox/exec/SortWindowBuild.h" +#include "velox/exec/SubPartitionedSortWindowBuild.h" #include "velox/exec/Task.h" namespace facebook::velox::exec { +namespace { +common::PrefixSortConfig makePrefixSortConfig( + const core::QueryConfig& queryConfig) { + return common::PrefixSortConfig{ + queryConfig.prefixSortNormalizedKeyMaxBytes(), + queryConfig.prefixSortMinRows(), + queryConfig.prefixSortMaxStringPrefixLength()}; +} + +} // namespace + Window::Window( int32_t operatorId, DriverCtx* driverCtx, @@ -55,16 +67,28 @@ Window::Window( windowNode, pool(), spillConfig, &nonReclaimableSection_); } } else { - windowBuild_ = std::make_unique( - windowNode, - pool(), - common::PrefixSortConfig{ - driverCtx->queryConfig().prefixSortNormalizedKeyMaxBytes(), - driverCtx->queryConfig().prefixSortMinRows(), - driverCtx->queryConfig().prefixSortMaxStringPrefixLength()}, - spillConfig, - &nonReclaimableSection_, - &spillStats_); + if (auto numSubPartitions = + operatorCtx_->driverCtx()->queryConfig().windowNumSubPartitions(); + numSubPartitions > 1) { + windowBuild_ = std::make_unique( + windowNode, + numSubPartitions, + pool(), + makePrefixSortConfig(driverCtx->queryConfig()), + spillConfig, + &nonReclaimableSection_, + &stats_, + spillStats_.get()); + } else { + windowBuild_ = std::make_unique( + windowNode, + pool(), + makePrefixSortConfig(driverCtx->queryConfig()), + spillConfig, + &nonReclaimableSection_, + &stats_, + spillStats_.get()); + } } } @@ -159,10 +183,11 @@ Window::WindowFrame Window::createWindowFrame( return std::make_optional( FrameChannelArg{kConstantChannel, nullptr, value}); } else { - return std::make_optional(FrameChannelArg{ - frameChannel, - BaseVector::create(frame->type(), 0, pool()), - std::nullopt}); + return std::make_optional( + FrameChannelArg{ + frameChannel, + BaseVector::create(frame->type(), 0, pool()), + std::nullopt}); } }; @@ -196,14 +221,15 @@ void Window::createWindowFunctions() { } } - windowFunctions_.push_back(WindowFunction::create( - windowNodeFunction.functionCall->name(), - functionArgs, - windowNodeFunction.functionCall->type(), - windowNodeFunction.ignoreNulls, - operatorCtx_->pool(), - &stringAllocator_, - operatorCtx_->driverCtx()->queryConfig())); + windowFunctions_.push_back( + WindowFunction::create( + windowNodeFunction.functionCall->name(), + functionArgs, + windowNodeFunction.functionCall->type(), + windowNodeFunction.ignoreNulls, + operatorCtx_->pool(), + &stringAllocator_, + operatorCtx_->driverCtx()->queryConfig())); windowFrames_.push_back( createWindowFrame(windowNode_, windowNodeFunction.frame, inputType)); @@ -686,6 +712,9 @@ RowVectorPtr Window::getOutput() { const auto numRowsLeft = numRows_ - numProcessedRows_; if (numRowsLeft == 0) { + if (windowBuild_ != nullptr) { + windowBuild_->release(); + } return nullptr; } diff --git a/velox/exec/Window.h b/velox/exec/Window.h index 8cbd3b1ee85c..563169ff4226 100644 --- a/velox/exec/Window.h +++ b/velox/exec/Window.h @@ -67,6 +67,11 @@ class Window : public Operator { void reclaim(uint64_t targetBytes, memory::MemoryReclaimer::Stats& stats) override; + /// Runtime statistics holding total number of batches read from spilled data. + /// 0 if no spilling occurred. + static inline const std::string kWindowSpillReadNumBatches{ + "windowSpillReadNumBatches"}; + private: // Used for k preceding/following frames. Index is the column index if k is a // column. value is used to read column values from the column index when k diff --git a/velox/exec/WindowBuild.h b/velox/exec/WindowBuild.h index 01c470803ed7..b047254c752e 100644 --- a/velox/exec/WindowBuild.h +++ b/velox/exec/WindowBuild.h @@ -70,10 +70,18 @@ class WindowBuild { /// Returns the average size of input rows in bytes stored in the data /// container of the WindowBuild. - std::optional estimateRowSize() { + virtual std::optional estimateRowSize() { return data_->estimateRowSize(); } + /// Releases the memory held by the window build. This is called by the + /// window operator when all rows have been processed. + void release() { + if (data_) { + data_->clear(); + } + } + void setNumRowsPerOutput(vector_size_t numRowsPerOutput) { numRowsPerOutput_ = numRowsPerOutput; } diff --git a/velox/exec/benchmarks/AtomicsBench.cpp b/velox/exec/benchmarks/AtomicsBench.cpp index 74343e9303f2..c86991ef6fcb 100644 --- a/velox/exec/benchmarks/AtomicsBench.cpp +++ b/velox/exec/benchmarks/AtomicsBench.cpp @@ -16,16 +16,39 @@ #include #include -#include #include #include #include +#include "velox/common/base/Portability.h" #include "velox/exec/OneWayStatusFlag.h" -using namespace ::testing; -using namespace facebook::velox; -static const size_t kNumThreads = 88; -static const size_t kNumIterations = 10000; +namespace { + +using facebook::velox::exec::OneWayStatusFlag; +constexpr size_t kNumThreads = 88; +constexpr size_t kNumIterations = 10000; + +#if defined(__x86_64__) && !defined(TSAN_BUILD) + +class OneWayStatusFlagUnsafe { + public: + bool check() const { + return fastStatus_ || atomicStatus_.load(); + } + + void set() { + if (!fastStatus_) { + atomicStatus_.store(true); + fastStatus_ = true; + } + } + + private: + bool fastStatus_{false}; + std::atomic_bool atomicStatus_{false}; +}; + +#endif void runParallelUpdates( std::function callback, @@ -46,12 +69,28 @@ void runParallelUpdates( } } -BENCHMARK(std_atomic_bool_write) { - std::atomic flag{false}; +BENCHMARK(std_atomic_bool_write_seq_cst) { + std::atomic_bool flag{false}; runParallelUpdates( [&](size_t iters) { for (size_t i = 0; i < iters; ++i) { flag.store(true); + bool dummy{}; + folly::doNotOptimizeAway(dummy); + } + }, + kNumThreads, // Threads + kNumIterations); // Iterations per thread +} + +BENCHMARK(std_atomic_bool_write_release) { + std::atomic_bool flag{false}; + runParallelUpdates( + [&](size_t iters) { + for (size_t i = 0; i < iters; ++i) { + flag.store(true, std::memory_order_release); + bool dummy{}; + folly::doNotOptimizeAway(dummy); } }, kNumThreads, // Threads @@ -59,76 +98,82 @@ BENCHMARK(std_atomic_bool_write) { } BENCHMARK(std_atomic_bool_write_relaxed) { - std::atomic flag{false}; + std::atomic_bool flag{false}; runParallelUpdates( [&](size_t iters) { for (size_t i = 0; i < iters; ++i) { flag.store(true, std::memory_order_relaxed); + bool dummy{}; + folly::doNotOptimizeAway(dummy); } }, kNumThreads, // Threads kNumIterations); // Iterations per thread } -BENCHMARK(std_atomic_bool_read_write_relaxed) { - std::atomic flag{false}; +BENCHMARK(one_way_flag_write) { + OneWayStatusFlag flag; runParallelUpdates( [&](size_t iters) { for (size_t i = 0; i < iters; ++i) { - if (!flag.load(std::memory_order_relaxed)) { - flag.store(true, std::memory_order_acq_rel); - } + flag.set(); + bool dummy{}; + folly::doNotOptimizeAway(dummy); } }, kNumThreads, // Threads kNumIterations); // Iterations per thread } -BENCHMARK(one_way_flag_write) { - exec::OneWayStatusFlag flag; +#if defined(__x86_64__) && !defined(TSAN_BUILD) + +BENCHMARK(one_way_flag_unsafe_write) { + OneWayStatusFlagUnsafe flag; runParallelUpdates( [&](size_t iters) { for (size_t i = 0; i < iters; ++i) { flag.set(); + bool dummy{}; + folly::doNotOptimizeAway(dummy); } }, kNumThreads, // Threads kNumIterations); // Iterations per thread } +#endif + // Read Benchmarks -BENCHMARK(std_atomic_bool_read) { - std::atomic flag{false}; +BENCHMARK(std_atomic_bool_read_seq_cst) { + std::atomic_bool flag{false}; runParallelUpdates( [&](size_t iters) { for (size_t i = 0; i < iters; ++i) { - folly::doNotOptimizeAway(flag.load()); + folly::doNotOptimizeAway(flag.load(std::memory_order_seq_cst)); } }, kNumThreads, // Threads kNumIterations); // Iterations per thread } -BENCHMARK(std_atomic_bool_relaxed_read) { - std::atomic flag{false}; +BENCHMARK(std_atomic_bool_read_acquire) { + std::atomic_bool flag{false}; runParallelUpdates( [&](size_t iters) { for (size_t i = 0; i < iters; ++i) { - folly::doNotOptimizeAway(flag.load(std::memory_order_relaxed)); + folly::doNotOptimizeAway(flag.load(std::memory_order_acquire)); } }, kNumThreads, // Threads kNumIterations); // Iterations per thread } -BENCHMARK(std_atomic_bool_read_relaxed_acquire) { - std::atomic flag{false}; +BENCHMARK(std_atomic_bool_read_relaxed) { + std::atomic_bool flag{false}; runParallelUpdates( [&](size_t iters) { for (size_t i = 0; i < iters; ++i) { - folly::doNotOptimizeAway( - flag.load(std::memory_order_relaxed) || - flag.load(std::memory_order_acquire)); + folly::doNotOptimizeAway(flag.load(std::memory_order_relaxed)); } }, kNumThreads, // Threads @@ -136,7 +181,21 @@ BENCHMARK(std_atomic_bool_read_relaxed_acquire) { } BENCHMARK(one_way_flag_read) { - exec::OneWayStatusFlag flag; + OneWayStatusFlag flag; + runParallelUpdates( + [&](size_t iters) { + for (size_t i = 0; i < iters; ++i) { + folly::doNotOptimizeAway(flag.check()); + } + }, + kNumThreads, // Threads + kNumIterations); // Iterations per thread +} + +#if defined(__x86_64__) && !defined(TSAN_BUILD) + +BENCHMARK(one_way_flag_unsafe_read) { + OneWayStatusFlagUnsafe flag; runParallelUpdates( [&](size_t iters) { for (size_t i = 0; i < iters; ++i) { @@ -147,6 +206,10 @@ BENCHMARK(one_way_flag_read) { kNumIterations); // Iterations per thread } +#endif + +} // namespace + int main(int argc, char** argv) { folly::Init init(&argc, &argv); folly::runBenchmarks(); diff --git a/velox/exec/benchmarks/CMakeLists.txt b/velox/exec/benchmarks/CMakeLists.txt index 994c7305bcea..e0924182982e 100644 --- a/velox/exec/benchmarks/CMakeLists.txt +++ b/velox/exec/benchmarks/CMakeLists.txt @@ -14,8 +14,11 @@ add_executable(velox_exec_vector_hasher_benchmark VectorHasherBenchmark.cpp) target_link_libraries( - velox_exec_vector_hasher_benchmark velox_exec velox_vector_test_lib - Folly::follybenchmark) + velox_exec_vector_hasher_benchmark + velox_exec + velox_vector_test_lib + Folly::follybenchmark +) add_executable(velox_filter_project_benchmark FilterProjectBenchmark.cpp) @@ -24,7 +27,8 @@ target_link_libraries( velox_exec velox_vector_test_lib velox_exec_test_lib - Folly::follybenchmark) + Folly::follybenchmark +) add_executable(velox_exchange_benchmark ExchangeBenchmark.cpp) @@ -33,7 +37,8 @@ target_link_libraries( velox_exec velox_exec_test_lib velox_vector_test_lib - Folly::follybenchmark) + Folly::follybenchmark +) add_executable(velox_merge_benchmark MergeBenchmark.cpp) @@ -43,7 +48,8 @@ target_link_libraries( velox_vector_test_lib Folly::follybenchmark GTest::gtest - GTest::gtest_main) + GTest::gtest_main +) add_executable(velox_hash_benchmark HashTableBenchmark.cpp) @@ -53,27 +59,39 @@ target_link_libraries( velox_exec_test_lib velox_profiler velox_vector_test_lib - Folly::follybenchmark) + Folly::follybenchmark +) -add_executable(velox_hash_join_list_result_benchmark - HashJoinListResultBenchmark.cpp) +add_executable(velox_hash_join_list_result_benchmark HashJoinListResultBenchmark.cpp) target_link_libraries( velox_hash_join_list_result_benchmark velox_exec velox_exec_test_lib velox_vector_test_lib - Folly::follybenchmark) + Folly::follybenchmark +) + +add_executable(velox_hash_join_build_benchmark HashJoinBuildBenchmark.cpp) + +target_link_libraries( + velox_hash_join_build_benchmark + velox_exec + velox_exec_test_lib + velox_vector_fuzzer + velox_vector_test_lib + Folly::follybenchmark +) -add_executable(velox_hash_join_prepare_join_table_benchmark - HashJoinPrepareJoinTableBenchmark.cpp) +add_executable(velox_hash_join_prepare_join_table_benchmark HashJoinPrepareJoinTableBenchmark.cpp) target_link_libraries( velox_hash_join_prepare_join_table_benchmark velox_exec velox_exec_test_lib velox_vector_test_lib - Folly::follybenchmark) + Folly::follybenchmark +) if(${VELOX_ENABLE_PARQUET}) add_executable(velox_sort_benchmark RowContainerSortBenchmark.cpp) @@ -86,22 +104,24 @@ if(${VELOX_ENABLE_PARQUET}) velox_vector_test_lib Folly::follybenchmark arrow - thrift) + thrift + ) endif() add_library(velox_orderby_benchmark_util OrderByBenchmarkUtil.cpp) -target_link_libraries( - velox_orderby_benchmark_util velox_vector_fuzzer velox_vector_test_lib) +target_link_libraries(velox_orderby_benchmark_util velox_vector_fuzzer velox_vector_test_lib) add_executable(velox_prefixsort_benchmark PrefixSortBenchmark.cpp) target_link_libraries( - velox_prefixsort_benchmark velox_exec velox_orderby_benchmark_util - Folly::follybenchmark) + velox_prefixsort_benchmark + velox_exec + velox_orderby_benchmark_util + Folly::follybenchmark +) -add_executable(velox_orderby_benchmark OrderByBenchmark.cpp - OrderByBenchmarkUtil.cpp) +add_executable(velox_orderby_benchmark OrderByBenchmark.cpp OrderByBenchmarkUtil.cpp) target_link_libraries( velox_orderby_benchmark @@ -110,7 +130,8 @@ target_link_libraries( velox_orderby_benchmark_util velox_vector_fuzzer velox_vector_test_lib - Folly::follybenchmark) + Folly::follybenchmark +) add_executable(velox_window_prefixsort_benchmark WindowPrefixSortBenchmark.cpp) @@ -123,10 +144,24 @@ target_link_libraries( velox_vector_fuzzer velox_vector_test_lib velox_window - Folly::follybenchmark) + Folly::follybenchmark +) + +add_executable(velox_window_sub_partitioned_sort_benchmark WindowSubPartitionedSortBenchmark.cpp) + +target_link_libraries( + velox_window_sub_partitioned_sort_benchmark + velox_aggregates + velox_exec + velox_exec_test_lib + velox_hive_connector + velox_vector_fuzzer + velox_vector_test_lib + velox_window + Folly::follybenchmark +) -add_executable(velox_streaming_aggregation_benchmark - StreamingAggregationBenchmark.cpp) +add_executable(velox_streaming_aggregation_benchmark StreamingAggregationBenchmark.cpp) target_link_libraries( velox_streaming_aggregation_benchmark @@ -136,4 +171,37 @@ target_link_libraries( velox_exec_test_lib velox_vector_fuzzer velox_vector_test_lib - Folly::follybenchmark) + Folly::follybenchmark +) + +add_executable(velox_exec_bm_duplicate_project DuplicateProjectBenchmark.cpp) + +target_link_libraries( + velox_exec_bm_duplicate_project + velox_exec_test_lib + velox_hive_connector + velox_vector_fuzzer + Folly::follybenchmark +) + +add_executable(velox_atomics_benchmark AtomicsBench.cpp) + +target_link_libraries(velox_atomics_benchmark Folly::follybenchmark) + +if(VELOX_ENABLE_GEO) + add_executable(velox_spatial_join_benchmark SpatialJoinBenchmark.cpp) + + target_compile_definitions(velox_spatial_join_benchmark PRIVATE VELOX_ENABLE_GEO) + + target_link_libraries( + velox_spatial_join_benchmark + velox_memory + velox_exec + velox_exec_test_lib + velox_parse_parser + velox_presto_types + velox_vector_test_lib + velox_functions_prestosql + Folly::follybenchmark + ) +endif() diff --git a/velox/exec/benchmarks/DuplicateProjectBenchmark.cpp b/velox/exec/benchmarks/DuplicateProjectBenchmark.cpp new file mode 100644 index 000000000000..4c08384d3c87 --- /dev/null +++ b/velox/exec/benchmarks/DuplicateProjectBenchmark.cpp @@ -0,0 +1,138 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include "velox/common/memory/Memory.h" +#include "velox/common/memory/SharedArbitrator.h" +#include "velox/exec/tests/utils/AssertQueryBuilder.h" +#include "velox/exec/tests/utils/HiveConnectorTestBase.h" +#include "velox/exec/tests/utils/PlanBuilder.h" +#include "velox/functions/prestosql/registration/RegistrationFunctions.h" + +DEFINE_int64(fuzzer_seed, 99887766, "Seed for random input dataset generator"); + +using namespace facebook::velox; +using namespace facebook::velox::exec; +using namespace facebook::velox::test; +using namespace facebook::velox::exec::test; + +namespace { + +static constexpr int32_t kNumVectors = 50; +static constexpr int32_t kRowsPerVector = 10'000; + +class DuplicateProjectBenchmark : public HiveConnectorTestBase { + public: + explicit DuplicateProjectBenchmark() { + HiveConnectorTestBase::SetUp(); + + inputType_ = ROW({ + {"a", INTEGER()}, + {"b", INTEGER()}, + {"c", VARCHAR()}, + {"d", VARCHAR()}, + }); + + VectorFuzzer::Options opts; + opts.vectorSize = kRowsPerVector; + opts.nullRatio = 0.2; + VectorFuzzer fuzzer(opts, pool_.get(), FLAGS_fuzzer_seed); + std::vector inputVectors; + for (auto i = 0; i < kNumVectors; ++i) { + std::vector children; + children.emplace_back(fuzzer.fuzzFlat(INTEGER())); + children.emplace_back(fuzzer.fuzzFlat(INTEGER())); + children.emplace_back(fuzzer.fuzzFlat(VARCHAR())); + children.emplace_back(fuzzer.fuzzFlat(VARCHAR())); + + inputVectors.emplace_back(makeRowVector(inputType_->names(), children)); + } + + sourceFilePath_ = TempFilePath::create(); + writeToFile(sourceFilePath_->getPath(), inputVectors); + } + + ~DuplicateProjectBenchmark() override { + HiveConnectorTestBase::TearDown(); + } + + void TestBody() override {} + + void run(const std::string& filter, const std::vector& project) { + auto plan = PlanBuilder() + .tableScan(inputType_) + .filter(filter) + .project(project) + .planNode(); + auto result = exec::test::AssertQueryBuilder(plan) + .split(makeHiveConnectorSplit(sourceFilePath_->getPath())) + .copyResults(pool_.get()); + auto numResultRows = result->size(); + folly::doNotOptimizeAway(numResultRows); + } + + void addBenchmarks() { + folly::addBenchmark( + __FILE__, "filter_integer_duplicate_project_integer", [this]() { + run("a > 99999", {"b", "b"}); + return 1; + }); + folly::addBenchmark( + __FILE__, "filter_string_duplicate_project_string", [this]() { + run("length(c) > 10", {"d", "d"}); + return 1; + }); + folly::addBenchmark(__FILE__, "duplicate_project_integer", [this]() { + run("true", {"b", "b"}); + return 1; + }); + folly::addBenchmark(__FILE__, "duplicate_project_string", [this]() { + run("true", {"d", "d"}); + return 1; + }); + folly::addBenchmark( + __FILE__, "filter_integers_duplicate_project_integers", [this]() { + run("a > 9999 and b < 10000", {"a", "a", "b", "b"}); + return 1; + }); + folly::addBenchmark( + __FILE__, "filter_strings_duplicate_project_strings", [this]() { + run("length(c) > 10 and length(d) < 20", {"c", "c", "d", "d"}); + return 1; + }); + } + + private: + RowTypePtr inputType_; + std::shared_ptr sourceFilePath_; +}; + +} // namespace + +int main(int argc, char** argv) { + folly::Init init(&argc, &argv); + facebook::velox::memory::MemoryManager::initialize( + facebook::velox::memory::MemoryManager::Options{}); + memory::SharedArbitrator::registerFactory(); + functions::prestosql::registerAllScalarFunctions(); + + DuplicateProjectBenchmark bm; + bm.addBenchmarks(); + folly::runBenchmarks(); + return 0; +} diff --git a/velox/exec/benchmarks/FilterProjectBenchmark.cpp b/velox/exec/benchmarks/FilterProjectBenchmark.cpp index e34be8a68e8d..4c0c26ac03e0 100644 --- a/velox/exec/benchmarks/FilterProjectBenchmark.cpp +++ b/velox/exec/benchmarks/FilterProjectBenchmark.cpp @@ -94,10 +94,11 @@ class FilterProjectBenchmark : public VectorTestBase { auto& type = data[0]->type()->as(); builder.values(data); for (auto level = 0; level < numStages; ++level) { - builder.filter(fmt::format( - "c0 >= {}", - static_cast( - 1000000 - pow(passPct / 100.0, 1 + level) * 1000000))); + builder.filter( + fmt::format( + "c0 >= {}", + static_cast( + 1000000 - pow(passPct / 100.0, 1 + level) * 1000000))); std::vector projections = {"c0"}; int32_t nthBigint = 0; int32_t nthVarchar = 0; diff --git a/velox/exec/benchmarks/HashJoinBuildBenchmark.cpp b/velox/exec/benchmarks/HashJoinBuildBenchmark.cpp new file mode 100644 index 000000000000..3d52e0581e13 --- /dev/null +++ b/velox/exec/benchmarks/HashJoinBuildBenchmark.cpp @@ -0,0 +1,394 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include + +#include "velox/exec/HashTable.h" +#include "velox/exec/tests/utils/AssertQueryBuilder.h" +#include "velox/exec/tests/utils/PlanBuilder.h" +#include "velox/exec/tests/utils/VectorTestUtil.h" +#include "velox/vector/tests/utils/VectorTestBase.h" + +using namespace facebook::velox; +using namespace facebook::velox::exec; +using namespace facebook::velox::exec::test; +using namespace facebook::velox::test; + +namespace { +struct BenchmarkParams { + BenchmarkParams() = default; + + // Benchmark params, we need to provide: + // -the expect hash mode, + // -the build row schema, + // -the duplicate factor, + // -number of building rows, + // -number of probing rows, + // -the abandon percentage, + // -the number of build vector batches. + BenchmarkParams( + BaseHashTable::HashMode mode, + const TypePtr& buildType, + double dupFactor, + int64_t buildSize, + int64_t probeSize, + int32_t abandonPct, + int32_t numBuildBatches) + : mode{mode}, + buildType{buildType}, + hashTableSize{static_cast(std::floor(buildSize / dupFactor))}, + buildSize{buildSize}, + probeSize{probeSize}, + numBuildBatches{numBuildBatches}, + dupFactor{dupFactor}, + abandonPct{abandonPct} { + VELOX_CHECK_LE(hashTableSize, buildSize); + VELOX_CHECK_GE(numBuildBatches, 1); + + if (hashTableSize > BaseHashTable::kArrayHashMaxSize && + mode == BaseHashTable::HashMode::kArray) { + VELOX_FAIL("Bad hash mode."); + } + + numFields = buildType->size(); + if (mode == BaseHashTable::HashMode::kNormalizedKey) { + extraValue = BaseHashTable::kArrayHashMaxSize + 100; + } else if (mode == BaseHashTable::HashMode::kHash) { + extraValue = std::numeric_limits::max() - 1; + } else { + extraValue = 0; + } + + title = fmt::format( + "dupFactor:{:<2},abandonPct:{},{}", + dupFactor, + abandonPct, + BaseHashTable::modeString(mode)); + } + + // Expected mode. + BaseHashTable::HashMode mode; + + // Type of build & probe row. + TypePtr buildType; + + // Distinct rows in the table. + int64_t hashTableSize; + + // Number of build rows. + int64_t buildSize; + + // Number of probe rows. + int64_t probeSize; + + // Number of build RowContainers. + int32_t numBuildBatches; + + // Title for reporting. + std::string title; + + // The duplicate factor, 2 means every row will repeat 2 times. + double dupFactor; + + // This parameter controls the hashing mode. It is incorporated into the keys + // on the build side. If the expected mode is an array, its value is 0. If + // the expected mode is a normalized key, its value is 'kArrayHashMaxSize' + + // 100 to make the key range > 'kArrayHashMaxSize'. If the expected mode is a + // hash, its value is the maximum value of int64_t minus 1 to make the key + // range == 'kRangeTooLarge'. + int64_t extraValue; + + // Number of fields. + int32_t numFields; + + int32_t abandonPct; + + std::string toString() const { + return fmt::format( + "DupFactor:{:<2}, AbandonPct:{}, HashMode:{:<14}", + dupFactor, + abandonPct, + BaseHashTable::modeString(mode)); + } +}; + +struct BenchmarkResult { + BenchmarkParams params; + + uint64_t totalClock{0}; + + uint64_t hashBuildPeakMemoryBytes{0}; + + bool isBuildNoDupHashTableAbandon{false}; + + // The mode of the table. + BaseHashTable::HashMode hashMode; + + std::string toString() const { + return fmt::format( + "{}, isAbandon:{:<5}, totalClock:{}ms, peakMemoryBytes:{}", + params.toString(), + isBuildNoDupHashTableAbandon, + totalClock / 1000'000, + succinctBytes(hashBuildPeakMemoryBytes)); + } +}; + +class HashJoinBuildBenchmark : public VectorTestBase { + public: + HashJoinBuildBenchmark() : randomEngine_((std::random_device{}())) {} + + BenchmarkResult run(BenchmarkParams params) { + params_ = std::move(params); + BenchmarkResult result; + result.params = params_; + result.hashMode = params_.mode; + + std::vector buildVectors; + makeBuildBatches(buildVectors); + + int64_t sequence = 0; + int64_t batchSize = params_.probeSize / 4; + std::vector probeVectors; + for (auto i = 0; i < 4; ++i) { + auto batch = makeProbeVector(batchSize, params_.hashTableSize, sequence); + probeVectors.emplace_back(batch); + } + + uint64_t totalClocks{0}; + { + ClockTimer timer(totalClocks); + auto plan = makeHashJoinPlan(buildVectors, probeVectors); + CursorParameters cursorParams; + cursorParams.planNode = std::move(plan); + cursorParams.queryCtx = core::QueryCtx::create( + executor_.get(), + core::QueryConfig{{}}, + {}, + cache::AsyncDataCache::getInstance(), + rootPool_); + cursorParams.queryCtx->testingOverrideConfigUnsafe({ + {core::QueryConfig::kAbandonDedupHashMapMinPct, + std::to_string(params_.abandonPct)}, + {core::QueryConfig::kAbandonDedupHashMapMinRows, "1000000"}, + }); + + cursorParams.maxDrivers = 1; + auto cursor = TaskCursor::create(cursorParams); + auto* task = cursor->task().get(); + while (cursor->moveNext()) { + } + waitForTaskCompletion(task); + result.isBuildNoDupHashTableAbandon = isBuildNoDupHashTableAbandon(task); + } + result.totalClock = totalClocks; + + result.hashBuildPeakMemoryBytes = getHashBuildPeakMemory(rootPool_.get()); + return result; + } + + private: + std::shared_ptr makeHashJoinPlan( + const std::vector& buildVectors, + const std::vector& probeVectors) { + auto planNodeIdGenerator = std::make_shared(); + return exec::test::PlanBuilder(planNodeIdGenerator, pool_.get()) + .values(probeVectors) + .project({"c0 AS t0", "c1 as t1", "c2 as t2"}) + .hashJoin( + {"t0"}, + {"u0"}, + exec::test::PlanBuilder(planNodeIdGenerator) + .values(buildVectors) + .project({"c0 AS u0"}) + .planNode(), + "", + {"t0", "t1", "match"}, + core::JoinType::kLeftSemiProject) + .planNode(); + } + + // Create the row vector for the build side, where the first column is used + // as the join key, and the remaining columns are dependent fields. + // If expect mode is array, the key is within the range [0, hashTableSize]; + // If expect mode is normalized key, the key is within the range + // [0, hashTableSize] + extraValue(kArrayHashMaxSize + 100); + // If expect mode is hash, the key is within the range [0, hashTableSize] + + // extraValue(max_int64 -1); + RowVectorPtr makeBuildRows( + std::vector& data, + int64_t start, + int64_t end, + bool addExtraValue) { + auto subData = + std::vector(data.begin() + start, data.begin() + end); + if (addExtraValue) { + subData[0] = params_.extraValue; + } + + std::vector children; + children.push_back(makeFlatVector(subData)); + return makeRowVector(children); + } + + // Generate the build side data batches. + void makeBuildBatches(std::vector& batches) { + int64_t buildKey = 0; + std::vector data; + for (auto i = 0; i < params_.buildSize; ++i) { + data.emplace_back((buildKey++) % params_.hashTableSize); + } + std::shuffle(data.begin(), data.end(), randomEngine_); + + auto size = params_.buildSize / params_.numBuildBatches; + for (auto i = 0; i < params_.numBuildBatches; ++i) { + batches.push_back(makeBuildRows( + data, + i * size, + (i + 1) * size + 1, + i == params_.numBuildBatches - 1)); + } + } + + // Create the row vector for the probe side, where the first column is used + // as the join key, and the remaining columns are dependent fields. + // Probe key is within the range [0, hashTableSize]. + RowVectorPtr + makeProbeVector(int64_t size, int64_t hashTableSize, int64_t& sequence) { + std::vector children; + for (int32_t i = 0; i < params_.numFields; ++i) { + children.push_back( + makeFlatVector( + size, + [&](vector_size_t row) { + return (sequence + row) % hashTableSize; + }, + nullptr)); + } + sequence += size; + + for (int32_t i = 0; i < 2; ++i) { + children.push_back( + makeFlatVector( + size, [&](vector_size_t row) { return row + size; }, nullptr)); + } + return makeRowVector(children); + } + + static int64_t getHashBuildPeakMemory(memory::MemoryPool* rootPool) { + int64_t hashBuildPeakBytes = 0; + std::vector pools; + pools.push_back(rootPool); + while (!pools.empty()) { + std::vector childPools; + for (auto pool : pools) { + pool->visitChildren([&](memory::MemoryPool* childPool) -> bool { + if (childPool->name().find("HashBuild") != std::string::npos) { + hashBuildPeakBytes += childPool->peakBytes(); + } + childPools.push_back(childPool); + return true; + }); + } + pools.swap(childPools); + } + if (hashBuildPeakBytes == 0) { + VELOX_FAIL("Failed to get HashBuild peak memory"); + } + return hashBuildPeakBytes; + } + + static bool isBuildNoDupHashTableAbandon(exec::Task* task) { + for (auto& pipelineStat : task->taskStats().pipelineStats) { + for (auto& operatorStat : pipelineStat.operatorStats) { + if (operatorStat.operatorType == "HashBuild") { + return operatorStat.runtimeStats["abandonBuildNoDupHash"].count != 0; + } + } + } + return false; + } + + std::default_random_engine randomEngine_; + BenchmarkParams params_; +}; + +} // namespace + +int main(int argc, char** argv) { + folly::Init init{&argc, &argv}; + memory::MemoryManager::Options options; + options.useMmapAllocator = true; + options.allocatorCapacity = 10UL << 30; + options.useMmapArena = true; + options.mmapArenaCapacityRatio = 1; + memory::MemoryManager::initialize(options); + + auto bm = std::make_unique(); + std::vector results; + + auto buildRowSize = (2L << 20) - 3; + auto probeRowSize = 100000000L; + + TypePtr twoKeyType{ROW({"k1"}, {BIGINT()})}; + + const std::vector hashModes = { + BaseHashTable::HashMode::kArray, + BaseHashTable::HashMode::kNormalizedKey, + BaseHashTable::HashMode::kHash, + }; + const std::vector dupFactorVector = { + 2, + 8, + 32, + }; + const std::vector abandonPcts = { + 90, + 80, + 70, + 50, + 0, + }; + + std::vector params; + for (auto mode : hashModes) { + for (auto dupFactor : dupFactorVector) { + for (auto pct : abandonPcts) { + params.push_back(BenchmarkParams( + mode, twoKeyType, dupFactor, buildRowSize, probeRowSize, pct, 512)); + } + } + } + + for (auto& param : params) { + BenchmarkResult result; + folly::addBenchmark(__FILE__, param.title, [param, &results, &bm]() { + results.emplace_back(bm->run(param)); + return 1; + }); + } + + folly::runBenchmarks(); + + for (auto& result : results) { + std::cout << result.toString() << std::endl; + } + return 0; +} diff --git a/velox/exec/benchmarks/HashJoinListResultBenchmark.cpp b/velox/exec/benchmarks/HashJoinListResultBenchmark.cpp index 417a86825f02..7e4e933a3787 100644 --- a/velox/exec/benchmarks/HashJoinListResultBenchmark.cpp +++ b/velox/exec/benchmarks/HashJoinListResultBenchmark.cpp @@ -280,10 +280,11 @@ class HashTableListJoinResultBenchmark : public VectorTestBase { std::vector children; children.push_back(makeFlatVector(data)); for (int32_t i = 0; i < params_.numDependentFields; ++i) { - children.push_back(makeFlatVector( - data.size(), - [&](vector_size_t row) { return row + maxKey; }, - nullptr)); + children.push_back( + makeFlatVector( + data.size(), + [&](vector_size_t row) { return row + maxKey; }, + nullptr)); } return makeRowVector(children); } @@ -311,21 +312,23 @@ class HashTableListJoinResultBenchmark : public VectorTestBase { RowVectorPtr makeProbeVector(int32_t size, int64_t hashTableSize, int64_t& sequence) { std::vector children; - children.push_back(makeFlatVector( - size, - [&](vector_size_t row) { return (sequence + row) % hashTableSize; }, - nullptr)); + children.push_back( + makeFlatVector( + size, + [&](vector_size_t row) { return (sequence + row) % hashTableSize; }, + nullptr)); sequence += size; for (int32_t i = 0; i < params_.numDependentFields; ++i) { - children.push_back(makeFlatVector( - size, [&](vector_size_t row) { return row + size; }, nullptr)); + children.push_back( + makeFlatVector( + size, [&](vector_size_t row) { return row + size; }, nullptr)); } return makeRowVector(children); } void copyVectorsToTable(RowVectorPtr batch, BaseHashTable* table) { int32_t batchSize = batch->size(); - raw_vector dummy(batchSize); + raw_vector dummy(batchSize, pool()); auto rowContainer = table->rows(); auto& hashers = table->hashers(); auto numKeys = hashers.size(); @@ -391,6 +394,8 @@ class HashTableListJoinResultBenchmark : public VectorTestBase { topTable_->prepareJoinTable( std::move(otherTables), BaseHashTable::kNoSpillInputStartPartitionBit, + 1'000'000, + false, executor_.get()); } buildTime_ = buildClocks; diff --git a/velox/exec/benchmarks/HashJoinPrepareJoinTableBenchmark.cpp b/velox/exec/benchmarks/HashJoinPrepareJoinTableBenchmark.cpp index cfe961b297e1..07f75f06cab1 100644 --- a/velox/exec/benchmarks/HashJoinPrepareJoinTableBenchmark.cpp +++ b/velox/exec/benchmarks/HashJoinPrepareJoinTableBenchmark.cpp @@ -129,6 +129,8 @@ class HashJoinPrepareJoinTableBenchmark : public VectorTestBase { topTable_->prepareJoinTable( std::move(otherTables_), BaseHashTable::kNoSpillInputStartPartitionBit, + 1'000'000, + false, executor_.get()); VELOX_CHECK_EQ(topTable_->hashMode(), params_.mode); } @@ -180,7 +182,7 @@ class HashJoinPrepareJoinTableBenchmark : public VectorTestBase { void copyVectorsToTable(RowVectorPtr batch, BaseHashTable* table) { int32_t batchSize = batch->size(); - raw_vector dummy(batchSize); + raw_vector dummy(batchSize, pool()); auto rowContainer = table->rows(); auto& hashers = table->hashers(); auto numKeys = hashers.size(); diff --git a/velox/exec/benchmarks/HashTableBenchmark.cpp b/velox/exec/benchmarks/HashTableBenchmark.cpp index 7352e46aef3f..0e739194dc6a 100644 --- a/velox/exec/benchmarks/HashTableBenchmark.cpp +++ b/velox/exec/benchmarks/HashTableBenchmark.cpp @@ -173,8 +173,9 @@ class HashTableBenchmark : public VectorTestBase { std::vector batches; std::vector> keyHashers; for (auto channel = 0; channel < params_.numKeys; ++channel) { - keyHashers.emplace_back(std::make_unique( - params_.buildType->childAt(channel), channel)); + keyHashers.emplace_back( + std::make_unique( + params_.buildType->childAt(channel), channel)); } auto table = HashTable::createForJoin( std::move(keyHashers), @@ -198,6 +199,8 @@ class HashTableBenchmark : public VectorTestBase { topTable_->prepareJoinTable( std::move(otherTables), BaseHashTable::kNoSpillInputStartPartitionBit, + 1'000'000, + false, executor_.get()); LOG(INFO) << "Made table " << topTable_->toString(); @@ -286,7 +289,7 @@ class HashTableBenchmark : public VectorTestBase { int32_t tableOffset, BaseHashTable* table) { int32_t batchSize = batches[0]->size(); - raw_vector dummy(batchSize); + raw_vector dummy(batchSize, pool()); int32_t batchOffset = 0; rowOfKey_.resize(tableOffset + batchSize * batches.size()); auto rowContainer = table->rows(); @@ -409,8 +412,9 @@ class HashTableBenchmark : public VectorTestBase { TypePtr buildType, std::vector& batches) { for (auto i = 0; i < numBatches; ++i) { - batches.push_back(std::static_pointer_cast( - makeVector(buildType, batchSize, sequence))); + batches.push_back( + std::static_pointer_cast( + makeVector(buildType, batchSize, sequence))); sequence += batchSize; } } diff --git a/velox/exec/benchmarks/MergeBenchmark.cpp b/velox/exec/benchmarks/MergeBenchmark.cpp index f5fbaee0ff6f..118007abe0d4 100644 --- a/velox/exec/benchmarks/MergeBenchmark.cpp +++ b/velox/exec/benchmarks/MergeBenchmark.cpp @@ -19,7 +19,7 @@ #include -#include "velox/exec/TreeOfLosers.h" +#include "velox/common/base/TreeOfLosers.h" #include "velox/exec/tests/utils/MergeTestBase.h" using namespace facebook::velox; diff --git a/velox/exec/benchmarks/OrderByBenchmark.cpp b/velox/exec/benchmarks/OrderByBenchmark.cpp index 326bbf0e53d1..be01eed258f2 100644 --- a/velox/exec/benchmarks/OrderByBenchmark.cpp +++ b/velox/exec/benchmarks/OrderByBenchmark.cpp @@ -56,7 +56,7 @@ class OrderByBenchmark { const auto start = getCurrentTimeMicro(); for (auto i = 0; i < iterations; ++i) { std::shared_ptr task; - test::AssertQueryBuilder(plan).runWithoutResults(task); + test::AssertQueryBuilder(plan).countResults(task); auto taskStats = exec::toPlanStats(task->taskStats()); auto& stats = taskStats.at(orderByNodeId); inputNs += stats.addInputTiming.wallNanos; @@ -79,8 +79,9 @@ class OrderByBenchmark { core::PlanNodeId& orderByNodeId) { folly::BenchmarkSuspender suspender; std::vector vectors; - vectors.emplace_back(OrderByBenchmarkUtil::fuzzRows( - test.rowType, test.numRows, pool_.get())); + vectors.emplace_back( + OrderByBenchmarkUtil::fuzzRows( + test.rowType, test.numRows, pool_.get())); std::vector keys; keys.reserve(test.numKeys); diff --git a/velox/exec/benchmarks/OrderByBenchmarkUtil.h b/velox/exec/benchmarks/OrderByBenchmarkUtil.h index 3605edc936f3..8b3855daabc6 100644 --- a/velox/exec/benchmarks/OrderByBenchmarkUtil.h +++ b/velox/exec/benchmarks/OrderByBenchmarkUtil.h @@ -23,12 +23,13 @@ class OrderByBenchmarkUtil { public: /// Add the benchmarks with the parameter. /// @param benchmarkFunc benchmark generator. - static void addBenchmarks(const std::function& benchmarkFunc); + static void addBenchmarks( + const std::function& benchmarkFunc); /// Generate RowVector by VectorFuzzer according to rowType. Use /// FLAGS_data_null_ratio to specify the columns null ratio diff --git a/velox/exec/benchmarks/SpatialJoinBenchmark.cpp b/velox/exec/benchmarks/SpatialJoinBenchmark.cpp new file mode 100644 index 000000000000..066cceb63795 --- /dev/null +++ b/velox/exec/benchmarks/SpatialJoinBenchmark.cpp @@ -0,0 +1,360 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include "velox/common/memory/Memory.h" +#include "velox/exec/tests/utils/AssertQueryBuilder.h" +#include "velox/exec/tests/utils/PlanBuilder.h" +#include "velox/functions/prestosql/registration/RegistrationFunctions.h" +#include "velox/parse/TypeResolver.h" +#include "velox/vector/tests/utils/VectorTestBase.h" + +/// Benchmark for SpatialJoin operator, which implements a nested-loop join +/// with spatial predicates (e.g., ST_INTERSECTS, ST_CONTAINS, ST_WITHIN). +/// +/// This benchmark measures the performance of spatial joins under different +/// conditions: +/// - Different build and probe side sizes (cross join cardinality) +/// - Different spatial predicates +/// - Different data distributions (dense vs sparse geometries) +/// - Inner vs Left join types +/// +/// The benchmark creates synthetic geometric data and measures the throughput +/// of spatial join operations. The focus is on understanding how the nested +/// loop pattern performs with varying data sizes and selectivity. + +using namespace facebook::velox; +using namespace facebook::velox::exec; +using namespace facebook::velox::exec::test; + +namespace { + +/// Spatial distribution patterns for geometry generation. +enum class Distribution { + kUniform, // Geometries uniformly distributed in space + kClustered // Geometries clustered in specific regions +}; + +// Constants for geometry generation. +constexpr int32_t kNullPatternModulo = 13; +constexpr int32_t kRandomCoordinateMax = 10000; +constexpr double kCoordinateScaleDivisor = 10.0; +constexpr int32_t kNumClusters = 5; +constexpr double kClusterSpacing = 200.0; +constexpr double kClusterCenterOffset = 100.0; +constexpr int32_t kClusterSpreadRange = 100; +constexpr int32_t kClusterSpreadHalf = 50; +constexpr double kPolygonSize = 10.0; + +// Constants for benchmark configuration. +constexpr int32_t kDefaultBatchSize = 10000; +constexpr int32_t kSmallBenchmarkSize = 1000; +constexpr int32_t kMediumProbeBenchmarkSize = 50000; +constexpr int32_t kMediumBuildBenchmarkSize = 5000; +constexpr int32_t kLargeProbeBenchmarkSize = 200000; +constexpr int32_t kLargeBuildBenchmarkSize = 50000; + +/// Parameters for a spatial join benchmark test case. +struct SpatialJoinBenchmarkParams { + /// Number of rows on the probe (left) side. + int32_t probeSize; + + /// Number of rows on the build (right) side. + int32_t buildSize; + + /// Spatial predicate to use (e.g., "ST_Intersects", "ST_Contains"). + std::string predicate; + + /// Join type (kInner or kLeft). + core::JoinType joinType; + + /// Spatial distribution pattern for geometry generation. + Distribution distribution; + + /// Description for benchmark naming. + std::string toString() const { + std::string joinTypeStr = + (joinType == core::JoinType::kInner) ? "Inner" : "Left"; + std::string distributionStr = + (distribution == Distribution::kUniform) ? "uniform" : "clustered"; + return fmt::format( + "{}x{}_{}_{}_{}", + probeSize, + buildSize, + predicate, + joinTypeStr, + distributionStr); + } +}; + +class SpatialJoinBenchmark : public facebook::velox::test::VectorTestBase { + public: + SpatialJoinBenchmark() : rng_((std::random_device{}())) {} + + /// Creates a vector of POINT geometries with specified distribution. + VectorPtr + makePointVector(int32_t size, Distribution distribution, bool nulls = false) { + return makeFlatVector( + size, + [&](vector_size_t row) { + if (nulls && (row % kNullPatternModulo == 0)) { + return std::string(""); + } + double x, y; + if (distribution == Distribution::kUniform) { + x = (folly::Random::rand32(rng_) % kRandomCoordinateMax) / + kCoordinateScaleDivisor; + y = (folly::Random::rand32(rng_) % kRandomCoordinateMax) / + kCoordinateScaleDivisor; + } else { + int cluster = row % kNumClusters; + double centerX = (cluster * kClusterSpacing) + kClusterCenterOffset; + double centerY = (cluster * kClusterSpacing) + kClusterCenterOffset; + x = centerX + + ((folly::Random::rand32(rng_) % kClusterSpreadRange) - + kClusterSpreadHalf); + y = centerY + + ((folly::Random::rand32(rng_) % kClusterSpreadRange) - + kClusterSpreadHalf); + } + return fmt::format("POINT ({} {})", x, y); + }, + [&](vector_size_t row) { + return nulls && (row % kNullPatternModulo == 0); + }); + } + + /// Creates a vector of POLYGON geometries with specified distribution. + VectorPtr makePolygonVector( + int32_t size, + Distribution distribution, + bool nulls = false) { + return makeFlatVector( + size, + [&](vector_size_t row) { + if (nulls && (row % kNullPatternModulo == 0)) { + return std::string(""); + } + double centerX, centerY; + if (distribution == Distribution::kUniform) { + centerX = (folly::Random::rand32(rng_) % kRandomCoordinateMax) / + kCoordinateScaleDivisor; + centerY = (folly::Random::rand32(rng_) % kRandomCoordinateMax) / + kCoordinateScaleDivisor; + } else { + int cluster = row % kNumClusters; + centerX = (cluster * kClusterSpacing) + kClusterCenterOffset; + centerY = (cluster * kClusterSpacing) + kClusterCenterOffset; + } + return fmt::format( + "POLYGON (({} {}, {} {}, {} {}, {} {}, {} {}))", + centerX - kPolygonSize, + centerY - kPolygonSize, + centerX + kPolygonSize, + centerY - kPolygonSize, + centerX + kPolygonSize, + centerY + kPolygonSize, + centerX - kPolygonSize, + centerY + kPolygonSize, + centerX - kPolygonSize, + centerY - kPolygonSize); + }, + [&](vector_size_t row) { + return nulls && (row % kNullPatternModulo == 0); + }); + } + + RowVectorPtr createProjectionVector( + const std::string& prefix, + RowVectorPtr input) { + const auto plan = PlanBuilder(std::make_shared()) + .values({input}) + .project( + {fmt::format("{}_id", prefix), + fmt::format( + "ST_GeometryFromText({}_geom) AS {}_geom", + prefix, + prefix)}) + .planNode(); + return AssertQueryBuilder(plan).copyResults(pool_.get()); + } + + /// Creates test data for the specified parameters. + std::pair, std::vector> makeTestData( + const SpatialJoinBenchmarkParams& params) { + // Create probe side data (points) + std::vector probeVectors; + const int32_t batchSize = std::min(params.probeSize, kDefaultBatchSize); + const int32_t numBatches = (params.probeSize + batchSize - 1) / batchSize; + + for (int32_t i = 0; i < numBatches; ++i) { + int32_t currentBatchSize = + std::min(batchSize, params.probeSize - (i * batchSize)); + auto geomVector = + makePointVector(currentBatchSize, params.distribution, false); + auto idVector = makeFlatVector( + currentBatchSize, + [i, batchSize](vector_size_t row) { return (i * batchSize) + row; }); + probeVectors.push_back(createProjectionVector( + "probe", + makeRowVector({"probe_id", "probe_geom"}, {idVector, geomVector}))); + } + + // Create build side data (polygons) + std::vector buildVectors; + const int32_t buildBatchSize = + std::min(params.buildSize, kDefaultBatchSize); + const int32_t numBuildBatches = + (params.buildSize + buildBatchSize - 1) / buildBatchSize; + + for (int32_t i = 0; i < numBuildBatches; ++i) { + int32_t currentBatchSize = + std::min(buildBatchSize, params.buildSize - (i * buildBatchSize)); + auto geomVector = + makePolygonVector(currentBatchSize, params.distribution, false); + auto idVector = makeFlatVector( + currentBatchSize, [i, buildBatchSize](vector_size_t row) { + return (i * buildBatchSize) + row; + }); + buildVectors.push_back(createProjectionVector( + "build", + makeRowVector({"build_id", "build_geom"}, {idVector, geomVector}))); + } + + return {probeVectors, buildVectors}; + } + + /// Creates a spatial join plan with the specified parameters. + std::shared_ptr makeSpatialJoinPlan( + std::vector&& probeVectors, + std::vector&& buildVectors, + const SpatialJoinBenchmarkParams& params) { + const auto planNodeIdGenerator = + std::make_shared(); + return PlanBuilder(planNodeIdGenerator) + .values(probeVectors) + .spatialJoin( + PlanBuilder(planNodeIdGenerator).values(buildVectors).planNode(), + fmt::format("{}(probe_geom, build_geom)", params.predicate), + "probe_geom", + "build_geom", + std::nullopt, + {"probe_id", "probe_geom", "build_id", "build_geom"}, + params.joinType) + .planNode(); + } + + /// Runs a single benchmark iteration. + uint64_t run( + std::shared_ptr plan, + const SpatialJoinBenchmarkParams& params) { + auto result = AssertQueryBuilder(plan).copyResults(pool_.get()); + return result->size(); + } + + /// Adds a benchmark for the given parameters. + void addBenchmark(const SpatialJoinBenchmarkParams& params) { + auto name = params.toString(); + folly::addBenchmark(__FILE__, name, [this, params]() { + std::shared_ptr plan; + BENCHMARK_SUSPEND { + auto [probeVectors, buildVectors] = makeTestData(params); + plan = makeSpatialJoinPlan( + std::move(probeVectors), std::move(buildVectors), params); + } + + run(plan, params); + return 1; + }); + } + + private: + std::default_random_engine rng_; +}; + +} // namespace + +int main(int argc, char** argv) { + folly::Init init{&argc, &argv}; + memory::initializeMemoryManager(memory::MemoryManager::Options{}); + parse::registerTypeResolver(); + functions::prestosql::registerAllScalarFunctions(); + + SpatialJoinBenchmark bm; + + // Small scale benchmarks (1K x 1K) + bm.addBenchmark( + {kSmallBenchmarkSize, + kSmallBenchmarkSize, + "ST_Intersects", + core::JoinType::kInner, + Distribution::kUniform}); + bm.addBenchmark( + {kSmallBenchmarkSize, + kSmallBenchmarkSize, + "ST_Intersects", + core::JoinType::kInner, + Distribution::kClustered}); + + // Medium scale benchmarks (50K x 5K) + bm.addBenchmark( + {kMediumProbeBenchmarkSize, + kMediumBuildBenchmarkSize, + "ST_Intersects", + core::JoinType::kInner, + Distribution::kUniform}); + bm.addBenchmark( + {kMediumProbeBenchmarkSize, + kMediumBuildBenchmarkSize, + "ST_Intersects", + core::JoinType::kInner, + Distribution::kClustered}); + + // Left join benchmarks (50K x 5K) + bm.addBenchmark( + {kMediumProbeBenchmarkSize / 2, + kMediumBuildBenchmarkSize, + "ST_Intersects", + core::JoinType::kLeft, + Distribution::kUniform}); + bm.addBenchmark( + {kMediumProbeBenchmarkSize / 2, + kMediumBuildBenchmarkSize, + "ST_Intersects", + core::JoinType::kLeft, + Distribution::kClustered}); + + // Contains predicate benchmarks (50K x 5K) + bm.addBenchmark( + {kMediumProbeBenchmarkSize / 2, + kMediumBuildBenchmarkSize, + "ST_Contains", + core::JoinType::kInner, + Distribution::kUniform}); + + // Large scale benchmark (200K x 50K) + bm.addBenchmark( + {kLargeProbeBenchmarkSize, + kLargeBuildBenchmarkSize, + "ST_Intersects", + core::JoinType::kInner, + Distribution::kUniform}); + + folly::runBenchmarks(); + return 0; +} diff --git a/velox/exec/benchmarks/StreamingAggregationBenchmark.cpp b/velox/exec/benchmarks/StreamingAggregationBenchmark.cpp index dc1f16ae6a1d..9393d7cef7d8 100644 --- a/velox/exec/benchmarks/StreamingAggregationBenchmark.cpp +++ b/velox/exec/benchmarks/StreamingAggregationBenchmark.cpp @@ -84,10 +84,9 @@ class StreamingAggregationBenchmark : public VectorTestBase { std::to_string(params.numGroups)); folly::addBenchmark(__FILE__, name, [plan = &test->plan]() { - std::shared_ptr task; exec::test::AssertQueryBuilder(*plan) .serialExecution(true) - .runWithoutResults(task); + .countResults(); return 1; }); diff --git a/velox/exec/benchmarks/VectorHasherBenchmark.cpp b/velox/exec/benchmarks/VectorHasherBenchmark.cpp index 41bb56c05672..f3af6953b601 100644 --- a/velox/exec/benchmarks/VectorHasherBenchmark.cpp +++ b/velox/exec/benchmarks/VectorHasherBenchmark.cpp @@ -74,7 +74,7 @@ void benchmarkComputeValueIds(bool withNulls) { [](vector_size_t row) { return row % 17; }, withNulls ? velox::test::VectorMaker::nullEvery(7) : nullptr); - raw_vector hashes(size); + raw_vector hashes(size, base.pool()); SelectivityVector rows(size); hasher.decode(*values, rows); hasher.computeValueIds(rows, hashes); @@ -155,7 +155,7 @@ void benchmarkComputeValueIdsForStrings(bool flattenDictionaries) { uint64_t multiplier = 1; for (int i = 0; i < 4; i++) { auto hasher = hashers[i].get(); - raw_vector result(size); + raw_vector result(size, base.pool()); hasher->decode(*vectors[i], allRows); auto ok = hasher->computeValueIds(allRows, result); folly::doNotOptimizeAway(ok); @@ -164,7 +164,7 @@ void benchmarkComputeValueIdsForStrings(bool flattenDictionaries) { } suspender.dismiss(); - raw_vector result(size); + raw_vector result(size, base.pool()); for (int i = 0; i < 10'000; i++) { for (int j = 0; j < 4; j++) { auto hasher = hashers[j].get(); @@ -198,7 +198,7 @@ BENCHMARK(computeValueIdsLowCardinalityLargeBatchSize) { auto values = base.vectorMaker().dictionaryVector(data); for (int i = 0; i < 10; i++) { - raw_vector hashes(batchSize); + raw_vector hashes(batchSize, base.pool()); SelectivityVector rows(batchSize); VectorHasher hasher(BIGINT(), 0); hasher.decode(*values, rows); @@ -228,7 +228,7 @@ BENCHMARK(computeValueIdsLowCardinalityNotAllUsed) { auto values = BaseVector::wrapInDictionary(nullptr, indices, batchSize, data); for (int i = 0; i < 10; i++) { - raw_vector hashes(batchSize); + raw_vector hashes(batchSize, base.pool()); SelectivityVector rows(batchSize); VectorHasher hasher(BIGINT(), 0); hasher.decode(*values, rows); @@ -258,7 +258,7 @@ BENCHMARK(computeValueIdsDictionaryForFiltering) { auto values = BaseVector::wrapInDictionary(nullptr, indices, batchSize, data); for (int i = 0; i < 10; i++) { - raw_vector hashes(batchSize); + raw_vector hashes(batchSize, base.pool()); SelectivityVector rows(batchSize); VectorHasher hasher(BIGINT(), 0); hasher.decode(*values, rows); diff --git a/velox/exec/benchmarks/WindowPrefixSortBenchmark.cpp b/velox/exec/benchmarks/WindowPrefixSortBenchmark.cpp index 27025f723515..c63f4dd1188c 100644 --- a/velox/exec/benchmarks/WindowPrefixSortBenchmark.cpp +++ b/velox/exec/benchmarks/WindowPrefixSortBenchmark.cpp @@ -72,8 +72,9 @@ class WindowPrefixSortBenchmark : public HiveConnectorTestBase { // Generate key with a small number of unique values from a small range // (0-16). - children.emplace_back(makeFlatVector( - kRowsPerVector, [](auto row) { return row % 17; })); + children.emplace_back( + makeFlatVector( + kRowsPerVector, [](auto row) { return row % 17; })); // Generate key with a small number of unique values from a large range // (300 total values). @@ -94,8 +95,9 @@ class WindowPrefixSortBenchmark : public HiveConnectorTestBase { // Generate a column with increasing values to get a deterministic sort // order. - children.emplace_back(makeFlatVector( - kRowsPerVector, [](auto row) { return row; })); + children.emplace_back( + makeFlatVector( + kRowsPerVector, [](auto row) { return row; })); // Generate random values without nulls. children.emplace_back(fuzzer.fuzzFlat(INTEGER())); @@ -192,7 +194,8 @@ class WindowPrefixSortBenchmark : public HiveConnectorTestBase { std::move(plan), 0, core::QueryCtx::create(executor_.get()), - Task::ExecutionMode::kSerial); + Task::ExecutionMode::kSerial, + exec::Consumer{}); } else { const std::unordered_map queryConfigMap( {{core::QueryConfig::kPrefixSortNormalizedKeyMaxBytes, "0"}}); @@ -202,7 +205,8 @@ class WindowPrefixSortBenchmark : public HiveConnectorTestBase { 0, core::QueryCtx::create( executor_.get(), core::QueryConfig(queryConfigMap)), - Task::ExecutionMode::kSerial); + Task::ExecutionMode::kSerial, + exec::Consumer{}); } } diff --git a/velox/exec/benchmarks/WindowSubPartitionedSortBenchmark.cpp b/velox/exec/benchmarks/WindowSubPartitionedSortBenchmark.cpp new file mode 100644 index 000000000000..2fe4451bea97 --- /dev/null +++ b/velox/exec/benchmarks/WindowSubPartitionedSortBenchmark.cpp @@ -0,0 +1,407 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include + +#include "velox/common/memory/SharedArbitrator.h" +#include "velox/exec/Cursor.h" +#include "velox/exec/tests/utils/HiveConnectorTestBase.h" +#include "velox/exec/tests/utils/PlanBuilder.h" +#include "velox/functions/prestosql/aggregates/RegisterAggregateFunctions.h" +#include "velox/functions/prestosql/window/WindowFunctionsRegistration.h" +#include "velox/vector/fuzzer/VectorFuzzer.h" + +DEFINE_int64(fuzzer_seed, 99887766, "Seed for random input dataset generator"); + +using namespace facebook::velox; +using namespace facebook::velox::test; +using namespace facebook::velox::exec; +using namespace facebook::velox::exec::test; + +static constexpr int32_t kNumVectors = 200; +static constexpr int32_t kRowsPerVector = 10'000; + +namespace { +class BenchmarkRecorder { + public: + BenchmarkRecorder() = default; + + void record(std::string name, uint64_t numBytes) { + // Only record the first apperance. + if (numBytesRecords_.count(name) == 0) { + numBytesRecords_[name] = {1, numBytes}; + names_.push_back(name); + } else { + auto& record = numBytesRecords_[name]; + record.numAppearance++; + record.totalCount += numBytes; + } + } + + std::string report() { + std::string result = "name, memory(MB)\n"; + for (auto& name : names_) { + auto& record = numBytesRecords_[name]; + result += fmt::format( + "{}, {}MB\n", + name, + record.totalCount / 1024 / 1024 / record.numAppearance); + } + return result; + } + + private: + struct Counter { + int32_t numAppearance{0}; + uint64_t totalCount{0}; + }; + std::vector names_; + std::unordered_map numBytesRecords_; +}; + +class WindowSubPartitionedSortBenchmark : public HiveConnectorTestBase { + public: + WindowSubPartitionedSortBenchmark( + int32_t numVectors, + int32_t rowsPerVector, + std::shared_ptr recorder) + : numVectors_(numVectors), + rowsPerVector_(rowsPerVector), + recorder_(recorder) { + memory::SharedArbitrator::registerFactory(); + HiveConnectorTestBase::SetUp(); + aggregate::prestosql::registerAllAggregateFunctions(); + window::prestosql::registerAllWindowFunctions(); + + inputType_ = ROW({ + {"k_array", INTEGER()}, + {"k_norm", INTEGER()}, + {"k_hash", INTEGER()}, + {"k_sort", INTEGER()}, + {"i32", INTEGER()}, + {"i64", BIGINT()}, + {"f32", REAL()}, + {"f64", DOUBLE()}, + {"i32_halfnull", INTEGER()}, + {"i64_halfnull", BIGINT()}, + {"f32_halfnull", REAL()}, + {"f64_halfnull", DOUBLE()}, + }); + + VectorFuzzer::Options opts; + opts.vectorSize = rowsPerVector_; + opts.nullRatio = 0; + VectorFuzzer fuzzer(opts, pool_.get(), FLAGS_fuzzer_seed); + std::vector inputVectors; + for (auto i = 0; i < numVectors_; ++i) { + std::vector children; + + // Generate key with a small number of unique values from a small range + // (0-16). + children.emplace_back( + makeFlatVector( + rowsPerVector_, [](auto row) { return row % 17; })); + + // Generate key with a small number of unique values from a large range + // (300 total values). + children.emplace_back( + makeFlatVector(rowsPerVector_, [](auto row) { + if (row % 3 == 0) { + return std::numeric_limits::max() - row % 100; + } else if (row % 3 == 1) { + return row % 100; + } else { + return std::numeric_limits::min() + row % 100; + } + })); + + // Generate key with many unique values from a large range (500K total + // values). + children.emplace_back(fuzzer.fuzzFlat(INTEGER())); + + // Generate a column with increasing values to get a deterministic sort + // order. + children.emplace_back( + makeFlatVector( + rowsPerVector_, [](auto row) { return row; })); + + // Generate random values without nulls. + children.emplace_back(fuzzer.fuzzFlat(INTEGER())); + children.emplace_back(fuzzer.fuzzFlat(BIGINT())); + children.emplace_back(fuzzer.fuzzFlat(REAL())); + children.emplace_back(fuzzer.fuzzFlat(DOUBLE())); + + // Generate random values with nulls. + opts.nullRatio = 0.05; // 5% + fuzzer.setOptions(opts); + + children.emplace_back(fuzzer.fuzzFlat(INTEGER())); + children.emplace_back(fuzzer.fuzzFlat(BIGINT())); + children.emplace_back(fuzzer.fuzzFlat(REAL())); + children.emplace_back(fuzzer.fuzzFlat(DOUBLE())); + + inputVectors.emplace_back(makeRowVector(inputType_->names(), children)); + } + + sourceFilePath_ = TempFilePath::create(); + writeToFile(sourceFilePath_->getPath(), inputVectors); + } + + ~WindowSubPartitionedSortBenchmark() override { + HiveConnectorTestBase::TearDown(); + } + + CpuWallTiming windowNanos() { + return windowNanos_; + } + + void TestBody() override {} + + void run( + const std::string& recordName, + const std::string& key, + const std::string& aggregate, + int32_t numSubPartitions) { + folly::BenchmarkSuspender suspender1; + + windowNanos_.clear(); + windowMems_.clear(); + + std::string functionSql = fmt::format( + "{} over (partition by {} order by k_sort)", aggregate, key); + + core::PlanNodeId tableScanPlanId; + core::PlanFragment plan = PlanBuilder() + .tableScan(inputType_) + .capturePlanNodeId(tableScanPlanId) + .window({functionSql}) + .planFragment(); + + vector_size_t numResultRows = 0; + auto task = makeTask(plan, numSubPartitions); + task->addSplit( + tableScanPlanId, + exec::Split(makeHiveConnectorSplit(sourceFilePath_->getPath()))); + task->noMoreSplits(tableScanPlanId); + suspender1.dismiss(); + + while (auto result = task->next()) { + numResultRows += result->size(); + } + + folly::BenchmarkSuspender suspender2; + auto stats = task->taskStats(); + for (auto& pipeline : stats.pipelineStats) { + for (auto& op : pipeline.operatorStats) { + if (op.operatorType == "Window") { + windowNanos_.add(op.addInputTiming); + windowNanos_.add(op.getOutputTiming); + windowMems_.add(op.memoryStats); + } + if (op.operatorType == "Values") { + // This is the timing for Window::noMoreInput() where the window + // sorting happens. So including in the cpu timing. + windowNanos_.add(op.finishTiming); + } + } + } + recorder_->record(recordName, windowMems_.peakTotalMemoryReservation); + suspender2.dismiss(); + folly::doNotOptimizeAway(numResultRows); + } + + std::shared_ptr makeTask( + core::PlanFragment plan, + int32_t numSubPartitions) { + bool subPartitionedSort = numSubPartitions > 1; + if (subPartitionedSort) { + const std::unordered_map queryConfigMap( + {{core::QueryConfig::kWindowNumSubPartitions, + std::to_string(numSubPartitions)}}); + return exec::Task::create( + "t", + std::move(plan), + 0, + core::QueryCtx::create( + executor_.get(), core::QueryConfig(queryConfigMap)), + Task::ExecutionMode::kSerial); + + } else { + return exec::Task::create( + "t", + std::move(plan), + 0, + core::QueryCtx::create(executor_.get()), + Task::ExecutionMode::kSerial); + } + } + + uint64_t getLatestMemoryUsage() { + return windowMems_.peakTotalMemoryReservation; + } + + private: + const int32_t numVectors_; + const int32_t rowsPerVector_; + const std::shared_ptr recorder_; + RowTypePtr inputType_; + std::shared_ptr sourceFilePath_; + + CpuWallTiming windowNanos_; + MemoryStats windowMems_; +}; + +std::unique_ptr benchmark; +auto recorder = std::make_shared(); + +void doSortRun( + uint32_t, + const std::string& recordName, + int32_t numSubPartitions, + const std::string& key, + const std::string& aggregate) { + benchmark->run(recordName, key, aggregate, numSubPartitions); +} + +#define BENCHMARK_AND_RECORD_HEAD(_num_, _name_, _key_, _agg_) \ + BENCHMARK_NAMED_PARAM( \ + doSortRun, \ + num##_num_##_##_name_, \ + fmt::format("num{}_{}", #_num_, #_name_), \ + _num_, \ + _key_, \ + _agg_); + +#define BENCHMARK_AND_RECORD_TAIL(_num_, _name_, _key_, _agg_) \ + BENCHMARK_RELATIVE_NAMED_PARAM( \ + doSortRun, \ + num##_num_##_##_name_, \ + fmt::format("num{}_{}", #_num_, #_name_), \ + _num_, \ + _key_, \ + _agg_); + +#define BATCHED_BENCHMARKS(_name_, _key_, _agg_) \ + BENCHMARK_AND_RECORD_HEAD(1, _name_, _key_, _agg_); \ + BENCHMARK_AND_RECORD_TAIL(2, _name_, _key_, _agg_); \ + BENCHMARK_AND_RECORD_TAIL(4, _name_, _key_, _agg_); \ + BENCHMARK_AND_RECORD_TAIL(8, _name_, _key_, _agg_); \ + BENCHMARK_AND_RECORD_TAIL(16, _name_, _key_, _agg_); \ + BENCHMARK_AND_RECORD_TAIL(32, _name_, _key_, _agg_); \ + BENCHMARK_AND_RECORD_TAIL(64, _name_, _key_, _agg_); \ + BENCHMARK_AND_RECORD_TAIL(128, _name_, _key_, _agg_); \ + BENCHMARK_AND_RECORD_TAIL(256, _name_, _key_, _agg_); \ + BENCHMARK_AND_RECORD_TAIL(512, _name_, _key_, _agg_); \ + BENCHMARK_AND_RECORD_TAIL(1024, _name_, _key_, _agg_); \ + BENCHMARK_AND_RECORD_TAIL(2048, _name_, _key_, _agg_); + +#define AGG_BENCHMARKS(_name_, _key_) \ + BATCHED_BENCHMARKS( \ + _name_##_INTEGER_##_key_, #_key_, fmt::format("{}(i32)", (#_name_))); \ + BATCHED_BENCHMARKS( \ + _name_##_REAL_##_key_, #_key_, fmt::format("{}(f32)", (#_name_))); \ + BATCHED_BENCHMARKS( \ + _name_##_INTEGER_NULLS_##_key_, \ + #_key_, \ + fmt::format("{}(i32_halfnull)", (#_name_))); \ + BATCHED_BENCHMARKS( \ + _name_##_REAL_NULLS_##_key_, \ + #_key_, \ + fmt::format("{}(f32_halfnull)", (#_name_))); + +#define MULTI_KEY_AGG_BENCHMARKS(_name_, _key1_, _key2_) \ + BATCHED_BENCHMARKS( \ + _name_##_BIGINT_##_key1_##_key2_, \ + fmt::format("{},{}", (#_key1_), (#_key2_)), \ + fmt::format("{}(i64)", (#_name_))); \ + BATCHED_BENCHMARKS( \ + _name_##_BIGINT_NULLS_##_key1_##_key2_, \ + fmt::format("{},{}", (#_key1_), (#_key2_)), \ + fmt::format("{}(i64_halfnull)", (#_name_))); \ + BATCHED_BENCHMARKS( \ + _name_##_DOUBLE_##_key1_##_key2_, \ + fmt::format("{},{}", (#_key1_), (#_key2_)), \ + fmt::format("{}(f64)", (#_name_))); \ + BATCHED_BENCHMARKS( \ + _name_##_DOUBLE_NULLS_##_key1_##_key2_, \ + fmt::format("{},{}", (#_key1_), (#_key2_)), \ + fmt::format("{}(f64_halfnull)", (#_name_))); + +// Count(1) aggregate. +BATCHED_BENCHMARKS(count_k_array, "k_array", "count(1)"); +BATCHED_BENCHMARKS(count_k_norm, "k_norm", "count(1)"); +BATCHED_BENCHMARKS(count_k_hash, "k_hash", "count(1)"); +BATCHED_BENCHMARKS(count_k_array_k_hash, "k_array,i32", "count(1)"); +BENCHMARK_DRAW_LINE(); + +// Count aggregate. +AGG_BENCHMARKS(count, k_array) +AGG_BENCHMARKS(count, k_norm) +AGG_BENCHMARKS(count, k_hash) +MULTI_KEY_AGG_BENCHMARKS(count, k_array, i32) +MULTI_KEY_AGG_BENCHMARKS(count, k_array, i64) +MULTI_KEY_AGG_BENCHMARKS(count, k_hash, f32) +MULTI_KEY_AGG_BENCHMARKS(count, k_hash, f64) +BENCHMARK_DRAW_LINE(); + +// Avg aggregate. +AGG_BENCHMARKS(avg, k_array) +AGG_BENCHMARKS(avg, k_norm) +AGG_BENCHMARKS(avg, k_hash) +MULTI_KEY_AGG_BENCHMARKS(avg, k_array, i32) +MULTI_KEY_AGG_BENCHMARKS(avg, k_array, i64) +MULTI_KEY_AGG_BENCHMARKS(avg, k_hash, f32) +MULTI_KEY_AGG_BENCHMARKS(avg, k_hash, f64) +BENCHMARK_DRAW_LINE(); + +// Min aggregate. +AGG_BENCHMARKS(min, k_array) +AGG_BENCHMARKS(min, k_norm) +AGG_BENCHMARKS(min, k_hash) +MULTI_KEY_AGG_BENCHMARKS(min, k_array, i32) +MULTI_KEY_AGG_BENCHMARKS(min, k_array, i64) +MULTI_KEY_AGG_BENCHMARKS(min, k_hash, f32) +MULTI_KEY_AGG_BENCHMARKS(min, k_hash, f64) +BENCHMARK_DRAW_LINE(); + +// Max aggregate. +AGG_BENCHMARKS(max, k_array) +AGG_BENCHMARKS(max, k_norm) +AGG_BENCHMARKS(max, k_hash) +MULTI_KEY_AGG_BENCHMARKS(max, k_array, i32) +MULTI_KEY_AGG_BENCHMARKS(max, k_array, i64) +MULTI_KEY_AGG_BENCHMARKS(max, k_hash, f32) +MULTI_KEY_AGG_BENCHMARKS(max, k_hash, f64) +BENCHMARK_DRAW_LINE(); + +} // namespace + +int main(int argc, char** argv) { + folly::Init init(&argc, &argv); + facebook::velox::memory::MemoryManager::initialize( + facebook::velox::memory::MemoryManager::Options{}); + + benchmark = std::make_unique( + kNumVectors, kRowsPerVector, recorder); + folly::runBenchmarks(); + benchmark.reset(); + + std::cout << std::endl << recorder->report(); + return 0; +} diff --git a/velox/exec/fuzzer/AggregationFuzzer.cpp b/velox/exec/fuzzer/AggregationFuzzer.cpp index fd98a8d8a83d..2f635bbfbef8 100644 --- a/velox/exec/fuzzer/AggregationFuzzer.cpp +++ b/velox/exec/fuzzer/AggregationFuzzer.cpp @@ -265,10 +265,7 @@ AggregationFuzzer::AggregationFuzzer( void AggregationFuzzer::go(const std::string& planPath) { Type::registerSerDe(); - connector::hive::HiveTableHandle::registerSerDe(); - connector::hive::LocationHandle::registerSerDe(); - connector::hive::HiveColumnHandle::registerSerDe(); - connector::hive::HiveInsertTableHandle::registerSerDe(); + connector::hive::HiveConnector::registerSerDe(); core::ITypedExpr::registerSerDe(); core::PlanNode::registerSerDe(); registerPartitionFunctionSerDe(); @@ -493,21 +490,23 @@ void makeAlternativePlansWithValues( const std::vector& projections, std::vector& plans) { // Partial -> final aggregation plan. - plans.push_back(PlanBuilder() - .values(inputVectors) - .projectExpressions(projections) - .partialAggregation(groupingKeys, aggregates, masks) - .finalAggregation() - .planNode()); + plans.push_back( + PlanBuilder() + .values(inputVectors) + .projectExpressions(projections) + .partialAggregation(groupingKeys, aggregates, masks) + .finalAggregation() + .planNode()); // Partial -> intermediate -> final aggregation plan. - plans.push_back(PlanBuilder() - .values(inputVectors) - .projectExpressions(projections) - .partialAggregation(groupingKeys, aggregates, masks) - .intermediateAggregation() - .finalAggregation() - .planNode()); + plans.push_back( + PlanBuilder() + .values(inputVectors) + .projectExpressions(projections) + .partialAggregation(groupingKeys, aggregates, masks) + .intermediateAggregation() + .finalAggregation() + .planNode()); // Partial -> local exchange -> final aggregation plan. auto numSources = std::min(4, inputVectors.size()); @@ -553,23 +552,25 @@ void makeAlternativePlansWithTableScan( // the false negatives. #ifndef TSAN_BUILD // Partial -> final aggregation plan. - plans.push_back(PlanBuilder() - .tableScan(inputRowType) - .projectExpressions(projections) - .partialAggregation(groupingKeys, aggregates, masks) - .localPartition(groupingKeys) - .finalAggregation() - .planNode()); + plans.push_back( + PlanBuilder() + .tableScan(inputRowType) + .projectExpressions(projections) + .partialAggregation(groupingKeys, aggregates, masks) + .localPartition(groupingKeys) + .finalAggregation() + .planNode()); // Partial -> intermediate -> final aggregation plan. - plans.push_back(PlanBuilder() - .tableScan(inputRowType) - .projectExpressions(projections) - .partialAggregation(groupingKeys, aggregates, masks) - .localPartition(groupingKeys) - .intermediateAggregation() - .finalAggregation() - .planNode()); + plans.push_back( + PlanBuilder() + .tableScan(inputRowType) + .projectExpressions(projections) + .partialAggregation(groupingKeys, aggregates, masks) + .localPartition(groupingKeys) + .intermediateAggregation() + .finalAggregation() + .planNode()); #endif } @@ -581,17 +582,18 @@ void makeStreamingPlansWithValues( const std::vector& projections, std::vector& plans) { // Single aggregation. - plans.push_back(PlanBuilder() - .values(inputVectors) - .projectExpressions(projections) - .orderBy(groupingKeys, false) - .streamingAggregation( - groupingKeys, - aggregates, - masks, - core::AggregationNode::Step::kSingle, - false) - .planNode()); + plans.push_back( + PlanBuilder() + .values(inputVectors) + .projectExpressions(projections) + .orderBy(groupingKeys, false) + .streamingAggregation( + groupingKeys, + aggregates, + masks, + core::AggregationNode::Step::kSingle, + false) + .planNode()); // Partial -> final aggregation plan. plans.push_back( @@ -646,17 +648,18 @@ void makeStreamingPlansWithTableScan( const std::vector& projections, std::vector& plans) { // Single aggregation. - plans.push_back(PlanBuilder() - .tableScan(inputRowType) - .projectExpressions(projections) - .orderBy(groupingKeys, false) - .streamingAggregation( - groupingKeys, - aggregates, - masks, - core::AggregationNode::Step::kSingle, - false) - .planNode()); + plans.push_back( + PlanBuilder() + .tableScan(inputRowType) + .projectExpressions(projections) + .orderBy(groupingKeys, false) + .streamingAggregation( + groupingKeys, + aggregates, + masks, + core::AggregationNode::Step::kSingle, + false) + .planNode()); // Partial -> final aggregation plan. plans.push_back( @@ -891,7 +894,7 @@ bool AggregationFuzzer::verifySortedAggregation( if (customVerification && (!aggregateOrderSensitive || customVerifier == nullptr || - customVerifier->supportsVerify())) { + customVerifier->supportsVerify() || customVerifier->supportsCompare())) { // We have custom verification enabled and: // 1) the aggregate function is not order sensitive (sorting the input won't // have an effect on the output) or @@ -899,13 +902,12 @@ bool AggregationFuzzer::verifySortedAggregation( // verification of this aggregation) or // 3) the custom verifier supports verification (it can't compare the // results of the aggregation with the reference DB) + // 4) the custom verifier supports compare. // keep the custom verifier enabled. return compareEquivalentPlanResults( plans, customVerification, input, customVerifier, 1); } else { - // If custom verification is not enabled or the custom verifier is used for - // compare and the aggregation is order sensitive (the result shoudl be - // deterministic if the input is sorted), then compare the results directly. + // If custom verification is not enabled, then compare the results directly. return compareEquivalentPlanResults(plans, false, input, nullptr, 1); } } @@ -1059,7 +1061,8 @@ bool AggregationFuzzer::compareEquivalentPlanResults( expectedResult.value(), firstPlan->outputType(), {resultOrError.result}), - "Velox and reference DB results don't match"); + "Velox and reference DB results don't match, plan: {}", + firstPlan->toString(true, true)); LOG(INFO) << "Verified results against reference DB"; } } else if (referenceQueryRunner_->supportsVeloxVectorResults()) { diff --git a/velox/exec/fuzzer/AggregationFuzzerBase.cpp b/velox/exec/fuzzer/AggregationFuzzerBase.cpp index 2bd16248fbdd..31606d040c3a 100644 --- a/velox/exec/fuzzer/AggregationFuzzerBase.cpp +++ b/velox/exec/fuzzer/AggregationFuzzerBase.cpp @@ -303,8 +303,9 @@ std::vector AggregationFuzzerBase::generateInputData( children.push_back(vectorFuzzer_.fuzz(inputType->childAt(j), size)); } - input.push_back(std::make_shared( - pool_.get(), inputType, nullptr, size, std::move(children))); + input.push_back( + std::make_shared( + pool_.get(), inputType, nullptr, size, std::move(children))); } if (generator != nullptr) { @@ -404,16 +405,18 @@ std::vector AggregationFuzzerBase::generateInputDataWithRowNumber( // values. This is done to introduce some repetition of key values for // windowing. auto baseVector = vectorFuzzer_.fuzz(types[i], numPartitions); - children.push_back(BaseVector::wrapInDictionary( - partitionNulls, partitionIndices, size, baseVector)); + children.push_back( + BaseVector::wrapInDictionary( + partitionNulls, partitionIndices, size, baseVector)); } else if ( windowFrameBoundsSet.find(names[i]) != windowFrameBoundsSet.end()) { // Frame bound columns cannot have NULLs. children.push_back(vectorFuzzer_.fuzzNotNull(types[i], size)); } else if (sortingKeySet.find(names[i]) != sortingKeySet.end()) { auto baseVector = vectorFuzzer_.fuzz(types[i], numPeerGroups); - children.push_back(BaseVector::wrapInDictionary( - sortingNulls, sortingIndices, size, baseVector)); + children.push_back( + BaseVector::wrapInDictionary( + sortingNulls, sortingIndices, size, baseVector)); } else { children.push_back(vectorFuzzer_.fuzz(types[i], size)); } @@ -553,13 +556,19 @@ void AggregationFuzzerBase::testPlan( const std::vector>& customVerifiers, const velox::fuzzer::ResultOrError& expected, int32_t maxDrivers) { - auto actual = execute( - planWithSplits.plan, - planWithSplits.splits, - injectSpill, - abandonPartial, - maxDrivers); - compare(actual, customVerification, customVerifiers, expected); + try { + auto actual = execute( + planWithSplits.plan, + planWithSplits.splits, + injectSpill, + abandonPartial, + maxDrivers); + compare(actual, customVerification, customVerifiers, expected); + } catch (...) { + LOG(ERROR) << "Failed while testing plan: " + << planWithSplits.plan->toString(true, true); + throw; + } } void AggregationFuzzerBase::compare( @@ -606,10 +615,12 @@ void AggregationFuzzerBase::compare( VELOX_CHECK( verifier->compare(expected.result, actual.result), "Logically equivalent plans produced different results"); + LOG(INFO) << "Verified through custom verifier."; } else if (verifier->supportsVerify()) { VELOX_CHECK( verifier->verify(actual.result), "Result of a logically equivalent plan failed custom verification"); + LOG(INFO) << "Verified through custom verifier."; } else { VELOX_UNREACHABLE( "Custom verifier must support either 'compare' or 'verify' API."); diff --git a/velox/exec/fuzzer/AggregationFuzzerBase.h b/velox/exec/fuzzer/AggregationFuzzerBase.h index 668fc9f6ecb6..e058bd898761 100644 --- a/velox/exec/fuzzer/AggregationFuzzerBase.h +++ b/velox/exec/fuzzer/AggregationFuzzerBase.h @@ -81,8 +81,6 @@ class AggregationFuzzerBase { : getFuzzerOptions(timestampPrecision), pool_.get()} { filesystems::registerLocalFileSystem(); - connector::registerConnectorFactory( - std::make_shared()); registerHiveConnector(hiveConfigs); dwrf::registerDwrfReaderFactory(); dwrf::registerDwrfWriterFactory(); @@ -123,6 +121,7 @@ class AggregationFuzzerBase { opts.stringVariableLength = true; opts.stringLength = 4'000; opts.nullRatio = FLAGS_null_ratio; + opts.useRandomNullPattern = true; opts.timestampPrecision = timestampPrecision; return opts; } diff --git a/velox/exec/fuzzer/AggregationFuzzerRunner.h b/velox/exec/fuzzer/AggregationFuzzerRunner.h index dc8d15928e7c..6aa044d84efe 100644 --- a/velox/exec/fuzzer/AggregationFuzzerRunner.h +++ b/velox/exec/fuzzer/AggregationFuzzerRunner.h @@ -101,7 +101,7 @@ class AggregationFuzzerRunner { if (filteredSignatures.empty()) { LOG(ERROR) << "No aggregate functions left after filtering using 'only' and 'skip' lists."; - exit(1); + return 1; } facebook::velox::parse::registerTypeResolver(); diff --git a/velox/exec/fuzzer/CMakeLists.txt b/velox/exec/fuzzer/CMakeLists.txt index 50b945562b7a..eadc76d8a283 100644 --- a/velox/exec/fuzzer/CMakeLists.txt +++ b/velox/exec/fuzzer/CMakeLists.txt @@ -20,10 +20,20 @@ add_library( PrestoQueryRunner.cpp PrestoQueryRunnerIntermediateTypeTransforms.cpp PrestoQueryRunnerTimestampWithTimeZoneTransform.cpp + PrestoQueryRunnerJsonTransform.cpp + PrestoQueryRunnerIntervalTransform.cpp PrestoQueryRunnerToSqlPlanNodeVisitor.cpp FuzzerUtil.cpp PrestoSql.cpp - PrestoQueryRunnerIntermediateTypeTransforms.cpp) +) + +# TODO Add VeloxQueryRunner to velox_fuzzer_util to support in +# ExpressionFuzzerTest. More information can be found here: +# https://github.com/facebookincubator/velox/issues/15414 +if(VELOX_ENABLE_REMOTE_FUNCTIONS) + target_sources(velox_fuzzer_util PRIVATE VeloxQueryRunner.cpp) + target_link_libraries(velox_fuzzer_util FBThrift::thriftcpp2) +endif() target_link_libraries( velox_fuzzer_util @@ -33,14 +43,14 @@ target_link_libraries( velox_exec_test_lib velox_expression_functions velox_presto_types - cpr::cpr - Boost::regex - velox_type_parser + CURL::libcurl + velox_presto_type_parser Folly::folly velox_hive_connector velox_dwio_dwrf_writer velox_dwio_catalog_fbhive - velox_dwio_faulty_file_sink) + velox_dwio_faulty_file_sink +) add_library(velox_aggregation_fuzzer_base AggregationFuzzerBase.cpp) @@ -58,7 +68,8 @@ target_link_libraries( velox_fuzzer_util velox_expression_test_utility velox_vector - velox_core) + velox_core +) add_library(velox_aggregation_fuzzer AggregationFuzzer.cpp) @@ -69,7 +80,8 @@ target_link_libraries( velox_exec_test_lib velox_expression_test_utility velox_aggregation_fuzzer_base - velox_fuzzer_util) + velox_fuzzer_util +) add_library(velox_window_fuzzer WindowFuzzer.cpp) @@ -81,7 +93,8 @@ target_link_libraries( velox_exec_test_lib velox_expression_test_utility velox_aggregation_fuzzer_base - velox_temp_path) + velox_temp_path +) add_library(velox_row_number_fuzzer_base_lib RowNumberFuzzerBase.cpp) @@ -90,35 +103,39 @@ target_link_libraries( velox_dwio_dwrf_reader velox_fuzzer_util velox_vector_fuzzer - velox_exec_test_lib) + velox_exec_test_lib +) add_library(velox_row_number_fuzzer_lib RowNumberFuzzer.cpp) target_link_libraries( - velox_row_number_fuzzer_lib velox_row_number_fuzzer_base_lib velox_type - velox_expression_test_utility) + velox_row_number_fuzzer_lib + velox_row_number_fuzzer_base_lib + velox_type + velox_expression_test_utility +) # RowNumber Fuzzer. add_executable(velox_row_number_fuzzer RowNumberFuzzerRunner.cpp) -target_link_libraries( - velox_row_number_fuzzer velox_row_number_fuzzer_lib) +target_link_libraries(velox_row_number_fuzzer velox_row_number_fuzzer_lib) add_library(velox_topn_row_number_fuzzer_lib TopNRowNumberFuzzer.cpp) target_link_libraries( - velox_topn_row_number_fuzzer_lib velox_row_number_fuzzer_base_lib velox_type - velox_expression_test_utility) + velox_topn_row_number_fuzzer_lib + velox_row_number_fuzzer_base_lib + velox_type + velox_expression_test_utility +) # TopNRowNumber Fuzzer. add_executable(velox_topn_row_number_fuzzer TopNRowNumberFuzzerRunner.cpp) -target_link_libraries( - velox_topn_row_number_fuzzer velox_topn_row_number_fuzzer_lib) +target_link_libraries(velox_topn_row_number_fuzzer velox_topn_row_number_fuzzer_lib) # Join Fuzzer. -add_executable(velox_join_fuzzer JoinFuzzerRunner.cpp JoinFuzzer.cpp - JoinMaker.cpp) +add_executable(velox_join_fuzzer JoinFuzzerRunner.cpp JoinFuzzer.cpp JoinMaker.cpp) target_link_libraries( velox_join_fuzzer @@ -126,7 +143,21 @@ target_link_libraries( velox_vector_fuzzer velox_fuzzer_util velox_exec_test_lib - velox_expression_test_utility) + velox_expression_test_utility +) + +# Spatial Join Fuzzer. +add_executable(velox_spatial_join_fuzzer SpatialJoinFuzzerRunner.cpp SpatialJoinFuzzer.cpp) + +target_link_libraries( + velox_spatial_join_fuzzer + velox_type + velox_vector_fuzzer + velox_fuzzer_util + velox_exec_test_lib + velox_expression_test_utility + velox_vector_test_lib +) add_library(velox_writer_fuzzer WriterFuzzer.cpp) @@ -140,9 +171,15 @@ target_link_libraries( velox_temp_path velox_vector_test_lib velox_dwio_faulty_file_sink - velox_file_test_utils) + velox_file_test_utils +) -add_library(velox_memory_arbitration_fuzzer MemoryArbitrationFuzzer.cpp) +# Arbitration Fuzzer. +add_executable( + velox_memory_arbitration_fuzzer + MemoryArbitrationFuzzerRunner.cpp + MemoryArbitrationFuzzer.cpp +) target_link_libraries( velox_memory_arbitration_fuzzer @@ -153,19 +190,22 @@ target_link_libraries( velox_exec_test_lib velox_expression_test_utility velox_functions_prestosql - velox_aggregates) + velox_aggregates +) add_library(velox_cache_fuzzer_lib CacheFuzzer.cpp) # Cache Fuzzer add_executable(velox_cache_fuzzer CacheFuzzerRunner.cpp) -target_link_libraries( - velox_cache_fuzzer velox_cache_fuzzer_lib velox_fuzzer_util) +target_link_libraries(velox_cache_fuzzer velox_cache_fuzzer_lib velox_fuzzer_util) target_link_libraries( - velox_cache_fuzzer_lib velox_dwio_common velox_temp_path - velox_vector_test_lib) + velox_cache_fuzzer_lib + velox_dwio_common + velox_temp_path + velox_vector_test_lib +) # Exchange Fuzzer add_executable(velox_exchange_fuzzer ExchangeFuzzer.cpp) @@ -175,7 +215,55 @@ target_link_libraries( velox_exec_test_lib velox_aggregates velox_vector_test_lib - velox_vector_fuzzer) + velox_vector_fuzzer +) + +# LocalRunnerService (requires FBThrift support) +if(VELOX_ENABLE_REMOTE_FUNCTIONS) + # Generate Thrift library for LocalRunnerService + include(FBThriftCppLibrary) + add_fbthrift_cpp_library( + local_runner_service_thrift + if/LocalRunnerService.thrift + SERVICES + LocalRunnerService + ) + + target_compile_options(local_runner_service_thrift PRIVATE -Wno-error=deprecated-declarations) + + # LocalRunnerService Library + add_library(velox_local_runner_service_lib LocalRunnerService.cpp) + + target_link_libraries( + velox_local_runner_service_lib + local_runner_service_thrift + velox_core + velox_exec + velox_exec_test_lib + velox_expression + velox_functions_prestosql + velox_common_base + velox_memory + Folly::folly + FBThrift::thriftcpp2 + gflags + glog::glog + ) + + target_include_directories(velox_local_runner_service_lib PUBLIC ${CMAKE_CURRENT_BINARY_DIR}) + + # LocalRunnerService Executable + add_executable(velox_local_runner_service_runner LocalRunnerServiceRunner.cpp) + + target_link_libraries( + velox_local_runner_service_runner + velox_local_runner_service_lib + velox_functions_prestosql + gtest + gflags + Folly::folly + ) +endif() if(${VELOX_BUILD_TESTING}) add_subdirectory(tests) diff --git a/velox/exec/fuzzer/CacheFuzzer.cpp b/velox/exec/fuzzer/CacheFuzzer.cpp index 2ab9ef5354f6..8b4a311fbbdb 100644 --- a/velox/exec/fuzzer/CacheFuzzer.cpp +++ b/velox/exec/fuzzer/CacheFuzzer.cpp @@ -380,17 +380,18 @@ void CacheFuzzer::initializeInputs() { // Initialize buffered input. auto readFile = fs_->openFileForRead(fileNames_[i]); auto const withExecutor = !folly::Random::oneIn(3, rng_); - inputs_.emplace_back(std::make_unique( - std::move(readFile), - MetricsLog::voidLog(), - fileIds_[i].id(), // NOLINT - cache_.get(), - tracker, - fileIds_[i].id(), // NOLINT - ioStats, - fsStats, - withExecutor ? executor_.get() : nullptr, - readOptions)); + inputs_.emplace_back( + std::make_unique( + std::move(readFile), + MetricsLog::voidLog(), + fileIds_[i], // NOLINT + cache_.get(), + tracker, + fileIds_[i], // NOLINT + ioStats, + fsStats, + withExecutor ? executor_.get() : nullptr, + readOptions)); // Divide file into fragments. std::vector> fragments; diff --git a/velox/exec/fuzzer/DuckQueryRunnerToSqlPlanNodeVisitor.cpp b/velox/exec/fuzzer/DuckQueryRunnerToSqlPlanNodeVisitor.cpp index 388e67ca36ad..4ec4f00ca5c1 100644 --- a/velox/exec/fuzzer/DuckQueryRunnerToSqlPlanNodeVisitor.cpp +++ b/velox/exec/fuzzer/DuckQueryRunnerToSqlPlanNodeVisitor.cpp @@ -277,7 +277,8 @@ void DuckQueryRunnerToSqlPlanNodeVisitor::visit( sql << inputType->nameOf(i); } - sql << ", row_number() OVER ("; + sql << ", " << core::TopNRowNumberNode::rankFunctionName(node.rankFunction()) + << "() OVER ("; const auto& partitionKeys = node.partitionKeys(); if (!partitionKeys.empty()) { diff --git a/velox/exec/fuzzer/DuckQueryRunnerToSqlPlanNodeVisitor.h b/velox/exec/fuzzer/DuckQueryRunnerToSqlPlanNodeVisitor.h index 72d0f2dd5419..d7566ed833f0 100644 --- a/velox/exec/fuzzer/DuckQueryRunnerToSqlPlanNodeVisitor.h +++ b/velox/exec/fuzzer/DuckQueryRunnerToSqlPlanNodeVisitor.h @@ -106,6 +106,11 @@ class DuckQueryRunnerToSqlPlanNodeVisitor : public PrestoSqlPlanNodeVisitor { const core::NestedLoopJoinNode& node, core::PlanNodeVisitorContext& ctx) const override; + void visit(const core::SpatialJoinNode&, core::PlanNodeVisitorContext&) + const override { + VELOX_NYI(); + } + void visit(const core::OrderByNode&, core::PlanNodeVisitorContext&) const override { VELOX_NYI(); @@ -119,6 +124,12 @@ class DuckQueryRunnerToSqlPlanNodeVisitor : public PrestoSqlPlanNodeVisitor { void visit(const core::ProjectNode& node, core::PlanNodeVisitorContext& ctx) const override; + void visit( + const core::ParallelProjectNode& node, + core::PlanNodeVisitorContext& ctx) const override { + VELOX_NYI(); + } + void visit(const core::RowNumberNode& node, core::PlanNodeVisitorContext& ctx) const override; diff --git a/velox/exec/fuzzer/ExchangeFuzzer.cpp b/velox/exec/fuzzer/ExchangeFuzzer.cpp index 4443b24fba21..84824658fc5d 100644 --- a/velox/exec/fuzzer/ExchangeFuzzer.cpp +++ b/velox/exec/fuzzer/ExchangeFuzzer.cpp @@ -409,9 +409,11 @@ class ExchangeFuzzer : public VectorTestBase { LOG(INFO) << "Terminating with error"; exit(1); } - LOG(INFO) << "Memory after run=" - << succinctBytes(memory::AllocationTraits::pageBytes( - memory::memoryManager()->allocator()->numAllocated())); + LOG(INFO) + << "Memory after run=" + << succinctBytes( + memory::AllocationTraits::pageBytes( + memory::memoryManager()->allocator()->numAllocated())); if (FLAGS_duration_sec == 0 && FLAGS_steps && counter + 1 >= FLAGS_steps) { diff --git a/velox/exec/fuzzer/FuzzerUtil.cpp b/velox/exec/fuzzer/FuzzerUtil.cpp index b229fd3243f0..caafe71da5b5 100644 --- a/velox/exec/fuzzer/FuzzerUtil.cpp +++ b/velox/exec/fuzzer/FuzzerUtil.cpp @@ -323,20 +323,28 @@ TypePtr sanitizeTryResolveType( const exec::TypeSignature& typeSignature, const std::unordered_map& variables, const std::unordered_map& resolvedTypeVariables) { - return sanitize(SignatureBinder::tryResolveType( - typeSignature, variables, resolvedTypeVariables)); + return sanitize( + SignatureBinder::tryResolveType( + typeSignature, variables, resolvedTypeVariables)); } TypePtr sanitizeTryResolveType( const exec::TypeSignature& typeSignature, const std::unordered_map& variables, const std::unordered_map& typeVariablesBindings, - std::unordered_map& integerVariablesBindings) { - return sanitize(SignatureBinder::tryResolveType( - typeSignature, - variables, - typeVariablesBindings, - integerVariablesBindings)); + std::unordered_map& integerVariablesBindings, + const std::unordered_map& + longEnumParameterVariablesBindings, + const std::unordered_map& + varcharEnumParameterVariablesBindings) { + return sanitize( + SignatureBinder::tryResolveType( + typeSignature, + variables, + typeVariablesBindings, + integerVariablesBindings, + longEnumParameterVariablesBindings, + varcharEnumParameterVariablesBindings)); } void setupMemory( @@ -353,11 +361,13 @@ void setupMemory( options.checkUsageLeak = true; options.arbitrationStateCheckCb = memoryArbitrationStateCheck; options.extraArbitratorConfigs = { - {std::string(velox::memory::SharedArbitrator::ExtraConfig:: - kGlobalArbitrationEnabled), + {std::string( + velox::memory::SharedArbitrator::ExtraConfig:: + kGlobalArbitrationEnabled), enableGlobalArbitration ? "true" : "false"}, - {std::string(velox::memory::SharedArbitrator::ExtraConfig:: - kMemoryPoolMinReclaimBytes), + {std::string( + velox::memory::SharedArbitrator::ExtraConfig:: + kMemoryPoolMinReclaimBytes), "0B"}}; facebook::velox::memory::MemoryManager::initialize(options); } @@ -365,17 +375,13 @@ void setupMemory( void registerHiveConnector( const std::unordered_map& hiveConfigs) { auto configs = hiveConfigs; - if (!connector::hasConnectorFactory( - connector::hive::HiveConnectorFactory::kHiveConnectorName)) { - connector::registerConnectorFactory( - std::make_shared()); - } - auto hiveConnector = - connector::getConnectorFactory( - connector::hive::HiveConnectorFactory::kHiveConnectorName) - ->newConnector( - kHiveConnectorId, - std::make_shared(std::move(configs))); + // Make sure not to run out of open file descriptors. + configs[connector::hive::HiveConfig::kNumCacheFileHandles] = "1000"; + + connector::hive::HiveConnectorFactory factory; + auto hiveConnector = factory.newConnector( + kHiveConnectorId, + std::make_shared(std::move(configs))); connector::registerConnector(hiveConnector); } diff --git a/velox/exec/fuzzer/FuzzerUtil.h b/velox/exec/fuzzer/FuzzerUtil.h index dd8f7b37936a..5ad39818ca55 100644 --- a/velox/exec/fuzzer/FuzzerUtil.h +++ b/velox/exec/fuzzer/FuzzerUtil.h @@ -117,7 +117,11 @@ TypePtr sanitizeTryResolveType( const exec::TypeSignature& typeSignature, const std::unordered_map& variables, const std::unordered_map& typeVariablesBindings, - std::unordered_map& integerVariablesBindings); + std::unordered_map& integerVariablesBindings, + const std::unordered_map& + longEnumParameterVariablesBindings, + const std::unordered_map& + varcharEnumParameterVariablesBindings); // Invoked to set up memory system with arbitration. void setupMemory( diff --git a/velox/exec/fuzzer/JoinFuzzer.cpp b/velox/exec/fuzzer/JoinFuzzer.cpp index ded82a03a310..73a0ebf8f901 100644 --- a/velox/exec/fuzzer/JoinFuzzer.cpp +++ b/velox/exec/fuzzer/JoinFuzzer.cpp @@ -88,6 +88,7 @@ class JoinFuzzer { opts.stringVariableLength = true; opts.stringLength = 100; opts.nullRatio = FLAGS_null_ratio; + opts.useRandomNullPattern = true; opts.timestampPrecision = VectorFuzzer::Options::TimestampPrecision::kMilliSeconds; return opts; @@ -205,14 +206,11 @@ JoinFuzzer::JoinFuzzer( // Make sure not to run out of open file descriptors. std::unordered_map hiveConfig = { {connector::hive::HiveConfig::kNumCacheFileHandles, "1000"}}; - connector::registerConnectorFactory( - std::make_shared()); - auto hiveConnector = - connector::getConnectorFactory( - connector::hive::HiveConnectorFactory::kHiveConnectorName) - ->newConnector( - test::kHiveConnectorId, - std::make_shared(std::move(hiveConfig))); + + connector::hive::HiveConnectorFactory factory; + auto hiveConnector = factory.newConnector( + test::kHiveConnectorId, + std::make_shared(std::move(hiveConfig))); connector::registerConnector(hiveConnector); dwrf::registerDwrfReaderFactory(); dwrf::registerDwrfWriterFactory(); diff --git a/velox/exec/fuzzer/JoinMaker.cpp b/velox/exec/fuzzer/JoinMaker.cpp index 07e2b66d6b38..3e4631eb96a4 100644 --- a/velox/exec/fuzzer/JoinMaker.cpp +++ b/velox/exec/fuzzer/JoinMaker.cpp @@ -62,10 +62,11 @@ std::vector makeSourcesForPartitionedJoinPlan( std::vector sourceNodes; for (const auto& sourceInput : sourceInputs) { - sourceNodes.push_back(test::PlanBuilder(planNodeIdGenerator) - .values(sourceInput) - .projectExpressions(joinSource->projections()) - .planNode()); + sourceNodes.push_back( + test::PlanBuilder(planNodeIdGenerator) + .values(sourceInput) + .projectExpressions(joinSource->projections()) + .planNode()); } return sourceNodes; diff --git a/velox/exec/fuzzer/LocalRunnerService.cpp b/velox/exec/fuzzer/LocalRunnerService.cpp new file mode 100644 index 000000000000..2475f6af40ed --- /dev/null +++ b/velox/exec/fuzzer/LocalRunnerService.cpp @@ -0,0 +1,168 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/fuzzer/if/gen-cpp2/LocalRunnerService.h" +#include +#include +#include +#include +#include +#include +#include +#include "velox/common/base/Exceptions.h" +#include "velox/common/memory/ByteStream.h" +#include "velox/exec/fuzzer/LocalRunnerService.h" +#include "velox/exec/tests/utils/AssertQueryBuilder.h" +#include "velox/serializers/PrestoSerializer.h" + +using namespace facebook::velox; +using namespace facebook::velox::runner; + +namespace facebook::velox::runner { +namespace { + +class StdoutCapture { + public: + StdoutCapture() { + oldCoutBuf_ = std::cout.rdbuf(); + std::cout.rdbuf(buffer_.rdbuf()); + } + ~StdoutCapture() { + std::cout.rdbuf(oldCoutBuf_); + } + std::string str() { + return buffer_.str(); + } + + private: + std::stringstream buffer_; + std::streambuf* oldCoutBuf_; +}; + +std::pair execute( + const std::string& serializedPlan, + const std::string& queryId, + std::shared_ptr pool) { + StdoutCapture stdoutCapture; + + core::PlanNodePtr plan; + try { + folly::dynamic planJson = folly::parseJson(serializedPlan); + VLOG(1) << "Deserializing plan:\n" << serializedPlan; + plan = core::PlanNode::deserialize(planJson, pool.get()); + } catch (const std::exception& e) { + throw std::runtime_error( + fmt::format("Failed to deserialize plan: {}", e.what())); + } + VLOG(1) << "Deserialized plan:\n" << plan->toString(true, true); + + try { + exec::test::AssertQueryBuilder queryBuilder(plan); + queryBuilder.config("session_timezone", "America/Los_Angeles"); + + std::shared_ptr task; + auto results = queryBuilder.copyResults(pool.get(), task); + + return {results, stdoutCapture.str()}; + } catch (const std::exception& e) { + throw std::runtime_error( + fmt::format("Error executing query: {}", e.what())); + } +} + +} // namespace + +std::string serializeBatch( + const RowVectorPtr& rowVector, + memory::MemoryPool* pool) { + std::ostringstream out; + + OStreamOutputStream outputStream(&out); + + auto serde = std::make_unique(); + serializer::presto::PrestoVectorSerde::PrestoOptions options; + + auto serializer = serde->createBatchSerializer(pool, &options); + serializer->serialize(rowVector, &outputStream); + + return out.str(); +} + +std::vector convertToBatches( + const std::vector& rowVectors, + memory::MemoryPool* pool) { + std::vector results; + + if (rowVectors.empty()) { + return results; + } + + auto leafPool = pool->addLeafChild("batchSerialization"); + + for (const auto& rowVector : rowVectors) { + Batch result; + const auto& rowType = rowVector->type()->asRow(); + + for (auto i = 0; i < rowType.size(); ++i) { + result.columnNames()->push_back(rowType.nameOf(i)); + result.columnTypes()->push_back(rowType.childAt(i)->toString()); + } + + std::string serializedData = serializeBatch(rowVector, leafPool.get()); + result.serializedData() = std::move(serializedData); + + results.push_back(std::move(result)); + } + + return results; +} + +void LocalRunnerServiceHandler::execute( + ExecutePlanResponse& response, + std::unique_ptr request) { + VLOG(1) << "Received executePlan request"; + + auto rootPool = memory::memoryManager()->addRootPool(); + auto pool = rootPool->addLeafChild("localRunnerHandler"); + + RowVectorPtr results; + std::string output; + + try { + VLOG(1) << "Executing plan in service handler"; + std::tie(results, output) = + ::execute(*request->serializedPlan(), *request->queryId(), pool); + + VLOG(1) << fmt::format( + "Result:\nresult rowVector: {}\nstdout: {}", + results->toString(true), + output); + } catch (const std::exception& e) { + VLOG(1) << "Exception executing plan: " << e.what(); + response.success() = false; + response.errorMessage() = e.what(); + return; + } + + VLOG(1) << "Converting results to Thrift response"; + auto resultBatches = convertToBatches({results}, rootPool.get()); + response.results() = std::move(resultBatches); + response.output() = output; + response.success() = true; + VLOG(1) << "Response sent"; +} + +} // namespace facebook::velox::runner diff --git a/velox/exec/fuzzer/LocalRunnerService.h b/velox/exec/fuzzer/LocalRunnerService.h new file mode 100644 index 000000000000..dbd5f5d8f17f --- /dev/null +++ b/velox/exec/fuzzer/LocalRunnerService.h @@ -0,0 +1,58 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/// Thrift service implementation and library for executing Velox query plans +/// remotely. +/// +/// This file provides conversion utilities and a service handler for the +/// LocalRunnerService. It enables remote execution of serialized Velox +/// expression evaluation primarily used for fuzzing where query plans need to +/// be executed on remote workers. + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "velox/exec/fuzzer/if/gen-cpp2/LocalRunnerService.h" +#include "velox/vector/ComplexVector.h" + +namespace facebook::velox::runner { + +/// Converts a collection of Velox RowVectors into Thrift Batches using +/// binary serialization. +std::vector convertToBatches( + const std::vector& rowVectors, + memory::MemoryPool* pool); + +/// Thrift service handler for executing Velox query plans. +/// Executes a serialized Velox query plan. This method deserializes the plan +/// from JSON, configures execution, runs the query plan to completion, +/// converts results to Thrift Batches and captures any subsequent errors or +/// output. The method returns a Thrift response containing the results. +class LocalRunnerServiceHandler + : public apache::thrift::ServiceHandler { + public: + void execute( + ExecutePlanResponse& response, + std::unique_ptr request) override; +}; + +} // namespace facebook::velox::runner diff --git a/velox/exec/fuzzer/LocalRunnerServiceRunner.cpp b/velox/exec/fuzzer/LocalRunnerServiceRunner.cpp new file mode 100644 index 000000000000..7bb15f9d355a --- /dev/null +++ b/velox/exec/fuzzer/LocalRunnerServiceRunner.cpp @@ -0,0 +1,59 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include + +#include "velox/core/ITypedExpr.h" +#include "velox/core/PlanNode.h" +#include "velox/exec/fuzzer/LocalRunnerService.h" +#include "velox/functions/prestosql/registration/RegistrationFunctions.h" +#include "velox/type/Type.h" + +using namespace facebook::velox; +using namespace facebook::velox::runner; + +DEFINE_int32( + port, + 9091, + "LocalRunnerService port number to be used in conjunction with ExpressionFuzzerTest flag local_runner_port."); + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + + folly::Init init(&argc, &argv); + + memory::initializeMemoryManager(memory::MemoryManager::Options{}); + Type::registerSerDe(); + core::PlanNode::registerSerDe(); + core::ITypedExpr::registerSerDe(); + functions::prestosql::registerAllScalarFunctions(); + functions::prestosql::registerInternalFunctions(); + + std::shared_ptr thriftServer = + std::make_shared(); + thriftServer->setPort(FLAGS_port); + thriftServer->setInterface(std::make_shared()); + thriftServer->setNumIOWorkerThreads(1); + thriftServer->setNumCPUWorkerThreads(1); + + VLOG(1) << "Starting LocalRunnerService"; + thriftServer->serve(); + + return 0; +} diff --git a/velox/exec/fuzzer/MemoryArbitrationFuzzer.cpp b/velox/exec/fuzzer/MemoryArbitrationFuzzer.cpp index 4aaeee65b887..6d9988108b43 100644 --- a/velox/exec/fuzzer/MemoryArbitrationFuzzer.cpp +++ b/velox/exec/fuzzer/MemoryArbitrationFuzzer.cpp @@ -18,10 +18,10 @@ #include #include +#include #include "velox/common/file/FileSystems.h" #include "velox/common/file/tests/FaultyFileSystem.h" #include "velox/common/fuzzer/Utils.h" -#include "velox/common/memory/SharedArbitrator.h" #include "velox/connectors/hive/HiveConnector.h" #include "velox/dwio/dwrf/RegisterDwrfReader.h" // @manual #include "velox/dwio/dwrf/RegisterDwrfWriter.h" // @manual @@ -31,7 +31,6 @@ #include "velox/exec/tests/utils/AssertQueryBuilder.h" #include "velox/exec/tests/utils/PlanBuilder.h" #include "velox/exec/tests/utils/TempDirectoryPath.h" -#include "velox/functions/sparksql/aggregates/Register.h" #include "velox/serializers/CompactRowSerializer.h" #include "velox/serializers/PrestoSerializer.h" #include "velox/serializers/UnsafeRowSerializer.h" @@ -78,14 +77,19 @@ DEFINE_int32( "Each second, the percentage chance of triggering global arbitration by " "calling shrinking pools globally."); +DEFINE_double( + spillable_query_ratio, + 0.7, + "The ratio of queries that are spillable."); + DEFINE_double( spill_faulty_fs_ratio, - 0.1, + 0.2, "Chance of spill filesystem being faulty(expressed as double from 0 to 1)"); -DEFINE_int32( +DEFINE_double( spill_fs_fault_injection_ratio, - 0.01, + 0.02, "The chance of actually injecting fault in file operations for spill " "filesystem. This is only applicable when 'spill_faulty_fs_ratio' is " "larger than 0"); @@ -96,9 +100,15 @@ DEFINE_int32( "After each specified number of milliseconds, abort a random task." "If given 0, no task will be aborted."); +DEFINE_string( + plan_type, + "all", + "Type of plans to test. Options: all, hash_join, aggregate, " + "row_number, topn_row_number, order_by."); + using namespace facebook::velox::tests::utils; -namespace facebook::velox::exec::test { +namespace facebook::velox::exec { namespace { using fuzzer::coinToss; @@ -147,7 +157,7 @@ class MemoryArbitrationFuzzer { return boost::random::uniform_int_distribution(min, max)(rng_); } - std::shared_ptr maybeGenerateFaultySpillDirectory(); + std::shared_ptr maybeGenerateFaultySpillDirectory(); // Returns a list of randomly generated key types for join and aggregation. std::vector generateKeyTypes(int32_t numKeys); @@ -158,7 +168,8 @@ class MemoryArbitrationFuzzer { // Returns randomly generated input with up to 3 additional payload columns. std::vector generateInput( const std::vector& keyNames, - const std::vector& keyTypes); + const std::vector& keyTypes, + int32_t minPayload = 0); // Reuses the 'generateInput' method to return randomly generated // probe input. @@ -178,6 +189,12 @@ class MemoryArbitrationFuzzer { const std::vector& keyNames, const std::vector& keyTypes); + // Reuses the 'generateInput' method to return randomly generated + // topN row number input. + std::vector generateTopNRowNumberInput( + const std::vector& keyNames, + const std::vector& keyTypes); + // Reuses the 'generateInput' method to return randomly generated // order by input. std::vector generateOrderByInput( @@ -207,6 +224,8 @@ class MemoryArbitrationFuzzer { std::vector rowNumberPlans(const std::string& tableDir); + std::vector topNRowNumberPlans(const std::string& tableDir); + std::vector orderByPlans(const std::string& tableDir); // Helper method that combines all above plan methods into one. @@ -223,6 +242,9 @@ class MemoryArbitrationFuzzer { return opts; } + std::string extractQueryIdFromSpillPath(const std::string& spillPath); + + const std::string kQueryIdPrefix = "query_id_"; FuzzerGenerator rng_; size_t currentSeed_{0}; std::unordered_map queryConfigsWithSpill_{ @@ -231,6 +253,7 @@ class MemoryArbitrationFuzzer { {core::QueryConfig::kSpillStartPartitionBit, "29"}, {core::QueryConfig::kAggregationSpillEnabled, "true"}, {core::QueryConfig::kRowNumberSpillEnabled, "true"}, + {core::QueryConfig::kTopNRowNumberSpillEnabled, "true"}, {core::QueryConfig::kOrderBySpillEnabled, "true"}, }; @@ -240,7 +263,7 @@ class MemoryArbitrationFuzzer { memory::kMaxMemory, memory::MemoryReclaimer::create())}; std::shared_ptr pool_{ - memory::memoryManager()->testingDefaultRoot().addLeafChild( + memory::memoryManager()->deprecatedSysRootPool().addLeafChild( "memoryArbitrationFuzzerLeaf", true)}; std::shared_ptr writerPool_{rootPool_->addAggregateChild( @@ -250,12 +273,16 @@ class MemoryArbitrationFuzzer { VectorFuzzer vectorFuzzer_; std::shared_ptr executor_{ std::make_shared( - std::thread::hardware_concurrency())}; + folly::hardware_concurrency())}; folly::Synchronized stats_; }; MemoryArbitrationFuzzer::MemoryArbitrationFuzzer(size_t initialSeed) : vectorFuzzer_{getFuzzerOptions(), pool_.get()} { + // Set timestamp precision as milliseconds, as timestamp may be used as + // paritition key, and presto doesn't supports nanosecond precision. + vectorFuzzer_.getMutableOptions().timestampPrecision = + fuzzer::FuzzerTimestampPrecision::kMilliSeconds; if (!isRegisteredNamedVectorSerde(VectorSerde::Kind::kPresto)) { serializer::presto::PrestoVectorSerde::registerNamedVectorSerde(); } @@ -268,14 +295,11 @@ MemoryArbitrationFuzzer::MemoryArbitrationFuzzer(size_t initialSeed) // Make sure not to run out of open file descriptors. std::unordered_map hiveConfig = { {connector::hive::HiveConfig::kNumCacheFileHandles, "1000"}}; - connector::registerConnectorFactory( - std::make_shared()); - const auto hiveConnector = - connector::getConnectorFactory( - connector::hive::HiveConnectorFactory::kHiveConnectorName) - ->newConnector( - kHiveConnectorId, - std::make_shared(std::move(hiveConfig))); + + connector::hive::HiveConnectorFactory hiveFactory; + const auto hiveConnector = hiveFactory.newConnector( + test::kHiveConnectorId, + std::make_shared(std::move(hiveConfig))); connector::registerConnector(hiveConnector); dwrf::registerDwrfReaderFactory(); dwrf::registerDwrfWriterFactory(); @@ -318,7 +342,8 @@ MemoryArbitrationFuzzer::generatePartitionKeys() { std::vector MemoryArbitrationFuzzer::generateInput( const std::vector& keyNames, - const std::vector& keyTypes) { + const std::vector& keyTypes, + int32_t minPayload) { std::vector names = keyNames; std::vector types = keyTypes; @@ -330,8 +355,7 @@ std::vector MemoryArbitrationFuzzer::generateInput( } } - // Add up to 3 payload columns. - const auto numPayload = randInt(0, 3); + const auto numPayload = randInt(minPayload, 3); for (auto i = 0; i < numPayload; ++i) { names.push_back(fmt::format("tp{}", i + keyNames.size())); types.push_back(vectorFuzzer_.randType(2 /*maxDepth*/)); @@ -426,6 +450,12 @@ std::vector MemoryArbitrationFuzzer::generateRowNumberInput( return generateInput(keyNames, keyTypes); } +std::vector MemoryArbitrationFuzzer::generateTopNRowNumberInput( + const std::vector& keyNames, + const std::vector& keyTypes) { + return generateInput(keyNames, keyTypes, 1); +} + std::vector MemoryArbitrationFuzzer::generateOrderByInput( const std::vector& keyNames, const std::vector& keyTypes) { @@ -445,7 +475,7 @@ MemoryArbitrationFuzzer::hashJoinPlans( (core::isLeftSemiProjectJoin(joinType) || core::isLeftSemiFilterJoin(joinType) || core::isAntiJoin(joinType)) ? asRowType(probeInput[0]->type())->names() - : concat( + : test::concat( asRowType(probeInput[0]->type()), asRowType(buildInput[0]->type())) ->names(); @@ -456,22 +486,23 @@ MemoryArbitrationFuzzer::hashJoinPlans( std::vector plans; auto planNodeIdGenerator = std::make_shared(); - auto plan = - PlanBuilder(planNodeIdGenerator) - .values(probeInput) - .hashJoin( - probeKeys, - buildKeys, - PlanBuilder(planNodeIdGenerator).values(buildInput).planNode(), - /*filter=*/"", - outputColumns, - joinType, - false) - .planNode(); + auto plan = test::PlanBuilder(planNodeIdGenerator) + .values(probeInput) + .hashJoin( + probeKeys, + buildKeys, + test::PlanBuilder(planNodeIdGenerator) + .values(buildInput) + .planNode(), + /*filter=*/"", + outputColumns, + joinType, + false) + .planNode(); plans.push_back(PlanWithSplits{std::move(plan), {}}); - if (!isTableScanSupported(probeInput[0]->type()) || - !isTableScanSupported(buildInput[0]->type())) { + if (!test::isTableScanSupported(probeInput[0]->type()) || + !test::isTableScanSupported(buildInput[0]->type())) { return plans; } @@ -480,13 +511,13 @@ MemoryArbitrationFuzzer::hashJoinPlans( const auto buildType = asRowType(buildInput[0]->type()); core::PlanNodeId probeScanId; core::PlanNodeId buildScanId; - plan = PlanBuilder(planNodeIdGenerator) + plan = test::PlanBuilder(planNodeIdGenerator) .tableScan(probeType) .capturePlanNodeId(probeScanId) .hashJoin( probeKeys, buildKeys, - PlanBuilder(planNodeIdGenerator) + test::PlanBuilder(planNodeIdGenerator) .tableScan(buildType) .capturePlanNodeId(buildScanId) .planNode(), @@ -495,9 +526,10 @@ MemoryArbitrationFuzzer::hashJoinPlans( joinType, false) .planNode(); - plans.push_back(PlanWithSplits{ - std::move(plan), - {{probeScanId, probeSplits}, {buildScanId, buildSplits}}}); + plans.push_back( + PlanWithSplits{ + std::move(plan), + {{probeScanId, probeSplits}, {buildScanId, buildSplits}}}); return plans; } @@ -513,14 +545,14 @@ MemoryArbitrationFuzzer::hashJoinPlans(const std::string& tableDir) { const auto numKeys = randInt(1, 5); const std::vector keyTypes = generateKeyTypes(numKeys); - std::vector probeKeys = makeNames("t", keyTypes.size()); - std::vector buildKeys = makeNames("u", keyTypes.size()); + std::vector probeKeys = test::makeNames("t", keyTypes.size()); + std::vector buildKeys = test::makeNames("u", keyTypes.size()); const auto probeInput = generateProbeInput(probeKeys, keyTypes); const auto buildInput = generateBuildInput(probeInput, probeKeys, buildKeys); - const std::vector probeScanSplits = - makeSplits(probeInput, fmt::format("{}/probe", tableDir), writerPool_); - const std::vector buildScanSplits = - makeSplits(buildInput, fmt::format("{}/build", tableDir), writerPool_); + const std::vector probeScanSplits = test::makeSplits( + probeInput, fmt::format("{}/probe", tableDir), writerPool_); + const std::vector buildScanSplits = test::makeSplits( + buildInput, fmt::format("{}/build", tableDir), writerPool_); std::vector totalPlans; for (const auto& joinType : kJoinTypes) { @@ -545,10 +577,11 @@ MemoryArbitrationFuzzer::aggregatePlans(const std::string& tableDir) { const auto numKeys = randInt(1, 5); // Reuse the hash join utilities to generate aggregation keys and inputs. const std::vector keyTypes = generateKeyTypes(numKeys); - const std::vector groupingKeys = makeNames("g", keyTypes.size()); + const std::vector groupingKeys = + test::makeNames("g", keyTypes.size()); const auto aggregateInput = generateAggregateInput(groupingKeys, keyTypes); const std::vector aggregates{"count(1)"}; - const std::vector splits = makeSplits( + const std::vector splits = test::makeSplits( aggregateInput, fmt::format("{}/aggregate", tableDir), writerPool_); std::vector plans; @@ -559,7 +592,7 @@ MemoryArbitrationFuzzer::aggregatePlans(const std::string& tableDir) { std::make_shared(); core::PlanNodeId scanId; auto plan = PlanWithSplits{ - PlanBuilder(planNodeIdGenerator) + test::PlanBuilder(planNodeIdGenerator) .tableScan(inputRowType) .capturePlanNodeId(scanId) .singleAggregation(groupingKeys, aggregates, {}) @@ -568,7 +601,7 @@ MemoryArbitrationFuzzer::aggregatePlans(const std::string& tableDir) { plans.push_back(std::move(plan)); plan = PlanWithSplits{ - PlanBuilder() + test::PlanBuilder() .values(aggregateInput) .singleAggregation(groupingKeys, aggregates, {}) .planNode(), @@ -582,7 +615,7 @@ MemoryArbitrationFuzzer::aggregatePlans(const std::string& tableDir) { std::make_shared(); core::PlanNodeId scanId; auto plan = PlanWithSplits{ - PlanBuilder(planNodeIdGenerator) + test::PlanBuilder(planNodeIdGenerator) .tableScan(inputRowType) .capturePlanNodeId(scanId) .partialAggregation(groupingKeys, aggregates, {}) @@ -592,7 +625,7 @@ MemoryArbitrationFuzzer::aggregatePlans(const std::string& tableDir) { plans.push_back(std::move(plan)); plan = PlanWithSplits{ - PlanBuilder() + test::PlanBuilder() .values(aggregateInput) .partialAggregation(groupingKeys, aggregates, {}) .finalAggregation() @@ -607,7 +640,7 @@ MemoryArbitrationFuzzer::aggregatePlans(const std::string& tableDir) { std::make_shared(); core::PlanNodeId scanId; auto plan = PlanWithSplits{ - PlanBuilder(planNodeIdGenerator) + test::PlanBuilder(planNodeIdGenerator) .tableScan(inputRowType) .capturePlanNodeId(scanId) .partialAggregation(groupingKeys, aggregates, {}) @@ -618,7 +651,7 @@ MemoryArbitrationFuzzer::aggregatePlans(const std::string& tableDir) { plans.push_back(std::move(plan)); plan = PlanWithSplits{ - PlanBuilder() + test::PlanBuilder() .values(aggregateInput) .partialAggregation(groupingKeys, aggregates, {}) .intermediateAggregation() @@ -641,7 +674,7 @@ MemoryArbitrationFuzzer::rowNumberPlans(const std::string& tableDir) { std::vector projectFields = keyNames; projectFields.emplace_back("row_number"); auto plan = PlanWithSplits{ - PlanBuilder() + test::PlanBuilder() .values(input) .rowNumber(keyNames) .project(projectFields) @@ -649,17 +682,17 @@ MemoryArbitrationFuzzer::rowNumberPlans(const std::string& tableDir) { {}}; plans.push_back(std::move(plan)); - if (!isTableScanSupported(input[0]->type())) { + if (!test::isTableScanSupported(input[0]->type())) { return plans; } - const std::vector splits = - makeSplits(input, fmt::format("{}/row_number", tableDir), writerPool_); + const std::vector splits = test::makeSplits( + input, fmt::format("{}/row_number", tableDir), writerPool_); auto planNodeIdGenerator = std::make_shared(); core::PlanNodeId scanId; plan = PlanWithSplits{ - PlanBuilder(planNodeIdGenerator) + test::PlanBuilder(planNodeIdGenerator) .tableScan(asRowType(input[0]->type())) .capturePlanNodeId(scanId) .rowNumber(keyNames) @@ -671,6 +704,90 @@ MemoryArbitrationFuzzer::rowNumberPlans(const std::string& tableDir) { return plans; } +std::vector +MemoryArbitrationFuzzer::topNRowNumberPlans(const std::string& tableDir) { + static const std::vector kRankFunctions = { + "row_number", "rank", "dense_rank"}; + + const auto [keyNames, keyTypes] = generatePartitionKeys(); + const auto input = generateTopNRowNumberInput(keyNames, keyTypes); + + std::vector plans; + + const auto inputType = asRowType(input[0]->type()); + std::vector sortingKeys; + + std::unordered_set partitionKeySet( + keyNames.begin(), keyNames.end()); + for (const auto& name : inputType->names()) { + if (partitionKeySet.find(name) == partitionKeySet.end()) { + sortingKeys.push_back(name); + } + } + + const auto numSortingKeys = randInt(1, sortingKeys.size()); + sortingKeys.resize(numSortingKeys); + + const auto rankFunction = + kRankFunctions[randInt(0, kRankFunctions.size() - 1)]; + const auto limit = randInt(1, 100); + const bool generateRowNumber = vectorFuzzer_.coinToss(0.5); + + std::vector projectFields = keyNames; + if (generateRowNumber) { + projectFields.emplace_back("row_number"); + } + + // Values plan with Partiton Keys + auto plan = PlanWithSplits{ + test::PlanBuilder() + .values(input) + .topNRank( + rankFunction, keyNames, sortingKeys, limit, generateRowNumber) + .project(projectFields) + .planNode(), + {}}; + plans.push_back(std::move(plan)); + + if (!test::isTableScanSupported(input[0]->type())) { + return plans; + } + + const std::vector splits = test::makeSplits( + input, fmt::format("{}/topn_row_number", tableDir), writerPool_); + + auto planNodeIdGenerator = std::make_shared(); + core::PlanNodeId scanId; + // TableScan Plan with Parition Keys + plan = PlanWithSplits{ + test::PlanBuilder(planNodeIdGenerator) + .tableScan(asRowType(input[0]->type())) + .capturePlanNodeId(scanId) + .topNRank( + rankFunction, keyNames, sortingKeys, limit, generateRowNumber) + .project(projectFields) + .planNode(), + {{scanId, splits}}}; + plans.push_back(std::move(plan)); + + std::vector globalProjectFields; + if (generateRowNumber) { + globalProjectFields.emplace_back("row_number"); + } + + // Global TopN + plan = PlanWithSplits{ + test::PlanBuilder() + .values(input) + .topNRank(rankFunction, {}, sortingKeys, limit, generateRowNumber) + .project(globalProjectFields) + .planNode(), + {}}; + plans.push_back(std::move(plan)); + + return plans; +} + std::vector MemoryArbitrationFuzzer::orderByPlans(const std::string& tableDir) { const auto [keyNames, keyTypes] = generatePartitionKeys(); @@ -679,20 +796,21 @@ MemoryArbitrationFuzzer::orderByPlans(const std::string& tableDir) { std::vector plans; auto plan = PlanWithSplits{ - PlanBuilder().values(input).orderBy(keyNames, false).planNode(), {}}; + test::PlanBuilder().values(input).orderBy(keyNames, false).planNode(), + {}}; plans.push_back(std::move(plan)); - if (!isTableScanSupported(input[0]->type())) { + if (!test::isTableScanSupported(input[0]->type())) { return plans; } - const std::vector splits = - makeSplits(input, fmt::format("{}/order_by", tableDir), writerPool_); + const std::vector splits = test::makeSplits( + input, fmt::format("{}/order_by", tableDir), writerPool_); auto planNodeIdGenerator = std::make_shared(); core::PlanNodeId scanId; plan = PlanWithSplits{ - PlanBuilder(std::move(planNodeIdGenerator)) + test::PlanBuilder(std::move(planNodeIdGenerator)) .tableScan(asRowType(input[0]->type())) .capturePlanNodeId(scanId) .orderBy(keyNames, false) @@ -706,33 +824,62 @@ MemoryArbitrationFuzzer::orderByPlans(const std::string& tableDir) { std::vector MemoryArbitrationFuzzer::allPlans(const std::string& tableDir) { std::vector plans; - for (const auto& plan : hashJoinPlans(tableDir)) { - plans.push_back(plan); - } - for (const auto& plan : aggregatePlans(tableDir)) { - plans.push_back(plan); - } - for (const auto& plan : rowNumberPlans(tableDir)) { - plans.push_back(plan); - } - for (const auto& plan : orderByPlans(tableDir)) { - plans.push_back(plan); - } + const std::string planType = FLAGS_plan_type; + + auto appendPlansIf = + [&](const std::string& type, + std::function(const std::string&)> + planGenerator) { + if (planType == "all" || planType == type) { + auto newPlans = planGenerator(tableDir); + plans.insert( + plans.end(), + std::make_move_iterator(newPlans.begin()), + std::make_move_iterator(newPlans.end())); + } + }; + appendPlansIf("hash_join", [this](const std::string& dir) { + return hashJoinPlans(dir); + }); + appendPlansIf("aggregate", [this](const std::string& dir) { + return aggregatePlans(dir); + }); + appendPlansIf("row_number", [this](const std::string& dir) { + return rowNumberPlans(dir); + }); + appendPlansIf("topn_row_number", [this](const std::string& dir) { + return topNRowNumberPlans(dir); + }); + appendPlansIf( + "order_by", [this](const std::string& dir) { return orderByPlans(dir); }); + + VELOX_USER_CHECK( + !plans.empty(), + "No plans generated for plan_type: {}. Valid options are: all, hash_join, aggregate, row_number, topn_row_number, order_by", + planType); + return plans; } -struct ThreadLocalStats { - uint64_t spillFsFaultCount{0}; -}; +std::string MemoryArbitrationFuzzer::extractQueryIdFromSpillPath( + const std::string& spillPath) { + std::vector parts; + folly::split('/', spillPath, parts); + for (const auto& part : parts) { + if (part.starts_with(kQueryIdPrefix)) { + return part; + } + } + VELOX_FAIL("No query id found in spill path: {}", spillPath); +} // Stats that keeps track of per thread execution status in verify() -thread_local ThreadLocalStats threadLocalStats; +folly::ConcurrentHashMap spillFsTaskSet; -std::shared_ptr +std::shared_ptr MemoryArbitrationFuzzer::maybeGenerateFaultySpillDirectory() { FuzzerGenerator fsRng(rng_()); - const auto injectFsFault = - coinToss(fsRng, FLAGS_spill_fs_fault_injection_ratio); + const auto injectFsFault = coinToss(fsRng, FLAGS_spill_faulty_fs_ratio); if (!injectFsFault) { return exec::test::TempDirectoryPath::create(false); } @@ -756,10 +903,14 @@ MemoryArbitrationFuzzer::maybeGenerateFaultySpillDirectory() { } FuzzerGenerator fsRng(rng_()); if (coinToss(fsRng, FLAGS_spill_fs_fault_injection_ratio)) { - ++threadLocalStats.spillFsFaultCount; + auto queryId = extractQueryIdFromSpillPath(op->path); + spillFsTaskSet.insert(queryId, folly::Unit()); VELOX_FAIL( - "Fault file injection on {}", - FaultFileOperation::typeString(op->type)); + "Fault file injection on {} of query {} path {}", + FaultFileOperation::typeString(op->type), + queryId, + op->path, + process::StackTrace().toString()); } }); return directory; @@ -772,7 +923,7 @@ void MemoryArbitrationFuzzer::verify() { auto plans = allPlans(tableScanDir->getPath()); SCOPE_EXIT { - waitForAllTasksToBeDeleted(); + test::waitForAllTasksToBeDeleted(); if (auto faultyFileSystem = std::dynamic_pointer_cast( filesystems::getFileSystem(spillDirectory->getPath(), nullptr))) { faultyFileSystem->clearFileFaultInjections(); @@ -792,36 +943,34 @@ void MemoryArbitrationFuzzer::verify() { queryThreads.emplace_back([&, spillDirectory, i, seed]() { FuzzerGenerator rng(seed); while (!stop) { - const auto prevSpillFsFaultCount = threadLocalStats.spillFsFaultCount; - const auto queryId = fmt::format("query_id_{}", queryCount++); + const auto queryId = fmt::format("{}{}", kQueryIdPrefix, queryCount++); queryTaskAbortRequestMap.insert(queryId, false); try { - const auto queryCtx = newQueryCtx( + const auto queryCtx = test::newQueryCtx( memory::memoryManager(), executor_.get(), FLAGS_arbitrator_capacity, queryId); const auto plan = plans.at(getRandomIndex(rng, plans.size() - 1)); - AssertQueryBuilder builder(plan.plan); + test::AssertQueryBuilder builder(plan.plan); builder.queryCtx(queryCtx); for (const auto& [planNodeId, nodeSplits] : plan.splits) { builder.splits(planNodeId, nodeSplits); } - if (coinToss(rng, 0.3)) { - builder.queryCtx(queryCtx).copyResults(pool_.get()); + if (coinToss(rng, FLAGS_spillable_query_ratio)) { + auto res = builder.configs(queryConfigsWithSpill_) + .spillDirectory( + spillDirectory->getPath() + + fmt::format("/{}/{}", i, queryId)) + .queryCtx(queryCtx) + .copyResults(pool_.get()); } else { - auto res = - builder.configs(queryConfigsWithSpill_) - .spillDirectory( - spillDirectory->getPath() + fmt::format("/{}/", i)) - .queryCtx(queryCtx) - .copyResults(pool_.get()); + builder.queryCtx(queryCtx).copyResults(pool_.get()); } ++stats_.wlock()->successCount; - VELOX_CHECK_EQ( - threadLocalStats.spillFsFaultCount, prevSpillFsFaultCount); + VELOX_CHECK(spillFsTaskSet.find(queryId) == spillFsTaskSet.end()); } catch (const VeloxException& e) { auto lockedStats = stats_.wlock(); if (e.errorCode() == error_code::kMemCapExceeded.c_str()) { @@ -830,9 +979,28 @@ void MemoryArbitrationFuzzer::verify() { ++lockedStats->abortCount; } else if (e.errorCode() == error_code::kInvalidState.c_str()) { const auto injectedSpillFsFault = - threadLocalStats.spillFsFaultCount > prevSpillFsFaultCount; + spillFsTaskSet.find(queryId) != spillFsTaskSet.end(); + if (injectedSpillFsFault) { + spillFsTaskSet.erase(queryId); + } const auto injectedTaskAbortRequest = queryTaskAbortRequestMap.find(queryId)->second; + + // Debug logging to understand the failure + if (!injectedSpillFsFault && !injectedTaskAbortRequest) { + LOG(ERROR) << "============== VELOX_CHECK failure debug info:"; + LOG(ERROR) << " queryId: " << queryId; + LOG(ERROR) << " spillFsTaskSet size: " << spillFsTaskSet.size(); + LOG(ERROR) << " spillFsTaskSet contents:"; + // Iterate through spillFsTaskSet to log contents + for (auto it = spillFsTaskSet.cbegin(); + it != spillFsTaskSet.cend(); + ++it) { + LOG(ERROR) << " key: " << it->first; + } + LOG(ERROR) << " error message: " << e.message(); + } + VELOX_CHECK( injectedSpillFsFault || injectedTaskAbortRequest, "injectedSpillFsFault: {}, injectedTaskAbortRequest: {}, error message: {}", @@ -947,4 +1115,4 @@ void MemoryArbitrationFuzzer::go() { void memoryArbitrationFuzzer(size_t seed) { MemoryArbitrationFuzzer(seed).go(); } -} // namespace facebook::velox::exec::test +} // namespace facebook::velox::exec diff --git a/velox/exec/fuzzer/MemoryArbitrationFuzzer.h b/velox/exec/fuzzer/MemoryArbitrationFuzzer.h index 73f32e2215d0..622d27084ae8 100644 --- a/velox/exec/fuzzer/MemoryArbitrationFuzzer.h +++ b/velox/exec/fuzzer/MemoryArbitrationFuzzer.h @@ -17,6 +17,6 @@ #include -namespace facebook::velox::exec::test { +namespace facebook::velox::exec { void memoryArbitrationFuzzer(size_t seed); } diff --git a/velox/exec/tests/MemoryArbitrationFuzzerTest.cpp b/velox/exec/fuzzer/MemoryArbitrationFuzzerRunner.cpp similarity index 65% rename from velox/exec/tests/MemoryArbitrationFuzzerTest.cpp rename to velox/exec/fuzzer/MemoryArbitrationFuzzerRunner.cpp index 1ad0247fb887..cfa1e928c1d1 100644 --- a/velox/exec/tests/MemoryArbitrationFuzzerTest.cpp +++ b/velox/exec/fuzzer/MemoryArbitrationFuzzerRunner.cpp @@ -16,13 +16,16 @@ #include #include -#include #include + +#include "velox/common/file/tests/FaultyFileSystem.h" +#include "velox/functions/prestosql/aggregates/RegisterAggregateFunctions.h" +#include "velox/functions/prestosql/registration/RegistrationFunctions.h" + #include "velox/common/memory/SharedArbitrator.h" #include "velox/connectors/hive/HiveConnector.h" -#include "velox/exec/MemoryReclaimer.h" #include "velox/exec/fuzzer/FuzzerUtil.h" -#include "velox/exec/fuzzer/MemoryArbitrationFuzzerRunner.h" +#include "velox/exec/fuzzer/MemoryArbitrationFuzzer.h" #include "velox/exec/fuzzer/PrestoQueryRunner.h" #include "velox/exec/fuzzer/ReferenceQueryRunner.h" @@ -39,13 +42,17 @@ DEFINE_int64( using namespace facebook::velox::exec; int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - // Calls common init functions in the necessary order, initializing - // singletons, installing proper signal handlers for better debugging + // singletons, installing proper signal handlers for a better debugging // experience, and initialize glog and gflags. folly::Init init(&argc, &argv); test::setupMemory(FLAGS_allocator_capacity, FLAGS_arbitrator_capacity); - const size_t initialSeed = FLAGS_seed == 0 ? std::time(nullptr) : FLAGS_seed; - return test::MemoryArbitrationFuzzerRunner::run(initialSeed); + const size_t seed = FLAGS_seed == 0 ? std::time(nullptr) : FLAGS_seed; + + facebook::velox::serializer::presto::PrestoVectorSerde::registerVectorSerde(); + facebook::velox::filesystems::registerLocalFileSystem(); + facebook::velox::tests::utils::registerFaultyFileSystem(); + facebook::velox::functions::prestosql::registerAllScalarFunctions(); + facebook::velox::aggregate::prestosql::registerAllAggregateFunctions(); + memoryArbitrationFuzzer(seed); } diff --git a/velox/exec/fuzzer/MemoryArbitrationFuzzerRunner.h b/velox/exec/fuzzer/MemoryArbitrationFuzzerRunner.h deleted file mode 100644 index dfe8144bb21e..000000000000 --- a/velox/exec/fuzzer/MemoryArbitrationFuzzerRunner.h +++ /dev/null @@ -1,43 +0,0 @@ -/* - * Copyright (c) Facebook, Inc. and its affiliates. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include - -#include "velox/common/file/FileSystems.h" - -#include "velox/common/file/tests/FaultyFileSystem.h" -#include "velox/exec/fuzzer/MemoryArbitrationFuzzer.h" -#include "velox/functions/prestosql/aggregates/RegisterAggregateFunctions.h" -#include "velox/functions/prestosql/registration/RegistrationFunctions.h" -#include "velox/serializers/PrestoSerializer.h" - -namespace facebook::velox::exec::test { - -class MemoryArbitrationFuzzerRunner { - public: - static int run(size_t seed) { - serializer::presto::PrestoVectorSerde::registerVectorSerde(); - filesystems::registerLocalFileSystem(); - tests::utils::registerFaultyFileSystem(); - functions::prestosql::registerAllScalarFunctions(); - aggregate::prestosql::registerAllAggregateFunctions(); - memoryArbitrationFuzzer(seed); - return RUN_ALL_TESTS(); - } -}; - -} // namespace facebook::velox::exec::test diff --git a/velox/exec/fuzzer/PrestoQueryRunner.cpp b/velox/exec/fuzzer/PrestoQueryRunner.cpp index b545235b683e..a2d10ab3f845 100644 --- a/velox/exec/fuzzer/PrestoQueryRunner.cpp +++ b/velox/exec/fuzzer/PrestoQueryRunner.cpp @@ -14,9 +14,10 @@ * limitations under the License. */ -#include // @manual +#include #include #include +#include #include #include "velox/common/base/Fs.h" @@ -32,19 +33,35 @@ #include "velox/exec/fuzzer/PrestoQueryRunnerToSqlPlanNodeVisitor.h" #include "velox/exec/fuzzer/PrestoSql.h" #include "velox/exec/tests/utils/QueryAssertions.h" +#include "velox/functions/prestosql/types/BingTileType.h" +#include "velox/functions/prestosql/types/GeometryType.h" +#include "velox/functions/prestosql/types/HyperLogLogType.h" #include "velox/functions/prestosql/types/IPAddressType.h" #include "velox/functions/prestosql/types/IPPrefixType.h" #include "velox/functions/prestosql/types/JsonType.h" +#include "velox/functions/prestosql/types/KHyperLogLogType.h" +#include "velox/functions/prestosql/types/QDigestType.h" +#include "velox/functions/prestosql/types/SetDigestType.h" +#include "velox/functions/prestosql/types/SfmSketchType.h" +#include "velox/functions/prestosql/types/TDigestType.h" +#include "velox/functions/prestosql/types/TimeWithTimezoneType.h" #include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h" #include "velox/functions/prestosql/types/UuidType.h" +#include "velox/functions/prestosql/types/parser/TypeParser.h" #include "velox/serializers/PrestoSerializer.h" -#include "velox/type/parser/TypeParser.h" using namespace facebook::velox; namespace facebook::velox::exec::test { namespace { +static size_t +writeFunction(char* data, size_t size, size_t nmemb, void* userdata) { + std::string* response = static_cast(userdata); + response->append(data, size * nmemb); + return size * nmemb; +} + void writeToFile( const std::string& path, const std::vector& data, @@ -126,7 +143,9 @@ class ServerResponse { std::vector types; for (const auto& column : response_["columns"]) { names.push_back(column["name"].asString()); - types.push_back(parseType(column["type"].asString())); + types.push_back( + facebook::velox::functions::prestosql::parseType( + column["type"].asString())); } auto rowType = ROW(std::move(names), std::move(types)); @@ -191,7 +210,8 @@ const std::vector& PrestoQueryRunner::supportedScalarTypes() const { // static bool PrestoQueryRunner::isSupportedDwrfType(const TypePtr& type) { - if (type->isDate() || type->isIntervalDayTime() || type->isUnKnown()) { + if (type->isDate() || type->isIntervalDayTime() || type->isUnKnown() || + isGeometryType(type)) { return false; } @@ -222,10 +242,10 @@ PrestoQueryRunner::inputProjections( // unchanged and the projection is just an identity mapping. if (isIntermediateOnlyType(childType)) { for (int batchIndex = 0; batchIndex < input.size(); batchIndex++) { - children[batchIndex].push_back(transformIntermediateOnlyType( - input[batchIndex]->childAt(childIndex))); + children[batchIndex].push_back( + transformIntermediateTypes(input[batchIndex]->childAt(childIndex))); } - projections.push_back(getIntermediateOnlyTypeProjectionExpr( + projections.push_back(getProjectionsToIntermediateTypes( childType, std::make_shared( names[childIndex], names[childIndex]), @@ -235,8 +255,9 @@ PrestoQueryRunner::inputProjections( children[batchIndex].push_back(input[batchIndex]->childAt(childIndex)); } - projections.push_back(std::make_shared( - names[childIndex], names[childIndex])); + projections.push_back( + std::make_shared( + names[childIndex], names[childIndex])); } } @@ -250,12 +271,13 @@ PrestoQueryRunner::inputProjections( std::vector output; output.reserve(input.size()); for (int batchIndex = 0; batchIndex < input.size(); batchIndex++) { - output.push_back(std::make_shared( - input[batchIndex]->pool(), - rowType, - input[batchIndex]->nulls(), - input[batchIndex]->size(), - std::move(children[batchIndex]))); + output.push_back( + std::make_shared( + input[batchIndex]->pool(), + rowType, + input[batchIndex]->nulls(), + input[batchIndex]->size(), + std::move(children[batchIndex]))); } return std::make_pair(output, projections); @@ -301,27 +323,45 @@ bool PrestoQueryRunner::isConstantExprSupported( return type->isPrimitiveType() && !type->isTimestamp() && !isJsonType(type) && !type->isIntervalDayTime() && !isIPAddressType(type) && !isIPPrefixType(type) && !isUuidType(type) && - !isTimestampWithTimeZoneType(type); + !isTimestampWithTimeZoneType(type) && !isHyperLogLogType(type) && + !isKHyperLogLogType(type) && !isTDigestType(type) && + !isQDigestType(type) && !isSetDigestType(type) && + !isBingTileType(type) && !isSfmSketchType(type) && + !isTimeWithTimeZone(type); } return true; } bool PrestoQueryRunner::isSupported(const exec::FunctionSignature& signature) { - // TODO: support queries with these types. Among the types below, hugeint is - // not a native type in Presto, so fuzzer should not use it as the type of - // cast-to or constant literals. Hyperloglog and TDigest can only be casted - // from varbinary and cannot be used as the type of constant literals. - // Interval year to month can only be casted from NULL and cannot be used as - // the type of constant literals. Json, Ipaddress, Ipprefix, and UUID require - // special handling, because Presto requires literals of these types to be - // valid, and doesn't allow creating HIVE columns of these types. + // TODO: support queries with these types. + // Types not supported by PrestoQueryRunner and their reasons: + // + // hugeint: + // - Not a native type in Presto + // - Fuzzer should not use it for cast-to or constant literals + // + // interval year to month: + // - Can only be casted from NULL + // - Cannot be used as constant literal types + // + // ipaddress, ipprefix, uuid: + // - Require special handling in Presto + // - Presto requires literals of these types to be valid + // - Cannot create HIVE columns of these types + // + // geometry: + // - Under development in Presto + // - Cannot be used as constant literals + // - Expected differences between Presto Java and Velox C++ implementations + // + // p4hyperloglog: + // - Not a native type in Presto + // - Cannot create HIVE columns of these types return !( - usesTypeName(signature, "bingtile") || usesTypeName(signature, "interval year to month") || usesTypeName(signature, "hugeint") || - usesTypeName(signature, "hyperloglog") || - usesTypeName(signature, "tdigest") || - usesInputTypeName(signature, "json") || + usesTypeName(signature, "geometry") || usesTypeName(signature, "time") || + usesTypeName(signature, "p4hyperloglog") || usesInputTypeName(signature, "ipaddress") || usesInputTypeName(signature, "ipprefix") || usesInputTypeName(signature, "uuid")); @@ -355,17 +395,19 @@ std::string PrestoQueryRunner::createTable( execute(fmt::format("DROP TABLE IF EXISTS {}", name)); - execute(fmt::format( - "CREATE TABLE {}({}) WITH (format = 'DWRF') AS SELECT {}", - name, - folly::join(", ", inputType->names()), - nullValues.str())); + execute( + fmt::format( + "CREATE TABLE {}({}) WITH (format = 'DWRF') AS SELECT {}", + name, + folly::join(", ", inputType->names()), + nullValues.str())); // Query Presto to find out table's location on disk. auto results = execute(fmt::format("SELECT \"$path\" FROM {}", name)); - auto filePath = extractSingleValue(results); - auto tableDirectoryPath = fs::path(filePath).parent_path(); + + // TODO: Remove explicit std::string_view cast. + auto tableDirectoryPath = fs::path(std::string_view(filePath)).parent_path(); // Delete the all-null row. execute(fmt::format("DELETE FROM {}", name)); @@ -373,6 +415,10 @@ std::string PrestoQueryRunner::createTable( return tableDirectoryPath; } +void PrestoQueryRunner::cleanUp(const std::string& name) { + execute(fmt::format("DROP TABLE IF EXISTS {}", name)); +} + std::pair< std::optional>, ReferenceQueryErrorCode> @@ -402,8 +448,13 @@ PrestoQueryRunner::executeAndReturnVector(const core::PlanNodePtr& plan) { writeToFile(filePath, input, writerPool.get()); } - // Run the query. - return std::make_pair(execute(*sql), ReferenceQueryErrorCode::kSuccess); + // Run the query. If successful, delete the table. + auto result = execute(*sql); + for (const auto& [tableName, _] : inputMap) { + cleanUp(tableName); + } + + return std::make_pair(result, ReferenceQueryErrorCode::kSuccess); } catch (const VeloxRuntimeError& e) { // Throw if connection to Presto server is unsuccessful. if (e.message().find("Couldn't connect to server") != std::string::npos) { @@ -455,38 +506,78 @@ std::vector PrestoQueryRunner::execute( std::string PrestoQueryRunner::startQuery( const std::string& sql, const std::string& sessionProperty) { - auto uri = fmt::format("{}/v1/statement?binaryResults=true", coordinatorUri_); - cpr::Url url{uri}; - cpr::Body body{sql}; - cpr::Header header( - {{"X-Presto-User", user_}, - {"X-Presto-Catalog", "hive"}, - {"X-Presto-Schema", "tpch"}, - {"Content-Type", "text/plain"}, - {"X-Presto-Session", sessionProperty}}); - cpr::Timeout timeout{timeout_}; - cpr::Response response = cpr::Post(url, body, header, timeout); + CURL* curl = curl_easy_init(); + VELOX_CHECK_NOT_NULL(curl, "Failed to initialize libcurl"); + + // Prepare curl headers + struct curl_slist* headers = nullptr; + headers = curl_slist_append( + headers, fmt::format("X-Presto-User: {}", user_).c_str()); + headers = curl_slist_append(headers, "X-Presto-Catalog: hive"); + headers = curl_slist_append(headers, "X-Presto-Schema: tpch"); + headers = curl_slist_append(headers, "Content-Type: text/plain"); + headers = curl_slist_append( + headers, fmt::format("X-Presto-Session: {}", sessionProperty).c_str()); + + std::string url = + fmt::format("{}/v1/statement?binaryResults=true", coordinatorUri_); + curl_easy_setopt(curl, CURLOPT_URL, url.c_str()); + curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers); + curl_easy_setopt(curl, CURLOPT_POST, 1L); + curl_easy_setopt(curl, CURLOPT_POSTFIELDS, sql.c_str()); + curl_easy_setopt(curl, CURLOPT_POSTFIELDSIZE, sql.size()); + curl_easy_setopt(curl, CURLOPT_TIMEOUT_MS, timeout_); + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, writeFunction); + curl_easy_setopt(curl, CURLOPT_SSLVERSION, CURL_SSLVERSION_TLSv1_2); + + std::string response; + curl_easy_setopt(curl, CURLOPT_WRITEDATA, &response); + + // Perform the request + CURLcode res = curl_easy_perform(curl); VELOX_CHECK_EQ( - response.status_code, - 200, + CURLE_OK, + res, "POST to {} failed: {}", - uri, - response.error.message); - return response.text; + coordinatorUri_, + curl_easy_strerror(res)); + + // Cleanup + curl_slist_free_all(headers); + curl_easy_cleanup(curl); + + return response; } std::string PrestoQueryRunner::fetchNext(const std::string& nextUri) { - cpr::Url url(nextUri); - cpr::Header header({{"X-Presto-Client-Binary-Results", "true"}}); - cpr::Timeout timeout{timeout_}; - cpr::Response response = cpr::Get(url, header, timeout); + CURL* curl = curl_easy_init(); + VELOX_CHECK_NOT_NULL(curl, "Failed to initialize libcurl"); + + // Set up the request URL + curl_easy_setopt(curl, CURLOPT_URL, nextUri.c_str()); + + // Set up headers + struct curl_slist* headers = nullptr; + headers = curl_slist_append(headers, "X-Presto-Client-Binary-Results: true"); + curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers); + curl_easy_setopt(curl, CURLOPT_TIMEOUT_MS, timeout_); + curl_easy_setopt(curl, CURLOPT_SSLVERSION, CURL_SSLVERSION_TLSv1_2); + + // Capture the response body + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, writeFunction); + std::string response; + curl_easy_setopt(curl, CURLOPT_WRITEDATA, &response); + + // Perform GET request + CURLcode res = curl_easy_perform(curl); VELOX_CHECK_EQ( - response.status_code, - 200, - "GET from {} failed: {}", - nextUri, - response.error.message); - return response.text; + CURLE_OK, res, "Get request failed: {}", curl_easy_strerror(res)); + + // Cleanup + curl_slist_free_all(headers); + curl_easy_cleanup(curl); + + return response; } bool PrestoQueryRunner::supportsVeloxVectorResults() const { @@ -494,3 +585,10 @@ bool PrestoQueryRunner::supportsVeloxVectorResults() const { } } // namespace facebook::velox::exec::test + +template <> +struct fmt::formatter : formatter { + auto format(CURLcode s, format_context& ctx) const { + return formatter::format(static_cast(s), ctx); + } +}; diff --git a/velox/exec/fuzzer/PrestoQueryRunner.h b/velox/exec/fuzzer/PrestoQueryRunner.h index 697bf748013c..64ecbcc86b39 100644 --- a/velox/exec/fuzzer/PrestoQueryRunner.h +++ b/velox/exec/fuzzer/PrestoQueryRunner.h @@ -121,6 +121,7 @@ class PrestoQueryRunner : public velox::exec::test::ReferenceQueryRunner { // Creates an empty table with given data type and table name. The function // returns the root directory of table files. std::string createTable(const std::string& name, const TypePtr& type); + void cleanUp(const std::string& name); const std::string coordinatorUri_; const std::string user_; diff --git a/velox/exec/fuzzer/PrestoQueryRunnerIntermediateTypeTransforms.cpp b/velox/exec/fuzzer/PrestoQueryRunnerIntermediateTypeTransforms.cpp index 45658502b413..97a1015b1b16 100644 --- a/velox/exec/fuzzer/PrestoQueryRunnerIntermediateTypeTransforms.cpp +++ b/velox/exec/fuzzer/PrestoQueryRunnerIntermediateTypeTransforms.cpp @@ -15,59 +15,77 @@ */ #include "velox/exec/fuzzer/PrestoQueryRunnerIntermediateTypeTransforms.h" +#include "velox/exec/fuzzer/PrestoQueryRunnerIntervalTransform.h" +#include "velox/exec/fuzzer/PrestoQueryRunnerJsonTransform.h" #include "velox/exec/fuzzer/PrestoQueryRunnerTimestampWithTimeZoneTransform.h" -#include "velox/expression/VectorWriters.h" +#include "velox/expression/Expr.h" +#include "velox/functions/prestosql/types/BingTileType.h" +#include "velox/functions/prestosql/types/HyperLogLogType.h" +#include "velox/functions/prestosql/types/JsonType.h" +#include "velox/functions/prestosql/types/KHyperLogLogType.h" +#include "velox/functions/prestosql/types/QDigestType.h" +#include "velox/functions/prestosql/types/SetDigestType.h" +#include "velox/functions/prestosql/types/SfmSketchType.h" +#include "velox/functions/prestosql/types/TDigestType.h" +#include "velox/functions/prestosql/types/TimeWithTimezoneType.h" #include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h" #include "velox/parse/Expressions.h" -#include "velox/vector/ComplexVector.h" -#include "velox/vector/DecodedVector.h" -#include "velox/vector/fuzzer/Utils.h" +#include "velox/parse/TypeResolver.h" +#include "velox/vector/tests/utils/VectorMaker.h" namespace facebook::velox::exec::test { namespace { -class ArrayTransform { - public: - VectorPtr transform(const VectorPtr& vector, const SelectivityVector& rows) - const; - - core::ExprPtr projectionExpr( - const TypePtr& type, - const core::ExprPtr& inputExpr, - const std::string& columnAlias) const; -}; - -class MapTransform { - public: - VectorPtr transform(const VectorPtr& vector, const SelectivityVector& rows) - const; - - core::ExprPtr projectionExpr( - const TypePtr& type, - const core::ExprPtr& inputExpr, - const std::string& columnAlias) const; -}; - -class RowTransform { - public: - VectorPtr transform(const VectorPtr& vector, const SelectivityVector& rows) - const; - - core::ExprPtr projectionExpr( - const TypePtr& type, - const core::ExprPtr& inputExpr, - const std::string& columnAlias) const; -}; - -const ArrayTransform kArrayTransform; -const MapTransform kMapTransform; -const RowTransform kRowTransform; const std::unordered_map>& intermediateTypeTransforms() { static std::unordered_map> intermediateTypeTransforms{ + // Note: INTERVAL DAY TO SECOND is included in the below because its + // transform is defined and properly tested; however, due to Presto + // Java imprecision in functions parse_duration and to_milliseconds, + // it will be temporarily excluded from fuzzer runs until Presto Java + // fixes said imprecision or when we compare with Prestissimo + // directly. Please track below: + // https://github.com/prestodb/presto/issues/25275 + // https://github.com/prestodb/presto/issues/25340 {TIMESTAMP_WITH_TIME_ZONE(), - std::make_shared()}}; + std::make_shared()}, + {HYPERLOGLOG(), + std::make_shared( + HYPERLOGLOG(), VARBINARY())}, + {KHYPERLOGLOG(), + std::make_shared( + KHYPERLOGLOG(), VARBINARY())}, + {TDIGEST(DOUBLE()), + std::make_shared( + TDIGEST(DOUBLE()), VARBINARY())}, + {QDIGEST(DOUBLE()), + std::make_shared( + QDIGEST(DOUBLE()), VARBINARY())}, + {QDIGEST(BIGINT()), + std::make_shared( + QDIGEST(BIGINT()), VARBINARY())}, + {QDIGEST(REAL()), + std::make_shared( + QDIGEST(REAL()), VARBINARY())}, + {SETDIGEST(), + std::make_shared( + SETDIGEST(), VARBINARY())}, + {SFMSKETCH(), + std::make_shared( + SFMSKETCH(), VARBINARY())}, + {JSON(), std::make_shared()}, + {TIME(), + std::make_shared( + TIME(), VARCHAR())}, + {TIME_WITH_TIME_ZONE(), + std::make_shared( + TIME_WITH_TIME_ZONE(), VARCHAR())}, + {BINGTILE(), + std::make_shared( + BINGTILE(), BIGINT())}, + {INTERVAL_DAY_TIME(), std::make_shared()}, + }; return intermediateTypeTransforms; } @@ -84,129 +102,43 @@ const std::shared_ptr& getIntermediateTypeTransform( return it->second; } -VectorPtr transformIntermediateOnlyType( - const VectorPtr& vector, +VectorPtr evaluateExpression( + core::ExprPtr& expression, + const VectorPtr& input, const SelectivityVector& rows) { - const auto& type = vector->type(); - if (type->isArray()) { - return kArrayTransform.transform(vector, rows); - } else if (type->isMap()) { - return kMapTransform.transform(vector, rows); - } else if (type->isRow()) { - return kRowTransform.transform(vector, rows); - } else { - const auto& transform = getIntermediateTypeTransform(type); - DecodedVector decoded(*vector); - const auto* base = decoded.base(); - - VectorPtr result; - const auto& transformedType = transform->transformedType(); - BaseVector::ensureWritable( - SelectivityVector(base->size()), transformedType, base->pool(), result); - VectorWriter writer; - writer.init(*result); - - if (base->isConstantEncoding()) { - if (rows.countSelected() > 0) { - if (base->isNullAt(0)) { - result = BaseVector::createNullConstant( - transformedType, base->size(), base->pool()); - } else { - const auto value = transform->transform(base, 0); - - if (value.isNull()) { - result = BaseVector::createNullConstant( - transformedType, base->size(), base->pool()); - } else { - writer.setOffset(0); - VELOX_DYNAMIC_TYPE_DISPATCH( - fuzzer::writeOne, - transformedType->kind(), - value, - writer.current()); - writer.commit(true); - result = BaseVector::wrapInConstant(base->size(), 0, result); - } - } - } - } else { - rows.applyToSelected([&](vector_size_t row) { - auto index = decoded.index(row); - writer.setOffset(index); - - if (base->isNullAt(index)) { - writer.commitNull(); - } else { - const auto value = transform->transform(base, index); - - if (value.isNull()) { - writer.commitNull(); - } else { - VELOX_DYNAMIC_TYPE_DISPATCH( - fuzzer::writeOne, - transformedType->kind(), - value, - writer.current()); - writer.commit(true); - } - } - }); - } - - if (!decoded.isIdentityMapping()) { - result = decoded.wrap(result, *vector, vector->size()); - } - - return result; - } + std::shared_ptr queryCtx_{velox::core::QueryCtx::create()}; + std::unique_ptr execCtx{ + std::make_unique(input->pool(), queryCtx_.get())}; + velox::test::VectorMaker vectorMaker(input->pool()); + auto rowVector = vectorMaker.rowVector({input}); + core::TypedExprPtr typedExpr = core::Expressions::inferTypes( + expression, rowVector->type(), input->pool()); + exec::ExprSet exprSet( + std::vector{typedExpr}, execCtx.get()); + exec::EvalCtx evalCtx(execCtx.get(), &exprSet, rowVector.get()); + std::vector result; + exprSet.eval(rows, evalCtx, result); + VELOX_CHECK_EQ(result.size(), 1); + VELOX_CHECK_NOT_NULL(result[0]); + return result[0]; } -// Converts an ArrayVector so that any intermediate only types in the elements -// are transformed. -VectorPtr ArrayTransform::transform( - const VectorPtr& vector, - const SelectivityVector& rows) const { - VELOX_CHECK(vector->type()->isArray()); - DecodedVector decoded(*vector); - const auto* base = decoded.base()->as(); - - SelectivityVector elementRows(base->elements()->size(), false); - rows.applyToSelected([&](vector_size_t row) { - if (!decoded.isNullAt(row)) { - auto index = decoded.index(row); - elementRows.setValidRange( - base->offsetAt(index), - base->offsetAt(index) + base->sizeAt(index), - true); - } - }); - elementRows.updateBounds(); - - VectorPtr elementsVector = - transformIntermediateOnlyType(base->elements(), elementRows); - - VectorPtr array = std::make_shared( - base->pool(), - ARRAY(elementsVector->type()), - base->nulls(), - base->size(), - base->offsets(), - base->sizes(), - elementsVector); - - if (!decoded.isIdentityMapping()) { - array = decoded.wrap(array, *vector, vector->size()); - } +enum class TransformDirection { TO_INTERMEDIATE, TO_TARGET }; - return array; -} +core::ExprPtr getProjection( + const TypePtr& type, + const core::ExprPtr& inputExpr, + const std::string& columnAlias, + const TransformDirection transformDirection); // Applies a lambda transform to the elements of an array to convert input -// types to intermediate only types where necessary. -core::ExprPtr ArrayTransform::projectionExpr( +// types <=> intermediate types (where necessary) depending on +// 'transformDirection'. +core::ExprPtr getProjectionForArray( const TypePtr& type, const core::ExprPtr& inputExpr, - const std::string& columnAlias) const { + const std::string& columnAlias, + const TransformDirection transformDirection) { VELOX_CHECK(type->isArray()); return std::make_shared( @@ -215,70 +147,22 @@ core::ExprPtr ArrayTransform::projectionExpr( inputExpr, std::make_shared( std::vector{"x"}, - getIntermediateOnlyTypeProjectionExpr( + getProjection( type->asArray().elementType(), std::make_shared("x", "x"), - "x"))}, + "x", + transformDirection))}, columnAlias); } -// Converts an MapVector so that any intermediate only types in the keys and -// values are transformed. -VectorPtr MapTransform::transform( - const VectorPtr& vector, - const SelectivityVector& rows) const { - VELOX_CHECK(vector->type()->isMap()); - DecodedVector decoded(*vector); - const auto* base = decoded.base()->as(); - - VectorPtr keysVector = base->mapKeys(); - VectorPtr valuesVector = base->mapValues(); - const auto& keysType = keysVector->type(); - const auto& valuesType = valuesVector->type(); - - SelectivityVector elementRows(keysVector->size(), false); - rows.applyToSelected([&](vector_size_t row) { - if (!decoded.isNullAt(row)) { - auto index = decoded.index(row); - elementRows.setValidRange( - base->offsetAt(index), - base->offsetAt(index) + base->sizeAt(index), - true); - } - }); - elementRows.updateBounds(); - - if (isIntermediateOnlyType(keysType)) { - keysVector = transformIntermediateOnlyType(keysVector, elementRows); - } - - if (isIntermediateOnlyType(valuesType)) { - valuesVector = transformIntermediateOnlyType(valuesVector, elementRows); - } - - VectorPtr map = std::make_shared( - base->pool(), - MAP(keysVector->type(), valuesVector->type()), - base->nulls(), - base->size(), - base->offsets(), - base->sizes(), - keysVector, - valuesVector); - - if (!decoded.isIdentityMapping()) { - map = decoded.wrap(map, *vector, vector->size()); - } - - return map; -} - // Applies a lambda transform to the keys and values of a map to convert input -// types to intermediate only types where necessary. -core::ExprPtr MapTransform::projectionExpr( +// types <=> intermediate types (where necessary) depending on +// 'transformDirection'. +core::ExprPtr getProjectionForMap( const TypePtr& type, const core::ExprPtr& inputExpr, - const std::string& columnAlias) const { + const std::string& columnAlias, + const TransformDirection transformDirection) { VELOX_CHECK(type->isMap()); const auto& mapType = type->asMap(); const auto& keysType = mapType.keyType(); @@ -293,10 +177,11 @@ core::ExprPtr MapTransform::projectionExpr( expr, std::make_shared( std::vector{"k", "v"}, - getIntermediateOnlyTypeProjectionExpr( + getProjection( keysType, std::make_shared("k", "k"), - "k"))}, + "k", + transformDirection))}, columnAlias); } @@ -307,103 +192,132 @@ core::ExprPtr MapTransform::projectionExpr( expr, std::make_shared( std::vector{"k", "v"}, - getIntermediateOnlyTypeProjectionExpr( + getProjection( valuesType, std::make_shared("v", "v"), - "v"))}, + "v", + transformDirection))}, columnAlias); } return expr; } -// Converts an RowVector so that any intermediate only types in the children -// are transformed. -VectorPtr RowTransform::transform( - const VectorPtr& vector, - const SelectivityVector& rows) const { - VELOX_CHECK(vector->type()->isRow()); - DecodedVector decoded(*vector); - const auto* base = decoded.base()->as(); - - SelectivityVector childRows(base->size(), false); - rows.applyToSelected([&](vector_size_t row) { - if (!decoded.isNullAt(row)) { - childRows.setValid(decoded.index(row), true); - } - }); - childRows.updateBounds(); - - std::vector children; - std::vector childrenTypes; - std::vector childrenNames = base->type()->asRow().names(); - for (const auto& child : base->children()) { - if (isIntermediateOnlyType(child->type())) { - children.push_back(transformIntermediateOnlyType(child, childRows)); - childrenTypes.push_back(children.back()->type()); - } else { - children.push_back(child); - childrenTypes.push_back(child->type()); +TypePtr replaceIntermediateWithTargetType(TypePtr type) { + if (type->isArray()) { + const auto& arrayType = type->asArray(); + return ARRAY(replaceIntermediateWithTargetType(arrayType.elementType())); + } else if (type->isMap()) { + const auto& mapType = type->asMap(); + return MAP( + replaceIntermediateWithTargetType(mapType.keyType()), + replaceIntermediateWithTargetType(mapType.valueType())); + } else if (type->isRow()) { + const auto& rowType = type->asRow(); + std::vector names; + std::vector children; + for (int i = 0; i < rowType.size(); i++) { + names.push_back(rowType.nameOf(i)); + children.push_back(replaceIntermediateWithTargetType(rowType.childAt(i))); } + return ROW(std::move(names), std::move(children)); } - - VectorPtr row = std::make_shared( - base->pool(), - ROW(std::move(childrenNames), std::move(childrenTypes)), - base->nulls(), - base->size(), - std::move(children)); - - if (!decoded.isIdentityMapping()) { - row = decoded.wrap(row, *vector, vector->size()); + if (isIntermediateOnlyType(type)) { + const auto& transform = getIntermediateTypeTransform(type); + return transform->targetType(); } - - return row; + return type; } -// Applies transforms to the children of a row to convert input types to -// intermediate only types where necessary, and reconstructs the row via -// row_constructor. -core::ExprPtr RowTransform::projectionExpr( +// Applies transforms to the children of a row to convert input +// types <=> intermediate types (where necessary) depending on +// 'transformDirection', and reconstructs the row via row_constructor. +core::ExprPtr getProjectionForRow( const TypePtr& type, const core::ExprPtr& inputExpr, - const std::string& columnAlias) const { + const std::string& columnAlias, + const TransformDirection transformDirection) { VELOX_CHECK(type->isRow()); const auto& rowType = type->asRow(); std::vector children; for (int i = 0; i < rowType.size(); i++) { if (isIntermediateOnlyType(rowType.childAt(i))) { - children.push_back(getIntermediateOnlyTypeProjectionExpr( + children.push_back(getProjection( rowType.childAt(i), std::make_shared( rowType.nameOf(i), rowType.nameOf(i), std::vector{inputExpr}), - rowType.nameOf(i))); - } else { - children.push_back(std::make_shared( - rowType.nameOf(i), rowType.nameOf(i), - std::vector{inputExpr})); + transformDirection)); + } else { + children.push_back( + std::make_shared( + rowType.nameOf(i), + rowType.nameOf(i), + std::vector{inputExpr})); } } + TypePtr outputRowType; + if (transformDirection == TransformDirection::TO_TARGET) { + outputRowType = replaceIntermediateWithTargetType(type); + } else { + outputRowType = type; + } + return std::make_shared( "switch", std::vector{ std::make_shared( "is_null", std::vector{inputExpr}, std::nullopt), std::make_shared( - type, variant::null(TypeKind::ROW), std::nullopt), + outputRowType, variant::null(TypeKind::ROW), std::nullopt), std::make_shared( - type, + outputRowType, std::make_shared( "row_constructor", std::move(children), std::nullopt), false, std::nullopt)}, columnAlias); } + +core::ExprPtr getProjection( + const TypePtr& originaltype, + const core::ExprPtr& inputExpr, + const std::string& columnAlias, + const TransformDirection transformDirection) { + if (originaltype->isArray()) { + return getProjectionForArray( + originaltype, inputExpr, columnAlias, transformDirection); + } else if (originaltype->isMap()) { + return getProjectionForMap( + originaltype, inputExpr, columnAlias, transformDirection); + } else if (originaltype->isRow()) { + return getProjectionForRow( + originaltype, inputExpr, columnAlias, transformDirection); + } + const auto& transform = getIntermediateTypeTransform(originaltype); + if (transformDirection == TransformDirection::TO_TARGET) { + return transform->projectToTargetType(inputExpr, columnAlias); + } + return transform->projectToIntermediateType(inputExpr, columnAlias); +} + +VectorPtr transformIntermediateTypes( + const VectorPtr& vector, + const SelectivityVector& rows) { + const auto& type = vector->type(); + auto expression = getProjection( + type, + std::make_shared("c0", "c0"), + "c0", + TransformDirection::TO_TARGET); + + return evaluateExpression(expression, vector, rows); +} + } // namespace bool isIntermediateOnlyType(const TypePtr& type) { @@ -421,24 +335,30 @@ bool isIntermediateOnlyType(const TypePtr& type) { return false; } -VectorPtr transformIntermediateOnlyType(const VectorPtr& vector) { - return transformIntermediateOnlyType( +VectorPtr transformIntermediateTypes(const VectorPtr& vector) { + return transformIntermediateTypes( vector, SelectivityVector(vector->size(), true)); } -core::ExprPtr getIntermediateOnlyTypeProjectionExpr( +core::ExprPtr getProjectionsToIntermediateTypes( const TypePtr& type, const core::ExprPtr& inputExpr, const std::string& columnAlias) { - if (type->isArray()) { - return kArrayTransform.projectionExpr(type, inputExpr, columnAlias); - } else if (type->isMap()) { - return kMapTransform.projectionExpr(type, inputExpr, columnAlias); - } else if (type->isRow()) { - return kRowTransform.projectionExpr(type, inputExpr, columnAlias); - } else { - return getIntermediateTypeTransform(type)->projectionExpr( - inputExpr, columnAlias); - } + return getProjection( + type, inputExpr, columnAlias, TransformDirection::TO_INTERMEDIATE); +} + +core::ExprPtr IntermediateTypeTransformUsingCast::projectToTargetType( + const core::ExprPtr& inputExpr, + const std::string& columnAlias) const { + return std::make_shared( + targetType_, inputExpr, false, columnAlias); +} + +core::ExprPtr IntermediateTypeTransformUsingCast::projectToIntermediateType( + const core::ExprPtr& inputExpr, + const std::string& columnAlias) const { + return std::make_shared( + intermediateType_, inputExpr, false, columnAlias); } } // namespace facebook::velox::exec::test diff --git a/velox/exec/fuzzer/PrestoQueryRunnerIntermediateTypeTransforms.h b/velox/exec/fuzzer/PrestoQueryRunnerIntermediateTypeTransforms.h index ca807b69260d..c29e16e9474c 100644 --- a/velox/exec/fuzzer/PrestoQueryRunnerIntermediateTypeTransforms.h +++ b/velox/exec/fuzzer/PrestoQueryRunnerIntermediateTypeTransforms.h @@ -20,50 +20,90 @@ #include "velox/vector/BaseVector.h" namespace facebook::velox::exec::test { -/// Defines a transform for an intermediate type. +/// Defines a transform for a intermediate type (that does not exist in file +/// formats like json) to a target type. class IntermediateTypeTransform { public: + IntermediateTypeTransform(TypePtr intermediateType, TypePtr targetType) + : intermediateType_(std::move(intermediateType)), + targetType_(std::move(targetType)) {} virtual ~IntermediateTypeTransform() = default; /// The type of the value returned by transform(). - virtual TypePtr transformedType() const = 0; + TypePtr targetType() const { + return targetType_; + }; - /// Converts the value in vector at position row into a value that can be - /// converted back to the original type by the result of projectExpr. Will - /// only be called with a ConstantVector or a flat-like Vector (depending on - /// the intermediate type's parent type) where the value is non-null. - virtual variant transform(const BaseVector* const vector, vector_size_t row) - const = 0; + TypePtr intermediateType() const { + return intermediateType_; + }; + + /// An expression tree that can convert the type of value returned by + /// 'inputExpr' into its target type. It is assumed this has default-null + /// behavior, i.e. a null input produces a null output. + virtual core::ExprPtr projectToTargetType( + const core::ExprPtr& inputExpr, + const std::string& columnAlias) const = 0; /// An expression tree that can convert the value returned by transform() back /// into its original value. It is assumed this has default-null behavior, /// i.e. a null input produces a null output. - virtual core::ExprPtr projectionExpr( + virtual core::ExprPtr projectToIntermediateType( const core::ExprPtr& inputExpr, const std::string& columnAlias) const = 0; + + protected: + TypePtr intermediateType_; + TypePtr targetType_; +}; + +/// This class offers a default implementation for the +/// IntermediateTypeTransform, enabling conversion to and from target types +/// using the CAST operator. It is designed to support intermediate types that +/// can utilize the CAST operator for correct conversion semantics. To add +/// support for such intermediate types, simply create an instance of this class +/// with the appropriate parameters. +class IntermediateTypeTransformUsingCast : public IntermediateTypeTransform { + public: + IntermediateTypeTransformUsingCast( + TypePtr intermediateType, + TypePtr targetType) + : IntermediateTypeTransform( + std::move(intermediateType), + std::move(targetType)) {} + + core::ExprPtr projectToTargetType( + const core::ExprPtr& inputExpr, + const std::string& columnAlias) const override; + + core::ExprPtr projectToIntermediateType( + const core::ExprPtr& inputExpr, + const std::string& columnAlias) const override; }; /// Returns true if this types is an intermediate only type or contains an /// intermediate only type. bool isIntermediateOnlyType(const TypePtr& type); -/// Converts a Vector of an intermediate only type, or containing one, to a -/// Vector of value(s) that can be input to a projection to produce those values -/// of that type but are of types supported as input. Preserves nulls and -/// encodings. -VectorPtr transformIntermediateOnlyType(const VectorPtr& vector); +/// Converts all the intermediate types in an input 'vector' to their respective +/// target types. For eg. ROW(, MAP(, BIGINT)) => +/// ROW(, MAP(, BIGINT)) +VectorPtr transformIntermediateTypes(const VectorPtr& vector); -/// Converts an expression that takes in a value of an intermediate only type so -/// that it applies a transformation to convert valid input types into values -/// of the intermediate only type. -/// @param type The expected output type of the expression, either an +/// Generates an expression that takes the output of 'inputExpr' and converts it +/// into 'ouputType'. The 'ouputType' is expected to be have same type hierarchy +/// as the types returned from 'inputExpr' with some of the leaf type replaced +/// with intermediate types. For eg. ouputType can be ROW(, +/// MAP(, BIGINT)) and type returned by 'inputExpr' can be +/// ROW(, MAP(, BIGINT)). +/// @param ouputType The expected output type of the expression, either an /// intermediate only type, or a complex type containing one. /// @param inputExpr The expression that will be the input to the returned /// expression, the output of this expression will be of a type not containing /// intermediate only types. /// @param columnAlias The alias to give the returned expression. -core::ExprPtr getIntermediateOnlyTypeProjectionExpr( - const TypePtr& type, +core::ExprPtr getProjectionsToIntermediateTypes( + const TypePtr& ouputType, const core::ExprPtr& inputExpr, const std::string& columnAlias); } // namespace facebook::velox::exec::test diff --git a/velox/exec/fuzzer/PrestoQueryRunnerIntervalTransform.cpp b/velox/exec/fuzzer/PrestoQueryRunnerIntervalTransform.cpp new file mode 100644 index 000000000000..0c1e52341d41 --- /dev/null +++ b/velox/exec/fuzzer/PrestoQueryRunnerIntervalTransform.cpp @@ -0,0 +1,61 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/fuzzer/PrestoQueryRunnerIntervalTransform.h" +#include "velox/parse/Expressions.h" + +namespace facebook::velox::exec::test { + +core::ExprPtr IntervalDayTimeTransform::projectToTargetType( + const core::ExprPtr& inputExpr, + const std::string& columnAlias) const { + return std::make_shared( + "try", + std::vector{std::make_shared( + "to_milliseconds", + std::vector{inputExpr}, + std::nullopt)}, + columnAlias); +} + +// The below transform executes the following Presto SQL: +// try(parse_duration(concat(CAST(inputExpr AS VARCHAR), 'ms'))) +// +// This casts the inputted bigint as a string and concatenates it with 'ms' +// to generate a valid millisecond duration as a string, which will be parsed +// as a valid interval and easily re-converted using to_milliseconds. +core::ExprPtr IntervalDayTimeTransform::projectToIntermediateType( + const core::ExprPtr& inputExpr, + const std::string& columnAlias) const { + return std::make_shared( + "try", + std::vector{std::make_shared( + "parse_duration", + std::vector{std::make_shared( + "concat", + std::vector{ + std::make_shared( + VARCHAR(), inputExpr, false, columnAlias), + std::make_shared( + VARCHAR(), + variant::create("ms"), + std::nullopt)}, + std::nullopt)}, + std::nullopt)}, + columnAlias); +} + +} // namespace facebook::velox::exec::test diff --git a/velox/exec/fuzzer/PrestoQueryRunnerIntervalTransform.h b/velox/exec/fuzzer/PrestoQueryRunnerIntervalTransform.h new file mode 100644 index 000000000000..097359e7d391 --- /dev/null +++ b/velox/exec/fuzzer/PrestoQueryRunnerIntervalTransform.h @@ -0,0 +1,42 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "velox/exec/fuzzer/PrestoQueryRunnerIntermediateTypeTransforms.h" + +namespace facebook::velox::exec::test { + +// This transform converts negative interval values to NULL due to the +// constraints in parse_duration. Although constants and inputs to +// to_milliseconds allow for negatives, parse_duration does not, the try +// will capture this error and NULL the output. The behavior is the same +// in Presto Java and Prestissimo. +class IntervalDayTimeTransform : public IntermediateTypeTransform { + public: + IntervalDayTimeTransform() + : IntermediateTypeTransform(INTERVAL_DAY_TIME(), BIGINT()) {} + + core::ExprPtr projectToTargetType( + const core::ExprPtr& inputExpr, + const std::string& columnAlias) const override; + + core::ExprPtr projectToIntermediateType( + const core::ExprPtr& inputExpr, + const std::string& columnAlias) const override; +}; + +} // namespace facebook::velox::exec::test diff --git a/velox/exec/fuzzer/PrestoQueryRunnerJsonTransform.cpp b/velox/exec/fuzzer/PrestoQueryRunnerJsonTransform.cpp new file mode 100644 index 000000000000..8285cbbc7276 --- /dev/null +++ b/velox/exec/fuzzer/PrestoQueryRunnerJsonTransform.cpp @@ -0,0 +1,42 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/fuzzer/PrestoQueryRunnerJsonTransform.h" +#include "velox/parse/Expressions.h" + +namespace facebook::velox::exec::test { + +core::ExprPtr JsonTransform::projectToTargetType( + const core::ExprPtr& inputExpr, + const std::string& columnAlias) const { + return std::make_shared( + "try", + std::vector{std::make_shared( + "json_format", std::vector{inputExpr}, std::nullopt)}, + columnAlias); +} + +core::ExprPtr JsonTransform::projectToIntermediateType( + const core::ExprPtr& inputExpr, + const std::string& columnAlias) const { + return std::make_shared( + "try", + std::vector{std::make_shared( + "json_parse", std::vector{inputExpr}, std::nullopt)}, + columnAlias); +} + +} // namespace facebook::velox::exec::test diff --git a/velox/exec/fuzzer/PrestoQueryRunnerJsonTransform.h b/velox/exec/fuzzer/PrestoQueryRunnerJsonTransform.h new file mode 100644 index 000000000000..88bc67dd56bc --- /dev/null +++ b/velox/exec/fuzzer/PrestoQueryRunnerJsonTransform.h @@ -0,0 +1,35 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "velox/exec/fuzzer/PrestoQueryRunnerIntermediateTypeTransforms.h" +#include "velox/functions/prestosql/types/JsonType.h" + +namespace facebook::velox::exec::test { +class JsonTransform : public IntermediateTypeTransform { + public: + JsonTransform() : IntermediateTypeTransform(JSON(), VARCHAR()) {} + + core::ExprPtr projectToTargetType( + const core::ExprPtr& inputExpr, + const std::string& columnAlias) const override; + + core::ExprPtr projectToIntermediateType( + const core::ExprPtr& inputExpr, + const std::string& columnAlias) const override; +}; +} // namespace facebook::velox::exec::test diff --git a/velox/exec/fuzzer/PrestoQueryRunnerTimestampWithTimeZoneTransform.cpp b/velox/exec/fuzzer/PrestoQueryRunnerTimestampWithTimeZoneTransform.cpp index ef1e6d0d1888..fcad8e0e4314 100644 --- a/velox/exec/fuzzer/PrestoQueryRunnerTimestampWithTimeZoneTransform.cpp +++ b/velox/exec/fuzzer/PrestoQueryRunnerTimestampWithTimeZoneTransform.cpp @@ -15,51 +15,17 @@ */ #include "velox/exec/fuzzer/PrestoQueryRunnerTimestampWithTimeZoneTransform.h" -#include "velox/functions/lib/DateTimeFormatter.h" -#include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h" #include "velox/parse/Expressions.h" -#include "velox/type/tz/TimeZoneMap.h" -#include "velox/vector/SimpleVector.h" namespace facebook::velox::exec::test { namespace { const std::string kFormat = "yyyy-MM-dd HH:mm:ss.SSS ZZZ"; const std::string kBackupFormat = "yyyy-MM-dd HH:mm:ss.SSS ZZ"; - -std::string format(const int64_t timestampWithTimeZone) { - static const std::shared_ptr kJodaDateTime = - functions::buildJodaDateTimeFormatter(kFormat).value(); - - const auto timestamp = unpackTimestampUtc(timestampWithTimeZone); - const auto timeZoneId = unpackZoneKeyId(timestampWithTimeZone); - auto* timezonePtr = tz::locateZone(tz::getTimeZoneName(timeZoneId)); - - const auto maxResultSize = kJodaDateTime->maxResultSize(timezonePtr); - std::string str; - str.resize(maxResultSize); - const auto resultSize = - kJodaDateTime->format(timestamp, timezonePtr, maxResultSize, str.data()); - str.resize(resultSize); - - return str; -} } // namespace -// Convert a TimestampWithTimeZone into a Varchar using DatetimeFormatter, so -// that we can get the TimestampWithTimeZone back by calling parse_datetime. -variant TimestampWithTimeZoneTransform::transform( - const BaseVector* const vector, - vector_size_t row) const { - VELOX_CHECK(isTimestampWithTimeZoneType(vector->type())); - VELOX_CHECK(!vector->isNullAt(row)); - - return variant::create( - format(vector->asChecked>()->valueAt(row))); -} - // Applies parse_datetime to a Vector of VARCHAR (formatted timestamps with time // zone) to produce values of type TimestampWithTimeZone. -core::ExprPtr TimestampWithTimeZoneTransform::projectionExpr( +core::ExprPtr TimestampWithTimeZoneTransform::projectToIntermediateType( const core::ExprPtr& inputExpr, const std::string& columnAlias) const { // format_datetime with the ZZZ pattern produces time zones that need to be diff --git a/velox/exec/fuzzer/PrestoQueryRunnerTimestampWithTimeZoneTransform.h b/velox/exec/fuzzer/PrestoQueryRunnerTimestampWithTimeZoneTransform.h index 874d5eba5802..228d4ab53a07 100644 --- a/velox/exec/fuzzer/PrestoQueryRunnerTimestampWithTimeZoneTransform.h +++ b/velox/exec/fuzzer/PrestoQueryRunnerTimestampWithTimeZoneTransform.h @@ -17,18 +17,18 @@ #pragma once #include "velox/exec/fuzzer/PrestoQueryRunnerIntermediateTypeTransforms.h" +#include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h" namespace facebook::velox::exec::test { -class TimestampWithTimeZoneTransform : public IntermediateTypeTransform { +class TimestampWithTimeZoneTransform + : public IntermediateTypeTransformUsingCast { public: - TypePtr transformedType() const override { - return VARCHAR(); - } + TimestampWithTimeZoneTransform() + : IntermediateTypeTransformUsingCast( + TIMESTAMP_WITH_TIME_ZONE(), + VARCHAR()) {} - variant transform(const BaseVector* const vector, vector_size_t row) - const override; - - core::ExprPtr projectionExpr( + core::ExprPtr projectToIntermediateType( const core::ExprPtr& inputExpr, const std::string& columnAlias) const override; }; diff --git a/velox/exec/fuzzer/PrestoQueryRunnerToSqlPlanNodeVisitor.cpp b/velox/exec/fuzzer/PrestoQueryRunnerToSqlPlanNodeVisitor.cpp index 0d74dbbd351d..61676760099e 100644 --- a/velox/exec/fuzzer/PrestoQueryRunnerToSqlPlanNodeVisitor.cpp +++ b/velox/exec/fuzzer/PrestoQueryRunnerToSqlPlanNodeVisitor.cpp @@ -203,7 +203,7 @@ void PrestoQueryRunnerToSqlPlanNodeVisitor::visit( static_cast(ctx); auto insertTableHandle = - std::dynamic_pointer_cast( + std::dynamic_pointer_cast( node.insertTableHandle()->connectorInsertTableHandle()); // Returns a CTAS sql with specified table properties from TableWriteNode, @@ -291,7 +291,8 @@ void PrestoQueryRunnerToSqlPlanNodeVisitor::visit( sql << inputType->nameOf(i); } - sql << ", row_number() OVER ("; + sql << ", " << core::TopNRowNumberNode::rankFunctionName(node.rankFunction()) + << "() OVER ("; const auto& partitionKeys = node.partitionKeys(); if (!partitionKeys.empty()) { @@ -344,7 +345,7 @@ void PrestoQueryRunnerToSqlPlanNodeVisitor::visit( std::stringstream sql; sql << "SELECT "; - const auto& inputType = node.sources()[0]->outputType(); + const auto& inputType = node.inputType(); for (auto i = 0; i < inputType->size(); ++i) { appendComma(i, sql); sql << inputType->nameOf(i); @@ -380,7 +381,7 @@ void PrestoQueryRunnerToSqlPlanNodeVisitor::visit( } sql << " " << queryRunnerContext_->windowFrames_.at(node.id()).at(i); - sql << ")"; + sql << ") as " << node.windowColumnNames()[i]; } // WindowNode should have a single source. diff --git a/velox/exec/fuzzer/PrestoQueryRunnerToSqlPlanNodeVisitor.h b/velox/exec/fuzzer/PrestoQueryRunnerToSqlPlanNodeVisitor.h index 2c2b5d09f530..dd26b66611c4 100644 --- a/velox/exec/fuzzer/PrestoQueryRunnerToSqlPlanNodeVisitor.h +++ b/velox/exec/fuzzer/PrestoQueryRunnerToSqlPlanNodeVisitor.h @@ -124,9 +124,20 @@ class PrestoQueryRunnerToSqlPlanNodeVisitor : public PrestoSqlPlanNodeVisitor { void visit(const core::ProjectNode& node, core::PlanNodeVisitorContext& ctx) const override; + void visit(const core::ParallelProjectNode&, core::PlanNodeVisitorContext&) + const override { + VELOX_NYI(); + } + void visit(const core::RowNumberNode& node, core::PlanNodeVisitorContext& ctx) const override; + void visit( + const core::SpatialJoinNode& node, + core::PlanNodeVisitorContext& ctx) const override { + VELOX_NYI(); + } + void visit(const core::TableScanNode& node, core::PlanNodeVisitorContext& ctx) const override { PrestoSqlPlanNodeVisitor::visit(node, ctx); diff --git a/velox/exec/fuzzer/PrestoSql.cpp b/velox/exec/fuzzer/PrestoSql.cpp index 8e9bab903771..6b034ef89d18 100644 --- a/velox/exec/fuzzer/PrestoSql.cpp +++ b/velox/exec/fuzzer/PrestoSql.cpp @@ -19,6 +19,7 @@ #include "velox/exec/fuzzer/PrestoQueryRunner.h" #include "velox/exec/fuzzer/ReferenceQueryRunner.h" #include "velox/functions/prestosql/types/JsonType.h" +#include "velox/vector/SimpleVector.h" namespace facebook::velox::exec::test { @@ -94,7 +95,7 @@ std::string toTypeSql(const TypePtr& type) { } default: if (type->isPrimitiveType()) { - return type->name(); + return type->toString(); } VELOX_UNSUPPORTED("Type is not supported: {}", type->toString()); } @@ -339,7 +340,7 @@ std::string toCallSql(const core::CallTypedExprPtr& call) { std::string toCastSql(const core::CastTypedExpr& cast) { std::stringstream sql; - if (cast.nullOnFailure()) { + if (cast.isTryCast()) { sql << "try_cast("; } else { sql << "cast("; @@ -393,6 +394,10 @@ std::string toConstantSql(const core::ConstantTypedExpr& constant) { sql << typeSql << " "; } sql << std::quoted(getConstantValue(constant), '\'', '\''); + } else if (type->isIntervalYearMonth()) { + sql << fmt::format("INTERVAL '{}' YEAR TO MONTH", constant.toString()); + } else if (type->isIntervalDayTime()) { + sql << fmt::format("INTERVAL '{}' DAY TO SECOND", constant.toString()); } else if (type->isBigint()) { sql << getConstantValue(constant); } else if (type->isPrimitiveType()) { @@ -593,6 +598,12 @@ void PrestoSqlPlanNodeVisitor::visit( visitorContext.sql = sql.str(); } +void PrestoSqlPlanNodeVisitor::visit( + const core::SpatialJoinNode& node, + core::PlanNodeVisitorContext& ctx) const { + VELOX_NYI("SpatialJoinNode is not yet supported in SQL conversion"); +} + std::optional PrestoSqlPlanNodeVisitor::toSql( const core::PlanNodePtr& node) const { PrestoSqlPlanNodeVisitorContext sourceContext; diff --git a/velox/exec/fuzzer/PrestoSql.h b/velox/exec/fuzzer/PrestoSql.h index 7e3ed99d7c4d..2acfaaa2df96 100644 --- a/velox/exec/fuzzer/PrestoSql.h +++ b/velox/exec/fuzzer/PrestoSql.h @@ -70,6 +70,10 @@ class PrestoSqlPlanNodeVisitor : public core::PlanNodeVisitor { const core::NestedLoopJoinNode& node, core::PlanNodeVisitorContext& ctx) const override; + void visit( + const core::SpatialJoinNode& node, + core::PlanNodeVisitorContext& ctx) const override; + void visit(const core::TableScanNode& node, core::PlanNodeVisitorContext& ctx) const override; diff --git a/velox/exec/fuzzer/ReferenceQueryRunner.h b/velox/exec/fuzzer/ReferenceQueryRunner.h index 4bfe30e955e7..983de96f5f8f 100644 --- a/velox/exec/fuzzer/ReferenceQueryRunner.h +++ b/velox/exec/fuzzer/ReferenceQueryRunner.h @@ -55,7 +55,8 @@ class ReferenceQueryRunner { enum class RunnerType { kPrestoQueryRunner, kDuckQueryRunner, - kSparkQueryRunner + kSparkQueryRunner, + kVeloxQueryRunner }; // @param aggregatePool Used to allocate memory needed for vectors produced @@ -114,15 +115,6 @@ class ReferenceQueryRunner { return true; } - /// Executes SQL query returned by the 'toSql' method using 'input' data. - /// Converts results using 'resultType' schema. - virtual std::multiset> execute( - const std::string& /*sql*/, - const std::vector& /*input*/, - const velox::RowTypePtr& /*resultType*/) { - VELOX_UNSUPPORTED(); - } - // Converts 'plan' into an SQL query and executes it. Result is returned as // a MaterializedRowMultiset with the ReferenceQueryErrorCode::kSuccess if // successful, or an std::nullopt with a ReferenceQueryErrorCode if the @@ -154,21 +146,12 @@ class ReferenceQueryRunner { VELOX_UNSUPPORTED(); } - /// Returns true if 'executeVector' can be called to get results as Velox - /// Vector. + /// Returns true if 'executeAndReturnVector' can be called to get results as + /// Velox Vector. virtual bool supportsVeloxVectorResults() const { return false; } - /// Similar to 'execute' but returns results in RowVector format. - /// Caller should ensure 'supportsVeloxVectorResults' returns true. - virtual std::vector executeVector( - const std::string& /*sql*/, - const std::vector& /*input*/, - const RowTypePtr& /*resultType*/) { - VELOX_UNSUPPORTED(); - } - virtual std::vector execute(const std::string& /*sql*/) { VELOX_UNSUPPORTED(); } diff --git a/velox/exec/fuzzer/RowNumberFuzzer.cpp b/velox/exec/fuzzer/RowNumberFuzzer.cpp index 2fc3320cdc4d..7c8c911db758 100644 --- a/velox/exec/fuzzer/RowNumberFuzzer.cpp +++ b/velox/exec/fuzzer/RowNumberFuzzer.cpp @@ -68,7 +68,12 @@ class RowNumberFuzzer : public RowNumberFuzzerBase { RowNumberFuzzer::RowNumberFuzzer( size_t initialSeed, std::unique_ptr referenceQueryRunner) - : RowNumberFuzzerBase(initialSeed, std::move(referenceQueryRunner)) {} + : RowNumberFuzzerBase(initialSeed, std::move(referenceQueryRunner)) { + // Set timestamp precision as milliseconds, as timestamp may be used as + // paritition key, and presto doesn't supports nanosecond precision. + vectorFuzzer_.getMutableOptions().timestampPrecision = + fuzzer::FuzzerTimestampPrecision::kMilliSeconds; +} std::pair, std::vector> RowNumberFuzzer::generatePartitionKeys() { diff --git a/velox/exec/fuzzer/RowNumberFuzzerBase.cpp b/velox/exec/fuzzer/RowNumberFuzzerBase.cpp index 1b634b0f23bf..307216916c2b 100644 --- a/velox/exec/fuzzer/RowNumberFuzzerBase.cpp +++ b/velox/exec/fuzzer/RowNumberFuzzerBase.cpp @@ -19,6 +19,7 @@ #include #include "velox/connectors/hive/HiveConnector.h" #include "velox/dwio/dwrf/RegisterDwrfReader.h" +#include "velox/dwio/dwrf/RegisterDwrfWriter.h" #include "velox/exec/Spill.h" #include "velox/exec/fuzzer/FuzzerUtil.h" #include "velox/exec/tests/utils/AssertQueryBuilder.h" @@ -78,6 +79,7 @@ RowNumberFuzzerBase::RowNumberFuzzerBase( void RowNumberFuzzerBase::setupReadWrite() { filesystems::registerLocalFileSystem(); dwrf::registerDwrfReaderFactory(); + dwrf::registerDwrfWriterFactory(); if (!isRegisteredNamedVectorSerde(VectorSerde::Kind::kPresto)) { serializer::presto::PrestoVectorSerde::registerNamedVectorSerde(); diff --git a/velox/exec/fuzzer/RowNumberFuzzerBase.h b/velox/exec/fuzzer/RowNumberFuzzerBase.h index abae8551c555..1aaf5f4f3b03 100644 --- a/velox/exec/fuzzer/RowNumberFuzzerBase.h +++ b/velox/exec/fuzzer/RowNumberFuzzerBase.h @@ -18,6 +18,7 @@ #include #include "velox/common/fuzzer/Utils.h" +#include "velox/exec/MemoryReclaimer.h" #include "velox/exec/Split.h" #include "velox/exec/fuzzer/ReferenceQueryRunner.h" #include "velox/vector/fuzzer/VectorFuzzer.h" @@ -124,14 +125,14 @@ class RowNumberFuzzerBase { memory::memoryManager()->addRootPool( "rowNumberFuzzer", memory::kMaxMemory, - memory::MemoryReclaimer::create())}; + exec::MemoryReclaimer::create())}; std::shared_ptr pool_{rootPool_->addLeafChild( "rowNumberFuzzerLeaf", true, - memory::MemoryReclaimer::create())}; + exec::MemoryReclaimer::create())}; std::shared_ptr writerPool_{rootPool_->addAggregateChild( "rowNumberFuzzerWriter", - memory::MemoryReclaimer::create())}; + exec::MemoryReclaimer::create())}; VectorFuzzer vectorFuzzer_; std::unique_ptr referenceQueryRunner_; }; diff --git a/velox/exec/fuzzer/SpatialJoinFuzzer.cpp b/velox/exec/fuzzer/SpatialJoinFuzzer.cpp new file mode 100644 index 000000000000..ffce520eef0a --- /dev/null +++ b/velox/exec/fuzzer/SpatialJoinFuzzer.cpp @@ -0,0 +1,598 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/fuzzer/SpatialJoinFuzzer.h" + +#include "velox/common/file/FileSystems.h" +#include "velox/common/fuzzer/Utils.h" +#include "velox/exec/fuzzer/FuzzerUtil.h" +#include "velox/exec/tests/utils/AssertQueryBuilder.h" +#include "velox/exec/tests/utils/PlanBuilder.h" +#include "velox/vector/fuzzer/VectorFuzzer.h" + +DEFINE_int32(steps, 10, "Number of plans to generate and test."); + +DEFINE_int32( + duration_sec, + 0, + "For how long it should run (in seconds). If zero, " + "it executes exactly --steps iterations and exits."); + +DEFINE_int32( + batch_size, + 100, + "The number of elements on each generated vector."); + +DEFINE_int32(num_batches, 10, "The number of generated vectors."); + +DEFINE_double( + null_ratio, + 0.1, + "Chance of adding a null value in a vector " + "(expressed as double from 0 to 1)."); + +namespace facebook::velox::exec { + +namespace { +using namespace facebook::velox; + +/// Spatial distribution patterns for geometry generation. +enum class GeometryDistribution { + kUniform, // Geometries uniformly distributed in space + kClustered, // Geometries clustered in specific regions + kSparse // Sparse geometries with low overlap probability +}; + +// Constants for geometry generation. +constexpr int32_t kRandomCoordinateMax = 1000; +constexpr int32_t kNumClusters = 5; +constexpr double kClusterSpacing = 200.0; +constexpr double kClusterCenterOffset = 100.0; +constexpr int32_t kClusterSpreadRange = 100; +constexpr int32_t kClusterSpreadHalf = kClusterSpreadRange / 2; +constexpr double kPolygonSize = 10.0; +constexpr double kSparseSpread = 2000.0; +constexpr uint32_t kMaxRadius = 100; + +// Base class for geometry string generators. +class GeometryInputGenerator : public AbstractInputGenerator { + public: + GeometryInputGenerator( + GeometryDistribution distribution, + size_t seed, + double nullRatio) + : AbstractInputGenerator(seed, VARCHAR(), nullptr, nullRatio), + distribution_(distribution) {} + + protected: + std::pair generateCoordinates() { + double x, y; + switch (distribution_) { + case GeometryDistribution::kUniform: { + x = fuzzer::rand( + rng_, -kRandomCoordinateMax, kRandomCoordinateMax); + y = fuzzer::rand( + rng_, -kRandomCoordinateMax, kRandomCoordinateMax); + break; + } + case GeometryDistribution::kClustered: { + uint32_t cluster = fuzzer::rand(rng_, 0, kNumClusters); + double centerX = (cluster * kClusterSpacing) + kClusterCenterOffset; + double centerY = (cluster * kClusterSpacing) + kClusterCenterOffset; + x = centerX + + ((fuzzer::rand( + rng_, -kClusterSpreadRange, kClusterSpreadRange)) - + kClusterSpreadHalf); + y = centerY + + ((fuzzer::rand( + rng_, -kClusterSpreadRange, kClusterSpreadRange)) - + kClusterSpreadHalf); + break; + } + case GeometryDistribution::kSparse: { + x = fuzzer::rand(rng_, -kSparseSpread, kSparseSpread); + y = fuzzer::rand(rng_, -kSparseSpread, kSparseSpread); + break; + } + } + return {x, y}; + } + + GeometryDistribution distribution_; +}; + +// Generates POINT geometry strings. +class PointInputGenerator : public GeometryInputGenerator { + public: + PointInputGenerator( + GeometryDistribution distribution, + size_t seed, + double nullRatio) + : GeometryInputGenerator(distribution, seed, nullRatio) {} + + variant generate() override { + if (fuzzer::coinToss(rng_, nullRatio_)) { + return variant::null(type_->kind()); + } + + auto [x, y] = generateCoordinates(); + return fmt::format("POINT ({} {})", x, y); + } +}; + +// Generates POLYGON geometry strings. +class PolygonInputGenerator : public GeometryInputGenerator { + public: + PolygonInputGenerator( + GeometryDistribution distribution, + size_t seed, + double nullRatio) + : GeometryInputGenerator(distribution, seed, nullRatio) {} + + variant generate() override { + if (fuzzer::coinToss(rng_, nullRatio_)) { + return variant::null(type_->kind()); + } + auto [centerX, centerY] = generateCoordinates(); + return fmt::format( + "POLYGON (({} {}, {} {}, {} {}, {} {}, {} {}))", + centerX - kPolygonSize, + centerY - kPolygonSize, + centerX + kPolygonSize, + centerY - kPolygonSize, + centerX + kPolygonSize, + centerY + kPolygonSize, + centerX - kPolygonSize, + centerY + kPolygonSize, + centerX - kPolygonSize, + centerY - kPolygonSize); + } +}; + +// Generates LINESTRING geometry strings. +class LineStringInputGenerator : public GeometryInputGenerator { + public: + LineStringInputGenerator( + GeometryDistribution distribution, + size_t seed, + double nullRatio) + : GeometryInputGenerator(distribution, seed, nullRatio) {} + + variant generate() override { + if (fuzzer::coinToss(rng_, nullRatio_)) { + return variant::null(type_->kind()); + } + auto [x1, y1] = generateCoordinates(); + double x2 = x1 + kPolygonSize; + double y2 = y1 + kPolygonSize; + return fmt::format("LINESTRING ({} {}, {} {})", x1, y1, x2, y2); + } +}; + +class SpatialJoinFuzzer { + public: + explicit SpatialJoinFuzzer(size_t initialSeed); + + void go(); + + private: + static VectorFuzzer::Options getFuzzerOptions() { + VectorFuzzer::Options opts; + opts.vectorSize = FLAGS_batch_size; + opts.stringVariableLength = true; + opts.stringLength = 100; + opts.nullRatio = FLAGS_null_ratio; + return opts; + } + + void seed(size_t seed) { + currentSeed_ = seed; + vectorFuzzer_.reSeed(seed); + rng_.seed(currentSeed_); + } + + void reSeed() { + seed(rng_()); + } + + // Randomly pick a join type supported by SpatialJoin. + core::JoinType pickJoinType(); + + // Randomly pick a spatial predicate function. + std::string pickSpatialPredicate(); + + // Randomly pick a geometry distribution pattern. + GeometryDistribution pickDistribution(); + + // Runs one test iteration from query plans generation, execution and result + // verification. + void verify(core::JoinType joinType); + + // Creates a vector of POINT geometries with specified distribution. + VectorPtr makePointVector(int32_t size, GeometryDistribution distribution); + + // Creates a vector of POLYGON geometries with specified distribution. + VectorPtr makePolygonVector(int32_t size, GeometryDistribution distribution); + + // Creates a vector of LINESTRING geometries with specified distribution. + VectorPtr makeLineStringVector( + int32_t size, + GeometryDistribution distribution); + + // Returns randomly generated probe input with geometry columns (as WKT + // strings). + std::vector generateProbeInput( + GeometryDistribution distribution); + + // Same as generateProbeInput() but copies over 10% of the input to ensure + // some matches during joining. Also generates an empty input with a 10% + // chance. + std::vector generateBuildInput( + const std::vector& probeInput, + GeometryDistribution distribution); + + // Executes a plan and returns the result. + RowVectorPtr execute(const core::PlanNodePtr& plan); + + int32_t randInt(int32_t min, int32_t max) { + return boost::random::uniform_int_distribution(min, max)(rng_); + } + + const std::shared_ptr pool_{ + memory::memoryManager()->addLeafPool()}; + std::mt19937 rng_; + size_t currentSeed_{0}; + + VectorFuzzer vectorFuzzer_; + + struct { + size_t numIterations{0}; + } stats_; +}; + +SpatialJoinFuzzer::SpatialJoinFuzzer(size_t initialSeed) + : vectorFuzzer_{getFuzzerOptions(), pool_.get()} { + filesystems::registerLocalFileSystem(); + seed(initialSeed); +} + +template +bool isDone(size_t i, T startTime) { + if (FLAGS_duration_sec > 0) { + std::chrono::duration elapsed = + std::chrono::system_clock::now() - startTime; + return elapsed.count() >= FLAGS_duration_sec; + } + return i >= FLAGS_steps; +} + +core::JoinType SpatialJoinFuzzer::pickJoinType() { + // SpatialJoin only supports INNER and LEFT join types. + static std::vector kJoinTypes = { + core::JoinType::kInner, core::JoinType::kLeft}; + + const size_t idx = randInt(0, kJoinTypes.size() - 1); + return kJoinTypes[idx]; +} + +std::string SpatialJoinFuzzer::pickSpatialPredicate() { + // Common spatial predicates supported by spatial joins. + static std::vector kPredicates = { + "ST_Intersects", + "ST_Contains", + "ST_Within", + "ST_Distance", + "ST_Overlaps", + "ST_Crosses", + "ST_Touches", + "ST_Equals"}; + + const size_t idx = randInt(0, kPredicates.size() - 1); + return kPredicates[idx]; +} + +GeometryDistribution SpatialJoinFuzzer::pickDistribution() { + static std::vector kDistributions = { + GeometryDistribution::kUniform, + GeometryDistribution::kClustered, + GeometryDistribution::kSparse}; + + const size_t idx = randInt(0, kDistributions.size() - 1); + return kDistributions[idx]; +} + +VectorPtr SpatialJoinFuzzer::makePointVector( + int32_t size, + GeometryDistribution distribution) { + auto generator = std::make_shared( + distribution, currentSeed_, getFuzzerOptions().nullRatio); + return vectorFuzzer_.fuzzFlat(VARCHAR(), size, generator); +} + +VectorPtr SpatialJoinFuzzer::makePolygonVector( + int32_t size, + GeometryDistribution distribution) { + auto generator = std::make_shared( + distribution, currentSeed_, getFuzzerOptions().nullRatio); + return vectorFuzzer_.fuzzFlat(VARCHAR(), size, generator); +} + +VectorPtr SpatialJoinFuzzer::makeLineStringVector( + int32_t size, + GeometryDistribution distribution) { + auto generator = std::make_shared( + distribution, currentSeed_, getFuzzerOptions().nullRatio); + return vectorFuzzer_.fuzzFlat(VARCHAR(), size, generator); +} + +std::vector SpatialJoinFuzzer::generateProbeInput( + GeometryDistribution distribution) { + std::vector input; + + const int32_t numRows = FLAGS_batch_size * FLAGS_num_batches; + const int32_t batchSize = FLAGS_batch_size; + const int32_t numBatches = FLAGS_num_batches; + + // Randomly pick geometry type for probe side. + const int geometryType = randInt(0, 2); + + for (int32_t i = 0; i < numBatches; ++i) { + int32_t currentBatchSize = std::min(batchSize, numRows - (i * batchSize)); + + VectorPtr geomVector; + if (geometryType == 0) { + geomVector = makePointVector(currentBatchSize, distribution); + } else if (geometryType == 1) { + geomVector = makePolygonVector(currentBatchSize, distribution); + } else { + geomVector = makeLineStringVector(currentBatchSize, distribution); + } + + auto idVector = vectorFuzzer_.fuzzFlat(BIGINT(), currentBatchSize); + auto rowType = ROW( + {"probe_id", "probe_geom_wkt"}, {idVector->type(), geomVector->type()}); + auto rowVector = std::make_shared( + pool_.get(), + rowType, + nullptr, + currentBatchSize, + std::vector{idVector, geomVector}); + input.push_back(rowVector); + } + + return input; +} + +std::vector SpatialJoinFuzzer::generateBuildInput( + const std::vector& probeInput, + GeometryDistribution distribution) { + std::vector input; + + // 1 in 10 times use empty build. + if (vectorFuzzer_.coinToss(0.1)) { + auto rowType = ROW({"build_id", "build_geom_wkt"}, {BIGINT(), VARCHAR()}); + auto rowVector = std::make_shared( + pool_.get(), + rowType, + nullptr, + 0, + std::vector{ + vectorFuzzer_.fuzzFlat(BIGINT(), 0), + vectorFuzzer_.fuzzFlat(VARCHAR(), 0)}); + return {rowVector}; + } + + // Randomly pick geometry type for build side. + const int geometryType = randInt(0, 2); + + for (const auto& probe : probeInput) { + auto numRows = 1 + probe->size() / 8; + + VectorPtr geomVector; + if (geometryType == 0) { + geomVector = makePointVector(numRows, distribution); + } else if (geometryType == 1) { + geomVector = makePolygonVector(numRows, distribution); + } else { + geomVector = makeLineStringVector(numRows, distribution); + } + + auto idVector = vectorFuzzer_.fuzzFlat(BIGINT(), numRows); + + // To ensure some matches, copy some geometries from probe side. + if (probe->size() > 0) { + std::vector rowNumbers(numRows); + SelectivityVector rows(numRows, false); + for (vector_size_t i = 0; i < numRows; ++i) { + if (vectorFuzzer_.coinToss(0.3)) { + rowNumbers[i] = randInt(0, probe->size() - 1); + rows.setValid(i, true); + } + } + + // Copy geometry from probe to build. + auto probeGeom = probe->childAt(1); + geomVector->copy(probeGeom.get(), rows, rowNumbers.data()); + } + + auto rowType = ROW( + {"build_id", "build_geom_wkt"}, {idVector->type(), geomVector->type()}); + auto rowVector = std::make_shared( + pool_.get(), + rowType, + nullptr, + numRows, + std::vector{idVector, geomVector}); + input.push_back(rowVector); + } + + return input; +} + +RowVectorPtr SpatialJoinFuzzer::execute(const core::PlanNodePtr& plan) { + LOG(INFO) << "Executing query plan: " << std::endl + << plan->toString(true, true); + + return test::AssertQueryBuilder(plan).copyResults(pool_.get()); +} + +void SpatialJoinFuzzer::verify(core::JoinType joinType) { + const auto distribution = pickDistribution(); + const auto predicate = pickSpatialPredicate(); + + // Generate test data (WKT strings). + auto probeInput = generateProbeInput(distribution); + auto buildInput = generateBuildInput(probeInput, distribution); + + if (VLOG_IS_ON(1)) { + VLOG(1) << "Probe input: " << probeInput[0]->toString(); + for (const auto& v : probeInput) { + VLOG(1) << std::endl << v->toString(0, v->size()); + } + + VLOG(1) << "Build input: " << buildInput[0]->toString(); + for (const auto& v : buildInput) { + VLOG(1) << std::endl << v->toString(0, v->size()); + } + } + + // Build spatial join plan with geometry conversion as part of the plan. + const auto planNodeIdGenerator = + std::make_shared(); + + std::string joinCondition; + std::optional radiusColumn; + std::optional radiusExpression; + if (predicate == "ST_Distance") { + // ST_Distance returns a value, use it with a threshold. + // For ST_Distance, we use a radius column instead of embedding the + // threshold in the join condition. + joinCondition = + fmt::format("{}(probe_geom, build_geom) < radius", predicate); + radiusColumn = "radius"; + radiusExpression = fmt::format( + "CAST({} AS DOUBLE) AS radius", + static_cast(randInt(0, kMaxRadius))); + } else { + // Other predicates return boolean. + joinCondition = fmt::format("{}(probe_geom, build_geom)", predicate); + } + + // Create SpatialJoin plan with geometry conversion projections. + auto spatialJoinPlan = + test::PlanBuilder(planNodeIdGenerator) + .values(probeInput) + // Convert probe WKT strings to Geometry + .project( + {"probe_id", + "ST_GeometryFromText(probe_geom_wkt) AS probe_geom", + "probe_geom_wkt"}) + .spatialJoin( + test::PlanBuilder(planNodeIdGenerator) + .values(buildInput) + // Convert build WKT strings to Geometry + .project( + radiusColumn.has_value() + ? std::vector< + std:: + string>{"build_id", "ST_GeometryFromText(build_geom_wkt) AS build_geom", "build_geom_wkt", radiusExpression.value()} + : std::vector< + std:: + string>{"build_id", "ST_GeometryFromText(build_geom_wkt) AS build_geom", "build_geom_wkt"}) + .planNode(), + joinCondition, + "probe_geom", + "build_geom", + radiusColumn, + {"probe_id", "probe_geom_wkt", "build_id", "build_geom_wkt"}, + joinType) + .planNode(); + + // Create equivalent NestedLoopJoin plan for comparison. + auto nestedLoopJoinPlan = + test::PlanBuilder(planNodeIdGenerator) + .values(probeInput) + // Convert probe WKT strings to Geometry + .project( + {"probe_id", + "ST_GeometryFromText(probe_geom_wkt) AS probe_geom", + "probe_geom_wkt"}) + .nestedLoopJoin( + test::PlanBuilder(planNodeIdGenerator) + .values(buildInput) + // Convert build WKT strings to Geometry + .project( + radiusColumn.has_value() + ? std::vector< + std:: + string>{"build_id", "ST_GeometryFromText(build_geom_wkt) AS build_geom", "build_geom_wkt", radiusExpression.value()} + : std::vector< + std:: + string>{"build_id", "ST_GeometryFromText(build_geom_wkt) AS build_geom", "build_geom_wkt"}) + .planNode(), + {joinCondition}, + {"probe_id", "probe_geom_wkt", "build_id", "build_geom_wkt"}, + joinType) + .planNode(); + + LOG(INFO) << "Executing SpatialJoin plan..."; + const auto spatialJoinResult = execute(spatialJoinPlan); + + LOG(INFO) << "Executing NestedLoopJoin plan..."; + const auto nestedLoopJoinResult = execute(nestedLoopJoinPlan); + + // Compare SpatialJoin vs NestedLoopJoin results. + auto result = + test::assertEqualResults({nestedLoopJoinResult}, {spatialJoinResult}); + VELOX_CHECK(result, "SpatialJoin and NestedLoopJoin results don't match"); + + LOG(INFO) << "SpatialJoin matches NestedLoopJoin."; +} + +void SpatialJoinFuzzer::go() { + VELOX_USER_CHECK( + FLAGS_steps > 0 || FLAGS_duration_sec > 0, + "Either --steps or --duration_sec needs to be greater than zero."); + VELOX_USER_CHECK_GE(FLAGS_batch_size, 10, "Batch size must be at least 10."); + + const auto startTime = std::chrono::system_clock::now(); + + while (!isDone(stats_.numIterations, startTime)) { + LOG(WARNING) << "==============================> Started iteration " + << stats_.numIterations << " (seed: " << currentSeed_ << ")"; + + // Pick join type. + const auto joinType = pickJoinType(); + + verify(joinType); + + LOG(WARNING) << "==============================> Done with iteration " + << stats_.numIterations; + + reSeed(); + ++stats_.numIterations; + } + + LOG(INFO) << "Total iterations: " << stats_.numIterations; +} + +} // namespace + +void spatialJoinFuzzer(size_t seed) { + SpatialJoinFuzzer(seed).go(); +} + +} // namespace facebook::velox::exec diff --git a/velox/exec/fuzzer/SpatialJoinFuzzer.h b/velox/exec/fuzzer/SpatialJoinFuzzer.h new file mode 100644 index 000000000000..92bdd967213b --- /dev/null +++ b/velox/exec/fuzzer/SpatialJoinFuzzer.h @@ -0,0 +1,36 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +namespace facebook::velox::exec { + +/// Runs the fuzzer for SpatialJoin operator. Generates random geometry data +/// and spatial join plans with various predicates (ST_Intersects, ST_Contains, +/// ST_Within, ST_Distance), comparing SpatialJoin results against +/// NestedLoopJoin as the reference implementation. +/// +/// The fuzzer tests: +/// - Different spatial predicates +/// - INNER and LEFT join types (the only types supported by SpatialJoin) +/// - Different geometry types (POINT, POLYGON, LINESTRING) +/// - Various data distributions (uniform, clustered, sparse) +/// - Different sizes of probe and build sides +/// - Plans with and without filters +/// - Different output column projections +void spatialJoinFuzzer(size_t seed); + +} // namespace facebook::velox::exec diff --git a/velox/exec/fuzzer/SpatialJoinFuzzerRunner.cpp b/velox/exec/fuzzer/SpatialJoinFuzzerRunner.cpp new file mode 100644 index 000000000000..943c872290c6 --- /dev/null +++ b/velox/exec/fuzzer/SpatialJoinFuzzerRunner.cpp @@ -0,0 +1,64 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include "velox/common/file/FileSystems.h" +#include "velox/common/memory/Memory.h" +#include "velox/exec/fuzzer/SpatialJoinFuzzer.h" +#include "velox/functions/prestosql/registration/RegistrationFunctions.h" +#include "velox/parse/TypeResolver.h" +#include "velox/serializers/PrestoSerializer.h" + +DEFINE_int64( + seed, + 0, + "Initial seed for random number generator used to reproduce previous " + "results (0 means start with random seed)."); + +using namespace facebook::velox; + +int main(int argc, char** argv) { + folly::Init init(&argc, &argv); + + // Initialize memory system. + memory::initializeMemoryManager(memory::MemoryManager::Options{}); + auto pool = memory::memoryManager()->addLeafPool(); + + // Register file systems. + filesystems::registerLocalFileSystem(); + + // Register Presto functions. + functions::prestosql::registerAllScalarFunctions(); + + // Register type resolver. + parse::registerTypeResolver(); + + // Register serializers. + if (!isRegisteredNamedVectorSerde(VectorSerde::Kind::kPresto)) { + serializer::presto::PrestoVectorSerde::registerNamedVectorSerde(); + } + + // Determine the seed. + size_t seed = FLAGS_seed == 0 ? std::random_device{}() : FLAGS_seed; + LOG(INFO) << "Using seed: " << seed; + + // Run the spatial join fuzzer. + exec::spatialJoinFuzzer(seed); + + return 0; +} diff --git a/velox/exec/fuzzer/TopNRowNumberFuzzer.cpp b/velox/exec/fuzzer/TopNRowNumberFuzzer.cpp index 7d658aa68885..34b33d330ee9 100644 --- a/velox/exec/fuzzer/TopNRowNumberFuzzer.cpp +++ b/velox/exec/fuzzer/TopNRowNumberFuzzer.cpp @@ -41,6 +41,8 @@ class TopNRowNumberFuzzer : public RowNumberFuzzerBase { std::pair, std::vector> generateKeys( const std::string& prefix); + std::string generateRankFunction(); + std::vector generateInput( const std::vector& keyNames, const std::vector& keyTypes, @@ -93,6 +95,19 @@ TopNRowNumberFuzzer::generateKeys(const std::string& prefix) { return std::make_pair(keys, types); } +std::string TopNRowNumberFuzzer::generateRankFunction() { + int32_t rankFunction = randInt(0, 2); + switch (rankFunction) { + case 0: + return "row_number"; + case 1: + return "rank"; + case 2: + return "dense_rank"; + } + return "row_number"; +} + std::vector TopNRowNumberFuzzer::generateInput( const std::vector& keyNames, const std::vector& keyTypes, @@ -153,12 +168,14 @@ std::vector TopNRowNumberFuzzer::generateInput( // values. This is done to introduce some repetition of key values for // windowing. auto baseVector = vectorFuzzer_.fuzz(keyTypes[i], numPartitions); - children.push_back(BaseVector::wrapInDictionary( - partitionNulls, partitionIndices, size, baseVector)); + children.push_back( + BaseVector::wrapInDictionary( + partitionNulls, partitionIndices, size, baseVector)); } else if (sortingKeySet.find(keyNames[i]) != sortingKeySet.end()) { auto baseVector = vectorFuzzer_.fuzz(keyTypes[i], numPeerGroups); - children.push_back(BaseVector::wrapInDictionary( - sortingNulls, sortingIndices, size, baseVector)); + children.push_back( + BaseVector::wrapInDictionary( + sortingNulls, sortingIndices, size, baseVector)); } else { children.push_back(vectorFuzzer_.fuzz(keyTypes[i], size)); } @@ -182,11 +199,13 @@ TopNRowNumberFuzzer::makeDefaultPlan( projectFields.emplace_back("row_number"); int32_t limit = randInt(1, FLAGS_batch_size); - auto plan = test::PlanBuilder() - .values(input) - .topNRowNumber(partitionKeys, sortKeys, limit, true) - .project(projectFields) - .planNode(); + auto plan = + test::PlanBuilder() + .values(input) + .topNRank( + generateRankFunction(), partitionKeys, sortKeys, limit, true) + .project(projectFields) + .planNode(); return std::make_pair(PlanWithSplits{std::move(plan)}, limit); } @@ -203,11 +222,13 @@ RowNumberFuzzerBase::PlanWithSplits TopNRowNumberFuzzer::makePlanWithTableScan( projectFields.emplace_back("row_number"); auto planNodeIdGenerator = std::make_shared(); - auto plan = test::PlanBuilder(planNodeIdGenerator) - .tableScan(asRowType(input[0]->type())) - .topNRowNumber(partitionKeys, sortKeys, limit, true) - .project(projectFields) - .planNode(); + auto plan = + test::PlanBuilder(planNodeIdGenerator) + .tableScan(asRowType(input[0]->type())) + .topNRank( + generateRankFunction(), partitionKeys, sortKeys, limit, true) + .project(projectFields) + .planNode(); const std::vector splits = test::makeSplits( input, fmt::format("{}/topn_row_number", tableDir), writerPool_); diff --git a/velox/exec/fuzzer/VeloxQueryRunner.cpp b/velox/exec/fuzzer/VeloxQueryRunner.cpp new file mode 100644 index 000000000000..af938741db0e --- /dev/null +++ b/velox/exec/fuzzer/VeloxQueryRunner.cpp @@ -0,0 +1,244 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/fuzzer/VeloxQueryRunner.h" + +#include +#include +#include +#include +#include +#include "velox/core/PlanNode.h" +#include "velox/exec/fuzzer/if/gen-cpp2/LocalRunnerService.h" +#include "velox/exec/tests/utils/QueryAssertions.h" +#include "velox/functions/prestosql/types/BingTileType.h" +#include "velox/functions/prestosql/types/GeometryType.h" +#include "velox/functions/prestosql/types/IPAddressType.h" +#include "velox/functions/prestosql/types/IPPrefixType.h" +#include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h" +#include "velox/functions/prestosql/types/UuidType.h" +#include "velox/serializers/PrestoSerializer.h" +#include "velox/type/parser/TypeParser.h" + +using namespace facebook::velox::runner; + +namespace facebook::velox::exec::test { + +namespace { + +RowTypePtr parseBatchRowType(Batch batch) { + std::vector names; + std::vector types; + + for (const auto& name : *batch.columnNames()) { + names.push_back(name); + } + + // Clean up type strings and format according to TypeParser::parseType + // expectation. Input types are serialized by type()->asRow() for Thrift + // struct by LocalRunnerService in the format of SOME_COMPLEX_TYPE (note the '<' and '>'). And in the case of complex type ROW, + // types follow column name and a semicolon: as an example, ROW. We need to change these particular character choices to + // paranthesis (in the case of the angled brackets) and spaces (in the case of + // the semicolon). As an example, + // MAP -> MAP(VARCHAR, TIMESTAMP) + // ROW -> ROW(f0 TYPE_1, f1 TYPE_2, etc.) + // as expected by TypeParser::parseType. Without this cleanup, parse will + // crash and fuzzer will fail. + for (const auto& typeString : *batch.columnTypes()) { + auto parsedTypeString = typeString; + std::replace(parsedTypeString.begin(), parsedTypeString.end(), '<', '('); + std::replace(parsedTypeString.begin(), parsedTypeString.end(), '>', ')'); + std::replace(parsedTypeString.begin(), parsedTypeString.end(), ':', ' '); + types.push_back(parseType(parsedTypeString)); + } + + return ROW(std::move(names), std::move(types)); +} + +std::vector deserializeBatches( + const std::vector& resultBatches, + memory::MemoryPool* pool) { + std::vector queryResults; + + auto serde = std::make_unique(); + serializer::presto::PrestoVectorSerde::PrestoOptions options; + + for (const auto& batch : resultBatches) { + VELOX_CHECK( + apache::thrift::is_non_optional_field_set_manually_or_by_serializer( + batch.serializedData_ref())); + VELOX_CHECK(!batch.serializedData()->empty()); + + // Deserialize binary data. + const auto& serializedData = *batch.serializedData(); + ByteRange byteRange{ + reinterpret_cast(const_cast(serializedData.data())), + static_cast(serializedData.length()), + 0}; + auto byteStream = std::make_unique( + std::vector{{byteRange}}); + + RowVectorPtr rowVector; + serde->deserialize( + byteStream.get(), + pool, + parseBatchRowType(batch), + &rowVector, + 0, + &options); + + VELOX_CHECK_NOT_NULL(rowVector); + queryResults.push_back(rowVector); + } + + return queryResults; +} + +std::shared_ptr> createThriftClient( + const std::string& host, + int port, + std::chrono::milliseconds timeout, + folly::EventBase* evb) { + folly::SocketAddress addr(host, port); + auto socket = folly::AsyncSocket::newSocket(evb, addr, timeout.count()); + auto channel = + apache::thrift::RocketClientChannel::newChannel(std::move(socket)); + return std::make_shared>( + std::move(channel)); +} +} // namespace + +VeloxQueryRunner::VeloxQueryRunner( + memory::MemoryPool* aggregatePool, + std::string serviceUri, + std::chrono::milliseconds timeout) + : ReferenceQueryRunner(aggregatePool), + serviceUri_(std::move(serviceUri)), + timeout_(timeout) { + pool_ = aggregatePool->addLeafChild("leaf"); + + folly::Uri uri(serviceUri_); + thriftHost_ = uri.host(); + thriftPort_ = uri.port(); +} + +const std::vector& VeloxQueryRunner::supportedScalarTypes() const { + static const std::vector kScalarTypes{ + BOOLEAN(), + TINYINT(), + SMALLINT(), + INTEGER(), + BIGINT(), + REAL(), + DOUBLE(), + VARCHAR(), + VARBINARY(), + TIMESTAMP(), + TIMESTAMP_WITH_TIME_ZONE(), + IPADDRESS(), + UUID(), + // https://github.com/facebookincubator/velox/issues/15379 (IPPREFIX) + // https://github.com/facebookincubator/velox/issues/15380 (Non-orderable + // custom types such as HYPERLOGLOG, JSON, BINGTILE, GEOMETRY, etc.) + }; + return kScalarTypes; +} + +const std::unordered_map& +VeloxQueryRunner::aggregationFunctionDataSpecs() const { + static const std::unordered_map + kAggregationFunctionDataSpecs{}; + return kAggregationFunctionDataSpecs; +} + +std::optional VeloxQueryRunner::toSql( + const core::PlanNodePtr& /*plan*/) { + // We don't need to convert to SQL for VeloxQueryRunner + // as we're sending the serialized plan directly + VELOX_FAIL("VeloxQueryRunner does not support SQL conversion"); +} + +bool VeloxQueryRunner::isConstantExprSupported( + const core::TypedExprPtr& /*expr*/) { + // Since we're using Velox directly, we support all constant expressions + return true; +} + +bool VeloxQueryRunner::isSupported( + const exec::FunctionSignature& /*signature*/) { + // Since we're using Velox directly, we support all function signatures + return true; +} + +std::vector VeloxQueryRunner::execute( + const std::string& /*sql*/) { + VELOX_FAIL("VeloxQueryRunner does not support SQL execution"); +} + +std::vector VeloxQueryRunner::execute( + const std::string& /*sql*/, + const std::string& /*sessionProperty*/) { + VELOX_FAIL("VeloxQueryRunner does not support SQL execution"); +} + +std::pair< + std::optional>>, + ReferenceQueryErrorCode> +VeloxQueryRunner::execute(const core::PlanNodePtr& plan) { + auto serializedPlan = serializePlan(plan); + auto queryId = fmt::format("velox_local_query_runner_{}", rand()); + + auto client = + createThriftClient(thriftHost_, thriftPort_, timeout_, &eventBase_); + + // Create the request + ExecutePlanRequest request; + request.serializedPlan() = serializedPlan; + request.queryId() = queryId; + request.numWorkers() = 4; // Default value + request.numDrivers() = 2; // Default value + + // Send the request + ExecutePlanResponse response; + try { + client->sync_execute(response, request); + } catch (const std::exception& e) { + VELOX_FAIL("Thrift request failed: {}", e.what()); + } + + // Handle the response + if (*response.success()) { + LOG(INFO) << "Reference eval succeeded."; + return std::make_pair( + exec::test::materialize( + deserializeBatches(*response.results(), pool_.get())), + ReferenceQueryErrorCode::kSuccess); + } else { + LOG(INFO) << "Reference eval failed."; + return std::make_pair( + std::nullopt, ReferenceQueryErrorCode::kReferenceQueryFail); + } +} + +std::string VeloxQueryRunner::serializePlan(const core::PlanNodePtr& plan) { + // Serialize the plan to JSON + folly::dynamic serializedPlan = plan->serialize(); + return folly::toJson(serializedPlan); +} + +} // namespace facebook::velox::exec::test diff --git a/velox/exec/fuzzer/VeloxQueryRunner.h b/velox/exec/fuzzer/VeloxQueryRunner.h new file mode 100644 index 000000000000..67826c62bced --- /dev/null +++ b/velox/exec/fuzzer/VeloxQueryRunner.h @@ -0,0 +1,77 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include "velox/exec/fuzzer/ReferenceQueryRunner.h" +#include "velox/vector/ComplexVector.h" + +namespace facebook::velox::exec::test { + +class VeloxQueryRunner : public ReferenceQueryRunner { + public: + /// @param serviceUri Thrift URI of the LocalRunnerService. + /// @param timeout Timeout in milliseconds of a request. + VeloxQueryRunner( + memory::MemoryPool* aggregatePool, + std::string serviceUri, + std::chrono::milliseconds timeout); + + RunnerType runnerType() const override { + return RunnerType::kVeloxQueryRunner; + } + + const std::vector& supportedScalarTypes() const override; + + const std::unordered_map& + aggregationFunctionDataSpecs() const override; + + std::optional toSql(const core::PlanNodePtr& plan) override; + + bool isConstantExprSupported(const core::TypedExprPtr& expr) override; + + bool isSupported(const exec::FunctionSignature& signature) override; + + std::pair< + std::optional>>, + ReferenceQueryErrorCode> + execute(const core::PlanNodePtr& plan) override; + + bool supportsVeloxVectorResults() const override { + return true; + } + + std::vector execute(const std::string& sql) override; + + std::vector execute( + const std::string& sql, + const std::string& sessionProperty) override; + + private: + // Serializes the plan node to JSON string + std::string serializePlan(const core::PlanNodePtr& plan); + + std::string serviceUri_; + std::chrono::milliseconds timeout_; + folly::EventBase eventBase_; + std::shared_ptr pool_; + + // Thrift-specific members + std::string thriftHost_; + int thriftPort_{9091}; +}; + +} // namespace facebook::velox::exec::test diff --git a/velox/exec/fuzzer/WindowFuzzer.cpp b/velox/exec/fuzzer/WindowFuzzer.cpp index e782d263cc67..28792adeb277 100644 --- a/velox/exec/fuzzer/WindowFuzzer.cpp +++ b/velox/exec/fuzzer/WindowFuzzer.cpp @@ -151,12 +151,12 @@ std::string WindowFuzzer::frameClauseString( const std::vector& kRangeOffsetColumns) { auto frameType = [&](const core::WindowNode::BoundType boundType, bool isStartBound) -> std::string { - const auto boundTypeString = core::WindowNode::boundTypeName(boundType); + const auto boundTypeString = core::WindowNode::toName(boundType); switch (boundType) { case core::WindowNode::BoundType::kUnboundedPreceding: case core::WindowNode::BoundType::kCurrentRow: case core::WindowNode::BoundType::kUnboundedFollowing: - return boundTypeString; + return std::string(boundTypeString); case core::WindowNode::BoundType::kPreceding: case core::WindowNode::BoundType::kFollowing: { std::string frameBound; @@ -177,7 +177,7 @@ std::string WindowFuzzer::frameClauseString( return fmt::format( " {} BETWEEN {} AND {}", - core::WindowNode::windowTypeName(frameMetadata.windowType), + core::WindowNode::toName(frameMetadata.windowType), frameType(frameMetadata.startBoundType, true), frameType(frameMetadata.endBoundType, false)); } @@ -218,7 +218,7 @@ const T WindowFuzzer::genOffsetAtIdx( "Offset cannot be generated: orderBy key type: {}, sortOrder ascending {}, frameBoundType {}", CppToType::name, sortOrder.toString(), - core::WindowNode::boundTypeName(frameBoundType)); + core::WindowNode::toName(frameBoundType)); return T{}; } @@ -609,8 +609,9 @@ void WindowFuzzer::testAlternativePlans( allKeys.emplace_back(key + " NULLS FIRST"); } for (const auto& keyAndOrder : sortingKeysAndOrders) { - allKeys.emplace_back(fmt::format( - "{} {}", keyAndOrder.key_, keyAndOrder.sortOrder_.toString())); + allKeys.emplace_back( + fmt::format( + "{} {}", keyAndOrder.key_, keyAndOrder.sortOrder_.toString())); } // Streaming window from values. @@ -793,7 +794,8 @@ bool WindowFuzzer::verifyWindow( expectedResult.value(), plan->outputType(), {resultOrError.result}), - "Velox and reference DB results don't match"); + "Velox and reference DB results don't match, plan: {}", + plan->toString(true, true)); LOG(INFO) << "Verified results against reference DB"; } } diff --git a/velox/exec/fuzzer/WindowFuzzerRunner.h b/velox/exec/fuzzer/WindowFuzzerRunner.h index d24536ff8175..e71ae53161c2 100644 --- a/velox/exec/fuzzer/WindowFuzzerRunner.h +++ b/velox/exec/fuzzer/WindowFuzzerRunner.h @@ -71,7 +71,7 @@ class WindowFuzzerRunner { filteredWindowSignatures.empty()) { LOG(ERROR) << "No function left after filtering using 'only' and 'skip' lists."; - exit(1); + return 1; } facebook::velox::parse::registerTypeResolver(); diff --git a/velox/exec/fuzzer/WriterFuzzer.cpp b/velox/exec/fuzzer/WriterFuzzer.cpp index 79109f9dfa84..5dcacca9e180 100644 --- a/velox/exec/fuzzer/WriterFuzzer.cpp +++ b/velox/exec/fuzzer/WriterFuzzer.cpp @@ -16,6 +16,7 @@ #include "velox/exec/fuzzer/WriterFuzzer.h" #include +#include #include #include @@ -149,8 +150,7 @@ class WriterFuzzer { const std::shared_ptr& outputDirectoryPath); // Generates table column handles based on table column properties - std::unordered_map> - getTableColumnHandles( + connector::ColumnHandleMap getTableColumnHandles( const std::vector& names, const std::vector& types, int32_t partitionOffset, @@ -248,7 +248,7 @@ class WriterFuzzer { }; // Supported partition key column types - // According to VectorHasher::typeKindSupportsValueIds and + // According to VectorHasher::typeSupportsValueIds and // https://github.com/prestodb/presto/blob/10143be627beb2c61aba5b3d36af473d2a8ef65e/presto-hive/src/main/java/com/facebook/presto/hive/HiveUtil.java#L593 const std::vector kPartitionKeyTypes_{ BOOLEAN(), @@ -256,7 +256,8 @@ class WriterFuzzer { SMALLINT(), INTEGER(), BIGINT(), - VARCHAR()}; + VARCHAR(), + TIMESTAMP()}; const std::shared_ptr faultyFs_ = std::dynamic_pointer_cast( @@ -364,11 +365,12 @@ void WriterFuzzer::go() { sortColumnOffset -= offset; sortBy.reserve(sortColumns.size()); for (const auto& sortByColumn : sortColumns) { - sortBy.push_back(std::make_shared( - sortByColumn, - kSortOrderTypes_.at( - boost::random::uniform_int_distribution( - 0, 1)(rng_)))); + sortBy.push_back( + std::make_shared( + sortByColumn, + kSortOrderTypes_.at( + boost::random::uniform_int_distribution( + 0, 1)(rng_)))); } } } @@ -483,8 +485,9 @@ std::vector WriterFuzzer::generateInputData( partitionValues.at(j - partitionOffset), size)); } } - input.push_back(std::make_shared( - pool_.get(), inputType, nullptr, size, std::move(children))); + input.push_back( + std::make_shared( + pool_.get(), inputType, nullptr, size, std::move(children))); } return input; @@ -635,14 +638,12 @@ void WriterFuzzer::verifyWriter( LOG(INFO) << "Verified results against reference DB"; } -std::unordered_map> -WriterFuzzer::getTableColumnHandles( +connector::ColumnHandleMap WriterFuzzer::getTableColumnHandles( const std::vector& names, const std::vector& types, const int32_t partitionOffset, const int32_t bucketCount) { - std::unordered_map> - columnHandle; + connector::ColumnHandleMap columnHandle; for (int i = 0; i < names.size(); ++i) { HiveColumnHandle::ColumnType columnType; if (i < partitionOffset) { @@ -706,8 +707,9 @@ RowVectorPtr WriterFuzzer::veloxToPrestoResult(const RowVectorPtr& result) { std::string WriterFuzzer::getReferenceOutputDirectoryPath(int32_t layers) { auto filePath = referenceQueryRunner_->execute("SELECT \"$path\" FROM tmp_write"); + auto stringView = extractSingleValue(filePath); auto tableDirectoryPath = - fs::path(extractSingleValue(filePath)).parent_path(); + fs::path(std::string_view(stringView)).parent_path(); while (layers-- > 0) { tableDirectoryPath = tableDirectoryPath.parent_path(); } @@ -745,11 +747,30 @@ void WriterFuzzer::comparePartitionAndBucket( // If not bucketed, only verify if their partition names match VELOX_CHECK( partitionNames == referencePartitionNames, - "Velox and reference DB output partitions don't match"); - } else { - VELOX_CHECK( - partitionNameAndFileCount == referencedPartitionNameAndFileCount, - "Velox and reference DB output partition and bucket don't match"); + "Velox and reference DB output partitions don't match. Velox: [{}], Presto: [{}]", + fmt::join(partitionNames, ", "), + fmt::join(referencePartitionNames, ", ")); + } else if (partitionNameAndFileCount != referencedPartitionNameAndFileCount) { + std::vector partitionNameAndFileCountStrs; + std::vector referencedPartitionNameAndFileCountStrs; + + partitionNameAndFileCountStrs.reserve(partitionNameAndFileCount.size()); + referencedPartitionNameAndFileCountStrs.reserve( + referencedPartitionNameAndFileCount.size()); + + for (const auto& p : partitionNameAndFileCount) { + partitionNameAndFileCountStrs.push_back( + fmt::format("'{}': {}", p.first, p.second)); + } + for (const auto& p : referencedPartitionNameAndFileCount) { + referencedPartitionNameAndFileCountStrs.push_back( + fmt::format("'{}': {}", p.first, p.second)); + } + + VELOX_FAIL( + "Velox and reference DB output partition and bucket don't match. Velox: {{{}}}, Presto: {{{}}}", + fmt::join(partitionNameAndFileCountStrs, ", "), + fmt::join(referencedPartitionNameAndFileCountStrs, ", ")); } } diff --git a/velox/exec/fuzzer/WriterFuzzerRunner.h b/velox/exec/fuzzer/WriterFuzzerRunner.h index 527b11bbef1f..fbe76685c898 100644 --- a/velox/exec/fuzzer/WriterFuzzerRunner.h +++ b/velox/exec/fuzzer/WriterFuzzerRunner.h @@ -77,15 +77,11 @@ class WriterFuzzerRunner { std::unique_ptr referenceQueryRunner) { filesystems::registerLocalFileSystem(); tests::utils::registerFaultyFileSystem(); - connector::registerConnectorFactory( - std::make_shared()); - auto hiveConnector = - connector::getConnectorFactory( - connector::hive::HiveConnectorFactory::kHiveConnectorName) - ->newConnector( - kHiveConnectorId, - std::make_shared( - std::unordered_map())); + connector::hive::HiveConnectorFactory factory; + auto hiveConnector = factory.newConnector( + kHiveConnectorId, + std::make_shared( + std::unordered_map())); connector::registerConnector(hiveConnector); dwrf::registerDwrfReaderFactory(); dwrf::registerDwrfWriterFactory(); diff --git a/velox/exec/fuzzer/if/LocalRunnerService.thrift b/velox/exec/fuzzer/if/LocalRunnerService.thrift new file mode 100644 index 000000000000..5e28b5f01518 --- /dev/null +++ b/velox/exec/fuzzer/if/LocalRunnerService.thrift @@ -0,0 +1,58 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// This file defines a Thrift service for executing Velox query plans remotely. +// Results are returned using Presto's binary serialization format for efficient +// data transfer. + +namespace cpp2 facebook.velox.runner + +// Represents a batch of rows using Presto's binary serialization format. +// The serialized data can be deserialized using PrestoVectorSerde to reconstruct +// the original RowVector. +struct Batch { + // Binary serialized RowVector data in Presto format + 1: binary serializedData; + // Column names in the RowVector + 2: list columnNames; + // Column type strings in the RowVector + 3: list columnTypes; +} + +// Request to execute a serialized Velox query plan. +struct ExecutePlanRequest { + 1: string serializedPlan; + 2: string queryId; + 3: i32 numWorkers = 4; + 4: i32 numDrivers = 2; +} + +// Response from executing a query plan. +struct ExecutePlanResponse { + 1: list results; + 2: string output; + 3: bool success; + 4: optional string errorMessage; +} + +// Service for executing Velox query plans locally. +// This service enables remote execution of serialized query plans with +// configurable parallelism, returning results in a structured format. +service LocalRunnerService { + // Inputs a Thrift request and executes a serialized Velox query plan and + // returns the results as a Thrift response. + ExecutePlanResponse execute(1: ExecutePlanRequest request); +} diff --git a/velox/exec/fuzzer/tests/CMakeLists.txt b/velox/exec/fuzzer/tests/CMakeLists.txt index 3179fad5cfc3..a77e6691a5cc 100644 --- a/velox/exec/fuzzer/tests/CMakeLists.txt +++ b/velox/exec/fuzzer/tests/CMakeLists.txt @@ -15,5 +15,25 @@ add_executable(presto_sql_test PrestoSqlTest.cpp) add_test(presto_sql_test presto_sql_test) -target_link_libraries( - presto_sql_test velox_fuzzer_util velox_presto_types) +target_link_libraries(presto_sql_test velox_fuzzer_util velox_presto_types) + +# LocalRunnerService Test (requires FBThrift support) +if(VELOX_ENABLE_REMOTE_FUNCTIONS) + add_executable(local_runner_service_test LocalRunnerServiceTest.cpp) + add_test(local_runner_service_test local_runner_service_test) + + target_link_libraries( + local_runner_service_test + velox_local_runner_service_lib + local_runner_service_thrift + velox_core + velox_type + velox_functions_prestosql + velox_functions_test_lib + velox_vector_test_lib + velox_common_base + Folly::folly + gtest + gtest_main + ) +endif() diff --git a/velox/exec/fuzzer/tests/LocalRunnerServiceTest.cpp b/velox/exec/fuzzer/tests/LocalRunnerServiceTest.cpp new file mode 100644 index 000000000000..97ec69308775 --- /dev/null +++ b/velox/exec/fuzzer/tests/LocalRunnerServiceTest.cpp @@ -0,0 +1,256 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#include "velox/core/PlanNode.h" +#include "velox/exec/fuzzer/LocalRunnerService.h" +#include "velox/exec/fuzzer/if/gen-cpp2/LocalRunnerService.h" +#include "velox/functions/prestosql/registration/RegistrationFunctions.h" +#include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h" +#include "velox/serializers/PrestoSerializer.h" +#include "velox/type/Type.h" + +using namespace facebook::velox; +using namespace facebook::velox::runner; +using namespace facebook::velox::test; + +namespace facebook::velox::fuzzer::test { +class LocalRunnerServiceTest : public functions::test::FunctionBaseTest { + protected: + void SetUp() override { + Type::registerSerDe(); + core::PlanNode::registerSerDe(); + core::ITypedExpr::registerSerDe(); + functions::prestosql::registerAllScalarFunctions(); + functions::prestosql::registerInternalFunctions(); + + createTestData(); + } + + void createTestData() { + // Create test vectors for different data types + auto rowType = ROW({ + {"bool_col", BOOLEAN()}, + {"int_col", INTEGER()}, + {"bigint_col", BIGINT()}, + {"double_col", DOUBLE()}, + {"varchar_col", VARCHAR()}, + {"timestamp_col", TIMESTAMP()}, + {"array_col", ARRAY(ARRAY(INTEGER()))}, + }); + + testRowVector_ = makeRowVector( + {"bool_col", + "int_col", + "bigint_col", + "double_col", + "varchar_col", + "timestamp_col", + "array_col"}, + { + makeFlatVector( + 10, + [](auto row) { return row % 2 == 0; }, + [](auto row) { return row % 3 == 0; }), + makeFlatVector( + 10, + [](auto row) { return row; }, + [](auto row) { return row % 3 == 0; }), + makeFlatVector( + 10, + [](auto row) { return row; }, + [](auto row) { return row % 3 == 0; }), + makeFlatVector( + 10, + [](auto row) { return row * 1.1; }, + [](auto row) { return row % 3 == 0; }), + makeFlatVector( + 10, + [](auto row) { return fmt::format("str_{}", row); }, + [](auto row) { return row % 3 == 0; }), + makeFlatVector( + 10, + [](auto row) { return facebook::velox::Timestamp(row, 0); }, + [](auto row) { return row % 3 == 0; }), + makeNestedArrayVectorFromJson( + {"[[1, 2]]", + "[[3]]", + "[[4]]", + "[[5, 6]]", + "[[7]]", + "[[8]]", + "[[9]]", + "[[10]]", + "[[11]]", + "[[12]]"}), + }); + + testRowVectorWrapped_ = makeRowVector( + {"bool_col", + "int_col", + "bigint_col", + "double_col", + "varchar_col", + "timestamp_col", + "array_col"}, + { + makeFlatVector( + 5, + [](auto row) { return row % 2 == 0; }, + [](auto row) { return row % 3 == 0; }), + wrapInDictionary( + makeIndices(5, [](auto row) { return (row * 17 + 3) % 10; }), + 5, + makeFlatVector( + 10, + [](auto row) { return row; }, + [](auto row) { return row % 3 == 0; })), + BaseVector::wrapInConstant( + 5, + 0, + makeFlatVector( + 10, + [](auto row) { return row; }, + [](auto row) { return row % 3 == 0; })), + makeFlatVector( + 5, + [](auto row) { return row * 1.1; }, + [](auto row) { return row % 3 == 0; }), + makeFlatVector( + 5, + [](auto row) { return fmt::format("str_{}", row); }, + [](auto row) { return row % 3 == 0; }), + makeFlatVector( + 5, + [](auto row) { return facebook::velox::Timestamp(row, 0); }, + [](auto row) { return row % 3 == 0; }), + makeNestedArrayVectorFromJson( + {"[[1, 2]]", + "[[3]]", + "[[4]]", + "[[5, 6]]", + "[[7]]", + "[[8]]", + "[[9]]", + "[[10]]", + "[[11]]", + "[[12]]"}), + }); + } + + RowVectorPtr testRowVector_; + RowVectorPtr testRowVectorWrapped_; +}; + +TEST_F(LocalRunnerServiceTest, ConvertToBatchesRoundTrip) { + auto result = facebook::velox::runner::convertToBatches( + {testRowVector_}, rootPool_.get()); + + ASSERT_EQ(result.size(), 1); + ASSERT_EQ(result[0].columnNames()->size(), 7); + ASSERT_EQ(result[0].columnTypes()->size(), 7); + + // Verify serializedData is present + ASSERT_GT(result[0].serializedData()->size(), 0); + + // Deserialize and verify + auto leafPool = rootPool_->addLeafChild("deserialize"); + auto serde = std::make_unique< + facebook::velox::serializer::presto::PrestoVectorSerde>(); + facebook::velox::serializer::presto::PrestoVectorSerde::PrestoOptions options; + + const auto& serializedData = *result[0].serializedData(); + ByteRange byteRange{ + reinterpret_cast(const_cast(serializedData.data())), + static_cast(serializedData.length()), + 0}; + auto byteStream = std::make_unique( + std::vector{{byteRange}}); + + RowVectorPtr deserialized; + serde->deserialize( + byteStream.get(), + leafPool.get(), + asRowType(testRowVector_->type()), + &deserialized, + 0, + &options); + + ASSERT_NE(deserialized, nullptr); + ASSERT_EQ(deserialized->size(), testRowVector_->size()); + ASSERT_EQ(deserialized->childrenSize(), testRowVector_->childrenSize()); + + assertEqualVectors(deserialized, testRowVector_); +} + +TEST_F(LocalRunnerServiceTest, ServiceHandlerMockRequestIntegration) { + LocalRunnerServiceHandler handler; + + auto request = std::make_unique(); + // Serialized plan for the following: + // expressions: (p0:DOUBLE, plus(null,0.1646418017335236)) + request->serializedPlan() = + R"({"names":["p0","p1"],"id":"project","name":"ProjectNode","sources":[{"name":"ProjectNode","id":"transform","projections":[{"name":"FieldAccessTypedExpr","type":{"name":"Type","type":"BIGINT"},"inputs":[{"name":"InputTypedExpr","type":{"type":"ROW","name":"Type","names":["row_number"],"cTypes":[{"name":"Type","type":"BIGINT"}]}}],"fieldName":"row_number"}],"names":["row_number"],"sources":[{"name":"ValuesNode","id":"efb6650a_8541_4214_82dd_9792a4965380","data":"AAAAAF4AAAB7ImNUeXBlcyI6W3sidHlwZSI6IkJJR0lOVCIsIm5hbWUiOiJUeXBlIn1dLCJuYW1lcyI6WyJyb3dfbnVtYmVyIl0sInR5cGUiOiJST1ciLCJuYW1lIjoiVHlwZSJ9AQAAAAABAAAAAQAAAAAfAAAAeyJ0eXBlIjoiQklHSU5UIiwibmFtZSI6IlR5cGUifQEAAAAAAQgAAAAAAAAAAAAAAA==","parallelizable":false,"repeatTimes":1}]}],"projections":[{"name":"CallTypedExpr","type":{"name":"Type","type":"DOUBLE"},"functionName":"plus","inputs":[{"name":"ConstantTypedExpr","type":{"name":"Type","type":"DOUBLE"},"valueVector":"AQAAAB8AAAB7InR5cGUiOiJET1VCTEUiLCJuYW1lIjoiVHlwZSJ9AQAAAAE="},{"name":"ConstantTypedExpr","type":{"name":"Type","type":"DOUBLE"},"valueVector":"AQAAAB8AAAB7InR5cGUiOiJET1VCTEUiLCJuYW1lIjoiVHlwZSJ9AQAAAAABAAAAifsSxT8="}]},{"name":"FieldAccessTypedExpr","type":{"name":"Type","type":"BIGINT"},"fieldName":"row_number"}]})"; + request->queryId() = "query1"; + + ExecutePlanResponse response; + handler.execute(response, std::move(request)); + + EXPECT_TRUE(*response.success()); + EXPECT_EQ(response.results()->size(), 1); + + const auto& batch = (*response.results()).front(); + EXPECT_EQ(batch.columnNames()->size(), 2); + EXPECT_EQ((*batch.columnNames())[0], "p0"); + EXPECT_EQ(batch.columnTypes()->size(), 2); + EXPECT_EQ((*batch.columnTypes())[0], "DOUBLE"); + EXPECT_GT(batch.serializedData()->size(), 0); +} + +TEST_F(LocalRunnerServiceTest, ServiceHandlerMockRequestIntegrationFailure) { + LocalRunnerServiceHandler handler; + + auto request = std::make_unique(); + // Serialized plan for the following: + // expressions: (p0:TINYINT, divide(89,"c0") + // Will encounter divide by zero error. + request->serializedPlan() = + R"({"projections":[{"inputs":[{"valueVector":"AQAAACAAAAB7InR5cGUiOiJUSU5ZSU5UIiwibmFtZSI6IlR5cGUifQEAAAAAAVk=","type":{"type":"TINYINT","name":"Type"},"name":"ConstantTypedExpr"},{"fieldName":"c0","type":{"type":"TINYINT","name":"Type"},"name":"FieldAccessTypedExpr"}],"functionName":"divide","type":{"type":"TINYINT","name":"Type"},"name":"CallTypedExpr"},{"fieldName":"row_number","type":{"type":"BIGINT","name":"Type"},"name":"FieldAccessTypedExpr"}],"sources":[{"projections":[{"inputs":[{"type":{"cTypes":[{"type":"TINYINT","name":"Type"},{"type":"BIGINT","name":"Type"}],"names":["c0","row_number"],"type":"ROW","name":"Type"},"name":"InputTypedExpr"}],"fieldName":"c0","type":{"type":"TINYINT","name":"Type"},"name":"FieldAccessTypedExpr"},{"inputs":[{"type":{"cTypes":[{"type":"TINYINT","name":"Type"},{"type":"BIGINT","name":"Type"}],"names":["c0","row_number"],"type":"ROW","name":"Type"},"name":"InputTypedExpr"}],"fieldName":"row_number","type":{"type":"BIGINT","name":"Type"},"name":"FieldAccessTypedExpr"}],"sources":[{"parallelizable":false,"repeatTimes":1,"data":"AAAAAIQAAAB7ImNUeXBlcyI6W3sidHlwZSI6IlRJTllJTlQiLCJuYW1lIjoiVHlwZSJ9LHsidHlwZSI6IkJJR0lOVCIsIm5hbWUiOiJUeXBlIn1dLCJuYW1lcyI6WyJjMCIsInJvd19udW1iZXIiXSwidHlwZSI6IlJPVyIsIm5hbWUiOiJUeXBlIn0KAAAAAAIAAAABAgAAACAAAAB7InR5cGUiOiJUSU5ZSU5UIiwibmFtZSI6IlR5cGUifQoAAAAAKAAAAAMAAAACAAAABgAAAAAAAAABAAAACAAAAAUAAAAAAAAACAAAAAUAAAACAAAAIAAAAHsidHlwZSI6IlRJTllJTlQiLCJuYW1lIjoiVHlwZSJ9CgAAAAECAAAA9/8oAAAACQAAAAQAAAAJAAAAAAAAAAYAAAAHAAAABAAAAAYAAAAAAAAAAAAAAAIAAAAgAAAAeyJ0eXBlIjoiVElOWUlOVCIsIm5hbWUiOiJUeXBlIn0KAAAAAQIAAAD7oigAAAAJAAAAAQAAAAkAAAAHAAAAAAAAAAUAAAAEAAAAAwAAAAEAAAAAAAAAAAAAACAAAAB7InR5cGUiOiJUSU5ZSU5UIiwibmFtZSI6IlR5cGUifQoAAAAAAQoAAABTOkYvJBw5ZUAAAQAAAAAfAAAAeyJ0eXBlIjoiQklHSU5UIiwibmFtZSI6IlR5cGUifQoAAAAAAVAAAAAAAAAAAAAAAAEAAAAAAAAAAgAAAAAAAAADAAAAAAAAAAQAAAAAAAAABQAAAAAAAAAGAAAAAAAAAAcAAAAAAAAACAAAAAAAAAAJAAAAAAAAAA==","id":"d69f11dc_1f0e_40ae_8c5d_2cde4b784a12","name":"ValuesNode"}],"names":["c0","row_number"],"id":"transform","name":"ProjectNode"}],"names":["p0","p1"],"id":"project","name":"ProjectNode"})"; + request->queryId() = "query1"; + + ExecutePlanResponse response; + handler.execute(response, std::move(request)); + + ASSERT_TRUE(response.errorMessage().has_value()); + auto errorMsg = response.errorMessage().value(); + EXPECT_NE(errorMsg.find("Error Source: USER"), std::string::npos); + EXPECT_NE(errorMsg.find("Error Code: ARITHMETIC_ERROR"), std::string::npos); + EXPECT_NE(errorMsg.find("Reason: division by zero"), std::string::npos); + + EXPECT_FALSE(*response.success()); + EXPECT_EQ(response.results()->size(), 0); +} + +} // namespace facebook::velox::fuzzer::test + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + folly::Init init(&argc, &argv); + return RUN_ALL_TESTS(); +} diff --git a/velox/exec/fuzzer/tests/PrestoSqlTest.cpp b/velox/exec/fuzzer/tests/PrestoSqlTest.cpp index 42dd2f8e3ea5..1461d5f16cdd 100644 --- a/velox/exec/fuzzer/tests/PrestoSqlTest.cpp +++ b/velox/exec/fuzzer/tests/PrestoSqlTest.cpp @@ -19,6 +19,10 @@ #include "velox/common/base/tests/GTestUtils.h" #include "velox/exec/fuzzer/PrestoSql.h" #include "velox/functions/prestosql/types/JsonType.h" +#include "velox/functions/prestosql/types/KHyperLogLogType.h" +#include "velox/functions/prestosql/types/QDigestType.h" +#include "velox/functions/prestosql/types/SetDigestType.h" +#include "velox/functions/prestosql/types/TDigestType.h" #include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h" namespace facebook::velox::exec::test { @@ -34,7 +38,13 @@ TEST(PrestoSqlTest, toTypeSql) { EXPECT_EQ(toTypeSql(DOUBLE()), "DOUBLE"); EXPECT_EQ(toTypeSql(VARCHAR()), "VARCHAR"); EXPECT_EQ(toTypeSql(VARBINARY()), "VARBINARY"); + EXPECT_EQ(toTypeSql(TDIGEST(DOUBLE())), "TDIGEST(DOUBLE)"); EXPECT_EQ(toTypeSql(TIMESTAMP()), "TIMESTAMP"); + EXPECT_EQ(toTypeSql(QDIGEST(DOUBLE())), "QDIGEST(DOUBLE)"); + EXPECT_EQ(toTypeSql(QDIGEST(BIGINT())), "QDIGEST(BIGINT)"); + EXPECT_EQ(toTypeSql(QDIGEST(REAL())), "QDIGEST(REAL)"); + EXPECT_EQ(toTypeSql(SETDIGEST()), "SETDIGEST"); + EXPECT_EQ(toTypeSql(KHYPERLOGLOG()), "KHYPERLOGLOG"); EXPECT_EQ(toTypeSql(DATE()), "DATE"); EXPECT_EQ(toTypeSql(TIMESTAMP_WITH_TIME_ZONE()), "TIMESTAMP WITH TIME ZONE"); EXPECT_EQ(toTypeSql(ARRAY(BOOLEAN())), "ARRAY(BOOLEAN)"); @@ -58,9 +68,8 @@ void toUnaryOperator( const std::string& expectedSql) { auto expression = std::make_shared( INTEGER(), - std::vector{ - std::make_shared(VARCHAR(), "c0")}, - operatorName); + operatorName, + std::make_shared(VARCHAR(), "c0")); EXPECT_EQ(toCallSql(expression), expectedSql); } @@ -69,10 +78,9 @@ void toBinaryOperator( const std::string& expectedSql) { auto expression = std::make_shared( INTEGER(), - std::vector{ - std::make_shared(INTEGER(), "c0"), - std::make_shared(INTEGER(), "c1")}, - operatorName); + operatorName, + std::make_shared(INTEGER(), "c0"), + std::make_shared(INTEGER(), "c1")); EXPECT_EQ(toCallSql(expression), expectedSql); } @@ -81,9 +89,8 @@ void toIsNullOrIsNotNull( const std::string& expectedSql) { auto expression = std::make_shared( BOOLEAN(), - std::vector{ - std::make_shared(INTEGER(), "c0")}, - operatorName); + operatorName, + std::make_shared(INTEGER(), "c0")); EXPECT_EQ(toCallSql(expression), expectedSql); } @@ -92,12 +99,12 @@ TEST(PrestoSqlTest, toCallSql) { toUnaryOperator("negate", "(- c0)"); toUnaryOperator("not", "(not c0)"); VELOX_ASSERT_THROW( - toCallSql(std::make_shared( - INTEGER(), - std::vector{ + toCallSql( + std::make_shared( + INTEGER(), + "not", std::make_shared(VARCHAR(), "c0"), - std::make_shared(VARCHAR(), "c1")}, - "not")), + std::make_shared(VARCHAR(), "c1"))), "Expected one argument to a unary operator"); // Binary operators @@ -114,271 +121,264 @@ TEST(PrestoSqlTest, toCallSql) { toBinaryOperator("gte", "(c0 >= c1)"); toBinaryOperator("distinct_from", "(c0 is distinct from c1)"); VELOX_ASSERT_THROW( - toCallSql(std::make_shared( - INTEGER(), - std::vector{ + toCallSql( + std::make_shared( + INTEGER(), + "plus", std::make_shared(INTEGER(), "c0"), std::make_shared(INTEGER(), "c1"), - std::make_shared(INTEGER(), "c3")}, - "plus")), + std::make_shared(INTEGER(), "c3"))), "Expected two arguments to a binary operator"); // Functions IS NULL and NOT NULL toIsNullOrIsNotNull("is_null", "(c0 is null)"); toIsNullOrIsNotNull("not_null", "(c0 is not null)"); VELOX_ASSERT_THROW( - toCallSql(std::make_shared( - BOOLEAN(), - std::vector{ + toCallSql( + std::make_shared( + BOOLEAN(), + "is_null", std::make_shared(INTEGER(), "c0"), - std::make_shared(INTEGER(), "c1")}, - "is_null")), + std::make_shared(INTEGER(), "c1"))), "Expected one argument to function 'is_null' or 'not_null'"); // Function IN EXPECT_EQ( - toCallSql(std::make_shared( - BOOLEAN(), - std::vector{ + toCallSql( + std::make_shared( + BOOLEAN(), + "in", std::make_shared(VARCHAR(), "a"), - std::make_shared(VARCHAR(), "b")}, - "in")), + std::make_shared(VARCHAR(), "b"))), "'a' in ('b')"); EXPECT_EQ( - toCallSql(std::make_shared( - BOOLEAN(), - std::vector{ + toCallSql( + std::make_shared( + BOOLEAN(), + "in", std::make_shared(VARCHAR(), "a"), std::make_shared(VARCHAR(), "b"), std::make_shared(VARCHAR(), "c"), - std::make_shared(VARCHAR(), "d")}, - "in")), + std::make_shared(VARCHAR(), "d"))), "'a' in ('b', 'c', 'd')"); VELOX_ASSERT_THROW( - toCallSql(std::make_shared( - BOOLEAN(), - std::vector{ - std::make_shared(VARCHAR(), "a")}, - "in")), + toCallSql( + std::make_shared( + BOOLEAN(), + "in", + std::make_shared(VARCHAR(), "a"))), "Expected at least two arguments to function 'in'"); // Function LIKE EXPECT_EQ( - toCallSql(std::make_shared( - BOOLEAN(), - std::vector{ + toCallSql( + std::make_shared( + BOOLEAN(), + "like", std::make_shared(VARCHAR(), "c0"), - std::make_shared(VARCHAR(), "a")}, - "like")), + std::make_shared(VARCHAR(), "a"))), "(c0 like 'a')"); EXPECT_EQ( - toCallSql(std::make_shared( - BOOLEAN(), - std::vector{ + toCallSql( + std::make_shared( + BOOLEAN(), + "like", std::make_shared(VARCHAR(), "c0"), std::make_shared(VARCHAR(), "a"), - std::make_shared(VARCHAR(), "b")}, - "like")), + std::make_shared(VARCHAR(), "b"))), "(c0 like 'a' escape 'b')"); VELOX_ASSERT_THROW( - toCallSql(std::make_shared( - BOOLEAN(), - std::vector{ - std::make_shared(VARCHAR(), "a")}, - "like")), + toCallSql( + std::make_shared( + BOOLEAN(), + "like", + std::make_shared(VARCHAR(), "a"))), "Expected at least two arguments to function 'like'"); VELOX_ASSERT_THROW( - toCallSql(std::make_shared( - BOOLEAN(), - std::vector{ + toCallSql( + std::make_shared( + BOOLEAN(), + "like", std::make_shared(VARCHAR(), "a"), std::make_shared(VARCHAR(), "b"), std::make_shared(VARCHAR(), "c"), - std::make_shared(VARCHAR(), "d")}, - "like")), + std::make_shared(VARCHAR(), "d"))), "Expected at most three arguments to function 'like'"); // Functions OR and AND EXPECT_EQ( - toCallSql(std::make_shared( - BOOLEAN(), - std::vector{ + toCallSql( + std::make_shared( + BOOLEAN(), + "or", std::make_shared(BOOLEAN(), true), - std::make_shared(BOOLEAN(), false)}, - "or")), + std::make_shared(BOOLEAN(), false))), "(BOOLEAN 'true' or BOOLEAN 'false')"); EXPECT_EQ( - toCallSql(std::make_shared( - BOOLEAN(), - std::vector{ + toCallSql( + std::make_shared( + BOOLEAN(), + "and", std::make_shared(BOOLEAN(), true), - std::make_shared(BOOLEAN(), false)}, - "and")), + std::make_shared(BOOLEAN(), false))), "(BOOLEAN 'true' and BOOLEAN 'false')"); EXPECT_EQ( - toCallSql(std::make_shared( - BOOLEAN(), - std::vector{ + toCallSql( + std::make_shared( + BOOLEAN(), + "or", std::make_shared(BOOLEAN(), true), std::make_shared(BOOLEAN(), false), std::make_shared(BOOLEAN(), true), - std::make_shared(BOOLEAN(), false)}, - "or")), + std::make_shared(BOOLEAN(), false))), "(BOOLEAN 'true' or BOOLEAN 'false' or BOOLEAN 'true' or BOOLEAN 'false')"); EXPECT_EQ( - toCallSql(std::make_shared( - BOOLEAN(), - std::vector{ + toCallSql( + std::make_shared( + BOOLEAN(), + "and", std::make_shared(BOOLEAN(), true), std::make_shared(BOOLEAN(), false), std::make_shared(BOOLEAN(), true), - std::make_shared(BOOLEAN(), false)}, - "and")), + std::make_shared(BOOLEAN(), false))), "(BOOLEAN 'true' and BOOLEAN 'false' and BOOLEAN 'true' and BOOLEAN 'false')"); VELOX_ASSERT_THROW( - toCallSql(std::make_shared( - BOOLEAN(), std::vector{}, "or")), + toCallSql(std::make_shared(BOOLEAN(), "or")), "Expected at least two arguments to function 'or' or 'and'"); // Functions ARRAY_CONSTRUCTOR and ROW_CONSTRUCTOR EXPECT_EQ( - toCallSql(std::make_shared( - ARRAY(INTEGER()), - std::vector{ + toCallSql( + std::make_shared( + ARRAY(INTEGER()), + "array_constructor", std::make_shared(VARCHAR(), "a"), std::make_shared(VARCHAR(), "b"), - std::make_shared(VARCHAR(), "c")}, - "array_constructor")), + std::make_shared(VARCHAR(), "c"))), "ARRAY['a', 'b', 'c']"); EXPECT_EQ( - toCallSql(std::make_shared( - ARRAY(INTEGER()), - std::vector{}, - "array_constructor")), + toCallSql( + std::make_shared( + ARRAY(INTEGER()), "array_constructor")), "ARRAY[]"); EXPECT_EQ( - toCallSql(std::make_shared( - BOOLEAN(), - std::vector{ + toCallSql( + std::make_shared( + BOOLEAN(), + "row_constructor", std::make_shared(VARCHAR(), "a"), std::make_shared(VARCHAR(), "b"), std::make_shared(VARCHAR(), "c"), - std::make_shared(VARCHAR(), "d")}, - "row_constructor")), + std::make_shared(VARCHAR(), "d"))), "row('a', 'b', 'c', 'd')"); VELOX_ASSERT_THROW( - toCallSql(std::make_shared( - BOOLEAN(), std::vector{}, "row_constructor")), + toCallSql( + std::make_shared(BOOLEAN(), "row_constructor")), "Expected at least one argument to function 'row_constructor'"); // Function BETWEEN EXPECT_EQ( - toCallSql(std::make_shared( - BOOLEAN(), - std::vector{ + toCallSql( + std::make_shared( + BOOLEAN(), + "between", std::make_shared(INTEGER(), "c0"), std::make_shared(INTEGER(), "c1"), - std::make_shared(INTEGER(), "c2")}, - "between")), + std::make_shared(INTEGER(), "c2"))), "(c0 between c1 and c2)"); // Edge case check for ambiguous parantheses processing, query will fail // without the parantheses wrapping the left-hand side. EXPECT_EQ( - toCallSql(std::make_shared( - BOOLEAN(), - std::vector{ + toCallSql( + std::make_shared( + BOOLEAN(), + "lt", std::make_shared( BOOLEAN(), - std::vector{ - std::make_shared( - INTEGER(), "c0"), - std::make_shared( - INTEGER(), "c0"), - std::make_shared( - INTEGER(), variant::null(TypeKind::INTEGER))}, - "between"), - std::make_shared(INTEGER(), "c0")}, - "lt")), + "between", + std::make_shared(INTEGER(), "c0"), + std::make_shared(INTEGER(), "c0"), + std::make_shared( + INTEGER(), variant::null(TypeKind::INTEGER))), + std::make_shared(INTEGER(), "c0"))), "((c0 between c0 and cast(null as INTEGER)) < c0)"); VELOX_ASSERT_THROW( - toCallSql(std::make_shared( - BOOLEAN(), std::vector{}, "between")), + toCallSql(std::make_shared(BOOLEAN(), "between")), "Expected three arguments to function 'between'"); // Function SUBSCRIPT, builds '[]' SQL EXPECT_EQ( - toCallSql(std::make_shared( - INTEGER(), - std::vector{ + toCallSql( + std::make_shared( + INTEGER(), + "subscript", std::make_shared( ARRAY(INTEGER()), "array"), - std::make_shared(INTEGER(), "c0")}, - "subscript")), + std::make_shared(INTEGER(), "c0"))), "array[c0]"); VELOX_ASSERT_THROW( - toCallSql(std::make_shared( - INTEGER(), - std::vector{ + toCallSql( + std::make_shared( + INTEGER(), + "subscript", std::make_shared( ARRAY(INTEGER()), "array"), std::make_shared(INTEGER(), "c0"), - std::make_shared(INTEGER(), "c1")}, - "subscript")), + std::make_shared(INTEGER(), "c1"))), "Expected two arguments to function 'subscript'"); // Function SWITCH, builds 'CASE WHEN ... THEN ... ELSE ... END' SQL // SWITCH cases with no ELSE. EXPECT_EQ( - toCallSql(std::make_shared( - INTEGER(), - std::vector{ + toCallSql( + std::make_shared( + INTEGER(), + "switch", std::make_shared(BOOLEAN(), "c0"), - std::make_shared(VARCHAR(), "c1")}, - "switch")), + std::make_shared(VARCHAR(), "c1"))), "case when c0 then c1 end"); EXPECT_EQ( - toCallSql(std::make_shared( - INTEGER(), - std::vector{ + toCallSql( + std::make_shared( + INTEGER(), + "switch", std::make_shared(BOOLEAN(), "c0"), std::make_shared(INTEGER(), "c1"), std::make_shared(BOOLEAN(), "c2"), - std::make_shared(INTEGER(), "c3")}, - "switch")), + std::make_shared(INTEGER(), "c3"))), "case when c0 then c1 when c2 then c3 end"); // SWITCH case with ELSE. EXPECT_EQ( - toCallSql(std::make_shared( - INTEGER(), - std::vector{ + toCallSql( + std::make_shared( + INTEGER(), + "switch", std::make_shared(BOOLEAN(), "c0"), std::make_shared(INTEGER(), "c1"), std::make_shared(BOOLEAN(), "c2"), std::make_shared(INTEGER(), "c3"), - std::make_shared(INTEGER(), "c4")}, - "switch")), + std::make_shared(INTEGER(), "c4"))), "case when c0 then c1 when c2 then c3 else c4 end"); VELOX_ASSERT_THROW( - toCallSql(std::make_shared( - INTEGER(), - std::vector{ - std::make_shared(INTEGER(), "c0")}, - "switch")), + toCallSql( + std::make_shared( + INTEGER(), + "switch", + std::make_shared(INTEGER(), "c0"))), "Expected at least two arguments to function 'switch'"); // Generic functions EXPECT_EQ( - toCallSql(std::make_shared( - INTEGER(), - std::vector{ + toCallSql( + std::make_shared( + INTEGER(), + "array_top_n", std::make_shared( ARRAY(INTEGER()), "c0"), - std::make_shared(INTEGER(), "c1")}, - "array_top_n")), + std::make_shared(INTEGER(), "c1"))), "array_top_n(c0, c1)"); EXPECT_EQ( - toCallSql(std::make_shared( - REAL(), std::vector{}, "infinity")), + toCallSql(std::make_shared(REAL(), "infinity")), "infinity()"); } @@ -405,5 +405,14 @@ TEST(PrestoSqlTest, toCallInputsSql) { EXPECT_EQ(sql.str(), "c0.field0"); } +TEST(PrestoSqlTest, toConstantSql) { + EXPECT_EQ( + toConstantSql(core::ConstantTypedExpr(INTERVAL_YEAR_MONTH(), 123)), + "INTERVAL '123' YEAR TO MONTH"); + EXPECT_EQ( + toConstantSql(core::ConstantTypedExpr(INTERVAL_DAY_TIME(), int64_t(123))), + "INTERVAL '123' DAY TO SECOND"); +} + } // namespace } // namespace facebook::velox::exec::test diff --git a/velox/exec/prefixsort/CMakeLists.txt b/velox/exec/prefixsort/CMakeLists.txt index aac7db501ecd..ccaca96a1c46 100644 --- a/velox/exec/prefixsort/CMakeLists.txt +++ b/velox/exec/prefixsort/CMakeLists.txt @@ -21,3 +21,5 @@ endif() if(${VELOX_ENABLE_BENCHMARKS}) add_subdirectory(benchmarks) endif() + +velox_install_library_headers() diff --git a/velox/exec/prefixsort/PrefixSortAlgorithm.h b/velox/exec/prefixsort/PrefixSortAlgorithm.h index 385a449c3509..7930b080e0a7 100644 --- a/velox/exec/prefixsort/PrefixSortAlgorithm.h +++ b/velox/exec/prefixsort/PrefixSortAlgorithm.h @@ -96,28 +96,12 @@ class PrefixSortIterator { return (prefix_ - other.prefix_) / other.entrySize_; } - FOLLY_ALWAYS_INLINE bool operator<(const PrefixSortIterator& other) const { - return prefix_ < other.prefix_; - } - - FOLLY_ALWAYS_INLINE bool operator>(const PrefixSortIterator& other) const { - return prefix_ > other.prefix_; - } - - FOLLY_ALWAYS_INLINE bool operator>=(const PrefixSortIterator& other) const { - return prefix_ >= other.prefix_; - } - - FOLLY_ALWAYS_INLINE bool operator<=(const PrefixSortIterator& other) const { - return prefix_ <= other.prefix_; - } - FOLLY_ALWAYS_INLINE bool operator==(const PrefixSortIterator& other) const { return prefix_ == other.prefix_; } - FOLLY_ALWAYS_INLINE bool operator!=(const PrefixSortIterator& other) const { - return prefix_ != other.prefix_; + FOLLY_ALWAYS_INLINE auto operator<=>(const PrefixSortIterator& other) const { + return prefix_ <=> other.prefix_; } private: diff --git a/velox/exec/prefixsort/PrefixSortEncoder.h b/velox/exec/prefixsort/PrefixSortEncoder.h index 945b1de5e160..87756792a726 100644 --- a/velox/exec/prefixsort/PrefixSortEncoder.h +++ b/velox/exec/prefixsort/PrefixSortEncoder.h @@ -28,7 +28,7 @@ namespace facebook::velox::exec::prefixsort { class PrefixSortEncoder { public: PrefixSortEncoder(bool ascending, bool nullsFirst) - : ascending_(ascending), nullsFirst_(nullsFirst){}; + : ascending_(ascending), nullsFirst_(nullsFirst) {} /// Encode native primitive types(such as uint64_t, int64_t, uint32_t, /// int32_t, uint16_t, int16_t, float, double, Timestamp). diff --git a/velox/exec/prefixsort/benchmarks/CMakeLists.txt b/velox/exec/prefixsort/benchmarks/CMakeLists.txt index 77a62b5c0481..5557dca1f5c4 100644 --- a/velox/exec/prefixsort/benchmarks/CMakeLists.txt +++ b/velox/exec/prefixsort/benchmarks/CMakeLists.txt @@ -11,8 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -add_executable(velox_prefix_sort_algorithm_benchmark - PrefixSortAlgorithmBenchmark.cpp) +add_executable(velox_prefix_sort_algorithm_benchmark PrefixSortAlgorithmBenchmark.cpp) target_link_libraries( - velox_prefix_sort_algorithm_benchmark velox_exec_prefixsort_test_lib - velox_vector_test_lib Folly::follybenchmark) + velox_prefix_sort_algorithm_benchmark + velox_exec_prefixsort_test_lib + velox_vector_test_lib + Folly::follybenchmark +) diff --git a/velox/exec/prefixsort/tests/CMakeLists.txt b/velox/exec/prefixsort/tests/CMakeLists.txt index 85e0748d9dea..99f5d91d2c81 100644 --- a/velox/exec/prefixsort/tests/CMakeLists.txt +++ b/velox/exec/prefixsort/tests/CMakeLists.txt @@ -13,16 +13,19 @@ # limitations under the License. add_subdirectory(utils) -add_executable(velox_exec_prefixsort_test PrefixSortAlgorithmTest.cpp - PrefixEncoderTest.cpp) +add_executable(velox_exec_prefixsort_test PrefixSortAlgorithmTest.cpp PrefixEncoderTest.cpp) add_test( NAME velox_exec_prefixsort_test COMMAND velox_exec_prefixsort_test - WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} +) set_tests_properties(velox_exec_prefixsort_test PROPERTIES TIMEOUT 3000) target_link_libraries( - velox_exec_prefixsort_test velox_exec_prefixsort_test_lib velox_vector_fuzzer - velox_vector_test_lib) + velox_exec_prefixsort_test + velox_exec_prefixsort_test_lib + velox_vector_fuzzer + velox_vector_test_lib +) diff --git a/velox/exec/prefixsort/tests/PrefixEncoderTest.cpp b/velox/exec/prefixsort/tests/PrefixEncoderTest.cpp index 398da750b045..ab88faa09cb2 100644 --- a/velox/exec/prefixsort/tests/PrefixEncoderTest.cpp +++ b/velox/exec/prefixsort/tests/PrefixEncoderTest.cpp @@ -270,8 +270,7 @@ class PrefixEncoderTest : public testing::Test, const auto rightValue = rightVector->isNullAt(i) ? std::nullopt : std::optional(rightVector->valueAt(i)); - if constexpr ( - Kind == TypeKind::VARCHAR || Kind == TypeKind::VARBINARY) { + if constexpr (is_string_kind(Kind)) { encoder.encode(leftValue, leftEncoded, 17, true); encoder.encode(rightValue, rightEncoded, 17, true); } else { diff --git a/velox/exec/prefixsort/tests/utils/CMakeLists.txt b/velox/exec/prefixsort/tests/utils/CMakeLists.txt index 813bd5696e7a..c4691246a33f 100644 --- a/velox/exec/prefixsort/tests/utils/CMakeLists.txt +++ b/velox/exec/prefixsort/tests/utils/CMakeLists.txt @@ -14,5 +14,4 @@ add_library(velox_exec_prefixsort_test_lib EncoderTestUtils.cpp) -target_link_libraries( - velox_exec_prefixsort_test_lib velox_vector_test_lib) +target_link_libraries(velox_exec_prefixsort_test_lib velox_vector_test_lib) diff --git a/velox/exec/tests/AggregateCompanionAdapterTest.cpp b/velox/exec/tests/AggregateCompanionAdapterTest.cpp index 61d27961ca86..e050dff7fdbb 100644 --- a/velox/exec/tests/AggregateCompanionAdapterTest.cpp +++ b/velox/exec/tests/AggregateCompanionAdapterTest.cpp @@ -77,8 +77,8 @@ class AggregateCompanionRegistryTest : public testing::Test { const std::vector& argTypes, const TypePtr& intermediateType, const TypePtr& resultType) { - const auto& [resolvedResult, resolveIntermediate] = - resolveAggregateFunction(name, argTypes); + const auto& resolvedResult = resolveResultType(name, argTypes); + const auto& resolveIntermediate = resolveIntermediateType(name, argTypes); checkEqual(resolvedResult, resultType); checkEqual(resolveIntermediate, intermediateType); } @@ -414,22 +414,27 @@ TEST_F( TEST_F( AggregateCompanionRegistryTest, resultTypeNotResolvableFromIntermediateType) { - // We only register companion functions for original signatures whose result - // type can be resolved from its intermediate type. + // We only register partial, merge and merge_extract companion functions for + // original signatures whose result type cannot be resolved from its + // intermediate type. std::vector> signatures{ AggregateFunctionSignatureBuilder() - .typeVariable("T") - .returnType("array(T)") - .intermediateType("varbinary") - .argumentType("T") + .integerVariable("a_precision") + .integerVariable("a_scale") + .integerVariable("i_precision", "min(38, a_precision + 10)") + .integerVariable("r_precision", "min(38, a_precision + 4)") + .integerVariable("r_scale", "min(38, a_scale + 4)") + .returnType("DECIMAL(r_precision, r_scale)") + .intermediateType("ROW(DECIMAL(i_precision, a_scale), bigint)") + .argumentType("DECIMAL(a_precision, a_scale)") .build()}; registerDummyAggregateFunction("aggregateFunc6", signatures); - checkAggregateSignaturesCount("aggregateFunc6_partial", 0); + checkAggregateSignaturesCount("aggregateFunc6_partial", 1); - checkAggregateSignaturesCount("aggregateFunc6_merge", 0); + checkAggregateSignaturesCount("aggregateFunc6_merge", 1); - checkAggregateSignaturesCount("aggregateFunc6_merge_extract", 0); + checkAggregateSignaturesCount("aggregateFunc6_merge_extract", 1); checkScalarSignaturesCount("aggregateFunc6_extract", 0); } diff --git a/velox/exec/tests/AggregateFunctionRegistryTest.cpp b/velox/exec/tests/AggregateFunctionRegistryTest.cpp index 223631579182..72b75c538388 100644 --- a/velox/exec/tests/AggregateFunctionRegistryTest.cpp +++ b/velox/exec/tests/AggregateFunctionRegistryTest.cpp @@ -15,77 +15,131 @@ */ #include "velox/exec/AggregateFunctionRegistry.h" + +#include #include + +#include "velox/common/base/tests/GTestUtils.h" #include "velox/exec/Aggregate.h" #include "velox/exec/AggregateUtil.h" #include "velox/exec/WindowFunction.h" #include "velox/exec/tests/AggregateRegistryTestUtil.h" -#include "velox/functions/Registerer.h" #include "velox/type/Type.h" namespace facebook::velox::exec::test { -class FunctionRegistryTest : public testing::Test { - public: - FunctionRegistryTest() { +class AggregateFunctionRegistryTest : public testing::Test { + protected: + AggregateFunctionRegistryTest() { registerAggregateFunc("aggregate_func"); registerAggregateFunc("Aggregate_Func_Alias"); } - void checkEqual(const TypePtr& actual, const TypePtr& expected) { - if (expected) { - EXPECT_EQ(*actual, *expected); - } else { - EXPECT_EQ(actual, nullptr); + void testResolve( + const std::string& name, + const std::vector& argTypes, + const TypePtr& expectedFinalType, + const TypePtr& expectedIntermediateType) { + { + auto finalType = resolveResultType(name, argTypes); + auto intermediateType = resolveIntermediateType(name, argTypes); + VELOX_EXPECT_EQ_TYPES(finalType, expectedFinalType); + VELOX_EXPECT_EQ_TYPES(intermediateType, expectedIntermediateType); + } + + { + std::vector coercions; + auto finalType = + resolveResultTypeWithCoercions(name, argTypes, coercions); + VELOX_EXPECT_EQ_TYPES(finalType, expectedFinalType); + + EXPECT_EQ(coercions.size(), argTypes.size()); + for (const auto& coercion : coercions) { + EXPECT_EQ(coercion, nullptr); + } } } - void testResolveAggregateFunction( - const std::string& functionName, + void testCoersions( + const std::string& name, const std::vector& argTypes, - const TypePtr& expectedReturn, - const TypePtr& expectedIntermediate) { - auto result = resolveAggregateFunction(functionName, argTypes); - checkEqual(result.first, expectedReturn); - checkEqual(result.second, expectedIntermediate); + const TypePtr& expectedFinalType, + const std::vector& expectedCoercions) { + VELOX_ASSERT_THROW( + resolveResultType(name, argTypes), + "Aggregate function signature is not supported"); + + std::vector coercions; + auto finalType = resolveResultTypeWithCoercions(name, argTypes, coercions); + VELOX_EXPECT_EQ_TYPES(finalType, expectedFinalType); + + EXPECT_EQ(coercions.size(), argTypes.size()); + for (int i = 0; i < coercions.size(); ++i) { + VELOX_EXPECT_EQ_TYPES(coercions[i], expectedCoercions[i]); + } + } + + void clearRegistry() { + aggregateFunctions().withWLock( + [](auto& aggregationFunctionMap) { aggregationFunctionMap.clear(); }); } }; -TEST_F(FunctionRegistryTest, hasAggregateFunctionSignature) { - testResolveAggregateFunction( +TEST_F(AggregateFunctionRegistryTest, basic) { + testResolve( "aggregate_func", {BIGINT(), DOUBLE()}, BIGINT(), ARRAY(BIGINT())); - testResolveAggregateFunction( + testResolve( "aggregate_func", {DOUBLE(), DOUBLE()}, DOUBLE(), ARRAY(DOUBLE())); - testResolveAggregateFunction( + testResolve( "aggregate_func", {ARRAY(BOOLEAN()), ARRAY(BOOLEAN())}, ARRAY(BOOLEAN()), ARRAY(ARRAY(BOOLEAN()))); - testResolveAggregateFunction("aggregate_func", {}, DATE(), DATE()); + testResolve("aggregate_func", {}, DATE(), DATE()); } -TEST_F(FunctionRegistryTest, hasAggregateFunctionSignatureWrongFunctionName) { - testResolveAggregateFunction( - "aggregate_func_nonexist", {BIGINT(), BIGINT()}, nullptr, nullptr); - testResolveAggregateFunction("aggregate_func_nonexist", {}, nullptr, nullptr); +TEST_F(AggregateFunctionRegistryTest, wrongFunctionName) { + VELOX_ASSERT_THROW( + resolveIntermediateType("aggregate_func_nonexist", {BIGINT(), BIGINT()}), + "Aggregate function not registered: aggregate_func_nonexist"); + VELOX_ASSERT_THROW( + resolveIntermediateType("aggregate_func_nonexist", {}), + "Aggregate function not registered: aggregate_func_nonexist"); } -TEST_F(FunctionRegistryTest, hasAggregateFunctionSignatureWrongArgType) { - testResolveAggregateFunction( - "aggregate_func", {DOUBLE(), BIGINT()}, nullptr, nullptr); - testResolveAggregateFunction("aggregate_func", {BIGINT()}, nullptr, nullptr); - testResolveAggregateFunction( - "aggregate_func", {BIGINT(), BIGINT(), BIGINT()}, nullptr, nullptr); +TEST_F(AggregateFunctionRegistryTest, wrongArgType) { + VELOX_ASSERT_THROW( + resolveIntermediateType("aggregate_func", {DOUBLE(), BIGINT()}), + "Aggregate function signature is not supported"); + VELOX_ASSERT_THROW( + resolveResultType("aggregate_func", {BIGINT()}), + "Aggregate function signature is not supported"); + VELOX_ASSERT_THROW( + resolveResultType("aggregate_func", {BIGINT(), BIGINT(), BIGINT()}), + "Aggregate function signature is not supported"); } -TEST_F(FunctionRegistryTest, functionNameInMixedCase) { - testResolveAggregateFunction( +TEST_F(AggregateFunctionRegistryTest, coercions) { + // (bigint, double) -> bigint + // (T, T) -> T + testCoersions( + "aggregate_func", {DOUBLE(), BIGINT()}, DOUBLE(), {nullptr, DOUBLE()}); + + testCoersions( + "aggregate_func", {TINYINT(), BIGINT()}, BIGINT(), {BIGINT(), nullptr}); + + testCoersions( + "aggregate_func", {INTEGER(), DOUBLE()}, BIGINT(), {BIGINT(), nullptr}); +} + +TEST_F(AggregateFunctionRegistryTest, functionNameInMixedCase) { + testResolve( "aggregatE_funC", {BIGINT(), DOUBLE()}, BIGINT(), ARRAY(BIGINT())); - testResolveAggregateFunction( + testResolve( "aggregatE_funC_aliaS", {DOUBLE(), DOUBLE()}, DOUBLE(), ARRAY(DOUBLE())); } -TEST_F(FunctionRegistryTest, getAggregateFunctionSignatures) { +TEST_F(AggregateFunctionRegistryTest, getSignatures) { auto functionSignatures = getAggregateFunctionSignatures(); auto aggregateFuncSignatures = functionSignatures["aggregate_func"]; std::vector aggregateFuncSignaturesStr; @@ -106,7 +160,7 @@ TEST_F(FunctionRegistryTest, getAggregateFunctionSignatures) { ASSERT_EQ(aggregateFuncSignaturesStr, expectedSignaturesStr); } -TEST_F(FunctionRegistryTest, aggregateWindowFunctionSignature) { +TEST_F(AggregateFunctionRegistryTest, windowFunction) { auto windowFunctionSignatures = getWindowFunctionSignatures("aggregate_func"); ASSERT_EQ(windowFunctionSignatures->size(), 3); @@ -121,12 +175,12 @@ TEST_F(FunctionRegistryTest, aggregateWindowFunctionSignature) { ASSERT_EQ(functionSignatures.count("(T,T) -> array(T) -> T"), 1); } -TEST_F(FunctionRegistryTest, duplicateRegistration) { +TEST_F(AggregateFunctionRegistryTest, duplicateRegistration) { EXPECT_FALSE(registerAggregateFunc("aggregate_func")); EXPECT_TRUE(registerAggregateFunc("aggregate_func", true)); } -TEST_F(FunctionRegistryTest, multipleNames) { +TEST_F(AggregateFunctionRegistryTest, multipleNames) { auto signatures = AggregateFunc::signatures(); auto factory = [&](core::AggregationNode::Step step, const std::vector& argTypes, @@ -149,9 +203,9 @@ TEST_F(FunctionRegistryTest, multipleNames) { /*overwrite*/ false); exec::AggregateRegistrationResult allSuccess{true, true, true, true, true}; EXPECT_EQ(registrationResult, allSuccess); - testResolveAggregateFunction( + testResolve( "aggregate_func1", {BIGINT(), DOUBLE()}, BIGINT(), ARRAY(BIGINT())); - testResolveAggregateFunction( + testResolve( "aggregate_func1_partial", {BIGINT(), DOUBLE()}, ARRAY(BIGINT()), @@ -166,7 +220,7 @@ TEST_F(FunctionRegistryTest, multipleNames) { exec::AggregateRegistrationResult onlyMainSuccess{ true, false, false, false, false}; EXPECT_EQ(registrationResult, onlyMainSuccess); - testResolveAggregateFunction( + testResolve( "aggregate_func2", {BIGINT(), DOUBLE()}, BIGINT(), ARRAY(BIGINT())); auto registrationResults = registerAggregateFunction( @@ -179,20 +233,32 @@ TEST_F(FunctionRegistryTest, multipleNames) { false, true, true, true, true}; EXPECT_EQ(registrationResults[0], allSuccessExceptMain); EXPECT_EQ(registrationResults[1], allSuccess); - testResolveAggregateFunction( + testResolve( "aggregate_func2", {BIGINT(), DOUBLE()}, BIGINT(), ARRAY(BIGINT())); - testResolveAggregateFunction( + testResolve( "aggregate_func2_partial", {BIGINT(), DOUBLE()}, ARRAY(BIGINT()), ARRAY(BIGINT())); - testResolveAggregateFunction( + testResolve( "aggregate_func3", {BIGINT(), DOUBLE()}, BIGINT(), ARRAY(BIGINT())); - testResolveAggregateFunction( + testResolve( "aggregate_func3_partial", {BIGINT(), DOUBLE()}, ARRAY(BIGINT()), ARRAY(BIGINT())); } +TEST_F(AggregateFunctionRegistryTest, getAggregateFunctionNames) { + clearRegistry(); + registerAggregateFunc("aggregate_func"); + registerAggregateFunc("Aggregate_Func_Alias"); + + auto functions = getAggregateFunctionNames(); + EXPECT_EQ(functions.size(), 2); + EXPECT_THAT( + functions, + testing::UnorderedElementsAre("aggregate_func", "aggregate_func_alias")); +} + } // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/AggregateSpillBenchmarkBase.cpp b/velox/exec/tests/AggregateSpillBenchmarkBase.cpp index b8f15c3c799b..91e0ddc469e8 100644 --- a/velox/exec/tests/AggregateSpillBenchmarkBase.cpp +++ b/velox/exec/tests/AggregateSpillBenchmarkBase.cpp @@ -36,10 +36,11 @@ std::unique_ptr makeRowContainer( true, // nullableKeys std::vector{}, dependentTypes, - false, // hasNext - false, // isJoinBuild - false, // hasProbedFlag - false, // hasNormalizedKey + /*hasNext=*/false, + /*isJoinBuild=*/false, + /*hasProbedFlag=*/false, + /*hasNormalizedKey=*/false, + /*useListRowIndex=*/false, pool.get()); } diff --git a/velox/exec/tests/AggregateSpillBenchmarkBase.h b/velox/exec/tests/AggregateSpillBenchmarkBase.h index aafb67d009ec..f25f35831bbd 100644 --- a/velox/exec/tests/AggregateSpillBenchmarkBase.h +++ b/velox/exec/tests/AggregateSpillBenchmarkBase.h @@ -20,7 +20,7 @@ namespace facebook::velox::exec::test { class AggregateSpillBenchmarkBase : public SpillerBenchmarkBase { public: explicit AggregateSpillBenchmarkBase(std::string spillerType) - : spillerType_(spillerType){}; + : spillerType_(spillerType) {}; /// Sets up the test. void setUp() override; diff --git a/velox/exec/tests/AggregationTest.cpp b/velox/exec/tests/AggregationTest.cpp index 25d552b25ed6..94c235f93c2f 100644 --- a/velox/exec/tests/AggregationTest.cpp +++ b/velox/exec/tests/AggregationTest.cpp @@ -26,6 +26,7 @@ #include "velox/common/testutil/TestValue.h" #include "velox/dwio/common/tests/utils/BatchMaker.h" #include "velox/exec/Aggregate.h" +#include "velox/exec/AggregateCompanionSignatures.h" #include "velox/exec/GroupingSet.h" #include "velox/exec/PlanNodeStats.h" #include "velox/exec/PrefixSort.h" @@ -37,6 +38,7 @@ #include "velox/exec/tests/utils/PlanBuilder.h" #include "velox/exec/tests/utils/SumNonPODAggregate.h" #include "velox/exec/tests/utils/TempDirectoryPath.h" +#include "velox/type/tests/utils/CustomTypesForTesting.h" namespace facebook::velox::exec::test { @@ -332,10 +334,12 @@ class AggregationTest : public OperatorTestBase { std::vector& batches) { std::vector children; dictionary->setSize(count * sizeof(vector_size_t)); - children.push_back(BaseVector::wrapInDictionary( - BufferPtr(nullptr), dictionary, count, rows->childAt(0))); - children.push_back(BaseVector::wrapInDictionary( - BufferPtr(nullptr), dictionary, count, rows->childAt(1))); + children.push_back( + BaseVector::wrapInDictionary( + BufferPtr(nullptr), dictionary, count, rows->childAt(0))); + children.push_back( + BaseVector::wrapInDictionary( + BufferPtr(nullptr), dictionary, count, rows->childAt(1))); children.push_back(children[1]); batches.push_back(vectorMaker_.rowVector(children)); dictionary = AlignedBuffer::allocate( @@ -380,18 +384,20 @@ class AggregationTest : public OperatorTestBase { false, true, true, + false, pool_.get()); } RowTypePtr rowType_{ - ROW({"c0", "c1", "c2", "c3", "c4", "c5", "c6"}, + ROW({"c0", "c1", "c2", "c3", "c4", "c5", "c6", "c7"}, {BIGINT(), SMALLINT(), INTEGER(), BIGINT(), REAL(), DOUBLE(), - VARCHAR()})}; + VARCHAR(), + TIMESTAMP()})}; folly::Random::DefaultGenerator rng_; memory::MemoryReclaimer::Stats reclaimerStats_; VectorFuzzer::Options fuzzerOpts_{ @@ -399,6 +405,8 @@ class AggregationTest : public OperatorTestBase { .nullRatio = 0, .stringLength = 1024, .stringVariableLength = false, + .timestampPrecision = + VectorFuzzer::Options::TimestampPrecision::kMicroSeconds, .allowLazyVector = false}; }; @@ -453,8 +461,8 @@ TEST_F(AggregationTest, missingFunctionOrSignature) { BIGINT(), inputs, "missing-function"); auto wrongInputTypes = std::make_shared(BIGINT(), inputs, "test_aggregate"); - auto missingInputs = std::make_shared( - BIGINT(), std::vector{}, "test_aggregate"); + auto missingInputs = + std::make_shared(BIGINT(), "test_aggregate"); auto makePlan = [&](const core::CallTypedExprPtr& aggExpr) { return PlanBuilder() @@ -475,7 +483,8 @@ TEST_F(AggregationTest, missingFunctionOrSignature) { std::vector{}, std::vector{"agg"}, aggregates, - false, + /*ignoreNullKeys=*/false, + /*noGroupsSpanBatches=*/false, std::move(source)); }) .planNode(); @@ -515,9 +524,7 @@ TEST_F(AggregationTest, missingLambdaFunction) { std::make_shared( ROW({"a", "b"}, {BIGINT(), BIGINT()}), std::make_shared( - BIGINT(), - std::vector{field("a"), field("b")}, - "multiply")), + BIGINT(), "multiply", field("a"), field("b"))), }; auto plan = PlanBuilder() @@ -538,7 +545,8 @@ TEST_F(AggregationTest, missingLambdaFunction) { std::vector{}, std::vector{"agg"}, aggregates, - false, + /*ignoreNullKeys=*/false, + /*noGroupsSpanBatches=*/false, std::move(source)); }) .planNode(); @@ -549,47 +557,6 @@ TEST_F(AggregationTest, missingLambdaFunction) { readCursor(params), "Aggregate function not registered: missing-lambda"); } -TEST_F(AggregationTest, DISABLED_resultTypeMismatch) { - using Step = core::AggregationNode::Step; - - registerAggregateFunction( - "test_aggregate", - {AggregateFunctionSignatureBuilder() - .returnType("bigint") - .intermediateType("bigint") - .argumentType("bigint") - .build()}, - [&](Step /*step*/, - const std::vector& /*argTypes*/, - const TypePtr& /*resultType*/, - const core::QueryConfig& /*config*/) - -> std::unique_ptr { VELOX_UNREACHABLE(); }, - false /*registerCompanionFunctions*/, - true /*overwrite*/); - - for (auto step : {Step::kIntermediate, Step::kPartial}) { - VELOX_ASSERT_THROW( - Aggregate::create( - "test_aggregate", - step, - std::vector{BIGINT()}, - INTEGER(), - core::QueryConfig{{}}), - "Intermediate type mismatch"); - } - - for (auto step : {Step::kFinal, Step::kSingle}) { - VELOX_ASSERT_THROW( - Aggregate::create( - "test_aggregate", - step, - std::vector{BIGINT()}, - INTEGER(), - core::QueryConfig{{}}), - "Final type mismatch"); - } -} - TEST_F(AggregationTest, global) { auto vectors = makeVectors(rowType_, 10, 100); createDuckDbTable(vectors); @@ -669,8 +636,11 @@ TEST_F(AggregationTest, manyGlobalAggregations) { createDuckDbTable(vectors); aggregates.clear(); for (int i = 0; i < rowType->size(); i++) { - aggregates.push_back(fmt::format( - "array_agg({} ORDER BY {})", rowType->nameOf(i), rowType->nameOf(i))); + aggregates.push_back( + fmt::format( + "array_agg({} ORDER BY {})", + rowType->nameOf(i), + rowType->nameOf(i))); } op = PlanBuilder() @@ -711,6 +681,20 @@ TEST_F(AggregationTest, singleStringKeyDistinct) { testSingleKey(vectors, "c6", true, true); } +TEST_F(AggregationTest, singleTimestampKey) { + auto vectors = createVectors(100, rowType_, fuzzerOpts_); + createDuckDbTable(vectors); + testSingleKey(vectors, "c7", false, false); + testSingleKey(vectors, "c7", true, false); +} + +TEST_F(AggregationTest, singleTimestampKeyDistinct) { + auto vectors = createVectors(100, rowType_, fuzzerOpts_); + createDuckDbTable(vectors); + testSingleKey(vectors, "c7", false, true); + testSingleKey(vectors, "c7", true, true); +} + TEST_F(AggregationTest, multiKey) { auto vectors = makeVectors(rowType_, 10, 100); createDuckDbTable(vectors); @@ -861,8 +845,9 @@ TEST_F(AggregationTest, allKeyTypes) { std::vector batches; for (auto i = 0; i < 10; ++i) { - batches.push_back(std::static_pointer_cast( - BatchMaker::createBatch(rowType, 100, *pool_))); + batches.push_back( + std::static_pointer_cast( + BatchMaker::createBatch(rowType, 100, *pool_))); } createDuckDbTable(batches); auto op = @@ -896,12 +881,13 @@ TEST_F(AggregationTest, partialAggregationMemoryLimit) { core::PlanNodeId aggNodeId; auto task = AssertQueryBuilder(duckDbQueryRunner_) .config(QueryConfig::kMaxPartialAggregationMemory, 100) - .plan(PlanBuilder() - .values(vectors) - .partialAggregation({"c0"}, {}) - .capturePlanNodeId(aggNodeId) - .finalAggregation() - .planNode()) + .plan( + PlanBuilder() + .values(vectors) + .partialAggregation({"c0"}, {}) + .capturePlanNodeId(aggNodeId) + .finalAggregation() + .planNode()) .assertResults("SELECT distinct c0 FROM tmp"); EXPECT_GT( toPlanStats(task->taskStats()) @@ -919,12 +905,13 @@ TEST_F(AggregationTest, partialAggregationMemoryLimit) { // Count aggregation. task = AssertQueryBuilder(duckDbQueryRunner_) .config(QueryConfig::kMaxPartialAggregationMemory, 1) - .plan(PlanBuilder() - .values(vectors) - .partialAggregation({"c0"}, {"count(1)"}) - .capturePlanNodeId(aggNodeId) - .finalAggregation() - .planNode()) + .plan( + PlanBuilder() + .values(vectors) + .partialAggregation({"c0"}, {"count(1)"}) + .capturePlanNodeId(aggNodeId) + .finalAggregation() + .planNode()) .assertResults("SELECT c0, count(1) FROM tmp GROUP BY 1"); EXPECT_GT( toPlanStats(task->taskStats()) @@ -942,12 +929,13 @@ TEST_F(AggregationTest, partialAggregationMemoryLimit) { // Global aggregation. task = AssertQueryBuilder(duckDbQueryRunner_) .config(QueryConfig::kMaxPartialAggregationMemory, 1) - .plan(PlanBuilder() - .values(vectors) - .partialAggregation({}, {"sum(c0)"}) - .capturePlanNodeId(aggNodeId) - .finalAggregation() - .planNode()) + .plan( + PlanBuilder() + .values(vectors) + .partialAggregation({}, {"sum(c0)"}) + .capturePlanNodeId(aggNodeId) + .finalAggregation() + .planNode()) .assertResults("SELECT sum(c0) FROM tmp"); EXPECT_EQ( 0, @@ -982,11 +970,12 @@ TEST_F(AggregationTest, partialDistinctWithAbandon) { .config(QueryConfig::kAbandonPartialAggregationMinRows, 100) .config(QueryConfig::kAbandonPartialAggregationMinPct, 50) .maxDrivers(1) - .plan(PlanBuilder() - .values(vectors) - .partialAggregation({"c0"}, {}) - .finalAggregation() - .planNode()) + .plan( + PlanBuilder() + .values(vectors) + .partialAggregation({"c0"}, {}) + .finalAggregation() + .planNode()) .assertResults("SELECT distinct c0 FROM tmp"); // with aggregation, just in case. @@ -994,11 +983,12 @@ TEST_F(AggregationTest, partialDistinctWithAbandon) { .config(QueryConfig::kAbandonPartialAggregationMinRows, 100) .config(QueryConfig::kAbandonPartialAggregationMinPct, 50) .maxDrivers(1) - .plan(PlanBuilder() - .values(vectors) - .partialAggregation({"c0"}, {"sum(c0)"}) - .finalAggregation() - .planNode()) + .plan( + PlanBuilder() + .values(vectors) + .partialAggregation({"c0"}, {"sum(c0)"}) + .finalAggregation() + .planNode()) .assertResults("SELECT distinct c0, sum(c0) FROM tmp group by c0"); } @@ -1016,6 +1006,7 @@ TEST_F(AggregationTest, distinctWithGroupingKeysReordered) { options.vectorSize = vectorSize; options.stringVariableLength = false; options.stringLength = 128; + options.nullRatio = 0.1; VectorFuzzer fuzzer(options, pool()); const int numVectors{5}; std::vector vectors; @@ -1029,20 +1020,20 @@ TEST_F(AggregationTest, distinctWithGroupingKeysReordered) { // first. auto spillDirectory = exec::test::TempDirectoryPath::create(); TestScopedSpillInjection scopedSpillInjection(100); - auto task = - AssertQueryBuilder(duckDbQueryRunner_) - .config(QueryConfig::kAbandonPartialAggregationMinRows, 100) - .config(QueryConfig::kAbandonPartialAggregationMinPct, 50) - .spillDirectory(spillDirectory->getPath()) - .config(QueryConfig::kSpillEnabled, true) - .config(QueryConfig::kAggregationSpillEnabled, true) - .config(QueryConfig::kSpillPrefixSortEnabled, true) - .maxDrivers(1) - .plan(PlanBuilder() - .values(vectors) - .singleAggregation({"c4", "c1", "c3", "c2", "c0"}, {}) - .planNode()) - .assertResults("SELECT distinct c4, c1, c3, c2, c0 FROM tmp"); + auto task = AssertQueryBuilder(duckDbQueryRunner_) + .config(QueryConfig::kAbandonPartialAggregationMinRows, 100) + .config(QueryConfig::kAbandonPartialAggregationMinPct, 50) + .spillDirectory(spillDirectory->getPath()) + .config(QueryConfig::kSpillEnabled, true) + .config(QueryConfig::kAggregationSpillEnabled, true) + .config(QueryConfig::kSpillPrefixSortEnabled, true) + .maxDrivers(1) + .plan( + PlanBuilder() + .values(vectors) + .singleAggregation({"c4", "c1", "c3", "c2", "c0"}, {}) + .planNode()) + .assertResults("SELECT distinct c4, c1, c3, c2, c0 FROM tmp"); } TEST_F(AggregationTest, largeValueRangeArray) { @@ -1142,12 +1133,13 @@ TEST_F(AggregationTest, partialAggregationMemoryLimitIncrease) { .config( QueryConfig::kMaxExtendedPartialAggregationMemory, std::to_string(testData.extendedPartialMemoryLimit)) - .plan(PlanBuilder() - .values(vectors) - .partialAggregation({"c0"}, {}) - .capturePlanNodeId(aggNodeId) - .finalAggregation() - .planNode()) + .plan( + PlanBuilder() + .values(vectors) + .partialAggregation({"c0"}, {}) + .capturePlanNodeId(aggNodeId) + .finalAggregation() + .planNode()) .assertResults("SELECT distinct c0 FROM tmp"); const auto runtimeStats = toPlanStats(task->taskStats()).at(aggNodeId).customStats; @@ -1218,13 +1210,13 @@ TEST_F(AggregationTest, partialAggregationMaybeReservationReleaseCheck) { TEST_F(AggregationTest, spillAll) { auto inputs = makeVectors(rowType_, 100, 10); - const auto numDistincts = - AssertQueryBuilder(PlanBuilder() - .values(inputs) - .singleAggregation({"c0"}, {}, {}) - .planNode()) - .copyResults(pool_.get()) - ->size(); + const auto numDistincts = AssertQueryBuilder( + PlanBuilder() + .values(inputs) + .singleAggregation({"c0"}, {}, {}) + .planNode()) + .copyResults(pool_.get()) + ->size(); auto plan = PlanBuilder() .values(inputs) @@ -1304,7 +1296,7 @@ TEST_F(AggregationTest, memoryAllocations) { // hash table, 1 for the RowContainer holding accumulators, 2 for results (1 // for values of the grouping key column, 1 for sum column). planStats = toPlanStats(task->taskStats()); - ASSERT_EQ(7, planStats.at(aggNodeId).numMemoryAllocations); + ASSERT_EQ(8, planStats.at(aggNodeId).numMemoryAllocations); } TEST_F(AggregationTest, groupingSets) { @@ -1738,11 +1730,12 @@ TEST_F(AggregationTest, outputBatchSizeCheckWithSpill) { .config( QueryConfig::kMaxOutputBatchRows, std::to_string(testData.maxOutputRows)) - .plan(PlanBuilder() - .values(inputs) - .singleAggregation({"c0"}, {"array_agg(c1)"}) - .capturePlanNodeId(aggrNodeId) - .planNode()) + .plan( + PlanBuilder() + .values(inputs) + .singleAggregation({"c0"}, {"array_agg(c1)"}) + .capturePlanNodeId(aggrNodeId) + .planNode()) .assertResults("SELECT c0, array_agg(c1) FROM tmp GROUP BY 1"); ASSERT_GT(toPlanStats(task->taskStats()).at(aggrNodeId).spilledBytes, 0); ASSERT_EQ( @@ -1802,11 +1795,12 @@ TEST_F(AggregationTest, outputBatchSizeCheckWithSpillForOrderedAggr) { .config( QueryConfig::kMaxOutputBatchRows, std::to_string(testData.maxOutputRows)) - .plan(PlanBuilder() - .values(vectors) - .singleAggregation({"c0"}, {"array_agg(c1 order by c1)"}) - .capturePlanNodeId(aggrNodeId) - .planNode()) + .plan( + PlanBuilder() + .values(vectors) + .singleAggregation({"c0"}, {"array_agg(c1 order by c1)"}) + .capturePlanNodeId(aggrNodeId) + .planNode()) .assertResults( "SELECT c0, array_agg(c1 order by c1) FROM tmp GROUP BY 1"); ASSERT_GT(toPlanStats(task->taskStats()).at(aggrNodeId).spilledBytes, 0); @@ -1848,11 +1842,12 @@ TEST_F(AggregationTest, spillDuringOutputProcessing) { .config( QueryConfig::kMaxOutputBatchRows, std::to_string(numOutputRows)) .config(QueryConfig::kSpillNumPartitionBits, "0") - .plan(PlanBuilder() - .values({input}) - .singleAggregation({"c0", "c1"}, {"max(c2)", "min(c3)"}) - .capturePlanNodeId(aggrNodeId) - .planNode()) + .plan( + PlanBuilder() + .values({input}) + .singleAggregation({"c0", "c1"}, {"max(c2)", "min(c3)"}) + .capturePlanNodeId(aggrNodeId) + .planNode()) .assertResults( "SELECT c0, c1, max(c2), min(c3) FROM tmp GROUP BY 1, 2"); @@ -1928,11 +1923,12 @@ TEST_F(AggregationTest, outputBatchSizeCheckWithoutSpill) { .config( QueryConfig::kMaxOutputBatchRows, std::to_string(testData.maxOutputRows)) - .plan(PlanBuilder() - .values(inputs) - .singleAggregation({"c0"}, {"array_agg(c1)"}) - .capturePlanNodeId(aggrNodeId) - .planNode()) + .plan( + PlanBuilder() + .values(inputs) + .singleAggregation({"c0"}, {"array_agg(c1)"}) + .capturePlanNodeId(aggrNodeId) + .planNode()) .assertResults("SELECT c0, array_agg(c1) FROM tmp GROUP BY 1"); ASSERT_EQ( @@ -1958,17 +1954,18 @@ DEBUG_ONLY_TEST_F(AggregationTest, minSpillableMemoryReservation) { createDuckDbTable(batches); for (int32_t minSpillableReservationPct : {5, 50, 100}) { - SCOPED_TRACE(fmt::format( - "minSpillableReservationPct: {}", minSpillableReservationPct)); + SCOPED_TRACE( + fmt::format( + "minSpillableReservationPct: {}", minSpillableReservationPct)); SCOPED_TESTVALUE_SET( "facebook::velox::exec::GroupingSet::addInputForActiveRows", std::function( ([&](exec::GroupingSet* groupingSet) { - memory::MemoryPool& pool = groupingSet->testingPool(); + memory::MemoryPool* pool = groupingSet->testingPool(); const auto availableReservationBytes = - pool.availableReservation(); - const auto currentUsedBytes = pool.usedBytes(); + pool->availableReservation(); + const auto currentUsedBytes = pool->usedBytes(); // Verifies we always have min reservation after ensuring the // input. ASSERT_GE( @@ -1988,10 +1985,11 @@ DEBUG_ONLY_TEST_F(AggregationTest, minSpillableMemoryReservation) { .config( QueryConfig::kSpillableReservationGrowthPct, std::to_string(minSpillableReservationPct + 1)) - .plan(PlanBuilder() - .values(batches) - .singleAggregation({"c0"}, {"array_agg(c2)", "max(c3)"}) - .planNode()) + .plan( + PlanBuilder() + .values(batches) + .singleAggregation({"c0"}, {"array_agg(c2)", "max(c3)"}) + .planNode()) .assertResults( "SELECT c0, array_agg(c2), max(c3) FROM tmp GROUP BY 1"); OperatorTestBase::deleteTaskAndCheckSpillDirectory(task); @@ -2026,11 +2024,12 @@ TEST_F(AggregationTest, distinctWithSpilling) { .spillDirectory(spillDirectory->getPath()) .config(QueryConfig::kSpillEnabled, true) .config(QueryConfig::kAggregationSpillEnabled, true) - .plan(PlanBuilder() - .values(testParam.inputs) - .singleAggregation({"c0"}, {}, {}) - .capturePlanNodeId(aggrNodeId) - .planNode()) + .plan( + PlanBuilder() + .values(testParam.inputs) + .singleAggregation({"c0"}, {}, {}) + .capturePlanNodeId(aggrNodeId) + .planNode()) .assertResults("SELECT distinct c0 FROM tmp"); // Verify that spilling is not triggered. @@ -2056,11 +2055,12 @@ TEST_F(AggregationTest, spillingForAggrsWithDistinct) { .spillDirectory(spillDirectory->getPath()) .config(QueryConfig::kSpillEnabled, true) .config(QueryConfig::kAggregationSpillEnabled, true) - .plan(PlanBuilder() - .values(vectors) - .singleAggregation({"c1"}, {"count(DISTINCT c0)"}, {}) - .capturePlanNodeId(aggrNodeId) - .planNode()) + .plan( + PlanBuilder() + .values(vectors) + .singleAggregation({"c1"}, {"count(DISTINCT c0)"}, {}) + .capturePlanNodeId(aggrNodeId) + .planNode()) .assertResults("SELECT c1, count(DISTINCT c0) FROM tmp GROUP BY c1"); // Verify that spilling is not triggered. const auto& queryConfig = task->queryCtx()->queryConfig(); @@ -2355,17 +2355,18 @@ TEST_F(AggregationTest, preGroupedAggregationWithSpilling) { .spillDirectory(spillDirectory->getPath()) .config(QueryConfig::kSpillEnabled, true) .config(QueryConfig::kAggregationSpillEnabled, true) - .plan(PlanBuilder() - .values(vectors) - .aggregation( - {"c0", "c1"}, - {"c0"}, - {"sum(c2)"}, - {}, - core::AggregationNode::Step::kSingle, - false) - .capturePlanNodeId(aggrNodeId) - .planNode()) + .plan( + PlanBuilder() + .values(vectors) + .aggregation( + {"c0", "c1"}, + {"c0"}, + {"sum(c2)"}, + {}, + core::AggregationNode::Step::kSingle, + false) + .capturePlanNodeId(aggrNodeId) + .planNode()) .assertResults("SELECT c0, c1, sum(c2) FROM tmp GROUP BY c0, c1"); auto stats = task->taskStats().pipelineStats; // Verify that spilling is not triggered. @@ -2450,8 +2451,9 @@ DEBUG_ONLY_TEST_F(AggregationTest, reclaimDuringInputProcessing) { auto tempDirectory = exec::test::TempDirectoryPath::create(); auto queryCtx = core::QueryCtx::create(executor_.get()); - queryCtx->testingOverrideMemoryPool(memory::memoryManager()->addRootPool( - queryCtx->queryId(), kMaxBytes, memory::MemoryReclaimer::create())); + queryCtx->testingOverrideMemoryPool( + memory::memoryManager()->addRootPool( + queryCtx->queryId(), kMaxBytes, memory::MemoryReclaimer::create())); auto expectedResult = AssertQueryBuilder( PlanBuilder() @@ -2600,13 +2602,15 @@ DEBUG_ONLY_TEST_F(AggregationTest, reclaimDuringReserve) { auto tempDirectory = exec::test::TempDirectoryPath::create(); auto queryCtx = core::QueryCtx::create(executor_.get()); - queryCtx->testingOverrideMemoryPool(memory::memoryManager()->addRootPool( - queryCtx->queryId(), kMaxBytes, memory::MemoryReclaimer::create())); + queryCtx->testingOverrideMemoryPool( + memory::memoryManager()->addRootPool( + queryCtx->queryId(), kMaxBytes, memory::MemoryReclaimer::create())); auto expectedResult = - AssertQueryBuilder(PlanBuilder() - .values(batches) - .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) - .planNode()) + AssertQueryBuilder( + PlanBuilder() + .values(batches) + .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) + .planNode()) .queryCtx(queryCtx) .copyResults(pool_.get()); @@ -2651,10 +2655,11 @@ DEBUG_ONLY_TEST_F(AggregationTest, reclaimDuringReserve) { }))); std::thread taskThread([&]() { - AssertQueryBuilder(PlanBuilder() - .values(batches) - .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) - .planNode()) + AssertQueryBuilder( + PlanBuilder() + .values(batches) + .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) + .planNode()) .queryCtx(queryCtx) .spillDirectory(tempDirectory->getPath()) .config(QueryConfig::kSpillEnabled, true) @@ -2838,8 +2843,9 @@ DEBUG_ONLY_TEST_F(AggregationTest, reclaimDuringOutputProcessing) { auto tempDirectory = exec::test::TempDirectoryPath::create(); auto queryCtx = core::QueryCtx::create(executor_.get()); - queryCtx->testingOverrideMemoryPool(memory::memoryManager()->addRootPool( - queryCtx->queryId(), kMaxBytes, memory::MemoryReclaimer::create())); + queryCtx->testingOverrideMemoryPool( + memory::memoryManager()->addRootPool( + queryCtx->queryId(), kMaxBytes, memory::MemoryReclaimer::create())); auto expectedResult = AssertQueryBuilder( PlanBuilder() @@ -3012,7 +3018,7 @@ DEBUG_ONLY_TEST_F(AggregationTest, reclaimDuringNonReclaimableSection) { if (!testData.nonReclaimableInput) { return; } - if (groupSet->testingPool().usedBytes() == 0) { + if (groupSet->testingPool()->usedBytes() == 0) { return; } if (!injectNonReclaimableSectionOnce.exchange(false)) { @@ -3497,8 +3503,9 @@ DEBUG_ONLY_TEST_F(AggregationTest, reclaimEmptyInput) { auto tempDirectory = exec::test::TempDirectoryPath::create(); auto queryCtx = core::QueryCtx::create(executor_.get()); - queryCtx->testingOverrideMemoryPool(memory::memoryManager()->addRootPool( - queryCtx->queryId(), kMaxBytes, memory::MemoryReclaimer::create())); + queryCtx->testingOverrideMemoryPool( + memory::memoryManager()->addRootPool( + queryCtx->queryId(), kMaxBytes, memory::MemoryReclaimer::create())); core::PlanNodeId aggNodeId; auto task = AssertQueryBuilder( @@ -3526,10 +3533,11 @@ DEBUG_ONLY_TEST_F(AggregationTest, reclaimEmptyOutput) { auto batches = makeVectors(rowType, 100, 5); auto expectedResult = - AssertQueryBuilder(PlanBuilder() - .values(batches) - .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) - .planNode()) + AssertQueryBuilder( + PlanBuilder() + .values(batches) + .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) + .planNode()) .copyResults(pool_.get()); std::atomic_int numGetOutput{0}; @@ -3569,24 +3577,25 @@ DEBUG_ONLY_TEST_F(AggregationTest, reclaimEmptyOutput) { auto tempDirectory = exec::test::TempDirectoryPath::create(); auto queryCtx = core::QueryCtx::create(executor_.get()); - queryCtx->testingOverrideMemoryPool(memory::memoryManager()->addRootPool( - queryCtx->queryId(), kMaxBytes, memory::MemoryReclaimer::create())); + queryCtx->testingOverrideMemoryPool( + memory::memoryManager()->addRootPool( + queryCtx->queryId(), kMaxBytes, memory::MemoryReclaimer::create())); core::PlanNodeId aggNodeId; - auto task = - AssertQueryBuilder(PlanBuilder() - .values(batches) - .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) - .capturePlanNodeId(aggNodeId) - .planNode()) - .spillDirectory(tempDirectory->getPath()) - .queryCtx(queryCtx) - .config(QueryConfig::kSpillEnabled, true) - .config(QueryConfig::kAggregationSpillEnabled, true) - // Set the output query configs to ensure fetch the result in one - // output batch. - .config(QueryConfig::kPreferredOutputBatchBytes, 1UL << 30) - .config(QueryConfig::kMaxOutputBatchRows, 1024) - .assertResults(expectedResult); + auto task = AssertQueryBuilder( + PlanBuilder() + .values(batches) + .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) + .capturePlanNodeId(aggNodeId) + .planNode()) + .spillDirectory(tempDirectory->getPath()) + .queryCtx(queryCtx) + .config(QueryConfig::kSpillEnabled, true) + .config(QueryConfig::kAggregationSpillEnabled, true) + // Set the output query configs to ensure fetch the result in + // one output batch. + .config(QueryConfig::kPreferredOutputBatchBytes, 1UL << 30) + .config(QueryConfig::kMaxOutputBatchRows, 1024) + .assertResults(expectedResult); // Since the spilling is triggered after the aggregation operator has produced // all the output, we don't expect any spilled data. auto taskStats = exec::toPlanStats(task->taskStats()); @@ -3673,11 +3682,12 @@ DEBUG_ONLY_TEST_F(AggregationTest, reclaimFromAggregation) { .config( core::QueryConfig::kMaxSpillRunRows, std::to_string(maxSpillRunRows)) - .plan(PlanBuilder() - .values(vectors) - .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) - .capturePlanNodeId(aggrNodeId) - .planNode()) + .plan( + PlanBuilder() + .values(vectors) + .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) + .capturePlanNodeId(aggrNodeId) + .planNode()) .assertResults( "SELECT c0, c1, array_agg(c2) FROM tmp GROUP BY c0, c1"); auto taskStats = exec::toPlanStats(task->taskStats()); @@ -3726,11 +3736,12 @@ DEBUG_ONLY_TEST_F(AggregationTest, reclaimFromDistinctAggregation) { .config( core::QueryConfig::kMaxSpillRunRows, std::to_string(maxSpillRunRows)) - .plan(PlanBuilder() - .values(vectors) - .singleAggregation({"c0"}, {}) - .capturePlanNodeId(aggrNodeId) - .planNode()) + .plan( + PlanBuilder() + .values(vectors) + .singleAggregation({"c0"}, {}) + .capturePlanNodeId(aggrNodeId) + .planNode()) .assertResults("SELECT distinct c0 FROM tmp"); auto taskStats = exec::toPlanStats(task->taskStats()); auto& planStats = taskStats.at(aggrNodeId); @@ -3765,10 +3776,11 @@ DEBUG_ONLY_TEST_F(AggregationTest, reclaimFromAggregationOnNoMoreInput) { .config(core::QueryConfig::kSpillEnabled, true) .config(core::QueryConfig::kAggregationSpillEnabled, true) .maxDrivers(1) - .plan(PlanBuilder() - .values(vectors) - .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) - .planNode()) + .plan( + PlanBuilder() + .values(vectors) + .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) + .planNode()) .assertResults( "SELECT c0, c1, array_agg(c2) FROM tmp GROUP BY c0, c1"); auto stats = task->taskStats().pipelineStats; @@ -3810,10 +3822,11 @@ DEBUG_ONLY_TEST_F(AggregationTest, reclaimFromAggregationDuringOutput) { .config(core::QueryConfig::kPreferredOutputBatchRows, numRows / 10) .maxDrivers(1) //.queryCtx(aggregationQueryCtx) - .plan(PlanBuilder() - .values(vectors) - .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) - .planNode()) + .plan( + PlanBuilder() + .values(vectors) + .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) + .planNode()) .assertResults( "SELECT c0, c1, array_agg(c2) FROM tmp GROUP BY c0, c1"); auto stats = task->taskStats().pipelineStats; @@ -3832,10 +3845,11 @@ TEST_F(AggregationTest, reclaimFromCompletedAggregation) { std::thread aggregationThread([&]() { auto task = AssertQueryBuilder(duckDbQueryRunner_) - .plan(PlanBuilder() - .values(vectors) - .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) - .planNode()) + .plan( + PlanBuilder() + .values(vectors) + .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) + .planNode()) .assertResults( "SELECT c0, c1, array_agg(c2) FROM tmp GROUP BY c0, c1"); waitForTaskCompletion(task.get()); @@ -4015,6 +4029,7 @@ TEST_F(AggregationTest, destroyAfterPartialInitialization) { false, // isJoinBuild false, // hasProbedFlag false, // hasNormalizedKeys + false, // useListRowIndex pool()); const auto rowColumn = rows.columnAt(0); agg.setOffsets( @@ -4030,6 +4045,64 @@ TEST_F(AggregationTest, destroyAfterPartialInitialization) { ASSERT_TRUE(agg.destroyCalled); } +DEBUG_ONLY_TEST_F( + AggregationTest, + uninitializedDistinctAggrWithExternalMemAggrDuringAbort) { + const auto createInput = + [&](int32_t startKey, uint32_t numGroups, uint32_t numElementsPerGroup) { + return makeRowVector({ + makeFlatVector([&]() { + std::vector keys; + for (auto i = 0; i < numGroups; ++i) { + for (auto j = 0; j < numElementsPerGroup; ++j) { + keys.push_back(startKey + i); + } + } + return keys; + }()), + makeFlatVector( + numGroups * numElementsPerGroup, + [&](auto row) { return startKey; }), + }); + }; + + std::vector inputs; + inputs.emplace_back(createInput(0, 10000, 10)); + createDuckDbTable(inputs); + + GroupingSet* groupingSet{nullptr}; + + std::atomic_bool groupingSetExtracted{false}; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::GroupingSet::addInputForActiveRows", + std::function([&](GroupingSet* _groupingSet) { + if (!groupingSetExtracted.exchange(true)) { + groupingSet = _groupingSet; + } + })); + + SCOPED_TESTVALUE_SET( + "facebook::velox::memory::MemoryPoolImpl::reserveThreadSafe", + std::function([&](void* /*unused*/) { + if (groupingSet == nullptr) { + return; + } + if (groupingSet->numRows() > 0) { + VELOX_FAIL("Inject allocation failure."); + } + })); + + auto plan = PlanBuilder() + .values(inputs) + .singleAggregation({"c0"}, {"array_agg(distinct c1)"}) + .planNode(); + + VELOX_ASSERT_THROW( + assertQuery( + plan, "SELECT c0, array_agg(distinct c1) FROM tmp GROUP BY c0"), + "Inject allocation failure."); +} + TEST_F(AggregationTest, nanKeys) { // Some keys are NaNs. auto kNaN = std::numeric_limits::quiet_NaN(); @@ -4065,4 +4138,40 @@ TEST_F(AggregationTest, nanKeys) { {makeRowVector({c0, c1}), c1}, {makeRowVector({e0, e1}), e1}); } + +TEST_F(AggregationTest, keysProvideCustomComparison) { + // Columns reused across test cases. + auto c0 = makeFlatVector( + {0, 1, 256, 257, 512, 513}, + velox::test::BIGINT_TYPE_WITH_CUSTOM_COMPARISON()); + auto c1 = makeFlatVector({1, 2, 1, 2, 1, 2}); + // Expected result columns reused across test cases. A deduplicated version of + // c0 and c1. + auto e0 = makeFlatVector( + {0, 1}, velox::test::BIGINT_TYPE_WITH_CUSTOM_COMPARISON()); + auto e1 = makeFlatVector({1, 2}); + + auto testDistinctAgg = [&](const std::vector& aggKeys, + const std::vector& inputCols, + const std::vector& expectedCols) { + auto plan = PlanBuilder() + .values({makeRowVector(inputCols)}) + .singleAggregation(aggKeys, {}, {}) + .planNode(); + AssertQueryBuilder(plan).assertResults(makeRowVector(expectedCols)); + }; + + // Test with a primitive type key. + testDistinctAgg({"c0"}, {c0}, {e0}); + // Multiple key columns. + testDistinctAgg({"c0", "c1"}, {c0, c1}, {e0, e1}); + + // Test with a complex type key. + testDistinctAgg({"c0"}, {makeRowVector({c0, c1})}, {makeRowVector({e0, e1})}); + // Multiple key columns. + testDistinctAgg( + {"c0", "c1"}, + {makeRowVector({c0, c1}), c1}, + {makeRowVector({e0, e1}), e1}); +} } // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/AssertQueryBuilderTest.cpp b/velox/exec/tests/AssertQueryBuilderTest.cpp index db4c44eb2c9f..6ae70783f7a9 100644 --- a/velox/exec/tests/AssertQueryBuilderTest.cpp +++ b/velox/exec/tests/AssertQueryBuilderTest.cpp @@ -125,7 +125,7 @@ TEST_F(AssertQueryBuilderTest, hiveSplits) { } // Split with partition key. - ColumnHandleMap assignments = { + connector::ColumnHandleMap assignments = { {"ds", partitionKey("ds", VARCHAR())}, {"c0", regularColumn("c0", BIGINT())}}; diff --git a/velox/exec/tests/AssignUniqueIdTest.cpp b/velox/exec/tests/AssignUniqueIdTest.cpp index 852a31adee9a..6a34e3bc1418 100644 --- a/velox/exec/tests/AssignUniqueIdTest.cpp +++ b/velox/exec/tests/AssignUniqueIdTest.cpp @@ -16,31 +16,51 @@ #include "velox/common/base/tests/GTestUtils.h" #include "velox/exec/PlanNodeStats.h" #include "velox/exec/tests/utils/AssertQueryBuilder.h" -#include "velox/exec/tests/utils/OperatorTestBase.h" +#include "velox/exec/tests/utils/HiveConnectorTestBase.h" #include "velox/exec/tests/utils/PlanBuilder.h" #include "velox/exec/tests/utils/QueryAssertions.h" -using namespace facebook::velox; +namespace facebook::velox::exec { + using namespace facebook::velox::test; -using namespace facebook::velox::exec; using namespace facebook::velox::exec::test; -class AssignUniqueIdTest : public OperatorTestBase { +namespace { + +class AssignUniqueIdTest : public HiveConnectorTestBase { protected: + void SetUp() override { + HiveConnectorTestBase::SetUp(); + } + void verifyUniqueId( const std::shared_ptr& plan, const std::vector& input) { CursorParameters params; params.planNode = plan; - auto result = readCursor(params); - auto numColumns = result.second[0]->childrenSize(); + ASSERT_EQ(result.second[0]->childrenSize(), input[0]->childrenSize() + 1); + verifyUniqueId(input, result.second); + + auto task = result.first->task(); + // Verify number of memory allocations. There should be exactly one + // allocation (per thread of execution) for the values buffer of the + // unique ID vector. Memory should be allocated when producing first + // batch of output and re-used for subsequent batches. + auto stats = toPlanStats(task->taskStats()); + ASSERT_EQ(1, stats.at(uniqueNodeId_).numMemoryAllocations); + } + + void verifyUniqueId( + const std::vector& input, + const std::vector& vectors) { + auto numColumns = vectors[0]->childrenSize(); ASSERT_EQ(numColumns, input[0]->childrenSize() + 1); std::set ids; for (int i = 0; i < numColumns; i++) { - for (auto batch = 0; batch < result.second.size(); ++batch) { - auto column = result.second[batch]->childAt(i); + for (auto batch = 0; batch < vectors.size(); ++batch) { + auto column = vectors[batch]->childAt(i); if (i < numColumns - 1) { assertEqualVectors(input[batch]->childAt(i), column); } else { @@ -59,15 +79,6 @@ class AssignUniqueIdTest : public OperatorTestBase { } ASSERT_EQ(totalInputSize, ids.size()); - - auto task = result.first->task(); - - // Verify number of memory allocations. There should be exactly one - // allocation (per thread of execution) for the values buffer of the unique - // ID vector. Memory should be allocated when producing first batch of - // output and re-used for subsequent batches. - auto stats = toPlanStats(task->taskStats()); - ASSERT_EQ(1, stats.at(uniqueNodeId_).numMemoryAllocations); } core::PlanNodeId uniqueNodeId_; @@ -76,6 +87,7 @@ class AssignUniqueIdTest : public OperatorTestBase { TEST_F(AssignUniqueIdTest, multiBatch) { vector_size_t batchSize = 1000; std::vector input; + input.reserve(3); for (int i = 0; i < 3; ++i) { input.push_back( makeRowVector({makeFlatVector(batchSize, folly::identity)})); @@ -133,9 +145,9 @@ TEST_F(AssignUniqueIdTest, multiThread) { ASSERT_EQ(batchSize * 8, ids.size()); // Verify number of memory allocations. There should be exactly one - // allocation (per thread of execution) for the values buffer of the unique - // ID vector. Memory should be allocated when producing first batch of - // output and re-used for subsequent batches. + // allocation (per thread of execution) for the values buffer of the + // unique ID vector. Memory should be allocated when producing first batch + // of output and re-used for subsequent batches. auto stats = toPlanStats(task->taskStats()); ASSERT_EQ(8, stats.at(uniqueNodeId_).numMemoryAllocations); } @@ -166,3 +178,46 @@ TEST_F(AssignUniqueIdTest, taskUniqueIdLimit) { AssertQueryBuilder(plan).copyResults(pool()), "(16777216 vs. 16777216) Unique 24-bit ID specified for AssignUniqueId exceeds the limit"); } + +TEST_F(AssignUniqueIdTest, barrier) { + auto rowType{ROW({"c0", "c1"}, {BIGINT(), INTEGER()})}; + + const int numSplits{5}; + + std::vector vectors; + std::vector> tempFiles; + + const int numRowsPerSplit{100}; + for (int32_t i = 0; i < numSplits; ++i) { + auto vector = makeRowVector(rowType, {.vectorSize = numRowsPerSplit}); + vectors.push_back(vector); + tempFiles.push_back(TempFilePath::create()); + } + writeToFiles(toFilePaths(tempFiles), vectors); + + auto plan = PlanBuilder() + .tableScan(rowType) + .assignUniqueId("row_number") + .project({"c0", "c1", "row_number"}) + .planNode(); + + for (const auto barrierExecution : {false, true}) { + SCOPED_TRACE(fmt::format("barrierExecution {}", barrierExecution)); + + std::shared_ptr task; + auto result = AssertQueryBuilder(plan) + .splits(makeHiveConnectorSplits(tempFiles)) + .serialExecution(true) + .barrierExecution(barrierExecution) + .copyResults(pool(), task); + auto results = split(result, numSplits); + + verifyUniqueId(vectors, results); + + const auto taskStats = task->taskStats(); + ASSERT_EQ(taskStats.numBarriers, barrierExecution ? numSplits : 0); + ASSERT_EQ(taskStats.numFinishedSplits, numSplits); + } +} +} // namespace +} // namespace facebook::velox::exec diff --git a/velox/exec/tests/AsyncConnectorTest.cpp b/velox/exec/tests/AsyncConnectorTest.cpp index 7f2ad55c00d6..91533ea7f904 100644 --- a/velox/exec/tests/AsyncConnectorTest.cpp +++ b/velox/exec/tests/AsyncConnectorTest.cpp @@ -32,11 +32,16 @@ const std::string kTestConnectorId = "test"; class TestTableHandle : public connector::ConnectorTableHandle { public: - TestTableHandle() : connector::ConnectorTableHandle(kTestConnectorId) {} + TestTableHandle(std::string name) + : connector::ConnectorTableHandle(kTestConnectorId), + name_{std::move(name)} {} - std::string toString() const override { - VELOX_NYI(); + const std::string& name() const override { + return name_; } + + private: + const std::string name_; }; class TestSplit : public connector::ConnectorSplit { @@ -55,7 +60,8 @@ class TestSplit : public connector::ConnectorSplit { return ContinueFuture::makeEmpty(); } - auto [promise, future] = makeVeloxContinuePromiseContract(); + auto [promise, future] = + makeVeloxContinuePromiseContract("TestSplit::touch"); promise_ = std::move(promise); scheduler_.addFunction( @@ -121,7 +127,7 @@ class TestDataSource : public connector::DataSource { return 0; } - std::unordered_map runtimeStats() override { + std::unordered_map getRuntimeStats() override { return {}; } @@ -137,18 +143,17 @@ class TestConnector : public connector::Connector { std::unique_ptr createDataSource( const RowTypePtr& /* outputType */, - const std::shared_ptr& /* tableHandle */, + const ConnectorTableHandlePtr& /* tableHandle */, const std::unordered_map< std::string, - std::shared_ptr>& /* columnHandles */, + connector::ColumnHandlePtr>& /* columnHandles */, connector::ConnectorQueryCtx* connectorQueryCtx) override { return std::make_unique(connectorQueryCtx->memoryPool()); } std::unique_ptr createDataSink( RowTypePtr /*inputType*/, - std::shared_ptr< - ConnectorInsertTableHandle> /*connectorInsertTableHandle*/, + ConnectorInsertTableHandlePtr /*connectorInsertTableHandle*/, ConnectorQueryCtx* /*connectorQueryCtx*/, CommitStrategy /*commitStrategy*/) override final { VELOX_NYI(); @@ -175,15 +180,13 @@ class AsyncConnectorTest : public OperatorTestBase { public: void SetUp() override { OperatorTestBase::SetUp(); - connector::registerConnectorFactory( - std::make_shared()); - auto testConnector = - connector::getConnectorFactory(TestConnectorFactory::kTestConnectorName) - ->newConnector( - kTestConnectorId, - std::make_shared( - std::unordered_map()), - nullptr); + TestConnectorFactory factory; + auto testConnector = factory.newConnector( + kTestConnectorId, + std::make_shared( + std::unordered_map()), + nullptr, + nullptr); connector::registerConnector(testConnector); } @@ -194,7 +197,7 @@ class AsyncConnectorTest : public OperatorTestBase { }; TEST_F(AsyncConnectorTest, basic) { - auto tableHandle = std::make_shared(); + auto tableHandle = std::make_shared("test"); core::PlanNodeId scanId; auto plan = PlanBuilder() .startTableScan() diff --git a/velox/exec/tests/CMakeLists.txt b/velox/exec/tests/CMakeLists.txt index 70342b290739..882fcfcf7052 100644 --- a/velox/exec/tests/CMakeLists.txt +++ b/velox/exec/tests/CMakeLists.txt @@ -15,13 +15,16 @@ add_subdirectory(utils) add_executable( aggregate_companion_functions_test - AggregateCompanionAdapterTest.cpp AggregateCompanionSignaturesTest.cpp - DummyAggregateFunction.cpp) + AggregateCompanionAdapterTest.cpp + AggregateCompanionSignaturesTest.cpp + DummyAggregateFunction.cpp +) add_test( NAME aggregate_companion_functions_test COMMAND aggregate_companion_functions_test - WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} +) target_link_libraries( aggregate_companion_functions_test @@ -30,7 +33,8 @@ target_link_libraries( velox_exec_test_lib velox_presto_types GTest::gtest - GTest::gtest_main) + GTest::gtest_main +) add_executable( velox_exec_test @@ -42,11 +46,13 @@ add_executable( AsyncConnectorTest.cpp ConcatFilesSpillMergeStreamTest.cpp ContainerRowSerdeTest.cpp + ColumnStatsCollectorTest.cpp CustomJoinTest.cpp EnforceSingleRowTest.cpp ExchangeClientTest.cpp ExpandTest.cpp FilterProjectTest.cpp + FilterToExpressionTest.cpp FunctionResolutionTest.cpp HashBitRangeTest.cpp HashJoinBridgeTest.cpp @@ -61,17 +67,19 @@ add_executable( MemoryReclaimerTest.cpp MergeJoinTest.cpp MergeTest.cpp + MergerTest.cpp MultiFragmentTest.cpp NestedLoopJoinTest.cpp OrderByTest.cpp OperatorTraceTest.cpp OutputBufferManagerTest.cpp + ParallelProjectTest.cpp PartitionedOutputTest.cpp PlanNodeSerdeTest.cpp + PlanNodeStatsTest.cpp PlanNodeToStringTest.cpp PlanNodeToSummaryStringTest.cpp PrefixSortTest.cpp - PrestoQueryRunnerTimestampWithTimeZoneTransformTest.cpp PrintPlanWithStatsTest.cpp ProbeOperatorStateTest.cpp TraceUtilTest.cpp @@ -81,8 +89,11 @@ add_executable( ScaledScanControllerTest.cpp ScaleWriterLocalPartitionTest.cpp SortBufferTest.cpp + SpatialIndexTest.cpp + HilbertIndexTest.cpp SpillerTest.cpp SpillTest.cpp + SplitListenerTest.cpp SplitTest.cpp SqlTest.cpp StreamingAggregationTest.cpp @@ -98,7 +109,12 @@ add_executable( VectorHasherTest.cpp WindowFunctionRegistryTest.cpp WindowTest.cpp - WriterFuzzerUtilTest.cpp) + WriterFuzzerUtilTest.cpp +) + +if(VELOX_ENABLE_GEO) + target_sources(velox_exec_test PRIVATE SpatialJoinTest.cpp) +endif() add_executable( velox_exec_infra_test @@ -112,19 +128,20 @@ add_executable( PrestoQueryRunnerTest.cpp QueryAssertionsTest.cpp TaskTest.cpp - TreeOfLosersTest.cpp) + TreeOfLosersTest.cpp +) -add_test( - NAME velox_exec_test - COMMAND velox_exec_test - WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) +add_test(NAME velox_exec_test COMMAND velox_exec_test WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) -set_tests_properties(velox_exec_test PROPERTIES TIMEOUT 3000) +# TODO: Revert back to 3000 once is fixed. +# https://github.com/facebookincubator/velox/issues/13879 +set_tests_properties(velox_exec_test PROPERTIES TIMEOUT 6000) add_test( NAME velox_exec_infra_test COMMAND velox_exec_infra_test - WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} +) target_link_libraries( velox_exec_test @@ -132,6 +149,7 @@ target_link_libraries( velox_dwio_common velox_dwio_common_exception velox_dwio_common_test_utils + velox_dwio_orc_reader velox_dwio_parquet_reader velox_dwio_parquet_writer velox_exec @@ -156,7 +174,6 @@ target_link_libraries( Boost::date_time Boost::filesystem Boost::program_options - Boost::regex Boost::thread Boost::system GTest::gtest @@ -165,7 +182,8 @@ target_link_libraries( Folly::folly gflags::gflags glog::glog - fmt::fmt) + fmt::fmt +) target_link_libraries( velox_exec_infra_test @@ -193,7 +211,6 @@ target_link_libraries( Boost::date_time Boost::filesystem Boost::program_options - Boost::regex Boost::thread Boost::system GTest::gtest @@ -202,7 +219,36 @@ target_link_libraries( Folly::folly gflags::gflags glog::glog - fmt::fmt) + fmt::fmt +) + +add_executable( + velox_exec_util_test + Main.cpp + PrestoQueryRunnerHyperLogLogTransformTest.cpp + PrestoQueryRunnerIntermediateTypeTransformTestBase.cpp + PrestoQueryRunnerTDigestTransformTest.cpp + PrestoQueryRunnerQDigestTransformTest.cpp + PrestoQueryRunnerKHyperLogLogTransformTest.cpp + PrestoQueryRunnerSetDigestTransformTest.cpp + PrestoQueryRunnerJsonTransformTest.cpp + PrestoQueryRunnerIntervalTransformTest.cpp + PrestoQueryRunnerTimeTransformTest.cpp + PrestoQueryRunnerTimestampWithTimeZoneTransformTest.cpp +) + +add_test(velox_exec_util_test velox_exec_util_test) + +target_link_libraries( + velox_exec_util_test + velox_fuzzer_util + velox_exec_test_lib + velox_functions_test_lib + velox_presto_types + velox_presto_types_fuzzer_utils + GTest::gtest + GTest::gtest_main +) add_executable(velox_in_10_min_demo VeloxIn10MinDemo.cpp) @@ -215,31 +261,34 @@ target_link_libraries( velox_exec velox_exec_test_lib velox_tpch_connector - velox_memory) - -# Arbitration Fuzzer. -add_executable(velox_memory_arbitration_fuzzer_test - MemoryArbitrationFuzzerTest.cpp) - -target_link_libraries( - velox_memory_arbitration_fuzzer_test velox_memory_arbitration_fuzzer - GTest::gtest GTest::gtest_main) + velox_memory +) -add_executable(velox_table_evolution_fuzzer_test TableEvolutionFuzzerTest.cpp - TableEvolutionFuzzer.cpp) +add_executable( + velox_table_evolution_fuzzer_test + TableEvolutionFuzzerTest.cpp + TableEvolutionFuzzer.cpp +) target_link_libraries( velox_table_evolution_fuzzer_test + velox_expression_fuzzer + velox_expression_test_utility + velox_constrained_input_generators + velox_fuzzer_util + velox_functions_prestosql velox_exec_test_lib velox_temp_path velox_vector_fuzzer GTest::gtest - GTest::gtest_main) + GTest::gtest_main +) add_test( NAME velox_table_evolution_fuzzer_test COMMAND velox_table_evolution_fuzzer_test - WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} +) add_executable(velox_aggregation_runner_test AggregationRunnerTest.cpp) @@ -250,33 +299,37 @@ target_link_libraries( velox_aggregates velox_vector_test_lib GTest::gtest - GTest::gtest_main) + GTest::gtest_main +) -add_library(velox_simple_aggregate SimpleAverageAggregate.cpp - SimpleArrayAggAggregate.cpp) +add_library(velox_simple_aggregate SimpleAverageAggregate.cpp SimpleArrayAggAggregate.cpp) target_link_libraries( velox_simple_aggregate velox_exec velox_expression velox_expression_functions - velox_aggregates) + velox_aggregates +) -add_executable(velox_simple_aggregate_test SimpleAggregateAdapterTest.cpp - Main.cpp) +add_executable(velox_simple_aggregate_test SimpleAggregateAdapterTest.cpp Main.cpp) target_link_libraries( velox_simple_aggregate_test velox_simple_aggregate velox_exec velox_functions_aggregates_test_lib - GTest::gtest) + GTest::gtest +) add_test(velox_simple_aggregate_test velox_simple_aggregate_test) if(VELOX_ENABLE_BENCHMARKS) - add_library(velox_spiller_join_benchmark_base JoinSpillInputBenchmarkBase.cpp - SpillerBenchmarkBase.cpp) + add_library( + velox_spiller_join_benchmark_base + JoinSpillInputBenchmarkBase.cpp + SpillerBenchmarkBase.cpp + ) target_link_libraries( velox_spiller_join_benchmark_base velox_exec @@ -286,14 +339,22 @@ if(VELOX_ENABLE_BENCHMARKS) glog::glog gflags::gflags Folly::folly - pthread) + pthread + ) add_executable(velox_spiller_join_benchmark SpillerJoinInputBenchmarkTest.cpp) - target_link_libraries(velox_spiller_join_benchmark velox_exec - velox_exec_test_lib velox_spiller_join_benchmark_base) + target_link_libraries( + velox_spiller_join_benchmark + velox_exec + velox_exec_test_lib + velox_spiller_join_benchmark_base + ) - add_library(velox_spiller_aggregate_benchmark_base - AggregateSpillBenchmarkBase.cpp SpillerBenchmarkBase.cpp) + add_library( + velox_spiller_aggregate_benchmark_base + AggregateSpillBenchmarkBase.cpp + SpillerBenchmarkBase.cpp + ) target_link_libraries( velox_spiller_aggregate_benchmark_base velox_exec @@ -303,28 +364,23 @@ if(VELOX_ENABLE_BENCHMARKS) glog::glog gflags::gflags Folly::folly - pthread) + pthread + ) - add_executable(velox_spiller_aggregate_benchmark - SpillerAggregateBenchmarkTest.cpp) + add_executable(velox_spiller_aggregate_benchmark SpillerAggregateBenchmarkTest.cpp) target_link_libraries( - velox_spiller_aggregate_benchmark velox_exec velox_exec_test_lib - velox_spiller_aggregate_benchmark_base) + velox_spiller_aggregate_benchmark + velox_exec + velox_exec_test_lib + velox_spiller_aggregate_benchmark_base + ) endif() -add_executable(cpr_http_client_test CprHttpClientTest.cpp) -add_test( - NAME cpr_http_client_test - COMMAND cpr_http_client_test - WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) -target_link_libraries( - cpr_http_client_test cpr::cpr GTest::gtest GTest::gtest_main) - add_executable(velox_driver_test OperatorReplacementTest.cpp Main.cpp) add_test( NAME velox_driver_test COMMAND velox_driver_test - WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} +) -target_link_libraries( - velox_driver_test velox_exec velox_exec_test_lib GTest::gtest) +target_link_libraries(velox_driver_test velox_exec velox_exec_test_lib GTest::gtest) diff --git a/velox/exec/tests/ColumnStatsCollectorTest.cpp b/velox/exec/tests/ColumnStatsCollectorTest.cpp new file mode 100644 index 000000000000..250ae45e9203 --- /dev/null +++ b/velox/exec/tests/ColumnStatsCollectorTest.cpp @@ -0,0 +1,289 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/core/PlanNode.h" +#include "velox/duckdb/conversion/DuckParser.h" +#include "velox/exec/ColumnStatsCollector.h" +#include "velox/exec/tests/utils/AggregationResolver.h" +#include "velox/exec/tests/utils/AssertQueryBuilder.h" +#include "velox/exec/tests/utils/OperatorTestBase.h" +#include "velox/exec/tests/utils/PlanBuilder.h" +#include "velox/parse/TypeResolver.h" + +namespace facebook::velox::exec::test { + +class ColumnStatsCollectorTest : public OperatorTestBase { + protected: + static void SetUpTestCase() { + OperatorTestBase::SetUpTestCase(); + } + + void SetUp() override { + OperatorTestBase::SetUp(); + } + + core::ColumnStatsSpec createStatsSpec( + const RowTypePtr& type, + const std::vector& groupingKeys, + core::AggregationNode::Step step, + const std::vector& aggregates) { + VELOX_CHECK( + step == core::AggregationNode::Step::kPartial || + step == core::AggregationNode::Step::kSingle, + "Unsupported aggregation step: {}", + core::AggregationNode::toName(step)); + + std::vector aggs; + aggs.reserve(aggregates.size()); + std::vector names; + names.reserve(aggregates.size()); + + duckdb::ParseOptions options; + options.parseIntegerAsBigint = true; + AggregateTypeResolver resolver(step); + + for (auto i = 0; i < aggregates.size(); ++i) { + const auto& aggregate = aggregates[i]; + const auto untypedExpr = duckdb::parseAggregateExpr(aggregate, options); + + core::AggregationNode::Aggregate agg; + agg.call = std::dynamic_pointer_cast( + core::Expressions::inferTypes(untypedExpr.expr, type, pool())); + + for (const auto& input : agg.call->inputs()) { + agg.rawInputTypes.push_back(input->type()); + } + + VELOX_CHECK_NULL(untypedExpr.maskExpr); + VELOX_CHECK(!untypedExpr.distinct); + VELOX_CHECK(untypedExpr.orderBy.empty()); + + aggs.emplace_back(agg); + + if (untypedExpr.expr->alias().has_value()) { + names.push_back(untypedExpr.expr->alias().value()); + } else { + names.push_back(fmt::format("a{}", i)); + } + } + VELOX_CHECK_EQ(aggs.size(), names.size()); + + std::vector groupingKeyExprs; + groupingKeyExprs.reserve(groupingKeys.size()); + for (const auto& groupingKey : groupingKeys) { + auto untypedGroupingKeyExpr = duckdb::parseExpr(groupingKey, options); + auto groupingKeyExpr = + std::dynamic_pointer_cast( + core::Expressions::inferTypes( + untypedGroupingKeyExpr, type, pool())); + VELOX_CHECK_NOT_NULL( + groupingKeyExpr, + "Grouping key must use a column name, not an expression: {}", + groupingKey); + groupingKeyExprs.emplace_back(std::move(groupingKeyExpr)); + } + + return core::ColumnStatsSpec( + std::move(groupingKeyExprs), step, std::move(names), std::move(aggs)); + } + + // Helper to create test data + std::vector createTestData( + const RowTypePtr& rowType, + size_t numBatches, + size_t batchSize, + bool withNulls = false) { + VectorFuzzer fuzzer( + {.vectorSize = batchSize, .nullRatio = withNulls ? 0.2 : 0.0}, pool()); + std::vector batches; + batches.reserve(numBatches); + for (int i = 0; i < numBatches; ++i) { + batches.push_back(fuzzer.fuzzInputRow(rowType)); + } + return batches; + } + + std::unique_ptr createStatsCollector( + const core::ColumnStatsSpec& statsSpec, + const RowTypePtr& inputType) { + return std::make_unique( + statsSpec, inputType, &queryConfig_, pool(), &nonReclaimableSection_); + } + + void verifyResult( + const std::vector& inputBatches, + const std::vector& groupingKeys, + core::AggregationNode::Step step, + const std::vector& aggregates, + const std::vector& resultBatches) { + core::PlanNodePtr plan; + if (step == core::AggregationNode::Step::kSingle) { + plan = PlanBuilder() + .values(inputBatches) + .singleAggregation(groupingKeys, aggregates) + .planNode(); + } else { + plan = PlanBuilder() + .values(inputBatches) + .partialAggregation(groupingKeys, aggregates) + .planNode(); + } + + const auto expectedResult = AssertQueryBuilder(plan).copyResults(pool()); + assertEqualResults(resultBatches, {expectedResult}); + } + + core::QueryConfig queryConfig_{{}}; + tsan_atomic nonReclaimableSection_{false}; +}; + +TEST_F(ColumnStatsCollectorTest, basic) { + const auto inputType = + ROW({"c0", "c1", "c2"}, {INTEGER(), DOUBLE(), INTEGER()}); + + struct { + std::vector groupingKeys; + std::vector aggregates; + core::AggregationNode::Step step; + bool hasNulls; + + std::string debugString() const { + return fmt::format( + "groupingKeys: {}, aggregates: {}, {}, hasNulls: {}", + folly::join(",", groupingKeys), + folly::join(",", aggregates), + core::AggregationNode::toName(step), + hasNulls); + } + } testCases[] = { + {std::vector{}, + std::vector{"sum(c0)", "count(c1)", "sum(c2)"}, + core::AggregationNode::Step::kSingle, + false}, + {std::vector{"c0"}, + std::vector{"count(c1)", "sum(c2)"}, + core::AggregationNode::Step::kSingle, + false}, + {std::vector{}, + std::vector{"sum(c0)", "count(c1)", "sum(c2)"}, + core::AggregationNode::Step::kPartial, + false}, + {std::vector{"c0"}, + std::vector{"count(c1)", "sum(c2)"}, + core::AggregationNode::Step::kPartial, + false}, + {std::vector{}, + std::vector{"sum(c0)", "count(c1)", "sum(c2)"}, + core::AggregationNode::Step::kSingle, + true}, + {std::vector{"c0"}, + std::vector{"count(c1)", "sum(c2)"}, + core::AggregationNode::Step::kSingle, + true}, + {std::vector{}, + std::vector{"sum(c0)", "count(c1)", "sum(c2)"}, + core::AggregationNode::Step::kPartial, + true}, + {std::vector{"c0"}, + std::vector{"count(c1)", "sum(c2)"}, + core::AggregationNode::Step::kPartial, + true}, + // Multiple grouping keys. + {std::vector{"c0", "c1"}, + std::vector{"sum(c2)"}, + core::AggregationNode::Step::kSingle, + false}, + {std::vector{"c0", "c1"}, + std::vector{"sum(c2)"}, + core::AggregationNode::Step::kPartial, + false}, + {std::vector{"c0", "c1"}, + std::vector{"sum(c2)"}, + core::AggregationNode::Step::kSingle, + true}, + {std::vector{"c0", "c1"}, + std::vector{"sum(c2)"}, + core::AggregationNode::Step::kPartial, + true}}; + + for (const auto& testCase : testCases) { + SCOPED_TRACE(testCase.debugString()); + + // Create stats spec for count and sum + auto statsSpec = createStatsSpec( + inputType, testCase.groupingKeys, testCase.step, testCase.aggregates); + + // Create collector. + auto collector = createStatsCollector(statsSpec, inputType); + + // Create test data. + auto inputBatches = createTestData(inputType, 10, 100, testCase.hasNulls); + VELOX_ASSERT_THROW(collector->addInput(inputBatches[0]), ""); + + // Initialize. + collector->initialize(); + + // Add input. + for (const auto& input : inputBatches) { + collector->addInput(input); + ASSERT_EQ(collector->getOutput(), nullptr); + } + ASSERT_FALSE(collector->finished()); + collector->noMoreInput(); + ASSERT_FALSE(collector->finished()); + + // Get output + std::vector outputBatches; + for (;;) { + auto output = collector->getOutput(); + if (output != nullptr) { + // Single output row without grouping keys. + if (testCase.groupingKeys.empty()) { + ASSERT_EQ(output->size(), 1); + } + ASSERT_EQ( + output->childrenSize(), + testCase.aggregates.size() + testCase.groupingKeys.size()); + // Verify column names (using asRowType to access field names) + auto rowType = std::dynamic_pointer_cast(output->type()); + ASSERT_NE(rowType, nullptr); + auto outputChannel = testCase.groupingKeys.size(); + const int numAggregates = testCase.aggregates.size(); + for (int i = 0; i < numAggregates; ++i) { + EXPECT_EQ( + rowType->nameOf(outputChannel++), statsSpec.aggregateNames[i]); + } + outputBatches.push_back(output); + EXPECT_FALSE(collector->finished()); + } else { + break; + } + } + // Verify finished. + EXPECT_TRUE(collector->finished()); + + verifyResult( + inputBatches, + testCase.groupingKeys, + testCase.step, + testCase.aggregates, + outputBatches); + } +} +} // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/ConcatFilesSpillMergeStreamTest.cpp b/velox/exec/tests/ConcatFilesSpillMergeStreamTest.cpp index 57a91976fba6..78218e912667 100644 --- a/velox/exec/tests/ConcatFilesSpillMergeStreamTest.cpp +++ b/velox/exec/tests/ConcatFilesSpillMergeStreamTest.cpp @@ -14,6 +14,7 @@ * limitations under the License. */ +#include #include "velox/common/file/FileSystems.h" #include "velox/exec/SortBuffer.h" #include "velox/exec/Spill.h" @@ -21,7 +22,6 @@ #include "velox/exec/tests/utils/TempDirectoryPath.h" #include "velox/type/Type.h" #include "velox/vector/fuzzer/VectorFuzzer.h" -#include "velox/vector/tests/utils/VectorTestBase.h" #include @@ -67,7 +67,7 @@ class ConcatFilesSpillMergeStreamTest : public OperatorTestBase { SpillFiles generateSortedSpillFiles( const std::vector& sortedVectors) { - const auto spiller = std::make_unique( + const auto spiller = std::make_unique( inputType_, std::nullopt, HashBitRange{}, @@ -99,8 +99,12 @@ class ConcatFilesSpillMergeStreamTest : public OperatorTestBase { std::vector> spillReadFiles; spillReadFiles.reserve(spillFiles.size()); for (const auto& spillFile : spillFiles) { - spillReadFiles.emplace_back(SpillReadFile::create( - spillFile, spillConfig_.readBufferSize, pool_.get(), &spillStats_)); + spillReadFiles.emplace_back( + SpillReadFile::create( + spillFile, + spillConfig_.readBufferSize, + pool_.get(), + &spillStats_)); } auto stream = ConcatFilesSpillMergeStream::create(i - 1, std::move(spillReadFiles)); @@ -196,7 +200,7 @@ class ConcatFilesSpillMergeStreamTest : public OperatorTestBase { {"c3", VARCHAR()}}); const std::shared_ptr executor_{ std::make_shared( - std::thread::hardware_concurrency())}; + folly::hardware_concurrency())}; const std::vector sortColumnIndices_{0, 2}; const std::vector sortCompareFlags_{ CompareFlags{}, @@ -221,6 +225,7 @@ class ConcatFilesSpillMergeStreamTest : public OperatorTestBase { 0, 0, "none", + 0, std::nullopt}; folly::Synchronized spillStats_; diff --git a/velox/exec/tests/CprHttpClientTest.cpp b/velox/exec/tests/CprHttpClientTest.cpp deleted file mode 100644 index 9a19ccaaed24..000000000000 --- a/velox/exec/tests/CprHttpClientTest.cpp +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Copyright (c) Facebook, Inc. and its affiliates. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "velox/common/base/tests/GTestUtils.h" - -#include -#include - -class CprHttpClientTest : public testing::Test {}; - -// This test requires open access to internet and most places test runners might -// be closed off from the general internet. And this test case is just an -// illustration of how to use cpr, so disable it by default. -TEST_F(CprHttpClientTest, DISABLED_basic) { - auto response = cpr::Get( - cpr::Url{"https://facebookincubator.github.io/velox/"}, - cpr::Timeout{std::chrono::seconds{3}}); - ASSERT_EQ(response.status_code, 200); - ASSERT_FALSE(response.text.empty()); - - response = cpr::Get(cpr::Url{"null"}); - ASSERT_NE(response.status_code, 200); - ASSERT_TRUE(response.text.empty()); - - response = cpr::Post( - cpr::Url{"https://facebookincubator.github.io/velox/"}, - cpr::Body{"select * from nation limit 1"}, - cpr::Header({{"Content-Type", "text/plain"}})); - ASSERT_EQ(response.status_code, 405); - ASSERT_FALSE(response.text.empty()); - - response = cpr::Post( - cpr::Url{"null"}, - cpr::Body{"select * from nation limit 1"}, - cpr::Header({{"Content-Type", "text/plain"}})); - ASSERT_NE(response.status_code, 200); - ASSERT_TRUE(response.text.empty()); -} diff --git a/velox/exec/tests/DriverTest.cpp b/velox/exec/tests/DriverTest.cpp index 7c22ef39a653..3169c79b775f 100644 --- a/velox/exec/tests/DriverTest.cpp +++ b/velox/exec/tests/DriverTest.cpp @@ -124,8 +124,9 @@ class DriverTest : public OperatorTestBase { bool addTestingPauser = false) { std::vector batches; for (int32_t i = 0; i < numBatches; ++i) { - batches.push_back(std::dynamic_pointer_cast( - BatchMaker::createBatch(rowType, rowsInBatch, *pool_))); + batches.push_back( + std::dynamic_pointer_cast( + BatchMaker::createBatch(rowType, rowsInBatch, *pool_))); } if (filterFunc) { int32_t hits = 0; @@ -468,10 +469,11 @@ TEST_F(DriverTest, error) { EXPECT_EQ(numRead, 0); EXPECT_TRUE(stateFutures_.at(0).isReady()); // Realized immediately since task not running. - EXPECT_TRUE(tasks_[0] - ->taskCompletionFuture() - .within(std::chrono::microseconds(1'000'000)) - .isReady()); + EXPECT_TRUE( + tasks_[0] + ->taskCompletionFuture() + .within(std::chrono::microseconds(1'000'000)) + .isReady()); EXPECT_EQ(tasks_[0]->state(), TaskState::kFailed); } @@ -799,8 +801,9 @@ TEST_F(DriverTest, pauserNode) { // all its Tasks in the test instance to create inter-Task pauses. static DriverTest* testInstance; testInstance = this; - Operator::registerOperator(std::make_unique( - kThreadsPerTask, sequence, testInstance)); + Operator::registerOperator( + std::make_unique( + kThreadsPerTask, sequence, testInstance)); std::vector params(kNumTasks); int32_t hits{0}; @@ -1144,6 +1147,43 @@ TEST_F(DriverTest, nonVeloxOperatorException) { "Operator::getOutput failed for [operator: Throw, plan node ID: 1]"); } +TEST_F(DriverTest, enableOperatorBatchSizeStatsConfig) { + CursorParameters params; + int32_t hits; + params.planNode = makeValuesFilterProject( + rowType_, + "m1 % 10 > 0", + "m1 % 3 + m2 % 5 + m3 % 7 + m4 % 11 + m5 % 13 + m6 % 17 + m7 % 19", + 100, + 1'000, + [](int64_t num) { return num % 10 > 0; }, + &hits); + params.maxDrivers = 4; + std::unordered_map queryConfig{ + {core::QueryConfig::kEnableOperatorBatchSizeStats, "true"}}; + params.queryCtx = core::QueryCtx::create( + executor_.get(), core::QueryConfig(std::move(queryConfig))); + int32_t numRead = 0; + readResults(params, ResultOperation::kRead, 1'000'000, &numRead); + EXPECT_EQ(numRead, 4 * hits); + auto stateFuture = tasks_[0]->taskCompletionFuture().within( + std::chrono::microseconds(100'000'000)); + auto& executor = folly::QueuedImmediateExecutor::instance(); + auto state = std::move(stateFuture).via(&executor); + state.wait(); + EXPECT_TRUE(tasks_[0]->isFinished()); + EXPECT_EQ(tasks_[0]->numRunningDrivers(), 0); + const auto taskStats = tasks_[0]->taskStats(); + ASSERT_EQ(taskStats.pipelineStats.size(), 1); + const auto& operatorStats = taskStats.pipelineStats[0].operatorStats; + EXPECT_GT(operatorStats[1].getOutputTiming.wallNanos, 0); + EXPECT_EQ(operatorStats[0].outputPositions, 400'000); + EXPECT_GT(operatorStats[0].outputBytes, 0); + EXPECT_EQ(operatorStats[1].inputPositions, 400'000); + EXPECT_EQ(operatorStats[1].outputPositions, 4 * hits); + EXPECT_GT(operatorStats[1].outputBytes, 0); +} + DEBUG_ONLY_TEST_F(DriverTest, driverSuspensionRaceWithTaskPause) { struct { int numDrivers; @@ -1577,7 +1617,8 @@ DEBUG_ONLY_TEST_F(DriverTest, driverCpuTimeSlicingCheck) { 0, core::QueryCtx::create( driverExecutor_.get(), core::QueryConfig{std::move(queryConfig)}), - testParam.executionMode); + testParam.executionMode, + exec::Consumer{}); while (task->next() != nullptr) { } } diff --git a/velox/exec/tests/ExchangeClientTest.cpp b/velox/exec/tests/ExchangeClientTest.cpp index 86883d2519de..18b0db75bac7 100644 --- a/velox/exec/tests/ExchangeClientTest.cpp +++ b/velox/exec/tests/ExchangeClientTest.cpp @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include "velox/exec/ExchangeClient.h" #include #include #include @@ -90,7 +91,8 @@ class ExchangeClientTest core::PlanFragment{plan}, 0, std::move(queryCtx), - Task::ExecutionMode::kParallel); + Task::ExecutionMode::kParallel, + exec::Consumer{}); } int32_t enqueue( @@ -107,9 +109,9 @@ class ExchangeClientTest return pageSize; } - std::vector> + std::vector> fetchPages(int consumerId, ExchangeClient& client, int32_t numPages) { - std::vector> allPages; + std::vector> allPages; for (auto i = 0; i < numPages; ++i) { bool atEnd{false}; ContinueFuture future; @@ -137,7 +139,7 @@ class ExchangeClientTest static void enqueue( ExchangeQueue& queue, - std::unique_ptr page) { + std::unique_ptr page) { std::vector promises; { std::lock_guard l(queue.mutex()); @@ -148,10 +150,10 @@ class ExchangeClientTest } } - static std::unique_ptr makePage(uint64_t size) { + static std::unique_ptr makePage(uint64_t size) { auto ioBuf = folly::IOBuf::create(size); ioBuf->append(size); - return std::make_unique(std::move(ioBuf), nullptr, 1); + return std::make_unique(std::move(ioBuf), nullptr, 1); } folly::Executor* executor() const { @@ -587,7 +589,7 @@ TEST_P(ExchangeClientTest, acknowledge) { SCOPED_TESTVALUE_SET( "facebook::velox::exec::test::LocalExchangeSource::pause", std::function(([&numberOfAcknowledgeRequests](void*) { - numberOfAcknowledgeRequests++; + ++numberOfAcknowledgeRequests; }))); { @@ -608,10 +610,11 @@ TEST_P(ExchangeClientTest, acknowledge) { client->addRemoteTaskId(sourceTaskId); client->noMoreRemoteTasks(); - ASSERT_TRUE(std::move(future) - .via(executor()) - .wait(std::chrono::seconds{10}) - .isReady()); + ASSERT_TRUE( + std::move(future) + .via(executor()) + .wait(std::chrono::seconds{10}) + .isReady()); #ifndef NDEBUG // The client knew there is more data available but could not fetch any more @@ -662,7 +665,7 @@ TEST_P(ExchangeClientTest, acknowledge) { int attempts = 100; bool outputBuffersEmpty; while (attempts > 0) { - attempts--; + --attempts; outputBuffersEmpty = bufferManager_->getUtilization(sourceTaskId) == 0; if (outputBuffersEmpty) { break; @@ -688,10 +691,11 @@ TEST_P(ExchangeClientTest, acknowledge) { pages = client->next(1, 1, &atEnd, &dequeueEndOfDataFuture); ASSERT_EQ(0, pages.size()); - ASSERT_TRUE(std::move(dequeueEndOfDataFuture) - .via(executor()) - .wait(std::chrono::seconds{10}) - .isReady()); + ASSERT_TRUE( + std::move(dequeueEndOfDataFuture) + .via(executor()) + .wait(std::chrono::seconds{10}) + .isReady()); pages = client->next(1, 1, &atEnd, &dequeueEndOfDataFuture); ASSERT_EQ(0, pages.size()); ASSERT_TRUE(atEnd); @@ -975,13 +979,193 @@ TEST_P(ExchangeClientTest, minOutputBatchBytesMultipleConsumers) { client->close(); } +TEST_P(ExchangeClientTest, skipRequestDataSizeWithSingleSource) { + // Test skipRequestDataSizeWithSingleSource flag behavior + + struct { + bool skipEnabled; + + std::string debugString() const { + return fmt::format("skipEnabled={}", skipEnabled); + } + } testSettings[] = { + // skip enabled + {true}, + // skip disabled + {false}}; + + for (const auto& setting : testSettings) { + SCOPED_TRACE(setting.debugString()); + + auto client = std::make_shared( + "test-" + setting.debugString(), + 17, + 1024, + 1, + kDefaultMinExchangeOutputBatchBytes, + pool(), + executor(), + 10, + setting.skipEnabled); + + client->close(); + } +} + +TEST_P(ExchangeClientTest, skipRequestDataSizeNotTriggeredWithMultipleSources) { + // Test that optimization is NOT triggered with multiple sources + + auto data = makeRowVector({makeFlatVector(100, folly::identity)}); + auto page = test::toSerializedPage(data, serdeKind_, bufferManager_, pool()); + + // Client with optimization ENABLED but multiple sources + auto client = std::make_shared( + "test-multi-source", + 17, + page->size() * 10, + 1, + kDefaultMinExchangeOutputBatchBytes, + pool(), + executor(), + 10, + // enableSingleSourceOptimization = true (but won't trigger with + // multiple sources) + true); + + // Setup: Create tasks with TWO sources + std::vector> tasks; + for (int i = 0; i < 2; ++i) { + auto taskId = fmt::format("local://test-source-{}", i); + auto task = makeTask(taskId); + bufferManager_->initializeTask( + task, core::PartitionedOutputNode::Kind::kPartitioned, 100, 16); + + // Enqueue data + for (int j = 0; j < 3; ++j) { + enqueue(taskId, 17, data); + } + + tasks.push_back(task); + client->addRemoteTaskId(taskId); + } + + client->noMoreRemoteTasks(); + + // Fetch pages - should work with regular path (not single source + // optimization) + // 3 pages from each of 2 sources + auto pages = fetchPages(1, *client, 6); + ASSERT_EQ(pages.size(), 6); + + // Cleanup + for (auto& task : tasks) { + task->requestCancel(); + bufferManager_->removeTask(task->taskId()); + } + + client->close(); +} + +// Test that lazyFetching=true defers data fetching until next() is called. +// When lazyFetching=false (default), fetching starts immediately when remote +// tasks are added via pickSourcesToRequestLocked(). When lazyFetching=true, +// pickSourcesToRequestLocked() is not called in addRemoteTaskId(), deferring +// the fetch until next() is called. This is useful for cached hash table +// scenarios where waiter tasks may not need the data if the table is already +// cached. +TEST_P(ExchangeClientTest, lazyFetching) { + auto data = makeRowVector({makeFlatVector({1, 2, 3, 4, 5})}); + + // Test with lazyFetching=false (default behavior). + // Verify that fetching starts and we can retrieve pages normally. + { + auto taskId = "local://eager-fetching-test"; + auto task = makeTask(taskId); + + bufferManager_->initializeTask( + task, core::PartitionedOutputNode::Kind::kPartitioned, 100, 16); + + auto client = std::make_shared( + "t", + 17, + ExchangeClient::kDefaultMaxQueuedBytes, + 1, + kDefaultMinExchangeOutputBatchBytes, + pool(), + executor(), + 10, // requestDataSizesMaxWaitSec + false, // skipRequestDataSizeWithSingleSource + false); // lazyFetching=false (default) + + client->addRemoteTaskId(taskId); + enqueue(taskId, 17, data); + + auto pages = fetchPages(1, *client, 1); + ASSERT_EQ(1, pages.size()); + + task->requestCancel(); + bufferManager_->removeTask(taskId); + client->close(); + } + + // Test with lazyFetching=true. + // Verify that we can still retrieve pages (fetch is triggered by next()). + { + auto taskId = "local://lazy-fetching-test"; + auto task = makeTask(taskId); + + bufferManager_->initializeTask( + task, core::PartitionedOutputNode::Kind::kPartitioned, 100, 16); + + auto client = std::make_shared( + "t", + 17, + ExchangeClient::kDefaultMaxQueuedBytes, + 1, + kDefaultMinExchangeOutputBatchBytes, + pool(), + executor(), + 10, // requestDataSizesMaxWaitSec + false, // skipRequestDataSizeWithSingleSource + true); // lazyFetching=true + + client->addRemoteTaskId(taskId); + enqueue(taskId, 17, data); + + // Even with lazy fetching, we should be able to retrieve pages + // since next() triggers the fetch. + auto pages = fetchPages(1, *client, 1); + ASSERT_EQ(1, pages.size()); + + task->requestCancel(); + bufferManager_->removeTask(taskId); + client->close(); + } +} + +// Test the new hasNoMoreSources() API +TEST_P(ExchangeClientTest, hasNoMoreSourcesApi) { + auto queue = std::make_shared(1, 0); + + // Initially, should return false + EXPECT_FALSE(queue->hasNoMoreSources()); + + // After calling noMoreSources(), should return true + queue->noMoreSources(); + + EXPECT_TRUE(queue->hasNoMoreSources()); +} + VELOX_INSTANTIATE_TEST_SUITE_P( ExchangeClientTest, ExchangeClientTest, testing::Values( VectorSerde::Kind::kPresto, VectorSerde::Kind::kCompactRow, - VectorSerde::Kind::kUnsafeRow)); + VectorSerde::Kind::kUnsafeRow), + [](const testing::TestParamInfo& info) { + return fmt::format("{}", info.param); + }); } // namespace } // namespace facebook::velox::exec diff --git a/velox/exec/tests/ExpandTest.cpp b/velox/exec/tests/ExpandTest.cpp index f541d20c04a5..d1161b374fc7 100644 --- a/velox/exec/tests/ExpandTest.cpp +++ b/velox/exec/tests/ExpandTest.cpp @@ -21,7 +21,6 @@ using namespace facebook::velox; using namespace facebook::velox::exec::test; namespace facebook::velox::exec { - namespace { class ExpandTest : public OperatorTestBase { public: @@ -37,7 +36,31 @@ class ExpandTest : public OperatorTestBase { }); } }; -} // anonymous namespace + +TEST_F(ExpandTest, complexConstant) { + auto data = makeRowVectorData(3); + auto children = data->children(); + auto arrayVector = + makeArrayVector({{1, 2, 3}, {1, 2, 3}, {1, 2, 3}}); + children.push_back(arrayVector); + children.push_back(makeAllNullArrayVector(3, INTEGER())); + children.push_back(makeNullConstant(TypeKind::INTEGER, 3)); + auto expected = makeRowVector(children); + + auto plan = PlanBuilder(pool()) + .values({data}) + .expand( + {{"k1", + "k2", + "a", + "b", + "ARRAY[1, 2, 3] as c", + "null::integer[] as d", + "null::integer as e"}}) + .planNode(); + + assertQuery(plan, expected); +} TEST_F(ExpandTest, groupingSets) { auto data = makeRowVectorData(1'000); @@ -151,4 +174,5 @@ TEST_F(ExpandTest, invalidUseCases) { "projections must not be empty."); } +} // namespace } // namespace facebook::velox::exec diff --git a/velox/exec/tests/ExpressionBuilderTest.cpp b/velox/exec/tests/ExpressionBuilderTest.cpp new file mode 100644 index 000000000000..084d3275d115 --- /dev/null +++ b/velox/exec/tests/ExpressionBuilderTest.cpp @@ -0,0 +1,195 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/tests/utils/ExpressionBuilder.h" + +#include +#include "velox/parse/ExpressionsParser.h" +#include "velox/type/Variant.h" + +namespace facebook::velox::expr_builder::test { +namespace { + +// Test convenience functions for downcasting. +template +std::shared_ptr as(detail::ExprWrapper in) { + return std::dynamic_pointer_cast(in.expr()); +} + +template +bool is(detail::ExprWrapper in) { + return std::dynamic_pointer_cast(in.expr()) != nullptr; +} + +// Parses a SQL expression using DuckDB. +core::ExprPtr parseSql(const std::string& sql) { + return parse::parseExpr(sql, {}); +} + +TEST(ExpressionBuilderTest, columnReference) { + EXPECT_EQ(col("c0"), parseSql("c0")); + EXPECT_EQ(parseSql("c0"), col("c0")); + EXPECT_EQ("c0"_c, parseSql("c0")); + + EXPECT_EQ(col("parent", "child"), parseSql("parent.child")); + EXPECT_EQ(col("parent").subfield("child"), parseSql("parent.child")); +} + +TEST(ExpressionBuilderTest, literals) { + auto validate = [](detail::ExprWrapper expr, + const TypePtr& expectedType, + variant expectedValue) { + EXPECT_TRUE(is(expr)); + auto constant = as(expr); + EXPECT_EQ(*constant->type(), *expectedType); + EXPECT_TRUE(constant->value().equalsWithEpsilon(expectedValue)); + }; + + // Integer literal types. + validate(lit(123456L), BIGINT(), variant(123456L)); + validate(lit(123), INTEGER(), variant(123)); + validate(lit(int16_t(123)), SMALLINT(), variant(int16_t(123))); + validate(lit(int8_t(123)), TINYINT(), variant(int8_t(123))); + + // Boolean. + validate(lit(true), BOOLEAN(), variant(true)); + validate(lit(false), BOOLEAN(), variant(false)); + + // Floating point. + validate(lit(10.1f), REAL(), variant(10.1f)); + validate(lit(10.1), DOUBLE(), variant(10.1)); + + // String. + validate(lit("str"), VARCHAR(), variant("str")); + + // Null. + validate(lit(nullptr), UNKNOWN(), variant::null(TypeKind::UNKNOWN)); +} + +TEST(ExpressionBuilderTest, comparisons) { + // Make sure all combinations work, as long as at least one side is a + // ExprWrapper. + EXPECT_EQ(col("a") == lit(10L), parseSql("a = 10")); + EXPECT_EQ(lit(10L) == col("a"), parseSql("10 = a")); + + EXPECT_EQ(col("a") == 10L, parseSql("a = 10")); + EXPECT_EQ(10L == col("a"), parseSql("10 = a")); + + EXPECT_EQ(col("a") == col("b"), parseSql("a = b")); + EXPECT_EQ(col("a") == nullptr, parseSql("a = null")); + + // Other comparisons. + EXPECT_EQ(col("a") != 1.1, parseSql("a != 1.1")); + EXPECT_EQ(col("a") != lit(1.1), parseSql("a != 1.1")); + EXPECT_EQ(col("a") > 42L, parseSql("a > 42")); + EXPECT_EQ(col("a") >= 42L, parseSql("a >= 42")); + EXPECT_EQ(col("a") < 42L, parseSql("a < 42")); + EXPECT_EQ(col("a") <= 42L, parseSql("a <= 42")); + + EXPECT_EQ(!col("a"), parseSql("not a")); + EXPECT_EQ(isNull(col("a")), parseSql("a is null")); + EXPECT_EQ(col("a").isNull(), parseSql("a is null")); + EXPECT_EQ(!isNull(col("a")), parseSql("a is not null")); + EXPECT_EQ(!col("a").isNull(), parseSql("a is not null")); + + EXPECT_EQ(isNull("a"), parseSql("\'a\' is null")); // this is "a" literal. +} + +TEST(ExpressionBuilderTest, between) { + EXPECT_EQ(between(col("a"), 0L, 10L), parseSql("a between 0 and 10")); + + EXPECT_EQ(col("a").between(0L, 10L), parseSql("a between 0 and 10")); +} + +TEST(ExpressionBuilderTest, arithmetics) { + EXPECT_EQ(col("b") + 1L, parseSql("b + 1")); + EXPECT_EQ(1L + col("b"), parseSql("1 + b")); + EXPECT_EQ(lit("str") + col("b"), parseSql("'str' + b")); + + EXPECT_EQ(col("b") - 1L, parseSql("b - 1")); + EXPECT_EQ(col("b") * 1L, parseSql("b * 1")); + EXPECT_EQ(col("b") / 1L, parseSql("b / 1")); + EXPECT_EQ(col("b") % 1L, parseSql("b % 1")); + + EXPECT_EQ(col("b") + 1L / col("c") * 10L, parseSql("b + 1 / c * 10")); +} + +TEST(ExpressionBuilderTest, conjuncts) { + EXPECT_EQ(col("b") && 1L, parseSql("b and 1")); + EXPECT_EQ(col("b") || 1L, parseSql("b or 1")); + EXPECT_EQ(col("b") || false, parseSql("b or false")); + + EXPECT_EQ(col("a") && col("b") || col("c"), parseSql("a and b or c")); +} + +TEST(ExpressionBuilderTest, functions) { + EXPECT_EQ(call("func"), parseSql("func()")); + EXPECT_EQ( + call("func", col("a"), 100L, col("c")), parseSql("func(a, 100, c)")); + + // Nested functions. + auto expr = call("f1", call("f2", col("a") > call("f3", col("d")))); + EXPECT_EQ(expr, parseSql("f1(f2(a > f3(d)))")); + + expr = 10L * col("c1") > call("func", 3.4, col("g") / col("h"), call("j")); + EXPECT_EQ(expr, parseSql("10 * c1 > func(3.4, g / h, j())")); +} + +TEST(ExpressionBuilderTest, casts) { + // Casts. + EXPECT_EQ(lit("1").cast(TINYINT()).toString(), "cast(1 as TINYINT)"); + EXPECT_EQ( + col("c0").cast(VARBINARY()).toString(), "cast(\"c0\" as VARBINARY)"); + + EXPECT_EQ(cast(1, TINYINT()).toString(), "cast(1 as TINYINT)"); + EXPECT_EQ( + cast(col("c0"), VARBINARY()).toString(), "cast(\"c0\" as VARBINARY)"); + + // Try casts. + EXPECT_EQ(lit("1").tryCast(TINYINT()).toString(), "try_cast(1 as TINYINT)"); + EXPECT_EQ( + col("c0").tryCast(VARBINARY()).toString(), + "try_cast(\"c0\" as VARBINARY)"); + + EXPECT_EQ(tryCast(1, TINYINT()).toString(), "try_cast(1 as TINYINT)"); + EXPECT_EQ( + tryCast(col("c0"), VARBINARY()).toString(), + "try_cast(\"c0\" as VARBINARY)"); +} + +TEST(ExpressionBuilderTest, alias) { + EXPECT_EQ(lit("str").alias("col"), parseSql("'str' as col")); + EXPECT_EQ(col("c1").alias("col"), parseSql("c1 as col")); + EXPECT_EQ((col("c1") > 1.1).alias("col"), parseSql("c1 > 1.1 as col")); + + EXPECT_EQ( + col("c1").between(1L, 10L).alias("my_col"), + parseSql("c1 between 1 and 10 as my_col")); + + // As a free function. + EXPECT_EQ(alias(col("c1") == "bla", "col"), parseSql("c1 = 'bla' as col")); +} + +TEST(ExpressionBuilderTest, lambdas) { + EXPECT_EQ(lambda("x", 1L), parseSql("x -> 1")); + EXPECT_EQ(lambda({"x"}, 1L), parseSql("x -> 1")); + EXPECT_EQ(lambda({"x"}, col("x") + 1L), parseSql("x -> x + 1")); + EXPECT_EQ( + lambda({"x", "y"}, col("x") * col("y")), parseSql("(x, y) -> x * y")); +} + +} // namespace +} // namespace facebook::velox::expr_builder::test diff --git a/velox/exec/tests/FilterProjectTest.cpp b/velox/exec/tests/FilterProjectTest.cpp index 9d7f2e8f172a..87d21381d37a 100644 --- a/velox/exec/tests/FilterProjectTest.cpp +++ b/velox/exec/tests/FilterProjectTest.cpp @@ -17,17 +17,12 @@ #include "velox/exec/PlanNodeStats.h" #include "velox/exec/tests/utils/AssertQueryBuilder.h" #include "velox/exec/tests/utils/HiveConnectorTestBase.h" -#include "velox/exec/tests/utils/OperatorTestBase.h" #include "velox/exec/tests/utils/PlanBuilder.h" -#include "velox/parse/Expressions.h" -using namespace facebook::velox; -using namespace facebook::velox::exec; -using namespace facebook::velox::exec::test; +namespace facebook::velox::exec { +namespace { -using facebook::velox::test::BatchMaker; - -class FilterProjectTest : public HiveConnectorTestBase { +class FilterProjectTest : public test::HiveConnectorTestBase { protected: void SetUp() override { HiveConnectorTestBase::SetUp(); @@ -37,7 +32,7 @@ class FilterProjectTest : public HiveConnectorTestBase { const std::vector& vectors, const std::string& filter = "c1 % 10 > 0") { core::PlanNodePtr filterNode; - auto plan = PlanBuilder() + auto plan = test::PlanBuilder() .values(vectors) .filter(filter) .capturePlanNode(filterNode) @@ -49,7 +44,7 @@ class FilterProjectTest : public HiveConnectorTestBase { void assertProject(const std::vector& vectors) { core::PlanNodePtr projectNode; - auto plan = PlanBuilder() + auto plan = test::PlanBuilder() .values(vectors) .project({"c0", "c1", "c0 + c1"}) .capturePlanNode(projectNode) @@ -64,49 +59,50 @@ class FilterProjectTest : public HiveConnectorTestBase { auto projectNodeId = plan->id(); auto it = planStats.find(projectNodeId); ASSERT_TRUE(it != planStats.end()); - ASSERT_TRUE(it->second.peakMemoryBytes > 0); + ASSERT_GT(it->second.peakMemoryBytes, 0); + } + + RowVectorPtr makeTestVector(vector_size_t size = 100) const { + auto rowType = ROW( + {"c0", "c1", "c2", "c3"}, {BIGINT(), INTEGER(), SMALLINT(), DOUBLE()}); + + return std::dynamic_pointer_cast( + velox::test::BatchMaker::createBatch(rowType, size, *pool_)); } - std::shared_ptr rowType_{ - ROW({"c0", "c1", "c2", "c3"}, - {BIGINT(), INTEGER(), SMALLINT(), DOUBLE()})}; + std::vector makeTestVectors() const { + std::vector vectors; + vectors.reserve(10); + for (int32_t i = 0; i < 10; ++i) { + vectors.push_back(makeTestVector()); + } + return vectors; + } }; TEST_F(FilterProjectTest, filter) { - std::vector vectors; - for (int32_t i = 0; i < 10; ++i) { - auto vector = std::dynamic_pointer_cast( - BatchMaker::createBatch(rowType_, 100, *pool_)); - vectors.push_back(vector); - } + auto vectors = makeTestVectors(); createDuckDbTable(vectors); - assertFilter(vectors); } TEST_F(FilterProjectTest, filterOverDictionary) { std::vector vectors; for (int32_t i = 0; i < 10; ++i) { - auto vector = std::dynamic_pointer_cast( - BatchMaker::createBatch(rowType_, 100, *pool_)); - - auto indices = - AlignedBuffer::allocate(2 * vector->size(), pool_.get()); - auto indicesPtr = indices->asMutable(); - for (int32_t j = 0; j < vector->size() / 2; j++) { - indicesPtr[2 * j] = j; - indicesPtr[2 * j + 1] = j; + auto vector = makeTestVector(); + + auto indices = allocateIndices(2 * vector->size(), pool_.get()); + auto rawIndices = indices->asMutable(); + for (int32_t j = 0; j < vector->size(); j += 2) { + rawIndices[j] = j / 2; + rawIndices[j + 1] = j / 2; } - std::vector newChildren = vector->children(); - newChildren[1] = BaseVector::wrapInDictionary( - BufferPtr(nullptr), indices, vector->size(), vector->childAt(1)); - vectors.push_back(std::make_shared( - pool_.get(), - rowType_, - BufferPtr(nullptr), - vector->size(), - newChildren, - 0 /*nullCount*/)); + + auto newChildren = vector->children(); + newChildren[1] = + wrapInDictionary(indices, vector->size(), vector->childAt(1)); + + vectors.push_back(makeRowVector(newChildren)); } createDuckDbTable(vectors); @@ -116,19 +112,13 @@ TEST_F(FilterProjectTest, filterOverDictionary) { TEST_F(FilterProjectTest, filterOverConstant) { std::vector vectors; for (int32_t i = 0; i < 10; ++i) { - auto vector = std::dynamic_pointer_cast( - BatchMaker::createBatch(rowType_, 100, *pool_)); + auto vector = makeTestVector(); std::vector newChildren = vector->children(); newChildren[1] = BaseVector::wrapInConstant(vector->size(), 7, vector->childAt(1)); - vectors.push_back(std::make_shared( - pool_.get(), - rowType_, - BufferPtr(nullptr), - vector->size(), - newChildren, - 0 /*nullCount*/)); + + vectors.push_back(makeRowVector(newChildren)); } createDuckDbTable(vectors); @@ -136,40 +126,27 @@ TEST_F(FilterProjectTest, filterOverConstant) { } TEST_F(FilterProjectTest, project) { - std::vector vectors; - for (int32_t i = 0; i < 10; ++i) { - auto vector = std::dynamic_pointer_cast( - BatchMaker::createBatch(rowType_, 100, *pool_)); - vectors.push_back(vector); - } + auto vectors = makeTestVectors(); createDuckDbTable(vectors); - assertProject(vectors); } TEST_F(FilterProjectTest, projectOverDictionary) { std::vector vectors; for (int32_t i = 0; i < 10; ++i) { - auto vector = std::dynamic_pointer_cast( - BatchMaker::createBatch(rowType_, 100, *pool_)); - - auto indices = - AlignedBuffer::allocate(2 * vector->size(), pool_.get()); - auto indicesPtr = indices->asMutable(); - for (int32_t j = 0; j < vector->size() / 2; j++) { - indicesPtr[2 * j] = j; - indicesPtr[2 * j + 1] = j; + auto vector = makeTestVector(); + + auto indices = allocateIndices(2 * vector->size(), pool_.get()); + auto rawIndices = indices->asMutable(); + for (int32_t j = 0; j < vector->size(); j += 2) { + rawIndices[j] = j / 2; + rawIndices[j + 1] = j / 2; } + std::vector newChildren = vector->children(); - newChildren[1] = BaseVector::wrapInDictionary( - BufferPtr(nullptr), indices, vector->size(), vector->childAt(1)); - vectors.push_back(std::make_shared( - pool_.get(), - rowType_, - BufferPtr(nullptr), - vector->size(), - newChildren, - 0 /*nullCount*/)); + newChildren[1] = + wrapInDictionary(indices, vector->size(), vector->childAt(1)); + vectors.push_back(makeRowVector(newChildren)); } createDuckDbTable(vectors); @@ -179,19 +156,12 @@ TEST_F(FilterProjectTest, projectOverDictionary) { TEST_F(FilterProjectTest, projectOverConstant) { std::vector vectors; for (int32_t i = 0; i < 10; ++i) { - auto vector = std::dynamic_pointer_cast( - BatchMaker::createBatch(rowType_, 100, *pool_)); + auto vector = makeTestVector(); std::vector newChildren = vector->children(); newChildren[1] = BaseVector::wrapInConstant(vector->size(), 7, vector->childAt(1)); - vectors.push_back(std::make_shared( - pool_.get(), - rowType_, - BufferPtr(nullptr), - vector->size(), - newChildren, - 0 /*nullCount*/)); + vectors.push_back(makeRowVector(newChildren)); } createDuckDbTable(vectors); @@ -218,7 +188,7 @@ TEST_F(FilterProjectTest, projectOverLazy) { createDuckDbTable({vectors}); - auto plan = PlanBuilder() + auto plan = test::PlanBuilder() .values({lazyVectors}) .project({"c0 > 0 AND c1 > 0.0", "c1 + 5.2"}) .planNode(); @@ -226,15 +196,10 @@ TEST_F(FilterProjectTest, projectOverLazy) { } TEST_F(FilterProjectTest, filterProject) { - std::vector vectors; - for (int32_t i = 0; i < 10; ++i) { - auto vector = std::dynamic_pointer_cast( - BatchMaker::createBatch(rowType_, 100, *pool_)); - vectors.push_back(vector); - } + auto vectors = makeTestVectors(); createDuckDbTable(vectors); - auto plan = PlanBuilder() + auto plan = test::PlanBuilder() .values(vectors) .filter("c1 % 10 > 0") .project({"c0", "c1", "c0 + c1"}) @@ -244,22 +209,17 @@ TEST_F(FilterProjectTest, filterProject) { } TEST_F(FilterProjectTest, dereference) { - std::vector vectors; - for (int32_t i = 0; i < 10; ++i) { - auto vector = std::dynamic_pointer_cast( - BatchMaker::createBatch(rowType_, 100, *pool_)); - vectors.push_back(vector); - } + auto vectors = makeTestVectors(); createDuckDbTable(vectors); - auto plan = PlanBuilder() + auto plan = test::PlanBuilder() .values(vectors) .project({"row_constructor(c1, c2) AS c1_c2"}) .project({"c1_c2.c1", "c1_c2.c2"}) .planNode(); assertQuery(plan, "SELECT c1, c2 FROM tmp"); - plan = PlanBuilder() + plan = test::PlanBuilder() .values(vectors) .project({"row_constructor(c1, c2) AS c1_c2"}) .filter("c1_c2.c1 % 10 = 5") @@ -291,15 +251,10 @@ TEST_F(FilterProjectTest, allFailedOrPassed) { // Tests fusing of consecutive filters and projects. TEST_F(FilterProjectTest, filterProjectFused) { - std::vector vectors; - for (int32_t i = 0; i < 10; ++i) { - auto vector = std::dynamic_pointer_cast( - BatchMaker::createBatch(rowType_, 100, *pool_)); - vectors.push_back(vector); - } + auto vectors = makeTestVectors(); createDuckDbTable(vectors); - auto plan = PlanBuilder() + auto plan = test::PlanBuilder() .values(vectors) .filter("c0 % 10 < 9") .project({"c0", "c1", "c0 % 100 + c1 % 50 AS e1"}) @@ -332,7 +287,7 @@ TEST_F(FilterProjectTest, projectAndIdentityOverLazy) { createDuckDbTable({vectors}); - auto plan = PlanBuilder() + auto plan = test::PlanBuilder() .values({lazyVectors}) .project({"c0 < 10 AND c1 < 10", "c1"}) .planNode(); @@ -350,7 +305,7 @@ TEST_F(FilterProjectTest, nestedFieldReferenceSharedChild) { }), }); auto plan = - PlanBuilder() + test::PlanBuilder() .values({vector}) .project({"coalesce((c0).c0.c0, 0) + coalesce((c0).c1.c0, 0)"}) .planNode(); @@ -358,22 +313,21 @@ TEST_F(FilterProjectTest, nestedFieldReferenceSharedChild) { for (int i = 0; i < 10; ++i) { expected->set(i, (i % 2 == 0 ? 0 : i) + (i % 3 == 0 ? 0 : i)); } - AssertQueryBuilder(plan).assertResults(makeRowVector({expected})); + test::AssertQueryBuilder(plan).assertResults(makeRowVector({expected})); } TEST_F(FilterProjectTest, numSilentThrow) { - auto row = makeRowVector( - {makeFlatVector(100, [&](auto row) { return row; })}); + auto row = makeRowVector({makeFlatVector(100, folly::identity)}); core::PlanNodeId filterId; // Change the plan when /0 error is fixed not to throw. - auto plan = PlanBuilder() + auto plan = test::PlanBuilder() .values({row}) .filter("try (c0 / 0) = 1") .capturePlanNodeId(filterId) .planNode(); - auto task = AssertQueryBuilder(plan).assertEmptyResults(); + auto task = test::AssertQueryBuilder(plan).assertEmptyResults(); auto planStats = toPlanStats(task->taskStats()); ASSERT_EQ(100, planStats.at(filterId).customStats.at("numSilentThrow").sum); } @@ -388,7 +342,7 @@ TEST_F(FilterProjectTest, statsSplitter) { // row. core::PlanNodeId filterId; core::PlanNodeId projectId; - auto plan = PlanBuilder() + auto plan = test::PlanBuilder() .values(split(data, 5)) .filter("if (c0 < 25, false, c0 % 3 = 0)") .capturePlanNodeId(filterId) @@ -397,7 +351,7 @@ TEST_F(FilterProjectTest, statsSplitter) { .planNode(); std::shared_ptr task; - AssertQueryBuilder(plan).runWithoutResults(task); + test::AssertQueryBuilder(plan).countResults(task); auto planStats = toPlanStats(task->taskStats()); @@ -423,24 +377,18 @@ TEST_F(FilterProjectTest, statsSplitter) { TEST_F(FilterProjectTest, barrier) { std::vector vectors; - std::vector> tempFiles; + std::vector> tempFiles; const int numSplits{5}; - const int numRowsPerSplit{100}; - for (int32_t i = 0; i < 5; ++i) { - auto vector = std::dynamic_pointer_cast( - BatchMaker::createBatch(rowType_, numRowsPerSplit, *pool_)); - vectors.push_back(vector); - tempFiles.push_back(TempFilePath::create()); + for (int32_t i = 0; i < numSplits; ++i) { + vectors.push_back(makeTestVector()); + tempFiles.push_back(test::TempFilePath::create()); } writeToFiles(toFilePaths(tempFiles), vectors); createDuckDbTable(vectors); - auto planNodeIdGenerator = std::make_shared(); core::PlanNodeId projectPlanNodeId; - auto plan = PlanBuilder(planNodeIdGenerator) - .startTableScan() - .outputType(rowType_) - .endTableScan() + auto plan = test::PlanBuilder() + .tableScan(vectors.front()->rowType()) .filter("c1 % 10 > 0") .project({"c0", "c1", "c0 + c1"}) .capturePlanNodeId(projectPlanNodeId) @@ -459,17 +407,16 @@ TEST_F(FilterProjectTest, barrier) { for (const auto& testData : testSettings) { SCOPED_TRACE(testData.toString()); auto task = - AssertQueryBuilder(plan, duckDbQueryRunner_) - .config(core::QueryConfig::kSparkPartitionId, "0") + test::AssertQueryBuilder(plan, duckDbQueryRunner_) .config( core::QueryConfig::kMaxSplitPreloadPerDriver, std::to_string(tempFiles.size())) - .splits(makeHiveConnectorSplits(tempFiles)) - .serialExecution(true) - .barrierExecution(testData.barrierExecution) .config( core::QueryConfig::kPreferredOutputBatchRows, std::to_string(testData.numOutputRows)) + .splits(makeHiveConnectorSplits(tempFiles)) + .serialExecution(true) + .barrierExecution(testData.barrierExecution) .assertResults("SELECT c0, c1, c0 + c1 FROM tmp WHERE c1 % 10 > 0"); const auto taskStats = task->taskStats(); ASSERT_EQ(taskStats.numBarriers, testData.barrierExecution ? numSplits : 0); @@ -482,3 +429,42 @@ TEST_F(FilterProjectTest, barrier) { numSplits); } } + +TEST_F(FilterProjectTest, lazyDereference) { + constexpr int kSize = 10; + VectorPtr expected[] = { + makeFlatVector(kSize, [](auto i) { return i; }), + makeFlatVector(kSize, [](auto i) { return i + kSize; }), + makeFlatVector(kSize, [](auto i) { return i + kSize * 2; }), + }; + auto vector = makeRowVector({ + makeRowVector({expected[0], expected[1]}), + makeRowVector({makeRowVector({expected[2]})}), + }); + auto file = test::TempFilePath::create(); + writeToFile(file->getPath(), vector); + CursorParameters params; + params.copyResult = false; + params.serialExecution = true; + params.planNode = test::PlanBuilder() + .tableScan(vector->rowType()) + .lazyDereference({"c0.c0", "c0.c1", "(c1).c0.c0"}) + .planNode(); + auto cursor = TaskCursor::create(params); + cursor->task()->addSplit( + "0", exec::Split(makeHiveConnectorSplit(file->getPath()))); + cursor->task()->noMoreSplits("0"); + ASSERT_TRUE(cursor->moveNext()); + auto* result = cursor->current()->asChecked(); + ASSERT_EQ(result->size(), kSize); + ASSERT_EQ(result->childrenSize(), 3); + for (int i = 0; i < result->childrenSize(); ++i) { + auto* lazy = result->childAt(i)->asChecked(); + ASSERT_FALSE(lazy->isLoaded()); + ASSERT_EQ(lazy->size(), kSize); + velox::test::assertEqualVectors(expected[i], lazy->loadedVectorShared()); + } +} + +} // namespace +} // namespace facebook::velox::exec diff --git a/velox/exec/tests/FilterToExpressionTest.cpp b/velox/exec/tests/FilterToExpressionTest.cpp new file mode 100644 index 000000000000..8f84e04fede5 --- /dev/null +++ b/velox/exec/tests/FilterToExpressionTest.cpp @@ -0,0 +1,425 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/exec/tests/utils/FilterToExpression.h" +#include +#include "velox/core/Expressions.h" +#include "velox/vector/tests/utils/VectorTestBase.h" + +namespace facebook::velox::core::test { + +class FilterToExpressionTest : public testing::Test, + public velox::test::VectorTestBase { + protected: + static void SetUpTestCase() { + memory::MemoryManager::testingSetInstance(memory::MemoryManager::Options{}); + } + + void verifyExpr( + const TypedExprPtr& expr, + const std::string& expectedType, + const std::string& expectedName) { + ASSERT_TRUE(expr != nullptr); + ASSERT_EQ(expr->type()->toString(), expectedType); + + auto callExpr = std::dynamic_pointer_cast(expr); + ASSERT_TRUE(callExpr != nullptr); + ASSERT_EQ(callExpr->name(), expectedName); + } + + TypedExprPtr toExpr(const common::Filter* filter, const TypePtr& type) { + common::Subfield subfield("a"); + return filterToExpr(subfield, filter, ROW({"a"}, {type}), pool()); + } +}; + +TEST_F(FilterToExpressionTest, alwaysTrue) { + auto filter = std::make_unique(); + auto expr = toExpr(filter.get(), BIGINT()); + + ASSERT_TRUE(expr != nullptr); + ASSERT_EQ(expr->type()->toString(), "BOOLEAN"); + + auto constantExpr = std::dynamic_pointer_cast(expr); + ASSERT_TRUE(constantExpr != nullptr); + ASSERT_TRUE(constantExpr->value().value()); +} + +TEST_F(FilterToExpressionTest, alwaysFalse) { + auto filter = std::make_unique(); + auto expr = toExpr(filter.get(), BIGINT()); + + ASSERT_TRUE(expr != nullptr); + ASSERT_EQ(expr->type()->toString(), "BOOLEAN"); + + auto constantExpr = std::dynamic_pointer_cast(expr); + ASSERT_TRUE(constantExpr != nullptr); + ASSERT_FALSE(constantExpr->value().value()); +} + +TEST_F(FilterToExpressionTest, isNull) { + auto filter = std::make_unique(); + auto expr = toExpr(filter.get(), BIGINT()); + + verifyExpr(expr, "BOOLEAN", "is_null"); + auto callExpr = std::dynamic_pointer_cast(expr); + ASSERT_EQ(callExpr->inputs().size(), 1); +} + +TEST_F(FilterToExpressionTest, isNotNull) { + auto filter = std::make_unique(); + auto expr = toExpr(filter.get(), BIGINT()); + + verifyExpr(expr, "BOOLEAN", "not"); + auto callExpr = std::dynamic_pointer_cast(expr); + ASSERT_EQ(callExpr->inputs().size(), 1); + + // Verify the inner expression is an IS_NULL operation + auto isNullExpr = + std::dynamic_pointer_cast(callExpr->inputs()[0]); + ASSERT_TRUE(isNullExpr != nullptr); + ASSERT_EQ(isNullExpr->name(), "is_null"); +} + +TEST_F(FilterToExpressionTest, boolValue) { + auto filter = std::make_unique(true, false); + auto expr = toExpr(filter.get(), BOOLEAN()); + + verifyExpr(expr, "BOOLEAN", "eq"); + auto callExpr = std::dynamic_pointer_cast(expr); + ASSERT_EQ(callExpr->inputs().size(), 2); + + // First input should be the field access expression + auto fieldExpr = callExpr->inputs()[0]; + ASSERT_TRUE(fieldExpr != nullptr); + + // Second input should be the boolean constant + auto constantExpr = + std::dynamic_pointer_cast(callExpr->inputs()[1]); + ASSERT_TRUE(constantExpr != nullptr); + ASSERT_EQ(constantExpr->value().value(), true); +} + +TEST_F(FilterToExpressionTest, bigintRangeSingleValue) { + auto filter = std::make_unique(42, 42, false); + auto expr = toExpr(filter.get(), BIGINT()); + + verifyExpr(expr, "BOOLEAN", "eq"); + auto callExpr = std::dynamic_pointer_cast(expr); + ASSERT_EQ(callExpr->inputs().size(), 2); + + auto constantExpr = + std::dynamic_pointer_cast(callExpr->inputs()[1]); + ASSERT_TRUE(constantExpr != nullptr); + ASSERT_EQ(constantExpr->value().value(), 42); +} + +TEST_F(FilterToExpressionTest, bigintRangeWithRange) { + auto filter = std::make_unique(10, 20, false); + auto expr = toExpr(filter.get(), BIGINT()); + + verifyExpr(expr, "BOOLEAN", "and"); + auto callExpr = std::dynamic_pointer_cast(expr); + ASSERT_EQ(callExpr->inputs().size(), 2); + + auto greaterOrEqual = + std::dynamic_pointer_cast(callExpr->inputs()[0]); + ASSERT_TRUE(greaterOrEqual != nullptr); + ASSERT_EQ(greaterOrEqual->name(), "gte"); + + auto lessOrEqual = + std::dynamic_pointer_cast(callExpr->inputs()[1]); + ASSERT_TRUE(lessOrEqual != nullptr); + ASSERT_EQ(lessOrEqual->name(), "lte"); +} + +TEST_F(FilterToExpressionTest, negatedBigintRangeSingleValue) { + auto filter = std::make_unique(42, 42, false); + auto expr = toExpr(filter.get(), BIGINT()); + + // The implementation now uses getNonNegated() which creates a NOT expression + // even for single values, so we expect "not" instead of "neq" + verifyExpr(expr, "BOOLEAN", "not"); + auto notExpr = std::dynamic_pointer_cast(expr); + ASSERT_EQ(notExpr->inputs().size(), 1); + + // The inner expression might be an OR expression due to handleNullAllowed + auto innerExpr = + std::dynamic_pointer_cast(notExpr->inputs()[0]); + ASSERT_TRUE(innerExpr != nullptr); + + if (innerExpr->name() == "or") { + // If it's an OR expression, the first input should be the EQ operation + ASSERT_EQ(innerExpr->inputs().size(), 2); + auto eqExpr = + std::dynamic_pointer_cast(innerExpr->inputs()[0]); + ASSERT_TRUE(eqExpr != nullptr); + ASSERT_EQ(eqExpr->name(), "eq"); + ASSERT_EQ(eqExpr->inputs().size(), 2); + + // Verify the constant value is 42 + auto constantExpr = + std::dynamic_pointer_cast(eqExpr->inputs()[1]); + ASSERT_TRUE(constantExpr != nullptr); + ASSERT_EQ(constantExpr->value().value(), 42); + } else if (innerExpr->name() == "eq") { + // If it's directly an EQ expression + ASSERT_EQ(innerExpr->inputs().size(), 2); + + // Verify the constant value is 42 + auto constantExpr = std::dynamic_pointer_cast( + innerExpr->inputs()[1]); + ASSERT_TRUE(constantExpr != nullptr); + ASSERT_EQ(constantExpr->value().value(), 42); + } else { + FAIL() << "Expected either 'or' or 'eq' expression, got: " + << innerExpr->name(); + } +} + +TEST_F(FilterToExpressionTest, doubleRange) { + auto filter = std::make_unique( + 1.5, false, false, 3.5, false, false, false); + auto expr = toExpr(filter.get(), DOUBLE()); + + verifyExpr(expr, "BOOLEAN", "and"); + auto callExpr = std::dynamic_pointer_cast(expr); + ASSERT_EQ(callExpr->inputs().size(), 2); + + auto greaterOrEqual = + std::dynamic_pointer_cast(callExpr->inputs()[0]); + ASSERT_TRUE(greaterOrEqual != nullptr); + ASSERT_EQ(greaterOrEqual->name(), "gte"); + + auto lessOrEqual = + std::dynamic_pointer_cast(callExpr->inputs()[1]); + ASSERT_TRUE(lessOrEqual != nullptr); + ASSERT_EQ(lessOrEqual->name(), "lte"); +} + +TEST_F(FilterToExpressionTest, floatRange) { + auto filter = std::make_unique( + 1.5f, false, true, 3.5f, false, true, false); + auto expr = toExpr(filter.get(), REAL()); + + verifyExpr(expr, "BOOLEAN", "and"); + auto callExpr = std::dynamic_pointer_cast(expr); + ASSERT_EQ(callExpr->inputs().size(), 2); + + auto greaterThan = + std::dynamic_pointer_cast(callExpr->inputs()[0]); + ASSERT_TRUE(greaterThan != nullptr); + ASSERT_EQ(greaterThan->name(), "gt"); + + auto lessThan = + std::dynamic_pointer_cast(callExpr->inputs()[1]); + ASSERT_TRUE(lessThan != nullptr); + ASSERT_EQ(lessThan->name(), "lt"); +} + +TEST_F(FilterToExpressionTest, bytesRange) { + auto filter = std::make_unique( + "apple", false, false, "orange", false, false, false); + auto expr = toExpr(filter.get(), VARCHAR()); + + verifyExpr(expr, "BOOLEAN", "and"); + auto callExpr = std::dynamic_pointer_cast(expr); + ASSERT_EQ(callExpr->inputs().size(), 2); + + auto greaterOrEqual = + std::dynamic_pointer_cast(callExpr->inputs()[0]); + ASSERT_TRUE(greaterOrEqual != nullptr); + ASSERT_EQ(greaterOrEqual->name(), "gte"); + + auto lessOrEqual = + std::dynamic_pointer_cast(callExpr->inputs()[1]); + ASSERT_TRUE(lessOrEqual != nullptr); + ASSERT_EQ(lessOrEqual->name(), "lte"); +} + +TEST_F(FilterToExpressionTest, bigintValuesUsingHashTable) { + std::vector values = {10, 20, 30}; + auto filter = common::createBigintValues(values, false); + auto expr = toExpr(filter.get(), BIGINT()); + + // The implementation creates an optimized expression: (range check) AND (in + // check) + verifyExpr(expr, "BOOLEAN", "and"); + auto callExpr = std::dynamic_pointer_cast(expr); + ASSERT_EQ(callExpr->inputs().size(), 2); + + // First input should be the range check (field >= min AND field <= max) + auto rangeCheckExpr = + std::dynamic_pointer_cast(callExpr->inputs()[0]); + ASSERT_TRUE(rangeCheckExpr != nullptr); + ASSERT_EQ(rangeCheckExpr->name(), "and"); + + // Second input should be the IN expression + auto inExpr = + std::dynamic_pointer_cast(callExpr->inputs()[1]); + ASSERT_TRUE(inExpr != nullptr); + ASSERT_EQ(inExpr->name(), "in"); + ASSERT_EQ(inExpr->inputs().size(), 2); + + auto arrayExpr = + std::dynamic_pointer_cast(inExpr->inputs()[1]); + ASSERT_TRUE(arrayExpr != nullptr); + ASSERT_EQ(arrayExpr->name(), "array_constructor"); + ASSERT_EQ(arrayExpr->inputs().size(), 3); +} + +TEST_F(FilterToExpressionTest, bytesValues) { + std::vector values = {"apple", "banana", "orange"}; + auto filter = std::make_unique(values, false); + auto expr = toExpr(filter.get(), VARCHAR()); + + verifyExpr(expr, "BOOLEAN", "in"); + auto callExpr = std::dynamic_pointer_cast(expr); + ASSERT_EQ(callExpr->inputs().size(), 2); + + auto arrayExpr = + std::dynamic_pointer_cast(callExpr->inputs()[1]); + ASSERT_TRUE(arrayExpr != nullptr); + ASSERT_EQ(arrayExpr->name(), "array_constructor"); + ASSERT_EQ(arrayExpr->inputs().size(), 3); +} + +TEST_F(FilterToExpressionTest, negatedBytesValues) { + std::vector values = {"apple", "banana", "orange"}; + auto filter = std::make_unique(values, false); + auto expr = toExpr(filter.get(), VARCHAR()); + + verifyExpr(expr, "BOOLEAN", "not"); + auto callExpr = std::dynamic_pointer_cast(expr); + ASSERT_EQ(callExpr->inputs().size(), 1); + + auto containsExpr = + std::dynamic_pointer_cast(callExpr->inputs()[0]); + ASSERT_TRUE(containsExpr != nullptr); + + ASSERT_TRUE(containsExpr->name() == "in" || containsExpr->name() == "or"); +} + +TEST_F(FilterToExpressionTest, negatedBigintValuesUsingHashTable) { + std::vector values = {10, 20, 30}; + auto filter = std::make_unique( + 10, 30, values, false); + auto expr = toExpr(filter.get(), BIGINT()); + + // The implementation creates a NOT expression for the optimized IN check + verifyExpr(expr, "BOOLEAN", "not"); + auto notExpr = std::dynamic_pointer_cast(expr); + ASSERT_EQ(notExpr->inputs().size(), 1); + + // The input should be an OR expression + auto orExpr = + std::dynamic_pointer_cast(notExpr->inputs()[0]); + ASSERT_TRUE(orExpr != nullptr); + ASSERT_EQ(orExpr->name(), "or"); + ASSERT_EQ(orExpr->inputs().size(), 2); + + // First input of OR should be range check + auto rangeCheckExpr = + std::dynamic_pointer_cast(orExpr->inputs()[0]); + ASSERT_TRUE(rangeCheckExpr != nullptr); + ASSERT_EQ(rangeCheckExpr->name(), "and"); + + // Second input of OR should be IS_NULL expression + auto isNullExpr = + std::dynamic_pointer_cast(orExpr->inputs()[1]); + ASSERT_TRUE(isNullExpr != nullptr); + ASSERT_EQ(isNullExpr->name(), "is_null"); +} + +TEST_F(FilterToExpressionTest, timestampRange) { + auto timestamp1 = Timestamp::fromMillis(1609459200000); // 2021-01-01 + auto timestamp2 = Timestamp::fromMillis(1640995200000); // 2022-01-01 + auto filter = + std::make_unique(timestamp1, timestamp2, false); + auto expr = toExpr(filter.get(), TIMESTAMP()); + + verifyExpr(expr, "BOOLEAN", "and"); + auto callExpr = std::dynamic_pointer_cast(expr); + ASSERT_EQ(callExpr->inputs().size(), 2); + + auto greaterOrEqual = + std::dynamic_pointer_cast(callExpr->inputs()[0]); + ASSERT_TRUE(greaterOrEqual != nullptr); + ASSERT_EQ(greaterOrEqual->name(), "gte"); + + auto lessOrEqual = + std::dynamic_pointer_cast(callExpr->inputs()[1]); + ASSERT_TRUE(lessOrEqual != nullptr); + ASSERT_EQ(lessOrEqual->name(), "lte"); +} + +TEST_F(FilterToExpressionTest, bigintMultiRange) { + std::vector> ranges; + ranges.push_back(std::make_unique(10, 20, false)); + ranges.push_back(std::make_unique(30, 40, false)); + auto filter = + std::make_unique(std::move(ranges), false); + auto expr = toExpr(filter.get(), BIGINT()); + + verifyExpr(expr, "BOOLEAN", "or"); + auto callExpr = std::dynamic_pointer_cast(expr); + ASSERT_EQ(callExpr->inputs().size(), 2); +} + +TEST_F(FilterToExpressionTest, multiRange) { + // Create a MultiRange filter with compatible filters for BIGINT field + std::vector> filters; + + // Add a BigintRange filter + filters.push_back(std::make_unique(10, 20, false)); + + // Add an IsNull filter + filters.push_back(std::make_unique()); + + // Add another BigintRange filter instead of BytesRange to avoid type mismatch + filters.push_back(std::make_unique(30, 40, false)); + + auto filter = std::make_unique(std::move(filters), false); + auto expr = toExpr(filter.get(), BIGINT()); + + // Verify the top-level expression is an OR + verifyExpr(expr, "BOOLEAN", "or"); + auto callExpr = std::dynamic_pointer_cast(expr); + ASSERT_EQ(callExpr->inputs().size(), 3); + + // Verify the first input is a BigintRange expression (AND of + // greater_than_or_equal and less_than_or_equal) + auto firstInput = + std::dynamic_pointer_cast(callExpr->inputs()[0]); + ASSERT_TRUE(firstInput != nullptr); + ASSERT_EQ(firstInput->name(), "and"); + ASSERT_EQ(firstInput->inputs().size(), 2); + + // Verify the second input is an IsNull expression + auto secondInput = + std::dynamic_pointer_cast(callExpr->inputs()[1]); + ASSERT_TRUE(secondInput != nullptr); + ASSERT_EQ(secondInput->name(), "is_null"); + + // Verify the third input is another BigintRange expression (AND of + // greater_than_or_equal and less_than_or_equal) + auto thirdInput = + std::dynamic_pointer_cast(callExpr->inputs()[2]); + ASSERT_TRUE(thirdInput != nullptr); + ASSERT_EQ(thirdInput->name(), "and"); + ASSERT_EQ(thirdInput->inputs().size(), 2); +} + +} // namespace facebook::velox::core::test diff --git a/velox/exec/tests/FunctionSignatureBuilderTest.cpp b/velox/exec/tests/FunctionSignatureBuilderTest.cpp index 6ac14b0592ad..832f32c957aa 100644 --- a/velox/exec/tests/FunctionSignatureBuilderTest.cpp +++ b/velox/exec/tests/FunctionSignatureBuilderTest.cpp @@ -45,11 +45,12 @@ TEST_F(FunctionSignatureBuilderTest, basicTypeTests) { // Integer variables do not have to be used in the inputs, but in that case // must appear in the return. - ASSERT_NO_THROW(FunctionSignatureBuilder() - .integerVariable("a") - .returnType("DECIMAL(a, a)") - .argumentType("integer") - .build();); + ASSERT_NO_THROW( + FunctionSignatureBuilder() + .integerVariable("a") + .returnType("DECIMAL(a, a)") + .argumentType("integer") + .build();); VELOX_ASSERT_THROW( FunctionSignatureBuilder() @@ -124,7 +125,7 @@ TEST_F(FunctionSignatureBuilderTest, typeParamTests) { .returnType("integer") .argumentType("row(..., varchar)") .build(), - "Failed to parse type signature [row(..., varchar)]: syntax error, unexpected COMMA"); + "Failed to parse type signature [row(..., varchar)]: syntax error, unexpected ELLIPSIS"); // Type params cant have type params. VELOX_ASSERT_THROW( @@ -155,6 +156,60 @@ TEST_F(FunctionSignatureBuilderTest, anyInReturn) { "Type 'Any' cannot appear in return type"); } +TEST_F(FunctionSignatureBuilderTest, homogeneousRowInReturn) { + VELOX_ASSERT_USER_THROW( + exec::FunctionSignatureBuilder() + .typeVariable("T") + .returnType("row(T, ...)") + .argumentType("T") + .build(), + "Homogeneous row cannot appear in return type"); + + VELOX_ASSERT_USER_THROW( + exec::FunctionSignatureBuilder() + .returnType("array(row(bigint, ...))") + .argumentType("bigint") + .build(), + "Homogeneous row cannot appear in return type"); +} + +TEST_F(FunctionSignatureBuilderTest, variableArity) { + // .variableArity() requires at least one argument. + VELOX_ASSERT_USER_THROW( + exec::FunctionSignatureBuilder() + .returnType("bigint") + .variableArity() + .build(), + "Variable arity requires at least one argument"); + + // .variableArity() can be used only once. + VELOX_ASSERT_USER_THROW( + exec::FunctionSignatureBuilder() + .returnType("bigint") + .variableArity("bigint") + .variableArity("integer") + .build(), + "Cannot add arguments after variable arity argument"); + + VELOX_ASSERT_USER_THROW( + exec::FunctionSignatureBuilder() + .returnType("bigint") + .argumentType("bigint") + .variableArity() + .variableArity() + .build(), + "Only one variable arity argument is allowed"); + + // No arguments can be added after calling .variableArity(). + VELOX_ASSERT_USER_THROW( + exec::FunctionSignatureBuilder() + .returnType("bigint") + .variableArity("bigint") + .argumentType("boolean") + .build(), + "Cannot add arguments after variable arity argument"); +} + TEST_F(FunctionSignatureBuilderTest, scalarConstantFlags) { { auto signature = FunctionSignatureBuilder() diff --git a/velox/exec/tests/GroupedExecutionTest.cpp b/velox/exec/tests/GroupedExecutionTest.cpp index 8263e17bcaed..da3dd53ee2e8 100644 --- a/velox/exec/tests/GroupedExecutionTest.cpp +++ b/velox/exec/tests/GroupedExecutionTest.cpp @@ -341,7 +341,7 @@ TEST_F(GroupedExecutionTest, hashJoinWithMixedGroupedExecution) { return fmt::format( "mode {}, joinType {}, supported {}", modeToString(mode), - core::joinTypeName(joinType), + core::JoinTypeName::toName(joinType), supported); } }; @@ -449,7 +449,7 @@ TEST_F(GroupedExecutionTest, hashJoinWithMixedGroupedExecution) { task->start(3, 1), fmt::format( "Hash join currently does not support mixed grouped execution for join type {}", - core::joinTypeName(testData.joinType))); + core::JoinTypeName::toName(testData.joinType))); continue; } @@ -675,16 +675,25 @@ DEBUG_ONLY_TEST_F( } })); + const auto spillDirectory = exec::test::TempDirectoryPath::create(); + std::optional spillOpts; + if (testData.enableSpill) { + spillOpts = common::SpillDiskOptions{ + .spillDirPath = spillDirectory->getPath(), + .spillDirCreated = true, + .spillDirCreateCb = nullptr}; + } + auto task = exec::Task::create( "0", std::move(planFragment), 0, std::move(queryCtx), - Task::ExecutionMode::kParallel); - const auto spillDirectory = exec::test::TempDirectoryPath::create(); - if (testData.enableSpill) { - task->setSpillDirectory(spillDirectory->getPath()); - } + Task::ExecutionMode::kParallel, + /*consumer=*/Consumer{}, + /*memoryArbitrationPriority=*/0, + spillOpts, + /*onError=*/nullptr); // 'numDriversPerGroup' drivers max to execute one group at a time. task->start(numDriversPerGroup, testData.groupConcurrency); @@ -817,15 +826,21 @@ DEBUG_ONLY_TEST_F( memory::testingRunArbitration(op->pool()); })); + const auto spillDirectory = exec::test::TempDirectoryPath::create(); + common::SpillDiskOptions spillOpts{ + .spillDirPath = spillDirectory->getPath(), + .spillDirCreated = true, + .spillDirCreateCb = nullptr}; + auto task = exec::Task::create( "0", std::move(planFragment), 0, std::move(queryCtx), - Task::ExecutionMode::kParallel); - const auto spillDirectory = exec::test::TempDirectoryPath::create(); - - task->setSpillDirectory(spillDirectory->getPath()); + Task::ExecutionMode::kParallel, + Consumer{}, + /*memoryArbitrationPriority=*/0, + spillOpts); // 'numDriversPerGroup' drivers max to execute one group at a time. task->start(numDriversPerGroup, 1); @@ -848,8 +863,8 @@ DEBUG_ONLY_TEST_F( } // Total drivers should be numDriversPerGroup * (numGroups + 1), but since - // probe does not receive termination signal, it cannot signal the build side - // to finish. we expect only build's numDriversPerGroup finished. + // probe does not receive termination signal, it cannot signal the build + // side to finish. we expect only build's numDriversPerGroup finished. waitForFinishedDrivers(task, numDriversPerGroup); // 'Delete results' from output buffer triggers 'set all output consumed', @@ -1025,8 +1040,8 @@ TEST_F(GroupedExecutionTest, groupedExecutionWithHashAndNestedLoopJoin) { const std::unordered_set expectedSplitGroupIds({1, 5, 8}); int numSplitGroupJoinNodes{0}; task->pool()->visitChildren([&](memory::MemoryPool* childPool) -> bool { - if (folly::StringPiece(childPool->name()) - .startsWith(fmt::format("node.{}[", joinNodeId))) { + if (childPool->name().starts_with( + fmt::format("node.{}[", joinNodeId))) { ++numSplitGroupJoinNodes; std::vector parts; folly::split(".", childPool->name(), parts); diff --git a/velox/exec/tests/HashJoinBridgeTest.cpp b/velox/exec/tests/HashJoinBridgeTest.cpp index d5a6e5b24072..757e1658491a 100644 --- a/velox/exec/tests/HashJoinBridgeTest.cpp +++ b/velox/exec/tests/HashJoinBridgeTest.cpp @@ -44,6 +44,7 @@ class HashJoinBridgeTestHelper { HashJoinBridge* const bridge_; }; +namespace { struct TestParam { int32_t numProbers{1}; int32_t numBuilders{1}; @@ -670,12 +671,16 @@ TEST_P(HashJoinBridgeTest, hashJoinTableType) { std::vector buildKeys; std::vector probeKeys; for (uint32_t i = 0; i < testData.buildKeyType->size(); i++) { - buildKeys.push_back(std::make_shared( - testData.buildKeyType->childAt(i), testData.buildKeyType->nameOf(i))); + buildKeys.push_back( + std::make_shared( + testData.buildKeyType->childAt(i), + testData.buildKeyType->nameOf(i))); } for (uint32_t i = 0; i < testData.probeKeyType->size(); i++) { - probeKeys.push_back(std::make_shared( - testData.probeKeyType->childAt(i), testData.probeKeyType->nameOf(i))); + probeKeys.push_back( + std::make_shared( + testData.probeKeyType->childAt(i), + testData.probeKeyType->nameOf(i))); } const auto joinNode = std::make_shared( "join-bridge-test", @@ -707,7 +712,7 @@ TEST(HashJoinBridgeTest, hashJoinTableSpillType) { std::string debugString() const { return fmt::format( "joinType: {}, expectedTableSpillType: {}", - joinTypeName(joinType), + core::JoinTypeName::toName(joinType), expectedTableSpillType->toString()); } } testSettings[] = { @@ -724,4 +729,5 @@ TEST(HashJoinBridgeTest, hashJoinTableSpillType) { ASSERT_EQ(spillType->names(), testData.expectedTableSpillType->names()); } } +} // namespace } // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/HashJoinTest.cpp b/velox/exec/tests/HashJoinTest.cpp index 4ff343979326..ff76fa0a16ba 100644 --- a/velox/exec/tests/HashJoinTest.cpp +++ b/velox/exec/tests/HashJoinTest.cpp @@ -33,38 +33,47 @@ #include "velox/exec/tests/utils/PlanBuilder.h" #include "velox/exec/tests/utils/TempDirectoryPath.h" #include "velox/exec/tests/utils/VectorTestUtil.h" +#include "velox/type/tests/utils/CustomTypesForTesting.h" #include "velox/vector/fuzzer/VectorFuzzer.h" using namespace facebook::velox; -using namespace facebook::velox::exec; using namespace facebook::velox::exec::test; using namespace facebook::velox::common::testutil; using facebook::velox::test::BatchMaker; +namespace facebook::velox::exec { namespace { -class HashJoinTest : public HashJoinTestBase { +class HashJoinTest : public HashJoinTestBase, + public testing::WithParamInterface { public: - HashJoinTest() : HashJoinTestBase(TestParam(1)) {} + HashJoinTest() : HashJoinTestBase(GetParam()) {} explicit HashJoinTest(const TestParam& param) : HashJoinTestBase(param) {} + + static std::vector getTestParams() { + return std::vector({TestParam{1, false}, TestParam{1, true}}); + } }; -class MultiThreadedHashJoinTest - : public HashJoinTest, - public testing::WithParamInterface { +class MultiThreadedHashJoinTest : public HashJoinTest { public: MultiThreadedHashJoinTest() : HashJoinTest(GetParam()) {} static std::vector getTestParams() { - return std::vector({TestParam{1}, TestParam{3}}); + return std::vector( + {TestParam{1, false}, + TestParam{1, true}, + TestParam{3, false}, + TestParam{3, true}}); } }; TEST_P(MultiThreadedHashJoinTest, bigintArray) { HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .keyTypes({BIGINT()}) .probeVectors(1600, 5) .buildVectors(1500, 5) @@ -76,6 +85,7 @@ TEST_P(MultiThreadedHashJoinTest, bigintArray) { TEST_P(MultiThreadedHashJoinTest, outOfJoinKeyColumnOrder) { HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeType(probeType_) .probeKeys({"t_k2"}) .probeVectors(5, 10) @@ -91,6 +101,7 @@ TEST_P(MultiThreadedHashJoinTest, outOfJoinKeyColumnOrder) { TEST_P(MultiThreadedHashJoinTest, joinWithCancellation) { HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .keyTypes({BIGINT()}) .probeVectors(1600, 5) .buildVectors(1500, 5) @@ -108,6 +119,7 @@ TEST_P(MultiThreadedHashJoinTest, testJoinWithSpillenabledCancellation) { auto spillDirectory = exec::test::TempDirectoryPath::create(); HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .keyTypes({BIGINT()}) .probeVectors(1600, 5) .buildVectors(1500, 5) @@ -128,6 +140,7 @@ TEST_P(MultiThreadedHashJoinTest, emptyBuild) { HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .hashProbeFinishEarlyOnEmptyBuild(finishOnEmpty) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .keyTypes({BIGINT()}) .probeVectors(1600, 5) .buildVectors(0, 5) @@ -159,6 +172,7 @@ TEST_P(MultiThreadedHashJoinTest, emptyBuild) { TEST_P(MultiThreadedHashJoinTest, emptyProbe) { HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .keyTypes({BIGINT()}) .probeVectors(0, 5) .buildVectors(1500, 5) @@ -195,6 +209,7 @@ TEST_P(MultiThreadedHashJoinTest, emptyProbe) { TEST_P(MultiThreadedHashJoinTest, normalizedKey) { HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .keyTypes({BIGINT(), VARCHAR()}) .probeVectors(1600, 5) .buildVectors(1500, 5) @@ -220,6 +235,7 @@ DEBUG_ONLY_TEST_P(MultiThreadedHashJoinTest, parallelJoinBuildCheck) { std::function([&](void*) { isParallelBuild = true; })); HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .keyTypes({BIGINT(), VARCHAR()}) .probeVectors(1600, 5) .buildVectors(1500, 5) @@ -250,6 +266,7 @@ DEBUG_ONLY_TEST_P( VELOX_ASSERT_THROW( HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .keyTypes({BIGINT(), VARCHAR()}) .probeVectors(1600, 5) .buildVectors(1500, 5) @@ -280,6 +297,7 @@ TEST_P(MultiThreadedHashJoinTest, allTypes) { TEST_P(MultiThreadedHashJoinTest, filter) { HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .keyTypes({BIGINT()}) .probeVectors(1600, 5) .buildVectors(1500, 5) @@ -312,6 +330,7 @@ DEBUG_ONLY_TEST_P(MultiThreadedHashJoinTest, filterSpillOnFirstProbeInput) { HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .keyTypes({BIGINT()}) .numDrivers(1) .probeVectors(1600, 5) @@ -362,6 +381,7 @@ TEST_P(MultiThreadedHashJoinTest, nullAwareAntiJoinWithNull) { HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeType(probeType_) .probeKeys({"t_k2"}) .probeVectors(std::move(probeVectors)) @@ -401,6 +421,7 @@ TEST_P(MultiThreadedHashJoinTest, rightSemiJoinFilterWithLargeOutput) { HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeKeys({"t0"}) .probeVectors(std::move(probeVectors)) .buildKeys({"u0"}) @@ -452,6 +473,7 @@ TEST_P(MultiThreadedHashJoinTest, arrayBasedLookup) { HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeKeys({"c0"}) .probeVectors(std::move(probeVectors)) .buildKeys({"c0"}) @@ -518,6 +540,7 @@ TEST_P(MultiThreadedHashJoinTest, joinSidesDifferentSchema) { HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeKeys({"t_c0"}) .probeVectors(std::move(probeVectors)) .probeProjections({"c0 AS t_c0", "c1 AS t_c1", "c2 AS t_c2"}) @@ -557,6 +580,7 @@ TEST_P(MultiThreadedHashJoinTest, innerJoinWithEmptyBuild) { HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .hashProbeFinishEarlyOnEmptyBuild(finishOnEmpty) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeKeys({"c0"}) .probeVectors(std::move(probeVectors)) .buildKeys({"c0"}) @@ -591,6 +615,7 @@ TEST_P(MultiThreadedHashJoinTest, innerJoinWithEmptyBuild) { TEST_P(MultiThreadedHashJoinTest, leftSemiJoinFilter) { HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeType(probeType_) .probeVectors(174, 5) .probeKeys({"t_k1"}) @@ -627,6 +652,7 @@ TEST_P(MultiThreadedHashJoinTest, leftSemiJoinFilterWithEmptyBuild) { HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .hashProbeFinishEarlyOnEmptyBuild(finishOnEmpty) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeKeys({"c0"}) .probeVectors(std::move(probeVectors)) .buildKeys({"c0"}) @@ -668,6 +694,7 @@ TEST_P(MultiThreadedHashJoinTest, leftSemiJoinFilterWithExtraFilter) { auto testBuildVectors = buildVectors; HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeKeys({"t0"}) .probeVectors(std::move(testProbeVectors)) .buildKeys({"u0"}) @@ -684,6 +711,7 @@ TEST_P(MultiThreadedHashJoinTest, leftSemiJoinFilterWithExtraFilter) { auto testBuildVectors = buildVectors; HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeKeys({"t0"}) .probeVectors(std::move(testProbeVectors)) .buildKeys({"u0"}) @@ -700,6 +728,7 @@ TEST_P(MultiThreadedHashJoinTest, leftSemiJoinFilterWithExtraFilter) { TEST_P(MultiThreadedHashJoinTest, rightSemiJoinFilter) { HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeType(probeType_) .probeVectors(133, 3) .probeKeys({"t_k1"}) @@ -741,6 +770,7 @@ TEST_P(MultiThreadedHashJoinTest, rightSemiJoinFilterWithEmptyBuild) { HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .hashProbeFinishEarlyOnEmptyBuild(finishOnEmpty) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeKeys({"t0"}) .probeVectors(std::move(probeVectors)) .buildKeys({"u0"}) @@ -798,6 +828,7 @@ TEST_P(MultiThreadedHashJoinTest, rightSemiJoinFilterWithAllMatches) { HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeKeys({"t0"}) .probeVectors(std::move(probeVectors)) .buildKeys({"u0"}) @@ -833,6 +864,7 @@ TEST_P(MultiThreadedHashJoinTest, rightSemiJoinFilterWithExtraFilter) { auto testBuildVectors = buildVectors; HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeKeys({"t0"}) .probeVectors(std::move(testProbeVectors)) .buildKeys({"u0"}) @@ -855,6 +887,7 @@ TEST_P(MultiThreadedHashJoinTest, rightSemiJoinFilterWithExtraFilter) { auto testBuildVectors = buildVectors; HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeKeys({"t0"}) .probeVectors(std::move(testProbeVectors)) .buildKeys({"u0"}) @@ -876,6 +909,7 @@ TEST_P(MultiThreadedHashJoinTest, rightSemiJoinFilterWithExtraFilter) { auto testBuildVectors = buildVectors; HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeKeys({"t0"}) .probeVectors(std::move(testProbeVectors)) .buildKeys({"u0"}) @@ -1019,6 +1053,7 @@ TEST_P(MultiThreadedHashJoinTest, nullAwareAntiJoin) { auto testBuildVectors = buildVectors; HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeKeys({"c0"}) .probeVectors(std::move(testProbeVectors)) .buildKeys({"c0"}) @@ -1039,6 +1074,7 @@ TEST_P(MultiThreadedHashJoinTest, nullAwareAntiJoin) { auto testBuildVectors = buildVectors; HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeKeys({"c0"}) .probeVectors(std::move(testProbeVectors)) .buildKeys({"c0"}) @@ -1059,6 +1095,7 @@ TEST_P(MultiThreadedHashJoinTest, nullAwareAntiJoin) { auto testBuildVectors = buildVectors; HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeKeys({"c0"}) .probeVectors(std::move(testProbeVectors)) .buildKeys({"c0"}) @@ -1096,6 +1133,7 @@ TEST_P(MultiThreadedHashJoinTest, nullAwareAntiJoinWithFilter) { HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeKeys({"t0"}) .probeVectors(std::move(probeVectors)) .buildKeys({"u0"}) @@ -1150,6 +1188,7 @@ TEST_P(MultiThreadedHashJoinTest, nullAwareAntiJoinWithFilterAndEmptyBuild) { HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .hashProbeFinishEarlyOnEmptyBuild(finishOnEmpty) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeKeys({"t0"}) .probeVectors(std::vector(probeVectors)) .buildKeys({"u0"}) @@ -1209,6 +1248,7 @@ TEST_P(MultiThreadedHashJoinTest, nullAwareAntiJoinWithFilterAndNullKey) { auto testBuildVectors = buildVectors; HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeKeys({"t0"}) .probeVectors(std::move(testProbeVectors)) .buildKeys({"u0"}) @@ -1269,6 +1309,7 @@ TEST_P( auto testBuildVectors = buildVectors; HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeKeys({"t0"}) .probeVectors(std::move(testProbeVectors)) .buildKeys({"u0"}) @@ -1307,6 +1348,7 @@ TEST_P(MultiThreadedHashJoinTest, nullAwareAntiJoinWithFilterOnNullableColumn) { }); HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeKeys({"t0"}) .probeVectors(std::move(probeVectors)) .buildKeys({"u0"}) @@ -1357,6 +1399,7 @@ TEST_P(MultiThreadedHashJoinTest, nullAwareAntiJoinWithFilterOnNullableColumn) { }); HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeKeys({"t0"}) .probeVectors(std::move(probeVectors)) .buildKeys({"u0"}) @@ -1405,6 +1448,7 @@ TEST_P(MultiThreadedHashJoinTest, antiJoin) { }); HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeKeys({"t0"}) .probeVectors(std::vector(probeVectors)) .buildKeys({"u0"}) @@ -1433,6 +1477,7 @@ TEST_P(MultiThreadedHashJoinTest, antiJoin) { for (const std::string& filter : filters) { HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeKeys({"t0"}) .probeVectors(std::vector(probeVectors)) .buildKeys({"u0"}) @@ -1440,9 +1485,10 @@ TEST_P(MultiThreadedHashJoinTest, antiJoin) { .joinType(core::JoinType::kAnti) .joinFilter(filter) .joinOutputLayout({"t0", "t1"}) - .referenceQuery(fmt::format( - "SELECT t.* FROM t WHERE NOT EXISTS (SELECT * FROM u WHERE u.u0 = t.t0 AND {})", - filter)) + .referenceQuery( + fmt::format( + "SELECT t.* FROM t WHERE NOT EXISTS (SELECT * FROM u WHERE u.u0 = t.t0 AND {})", + filter)) .run(); } } @@ -1472,6 +1518,7 @@ TEST_P(MultiThreadedHashJoinTest, antiJoinWithFilterAndEmptyBuild) { HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .hashProbeFinishEarlyOnEmptyBuild(finishOnEmpty) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeKeys({"t0"}) .probeVectors(std::vector(probeVectors)) .buildKeys({"u0"}) @@ -1546,6 +1593,7 @@ TEST_P(MultiThreadedHashJoinTest, leftJoin) { HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeKeys({"c0"}) .probeVectors(std::move(probeVectors)) .buildKeys({"u_c0"}) @@ -1601,6 +1649,7 @@ TEST_P(MultiThreadedHashJoinTest, nullStatsWithEmptyBuild) { HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeKeys({"c0"}) .probeVectors(std::move(probeVectors)) .buildKeys({"u_c0"}) @@ -1684,6 +1733,7 @@ TEST_P(MultiThreadedHashJoinTest, leftJoinWithEmptyBuild) { HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .hashProbeFinishEarlyOnEmptyBuild(finishOnEmpty) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeKeys({"c0"}) .probeVectors(std::move(probeVectors)) .buildKeys({"u_c0"}) @@ -1745,6 +1795,7 @@ TEST_P(MultiThreadedHashJoinTest, leftJoinWithNoJoin) { HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeKeys({"c0"}) .probeVectors(std::move(probeVectors)) .buildKeys({"u_c0"}) @@ -1803,6 +1854,7 @@ TEST_P(MultiThreadedHashJoinTest, leftJoinWithAllMatch) { HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeKeys({"c0"}) .probeVectors(std::move(probeVectors)) .probeFilter("c0 < 5") @@ -1866,6 +1918,7 @@ TEST_P(MultiThreadedHashJoinTest, leftJoinWithFilter) { auto testBuildVectors = buildVectors; HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeKeys({"c0"}) .probeVectors(std::move(testProbeVectors)) .buildKeys({"u_c0"}) @@ -1885,6 +1938,7 @@ TEST_P(MultiThreadedHashJoinTest, leftJoinWithFilter) { auto testBuildVectors = buildVectors; HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeKeys({"c0"}) .probeVectors(std::move(testProbeVectors)) .buildKeys({"u_c0"}) @@ -1937,6 +1991,7 @@ TEST_P(MultiThreadedHashJoinTest, leftJoinWithNullableFilter) { HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeKeys({"c0"}) .probeVectors(std::move(probeVectors)) .buildKeys({"u_c0"}) @@ -1988,6 +2043,7 @@ TEST_P(MultiThreadedHashJoinTest, rightJoin) { HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeKeys({"c0"}) .probeVectors(std::move(probeVectors)) .buildKeys({"u_c0"}) @@ -2043,6 +2099,7 @@ TEST_P(MultiThreadedHashJoinTest, rightJoinWithEmptyBuild) { HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .hashProbeFinishEarlyOnEmptyBuild(finishOnEmpty) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeKeys({"c0"}) .probeVectors(std::move(probeVectors)) .buildKeys({"u_c0"}) @@ -2095,6 +2152,7 @@ TEST_P(MultiThreadedHashJoinTest, rightJoinWithAllMatch) { HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeKeys({"c0"}) .probeVectors(std::move(probeVectors)) .buildKeys({"u_c0"}) @@ -2150,6 +2208,7 @@ TEST_P(MultiThreadedHashJoinTest, rightJoinWithFilter) { auto testBuildVectors = buildVectors; HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeKeys({"c0"}) .probeVectors(std::move(testProbeVectors)) .buildKeys({"u_c0"}) @@ -2169,6 +2228,7 @@ TEST_P(MultiThreadedHashJoinTest, rightJoinWithFilter) { auto testBuildVectors = buildVectors; HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeKeys({"c0"}) .probeVectors(std::move(testProbeVectors)) .buildKeys({"u_c0"}) @@ -2222,6 +2282,7 @@ TEST_P(MultiThreadedHashJoinTest, fullJoin) { HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeKeys({"c0"}) .probeVectors(std::move(probeVectors)) .buildKeys({"u_c0"}) @@ -2277,6 +2338,7 @@ TEST_P(MultiThreadedHashJoinTest, fullJoinWithEmptyBuild) { HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .hashProbeFinishEarlyOnEmptyBuild(finishOnEmpty) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeKeys({"c0"}) .probeVectors(std::move(probeVectors)) .buildKeys({"u_c0"}) @@ -2330,6 +2392,7 @@ TEST_P(MultiThreadedHashJoinTest, fullJoinWithNoMatch) { HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeKeys({"c0"}) .probeVectors(std::move(probeVectors)) .buildKeys({"u_c0"}) @@ -2385,6 +2448,7 @@ TEST_P(MultiThreadedHashJoinTest, fullJoinWithFilters) { auto testBuildVectors = buildVectors; HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeKeys({"c0"}) .probeVectors(std::move(testProbeVectors)) .buildKeys({"u_c0"}) @@ -2404,6 +2468,7 @@ TEST_P(MultiThreadedHashJoinTest, fullJoinWithFilters) { auto testBuildVectors = buildVectors; HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .probeKeys({"c0"}) .probeVectors(std::move(testProbeVectors)) .buildKeys({"u_c0"}) @@ -2421,6 +2486,7 @@ TEST_P(MultiThreadedHashJoinTest, fullJoinWithFilters) { TEST_P(MultiThreadedHashJoinTest, noSpillLevelLimit) { HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .keyTypes({INTEGER()}) .probeVectors(1600, 5) .buildVectors(1500, 5) @@ -2441,7 +2507,7 @@ TEST_P(MultiThreadedHashJoinTest, noSpillLevelLimit) { // Verify that dynamic filter pushed down is turned off for null-aware right // semi project join. -TEST_F(HashJoinTest, nullAwareRightSemiProjectOverScan) { +TEST_P(HashJoinTest, nullAwareRightSemiProjectOverScan) { std::vector probes; std::vector builds; // Matches present: @@ -2515,7 +2581,7 @@ TEST_F(HashJoinTest, nullAwareRightSemiProjectOverScan) { } } -TEST_F(HashJoinTest, duplicateJoinKeys) { +TEST_P(HashJoinTest, duplicateJoinKeys) { auto leftVectors = makeBatches(3, [&](int32_t /*unused*/) { return makeRowVector({ makeNullableFlatVector( @@ -2597,7 +2663,7 @@ TEST_F(HashJoinTest, duplicateJoinKeys) { } } -TEST_F(HashJoinTest, semiProject) { +TEST_P(HashJoinTest, semiProject) { // Some keys have multiple rows: 2, 3, 5. auto probeVectors = makeBatches(3, [&](int32_t /*unused*/) { return makeRowVector({ @@ -2713,7 +2779,7 @@ TEST_F(HashJoinTest, semiProject) { .run(); } -TEST_F(HashJoinTest, semiProjectWithNullKeys) { +TEST_P(HashJoinTest, semiProjectWithNullKeys) { // Some keys have multiple rows: 2, 3, 5. auto probeVectors = makeBatches(3, [&](int32_t /*unused*/) { return makeRowVector( @@ -2916,7 +2982,7 @@ TEST_F(HashJoinTest, semiProjectWithNullKeys) { .run(); } -TEST_F(HashJoinTest, semiProjectWithFilter) { +TEST_P(HashJoinTest, semiProjectWithFilter) { auto probeVectors = makeBatches(3, [&](auto /*unused*/) { return makeRowVector( {"t0", "t1"}, @@ -2965,8 +3031,10 @@ TEST_F(HashJoinTest, semiProjectWithFilter) { HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .planNode(plan) - .referenceQuery(fmt::format( - "SELECT t0, t1, t0 IN (SELECT u0 FROM u WHERE {}) FROM t", filter)) + .referenceQuery( + fmt::format( + "SELECT t0, t1, t0 IN (SELECT u0 FROM u WHERE {}) FROM t", + filter)) .injectSpill(false) .run(); @@ -2976,15 +3044,16 @@ TEST_F(HashJoinTest, semiProjectWithFilter) { // these values. HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .planNode(plan) - .referenceQuery(fmt::format( - "SELECT t0, t1, EXISTS (SELECT * FROM u WHERE (u0 is not null OR t0 is not null) AND u0 = t0 AND {}) FROM t", - filter)) + .referenceQuery( + fmt::format( + "SELECT t0, t1, EXISTS (SELECT * FROM u WHERE (u0 is not null OR t0 is not null) AND u0 = t0 AND {}) FROM t", + filter)) .injectSpill(false) .run(); } } -TEST_F(HashJoinTest, nullAwareRightSemiProjectWithFilterNotAllowed) { +TEST_P(HashJoinTest, nullAwareRightSemiProjectWithFilterNotAllowed) { auto probe = makeRowVector(ROW({"t0", "t1"}, {INTEGER(), BIGINT()}), 10); auto build = makeRowVector(ROW({"u0", "u1"}, {INTEGER(), BIGINT()}), 10); @@ -3003,7 +3072,7 @@ TEST_F(HashJoinTest, nullAwareRightSemiProjectWithFilterNotAllowed) { "Null-aware right semi project join doesn't support extra filter"); } -TEST_F(HashJoinTest, leftSemiJoinWithExtraOutputCapacity) { +TEST_P(HashJoinTest, leftSemiJoinWithExtraOutputCapacity) { std::vector probeVectors; std::vector buildVectors; probeVectors.push_back(makeRowVector( @@ -3078,7 +3147,7 @@ TEST_F(HashJoinTest, leftSemiJoinWithExtraOutputCapacity) { } } -TEST_F(HashJoinTest, nullAwareMultiKeyNotAllowed) { +TEST_P(HashJoinTest, nullAwareMultiKeyNotAllowed) { auto probe = makeRowVector( ROW({"t0", "t1", "t2"}, {INTEGER(), BIGINT(), VARCHAR()}), 10); auto build = makeRowVector( @@ -3128,7 +3197,7 @@ TEST_F(HashJoinTest, nullAwareMultiKeyNotAllowed) { "Null-aware joins allow only one join key"); } -TEST_F(HashJoinTest, semiProjectOverLazyVectors) { +TEST_P(HashJoinTest, semiProjectOverLazyVectors) { auto probeVectors = makeBatches(1, [&](auto /*unused*/) { return makeRowVector( {"t0", "t1"}, @@ -3232,12 +3301,15 @@ TEST_F(HashJoinTest, semiProjectOverLazyVectors) { } VELOX_INSTANTIATE_TEST_SUITE_P( - HashJoinTest, MultiThreadedHashJoinTest, - testing::ValuesIn(MultiThreadedHashJoinTest::getTestParams())); + MultiThreadedHashJoinTest, + testing::ValuesIn(MultiThreadedHashJoinTest::getTestParams()), + [](const testing::TestParamInfo& info) { + return TestParamToName(info.param); + }); // TODO: try to parallelize the following test cases if possible. -TEST_F(HashJoinTest, memory) { +TEST_P(HashJoinTest, memory) { // Measures memory allocation in a 1:n hash join followed by // projection and aggregation. We expect vectors to be mostly // reused, except for t_k0 + 1, which is a dictionary after the @@ -3276,7 +3348,7 @@ TEST_F(HashJoinTest, memory) { EXPECT_GT(40'000'000, params.queryCtx->pool()->stats().cumulativeBytes); } -TEST_F(HashJoinTest, lazyVectors) { +TEST_P(HashJoinTest, lazyVectors) { // a dataset of multiple row groups with multiple columns. We create // different dictionary wrappings for different columns and load the // rows in scope at different times. @@ -3322,8 +3394,9 @@ TEST_F(HashJoinTest, lazyVectors) { } std::vector buildSplits; for (int i = 0; i < buildVectors.size(); ++i) { - buildSplits.push_back(exec::Split(makeHiveConnectorSplit( - tempFiles[probeSplits.size() + i]->getPath()))); + buildSplits.push_back( + exec::Split(makeHiveConnectorSplit( + tempFiles[probeSplits.size() + i]->getPath()))); } SplitInput splits; splits.emplace(probeScanId, probeSplits); @@ -3390,7 +3463,7 @@ TEST_F(HashJoinTest, lazyVectors) { } } -TEST_F(HashJoinTest, lazyVectorNotLoadedInFilter) { +TEST_P(HashJoinTest, lazyVectorNotLoadedInFilter) { // Ensure that if lazy vectors are temporarily wrapped during a filter's // execution and remain unloaded, the temporary wrap is promptly // discarded. This precaution prevents the generation of the probe's output @@ -3408,7 +3481,7 @@ TEST_F(HashJoinTest, lazyVectorNotLoadedInFilter) { "SELECT t.c1, t.c2 FROM t, u WHERE t.c0 = u.c0"); } -TEST_F(HashJoinTest, lazyVectorPartiallyLoadedInFilterLeftJoin) { +TEST_P(HashJoinTest, lazyVectorPartiallyLoadedInFilterLeftJoin) { // Test the case where a filter loads a subset of the rows that will be output // from a column on the probe side. @@ -3419,7 +3492,7 @@ TEST_F(HashJoinTest, lazyVectorPartiallyLoadedInFilterLeftJoin) { "SELECT t.c1, t.c2 FROM t LEFT JOIN u ON t.c0 = u.c0 AND (c1 > 0 AND c2 > 0)"); } -TEST_F(HashJoinTest, lazyVectorPartiallyLoadedInFilterFullJoin) { +TEST_P(HashJoinTest, lazyVectorPartiallyLoadedInFilterFullJoin) { // Test the case where a filter loads a subset of the rows that will be output // from a column on the probe side. @@ -3430,7 +3503,7 @@ TEST_F(HashJoinTest, lazyVectorPartiallyLoadedInFilterFullJoin) { "SELECT t.c1, t.c2 FROM t FULL OUTER JOIN u ON t.c0 = u.c0 AND (c1 > 0 AND c2 > 0)"); } -TEST_F(HashJoinTest, lazyVectorPartiallyLoadedInFilterLeftSemiProject) { +TEST_P(HashJoinTest, lazyVectorPartiallyLoadedInFilterLeftSemiProject) { // Test the case where a filter loads a subset of the rows that will be output // from a column on the probe side. @@ -3441,7 +3514,7 @@ TEST_F(HashJoinTest, lazyVectorPartiallyLoadedInFilterLeftSemiProject) { "SELECT t.c1, t.c2, EXISTS (SELECT * FROM u WHERE t.c0 = u.c0 AND (t.c1 > 0 AND t.c2 > 0)) FROM t"); } -TEST_F(HashJoinTest, lazyVectorPartiallyLoadedInFilterAntiJoin) { +TEST_P(HashJoinTest, lazyVectorPartiallyLoadedInFilterAntiJoin) { // Test the case where a filter loads a subset of the rows that will be output // from a column on the probe side. @@ -3452,7 +3525,7 @@ TEST_F(HashJoinTest, lazyVectorPartiallyLoadedInFilterAntiJoin) { "SELECT t.c1, t.c2 FROM t WHERE NOT EXISTS (SELECT * FROM u WHERE t.c0 = u.c0 AND (t.c1 > 0 AND t.c2 > 0))"); } -TEST_F(HashJoinTest, lazyVectorPartiallyLoadedInFilterInnerJoin) { +TEST_P(HashJoinTest, lazyVectorPartiallyLoadedInFilterInnerJoin) { // Test the case where a filter loads a subset of the rows that will be output // from a column on the probe side. @@ -3463,7 +3536,7 @@ TEST_F(HashJoinTest, lazyVectorPartiallyLoadedInFilterInnerJoin) { "SELECT t.c1, t.c2 FROM t, u WHERE t.c0 = u.c0 AND NOT (c1 < 15 AND c2 >= 0)"); } -TEST_F(HashJoinTest, lazyVectorPartiallyLoadedInFilterLeftSemiFilter) { +TEST_P(HashJoinTest, lazyVectorPartiallyLoadedInFilterLeftSemiFilter) { // Test the case where a filter loads a subset of the rows that will be output // from a column on the probe side. @@ -3474,7 +3547,7 @@ TEST_F(HashJoinTest, lazyVectorPartiallyLoadedInFilterLeftSemiFilter) { "SELECT t.c1, t.c2 FROM t WHERE c0 IN (SELECT u.c0 FROM u WHERE t.c0 = u.c0 AND NOT (t.c1 < 15 AND t.c2 >= 0))"); } -TEST_F(HashJoinTest, dynamicFilters) { +TEST_P(HashJoinTest, dynamicFilters) { const int32_t numSplits = 10; const int32_t numRowsProbe = 333; const int32_t numRowsBuild = 100; @@ -3542,7 +3615,7 @@ TEST_F(HashJoinTest, dynamicFilters) { // Basic push-down. { - // Inner join. + SCOPED_TRACE("Inner join"); core::PlanNodeId probeScanId; core::PlanNodeId joinId; auto op = PlanBuilder(planNodeIdGenerator, pool_.get()) @@ -3720,8 +3793,9 @@ TEST_F(HashJoinTest, dynamicFilters) { // Basic push-down with column names projected out of the table scan // having different names than column names in the files. { + SCOPED_TRACE("Inner join column rename"); auto scanOutputType = ROW({"a", "b"}, {INTEGER(), BIGINT()}); - ColumnHandleMap assignments; + connector::ColumnHandleMap assignments; assignments["a"] = regularColumn("c0", INTEGER()); assignments["b"] = regularColumn("c1", BIGINT()); @@ -3768,6 +3842,7 @@ TEST_F(HashJoinTest, dynamicFilters) { // Push-down that requires merging filters. { + SCOPED_TRACE("Merge filters"); core::PlanNodeId probeScanId; core::PlanNodeId joinId; auto op = PlanBuilder(planNodeIdGenerator, pool_.get()) @@ -3808,6 +3883,7 @@ TEST_F(HashJoinTest, dynamicFilters) { // Push-down that turns join into a no-op. { + SCOPED_TRACE("canReplaceWithDynamicFilter"); core::PlanNodeId probeScanId; core::PlanNodeId joinId; auto op = @@ -3851,6 +3927,7 @@ TEST_F(HashJoinTest, dynamicFilters) { // Push-down that turns join into a no-op with output having a different // number of columns than the input. { + SCOPED_TRACE("canReplaceWithDynamicFilter column rename"); core::PlanNodeId probeScanId; core::PlanNodeId joinId; auto op = PlanBuilder(planNodeIdGenerator, pool_.get()) @@ -3891,6 +3968,7 @@ TEST_F(HashJoinTest, dynamicFilters) { // Push-down that requires merging filters and turns join into a no-op. { + SCOPED_TRACE("canReplaceWithDynamicFilter merge filters"); core::PlanNodeId probeScanId; core::PlanNodeId joinId; auto op = PlanBuilder(planNodeIdGenerator, pool_.get()) @@ -3931,6 +4009,7 @@ TEST_F(HashJoinTest, dynamicFilters) { // Push-down with highly selective filter in the scan. { + SCOPED_TRACE("Highly selective filter"); // Inner join. core::PlanNodeId probeScanId; core::PlanNodeId joinId; @@ -3945,6 +4024,7 @@ TEST_F(HashJoinTest, dynamicFilters) { .planNode(); { + SCOPED_TRACE("Inner join"); HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .planNode(std::move(op)) .makeInputSplits(makeInputSplits(probeScanId)) @@ -3989,6 +4069,7 @@ TEST_F(HashJoinTest, dynamicFilters) { .planNode(); { + SCOPED_TRACE("Left semi join"); HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .planNode(std::move(op)) .makeInputSplits(makeInputSplits(probeScanId)) @@ -4033,6 +4114,7 @@ TEST_F(HashJoinTest, dynamicFilters) { .planNode(); { + SCOPED_TRACE("Right semi join"); HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .planNode(std::move(op)) .makeInputSplits(makeInputSplits(probeScanId)) @@ -4073,6 +4155,7 @@ TEST_F(HashJoinTest, dynamicFilters) { .planNode(); { + SCOPED_TRACE("Right join"); HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .planNode(std::move(op)) .makeInputSplits(makeInputSplits(probeScanId)) @@ -4104,6 +4187,7 @@ TEST_F(HashJoinTest, dynamicFilters) { // Disable filter push-down by using values in place of scan. { + SCOPED_TRACE("Disabled in case of values node"); core::PlanNodeId joinId; auto op = PlanBuilder(planNodeIdGenerator, pool_.get()) .values(probeVectors) @@ -4127,6 +4211,7 @@ TEST_F(HashJoinTest, dynamicFilters) { // Disable filter push-down by using an expression as the join key on the // probe side. { + SCOPED_TRACE("Disabled in case of join condition"); core::PlanNodeId probeScanId; core::PlanNodeId joinId; auto op = PlanBuilder(planNodeIdGenerator, pool_.get()) @@ -4153,7 +4238,7 @@ TEST_F(HashJoinTest, dynamicFilters) { } } -TEST_F(HashJoinTest, dynamicFiltersStatsWithChainedJoins) { +TEST_P(HashJoinTest, dynamicFiltersStatsWithChainedJoins) { const int32_t numSplits = 10; const int32_t numProbeRows = 333; const int32_t numBuildRows = 100; @@ -4250,7 +4335,7 @@ TEST_F(HashJoinTest, dynamicFiltersStatsWithChainedJoins) { .run(); } -TEST_F(HashJoinTest, dynamicFiltersWithSkippedSplits) { +TEST_P(HashJoinTest, dynamicFiltersWithSkippedSplits) { const int32_t numSplits = 20; const int32_t numNonSkippedSplits = 10; const int32_t numRowsProbe = 333; @@ -4472,7 +4557,7 @@ TEST_F(HashJoinTest, dynamicFiltersWithSkippedSplits) { } } -TEST_F(HashJoinTest, dynamicFiltersAppliedToPreloadedSplits) { +TEST_P(HashJoinTest, dynamicFiltersAppliedToPreloadedSplits) { vector_size_t size = 1000; const int32_t numSplits = 5; @@ -4500,7 +4585,7 @@ TEST_F(HashJoinTest, dynamicFiltersAppliedToPreloadedSplits) { } auto outputType = ROW({"p0", "p1"}, {BIGINT(), BIGINT()}); - ColumnHandleMap assignments = { + connector::ColumnHandleMap assignments = { {"p0", regularColumn("p0", BIGINT())}, {"p1", partitionKey("p1", BIGINT())}}; createDuckDbTable("p", probeVectors); @@ -4556,7 +4641,7 @@ TEST_F(HashJoinTest, dynamicFiltersAppliedToPreloadedSplits) { .run(); } -TEST_F(HashJoinTest, dynamicFiltersPushDownThroughAgg) { +TEST_P(HashJoinTest, dynamicFiltersPushDownThroughAgg) { const int32_t numRowsProbe = 300; const int32_t numRowsBuild = 100; @@ -4625,7 +4710,7 @@ TEST_F(HashJoinTest, dynamicFiltersPushDownThroughAgg) { .run(); } -TEST_F(HashJoinTest, noDynamicFiltersPushDownThroughRightJoin) { +TEST_P(HashJoinTest, noDynamicFiltersPushDownThroughRightJoin) { std::vector innerBuild = {makeRowVector( {"a"}, { @@ -4672,7 +4757,7 @@ TEST_F(HashJoinTest, noDynamicFiltersPushDownThroughRightJoin) { // Verify the size of the join output vectors when projecting build-side // variable-width column. -TEST_F(HashJoinTest, memoryUsage) { +TEST_P(HashJoinTest, memoryUsage) { std::vector probeVectors = makeBatches(10, [&](int32_t /*unused*/) { return makeRowVector( @@ -4727,7 +4812,7 @@ TEST_F(HashJoinTest, memoryUsage) { /// Test an edge case in producing small output batches where the logic to /// calculate the set of probe-side rows to load lazy vectors for was /// triggering a crash. -TEST_F(HashJoinTest, smallOutputBatchSize) { +TEST_P(HashJoinTest, smallOutputBatchSize) { // Setup probe data with 50 non-null matching keys followed by 50 null // keys: 1, 2, 1, 2,...null, null. auto probeVectors = makeRowVector({ @@ -4773,12 +4858,13 @@ TEST_F(HashJoinTest, smallOutputBatchSize) { .run(); } -TEST_F(HashJoinTest, spillFileSize) { +TEST_P(HashJoinTest, spillFileSize) { const std::vector maxSpillFileSizes({0, 1, 1'000'000'000}); for (const auto spillFileSize : maxSpillFileSizes) { SCOPED_TRACE(fmt::format("spillFileSize: {}", spillFileSize)); HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .keyTypes({BIGINT()}) .probeVectors(100, 3) .buildVectors(100, 3) @@ -4809,10 +4895,11 @@ TEST_F(HashJoinTest, spillFileSize) { } } -TEST_F(HashJoinTest, spillPartitionBitsOverlap) { +TEST_P(HashJoinTest, spillPartitionBitsOverlap) { auto builder = HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .keyTypes({BIGINT(), BIGINT()}) .probeVectors(2'000, 3) .buildVectors(2'000, 3) @@ -4827,7 +4914,7 @@ TEST_F(HashJoinTest, spillPartitionBitsOverlap) { // The test is to verify if the hash build reservation has been released on // task error. -DEBUG_ONLY_TEST_F(HashJoinTest, buildReservationReleaseCheck) { +DEBUG_ONLY_TEST_P(HashJoinTest, buildReservationReleaseCheck) { std::vector probeVectors = makeBatches(1, [&](int32_t /*unused*/) { return std::dynamic_pointer_cast( @@ -4879,7 +4966,7 @@ DEBUG_ONLY_TEST_F(HashJoinTest, buildReservationReleaseCheck) { ASSERT_TRUE(waitForTaskAborted(task, 5'000'000)); } -TEST_F(HashJoinTest, dynamicFilterOnPartitionKey) { +TEST_P(HashJoinTest, dynamicFilterOnPartitionKey) { vector_size_t size = 10; auto filePaths = makeFilePaths(1); auto rowVector = makeRowVector( @@ -4894,7 +4981,7 @@ TEST_F(HashJoinTest, dynamicFilterOnPartitionKey) { .partitionKey("k", "0") .build(); auto outputType = ROW({"n1_0", "n1_1"}, {BIGINT(), BIGINT()}); - ColumnHandleMap assignments = { + connector::ColumnHandleMap assignments = { {"n1_0", regularColumn("c0", BIGINT())}, {"n1_1", partitionKey("k", BIGINT())}}; @@ -4926,7 +5013,7 @@ TEST_F(HashJoinTest, dynamicFilterOnPartitionKey) { .run(); } -TEST_F(HashJoinTest, probeMemoryLimitOnBuildProjection) { +TEST_P(HashJoinTest, probeMemoryLimitOnBuildProjection) { const uint64_t numBuildRows = 20; std::vector probeVectors = makeBatches(10, [&](int32_t /*unused*/) { @@ -5031,7 +5118,7 @@ TEST_F(HashJoinTest, probeMemoryLimitOnBuildProjection) { } } -DEBUG_ONLY_TEST_F(HashJoinTest, reclaimDuringInputProcessing) { +DEBUG_ONLY_TEST_P(HashJoinTest, reclaimDuringInputProcessing) { constexpr int64_t kMaxBytes = 1LL << 30; // 1GB VectorFuzzer fuzzer({.vectorSize = 1000}, pool()); const int32_t numBuildVectors = 10; @@ -5126,6 +5213,7 @@ DEBUG_ONLY_TEST_F(HashJoinTest, reclaimDuringInputProcessing) { std::thread taskThread([&]() { HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .planNode(plan) .queryPool(std::move(queryPool)) .injectSpill(false) @@ -5197,7 +5285,7 @@ DEBUG_ONLY_TEST_F(HashJoinTest, reclaimDuringInputProcessing) { ASSERT_EQ(reclaimerStats_, memory::MemoryReclaimer::Stats{}); } -DEBUG_ONLY_TEST_F(HashJoinTest, reclaimDuringReserve) { +DEBUG_ONLY_TEST_P(HashJoinTest, reclaimDuringReserve) { constexpr int64_t kMaxBytes = 1LL << 30; // 1GB const int32_t numBuildVectors = 3; std::vector buildVectors; @@ -5281,6 +5369,7 @@ DEBUG_ONLY_TEST_F(HashJoinTest, reclaimDuringReserve) { std::thread taskThread([&]() { HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .planNode(plan) .queryPool(std::move(queryPool)) .injectSpill(false) @@ -5330,7 +5419,7 @@ DEBUG_ONLY_TEST_F(HashJoinTest, reclaimDuringReserve) { taskThread.join(); } -DEBUG_ONLY_TEST_F(HashJoinTest, reclaimDuringAllocation) { +DEBUG_ONLY_TEST_P(HashJoinTest, reclaimDuringAllocation) { constexpr int64_t kMaxBytes = 1LL << 30; // 1GB VectorFuzzer fuzzer({.vectorSize = 1000}, pool()); const int32_t numBuildVectors = 10; @@ -5414,6 +5503,7 @@ DEBUG_ONLY_TEST_F(HashJoinTest, reclaimDuringAllocation) { std::thread taskThread([&]() { HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .planNode(plan) .queryPool(std::move(queryPool)) .injectSpill(false) @@ -5461,7 +5551,7 @@ DEBUG_ONLY_TEST_F(HashJoinTest, reclaimDuringAllocation) { ASSERT_EQ(reclaimerStats_, memory::MemoryReclaimer::Stats{0}); } -DEBUG_ONLY_TEST_F(HashJoinTest, reclaimDuringOutputProcessing) { +DEBUG_ONLY_TEST_P(HashJoinTest, reclaimDuringOutputProcessing) { constexpr int64_t kMaxBytes = 1LL << 30; // 1GB VectorFuzzer fuzzer({.vectorSize = 1000}, pool()); const int32_t numBuildVectors = 10; @@ -5533,6 +5623,7 @@ DEBUG_ONLY_TEST_F(HashJoinTest, reclaimDuringOutputProcessing) { std::thread taskThread([&]() { HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .planNode(plan) .queryPool(std::move(queryPool)) .injectSpill(false) @@ -5594,7 +5685,7 @@ DEBUG_ONLY_TEST_F(HashJoinTest, reclaimDuringOutputProcessing) { ASSERT_EQ(reclaimerStats_.numNonReclaimableAttempts, 1); } -DEBUG_ONLY_TEST_F(HashJoinTest, reclaimDuringWaitForProbe) { +DEBUG_ONLY_TEST_P(HashJoinTest, reclaimDuringWaitForProbe) { constexpr int64_t kMaxBytes = 1LL << 30; // 1GB VectorFuzzer fuzzer({.vectorSize = 1000}, pool()); const int32_t numBuildVectors = 10; @@ -5680,6 +5771,7 @@ DEBUG_ONLY_TEST_F(HashJoinTest, reclaimDuringWaitForProbe) { std::thread taskThread([&]() { HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .planNode(plan) .queryPool(std::move(queryPool)) .injectSpill(false) @@ -5732,7 +5824,7 @@ DEBUG_ONLY_TEST_F(HashJoinTest, reclaimDuringWaitForProbe) { ASSERT_EQ(reclaimerStats_.numNonReclaimableAttempts, 1); } -DEBUG_ONLY_TEST_F(HashJoinTest, hashBuildAbortDuringOutputProcessing) { +DEBUG_ONLY_TEST_P(HashJoinTest, hashBuildAbortDuringOutputProcessing) { const auto buildVectors = makeVectors(buildType_, 10, 128); const auto probeVectors = makeVectors(probeType_, 5, 128); @@ -5798,6 +5890,7 @@ DEBUG_ONLY_TEST_F(HashJoinTest, hashBuildAbortDuringOutputProcessing) { VELOX_ASSERT_THROW( HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .planNode(plan) .injectSpill(false) .referenceQuery( @@ -5808,7 +5901,7 @@ DEBUG_ONLY_TEST_F(HashJoinTest, hashBuildAbortDuringOutputProcessing) { } } -DEBUG_ONLY_TEST_F(HashJoinTest, hashBuildAbortDuringInputProcessing) { +DEBUG_ONLY_TEST_P(HashJoinTest, hashBuildAbortDuringInputProcessing) { const auto buildVectors = makeVectors(buildType_, 10, 128); const auto probeVectors = makeVectors(probeType_, 5, 128); @@ -5874,6 +5967,7 @@ DEBUG_ONLY_TEST_F(HashJoinTest, hashBuildAbortDuringInputProcessing) { VELOX_ASSERT_THROW( HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .planNode(plan) .injectSpill(false) .referenceQuery( @@ -5885,7 +5979,7 @@ DEBUG_ONLY_TEST_F(HashJoinTest, hashBuildAbortDuringInputProcessing) { } } -DEBUG_ONLY_TEST_F(HashJoinTest, hashBuildAbortDuringAllocation) { +DEBUG_ONLY_TEST_P(HashJoinTest, hashBuildAbortDuringAllocation) { const auto buildVectors = makeVectors(buildType_, 10, 128); const auto probeVectors = makeVectors(probeType_, 5, 128); @@ -5952,6 +6046,7 @@ DEBUG_ONLY_TEST_F(HashJoinTest, hashBuildAbortDuringAllocation) { VELOX_ASSERT_THROW( HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .planNode(plan) .injectSpill(false) .referenceQuery( @@ -5963,7 +6058,7 @@ DEBUG_ONLY_TEST_F(HashJoinTest, hashBuildAbortDuringAllocation) { } } -DEBUG_ONLY_TEST_F(HashJoinTest, hashProbeAbortDuringInputProcessing) { +DEBUG_ONLY_TEST_P(HashJoinTest, hashProbeAbortDuringInputProcessing) { const auto buildVectors = makeVectors(buildType_, 10, 128); const auto probeVectors = makeVectors(probeType_, 5, 128); @@ -6025,6 +6120,7 @@ DEBUG_ONLY_TEST_F(HashJoinTest, hashProbeAbortDuringInputProcessing) { VELOX_ASSERT_THROW( HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .planNode(plan) .injectSpill(false) .referenceQuery( @@ -6035,7 +6131,7 @@ DEBUG_ONLY_TEST_F(HashJoinTest, hashProbeAbortDuringInputProcessing) { } } -TEST_F(HashJoinTest, leftJoinWithMissAtEndOfBatch) { +TEST_P(HashJoinTest, leftJoinWithMissAtEndOfBatch) { // Tests some cases where the row at the end of an output batch fails the // filter. auto probeVectors = std::vector{makeRowVector( @@ -6070,9 +6166,10 @@ TEST_F(HashJoinTest, leftJoinWithMissAtEndOfBatch) { .numDrivers(1) .config( core::QueryConfig::kPreferredOutputBatchRows, std::to_string(10)) - .referenceQuery(fmt::format( - "SELECT t_k1, u_k1 from t left join u on t_k1 = u_k1 and {}", - filter)) + .referenceQuery( + fmt::format( + "SELECT t_k1, u_k1 from t left join u on t_k1 = u_k1 and {}", + filter)) .run(); }; @@ -6086,7 +6183,7 @@ TEST_F(HashJoinTest, leftJoinWithMissAtEndOfBatch) { test("t_k2 > 9"); } -TEST_F(HashJoinTest, leftJoinWithMissAtEndOfBatchMultipleBuildMatches) { +TEST_P(HashJoinTest, leftJoinWithMissAtEndOfBatchMultipleBuildMatches) { // Tests some cases where the row at the end of an output batch fails the // filter and there are multiple matches with the build side.. auto probeVectors = std::vector{makeRowVector( @@ -6121,9 +6218,10 @@ TEST_F(HashJoinTest, leftJoinWithMissAtEndOfBatchMultipleBuildMatches) { .numDrivers(1) .config( core::QueryConfig::kPreferredOutputBatchRows, std::to_string(10)) - .referenceQuery(fmt::format( - "SELECT t_k1, u_k1 from t left join u on t_k1 = u_k1 and {}", - filter)) + .referenceQuery( + fmt::format( + "SELECT t_k1, u_k1 from t left join u on t_k1 = u_k1 and {}", + filter)) .run(); }; @@ -6135,7 +6233,7 @@ TEST_F(HashJoinTest, leftJoinWithMissAtEndOfBatchMultipleBuildMatches) { test("t_k2 != 4 and t_k2 != 8"); } -TEST_F(HashJoinTest, leftJoinPreserveProbeOrder) { +TEST_P(HashJoinTest, leftJoinPreserveProbeOrder) { const std::vector probeVectors = { makeRowVector( {"k1", "v1"}, @@ -6177,7 +6275,7 @@ TEST_F(HashJoinTest, leftJoinPreserveProbeOrder) { ASSERT_EQ(v1->valueAt(2), 0); } -DEBUG_ONLY_TEST_F(HashJoinTest, minSpillableMemoryReservation) { +DEBUG_ONLY_TEST_P(HashJoinTest, minSpillableMemoryReservation) { VectorFuzzer fuzzer({.vectorSize = 1000}, pool()); const int32_t numBuildVectors = 10; std::vector buildVectors; @@ -6208,8 +6306,9 @@ DEBUG_ONLY_TEST_F(HashJoinTest, minSpillableMemoryReservation) { .planNode(); for (int32_t minSpillableReservationPct : {5, 50, 100}) { - SCOPED_TRACE(fmt::format( - "minSpillableReservationPct: {}", minSpillableReservationPct)); + SCOPED_TRACE( + fmt::format( + "minSpillableReservationPct: {}", minSpillableReservationPct)); SCOPED_TESTVALUE_SET( "facebook::velox::exec::HashBuild::addInput", @@ -6226,6 +6325,7 @@ DEBUG_ONLY_TEST_F(HashJoinTest, minSpillableMemoryReservation) { auto tempDirectory = exec::test::TempDirectoryPath::create(); HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .planNode(plan) .injectSpill(false) .spillDirectory(tempDirectory->getPath()) @@ -6235,7 +6335,7 @@ DEBUG_ONLY_TEST_F(HashJoinTest, minSpillableMemoryReservation) { } } -DEBUG_ONLY_TEST_F(HashJoinTest, exceededMaxSpillLevel) { +DEBUG_ONLY_TEST_P(HashJoinTest, exceededMaxSpillLevel) { VectorFuzzer fuzzer({.vectorSize = 1000}, pool()); const int32_t numBuildVectors = 10; std::vector buildVectors; @@ -6325,7 +6425,7 @@ DEBUG_ONLY_TEST_F(HashJoinTest, exceededMaxSpillLevel) { exceededMaxSpillLevelCount + 16); } -TEST_F(HashJoinTest, maxSpillBytes) { +TEST_P(HashJoinTest, maxSpillBytes) { const auto rowType = ROW({"c0", "c1", "c2"}, {INTEGER(), INTEGER(), VARCHAR()}); const auto probeVectors = createVectors(rowType, 1024, 10 << 20); @@ -6382,7 +6482,7 @@ TEST_F(HashJoinTest, maxSpillBytes) { } } -TEST_F(HashJoinTest, onlyHashBuildMaxSpillBytes) { +TEST_P(HashJoinTest, onlyHashBuildMaxSpillBytes) { const auto rowType = ROW({"c0", "c1", "c2"}, {INTEGER(), INTEGER(), VARCHAR()}); const auto probeVectors = createVectors(rowType, 32, 128); @@ -6438,7 +6538,7 @@ TEST_F(HashJoinTest, onlyHashBuildMaxSpillBytes) { } } -TEST_F(HashJoinTest, reclaimFromJoinBuilderWithMultiDrivers) { +TEST_P(HashJoinTest, reclaimFromJoinBuilderWithMultiDrivers) { auto rowType = ROW({ {"c0", INTEGER()}, {"c1", INTEGER()}, @@ -6487,7 +6587,135 @@ TEST_F(HashJoinTest, reclaimFromJoinBuilderWithMultiDrivers) { ASSERT_GT(arbitrator->stats().reclaimedUsedBytes, 0); } -DEBUG_ONLY_TEST_F( +TEST_P(HashJoinTest, semiJoinAbandonBuildNoDupHashEarly) { + auto probeVectors = makeBatches(3, [&](int32_t /*unused*/) { + return makeRowVector({ + makeFlatVector({1, 2, 2, 3, 3, 3, 4, 5, 5, 6, 7}), + makeFlatVector({10, 20, 21, 30, 31, 32, 40, 50, 51, 60, 70}), + }); + }); + + auto buildVectors = makeBatches(3, [&](int32_t /*unused*/) { + return makeRowVector({ + makeFlatVector({1, 1, 3, 4, 5, 5, 7, 8}), + makeFlatVector({100, 101, 300, 400, 500, 501, 700, 800}), + }); + }); + + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", buildVectors); + + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .values(probeVectors) + .project({"c0 AS t0", "c1 AS t1"}) + .hashJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator) + .values(buildVectors) + .project({"c0 AS u0", "c1 AS u1"}) + .planNode(), + "", + {"t0", "t1", "match"}, + core::JoinType::kLeftSemiProject) + .planNode(); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .config(core::QueryConfig::kAbandonDedupHashMapMinRows, "1") + .config(core::QueryConfig::kAbandonDedupHashMapMinPct, "10") + .planNode(plan) + .referenceQuery( + "SELECT t.c0, t.c1, EXISTS (SELECT * FROM u WHERE t.c0 = u.c0) FROM t") + .run(); +} + +TEST_P(HashJoinTest, antiJoinAbandonBuildNoDupHashEarly) { + auto probeVectors = makeBatches(64, [&](int32_t /*unused*/) { + return makeRowVector( + {"t0", "t1"}, + { + makeNullableFlatVector({std::nullopt, 1, 2, 3, 4, 5, 6}), + makeFlatVector({0, 1, 2, 3, 4, 5, 6}), + }); + }); + auto buildVectors = makeBatches(64, [&](int32_t /*unused*/) { + return makeRowVector( + {"u0", "u1"}, + { + makeNullableFlatVector({std::nullopt, 2, 3, 4, 6, 7, 8}), + makeFlatVector({0, 2, 3, 4, 6, 7, 8}), + }); + }); + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .config(core::QueryConfig::kAbandonDedupHashMapMinRows, "1") + .config(core::QueryConfig::kAbandonDedupHashMapMinPct, "10") + .numDrivers(numDrivers_) + .probeKeys({"t0"}) + .probeVectors(std::vector(probeVectors)) + .buildKeys({"u0"}) + .buildVectors(std::vector(buildVectors)) + .joinType(core::JoinType::kAnti) + .joinOutputLayout({"t0", "t1"}) + .referenceQuery( + "SELECT t.* FROM t WHERE NOT EXISTS (SELECT * FROM u WHERE u.u0 = t.t0)") + .run(); +} + +TEST_P(HashJoinTest, semiJoinDeduplicateResetCapacity) { + const int32_t vectorSize = 10; + const int32_t batches = 210; + auto probeVectors = makeBatches(batches, [&](int32_t /*unused*/) { + return makeRowVector({ + // Join Key is double -> VectorHasher::typeKindSupportsValueIds will + // return false -> HashMode is kHash + makeFlatVector( + vectorSize, [&](vector_size_t /*row*/) { return rand(); }), + makeFlatVector( + vectorSize, [&](vector_size_t /*row*/) { return rand(); }), + }); + }); + + auto buildVectors = makeBatches(batches, [&](int32_t batch) { + return makeRowVector({ + makeFlatVector( + vectorSize, [&](vector_size_t /*row*/) { return rand(); }), + makeFlatVector( + vectorSize, [&](vector_size_t /*row*/) { return rand(); }), + }); + }); + + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", buildVectors); + + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .values(probeVectors) + .project({"c0 AS t0", "c1 AS t1"}) + .hashJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator) + .values(buildVectors) + .project({"c0 AS u0", "c1 AS u1"}) + .planNode(), + "", + {"t0", "t1", "match"}, + core::JoinType::kLeftSemiProject) + .planNode(); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .config(core::QueryConfig::kAbandonDedupHashMapMinRows, "10") + .config(core::QueryConfig::kAbandonDedupHashMapMinPct, "50") + .numDrivers(1) + .checkSpillStats(false) + .planNode(plan) + .referenceQuery( + "SELECT t.c0, t.c1, EXISTS (SELECT * FROM u WHERE t.c0 = u.c0) FROM t") + .run(); +} + +DEBUG_ONLY_TEST_P( HashJoinTest, failedToReclaimFromHashJoinBuildersInNonReclaimableSection) { auto rowType = ROW({ @@ -6580,7 +6808,7 @@ DEBUG_ONLY_TEST_F( 2); } -DEBUG_ONLY_TEST_F(HashJoinTest, reclaimDuringTableBuild) { +DEBUG_ONLY_TEST_P(HashJoinTest, reclaimDuringTableBuild) { VectorFuzzer fuzzer({.vectorSize = 1000}, pool()); const int32_t numBuildVectors = 5; std::vector buildVectors; @@ -6640,7 +6868,7 @@ DEBUG_ONLY_TEST_F(HashJoinTest, reclaimDuringTableBuild) { .run(); } -DEBUG_ONLY_TEST_F(HashJoinTest, exceptionDuringFinishJoinBuild) { +DEBUG_ONLY_TEST_P(HashJoinTest, exceptionDuringFinishJoinBuild) { // This test is to make sure there is no memory leak when exceptions are // thrown while parallelly preparing join table. auto memoryManager = memory::memoryManager(); @@ -6724,7 +6952,7 @@ DEBUG_ONLY_TEST_F(HashJoinTest, exceptionDuringFinishJoinBuild) { ASSERT_EQ(arbitrator->stats().freeCapacityBytes, expectedFreeCapacityBytes); } -DEBUG_ONLY_TEST_F(HashJoinTest, arbitrationTriggeredDuringParallelJoinBuild) { +DEBUG_ONLY_TEST_P(HashJoinTest, arbitrationTriggeredDuringParallelJoinBuild) { std::unique_ptr memoryManager = createMemoryManager(); const uint64_t numDrivers = 2; @@ -6839,7 +7067,7 @@ DEBUG_ONLY_TEST_F(HashJoinTest, arbitrationTriggeredDuringParallelJoinBuild) { waitForAllTasksToBeDeleted(); } -DEBUG_ONLY_TEST_F(HashJoinTest, arbitrationTriggeredByEnsureJoinTableFit) { +DEBUG_ONLY_TEST_P(HashJoinTest, arbitrationTriggeredByEnsureJoinTableFit) { // Use manual spill injection other than spill injection framework. This is // because spill injection framework does not allow fine grain spill within a // single operator (We do not want to spill during addInput() but only during @@ -6853,6 +7081,7 @@ DEBUG_ONLY_TEST_F(HashJoinTest, arbitrationTriggeredByEnsureJoinTableFit) { auto tempDirectory = exec::test::TempDirectoryPath::create(); HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) .numDrivers(numDrivers_) + .parallelizeJoinBuildRows(parallelBuildSideRowsEnabled_) .injectSpill(false) .spillDirectory(tempDirectory->getPath()) .keyTypes({BIGINT()}) @@ -6867,8 +7096,8 @@ DEBUG_ONLY_TEST_F(HashJoinTest, arbitrationTriggeredByEnsureJoinTableFit) { .run(); } -DEBUG_ONLY_TEST_F(HashJoinTest, joinBuildSpillError) { - const int kMemoryCapacity = 32 << 20; +DEBUG_ONLY_TEST_P(HashJoinTest, joinBuildSpillError) { + const int kMemoryCapacity = 27 << 20; // Set a small memory capacity to trigger spill. std::unique_ptr memoryManager = createMemoryManager(kMemoryCapacity, 0); @@ -6930,7 +7159,7 @@ DEBUG_ONLY_TEST_F(HashJoinTest, joinBuildSpillError) { waitForAllTasksToBeDeleted(); } -DEBUG_ONLY_TEST_F(HashJoinTest, probeSpillOnWaitForPeers) { +DEBUG_ONLY_TEST_P(HashJoinTest, probeSpillOnWaitForPeers) { // This test creates a scenario when tester probe thread finishes processing // input, entering kWaitForPeers state, and the other thread is still // processing, spill is triggered properly performed. @@ -7037,7 +7266,7 @@ DEBUG_ONLY_TEST_F(HashJoinTest, probeSpillOnWaitForPeers) { waitForAllTasksToBeDeleted(); } -DEBUG_ONLY_TEST_F(HashJoinTest, taskWaitTimeout) { +DEBUG_ONLY_TEST_P(HashJoinTest, taskWaitTimeout) { const int queryMemoryCapacity = 128 << 20; // Creates a large number of vectors based on the query capacity to trigger // memory arbitration. @@ -7127,7 +7356,7 @@ DEBUG_ONLY_TEST_F(HashJoinTest, taskWaitTimeout) { } } -DEBUG_ONLY_TEST_F(HashJoinTest, hashProbeSpill) { +DEBUG_ONLY_TEST_P(HashJoinTest, hashProbeSpill) { struct { bool triggerBuildSpill; // Triggers after no more input or not. @@ -7217,7 +7446,9 @@ DEBUG_ONLY_TEST_F(HashJoinTest, hashProbeSpill) { .injectSpill(false) .verifier([&](const std::shared_ptr& task, bool /*unused*/) { auto opStats = toOperatorStats(task->taskStats()); - ASSERT_GT(opStats.at("HashProbe").spilledBytes, 0); + if (!parallelBuildSideRowsEnabled_) { + ASSERT_GT(opStats.at("HashProbe").spilledBytes, 0); + } if (testData.triggerBuildSpill) { ASSERT_GT(opStats.at("HashBuild").spilledBytes, 0); } else { @@ -7232,7 +7463,7 @@ DEBUG_ONLY_TEST_F(HashJoinTest, hashProbeSpill) { } } -DEBUG_ONLY_TEST_F(HashJoinTest, hashProbeSpillInMiddeOfLastOutputProcessing) { +DEBUG_ONLY_TEST_P(HashJoinTest, hashProbeSpillInMiddleOfLastOutputProcessing) { std::atomic_int outputCountAfterNoMoreInout{0}; std::atomic_bool injectOnce{true}; ::facebook::velox::common::testutil::ScopedTestValue abc( @@ -7285,7 +7516,7 @@ DEBUG_ONLY_TEST_F(HashJoinTest, hashProbeSpillInMiddeOfLastOutputProcessing) { // Inject probe-side spilling in the middle of output processing. If // 'recursiveSpill' is true, we trigger probe-spilling when probe the hash table // built from spilled data. -DEBUG_ONLY_TEST_F(HashJoinTest, hashProbeSpillInMiddeOfOutputProcessing) { +DEBUG_ONLY_TEST_P(HashJoinTest, hashProbeSpillInMiddleOfOutputProcessing) { for (bool recursiveSpill : {false, true}) { std::atomic_int buildInputCount{0}; SCOPED_TESTVALUE_SET( @@ -7355,7 +7586,7 @@ DEBUG_ONLY_TEST_F(HashJoinTest, hashProbeSpillInMiddeOfOutputProcessing) { } } -DEBUG_ONLY_TEST_F(HashJoinTest, hashProbeSpillWhenOneOfProbeFinish) { +DEBUG_ONLY_TEST_P(HashJoinTest, hashProbeSpillWhenOneOfProbeFinish) { const int numDrivers{3}; std::atomic_bool probeWaitFlag{true}; @@ -7412,7 +7643,7 @@ DEBUG_ONLY_TEST_F(HashJoinTest, hashProbeSpillWhenOneOfProbeFinish) { queryThread.join(); } -DEBUG_ONLY_TEST_F(HashJoinTest, hashProbeSpillExceedLimit) { +DEBUG_ONLY_TEST_P(HashJoinTest, hashProbeSpillExceedLimit) { // If 'buildTriggerSpill' is true, then spilling is triggered by hash build. for (const bool buildTriggerSpill : {false, true}) { SCOPED_TRACE(fmt::format("buildTriggerSpill {}", buildTriggerSpill)); @@ -7484,7 +7715,7 @@ DEBUG_ONLY_TEST_F(HashJoinTest, hashProbeSpillExceedLimit) { } } -DEBUG_ONLY_TEST_F(HashJoinTest, hashProbeSpillUnderNonReclaimableSection) { +DEBUG_ONLY_TEST_P(HashJoinTest, hashProbeSpillUnderNonReclaimableSection) { std::atomic_bool injectOnce{true}; SCOPED_TESTVALUE_SET( "facebook::velox::common::memory::MemoryPoolImpl::allocateNonContiguous", @@ -7528,7 +7759,7 @@ DEBUG_ONLY_TEST_F(HashJoinTest, hashProbeSpillUnderNonReclaimableSection) { // This test case is to cover the case that hash probe trigger spill for right // semi join types and the pending input needs to be processed in multiple // steps. -DEBUG_ONLY_TEST_F(HashJoinTest, spillOutputWithRightSemiJoins) { +DEBUG_ONLY_TEST_P(HashJoinTest, spillOutputWithRightSemiJoins) { for (const auto joinType : {core::JoinType::kRightSemiFilter, core::JoinType::kRightSemiProject}) { std::atomic_bool injectOnce{true}; @@ -7584,7 +7815,7 @@ DEBUG_ONLY_TEST_F(HashJoinTest, spillOutputWithRightSemiJoins) { } } -DEBUG_ONLY_TEST_F(HashJoinTest, spillCheckOnLeftSemiFilterWithDynamicFilters) { +DEBUG_ONLY_TEST_P(HashJoinTest, spillCheckOnLeftSemiFilterWithDynamicFilters) { const int32_t numSplits = 10; const int32_t numRowsProbe = 333; const int32_t numRowsBuild = 100; @@ -7700,7 +7931,7 @@ DEBUG_ONLY_TEST_F(HashJoinTest, spillCheckOnLeftSemiFilterWithDynamicFilters) { // start processing. This can cause unnecessary spill and query OOM under some // real workload with many stages as each hash probe might reserve non-trivial // amount of memory. -DEBUG_ONLY_TEST_F( +DEBUG_ONLY_TEST_P( HashJoinTest, hashProbeMemoryReservationCheckBeforeProbeStartWithSpillEnabled) { fuzzerOpts_.vectorSize = 128; @@ -7746,7 +7977,7 @@ DEBUG_ONLY_TEST_F( .run(); } -TEST_F(HashJoinTest, nanKeys) { +TEST_P(HashJoinTest, nanKeys) { // Verify the NaN values with different binary representations are considered // equal. static const double kNan = std::numeric_limits::quiet_NaN(); @@ -7780,7 +8011,7 @@ TEST_F(HashJoinTest, nanKeys) { facebook::velox::test::assertEqualVectors(expected, result); } -DEBUG_ONLY_TEST_F(HashJoinTest, spillOnBlockedProbe) { +DEBUG_ONLY_TEST_P(HashJoinTest, spillOnBlockedProbe) { auto blockedOperatorFactoryUniquePtr = std::make_unique(); auto blockedOperatorFactory = blockedOperatorFactoryUniquePtr.get(); @@ -7854,9 +8085,10 @@ DEBUG_ONLY_TEST_F(HashJoinTest, spillOnBlockedProbe) { } arbitrationThread.join(); waitForAllTasksToBeDeleted(30'000'000); + Operator::unregisterAllOperators(); } -DEBUG_ONLY_TEST_F(HashJoinTest, buildReclaimedMemoryReport) { +DEBUG_ONLY_TEST_P(HashJoinTest, buildReclaimedMemoryReport) { constexpr int64_t kMaxBytes = 1LL << 30; // 1GB const int32_t numBuildVectors = 3; std::vector buildVectors; @@ -7983,7 +8215,7 @@ DEBUG_ONLY_TEST_F(HashJoinTest, buildReclaimedMemoryReport) { taskThread.join(); } -DEBUG_ONLY_TEST_F(HashJoinTest, probeReclaimedMemoryReport) { +DEBUG_ONLY_TEST_P(HashJoinTest, probeReclaimedMemoryReport) { constexpr int64_t kMaxBytes = 1LL << 30; // 1GB const int32_t numBuildVectors = 3; std::vector buildVectors; @@ -8088,7 +8320,7 @@ DEBUG_ONLY_TEST_F(HashJoinTest, probeReclaimedMemoryReport) { taskThread.join(); } -DEBUG_ONLY_TEST_F(HashJoinTest, hashTableCleanupAfterProbeFinish) { +DEBUG_ONLY_TEST_P(HashJoinTest, hashTableCleanupAfterProbeFinish) { auto buildVectors = makeVectors(buildType_, 5, 100); auto probeVectors = makeVectors(probeType_, 5, 100); @@ -8139,4 +8371,347 @@ DEBUG_ONLY_TEST_F(HashJoinTest, hashTableCleanupAfterProbeFinish) { .run(); ASSERT_TRUE(tableEmpty); } + +TEST_P(HashJoinTest, innerJoinForTypeWithCustomComparisonAndSmallVector) { + // This test corresponds to the SQL query: + // SELECT + // LEFT_TABLE.ip_addr as ip_left_string, + // RIGHT_TABLE.ip_addr as ip_right_string, + // CAST(LEFT_TABLE.ip_addr AS IPADDRESS) as ip_left_ip_address_type, + // CAST(RIGHT_TABLE.ip_addr AS IPADDRESS) as ip_right_ip_address_type, + // CAST(LEFT_TABLE.ip_addr AS IPADDRESS) = CAST(RIGHT_TABLE.ip_addr AS + // IPADDRESS) as are_equal_as_ip_address_type + // FROM + // (VALUES ('2620:10d:c0a8:f0::37'), ('2620:10d:c053:33::37')) AS + // LEFT_TABLE(ip_addr) INNER JOIN (VALUES ('2620:10d:c0a8:f0::37')) AS + // RIGHT_TABLE(ip_addr) ON CAST(LEFT_TABLE.ip_addr AS IPADDRESS) = + // CAST(RIGHT_TABLE.ip_addr AS IPADDRESS) + // LIMIT 1000 + + auto leftVectors = makeRowVector({makeFlatVector( + {StringView("2620:10d:c0a8:f0::37"), + StringView("2620:10d:c053:33::37")})}); + + auto rightVectors = makeRowVector( + {makeFlatVector({StringView("2620:10d:c0a8:f0::37")})}); + createDuckDbTable("t", {leftVectors}); + createDuckDbTable("u", {rightVectors}); + + auto planNodeIdGenerator = std::make_shared(); + + auto rightPlan = PlanBuilder(planNodeIdGenerator) + .values({rightVectors}) + .project( + {"c0 AS ip_addr_right", + "CAST(c0 AS IPADDRESS) AS ip_addr_cast_right"}) + .planNode(); + + auto plan = + PlanBuilder(planNodeIdGenerator) + .values({leftVectors}) + .project({"c0 AS ip_addr", "CAST(c0 AS IPADDRESS) AS ip_addr_cast"}) + .hashJoin( + {"ip_addr_cast"}, + {"ip_addr_cast_right"}, + rightPlan, + "", + {"ip_addr", "ip_addr_cast"}, + core::JoinType::kInner) + .limit(0, 1000, false) + .planNode(); + + auto result = AssertQueryBuilder(plan).copyResults(pool()); + + ASSERT_EQ(result->size(), 1); + auto ipAddr = result->childAt(0)->as>(); + + // We expect 1 row (only the matching IPv6 address: 2620:10d:c0a8:f0::37) + ASSERT_EQ(ipAddr->valueAt(0), StringView("2620:10d:c0a8:f0::37")); + + // Test that different IPADDRESS values correctly don't match in hash join. + leftVectors = makeRowVector({ + makeFlatVector({ + "2620:10d:c053:33::37"_sv, + }), + }); + + rightVectors = makeRowVector({ + makeFlatVector({ + "2620:10d:c0a8:f0::37"_sv, + }), + }); + + planNodeIdGenerator = std::make_shared(); + + rightPlan = PlanBuilder(planNodeIdGenerator) + .values({rightVectors}) + .project( + {"c0 AS ip_addr_right", + "CAST(c0 AS IPADDRESS) AS ip_addr_cast_right"}) + .planNode(); + + plan = PlanBuilder(planNodeIdGenerator) + .values({leftVectors}) + .project( + {"c0 AS ip_left", + "CAST(c0 AS IPADDRESS) AS ip_left_cast", + "CAST(c0 AS VARCHAR) AS ip_left_string"}) + .hashJoin( + {"ip_left_cast"}, + {"ip_addr_cast_right"}, + rightPlan, + "", + {"ip_left_cast", "ip_addr_cast_right", "ip_addr_right"}, + core::JoinType::kInner) + .planNode(); + + // Result should be empty since the IP addresses are different + result = AssertQueryBuilder(plan).copyResults(pool()); + ASSERT_EQ(result->size(), 0) + << "Expected no matches between different IP addresses, but got " + << result->size() << " rows"; +} + +/// Test hash join where build-side keys have a type that supports custom +/// comparison and come from a small range which would allow for array-based +/// lookup instead of a hash table for other types. +TEST_P(HashJoinTest, arrayBasedLookupCustomComparisonType) { + std::vector probeVectors = { + makeRowVector({makeFlatVector( + 1'024, + [](auto row) { return row; }, + nullptr, + velox::test::BIGINT_TYPE_WITH_CUSTOM_COMPARISON())})}; + + std::vector buildVectors = { + makeRowVector({makeFlatVector( + 256, + [](auto row) { return row; }, + nullptr, + velox::test::BIGINT_TYPE_WITH_CUSTOM_COMPARISON())})}; + + auto planNodeIdGenerator = std::make_shared(); + + auto rightPlan = PlanBuilder(planNodeIdGenerator) + .values({buildVectors}) + .project({"c0 as right"}) + .planNode(); + + auto plan = PlanBuilder(planNodeIdGenerator) + .values({probeVectors}) + .project({"c0 as left"}) + .hashJoin( + {"left"}, + {"right"}, + rightPlan, + "", + {"left"}, + core::JoinType::kInner) + .planNode(); + + auto result = AssertQueryBuilder(plan).copyResults(pool()); + + // The probe side consists of the values 0-1023, the build side consists of + // the values 0-255. If custom comparison is not respected, the join will + // produce 256 values (0-255). When custom comparison is respected equality is + // treated mod 256 so we get 1024 values (0-1023). + EXPECT_EQ(result->size(), 1'024); +} + +DEBUG_ONLY_TEST_P( + HashJoinTest, + hashProbeShouldYieldWhenFilterConsistentlyRejectAll) { + const uint32_t kProbeSize = 100; + const uint32_t kBuildSize = 10'000; + const uint64_t kDriverCpuTimeSliceLimitMs = 1'000; + const std::string kLargeBatchSize = + folly::to(kProbeSize * kBuildSize); + + struct { + uint32_t numGetOutputCalls; + bool hasDelay; + std::string debugString() const { + return fmt::format( + "numGetOutputCalls: {}, hasDelay: {}", numGetOutputCalls, hasDelay); + } + } testSettings[] = {{0, false}, {0, true}}; + + // Create probe data with keys 0-99 and an additional filter column + const auto probeData = makeRowVector( + {"t_k1", "t_filter"}, + { + makeFlatVector(kProbeSize, [](auto row) { return row; }), + makeFlatVector( + kProbeSize, + [](/*row=*/auto) { return 1; }), // All rows have value 1 + }); + + const auto buildData = makeRowVector( + {"u_k1"}, + { + makeFlatVector(kBuildSize, [](auto row) { return row; }), + }); + + createDuckDbTable("t", {probeData}); + createDuckDbTable("u", {buildData}); + + auto planNodeIdGenerator = std::make_shared(); + auto planNode = + PlanBuilder(planNodeIdGenerator) + .values({probeData}) + .hashJoin( + {"t_k1"}, + {"u_k1"}, + PlanBuilder(planNodeIdGenerator).values({buildData}).planNode(), + // Filter that DOES find join matches but then rejects all of them + // This ensures numOut > 0 after listJoinResults, but == 0 after + // evalFilter All probe rows have t_filter=1, so the condition + // t_filter > 100000 rejects all + "t_filter > 100000", + {"t_k1", "u_k1"}, + core::JoinType::kInner) + .planNode(); + + for (auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + std::atomic_int hashProbeGetOutputCalls{0}; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::Driver::runInternal::getOutput", + std::function([&](Operator* op) { + if (op->operatorType() == "HashProbe") { + // Inject delay on the 2nd getOutput call when hasDelay is true + // This simulates the scenario where: + // 1. First getOutput: Probe data added via addInput + // 2. Second getOutput: Join finds matches, filter rejects all + // During this call, we inject delay INSIDE the processing + // to simulate CPU-intensive work in the loop + if (hashProbeGetOutputCalls.fetch_add(1) == 1 && + testData.hasDelay) { + std::this_thread::sleep_for( + std::chrono::milliseconds(2 * kDriverCpuTimeSliceLimitMs)); + } + } + })); + + auto queryCtx = core::QueryCtx::create( + executor_.get(), + core::QueryConfig({ + {core::QueryConfig::kDriverCpuTimeSliceLimitMs, + folly::to(kDriverCpuTimeSliceLimitMs)}, + {core::QueryConfig::kPreferredOutputBatchRows, kLargeBatchSize}, + })); + + AssertQueryBuilder(planNode, duckDbQueryRunner_) + .queryCtx(queryCtx) + .maxDrivers(1) + .assertResults( + "SELECT t_k1, u_k1 FROM t, u WHERE t_k1 = u_k1 AND t_filter > 100000"); + testData.numGetOutputCalls = hashProbeGetOutputCalls.load(); + } + ASSERT_LT( + testSettings[0].numGetOutputCalls, testSettings[1].numGetOutputCalls); +} + +// This test validates that when spillOutput() is running (toSpillOutput=true), +// the operator should NOT yield even when shouldYield() returns true. This is +// critical because yielding during spillOutput would break the spilling loop. +DEBUG_ONLY_TEST_P( + HashJoinTest, + spillOutputShouldNotYieldWhenFilterConsistentlyRejectAll) { + const uint32_t kProbeSize = 100; + const uint32_t kBuildSize = 10'000; + const uint64_t driverCpuTimeSliceLimitMs = 1'000; + const std::string largeBatchSize = + folly::to(kProbeSize * kBuildSize); + + // Create probe data with keys 0-99 and an additional filter column + const auto probeData = makeRowVector( + {"t_k1", "t_filter"}, + { + makeFlatVector(kProbeSize, [](auto row) { return row; }), + makeFlatVector( + kProbeSize, + [](/*row=*/auto) { return 1; }), // All rows have value 1 + }); + + const auto buildData = makeRowVector( + {"u_k1"}, + { + makeFlatVector(kBuildSize, [](auto row) { return row; }), + }); + + createDuckDbTable("t", {probeData}); + createDuckDbTable("u", {buildData}); + + auto planNodeIdGenerator = std::make_shared(); + auto planNode = + PlanBuilder(planNodeIdGenerator) + .values({probeData}) + .hashJoin( + {"t_k1"}, + {"u_k1"}, + PlanBuilder(planNodeIdGenerator).values({buildData}).planNode(), + // Filter that DOES find join matches but then rejects all of them + // This ensures numOut > 0 after listJoinResults, but == 0 after + // evalFilter. All probe rows have t_filter=1, so the condition + // t_filter > 100000 rejects all + "t_filter > 100000", + {"t_k1", "u_k1"}, + core::JoinType::kInner) + .planNode(); + + std::atomic_bool spillTriggered{false}; + ::facebook ::velox ::common ::testutil ::ScopedTestValue _scopedTestValue5200( + "facebook::velox::exec::Driver::runInternal::getOutput", + std::function([&](Operator* op) { + if (spillTriggered.load() || op->operatorType() != "HashProbe" || + !op->testingHasInput()) { + return; + } + spillTriggered = true; + testingRunArbitration(op->pool()); + })); + + // We inject delay in reclaim to trigger shouldYield(). + // The test verifies that the query completes successfully despite + // shouldYield() returning true, which would only happen if the + // !toSpillOutput check prevents early return. + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::HashProbe::reclaim", + std::function([&](HashProbe* probe) { + if (!spillTriggered.load()) { + return; + } + // Inject delay once to trigger shouldYield() + std::this_thread::sleep_for( + std::chrono::milliseconds(2 * driverCpuTimeSliceLimitMs)); + })); + + const auto spillDirectory = exec::test::TempDirectoryPath::create(); + AssertQueryBuilder(planNode, duckDbQueryRunner_) + .queryCtx(core::QueryCtx::create(driverExecutor_.get())) + .maxDrivers(1) + .spillDirectory(spillDirectory->getPath()) + .config(core::QueryConfig::kSpillEnabled, true) + .config(core::QueryConfig::kJoinSpillEnabled, true) + .config(core::QueryConfig::kSpillStartPartitionBit, 29) + .config( + core::QueryConfig::kDriverCpuTimeSliceLimitMs, + driverCpuTimeSliceLimitMs) + .config(core::QueryConfig::kPreferredOutputBatchRows, largeBatchSize) + .assertResults( + "SELECT t_k1, u_k1 FROM t, u WHERE t_k1 = u_k1 AND t_filter > 100000"); + + ASSERT_TRUE(spillTriggered.load()); +} + +VELOX_INSTANTIATE_TEST_SUITE_P( + HashJoinTest, + HashJoinTest, + testing::ValuesIn(HashJoinTest::getTestParams()), + [](const testing::TestParamInfo& info) { + return TestParamToName(info.param); + }); + } // namespace +} // namespace facebook::velox::exec diff --git a/velox/exec/tests/HashJoinWithCacheTest.cpp b/velox/exec/tests/HashJoinWithCacheTest.cpp new file mode 100644 index 000000000000..772c282a7804 --- /dev/null +++ b/velox/exec/tests/HashJoinWithCacheTest.cpp @@ -0,0 +1,574 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/common/memory/MemoryArbitrator.h" +#include "velox/common/testutil/TestValue.h" +#include "velox/exec/Cursor.h" +#include "velox/exec/HashJoinBridge.h" +#include "velox/exec/HashProbe.h" +#include "velox/exec/HashTable.h" +#include "velox/exec/HashTableCache.h" +#include "velox/exec/MemoryReclaimer.h" +#include "velox/exec/PlanNodeStats.h" + +#include "velox/exec/tests/utils/ArbitratorTestUtil.h" +#include "velox/exec/tests/utils/AssertQueryBuilder.h" +#include "velox/exec/tests/utils/HiveConnectorTestBase.h" +#include "velox/exec/tests/utils/PlanBuilder.h" +#include "velox/exec/tests/utils/QueryAssertions.h" + +#include "velox/exec/tests/utils/VectorTestUtil.h" + +namespace facebook::velox::exec::test { +namespace { + +// Test fixture for hash join with hash table caching tests. +class HashJoinWithCacheTest : public HiveConnectorTestBase {}; + +// Tests hash table caching for broadcast joins. +// First task builds the table (cache miss), second task reuses it (cache hit). +TEST_F(HashJoinWithCacheTest, sequential) { + // Use a unique query ID for this test to ensure clean cache state. + const std::string queryId = + "hashTableCachingTest_" + + std::to_string( + std::chrono::steady_clock::now().time_since_epoch().count()); + + // Create probe and build vectors with distinct column names. + std::vector probeVectors = makeBatches(10, [&](int32_t) { + return makeRowVector( + {"t_k", "t_v"}, + { + makeFlatVector(100, [](auto row) { return row % 23; }), + makeFlatVector(100, [](auto row) { return row; }), + }); + }); + + std::vector buildVectors = makeBatches(5, [&](int32_t) { + return makeRowVector( + {"u_k", "u_v"}, + { + makeFlatVector(50, [](auto row) { return row % 31; }), + makeFlatVector(50, [](auto row) { return row * 10; }), + }); + }); + + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", buildVectors); + + auto planNodeIdGenerator = std::make_shared(); + + // Build the plan using HashJoinNode::Builder to set useHashTableCache. + auto buildPlanNode = + PlanBuilder(planNodeIdGenerator).values(buildVectors).planNode(); + + auto probePlanNode = + PlanBuilder(planNodeIdGenerator).values(probeVectors).planNode(); + + // Create HashJoinNode with useHashTableCache = true. + auto outputType = ROW( + {"t_k", "t_v", "u_k", "u_v"}, {INTEGER(), BIGINT(), INTEGER(), BIGINT()}); + + auto joinNode = + core::HashJoinNode::Builder() + .id(planNodeIdGenerator->next()) + .joinType(core::JoinType::kInner) + .nullAware(false) + .leftKeys( + {std::make_shared(INTEGER(), "t_k")}) + .rightKeys( + {std::make_shared(INTEGER(), "u_k")}) + .left(probePlanNode) + .right(buildPlanNode) + .outputType(outputType) + .useHashTableCache(true) + .build(); + const auto joinNodeId = joinNode->id(); + + const int numDrivers = 3; + + // Create a shared QueryCtx for all tasks. + auto queryCtx = core::QueryCtx::create( + driverExecutor_.get(), + core::QueryConfig({}), + std::unordered_map>{}, + cache::AsyncDataCache::getInstance(), + nullptr, + nullptr, + queryId); + + // Helper to run a task and return the completed task. + // Both tasks use the same queryCtx so they share the cache entry. + auto runTask = [&]() { + return AssertQueryBuilder(joinNode, duckDbQueryRunner_) + .queryCtx(queryCtx) + .maxDrivers(numDrivers) + .assertResults( + "SELECT t.t_k, t.t_v, u.u_k, u.u_v FROM t, u WHERE t.t_k = u.u_k"); + }; + + // First task - should build the table (cache miss). + auto task1 = runTask(); + + // Get stats from first task - expect cache miss. + auto opStats1 = toOperatorStats(task1->taskStats()); + ASSERT_EQ(opStats1.count("HashBuild"), 1); + auto& hashBuildStats1 = opStats1.at("HashBuild"); + + // The last driver that finishes building reports the cache miss stat. + ASSERT_EQ( + hashBuildStats1.runtimeStats.count(BaseHashTable::kHashTableCacheMiss), 1) + << "First task should report cache miss"; + EXPECT_EQ( + hashBuildStats1.runtimeStats.at(BaseHashTable::kHashTableCacheMiss).count, + 1) + << "Exactly one driver should report cache miss (the one that builds)"; + EXPECT_EQ( + hashBuildStats1.runtimeStats.count(BaseHashTable::kHashTableCacheHit), 0) + << "First task should not have any cache hits"; + + // Second task - should reuse the cached table (cache hit). + auto task2 = runTask(); + + // Get stats from second task - expect cache hit. + auto opStats2 = toOperatorStats(task2->taskStats()); + ASSERT_EQ(opStats2.count("HashBuild"), 1); + auto& hashBuildStats2 = opStats2.at("HashBuild"); + + // The last driver that finishes reports the cache hit stat. + ASSERT_EQ( + hashBuildStats2.runtimeStats.count(BaseHashTable::kHashTableCacheHit), 1) + << "Second task should report cache hit"; + EXPECT_EQ( + hashBuildStats2.runtimeStats.at(BaseHashTable::kHashTableCacheHit).count, + 1) + << "Exactly one driver should report cache hit (the one after barrier)"; + EXPECT_EQ( + hashBuildStats2.runtimeStats.count(BaseHashTable::kHashTableCacheMiss), 0) + << "Second task should not have any cache misses"; + + // Clean up cache entry before tasks are destroyed. + // The release callback on QueryCtx fires too late (after Task destruction + // starts), so we need explicit cleanup in tests. + const auto cacheKey = fmt::format("{}:{}", queryId, joinNodeId); + HashTableCache::instance()->drop(cacheKey); +} + +// Tests that multiple tasks running concurrently share the cached hash table. +// One task builds the table (cache miss), all others wait and reuse it +// (cache hits). +TEST_F(HashJoinWithCacheTest, concurrent) { + // Use a unique query ID for this test to ensure clean cache state. + const std::string queryId = + "hashTableCachingConcurrentTest_" + + std::to_string( + std::chrono::steady_clock::now().time_since_epoch().count()); + + // Create probe and build vectors with distinct column names. + std::vector probeVectors = makeBatches(10, [&](int32_t) { + return makeRowVector( + {"t_k", "t_v"}, + { + makeFlatVector(100, [](auto row) { return row % 23; }), + makeFlatVector(100, [](auto row) { return row; }), + }); + }); + + std::vector buildVectors = makeBatches(5, [&](int32_t) { + return makeRowVector( + {"u_k", "u_v"}, + { + makeFlatVector(50, [](auto row) { return row % 31; }), + makeFlatVector(50, [](auto row) { return row * 10; }), + }); + }); + + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", buildVectors); + + auto planNodeIdGenerator = std::make_shared(); + + // Build the plan using HashJoinNode::Builder to set useHashTableCache. + auto buildPlanNode = + PlanBuilder(planNodeIdGenerator).values(buildVectors).planNode(); + + auto probePlanNode = + PlanBuilder(planNodeIdGenerator).values(probeVectors).planNode(); + + // Create HashJoinNode with useHashTableCache = true. + auto outputType = ROW( + {"t_k", "t_v", "u_k", "u_v"}, {INTEGER(), BIGINT(), INTEGER(), BIGINT()}); + + auto joinNode = + core::HashJoinNode::Builder() + .id(planNodeIdGenerator->next()) + .joinType(core::JoinType::kInner) + .nullAware(false) + .leftKeys( + {std::make_shared(INTEGER(), "t_k")}) + .rightKeys( + {std::make_shared(INTEGER(), "u_k")}) + .left(probePlanNode) + .right(buildPlanNode) + .outputType(outputType) + .useHashTableCache(true) + .build(); + const auto joinNodeId = joinNode->id(); + + const int numDrivers = 3; + + // Create a shared QueryCtx for all tasks. + auto queryCtx = core::QueryCtx::create( + driverExecutor_.get(), + core::QueryConfig({}), + std::unordered_map>{}, + cache::AsyncDataCache::getInstance(), + nullptr, + nullptr, + queryId); + + // Helper to run a task and return the completed task. + // All tasks use the same queryCtx so they share the cache entry. + auto runTask = [&]() { + return AssertQueryBuilder(joinNode, duckDbQueryRunner_) + .queryCtx(queryCtx) + .maxDrivers(numDrivers) + .assertResults( + "SELECT t.t_k, t.t_v, u.u_k, u.u_v FROM t, u WHERE t.t_k = u.u_k"); + }; + + // Run 10 threads concurrently, each executing 5 tasks sequentially. + constexpr int kNumThreads = 10; + constexpr int kTasksPerThread = 5; + constexpr int kTotalTasks = kNumThreads * kTasksPerThread; + + // Each thread maintains its own local vector of tasks. + std::vector>> threadTasks(kNumThreads); + std::vector threads; + threads.reserve(kNumThreads); + + // Launch all threads at once. + for (int i = 0; i < kNumThreads; ++i) { + threads.emplace_back([&, i]() { + // Each thread runs multiple tasks sequentially into its local vector. + threadTasks[i].reserve(kTasksPerThread); + for (int j = 0; j < kTasksPerThread; ++j) { + threadTasks[i].push_back(runTask()); + } + }); + } + + // Wait for all threads to complete. + for (auto& thread : threads) { + thread.join(); + } + + // Merge all thread-local task vectors into a single vector. + std::vector> allTasks; + allTasks.reserve(kTotalTasks); + for (auto& tasks : threadTasks) { + for (auto& task : tasks) { + allTasks.push_back(std::move(task)); + } + } + + ASSERT_EQ(allTasks.size(), kTotalTasks); + + // Collect stats from all tasks. + int totalCacheMisses = 0; + int totalCacheHits = 0; + + for (const auto& task : allTasks) { + auto opStats = toOperatorStats(task->taskStats()); + ASSERT_EQ(opStats.count("HashBuild"), 1); + auto& hashBuildStats = opStats.at("HashBuild"); + + if (hashBuildStats.runtimeStats.count(BaseHashTable::kHashTableCacheMiss)) { + totalCacheMisses += + hashBuildStats.runtimeStats.at(BaseHashTable::kHashTableCacheMiss) + .count; + } + if (hashBuildStats.runtimeStats.count(BaseHashTable::kHashTableCacheHit)) { + totalCacheHits += + hashBuildStats.runtimeStats.at(BaseHashTable::kHashTableCacheHit) + .count; + } + } + + // Exactly one task should build (cache miss) and all others should reuse + // (cache hits). + EXPECT_EQ(totalCacheMisses, 1) + << "Exactly one task should report a cache miss (the builder)"; + EXPECT_EQ(totalCacheHits, kTotalTasks - 1) + << "All other tasks should report cache hits"; + + // Clean up cache entry before tasks are destroyed. + const auto cacheKey = fmt::format("{}:{}", queryId, joinNodeId); + HashTableCache::instance()->drop(cacheKey); +} + +// Tests that HashBuild and HashProbe cannot reclaim when using a cached hash +// table. When useHashTableCache() is true, canReclaim() returns false for both +// operators because spilling would clear the cached table and corrupt it for +// other tasks. This test uses TestValue to verify canReclaim() returns false. +DEBUG_ONLY_TEST_F(HashJoinWithCacheTest, probeCannotSpillWithCachedTable) { + const std::string queryId = + "probeCannotSpillTest_" + + std::to_string( + std::chrono::steady_clock::now().time_since_epoch().count()); + + // Create build and probe vectors. + std::vector buildVectors = makeBatches(1, [&](int32_t) { + return makeRowVector( + {"u_k", "u_v"}, + { + makeFlatVector(100, [](auto row) { return row % 23; }), + makeFlatVector(100, [](auto row) { return row * 10; }), + }); + }); + + std::vector probeVectors = makeBatches(5, [&](int32_t) { + return makeRowVector( + {"t_k", "t_v"}, + { + makeFlatVector(100, [](auto row) { return row % 23; }), + makeFlatVector(100, [](auto row) { return row; }), + }); + }); + + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", buildVectors); + + auto planNodeIdGenerator = std::make_shared(); + auto buildPlanNode = + PlanBuilder(planNodeIdGenerator).values(buildVectors).planNode(); + auto probePlanNode = + PlanBuilder(planNodeIdGenerator).values(probeVectors).planNode(); + + auto outputType = ROW( + {"t_k", "t_v", "u_k", "u_v"}, {INTEGER(), BIGINT(), INTEGER(), BIGINT()}); + + auto joinNode = + core::HashJoinNode::Builder() + .id(planNodeIdGenerator->next()) + .joinType(core::JoinType::kInner) + .nullAware(false) + .leftKeys( + {std::make_shared(INTEGER(), "t_k")}) + .rightKeys( + {std::make_shared(INTEGER(), "u_k")}) + .left(probePlanNode) + .right(buildPlanNode) + .outputType(outputType) + .useHashTableCache(true) + .build(); + const auto joinNodeId = joinNode->id(); + + auto queryCtx = core::QueryCtx::create( + driverExecutor_.get(), + core::QueryConfig({}), + std::unordered_map>{}, + cache::AsyncDataCache::getInstance(), + nullptr, + nullptr, + queryId); + + // Use TestValue to verify canReclaim() returns false for HashBuild. + std::atomic_bool hashBuildChecked{false}; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::HashBuild::addInput", + std::function([&](HashBuild* build) { + if (hashBuildChecked.exchange(true)) { + return; + } + ASSERT_FALSE(build->canReclaim()) + << "HashBuild should not be reclaimable with cached hash table"; + })); + + // Use TestValue to verify canReclaim() returns false for HashProbe. + std::atomic_bool hashProbeChecked{false}; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::Driver::runInternal::addInput", + std::function([&](Operator* op) { + if (!isHashProbeMemoryPool(*op->pool())) { + return; + } + if (hashProbeChecked.exchange(true)) { + return; + } + auto* probe = dynamic_cast(op); + ASSERT_NE(probe, nullptr); + ASSERT_FALSE(probe->canReclaim()) + << "HashProbe should not be reclaimable with cached hash table"; + })); + + // Run the query and verify results. + auto task = + AssertQueryBuilder(joinNode, duckDbQueryRunner_) + .queryCtx(queryCtx) + .maxDrivers(1) + .assertResults( + "SELECT t.t_k, t.t_v, u.u_k, u.u_v FROM t, u WHERE t.t_k = u.u_k"); + + // Verify that both operators were checked. + ASSERT_TRUE(hashBuildChecked) << "HashBuild canReclaim check was not reached"; + ASSERT_TRUE(hashProbeChecked) << "HashProbe canReclaim check was not reached"; + + // Verify that HashProbe operator stats show no spilling. + auto opStats = toOperatorStats(task->taskStats()); + ASSERT_EQ(opStats.count("HashProbe"), 1); + auto& probeStats = opStats.at("HashProbe"); + + EXPECT_EQ(probeStats.spilledInputBytes, 0) + << "HashProbe should not spill when using cached hash table"; + EXPECT_EQ(probeStats.spilledBytes, 0) + << "HashProbe should not spill when using cached hash table"; + + // Clean up cache entry before task is destroyed. + const auto cacheKey = fmt::format("{}:{}", queryId, joinNodeId); + HashTableCache::instance()->drop(cacheKey); +} + +// Tests OOM behavior with cached hash tables via memory arbitration. +// This test triggers memory arbitration during HashProbe to verify: +// 1. HashProbe::canReclaim() returns false when useHashTableCache=true +// 2. When an allocation exceeds capacity, arbitration runs but can't reclaim +// 3. OOM is thrown by the arbitration framework +// 4. Cleanup works correctly via QueryCtx release callbacks +DEBUG_ONLY_TEST_F(HashJoinWithCacheTest, probeOOMWithCachedTable) { + const std::string queryId = + "probeOOMTest_" + + std::to_string( + std::chrono::steady_clock::now().time_since_epoch().count()); + + // Create build side with ~1MB hash table. + // 10 batches × 10,000 rows = 100,000 rows × 12 bytes = ~1.2MB raw data. + std::vector buildVectors = makeBatches(10, [&](int32_t) { + return makeRowVector( + {"u_k", "u_v"}, + { + makeFlatVector(10000, [](auto row) { return row % 5000; }), + makeFlatVector(10000, [](auto row) { return row * 10; }), + }); + }); + + // Create probe vectors with matching key range. + std::vector probeVectors = makeBatches(10, [&](int32_t) { + return makeRowVector( + {"t_k", "t_v"}, + { + makeFlatVector(1000, [](auto row) { return row % 5000; }), + makeFlatVector(1000, [](auto row) { return row; }), + }); + }); + + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", buildVectors); + + auto planNodeIdGenerator = std::make_shared(); + auto buildPlanNode = + PlanBuilder(planNodeIdGenerator).values(buildVectors).planNode(); + auto probePlanNode = + PlanBuilder(planNodeIdGenerator).values(probeVectors).planNode(); + + auto outputType = ROW( + {"t_k", "t_v", "u_k", "u_v"}, {INTEGER(), BIGINT(), INTEGER(), BIGINT()}); + + auto joinNode = + core::HashJoinNode::Builder() + .id(planNodeIdGenerator->next()) + .joinType(core::JoinType::kInner) + .nullAware(false) + .leftKeys( + {std::make_shared(INTEGER(), "t_k")}) + .rightKeys( + {std::make_shared(INTEGER(), "u_k")}) + .left(probePlanNode) + .right(buildPlanNode) + .outputType(outputType) + .useHashTableCache(true) + .build(); + + // Create QueryCtx with sufficient memory for build but we'll exhaust it + // during probe via TestValue injection. + auto queryCtx = core::QueryCtx::create( + driverExecutor_.get(), + core::QueryConfig({}), + std::unordered_map>{}, + cache::AsyncDataCache::getInstance(), + nullptr, + nullptr, + queryId); + + // Use a pool with limited capacity. Build uses ~10MB, so 20MB should be + // enough for build but tight for additional allocations during probe. + constexpr int64_t kPoolCapacity = 20 * 1024 * 1024; // 20MB + queryCtx->testingOverrideMemoryPool( + memory::memoryManager()->addRootPool( + queryCtx->queryId(), kPoolCapacity, exec::MemoryReclaimer::create())); + + // Use TestValue at the addInput injection point to trigger OOM during + // HashProbe. We allocate more memory than the pool has available, which + // triggers arbitration. Since HashProbe::canReclaim() returns false when + // useHashTableCache=true, arbitration cannot reclaim and OOM is thrown. + std::atomic_bool injected{false}; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::Driver::runInternal::addInput", + std::function([&](Operator* op) { + // Only inject once, and only for HashProbe operator. + if (!isHashProbeMemoryPool(*op->pool())) { + return; + } + if (injected.exchange(true)) { + return; + } + + auto* probe = dynamic_cast(op); + ASSERT_NE(probe, nullptr); + + // Verify that HashProbe cannot reclaim when using cached hash table. + // canReclaim() returns false because canSpill() is false. + ASSERT_FALSE(probe->canReclaim()) + << "HashProbe should not be reclaimable with cached hash table"; + + // Allocate memory equal to pool capacity. + // If HashProbe could spill, arbitration would reclaim memory and this + // allocation would succeed. But since canReclaim() returns false with + // cached hash table, arbitration can't free memory and OOM is thrown. + auto* pool = op->pool(); + // This allocation will trigger arbitration and throw OOM. + pool->allocate(kPoolCapacity); + })); + + // This should throw OOM during probe. The cleanup should work correctly. + VELOX_ASSERT_THROW( + AssertQueryBuilder(joinNode, duckDbQueryRunner_) + .queryCtx(queryCtx) + .maxDrivers(1) + .copyResults(pool()), + "Exceeded memory pool capacity"); + + waitForAllTasksToBeDeleted(); + queryCtx.reset(); + // Cache should be cleaned up by QueryCtx destructor via release callback. +} + +} // namespace +} // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/HashTableCacheTest.cpp b/velox/exec/tests/HashTableCacheTest.cpp new file mode 100644 index 000000000000..38be0378ba42 --- /dev/null +++ b/velox/exec/tests/HashTableCacheTest.cpp @@ -0,0 +1,225 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/HashTableCache.h" + +#include + +#include "velox/common/caching/AsyncDataCache.h" +#include "velox/common/memory/Memory.h" +#include "velox/core/QueryCtx.h" + +namespace facebook::velox::exec::test { + +class HashTableCacheTest : public testing::Test { + protected: + static void SetUpTestCase() { + memory::MemoryManager::testingSetInstance({}); + } + + void SetUp() override { + pool_ = memory::memoryManager()->addRootPool("HashTableCacheTest"); + queryCtx_ = core::QueryCtx::create(); + } + + void TearDown() override { + // Clean up any cache entries created during tests. + auto* cache = HashTableCache::instance(); + for (const auto& key : createdKeys_) { + cache->drop(key); + } + createdKeys_.clear(); + queryCtx_.reset(); + } + + // Helper to track keys for cleanup. + void trackKey(const std::string& key) { + createdKeys_.push_back(key); + } + + std::shared_ptr pool_; + std::shared_ptr queryCtx_; + std::vector createdKeys_; +}; + +TEST_F(HashTableCacheTest, basicGet) { + auto* cache = HashTableCache::instance(); + const std::string key = "query1:node1"; + trackKey(key); + + ContinueFuture future = ContinueFuture::makeEmpty(); + auto entry = cache->get(key, "task1", queryCtx_.get(), &future); + + ASSERT_NE(entry, nullptr); + ASSERT_NE(entry->tablePool, nullptr); + EXPECT_EQ(entry->builderTaskId, "task1"); + EXPECT_FALSE(entry->buildComplete); + // First caller should not get a future (they are the builder). + EXPECT_FALSE(future.valid()); +} + +TEST_F(HashTableCacheTest, secondCallerGetsWaitFuture) { + auto* cache = HashTableCache::instance(); + const std::string key = "query2:node1"; + trackKey(key); + + // First caller (builder). + ContinueFuture future1 = ContinueFuture::makeEmpty(); + auto entry1 = cache->get(key, "task1", queryCtx_.get(), &future1); + EXPECT_FALSE(future1.valid()); + EXPECT_EQ(entry1->builderTaskId, "task1"); + + // Second caller (waiter). + ContinueFuture future2 = ContinueFuture::makeEmpty(); + auto entry2 = cache->get(key, "task2", queryCtx_.get(), &future2); + + // Should get the same entry. + EXPECT_EQ(entry1, entry2); + // Should get a valid future to wait on. + EXPECT_TRUE(future2.valid()); + // Builder task ID should not change. + EXPECT_EQ(entry2->builderTaskId, "task1"); +} + +TEST_F(HashTableCacheTest, putNotifiesWaiters) { + auto* cache = HashTableCache::instance(); + const std::string key = "query3:node1"; + trackKey(key); + + // First caller (builder). + ContinueFuture future1 = ContinueFuture::makeEmpty(); + auto entry = cache->get(key, "task1", queryCtx_.get(), &future1); + + // Second caller (waiter). + ContinueFuture future2 = ContinueFuture::makeEmpty(); + cache->get(key, "task2", queryCtx_.get(), &future2); + ASSERT_TRUE(future2.valid()); + + // Put the table. + cache->put(key, nullptr, false); + + // Entry should now be marked complete. + EXPECT_TRUE(entry->buildComplete); + + // The future should be fulfilled. + EXPECT_TRUE(future2.isReady()); +} + +TEST_F(HashTableCacheTest, getAfterBuildComplete) { + auto* cache = HashTableCache::instance(); + const std::string key = "query4:node1"; + trackKey(key); + + // First caller creates and builds. + ContinueFuture future1 = ContinueFuture::makeEmpty(); + cache->get(key, "task1", queryCtx_.get(), &future1); + cache->put(key, nullptr, true); + + // Later caller should get completed entry without waiting. + ContinueFuture future2 = ContinueFuture::makeEmpty(); + auto entry = cache->get(key, "task2", queryCtx_.get(), &future2); + + EXPECT_TRUE(entry->buildComplete); + EXPECT_FALSE(future2.valid()); // No need to wait. + EXPECT_TRUE(entry->hasNullKeys); +} + +TEST_F(HashTableCacheTest, drop) { + auto* cache = HashTableCache::instance(); + const std::string key = "query5:node1"; + // Don't track - we're testing drop. + + ContinueFuture future = ContinueFuture::makeEmpty(); + auto entry1 = cache->get(key, "task1", queryCtx_.get(), &future); + ASSERT_NE(entry1, nullptr); + + // Keep track of the original pool to verify it's different after flush. + auto originalPool = entry1->tablePool; + + // Drop the entry. + cache->drop(key); + + // Getting the same key should create a new entry. + // Use a new queryCtx with different pool to avoid the leaf child name + // collision. + auto pool2 = memory::memoryManager()->addRootPool("HashTableCacheTest2"); + auto queryCtx2 = core::QueryCtx::create( + nullptr, // executor + core::QueryConfig{{}}, // queryConfig + {}, // connectorConfigs + cache::AsyncDataCache::getInstance(), // cache + pool2); + ContinueFuture future2 = ContinueFuture::makeEmpty(); + auto entry2 = cache->get(key, "task2", queryCtx2.get(), &future2); + + EXPECT_NE(entry1, entry2); + EXPECT_EQ(entry2->builderTaskId, "task2"); + // The pool should be different (created under queryCtx2's pool). + EXPECT_NE(entry1->tablePool, entry2->tablePool); + + // Cleanup. + cache->drop(key); +} + +TEST_F(HashTableCacheTest, concurrentWaiters) { + auto* cache = HashTableCache::instance(); + const std::string key = "query6:node1"; + trackKey(key); + + // First caller (builder). + ContinueFuture builderFuture = ContinueFuture::makeEmpty(); + auto entry = cache->get(key, "builder", queryCtx_.get(), &builderFuture); + + // Multiple waiters. + constexpr int kNumWaiters = 5; + std::vector waiterFutures(kNumWaiters); + + for (int i = 0; i < kNumWaiters; ++i) { + waiterFutures[i] = ContinueFuture::makeEmpty(); + cache->get( + key, fmt::format("waiter{}", i), queryCtx_.get(), &waiterFutures[i]); + EXPECT_TRUE(waiterFutures[i].valid()); + } + + // Put the table. + cache->put(key, nullptr, false); + + // All waiters should be notified. + for (int i = 0; i < kNumWaiters; ++i) { + EXPECT_TRUE(waiterFutures[i].isReady()); + } +} + +TEST_F(HashTableCacheTest, builderTaskDriversDoNotWait) { + auto* cache = HashTableCache::instance(); + const std::string key = "query7:node1"; + trackKey(key); + + // First driver of builder task. + ContinueFuture future1 = ContinueFuture::makeEmpty(); + auto entry1 = cache->get(key, "task1", queryCtx_.get(), &future1); + EXPECT_FALSE(future1.valid()); + + // Second driver of the same builder task should not wait. + ContinueFuture future2 = ContinueFuture::makeEmpty(); + auto entry2 = cache->get(key, "task1", queryCtx_.get(), &future2); + + EXPECT_EQ(entry1, entry2); + // Same task should not get a future - they coordinate via JoinBridge. + EXPECT_FALSE(future2.valid()); +} + +} // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/HashTableTest.cpp b/velox/exec/tests/HashTableTest.cpp index bee323d38f8f..390bbf5bda6e 100644 --- a/velox/exec/tests/HashTableTest.cpp +++ b/velox/exec/tests/HashTableTest.cpp @@ -132,8 +132,9 @@ class HashTableTest : public testing::TestWithParam, std::vector batches; std::vector> keyHashers; for (auto channel = 0; channel < numKeys; ++channel) { - keyHashers.emplace_back(std::make_unique( - buildType->childAt(channel), channel)); + keyHashers.emplace_back( + std::make_unique( + buildType->childAt(channel), channel)); } auto table = HashTable::createForJoin( std::move(keyHashers), dependentTypes, true, false, 1'000, pool()); @@ -158,6 +159,8 @@ class HashTableTest : public testing::TestWithParam, topTable_->prepareJoinTable( std::move(otherTables), BaseHashTable::kNoSpillInputStartPartitionBit, + 1'000'000, + false, executor_.get()); ASSERT_GE( estimatedTableSize, @@ -431,8 +434,9 @@ class HashTableTest : public testing::TestWithParam, TypePtr buildType, std::vector& batches) { for (auto i = 0; i < numBatches; ++i) { - batches.push_back(std::static_pointer_cast( - makeVector(buildType, batchSize, sequence))); + batches.push_back( + std::static_pointer_cast( + makeVector(buildType, batchSize, sequence))); sequence += batchSize; } } @@ -542,7 +546,11 @@ class HashTableTest : public testing::TestWithParam, std::move(hashers), {BIGINT()}, true, false, 1'000, pool()); copyVectorsToTable({batch}, 0, table.get()); table->prepareJoinTable( - {}, BaseHashTable::kNoSpillInputStartPartitionBit, executor_.get()); + {}, + BaseHashTable::kNoSpillInputStartPartitionBit, + 1'000'000, + false, + executor_.get()); ASSERT_EQ(table->hashMode(), mode); std::vector rows(nullValues.size()); BaseHashTable::NullKeyRowsIterator iter; @@ -840,7 +848,11 @@ TEST_P(HashTableTest, regularHashingTableSize) { makeRows(1 << 12, 1, 0, type, batches); copyVectorsToTable(batches, 0, table.get()); table->prepareJoinTable( - {}, BaseHashTable::kNoSpillInputStartPartitionBit, executor_.get()); + {}, + BaseHashTable::kNoSpillInputStartPartitionBit, + 1'000'000, + false, + executor_.get()); ASSERT_EQ(table->hashMode(), mode); EXPECT_GE(table->testingRehashSize(), table->numDistinct()); }; @@ -1146,6 +1158,8 @@ DEBUG_ONLY_TEST_P(HashTableTest, failureInCreateRowPartitions) { topTable->prepareJoinTable( std::move(otherTables), BaseHashTable::kNoSpillInputStartPartitionBit, + 1'000'000, + false, executor_.get()); auto topTabletestHelper = HashTableTestHelper::create(topTable.get()); @@ -1234,7 +1248,8 @@ TEST_P(HashTableTest, toStringSingleKey) { store(*table->rows(), data); - table->prepareJoinTable({}, BaseHashTable::kNoSpillInputStartPartitionBit); + table->prepareJoinTable( + {}, BaseHashTable::kNoSpillInputStartPartitionBit, 1'000'000); ASSERT_NO_THROW(table->toString()); ASSERT_NO_THROW(table->toString(0)); @@ -1265,7 +1280,8 @@ TEST_P(HashTableTest, toStringMultipleKeys) { store(*table->rows(), data); - table->prepareJoinTable({}, BaseHashTable::kNoSpillInputStartPartitionBit); + table->prepareJoinTable( + {}, BaseHashTable::kNoSpillInputStartPartitionBit, 1'000'000); ASSERT_NO_THROW(table->toString()); } diff --git a/velox/exec/tests/HilbertIndexTest.cpp b/velox/exec/tests/HilbertIndexTest.cpp new file mode 100644 index 000000000000..85cc634f253f --- /dev/null +++ b/velox/exec/tests/HilbertIndexTest.cpp @@ -0,0 +1,174 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/HilbertIndex.h" +#include +#include + +using namespace ::testing; +using namespace facebook::velox::exec; + +namespace facebook::velox::exec::test { + +class HilbertIndexTest : public virtual testing::Test {}; + +TEST_F(HilbertIndexTest, testOrder) { + HilbertIndex hilbert(0, 0, 4, 4); + + uint32_t h0 = hilbert.indexOf(0.0, 0.0); + uint32_t h1 = hilbert.indexOf(1.0, 1.0); + uint32_t h2 = hilbert.indexOf(1.0, 3.0); + uint32_t h3 = hilbert.indexOf(3.0, 3.0); + uint32_t h4 = hilbert.indexOf(3.0, 1.0); + + ASSERT_LT(h0, h1); + ASSERT_LT(h1, h2); + ASSERT_LT(h2, h3); + ASSERT_LT(h3, h4); +} + +TEST_F(HilbertIndexTest, testOutOfBounds) { + HilbertIndex hilbert(0, 0, 1, 1); + + ASSERT_EQ(hilbert.indexOf(2.0, 2.0), std::numeric_limits::max()); +} + +TEST_F(HilbertIndexTest, testDegenerateRectangle) { + HilbertIndex hilbert(0, 0, 0, 0); + + ASSERT_EQ(hilbert.indexOf(0.0, 0.0), 0); + ASSERT_EQ(hilbert.indexOf(2.0, 2.0), std::numeric_limits::max()); +} + +TEST_F(HilbertIndexTest, testDegenerateHorizontalRectangle) { + HilbertIndex hilbert(0, 0, 4, 0); + + ASSERT_EQ(hilbert.indexOf(0.0, 0.0), 0); + ASSERT_LT(hilbert.indexOf(1.0, 0.0), hilbert.indexOf(2.0, 0.0)); + ASSERT_EQ(hilbert.indexOf(0.0, 2.0), std::numeric_limits::max()); + ASSERT_EQ(hilbert.indexOf(2.0, 2.0), std::numeric_limits::max()); +} + +TEST_F(HilbertIndexTest, testDegenerateVerticalRectangle) { + HilbertIndex hilbert(0, 0, 0, 4); + + ASSERT_EQ(hilbert.indexOf(0.0, 0.0), 0); + ASSERT_LT(hilbert.indexOf(0.0, 1.0), hilbert.indexOf(0.0, 2.0)); + ASSERT_EQ(hilbert.indexOf(2.0, 0.0), std::numeric_limits::max()); + ASSERT_EQ(hilbert.indexOf(2.0, 2.0), std::numeric_limits::max()); +} + +TEST_F(HilbertIndexTest, testNegativeCoordinates) { + HilbertIndex hilbert(-10, -10, 10, 10); + + uint32_t h0 = hilbert.indexOf(-5.0, -5.0); + uint32_t h1 = hilbert.indexOf(0.0, 0.0); + uint32_t h2 = hilbert.indexOf(5.0, 5.0); + + ASSERT_LT(h0, h1); + ASSERT_LT(h1, h2); + + ASSERT_EQ( + hilbert.indexOf(-15.0, -15.0), std::numeric_limits::max()); + ASSERT_EQ(hilbert.indexOf(15.0, 15.0), std::numeric_limits::max()); +} + +TEST_F(HilbertIndexTest, testFloatingPointPrecision) { + HilbertIndex hilbert(0, 0, 1, 1); + + uint32_t h1 = hilbert.indexOf(0.1, 0.1); + uint32_t h2 = hilbert.indexOf(0.2, 0.2); + uint32_t h3 = hilbert.indexOf(0.9, 0.9); + + ASSERT_LT(h1, h2); + ASSERT_LT(h2, h3); +} + +TEST_F(HilbertIndexTest, testBoundaryPoints) { + HilbertIndex hilbert(0, 0, 10, 10); + + uint32_t h0 = hilbert.indexOf(0.0, 0.0); + uint32_t h1 = hilbert.indexOf(10.0, 10.0); + uint32_t h2 = hilbert.indexOf(0.0, 10.0); + // Bottom-right corner is at the end of the range, so may be MAX + + ASSERT_NE(h0, std::numeric_limits::max()); + ASSERT_NE(h1, std::numeric_limits::max()); + ASSERT_NE(h2, std::numeric_limits::max()); +} + +TEST_F(HilbertIndexTest, testLargeCoordinates) { + HilbertIndex hilbert(0, 0, 1000000, 1000000); + + uint32_t h1 = hilbert.indexOf(100000, 100000); + uint32_t h2 = hilbert.indexOf(500000, 500000); + uint32_t h3 = hilbert.indexOf(900000, 900000); + + ASSERT_LT(h1, h2); + ASSERT_LT(h2, h3); +} + +TEST_F(HilbertIndexTest, testDensityInSmallRegion) { + HilbertIndex hilbert(0, 0, 100, 100); + + std::vector indices; + for (int i = 0; i < 10; ++i) { + for (int j = 0; j < 10; ++j) { + indices.push_back(hilbert.indexOf(i * 10.0f + 5.0f, j * 10.0f + 5.0f)); + } + } + + std::set uniqueIndices(indices.begin(), indices.end()); + ASSERT_EQ(indices.size(), 100); + ASSERT_GT(uniqueIndices.size(), 90); +} + +TEST_F(HilbertIndexTest, testSpatialLocality) { + HilbertIndex hilbert(0, 0, 100, 100); + + uint32_t h1 = hilbert.indexOf(50.0, 50.0); + uint32_t h2 = hilbert.indexOf(50.1, 50.1); + uint32_t h3 = hilbert.indexOf(50.2, 50.2); + uint32_t h4 = hilbert.indexOf(90.0, 90.0); + + uint32_t diff12 = std::abs(static_cast(h1 - h2)); + uint32_t diff23 = std::abs(static_cast(h2 - h3)); + uint32_t diff14 = std::abs(static_cast(h1 - h4)); + + ASSERT_LT(diff12, diff14); + ASSERT_LT(diff23, diff14); +} + +TEST_F(HilbertIndexTest, testIdenticalPoints) { + HilbertIndex hilbert(0, 0, 10, 10); + + uint32_t h1 = hilbert.indexOf(5.0, 5.0); + uint32_t h2 = hilbert.indexOf(5.0, 5.0); + + ASSERT_EQ(h1, h2); +} + +TEST_F(HilbertIndexTest, testExtremelySmallBounds) { + HilbertIndex hilbert(0, 0, 0.001, 0.001); + + uint32_t h1 = hilbert.indexOf(0.0, 0.0); + uint32_t h2 = hilbert.indexOf(0.0005, 0.0005); + + ASSERT_NE(h1, std::numeric_limits::max()); + ASSERT_NE(h2, std::numeric_limits::max()); +} + +} // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/IndexLookupJoinTest.cpp b/velox/exec/tests/IndexLookupJoinTest.cpp index 857ee78d61a8..34179e0dcf1c 100644 --- a/velox/exec/tests/IndexLookupJoinTest.cpp +++ b/velox/exec/tests/IndexLookupJoinTest.cpp @@ -15,6 +15,7 @@ */ #include "velox/exec/IndexLookupJoin.h" +#include "fmt/format.h" #include "folly/experimental/EventCount.h" #include "gmock/gmock.h" #include "gtest/gtest-matchers.h" @@ -34,24 +35,31 @@ using namespace facebook::velox::exec; using namespace facebook::velox::exec::test; using namespace facebook::velox::common::testutil; -namespace fecebook::velox::exec::test { +namespace facebook::velox::exec::test { namespace { struct TestParam { bool asyncLookup; int32_t numPrefetches; bool serialExecution; + bool hasNullKeys; - TestParam(bool _asyncLookup, int32_t _numPrefetches, bool _serialExecution) + TestParam( + bool _asyncLookup, + int32_t _numPrefetches, + bool _serialExecution, + bool _hasNullKeys) : asyncLookup(_asyncLookup), numPrefetches(_numPrefetches), - serialExecution(_serialExecution) {} + serialExecution(_serialExecution), + hasNullKeys(_hasNullKeys) {} std::string toString() const { return fmt::format( - "asyncLookup={}, numPrefetches={}, serialExecution={}", + "asyncLookup={}, numPrefetches={}, serialExecution={}, hasNullKeys={}", asyncLookup, numPrefetches, - serialExecution); + serialExecution, + hasNullKeys); } }; @@ -60,36 +68,27 @@ class IndexLookupJoinTest : public IndexLookupJoinTestBase, public: static std::vector getTestParams() { std::vector testParams; - testParams.emplace_back(true, 0, true); - testParams.emplace_back(true, 0, false); - testParams.emplace_back(false, 0, true); - testParams.emplace_back(false, 0, false); - testParams.emplace_back(true, 3, true); - testParams.emplace_back(true, 3, false); - testParams.emplace_back(false, 3, true); - testParams.emplace_back(false, 3, false); + for (bool asyncLookup : {false, true}) { + for (int numPrefetches : {0, 3}) { + for (bool serialExecution : {false, true}) { + for (bool hasNullKeys : {false, true}) { + testParams.emplace_back( + asyncLookup, numPrefetches, serialExecution, hasNullKeys); + } + } + } + } return testParams; } protected: - IndexLookupJoinTest() = default; - void SetUp() override { HiveConnectorTestBase::SetUp(); core::PlanNode::registerSerDe(); connector::hive::HiveColumnHandle::registerSerDe(); Type::registerSerDe(); core::ITypedExpr::registerSerDe(); - connector::registerConnectorFactory( - std::make_shared()); - std::shared_ptr connector = - connector::getConnectorFactory(kTestIndexConnectorName) - ->newConnector( - kTestIndexConnectorName, - {}, - nullptr, - connectorCpuExecutor_.get()); - connector::registerConnector(connector); + TestIndexConnectorFactory::registerConnector(connectorCpuExecutor_.get()); keyType_ = ROW({"u0", "u1", "u2"}, {BIGINT(), BIGINT(), BIGINT()}); valueType_ = ROW({"u3", "u4", "u5"}, {BIGINT(), BIGINT(), VARCHAR()}); @@ -97,12 +96,9 @@ class IndexLookupJoinTest : public IndexLookupJoinTestBase, probeType_ = ROW( {"t0", "t1", "t2", "t3", "t4", "t5"}, {BIGINT(), BIGINT(), BIGINT(), BIGINT(), ARRAY(BIGINT()), VARCHAR()}); - - TestIndexTableHandle::registerSerDe(); } void TearDown() override { - connector::unregisterConnectorFactory(kTestIndexConnectorName); connector::unregisterConnector(kTestIndexConnectorName); HiveConnectorTestBase::TearDown(); } @@ -113,88 +109,30 @@ class IndexLookupJoinTest : public IndexLookupJoinTestBase, ASSERT_EQ(plan->toString(true, true), copy->toString(true, true)); } - // Create index table with the given key and value inputs. - std::shared_ptr createIndexTable( - int numEqualJoinKeys, - const RowVectorPtr& keyData, - const RowVectorPtr& valueData) { - const auto keyType = - std::dynamic_pointer_cast(keyData->type()); - VELOX_CHECK_GE(keyType->size(), 1); - VELOX_CHECK_GE(keyType->size(), numEqualJoinKeys); - auto valueType = - std::dynamic_pointer_cast(valueData->type()); - VELOX_CHECK_GE(valueType->size(), 1); - const auto numRows = keyData->size(); - VELOX_CHECK_EQ(numRows, valueData->size()); - - std::vector> hashers; - hashers.reserve(numEqualJoinKeys); - std::vector keyVectors; - keyVectors.reserve(numEqualJoinKeys); - for (auto i = 0; i < numEqualJoinKeys; ++i) { - hashers.push_back(std::make_unique(keyType->childAt(i), i)); - keyVectors.push_back(keyData->childAt(i)); - } - - std::vector dependentTypes; - std::vector dependentVectors; - for (int i = numEqualJoinKeys; i < keyType->size(); ++i) { - dependentTypes.push_back(keyType->childAt(i)); - dependentVectors.push_back(keyData->childAt(i)); - } - for (int i = 0; i < valueType->size(); ++i) { - dependentTypes.push_back(valueType->childAt(i)); - dependentVectors.push_back(valueData->childAt(i)); - } - - // Create the table. - auto table = HashTable::createForJoin( - std::move(hashers), - /*dependentTypes=*/dependentTypes, - /*allowDuplicates=*/true, - /*hasProbedFlag=*/false, - /*minTableSizeForParallelJoinBuild=*/1, - pool_.get()); - - // Insert data into the row container. - auto* rowContainer = table->rows(); - std::vector decodedVectors; - for (auto& vector : keyData->children()) { - decodedVectors.emplace_back(*vector); - } - for (auto& vector : valueData->children()) { - decodedVectors.emplace_back(*vector); - } - - for (auto row = 0; row < numRows; ++row) { - auto* newRow = rowContainer->newRow(); - - for (auto col = 0; col < decodedVectors.size(); ++col) { - rowContainer->store(decodedVectors[col], row, newRow, col); - } - } - - // Build the table index. - table->prepareJoinTable({}, BaseHashTable::kNoSpillInputStartPartitionBit); - return std::make_shared( - std::move(keyType), std::move(valueType), std::move(table)); - } - // Makes index table handle with the specified index table and async lookup // flag. - std::shared_ptr makeIndexTableHandle( + static std::shared_ptr makeIndexTableHandle( const std::shared_ptr& indexTable, bool asyncLookup) { return std::make_shared( kTestIndexConnectorName, indexTable, asyncLookup); } + static connector::ColumnHandleMap makeIndexColumnHandles( + const std::vector& names) { + connector::ColumnHandleMap handles; + for (const auto& name : names) { + handles.emplace(name, std::make_shared(name)); + } + + return handles; + } + const std::unique_ptr connectorCpuExecutor_{ std::make_unique(128)}; }; -TEST_P(IndexLookupJoinTest, joinCondition) { +TEST_F(IndexLookupJoinTest, joinCondition) { const auto rowType = ROW({"c0", "c1", "c2", "c3", "c4"}, {BIGINT(), BIGINT(), BIGINT(), ARRAY(BIGINT()), BIGINT()}); @@ -207,9 +145,7 @@ TEST_P(IndexLookupJoinTest, joinCondition) { auto inFilterCondition = PlanBuilder::parseIndexJoinCondition( "contains(ARRAY[1,2], c2)", rowType, pool_.get()); ASSERT_TRUE(inFilterCondition->isFilter()); - ASSERT_EQ( - inFilterCondition->toString(), - "ROW[\"c2\"] IN 2 elements starting at 0 {1, 2}"); + ASSERT_EQ(inFilterCondition->toString(), "ROW[\"c2\"] IN {1, 2}"); auto betweenFilterCondition = PlanBuilder::parseIndexJoinCondition( "c0 between 0 AND 1", rowType, pool_.get()); @@ -236,13 +172,17 @@ TEST_P(IndexLookupJoinTest, joinCondition) { ASSERT_EQ( betweenJoinCondition3->toString(), "ROW[\"c0\"] BETWEEN ROW[\"c1\"] AND 0"); + + auto equalFilterCondition = + PlanBuilder::parseIndexJoinCondition("c0=1", rowType, pool_.get()); + ASSERT_TRUE(equalFilterCondition->isFilter()); + ASSERT_EQ(equalFilterCondition->toString(), "ROW[\"c0\"] = 1"); } TEST_P(IndexLookupJoinTest, planNodeAndSerde) { TestIndexTableHandle::registerSerDe(); - auto indexConnectorHandle = std::make_shared( - kTestIndexConnectorName, nullptr, true); + auto indexConnectorHandle = makeIndexTableHandle(nullptr, true); auto left = makeRowVector( {"t0", "t1", "t2", "t3", "t4"}, @@ -269,7 +209,7 @@ TEST_P(IndexLookupJoinTest, planNodeAndSerde) { auto planBuilder = PlanBuilder(); auto nonIndexTableScan = std::dynamic_pointer_cast( PlanBuilder::TableScanBuilder(planBuilder) - .outputType(std::dynamic_pointer_cast(right->type())) + .outputType(asRowType(right->type())) .endTableScan() .planNode()); VELOX_CHECK_NOT_NULL(nonIndexTableScan); @@ -277,7 +217,7 @@ TEST_P(IndexLookupJoinTest, planNodeAndSerde) { auto indexTableScan = std::dynamic_pointer_cast( PlanBuilder::TableScanBuilder(planBuilder) .tableHandle(indexConnectorHandle) - .outputType(std::dynamic_pointer_cast(right->type())) + .outputType(asRowType(right->type())) .endTableScan() .planNode()); VELOX_CHECK_NOT_NULL(indexTableScan); @@ -285,13 +225,13 @@ TEST_P(IndexLookupJoinTest, planNodeAndSerde) { for (const auto joinType : {core::JoinType::kLeft, core::JoinType::kInner}) { auto plan = PlanBuilder(planNodeIdGenerator) .values({left}) - .indexLookupJoin( - {"t0"}, - {"u0"}, - indexTableScan, - {}, - {"t0", "u1", "t2", "t1"}, - joinType) + .startIndexLookupJoin() + .leftKeys({"t0"}) + .rightKeys({"u0"}) + .indexSource(indexTableScan) + .outputLayout({"t0", "u1", "t2", "t1"}) + .joinType(joinType) + .endIndexLookupJoin() .planNode(); auto indexLookupJoinNode = std::dynamic_pointer_cast(plan); @@ -303,6 +243,101 @@ TEST_P(IndexLookupJoinTest, planNodeAndSerde) { } // with in join conditions. + for (const auto joinType : {core::JoinType::kLeft, core::JoinType::kInner}) { + auto plan = PlanBuilder(planNodeIdGenerator, pool_.get()) + .values({left}) + .startIndexLookupJoin() + .leftKeys({"t0"}) + .rightKeys({"u0"}) + .indexSource(indexTableScan) + .joinConditions({"contains(t3, u0)", "contains(t4, u1)"}) + .outputLayout({"t0", "u1", "t2", "t1"}) + .joinType(joinType) + .endIndexLookupJoin() + .planNode(); + auto indexLookupJoinNode = + std::dynamic_pointer_cast(plan); + ASSERT_EQ(indexLookupJoinNode->joinConditions().size(), 2); + ASSERT_EQ( + indexLookupJoinNode->lookupSource()->tableHandle()->connectorId(), + kTestIndexConnectorName); + testSerde(plan); + } + + // with between join conditions. + for (const auto joinType : {core::JoinType::kLeft, core::JoinType::kInner}) { + auto plan = PlanBuilder(planNodeIdGenerator, pool_.get()) + .values({left}) + .startIndexLookupJoin() + .leftKeys({"t0"}) + .rightKeys({"u0"}) + .indexSource(indexTableScan) + .joinConditions( + {"u0 between t0 AND t1", + "u1 between t1 AND 10", + "u1 between 10 AND t1"}) + .outputLayout({"t0", "u1", "t2", "t1"}) + .joinType(joinType) + .endIndexLookupJoin() + .planNode(); + auto indexLookupJoinNode = + std::dynamic_pointer_cast(plan); + ASSERT_EQ(indexLookupJoinNode->joinConditions().size(), 3); + ASSERT_EQ( + indexLookupJoinNode->lookupSource()->tableHandle()->connectorId(), + kTestIndexConnectorName); + testSerde(plan); + } + + // with mix join conditions. + for (const auto joinType : {core::JoinType::kLeft, core::JoinType::kInner}) { + auto plan = + PlanBuilder(planNodeIdGenerator, pool_.get()) + .values({left}) + .startIndexLookupJoin() + .leftKeys({"t0"}) + .rightKeys({"u0"}) + .indexSource(indexTableScan) + .joinConditions({"contains(t3, u0)", "u1 between 10 AND t1"}) + .outputLayout({"t0", "u1", "t2", "t1"}) + .joinType(joinType) + .endIndexLookupJoin() + .planNode(); + auto indexLookupJoinNode = + std::dynamic_pointer_cast(plan); + ASSERT_EQ(indexLookupJoinNode->joinConditions().size(), 2); + ASSERT_EQ( + indexLookupJoinNode->lookupSource()->tableHandle()->connectorId(), + kTestIndexConnectorName); + testSerde(plan); + } + + // with has match column. + { + auto plan = + PlanBuilder(planNodeIdGenerator, pool_.get()) + .values({left}) + .startIndexLookupJoin() + .leftKeys({"t0"}) + .rightKeys({"u0"}) + .indexSource(indexTableScan) + .joinConditions({"contains(t3, u0)", "u1 between 10 AND t1"}) + .hasMarker(true) + .outputLayout({"t0", "u1", "t2", "t1", "match"}) + .joinType(core::JoinType::kLeft) + .endIndexLookupJoin() + .planNode(); + auto indexLookupJoinNode = + std::dynamic_pointer_cast(plan); + ASSERT_EQ(indexLookupJoinNode->joinConditions().size(), 2); + ASSERT_EQ(indexLookupJoinNode->filter(), nullptr); + ASSERT_EQ( + indexLookupJoinNode->lookupSource()->tableHandle()->connectorId(), + kTestIndexConnectorName); + testSerde(plan); + } + + // with filter. for (const auto joinType : {core::JoinType::kLeft, core::JoinType::kInner}) { auto plan = PlanBuilder(planNodeIdGenerator, pool_.get()) .values({left}) @@ -310,20 +345,25 @@ TEST_P(IndexLookupJoinTest, planNodeAndSerde) { {"t0"}, {"u0"}, indexTableScan, - {"contains(t3, u0)", "contains(t4, u1)"}, + {}, + /*filter=*/"t1 % 2 = 0", + /*hasMarker=*/false, {"t0", "u1", "t2", "t1"}, joinType) .planNode(); auto indexLookupJoinNode = std::dynamic_pointer_cast(plan); - ASSERT_EQ(indexLookupJoinNode->joinConditions().size(), 2); + ASSERT_TRUE(indexLookupJoinNode->joinConditions().empty()); + ASSERT_NE(indexLookupJoinNode->filter(), nullptr); + ASSERT_EQ( + indexLookupJoinNode->filter()->toString(), "eq(mod(ROW[\"t1\"],2),0)"); ASSERT_EQ( indexLookupJoinNode->lookupSource()->tableHandle()->connectorId(), kTestIndexConnectorName); testSerde(plan); } - // with between join conditions. + // with join conditions and filter. for (const auto joinType : {core::JoinType::kLeft, core::JoinType::kInner}) { auto plan = PlanBuilder(planNodeIdGenerator, pool_.get()) .values({left}) @@ -331,36 +371,73 @@ TEST_P(IndexLookupJoinTest, planNodeAndSerde) { {"t0"}, {"u0"}, indexTableScan, - {"u0 between t0 AND t1", - "u1 between t1 AND 10", - "u1 between 10 AND t1"}, + {"contains(t3, u0)"}, + /*filter=*/"u1 % 2 = 0 AND t2 > 5", + /*hasMarker=*/false, {"t0", "u1", "t2", "t1"}, joinType) .planNode(); auto indexLookupJoinNode = std::dynamic_pointer_cast(plan); - ASSERT_EQ(indexLookupJoinNode->joinConditions().size(), 3); + ASSERT_EQ(indexLookupJoinNode->joinConditions().size(), 1); + ASSERT_NE(indexLookupJoinNode->filter(), nullptr); + ASSERT_EQ( + indexLookupJoinNode->filter()->toString(), + "and(eq(mod(ROW[\"u1\"],2),0),gt(ROW[\"t2\"],5))"); ASSERT_EQ( indexLookupJoinNode->lookupSource()->tableHandle()->connectorId(), kTestIndexConnectorName); testSerde(plan); } - // with mix join conditions. - for (const auto joinType : {core::JoinType::kLeft, core::JoinType::kInner}) { + // with filter and marker for left join. + { auto plan = PlanBuilder(planNodeIdGenerator, pool_.get()) .values({left}) .indexLookupJoin( {"t0"}, {"u0"}, indexTableScan, - {"contains(t3, u0)", "u1 between 10 AND t1"}, - {"t0", "u1", "t2", "t1"}, - joinType) + {"u1 between 10 AND t1"}, + /*filter=*/"t2 < u2", + /*hasMarker=*/true, + {"t0", "u1", "t2", "t1", "match"}, + core::JoinType::kLeft) .planNode(); auto indexLookupJoinNode = std::dynamic_pointer_cast(plan); - ASSERT_EQ(indexLookupJoinNode->joinConditions().size(), 2); + ASSERT_EQ(indexLookupJoinNode->joinConditions().size(), 1); + ASSERT_NE(indexLookupJoinNode->filter(), nullptr); + ASSERT_EQ( + indexLookupJoinNode->filter()->toString(), + "lt(ROW[\"t2\"],ROW[\"u2\"])"); + ASSERT_TRUE(indexLookupJoinNode->hasMarker()); + ASSERT_EQ( + indexLookupJoinNode->lookupSource()->tableHandle()->connectorId(), + kTestIndexConnectorName); + testSerde(plan); + } + + // with complex filter expression. + { + auto plan = PlanBuilder(planNodeIdGenerator, pool_.get()) + .values({left}) + .startIndexLookupJoin() + .leftKeys({"t0"}) + .rightKeys({"u0"}) + .indexSource(indexTableScan) + .filter("(t1 + u1) * 2 > 100 OR t2 = u2") + .outputLayout({"t0", "u1", "t2", "t1"}) + .joinType(core::JoinType::kInner) + .endIndexLookupJoin() + .planNode(); + auto indexLookupJoinNode = + std::dynamic_pointer_cast(plan); + ASSERT_TRUE(indexLookupJoinNode->joinConditions().empty()); + ASSERT_NE(indexLookupJoinNode->filter(), nullptr); + ASSERT_EQ( + indexLookupJoinNode->filter()->toString(), + "or(gt(multiply(plus(ROW[\"t1\"],ROW[\"u1\"]),2),100),eq(ROW[\"t2\"],ROW[\"u2\"]))"); ASSERT_EQ( indexLookupJoinNode->lookupSource()->tableHandle()->connectorId(), kTestIndexConnectorName); @@ -372,13 +449,13 @@ TEST_P(IndexLookupJoinTest, planNodeAndSerde) { VELOX_ASSERT_USER_THROW( PlanBuilder(planNodeIdGenerator) .values({left}) - .indexLookupJoin( - {"t0"}, - {"u0"}, - indexTableScan, - {}, - {"t0", "u1", "t2", "t1"}, - core::JoinType::kFull) + .startIndexLookupJoin() + .leftKeys({"t0"}) + .rightKeys({"u0"}) + .indexSource(indexTableScan) + .outputLayout({"t0", "u1", "t2", "t1"}) + .joinType(core::JoinType::kFull) + .endIndexLookupJoin() .planNode(), "Unsupported index lookup join type FULL"); } @@ -388,8 +465,12 @@ TEST_P(IndexLookupJoinTest, planNodeAndSerde) { VELOX_ASSERT_USER_THROW( PlanBuilder(planNodeIdGenerator) .values({left}) - .indexLookupJoin( - {"t0"}, {"u0"}, nonIndexTableScan, {}, {"t0", "u1", "t2", "t1"}) + .startIndexLookupJoin() + .leftKeys({"t0"}) + .rightKeys({"u0"}) + .indexSource(nonIndexTableScan) + .outputLayout({"t0", "u1", "t2", "t1"}) + .endIndexLookupJoin() .planNode(), "The lookup table handle hive_table from connector test-hive doesn't support index lookup"); } @@ -399,14 +480,15 @@ TEST_P(IndexLookupJoinTest, planNodeAndSerde) { VELOX_ASSERT_THROW( PlanBuilder(planNodeIdGenerator) .values({left}) - .indexLookupJoin( - {"t0", "t1"}, - {"u0"}, - indexTableScan, - {"contains(t4, u0)"}, - {"t0", "u1", "t2", "t1"}) + .startIndexLookupJoin() + .leftKeys({"t0", "t1"}) + .rightKeys({"u0"}) + .indexSource(indexTableScan) + .joinConditions({"contains(t4, u0)"}) + .outputLayout({"t0", "u1", "t2", "t1"}) + .endIndexLookupJoin() .planNode(), - "JoinNode requires same number of join keys on left and right sides"); + "The index lookup join node requires same number of join keys on left and right sides"); } // No join keys. @@ -414,14 +496,15 @@ TEST_P(IndexLookupJoinTest, planNodeAndSerde) { VELOX_ASSERT_THROW( PlanBuilder(planNodeIdGenerator) .values({left}) - .indexLookupJoin( - {}, - {}, - indexTableScan, - {"contains(t4, u0)"}, - {"t0", "u1", "t2", "t1"}) + .startIndexLookupJoin() + .leftKeys({}) + .rightKeys({}) + .indexSource(indexTableScan) + .joinConditions({"contains(t4, u0)"}) + .outputLayout({"t0", "u1", "t2", "t1"}) + .endIndexLookupJoin() .planNode(), - "JoinNode requires at least one join key"); + "The index lookup join node requires at least one join key"); } } @@ -447,7 +530,7 @@ TEST_P(IndexLookupJoinTest, equalJoin) { matchPct, folly::join(",", scanOutputColumns), folly::join(",", outputColumns), - core::joinTypeName(joinType), + core::JoinTypeName::toName(joinType), duckDbVerifySql); } } testSettings[] = { @@ -643,7 +726,6 @@ TEST_P(IndexLookupJoinTest, equalJoin) { {"t1", "u1", "u2", "u3", "u5"}, core::JoinType::kLeft, "SELECT t.c1, u.c1, u.c2, u.c3, u.c5 FROM t LEFT JOIN u ON t.c0 = u.c0"}, - // 10% match with larger lookup table. {{500, 1, 1}, 10, @@ -787,6 +869,7 @@ TEST_P(IndexLookupJoinTest, equalJoin) { tableData, pool_, {"t0", "t1", "t2"}, + GetParam().hasNullKeys, {}, {}, testData.matchPct); @@ -796,18 +879,19 @@ TEST_P(IndexLookupJoinTest, equalJoin) { createDuckDbTable("t", probeVectors); createDuckDbTable("u", {tableData.tableData}); - const auto indexTable = createIndexTable( - /*numEqualJoinKeys=*/3, tableData.keyData, tableData.valueData); + const auto indexTable = TestIndexTable::create( + /*numEqualJoinKeys=*/3, + tableData.keyData, + tableData.valueData, + *pool()); const auto indexTableHandle = makeIndexTableHandle(indexTable, GetParam().asyncLookup); auto planNodeIdGenerator = std::make_shared(); - std::unordered_map> - columnHandles; const auto indexScanNode = makeIndexScanNode( planNodeIdGenerator, indexTableHandle, makeScanOutputType(testData.scanOutputColumns), - columnHandles); + makeIndexColumnHandles(testData.scanOutputColumns)); auto plan = makeLookupPlan( planNodeIdGenerator, @@ -815,6 +899,8 @@ TEST_P(IndexLookupJoinTest, equalJoin) { {"t0", "t1", "t2"}, {"u0", "u1", "u2"}, {}, + /*filter=*/"", + /*hasMarker=*/false, testData.joinType, testData.outputColumns); runLookupQuery( @@ -825,6 +911,22 @@ TEST_P(IndexLookupJoinTest, equalJoin) { 32, GetParam().numPrefetches, testData.duckDbVerifySql); + if (testData.joinType != core::JoinType::kLeft) { + continue; + } + const auto probeScanId = probeScanNodeId_; + auto planWithMatchColumn = makeLookupPlan( + planNodeIdGenerator, + indexScanNode, + {"t0", "t1", "t2"}, + {"u0", "u1", "u2"}, + {}, + /*filter=*/"", + /*hasMarker=*/true, + testData.joinType, + testData.outputColumns); + verifyResultWithMatchColumn( + plan, probeScanId, planWithMatchColumn, probeScanNodeId_, probeFiles); } } @@ -850,7 +952,7 @@ TEST_P(IndexLookupJoinTest, betweenJoinCondition) { betweenMatchPct, folly::join(",", lookupOutputColumns), folly::join(",", outputColumns), - core::joinTypeName(joinType), + core::JoinTypeName::toName(joinType), duckDbVerifySql); } } testSettings[] = {// Inner join. @@ -1240,10 +1342,11 @@ TEST_P(IndexLookupJoinTest, betweenJoinCondition) { tableData, pool_, {"t0", "t1"}, + GetParam().hasNullKeys, {}, {{"t2", "t3"}}, - /*eqaulityMatchPct=*/80, - /*inColumns=*/std::nullopt, + /*equalMatchPct=*/80, + /*inMatchPct=*/std::nullopt, testData.betweenMatchPct); std::vector> probeFiles = createProbeFiles(probeVectors); @@ -1251,18 +1354,19 @@ TEST_P(IndexLookupJoinTest, betweenJoinCondition) { createDuckDbTable("t", probeVectors); createDuckDbTable("u", {tableData.tableData}); - const auto indexTable = createIndexTable( - /*numEqualJoinKeys=*/2, tableData.keyData, tableData.valueData); + const auto indexTable = TestIndexTable::create( + /*numEqualJoinKeys=*/2, + tableData.keyData, + tableData.valueData, + *pool()); const auto indexTableHandle = makeIndexTableHandle(indexTable, GetParam().asyncLookup); auto planNodeIdGenerator = std::make_shared(); - std::unordered_map> - columnHandles; const auto indexScanNode = makeIndexScanNode( planNodeIdGenerator, indexTableHandle, makeScanOutputType(testData.lookupOutputColumns), - columnHandles); + makeIndexColumnHandles(testData.lookupOutputColumns)); auto plan = makeLookupPlan( planNodeIdGenerator, @@ -1270,6 +1374,8 @@ TEST_P(IndexLookupJoinTest, betweenJoinCondition) { {"t0", "t1"}, {"u0", "u1"}, {testData.betweenCondition}, + /*filter=*/"", + /*hasMarker=*/false, testData.joinType, testData.outputColumns); runLookupQuery( @@ -1280,6 +1386,22 @@ TEST_P(IndexLookupJoinTest, betweenJoinCondition) { 32, GetParam().numPrefetches, testData.duckDbVerifySql); + if (testData.joinType != core::JoinType::kLeft) { + continue; + } + const auto probeScanId = probeScanNodeId_; + auto planWithMatchColumn = makeLookupPlan( + planNodeIdGenerator, + indexScanNode, + {"t0", "t1"}, + {"u0", "u1"}, + {testData.betweenCondition}, + /*filter=*/"", + /*hasMarker=*/true, + testData.joinType, + testData.outputColumns); + verifyResultWithMatchColumn( + plan, probeScanId, planWithMatchColumn, probeScanNodeId_, probeFiles); } } @@ -1305,7 +1427,7 @@ TEST_P(IndexLookupJoinTest, inJoinCondition) { inMatchPct, folly::join(",", lookupOutputColumns), folly::join(",", outputColumns), - core::joinTypeName(joinType), + core::JoinTypeName::toName(joinType), duckDbVerifySql); } } testSettings[] = { @@ -1562,9 +1684,10 @@ TEST_P(IndexLookupJoinTest, inJoinCondition) { tableData, pool_, {"t0", "t1"}, + GetParam().hasNullKeys, {{"t4"}}, {}, - /*eqaulityMatchPct=*/80, + /*equalMatchPct=*/80, testData.inMatchPct); std::vector> probeFiles = createProbeFiles(probeVectors); @@ -1572,18 +1695,19 @@ TEST_P(IndexLookupJoinTest, inJoinCondition) { createDuckDbTable("t", probeVectors); createDuckDbTable("u", {tableData.tableData}); - const auto indexTable = createIndexTable( - /*numEqualJoinKeys=*/2, tableData.keyData, tableData.valueData); + const auto indexTable = TestIndexTable::create( + /*numEqualJoinKeys=*/2, + tableData.keyData, + tableData.valueData, + *pool()); const auto indexTableHandle = makeIndexTableHandle(indexTable, GetParam().asyncLookup); auto planNodeIdGenerator = std::make_shared(); - std::unordered_map> - columnHandles; const auto indexScanNode = makeIndexScanNode( planNodeIdGenerator, indexTableHandle, makeScanOutputType(testData.lookupOutputColumns), - columnHandles); + makeIndexColumnHandles(testData.lookupOutputColumns)); auto plan = makeLookupPlan( planNodeIdGenerator, @@ -1591,6 +1715,8 @@ TEST_P(IndexLookupJoinTest, inJoinCondition) { {"t0", "t1"}, {"u0", "u1"}, {testData.inCondition}, + /*filter=*/"", + /*hasMarker=*/false, testData.joinType, testData.outputColumns); runLookupQuery( @@ -1601,88 +1727,552 @@ TEST_P(IndexLookupJoinTest, inJoinCondition) { 32, GetParam().numPrefetches, testData.duckDbVerifySql); + if (testData.joinType != core::JoinType::kLeft) { + continue; + } + const auto probeScanId = probeScanNodeId_; + auto planWithMatchColumn = makeLookupPlan( + planNodeIdGenerator, + indexScanNode, + {"t0", "t1"}, + {"u0", "u1"}, + {testData.inCondition}, + /*filter=*/"", + /*hasMarker=*/true, + testData.joinType, + testData.outputColumns); + verifyResultWithMatchColumn( + plan, probeScanId, planWithMatchColumn, probeScanNodeId_, probeFiles); } } -DEBUG_ONLY_TEST_P(IndexLookupJoinTest, connectorError) { - SequenceTableData tableData; - generateIndexTableData({100, 1, 1}, tableData, pool_); - const std::vector probeVectors = generateProbeInput( - 20, 100, 1, tableData, pool_, {"t0", "t1", "t2"}, {}, {}, 100); - std::vector> probeFiles = - createProbeFiles(probeVectors); +TEST_P(IndexLookupJoinTest, prefixKeysEqualJoin) { + struct { + std::vector keyCardinalities; + int numProbeBatches; + int numRowsPerProbeBatch; + int matchPct; + int numKeysToUse; // Number of keys to use from the full key set + std::vector scanOutputColumns; + std::vector outputColumns; + core::JoinType joinType; + std::string duckDbVerifySql; - const std::string errorMsg{"injectedError"}; - std::atomic_int lookupCount{0}; - SCOPED_TESTVALUE_SET( - "facebook::velox::exec::test::TestIndexSource::ResultIterator::syncLookup", - std::function([&](void*) { - // Triggers error in the middle. - if (lookupCount++ == 10) { - VELOX_FAIL(errorMsg); - } - })); + std::string debugString() const { + return fmt::format( + "keyCardinalities: {}, numProbeBatches: {}, numRowsPerProbeBatch: {}, matchPct: {}, numKeysToUse: {}, " + "scanOutputColumns: {}, outputColumns: {}, joinType: {}," + " duckDbVerifySql: {}", + folly::join(",", keyCardinalities), + numProbeBatches, + numRowsPerProbeBatch, + matchPct, + numKeysToUse, + folly::join(",", scanOutputColumns), + folly::join(",", outputColumns), + core::JoinTypeName::toName(joinType), + duckDbVerifySql); + } + } testSettings[] = { + // Inner join with 2 out of 3 keys + {{100, 1, 1}, + 5, + 100, + 80, + 2, + {"u0", "u1", "u2", "u3", "u5"}, + {"t0", "t1", "u0", "u1", "u2", "u3", "u5"}, + core::JoinType::kInner, + "SELECT t.c0, t.c1, u.c0, u.c1, u.c2, u.c3, u.c5 FROM t, u WHERE t.c0 = u.c0 AND t.c1 = u.c1"}, + // Left join with 2 out of 3 keys + {{100, 1, 1}, + 5, + 100, + 80, + 2, + {"u0", "u1", "u2", "u3", "u5"}, + {"t0", "t1", "u0", "u1", "u2", "u3", "u5"}, + core::JoinType::kLeft, + "SELECT t.c0, t.c1, u.c0, u.c1, u.c2, u.c3, u.c5 FROM t LEFT JOIN u ON t.c0 = u.c0 AND t.c1 = u.c1 "}, + // Inner join with 1 out of 3 keys + {{100, 1, 1}, + 5, + 100, + 80, + 1, + {"u0", "u1", "u2", "u3", "u5"}, + {"t0", "u0", "u1", "u2", "u3", "u5"}, + core::JoinType::kInner, + "SELECT t.c0, u.c0, u.c1, u.c2, u.c3, u.c5 FROM t, u WHERE t.c0 = u.c0 "}, + // Left join with 1 out of 3 keys + {{100, 1, 1}, + 5, + 100, + 80, + 1, + {"u0", "u1", "u2", "u3", "u5"}, + {"t0", "u0", "u1", "u2", "u3", "u5"}, + core::JoinType::kLeft, + "SELECT t.c0, u.c0, u.c1, u.c2, u.c3, u.c5 FROM t LEFT JOIN u ON t.c0 = u.c0 "}}; - const auto indexTable = createIndexTable( - /*numEqualJoinKeys=*/3, tableData.keyData, tableData.valueData); - const auto indexTableHandle = - makeIndexTableHandle(indexTable, GetParam().asyncLookup); - auto planNodeIdGenerator = std::make_shared(); - std::unordered_map> - columnHandles; - const auto indexScanNode = makeIndexScanNode( - planNodeIdGenerator, - indexTableHandle, - makeScanOutputType({"u0", "u1", "u2", "u5"}), - columnHandles); + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); - auto plan = makeLookupPlan( - planNodeIdGenerator, - indexScanNode, - {"t0", "t1", "t2"}, - {"u0", "u1", "u2"}, - {}, - core::JoinType::kInner, - {"u0", "u1", "u2", "t5"}); - VELOX_ASSERT_THROW( - runLookupQuery( - plan, - probeFiles, - GetParam().serialExecution, - GetParam().serialExecution, - 100, - GetParam().numPrefetches, - "SELECT u.c0, u.c1, t.c2, t.c5 FROM t, u WHERE t.c0 = u.c0 AND t.c1 = u.c1 AND t.c2 = u.c2"), - errorMsg); -} + SequenceTableData tableData; + generateIndexTableData(testData.keyCardinalities, tableData, pool_); -DEBUG_ONLY_TEST_P(IndexLookupJoinTest, prefetch) { - if (!GetParam().asyncLookup || GetParam().serialExecution) { - // This test only works for async lookup. - return; - } - SequenceTableData tableData; - generateIndexTableData({100, 1, 1}, tableData, pool_); - const int numProbeBatches{20}; - ASSERT_GT(numProbeBatches, GetParam().numPrefetches); - const std::vector probeVectors = generateProbeInput( - numProbeBatches, - 100, - 1, - tableData, - pool_, - {"t0", "t1", "t2"}, - {}, - {}, - 100); - std::vector> probeFiles = - createProbeFiles(probeVectors); - createDuckDbTable("t", probeVectors); - createDuckDbTable("u", {tableData.tableData}); + // Generate probe vectors with only the prefix of keys + std::vector probeKeys; + for (int i = 0; i < testData.numKeysToUse; ++i) { + probeKeys.push_back(fmt::format("t{}", i)); + } - std::atomic_int lookupCount{0}; - folly::EventCount asyncLookupWait; - std::atomic_bool asyncLookupWaitFlag{true}; + auto probeVectors = generateProbeInput( + testData.numProbeBatches, + testData.numRowsPerProbeBatch, + 1, + tableData, + pool_, + probeKeys, + GetParam().hasNullKeys, + {}, + {}, + testData.matchPct); + std::vector> probeFiles = + createProbeFiles(probeVectors); + + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", {tableData.tableData}); + + const auto indexTable = TestIndexTable::create( + /*numEqualJoinKeys=*/testData.numKeysToUse, + tableData.keyData, + tableData.valueData, + *pool()); + const auto indexTableHandle = + makeIndexTableHandle(indexTable, GetParam().asyncLookup); + auto planNodeIdGenerator = std::make_shared(); + const auto indexScanNode = makeIndexScanNode( + planNodeIdGenerator, + indexTableHandle, + makeScanOutputType(testData.scanOutputColumns), + makeIndexColumnHandles(testData.scanOutputColumns)); + + // Create left and right join keys with only the prefix + std::vector leftKeys; + std::vector rightKeys; + for (int i = 0; i < testData.numKeysToUse; ++i) { + leftKeys.push_back(fmt::format("t{}", i)); + rightKeys.push_back(fmt::format("u{}", i)); + } + + auto plan = makeLookupPlan( + planNodeIdGenerator, + indexScanNode, + leftKeys, + rightKeys, + {}, + /*filter=*/"", + /*hasMarker=*/false, + testData.joinType, + testData.outputColumns); + runLookupQuery( + plan, + probeFiles, + GetParam().serialExecution, + GetParam().serialExecution, + 32, + GetParam().numPrefetches, + testData.duckDbVerifySql); + if (testData.joinType != core::JoinType::kLeft) { + continue; + } + const auto probeScanId = probeScanNodeId_; + auto planWithMatchColumn = makeLookupPlan( + planNodeIdGenerator, + indexScanNode, + leftKeys, + rightKeys, + {}, + /*filter=*/"", + /*hasMarker=*/true, + testData.joinType, + testData.outputColumns); + verifyResultWithMatchColumn( + plan, probeScanId, planWithMatchColumn, probeScanNodeId_, probeFiles); + } +} + +TEST_P(IndexLookupJoinTest, prefixKeysbetweenJoinCondition) { + struct { + std::vector keyCardinalities; + int numProbeBatches; + int numProbeRowsPerBatch; + std::string betweenCondition; + int betweenMatchPct; + std::vector lookupOutputColumns; + std::vector outputColumns; + core::JoinType joinType; + std::string duckDbVerifySql; + + std::string debugString() const { + return fmt::format( + "keyCardinalities: {}, numProbeBatches: {}, numProbeRowsPerBatch: {}, betweenCondition: {}, betweenMatchPct: {}, " + "lookupOutputColumns: {}, outputColumns: {}, joinType: {}, duckDbVerifySql: {}", + folly::join(",", keyCardinalities), + numProbeBatches, + numProbeRowsPerBatch, + betweenCondition, + betweenMatchPct, + folly::join(",", lookupOutputColumns), + folly::join(",", outputColumns), + core::JoinTypeName::toName(joinType), + duckDbVerifySql); + } + } testSettings[] = { + {{50, 1, 10}, + 5, + 100, + "u1 between t1 and t3", + 10, + {"u0", "u1", "u2", "u3"}, + {"t0", "t1", "t3", "u1", "u3"}, + core::JoinType::kInner, + "SELECT t.c0, t.c1, t.c3, u.c1, u.c3 FROM t, u WHERE t.c0 = u.c0 AND u.c1 BETWEEN t.c1 AND t.c3"}, + {{50, 1, 10}, + 5, + 100, + "u1 between t1 and t3", + 10, + {"u0", "u1", "u2", "u3"}, + {"t0", "t1", "t3", "u1", "u3"}, + core::JoinType::kLeft, + "SELECT t.c0, t.c1, t.c3, u.c1, u.c3 FROM t LEFT JOIN u ON t.c0 = u.c0 AND u.c1 BETWEEN t.c1 AND t.c3"}, + {{50, 1, 10}, + 5, + 100, + "u1 between 0 and t1", + 80, + {"u0", "u1", "u2", "u3"}, + {"t0", "t1", "u1", "u3"}, + core::JoinType::kInner, + "SELECT t.c0, t.c1, u.c1, u.c3 FROM t, u WHERE t.c0 = u.c0 AND u.c1 BETWEEN 0 AND t.c1"}, + {{50, 1, 10}, + 5, + 100, + "u1 between 0 and t1", + 80, + {"u0", "u1", "u2", "u3"}, + {"t0", "t1", "u1", "u3"}, + core::JoinType::kLeft, + "SELECT t.c0, t.c1, u.c1, u.c3 FROM t LEFT JOIN u ON t.c0 = u.c0 AND u.c1 BETWEEN 0 AND t.c1"}}; + + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + + SequenceTableData tableData; + generateIndexTableData(testData.keyCardinalities, tableData, pool_); + + auto probeVectors = generateProbeInput( + testData.numProbeBatches, + testData.numProbeRowsPerBatch, + 1, + tableData, + pool_, + {"t0"}, + GetParam().hasNullKeys, + {}, + {{"t1", "t2"}}, + /*equalMatchPct=*/80, + /*inMatchPct=*/std::nullopt, + testData.betweenMatchPct); + std::vector> probeFiles = + createProbeFiles(probeVectors); + + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", {tableData.tableData}); + + const auto indexTable = TestIndexTable::create( + /*numEqualJoinKeys=*/1, + tableData.keyData, + tableData.valueData, + *pool()); + const auto indexTableHandle = + makeIndexTableHandle(indexTable, GetParam().asyncLookup); + auto planNodeIdGenerator = std::make_shared(); + const auto indexScanNode = makeIndexScanNode( + planNodeIdGenerator, + indexTableHandle, + makeScanOutputType(testData.lookupOutputColumns), + makeIndexColumnHandles(testData.lookupOutputColumns)); + + auto plan = makeLookupPlan( + planNodeIdGenerator, + indexScanNode, + {"t0"}, + {"u0"}, + {testData.betweenCondition}, + /*filter=*/"", + /*hasMarker=*/false, + testData.joinType, + testData.outputColumns); + runLookupQuery( + plan, + probeFiles, + GetParam().serialExecution, + GetParam().serialExecution, + 32, + GetParam().numPrefetches, + testData.duckDbVerifySql); + if (testData.joinType != core::JoinType::kLeft) { + continue; + } + const auto probeScanId = probeScanNodeId_; + auto planWithMatchColumn = makeLookupPlan( + planNodeIdGenerator, + indexScanNode, + {"t0"}, + {"u0"}, + {testData.betweenCondition}, + /*filter=*/"", + /*hasMarker=*/true, + testData.joinType, + testData.outputColumns); + verifyResultWithMatchColumn( + plan, probeScanId, planWithMatchColumn, probeScanNodeId_, probeFiles); + } +} + +TEST_P(IndexLookupJoinTest, prefixInJoinCondition) { + struct { + std::vector keyCardinalities; + int numProbeBatches; + int numProbeRowsPerBatch; + std::string inCondition; + int inMatchPct; + std::vector lookupOutputColumns; + std::vector outputColumns; + core::JoinType joinType; + std::string duckDbVerifySql; + + std::string debugString() const { + return fmt::format( + "keyCardinalities: {}: numProbeBatches: {}, numProbeRowsPerBatch: {}, inCondition: {}, inMatchPct: {}, lookupOutputColumns: {}, outputColumns: {}, joinType: {}, duckDbVerifySql: {}", + folly::join(",", keyCardinalities), + numProbeBatches, + numProbeRowsPerBatch, + inCondition, + inMatchPct, + folly::join(",", lookupOutputColumns), + folly::join(",", outputColumns), + core::JoinTypeName::toName(joinType), + duckDbVerifySql); + } + } testSettings[] = { + // Inner join with prefix keys (c0, c1) + {{50, 1, 10}, + 1, + 100, + "contains(t4, u1)", + 10, + {"u0", "u1", "u2", "u3"}, + {"t0", "t1", "t4", "u1", "u3"}, + core::JoinType::kInner, + "SELECT t.c0, t.c1, t.c4, u.c1, u.c3 FROM t, u WHERE t.c0 = u.c0 AND array_contains(t.c4, u.c1)"}, + // Left join with prefix keys (c0, c1) + {{50, 1, 10}, + 1, + 100, + "contains(t4, u1)", + 10, + {"u0", "u1", "u2", "u3"}, + {"t0", "t1", "t4", "u1", "u3"}, + core::JoinType::kLeft, + "SELECT t.c0, t.c1, t.c4, u.c1, u.c3 FROM t LEFT JOIN u ON t.c0 = u.c0 AND array_contains(t.c4, u.c1)"}, + // Inner join with prefix keys (c0, c1) - higher match percentage + {{50, 1, 10}, + 10, + 100, + "contains(t4, u1)", + 80, + {"u0", "u1", "u2", "u3"}, + {"t0", "t1", "t4", "u1", "u3"}, + core::JoinType::kInner, + "SELECT t.c0, t.c1, t.c4, u.c1, u.c3 FROM t, u WHERE t.c0 = u.c0 AND array_contains(t.c4, u.c1)"}, + // Left join with prefix keys (c0, c1) - higher match percentage + {{50, 1, 10}, + 10, + 100, + "contains(t4, u1)", + 80, + {"u0", "u1", "u2", "u3"}, + {"t0", "t1", "t4", "u1", "u3"}, + core::JoinType::kLeft, + "SELECT t.c0, t.c1, t.c4, u.c1, u.c3 FROM t LEFT JOIN u ON t.c0 = u.c0 AND array_contains(t.c4, u.c1)"}}; + + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + + SequenceTableData tableData; + generateIndexTableData(testData.keyCardinalities, tableData, pool_); + auto probeVectors = generateProbeInput( + testData.numProbeBatches, + testData.numProbeRowsPerBatch, + 1, + tableData, + pool_, + {"t0"}, + GetParam().hasNullKeys, + {{"t4"}}, + {}, + /*equalMatchPct=*/80, + testData.inMatchPct); + std::vector> probeFiles = + createProbeFiles(probeVectors); + + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", {tableData.tableData}); + + const auto indexTable = TestIndexTable::create( + /*numEqualJoinKeys=*/1, + tableData.keyData, + tableData.valueData, + *pool()); + const auto indexTableHandle = + makeIndexTableHandle(indexTable, GetParam().asyncLookup); + auto planNodeIdGenerator = std::make_shared(); + const auto indexScanNode = makeIndexScanNode( + planNodeIdGenerator, + indexTableHandle, + makeScanOutputType(testData.lookupOutputColumns), + makeIndexColumnHandles(testData.lookupOutputColumns)); + + auto plan = makeLookupPlan( + planNodeIdGenerator, + indexScanNode, + {"t0"}, + {"u0"}, + {testData.inCondition}, + /*filter=*/"", + /*hasMarker=*/false, + testData.joinType, + testData.outputColumns); + runLookupQuery( + plan, + probeFiles, + GetParam().serialExecution, + GetParam().serialExecution, + 32, + GetParam().numPrefetches, + testData.duckDbVerifySql); + if (testData.joinType != core::JoinType::kLeft) { + continue; + } + const auto probeScanId = probeScanNodeId_; + auto planWithMatchColumn = makeLookupPlan( + planNodeIdGenerator, + indexScanNode, + {"t0"}, + {"u0"}, + {testData.inCondition}, + /*filter=*/"", + /*hasMarker=*/true, + testData.joinType, + testData.outputColumns); + verifyResultWithMatchColumn( + plan, probeScanId, planWithMatchColumn, probeScanNodeId_, probeFiles); + } +} + +DEBUG_ONLY_TEST_P(IndexLookupJoinTest, connectorError) { + SequenceTableData tableData; + generateIndexTableData({100, 1, 1}, tableData, pool_); + const std::vector probeVectors = generateProbeInput( + 20, + 100, + 1, + tableData, + pool_, + {"t0", "t1", "t2"}, + GetParam().hasNullKeys, + {}, + {}, + 100); + std::vector> probeFiles = + createProbeFiles(probeVectors); + + const std::string errorMsg{"injectedError"}; + std::atomic_int lookupCount{0}; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::test::TestIndexSource::ResultIterator::syncLookup", + std::function([&](void*) { + // Triggers error in the middle. + if (lookupCount++ == 10) { + VELOX_FAIL(errorMsg); + } + })); + + const auto indexTable = TestIndexTable::create( + /*numEqualJoinKeys=*/3, tableData.keyData, tableData.valueData, *pool()); + const auto indexTableHandle = + makeIndexTableHandle(indexTable, GetParam().asyncLookup); + auto planNodeIdGenerator = std::make_shared(); + const auto indexScanNode = makeIndexScanNode( + planNodeIdGenerator, + indexTableHandle, + makeScanOutputType({"u0", "u1", "u2", "u5"}), + makeIndexColumnHandles({"u0", "u1", "u2", "u5"})); + + auto plan = makeLookupPlan( + planNodeIdGenerator, + indexScanNode, + {"t0", "t1", "t2"}, + {"u0", "u1", "u2"}, + {}, + /*filter=*/"", + /*hasMarker=*/false, + core::JoinType::kInner, + {"u0", "u1", "u2", "t5"}); + VELOX_ASSERT_THROW( + runLookupQuery( + plan, + probeFiles, + GetParam().serialExecution, + GetParam().serialExecution, + 100, + GetParam().numPrefetches, + "SELECT u.c0, u.c1, t.c2, t.c5 FROM t, u WHERE t.c0 = u.c0 AND t.c1 = u.c1 AND t.c2 = u.c2"), + errorMsg); +} + +DEBUG_ONLY_TEST_P(IndexLookupJoinTest, prefetch) { + if (!GetParam().asyncLookup || GetParam().serialExecution) { + // This test only works for async lookup. + return; + } + SequenceTableData tableData; + generateIndexTableData({100, 1, 1}, tableData, pool_); + const int numProbeBatches{20}; + ASSERT_GT(numProbeBatches, GetParam().numPrefetches); + const std::vector probeVectors = generateProbeInput( + numProbeBatches, + 100, + 1, + tableData, + pool_, + {"t0", "t1", "t2"}, + GetParam().hasNullKeys, + {}, + {}, + 100); + std::vector> probeFiles = + createProbeFiles(probeVectors); + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", {tableData.tableData}); + + std::atomic_int lookupCount{0}; + folly::EventCount asyncLookupWait; + std::atomic_bool asyncLookupWaitFlag{true}; SCOPED_TESTVALUE_SET( "facebook::velox::exec::test::TestIndexSource::ResultIterator::asyncLookup", std::function([&](void*) { @@ -1692,18 +2282,16 @@ DEBUG_ONLY_TEST_P(IndexLookupJoinTest, prefetch) { asyncLookupWait.await([&] { return !asyncLookupWaitFlag.load(); }); })); - const auto indexTable = createIndexTable( - /*numEqualJoinKeys=*/3, tableData.keyData, tableData.valueData); + const auto indexTable = TestIndexTable::create( + /*numEqualJoinKeys=*/3, tableData.keyData, tableData.valueData, *pool()); const auto indexTableHandle = makeIndexTableHandle(indexTable, GetParam().asyncLookup); auto planNodeIdGenerator = std::make_shared(); - std::unordered_map> - columnHandles; const auto indexScanNode = makeIndexScanNode( planNodeIdGenerator, indexTableHandle, makeScanOutputType({"u0", "u1", "u2", "u3", "u5"}), - columnHandles); + makeIndexColumnHandles({"u0", "u1", "u2", "u3", "u5"})); auto plan = makeLookupPlan( planNodeIdGenerator, @@ -1711,6 +2299,8 @@ DEBUG_ONLY_TEST_P(IndexLookupJoinTest, prefetch) { {"t0", "t1", "t2"}, {"u0", "u1", "u2"}, {}, + /*filter=*/"", + /*hasMarker=*/false, core::JoinType::kInner, {"u3", "t5"}); std::thread queryThread([&] { @@ -1733,7 +2323,7 @@ DEBUG_ONLY_TEST_P(IndexLookupJoinTest, prefetch) { queryThread.join(); } -TEST_P(IndexLookupJoinTest, outputBatchSize) { +TEST_P(IndexLookupJoinTest, outputBatchSizeWithInnerJoin) { SequenceTableData tableData; generateIndexTableData({3'000, 1, 1}, tableData, pool_); @@ -1741,27 +2331,37 @@ TEST_P(IndexLookupJoinTest, outputBatchSize) { int numProbeBatches; int numRowsPerProbeBatch; int maxBatchRows; + bool splitOutput; int numExpectedOutputBatch; std::string debugString() const { return fmt::format( - "numProbeBatches: {}, numRowsPerProbeBatch: {}, maxBatchRows: {}, numExpectedOutputBatch: {}", + "numProbeBatches: {}, numRowsPerProbeBatch: {}, maxBatchRows: {}, splitOutput: {}, numExpectedOutputBatch: {}", numProbeBatches, numRowsPerProbeBatch, maxBatchRows, + splitOutput, numExpectedOutputBatch); } } testSettings[] = { - {10, 100, 10, 100}, - {10, 500, 10, 500}, - {10, 1, 200, 10}, - {1, 500, 10, 50}, - {1, 300, 10, 30}, - {1, 500, 200, 3}, - {10, 200, 200, 10}, - {10, 500, 300, 20}, - {10, 50, 1, 500}}; - + {10, 100, 10, false, 10}, + {10, 500, 10, false, 10}, + {10, 1, 200, false, 10}, + {1, 500, 10, false, 1}, + {1, 300, 10, false, 1}, + {1, 500, 200, false, 1}, + {10, 200, 200, false, 10}, + {10, 500, 300, false, 10}, + {10, 50, 1, false, 10}, + {10, 100, 10, true, 100}, + {10, 500, 10, true, 500}, + {10, 1, 200, true, 10}, + {1, 500, 10, true, 50}, + {1, 300, 10, true, 30}, + {1, 500, 200, true, 3}, + {10, 200, 200, true, 10}, + {10, 500, 300, true, 20}, + {10, 50, 1, true, 500}}; for (const auto& testData : testSettings) { SCOPED_TRACE(testData.debugString()); @@ -1772,6 +2372,7 @@ TEST_P(IndexLookupJoinTest, outputBatchSize) { tableData, pool_, {"t0", "t1", "t2"}, + GetParam().hasNullKeys, {}, {}, /*equalMatchPct=*/100); @@ -1781,18 +2382,20 @@ TEST_P(IndexLookupJoinTest, outputBatchSize) { createDuckDbTable("t", probeVectors); createDuckDbTable("u", {tableData.tableData}); - const auto indexTable = createIndexTable( - /*numEqualJoinKeys=*/3, tableData.keyData, tableData.valueData); + const auto indexTable = TestIndexTable::create( + /*numEqualJoinKeys=*/3, + tableData.keyData, + tableData.valueData, + *pool()); const auto indexTableHandle = makeIndexTableHandle(indexTable, GetParam().asyncLookup); auto planNodeIdGenerator = std::make_shared(); - std::unordered_map> - columnHandles; + const auto indexScanNode = makeIndexScanNode( planNodeIdGenerator, indexTableHandle, makeScanOutputType({"u0", "u1", "u2", "u5"}), - columnHandles); + makeIndexColumnHandles({"u0", "u1", "u2", "u5"})); auto plan = makeLookupPlan( planNodeIdGenerator, @@ -1800,6 +2403,8 @@ TEST_P(IndexLookupJoinTest, outputBatchSize) { {"t0", "t1", "t2"}, {"u0", "u1", "u2"}, {}, + /*filter=*/"", + /*hasMarker=*/false, core::JoinType::kInner, {"t4", "u5"}); const auto task = @@ -1814,6 +2419,9 @@ TEST_P(IndexLookupJoinTest, outputBatchSize) { .config( core::QueryConfig::kPreferredOutputBatchBytes, std::to_string(1ULL << 30)) + .config( + core::QueryConfig::kIndexLookupJoinSplitOutput, + testData.splitOutput ? "true" : "false") .splits(probeScanNodeId_, makeHiveConnectorSplits(probeFiles)) .serialExecution(GetParam().serialExecution) .barrierExecution(GetParam().serialExecution) @@ -1825,6 +2433,128 @@ TEST_P(IndexLookupJoinTest, outputBatchSize) { } } +TEST_P(IndexLookupJoinTest, outputBatchSizeWithLeftJoin) { + SequenceTableData tableData; + generateIndexTableData({3'000, 1, 1}, tableData, pool_); + + struct { + int numProbeBatches; + int numRowsPerProbeBatch; + int maxBatchRows; + bool splitOutput; + int numExpectedOutputBatch; + + std::string debugString() const { + return fmt::format( + "numProbeBatches: {}, numRowsPerProbeBatch: {}, maxBatchRows: {}, splitOutput: {}, numExpectedOutputBatch: {}", + numProbeBatches, + numRowsPerProbeBatch, + maxBatchRows, + splitOutput, + numExpectedOutputBatch); + } + } testSettings[] = { + {10, 100, 10, false, 10}, + {10, 500, 10, false, 10}, + {10, 1, 200, false, 10}, + {1, 500, 10, false, 1}, + {1, 300, 10, false, 1}, + {1, 500, 200, false, 1}, + {10, 200, 200, false, 10}, + {10, 500, 300, false, 10}, + {10, 50, 1, false, 10}, + {10, 100, 10, true, 100}, + {10, 500, 10, true, 500}, + {10, 1, 200, true, 10}, + {1, 500, 10, true, 50}, + {1, 300, 10, true, 30}, + {1, 500, 200, true, 3}, + {10, 200, 200, true, 10}, + {10, 500, 300, true, 20}, + {10, 50, 1, true, 500}}; + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + + const auto probeVectors = generateProbeInput( + testData.numProbeBatches, + testData.numRowsPerProbeBatch, + 1, + tableData, + pool_, + {"t0", "t1", "t2"}, + GetParam().hasNullKeys, + {}, + {}, + /*equalMatchPct=*/100); + std::vector> probeFiles = + createProbeFiles(probeVectors); + + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", {tableData.tableData}); + + const auto indexTable = TestIndexTable::create( + /*numEqualJoinKeys=*/3, + tableData.keyData, + tableData.valueData, + *pool()); + const auto indexTableHandle = + makeIndexTableHandle(indexTable, GetParam().asyncLookup); + auto planNodeIdGenerator = std::make_shared(); + const auto indexScanNode = makeIndexScanNode( + planNodeIdGenerator, + indexTableHandle, + makeScanOutputType({"u0", "u1", "u2", "u5"}), + makeIndexColumnHandles({"u0", "u1", "u2", "u5"})); + + auto plan = makeLookupPlan( + planNodeIdGenerator, + indexScanNode, + {"t0", "t1", "t2"}, + {"u0", "u1", "u2"}, + {}, + /*filter=*/"", + /*hasMarker=*/false, + core::JoinType::kLeft, + {"t4", "u5"}); + const auto task = + AssertQueryBuilder(duckDbQueryRunner_) + .plan(plan) + .config( + core::QueryConfig::kIndexLookupJoinMaxPrefetchBatches, + std::to_string(GetParam().numPrefetches)) + .config( + core::QueryConfig::kPreferredOutputBatchRows, + std::to_string(testData.maxBatchRows)) + .config( + core::QueryConfig::kPreferredOutputBatchBytes, + std::to_string(1ULL << 30)) + .config( + core::QueryConfig::kIndexLookupJoinSplitOutput, + testData.splitOutput ? "true" : "false") + .splits(probeScanNodeId_, makeHiveConnectorSplits(probeFiles)) + .serialExecution(GetParam().serialExecution) + .barrierExecution(GetParam().serialExecution) + .assertResults( + "SELECT t.c4, u.c5 FROM t LEFT JOIN u ON t.c0 = u.c0 AND t.c1 = u.c1 AND t.c2 = u.c2"); + ASSERT_EQ( + toPlanStats(task->taskStats()).at(joinNodeId_).outputVectors, + testData.numExpectedOutputBatch); + const auto probeScanId = probeScanNodeId_; + auto planWithMatchColumn = makeLookupPlan( + planNodeIdGenerator, + indexScanNode, + {"t0", "t1", "t2"}, + {"u0", "u1", "u2"}, + {}, + /*filter=*/"", + /*hasMarker=*/true, + core::JoinType::kLeft, + {"t4", "u5"}); + verifyResultWithMatchColumn( + plan, probeScanId, planWithMatchColumn, probeScanNodeId_, probeFiles); + } +} + DEBUG_ONLY_TEST_P(IndexLookupJoinTest, runtimeStats) { SequenceTableData tableData; generateIndexTableData({100, 1, 1}, tableData, pool_); @@ -1836,6 +2566,7 @@ DEBUG_ONLY_TEST_P(IndexLookupJoinTest, runtimeStats) { tableData, pool_, {"t0", "t1", "t2"}, + GetParam().hasNullKeys, {}, {}, 100); @@ -1850,18 +2581,16 @@ DEBUG_ONLY_TEST_P(IndexLookupJoinTest, runtimeStats) { std::this_thread::sleep_for(std::chrono::milliseconds(100)); // NOLINT })); - const auto indexTable = createIndexTable( - /*numEqualJoinKeys=*/3, tableData.keyData, tableData.valueData); + const auto indexTable = TestIndexTable::create( + /*numEqualJoinKeys=*/3, tableData.keyData, tableData.valueData, *pool()); const auto indexTableHandle = makeIndexTableHandle(indexTable, GetParam().asyncLookup); auto planNodeIdGenerator = std::make_shared(); - std::unordered_map> - columnHandles; const auto indexScanNode = makeIndexScanNode( planNodeIdGenerator, indexTableHandle, makeScanOutputType({"u0", "u1", "u2", "u3", "u5"}), - columnHandles); + makeIndexColumnHandles({"u0", "u1", "u2", "u3", "u5"})); auto plan = makeLookupPlan( planNodeIdGenerator, @@ -1869,6 +2598,8 @@ DEBUG_ONLY_TEST_P(IndexLookupJoinTest, runtimeStats) { {"t0", "t1", "t2"}, {"u0", "u1", "u2"}, {}, + /*filter=*/"", + /*hasMarker=*/false, core::JoinType::kInner, {"u3", "t5"}); auto task = runLookupQuery( @@ -1915,40 +2646,259 @@ DEBUG_ONLY_TEST_P(IndexLookupJoinTest, runtimeStats) { ".*Runtime stats.*connectorResultPrepareCpuNanos.*")); } -TEST_P(IndexLookupJoinTest, barrier) { - if (!GetParam().serialExecution) { - GTEST_SKIP(); - } +TEST_P(IndexLookupJoinTest, barrier) { + SequenceTableData tableData; + generateIndexTableData({100, 1, 1}, tableData, pool_); + const int numProbeSplits{5}; + const auto probeVectors = generateProbeInput( + numProbeSplits, + 256, + 1, + tableData, + pool_, + {"t0", "t1", "t2"}, + GetParam().hasNullKeys, + {}, + {}, + 100); + std::vector> probeFiles = + createProbeFiles(probeVectors); + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", {tableData.tableData}); + + const auto indexTable = TestIndexTable::create( + /*numEqualJoinKeys=*/3, tableData.keyData, tableData.valueData, *pool()); + const auto indexTableHandle = + makeIndexTableHandle(indexTable, GetParam().asyncLookup); + auto planNodeIdGenerator = std::make_shared(); + const auto indexScanNode = makeIndexScanNode( + planNodeIdGenerator, + indexTableHandle, + makeScanOutputType({"u0", "u1", "u2", "u3", "u5"}), + makeIndexColumnHandles({"u0", "u1", "u2", "u3", "u5"})); + + auto plan = makeLookupPlan( + planNodeIdGenerator, + indexScanNode, + {"t0", "t1", "t2"}, + {"u0", "u1", "u2"}, + {}, + /*filter=*/"", + /*hasMarker=*/false, + core::JoinType::kInner, + {"u3", "t5"}); + + struct { + int numPrefetches; + bool barrierExecution; + + std::string debugString() const { + return fmt::format( + "numPrefetches {}, barrierExecution {}", + numPrefetches, + barrierExecution); + } + } testSettings[] = { + {0, true}, + {0, false}, + {1, true}, + {1, false}, + {4, true}, + {4, false}, + {256, true}, + {256, false}}; + + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + auto task = runLookupQuery( + plan, + probeFiles, + true, + testData.barrierExecution, + 32, + testData.numPrefetches, + "SELECT u.c3, t.c5 FROM t, u WHERE t.c0 = u.c0 AND t.c1 = u.c1 AND t.c2 = u.c2"); + + const auto taskStats = task->taskStats(); + ASSERT_EQ( + taskStats.numBarriers, testData.barrierExecution ? numProbeSplits : 0); + ASSERT_EQ(taskStats.numFinishedSplits, numProbeSplits); + } +} + +TEST_P(IndexLookupJoinTest, nullKeys) { + SequenceTableData tableData; + generateIndexTableData({100, 1, 1}, tableData, pool_); + const int numProbeSplits{5}; + const int probeBatchSize{256}; + const auto probeVectors = generateProbeInput( + numProbeSplits, + probeBatchSize, + 1, + tableData, + pool_, + {"t0", "t1", "t2"}, + /*hasNullKeys=*/true, + {}, + {}, + /*equalMatchPct=*/100); + // Set some probe key vector to all nulls to trigger the case that entire + // probe input is skipped. + for (int i = 0; i < numProbeSplits; i += 2) { + for (int row = 0; row < probeBatchSize; ++row) { + probeVectors[i]->childAt(i % keyType_->size())->setNull(row, true); + } + } + std::vector> probeFiles = + createProbeFiles(probeVectors); + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", {tableData.tableData}); + + const auto indexTable = TestIndexTable::create( + /*numEqualJoinKeys=*/3, tableData.keyData, tableData.valueData, *pool()); + const auto indexTableHandle = + makeIndexTableHandle(indexTable, GetParam().asyncLookup); + auto planNodeIdGenerator = std::make_shared(); + std::unordered_map> + columnHandles; + const auto indexScanNode = makeIndexScanNode( + planNodeIdGenerator, + indexTableHandle, + makeScanOutputType({"u0", "u1", "u2", "u3", "u5"}), + makeIndexColumnHandles({"u0", "u1", "u2", "u3", "u5"})); + + const auto innerPlan = makeLookupPlan( + planNodeIdGenerator, + indexScanNode, + {"t0", "t1", "t2"}, + {"u0", "u1", "u2"}, + {}, + /*filter=*/"", + /*hasMarker=*/false, + core::JoinType::kInner, + {"u3", "t5"}); + + runLookupQuery( + innerPlan, + probeFiles, + true, + true, + 32, + GetParam().numPrefetches, + "SELECT u.c3, t.c5 FROM t, u WHERE t.c0 = u.c0 AND t.c1 = u.c1 AND t.c2 = u.c2"); + + const auto leftPlan = makeLookupPlan( + planNodeIdGenerator, + indexScanNode, + {"t0", "t1", "t2"}, + {"u0", "u1", "u2"}, + {}, + /*filter=*/"", + /*hasMarker=*/false, + core::JoinType::kLeft, + {"u3", "t5"}); + + runLookupQuery( + leftPlan, + probeFiles, + true, + true, + 32, + GetParam().numPrefetches, + "SELECT u.c3, t.c5 FROM t LEFT JOIN u ON t.c0 = u.c0 AND t.c1 = u.c1 AND t.c2 = u.c2"); + + const auto probeScanId = probeScanNodeId_; + auto planWithMatchColumn = makeLookupPlan( + planNodeIdGenerator, + indexScanNode, + {"t0", "t1", "t2"}, + {"u0", "u1", "u2"}, + {}, + /*filter=*/"", + /*hasMarker=*/true, + core::JoinType::kLeft, + {"u3", "t5"}); + verifyResultWithMatchColumn( + leftPlan, probeScanId, planWithMatchColumn, probeScanNodeId_, probeFiles); +} + +TEST_P(IndexLookupJoinTest, joinFuzzer) { + SequenceTableData tableData; + generateIndexTableData({1024, 1, 1}, tableData, pool_); + const auto probeVectors = generateProbeInput( + 50, 256, 1, tableData, pool_, {"t0", "t1", "t2"}, GetParam().hasNullKeys); + std::vector> probeFiles = + createProbeFiles(probeVectors); + + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", {tableData.tableData}); + + const auto indexTable = TestIndexTable::create( + /*numEqualJoinKeys=*/1, tableData.keyData, tableData.valueData, *pool()); + const auto indexTableHandle = + makeIndexTableHandle(indexTable, GetParam().asyncLookup); + auto planNodeIdGenerator = std::make_shared(); + auto scanOutput = tableType_->names(); + std::random_device rd; + std::mt19937 g(rd()); + std::shuffle(scanOutput.begin(), scanOutput.end(), g); + const auto indexScanNode = makeIndexScanNode( + planNodeIdGenerator, + indexTableHandle, + makeScanOutputType(scanOutput), + makeIndexColumnHandles(scanOutput)); + + auto plan = makeLookupPlan( + planNodeIdGenerator, + indexScanNode, + {"t0"}, + {"u0"}, + {"contains(t4, u1)", "u2 between t1 and t2"}, + /*filter=*/"", + /*hasMarker=*/false, + core::JoinType::kInner, + {"u0", "u4", "t0", "t1", "t4"}); + runLookupQuery( + plan, + probeFiles, + GetParam().serialExecution, + GetParam().serialExecution, + 32, + GetParam().numPrefetches, + "SELECT u.c0, u.c1, u.c2, u.c3, u.c4, u.c5, t.c0, t.c1, t.c2, t.c3, t.c4, t.c5 FROM t, u WHERE t.c0 = u.c0 AND array_contains(t.c4, u.c1) AND u.c2 BETWEEN t.c1 AND t.c2"); +} + +TEST_P(IndexLookupJoinTest, tableRowsWithDuplicateKeys) { SequenceTableData tableData; - generateIndexTableData({100, 1, 1}, tableData, pool_); - const int numProbeSplits{5}; - const auto probeVectors = generateProbeInput( - numProbeSplits, - 256, - 1, - tableData, - pool_, - {"t0", "t1", "t2"}, - {}, - {}, - 100); + generateIndexTableData({10, 1, 1}, tableData, pool_); + for (int i = 0; i < keyType_->size(); ++i) { + tableData.keyData->childAt(i) = makeFlatVector( + tableData.keyData->childAt(i)->size(), + [](auto /*unused*/) { return 1; }); + tableData.tableData->childAt(i) = makeFlatVector( + tableData.keyData->childAt(i)->size(), + [](auto /*unused*/) { return 1; }); + } + + auto probeVectors = generateProbeInput( + 4, 32, 1, tableData, pool_, {"t0", "t1", "t2"}, false, {}, {}, 100); std::vector> probeFiles = createProbeFiles(probeVectors); + createDuckDbTable("t", probeVectors); createDuckDbTable("u", {tableData.tableData}); - const auto indexTable = createIndexTable( - /*numEqualJoinKeys=*/3, tableData.keyData, tableData.valueData); + const auto indexTable = TestIndexTable::create( + /*numEqualJoinKeys=*/3, tableData.keyData, tableData.valueData, *pool()); const auto indexTableHandle = makeIndexTableHandle(indexTable, GetParam().asyncLookup); auto planNodeIdGenerator = std::make_shared(); - std::unordered_map> - columnHandles; + auto scanOutput = tableType_->names(); const auto indexScanNode = makeIndexScanNode( planNodeIdGenerator, indexTableHandle, - makeScanOutputType({"u0", "u1", "u2", "u3", "u5"}), - columnHandles); + makeScanOutputType(scanOutput), + makeIndexColumnHandles(scanOutput)); auto plan = makeLookupPlan( planNodeIdGenerator, @@ -1956,91 +2906,391 @@ TEST_P(IndexLookupJoinTest, barrier) { {"t0", "t1", "t2"}, {"u0", "u1", "u2"}, {}, + /*filter=*/"", + /*hasMarker=*/false, core::JoinType::kInner, - {"u3", "t5"}); + scanOutput); + runLookupQuery( + plan, + probeFiles, + GetParam().serialExecution, + GetParam().serialExecution, + 32, + GetParam().numPrefetches, + "SELECT u.c0, u.c1, u.c2, u.c3, u.c4, u.c5 FROM t, u WHERE t.c0 = u.c0 AND t.c1 = u.c1 AND u.c2 = t.c2"); +} +TEST_P(IndexLookupJoinTest, withFilter) { struct { - int numPrefetches; - bool barrierExecution; + std::vector keyCardinalities; + int numProbeBatches; + int numRowsPerProbeBatch; + int matchPct; + std::vector scanOutputColumns; + std::vector outputColumns; + core::JoinType joinType; + std::string filter; + std::string duckDbVerifySql; std::string debugString() const { return fmt::format( - "numPrefetches {}, barrierExecution {}", - numPrefetches, - barrierExecution); + "keyCardinalities: {}, numProbeBatches: {}, numRowsPerProbeBatch: {}, matchPct: {}, " + "scanOutputColumns: {}, outputColumns: {}, joinType: {}, filter: {}, " + "duckDbVerifySql: {}", + folly::join(",", keyCardinalities), + numProbeBatches, + numRowsPerProbeBatch, + matchPct, + folly::join(",", scanOutputColumns), + folly::join(",", outputColumns), + core::JoinTypeName::toName(joinType), + filter, + duckDbVerifySql); } } testSettings[] = { - {0, true}, - {0, false}, - {1, true}, - {1, false}, - {4, true}, - {4, false}, - {256, true}, - {256, false}}; + // Inner join with filter on probe side + {{100, 1, 1}, + 5, + 100, + 80, + {"u0", "u1", "u2", "u3", "u5"}, + {"t1", "u1", "u2", "u3", "u5"}, + core::JoinType::kInner, + "t3 % 2 = 0", + "SELECT t.c1, u.c1, u.c2, u.c3, u.c5 FROM t, u WHERE t.c0 = u.c0 AND t.c3 % 2 = 0"}, + // Inner join with filter always be true. + {{100, 1, 1}, + 5, + 100, + 80, + {"u0", "u1", "u2", "u3", "u5"}, + {"t1", "u1", "u2", "u3", "u5"}, + core::JoinType::kInner, + "t3 = t3", + "SELECT t.c1, u.c1, u.c2, u.c3, u.c5 FROM t, u WHERE t.c0 = u.c0 AND t.c3 = t.c3"}, + // Inner join with filter always be false. + {{100, 1, 1}, + 5, + 100, + 80, + {"u0", "u1", "u2", "u3", "u5"}, + {"t1", "u1", "u2", "u3", "u5"}, + core::JoinType::kInner, + "t3 != t3", + "SELECT t.c1, u.c1, u.c2, u.c3, u.c5 FROM t, u WHERE t.c0 = u.c0 AND t.c3 != t.c3"}, + + // Inner join with filter on lookup side + {{100, 1, 1}, + 5, + 100, + 80, + {"u0", "u1", "u2", "u3", "u5"}, + {"t1", "u1", "u2", "u3", "u5"}, + core::JoinType::kInner, + "u3 % 2 = 0", + "SELECT t.c1, u.c1, u.c2, u.c3, u.c5 FROM t, u WHERE t.c0 = u.c0 AND u.c3 % 2 = 0"}, + // Inner join with filter always be true. + {{100, 1, 1}, + 5, + 100, + 80, + {"u0", "u1", "u2", "u3", "u5"}, + {"t1", "u1", "u2", "u3", "u5"}, + core::JoinType::kInner, + "u3 = u3", + "SELECT t.c1, u.c1, u.c2, u.c3, u.c5 FROM t, u WHERE t.c0 = u.c0 AND u.c3 = u.c3"}, + // Inner join with filter always be false. + {{100, 1, 1}, + 5, + 100, + 80, + {"u0", "u1", "u2", "u3", "u5"}, + {"t1", "u1", "u2", "u3", "u5"}, + core::JoinType::kInner, + "u3 != u3", + "SELECT t.c1, u.c1, u.c2, u.c3, u.c5 FROM t, u WHERE t.c0 = u.c0 AND u.c3 != u.c3"}, + + // Inner join with filter on both side + {{100, 1, 1}, + 5, + 100, + 80, + {"u0", "u1", "u2", "u3", "u5"}, + {"t1", "u1", "u2", "u3", "u5"}, + core::JoinType::kInner, + "u3 % 2 = 0 AND t3 % 2 = 0", + "SELECT t.c1, u.c1, u.c2, u.c3, u.c5 FROM t, u WHERE t.c0 = u.c0 AND u.c3 % 2 = 0 AND t.c3 % 2 = 0"}, + // Inner join with filter always be true. + {{100, 1, 1}, + 5, + 100, + 80, + {"u0", "u1", "u2", "u3", "u5"}, + {"t1", "u1", "u2", "u3", "u5"}, + core::JoinType::kInner, + "u3 = u3 AND t3 = t3", + "SELECT t.c1, u.c1, u.c2, u.c3, u.c5 FROM t, u WHERE t.c0 = u.c0 AND u.c3 = u.c3 AND t.c3 = t.c3"}, + // Inner join with filter always be false. + {{100, 1, 1}, + 5, + 100, + 80, + {"u0", "u1", "u2", "u3", "u5"}, + {"t1", "u1", "u2", "u3", "u5"}, + core::JoinType::kInner, + "u3 != u3 AND t3 != t3", + "SELECT t.c1, u.c1, u.c2, u.c3, u.c5 FROM t, u WHERE t.c0 = u.c0 AND u.c3 != u.c3 AND t.c3 != t.c3"}, + + // Left join with filter on probe side + {{100, 1, 1}, + 5, + 100, + 80, + {"u0", "u1", "u2", "u3", "u5"}, + {"t1", "u1", "u2", "u3", "u5"}, + core::JoinType::kLeft, + "t3 % 2 = 0", + "SELECT t.c1, u.c1, u.c2, u.c3, u.c5 FROM t LEFT JOIN u ON t.c0 = u.c0 AND t.c3 % 2 = 0"}, + // Left join with filter always be true. + {{100, 1, 1}, + 5, + 100, + 80, + {"u0", "u1", "u2", "u3", "u5"}, + {"t1", "u1", "u2", "u3", "u5"}, + core::JoinType::kLeft, + "t3 = t3", + "SELECT t.c1, u.c1, u.c2, u.c3, u.c5 FROM t LEFT JOIN u ON t.c0 = u.c0 AND t.c3 = t.c3"}, + // Inner join with filter always be false. + {{100, 1, 1}, + 5, + 100, + 80, + {"u0", "u1", "u2", "u3", "u5"}, + {"t1", "u1", "u2", "u3", "u5"}, + core::JoinType::kLeft, + "t3 != t3", + "SELECT t.c1, u.c1, u.c2, u.c3, u.c5 FROM t LEFT JOIN u ON t.c0 = u.c0 AND t.c3 != t.c3"}, + + // Left join with filter on lookup side + {{100, 1, 1}, + 5, + 100, + 80, + {"u0", "u1", "u2", "u3", "u5"}, + {"t1", "u1", "u2", "u3", "u5"}, + core::JoinType::kLeft, + "u3 % 2 = 0", + "SELECT t.c1, u.c1, u.c2, u.c3, u.c5 FROM t LEFT JOIN u ON t.c0 = u.c0 AND u.c3 % 2 = 0"}, + // Inner join with filter always be true. + {{100, 1, 1}, + 5, + 100, + 80, + {"u0", "u1", "u2", "u3", "u5"}, + {"t1", "u1", "u2", "u3", "u5"}, + core::JoinType::kLeft, + "u3 = u3", + "SELECT t.c1, u.c1, u.c2, u.c3, u.c5 FROM t LEFT JOIN u ON t.c0 = u.c0 AND u.c3 = u.c3"}, + // Left join with filter always be false. + {{100, 1, 1}, + 5, + 100, + 80, + {"u0", "u1", "u2", "u3", "u5"}, + {"t1", "u1", "u2", "u3", "u5"}, + core::JoinType::kLeft, + "u3 != u3", + "SELECT t.c1, u.c1, u.c2, u.c3, u.c5 FROM t LEFT JOIN u ON t.c0 = u.c0 AND u.c3 != u.c3"}, + + // Left join with filter on both side + {{100, 1, 1}, + 5, + 100, + 80, + {"u0", "u1", "u2", "u3", "u5"}, + {"t1", "u1", "u2", "u3", "u5"}, + core::JoinType::kLeft, + "u3 % 2 = 0 AND t3 % 2 = 0", + "SELECT t.c1, u.c1, u.c2, u.c3, u.c5 FROM t LEFT JOIN u ON t.c0 = u.c0 AND u.c3 % 2 = 0 AND t.c3 % 2 = 0"}, + // Left join with filter always be true. + {{100, 1, 1}, + 5, + 100, + 80, + {"u0", "u1", "u2", "u3", "u5"}, + {"t1", "u1", "u2", "u3", "u5"}, + core::JoinType::kLeft, + "u3 = u3 AND t3 = t3", + "SELECT t.c1, u.c1, u.c2, u.c3, u.c5 FROM t LEFT JOIN u ON t.c0 = u.c0 AND u.c3 = u.c3 AND t.c3 = t.c3"}, + // Left join with filter always be false. + {{100, 1, 1}, + 5, + 100, + 80, + {"u0", "u1", "u2", "u3", "u5"}, + {"t1", "u1", "u2", "u3", "u5"}, + core::JoinType::kLeft, + "u3 != u3 AND t3 != t3", + "SELECT t.c1, u.c1, u.c2, u.c3, u.c5 FROM t LEFT JOIN u ON t.c0 = u.c0 AND u.c3 != u.c3 AND t.c3 != t.c3"}}; for (const auto& testData : testSettings) { SCOPED_TRACE(testData.debugString()); - auto task = runLookupQuery( + + SequenceTableData tableData; + generateIndexTableData(testData.keyCardinalities, tableData, pool_); + auto probeVectors = generateProbeInput( + testData.numProbeBatches, + testData.numRowsPerProbeBatch, + 1, + tableData, + pool_, + {"t0", "t1", "t2"}, + GetParam().hasNullKeys, + {}, + {}, + testData.matchPct); + std::vector> probeFiles = + createProbeFiles(probeVectors); + + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", {tableData.tableData}); + + const auto indexTable = TestIndexTable::create( + /*numEqualJoinKeys=*/3, + tableData.keyData, + tableData.valueData, + *pool()); + const auto indexTableHandle = + makeIndexTableHandle(indexTable, GetParam().asyncLookup); + auto planNodeIdGenerator = std::make_shared(); + const auto indexScanNode = makeIndexScanNode( + planNodeIdGenerator, + indexTableHandle, + makeScanOutputType(testData.scanOutputColumns), + makeIndexColumnHandles(testData.scanOutputColumns)); + + // Create a plan with filter + auto plan = makeLookupPlan( + planNodeIdGenerator, + indexScanNode, + {"t0", "t1", "t2"}, + {"u0", "u1", "u2"}, + {}, + testData.filter, + /*hasMarker=*/false, + testData.joinType, + testData.outputColumns); + + runLookupQuery( plan, probeFiles, - true, - testData.barrierExecution, + GetParam().serialExecution, + GetParam().serialExecution, 32, - testData.numPrefetches, - "SELECT u.c3, t.c5 FROM t, u WHERE t.c0 = u.c0 AND t.c1 = u.c1 AND t.c2 = u.c2"); + GetParam().numPrefetches, + testData.duckDbVerifySql); - const auto taskStats = task->taskStats(); - ASSERT_EQ( - taskStats.numBarriers, testData.barrierExecution ? numProbeSplits : 0); - ASSERT_EQ(taskStats.numFinishedSplits, numProbeSplits); + if (testData.joinType != core::JoinType::kLeft) { + continue; + } + const auto probeScanId = probeScanNodeId_; + auto planWithMatchColumn = makeLookupPlan( + planNodeIdGenerator, + indexScanNode, + {"t0", "t1", "t2"}, + {"u0", "u1", "u2"}, + {}, + testData.filter, + /*hasMarker=*/true, + testData.joinType, + testData.outputColumns); + verifyResultWithMatchColumn( + plan, probeScanId, planWithMatchColumn, probeScanNodeId_, probeFiles); } } -TEST_P(IndexLookupJoinTest, joinFuzzer) { +TEST_P(IndexLookupJoinTest, mixedFilterBatches) { + // Create SequenceTableData using VectorTestBase utilities SequenceTableData tableData; - generateIndexTableData({1024, 1, 1}, tableData, pool_); - const auto probeVectors = - generateProbeInput(50, 256, 1, tableData, pool_, {"t0", "t1", "t2"}); + + const std::string dummyString("test"); + StringView dummyStringView(dummyString); + // Create table key data (u0, u1, u2) using makeFlatVector + auto u0 = makeFlatVector(64, [&](auto row) { return row % 8; }); + auto u1 = makeFlatVector(64, [&](auto row) { return row % 8; }); + auto u2 = makeFlatVector(64, [&](auto row) { return row % 8; }); + tableData.keyData = makeRowVector({"u0", "u1", "u2"}, {u0, u1, u2}); + + // Create table value data (u3, u4, u5) using makeFlatVector + auto u3 = makeFlatVector(64, [&](auto row) { return row; }); + auto u4 = makeFlatVector(64, [&](auto row) { return row; }); + auto u5 = makeFlatVector( + 64, [&](auto /*unused*/) { return dummyStringView; }); + tableData.valueData = makeRowVector({"u3", "u4", "u5"}, {u3, u4, u5}); + + // Create complete table data by combining key and value data + tableData.tableData = makeRowVector( + {"u0", "u1", "u2", "u3", "u4", "u5"}, {u0, u1, u2, u3, u4, u5}); + + // Create probe vectors using makeArrayVectorFromJson in a loop + std::vector probeVectors; + probeVectors.reserve(5); + for (int i = 0; i < 5; ++i) { + probeVectors.push_back(makeRowVector( + {"t0", "t1", "t2", "t3", "t4", "t5"}, + {makeFlatVector(128, [&](auto row) { return row; }), + makeFlatVector(128, [&](auto row) { return row; }), + makeFlatVector(128, [&](auto row) { return row; }), + makeFlatVector(128, [&](auto row) { return row; }), + makeArrayVector( + 128, + [](vector_size_t /*unused*/) { return 1; }, + [](vector_size_t, vector_size_t) { return 1; }), + makeFlatVector( + 128, [&](auto /*unused*/) { return dummyStringView; })})); + } + std::vector> probeFiles = createProbeFiles(probeVectors); createDuckDbTable("t", probeVectors); createDuckDbTable("u", {tableData.tableData}); - const auto indexTable = createIndexTable( - /*numEqualJoinKeys=*/1, tableData.keyData, tableData.valueData); + const auto indexTable = TestIndexTable::create( + /*numEqualJoinKeys=*/3, tableData.keyData, tableData.valueData, *pool()); const auto indexTableHandle = makeIndexTableHandle(indexTable, GetParam().asyncLookup); auto planNodeIdGenerator = std::make_shared(); - auto scanOutput = tableType_->names(); - std::random_device rd; - std::mt19937 g(rd()); - std::shuffle(scanOutput.begin(), scanOutput.end(), g); - std::unordered_map> - columnHandles; const auto indexScanNode = makeIndexScanNode( planNodeIdGenerator, indexTableHandle, - makeScanOutputType(scanOutput), - columnHandles); + makeScanOutputType({"u0", "u1", "u2", "u3", "u5"}), + makeIndexColumnHandles({"u0", "u1", "u2", "u3", "u5"})); auto plan = makeLookupPlan( planNodeIdGenerator, indexScanNode, - {"t0"}, - {"u0"}, - {"contains(t4, u1)", "u2 between t1 and t2"}, - core::JoinType::kInner, - {"u0", "u4", "t0", "t1", "t4"}); - runLookupQuery( - plan, - probeFiles, - GetParam().serialExecution, - GetParam().serialExecution, - 32, - GetParam().numPrefetches, - "SELECT u.c0, u.c1, u.c2, u.c3, u.c4, u.c5, t.c0, t.c1, t.c2, t.c3, t.c4, t.c5 FROM t, u WHERE t.c0 = u.c0 AND array_contains(t.c4, u.c1) AND u.c2 BETWEEN t.c1 AND t.c2"); + {"t0", "t1", "t2"}, + {"u0", "u1", "u2"}, + {}, + "t3 > 4", + /*hasMarker=*/false, + core::JoinType::kLeft, + {"t1", "u1", "u2", "u3", "u5"}); + + AssertQueryBuilder(duckDbQueryRunner_) + .plan(plan) + .config( + core::QueryConfig::kIndexLookupJoinMaxPrefetchBatches, + std::to_string(GetParam().numPrefetches)) + .config(core::QueryConfig::kPreferredOutputBatchRows, "4") + .config(core::QueryConfig::kIndexLookupJoinSplitOutput, "true") + .splits(probeScanNodeId_, makeHiveConnectorSplits(probeFiles)) + .serialExecution(GetParam().serialExecution) + .barrierExecution(GetParam().serialExecution) + .assertResults( + "SELECT t.c1, u.c1, u.c2, u.c3, u.c5 FROM t LEFT JOIN u ON t.c0 = u.c0 AND t.c1 = u.c1 AND t.c2 = u.c2 AND t.c3 > 4"); } } // namespace @@ -2050,9 +3300,10 @@ VELOX_INSTANTIATE_TEST_SUITE_P( testing::ValuesIn(IndexLookupJoinTest::getTestParams()), [](const testing::TestParamInfo& info) { return fmt::format( - "{}_{}prefetches_{}", + "{}_{}prefetches_{}_{}", info.param.asyncLookup ? "async" : "sync", info.param.numPrefetches, - info.param.serialExecution ? "serial" : "parallel"); + info.param.serialExecution ? "serial" : "parallel", + info.param.hasNullKeys ? "nullKeys" : "noNullKeys"); }); -} // namespace fecebook::velox::exec::test +} // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/LimitTest.cpp b/velox/exec/tests/LimitTest.cpp index d1be2199c006..c00dbc0a32ca 100644 --- a/velox/exec/tests/LimitTest.cpp +++ b/velox/exec/tests/LimitTest.cpp @@ -14,6 +14,8 @@ * limitations under the License. */ #include "velox/exec/OutputBufferManager.h" +#include "velox/exec/PlanNodeStats.h" +#include "velox/exec/tests/utils/AssertQueryBuilder.h" #include "velox/exec/tests/utils/HiveConnectorTestBase.h" #include "velox/exec/tests/utils/PlanBuilder.h" @@ -21,6 +23,7 @@ using namespace facebook::velox; using namespace facebook::velox::exec; using namespace facebook::velox::exec::test; +namespace { class LimitTest : public HiveConnectorTestBase {}; TEST_F(LimitTest, basic) { @@ -142,3 +145,190 @@ TEST_F(LimitTest, partialLimitEagerFlush) { test(true); test(false); } + +TEST_F(LimitTest, barrier) { + std::vector vectors; + std::vector> tempFiles; + const int numSplits{5}; + const int numRowsPerSplit{1'000}; + for (int32_t i = 0; i < numSplits; ++i) { + vectors.push_back(makeRowVector({makeFlatVector( + numRowsPerSplit, [](auto row) { return row; })})); + tempFiles.push_back(TempFilePath::create()); + } + writeToFiles(toFilePaths(tempFiles), vectors); + createDuckDbTable(vectors); + + struct { + bool barrierExecution; + int offset; + int limit; + int numExpectedBarriers; + int numExpectedOutputRows; + int numExpectedFinishedSplits; + int numExpectedOutputBatches; + + std::string toString() const { + return fmt::format( + "barrierExecution {}, offset: {}, limit: {}, numExpectedBarriers: {}, numExpectedOutputRows: {}, numExpectedFinishedSplits: {}, numExpectedOutputBatches: {}", + barrierExecution, + offset, + limit, + numExpectedBarriers, + numExpectedOutputRows, + numExpectedFinishedSplits, + numExpectedOutputBatches); + } + } testSettings[] = {// Test the case where the limit covers all the input rows + // with barrier and not. + {true, + 0, + numRowsPerSplit * numSplits, + numSplits, + numRowsPerSplit * numSplits, + numSplits - 1, + numSplits}, + {false, + 0, + numRowsPerSplit * numSplits, + 0, + numRowsPerSplit * numSplits, + numSplits - 1, + numSplits}, + // Test the cases where the limit covers the first entire + // split rows with barrier or not. with barrier and not. + {true, 0, numRowsPerSplit, 1, numRowsPerSplit, 0, 1}, + {false, 0, numRowsPerSplit, 0, numRowsPerSplit, 0, 1}, + // Test the case where the limit covers the one and half + // split rows with barrier or not. + {true, + 0, + numRowsPerSplit + numRowsPerSplit / 2, + 2, + numRowsPerSplit + numRowsPerSplit / 2, + 1, + 2}, + {false, + 0, + numRowsPerSplit + numRowsPerSplit / 2, + 0, + numRowsPerSplit + numRowsPerSplit / 2, + 1, + 2}, + {true, + 0, + numRowsPerSplit + numRowsPerSplit - 1, + 2, + numRowsPerSplit + numRowsPerSplit - 1, + 1, + 2}, + {false, + 0, + numRowsPerSplit + numRowsPerSplit - 1, + 0, + numRowsPerSplit + numRowsPerSplit - 1, + 1, + 2}, + // Test the case where the limit set to cover more than + // all the input rows with barrier or not. + {true, + 0, + numRowsPerSplit * (numSplits + 1), + numSplits, + numRowsPerSplit * numSplits, + numSplits, + numSplits}, + {false, + 0, + numRowsPerSplit * (numSplits + 1), + 0, + numRowsPerSplit * numSplits, + numSplits, + numSplits}, + // Test the cases where the limit set to cover partial + // input rows in the middle with barrier or not. + {true, + numRowsPerSplit, + numRowsPerSplit + numRowsPerSplit - 1, + 3, + numRowsPerSplit + numRowsPerSplit - 1, + 2, + 2}, + {false, + numRowsPerSplit, + numRowsPerSplit + numRowsPerSplit - 1, + 0, + numRowsPerSplit + numRowsPerSplit - 1, + 2, + 2}, + {true, + numRowsPerSplit, + numRowsPerSplit * (numSplits - 1), + numSplits, + numRowsPerSplit * (numSplits - 1), + numSplits - 1, + numSplits - 1}, + {false, + numRowsPerSplit, + numRowsPerSplit * (numSplits - 1), + 0, + numRowsPerSplit * (numSplits - 1), + numSplits - 1, + numSplits - 1}, + {true, + numRowsPerSplit, + numRowsPerSplit / 2, + 2, + numRowsPerSplit / 2, + 1, + 1}, + {false, + numRowsPerSplit, + numRowsPerSplit / 2, + 0, + numRowsPerSplit / 2, + 1, + 1}, + {true, + numRowsPerSplit / 2, + numRowsPerSplit * numSplits, + numSplits, + numRowsPerSplit * numSplits - numRowsPerSplit / 2, + numSplits, + numSplits}, + {false, + numRowsPerSplit / 2, + numRowsPerSplit * numSplits, + 0, + numRowsPerSplit * numSplits - numRowsPerSplit / 2, + numSplits, + numSplits}}; + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.toString()); + core::PlanNodeId limitPlanNodeId; + auto plan = PlanBuilder() + .tableScan(asRowType(vectors.back()->type())) + .limit(testData.offset, testData.limit, true) + .capturePlanNodeId(limitPlanNodeId) + .planNode(); + auto task = AssertQueryBuilder(plan, duckDbQueryRunner_) + .splits(makeHiveConnectorSplits(tempFiles)) + .serialExecution(true) + .barrierExecution(testData.barrierExecution) + .assertResults( + fmt::format( + "SELECT * FROM tmp LIMIT {} OFFSET {}", + testData.limit, + testData.offset)); + const auto taskStats = task->taskStats(); + ASSERT_EQ(taskStats.numBarriers, testData.numExpectedBarriers); + ASSERT_EQ(taskStats.numFinishedSplits, testData.numExpectedFinishedSplits); + ASSERT_EQ( + exec::toPlanStats(taskStats).at(limitPlanNodeId).outputRows, + testData.numExpectedOutputRows); + ASSERT_EQ( + exec::toPlanStats(taskStats).at(limitPlanNodeId).outputVectors, + testData.numExpectedOutputBatches); + } +} +} // namespace diff --git a/velox/exec/tests/LocalPartitionTest.cpp b/velox/exec/tests/LocalPartitionTest.cpp index 508f8db67618..917aa1e32131 100644 --- a/velox/exec/tests/LocalPartitionTest.cpp +++ b/velox/exec/tests/LocalPartitionTest.cpp @@ -329,6 +329,80 @@ TEST_F(LocalPartitionTest, partitionBuffering) { queryBuilder.assertResults(query), 2200, 2, 2); } +TEST_F(LocalPartitionTest, partitionBufferingWithStringBuffers) { + // Test string buffer memory accounting with partition buffering. + // String buffers are multiply-referenced across partitions. The fix + // amortizes string buffer sizes across partitions to avoid over-counting, + // which allows more efficient buffering. + std::vector vectors; + for (auto i = 0; i < 4; ++i) { + vectors.emplace_back(makeRowVector( + {"c0", "c1"}, + {makeFlatVector(100, [](auto row) { return row % 2; }), + makeFlatVector( + 100, [](auto /*row*/) { return std::string(100, 'a'); })})); + } + + auto runQuery = [&](const std::vector& input, + int maxDrivers, + int maxPartitionBufferSize) { + std::string query{"SELECT c0, arbitrary(c1) FROM tmp GROUP BY c0"}; + auto planNodeIdGenerator = std::make_shared(); + auto plan = + PlanBuilder(planNodeIdGenerator) + .localPartition( + {"c0"}, + {PlanBuilder(planNodeIdGenerator).values(input).planNode()}) + .partialAggregation({"c0"}, {"arbitrary(c1)"}) + .planNode(); + createDuckDbTable(vectors); + + AssertQueryBuilder queryBuilder(plan, duckDbQueryRunner_); + queryBuilder.maxDrivers(maxDrivers); + + std::unordered_map configs; + configs[core::QueryConfig:: + kMinLocalExchangePartitionCountToUsePartitionBuffer] = + std::to_string(2); + configs[core::QueryConfig::kMaxLocalExchangePartitionBufferSize] = + std::to_string(maxPartitionBufferSize); + queryBuilder.configs(configs); + + return queryBuilder.assertResults(query); + }; + + // With amortized string buffer accounting, the buffer can accumulate rows + // from all input vectors. So only 2 vectors are flushed to + // LocalExchangeQueues, compared to 8 if not amortizing string buffer sizes. + verifyExchangeSourceOperatorStats(runQuery(vectors, 2, 50000), 400, 2, 2); + + // Test case with 99% of data belonging to one partition. We expect all + // partition buffers getting flushed together when the total size of all + // partition buffers exceeds the limit, instead of flushing each partiton + // buffer individually. + vectors.clear(); + for (auto i = 0; i < 4; ++i) { + vectors.emplace_back(makeRowVector( + {"c0", "c1"}, + {makeFlatVector( + 100, + [](auto row) { + if (row == 0) { + return 0; + } else { + return 1; + } + }), + makeFlatVector( + 100, [](auto /*row*/) { return std::string(100, 'a'); })})); + } + + // The total size of all partition buffers should exceed the limit and trigger + // the flush of all partition buffers at every batch, hence flushing a total + // of 8 vectors. + verifyExchangeSourceOperatorStats(runQuery(vectors, 2, 20000), 400, 8, 2); +} + TEST_F(LocalPartitionTest, partitionBufferingPreserveEncoding) { std::vector vectors = { makeRowVector({"c0"}, {makeConstant(0, 100)}), @@ -388,8 +462,9 @@ TEST_F(LocalPartitionTest, maxBufferSizeGather) { auto valuesNode = [&](int start, int end) { return PlanBuilder(planNodeIdGenerator) - .values(std::vector( - vectors.begin() + start, vectors.begin() + end)) + .values( + std::vector( + vectors.begin() + start, vectors.begin() + end)) .planNode(); }; @@ -498,9 +573,11 @@ TEST_F(LocalPartitionTest, indicesBufferCapacity) { params.maxDrivers = 2; auto cursor = TaskCursor::create(params); for (auto i = 0; i < filePaths.size(); ++i) { - auto id = scanNodeIds[i % 3]; + auto& id = scanNodeIds[i % 3]; cursor->task()->addSplit( id, Split(makeHiveConnectorSplit(filePaths[i]->getPath()))); + } + for (auto& id : scanNodeIds) { cursor->task()->noMoreSplits(id); } int numRows = 0; @@ -959,8 +1036,9 @@ TEST_F(LocalPartitionTest, unionAllLocalExchangeWithInterDependency) { } }; - Operator::registerOperator(std::make_unique( - std::move(blockingCallback), std::move(finishCallback))); + Operator::registerOperator( + std::make_unique( + std::move(blockingCallback), std::move(finishCallback))); auto planNodeIdGenerator = std::make_shared(); auto plan = PlanBuilder(planNodeIdGenerator) @@ -1028,8 +1106,9 @@ TEST_F( auto finishCallback = [&](bool /*unused*/) {}; - Operator::registerOperator(std::make_unique( - std::move(blockingCallback), std::move(finishCallback))); + Operator::registerOperator( + std::make_unique( + std::move(blockingCallback), std::move(finishCallback))); auto planNodeIdGenerator = std::make_shared(); auto plan = PlanBuilder(planNodeIdGenerator) diff --git a/velox/exec/tests/MergeJoinTest.cpp b/velox/exec/tests/MergeJoinTest.cpp index 5714301a2f3b..f48871c09139 100644 --- a/velox/exec/tests/MergeJoinTest.cpp +++ b/velox/exec/tests/MergeJoinTest.cpp @@ -16,6 +16,7 @@ #include "velox/common/base/tests/GTestUtils.h" #include "velox/common/testutil/TestValue.h" +#include "velox/exec/PlanNodeStats.h" #include "velox/exec/tests/utils/AssertQueryBuilder.h" #include "velox/exec/tests/utils/HiveConnectorTestBase.h" #include "velox/exec/tests/utils/PlanBuilder.h" @@ -112,12 +113,13 @@ class MergeJoinTest : public HiveConnectorTestBase { for (const auto& row : input) { std::vector children; for (const auto& child : row->children()) { - children.push_back(std::make_shared( - pool(), - child->type(), - child->size(), - std::make_unique( - batchId, counter, [=](RowSet) { return child; }))); + children.push_back( + std::make_shared( + pool(), + child->type(), + child->size(), + std::make_unique( + batchId, counter, [=, this](RowSet) { return child; }))); } data.push_back(makeRowVector(children)); @@ -369,6 +371,87 @@ class MergeJoinTest : public HiveConnectorTestBase { std::bind( &MergeJoinTest::generateLazyInput, this, std::placeholders::_1)); } + + void testJoinTwoKeysWithNulls( + RowVectorPtr& leftVectors, + RowVectorPtr& rightVectors) { + auto leftFile = TempFilePath::create(); + writeToFile(leftFile->getPath(), leftVectors); + createDuckDbTable("t", {leftVectors}); + auto rightFile = TempFilePath::create(); + writeToFile(rightFile->getPath(), rightVectors); + createDuckDbTable("u", {rightVectors}); + + auto joinTypes = { + core::JoinType::kInner, + core::JoinType::kLeft, + core::JoinType::kRight, + core::JoinType::kFull, + }; + + for (auto joinType : joinTypes) { + auto planNodeIdGenerator = std::make_shared(); + core::PlanNodeId leftScanId; + core::PlanNodeId rightScanId; + auto op = PlanBuilder(planNodeIdGenerator) + .tableScan( + ROW({"c0", "c1", "c2", "c3"}, + {VARCHAR(), VARCHAR(), VARCHAR(), VARCHAR()})) + .capturePlanNodeId(leftScanId) + .mergeJoin( + {"c0", "c1"}, + {"rc0", "rc1"}, + PlanBuilder(planNodeIdGenerator) + .tableScan( + ROW({"rc0", "rc1", "rc2"}, + {VARCHAR(), VARCHAR(), VARCHAR()})) + .capturePlanNodeId(rightScanId) + .planNode(), + "", + {"c0", "c1", "c2", "c3", "rc0", "rc1", "rc2"}, + joinType) + .planNode(); + AssertQueryBuilder(op, duckDbQueryRunner_) + .split(rightScanId, makeHiveConnectorSplit(rightFile->getPath())) + .split(leftScanId, makeHiveConnectorSplit(leftFile->getPath())) + .assertResults( + fmt::format( + "SELECT * FROM t {} JOIN u " + "ON t.c0 = u.rc0 AND t.c1 = u.rc1", + core::JoinTypeName::toName(joinType))); + } + + { + // anti join + auto planNodeIdGenerator = std::make_shared(); + core::PlanNodeId leftScanId; + core::PlanNodeId rightScanId; + auto op = PlanBuilder(planNodeIdGenerator) + .tableScan( + ROW({"c0", "c1", "c2", "c3"}, + {VARCHAR(), VARCHAR(), VARCHAR(), VARCHAR()})) + .capturePlanNodeId(leftScanId) + .mergeJoin( + {"c0", "c1"}, + {"rc0", "rc1"}, + PlanBuilder(planNodeIdGenerator) + .tableScan( + ROW({"rc0", "rc1", "rc2"}, + {VARCHAR(), VARCHAR(), VARCHAR()})) + .capturePlanNodeId(rightScanId) + .planNode(), + "", + {"c0", "c1", "c2", "c3"}, + core::JoinType::kAnti) + .planNode(); + AssertQueryBuilder(op, duckDbQueryRunner_) + .split(rightScanId, makeHiveConnectorSplit(rightFile->getPath())) + .split(leftScanId, makeHiveConnectorSplit(leftFile->getPath())) + .assertResults( + "SELECT * FROM t WHERE NOT exists (select * from u " + "where t.c0 = u.rc0 AND t.c1 = u.rc1)"); + } + } }; TEST_F(MergeJoinTest, oneToOneAllMatch) { @@ -870,10 +953,11 @@ TEST_F(MergeJoinTest, lazyVectors) { AssertQueryBuilder(op, duckDbQueryRunner_) .split(rightScanId, makeHiveConnectorSplit(rightFile->getPath())) .split(leftScanId, makeHiveConnectorSplit(leftFile->getPath())) - .assertResults(fmt::format( - "SELECT c0, rc0, c1, rc1, c2, c3 FROM t {} JOIN u " - "ON t.c0 = u.rc0 AND c1 + rc1 < 30", - joinTypeName(joinType))); + .assertResults( + fmt::format( + "SELECT c0, rc0, c1, rc1, c2, c3 FROM t {} JOIN u " + "ON t.c0 = u.rc0 AND c1 + rc1 < 30", + core::JoinTypeName::toName(joinType))); } } @@ -1016,6 +1100,110 @@ TEST_F(MergeJoinTest, semiJoinWithMultipleMatchVectors) { core::JoinType::kLeftSemiFilter); } +TEST_F(MergeJoinTest, semiJoinWithMultiMatchedRowsWithFilter) { + auto left = makeRowVector( + {"t0", "t1"}, + {makeNullableFlatVector({2, 2, 2, 2, 2}), + makeNullableFlatVector({3, 2, 3, 2, 2})}); + + auto right = makeRowVector( + {"u0", "u1"}, + {makeNullableFlatVector({2, 2, 2, 2, 2, 2}), + makeNullableFlatVector({2, 2, 2, 2, 2, 4})}); + + createDuckDbTable("t", {left}); + createDuckDbTable("u", {right}); + + auto testSemiJoin = [&](const std::string& filter, + const std::string& sql, + const std::vector& outputLayout, + core::JoinType joinType) { + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .values(split(left, 2)) + .mergeJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator) + .values(split(right, 2)) + .planNode(), + filter, + outputLayout, + joinType) + .planNode(); + AssertQueryBuilder(plan, duckDbQueryRunner_) + .config(core::QueryConfig::kPreferredOutputBatchRows, "2") + .config(core::QueryConfig::kMaxOutputBatchRows, "2") + .assertResults(sql); + }; + + // Left Semi join With filter + testSemiJoin( + "t1 > u1", + "SELECT t0, t1 FROM t where t0 IN (SELECT u0 from u where t1 > u1)", + {"t0", "t1"}, + core::JoinType::kLeftSemiFilter); + + // Right Semi join With filter + testSemiJoin( + "u1 > t1", + "SELECT u0, u1 FROM u where u0 IN (SELECT t0 from t where u1 > t1)", + {"u0", "u1"}, + core::JoinType::kRightSemiFilter); +} + +TEST_F(MergeJoinTest, semiJoinWithOneMatchedRowWithFilter) { + auto left = makeRowVector( + {"t0", "t1"}, + {makeNullableFlatVector({2, 2}), + makeNullableFlatVector({3, 5})}); + + auto right = makeRowVector( + {"u0", "u1"}, + {makeNullableFlatVector({2, 2}), + makeNullableFlatVector({1, 4})}); + + createDuckDbTable("t", {left}); + createDuckDbTable("u", {right}); + + auto testSemiJoin = [&](const std::string& filter, + const std::string& sql, + const std::vector& outputLayout, + core::JoinType joinType) { + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .values(split(left, 2)) + .mergeJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator) + .values(split(right, 2)) + .planNode(), + filter, + outputLayout, + joinType) + .planNode(); + AssertQueryBuilder(plan, duckDbQueryRunner_) + .config(core::QueryConfig::kPreferredOutputBatchRows, "2") + .config(core::QueryConfig::kMaxOutputBatchRows, "2") + .assertResults(sql); + }; + + // Left Semi join With filter + testSemiJoin( + "t1 > u1", + "SELECT t0, t1 FROM t where t0 IN (SELECT u0 from u where t1 > u1)", + {"t0", "t1"}, + core::JoinType::kLeftSemiFilter); + + // Right Semi join With filter + testSemiJoin( + "u1 > t1", + "SELECT u0, u1 FROM u where u0 IN (SELECT t0 from t where u1 > t1)", + {"u0", "u1"}, + core::JoinType::kRightSemiFilter); +} + TEST_F(MergeJoinTest, rightJoin) { auto left = makeRowVector( {"t0"}, @@ -1206,6 +1394,160 @@ TEST_F(MergeJoinTest, antiJoinWithTwoJoinKeys) { "SELECT * FROM t WHERE NOT exists (select * from u where t.a = u.c and t.b < u.d)"); } +TEST_F(MergeJoinTest, matchRatioStats) { + // Test match ratio statistics for different join scenarios. + + // Inner join with full match (all rows match). + { + auto left = makeRowVector( + {"t0"}, {makeNullableFlatVector({1, 2, 3, 4, 5})}); + auto right = makeRowVector( + {"u0"}, {makeNullableFlatVector({1, 2, 3, 4, 5})}); + + createDuckDbTable("t", {left}); + createDuckDbTable("u", {right}); + + auto planNodeIdGenerator = std::make_shared(); + core::PlanNodeId mergeJoinNodeId; + auto plan = + PlanBuilder(planNodeIdGenerator) + .values({left}) + .mergeJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator).values({right}).planNode(), + "", + {"t0", "u0"}, + core::JoinType::kInner) + .capturePlanNodeId(mergeJoinNodeId) + .planNode(); + + auto task = AssertQueryBuilder(plan, duckDbQueryRunner_) + .assertResults("SELECT t0, u0 FROM t, u WHERE t0 = u0"); + + auto stats = toPlanStats(task->taskStats()); + ASSERT_EQ(stats.at(mergeJoinNodeId).outputRows, 5); + + auto runtimeStats = stats.at(mergeJoinNodeId).customStats; + ASSERT_EQ(runtimeStats.at("matchedLeftRows").sum, 5); + ASSERT_EQ(runtimeStats.at("matchedRightRows").sum, 5); + } + + // Inner join with partial match. + { + auto left = makeRowVector( + {"t0"}, {makeNullableFlatVector({1, 2, 3, 4, 5, 6, 7, 8})}); + auto right = makeRowVector( + {"u0"}, {makeNullableFlatVector({2, 4, 6, 10, 12})}); + + createDuckDbTable("t", {left}); + createDuckDbTable("u", {right}); + + auto planNodeIdGenerator = std::make_shared(); + core::PlanNodeId mergeJoinNodeId; + auto plan = + PlanBuilder(planNodeIdGenerator) + .values({left}) + .mergeJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator).values({right}).planNode(), + "", + {"t0", "u0"}, + core::JoinType::kInner) + .capturePlanNodeId(mergeJoinNodeId) + .planNode(); + + auto task = AssertQueryBuilder(plan, duckDbQueryRunner_) + .assertResults("SELECT t0, u0 FROM t, u WHERE t0 = u0"); + + auto stats = toPlanStats(task->taskStats()); + ASSERT_EQ(stats.at(mergeJoinNodeId).outputRows, 3); + + auto runtimeStats = stats.at(mergeJoinNodeId).customStats; + // Only 3 left rows match (2, 4, 6). + ASSERT_EQ(runtimeStats.at("matchedLeftRows").sum, 3); + // Only 3 right rows match (2, 4, 6). + ASSERT_EQ(runtimeStats.at("matchedRightRows").sum, 3); + } + + // Left join - all left rows appear in output. + { + auto left = makeRowVector( + {"t0"}, {makeNullableFlatVector({1, 2, 3, 4, 5})}); + auto right = + makeRowVector({"u0"}, {makeNullableFlatVector({2, 4})}); + + createDuckDbTable("t", {left}); + createDuckDbTable("u", {right}); + + auto planNodeIdGenerator = std::make_shared(); + core::PlanNodeId mergeJoinNodeId; + auto plan = + PlanBuilder(planNodeIdGenerator) + .values({left}) + .mergeJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator).values({right}).planNode(), + "", + {"t0", "u0"}, + core::JoinType::kLeft) + .capturePlanNodeId(mergeJoinNodeId) + .planNode(); + + auto task = + AssertQueryBuilder(plan, duckDbQueryRunner_) + .assertResults("SELECT t0, u0 FROM t LEFT JOIN u ON t0 = u0"); + + auto stats = toPlanStats(task->taskStats()); + ASSERT_EQ(stats.at(mergeJoinNodeId).outputRows, 5); + + auto runtimeStats = stats.at(mergeJoinNodeId).customStats; + // Only 2 left rows match (2, 4). + ASSERT_EQ(runtimeStats.at("matchedLeftRows").sum, 2); + ASSERT_EQ(runtimeStats.at("matchedRightRows").sum, 2); + } + + // Join with duplicate keys (cartesian product). + { + auto left = makeRowVector( + {"t0"}, {makeNullableFlatVector({1, 1, 1, 2, 2})}); + auto right = makeRowVector( + {"u0"}, {makeNullableFlatVector({1, 1, 2, 2, 2})}); + + createDuckDbTable("t", {left}); + createDuckDbTable("u", {right}); + + auto planNodeIdGenerator = std::make_shared(); + core::PlanNodeId mergeJoinNodeId; + auto plan = + PlanBuilder(planNodeIdGenerator) + .values({left}) + .mergeJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator).values({right}).planNode(), + "", + {"t0", "u0"}, + core::JoinType::kInner) + .capturePlanNodeId(mergeJoinNodeId) + .planNode(); + + auto task = AssertQueryBuilder(plan, duckDbQueryRunner_) + .assertResults("SELECT t0, u0 FROM t, u WHERE t0 = u0"); + + auto stats = toPlanStats(task->taskStats()); + ASSERT_EQ(stats.at(mergeJoinNodeId).outputRows, 12); + + auto runtimeStats = stats.at(mergeJoinNodeId).customStats; + // 3 left rows with key=1 and 2 left rows with key=2. + ASSERT_EQ(runtimeStats.at("matchedLeftRows").sum, 5); + // 2 right rows with key=1 and 3 right rows with key=2. + ASSERT_EQ(runtimeStats.at("matchedRightRows").sum, 5); + } +} + TEST_F(MergeJoinTest, antiJoinWithUniqueJoinKeys) { auto left = makeRowVector( {"a", "b"}, @@ -1811,3 +2153,85 @@ TEST_F(MergeJoinTest, barrier) { } } } + +TEST_F(MergeJoinTest, antiJoinWithFilterWithMultiMatchedRows) { + auto left = makeRowVector({"t0"}, {makeNullableFlatVector({1, 2})}); + + auto right = + makeRowVector({"u0"}, {makeNullableFlatVector({1, 2, 2, 2})}); + + createDuckDbTable("t", {left}); + createDuckDbTable("u", {right}); + + // Anti join. + auto planNodeIdGenerator = std::make_shared(); + auto plan = + PlanBuilder(planNodeIdGenerator) + .values({left}) + .mergeJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator).values({right}).planNode(), + "t0 > 2", + {"t0"}, + core::JoinType::kAnti) + .planNode(); + + AssertQueryBuilder(plan, duckDbQueryRunner_) + .assertResults( + "SELECT t0 FROM t WHERE NOT exists (select 1 from u where t0 = u0 AND t.t0 > 2 ) "); +} + +TEST_F(MergeJoinTest, antiJoinWithTwoJoinKeysInDifferentBatch) { + auto left = makeRowVector( + {"a", "b"}, + {makeNullableFlatVector({1, 1, 1, 1}), + makeNullableFlatVector({3.0, 3.0, 3.0, 3.0})}); + + auto right = makeRowVector( + {"c", "d"}, + {makeNullableFlatVector({1, 1, 1}), + makeNullableFlatVector({2.0, 2.0, 4.0})}); + + createDuckDbTable("t", {left}); + createDuckDbTable("u", {right}); + + // Anti join. + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .values({split(left, 2)}) + .mergeJoin( + {"a"}, + {"c"}, + PlanBuilder(planNodeIdGenerator) + .values({split(right, 2)}) + .planNode(), + "b < d", + {"a", "b"}, + core::JoinType::kAnti) + .planNode(); + + AssertQueryBuilder(plan, duckDbQueryRunner_) + .assertResults( + "SELECT * FROM t WHERE NOT exists (select * from u where t.a = u.c and t.b < u.d)"); +} + +TEST_F(MergeJoinTest, testJoinWithTwoKeysAndSecondColumnHasNulls) { + auto left = makeRowVector( + {"c0", "c1", "c2", "c3"}, + { + makeNullableFlatVector( + {"202408", "202409", "202409", "202410"}), + makeNullableFlatVector({"1", std::nullopt, "2", "3"}), + makeNullableFlatVector({"1", "2", "2", "3"}), + makeNullableFlatVector({"1", "2", "2", "3"}), + }); + auto right = makeRowVector( + {"rc0", "rc1", "rc2"}, + {makeNullableFlatVector( + {"202408", "202409", "202409", "202410"}), + makeNullableFlatVector({"1", std::nullopt, "2", "3"}), + makeNullableFlatVector({"1", std::nullopt, "2", "3"})}); + + testJoinTwoKeysWithNulls(left, right); +} diff --git a/velox/exec/tests/MergeTest.cpp b/velox/exec/tests/MergeTest.cpp index da7fc3a55183..843d524fc5f3 100644 --- a/velox/exec/tests/MergeTest.cpp +++ b/velox/exec/tests/MergeTest.cpp @@ -13,11 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + +#include "velox/exec/Merge.h" #include "folly/experimental/EventCount.h" #include "velox/common/base/tests/GTestUtils.h" #include "velox/common/testutil/TestValue.h" +#include "velox/exec/PlanNodeStats.h" +#include "velox/exec/tests/utils/AssertQueryBuilder.h" #include "velox/exec/tests/utils/OperatorTestBase.h" #include "velox/exec/tests/utils/PlanBuilder.h" +#include "velox/exec/tests/utils/TempDirectoryPath.h" using namespace facebook::velox; using namespace facebook::velox::exec; @@ -25,6 +30,11 @@ using namespace facebook::velox::common::testutil; using namespace facebook::velox::exec::test; class MergeTest : public OperatorTestBase { + public: + MergeTest() { + filesystems::registerLocalFileSystem(); + } + protected: void testSingleKey( const std::vector& inputVectors, @@ -33,7 +43,6 @@ class MergeTest : public OperatorTestBase { std::vector sortOrderSqls = { "NULLS LAST", "NULLS FIRST", "DESC NULLS FIRST", "DESC NULLS LAST"}; - for (const auto& sortOrderSql : sortOrderSqls) { const auto orderByClause = fmt::format("{} {}", key, sortOrderSql); auto planNodeIdGenerator = std::make_shared(); @@ -41,18 +50,17 @@ class MergeTest : public OperatorTestBase { .localMerge( {orderByClause}, {PlanBuilder(planNodeIdGenerator) - .values(inputVectors, true) + .values(inputVectors) .orderBy({orderByClause}, true) .planNode()}) .planNode(); - + // Use single source for local merge. CursorParameters params; params.planNode = plan; - params.maxDrivers = 2; + params.maxDrivers = 1; assertQueryOrdered( params, - "SELECT * FROM (SELECT * FROM tmp UNION ALL SELECT * FROM tmp) ORDER BY " + - orderByClause, + "SELECT * FROM (SELECT * FROM tmp) ORDER BY " + orderByClause, {keyIndex}); // Use multiple sources for local merge. @@ -100,18 +108,17 @@ class MergeTest : public OperatorTestBase { .localMerge( orderByClauses, {PlanBuilder(planNodeIdGenerator) - .values(inputVectors, true) + .values(inputVectors) .orderBy(orderByClauses, true) .planNode()}) .planNode(); - + // Use single source for local merge. CursorParameters params; params.planNode = plan; - params.maxDrivers = 2; + params.maxDrivers = 1; assertQueryOrdered( params, - "SELECT * FROM (SELECT * FROM tmp UNION ALL SELECT * FROM tmp) " + - orderBySql, + "SELECT * FROM (SELECT * FROM tmp) " + orderBySql, sortingKeys); // Use multiple sources for local merge. @@ -131,12 +138,473 @@ class MergeTest : public OperatorTestBase { } } } + + void testSingleKeyWithSpill( + const std::vector& inputVectors, + const std::string& key) { + auto keyIndex = inputVectors[0]->type()->asRow().getChildIdx(key); + + std::vector sortOrderSqls = { + "NULLS LAST", "NULLS FIRST", "DESC NULLS FIRST", "DESC NULLS LAST"}; + + for (const auto& sortOrderSql : sortOrderSqls) { + const auto orderByClause = fmt::format("{} {}", key, sortOrderSql); + const auto planNodeIdGenerator = + std::make_shared(); + const std::shared_ptr spillDirectory = + TempDirectoryPath::create(); + std::vector> sources; + for (const auto& input : inputVectors) { + sources.push_back(PlanBuilder(planNodeIdGenerator) + .values({input}) + .orderBy({orderByClause}, true) + .planNode()); + } + core::PlanNodeId nodeId; + const auto plan = PlanBuilder(planNodeIdGenerator) + .localMerge({orderByClause}, std::move(sources)) + .capturePlanNodeId(nodeId) + .planNode(); + CursorParameters params; + params.planNode = plan; + params.maxDrivers = 2; + params.queryCtx = createQueryCtx( + {{"spill_enabled", "true"}, + {"local_merge_spill_enabled", "true"}, + {"local_merge_max_num_merge_sources", "2"}}); + params.spillDirectory = spillDirectory->getPath(); + auto task = assertQueryOrdered( + params, "SELECT * FROM tmp ORDER BY " + orderByClause, {keyIndex}); + auto taskStats = toPlanStats(task->taskStats()); + auto& planStats = taskStats.at(nodeId); + auto expectedNumSpillFiles = (inputVectors.size() + 2 - 1) / 2; + ASSERT_GT(planStats.spilledBytes, 0); + ASSERT_EQ(planStats.spilledPartitions, expectedNumSpillFiles); + ASSERT_EQ(planStats.spilledFiles, expectedNumSpillFiles); + ASSERT_EQ( + planStats.spilledRows, inputVectors.size() * inputVectors[0]->size()); + } + } + + void testTwoKeysWithSpill( + const std::vector& inputVectors, + const std::string& key1, + const std::string& key2) { + auto& rowType = inputVectors[0]->type()->asRow(); + auto sortingKeys = {rowType.getChildIdx(key1), rowType.getChildIdx(key2)}; + + std::vector sortOrders = { + core::kAscNullsLast, + core::kAscNullsFirst, + core::kDescNullsFirst, + core::kDescNullsLast}; + std::vector sortOrderSqls = { + "NULLS LAST", "NULLS FIRST", "DESC NULLS FIRST", "DESC NULLS LAST"}; + + for (auto i = 0; i < sortOrders.size(); ++i) { + for (auto j = 0; j < sortOrders.size(); ++j) { + const std::vector orderByClauses = { + fmt::format("{} {}", key1, sortOrderSqls[i]), + fmt::format("{} {}", key2, sortOrderSqls[j])}; + const auto orderBySql = fmt::format( + "ORDER BY {}, {}", orderByClauses[0], orderByClauses[1]); + const auto planNodeIdGenerator = + std::make_shared(); + const std::shared_ptr spillDirectory = + TempDirectoryPath::create(); + std::vector> sources; + for (const auto& input : inputVectors) { + sources.push_back(PlanBuilder(planNodeIdGenerator) + .values({input}) + .orderBy(orderByClauses, true) + .planNode()); + } + core::PlanNodeId nodeId; + const auto plan = PlanBuilder(planNodeIdGenerator) + .localMerge({orderByClauses}, std::move(sources)) + .capturePlanNodeId(nodeId) + .planNode(); + CursorParameters params; + params.planNode = plan; + params.maxDrivers = 2; + params.queryCtx = createQueryCtx( + {{"spill_enabled", "true"}, + {"local_merge_spill_enabled", "true"}, + {"local_merge_max_num_merge_sources", "2"}}); + params.spillDirectory = spillDirectory->getPath(); + auto task = assertQueryOrdered( + params, "SELECT * FROM tmp " + orderBySql, sortingKeys); + auto taskStats = toPlanStats(task->taskStats()); + auto& planStats = taskStats.at(nodeId); + auto expectedNumSpillFiles = (inputVectors.size() + 2 - 1) / 2; + ASSERT_GT(planStats.spilledBytes, 0); + ASSERT_EQ(planStats.spilledPartitions, expectedNumSpillFiles); + ASSERT_EQ(planStats.spilledFiles, expectedNumSpillFiles); + ASSERT_EQ( + planStats.spilledRows, + inputVectors.size() * inputVectors[0]->size()); + } + } + } + + void testLocalMerge( + int numInputSources, + int numInputBatches, + int inputBatchSize, + int maxBatchRows, + int maxBatchBytes, + int expectedOuputBatches) { + std::vector> inputVectors; + for (int32_t i = 0; i < numInputSources; ++i) { + std::vector vectors; + for (int32_t j = 0; j < numInputBatches; ++j) { + auto c0 = makeFlatVector( + inputBatchSize, + [&](auto row) { return inputBatchSize * j + row; }, + nullEvery(5)); + auto c1 = makeFlatVector( + inputBatchSize, [&](auto row) { return row; }, nullEvery(5)); + auto c2 = makeFlatVector( + inputBatchSize, [](auto row) { return row * 0.1; }, nullEvery(11)); + auto c3 = makeFlatVector(inputBatchSize, [](auto row) { + return StringView::makeInline(std::to_string(row)); + }); + vectors.push_back(makeRowVector({c0, c1, c2, c3})); + } + inputVectors.push_back(std::move(vectors)); + } + std::vector duckInputs; + for (const auto& input : inputVectors) { + for (const auto& vector : input) { + duckInputs.push_back(vector); + } + } + createDuckDbTable(duckInputs); + auto keyIndex = inputVectors[0][0]->type()->asRow().getChildIdx("c0"); + + const auto orderByClause = fmt::format("{}", "c0"); + const auto planNodeIdGenerator = + std::make_shared(); + const std::shared_ptr spillDirectory = + TempDirectoryPath::create(); + std::vector> sources; + sources.reserve(inputVectors.size()); + for (const auto& vectors : inputVectors) { + sources.push_back(PlanBuilder(planNodeIdGenerator) + .values(vectors) + .orderBy({orderByClause}, true) + .planNode()); + } + core::PlanNodeId localMergeNodeId; + const auto plan = PlanBuilder(planNodeIdGenerator) + .localMerge({orderByClause}, std::move(sources)) + .capturePlanNodeId(localMergeNodeId) + .planNode(); + CursorParameters params; + params.planNode = plan; + params.maxDrivers = numInputSources; + params.queryCtx = createQueryCtx( + {{"spill_enabled", "false"}, + {"local_merge_spill_enabled", "false"}, + {"preferred_output_batch_bytes", std::to_string(maxBatchBytes)}, + {"preferred_output_batch_rows", std::to_string(maxBatchRows)}}, + false); + auto task = assertQueryOrdered( + params, "SELECT * FROM tmp ORDER BY " + orderByClause, {keyIndex}); + + auto taskStats = toPlanStats(task->taskStats()); + const auto& mergeStats = taskStats.at(localMergeNodeId); + ASSERT_EQ(mergeStats.spilledBytes, 0); + ASSERT_EQ(mergeStats.spilledPartitions, 0); + ASSERT_EQ(mergeStats.spilledFiles, 0); + ASSERT_EQ(mergeStats.spilledRows, 0); + ASSERT_EQ( + mergeStats.outputRows, + numInputSources * numInputBatches * inputBatchSize); + ASSERT_EQ(mergeStats.outputVectors, expectedOuputBatches); + } + + void testLocalMergeSpill( + int numInputSources, + int numInputBatches, + int inputBatchSize, + int maxBatchRows, + int maxBatchBytes, + int numMaxMergeSources, + bool hasSpillExecutor, + int expectedOuputBatches) { + std::vector> inputVectors; + for (int32_t i = 0; i < numInputSources; ++i) { + std::vector vectors; + for (int32_t j = 0; j < numInputBatches; ++j) { + auto c0 = makeFlatVector( + inputBatchSize, + [&](auto row) { return inputBatchSize * j + row; }, + nullEvery(5)); + auto c1 = makeFlatVector( + inputBatchSize, [&](auto row) { return row; }, nullEvery(5)); + auto c2 = makeFlatVector( + inputBatchSize, [](auto row) { return row * 0.1; }, nullEvery(11)); + auto c3 = makeFlatVector(inputBatchSize, [](auto row) { + return StringView::makeInline(std::to_string(row)); + }); + vectors.push_back(makeRowVector({c0, c1, c2, c3})); + } + inputVectors.push_back(std::move(vectors)); + } + std::vector duckInputs; + for (const auto& input : inputVectors) { + for (const auto& vector : input) { + duckInputs.push_back(vector); + } + } + createDuckDbTable(duckInputs); + auto keyIndex = inputVectors[0][0]->type()->asRow().getChildIdx("c0"); + + const auto orderByClause = fmt::format("{}", "c0"); + const auto planNodeIdGenerator = + std::make_shared(); + const std::shared_ptr spillDirectory = + TempDirectoryPath::create(); + std::vector> sources; + sources.reserve(inputVectors.size()); + for (const auto& vectors : inputVectors) { + sources.push_back(PlanBuilder(planNodeIdGenerator) + .values(vectors) + .orderBy({orderByClause}, true) + .planNode()); + } + core::PlanNodeId nodeId; + const auto plan = PlanBuilder(planNodeIdGenerator) + .localMerge({orderByClause}, std::move(sources)) + .capturePlanNodeId(nodeId) + .planNode(); + CursorParameters params; + params.planNode = plan; + params.maxDrivers = 2; + params.queryCtx = createQueryCtx( + {{"spill_enabled", "true"}, + {"local_merge_spill_enabled", "true"}, + {"local_merge_max_num_merge_sources", + std::to_string(numMaxMergeSources)}, + {"preferred_output_batch_bytes", std::to_string(maxBatchBytes)}, + {"preferred_output_batch_rows", std::to_string(maxBatchRows)}}, + hasSpillExecutor); + params.spillDirectory = spillDirectory->getPath(); + auto task = assertQueryOrdered( + params, "SELECT * FROM tmp ORDER BY " + orderByClause, {keyIndex}); + + auto taskStats = toPlanStats(task->taskStats()); + auto& planStats = taskStats.at(nodeId); + if (inputBatchSize == 0 || numMaxMergeSources >= numInputSources || + !hasSpillExecutor) { + ASSERT_EQ(planStats.spilledBytes, 0); + ASSERT_EQ(planStats.spilledPartitions, 0); + ASSERT_EQ(planStats.spilledFiles, 0); + ASSERT_EQ(planStats.spilledRows, 0); + ASSERT_EQ( + planStats.customStats.count(Merge::kSpilledSourceReadWallNanos), 0); + ASSERT_GE( + planStats.customStats.count(Merge::kStreamingSourceReadWallNanos), 0); + } else { + const auto expectedFiles = + (inputVectors.size() + numMaxMergeSources - 1) / numMaxMergeSources; + const auto expectedSpillRows = inputBatchSize * duckInputs.size(); + ASSERT_GT(planStats.spilledBytes, 0); + ASSERT_EQ(planStats.spilledPartitions, expectedFiles); + ASSERT_EQ(planStats.spilledFiles, expectedFiles); + ASSERT_EQ(planStats.spilledRows, expectedSpillRows); + ASSERT_GE( + planStats.customStats.count(Merge::kSpilledSourceReadWallNanos), 0); + ASSERT_GE( + planStats.customStats.count(Merge::kStreamingSourceReadWallNanos), 0); + } + ASSERT_EQ( + planStats.outputRows, + numInputSources * numInputBatches * inputBatchSize); + ASSERT_EQ(planStats.outputVectors, expectedOuputBatches); + } + + std::shared_ptr createQueryCtx( + std::unordered_map queryConfigs = {}, + bool hasSpillExecutor = true) const { + return core::QueryCtx::create( + executor_.get(), + core::QueryConfig{std::move(queryConfigs)}, + {}, + nullptr, + nullptr, + hasSpillExecutor ? spillExecutor_.get() : nullptr); + } }; -TEST_F(MergeTest, localMerge) { - vector_size_t batchSize = 1000; +TEST_F(MergeTest, localMergeSpillBasic) { std::vector vectors; - for (int32_t i = 0; i < 3; ++i) { + for (int32_t i = 0; i < 9; ++i) { + constexpr vector_size_t batchSize = 137; + auto c0 = makeFlatVector( + batchSize, [&](auto row) { return batchSize * i + row; }, nullEvery(5)); + auto c1 = makeFlatVector( + batchSize, [&](auto row) { return row; }, nullEvery(5)); + auto c2 = makeFlatVector( + batchSize, [](auto row) { return row * 0.1; }, nullEvery(11)); + auto c3 = makeFlatVector(batchSize, [](auto row) { + return StringView::makeInline(std::to_string(row)); + }); + vectors.push_back(makeRowVector({c0, c1, c2, c3})); + } + createDuckDbTable(vectors); + + testSingleKeyWithSpill(vectors, "c0"); + testSingleKeyWithSpill(vectors, "c3"); + + testTwoKeysWithSpill(vectors, "c0", "c3"); + testTwoKeysWithSpill(vectors, "c3", "c0"); +} + +TEST_F(MergeTest, localMergeSpill) { + struct TestParam { + int numInputSources; + int numInputBatches; + int inputBatchSize; + int maxNumMergeSources; + int maxOutputBatchRows; + int maxOutputBatchBytes; + int numExpectedOutputBatches; + bool hasSpillExecutor; + + std::string debugString() const { + return fmt::format( + "numInputSources {}, numInputBatches {}, inputBatchSize {}, maxNumMergeSources {}, maxOutputBatchRows {}, maxOutputBatchBytes {}, numExpectedOutputBatches {}, hasSpillExecutor {}", + numInputSources, + numInputBatches, + inputBatchSize, + maxNumMergeSources, + maxOutputBatchRows, + maxOutputBatchBytes, + numExpectedOutputBatches, + hasSpillExecutor); + } + } testSettings[]{ + {1, 1, 1, 1, 1, std::numeric_limits::max(), 1, true}, + {1, 1, 1, 1, 1, std::numeric_limits::max(), 1, false}, + {1, 4, 1, 1, 1, std::numeric_limits::max(), 4, true}, + {1, 4, 1, 1, 1, std::numeric_limits::max(), 4, false}, + {1, 4, 32, 1, 1, std::numeric_limits::max(), 4 * 32, true}, + {1, 4, 32, 1, 1, std::numeric_limits::max(), 4 * 32, false}, + {3, 4, 32, 1, 1, std::numeric_limits::max(), 3 * 4 * 32, true}, + {3, 4, 32, 1, 1, std::numeric_limits::max(), 3 * 4 * 32, false}, + {3, 4, 32, 2, 1, std::numeric_limits::max(), 3 * 4 * 32, true}, + {3, 4, 32, 2, 1, std::numeric_limits::max(), 3 * 4 * 32, false}, + {3, 4, 32, 3, 1, std::numeric_limits::max(), 3 * 4 * 32, true}, + {3, 4, 32, 3, 1, std::numeric_limits::max(), 3 * 4 * 32, false}, + {3, 4, 32, 4, 1, std::numeric_limits::max(), 3 * 4 * 32, true}, + {3, 4, 32, 4, 1, std::numeric_limits::max(), 3 * 4 * 32, false}, + {1, 1, 1, 1, 1024, std::numeric_limits::max(), 1, true}, + {1, 1, 1, 1, 1024, std::numeric_limits::max(), 1, false}, + {1, 4, 1, 1, 1024, std::numeric_limits::max(), 1, true}, + {1, 4, 1, 1, 1024, std::numeric_limits::max(), 1, false}, + {1, 4, 32, 1, 1024, std::numeric_limits::max(), 1, true}, + {1, 4, 32, 1, 1024, std::numeric_limits::max(), 1, false}, + {3, 4, 32, 1, 1024, std::numeric_limits::max(), 1, true}, + {3, 4, 32, 1, 1024, std::numeric_limits::max(), 1, false}, + {3, 4, 32, 2, 1024, std::numeric_limits::max(), 1, true}, + {3, 4, 32, 2, 1024, std::numeric_limits::max(), 1, false}, + {3, 4, 32, 3, 1024, std::numeric_limits::max(), 1, true}, + {3, 4, 32, 3, 1024, std::numeric_limits::max(), 1, false}, + {3, 4, 32, 4, 1024, std::numeric_limits::max(), 1, true}, + {3, 4, 32, 4, 1024, std::numeric_limits::max(), 1, false}, + {1, 1, 1, 1, 1024, 1, 1, true}, + {1, 1, 1, 1, 1024, 1, 1, false}, + {1, 4, 1, 1, 1024, 1, 4, true}, + {1, 4, 1, 1, 1024, 1, 4, false}, + {1, 4, 32, 1, 1024, 1, 4 * 32, true}, + {1, 4, 32, 1, 1024, 1, 4 * 32, false}, + {3, 4, 32, 1, 1024, 1, 3 * 4 * 32, true}, + {3, 4, 32, 1, 1024, 1, 3 * 4 * 32, false}, + {3, 4, 32, 2, 1024, 1, 3 * 4 * 32, true}, + {3, 4, 32, 2, 1024, 1, 3 * 4 * 32, false}, + {3, 4, 32, 3, 1024, 1, 3 * 4 * 32, true}, + {3, 4, 32, 3, 1024, 1, 3 * 4 * 32, false}, + {3, 4, 32, 4, 1024, 1, 3 * 4 * 32, true}, + {3, 4, 32, 4, 1024, 1, 3 * 4 * 32, false}}; + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + testLocalMergeSpill( + testData.numInputSources, + testData.numInputBatches, + testData.inputBatchSize, + testData.maxOutputBatchRows, + testData.maxOutputBatchBytes, + testData.maxNumMergeSources, + testData.hasSpillExecutor, + testData.numExpectedOutputBatches); + } +} + +TEST_F(MergeTest, localMergeSpillPartialEmpty) { + std::vector vectors; + for (int32_t i = 0; i < 9; ++i) { + auto batchSize = 30; + if (i % 2 == 0) { + batchSize = 0; + } + + auto c0 = makeFlatVector( + batchSize, [&](auto row) { return batchSize * i + row; }, nullEvery(5)); + auto c1 = makeFlatVector( + batchSize, [&](auto row) { return row; }, nullEvery(5)); + auto c2 = makeFlatVector( + batchSize, [](auto row) { return row * 0.1; }, nullEvery(11)); + auto c3 = makeFlatVector(batchSize, [](auto row) { + return StringView::makeInline(std::to_string(row)); + }); + vectors.push_back(makeRowVector({c0, c1, c2, c3})); + } + createDuckDbTable(vectors); + auto keyIndex = vectors[0]->type()->asRow().getChildIdx("c0"); + + const auto orderByClause = fmt::format("{}", "c0"); + const auto planNodeIdGenerator = + std::make_shared(); + const std::shared_ptr spillDirectory = + TempDirectoryPath::create(); + std::vector> sources; + sources.reserve(vectors.size()); + for (const auto& vector : vectors) { + sources.push_back(PlanBuilder(planNodeIdGenerator) + .values({vector}) + .orderBy({orderByClause}, true) + .planNode()); + } + core::PlanNodeId nodeId; + const auto plan = PlanBuilder(planNodeIdGenerator) + .localMerge({orderByClause}, std::move(sources)) + .capturePlanNodeId(nodeId) + .planNode(); + CursorParameters params; + params.planNode = plan; + params.maxDrivers = 2; + params.queryCtx = createQueryCtx( + {{"spill_enabled", "true"}, + {"local_merge_spill_enabled", "true"}, + {"local_merge_max_num_merge_sources", "2"}}); + params.spillDirectory = spillDirectory->getPath(); + auto task = assertQueryOrdered( + params, "SELECT * FROM tmp ORDER BY " + orderByClause, {keyIndex}); + + auto taskStats = toPlanStats(task->taskStats()); + auto& planStats = taskStats.at(nodeId); + ASSERT_GT(planStats.spilledBytes, 0); + ASSERT_EQ(planStats.spilledPartitions, 4); + ASSERT_EQ(planStats.spilledFiles, 4); + ASSERT_EQ(planStats.spilledRows, 120); +} + +DEBUG_ONLY_TEST_F(MergeTest, localMergeSpillWithException) { + std::vector vectors; + for (int32_t i = 0; i < 9; ++i) { + constexpr vector_size_t batchSize = 137; auto c0 = makeFlatVector( batchSize, [&](auto row) { return batchSize * i + row; }, nullEvery(5)); auto c1 = makeFlatVector( @@ -150,6 +618,200 @@ TEST_F(MergeTest, localMerge) { } createDuckDbTable(vectors); + for (auto i = 0; i < 11; ++i) { + std::atomic_int cnt{0}; + const auto errorMessage = "ConcatFilesSpillBatchStream::nextBatch fail"; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::ConcatFilesSpillBatchStream::nextBatch", + std::function([&](void* /*unused*/) { + if (cnt++ == i) { + VELOX_FAIL("ConcatFilesSpillBatchStream::nextBatch fail"); + } + })); + + VELOX_ASSERT_THROW(testSingleKeyWithSpill(vectors, "c0"), errorMessage); + } +} + +DEBUG_ONLY_TEST_F(MergeTest, localMergeSmallBatch) { + std::vector vectors; + for (int32_t i = 0; i < 9; ++i) { + auto batchSize = 30; + if (i != 0) { + batchSize = 0; + } + + auto c0 = makeFlatVector( + batchSize, [&](auto row) { return batchSize * i + row; }, nullEvery(5)); + auto c1 = makeFlatVector( + batchSize, [&](auto row) { return row; }, nullEvery(5)); + auto c2 = makeFlatVector( + batchSize, [](auto row) { return row * 0.1; }, nullEvery(11)); + auto c3 = makeFlatVector(batchSize, [](auto row) { + return StringView::makeInline(std::to_string(row)); + }); + vectors.push_back(makeRowVector({c0, c1, c2, c3})); + } + createDuckDbTable(vectors); + auto keyIndex = vectors[0]->type()->asRow().getChildIdx("c0"); + + const auto orderByClause = fmt::format("{}", "c0"); + const auto planNodeIdGenerator = + std::make_shared(); + const std::shared_ptr spillDirectory = + TempDirectoryPath::create(); + std::vector> sources; + sources.reserve(vectors.size()); + for (const auto& vector : vectors) { + sources.push_back(PlanBuilder(planNodeIdGenerator) + .values({vector}) + .orderBy({orderByClause}, true) + .planNode()); + } + core::PlanNodeId nodeId; + const auto plan = PlanBuilder(planNodeIdGenerator) + .localMerge({orderByClause}, std::move(sources)) + .capturePlanNodeId(nodeId) + .planNode(); + + std::atomic_bool blockFlag{true}; + folly::Promise promise; + folly::EventCount callWait; + std::atomic_bool callWaitFlag{true}; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::SourceMerger::getOutput", + std::function*)>( + [&](std::vector* sourceBlockingFutures) { + if (blockFlag) { + blockFlag = false; + auto [p, f] = folly::makePromiseContract(); + sourceBlockingFutures->push_back(std::move(f)); + promise = std::move(p); + callWaitFlag = false; + callWait.notifyAll(); + } + })); + + std::thread promiseThread([&]() { + callWait.await([&]() { return !callWaitFlag.load(); }); + std::this_thread::sleep_for(std::chrono::milliseconds(1'000)); // NOLINT + promise.setValue(); + }); + + CursorParameters params; + params.planNode = plan; + auto task = assertQueryOrdered( + params, "SELECT * FROM tmp ORDER BY " + orderByClause, {keyIndex}); + auto taskStats = toPlanStats(task->taskStats()); + ASSERT_EQ(taskStats[nodeId].outputRows, 30); + ASSERT_EQ(taskStats[nodeId].outputVectors, 1); + promiseThread.join(); +} + +DEBUG_ONLY_TEST_F(MergeTest, localMergeAbort) { + std::vector> inputVectors; + for (int32_t i = 0; i < 4; ++i) { + std::vector vectors; + for (int32_t j = 0; j < 13; ++j) { + constexpr auto batchSize = 5000; + auto c0 = makeFlatVector( + batchSize, + [&](auto row) { return batchSize * j + row; }, + nullEvery(5)); + auto c1 = makeFlatVector( + batchSize, [&](auto row) { return row; }, nullEvery(5)); + auto c2 = makeFlatVector( + batchSize, [](auto row) { return row * 0.1; }, nullEvery(11)); + auto c3 = makeFlatVector(batchSize, [](auto row) { + return StringView::makeInline(std::to_string(row)); + }); + vectors.push_back(makeRowVector({c0, c1, c2, c3})); + } + inputVectors.push_back(std::move(vectors)); + } + + const auto orderByClause = fmt::format("{}", "c0"); + const auto planNodeIdGenerator = + std::make_shared(); + const std::shared_ptr spillDirectory = + TempDirectoryPath::create(); + std::vector> sources; + sources.reserve(inputVectors.size()); + for (const auto& vectors : inputVectors) { + sources.push_back(PlanBuilder(planNodeIdGenerator) + .values(vectors) + .orderBy({orderByClause}, true) + .planNode()); + } + core::PlanNodeId nodeId; + const auto plan = PlanBuilder(planNodeIdGenerator) + .localMerge({orderByClause}, std::move(sources)) + .capturePlanNodeId(nodeId) + .planNode(); + std::atomic_int cnt{0}; + std::atomic_bool blocked{false}; + folly::EventCount callWait; + std::atomic_bool callWaitFlag{true}; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::SpillMerger::getOutput", + std::function*)>( + [&](std::vector* /*unused*/) { + if (blocked) { + std::this_thread::sleep_for( + std::chrono::milliseconds(1'000)); // NOLINT + blocked = false; + callWaitFlag = false; + callWait.notifyAll(); + VELOX_USER_FAIL("Abort merge"); + } + })); + + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::SpillMerger::readFromSpillFileStream", + std::function([&](void* /*unused*/) { + if (cnt++ == 2) { + blocked = true; + callWait.await([&]() { return callWaitFlag.load(); }); + std::this_thread::sleep_for( + std::chrono::milliseconds(1'000)); // NOLINT + } + })); + + auto queryCtx = createQueryCtx(); + VELOX_ASSERT_THROW( + AssertQueryBuilder(plan) + .spillDirectory(spillDirectory->getPath()) + .queryCtx(queryCtx) + .config(core::QueryConfig::kSpillEnabled, true) + .config(core::QueryConfig::kLocalMergeSpillEnabled, true) + .config(core::QueryConfig::kLocalMergeMaxNumMergeSources, 2) + .config(core::QueryConfig::kMaxOutputBatchRows, 10) + .config(core::QueryConfig::kPreferredOutputBatchRows, 10) + .copyResults(pool()), + "Abort merge"); + std::dynamic_pointer_cast(spillExecutor_) + ->join(); +} + +TEST_F(MergeTest, localMerge) { + std::vector vectors; + for (int32_t i = 0; i < 3; ++i) { + static constexpr vector_size_t kBatchSize = 100; + auto c0 = makeFlatVector( + kBatchSize, + [&](auto row) { return kBatchSize * i + row; }, + nullEvery(5)); + auto c1 = makeFlatVector( + kBatchSize, [&](auto row) { return row; }, nullEvery(5)); + auto c2 = makeFlatVector( + kBatchSize, [](auto row) { return row * 0.1; }, nullEvery(11)); + auto c3 = makeFlatVector(kBatchSize, [](auto row) { + return StringView::makeInline(std::to_string(row)); + }); + vectors.push_back(makeRowVector({c0, c1, c2, c3})); + } + createDuckDbTable(vectors); + testSingleKey(vectors, "c0"); testSingleKey(vectors, "c3"); @@ -238,3 +900,50 @@ TEST_F(MergeTest, offByOne) { {{core::QueryConfig::kPreferredOutputBatchRows, "6"}}); assertQueryOrdered(params, "VALUES (0), (1), (2), (3), (4), (5), (10)", {0}); } + +TEST_F(MergeTest, localMergeOutputSizeWithoutSpill) { + struct TestParam { + int numSources; + int numInputBatches; + int inputBatchSize; + int maxOutputBatchRows; + int maxOutputBatchBytes; + int numExpectedOutputBatches; + + std::string debugString() const { + return fmt::format( + "numSources {}, numInputBatches {}, inputBatchSize {}, maxOutputBatchRows {}, maxOutputBatchBytes {}, numExpectedOutputBatches {}", + numSources, + numInputBatches, + inputBatchSize, + maxOutputBatchRows, + maxOutputBatchBytes, + numExpectedOutputBatches); + } + } testSettings[]{ + {1, 1, 1, 1, 1'000'000'000, 1}, + {3, 1, 1, 1, 1'000'000'000, 3}, + {3, 4, 1, 1, 1'000'000'000, 3 * 4}, + {3, 4, 32, 1, 1'000'000'000, 3 * 4 * 32}, + {1, 1, 1, 1024, 1'000'000'000, 1}, + {3, 1, 1, 1024, 1'000'000'000, 1}, + {3, 4, 32, 1024, 1'000'000'000, 1}, + {1, 1, 1, 1, 1, 1}, + {3, 1, 1, 1, 1, 3}, + {3, 4, 1, 1, 1, 3 * 4}, + {3, 4, 32, 1, 1, 3 * 4 * 32}, + {1, 1, 1, 1024, 1, 1}, + {3, 1, 1, 1024, 1, 3}, + {3, 4, 1, 1024, 1, 3 * 4}, + {3, 4, 32, 1024, 1, 3 * 4 * 32}}; + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + testLocalMerge( + testData.numSources, + testData.numInputBatches, + testData.inputBatchSize, + testData.maxOutputBatchRows, + testData.maxOutputBatchBytes, + testData.numExpectedOutputBatches); + } +} diff --git a/velox/exec/tests/MergerTest.cpp b/velox/exec/tests/MergerTest.cpp new file mode 100644 index 000000000000..1d848bd115ae --- /dev/null +++ b/velox/exec/tests/MergerTest.cpp @@ -0,0 +1,461 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/common/file/FileSystems.h" +#include "velox/exec/Merge.h" +#include "velox/exec/MergeSource.h" +#include "velox/exec/SortBuffer.h" +#include "velox/exec/Spill.h" +#include "velox/exec/tests/utils/OperatorTestBase.h" +#include "velox/exec/tests/utils/TempDirectoryPath.h" +#include "velox/type/Type.h" +#include "velox/vector/fuzzer/VectorFuzzer.h" +#include "velox/vector/tests/utils/VectorTestBase.h" + +#include + +using namespace facebook::velox::exec; +using namespace facebook::velox::exec::test; +using namespace facebook::velox; +using namespace facebook::velox::memory; + +namespace facebook::velox::exec::test { + +class MergerTest : public OperatorTestBase { + protected: + void SetUp() override { + OperatorTestBase::SetUp(); + filesystems::registerLocalFileSystem(); + } + + std::vector generateSortedVectors( + const int32_t numVectors, + const size_t vectorSize) { + const VectorFuzzer::Options fuzzerOpts{.vectorSize = vectorSize}; + const auto vectors = createVectors(numVectors, inputType_, fuzzerOpts); + const auto sortBuffer = std::make_unique( + inputType_, + sortColumnIndices_, + sortCompareFlags_, + pool_.get(), + &nonReclaimableSection_, + common::PrefixSortConfig{}, + nullptr, + nullptr); + for (const auto& vector : vectors) { + sortBuffer->addInput(vector); + } + sortBuffer->noMoreInput(); + std::vector sortedVectors; + sortedVectors.reserve(numVectors); + for (auto i = 0; i < numVectors; ++i) { + sortedVectors.emplace_back(sortBuffer->getOutput(vectorSize)); + } + return sortedVectors; + } + + SpillFiles generateSortedSpillFiles( + const std::vector& sortedVectors) { + const auto spiller = std::make_unique( + inputType_, + std::nullopt, + HashBitRange{}, + sortingKeys_, + &spillConfig_, + spillStats_.get()); + for (const auto& vector : sortedVectors) { + spiller->spill(SpillPartitionId(0), vector); + } + SpillPartitionSet spillPartitionSet; + spiller->finishSpill(spillPartitionSet); + EXPECT_EQ(spillPartitionSet.size(), 1); + return spillPartitionSet.cbegin()->second->files(); + } + + std::pair< + std::vector>, + std::vector>>> + generateInputs(size_t numStreams, size_t vectorSize) { + std::vector> totalVectors; + std::vector>> + spillReadFilesGroups; + for (auto i = 1; i <= numStreams; ++i) { + const auto vectors = generateSortedVectors(i * 3 + 1, vectorSize); + totalVectors.push_back(vectors); + const auto spillFiles = generateSortedSpillFiles(vectors); + EXPECT_EQ(spillFiles.size(), vectors.size()); + std::vector> spillReadFiles; + spillReadFiles.reserve(spillFiles.size()); + for (const auto& spillFile : spillFiles) { + spillReadFiles.emplace_back( + SpillReadFile::create( + spillFile, + spillConfig_.readBufferSize, + pool_.get(), + spillStats_.get())); + } + spillReadFilesGroups.emplace_back(std::move(spillReadFiles)); + } + return std::make_pair( + std::move(totalVectors), std::move(spillReadFilesGroups)); + } + + std::vector makeExpectedResults( + const std::vector>& inputs, + size_t vectorSize) { + std::vector flatInputs; + for (const auto& vectors : inputs) { + for (const auto& vector : vectors) { + flatInputs.emplace_back(vector); + } + } + const auto sortBuffer = std::make_unique( + inputType_, + sortColumnIndices_, + sortCompareFlags_, + pool_.get(), + &nonReclaimableSection_, + common::PrefixSortConfig{}, + nullptr, + nullptr); + for (const auto& vector : flatInputs) { + sortBuffer->addInput(vector); + } + sortBuffer->noMoreInput(); + std::vector sortedVectors; + sortedVectors.reserve(flatInputs.size()); + for (auto i = 0; i < flatInputs.size(); ++i) { + sortedVectors.emplace_back(sortBuffer->getOutput(vectorSize)); + } + return sortedVectors; + } + + std::unique_ptr createSourceMerger( + const std::vector>& sources, + uint64_t outputBatchRows, + uint64_t outputBatchBytes) const { + std::vector> sourceStreams; + for (const auto& source : sources) { + sourceStreams.push_back( + std::make_unique( + source.get(), sortingKeys_, outputBatchRows)); + } + return std::make_unique( + inputType_, + std::move(sourceStreams), + outputBatchRows, + outputBatchBytes, + pool()); + } + + static std::vector> createMergeSources( + size_t numSources, + size_t queueSize) { + std::vector> sources; + sources.reserve(numSources); + for (auto i = 0; i < numSources; ++i) { + sources.push_back(MergeSource::createLocalMergeSource(queueSize)); + } + for (const auto& source : sources) { + source->start(); + } + return sources; + } + + void produceAsync( + MergeSource* mergeSource, + const std::vector& vectors, + size_t index = 0) const { + ContinueFuture future; + if (index >= vectors.size()) { + const auto reason = mergeSource->enqueue(nullptr, &future); + EXPECT_EQ(reason, BlockingReason::kNotBlocked); + return; + } + + mergeSource->enqueue(vectors[index], &future); + std::move(future) + .via(executor_.get()) + .thenValue([this, mergeSource, &vectors, index](folly::Unit) { + produceAsync(mergeSource, vectors, index + 1); + }) + .thenError(folly::tag_t{}, [](const std::exception& e) { + VELOX_FAIL(e.what()); + }); + } + + void createProducers( + int num, + const std::vector>& inputs, + const std::vector>& sources) const { + for (auto i = 0; i < inputs.size(); ++i) { + executor_->add([&, i]() { produceAsync(sources[i].get(), inputs[i]); }); + } + } + + static std::vector getOutputFromSourceMerger( + SourceMerger* sourceMerger) { + std::vector sourceBlockingFutures; + std::vector results; + for (;;) { + sourceMerger->isBlocked(sourceBlockingFutures); + if (!sourceBlockingFutures.empty()) { + auto future = std::move(sourceBlockingFutures.back()); + sourceBlockingFutures.pop_back(); + future.wait(); + continue; + } + bool atEnd = false; + auto output = sourceMerger->getOutput(sourceBlockingFutures, atEnd); + if (output != nullptr) { + results.emplace_back(std::move(output)); + } + + if (atEnd) { + break; + } + } + return results; + } + + std::shared_ptr createSpillMerger( + std::vector>> + spillReadFilesGroups, + vector_size_t outputBatchRows, + int queueSize) const { + return std::make_shared( + sortingKeys_, + inputType_, + std::move(spillReadFilesGroups), + outputBatchRows, + std::numeric_limits::max(), + queueSize, + &spillConfig_, + spillStats_, + pool()); + } + + static std::vector getOutputFromSpillMerger( + SpillMerger* spillMerger) { + std::vector sourceBlockingFutures; + std::vector results; + for (;;) { + bool atEnd = false; + auto output = spillMerger->getOutput(sourceBlockingFutures, atEnd); + if (output != nullptr) { + results.emplace_back(std::move(output)); + } + + if (atEnd) { + break; + } + + while (!sourceBlockingFutures.empty()) { + auto future = std::move(sourceBlockingFutures.back()); + sourceBlockingFutures.pop_back(); + future.wait(); + } + } + return results; + } + + static void checkResults( + std::vector expectedResults, + std::vector actualResults) { + ASSERT_TRUE(assertEqualResults(expectedResults, actualResults)); + const auto& actual = actualResults[0]; + std::for_each( + std::next(actualResults.begin()), + actualResults.end(), + [&](const auto& ele) { actual->append(ele.get()); }); + const auto& expect = expectedResults[0]; + std::for_each( + std::next(expectedResults.begin()), + expectedResults.end(), + [&](const auto& ele) { expect->append(ele.get()); }); + facebook::velox::test::assertEqualVectors(expect, actual); + } + + private: + const RowTypePtr inputType_ = ROW({{"c0", BIGINT()}, {"c1", SMALLINT()}}); + const std::shared_ptr executor_{ + std::make_shared( + folly::hardware_concurrency())}; + const std::vector sortColumnIndices_{0, 1}; + const std::vector sortCompareFlags_{ + CompareFlags{.ascending = true}, + CompareFlags{.ascending = false}}; + const std::vector sortingKeys_ = + SpillState::makeSortingKeys(sortColumnIndices_, sortCompareFlags_); + const std::shared_ptr spillDirectory_ = + exec::test::TempDirectoryPath::create(); + const common::SpillConfig spillConfig_{ + [&]() -> const std::string& { return spillDirectory_->getPath(); }, + [&](uint64_t) {}, + "0.0.0", + 10, // Force to create a file per spill to mock multiple files per stream + 0, + 1 << 20, + executor_.get(), + 100, + 100, + 0, + 0, + 0, + 0, + 0, + "none", + 0, + std::nullopt}; + + std::shared_ptr> spillStats_ = + std::make_shared>(); + tsan_atomic nonReclaimableSection_{false}; +}; +} // namespace facebook::velox::exec::test + +TEST_F(MergerTest, sourceMerger) { + struct TestSetting { + size_t maxOutputRows; + size_t maxOutputBytes; + size_t numSources; + size_t queueSize; + + std::string debugString() const { + return fmt::format( + "maxOutputRows:{}, maxOutputBytes:{}, numStreams:{}, queueSize:{}", + maxOutputRows, + succinctBytes(maxOutputBytes), + numSources, + queueSize); + } + }; + std::vector testSettings; + for (size_t maxOutputRows : {1, 7, 16}) { + for (size_t numSources : {1, 3, 8}) { + for (size_t queueSize : {1, 2}) { + testSettings.push_back( + {maxOutputRows, 1'000'000'000, numSources, queueSize}); + } + } + } + testSettings.push_back({32, 1'000'000'000, 3, 2}); + testSettings.push_back({1024, 1'000'000'000, 8, 2}); + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + std::vector> inputs; + for (auto i = 1; i <= testData.numSources; ++i) { + inputs.emplace_back(generateSortedVectors(i, 16)); + } + const auto sources = + createMergeSources(testData.numSources, testData.queueSize); + const auto sourceMerger = createSourceMerger( + sources, testData.maxOutputRows, testData.maxOutputBytes); + createProducers(testData.numSources, inputs, sources); + const auto results = getOutputFromSourceMerger(sourceMerger.get()); + const auto expectedResults = makeExpectedResults(inputs, 16); + checkResults(expectedResults, results); + } +} + +TEST_F(MergerTest, sourceMergerWithEmptySources) { + std::vector> inputs; + for (auto i = 0; i < 10; ++i) { + const auto numVectors = (i % 2 == 0) ? 0 : i; + inputs.emplace_back(generateSortedVectors(numVectors, 16)); + } + const auto sources = createMergeSources(10, 2); + const auto sourceMerger = createSourceMerger(sources, 32, 1000'000'000); + createProducers(10, inputs, sources); + const auto results = getOutputFromSourceMerger(sourceMerger.get()); + const auto expectedResults = makeExpectedResults(inputs, 16); + checkResults(expectedResults, results); +} + +TEST_F(MergerTest, spillMerger) { + struct TestSetting { + size_t maxOutputRows; + size_t numSources; + size_t queueSize; + + std::string debugString() const { + return fmt::format( + "maxOutputRows:{}, numStreams:{}, queueSize:{}", + maxOutputRows, + numSources, + queueSize); + } + }; + std::vector testSettings; + for (size_t maxOutputRows : {1, 7, 16}) { + for (size_t numSources : {1, 3, 8}) { + for (size_t queueSize : {1, 2}) { + testSettings.push_back({maxOutputRows, numSources, queueSize}); + } + } + } + testSettings.push_back({32, 3, 2}); + testSettings.push_back({1024, 8, 2}); + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + const auto sources = + createMergeSources(testData.numSources, testData.queueSize); + auto [inputs, filesGroup] = generateInputs(testData.numSources, 16); + ASSERT_EQ(filesGroup.size(), testData.numSources); + const auto spillMerger = createSpillMerger( + std::move(filesGroup), testData.maxOutputRows, testData.queueSize); + spillMerger->start(); + const auto results = getOutputFromSpillMerger(spillMerger.get()); + const auto expectedResults = makeExpectedResults(inputs, 16); + checkResults(expectedResults, results); + } +} + +DEBUG_ONLY_TEST_F(MergerTest, spillMergerException) { + struct TestSetting { + size_t maxOutputRows; + size_t numSources; + size_t queueSize; + + std::string debugString() const { + return fmt::format( + "maxOutputRows:{}, numStreams:{}, queueSize:{}", + maxOutputRows, + numSources, + queueSize); + } + }; + + std::atomic_int cnt{0}; + const auto errorMessage = "ConcatFilesSpillBatchStream::nextBatch fail"; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::ConcatFilesSpillBatchStream::nextBatch", + std::function([&](void* /*unused*/) { + if (cnt++ == 11) { + VELOX_FAIL("ConcatFilesSpillBatchStream::nextBatch fail"); + } + })); + const auto numSources = 5; + const auto queueSize = 2; + const auto sources = createMergeSources(numSources, queueSize); + auto [inputs, filesGroup] = generateInputs(numSources, 16); + const auto spillMerger = + createSpillMerger(std::move(filesGroup), 100, queueSize); + spillMerger->start(); + VELOX_ASSERT_THROW(getOutputFromSpillMerger(spillMerger.get()), errorMessage); +} diff --git a/velox/exec/tests/MultiFragmentTest.cpp b/velox/exec/tests/MultiFragmentTest.cpp index 52c46c20180a..13c36172a9fb 100644 --- a/velox/exec/tests/MultiFragmentTest.cpp +++ b/velox/exec/tests/MultiFragmentTest.cpp @@ -16,6 +16,7 @@ #include "folly/experimental/EventCount.h" #include "velox/common/base/tests/GTestUtils.h" #include "velox/common/testutil/TestValue.h" +#include "velox/connectors/hive/HiveConnector.h" #include "velox/connectors/hive/HiveConnectorSplit.h" #include "velox/dwio/common/tests/utils/BatchMaker.h" #include "velox/exec/Exchange.h" @@ -23,11 +24,13 @@ #include "velox/exec/PartitionedOutput.h" #include "velox/exec/PlanNodeStats.h" #include "velox/exec/RoundRobinPartitionFunction.h" +#include "velox/exec/TableScan.h" #include "velox/exec/tests/utils/AssertQueryBuilder.h" #include "velox/exec/tests/utils/HiveConnectorTestBase.h" #include "velox/exec/tests/utils/LocalExchangeSource.h" #include "velox/exec/tests/utils/PlanBuilder.h" #include "velox/exec/tests/utils/SerializedPageUtil.h" +#include "velox/exec/tests/utils/TempDirectoryPath.h" using namespace facebook::velox::exec::test; @@ -101,13 +104,14 @@ class MultiFragmentTest : public HiveConnectorTestBase, int destination = 0, Consumer consumer = nullptr, int64_t maxMemory = memory::kMaxMemory, - folly::Executor* executor = nullptr) { + folly::Executor* executor = nullptr) const { auto configCopy = configSettings_; auto queryCtx = core::QueryCtx::create( executor ? executor : executor_.get(), core::QueryConfig(std::move(configCopy))); - queryCtx->testingOverrideMemoryPool(memory::memoryManager()->addRootPool( - queryCtx->queryId(), maxMemory, MemoryReclaimer::create())); + queryCtx->testingOverrideMemoryPool( + memory::memoryManager()->addRootPool( + queryCtx->queryId(), maxMemory, MemoryReclaimer::create())); core::PlanFragment planFragment{planNode}; return Task::create( taskId, @@ -118,6 +122,41 @@ class MultiFragmentTest : public HiveConnectorTestBase, std::move(consumer)); } + std::shared_ptr makeTask( + const std::string& taskId, + const core::PlanNodePtr& planNode, + std::unordered_map& extraQueryConfigs, + int destination = 0, + Consumer consumer = nullptr, + int64_t maxMemory = memory::kMaxMemory, + const std::optional& diskSpillOpts = + std::nullopt) const { + auto configCopy = configSettings_; + for (const auto& [k, v] : extraQueryConfigs) { + configCopy[k] = v; + } + auto queryCtx = core::QueryCtx::create( + executor_.get(), + core::QueryConfig(std::move(configCopy)), + {}, + nullptr, + nullptr, + executor_.get()); + queryCtx->testingOverrideMemoryPool( + memory::memoryManager()->addRootPool( + queryCtx->queryId(), maxMemory, MemoryReclaimer::create())); + core::PlanFragment planFragment{planNode}; + return Task::create( + taskId, + std::move(planFragment), + destination, + std::move(queryCtx), + Task::ExecutionMode::kParallel, + std::move(consumer), + /*memoryArbitrationPriority=*/0, + diskSpillOpts); + } + std::vector makeVectors(int count, int rowsPerVector) { std::vector vectors; for (int i = 0; i < count; ++i) { @@ -128,8 +167,8 @@ class MultiFragmentTest : public HiveConnectorTestBase, return vectors; } - void addHiveSplits( - std::shared_ptr task, + static void addHiveSplits( + const std::shared_ptr& task, const std::vector>& filePaths) { for (auto& filePath : filePaths) { auto split = exec::Split( @@ -144,11 +183,33 @@ class MultiFragmentTest : public HiveConnectorTestBase, task->noMoreSplits("0"); } - exec::Split remoteSplit(const std::string& taskId) { + static void addHiveSplits( + const std::shared_ptr& task, + const std::unordered_map< + core::PlanNodeId, + std::vector>>& nodeSplits) { + for (const auto& [scanNodeId, splitPaths] : nodeSplits) { + for (const auto& filePath : splitPaths) { + auto split = exec::Split( + std::make_shared( + kHiveConnectorId, + "file:" + filePath->getPath(), + facebook::velox::dwio::common::FileFormat::DWRF), + -1); + task->addSplit(scanNodeId, std::move(split)); + } + } + + for (const auto& [scanNodeId, _] : nodeSplits) { + task->noMoreSplits(scanNodeId); + } + } + + static exec::Split remoteSplit(const std::string& taskId) { return exec::Split(std::make_shared(taskId)); } - void addRemoteSplits( + static void addRemoteSplits( std::shared_ptr task, const std::vector& remoteTaskIds) { for (auto& taskId : remoteTaskIds) { @@ -197,7 +258,7 @@ class MultiFragmentTest : public HiveConnectorTestBase, exchangeStats.at("localExchangeSource.numPages").count); ASSERT_EQ( expectedBackgroundCpuCount, - exchangeStats.at(ExchangeClient::kBackgroundCpuTimeMs).count); + exchangeStats.at(Operator::kBackgroundCpuTimeNanos).count); ASSERT_EQ( expectedBackgroundCpuCount, taskStats.at("0").backgroundTiming.count); } @@ -525,42 +586,61 @@ TEST_P(MultiFragmentTest, abortMergeExchange) { TEST_P(MultiFragmentTest, mergeExchange) { setupSources(20, 1000); - static const core::SortOrder kAscNullsLast(true, false); std::vector> tasks; - std::vector> filePaths0( filePaths_.begin(), filePaths_.begin() + 10); std::vector> filePaths1( filePaths_.begin() + 10, filePaths_.end()); - std::vector>> filePathsList = { filePaths0, filePaths1}; - std::vector partialSortTaskIds; RowTypePtr outputType; - core::PlanNodeId partitionNodeId; - for (int i = 0; i < 2; ++i) { + for (int numPartialSortTasks = 0; numPartialSortTasks < 2; + ++numPartialSortTasks) { auto sortTaskId = makeTaskId("orderby", tasks.size()); partialSortTaskIds.push_back(sortTaskId); auto planNodeIdGenerator = std::make_shared(); + std::vector> sources; + sources.reserve(5); + std::vector scanNodeIds; + scanNodeIds.reserve(5); + for (int numSourcesPerPartialSortTask = 0; numSourcesPerPartialSortTask < 5; + ++numSourcesPerPartialSortTask) { + core::PlanNodeId scanNodeId; + sources.push_back(PlanBuilder(planNodeIdGenerator) + .tableScan(rowType_) + .capturePlanNodeId(scanNodeId) + .orderBy({"c0"}, true) + .planNode()); + scanNodeIds.push_back(scanNodeId); + } + auto partialSortPlan = PlanBuilder(planNodeIdGenerator) - .localMerge( - {"c0"}, - {PlanBuilder(planNodeIdGenerator) - .tableScan(rowType_) - .orderBy({"c0"}, true) - .planNode()}) + .localMerge({"c0"}, std::move(sources)) .partitionedOutput({}, 1, /*outputLayout=*/{}, GetParam().serdeKind) .capturePlanNodeId(partitionNodeId) .planNode(); - auto sortTask = makeTask(sortTaskId, partialSortPlan, tasks.size()); tasks.push_back(sortTask); sortTask->start(4); - addHiveSplits(sortTask, filePathsList[i]); + + std::unordered_map< + core::PlanNodeId, + std::vector>> + nodeSplits; + nodeSplits.reserve(5); + for (int i = 0; i < 5; ++i) { + nodeSplits[scanNodeIds[i]] = std::vector>(); + } + auto& filePaths = filePathsList[numPartialSortTasks]; + for (int i = 0; i < filePaths.size(); ++i) { + nodeSplits[scanNodeIds[i % scanNodeIds.size()]].push_back(filePaths[i]); + } + + addHiveSplits(sortTask, nodeSplits); outputType = partialSortPlan->outputType(); } @@ -795,6 +875,138 @@ TEST_P(MultiFragmentTest, partitionedOutput) { } } +TEST_P(MultiFragmentTest, mergeExchangeWithSpill) { + setupSources(20, 1000); + static const core::SortOrder kAscNullsLast(true, false); + std::vector> tasks; + std::vector> filePaths0( + filePaths_.begin(), filePaths_.begin() + 10); + std::vector> filePaths1( + filePaths_.begin() + 10, filePaths_.end()); + std::vector>> filePathsList = { + filePaths0, filePaths1}; + std::vector partialSortTaskIds; + RowTypePtr outputType; + std::vector> spillDirectories; + core::PlanNodeId partitionNodeId; + std::unordered_map spillMergeConfigs{ + {"spill_enabled", "true"}, + {"local_merge_spill_enabled", "true"}, + {"local_merge_max_num_merge_sources", "3"}}; + std::vector localMergeNodeIds; + for (int numPartialSortTasks = 0; numPartialSortTasks < 2; + ++numPartialSortTasks) { + auto sortTaskId = makeTaskId("orderby", tasks.size()); + partialSortTaskIds.push_back(sortTaskId); + auto planNodeIdGenerator = std::make_shared(); + std::vector> sources; + sources.reserve(5); + std::vector scanNodeIds; + scanNodeIds.reserve(5); + for (int numSourcesPerPartialSortTask = 0; numSourcesPerPartialSortTask < 5; + ++numSourcesPerPartialSortTask) { + core::PlanNodeId scanNodeId; + sources.push_back(PlanBuilder(planNodeIdGenerator) + .tableScan(rowType_) + .capturePlanNodeId(scanNodeId) + .orderBy({"c0"}, true) + .planNode()); + scanNodeIds.push_back(scanNodeId); + } + + core::PlanNodeId localMergeNodeId; + auto partialSortPlan = + PlanBuilder(planNodeIdGenerator) + .localMerge({"c0"}, std::move(sources)) + .capturePlanNodeId(localMergeNodeId) + .partitionedOutput({}, 1, /*outputLayout=*/{}, GetParam().serdeKind) + .capturePlanNodeId(partitionNodeId) + .planNode(); + localMergeNodeIds.push_back(localMergeNodeId); + spillDirectories.push_back(TempDirectoryPath::create()); + common::SpillDiskOptions spillOpts; + spillOpts.spillDirPath = spillDirectories[numPartialSortTasks]->getPath(); + auto sortTask = makeTask( + sortTaskId, + partialSortPlan, + spillMergeConfigs, + tasks.size(), + /*consumer=*/nullptr, + memory::kMaxMemory, + spillOpts); + tasks.push_back(sortTask); + sortTask->start(4); + + std::unordered_map< + core::PlanNodeId, + std::vector>> + nodeSplits; + nodeSplits.reserve(5); + for (int i = 0; i < 5; ++i) { + nodeSplits[scanNodeIds[i]] = std::vector>(); + } + auto& filePaths = filePathsList[numPartialSortTasks]; + for (int i = 0; i < filePaths.size(); ++i) { + nodeSplits[scanNodeIds[i % scanNodeIds.size()]].push_back(filePaths[i]); + } + + addHiveSplits(sortTask, nodeSplits); + outputType = partialSortPlan->outputType(); + } + + auto finalSortTaskId = makeTaskId("orderby", tasks.size()); + core::PlanNodeId mergeExchangeId; + auto finalSortPlan = + PlanBuilder() + .mergeExchange(outputType, {"c0"}, GetParam().serdeKind) + .capturePlanNodeId(mergeExchangeId) + .partitionedOutput({}, 1, /*outputLayout=*/{}, GetParam().serdeKind) + .planNode(); + + auto mergeTask = makeTask(finalSortTaskId, finalSortPlan, 0); + tasks.push_back(mergeTask); + mergeTask->start(1); + addRemoteSplits(mergeTask, partialSortTaskIds); + + auto op = PlanBuilder().exchange(outputType, GetParam().serdeKind).planNode(); + + test::AssertQueryBuilder(op, duckDbQueryRunner_) + .split(remoteSplit(finalSortTaskId)) + .config( + core::QueryConfig::kShuffleCompressionKind, + common::compressionKindToString(GetParam().compressionKind)) + .assertResults( + "SELECT * FROM tmp ORDER BY 1 NULLS LAST", std::vector{0}); + + for (auto& task : tasks) { + ASSERT_TRUE(waitForTaskCompletion(task.get())) << task->taskId(); + } + for (auto i = 0; i < 2; ++i) { + auto taskStats = toPlanStats(tasks[i]->taskStats()); + auto& planStats = taskStats.at(localMergeNodeIds[i]); + ASSERT_GT(planStats.spilledBytes, 0); + ASSERT_GT(planStats.spilledPartitions, 0); + ASSERT_GT(planStats.spilledFiles, 0); + } + + const auto finalSortStats = toPlanStats(mergeTask->taskStats()); + const auto& mergeExchangeStats = finalSortStats.at(mergeExchangeId); + + EXPECT_EQ(20'000, mergeExchangeStats.inputRows); + EXPECT_EQ(20'000, mergeExchangeStats.rawInputRows); + + EXPECT_LT(0, mergeExchangeStats.inputBytes); + EXPECT_LT(0, mergeExchangeStats.rawInputBytes); + + const auto serdeKindRuntimsStats = + mergeExchangeStats.customStats.at(Operator::kShuffleSerdeKind); + ASSERT_EQ(serdeKindRuntimsStats.count, 1); + ASSERT_EQ( + serdeKindRuntimsStats.min, static_cast(GetParam().serdeKind)); + ASSERT_EQ( + serdeKindRuntimsStats.max, static_cast(GetParam().serdeKind)); +} + TEST_P(MultiFragmentTest, noHashPartitionSkew) { setupSources(10, 1000); @@ -1411,6 +1623,56 @@ TEST_P(MultiFragmentTest, mergeExchangeOverEmptySources) { } } +DEBUG_ONLY_TEST_P(MultiFragmentTest, mergeExchangeFailureOnStart) { + std::vector> tasks; + std::vector leafTaskIds; + + const auto injectErrorMsg{"injectError"}; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::ExchangeQueue::dequeueLocked", + std::function( + ([&](ExchangeQueue* /*unused*/) { VELOX_FAIL(injectErrorMsg); }))); + + auto data = makeRowVector(rowType_, 0); + + for (int i = 0; i < 2; ++i) { + auto taskId = makeTaskId("leaf-", i); + leafTaskIds.push_back(taskId); + auto plan = + PlanBuilder() + .values({data}) + .partitionedOutput({}, 1, /*outputLayout=*/{}, GetParam().serdeKind) + .planNode(); + + auto task = makeTask(taskId, plan, tasks.size()); + tasks.push_back(task); + task->start(4); + } + + auto exchangeTaskId = makeTaskId("exchange-", 0); + auto plan = PlanBuilder() + .mergeExchange(rowType_, {"c0"}, GetParam().serdeKind) + .singleAggregation({"c0"}, {"count(1)"}) + .planNode(); + + std::vector leafTaskSplits; + for (auto leafTaskId : leafTaskIds) { + leafTaskSplits.emplace_back(remoteSplit(leafTaskId)); + } + VELOX_ASSERT_THROW( + test::AssertQueryBuilder(plan, duckDbQueryRunner_) + .splits(std::move(leafTaskSplits)) + .config( + core::QueryConfig::kShuffleCompressionKind, + common::compressionKindToString(GetParam().compressionKind)) + .assertResults(""), + injectErrorMsg); + + for (auto& task : tasks) { + ASSERT_TRUE(waitForTaskCompletion(task.get())) << task->taskId(); + } +} + namespace { core::PlanNodePtr makeJoinOverExchangePlan( const RowTypePtr& exchangeType, @@ -2164,7 +2426,7 @@ DEBUG_ONLY_TEST_P( std::thread failThread([&]() { try { VELOX_FAIL("Test terminate task"); - } catch (const VeloxException& e) { + } catch (const VeloxException&) { task->setError(std::current_exception()); } }); @@ -2237,8 +2499,8 @@ DEBUG_ONLY_TEST_P(MultiFragmentTest, mergeWithEarlyTermination) { mergeIsBlockedReady.store(true); mergeIsBlockedWait.notifyAll(); - ASSERT_TRUE(waitForTaskCompletion(partialSortTask.get(), 1'000'000'000)); - ASSERT_TRUE(waitForTaskAborted(finalSortTask.get(), 1'000'000'000)); + ASSERT_TRUE(waitForTaskCompletion(partialSortTask.get(), 30'000'000)); + ASSERT_TRUE(waitForTaskAborted(finalSortTask.get(), 30'000'000)); } class DataFetcher { @@ -2397,12 +2659,15 @@ DEBUG_ONLY_TEST_P(MultiFragmentTest, maxBytes) { return; } std::string s(25, 'x'); - // Keep the row count under 7000 to avoid hitting the row limit in the - // operator instead. + // This number is chosen so that serialization of all the 100 vectors can fit + // in 32MB (the QueryConfig::maxOutputBufferSize). Otherwise the driver is + // blocked, since the data consumed by DataFetcher is so small that the + // remaining buffered data is still larger than OutputBuffer::continueSize_. + constexpr int kNumRows = 4'800; auto data = makeRowVector({ - makeFlatVector(5'000, [](auto row) { return row; }), - makeFlatVector(5'000, [](auto row) { return row; }), - makeConstant(StringView(s), 5'000), + makeFlatVector(kNumRows, [](auto row) { return row; }), + makeFlatVector(kNumRows, [](auto row) { return row; }), + makeConstant(StringView(s), kNumRows), }); core::PlanNodeId outputNodeId; @@ -2567,7 +2832,7 @@ TEST_P(MultiFragmentTest, earlyTaskFailure) { if (internalFailure) { try { VELOX_FAIL("memoryAbortTest"); - } catch (const VeloxRuntimeError& e) { + } catch (const VeloxRuntimeError&) { finalSortTask->pool()->abort(std::current_exception()); } } else { @@ -2652,8 +2917,8 @@ TEST_P(MultiFragmentTest, mergeSmallBatchesInExchange) { if (GetParam().serdeKind == VectorSerde::Kind::kPresto) { test(1, 1'000); test(1'000, 56); - test(10'000, 6); - test(100'000, 1); + test(10'000, 7); + test(100'000, 2); } else if (GetParam().serdeKind == VectorSerde::Kind::kCompactRow) { test(1, 1'000); test(1'000, 39); @@ -2662,8 +2927,8 @@ TEST_P(MultiFragmentTest, mergeSmallBatchesInExchange) { } else { test(1, 1'000); test(1'000, 72); - test(10'000, 7); - test(100'000, 1); + test(10'000, 8); + test(100'000, 2); } } @@ -2840,10 +3105,11 @@ TEST_P(MultiFragmentTest, compression) { test("local://t1", 0.7, false); } SCOPED_TRACE(fmt::format("minCompressionRatio 0.0000001")); - { test("local://t2", 0.0000001, true); } + { + test("local://t2", 0.0000001, true); + } } } - TEST_P(MultiFragmentTest, scaledTableScan) { const int numSplits = 20; std::vector> splitFiles; @@ -2970,6 +3236,158 @@ TEST_P(MultiFragmentTest, scaledTableScan) { } } +// Test row output with no columns (empty schema). +TEST_P(MultiFragmentTest, emptySchema) { + // Create data with rows but no columns + auto emptyRowType = ROW({}, {}); + auto data = makeRowVector(emptyRowType, 1'000); + + std::vector> tasks; + auto leafTaskId = makeTaskId("leaf", 0); + + // Leaf task: Values -> PartitionedOutput + auto leafPlan = + PlanBuilder() + .values({data}) + .partitionedOutput({}, 1, /*outputLayout=*/{}, GetParam().serdeKind) + .planNode(); + + auto leafTask = makeTask(leafTaskId, leafPlan, tasks.size()); + tasks.push_back(leafTask); + leafTask->start(4); + + // Root task: Exchange -> Project + auto rootTaskId = makeTaskId("root", 0); + auto rootPlan = PlanBuilder() + .exchange(emptyRowType, GetParam().serdeKind) + .singleAggregation({}, {"count(1)"}) + .planNode(); + + test::AssertQueryBuilder(rootPlan, duckDbQueryRunner_) + .split(remoteSplit(leafTaskId)) + .config( + core::QueryConfig::kShuffleCompressionKind, + common::compressionKindToString(GetParam().compressionKind)) + .assertResults("SELECT 1000"); + + for (auto& task : tasks) { + ASSERT_TRUE(waitForTaskCompletion(task.get())) << task->taskId(); + } +} + +// Test stateful deserialization with different batch byte limits. +// This validates that the Exchange operator correctly breaks in the middle +// and continues from the leftover when batch size limits are reached. +TEST_P(MultiFragmentTest, batchBytes) { + auto test = [&](int32_t numBatches, + int32_t rowsPerBatch, + uint64_t preferredBatchBytes, + uint64_t expectedAtLeastOutputBatches = 0) { + SCOPED_TRACE( + fmt::format( + "numBatches={}, rowsPerBatch={}, preferredBatchBytes={}", + numBatches, + rowsPerBatch, + succinctBytes(preferredBatchBytes))); + + std::vector batches; + batches.reserve(numBatches); + + for (int i = 0; i < numBatches; ++i) { + auto batch = makeRowVector({ + makeFlatVector( + rowsPerBatch, + [i, rowsPerBatch](auto row) { return i * rowsPerBatch + row; }), + makeFlatVector( + rowsPerBatch, + [i, rowsPerBatch](auto row) { + return (i * rowsPerBatch + row) % 1000; + }), + makeFlatVector( + rowsPerBatch, + [i, rowsPerBatch](auto row) { + return (i * rowsPerBatch + row) * 1.5; + }), + }); + batches.push_back(batch); + } + + auto leafTaskId = makeTaskId("leaf", 0); + auto leafPlan = + PlanBuilder() + .values(batches) + .partitionedOutput({}, 1, {"c0", "c1", "c2"}, GetParam().serdeKind) + .planNode(); + + auto leafTask = makeTask(leafTaskId, leafPlan, 0); + leafTask->start(1); + + core::PlanNodeId exchangeNodeId; + auto rootPlan = + PlanBuilder() + .exchange( + ROW({"c0", "c1", "c2"}, {BIGINT(), INTEGER(), DOUBLE()}), + GetParam().serdeKind) + .capturePlanNodeId(exchangeNodeId) + .singleAggregation({}, {"count(1)", "sum(c0)", "avg(c2)"}) + .planNode(); + + auto extraConfigs = std::unordered_map{ + {core::QueryConfig::kPreferredOutputBatchBytes, + std::to_string(preferredBatchBytes)}, + {core::QueryConfig::kShuffleCompressionKind, + common::compressionKindToString(GetParam().compressionKind)}}; + + auto task = test::AssertQueryBuilder(rootPlan, duckDbQueryRunner_) + .split(remoteSplit(leafTaskId)) + .configs(extraConfigs) + .assertResults( + fmt::format( + "SELECT {}, {}, {}", + numBatches * rowsPerBatch, + (static_cast(numBatches) * rowsPerBatch * + (numBatches * rowsPerBatch - 1)) / + 2, + (static_cast(numBatches) * rowsPerBatch * + (numBatches * rowsPerBatch - 1)) / + 2 * 1.5 / (numBatches * rowsPerBatch))); + + waitForTaskCompletion(leafTask.get()); + + // Verify Exchange stats to ensure data was processed correctly + auto rootTaskStats = toPlanStats(task->taskStats()); + const auto& exchangeStats = rootTaskStats.at(exchangeNodeId); + + EXPECT_GE(exchangeStats.outputVectors, expectedAtLeastOutputBatches); + }; + + // Presto serialization operates at page-level granularity (pages are atomic). + // The number of output batches depends on how many Presto pages are created + // during serialization, which varies based on encoding, compression, and + // data. + // + // For this test (100 input batches × 100 rows = 10,000 rows): + // The actual behavior shows all pages are merged and processed together, + // resulting in a single batch output currently. + // + // This is a known limitation - Presto pages cannot be partially deserialized. + // The fix prevents INT32_MAX overflow by controlling the merge size, but + // fine-grained batch control requires deeper changes to PrestoVectorSerde. + + if (GetParam().serdeKind == VectorSerde::Kind::kPresto) { + // Current implementation merges all pages and processes in one batch + // The key improvement is preventing overflow, not fine-grained batching + test(100, 100, 1, 1); // Expect single batch with all data + test(100, 100, 1ULL << 30, 1); // Expect single batch with all data + } else { + // Row-based serialization (CompactRow/UnsafeRow) supports row-level + // batching With 1 byte limit: Can produce many small batches + test(100, 100, 1, 100); + // With 1GB limit: All rows fit in one batch + test(100, 100, 1ULL << 30, 1); + } +} + VELOX_INSTANTIATE_TEST_SUITE_P( MultiFragmentTest, MultiFragmentTest, diff --git a/velox/exec/tests/NestedLoopJoinTest.cpp b/velox/exec/tests/NestedLoopJoinTest.cpp index f6f21a1138ad..4baec4a8c432 100644 --- a/velox/exec/tests/NestedLoopJoinTest.cpp +++ b/velox/exec/tests/NestedLoopJoinTest.cpp @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include "velox/common/base/tests/GTestUtils.h" #include "velox/core/PlanNode.h" #include "velox/exec/NestedLoopJoinBuild.h" #include "velox/exec/tests/utils/AssertQueryBuilder.h" @@ -99,11 +100,12 @@ class NestedLoopJoinTest : public HiveConnectorTestBase { for (const auto joinType : joinTypes_) { for (const auto& comparison : comparisons_) { - SCOPED_TRACE(fmt::format( - "maxDrivers:{} joinType:{} comparison:{}", - std::to_string(numDrivers), - joinTypeName(joinType), - comparison)); + SCOPED_TRACE( + fmt::format( + "maxDrivers:{} joinType:{} comparison:{}", + std::to_string(numDrivers), + core::JoinTypeName::toName(joinType), + comparison)); params.planNode = PlanBuilder(planNodeIdGenerator) @@ -122,7 +124,9 @@ class NestedLoopJoinTest : public HiveConnectorTestBase { assertQuery( params, fmt::format( - fmt::runtime(queryStr_), joinTypeName(joinType), comparison)); + fmt::runtime(queryStr_), + core::JoinTypeName::toName(joinType), + comparison)); } } } @@ -376,7 +380,7 @@ TEST_F(NestedLoopJoinTest, outerJoinWithoutCondition) { op, fmt::format( "SELECT count(*) FROM t {} join u on 1", - core::joinTypeName(joinType))); + core::JoinTypeName::toName(joinType))); }; testOuterJoin(core::JoinType::kLeft); testOuterJoin(core::JoinType::kRight); @@ -728,5 +732,86 @@ TEST_F(NestedLoopJoinTest, mergeBuildVectorsOverflow) { ASSERT_EQ(mergeResult.size(), 2); } +DEBUG_ONLY_TEST_F(NestedLoopJoinTest, longBatchDurationYield) { + const uint32_t kProbeSize = 10; + const uint32_t kBuildSize = 1'000; + const uint64_t kDriverCpuTimeSliceLimitMs = 1'000; + const std::string kLargeBatchSize = + folly::to(kProbeSize * kBuildSize); + + struct { + uint32_t numGetOutputCalls; + bool hasDelay; + std::string debugString() const { + return fmt::format( + "numGetOutputCalls: {}, needSleep: {}", numGetOutputCalls, hasDelay); + } + } testSettings[] = {{0, false}, {0, true}}; + + const auto probeData = makeRowVector( + {"t_c0", "t_c1"}, + { + makeFlatVector(kProbeSize, [](auto row) { return row; }), + makeFlatVector(kProbeSize, [](auto row) { return row * 2; }), + }); + + const auto buildData = makeRowVector( + {"u_c0", "u_c1"}, + { + makeFlatVector(kBuildSize, [](auto row) { return row; }), + makeFlatVector(kBuildSize, [](auto row) { return row * 3; }), + }); + + createDuckDbTable("t", {probeData}); + createDuckDbTable("u", {buildData}); + + auto planNodeIdGenerator = std::make_shared(); + auto planNode = + PlanBuilder(planNodeIdGenerator) + .values({probeData}) + .nestedLoopJoin( + PlanBuilder(planNodeIdGenerator).values({buildData}).planNode(), + "", + {"t_c0", "t_c1", "u_c0", "u_c1"}, + core::JoinType::kInner) + .planNode(); + + for (auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + std::atomic nestedLoopJoinProbeGetOutputCalls{0}; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::Driver::runInternal::getOutput", + std::function([&](void* op) { + auto* operatorPtr = static_cast(op); + if (operatorPtr->operatorType() == "NestedLoopJoinProbe") { + // The second time NestedLoopJoinProbe::getOutput actually calls + // generateOutput, and the function + // NestedLoopJoinProbe::shouldYield is expected to be called. + if (nestedLoopJoinProbeGetOutputCalls.fetch_add(1) == 2 && + testData.hasDelay) { + std::this_thread::sleep_for( + std::chrono::milliseconds(2 * kDriverCpuTimeSliceLimitMs)); + } + } + })); + + auto queryCtx = core::QueryCtx::create( + executor_.get(), + core::QueryConfig({ + {core::QueryConfig::kDriverCpuTimeSliceLimitMs, + folly::to(kDriverCpuTimeSliceLimitMs)}, + {core::QueryConfig::kPreferredOutputBatchRows, kLargeBatchSize}, + })); + + AssertQueryBuilder(planNode, duckDbQueryRunner_) + .queryCtx(queryCtx) + .maxDrivers(1) + .assertResults("SELECT t_c0, t_c1, u_c0, u_c1 FROM t, u"); + testData.numGetOutputCalls = nestedLoopJoinProbeGetOutputCalls.load(); + } + ASSERT_LT( + testSettings[0].numGetOutputCalls, testSettings[1].numGetOutputCalls); +} + } // namespace } // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/OperatorTraceTest.cpp b/velox/exec/tests/OperatorTraceTest.cpp index 64a9886f50d5..25da7c069bf1 100644 --- a/velox/exec/tests/OperatorTraceTest.cpp +++ b/velox/exec/tests/OperatorTraceTest.cpp @@ -21,15 +21,14 @@ #include "velox/common/base/tests/GTestUtils.h" #include "velox/common/file/FileSystems.h" -#include "velox/connectors/hive/HiveConnectorSplit.h" +#include "velox/connectors/hive/HiveConnector.h" #include "velox/exec/OperatorTraceReader.h" -#include "velox/exec/OperatorTraceWriter.h" #include "velox/exec/PartitionFunction.h" #include "velox/exec/Split.h" #include "velox/exec/TaskTraceReader.h" +#include "velox/exec/TaskTraceWriter.h" #include "velox/exec/Trace.h" #include "velox/exec/TraceUtil.h" -#include "velox/exec/tests/utils/ArbitratorTestUtil.h" #include "velox/exec/tests/utils/AssertQueryBuilder.h" #include "velox/exec/tests/utils/HiveConnectorTestBase.h" #include "velox/exec/tests/utils/PlanBuilder.h" @@ -50,15 +49,11 @@ class OperatorTraceTest : public HiveConnectorTestBase { } Type::registerSerDe(); common::Filter::registerSerDe(); - connector::hive::HiveTableHandle::registerSerDe(); - connector::hive::LocationHandle::registerSerDe(); - connector::hive::HiveColumnHandle::registerSerDe(); - connector::hive::HiveInsertTableHandle::registerSerDe(); - connector::hive::HiveConnectorSplit::registerSerDe(); - connector::hive::HiveInsertFileNameGenerator::registerSerDe(); + connector::hive::HiveConnector::registerSerDe(); core::PlanNode::registerSerDe(); core::ITypedExpr::registerSerDe(); registerPartitionFunctionSerDe(); + registerDummySourceSerDe(); } void SetUp() override { @@ -100,7 +95,9 @@ class OperatorTraceTest : public HiveConnectorTestBase { } for (auto i = 0; i < left->sources().size(); ++i) { - isSamePlan(left->sources().at(i), right->sources().at(i)); + if (!isSamePlan(left->sources().at(i), right->sources().at(i))) { + return false; + } } return true; } @@ -135,7 +132,7 @@ TEST_F(OperatorTraceTest, emptyTrace) { .config(core::QueryConfig::kQueryTraceDir, traceDirPath->getPath()) .config(core::QueryConfig::kQueryTraceMaxBytes, 100UL << 30) .config(core::QueryConfig::kQueryTraceTaskRegExp, ".*") - .config(core::QueryConfig::kQueryTraceNodeIds, planNodeId) + .config(core::QueryConfig::kQueryTraceNodeId, planNodeId) .assertResults("SELECT a, count(1) FROM tmp WHERE a > 0 GROUP BY 1"); const auto taskTraceDir = @@ -200,7 +197,7 @@ TEST_F(OperatorTraceTest, traceData) { core::QueryConfig::kQueryTraceMaxBytes, testData.maxTracedBytes) .config(core::QueryConfig::kQueryTraceTaskRegExp, ".*") - .config(core::QueryConfig::kQueryTraceNodeIds, planNodeId) + .config(core::QueryConfig::kQueryTraceNodeId, planNodeId) .assertResults("SELECT a, count(1) FROM tmp GROUP BY 1"), "Query exceeded per-query local trace limit of"); continue; @@ -213,7 +210,7 @@ TEST_F(OperatorTraceTest, traceData) { .config( core::QueryConfig::kQueryTraceMaxBytes, testData.maxTracedBytes) .config(core::QueryConfig::kQueryTraceTaskRegExp, ".*") - .config(core::QueryConfig::kQueryTraceNodeIds, planNodeId) + .config(core::QueryConfig::kQueryTraceNodeId, planNodeId) .assertResults("SELECT a, count(1) FROM tmp GROUP BY 1"); const auto fs = @@ -268,6 +265,7 @@ TEST_F(OperatorTraceTest, traceMetadata) { const auto outputDir = TempDirectoryPath::create(); auto planNodeIdGenerator = std::make_shared(); + core::PlanNodeId traceNodeId; const auto planNode = PlanBuilder(planNodeIdGenerator) .values(rows, false) @@ -283,6 +281,7 @@ TEST_F(OperatorTraceTest, traceMetadata) { "c0 < 135", {"c0", "c1", "c2"}, core::JoinType::kInner) + .capturePlanNodeId(traceNodeId) .planNode(); const auto expectedQueryConfigs = std::unordered_map{ @@ -301,14 +300,15 @@ TEST_F(OperatorTraceTest, traceMetadata) { core::QueryConfig(expectedQueryConfigs), expectedConnectorProperties); auto writer = trace::TaskTraceMetadataWriter(outputDir->getPath(), pool()); - writer.write(queryCtx, planNode); + auto traceNode = getTraceNode(planNode, traceNodeId); + writer.write(queryCtx, traceNode); const auto reader = trace::TaskTraceMetadataReader(outputDir->getPath(), pool()); const auto actualQueryConfigs = reader.queryConfigs(); const auto actualConnectorProperties = reader.connectorProperties(); const auto actualQueryPlan = reader.queryPlan(); - ASSERT_TRUE(isSamePlan(actualQueryPlan, planNode)); + ASSERT_TRUE(isSamePlan(actualQueryPlan, traceNode)); ASSERT_EQ(actualQueryConfigs.size(), expectedQueryConfigs.size()); for (const auto& [key, value] : actualQueryConfigs) { ASSERT_EQ(actualQueryConfigs.at(key), expectedQueryConfigs.at(key)); @@ -380,7 +380,7 @@ TEST_F(OperatorTraceTest, task) { std::to_string(100UL << 30)}, {core::QueryConfig::kQueryTraceDir, outputDir->getPath()}, {core::QueryConfig::kQueryTraceTaskRegExp, testData.taskRegExpr}, - {core::QueryConfig::kQueryTraceNodeIds, hashJoinNodeId}, + {core::QueryConfig::kQueryTraceNodeId, hashJoinNodeId}, {"key1", "value1"}, }; @@ -423,7 +423,8 @@ TEST_F(OperatorTraceTest, task) { const auto actualConnectorProperties = reader.connectorProperties(); const auto actualQueryPlan = reader.queryPlan(); - ASSERT_TRUE(isSamePlan(actualQueryPlan, planNode)); + ASSERT_TRUE( + isSamePlan(actualQueryPlan, getTraceNode(planNode, hashJoinNodeId))); ASSERT_EQ(actualQueryConfigs.size(), expectedQueryConfigs.size()); for (const auto& [key, value] : actualQueryConfigs) { ASSERT_EQ(actualQueryConfigs.at(key), expectedQueryConfigs.at(key)); @@ -486,32 +487,16 @@ TEST_F(OperatorTraceTest, error) { .queryCtx(queryCtx) .maxDrivers(1) .copyResults(pool()), - "Query trace nodes are not set"); - } - // Duplicate trace plan node ids. - { - const auto queryConfigs = std::unordered_map{ - {core::QueryConfig::kQueryTraceEnabled, "true"}, - {core::QueryConfig::kQueryTraceDir, "traceDir"}, - {core::QueryConfig::kQueryTraceTaskRegExp, ".*"}, - {core::QueryConfig::kQueryTraceNodeIds, "1,1"}, - }; - const auto queryCtx = core::QueryCtx::create( - executor_.get(), core::QueryConfig(queryConfigs)); - VELOX_ASSERT_USER_THROW( - AssertQueryBuilder(planNode) - .queryCtx(queryCtx) - .maxDrivers(1) - .copyResults(pool()), - "Duplicate trace nodes found: 1, 1"); + "Query trace node ID are not set"); } + // Nonexist trace plan node id. { const auto queryConfigs = std::unordered_map{ {core::QueryConfig::kQueryTraceEnabled, "true"}, {core::QueryConfig::kQueryTraceDir, "traceDir"}, {core::QueryConfig::kQueryTraceTaskRegExp, ".*"}, - {core::QueryConfig::kQueryTraceNodeIds, "nonexist"}, + {core::QueryConfig::kQueryTraceNodeId, "nonexist"}, }; const auto queryCtx = core::QueryCtx::create( executor_.get(), core::QueryConfig(queryConfigs)); @@ -520,7 +505,8 @@ TEST_F(OperatorTraceTest, error) { .queryCtx(queryCtx) .maxDrivers(1) .copyResults(pool()), - "Trace plan nodes not found from task"); + + "Trace plan node ID = nonexist not found from task"); } } @@ -578,7 +564,7 @@ TEST_F(OperatorTraceTest, traceTableWriter) { .config( core::QueryConfig::kQueryTraceTaskRegExp, testData.taskRegExpr) - .config(core::QueryConfig::kQueryTraceNodeIds, tableWriteNodeId) + .config(core::QueryConfig::kQueryTraceNodeId, tableWriteNodeId) .copyResults(pool(), task), "Query exceeded per-query local trace limit of"); continue; @@ -589,7 +575,7 @@ TEST_F(OperatorTraceTest, traceTableWriter) { .config(core::QueryConfig::kQueryTraceDir, traceRoot) .config(core::QueryConfig::kQueryTraceMaxBytes, testData.maxTracedBytes) .config(core::QueryConfig::kQueryTraceTaskRegExp, testData.taskRegExpr) - .config(core::QueryConfig::kQueryTraceNodeIds, tableWriteNodeId) + .config(core::QueryConfig::kQueryTraceNodeId, tableWriteNodeId) .copyResults(pool(), task); const auto taskTraceDir = getTaskTraceDirectory(traceRoot, *task); @@ -684,7 +670,7 @@ TEST_F(OperatorTraceTest, filterProject) { .config( core::QueryConfig::kQueryTraceTaskRegExp, testData.taskRegExpr) - .config(core::QueryConfig::kQueryTraceNodeIds, projectNodeId) + .config(core::QueryConfig::kQueryTraceNodeId, projectNodeId) .copyResults(pool(), task), "Query exceeded per-query local trace limit of"); continue; @@ -695,7 +681,7 @@ TEST_F(OperatorTraceTest, filterProject) { .config(core::QueryConfig::kQueryTraceDir, traceRoot) .config(core::QueryConfig::kQueryTraceMaxBytes, testData.maxTracedBytes) .config(core::QueryConfig::kQueryTraceTaskRegExp, testData.taskRegExpr) - .config(core::QueryConfig::kQueryTraceNodeIds, projectNodeId) + .config(core::QueryConfig::kQueryTraceNodeId, projectNodeId) .copyResults(pool(), task); const auto taskTraceDir = getTaskTraceDirectory(traceRoot, *task); @@ -759,7 +745,7 @@ TEST_F(OperatorTraceTest, traceSplitRoundTrip) { .config(core::QueryConfig::kQueryTraceEnabled, true) .config(core::QueryConfig::kQueryTraceDir, traceDirPath->getPath()) .config(core::QueryConfig::kQueryTraceTaskRegExp, ".*") - .config(core::QueryConfig::kQueryTraceNodeIds, "0") + .config(core::QueryConfig::kQueryTraceNodeId, "0") .splits(splits) .copyResults(pool(), task); @@ -769,7 +755,7 @@ TEST_F(OperatorTraceTest, traceSplitRoundTrip) { for (int i = 0; i < 3; ++i) { const auto opTraceDir = getOpTraceDirectory( taskTraceDir, - /*planNodeId=*/"0", + /*nodeId=*/"0", /*pipelineId=*/0, /*driverId=*/i); const auto summaryFilePath = getOpTraceSummaryFilePath(opTraceDir); @@ -825,7 +811,7 @@ TEST_F(OperatorTraceTest, traceSplitPartial) { .config(core::QueryConfig::kQueryTraceEnabled, true) .config(core::QueryConfig::kQueryTraceDir, traceDirPath->getPath()) .config(core::QueryConfig::kQueryTraceTaskRegExp, ".*") - .config(core::QueryConfig::kQueryTraceNodeIds, "0") + .config(core::QueryConfig::kQueryTraceNodeId, "0") .splits(splits) .copyResults(pool(), task); @@ -835,7 +821,7 @@ TEST_F(OperatorTraceTest, traceSplitPartial) { for (int i = 0; i < 3; ++i) { const auto opTraceDir = getOpTraceDirectory( taskTraceDir, - /*planNodeId=*/"0", + /*nodeId=*/"0", /*pipelineId=*/0, /*driverId=*/i); const auto summaryFilePath = getOpTraceSummaryFilePath(opTraceDir); @@ -914,7 +900,7 @@ TEST_F(OperatorTraceTest, traceSplitCorrupted) { .config(core::QueryConfig::kQueryTraceEnabled, true) .config(core::QueryConfig::kQueryTraceDir, traceDirPath->getPath()) .config(core::QueryConfig::kQueryTraceTaskRegExp, ".*") - .config(core::QueryConfig::kQueryTraceNodeIds, "0") + .config(core::QueryConfig::kQueryTraceNodeId, "0") .splits(splits) .copyResults(pool(), task); @@ -924,7 +910,7 @@ TEST_F(OperatorTraceTest, traceSplitCorrupted) { for (int i = 0; i < 3; ++i) { const auto opTraceDir = getOpTraceDirectory( taskTraceDir, - /*planNodeId=*/"0", + /*nodeId=*/"0", /*pipelineId=*/0, /*driverId=*/i); const auto summaryFilePath = getOpTraceSummaryFilePath(opTraceDir); @@ -1056,7 +1042,7 @@ TEST_F(OperatorTraceTest, hashJoin) { .config( core::QueryConfig::kQueryTraceTaskRegExp, testData.taskRegExpr) - .config(core::QueryConfig::kQueryTraceNodeIds, hashJoinNodeId) + .config(core::QueryConfig::kQueryTraceNodeId, hashJoinNodeId) .copyResults(pool(), task), "Query exceeded per-query local trace limit of"); continue; @@ -1067,7 +1053,7 @@ TEST_F(OperatorTraceTest, hashJoin) { .config(core::QueryConfig::kQueryTraceDir, traceRoot) .config(core::QueryConfig::kQueryTraceMaxBytes, testData.maxTracedBytes) .config(core::QueryConfig::kQueryTraceTaskRegExp, testData.taskRegExpr) - .config(core::QueryConfig::kQueryTraceNodeIds, hashJoinNodeId) + .config(core::QueryConfig::kQueryTraceNodeId, hashJoinNodeId) .copyResults(pool(), task); const auto taskTraceDir = getTaskTraceDirectory(traceRoot, *task); @@ -1135,13 +1121,18 @@ TEST_F(OperatorTraceTest, canTrace) { {"PartitionedOutput", true}, {"HashBuild", true}, {"HashProbe", true}, + {"IndexLookupJoin", true}, + {"Unnest", true}, {"RowNumber", false}, - {"OrderBy", false}, + {"OrderBy", true}, + {"TopNRowNumber", true}, {"PartialAggregation", true}, {"Aggregation", true}, {"TableWrite", true}, {"TableScan", true}, - {"FilterProject", true}}; + {"FilterProject", true}, + {"Exchange", true}, + {"MergeExchange", true}}; for (const auto& testData : testSettings) { SCOPED_TRACE(testData.debugString()); ASSERT_EQ(testData.canTrace, trace::canTrace(testData.operatorType)); @@ -1172,9 +1163,9 @@ TEST_F(OperatorTraceTest, hiveConnectorId) { .config(core::QueryConfig::kQueryTraceEnabled, true) .config(core::QueryConfig::kQueryTraceDir, traceDirPath->getPath()) .config(core::QueryConfig::kQueryTraceTaskRegExp, ".*") - .config(core::QueryConfig::kQueryTraceNodeIds, "0") + .config(core::QueryConfig::kQueryTraceNodeId, "0") .splits(splits) - .runWithoutResults(task); + .countResults(task); const auto taskTraceDir = getTaskTraceDirectory(traceDirPath->getPath(), *task); const auto reader = trace::TaskTraceMetadataReader(taskTraceDir, pool()); diff --git a/velox/exec/tests/OperatorUtilsTest.cpp b/velox/exec/tests/OperatorUtilsTest.cpp index 4ea4864cdd77..0036f2f3dbf9 100644 --- a/velox/exec/tests/OperatorUtilsTest.cpp +++ b/velox/exec/tests/OperatorUtilsTest.cpp @@ -17,6 +17,7 @@ #include #include "velox/dwio/common/tests/utils/BatchMaker.h" #include "velox/exec/Operator.h" +#include "velox/exec/tests/utils/MergeTestBase.h" #include "velox/exec/tests/utils/OperatorTestBase.h" #include "velox/vector/fuzzer/VectorFuzzer.h" @@ -48,7 +49,8 @@ class OperatorUtilsTest : public OperatorTestBase { std::move(planFragment), 0, core::QueryCtx::create(executor_.get()), - Task::ExecutionMode::kParallel); + Task::ExecutionMode::kParallel, + exec::Consumer{}); driver_ = Driver::testingCreate(); driverCtx_ = std::make_unique(task_, 0, 0, 0, 0); driverCtx_->driver = driver_.get(); @@ -57,7 +59,8 @@ class OperatorUtilsTest : public OperatorTestBase { void gatherCopyTest( const std::shared_ptr& targetType, const std::shared_ptr& sourceType, - int numSources) { + int numSources, + bool flattenSources = true) { folly::Random::DefaultGenerator rng(1); const int kNumRows = 500; const int kNumColumns = sourceType->size(); @@ -65,8 +68,9 @@ class OperatorUtilsTest : public OperatorTestBase { // Build source vectors with nulls. std::vector sources; for (int i = 0; i < numSources; ++i) { - sources.push_back(std::static_pointer_cast( - BatchMaker::createBatch(sourceType, kNumRows, *pool_))); + sources.push_back( + std::static_pointer_cast( + BatchMaker::createBatch(sourceType, kNumRows, *pool_))); for (int j = 0; j < kNumColumns; ++j) { auto vector = sources.back()->childAt(j); int nullRow = (folly::Random::rand32() % kNumRows) / 4; @@ -78,6 +82,23 @@ class OperatorUtilsTest : public OperatorTestBase { } } + if (!flattenSources) { + for (int i = 0; i < numSources; ++i) { + const auto source = sources[i]; + const auto numRows = source->size(); + std::vector sortIndices(numRows, 0); + for (auto i = 0; i < numRows; ++i) { + sortIndices[i] = i; + } + BufferPtr indices = allocateIndices(numRows, pool_.get()); + auto rawIndices = indices->asMutable(); + for (size_t i = 0; i < numRows; ++i) { + rawIndices[i] = sortIndices[i]; + } + sources[i] = wrap(numRows, indices, source); + } + } + std::vector columnMap; if (sourceType != targetType) { for (column_index_t sourceChannel = 0; sourceChannel < kNumColumns; @@ -133,6 +154,61 @@ class OperatorUtilsTest : public OperatorTestBase { } } + void gatherMergeTest( + int32_t numValues, + int numMergeWays, + int targetSize, + bool useRandom) { + auto goldenVector = makeRowVector({ + makeFlatVector(numValues, [&](auto row) { return row; }), + }); + std::vector> mergeWays(numMergeWays); + for (int32_t value = 0; value < numValues; value++) { + int way = useRandom ? folly::Random::rand32() % numMergeWays + : value % numMergeWays; + mergeWays[way].push_back(value); + } + std::vector sources; + std::vector> streams; + std::vector sortKeys = {{0, {true, true}}}; + for (int way = 0; way < numMergeWays; way++) { + auto source = makeRowVector({ + makeFlatVector( + mergeWays[way].size(), + [&](auto row) { return mergeWays[way][row]; }), + }); + sources.push_back(source); + streams.push_back( + std::make_unique(way, sortKeys, source)); + } + auto mergeTree = + std::make_unique>(std::move(streams)); + RowVectorPtr targetVector = std::static_pointer_cast( + BaseVector::create(sources[0]->type(), targetSize, pool_.get())); + std::vector bufferSources(targetSize); + std::vector bufferSourceIndices(targetSize); + for (int32_t batch = 0; batch * targetSize < numValues; batch++) { + int32_t valueBegin = batch * targetSize; + int32_t valueEnd = valueBegin + targetSize; + valueEnd = std::min(valueEnd, numValues); + VectorPtr tmp = std::move(targetVector); + BaseVector::prepareForReuse(tmp, targetSize); + targetVector = std::static_pointer_cast(tmp); + for (auto& child : targetVector->children()) { + child->resize(targetSize); + } + int count = 0; + testingGatherMerge( + targetVector, *mergeTree, count, bufferSources, bufferSourceIndices); + EXPECT_EQ(count, valueEnd - valueBegin); + auto result = targetVector->childAt(0).get(); + auto golden = goldenVector->childAt(0).get(); + for (int32_t row = 0; row < valueEnd - valueBegin; row++) { + EXPECT_TRUE(result->equalValueAt(golden, row, valueBegin + row)); + } + } + } + void setTaskOutputBatchConfig( uint32_t preferredBatchSize, uint32_t maxRows, @@ -373,6 +449,67 @@ TEST_F(OperatorUtilsTest, gatherCopy) { } } +TEST_F(OperatorUtilsTest, gatherCopyEncoding) { + std::shared_ptr rowType; + std::shared_ptr reversedRowType; + { + std::vector names = { + "bool_val", + "tiny_val", + "small_val", + "int_val", + "long_val", + "ordinal", + "float_val", + "double_val", + "string_val", + "array_val", + "struct_val", + "map_val"}; + std::vector reversedNames = names; + std::reverse(reversedNames.begin(), reversedNames.end()); + + std::vector> types = { + BOOLEAN(), + TINYINT(), + SMALLINT(), + INTEGER(), + BIGINT(), + BIGINT(), + REAL(), + DOUBLE(), + VARCHAR(), + ARRAY(VARCHAR()), + ROW({{"s_int", INTEGER()}, {"s_array", ARRAY(REAL())}}), + MAP(VARCHAR(), + MAP(BIGINT(), + ROW({{"s2_int", INTEGER()}, {"s2_string", VARCHAR()}})))}; + std::vector> reversedTypes = types; + std::reverse(reversedTypes.begin(), reversedTypes.end()); + + rowType = ROW(std::move(names), std::move(types)); + reversedRowType = ROW(std::move(reversedNames), std::move(reversedTypes)); + } + + // Gather copy with identical column mapping. + gatherCopyTest(rowType, rowType, 1, false); + gatherCopyTest(rowType, rowType, 5, false); + // Gather copy with non-identical column mapping. + gatherCopyTest(rowType, reversedRowType, 1, false); + gatherCopyTest(rowType, reversedRowType, 5, false); +} + +TEST_F(OperatorUtilsTest, gatherMerge) { + gatherMergeTest(1234, 2, 10, false); + gatherMergeTest(1234, 2, 100, false); + gatherMergeTest(1234, 10, 10, false); + gatherMergeTest(1234, 10, 100, false); + gatherMergeTest(1234, 2, 10, true); + gatherMergeTest(1234, 2, 100, true); + gatherMergeTest(1234, 10, 10, true); + gatherMergeTest(1234, 10, 100, true); +} + TEST_F(OperatorUtilsTest, makeOperatorSpillPath) { EXPECT_EQ("spill/3_1_100", makeOperatorSpillPath("spill", 3, 1, 100)); } @@ -448,6 +585,36 @@ TEST_F(OperatorUtilsTest, addOperatorRuntimeStats) { ASSERT_EQ(stats[statsName].min, 100); } +TEST_F(OperatorUtilsTest, setOperatorRuntimeStats) { + std::unordered_map stats; + const std::string statsName("stats"); + const RuntimeCounter minStatsValue(100, RuntimeCounter::Unit::kBytes); + const RuntimeCounter maxStatsValue(200, RuntimeCounter::Unit::kBytes); + setOperatorRuntimeStats(statsName, minStatsValue, stats); + ASSERT_EQ(stats[statsName].count, 1); + ASSERT_EQ(stats[statsName].sum, 100); + ASSERT_EQ(stats[statsName].max, 100); + ASSERT_EQ(stats[statsName].min, 100); + + setOperatorRuntimeStats(statsName, maxStatsValue, stats); + ASSERT_EQ(stats[statsName].count, 1); + ASSERT_EQ(stats[statsName].sum, 200); + ASSERT_EQ(stats[statsName].max, 200); + ASSERT_EQ(stats[statsName].min, 200); + + addOperatorRuntimeStats(statsName, maxStatsValue, stats); + ASSERT_EQ(stats[statsName].count, 2); + ASSERT_EQ(stats[statsName].sum, 400); + ASSERT_EQ(stats[statsName].max, 200); + ASSERT_EQ(stats[statsName].min, 200); + + setOperatorRuntimeStats(statsName, minStatsValue, stats); + ASSERT_EQ(stats[statsName].count, 1); + ASSERT_EQ(stats[statsName].sum, 100); + ASSERT_EQ(stats[statsName].max, 100); + ASSERT_EQ(stats[statsName].min, 100); +} + TEST_F(OperatorUtilsTest, initializeRowNumberMapping) { BufferPtr mapping; auto rawMapping = initializeRowNumberMapping(mapping, 10, pool()); @@ -531,6 +698,52 @@ TEST_F(OperatorUtilsTest, projectChildren) { } } +TEST_F(OperatorUtilsTest, projectDuplicateChildren) { + // Test wrapping an unloaded lazy vector in dictionary vector multiple + // times. + auto flatVector = makeNullableFlatVector( + std::vector>{1, std::nullopt, 3, 4, 5}); + const auto size = flatVector->size(); + + auto lazyVector = std::make_shared( + pool(), + BIGINT(), + size, + std::make_unique([&](RowSet /*rows*/) { + return makeFlatVector( + size, + [&](vector_size_t row) { return flatVector->valueAt(row); }, + [&](vector_size_t row) { return flatVector->isNullAt(row); }); + })); + + std::vector children = {lazyVector}; + auto rowVector = makeRowVector(std::move(children)); + + std::vector identityProjections; + identityProjections.emplace_back(0, 0); + identityProjections.emplace_back(0, 1); + + auto mapping = makeIndices(size, [](auto row) { return row % 3; }); + + std::vector projectedChildren(2); + projectChildren( + projectedChildren, rowVector, identityProjections, size, mapping); + + for (const auto& projection : identityProjections) { + auto* result = projectedChildren[projection.outputChannel].get(); + result->loadedVector(); + auto* source = rowVector->childAt(projection.inputChannel).get(); + for (auto i = 0; i < size; ++i) { + auto srcIndex = mapping->as()[i]; + if (result->isNullAt(i)) { + ASSERT_TRUE(source->isNullAt(srcIndex)); + } else { + ASSERT_TRUE(result->equalValueAt(source, i, srcIndex)); + } + } + } +} + TEST_F(OperatorUtilsTest, reclaimableSectionGuard) { RowTypePtr rowType = ROW({"c0"}, {INTEGER()}); diff --git a/velox/exec/tests/OrderByTest.cpp b/velox/exec/tests/OrderByTest.cpp index c343dbe8020d..25548f3f0481 100644 --- a/velox/exec/tests/OrderByTest.cpp +++ b/velox/exec/tests/OrderByTest.cpp @@ -558,8 +558,9 @@ DEBUG_ONLY_TEST_F(OrderByTest, reclaimDuringInputProcessing) { auto spillDirectory = exec::test::TempDirectoryPath::create(); auto queryCtx = core::QueryCtx::create(executor_.get()); - queryCtx->testingOverrideMemoryPool(memory::memoryManager()->addRootPool( - queryCtx->queryId(), kMaxBytes, memory::MemoryReclaimer::create())); + queryCtx->testingOverrideMemoryPool( + memory::memoryManager()->addRootPool( + queryCtx->queryId(), kMaxBytes, memory::MemoryReclaimer::create())); auto expectedResult = AssertQueryBuilder( PlanBuilder() @@ -700,8 +701,9 @@ DEBUG_ONLY_TEST_F(OrderByTest, reclaimDuringReserve) { auto spillDirectory = exec::test::TempDirectoryPath::create(); auto queryCtx = core::QueryCtx::create(executor_.get()); - queryCtx->testingOverrideMemoryPool(memory::memoryManager()->addRootPool( - queryCtx->queryId(), kMaxBytes, memory::MemoryReclaimer::create())); + queryCtx->testingOverrideMemoryPool( + memory::memoryManager()->addRootPool( + queryCtx->queryId(), kMaxBytes, memory::MemoryReclaimer::create())); auto expectedResult = AssertQueryBuilder( PlanBuilder() @@ -946,8 +948,9 @@ DEBUG_ONLY_TEST_F(OrderByTest, reclaimDuringOutputProcessing) { SCOPED_TRACE(fmt::format("enableSpilling {}", enableSpilling)); auto spillDirectory = exec::test::TempDirectoryPath::create(); auto queryCtx = core::QueryCtx::create(executor_.get()); - queryCtx->testingOverrideMemoryPool(memory::memoryManager()->addRootPool( - queryCtx->queryId(), kMaxBytes, memory::MemoryReclaimer::create())); + queryCtx->testingOverrideMemoryPool( + memory::memoryManager()->addRootPool( + queryCtx->queryId(), kMaxBytes, memory::MemoryReclaimer::create())); auto expectedResult = AssertQueryBuilder( PlanBuilder() @@ -1319,11 +1322,12 @@ DEBUG_ONLY_TEST_F(OrderByTest, reclaimFromOrderBy) { .spillDirectory(spillDirectory->getPath()) .config(core::QueryConfig::kSpillEnabled, true) .config(core::QueryConfig::kOrderBySpillEnabled, true) - .plan(PlanBuilder() - .values(vectors) - .orderBy({"c0 ASC NULLS LAST"}, false) - .capturePlanNodeId(orderById) - .planNode()) + .plan( + PlanBuilder() + .values(vectors) + .orderBy({"c0 ASC NULLS LAST"}, false) + .capturePlanNodeId(orderById) + .planNode()) .assertResults("SELECT * FROM tmp ORDER BY c0 ASC NULLS LAST"); auto taskStats = exec::toPlanStats(task->taskStats()); auto& planStats = taskStats.at(orderById); @@ -1357,10 +1361,11 @@ DEBUG_ONLY_TEST_F(OrderByTest, reclaimFromEmptyOrderBy) { .spillDirectory(spillDirectory->getPath()) .config(core::QueryConfig::kSpillEnabled, true) .config(core::QueryConfig::kOrderBySpillEnabled, true) - .plan(PlanBuilder() - .values(vectors) - .orderBy({"c0 ASC NULLS LAST"}, false) - .planNode()) + .plan( + PlanBuilder() + .values(vectors) + .orderBy({"c0 ASC NULLS LAST"}, false) + .planNode()) .assertResults("SELECT * FROM tmp ORDER BY c0 ASC NULLS LAST"); // Verify no spill has been triggered. const auto stats = task->taskStats().pipelineStats; @@ -1376,8 +1381,9 @@ DEBUG_ONLY_TEST_F(OrderByTest, orderByWithLazyInput) { VectorFuzzer(fuzzerOpts_, pool()).fuzzRowChildrenToLazy(nonLazyVector)); std::vector lazyInputCopy; - lazyInputCopy.push_back(std::dynamic_pointer_cast( - nonLazyVector->testingCopyPreserveEncodings())); + lazyInputCopy.push_back( + std::dynamic_pointer_cast( + nonLazyVector->testingCopyPreserveEncodings())); createDuckDbTable(lazyInputCopy); std::atomic_bool nonReclaimableSectionEntered{false}; @@ -1404,10 +1410,11 @@ DEBUG_ONLY_TEST_F(OrderByTest, orderByWithLazyInput) { .spillDirectory(spillDirectory->getPath()) .config(core::QueryConfig::kSpillEnabled, true) .config(core::QueryConfig::kOrderBySpillEnabled, true) - .plan(PlanBuilder() - .values(lazyInput) - .orderBy({"c0 ASC NULLS LAST"}, false) - .planNode()) + .plan( + PlanBuilder() + .values(lazyInput) + .orderBy({"c0 ASC NULLS LAST"}, false) + .planNode()) .assertResults("SELECT * FROM tmp ORDER BY c0 ASC NULLS LAST"); ASSERT_TRUE(lazyLoadedInNonReclaimableSection.has_value()); diff --git a/velox/exec/tests/OutputBufferManagerTest.cpp b/velox/exec/tests/OutputBufferManagerTest.cpp index e1e5feaa494d..4c8ffd2dfe5c 100644 --- a/velox/exec/tests/OutputBufferManagerTest.cpp +++ b/velox/exec/tests/OutputBufferManagerTest.cpp @@ -14,6 +14,7 @@ * limitations under the License. */ #include "velox/exec/OutputBufferManager.h" +#include #include #include "folly/experimental/EventCount.h" #include "velox/common/base/tests/GTestUtils.h" @@ -109,13 +110,14 @@ class OutputBufferManagerTest : public testing::Test { std::move(planFragment), 0, std::move(queryCtx), - Task::ExecutionMode::kParallel); + Task::ExecutionMode::kParallel, + exec::Consumer{}); bufferManager_->initializeTask(task, kind, numDestinations, numDrivers); return task; } - std::unique_ptr makeSerializedPage( + std::unique_ptr makeSerializedPage( RowTypePtr rowType, vector_size_t size) { auto vector = std::dynamic_pointer_cast( @@ -375,8 +377,9 @@ class OutputBufferManagerTest : public testing::Test { std::vector> pages, int64_t inSequence, std::vector remainingBytes) { - promise.setValue(Response{ - std::move(pages), inSequence, std::move(remainingBytes)}); + promise.setValue( + Response{ + std::move(pages), inSequence, std::move(remainingBytes)}); }); future.wait(); ASSERT_TRUE(future.isReady()); @@ -435,7 +438,7 @@ class OutputBufferManagerTest : public testing::Test { const VectorSerde::Kind serdeKind_; std::shared_ptr executor_{ std::make_shared( - std::thread::hardware_concurrency())}; + folly::hardware_concurrency())}; std::shared_ptr pool_; std::shared_ptr bufferManager_; RowTypePtr rowType_; @@ -570,21 +573,18 @@ VELOX_INSTANTIATE_TEST_SUITE_P( TEST_F(OutputBufferManagerTest, outputType) { ASSERT_EQ( - PartitionedOutputNode::kindString( - PartitionedOutputNode::Kind::kPartitioned), + PartitionedOutputNode::toName(PartitionedOutputNode::Kind::kPartitioned), "PARTITIONED"); ASSERT_EQ( - PartitionedOutputNode::kindString( - PartitionedOutputNode::Kind::kArbitrary), + PartitionedOutputNode::toName(PartitionedOutputNode::Kind::kArbitrary), "ARBITRARY"); ASSERT_EQ( - PartitionedOutputNode::kindString( - PartitionedOutputNode::Kind::kBroadcast), + PartitionedOutputNode::toName(PartitionedOutputNode::Kind::kBroadcast), "BROADCAST"); VELOX_ASSERT_THROW( - PartitionedOutputNode::kindString( + PartitionedOutputNode::toName( static_cast(100)), - "Invalid Output Kind 100"); + "Invalid enum value: 100"); } TEST_P(OutputBufferManagerWithDifferentSerdeKindsTest, destinationBuffer) { @@ -1473,7 +1473,7 @@ TEST_P( std::memcpy(iobuf->writableData(), payload.data(), payloadSize); iobuf->append(payloadSize); - auto page = std::make_unique(std::move(iobuf)); + auto page = std::make_unique(std::move(iobuf)); auto queue = std::make_shared(1, 0); std::vector promises; diff --git a/velox/exec/tests/ParallelProjectTest.cpp b/velox/exec/tests/ParallelProjectTest.cpp new file mode 100644 index 000000000000..e3fcb8d54e6d --- /dev/null +++ b/velox/exec/tests/ParallelProjectTest.cpp @@ -0,0 +1,48 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/tests/utils/OperatorTestBase.h" +#include "velox/exec/tests/utils/PlanBuilder.h" + +namespace facebook::velox::exec { +namespace { + +class ParallelProjectTest : public test::OperatorTestBase { + protected: + void SetUp() override { + test::OperatorTestBase::SetUp(); + } +}; + +TEST_F(ParallelProjectTest, basic) { + auto data = makeRowVector({ + makeFlatVector(100, folly::identity), + makeFlatVector(100, folly::identity), + }); + + createDuckDbTable({data}); + + auto plan = + test::PlanBuilder() + .values({data}) + .parallelProject({{"c0 + 1", "c0 * 2"}, {"c1 + 10", "c1 * 3"}}) + .planNode(); + + assertQuery(plan, "SELECT c0 + 1, c0 * 2, c1 + 10, c1 * 3 FROM tmp"); +} + +} // namespace +} // namespace facebook::velox::exec diff --git a/velox/exec/tests/PlanBuilderTest.cpp b/velox/exec/tests/PlanBuilderTest.cpp index 69436314762b..3366cb71a8ab 100644 --- a/velox/exec/tests/PlanBuilderTest.cpp +++ b/velox/exec/tests/PlanBuilderTest.cpp @@ -13,10 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + #include "velox/exec/tests/utils/PlanBuilder.h" #include "velox/common/base/tests/GTestUtils.h" #include "velox/core/Expressions.h" #include "velox/exec/WindowFunction.h" +#include "velox/exec/tests/utils/ExpressionBuilder.h" +#include "velox/exec/tests/utils/HiveConnectorTestBase.h" +#include "velox/exec/tests/utils/TestIndexStorageConnector.h" #include "velox/functions/prestosql/aggregates/RegisterAggregateFunctions.h" #include "velox/functions/prestosql/registration/RegistrationFunctions.h" #include "velox/parse/Expressions.h" @@ -254,51 +258,220 @@ TEST_F(PlanBuilderTest, missingOutputType) { } TEST_F(PlanBuilderTest, projectExpressions) { + using namespace velox::expr_builder; + // Non-typed Expressions. // Simple field access. auto data = ROW({"c0"}, {BIGINT()}); VELOX_CHECK_EQ( PlanBuilder() .tableScan("tmp", data) - .projectExpressions( - {std::make_shared("c0", std::nullopt)}) + .projectExpressions({col("c0")}) .planNode() ->toString(true, false), "-- Project[1][expressions: (c0:BIGINT, ROW[\"c0\"])] -> c0:BIGINT\n"); + // Dereference test using field access query. data = ROW({"c0"}, {ROW({"field0"}, {BIGINT()})}); VELOX_CHECK_EQ( PlanBuilder() .tableScan("tmp", data) - .projectExpressions({std::make_shared( - "field0", - std::nullopt, - std::vector{ - std::make_shared( - "c0", std::nullopt)})}) + .projectExpressions({col("c0").subfield("field0")}) .planNode() ->toString(true, false), "-- Project[1][expressions: (field0:BIGINT, ROW[\"c0\"][field0])] -> field0:BIGINT\n"); // Test Typed Expressions + auto rowType = ROW({"c0"}, {VARCHAR()}); VELOX_CHECK_EQ( PlanBuilder() - .tableScan("tmp", ROW({"c0"}, {ROW({VARCHAR()})})) + .tableScan("tmp", rowType) .projectExpressions( - {std::make_shared(VARCHAR(), "c0")}) + {core::Expressions::inferTypes(col("c0"), rowType, pool_.get())}) + .planNode() + ->toString(true, false), + "-- Project[1][expressions: (p0:VARCHAR, ROW[\"c0\"])] -> p0:VARCHAR\n"); + + rowType = ROW({"c0"}, {ROW({"field0"}, {VARCHAR()})}); + VELOX_CHECK_EQ( + PlanBuilder() + .tableScan("tmp", rowType) + .projectExpressions({core::Expressions::inferTypes( + col("c0").subfield("field0"), rowType, pool_.get())}) + .planNode() + ->toString(true, false), + "-- Project[1][expressions: (p0:VARCHAR, ROW[\"c0\"][field0])] -> p0:VARCHAR\n"); +} + +TEST_F(PlanBuilderTest, filter) { + auto data = ROW({"c0"}, {BIGINT()}); + constexpr std::string_view expectation = + "-- Filter[1][expression: gt(plus(ROW[\"c0\"],10),100)] -> c0:BIGINT\n"; + + // Filter with SQL snippet. + VELOX_CHECK_EQ( + PlanBuilder() + .tableScan("tmp", data) + .filter("c0 + 10 > 100") .planNode() ->toString(true, false), - "-- Project[1][expressions: (p0:VARCHAR, \"c0\")] -> p0:VARCHAR\n"); + expectation); + + using namespace velox::expr_builder; + + // Filter with untyped expression (same expression as above). VELOX_CHECK_EQ( PlanBuilder() - .tableScan("tmp", ROW({"c0"}, {ROW({VARCHAR()})})) - .projectExpressions({std::make_shared( - VARCHAR(), - std::make_shared(VARCHAR(), "c0"), - "field0")}) + .tableScan("tmp", data) + .filter(col("c0") + 10L > 100L) .planNode() ->toString(true, false), - "-- Project[1][expressions: (p0:VARCHAR, \"c0\"[\"field0\"])] -> p0:VARCHAR\n"); + expectation); +} + +TEST_F(PlanBuilderTest, commitStrategyParameter) { + auto data = makeRowVector({makeFlatVector(10, folly::identity)}); + auto directory = "/some/test/directory"; + + // Lambda to create a plan with given commitStrategy and verify it + auto testCommitStrategy = [&](connector::CommitStrategy commitStrategy) { + // Create a plan with commitStrategy + auto planBuilder = PlanBuilder().values({data}).tableWrite( + directory, + {}, + 0, + {}, + {}, + dwio::common::FileFormat::DWRF, + {}, + PlanBuilder::kHiveDefaultConnectorId, + {}, + nullptr, + "", + common::CompressionKind_NONE, + nullptr, + false); + + core::PlanNodePtr plan; + // Conditionally set commitStrategy if it's not kNoCommit + if (commitStrategy != connector::CommitStrategy::kNoCommit) { + plan = PlanBuilder::TableWriterBuilder(planBuilder) + .commitStrategy(commitStrategy) + .endTableWriter() + .planNode(); + } else { + plan = std::move(planBuilder.planNode()); + } + + // Verify the plan node has the correct commit strategy + auto tableWriteNode = + std::dynamic_pointer_cast(plan); + ASSERT_NE(tableWriteNode, nullptr); + ASSERT_EQ(tableWriteNode->commitStrategy(), commitStrategy); + }; + + // Test with explicit task commit strategy + testCommitStrategy(connector::CommitStrategy::kTaskCommit); + + // Test with no explicit commit strategy (should default to kNoCommit) + testCommitStrategy(connector::CommitStrategy::kNoCommit); +} + +TEST_F(PlanBuilderTest, indexLookupJoinBuilder) { + auto leftType = ROW({"t0", "t1"}, {BIGINT(), ARRAY(BIGINT())}); + auto rightType = ROW({"u0", "u1"}, {BIGINT(), BIGINT()}); + + // Create a TestIndexTableHandle that supports index lookup + auto indexTableHandle = std::make_shared( + kTestIndexConnectorName, nullptr, false); + + // Create column handles for the index table + connector::ColumnHandleMap columnHandles; + for (const auto& name : rightType->names()) { + columnHandles[name] = std::make_shared(name); + } + + auto rightScan = std::make_shared( + "right_scan", rightType, indexTableHandle, columnHandles); + + auto plan = PlanBuilder(pool_.get()) + .tableScan(leftType) + .startIndexLookupJoin() + .leftKeys({"t0"}) + .rightKeys({"u0"}) + .indexSource(rightScan) + .joinConditions({"contains(t1, u1)"}) + .hasMarker(false) + .outputLayout({"t0", "u1"}) + .joinType(core::JoinType::kInner) + .filter("t0 > 0") + .endIndexLookupJoin() + .planNode(); + + auto indexJoinNode = + std::dynamic_pointer_cast(plan); + ASSERT_NE(indexJoinNode, nullptr); + ASSERT_EQ(indexJoinNode->joinType(), core::JoinType::kInner); + ASSERT_EQ(indexJoinNode->leftKeys().size(), 1); + ASSERT_EQ(indexJoinNode->rightKeys().size(), 1); + ASSERT_EQ(indexJoinNode->leftKeys()[0]->name(), "t0"); + ASSERT_EQ(indexJoinNode->rightKeys()[0]->name(), "u0"); + ASSERT_EQ(indexJoinNode->joinConditions().size(), 1); + ASSERT_FALSE(indexJoinNode->hasMarker()); + ASSERT_EQ(indexJoinNode->outputType()->names().size(), 2); + ASSERT_EQ(indexJoinNode->outputType()->names()[0], "t0"); + ASSERT_EQ(indexJoinNode->outputType()->names()[1], "u1"); + ASSERT_EQ(indexJoinNode->filter()->toString(), "gt(ROW[\"t0\"],0)"); +} + +TEST_F(PlanBuilderTest, insertTableHandleParameter) { + auto data = makeRowVector({makeFlatVector(10, folly::identity)}); + auto directory = "/some/test/directory"; + + // Lambda to create a plan with given insertableHandle and verify it + auto testInsertTableHandle = + [&](std::shared_ptr insertTableHandle) { + // Create a plan with insertTableHandle + auto planBuilder = PlanBuilder().values({data}).tableWrite( + directory, + {}, + 0, + {}, + {}, + dwio::common::FileFormat::DWRF, + {}, + PlanBuilder::kHiveDefaultConnectorId, + {}, + nullptr, + "", + common::CompressionKind_NONE, + nullptr, + false, + connector::CommitStrategy::kNoCommit, + insertTableHandle); + + // Verify the plan node has the correct insert Table Handle. + auto tableWriteNode = + std::dynamic_pointer_cast( + planBuilder.planNode()); + ASSERT_NE(tableWriteNode, nullptr); + ASSERT_EQ(tableWriteNode->insertTableHandle(), insertTableHandle); + }; + + auto rowType = ROW({"c0", "c1", "c2"}, {BIGINT(), INTEGER(), SMALLINT()}); + auto hiveHandle = HiveConnectorTestBase::makeHiveInsertTableHandle( + rowType->names(), + rowType->children(), + {rowType->names()[0]}, // partitionedBy + nullptr, // bucketProperty + HiveConnectorTestBase::makeLocationHandle( + "/path/to/test", + std::nullopt, + connector::hive::LocationHandle::TableType::kNew)); + + auto insertHandle = std::make_shared( + std::string(PlanBuilder::kHiveDefaultConnectorId), hiveHandle); + testInsertTableHandle(insertHandle); } } // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/PlanNodeSerdeTest.cpp b/velox/exec/tests/PlanNodeSerdeTest.cpp index 54f5d77361bf..65d64c348f2f 100644 --- a/velox/exec/tests/PlanNodeSerdeTest.cpp +++ b/velox/exec/tests/PlanNodeSerdeTest.cpp @@ -13,12 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/connectors/hive/HiveConnector.h" #include "velox/exec/PartitionFunction.h" #include "velox/exec/tests/utils/HiveConnectorTestBase.h" #include "velox/exec/tests/utils/PlanBuilder.h" #include "velox/functions/prestosql/aggregates/RegisterAggregateFunctions.h" #include "velox/functions/prestosql/registration/RegistrationFunctions.h" #include "velox/parse/TypeResolver.h" +#include "velox/vector/tests/utils/VectorTestBase.h" #include @@ -38,12 +41,7 @@ class PlanNodeSerdeTest : public testing::Test, Type::registerSerDe(); common::Filter::registerSerDe(); - connector::hive::HiveTableHandle::registerSerDe(); - connector::hive::LocationHandle::registerSerDe(); - connector::hive::HiveColumnHandle::registerSerDe(); - connector::hive::HiveInsertTableHandle::registerSerDe(); - connector::hive::HiveInsertFileNameGenerator::registerSerDe(); - connector::hive::registerHivePartitionFunctionSerDe(); + connector::hive::HiveConnector::registerSerDe(); core::PlanNode::registerSerDe(); core::ITypedExpr::registerSerDe(); registerPartitionFunctionSerDe(); @@ -52,6 +50,11 @@ class PlanNodeSerdeTest : public testing::Test, makeFlatVector({1, 2, 3}), makeFlatVector({10, 20, 30}), makeConstant(true, 3), + makeArrayVector({ + {1, 2}, + {3, 4, 5}, + {}, + }), })}; } @@ -83,6 +86,26 @@ class PlanNodeSerdeTest : public testing::Test, } } + void topNRankSerdeTest(std::string_view function) { + auto plan = PlanBuilder() + .values({data_}) + .topNRank(function, {}, {"c0", "c2"}, 10, false) + .planNode(); + testSerde(plan); + + plan = PlanBuilder() + .values({data_}) + .topNRank(function, {}, {"c0", "c2"}, 10, true) + .planNode(); + testSerde(plan); + + plan = PlanBuilder() + .values({data_}) + .topNRank(function, {"c0"}, {"c1", "c2"}, 10, false) + .planNode(); + testSerde(plan); + } + std::vector data_; }; @@ -95,6 +118,14 @@ TEST_F(PlanNodeSerdeTest, aggregation) { testSerde(plan); + plan = PlanBuilder(pool_.get()) + .values({data_}) + .partialAggregation({"c0"}, {"count(ARRAY[1, 2])"}) + .finalAggregation() + .planNode(); + + testSerde(plan); + // Aggregation over sorted inputs. plan = PlanBuilder() .values({data_}) @@ -210,6 +241,12 @@ TEST_F(PlanNodeSerdeTest, exchange) { TEST_F(PlanNodeSerdeTest, filter) { auto plan = PlanBuilder().values({data_}).filter("c0 > 100").planNode(); testSerde(plan); + + plan = PlanBuilder(pool_.get()) + .values({data_}) + .filter("c3 = ARRAY[1,2,3]") + .planNode(); + testSerde(plan); } TEST_F(PlanNodeSerdeTest, groupId) { @@ -305,11 +342,16 @@ TEST_F(PlanNodeSerdeTest, localMerge) { TEST_F(PlanNodeSerdeTest, mergeJoin) { auto probe = makeRowVector( - {"t0", "t1", "t2"}, + {"t0", "t1", "t2", "t3"}, { makeFlatVector({1, 2, 3}), makeFlatVector({10, 20, 30}), makeFlatVector({true, true, false}), + makeArrayVector({ + {1, 2}, + {3, 4, 5}, + {}, + }), }); auto build = makeRowVector( @@ -335,6 +377,19 @@ TEST_F(PlanNodeSerdeTest, mergeJoin) { testSerde(plan); + plan = PlanBuilder(planNodeIdGenerator, pool_.get()) + .values({probe}) + .mergeJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator).values({build}).planNode(), + "t3 = ARRAY[1,2,3]", + {"t0", "t1", "u2", "t2"}, + core::JoinType::kInner) + .planNode(); + + testSerde(plan); + plan = PlanBuilder(planNodeIdGenerator) .values({probe}) .mergeJoin( @@ -401,13 +456,35 @@ TEST_F(PlanNodeSerdeTest, project) { testSerde(plan); } +TEST_F(PlanNodeSerdeTest, parallelProjet) { + auto plan = PlanBuilder() + .values({data_}) + .parallelProject({{"c0 + 1", "c0 * 2"}, {"c1 + 1", "c1 * 2"}}) + .planNode(); + + testSerde(plan); + + plan = PlanBuilder() + .values({data_}) + .parallelProject( + {{"c0 + 1", "c0 * 2"}, {"c1 + 1", "c1 * 2"}}, {"c2", "c3"}) + .planNode(); + + testSerde(plan); +} + TEST_F(PlanNodeSerdeTest, hashJoin) { auto probe = makeRowVector( - {"t0", "t1", "t2"}, + {"t0", "t1", "t2", "t3"}, { makeFlatVector({1, 2, 3}), makeFlatVector({10, 20, 30}), makeFlatVector({true, true, false}), + makeArrayVector({ + {1, 2}, + {3, 4, 5}, + {}, + }), }); auto build = makeRowVector( @@ -433,6 +510,19 @@ TEST_F(PlanNodeSerdeTest, hashJoin) { testSerde(plan); + plan = PlanBuilder(planNodeIdGenerator, pool_.get()) + .values({probe}) + .hashJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator).values({build}).planNode(), + "t3 = ARRAY[1,2,3]", + {"t0", "t1", "u2", "t2"}, + core::JoinType::kInner) + .planNode(); + + testSerde(plan); + plan = PlanBuilder(planNodeIdGenerator) .values({probe}) .hashJoin( @@ -482,6 +572,12 @@ TEST_F(PlanNodeSerdeTest, unnest) { plan = PlanBuilder().values({data}).unnest({"c0"}, {"c1"}, "ordinal").planNode(); testSerde(plan); + + plan = PlanBuilder() + .values({data}) + .unnest({"c0"}, {"c1"}, "ordinal", "emptyUnnestValue") + .planNode(); + testSerde(plan); } TEST_F(PlanNodeSerdeTest, values) { @@ -561,34 +657,56 @@ TEST_F(PlanNodeSerdeTest, rowNumber) { } TEST_F(PlanNodeSerdeTest, scan) { - auto plan = PlanBuilder(pool_.get()) - .tableScan( - ROW({"a", "b", "c", "d"}, - {BIGINT(), BIGINT(), BOOLEAN(), DOUBLE()}), - {"a < 5", "b = 7", "c = true", "d > 0.01"}, - "a + b < 100") - .planNode(); - testSerde(plan); + { + auto plan = PlanBuilder(pool_.get()) + .tableScan( + ROW({"a", "b", "c", "d"}, + {BIGINT(), BIGINT(), BOOLEAN(), DOUBLE()}), + {"a < 5", "b = 7", "c = true", "d > 0.01"}, + "a + b < 100") + .planNode(); + testSerde(plan); + } + + { + auto plan = + PlanBuilder() + .startTableScan() + .outputType(ROW({"x"}, {BIGINT()})) + .assignments( + {{"x", HiveConnectorTestBase::regularColumn("a", BIGINT())}}) + .dataColumns(ROW({"a", "b"}, {BIGINT(), BIGINT()})) + .filterColumnHandles({ + HiveConnectorTestBase::partitionKey("ds", VARCHAR()), + HiveConnectorTestBase::regularColumn("a", BIGINT()), + }) + .remainingFilter("length(ds) + a % 2 > 0") + .endTableScan() + .planNode(); + testSerde(plan); + } + + { + auto plan = PlanBuilder() + .startTableScan() + .outputType(ROW({"x"}, {BIGINT()})) + .sampleRate(0.5) + .endTableScan() + .planNode(); + testSerde(plan); + } } TEST_F(PlanNodeSerdeTest, topNRowNumber) { - auto plan = PlanBuilder() - .values({data_}) - .topNRowNumber({}, {"c0", "c2"}, 10, false) - .planNode(); - testSerde(plan); + topNRankSerdeTest("row_number"); +} - plan = PlanBuilder() - .values({data_}) - .topNRowNumber({}, {"c0", "c2"}, 10, true) - .planNode(); - testSerde(plan); +TEST_F(PlanNodeSerdeTest, topNRank) { + topNRankSerdeTest("rank"); +} - plan = PlanBuilder() - .values({data_}) - .topNRowNumber({"c0"}, {"c1", "c2"}, 10, false) - .planNode(); - testSerde(plan); +TEST_F(PlanNodeSerdeTest, topNDenseRank) { + topNRankSerdeTest("dense_rank"); } TEST_F(PlanNodeSerdeTest, write) { @@ -617,7 +735,17 @@ TEST_F(PlanNodeSerdeTest, tableWriteMerge) { plan = PlanBuilder(pool_.get()) .values(data_) - .tableWrite("targetDirectory") + .tableWrite( + "targetDirectory", dwio::common::FileFormat::DWRF, {"min(c0)"}) + .localPartition(std::vector{}) + .tableWriteMerge() + .planNode(); + testSerde(plan); + + plan = PlanBuilder(pool_.get()) + .values(data_) + .tableWrite( + "targetDirectory", dwio::common::FileFormat::DWRF, {"min(c0)"}) .localPartition(std::vector{}) .tableWriteMerge() .planNode(); @@ -638,8 +766,233 @@ TEST_F(PlanNodeSerdeTest, tableWriteWithStats) { "approx_distinct(c2)", "sum_data_size_for_stats(c2)", "max_data_size_for_stats(c2)"}) + .localPartition(std::vector{}) + .tableWriteMerge() .planNode(); testSerde(plan); } +TEST_F(PlanNodeSerdeTest, columnStatsSpec) { + // Helper function to create typed expressions + const auto createFieldAccess = [&](const std::string& name, + const TypePtr& type) { + return std::make_shared(type, name); + }; + + const auto createCallExpr = [&](const std::string& name, + const TypePtr& returnType, + const std::vector& args) { + return std::make_shared(returnType, args, name); + }; + + // Test 1: Empty ColumnStatsSpec + { + VELOX_ASSERT_THROW( + std::make_unique( + std::vector{}, + core::AggregationNode::Step::kSingle, + std::vector{}, + std::vector{}), + ""); + } + + // Test 2: ColumnStatsSpec with only aggregates (no grouping keys) + { + std::vector groupingKeys; + std::vector aggregateNames = {"min", "max", "count"}; + std::vector aggregates; + + // Create min(c0) aggregate + auto minInput = createFieldAccess("c0", BIGINT()); + auto minCall = createCallExpr("min", BIGINT(), {minInput}); + aggregates.push_back({minCall, {BIGINT()}, nullptr, {}, {}, false}); + + // Create max(c0) aggregate + auto maxInput = createFieldAccess("c0", BIGINT()); + auto maxCall = createCallExpr("max", BIGINT(), {maxInput}); + aggregates.push_back({maxCall, {BIGINT()}, nullptr, {}, {}, false}); + + // Create count(c1) aggregate + auto countInput = createFieldAccess("c1", INTEGER()); + auto countCall = createCallExpr("count", BIGINT(), {countInput}); + aggregates.push_back({countCall, {INTEGER()}, nullptr, {}, {}, false}); + + core::ColumnStatsSpec spec( + std::move(groupingKeys), + core::AggregationNode::Step::kPartial, + std::move(aggregateNames), + std::move(aggregates)); + + auto serialized = spec.serialize(); + auto deserialized = core::ColumnStatsSpec::create(serialized, pool()); + + EXPECT_TRUE(deserialized.groupingKeys.empty()); + EXPECT_EQ(deserialized.aggregateNames.size(), 3); + EXPECT_EQ(deserialized.aggregates.size(), 3); + EXPECT_EQ( + deserialized.aggregationStep, core::AggregationNode::Step::kPartial); + EXPECT_EQ(deserialized.aggregateNames.size(), 3); + + EXPECT_EQ(deserialized.aggregateNames[0], "min"); + EXPECT_EQ(deserialized.aggregateNames[1], "max"); + EXPECT_EQ(deserialized.aggregateNames[2], "count"); + EXPECT_EQ(deserialized.aggregates[0].call->name(), "min"); + EXPECT_EQ(deserialized.aggregates[1].call->name(), "max"); + EXPECT_EQ(deserialized.aggregates[2].call->name(), "count"); + } + + // Test 3: ColumnStatsSpec with grouping keys and aggregates + { + std::vector groupingKeys; + groupingKeys.push_back(createFieldAccess("partition_key", VARCHAR())); + groupingKeys.push_back(createFieldAccess("bucket_id", INTEGER())); + + std::vector aggregateNames = {"min", "sum"}; + std::vector aggregates; + + // Create min(c0) aggregate + auto minInput = createFieldAccess("c0", BIGINT()); + auto minCall = createCallExpr("min", BIGINT(), {minInput}); + aggregates.push_back({minCall, {BIGINT()}, nullptr, {}, {}, false}); + + // Create sum(c1) aggregate + auto sumInput = createFieldAccess("c1", DOUBLE()); + auto sumCall = createCallExpr("sum", DOUBLE(), {sumInput}); + aggregates.push_back({sumCall, {DOUBLE()}, nullptr, {}, {}, false}); + + core::ColumnStatsSpec spec( + std::move(groupingKeys), + core::AggregationNode::Step::kIntermediate, + std::move(aggregateNames), + std::move(aggregates)); + + auto serialized = spec.serialize(); + auto deserialized = core::ColumnStatsSpec::create(serialized, pool()); + + EXPECT_EQ(deserialized.groupingKeys.size(), 2); + EXPECT_EQ(deserialized.aggregateNames.size(), 2); + EXPECT_EQ(deserialized.aggregates.size(), 2); + EXPECT_EQ( + deserialized.aggregationStep, + core::AggregationNode::Step::kIntermediate); + EXPECT_EQ(deserialized.aggregateNames.size(), 2); + + EXPECT_EQ(deserialized.groupingKeys[0]->name(), "partition_key"); + EXPECT_EQ(deserialized.groupingKeys[0]->type(), VARCHAR()); + EXPECT_EQ(deserialized.groupingKeys[1]->name(), "bucket_id"); + EXPECT_EQ(deserialized.groupingKeys[1]->type(), INTEGER()); + + EXPECT_EQ(deserialized.aggregateNames[0], "min"); + EXPECT_EQ(deserialized.aggregateNames[1], "sum"); + EXPECT_EQ(deserialized.aggregates[0].call->name(), "min"); + EXPECT_EQ(deserialized.aggregates[1].call->name(), "sum"); + } + + // Test 4: ColumnStatsSpec with different aggregation steps + for (auto step : + {core::AggregationNode::Step::kSingle, + core::AggregationNode::Step::kPartial, + core::AggregationNode::Step::kIntermediate, + core::AggregationNode::Step::kFinal}) { + std::vector groupingKeys; + std::vector aggregateNames = {"count"}; + std::vector aggregates; + + auto countInput = createFieldAccess("test_col", BIGINT()); + auto countCall = createCallExpr("count", BIGINT(), {countInput}); + aggregates.push_back({countCall, {BIGINT()}, nullptr, {}, {}, false}); + + core::ColumnStatsSpec spec( + std::move(groupingKeys), + step, + std::move(aggregateNames), + std::move(aggregates)); + + auto serialized = spec.serialize(); + auto deserialized = core::ColumnStatsSpec::create(serialized, pool()); + + EXPECT_EQ(deserialized.aggregationStep, step); + EXPECT_EQ(deserialized.aggregateNames.size(), 1); + EXPECT_EQ(deserialized.aggregateNames[0], "count"); + } + + // Test 5: ColumnStatsSpec with complex aggregates (distinct, with mask, with + // sorting) + { + std::vector groupingKeys; + std::vector aggregateNames = { + "count_distinct", "array_agg_sorted"}; + std::vector aggregates; + + // Create count(distinct c0) aggregate + auto distinctInput = createFieldAccess("c0", VARCHAR()); + auto countDistinctCall = createCallExpr("count", BIGINT(), {distinctInput}); + auto maskField = createFieldAccess("mask_col", BOOLEAN()); + aggregates.push_back( + {countDistinctCall, {VARCHAR()}, maskField, {}, {}, true}); + + // Create array_agg(c1 ORDER BY c2) aggregate + auto arrayInput = createFieldAccess("c1", INTEGER()); + auto arrayAggCall = + createCallExpr("array_agg", ARRAY(INTEGER()), {arrayInput}); + auto sortingKey = createFieldAccess("c2", BIGINT()); + std::vector sortingKeys = {sortingKey}; + std::vector sortingOrders = {core::SortOrder(true, true)}; + aggregates.push_back( + {arrayAggCall, + {INTEGER()}, + nullptr, + sortingKeys, + sortingOrders, + false}); + + core::ColumnStatsSpec spec( + std::move(groupingKeys), + core::AggregationNode::Step::kSingle, + std::move(aggregateNames), + std::move(aggregates)); + + auto serialized = spec.serialize(); + auto deserialized = core::ColumnStatsSpec::create(serialized, pool()); + + EXPECT_EQ(deserialized.aggregates.size(), 2); + + // Check distinct aggregate + EXPECT_TRUE(deserialized.aggregates[0].distinct); + EXPECT_NE(deserialized.aggregates[0].mask, nullptr); + EXPECT_EQ(deserialized.aggregates[0].mask->name(), "mask_col"); + + // Check sorted aggregate + EXPECT_FALSE(deserialized.aggregates[1].distinct); + EXPECT_EQ(deserialized.aggregates[1].mask, nullptr); + EXPECT_EQ(deserialized.aggregates[1].sortingKeys.size(), 1); + EXPECT_EQ(deserialized.aggregates[1].sortingKeys[0]->name(), "c2"); + EXPECT_EQ(deserialized.aggregates[1].sortingOrders.size(), 1); + EXPECT_TRUE(deserialized.aggregates[1].sortingOrders[0].isAscending()); + EXPECT_TRUE(deserialized.aggregates[1].sortingOrders[0].isNullsFirst()); + } + + // Error cases. + { + std::vector groupingKeys; + groupingKeys.push_back(createFieldAccess("partition_key", VARCHAR())); + groupingKeys.push_back(createFieldAccess("bucket_id", INTEGER())); + + std::vector aggregateNames = {"min", "sum"}; + std::vector aggregates; + + // Create min(c0) aggregate + auto minInput = createFieldAccess("c0", BIGINT()); + auto minCall = createCallExpr("min", BIGINT(), {minInput}); + aggregates.push_back({minCall, {BIGINT()}, nullptr, {}, {}, false}); + VELOX_ASSERT_THROW( + core::ColumnStatsSpec( + std::move(groupingKeys), + core::AggregationNode::Step::kSingle, + std::move(aggregateNames), + std::move(aggregates)), + ""); + } +} + } // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/PlanNodeStatsTest.cpp b/velox/exec/tests/PlanNodeStatsTest.cpp new file mode 100644 index 000000000000..2a59537eb430 --- /dev/null +++ b/velox/exec/tests/PlanNodeStatsTest.cpp @@ -0,0 +1,34 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/PlanNodeStats.h" +#include + +namespace facebook::velox::exec::test { + +TEST(PlanNodeStatsTest, exprStatsTotal) { + PlanNodeStats stats; + stats.expressionStats["foo"] = ExprStats{ + .timing = {.wallNanos = 1, .cpuNanos = 2}, + .numProcessedRows = 3, + .numProcessedVectors = 4}; + + PlanNodeStats total; + total += stats; + EXPECT_EQ(total.expressionStats["foo"], stats.expressionStats["foo"]); +} + +} // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/PlanNodeToStringTest.cpp b/velox/exec/tests/PlanNodeToStringTest.cpp index 0619aabb6c47..334c35127db8 100644 --- a/velox/exec/tests/PlanNodeToStringTest.cpp +++ b/velox/exec/tests/PlanNodeToStringTest.cpp @@ -14,21 +14,21 @@ * limitations under the License. */ +#include "velox/connectors/hive/HiveConnector.h" #include "velox/exec/WindowFunction.h" -#include "velox/exec/tests/utils/HiveConnectorTestBase.h" #include "velox/exec/tests/utils/PlanBuilder.h" #include "velox/functions/prestosql/aggregates/RegisterAggregateFunctions.h" #include "velox/functions/prestosql/registration/RegistrationFunctions.h" #include "velox/parse/TypeResolver.h" +#include "velox/vector/tests/utils/VectorTestBase.h" #include -using namespace facebook; -using namespace facebook::velox; -using namespace facebook::velox::common::test; - using facebook::velox::exec::test::PlanBuilder; +namespace facebook::velox::exec { +namespace { + class PlanNodeToStringTest : public testing::Test, public velox::test::VectorTestBase { public: @@ -75,25 +75,25 @@ TEST_F(PlanNodeToStringTest, recursive) { TEST_F(PlanNodeToStringTest, detailed) { ASSERT_EQ( - "-- Project[4][expressions: (out3:BIGINT, plus(cast ROW[\"out1\"] as BIGINT,10))] -> out3:BIGINT\n", + "-- Project[4][expressions: (out3:BIGINT, plus(cast(ROW[\"out1\"] as BIGINT),10))] -> out3:BIGINT\n", plan_->toString(true, false)); } TEST_F(PlanNodeToStringTest, recursiveAndDetailed) { ASSERT_EQ( - "-- Project[4][expressions: (out3:BIGINT, plus(cast ROW[\"out1\"] as BIGINT,10))] -> out3:BIGINT\n" - " -- Filter[3][expression: lt(mod(cast ROW[\"out1\"] as BIGINT,10),8)] -> out1:SMALLINT, out2:BIGINT\n" - " -- Project[2][expressions: (out1:SMALLINT, ROW[\"c0\"]), (out2:BIGINT, plus(mod(cast ROW[\"c0\"] as BIGINT,100),mod(cast ROW[\"c1\"] as BIGINT,50)))] -> out1:SMALLINT, out2:BIGINT\n" - " -- Filter[1][expression: lt(mod(cast ROW[\"c0\"] as BIGINT,10),9)] -> c0:SMALLINT, c1:INTEGER, c2:BIGINT\n" + "-- Project[4][expressions: (out3:BIGINT, plus(cast(ROW[\"out1\"] as BIGINT),10))] -> out3:BIGINT\n" + " -- Filter[3][expression: lt(mod(cast(ROW[\"out1\"] as BIGINT),10),8)] -> out1:SMALLINT, out2:BIGINT\n" + " -- Project[2][expressions: (out1:SMALLINT, ROW[\"c0\"]), (out2:BIGINT, plus(mod(cast(ROW[\"c0\"] as BIGINT),100),mod(cast(ROW[\"c1\"] as BIGINT),50)))] -> out1:SMALLINT, out2:BIGINT\n" + " -- Filter[1][expression: lt(mod(cast(ROW[\"c0\"] as BIGINT),10),9)] -> c0:SMALLINT, c1:INTEGER, c2:BIGINT\n" " -- Values[0][5 rows in 1 vectors] -> c0:SMALLINT, c1:INTEGER, c2:BIGINT\n", plan_->toString(true, true)); } TEST_F(PlanNodeToStringTest, withContext) { auto addContext = [](const core::PlanNodeId& planNodeId, - const std::string& /* indentation */, - std::stringstream& stream) { - stream << "Context for " << planNodeId; + const std::string& indentation, + std::ostream& stream) { + stream << indentation << "Context for " << planNodeId << std::endl; }; ASSERT_EQ( @@ -102,7 +102,7 @@ TEST_F(PlanNodeToStringTest, withContext) { plan_->toString(false, false, addContext)); ASSERT_EQ( - "-- Project[4][expressions: (out3:BIGINT, plus(cast ROW[\"out1\"] as BIGINT,10))] -> out3:BIGINT\n" + "-- Project[4][expressions: (out3:BIGINT, plus(cast(ROW[\"out1\"] as BIGINT),10))] -> out3:BIGINT\n" " Context for 4\n", plan_->toString(true, false, addContext)); @@ -120,13 +120,13 @@ TEST_F(PlanNodeToStringTest, withContext) { plan_->toString(false, true, addContext)); ASSERT_EQ( - "-- Project[4][expressions: (out3:BIGINT, plus(cast ROW[\"out1\"] as BIGINT,10))] -> out3:BIGINT\n" + "-- Project[4][expressions: (out3:BIGINT, plus(cast(ROW[\"out1\"] as BIGINT),10))] -> out3:BIGINT\n" " Context for 4\n" - " -- Filter[3][expression: lt(mod(cast ROW[\"out1\"] as BIGINT,10),8)] -> out1:SMALLINT, out2:BIGINT\n" + " -- Filter[3][expression: lt(mod(cast(ROW[\"out1\"] as BIGINT),10),8)] -> out1:SMALLINT, out2:BIGINT\n" " Context for 3\n" - " -- Project[2][expressions: (out1:SMALLINT, ROW[\"c0\"]), (out2:BIGINT, plus(mod(cast ROW[\"c0\"] as BIGINT,100),mod(cast ROW[\"c1\"] as BIGINT,50)))] -> out1:SMALLINT, out2:BIGINT\n" + " -- Project[2][expressions: (out1:SMALLINT, ROW[\"c0\"]), (out2:BIGINT, plus(mod(cast(ROW[\"c0\"] as BIGINT),100),mod(cast(ROW[\"c1\"] as BIGINT),50)))] -> out1:SMALLINT, out2:BIGINT\n" " Context for 2\n" - " -- Filter[1][expression: lt(mod(cast ROW[\"c0\"] as BIGINT,10),9)] -> c0:SMALLINT, c1:INTEGER, c2:BIGINT\n" + " -- Filter[1][expression: lt(mod(cast(ROW[\"c0\"] as BIGINT),10),9)] -> c0:SMALLINT, c1:INTEGER, c2:BIGINT\n" " Context for 1\n" " -- Values[0][5 rows in 1 vectors] -> c0:SMALLINT, c1:INTEGER, c2:BIGINT\n" " Context for 0\n", @@ -136,9 +136,11 @@ TEST_F(PlanNodeToStringTest, withContext) { TEST_F(PlanNodeToStringTest, withMultiLineContext) { auto addContext = [](const core::PlanNodeId& planNodeId, const std::string& indentation, - std::stringstream& stream) { - stream << "Context for " << planNodeId << ": line 1" << std::endl; - stream << indentation << "Context for " << planNodeId << ": line 2"; + std::ostream& stream) { + stream << indentation << "Context for " << planNodeId << ": line 1" + << std::endl; + stream << indentation << "Context for " << planNodeId << ": line 2" + << std::endl; }; ASSERT_EQ( @@ -148,7 +150,7 @@ TEST_F(PlanNodeToStringTest, withMultiLineContext) { plan_->toString(false, false, addContext)); ASSERT_EQ( - "-- Project[4][expressions: (out3:BIGINT, plus(cast ROW[\"out1\"] as BIGINT,10))] -> out3:BIGINT\n" + "-- Project[4][expressions: (out3:BIGINT, plus(cast(ROW[\"out1\"] as BIGINT),10))] -> out3:BIGINT\n" " Context for 4: line 1\n" " Context for 4: line 2\n", plan_->toString(true, false, addContext)); @@ -755,7 +757,7 @@ TEST_F(PlanNodeToStringTest, tableScan) { "range filters: [(discount, DoubleRange: [0.050000, 0.070000] no nulls), " "(quantity, DoubleRange: (-inf, 24.000000) no nulls), " "(shipdate, BytesRange: [1994-01-01, 1994-12-31] no nulls)], " - "remaining filter: (not(like(ROW[\"comment\"],\"%special%request%\")))] " + "remaining filter: (not(like(ROW[\"comment\"],%special%request%)))] " "-> discount:DOUBLE, quantity:DOUBLE, shipdate:VARCHAR, comment:VARCHAR\n"; ASSERT_EQ(output, plan->toString(true, false)); } @@ -766,7 +768,7 @@ TEST_F(PlanNodeToStringTest, tableScan) { .planNode(); ASSERT_EQ( - "-- TableScan[0][table: hive_table, remaining filter: (not(like(ROW[\"comment\"],\"%special%request%\")))] " + "-- TableScan[0][table: hive_table, remaining filter: (not(like(ROW[\"comment\"],%special%request%)))] " "-> discount:DOUBLE, quantity:DOUBLE, shipdate:VARCHAR, comment:VARCHAR\n", plan->toString(true, false)); } @@ -927,38 +929,58 @@ TEST_F(PlanNodeToStringTest, rowNumber) { plan->toString(true, false)); } -TEST_F(PlanNodeToStringTest, topNRowNumber) { +namespace { +void topNRankPlanNodeToStringTest(std::string_view function) { auto rowType = ROW({"a", "b"}, {BIGINT(), VARCHAR()}); auto plan = PlanBuilder() .tableScan(rowType) - .topNRowNumber({}, {"a DESC"}, 10, false) + .topNRank(function, {}, {"a DESC"}, 10, false) .planNode(); ASSERT_EQ("-- TopNRowNumber[1]\n", plan->toString()); ASSERT_EQ( - "-- TopNRowNumber[1][order by (a DESC NULLS LAST) limit 10] -> a:BIGINT, b:VARCHAR\n", + fmt::format( + "-- TopNRowNumber[1][{} order by (a DESC NULLS LAST) limit 10] -> a:BIGINT, b:VARCHAR\n", + function), plan->toString(true, false)); plan = PlanBuilder() .tableScan(rowType) - .topNRowNumber({}, {"a DESC"}, 10, true) + .topNRank(function, {}, {"a DESC"}, 10, true) .planNode(); ASSERT_EQ("-- TopNRowNumber[1]\n", plan->toString()); ASSERT_EQ( - "-- TopNRowNumber[1][order by (a DESC NULLS LAST) limit 10] -> a:BIGINT, b:VARCHAR, row_number:BIGINT\n", + fmt::format( + "-- TopNRowNumber[1][{} order by (a DESC NULLS LAST) limit 10] -> a:BIGINT, b:VARCHAR, row_number:BIGINT\n", + function), plan->toString(true, false)); plan = PlanBuilder() .tableScan(rowType) - .topNRowNumber({"a"}, {"b"}, 10, false) + .topNRank(function, {"a"}, {"b"}, 10, false) .planNode(); ASSERT_EQ("-- TopNRowNumber[1]\n", plan->toString()); ASSERT_EQ( - "-- TopNRowNumber[1][partition by (a) order by (b ASC NULLS LAST) limit 10] -> a:BIGINT, b:VARCHAR\n", + fmt::format( + "-- TopNRowNumber[1][{} partition by (a) order by (b ASC NULLS LAST) limit 10] -> a:BIGINT, b:VARCHAR\n", + function), plan->toString(true, false)); } +} // namespace + +TEST_F(PlanNodeToStringTest, topNRowNumber) { + topNRankPlanNodeToStringTest("row_number"); +} + +TEST_F(PlanNodeToStringTest, topNRank) { + topNRankPlanNodeToStringTest("rank"); +} + +TEST_F(PlanNodeToStringTest, topNDenseRank) { + topNRankPlanNodeToStringTest("dense_rank"); +} TEST_F(PlanNodeToStringTest, markDistinct) { auto op = @@ -971,3 +993,6 @@ TEST_F(PlanNodeToStringTest, markDistinct) { "-- MarkDistinct[1][a, b] -> a:VARCHAR, b:BIGINT, c:BIGINT, marker:BOOLEAN\n", op->toString(true, false)); } + +} // namespace +} // namespace facebook::velox::exec diff --git a/velox/exec/tests/PlanNodeToSummaryStringTest.cpp b/velox/exec/tests/PlanNodeToSummaryStringTest.cpp index 2a726e2988bc..4338c3cd1981 100644 --- a/velox/exec/tests/PlanNodeToSummaryStringTest.cpp +++ b/velox/exec/tests/PlanNodeToSummaryStringTest.cpp @@ -20,6 +20,7 @@ #include "velox/parse/TypeResolver.h" #include "velox/vector/tests/utils/VectorTestBase.h" +#include #include using facebook::velox::exec::test::PlanBuilder; @@ -39,6 +40,13 @@ class PlanNodeToSummaryStringTest : public testing::Test, static void SetUpTestCase() { memory::MemoryManager::testingSetInstance(memory::MemoryManager::Options{}); } + + static std::vector toLines(const std::string& planText) { + std::vector lines; + folly::split("\n", planText, lines); + + return lines; + } }; TEST_F(PlanNodeToSummaryStringTest, basic) { @@ -64,12 +72,11 @@ TEST_F(PlanNodeToSummaryStringTest, basic) { " functions: plus: 3, subscript: 6\n" " constants: BIGINT: 9\n" " projections: 7 out of 7\n" - " dereferences: 0 out of 7\n" " -- Filter[1]: 3 fields: a INTEGER, b ARRAY, c MAP\n" " expressions: call: 8, cast: 2, constant: 5, field: 3\n" " functions: and: 2, cardinality: 1, gt: 3, plus: 1, subscript: 1\n" " constants: BIGINT: 5\n" - " filter: and(and(gt(cast ROW[\"a\"] as BIGINT,10),gt(cardinal...\n" + " filter: and(and(gt(cast(ROW[\"a\"] as BIGINT),10),gt(cardina...\n" " -- TableScan[0]: 3 fields: a INTEGER, b ARRAY, c MAP\n" " table: hive_table\n", plan->toSummaryString()); @@ -80,15 +87,14 @@ TEST_F(PlanNodeToSummaryStringTest, basic) { " functions: plus: 3, subscript: 6\n" " constants: BIGINT: 9\n" " projections: 7 out of 7\n" - " p0: plus(cast ROW[\"a\"] as BIGINT,1)\n" - " p1: subscript(ROW[\"b\"],cast 1 as INTEGER)\n" + " p0: plus(cast(ROW[\"a\"] as BIGINT),1)\n" + " p1: subscript(ROW[\"b\"],cast(1 as INTEGER))\n" " ... 5 more\n" - " dereferences: 0 out of 7\n" " -- Filter[1]: 3 fields: a INTEGER, b ARRAY(BIGINT), c MAP(TINYINT, BIGINT)\n" " expressions: call: 8, cast: 2, constant: 5, field: 3\n" " functions: and: 2, cardinality: 1, gt: 3, plus: 1, subscript: 1\n" " constants: BIGINT: 5\n" - " filter: and(and(gt(cast ROW[\"a\"] as BIGINT,10),gt(cardinal...\n" + " filter: and(and(gt(cast(ROW[\"a\"] as BIGINT),10),gt(cardina...\n" " -- TableScan[0]: 3 fields: a INTEGER, b ARRAY(BIGINT), c MAP(TINYINT, BIGINT)\n" " table: hive_table\n", plan->toSummaryString({ @@ -96,6 +102,14 @@ TEST_F(PlanNodeToSummaryStringTest, basic) { .maxOutputFields = 3, .maxChildTypes = 2, })); + + ASSERT_THAT( + toLines(plan->toSkeletonString()), + testing::ElementsAre( + testing::Eq("-- Filter[1]: 3 fields"), + testing::Eq(" -- TableScan[0]: 3 fields"), + testing::Eq(" table: hive_table"), + testing::Eq(""))); } TEST_F(PlanNodeToSummaryStringTest, expressions) { @@ -118,47 +132,56 @@ TEST_F(PlanNodeToSummaryStringTest, expressions) { "d.y", "length(d.z) * strpos(d.z, 'foo')", "12.345", + "ceil(cast(a as real))", }) .planNode(); ASSERT_EQ( - "-- Project[1]: 7 fields: a INTEGER, p1 ARRAY, c MAP, p3 BIGINT, y BIGINT, ...\n" - " expressions: call: 6, constant: 4, dereference: 4, field: 8, lambda: 1\n" - " functions: length: 1, multiply: 2, plus: 1, strpos: 1, transform: 1\n" + "-- Project[1]: 8 fields: a INTEGER, p1 ARRAY, c MAP, p3 BIGINT, y BIGINT, ...\n" + " expressions: call: 7, cast: 1, constant: 4, dereference: 4, field: 7, lambda: 1\n" + " functions: ceil: 1, length: 1, multiply: 2, plus: 1, strpos: 1, transform: 1\n" " constants: BIGINT: 2, DOUBLE: 1, VARCHAR: 1\n" - " projections: 4 out of 7\n" - " dereferences: 1 out of 7\n" + " projections: 4 out of 8\n" + " dereferences: 1 out of 8\n" + " constant projections: 1 out of 8\n" " -- TableScan[0]: 4 fields: a INTEGER, b ARRAY, c MAP, d ROW(3)\n" " table: hive_table\n", plan->toSummaryString()); ASSERT_EQ( - "-- Project[1]: 7 fields: a INTEGER, p1 ARRAY, c MAP, p3 BIGINT, y BIGINT, ...\n" - " expressions: call: 6, constant: 4, dereference: 4, field: 8, lambda: 1\n" - " functions: length: 1, multiply: 2, plus: 1, strpos: 1, transform: 1\n" + "-- Project[1]: 8 fields: a INTEGER, p1 ARRAY, c MAP, p3 BIGINT, y BIGINT, ...\n" + " expressions: call: 7, cast: 1, constant: 4, dereference: 4, field: 7, lambda: 1\n" + " functions: ceil: 1, length: 1, multiply: 2, plus: 1, strpos: 1, transform: 1\n" " constants: BIGINT: 2, DOUBLE: 1, VARCHAR: 1\n" - " projections: 4 out of 7\n" + " projections: 4 out of 8\n" " p1: transform(ROW[\"b\"],lambda ROW -> plus(RO...\n" " p3: multiply(ROW[\"d\"][x],10)\n" - " p5: multiply(length(ROW[\"d\"][z]),strpos(ROW[\"d\"][z],\"f...\n" + " p5: multiply(length(ROW[\"d\"][z]),strpos(ROW[\"d\"][z],fo...\n" " ... 1 more\n" - " dereferences: 1 out of 7\n" + " dereferences: 1 out of 8\n" " y: ROW[\"d\"][y]\n" + " constant projections: 1 out of 8\n" + " p6: 12.345\n" " -- TableScan[0]: 4 fields: a INTEGER, b ARRAY, c MAP, d ROW(3)\n" " table: hive_table\n", plan->toSummaryString( - {.project = {.maxProjections = 3, .maxDereferences = 2}})); + {.project = { + .maxProjections = 3, + .maxDereferences = 2, + .maxConstants = 1, + }})); + + ASSERT_THAT( + toLines(plan->toSkeletonString()), + testing::ElementsAre( + testing::Eq("-- TableScan[0]: 4 fields"), + testing::Eq(" table: hive_table"), + testing::Eq(""))); } TEST_F(PlanNodeToSummaryStringTest, aggregation) { aggregate::prestosql::registerAllAggregateFunctions(); - auto rowType = - ROW({"a", "b", "c"}, - { - INTEGER(), - INTEGER(), - INTEGER(), - }); + auto rowType = ROW({"a", "b", "c"}, {INTEGER(), INTEGER(), INTEGER()}); auto plan = PlanBuilder() .tableScan(rowType) @@ -189,7 +212,46 @@ TEST_F(PlanNodeToSummaryStringTest, aggregation) { " -- TableScan[0]: 3 fields: a INTEGER, b INTEGER, c INTEGER\n" " table: hive_table\n", plan->toSummaryString({.aggregate = {.maxAggregations = 3}})); + + ASSERT_THAT( + toLines(plan->toSkeletonString()), + testing::ElementsAre( + testing::Eq("-- Aggregation[1]: 6 fields"), + testing::Eq(" -- TableScan[0]: 3 fields"), + testing::Eq(" table: hive_table"), + testing::Eq(""))); } +TEST_F(PlanNodeToSummaryStringTest, withContext) { + auto plan = PlanBuilder() + .tableScan(ROW({"a", "b", "c"}, INTEGER())) + .project({"a + b", "a * c"}) + .filter("p0 > p1") + .planNode(); + + auto addContext = [](const PlanNodeId& planNodeId, + const std::string& indentation, + std::ostream& stream) { + stream << indentation << "Context for " << planNodeId << std::endl; + }; + + ASSERT_THAT( + toLines(plan->toSummaryString({}, addContext)), + testing::ElementsAre( + testing::StartsWith("-- Filter[2]"), + testing::StartsWith(" expressions:"), + testing::StartsWith(" functions:"), + testing::StartsWith(" filter:"), + testing::StartsWith(" Context for 2"), + testing::StartsWith(" -- Project[1]"), + testing::StartsWith(" expressions:"), + testing::StartsWith(" functions:"), + testing::StartsWith(" projections:"), + testing::StartsWith(" Context for 1"), + testing::StartsWith(" -- TableScan[0]"), + testing::StartsWith(" table: hive_table"), + testing::StartsWith(" Context for 0"), + testing::Eq(""))); +} } // namespace } // namespace facebook::velox::core diff --git a/velox/exec/tests/PrefixSortTest.cpp b/velox/exec/tests/PrefixSortTest.cpp index f404fa05ac4d..732d590dc052 100644 --- a/velox/exec/tests/PrefixSortTest.cpp +++ b/velox/exec/tests/PrefixSortTest.cpp @@ -14,6 +14,7 @@ * limitations under the License. */ +#include #include #include "velox/exec/PrefixSort.h" diff --git a/velox/exec/tests/PrestoQueryRunnerHyperLogLogTransformTest.cpp b/velox/exec/tests/PrestoQueryRunnerHyperLogLogTransformTest.cpp new file mode 100644 index 000000000000..9349f42beed9 --- /dev/null +++ b/velox/exec/tests/PrestoQueryRunnerHyperLogLogTransformTest.cpp @@ -0,0 +1,54 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/fuzzer/PrestoQueryRunnerIntermediateTypeTransforms.h" +#include "velox/exec/tests/PrestoQueryRunnerIntermediateTypeTransformTestBase.h" +#include "velox/functions/prestosql/types/HyperLogLogType.h" + +namespace facebook::velox::exec::test { +namespace { + +class PrestoQueryRunnerHyperLogLogTransformTest + : public PrestoQueryRunnerIntermediateTypeTransformTestBase {}; + +TEST_F(PrestoQueryRunnerHyperLogLogTransformTest, isIntermediateOnlyType) { + ASSERT_TRUE(isIntermediateOnlyType(HYPERLOGLOG())); + ASSERT_TRUE(isIntermediateOnlyType(ARRAY(HYPERLOGLOG()))); + ASSERT_TRUE(isIntermediateOnlyType(MAP(HYPERLOGLOG(), SMALLINT()))); + ASSERT_TRUE(isIntermediateOnlyType(MAP(VARBINARY(), HYPERLOGLOG()))); + ASSERT_TRUE(isIntermediateOnlyType(ROW({HYPERLOGLOG(), SMALLINT()}))); + ASSERT_TRUE(isIntermediateOnlyType(ROW( + {SMALLINT(), TIMESTAMP(), ARRAY(ROW({MAP(VARCHAR(), HYPERLOGLOG())}))}))); +} + +TEST_F(PrestoQueryRunnerHyperLogLogTransformTest, transform) { + test(HYPERLOGLOG()); +} + +TEST_F(PrestoQueryRunnerHyperLogLogTransformTest, transformArray) { + testArray(HYPERLOGLOG()); +} + +TEST_F(PrestoQueryRunnerHyperLogLogTransformTest, transformMap) { + testMap(HYPERLOGLOG()); +} + +TEST_F(PrestoQueryRunnerHyperLogLogTransformTest, transformRow) { + testRow(HYPERLOGLOG()); +} + +} // namespace +} // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/PrestoQueryRunnerIntermediateTypeTransformTestBase.cpp b/velox/exec/tests/PrestoQueryRunnerIntermediateTypeTransformTestBase.cpp new file mode 100644 index 000000000000..a52fc8150685 --- /dev/null +++ b/velox/exec/tests/PrestoQueryRunnerIntermediateTypeTransformTestBase.cpp @@ -0,0 +1,193 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/tests/PrestoQueryRunnerIntermediateTypeTransformTestBase.h" + +namespace facebook::velox::exec::test { + +void PrestoQueryRunnerIntermediateTypeTransformTestBase::test( + const VectorPtr& vector) { + const auto colName = "col"; + const auto input = + makeRowVector({colName}, {transformIntermediateTypes(vector)}); + + auto expr = getProjectionsToIntermediateTypes( + vector->type(), + std::make_shared( + colName, + std::nullopt, + std::vector{std::make_shared()}), + colName); + + core::PlanNodePtr plan = + PlanBuilder().values({input}).projectExpressions({expr}).planNode(); + + AssertQueryBuilder(plan).assertResults(makeRowVector({colName}, {vector})); +} + +void PrestoQueryRunnerIntermediateTypeTransformTestBase::testDictionary( + const VectorPtr& base) { + // Wrap in a dictionary without nulls. + auto dict = BaseVector::wrapInDictionary( + nullptr, makeIndicesInReverse(100), 100, base); + test(dict); + // Wrap in a dictionary with some nulls. + dict = BaseVector::wrapInDictionary( + makeNulls(100, [](vector_size_t row) { return row % 10 == 0; }), + makeIndicesInReverse(100), + 100, + base); + test(dict); + // Wrap in a dictionary with all nulls. + dict = BaseVector::wrapInDictionary( + makeNulls(100, [](vector_size_t) { return true; }), + makeIndicesInReverse(100), + 100, + base); + test(dict); +} + +void PrestoQueryRunnerIntermediateTypeTransformTestBase::testConstant( + const VectorPtr& base) { + // Test a non-null constant. + test(BaseVector::wrapInConstant(100, 0, base)); + // Test a null constant. + test(BaseVector::createNullConstant(ARRAY(base->type()), 100, pool_.get())); +} + +void PrestoQueryRunnerIntermediateTypeTransformTestBase::test( + const TypePtr& type) { + // No nulls. + auto& opts = fuzzer_.getMutableOptions(); + opts.nullRatio = 0; + auto vector = fuzzer_.fuzzFlat(type, kVectorSize); + test(vector); + // All nulls. + opts.nullRatio = 1; + vector = fuzzer_.fuzzFlat(type, kVectorSize); + test(vector); + // Some nulls. + opts.nullRatio = 0.1; + vector = fuzzer_.fuzzFlat(type, kVectorSize); + test(vector); + + testDictionary(vector); + testConstant(vector); +} + +void PrestoQueryRunnerIntermediateTypeTransformTestBase::testArray( + const TypePtr& type) { + // Test array vector no nulls. + auto& opts = fuzzer_.getMutableOptions(); + opts.nullRatio = 0; + auto array = fuzzer_.fuzzArray(type, kVectorSize); + test(array); + + // Test array vector all nulls. + opts.nullRatio = 1; + array = fuzzer_.fuzzArray(type, kVectorSize); + test(array); + + // Test array vector some nulls. + opts.nullRatio = 0.1; + array = fuzzer_.fuzzArray(type, kVectorSize); + test(array); + + testDictionary(array); + testConstant(array); +} + +void PrestoQueryRunnerIntermediateTypeTransformTestBase::testMap( + const TypePtr& type) { + // Test map vector no nulls. + auto& opts = fuzzer_.getMutableOptions(); + opts.nullRatio = 0.0; + auto map = fuzzer_.fuzzMap(type, type, kVectorSize); + test(map); + + // Test map vector all nulls. + opts.nullRatio = 1; + map = fuzzer_.fuzzMap(type, type, kVectorSize); + test(map); + + // Test map vector some nulls. + opts.nullRatio = 0.1; + map = fuzzer_.fuzzMap(type, type, kVectorSize); + test(map); + + testDictionary(map); + testConstant(map); +} + +void PrestoQueryRunnerIntermediateTypeTransformTestBase::testRow( + const TypePtr& type) { + const auto size = 100; + auto& opts = fuzzer_.getMutableOptions(); + opts.nullRatio = 0; + auto field1 = fuzzer_.fuzz(type, size); + opts.nullRatio = 0.1; + auto field2 = fuzzer_.fuzz(type, size); + opts.nullRatio = 1; + auto field3 = fuzzer_.fuzz(type, size); + const auto rowType = + ROW({"c1", "c2", "c3"}, {field1->type(), field2->type(), field3->type()}); + // Test map vector no nulls. + test(vectorMaker_.rowVector({field1, field2, field3})); + + // Test row vector some nulls. + test( + std::make_shared( + pool_.get(), + rowType, + makeNulls(size, [](vector_size_t row) { return row % 10 == 0; }), + size, + std::vector{field1, field2, field3})); + + // Test row vector all nulls. + test( + std::make_shared( + pool_.get(), + rowType, + makeNulls(size, [](vector_size_t) { return true; }), + size, + std::vector{field1, field2, field3})); + + const auto base = vectorMaker_.rowVector({field1, field2, field3}); + testDictionary(base); + testConstant(base); +} + +void PrestoQueryRunnerIntermediateTypeTransformTestBase::testArray( + const VectorPtr& vector) { + test(fuzzer_.fuzzArray(vector, vector->size())); +} + +void PrestoQueryRunnerIntermediateTypeTransformTestBase::testMap( + const VectorPtr& keys, + const VectorPtr& values) { + VELOX_DCHECK_EQ(keys->size(), values->size()); + test(fuzzer_.fuzzMap(keys, values, keys->size())); +} + +void PrestoQueryRunnerIntermediateTypeTransformTestBase::testRow( + std::vector&& vectors, + std::vector names) { + auto vector_size = vectors.size(); + VELOX_DCHECK_EQ(vector_size, names.size()); + test(fuzzer_.fuzzRow(std::move(vectors), names, vector_size)); +} + +} // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/PrestoQueryRunnerIntermediateTypeTransformTestBase.h b/velox/exec/tests/PrestoQueryRunnerIntermediateTypeTransformTestBase.h new file mode 100644 index 000000000000..56c93f9966b8 --- /dev/null +++ b/velox/exec/tests/PrestoQueryRunnerIntermediateTypeTransformTestBase.h @@ -0,0 +1,57 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include + +#include "velox/exec/fuzzer/PrestoQueryRunnerIntermediateTypeTransforms.h" +#include "velox/exec/tests/utils/AssertQueryBuilder.h" +#include "velox/exec/tests/utils/PlanBuilder.h" +#include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h" + +namespace facebook::velox::exec::test { + +class PrestoQueryRunnerIntermediateTypeTransformTestBase + : public functions::test::FunctionBaseTest { + protected: + void test(const VectorPtr& vector); + + void testDictionary(const VectorPtr& base); + + void testConstant(const VectorPtr& base); + + void test(const TypePtr& type); + + void testArray(const TypePtr& type); + + void testMap(const TypePtr& type); + + void testRow(const TypePtr& type); + + void testArray(const VectorPtr& vector); + + void testMap(const VectorPtr& keys, const VectorPtr& values); + + void testRow( + std::vector&& vectors, + std::vector names); + + private: + const int32_t kVectorSize = 100; + VectorFuzzer fuzzer_{VectorFuzzer::Options{}, pool_.get(), 123}; +}; + +} // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/PrestoQueryRunnerIntervalTransformTest.cpp b/velox/exec/tests/PrestoQueryRunnerIntervalTransformTest.cpp new file mode 100644 index 000000000000..340500efc4de --- /dev/null +++ b/velox/exec/tests/PrestoQueryRunnerIntervalTransformTest.cpp @@ -0,0 +1,182 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/fuzzer/PrestoQueryRunnerIntermediateTypeTransforms.h" +#include "velox/exec/tests/PrestoQueryRunnerIntermediateTypeTransformTestBase.h" + +namespace facebook::velox::exec::test { +namespace { + +class PrestoQueryRunnerIntervalTransformTest + : public PrestoQueryRunnerIntermediateTypeTransformTestBase {}; + +TEST_F(PrestoQueryRunnerIntervalTransformTest, isIntermediateOnlyType) { + ASSERT_TRUE(isIntermediateOnlyType(INTERVAL_DAY_TIME())); + ASSERT_TRUE(isIntermediateOnlyType(ARRAY(INTERVAL_DAY_TIME()))); + ASSERT_TRUE(isIntermediateOnlyType(MAP(INTERVAL_DAY_TIME(), SMALLINT()))); + ASSERT_TRUE(isIntermediateOnlyType(MAP(VARBINARY(), INTERVAL_DAY_TIME()))); + ASSERT_TRUE(isIntermediateOnlyType(ROW({INTERVAL_DAY_TIME(), SMALLINT()}))); + ASSERT_TRUE(isIntermediateOnlyType(ROW( + {SMALLINT(), + TIMESTAMP(), + ARRAY(ROW({MAP(VARCHAR(), INTERVAL_DAY_TIME())}))}))); +} + +TEST_F(PrestoQueryRunnerIntervalTransformTest, roundTrip) { + std::vector> no_nulls{0, 1, 2, 3}; + test(makeNullableFlatVector(no_nulls, INTERVAL_DAY_TIME())); + + std::vector> some_nulls{0, 1, std::nullopt, 3}; + test(makeNullableFlatVector(some_nulls, INTERVAL_DAY_TIME())); + + std::vector> all_nulls{ + std::nullopt, std::nullopt, std::nullopt}; + test(makeNullableFlatVector(all_nulls, INTERVAL_DAY_TIME())); +} + +TEST_F(PrestoQueryRunnerIntervalTransformTest, negative) { + auto vector = makeNullableFlatVector( + std::vector>{ + -1, + -2, + -3, + -4, + -5, + }, + INTERVAL_DAY_TIME()); + const auto colName = "col"; + const auto input = + makeRowVector({colName}, {transformIntermediateTypes(vector)}); + + auto expr = getProjectionsToIntermediateTypes( + vector->type(), + std::make_shared( + colName, + std::nullopt, + std::vector{std::make_shared()}), + colName); + + core::PlanNodePtr plan = + PlanBuilder().values({input}).projectExpressions({expr}).planNode(); + + AssertQueryBuilder(plan).assertResults(makeRowVector( + {colName}, + {makeNullableFlatVector( + std::vector>{ + std::nullopt, + std::nullopt, + std::nullopt, + std::nullopt, + std::nullopt, + }, + INTERVAL_DAY_TIME())})); +} + +TEST_F(PrestoQueryRunnerIntervalTransformTest, transformArray) { + auto input = makeNullableFlatVector( + std::vector>{ + 0, + 1, + 12, + 123, + 1234, + 12345, + 123456, + 1234567, + 12345678, + 123456789, + 1234567890, + 12345678901, + 123456789012, + 1234567890123, + 12345678901234, + 123456789012345, + 1234567890123456, + 12345678901234567, + 123456789012345678, + 1234567890123456789, + }, + INTERVAL_DAY_TIME()); + testArray(input); + + input = makeNullableFlatVector( + std::vector>{ + 993296205767471, + 101271764434518, + 166587109740908, + 210274651317771, + 276381323443199, + 283617048324990, + 519099922518052, + 530020098439118, + 604149362180160, + 622016152847258}, + INTERVAL_DAY_TIME()); + testArray(input); +} + +TEST_F(PrestoQueryRunnerIntervalTransformTest, transformMap) { + auto keys = makeNullableFlatVector( + std::vector>{ + 1, + 12, + 123, + 1234, + 12345, + 123456, + 1234567, + 12345678, + 123456789, + 1234567890, + }, + INTERVAL_DAY_TIME()); + + auto values = makeNullableFlatVector( + std::vector>{ + 993296205767471, + 101271764434518, + 166587109740908, + 210274651317771, + 276381323443199, + 283617048324990, + 519099922518052, + 530020098439118, + 604149362180160, + 622016152847258}, + INTERVAL_DAY_TIME()); + + testMap(keys, values); +} + +TEST_F(PrestoQueryRunnerIntervalTransformTest, transformRow) { + auto input = makeNullableFlatVector( + std::vector>{ + 993296205767471, + 101271764434518, + 166587109740908, + 210274651317771, + 276381323443199, + 283617048324990, + 519099922518052, + 530020098439118, + 604149362180160, + 622016152847258}, + INTERVAL_DAY_TIME()); + testRow({input}, {"row"}); +} + +} // namespace +} // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/PrestoQueryRunnerJsonTransformTest.cpp b/velox/exec/tests/PrestoQueryRunnerJsonTransformTest.cpp new file mode 100644 index 000000000000..d54373791d08 --- /dev/null +++ b/velox/exec/tests/PrestoQueryRunnerJsonTransformTest.cpp @@ -0,0 +1,155 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/fuzzer/PrestoQueryRunnerIntermediateTypeTransforms.h" +#include "velox/exec/tests/PrestoQueryRunnerIntermediateTypeTransformTestBase.h" +#include "velox/functions/prestosql/types/JsonType.h" + +namespace facebook::velox::exec::test { +namespace { +class PrestoQueryRunnerJsonTransformTest + : public PrestoQueryRunnerIntermediateTypeTransformTestBase {}; + +TEST_F(PrestoQueryRunnerJsonTransformTest, verifyTransformExpression) { + core::PlanNodePtr plan = + PlanBuilder() + .values({makeRowVector( + {"c0"}, + {transformIntermediateTypes(makeNullableFlatVector( + std::vector>{}, JSON()))})}) + .projectExpressions({getProjectionsToIntermediateTypes( + JSON(), + std::make_shared( + "c0", + std::nullopt, + std::vector{ + std::make_shared()}), + "c0")}) + .planNode(); + + VELOX_CHECK_EQ( + plan->toString(true, false), + "-- Project[1][expressions: (c0:JSON, try(json_parse(ROW[\"c0\"])))] -> c0:JSON\n"); + AssertQueryBuilder(plan).assertTypeAndNumRows(JSON(), 0); +} + +TEST_F(PrestoQueryRunnerJsonTransformTest, roundTrip) { + std::vector> no_nulls{"1", "2", "3"}; + test(makeNullableFlatVector(no_nulls, JSON())); + + std::vector> some_nulls{"1", std::nullopt, "3"}; + test(makeNullableFlatVector(some_nulls, JSON())); + + std::vector> all_nulls{ + std::nullopt, std::nullopt, std::nullopt}; + test(makeNullableFlatVector(all_nulls, JSON())); +} + +TEST_F(PrestoQueryRunnerJsonTransformTest, isIntermediateOnlyType) { + ASSERT_TRUE(isIntermediateOnlyType(JSON())); + ASSERT_TRUE(isIntermediateOnlyType(ARRAY(JSON()))); + ASSERT_TRUE(isIntermediateOnlyType(MAP(JSON(), SMALLINT()))); + ASSERT_TRUE(isIntermediateOnlyType(MAP(VARBINARY(), JSON()))); + ASSERT_TRUE(isIntermediateOnlyType(ROW({JSON(), SMALLINT()}))); + ASSERT_TRUE(isIntermediateOnlyType(ROW({BOOLEAN(), ARRAY(JSON())}))); + ASSERT_TRUE(isIntermediateOnlyType( + ROW({SMALLINT(), TIMESTAMP(), ARRAY(ROW({MAP(VARCHAR(), JSON())}))}))); +} + +TEST_F(PrestoQueryRunnerJsonTransformTest, transformArray) { + std::vector> valid_json{ + "1", + "2", + "3", + "\"{}\"", + "[1,2,3]", + "{\"name\":\"John\"}", + "{\"product\":\"laptop\"}", + "{\"temperature\":23.5}", + "{\"isActive\":true}", + "{\"coordinates\":{\"latitude\":40.7128}}", + "{\"colors\":[\"red\"]}", + "{\"user\":{\"id\":123}}", + "{\"order\":{\"id\":456}}", + "{\"event\":\"concert\"}", + "{\"settings\":{\"volume\":75}}", + "{\"company\":{\"name\":\"TechCorp\"}}", + "{\"university\":{\"departments\":[{\"name\":\"ComputerScience\"}]}}", + "{\"library\":{\"books\":[{\"title\":\"1984\"}]}}", + "{\"restaurant\":{\"menu\":{\"appetizers\":[\"salad\"]}}}", + "{\"project\":{\"name\":\"Apollo\"}}", + }; + auto input = makeNullableFlatVector(valid_json, JSON()); + testArray(input); +} + +TEST_F(PrestoQueryRunnerJsonTransformTest, transformMap) { + std::vector> valid_json_keys{ + "\"key1\"", + "\"key2\"", + "\"key3\"", + "{\"address\":{\"city\":\"New York\"}}", + "{\"company\":{\"name\":\"TechCorp\"}}", + "{\"product\":\"Laptop\"}", + "\"key7\"", + "\"key8\"", + "\"key9\"", + "\"key10\""}; + auto keys = makeNullableFlatVector(valid_json_keys, JSON()); + std::vector> valid_json_values{ + "{\"name\":\"Alice\"}", + "{\"age\":30}", + "{\"address\":{\"city\":\"New York\"}}", + "{\"company\":{\"name\":\"TechCorp\"}}", + "{\"product\":\"Laptop\"}", + "{\"price\":999.99}", + "{\"user\":{\"id\":123}}", + "{\"order\":{\"id\":456}}", + "{\"event\":{\"name\":\"Conference\"}}", + "{\"library\":{\"books\":[{\"title\":\"1984\"}]}}"}; + auto values = makeNullableFlatVector(valid_json_values, JSON()); + testMap(keys, values); +} + +TEST_F(PrestoQueryRunnerJsonTransformTest, transformRow) { + std::vector> valid_json{ + "1", + "2", + "3", + "\"{}\"", + "[1,2,3]", + "{\"name\":\"John\"}", + "{\"product\":\"laptop\"}", + "{\"temperature\":23.5}", + "{\"isActive\":true}", + "{\"coordinates\":{\"latitude\":40.7128}}", + "{\"colors\":[\"red\"]}", + "{\"user\":{\"id\":123}}", + "{\"order\":{\"id\":456}}", + "{\"event\":\"concert\"}", + "{\"settings\":{\"volume\":75}}", + "{\"company\":{\"name\":\"TechCorp\"}}", + "{\"university\":{\"departments\":[{\"name\":\"ComputerScience\"}]}}", + "{\"library\":{\"books\":[{\"title\":\"1984\"}]}}", + "{\"restaurant\":{\"menu\":{\"appetizers\":[\"salad\"]}}}", + "{\"project\":{\"name\":\"Apollo\"}}", + }; + auto input = makeNullableFlatVector(valid_json, JSON()); + testRow({input}, {"row"}); +} + +} // namespace +} // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/PrestoQueryRunnerKHyperLogLogTransformTest.cpp b/velox/exec/tests/PrestoQueryRunnerKHyperLogLogTransformTest.cpp new file mode 100644 index 000000000000..80a4f5a7cd16 --- /dev/null +++ b/velox/exec/tests/PrestoQueryRunnerKHyperLogLogTransformTest.cpp @@ -0,0 +1,56 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/fuzzer/PrestoQueryRunnerIntermediateTypeTransforms.h" +#include "velox/exec/tests/PrestoQueryRunnerIntermediateTypeTransformTestBase.h" +#include "velox/functions/prestosql/types/KHyperLogLogType.h" + +namespace facebook::velox::exec::test { +namespace { + +class PrestoQueryRunnerKHyperLogLogTransformTest + : public PrestoQueryRunnerIntermediateTypeTransformTestBase {}; + +TEST_F(PrestoQueryRunnerKHyperLogLogTransformTest, isIntermediateOnlyType) { + ASSERT_TRUE(isIntermediateOnlyType(KHYPERLOGLOG())); + ASSERT_TRUE(isIntermediateOnlyType(ARRAY(KHYPERLOGLOG()))); + ASSERT_TRUE(isIntermediateOnlyType(MAP(KHYPERLOGLOG(), SMALLINT()))); + ASSERT_TRUE(isIntermediateOnlyType(MAP(VARBINARY(), KHYPERLOGLOG()))); + ASSERT_TRUE(isIntermediateOnlyType(ROW({KHYPERLOGLOG(), SMALLINT()}))); + ASSERT_TRUE(isIntermediateOnlyType(ROW( + {SMALLINT(), + TIMESTAMP(), + ARRAY(ROW({MAP(VARCHAR(), KHYPERLOGLOG())}))}))); +} + +TEST_F(PrestoQueryRunnerKHyperLogLogTransformTest, transform) { + test(KHYPERLOGLOG()); +} + +TEST_F(PrestoQueryRunnerKHyperLogLogTransformTest, transformArray) { + testArray(KHYPERLOGLOG()); +} + +TEST_F(PrestoQueryRunnerKHyperLogLogTransformTest, transformMap) { + testMap(KHYPERLOGLOG()); +} + +TEST_F(PrestoQueryRunnerKHyperLogLogTransformTest, transformRow) { + testRow(KHYPERLOGLOG()); +} + +} // namespace +} // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/PrestoQueryRunnerQDigestTransformTest.cpp b/velox/exec/tests/PrestoQueryRunnerQDigestTransformTest.cpp new file mode 100644 index 000000000000..173d4966d5e5 --- /dev/null +++ b/velox/exec/tests/PrestoQueryRunnerQDigestTransformTest.cpp @@ -0,0 +1,84 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/fuzzer/PrestoQueryRunnerIntermediateTypeTransforms.h" +#include "velox/exec/tests/PrestoQueryRunnerIntermediateTypeTransformTestBase.h" +#include "velox/functions/prestosql/types/QDigestType.h" + +namespace facebook::velox::exec::test { +namespace { + +class PrestoQueryRunnerQDigestTransformTest + : public PrestoQueryRunnerIntermediateTypeTransformTestBase {}; + +TEST_F(PrestoQueryRunnerQDigestTransformTest, isIntermediateOnlyType) { + ASSERT_TRUE(isIntermediateOnlyType(QDIGEST(DOUBLE()))); + ASSERT_TRUE(isIntermediateOnlyType(ARRAY(QDIGEST(DOUBLE())))); + ASSERT_TRUE(isIntermediateOnlyType(MAP(QDIGEST(DOUBLE()), SMALLINT()))); + ASSERT_TRUE(isIntermediateOnlyType(MAP(VARBINARY(), QDIGEST(DOUBLE())))); + ASSERT_TRUE(isIntermediateOnlyType(ROW({QDIGEST(DOUBLE()), SMALLINT()}))); + ASSERT_TRUE(isIntermediateOnlyType(ROW( + {SMALLINT(), + TIMESTAMP(), + ARRAY(ROW({MAP(VARCHAR(), QDIGEST(DOUBLE()))}))}))); + + ASSERT_TRUE(isIntermediateOnlyType(QDIGEST(BIGINT()))); + ASSERT_TRUE(isIntermediateOnlyType(ARRAY(QDIGEST(BIGINT())))); + ASSERT_TRUE(isIntermediateOnlyType(MAP(QDIGEST(BIGINT()), SMALLINT()))); + ASSERT_TRUE(isIntermediateOnlyType(MAP(VARBINARY(), QDIGEST(BIGINT())))); + ASSERT_TRUE(isIntermediateOnlyType(ROW({QDIGEST(BIGINT()), SMALLINT()}))); + ASSERT_TRUE(isIntermediateOnlyType(ROW( + {SMALLINT(), + TIMESTAMP(), + ARRAY(ROW({MAP(VARCHAR(), QDIGEST(BIGINT()))}))}))); + + ASSERT_TRUE(isIntermediateOnlyType(QDIGEST(REAL()))); + ASSERT_TRUE(isIntermediateOnlyType(ARRAY(QDIGEST(REAL())))); + ASSERT_TRUE(isIntermediateOnlyType(MAP(QDIGEST(REAL()), SMALLINT()))); + ASSERT_TRUE(isIntermediateOnlyType(MAP(VARBINARY(), QDIGEST(REAL())))); + ASSERT_TRUE(isIntermediateOnlyType(ROW({QDIGEST(REAL()), SMALLINT()}))); + ASSERT_TRUE(isIntermediateOnlyType(ROW( + {SMALLINT(), + TIMESTAMP(), + ARRAY(ROW({MAP(VARCHAR(), QDIGEST(REAL()))}))}))); +} + +TEST_F(PrestoQueryRunnerQDigestTransformTest, transform) { + test(QDIGEST(DOUBLE())); + test(QDIGEST(BIGINT())); + test(QDIGEST(REAL())); +} + +TEST_F(PrestoQueryRunnerQDigestTransformTest, transformArray) { + testArray(QDIGEST(DOUBLE())); + testArray(QDIGEST(BIGINT())); + testArray(QDIGEST(REAL())); +} + +TEST_F(PrestoQueryRunnerQDigestTransformTest, transformMap) { + testMap(QDIGEST(DOUBLE())); + testMap(QDIGEST(BIGINT())); + testMap(QDIGEST(REAL())); +} + +TEST_F(PrestoQueryRunnerQDigestTransformTest, transformRow) { + testRow(QDIGEST(DOUBLE())); + testRow(QDIGEST(BIGINT())); + testRow(QDIGEST(REAL())); +} + +} // namespace +} // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/PrestoQueryRunnerSetDigestTransformTest.cpp b/velox/exec/tests/PrestoQueryRunnerSetDigestTransformTest.cpp new file mode 100644 index 000000000000..af15b0dec719 --- /dev/null +++ b/velox/exec/tests/PrestoQueryRunnerSetDigestTransformTest.cpp @@ -0,0 +1,54 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/fuzzer/PrestoQueryRunnerIntermediateTypeTransforms.h" +#include "velox/exec/tests/PrestoQueryRunnerIntermediateTypeTransformTestBase.h" +#include "velox/functions/prestosql/types/SetDigestType.h" + +namespace facebook::velox::exec::test { +namespace { + +class PrestoQueryRunnerSetDigestTransformTest + : public PrestoQueryRunnerIntermediateTypeTransformTestBase {}; + +TEST_F(PrestoQueryRunnerSetDigestTransformTest, isIntermediateOnlyType) { + ASSERT_TRUE(isIntermediateOnlyType(SETDIGEST())); + ASSERT_TRUE(isIntermediateOnlyType(ARRAY(SETDIGEST()))); + ASSERT_TRUE(isIntermediateOnlyType(MAP(SETDIGEST(), SMALLINT()))); + ASSERT_TRUE(isIntermediateOnlyType(MAP(VARBINARY(), SETDIGEST()))); + ASSERT_TRUE(isIntermediateOnlyType(ROW({SETDIGEST(), SMALLINT()}))); + ASSERT_TRUE(isIntermediateOnlyType(ROW( + {SMALLINT(), TIMESTAMP(), ARRAY(ROW({MAP(VARCHAR(), SETDIGEST())}))}))); +} + +TEST_F(PrestoQueryRunnerSetDigestTransformTest, transform) { + test(SETDIGEST()); +} + +TEST_F(PrestoQueryRunnerSetDigestTransformTest, transformArray) { + testArray(SETDIGEST()); +} + +TEST_F(PrestoQueryRunnerSetDigestTransformTest, transformMap) { + testMap(SETDIGEST()); +} + +TEST_F(PrestoQueryRunnerSetDigestTransformTest, transformRow) { + testRow(SETDIGEST()); +} + +} // namespace +} // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/PrestoQueryRunnerTDigestTransformTest.cpp b/velox/exec/tests/PrestoQueryRunnerTDigestTransformTest.cpp new file mode 100644 index 000000000000..4dd072cf8591 --- /dev/null +++ b/velox/exec/tests/PrestoQueryRunnerTDigestTransformTest.cpp @@ -0,0 +1,56 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/fuzzer/PrestoQueryRunnerIntermediateTypeTransforms.h" +#include "velox/exec/tests/PrestoQueryRunnerIntermediateTypeTransformTestBase.h" +#include "velox/functions/prestosql/types/TDigestType.h" + +namespace facebook::velox::exec::test { +namespace { + +class PrestoQueryRunnerTDigestTransformTest + : public PrestoQueryRunnerIntermediateTypeTransformTestBase {}; + +TEST_F(PrestoQueryRunnerTDigestTransformTest, isIntermediateOnlyType) { + ASSERT_TRUE(isIntermediateOnlyType(TDIGEST(DOUBLE()))); + ASSERT_TRUE(isIntermediateOnlyType(ARRAY(TDIGEST(DOUBLE())))); + ASSERT_TRUE(isIntermediateOnlyType(MAP(TDIGEST(DOUBLE()), SMALLINT()))); + ASSERT_TRUE(isIntermediateOnlyType(MAP(VARBINARY(), TDIGEST(DOUBLE())))); + ASSERT_TRUE(isIntermediateOnlyType(ROW({TDIGEST(DOUBLE()), SMALLINT()}))); + ASSERT_TRUE(isIntermediateOnlyType(ROW( + {SMALLINT(), + TIMESTAMP(), + ARRAY(ROW({MAP(VARCHAR(), TDIGEST(DOUBLE()))}))}))); +} + +TEST_F(PrestoQueryRunnerTDigestTransformTest, transform) { + test(TDIGEST(DOUBLE())); +} + +TEST_F(PrestoQueryRunnerTDigestTransformTest, transformArray) { + testArray(TDIGEST(DOUBLE())); +} + +TEST_F(PrestoQueryRunnerTDigestTransformTest, transformMap) { + testMap(TDIGEST(DOUBLE())); +} + +TEST_F(PrestoQueryRunnerTDigestTransformTest, transformRow) { + testRow(TDIGEST(DOUBLE())); +} + +} // namespace +} // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/PrestoQueryRunnerTest.cpp b/velox/exec/tests/PrestoQueryRunnerTest.cpp index ee8d1e7bcd2f..598d5633bd91 100644 --- a/velox/exec/tests/PrestoQueryRunnerTest.cpp +++ b/velox/exec/tests/PrestoQueryRunnerTest.cpp @@ -203,7 +203,7 @@ TEST_F(PrestoQueryRunnerTest, toSql) { EXPECT_EQ( queryRunner->toSql(plan), fmt::format( - "SELECT c0, c1, c2, first_value(c0) OVER (PARTITION BY c1 ORDER BY c2 ASC NULLS LAST {}) FROM (tmp)", + "SELECT c0, c1, c2, first_value(c0) OVER (PARTITION BY c1 ORDER BY c2 ASC NULLS LAST {}) as w0 FROM (tmp)", frameClause)); const auto firstValueFrame = @@ -222,7 +222,7 @@ TEST_F(PrestoQueryRunnerTest, toSql) { EXPECT_EQ( queryRunner->toSql(plan), fmt::format( - "SELECT c0, c1, c2, first_value(c0) OVER (PARTITION BY c1 ORDER BY c2 DESC NULLS FIRST {}), last_value(c0) OVER (PARTITION BY c1 ORDER BY c2 DESC NULLS FIRST {}) FROM (tmp)", + "SELECT c0, c1, c2, first_value(c0) OVER (PARTITION BY c1 ORDER BY c2 DESC NULLS FIRST {}) as w0, last_value(c0) OVER (PARTITION BY c1 ORDER BY c2 DESC NULLS FIRST {}) as w1 FROM (tmp)", firstValueFrame, lastValueFrame)); } diff --git a/velox/exec/tests/PrestoQueryRunnerTimeTransformTest.cpp b/velox/exec/tests/PrestoQueryRunnerTimeTransformTest.cpp new file mode 100644 index 000000000000..0e72861ff20b --- /dev/null +++ b/velox/exec/tests/PrestoQueryRunnerTimeTransformTest.cpp @@ -0,0 +1,113 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/tests/PrestoQueryRunnerIntermediateTypeTransformTestBase.h" + +namespace facebook::velox::exec::test { +namespace { + +class PrestoQueryRunnerTimeTransformTest + : public PrestoQueryRunnerIntermediateTypeTransformTestBase {}; + +// Test that TIME is recognized as an intermediate type that needs +// transformation +TEST_F(PrestoQueryRunnerTimeTransformTest, isIntermediateOnlyType) { + // Core test: TIME should be an intermediate type + ASSERT_TRUE(isIntermediateOnlyType(TIME())); + + // Complex types containing TIME should also be intermediate types + ASSERT_TRUE(isIntermediateOnlyType(ARRAY(TIME()))); + ASSERT_TRUE(isIntermediateOnlyType(MAP(VARCHAR(), TIME()))); + ASSERT_TRUE(isIntermediateOnlyType(MAP(TIME(), VARCHAR()))); + ASSERT_TRUE(isIntermediateOnlyType(ROW({TIME(), BIGINT()}))); +} + +TEST_F(PrestoQueryRunnerTimeTransformTest, roundTrip) { + // Test basic TIME values (no nulls, some nulls, all nulls) + std::vector> no_nulls{0, 3661000, 43200000, 86399999}; + test(makeNullableFlatVector(no_nulls, TIME())); + + std::vector> some_nulls{ + 0, 3661000, std::nullopt, 86399999}; + test(makeNullableFlatVector(some_nulls, TIME())); + + std::vector> all_nulls{ + std::nullopt, std::nullopt, std::nullopt}; + test(makeNullableFlatVector(all_nulls, TIME())); +} + +TEST_F(PrestoQueryRunnerTimeTransformTest, transformArray) { + auto input = makeNullableFlatVector( + std::vector>{ + 0, // 00:00:00.000 + 1000, // 00:00:01.000 + 3661000, // 01:01:01.000 + 43200000, // 12:00:00.000 (noon) + 86399999, // 23:59:59.999 + 3723456, // 01:02:03.456 + 45678901, // 12:41:18.901 + std::nullopt, + 72000000, // 20:00:00.000 + 36000000 // 10:00:00.000 + }, + TIME()); + testArray(input); +} + +TEST_F(PrestoQueryRunnerTimeTransformTest, transformMap) { + // keys can't be null for maps + auto keys = makeNullableFlatVector( + std::vector>{ + 0, // 00:00:00.000 + 3661000, // 01:01:01.000 + 43200000, // 12:00:00.000 + 86399999, // 23:59:59.999 + 36000000, // 10:00:00.000 + 72000000, // 20:00:00.000 + 1800000, // 00:30:00.000 + 7200000, // 02:00:00.000 + 64800000, // 18:00:00.000 + 32400000 // 09:00:00.000 + }, + TIME()); + + auto values = makeNullableFlatVector( + {100, 200, std::nullopt, 400, 500, std::nullopt, 700, 800, 900, 1000}, + BIGINT()); + + testMap(keys, values); +} + +TEST_F(PrestoQueryRunnerTimeTransformTest, transformRow) { + auto input = makeNullableFlatVector( + std::vector>{ + 0, // 00:00:00.000 + 3661000, // 01:01:01.000 + 43200000, // 12:00:00.000 + 86399999, // 23:59:59.999 + std::nullopt, + 36000000, // 10:00:00.000 + 72000000, // 20:00:00.000 + 1800000, // 00:30:00.000 + 7200000, // 02:00:00.000 + 64800000 // 18:00:00.000 + }, + TIME()); + testRow({input}, {"time_col"}); +} + +} // namespace +} // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/PrestoQueryRunnerTimestampWithTimeZoneTransformTest.cpp b/velox/exec/tests/PrestoQueryRunnerTimestampWithTimeZoneTransformTest.cpp index 92445e94cbd9..4b60208686ea 100644 --- a/velox/exec/tests/PrestoQueryRunnerTimestampWithTimeZoneTransformTest.cpp +++ b/velox/exec/tests/PrestoQueryRunnerTimestampWithTimeZoneTransformTest.cpp @@ -15,16 +15,14 @@ */ #include "velox/exec/fuzzer/PrestoQueryRunnerIntermediateTypeTransforms.h" -#include "velox/exec/tests/utils/AssertQueryBuilder.h" -#include "velox/exec/tests/utils/PlanBuilder.h" -#include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h" +#include "velox/exec/tests/PrestoQueryRunnerIntermediateTypeTransformTestBase.h" #include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h" #include "velox/functions/prestosql/types/fuzzer_utils/TimestampWithTimeZoneInputGenerator.h" namespace facebook::velox::exec::test { namespace { class PrestoQueryRunnerTimestampWithTimeZoneTransformTest - : public functions::test::FunctionBaseTest { + : public PrestoQueryRunnerIntermediateTypeTransformTestBase { public: VectorPtr fuzzTimestampWithTimeZone( const size_t seed, @@ -44,51 +42,6 @@ class PrestoQueryRunnerTimestampWithTimeZoneTransformTest return makeNullableFlatVector(values, TIMESTAMP_WITH_TIME_ZONE()); } - - void test(const VectorPtr& vector) { - const auto colName = "col"; - const auto input = - makeRowVector({colName}, {transformIntermediateOnlyType(vector)}); - - auto expr = getIntermediateOnlyTypeProjectionExpr( - vector->type(), - std::make_shared( - colName, - std::nullopt, - std::vector{std::make_shared()}), - colName); - - core::PlanNodePtr plan = - PlanBuilder().values({input}).projectExpressions({expr}).planNode(); - - AssertQueryBuilder(plan).assertResults(makeRowVector({colName}, {vector})); - } - - void testDictionary(const VectorPtr& base) { - // Wrap in a dictionary without nulls. - test(BaseVector::wrapInDictionary( - nullptr, makeIndicesInReverse(100), 100, base)); - // Wrap in a dictionary with some nulls. - test(BaseVector::wrapInDictionary( - makeNulls(100, [](vector_size_t row) { return row % 10 == 0; }), - makeIndicesInReverse(100), - 100, - base)); - // Wrap in a dictionary with all nulls. - test(BaseVector::wrapInDictionary( - makeNulls(100, [](vector_size_t) { return true; }), - makeIndicesInReverse(100), - 100, - base)); - } - - void testConstant(const VectorPtr& base) { - // Test a non-null constant. - test(BaseVector::wrapInConstant(100, 0, base)); - // Test a null constant. - test(BaseVector::createNullConstant( - ARRAY(TIMESTAMP_WITH_TIME_ZONE()), 100, pool_.get())); - } }; TEST_F( @@ -145,8 +98,8 @@ TEST_F( transformIntermediateOnlyTypeTimestampWithTimeZoneArray) { auto elements = fuzzTimestampWithTimeZone(0, 0.1, 1000); auto size = 100; - std::vector offsets(size + 1); - for (int i = 0; i < size + 1; i++) { + std::vector offsets; + for (int i = 0; i < size; i++) { offsets.push_back(i * 10); } // Test array vector no nulls. @@ -177,8 +130,8 @@ TEST_F( auto keys = fuzzTimestampWithTimeZone(0, 0, 1000); auto values = fuzzTimestampWithTimeZone(1, 0.1, 1000); auto size = 100; - std::vector offsets(size + 1); - for (int i = 0; i < size + 1; i++) { + std::vector offsets; + for (int i = 0; i < size; i++) { offsets.push_back(i * 10); } // Test map vector no nulls. @@ -216,20 +169,22 @@ TEST_F( test(vectorMaker_.rowVector({field1, field2, field3})); // Test row vector some nulls. - test(std::make_shared( - pool_.get(), - rowType, - makeNulls(size, [](vector_size_t row) { return row % 10 == 0; }), - size, - std::vector{field1, field2, field3})); + test( + std::make_shared( + pool_.get(), + rowType, + makeNulls(size, [](vector_size_t row) { return row % 10 == 0; }), + size, + std::vector{field1, field2, field3})); // Test row vector all nulls. - test(std::make_shared( - pool_.get(), - rowType, - makeNulls(size, [](vector_size_t) { return true; }), - size, - std::vector{field1, field2, field3})); + test( + std::make_shared( + pool_.get(), + rowType, + makeNulls(size, [](vector_size_t) { return true; }), + size, + std::vector{field1, field2, field3})); const auto base = vectorMaker_.rowVector({field1, field2, field3}); testDictionary(base); diff --git a/velox/exec/tests/PrintPlanWithStatsTest.cpp b/velox/exec/tests/PrintPlanWithStatsTest.cpp index 4064ea156c87..c1dbfc1303b8 100644 --- a/velox/exec/tests/PrintPlanWithStatsTest.cpp +++ b/velox/exec/tests/PrintPlanWithStatsTest.cpp @@ -48,6 +48,12 @@ void compareOutputs( for (; std::getline(iss, line);) { lineCount++; std::vector potentialLines; + if (expectedLineIndex >= expectedRegex.size()) { + ASSERT_FALSE(true) << "Output has more lines than expected." + << "\n Source: " << testName + << "\n Line number: " << lineCount + << "\n Unexpected Line: " << line; + } auto expectedLine = expectedRegex.at(expectedLineIndex++); while (!RE2::FullMatch(line, expectedLine.line)) { potentialLines.push_back(expectedLine.line); @@ -59,11 +65,18 @@ void compareOutputs( << "\n Expected Line one of: " << folly::join(",", potentialLines); } + if (expectedLineIndex >= expectedRegex.size()) { + ASSERT_FALSE(true) + << "Output did not match and no more patterns to check." + << "\n Source: " << testName << "\n Line number: " << lineCount + << "\n Line: " << line + << "\n Expected Line one of: " << folly::join(",", potentialLines); + } expectedLine = expectedRegex.at(expectedLineIndex++); } } for (int i = expectedLineIndex; i < expectedRegex.size(); i++) { - ASSERT_TRUE(expectedRegex[expectedLineIndex].optional); + ASSERT_TRUE(expectedRegex[i].optional); } } @@ -157,6 +170,7 @@ TEST_F(PrintPlanWithStatsTest, innerJoinWithTableScan) { {" dataSourceLazyCpuNanos[ ]* sum: .+, count: .+, min: .+, max: .+"}, {" dataSourceLazyInputBytes[ ]* sum: .+, count: .+, min: .+, max: .+"}, {" dataSourceLazyWallNanos[ ]* sum: .+, count: 1, min: .+, max: .+"}, + {" driverCpuTimeNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, {" runningAddInputWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, {" runningFinishWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, {" runningGetOutputWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, @@ -164,10 +178,12 @@ TEST_F(PrintPlanWithStatsTest, innerJoinWithTableScan) { {" Output: 2000 rows \\(.+\\), Cpu time: .+, Blocked wall time: .+, Peak memory: .+, Memory allocations: .+, CPU breakdown: B/I/O/F (.+/.+/.+/.+)"}, {" HashBuild: Input: 100 rows \\(.+\\), Output: 0 rows \\(.+\\), Cpu time: .+, Blocked wall time: .+, Peak memory: .+, Memory allocations: .+, Threads: 1, CPU breakdown: B/I/O/F (.+/.+/.+/.+)"}, {" distinctKey0\\s+sum: 101, count: 1, min: 101, max: 101, avg: 101"}, + {" driverCpuTimeNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, {" hashtable.buildWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, {" hashtable.capacity\\s+sum: 200, count: 1, min: 200, max: 200, avg: 200"}, {" hashtable.numDistinct\\s+sum: 100, count: 1, min: 100, max: 100, avg: 100"}, {" hashtable.numRehashes\\s+sum: 1, count: 1, min: 1, max: 1, avg: 1"}, + {" hashtable.vectorHasherMergeCpuNanos\\s+sum: .*, count: 1, min: .*, max: .*, avg: .*"}, {" queuedWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, {" rangeKey0\\s+sum: 200, count: 1, min: 200, max: 200, avg: 200"}, {" runningAddInputWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, @@ -180,6 +196,8 @@ TEST_F(PrintPlanWithStatsTest, innerJoinWithTableScan) { true}, {" blockedWaitForJoinBuildWallNanos\\s+sum: .+, count: 1, min: .+, max: .+", true}, + {" driverCpuTimeNanos\\s+sum: .+, count: 1, min: .+, max: .+", + true}, {" dynamicFiltersProduced\\s+sum: 1, count: 1, min: 1, max: 1, avg: 1", true}, {" queuedWallNanos\\s+sum: .+, count: 1, min: .+, max: .+", true}, @@ -191,38 +209,44 @@ TEST_F(PrintPlanWithStatsTest, innerJoinWithTableScan) { {" runningGetOutputWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, {" -- TableScan\\[2\\]\\[table: hive_table\\] -> c0:INTEGER, c1:BIGINT"}, {" Input: 2000 rows \\(.+\\), Raw Input: 20480 rows \\(.+\\), Output: 2000 rows \\(.+\\), Cpu time: .+, Blocked wall time: .+, Peak memory: .+, Memory allocations: .+, Threads: 1, Splits: 20, DynamicFilter producer plan nodes: 3, CPU breakdown: B/I/O/F (.+/.+/.+/.+)"}, + {" connectorSplitSize[ ]* sum: .+, count: .+, min: .+, max: .+"}, {" dataSourceAddSplitWallNanos[ ]* sum: .+, count: 1, min: .+, max: .+"}, - {" dataSourceReadWallNanos[ ]* sum: .+, count: 1, min: .+, max: .+"}, + {" dataSourceReadWallNanos[ ]* sum: .+, count: .+, min: .+, max: .+"}, + {" driverCpuTimeNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, {" dynamicFiltersAccepted[ ]* sum: 1, count: 1, min: 1, max: 1, avg: 1"}, {" footerBufferOverread[ ]* sum: .+, count: 1, min: .+, max: .+"}, {" ioWaitWallNanos [ ]* sum: .+, count: .+ min: .+, max: .+"}, - {" maxSingleIoWaitWallNanos[ ]*sum: .+, count: 1, min: .+, max: .+"}, {" numPrefetch [ ]* sum: .+, count: 1, min: .+, max: .+"}, {" numRamRead [ ]* sum: 60, count: 1, min: 60, max: 60, avg: 60"}, - {" numStorageRead [ ]* sum: .+, count: 1, min: .+, max: .+"}, {" numStripes[ ]* sum: .+, count: 1, min: .+, max: .+"}, {" overreadBytes[ ]* sum: 0B, count: 1, min: 0B, max: 0B, avg: 0B"}, - {" prefetchBytes [ ]* sum: .+, count: 1, min: .+, max: .+"}, + {" prefetchBytes [ ]* sum: .+, count: .+, min: .+, max: .+"}, + {" preloadSplitPrepareTimeNanos[ ]* sum: .+, count: .+, min: .+, max: .+, avg: .+"}, {" preloadedSplits[ ]+sum: .+, count: .+, min: .+, max: .+", true}, {" processedSplits[ ]+sum: 20, count: 1, min: 20, max: 20, avg: 20"}, {" processedStrides[ ]+sum: 20, count: 1, min: 20, max: 20, avg: 20"}, - {" ramReadBytes [ ]* sum: .+, count: 1, min: .+, max: .+"}, + {" processedUnits [ ]* sum: .+, count: .+, min: .+, max: .+"}, + {" ramReadBytes [ ]* sum: .+, count: .+, min: .+, max: .+"}, {" readyPreloadedSplits[ ]+sum: .+, count: .+, min: .+, max: .+", true}, {" runningAddInputWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, {" runningFinishWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, {" runningGetOutputWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, - {" storageReadBytes [ ]* sum: .+, count: 1, min: .+, max: .+"}, - {" totalRemainingFilterTime\\s+sum: .+, count: .+, min: .+, max: .+"}, + {" storageReadBytes [ ]* sum: .+, count: .+, min: .+, max: .+"}, + {" totalRemainingFilterWallNanos\\s+sum: .+, count: .+, min: .+, max: .+"}, {" totalScanTime [ ]* sum: .+, count: .+, min: .+, max: .+"}, + {" unitLoadNanos[ ]* sum: .+, count: .+, min: .+, max: .+, avg: .+"}, + {" waitForPreloadSplitNanos[ ]* sum: .+, count: .+, min: .+, max: .+, avg: .+"}, {" -- Project\\[1\\]\\[expressions: \\(u_c0:INTEGER, ROW\\[\"c0\"\\]\\), \\(u_c1:BIGINT, ROW\\[\"c1\"\\]\\)\\] -> u_c0:INTEGER, u_c1:BIGINT"}, {" Output: 100 rows \\(.+\\), Cpu time: .+, Blocked wall time: .+, Peak memory: 0B, Memory allocations: .+, Threads: 1, CPU breakdown: B/I/O/F (.+/.+/.+/.+)"}, + {" driverCpuTimeNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, {" runningAddInputWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, {" runningFinishWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, {" runningGetOutputWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, {" -- Values\\[0\\]\\[100 rows in 1 vectors\\] -> c0:INTEGER, c1:BIGINT"}, {" Input: 0 rows \\(.+\\), Output: 100 rows \\(.+\\), Cpu time: .+, Blocked wall time: .+, Peak memory: 0B, Memory allocations: .+, Threads: 1, CPU breakdown: B/I/O/F (.+/.+/.+/.+)"}, + {" driverCpuTimeNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, {" runningAddInputWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, {" runningFinishWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, {" runningGetOutputWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}}); @@ -275,6 +299,7 @@ TEST_F(PrintPlanWithStatsTest, partialAggregateWithTableScan) { {" dataSourceLazyInputBytes\\s+sum: .+, count: .+, min: .+, max: .+"}, {" dataSourceLazyWallNanos\\s+sum: .+, count: .+, min: .+, max: .+"}, {" distinctKey0\\s+sum: .+, count: 1, min: .+, max: .+"}, + {" driverCpuTimeNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, {" hashtable.capacity\\s+sum: (?:1273|1252), count: 1, min: (?:1273|1252), max: (?:1273|1252), avg: (?:1273|1252)"}, {" hashtable.numDistinct\\s+sum: (?:849|835), count: 1, min: (?:849|835), max: (?:849|835), avg: (?:849|835)"}, {" hashtable.numRehashes\\s+sum: 1, count: 1, min: 1, max: 1, avg: 1"}, @@ -285,31 +310,33 @@ TEST_F(PrintPlanWithStatsTest, partialAggregateWithTableScan) { {" runningGetOutputWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, {" -- TableScan\\[0\\]\\[table: hive_table\\] -> c0:BIGINT, c1:INTEGER, c2:SMALLINT, c3:REAL, c4:DOUBLE, c5:VARCHAR"}, {" Input: 10000 rows \\(.+\\), Output: 10000 rows \\(.+\\), Cpu time: .+, Blocked wall time: .+, Peak memory: .+, Memory allocations: .+, Threads: 1, Splits: 1, CPU breakdown: B/I/O/F (.+/.+/.+/.+)"}, + {" connectorSplitSize[ ]* sum: .+, count: .+, min: .+, max: .+"}, {" dataSourceAddSplitWallNanos[ ]* sum: .+, count: 1, min: .+, max: .+"}, - {" dataSourceReadWallNanos[ ]* sum: .+, count: 1, min: .+, max: .+"}, + {" dataSourceReadWallNanos[ ]* sum: .+, count: .+, min: .+, max: .+"}, + {" driverCpuTimeNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, {" footerBufferOverread[ ]* sum: .+, count: 1, min: .+, max: .+"}, {" ioWaitWallNanos [ ]* sum: .+, count: .+ min: .+, max: .+"}, - {" maxSingleIoWaitWallNanos[ ]*sum: .+, count: 1, min: .+, max: .+"}, {" numPrefetch [ ]* sum: .+, count: .+, min: .+, max: .+"}, {" numRamRead [ ]* sum: 7, count: 1, min: 7, max: 7, avg: 7"}, - {" numStorageRead [ ]* sum: .+, count: 1, min: .+, max: .+"}, {" numStripes[ ]* sum: .+, count: 1, min: .+, max: .+"}, {" overreadBytes[ ]* sum: 0B, count: 1, min: 0B, max: 0B, avg: 0B"}, {" prefetchBytes [ ]* sum: .+, count: 1, min: .+, max: .+"}, {" processedSplits [ ]* sum: 1, count: 1, min: 1, max: 1, avg: 1"}, {" processedStrides [ ]* sum: 1, count: 1, min: 1, max: 1, avg: 1"}, + {" processedUnits [ ]* sum: .+, count: .+, min: .+, max: .+"}, {" preloadedSplits[ ]+sum: .+, count: .+, min: .+, max: .+", true}, - {" ramReadBytes [ ]* sum: .+, count: 1, min: .+, max: .+"}, + {" ramReadBytes [ ]* sum: .+, count: .+, min: .+, max: .+"}, {" readyPreloadedSplits[ ]+sum: .+, count: .+, min: .+, max: .+", true}, {" runningAddInputWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, {" runningFinishWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, {" runningGetOutputWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, - {" storageReadBytes [ ]* sum: .+, count: 1, min: .+, max: .+"}, - {" totalRemainingFilterTime\\s+sum: .+, count: .+, min: .+, max: .+"}, - {" totalScanTime [ ]* sum: .+, count: .+, min: .+, max: .+"}}); + {" storageReadBytes [ ]* sum: .+, count: .+, min: .+, max: .+"}, + {" totalRemainingFilterWallNanos\\s+sum: .+, count: .+, min: .+, max: .+"}, + {" totalScanTime [ ]* sum: .+, count: .+, min: .+, max: .+"}, + {" unitLoadNanos[ ]* sum: .+, count: .+, min: .+, max: .+, avg: .+"}}); } } @@ -349,40 +376,44 @@ TEST_F(PrintPlanWithStatsTest, tableWriterWithTableScan) { {" dataSourceLazyCpuNanos\\s+sum: .+, count: .+, min: .+, max: .+"}, {" dataSourceLazyInputBytes\\s+sum: .+, count: .+, min: .+, max: .+"}, {" dataSourceLazyWallNanos\\s+sum: .+, count: .+, min: .+, max: .+"}, + {" driverCpuTimeNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, + {" dwrfWriterCount\\s+sum: .+, count: 1, min: .+, max: .+"}, {" numWrittenFiles\\s+sum: .+, count: 1, min: .+, max: .+"}, {" runningAddInputWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, {" runningFinishWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, {" runningGetOutputWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, {" runningWallNanos\\s+sum: .+, count: 1, min: .+, max: .+, avg: .+"}, {" stripeSize\\s+sum: .+, count: 1, min: .+, max: .+"}, - {" writeIOWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, + {" writeIOWallNanos\\s+sum: .+, count: 1, min: .+, max: .+, avg: .+"}, {R"( -- TableScan\[0\]\[table: hive_table\] -> c0:BIGINT, c1:INTEGER, c2:SMALLINT, c3:REAL, c4:DOUBLE, c5:VARCHAR)"}, {R"( Input: 100 rows \(.+\), Output: 100 rows \(.+\), Cpu time: .+, Blocked wall time: .+, Peak memory: .+, Memory allocations: .+, Threads: 1, Splits: 1, CPU breakdown: B/I/O/F (.+/.+/.+/.+))"}, + {" connectorSplitSize[ ]* sum: .+, count: .+, min: .+, max: .+"}, {" dataSourceAddSplitWallNanos[ ]* sum: .+, count: 1, min: .+, max: .+"}, - {" dataSourceReadWallNanos[ ]* sum: .+, count: 1, min: .+, max: .+"}, + {" dataSourceReadWallNanos[ ]* sum: .+, count: .+, min: .+, max: .+"}, + {" driverCpuTimeNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, {" footerBufferOverread[ ]* sum: .+, count: 1, min: .+, max: .+"}, {" ioWaitWallNanos [ ]* sum: .+, count: .+ min: .+, max: .+"}, - {" maxSingleIoWaitWallNanos[ ]*sum: .+, count: 1, min: .+, max: .+"}, {" numPrefetch [ ]* sum: .+, count: .+, min: .+, max: .+"}, {" numRamRead [ ]* sum: 7, count: 1, min: 7, max: 7, avg: 7"}, - {" numStorageRead [ ]* sum: .+, count: 1, min: .+, max: .+"}, {" numStripes[ ]* sum: .+, count: 1, min: .+, max: .+"}, {" overreadBytes[ ]* sum: 0B, count: 1, min: 0B, max: 0B, avg: 0B"}, {" prefetchBytes [ ]* sum: .+, count: 1, min: .+, max: .+"}, {" processedSplits [ ]* sum: 1, count: 1, min: 1, max: 1, avg: 1"}, {" processedStrides [ ]* sum: 1, count: 1, min: 1, max: 1, avg: 1"}, + {" processedUnits [ ]* sum: .+, count: .+, min: .+, max: .+"}, {" preloadedSplits[ ]+sum: .+, count: .+, min: .+, max: .+", true}, - {" ramReadBytes [ ]* sum: .+, count: 1, min: .+, max: .+"}, + {" ramReadBytes [ ]* sum: .+, count: .+, min: .+, max: .+"}, {" readyPreloadedSplits[ ]+sum: .+, count: .+, min: .+, max: .+", true}, {" runningAddInputWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, {" runningFinishWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, {" runningGetOutputWallNanos\\s+sum: .+, count: 1, min: .+, max: .+"}, - {" storageReadBytes [ ]* sum: .+, count: 1, min: .+, max: .+"}, - {" totalRemainingFilterTime\\s+sum: .+, count: .+, min: .+, max: .+"}, - {" totalScanTime [ ]* sum: .+, count: .+, min: .+, max: .+"}}); + {" storageReadBytes [ ]* sum: .+, count: .+, min: .+, max: .+"}, + {" totalRemainingFilterWallNanos\\s+sum: .+, count: .+, min: .+, max: .+"}, + {" totalScanTime [ ]* sum: .+, count: .+, min: .+, max: .+"}, + {" unitLoadNanos[ ]* sum: .+, count: .+, min: .+, max: .+, avg: .+"}}); } TEST_F(PrintPlanWithStatsTest, taskAPI) { diff --git a/velox/exec/tests/QueryAssertionsTest.cpp b/velox/exec/tests/QueryAssertionsTest.cpp index 21b5d60099a0..eb08c7b14b7e 100644 --- a/velox/exec/tests/QueryAssertionsTest.cpp +++ b/velox/exec/tests/QueryAssertionsTest.cpp @@ -182,7 +182,7 @@ TEST_F(QueryAssertionsTest, singleFloatColumn) { size, [&](auto row) { return row == 302 - ? 2.01 + std::max(kEpsilon, double(6 * FLT_EPSILON)) + ? 2.01 + std::max(Variant::kEpsilon, double(6 * FLT_EPSILON)) : row % 6 + 0.01; }, nullEvery(7)), @@ -282,15 +282,17 @@ TEST_F(QueryAssertionsTest, multiFloatColumnWithUniqueKeys) { makeFlatVector( size, [&](auto row) { - return row == 6 ? 2 + std::max(float(kEpsilon), 6 * FLT_EPSILON) - : row % 4; + return row == 6 + ? 2 + std::max(float(Variant::kEpsilon), 6 * FLT_EPSILON) + : row % 4; }, nullEvery(5)), makeFlatVector( size, [&](auto row) { - return row == 1 ? 1.01 + std::max(kEpsilon, double(3 * FLT_EPSILON)) - : row % 6 + 0.01; + return row == 1 + ? 1.01 + std::max(Variant::kEpsilon, double(3 * FLT_EPSILON)) + : row % 6 + 0.01; }, nullEvery(7)), }); diff --git a/velox/exec/tests/RowContainerTest.cpp b/velox/exec/tests/RowContainerTest.cpp index d7d2622caeff..281ef6e86858 100644 --- a/velox/exec/tests/RowContainerTest.cpp +++ b/velox/exec/tests/RowContainerTest.cpp @@ -108,7 +108,8 @@ class RowContainerTestHelper { RowContainer* const rowContainer_; }; -class RowContainerTest : public exec::test::RowContainerTestBase { +class RowContainerTest : public exec::test::RowContainerTestBase, + public testing::WithParamInterface { protected: static void SetUpTestCase() { memory::MemoryManager::testingSetInstance(memory::MemoryManager::Options{}); @@ -344,7 +345,9 @@ class RowContainerTest : public exec::test::RowContainerTestBase { sum += data.rowSize(row) - data.fixedRowSize(); } auto usage = data.stringAllocator().currentBytes(); - EXPECT_EQ(usage, sum); + if (data.testingRowPointers().empty()) { + EXPECT_EQ(usage, sum); + } } std::vector store( @@ -484,11 +487,12 @@ class RowContainerTest : public exec::test::RowContainerTestBase { const VectorPtr& expected, std::optional flags) { // If no flags provided then it must be the default of {true, true}. - SCOPED_TRACE(fmt::format( - "{}, ascending = {}, nullsFirst = {}", - type->toString(), - flags.has_value() ? flags.value().ascending : true, - flags.has_value() ? flags.value().nullsFirst : true)); + SCOPED_TRACE( + fmt::format( + "{}, ascending = {}, nullsFirst = {}", + type->toString(), + flags.has_value() ? flags.value().ascending : true, + flags.has_value() ? flags.value().nullsFirst : true)); // Set 'isJoinBuild' to true to enable nullable sort key in test. auto rowContainer = makeRowContainer({type}, {type}, false); @@ -570,8 +574,8 @@ class RowContainerTest : public exec::test::RowContainerTestBase { for (auto row : rows) { ASSERT_EQ( expected[index], - rowContainer->equals( - row, rowContainer->columnAt(0), rhsDecoded, index)) + rowContainer->compare( + row, rowContainer->columnAt(0), rhsDecoded, index) == 0) << fmt::format( "Mismatch at index {} with canHandleNulls {}", index, @@ -935,7 +939,7 @@ static int32_t sign(int32_t n) { } } // namespace -TEST_F(RowContainerTest, extractWithNullsAndTargetOffset) { +TEST_P(RowContainerTest, extractWithNullsAndTargetOffset) { constexpr int32_t kNumRows = 100; // The second column must have no nulls in the first batch. auto rowVector1 = makeRowVector({ @@ -961,8 +965,8 @@ TEST_F(RowContainerTest, extractWithNullsAndTargetOffset) { // Create and fill up two row containers from two row vectors. std::vector vecTypes = {BOOLEAN(), VARCHAR(), TINYINT()}; RowTypePtr rowType = VectorMaker::rowType({BOOLEAN(), VARCHAR(), TINYINT()}); - auto data1 = makeRowContainer({}, vecTypes); - auto data2 = makeRowContainer({}, vecTypes); + auto data1 = makeRowContainer({}, vecTypes, true, GetParam()); + auto data2 = makeRowContainer({}, vecTypes, true, GetParam()); for (auto i = 0; i < kNumRows; i++) { data1->newRow(); data2->newRow(); @@ -1045,7 +1049,7 @@ TEST_F(RowContainerTest, storeExtractArrayOfVarchar) { roundTrip(input); } -TEST_F(RowContainerTest, types) { +TEST_P(RowContainerTest, types) { constexpr int32_t kNumRows = 100; auto batch = makeDataset( ROW( @@ -1105,7 +1109,7 @@ TEST_F(RowContainerTest, types) { std::vector dependents; dependents.insert( dependents.begin(), types.begin() + types.size() / 2, types.end()); - auto data = makeRowContainer(keys, dependents); + auto data = makeRowContainer(keys, dependents, true, GetParam()); EXPECT_GT(data->nextOffset(), 0); EXPECT_GT(data->probedFlagOffset(), 0); @@ -1177,11 +1181,13 @@ TEST_F(RowContainerTest, types) { EXPECT_EQ(source->hashValueAt(i), hashes[i]); // Test non-null and nullable variants of equals. if (column < keys.size()) { - EXPECT_TRUE( - data->equals(rows[i], data->columnAt(column), decoded, i)); + EXPECT_EQ( + data->compare(rows[i], data->columnAt(column), decoded, i), + 0); } else if (!columnType->isMap()) { - EXPECT_TRUE( - data->equals(rows[i], data->columnAt(column), decoded, i)); + EXPECT_EQ( + data->compare(rows[i], data->columnAt(column), decoded, i), + 0); } // Non-key map columns are not comparable, as the map keys are not sorted. if (columnType->isMap() && column >= keys.size()) { @@ -1215,7 +1221,7 @@ TEST_F(RowContainerTest, types) { EXPECT_LT(0, free.second); } -TEST_F(RowContainerTest, extractNulls) { +TEST_P(RowContainerTest, extractNulls) { constexpr int32_t kNumRows = 100; auto batch = makeRowVector({ makeFlatVector( @@ -1266,7 +1272,7 @@ TEST_F(RowContainerTest, extractNulls) { ARRAY(INTEGER()), MAP(INTEGER(), INTEGER()), ROW({INTEGER(), INTEGER()})}; - auto data = makeRowContainer({}, rowType); + auto data = makeRowContainer({}, rowType, true, GetParam()); for (int i = 0; i < kNumRows; i++) { data->newRow(); } @@ -1347,11 +1353,11 @@ TEST_F(RowContainerTest, erase) { RowContainerTestHelper(data.get()).checkConsistency(); } -TEST_F(RowContainerTest, initialNulls) { +TEST_P(RowContainerTest, initialNulls) { std::vector keys{INTEGER()}; std::vector dependent{INTEGER()}; // Join build. - auto data = makeRowContainer(keys, dependent, true); + auto data = makeRowContainer(keys, dependent, true, GetParam()); auto row = data->newRow(); auto isNullAt = [](const RowContainer& data, const char* row, int32_t i) { auto column = data.columnAt(i); @@ -1361,7 +1367,7 @@ TEST_F(RowContainerTest, initialNulls) { EXPECT_FALSE(isNullAt(*data, row, 0)); EXPECT_FALSE(isNullAt(*data, row, 1)); // Non-join build. - data = makeRowContainer(keys, dependent, false); + data = makeRowContainer(keys, dependent, false, GetParam()); row = data->newRow(); EXPECT_FALSE(isNullAt(*data, row, 0)); EXPECT_FALSE(isNullAt(*data, row, 1)); @@ -1369,7 +1375,7 @@ TEST_F(RowContainerTest, initialNulls) { TEST_F(RowContainerTest, rowSize) { constexpr int32_t kNumRows = 100; - auto data = makeRowContainer({SMALLINT()}, {VARCHAR()}); + auto data = makeRowContainer({SMALLINT()}, {VARCHAR()}, true); // The layout is expected to be smallint - 6 bytes of padding - 1 byte of bits // - StringView - rowSize - next pointer. The bits are a null flag for the @@ -1405,10 +1411,10 @@ TEST_F(RowContainerTest, rowSize) { EXPECT_EQ(rows, rowsFromContainer); } -TEST_F(RowContainerTest, columnSize) { +TEST_P(RowContainerTest, columnSize) { const uint64_t kNumRows = 1000; - auto rowContainer = - makeRowContainer({BIGINT(), VARCHAR()}, {BIGINT(), VARCHAR()}); + auto rowContainer = makeRowContainer( + {BIGINT(), VARCHAR()}, {BIGINT(), VARCHAR()}, true, GetParam()); VectorFuzzer fuzzer( { @@ -1452,8 +1458,8 @@ TEST_F(RowContainerTest, columnSize) { } } -TEST_F(RowContainerTest, rowSizeWithNormalizedKey) { - auto data = makeRowContainer({SMALLINT()}, {VARCHAR()}); +TEST_P(RowContainerTest, rowSizeWithNormalizedKey) { + auto data = makeRowContainer({SMALLINT()}, {VARCHAR()}, true, GetParam()); data->newRow(); data->disableNormalizedKeys(); data->newRow(); @@ -1469,7 +1475,7 @@ TEST_F(RowContainerTest, estimateRowSize) { // Make a RowContainer with a fixed-length key column and a variable-length // dependent column. - auto rowContainer = makeRowContainer({BIGINT()}, {VARCHAR()}); + auto rowContainer = makeRowContainer({BIGINT()}, {VARCHAR()}, true); EXPECT_FALSE(rowContainer->estimateRowSize().has_value()); // Store rows to the container. @@ -1512,6 +1518,7 @@ TEST_F(RowContainerTest, alignment) { false, true, true, + false, pool_.get()); constexpr int kNumRows = 100; char* rows[kNumRows]; @@ -1679,6 +1686,7 @@ TEST_F(RowContainerTest, probedFlag) { true, // isJoinBuild true, // hasProbedFlag false, // hasNormalizedKey + false, // useListRowIndex pool_.get()); auto input = makeRowVector({ @@ -1868,8 +1876,10 @@ TEST_F(RowContainerTest, unknown) { } for (size_t row = 0; row < size; ++row) { - ASSERT_TRUE(rowContainer->equals( - rows[row], rowContainer->columnAt(0), decoded, row)); + ASSERT_EQ( + rowContainer->compare( + rows[row], rowContainer->columnAt(0), decoded, row), + 0); } { @@ -1892,24 +1902,29 @@ TEST_F(RowContainerTest, unknown) { // Verify compare method with row and decoded vector as input // Sorting a NULL constant Vector doesn't change the Vector, so we just // validate that it runs without throwing an exception. - EXPECT_NO_THROW(std::sort( - indexedRows.begin(), - indexedRows.end(), - [&](const std::pair& l, const std::pair& r) { - return rowContainer->compare( - l.second, rowContainer->columnAt(0), decoded, r.first, {}) < - 0; - })); + EXPECT_NO_THROW( + std::sort( + indexedRows.begin(), + indexedRows.end(), + [&](const std::pair& l, const std::pair& r) { + return rowContainer->compare( + l.second, + rowContainer->columnAt(0), + decoded, + r.first, + {}) < 0; + })); // Verify compareRows method with row as input. // Sorting a NULL constant Vector doesn't change the Vector, so we just // validate that it runs without throwing an exception. - EXPECT_NO_THROW(std::sort( - indexedRows.begin(), - indexedRows.end(), - [&](const std::pair& l, const std::pair& r) { - return rowContainer->compareRows(l.second, r.second) < 0; - })); + EXPECT_NO_THROW( + std::sort( + indexedRows.begin(), + indexedRows.end(), + [&](const std::pair& l, const std::pair& r) { + return rowContainer->compareRows(l.second, r.second) < 0; + })); } TEST_F(RowContainerTest, nans) { @@ -1945,8 +1960,10 @@ TEST_F(RowContainerTest, nans) { // Verify that they are considered equal. for (size_t row = 0; row < size; ++row) { - ASSERT_TRUE(rowContainer->equals( - rows[row], rowContainer->columnAt(0), decoded, row)); + ASSERT_EQ( + rowContainer->compare( + rows[row], rowContainer->columnAt(0), decoded, row), + 0); } ASSERT_EQ(rowContainer->compare(rows[0], rows[1], 0, {}), 0); } @@ -1989,10 +2006,10 @@ TEST_F(RowContainerTest, toString) { EXPECT_EQ( rowContainer->toString(rows[0]), - "{1, summer, 11, 0.10000000149011612, 3 elements starting at 0 {1, 2, 3}}"); + "{1, summer, 11, 0.10000000149011612, {1, 2, 3}}"); EXPECT_EQ( rowContainer->toString(rows[1]), - "{2, fall, 0, 2.3399999141693115, 2 elements starting at 0 {4, 5}}"); + "{2, fall, 0, 2.3399999141693115, {4, 5}}"); EXPECT_EQ( rowContainer->toString(rows[2]), "{3, winter, 12, 123.00299835205078, null}"); @@ -2133,7 +2150,7 @@ DEBUG_ONLY_TEST_F(RowContainerTest, eraseAfterOomStoringString) { rowContainer->eraseRows(folly::Range(rows.data(), numRows)); } -TEST_F(RowContainerTest, hugeIntStoreWithNulls) { +TEST_P(RowContainerTest, hugeIntStoreWithNulls) { constexpr int32_t kNumRows = 100; constexpr int32_t kColumnIndex = 0; @@ -2166,7 +2183,7 @@ TEST_F(RowContainerTest, hugeIntStoreWithNulls) { dictNulls, dictIndices, kNumRows, hugeIntVector); std::vector keys; - auto data = makeRowContainer({HUGEINT()}, {}, false); + auto data = makeRowContainer({HUGEINT()}, {}, false, GetParam()); std::vector rows(kNumRows); for (auto i = 0; i < kNumRows; ++i) { rows[i] = data->newRow(); @@ -2181,9 +2198,9 @@ TEST_F(RowContainerTest, hugeIntStoreWithNulls) { assertEqualVectors(source, extracted); } -TEST_F(RowContainerTest, columnHasNulls) { - auto rowContainer = - makeRowContainer({BIGINT(), BIGINT()}, {BIGINT(), BIGINT()}, false); +TEST_P(RowContainerTest, columnHasNulls) { + auto rowContainer = makeRowContainer( + {BIGINT(), BIGINT()}, {BIGINT(), BIGINT()}, false, GetParam()); for (int i = 0; i < rowContainer->columnTypes().size(); ++i) { ASSERT_FALSE(rowContainer->columnHasNulls(i)); } @@ -2234,7 +2251,7 @@ TEST_F(RowContainerTest, columnHasNulls) { } } -TEST_F(RowContainerTest, store) { +TEST_P(RowContainerTest, store) { const uint64_t kNumRows = 1000; auto rowVectorWithNulls = makeRowVector({ makeFlatVector( @@ -2264,7 +2281,7 @@ TEST_F(RowContainerTest, store) { }); for (auto& rowVector : {rowVectorWithNulls, rowVectorNoNulls}) { auto rowContainer = makeRowContainer( - {BIGINT(), VARCHAR()}, {BIGINT(), ARRAY(BIGINT())}, false); + {BIGINT(), VARCHAR()}, {BIGINT(), ARRAY(BIGINT())}, false, GetParam()); std::vector rows; rows.reserve(kNumRows); @@ -2408,7 +2425,7 @@ TEST_F(RowContainerTest, customComparisonRow) { }); } -TEST_F(RowContainerTest, isNanAt) { +TEST_P(RowContainerTest, isNanAt) { const auto kNan = std::numeric_limits::quiet_NaN(); const auto kNanF = std::numeric_limits::quiet_NaN(); auto rowVector = makeRowVector({ @@ -2419,8 +2436,8 @@ TEST_F(RowContainerTest, isNanAt) { }); const auto kNumRows = rowVector->size(); - auto rowContainer = - makeRowContainer({REAL(), DOUBLE()}, {REAL(), DOUBLE()}, false); + auto rowContainer = makeRowContainer( + {REAL(), DOUBLE()}, {REAL(), DOUBLE()}, false, GetParam()); std::vector rows; rows.reserve(kNumRows); @@ -2653,7 +2670,7 @@ TEST_F(RowContainerTest, rowColumnStats) { EXPECT_EQ(stats.nullCount(), 3); } -TEST_F(RowContainerTest, storeAndCollectColumnStats) { +TEST_P(RowContainerTest, storeAndCollectColumnStats) { const uint64_t kNumRows = 1000; auto rowVector = makeRowVector({ makeFlatVector( @@ -2664,7 +2681,8 @@ TEST_F(RowContainerTest, storeAndCollectColumnStats) { nullEvery(7)), }); - auto rowContainer = makeRowContainer({BIGINT(), VARCHAR()}, {}, false); + auto rowContainer = + makeRowContainer({BIGINT(), VARCHAR()}, {}, false, GetParam()); std::vector rows; rows.reserve(kNumRows); @@ -2709,4 +2727,47 @@ TEST_F(RowContainerTest, storeAndCollectColumnStats) { } } +TEST_F(RowContainerTest, setAllNull) { + std::vector keyTypes = {INTEGER()}; + std::vector accumulators{Accumulator( + true, 8, false, 8, INTEGER(), [](auto, auto) {}, [](auto) {})}; + + auto rowContainer = std::make_unique( + keyTypes, + true, + accumulators, + std::vector{}, + false, + true, + false, + false, + false, + pool_.get()); + + auto row = rowContainer->newRow(); + + auto keyColumn = rowContainer->columnAt(0); + auto accColumn = rowContainer->columnAt(1); + ASSERT_FALSE( + RowContainer::isNullAt(row, keyColumn.nullByte(), keyColumn.nullMask())); + ASSERT_FALSE( + RowContainer::isNullAt(row, accColumn.nullByte(), accColumn.nullMask())); + ASSERT_EQ( + (row[accColumn.initializedByte()] & accColumn.initializedMask()), 0); + + rowContainer->setAllNull(row); + + ASSERT_TRUE( + RowContainer::isNullAt(row, keyColumn.nullByte(), keyColumn.nullMask())); + ASSERT_TRUE( + RowContainer::isNullAt(row, accColumn.nullByte(), accColumn.nullMask())); + ASSERT_EQ( + (row[accColumn.initializedByte()] & accColumn.initializedMask()), 0); +} + +VELOX_INSTANTIATE_TEST_SUITE_P( + RowContainerTest, + RowContainerTest, + testing::ValuesIn({false, true})); + } // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/RowNumberTest.cpp b/velox/exec/tests/RowNumberTest.cpp index 84e272d42739..7a00000f76a9 100644 --- a/velox/exec/tests/RowNumberTest.cpp +++ b/velox/exec/tests/RowNumberTest.cpp @@ -218,11 +218,12 @@ TEST_F(RowNumberTest, spill) { core::QueryConfig::kSpillNumPartitionBits, testData.spillPartitionBits) .queryCtx(queryCtx) - .plan(PlanBuilder() - .values(vectors) - .rowNumber({"c0"}) - .capturePlanNodeId(rowNumberPlanNodeId) - .planNode()) + .plan( + PlanBuilder() + .values(vectors) + .rowNumber({"c0"}) + .capturePlanNodeId(rowNumberPlanNodeId) + .planNode()) .assertResults( "SELECT *, row_number() over (partition by c0) FROM tmp"); auto taskStats = toPlanStats(task->taskStats()); @@ -404,11 +405,12 @@ DEBUG_ONLY_TEST_F(RowNumberTest, spillOnlyDuringInputOrOutput) { core::QueryConfig::kSpillNumPartitionBits, testData.spillPartitionBits) .queryCtx(queryCtx) - .plan(PlanBuilder() - .values(vectors) - .rowNumber({"c0"}) - .capturePlanNodeId(rowNumberPlanNodeId) - .planNode()) + .plan( + PlanBuilder() + .values(vectors) + .rowNumber({"c0"}) + .capturePlanNodeId(rowNumberPlanNodeId) + .planNode()) .assertResults( "SELECT *, row_number() over (partition by c0) FROM tmp"); auto taskStats = toPlanStats(task->taskStats()); @@ -488,11 +490,12 @@ DEBUG_ONLY_TEST_F(RowNumberTest, recursiveSpill) { .config( core::QueryConfig::kMaxSpillLevel, testData.maxSpillLevel - 1) .queryCtx(queryCtx) - .plan(PlanBuilder() - .values(vectors) - .rowNumber({"c0"}) - .capturePlanNodeId(rowNumberPlanNodeId) - .planNode()) + .plan( + PlanBuilder() + .values(vectors) + .rowNumber({"c0"}) + .capturePlanNodeId(rowNumberPlanNodeId) + .planNode()) .assertResults( "SELECT *, row_number() over (partition by c0) FROM tmp"); auto taskStats = toPlanStats(task->taskStats()); @@ -550,11 +553,12 @@ TEST_F(RowNumberTest, spillWithYield) { core::QueryConfig::kDriverCpuTimeSliceLimitMs, testData.cpuTimeSliceLimitMs) .queryCtx(queryCtx) - .plan(PlanBuilder() - .values(vectors) - .rowNumber({"c0"}) - .capturePlanNodeId(rowNumberPlanNodeId) - .planNode()) + .plan( + PlanBuilder() + .values(vectors) + .rowNumber({"c0"}) + .capturePlanNodeId(rowNumberPlanNodeId) + .planNode()) .assertResults( "SELECT *, row_number() over (partition by c0) FROM tmp"); auto taskStats = toPlanStats(task->taskStats()); diff --git a/velox/exec/tests/ScaleWriterLocalPartitionTest.cpp b/velox/exec/tests/ScaleWriterLocalPartitionTest.cpp index 43a07e7c2ddc..011d3876311d 100644 --- a/velox/exec/tests/ScaleWriterLocalPartitionTest.cpp +++ b/velox/exec/tests/ScaleWriterLocalPartitionTest.cpp @@ -81,6 +81,10 @@ class TestExchangeController { if (holdBufferBytes_ == 0) { return; } + if (holdBuffer_ != nullptr) { + return; + } + holdPool_ = pool; holdBuffer_ = holdPool_->allocate(holdBufferBytes_); } @@ -227,11 +231,6 @@ class FakeSourceOperator : public SourceOperator { private: void initialize() override { Operator::initialize(); - - if (operatorCtx_->driverCtx()->driverId != 0) { - return; - } - testController_->maybeHoldBuffer(pool()); } @@ -495,10 +494,11 @@ class ScaleWriterLocalPartitionTest : public HiveConnectorTestBase { for (const auto& name : rowType_->names()) { orderByKeys.push_back(fmt::format("{} ASC NULLS FIRST", name)); } - AssertQueryBuilder queryBuilder(PlanBuilder() - .values(inputVectors) - .orderBy(orderByKeys, false) - .planNode()); + AssertQueryBuilder queryBuilder( + PlanBuilder() + .values(inputVectors) + .orderBy(orderByKeys, false) + .planNode()); return queryBuilder.copyResults(pool_.get()); } @@ -812,11 +812,10 @@ TEST_F(ScaleWriterLocalPartitionTest, partitionBasic) { {4, 4, 4, 1ULL << 30, 1.0, 0, {1, 2}, 0.8, 0.6, false}, {1, 4, 4, 0, 1.0, 0, {1, 2}, 0.3, 0.2, false}, {4, 4, 4, 0, 1.0, 0, {1, 2}, 0.3, 0.2, false}, - {1, 4, 4, 0, 0.1, queryCapacity / 2, {1, 2}, 0.8, 0.6, false}, - {4, 4, 4, 0, 0.1, queryCapacity / 2, {1, 2}, 0.8, 0.6, false}, + {1, 4, 4, 0, 0.0001, queryCapacity / 2, {1, 2}, 0.8, 0.6, false}, + {4, 4, 4, 0, 0.0001, queryCapacity / 2, {1, 2}, 0.8, 0.6, false}, {1, 32, 128, 0, 1.0, 0, {1, 2, 3, 4, 5, 6, 7, 8}, 0.8, 0.6, true}, - {4, 32, 128, 0, 1.0, 0, {1, 2, 3, 4, 5, 6, 7, 8}, 0.8, 0.6, true}, - }; + {4, 32, 128, 0, 1.0, 0, {1, 2, 3, 4, 5, 6, 7, 8}, 0.8, 0.6, true}}; for (const auto& testData : testSettings) { SCOPED_TRACE(testData.debugString()); @@ -888,17 +887,25 @@ TEST_F(ScaleWriterLocalPartitionTest, partitionBasic) { .copyResults(pool_.get(), task); auto planStats = toPlanStats(task->taskStats()); if (testData.expectedRebalance) { - ASSERT_EQ( + // NOTE: thre is time race across multiple drivers on the shared balancer + // stats reporting which can't totally avoid in tests under some extreme + // conditions like the rebalance happens on the last close driver. + ASSERT_LE( planStats.at(exchnangeNodeId) .customStats.count( ScaleWriterPartitioningLocalPartition::kScaledPartitions), 1); - ASSERT_GT( - planStats.at(exchnangeNodeId) - .customStats - .at(ScaleWriterPartitioningLocalPartition::kScaledPartitions) - .sum, - 0); + if (planStats.at(exchnangeNodeId) + .customStats.count( + ScaleWriterPartitioningLocalPartition::kScaledPartitions) == + 1) { + ASSERT_GT( + planStats.at(exchnangeNodeId) + .customStats + .at(ScaleWriterPartitioningLocalPartition::kScaledPartitions) + .sum, + 0); + } ASSERT_EQ( planStats.at(exchnangeNodeId) .customStats.count( diff --git a/velox/exec/tests/SerializedPageSpillerTest.cpp b/velox/exec/tests/SerializedPageSpillerTest.cpp deleted file mode 100644 index 7c97c5b3d543..000000000000 --- a/velox/exec/tests/SerializedPageSpillerTest.cpp +++ /dev/null @@ -1,467 +0,0 @@ -/* - * Copyright (c) Facebook, Inc. and its affiliates. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "velox/exec/SerializedPageSpiller.h" -#include "velox/common/base/tests/GTestUtils.h" -#include "velox/exec/tests/utils/OperatorTestBase.h" -#include "velox/exec/tests/utils/SerializedPageUtil.h" -#include "velox/exec/tests/utils/TempDirectoryPath.h" -#include "velox/exec/tests/utils/TempFilePath.h" - -using namespace facebook::velox; -using namespace facebook::velox::exec; -using namespace facebook::velox::exec::test; -namespace facebook::velox::exec::test { -class SerializedPageSpillerHelper { - public: - SerializedPageSpillerHelper(SerializedPageSpiller& spiller) - : spiller_(spiller) {} - - void checkSpillerConsistency() { - if (spiller_.totalPages_ == 0) { - ASSERT_EQ(spiller_.totalBytes_, 0); - ASSERT_EQ(spiller_.bufferStream_, nullptr); - } - if (spiller_.totalBytes_ > 0) { - ASSERT_GT(spiller_.totalPages_, 0); - ASSERT_GT(spiller_.bufferStream_->tellp(), spiller_.totalBytes_); - } - } - - private: - SerializedPageSpiller const& spiller_; -}; - -class SerializedPageSpillReaderHelper { - public: - SerializedPageSpillReaderHelper(SerializedPageSpillReader& reader) - : reader_(reader) {} - - void checkReaderConsistency() { - ASSERT_GE(reader_.numPages_, reader_.bufferedPages_.size()); - } - - void assertNumBufferedPages(uint64_t numPages) { - ASSERT_EQ(reader_.bufferedPages_.size(), numPages); - } - - private: - SerializedPageSpillReader const& reader_; -}; - -class SerializedPageSpillerTest : public exec::test::OperatorTestBase { - public: - void SetUp() override { - OperatorTestBase::SetUp(); - filesystems::registerLocalFileSystem(); - rng_.seed(0); - } - - protected: - std::vector> generateData( - uint32_t numPages, - int64_t maxPageSize, - bool hasVoidedNumRows, - int64_t maxNumRows) { - std::vector> pages; - pages.reserve(numPages); - for (auto i = 0; i < numPages; ++i) { - auto iobufBytes = folly::Random().rand64(maxPageSize, rng_); - - // Setup a chained iobuf. - std::unique_ptr iobuf; - if (iobufBytes > 1) { - auto firstHalfBytes = iobufBytes / 2; - iobuf = folly::IOBuf::create(firstHalfBytes); - std::memset(iobuf->writableData(), 'x', firstHalfBytes); - iobuf->append(firstHalfBytes); - - auto secondHalfBytes = iobufBytes - firstHalfBytes; - auto secondHalfBuf = folly::IOBuf::create(secondHalfBytes); - std::memset(secondHalfBuf->writableData(), 'y', secondHalfBytes); - secondHalfBuf->append(secondHalfBytes); - iobuf->prependChain(std::move(secondHalfBuf)); - } else { - iobuf = folly::IOBuf::create(iobufBytes); - std::memset(iobuf->writableData(), 'x', iobufBytes); - iobuf->append(iobufBytes); - } - - std::optional numRowsOpt; - if (!hasVoidedNumRows || folly::Random().oneIn(2, rng_)) { - numRowsOpt = std::optional(folly::Random().rand64(maxNumRows, rng_)); - } - pages.push_back(std::make_shared( - std::move(iobuf), nullptr, numRowsOpt)); - } - return pages; - } - - void checkIOBufsEqual( - std::unique_ptr& buf1, - std::unique_ptr& buf2) { - auto coalescedBuf1 = buf1->coalesce(); - auto coalescedBuf2 = buf2->coalesce(); - ASSERT_EQ(coalescedBuf1.size(), coalescedBuf2.size()); - ASSERT_EQ( - std::memcmp( - coalescedBuf1.data(), coalescedBuf2.data(), coalescedBuf1.size()), - 0); - } - - void checkSerializedPageEqual(SerializedPage& page1, SerializedPage& page2) { - ASSERT_EQ(page1.numRows().has_value(), page2.numRows().has_value()); - if (page1.numRows().has_value()) { - ASSERT_EQ(page1.numRows().value(), page1.numRows().value()); - } - auto buf1 = page1.getIOBuf(); - auto buf2 = page2.getIOBuf(); - checkIOBufsEqual(buf1, buf2); - } - - folly::Random::DefaultGenerator rng_; - common::UpdateAndCheckSpillLimitCB updateAndCheckSpillLimitCB_{ - [](uint64_t /* unused */) {}}; -}; -} // namespace facebook::velox::exec::test - -TEST_F(SerializedPageSpillerTest, pageSpillerBasic) { - auto pool = rootPool_->addLeafChild("destinationBufferSpiller"); - - struct TestValue { - uint32_t numPages; - int64_t maxPageSize; - bool hasVoidedNumRows; - int64_t maxNumRows; - uint64_t readBufferSize; - uint64_t writeBufferSize; - uint64_t targetFileSize; - - std::string debugString() const { - return fmt::format( - "numPages {}, maxPageSize {}, hasVoidedNumRows {}, maxNumRows {}, " - "readBufferSize {}, writeBufferSize {}, targetFileSize {}", - numPages, - maxPageSize, - hasVoidedNumRows, - maxNumRows, - readBufferSize, - writeBufferSize, - targetFileSize); - } - }; - - std::vector testValues{ - {10, 64, true, 20, 1024, 1024, 2048}, - {10, 64, false, 20, 1024, 0, 2048}, - {0, 64, true, 20, 1024, 256, 2048}, - {10, 64, true, 20, 1, 2048, 128}, - {10, 0, true, 20, 128, 2048, 128}}; - - for (const auto& testValue : testValues) { - SCOPED_TRACE(testValue.debugString()); - - auto tempFile = exec::test::TempFilePath::create(); - const auto& prefixPath = tempFile->getPath(); - auto fs = filesystems::getFileSystem(prefixPath, {}); - SCOPE_EXIT { - fs->remove(prefixPath); - }; - - auto pages = generateData( - testValue.numPages, - testValue.maxPageSize, - testValue.hasVoidedNumRows, - testValue.maxNumRows); - - folly::Synchronized spillStats; - SerializedPageSpiller spiller( - testValue.writeBufferSize, - testValue.targetFileSize, - prefixPath, - "", - updateAndCheckSpillLimitCB_, - pool.get(), - &spillStats); - SerializedPageSpillerHelper spillerHelper(spiller); - spillerHelper.checkSpillerConsistency(); - - spiller.spill(pages); - spillerHelper.checkSpillerConsistency(); - - auto spillResults = spiller.finishSpill(); - - SerializedPageSpillReader reader( - std::move(spillResults), - testValue.readBufferSize, - pool.get(), - &spillStats); - SerializedPageSpillReaderHelper readerHelper(reader); - - ASSERT_EQ(reader.empty(), pages.empty()); - ASSERT_EQ(reader.numPages(), pages.size()); - - VELOX_ASSERT_THROW(reader.at(pages.size()), ""); - readerHelper.checkReaderConsistency(); - - VELOX_ASSERT_THROW(reader.deleteFront(pages.size() + 1), ""); - readerHelper.checkReaderConsistency(); - - if (pages.empty()) { - ASSERT_TRUE(reader.empty()); - continue; - } - - ASSERT_FALSE(reader.empty()); - uint32_t i = 0; - while (!reader.empty()) { - if (reader.at(0) != nullptr) { - ASSERT_EQ(reader.at(0)->size(), pages[i]->size()); - } else { - ASSERT_EQ(pages[i], nullptr); - } - readerHelper.checkReaderConsistency(); - - auto unspilledPage = reader.at(0); - readerHelper.checkReaderConsistency(); - - ASSERT_LT(i, pages.size()); - ASSERT_EQ(unspilledPage->numRows(), pages[i]->numRows()); - ASSERT_EQ(unspilledPage->size(), pages[i]->size()); - auto originalIOBuf = pages[i]->getIOBuf(); - auto unspilledIOBuf = unspilledPage->getIOBuf(); - checkIOBufsEqual(originalIOBuf, unspilledIOBuf); - if (testValue.maxPageSize == 0) { - ASSERT_GE(pool->usedBytes(), 0); - } else { - ASSERT_GT(pool->usedBytes(), 0); - } - reader.deleteFront(1); - ++i; - } - ASSERT_EQ(i, pages.size()); - ASSERT_TRUE(reader.empty()); - ASSERT_EQ(reader.numPages(), 0); - - VELOX_ASSERT_THROW(reader.at(0), ""); - readerHelper.checkReaderConsistency(); - - VELOX_ASSERT_THROW(reader.deleteFront(1), ""); - readerHelper.checkReaderConsistency(); - - ASSERT_NO_THROW(reader.deleteAll()); - readerHelper.checkReaderConsistency(); - } - ASSERT_EQ(pool->usedBytes(), 0); -} - -TEST_F(SerializedPageSpillerTest, spillReaderAccessors) { - auto pool = rootPool_->addLeafChild("spillReaderAccessors"); - auto pages = generateData(20, 1LL << 20, true, 1000); - - struct TestValue { - std::string testName; - std::function>&, - SerializedPageSpillReader&, - uint64_t)> - accessorVerifier; - std::string debugString() { - return testName; - } - }; - - std::vector testValues{ - {"SerializedPageSpillReader::at", - [this](auto& originalPages, auto& reader, auto index) { - // Accessor verifier for SerializedPageSpillReader::at() - if (index >= originalPages.size()) { - VELOX_ASSERT_THROW(reader.at(index), ""); - return; - } - auto originalPage = originalPages[index]; - auto unspilledPage = reader.at(index); - if (originalPage == nullptr) { - ASSERT_EQ(unspilledPage, nullptr); - return; - } - ASSERT_EQ(originalPage->size(), unspilledPage->size()); - ASSERT_EQ(originalPage->numRows(), unspilledPage->numRows()); - auto originalIOBuf = originalPage->getIOBuf(); - auto unspilledIOBuf = unspilledPage->getIOBuf(); - checkIOBufsEqual(originalIOBuf, unspilledIOBuf); - }}}; - - for (auto& testValue : testValues) { - SCOPED_TRACE(testValue.debugString()); - auto tempFile = exec::test::TempFilePath::create(); - const auto& prefixPath = tempFile->getPath(); - auto fs = filesystems::getFileSystem(prefixPath, {}); - SCOPE_EXIT { - fs->remove(prefixPath); - }; - - folly::Synchronized spillStats; - SerializedPageSpiller spiller( - 1024, - 2048, - prefixPath, - "", - updateAndCheckSpillLimitCB_, - pool.get(), - &spillStats); - SerializedPageSpillerHelper spillerHelper(spiller); - spiller.spill(pages); - spillerHelper.checkSpillerConsistency(); - auto spillResults = spiller.finishSpill(); - - SerializedPageSpillReader reader( - std::move(spillResults), 1024, pool.get(), &spillStats); - SerializedPageSpillReaderHelper readerHelper(reader); - - testValue.accessorVerifier(pages, reader, 1); - readerHelper.checkReaderConsistency(); - readerHelper.assertNumBufferedPages(2); - - testValue.accessorVerifier(pages, reader, 10); - readerHelper.checkReaderConsistency(); - readerHelper.assertNumBufferedPages(11); - - testValue.accessorVerifier(pages, reader, 25); - readerHelper.checkReaderConsistency(); - readerHelper.assertNumBufferedPages(11); - - testValue.accessorVerifier(pages, reader, 19); - readerHelper.checkReaderConsistency(); - readerHelper.assertNumBufferedPages(20); - - testValue.accessorVerifier(pages, reader, 5); - readerHelper.checkReaderConsistency(); - readerHelper.assertNumBufferedPages(20); - } - ASSERT_EQ(pool->usedBytes(), 0); -} - -TEST_F(SerializedPageSpillerTest, spillReaderDelete) { - auto pool = rootPool_->addLeafChild("spillReaderDelete"); - const auto kNumPages = 20; - auto pages = generateData(kNumPages, 1LL << 20, true, 1000); - - struct TestValue { - uint32_t numBufferedPages; - uint32_t numDelete; - - std::string debugString() { - return fmt::format( - "numBufferedPages {}, numDelete {}", numBufferedPages, numDelete); - } - }; - - std::vector testValues{ - {0, 0}, - {0, 10}, - {0, 20}, - {0, 25}, - {10, 0}, - {10, 5}, - {10, 15}, - {10, 20}, - {10, 25}}; - for (auto& testValue : testValues) { - SCOPED_TRACE(testValue.debugString()); - // Test delete front. - auto tempFile = exec::test::TempFilePath::create(); - const auto& prefixPath = tempFile->getPath(); - auto fs = filesystems::getFileSystem(prefixPath, {}); - SCOPE_EXIT { - fs->remove(prefixPath); - }; - - folly::Synchronized spillStats; - SerializedPageSpiller spiller( - 1024, - 2048, - prefixPath, - "", - updateAndCheckSpillLimitCB_, - pool.get(), - &spillStats); - SerializedPageSpillerHelper spillerHelper(spiller); - spiller.spill(pages); - spillerHelper.checkSpillerConsistency(); - auto spillResult = spiller.finishSpill(); - - SerializedPageSpillReader reader( - std::move(spillResult), 1024, pool.get(), &spillStats); - SerializedPageSpillReaderHelper readerHelper(reader); - - // Unspill pages to buffer - if (testValue.numBufferedPages > 0) { - reader.at(testValue.numBufferedPages - 1); - } - readerHelper.checkReaderConsistency(); - readerHelper.assertNumBufferedPages(testValue.numBufferedPages); - - if (testValue.numDelete > kNumPages) { - VELOX_ASSERT_THROW(reader.deleteFront(testValue.numDelete), ""); - readerHelper.checkReaderConsistency(); - readerHelper.assertNumBufferedPages(testValue.numBufferedPages); - continue; - } else { - reader.deleteFront(testValue.numDelete); - } - readerHelper.checkReaderConsistency(); - if (testValue.numDelete <= testValue.numBufferedPages) { - readerHelper.assertNumBufferedPages( - testValue.numBufferedPages - testValue.numDelete); - } else { - readerHelper.assertNumBufferedPages(0); - } - } - - { - // Test delete all. - auto tempFile = exec::test::TempFilePath::create(); - const auto& prefixPath = tempFile->getPath(); - auto fs = filesystems::getFileSystem(prefixPath, {}); - SCOPE_EXIT { - fs->remove(prefixPath); - }; - - folly::Synchronized spillStats; - SerializedPageSpiller spiller( - 1024, - 2048, - prefixPath, - "", - updateAndCheckSpillLimitCB_, - pool.get(), - &spillStats); - SerializedPageSpillerHelper spillerHelper(spiller); - spiller.spill(pages); - spillerHelper.checkSpillerConsistency(); - auto spillResult = spiller.finishSpill(); - - SerializedPageSpillReader reader( - std::move(spillResult), 1024, pool.get(), &spillStats); - SerializedPageSpillReaderHelper readerHelper(reader); - - reader.at(10); - reader.deleteAll(); - readerHelper.checkReaderConsistency(); - readerHelper.assertNumBufferedPages(0); - } -} diff --git a/velox/exec/tests/SimpleAverageAggregate.cpp b/velox/exec/tests/SimpleAverageAggregate.cpp index fea9254cb1cf..ea710380b3b6 100644 --- a/velox/exec/tests/SimpleAverageAggregate.cpp +++ b/velox/exec/tests/SimpleAverageAggregate.cpp @@ -34,9 +34,8 @@ class AverageAggregate { using InputType = Row; // Type of intermediate result vector wrapped in Row. - using IntermediateType = - Row; + using IntermediateType = Row; // Type of output vector. using OutputType = @@ -102,18 +101,20 @@ exec::AggregateRegistrationResult registerSimpleAverageAggregate( std::vector> signatures; for (const auto& inputType : {"smallint", "integer", "bigint", "double"}) { - signatures.push_back(exec::AggregateFunctionSignatureBuilder() - .returnType("double") - .intermediateType("row(double,bigint)") - .argumentType(inputType) - .build()); + signatures.push_back( + exec::AggregateFunctionSignatureBuilder() + .returnType("double") + .intermediateType("row(double,bigint)") + .argumentType(inputType) + .build()); } - signatures.push_back(exec::AggregateFunctionSignatureBuilder() - .returnType("real") - .intermediateType("row(double,bigint)") - .argumentType("real") - .build()); + signatures.push_back( + exec::AggregateFunctionSignatureBuilder() + .returnType("real") + .intermediateType("row(double,bigint)") + .argumentType("real") + .build()); return exec::registerAggregateFunction( name, diff --git a/velox/exec/tests/SortBufferTest.cpp b/velox/exec/tests/SortBufferTest.cpp index 93bf32edf4df..80774a2d2613 100644 --- a/velox/exec/tests/SortBufferTest.cpp +++ b/velox/exec/tests/SortBufferTest.cpp @@ -15,6 +15,7 @@ */ #include "velox/exec/SortBuffer.h" +#include #include #include "velox/common/base/tests/GTestUtils.h" @@ -89,6 +90,7 @@ class SortBufferTest : public OperatorTestBase, 0, 0, "none", + 0, spillPrefixSortConfig); } @@ -118,7 +120,9 @@ class SortBufferTest : public OperatorTestBase, const std::shared_ptr executor_{ std::make_shared( - std::thread::hardware_concurrency())}; + folly::hardware_concurrency())}; + const std::shared_ptr fuzzerPool_ = + memory::memoryManager()->addLeafPool("SortBufferTest"); tsan_atomic nonReclaimableSection_{false}; folly::Random::DefaultGenerator rng_; @@ -335,13 +339,10 @@ TEST_P(SortBufferTest, DISABLED_randomData) { &nonReclaimableSection_, prefixSortConfig_); - const std::shared_ptr fuzzerPool = - memory::memoryManager()->addLeafPool("VectorFuzzer"); - std::vector inputVectors; inputVectors.reserve(3); for (size_t inputRows : {1000, 1000, 1000}) { - VectorFuzzer fuzzer({.vectorSize = inputRows}, fuzzerPool.get()); + VectorFuzzer fuzzer({.vectorSize = inputRows}, fuzzerPool_.get()); RowVectorPtr input = fuzzer.fuzzRow(inputType_); sortBuffer->addInput(input); inputVectors.push_back(input); @@ -401,6 +402,7 @@ TEST_P(SortBufferTest, batchOutput) { 0, 0, "none", + 0, prefixSortConfig_); folly::Synchronized spillStats; auto sortBuffer = std::make_unique( @@ -413,15 +415,11 @@ TEST_P(SortBufferTest, batchOutput) { testData.triggerSpill ? &spillConfig : nullptr, &spillStats); ASSERT_EQ(sortBuffer->canSpill(), testData.triggerSpill); - - const std::shared_ptr fuzzerPool = - memory::memoryManager()->addLeafPool("VectorFuzzer"); - std::vector inputVectors; inputVectors.reserve(testData.numInputRows.size()); uint64_t totalNumInput = 0; for (size_t inputRows : testData.numInputRows) { - VectorFuzzer fuzzer({.vectorSize = inputRows}, fuzzerPool.get()); + VectorFuzzer fuzzer({.vectorSize = inputRows}, fuzzerPool_.get()); RowVectorPtr input = fuzzer.fuzzRow(inputType_); sortBuffer->addInput(input); inputVectors.push_back(input); @@ -498,6 +496,7 @@ TEST_P(SortBufferTest, spill) { 0, 0, "none", + 0, prefixSortConfig_); folly::Synchronized spillStats; auto sortBuffer = std::make_unique( @@ -510,9 +509,7 @@ TEST_P(SortBufferTest, spill) { testData.spillEnabled ? &spillConfig : nullptr, &spillStats); - const std::shared_ptr fuzzerPool = - memory::memoryManager()->addLeafPool("spillSource"); - VectorFuzzer fuzzer({.vectorSize = 1024}, fuzzerPool.get()); + VectorFuzzer fuzzer({.vectorSize = 1024}, fuzzerPool_.get()); uint64_t totalNumInput = 0; ASSERT_EQ(memory::spillMemoryPool()->stats().usedBytes, 0); @@ -595,9 +592,7 @@ DEBUG_ONLY_TEST_P(SortBufferTest, spillDuringInput) { ASSERT_EQ(sortBuffer->pool()->usedBytes(), 0); }))); - const std::shared_ptr fuzzerPool = - memory::memoryManager()->addLeafPool("spillDuringInput"); - VectorFuzzer fuzzer({.vectorSize = 1024}, fuzzerPool.get()); + VectorFuzzer fuzzer({.vectorSize = 1024}, fuzzerPool_.get()); ASSERT_EQ(memory::spillMemoryPool()->stats().usedBytes, 0); const auto peakSpillMemoryUsage = @@ -621,6 +616,7 @@ DEBUG_ONLY_TEST_P(SortBufferTest, spillDuringInput) { ASSERT_GE( memory::spillMemoryPool()->stats().peakBytes, peakSpillMemoryUsage); } + ASSERT_EQ(pool_->usedBytes(), 0); } DEBUG_ONLY_TEST_P(SortBufferTest, spillDuringOutput) { @@ -645,10 +641,7 @@ DEBUG_ONLY_TEST_P(SortBufferTest, spillDuringOutput) { sortBuffer->spill(); ASSERT_EQ(sortBuffer->pool()->usedBytes(), 0); }))); - - const std::shared_ptr fuzzerPool = - memory::memoryManager()->addLeafPool("spillDuringOutput"); - VectorFuzzer fuzzer({.vectorSize = 1024}, fuzzerPool.get()); + VectorFuzzer fuzzer({.vectorSize = 1024}, fuzzerPool_.get()); ASSERT_EQ(memory::spillMemoryPool()->stats().usedBytes, 0); const auto peakSpillMemoryUsage = @@ -690,10 +683,7 @@ DEBUG_ONLY_TEST_P(SortBufferTest, reserveMemorySortGetOutput) { prefixSortConfig_, spillEnabled ? &spillConfig : nullptr, &spillStats); - - const std::shared_ptr fuzzerPool = - memory::memoryManager()->addLeafPool("reserveMemoryGetOutput"); - VectorFuzzer fuzzer({.vectorSize = 1024}, fuzzerPool.get()); + VectorFuzzer fuzzer({.vectorSize = 1024}, fuzzerPool_.get()); const int numInputs{10}; for (int i = 0; i < numInputs; ++i) { @@ -735,8 +725,11 @@ DEBUG_ONLY_TEST_P(SortBufferTest, reserveMemorySort) { } testSettings[] = {{false, true}, {true, false}, {true, true}}; for (const auto [usePrefixSort, spillEnabled] : testSettings) { - SCOPED_TRACE(fmt::format( - "usePrefixSort: {}, spillEnabled: {}, ", usePrefixSort, spillEnabled)); + SCOPED_TRACE( + fmt::format( + "usePrefixSort: {}, spillEnabled: {}, ", + usePrefixSort, + spillEnabled)); auto spillDirectory = exec::test::TempDirectoryPath::create(); auto spillConfig = getSpillConfig(spillDirectory->getPath(), usePrefixSort); folly::Synchronized spillStats; @@ -777,9 +770,6 @@ DEBUG_ONLY_TEST_P(SortBufferTest, reserveMemorySort) { } TEST_P(SortBufferTest, emptySpill) { - const std::shared_ptr fuzzerPool = - memory::memoryManager()->addLeafPool("emptySpillSource"); - for (bool hasPostSpillData : {false, true}) { SCOPED_TRACE(fmt::format("hasPostSpillData {}", hasPostSpillData)); auto spillDirectory = exec::test::TempDirectoryPath::create(); @@ -797,7 +787,7 @@ TEST_P(SortBufferTest, emptySpill) { sortBuffer->spill(); if (hasPostSpillData) { - VectorFuzzer fuzzer({.vectorSize = 1024}, fuzzerPool.get()); + VectorFuzzer fuzzer({.vectorSize = 1024}, fuzzerPool_.get()); sortBuffer->addInput(fuzzer.fuzzRow(inputType_)); } sortBuffer->noMoreInput(); diff --git a/velox/exec/tests/SpatialIndexTest.cpp b/velox/exec/tests/SpatialIndexTest.cpp new file mode 100644 index 000000000000..e2a7778cd52f --- /dev/null +++ b/velox/exec/tests/SpatialIndexTest.cpp @@ -0,0 +1,550 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/SpatialIndex.h" +#include +#include + +using namespace ::testing; +using namespace facebook::velox::exec; + +namespace facebook::velox::exec::test { + +class SpatialIndexTest : public virtual testing::Test { + protected: + void makeIndex( + std::vector envelopes, + uint32_t branchSize = SpatialIndex::kDefaultRTreeBranchSize) { + branchSize_ = branchSize; + Envelope bounds = Envelope::of(envelopes); + index_ = SpatialIndex(std::move(bounds), std::move(envelopes), branchSize); + } + + Envelope indexBounds() const { + return index_.bounds(); + } + + void assertQuery( + double minX, + double minY, + double maxX, + double maxY, + std::vector expected) const { + std::vector actual = + index_.query(Envelope::from(minX, minY, maxX, maxY)); + std::sort(actual.begin(), actual.end()); + std::sort(expected.begin(), expected.end()); + ASSERT_EQ(actual, expected); + } + + SpatialIndex index_; + uint32_t branchSize_ = SpatialIndex::kDefaultRTreeBranchSize; +}; + +TEST_F(SpatialIndexTest, testEnvelope) { + Envelope empty = Envelope::empty(); + ASSERT_TRUE(empty.isEmpty()); + ASSERT_FALSE(Envelope::intersects(empty, empty)); + + Envelope point = + Envelope{.minX = 0, .minY = 0, .maxX = 0, .maxY = 0, .rowIndex = -1}; + ASSERT_FALSE(point.isEmpty()); + ASSERT_FALSE(Envelope::intersects(empty, point)); + ASSERT_TRUE(Envelope::intersects(point, point)); +} + +TEST_F(SpatialIndexTest, testNaNHandling) { + float nan = std::numeric_limits::quiet_NaN(); + + Envelope envWithNaN{ + .minX = nan, .minY = 0, .maxX = 1, .maxY = 1, .rowIndex = 0}; + ASSERT_TRUE(envWithNaN.isEmpty()); + + Envelope envWithNaN2{ + .minX = 0, .minY = 0, .maxX = nan, .maxY = 1, .rowIndex = 0}; + ASSERT_TRUE(envWithNaN2.isEmpty()); + + Envelope envWithNaN3{ + .minX = 0, .minY = nan, .maxX = 1, .maxY = 1, .rowIndex = 0}; + ASSERT_TRUE(envWithNaN3.isEmpty()); + + Envelope envWithNaN4{ + .minX = 0, .minY = 0, .maxX = 1, .maxY = nan, .rowIndex = 0}; + ASSERT_TRUE(envWithNaN4.isEmpty()); + + Envelope envWithNaN5{ + .minX = nan, .minY = nan, .maxX = nan, .maxY = nan, .rowIndex = 0}; + ASSERT_TRUE(envWithNaN5.isEmpty()); +} + +TEST_F(SpatialIndexTest, testEmptyIndex) { + makeIndex(std::vector{}); + Envelope bounds = indexBounds(); + ASSERT_EQ(bounds.minX, std::numeric_limits::infinity()); + ASSERT_EQ(bounds.minY, std::numeric_limits::infinity()); + ASSERT_EQ(bounds.maxX, -std::numeric_limits::infinity()); + ASSERT_EQ(bounds.maxY, -std::numeric_limits::infinity()); + ASSERT_EQ(bounds.rowIndex, -1); + + assertQuery(0, 0, 1, 1, {}); +} + +TEST_F(SpatialIndexTest, testSingleEnvelope) { + makeIndex( + std::vector{Envelope{ + .minX = 1, .minY = 11, .maxX = 2, .maxY = 12, .rowIndex = 0}}); + + Envelope bounds = indexBounds(); + ASSERT_EQ(bounds.minX, 1); + ASSERT_EQ(bounds.minY, 11); + ASSERT_EQ(bounds.maxX, 2); + ASSERT_EQ(bounds.maxY, 12); + + assertQuery(1.5, 11.5, 1.5, 11.5, {0}); + assertQuery(0.5, 10.5, 1.5, 11.5, {0}); + assertQuery(0, 10, 0.5, 10.5, {}); + assertQuery(3, 13, 4, 14, {}); +} + +TEST_F(SpatialIndexTest, testPointProbe) { + makeIndex( + std::vector{ + Envelope{.minX = 1, .minY = 0, .maxX = 1, .maxY = 0, .rowIndex = 6}, + Envelope{.minX = 0, .minY = 0, .maxX = 0, .maxY = 0, .rowIndex = 5}, + Envelope{.minX = 0, .minY = 0, .maxX = 1, .maxY = 1, .rowIndex = 4}, + Envelope{.minX = -1, .minY = -1, .maxX = 0, .maxY = 0, .rowIndex = 3}, + Envelope{.minX = -1, .minY = -1, .maxX = 1, .maxY = 1, .rowIndex = 2}, + Envelope{ + .minX = 0.5, .minY = 0.5, .maxX = 1, .maxY = 1, .rowIndex = 1}, + }); + Envelope bounds = indexBounds(); + ASSERT_EQ(bounds.minX, -1); + ASSERT_EQ(bounds.minY, -1); + ASSERT_EQ(bounds.maxX, 1); + ASSERT_EQ(bounds.maxY, 1); + ASSERT_EQ(bounds.rowIndex, -1); + + assertQuery(0, 0, 0, 0, {2, 3, 4, 5}); + assertQuery(0, 1, 0, 1, {2, 4}); +} + +TEST_F(SpatialIndexTest, testFloatImprecision) { + // Since the index casts doubles to floats then nudges the result, + // we should make sure that the index gives the right results on + // cases where the double doesn't have an exact float representation. + float float1 = 1.0f; + float float1Down = + std::nextafterf(float1, -std::numeric_limits::infinity()); + float float2 = 2.0f; + float float2Up = + std::nextafterf(float2, std::numeric_limits::infinity()); + + double baseMax = static_cast(float2); + double baseMaxUp = + std::nextafter(baseMax, std::numeric_limits::infinity()); + double baseMaxDown = + std::nextafter(baseMax, -std::numeric_limits::infinity()); + double baseMin = static_cast(float1); + double baseMinUp = + std::nextafter(baseMin, std::numeric_limits::infinity()); + double baseMinDown = + std::nextafter(baseMin, -std::numeric_limits::infinity()); + + makeIndex( + std::vector{ + Envelope::from(baseMin, baseMin, baseMax, baseMax, 1), + Envelope::from(baseMinUp, baseMinUp, baseMaxUp, baseMaxUp, 2), + Envelope::from(baseMinDown, baseMinDown, baseMaxDown, baseMaxDown, 3), + }); + + Envelope bounds = indexBounds(); + ASSERT_EQ(bounds.minX, float1Down); + ASSERT_EQ(bounds.minY, float1Down); + ASSERT_EQ(bounds.maxX, float2Up); + ASSERT_EQ(bounds.maxY, float2Up); + + assertQuery(2.1, 2.1, 2.1, 2.1, {}); + assertQuery(baseMin, baseMin, baseMin, baseMin, {1, 2, 3}); + assertQuery(baseMinDown, baseMinDown, baseMinDown, baseMinDown, {1, 2, 3}); + assertQuery(baseMinUp, baseMinUp, baseMinUp, baseMinUp, {1, 2, 3}); + assertQuery(baseMax, baseMax, baseMax, baseMax, {1, 2, 3}); + assertQuery(baseMaxDown, baseMaxDown, baseMaxDown, baseMaxDown, {1, 2, 3}); + assertQuery(baseMaxUp, baseMaxUp, baseMaxUp, baseMaxUp, {1, 2, 3}); +} + +TEST_F(SpatialIndexTest, testFloatImprecisionSubnormal) { + // Check that our bumping rules work for subnormal floats as well. + float subnormalFloatDown = + std::nextafterf(0.0, -std::numeric_limits::infinity()); + float subnormalFloatUp = + std::nextafterf(0.0, std::numeric_limits::infinity()); + + double subnormalDoubleDown = + std::nextafter(0.0, -std::numeric_limits::infinity()); + double subnormalDoubleUp = + std::nextafter(0.0, std::numeric_limits::infinity()); + + makeIndex( + std::vector{ + Envelope::from(0.0, 0.0, 0.0, 0.0, 1), + Envelope::from( + subnormalDoubleDown, + subnormalDoubleDown, + subnormalDoubleDown, + subnormalDoubleDown, + 2), + Envelope::from( + subnormalDoubleUp, + subnormalDoubleUp, + subnormalDoubleUp, + subnormalDoubleUp, + 3), + Envelope::from( + subnormalDoubleDown, + subnormalDoubleDown, + subnormalDoubleUp, + subnormalDoubleUp, + 4), + }); + + Envelope bounds = indexBounds(); + ASSERT_EQ(bounds.minX, subnormalFloatDown); + ASSERT_EQ(bounds.minY, subnormalFloatDown); + ASSERT_EQ(bounds.maxX, subnormalFloatUp); + ASSERT_EQ(bounds.maxY, subnormalFloatUp); + + assertQuery(0.1, 0.1, 0.1, 0.1, {}); + assertQuery(0.0, 0.0, 0.0, 0.0, {1, 2, 3, 4}); + assertQuery( + subnormalDoubleDown, + subnormalDoubleDown, + subnormalDoubleDown, + subnormalDoubleDown, + {1, 2, 3, 4}); + assertQuery( + subnormalDoubleUp, + subnormalDoubleUp, + subnormalDoubleUp, + subnormalDoubleUp, + {1, 2, 3, 4}); + assertQuery( + subnormalDoubleDown, + subnormalDoubleDown, + subnormalDoubleUp, + subnormalDoubleUp, + {1, 2, 3, 4}); +} + +TEST_F(SpatialIndexTest, testNegativeCoordinates) { + makeIndex( + std::vector{ + Envelope{ + .minX = -5, .minY = -5, .maxX = -1, .maxY = -1, .rowIndex = 0}, + Envelope{ + .minX = -10, .minY = -10, .maxX = -6, .maxY = -6, .rowIndex = 1}, + Envelope{ + .minX = -3, .minY = -8, .maxX = 2, .maxY = -4, .rowIndex = 2}}); + + Envelope bounds = indexBounds(); + ASSERT_EQ(bounds.minX, -10); + ASSERT_EQ(bounds.minY, -10); + ASSERT_EQ(bounds.maxX, 2); + ASSERT_EQ(bounds.maxY, -1); + + assertQuery(-7, -7, -7, -7, {1}); + assertQuery(-2, -5, -2, -5, {0, 2}); + assertQuery(0, 0, 1, 1, {}); +} + +TEST_F(SpatialIndexTest, testOverlappingEnvelopes) { + makeIndex( + std::vector{ + Envelope{.minX = 0, .minY = 0, .maxX = 10, .maxY = 10, .rowIndex = 0}, + Envelope{.minX = 5, .minY = 5, .maxX = 15, .maxY = 15, .rowIndex = 1}, + Envelope{.minX = 2, .minY = 2, .maxX = 8, .maxY = 8, .rowIndex = 2}, + Envelope{ + .minX = 7, .minY = 7, .maxX = 12, .maxY = 12, .rowIndex = 3}}); + + assertQuery(6, 6, 6, 6, {0, 1, 2}); + assertQuery(8, 8, 8, 8, {0, 1, 2, 3}); + assertQuery(9, 9, 9, 9, {0, 1, 3}); + assertQuery(3, 3, 3, 3, {0, 2}); + assertQuery(13, 13, 13, 13, {1}); +} + +TEST_F(SpatialIndexTest, testNonOverlappingEnvelopes) { + makeIndex( + std::vector{ + Envelope{.minX = 0, .minY = 0, .maxX = 1, .maxY = 1, .rowIndex = 0}, + Envelope{.minX = 2, .minY = 2, .maxX = 3, .maxY = 3, .rowIndex = 1}, + Envelope{.minX = 4, .minY = 4, .maxX = 5, .maxY = 5, .rowIndex = 2}, + Envelope{.minX = 6, .minY = 6, .maxX = 7, .maxY = 7, .rowIndex = 3}}); + + assertQuery(0.5, 0.5, 0.5, 0.5, {0}); + assertQuery(2.5, 2.5, 2.5, 2.5, {1}); + assertQuery(4.5, 4.5, 4.5, 4.5, {2}); + assertQuery(6.5, 6.5, 6.5, 6.5, {3}); + assertQuery(1.5, 1.5, 1.5, 1.5, {}); +} + +TEST_F(SpatialIndexTest, testLargeQueryEnvelope) { + makeIndex( + std::vector{ + Envelope{.minX = 1, .minY = 1, .maxX = 2, .maxY = 2, .rowIndex = 0}, + Envelope{.minX = 3, .minY = 3, .maxX = 4, .maxY = 4, .rowIndex = 1}, + Envelope{.minX = 5, .minY = 5, .maxX = 6, .maxY = 6, .rowIndex = 2}}); + + assertQuery(0, 0, 10, 10, {0, 1, 2}); + assertQuery(-100, -100, 100, 100, {0, 1, 2}); +} + +TEST_F(SpatialIndexTest, testSmallQueryEnvelope) { + makeIndex( + std::vector{ + Envelope{ + .minX = 0, .minY = 0, .maxX = 100, .maxY = 100, .rowIndex = 0}, + Envelope{ + .minX = 50, + .minY = 50, + .maxX = 150, + .maxY = 150, + .rowIndex = 1}}); + + assertQuery(25, 25, 26, 26, {0}); + assertQuery(75, 75, 76, 76, {0, 1}); + assertQuery(125, 125, 126, 126, {1}); + assertQuery(0.1, 0.1, 0.2, 0.2, {0}); +} + +TEST_F(SpatialIndexTest, testEdgeTouching) { + makeIndex( + std::vector{ + Envelope{.minX = 0, .minY = 0, .maxX = 5, .maxY = 5, .rowIndex = 0}, + Envelope{.minX = 5, .minY = 0, .maxX = 10, .maxY = 5, .rowIndex = 1}, + Envelope{.minX = 0, .minY = 5, .maxX = 5, .maxY = 10, .rowIndex = 2}, + Envelope{ + .minX = 5, .minY = 5, .maxX = 10, .maxY = 10, .rowIndex = 3}}); + + assertQuery(5, 5, 5, 5, {0, 1, 2, 3}); + assertQuery(5, 2, 5, 2, {0, 1}); + assertQuery(2, 5, 2, 5, {0, 2}); +} + +TEST_F(SpatialIndexTest, testCornerTouching) { + makeIndex( + std::vector{ + Envelope{.minX = 0, .minY = 0, .maxX = 5, .maxY = 5, .rowIndex = 0}, + Envelope{ + .minX = 5, .minY = 5, .maxX = 10, .maxY = 10, .rowIndex = 1}}); + + assertQuery(5, 5, 5, 5, {0, 1}); + assertQuery(4.9, 4.9, 5.1, 5.1, {0, 1}); +} + +TEST_F(SpatialIndexTest, testInfiniteValues) { + float inf = std::numeric_limits::infinity(); + float negInf = -std::numeric_limits::infinity(); + + makeIndex( + std::vector{ + Envelope{.minX = 0, .minY = 0, .maxX = 1, .maxY = 1, .rowIndex = 0}}); + + assertQuery(inf, inf, inf, inf, {}); + assertQuery(negInf, negInf, negInf, negInf, {}); + assertQuery(negInf, negInf, inf, inf, {0}); +} + +TEST_F(SpatialIndexTest, testLargeDataset) { + std::vector envelopes; + envelopes.reserve(1000); + for (int i = 0; i < 1000; ++i) { + envelopes.push_back( + Envelope{ + .minX = static_cast(i), + .minY = static_cast(i), + .maxX = static_cast(i + 1), + .maxY = static_cast(i + 1), + .rowIndex = i}); + } + makeIndex(std::move(envelopes)); + + assertQuery(500.5, 500.5, 500.5, 500.5, {500}); + assertQuery(100.5, 100.5, 104.5, 104.5, {100, 101, 102, 103, 104}); + assertQuery(-1, -1, -1, -1, {}); + assertQuery(1001, 1001, 1001, 1001, {}); + + Envelope bounds = indexBounds(); + ASSERT_EQ(bounds.minX, 0); + ASSERT_EQ(bounds.minY, 0); + ASSERT_EQ(bounds.maxX, 1000); + ASSERT_EQ(bounds.maxY, 1000); +} + +TEST_F(SpatialIndexTest, testVeryLargeCoordinates) { + float largeVal = 1e20f; + makeIndex( + std::vector{ + Envelope{ + .minX = -largeVal, + .minY = -largeVal, + .maxX = -largeVal + 1, + .maxY = -largeVal + 1, + .rowIndex = 0}, + Envelope{ + .minX = largeVal - 1, + .minY = largeVal - 1, + .maxX = largeVal, + .maxY = largeVal, + .rowIndex = 1}}); + + assertQuery( + -largeVal + 0.5, -largeVal + 0.5, -largeVal + 0.5, -largeVal + 0.5, {0}); + assertQuery( + largeVal - 0.5, largeVal - 0.5, largeVal - 0.5, largeVal - 0.5, {1}); + assertQuery(0, 0, 0, 0, {}); +} + +TEST_F(SpatialIndexTest, testQueryOutsideBounds) { + makeIndex( + std::vector{ + Envelope{.minX = 0, .minY = 0, .maxX = 10, .maxY = 10, .rowIndex = 0}, + Envelope{ + .minX = 5, .minY = 5, .maxX = 15, .maxY = 15, .rowIndex = 1}}); + + assertQuery(-10, -10, -5, -5, {}); + assertQuery(20, 20, 25, 25, {}); + assertQuery(-10, 5, -5, 10, {}); + assertQuery(5, 20, 10, 25, {}); +} + +TEST_F(SpatialIndexTest, testPartialOverlap) { + makeIndex( + std::vector{Envelope{ + .minX = 0, .minY = 0, .maxX = 10, .maxY = 10, .rowIndex = 0}}); + + assertQuery(-5, -5, 5, 5, {0}); + assertQuery(5, -5, 15, 5, {0}); + assertQuery(-5, 5, 5, 15, {0}); + assertQuery(5, 5, 15, 15, {0}); +} + +TEST_F(SpatialIndexTest, testMixedSizeEnvelopes) { + makeIndex( + std::vector{ + Envelope{ + .minX = 0, .minY = 0, .maxX = 0.1, .maxY = 0.1, .rowIndex = 0}, + Envelope{ + .minX = 1, .minY = 1, .maxX = 100, .maxY = 100, .rowIndex = 1}, + Envelope{ + .minX = 50, .minY = 50, .maxX = 51, .maxY = 51, .rowIndex = 2}}); + + assertQuery(0.05, 0.05, 0.05, 0.05, {0}); + assertQuery(50, 50, 100, 100, {1, 2}); + assertQuery(25, 25, 25, 25, {1}); +} + +TEST_F(SpatialIndexTest, testZeroAreaEnvelopes) { + makeIndex( + std::vector{ + Envelope{.minX = 0, .minY = 0, .maxX = 0, .maxY = 0, .rowIndex = 0}, + Envelope{.minX = 1, .minY = 1, .maxX = 1, .maxY = 1, .rowIndex = 1}, + Envelope{.minX = 2, .minY = 2, .maxX = 2, .maxY = 2, .rowIndex = 2}}); + + assertQuery(0, 0, 0, 0, {0}); + assertQuery(1, 1, 1, 1, {1}); + assertQuery(2, 2, 2, 2, {2}); + assertQuery(0.5, 0.5, 0.5, 0.5, {}); + assertQuery(0, 0, 2, 2, {0, 1, 2}); +} + +TEST_F(SpatialIndexTest, testIdenticalEnvelopes) { + makeIndex( + std::vector{ + Envelope{.minX = 5, .minY = 5, .maxX = 10, .maxY = 10, .rowIndex = 0}, + Envelope{.minX = 5, .minY = 5, .maxX = 10, .maxY = 10, .rowIndex = 1}, + Envelope{ + .minX = 5, .minY = 5, .maxX = 10, .maxY = 10, .rowIndex = 2}}); + + assertQuery(7, 7, 7, 7, {0, 1, 2}); + assertQuery(5, 5, 10, 10, {0, 1, 2}); + assertQuery(4, 4, 4, 4, {}); +} + +TEST_F(SpatialIndexTest, testDifferentBranchSizes) { + std::vector branchSizes = {2, 3, 4, 8, 16, 32, 64, 128, 256}; + + for (uint32_t branchSize : branchSizes) { + std::vector envelopes; + envelopes.reserve(100); + for (int i = 0; i < 100; ++i) { + envelopes.push_back( + Envelope{ + .minX = static_cast(i), + .minY = static_cast(i), + .maxX = static_cast(i + 1), + .maxY = static_cast(i + 1), + .rowIndex = i}); + } + makeIndex(std::move(envelopes), branchSize); + + assertQuery(50.5, 50.5, 50.5, 50.5, {50}); + assertQuery(10.5, 10.5, 14.5, 14.5, {10, 11, 12, 13, 14}); + assertQuery(-1, -1, -1, -1, {}); + assertQuery(101, 101, 101, 101, {}); + + Envelope bounds = indexBounds(); + ASSERT_EQ(bounds.minX, 0); + ASSERT_EQ(bounds.minY, 0); + ASSERT_EQ(bounds.maxX, 100); + ASSERT_EQ(bounds.maxY, 100); + } +} + +TEST_F(SpatialIndexTest, testSmallBranchSize) { + makeIndex( + std::vector{ + Envelope{.minX = 0, .minY = 0, .maxX = 10, .maxY = 10, .rowIndex = 0}, + Envelope{.minX = 5, .minY = 5, .maxX = 15, .maxY = 15, .rowIndex = 1}, + Envelope{.minX = 2, .minY = 2, .maxX = 8, .maxY = 8, .rowIndex = 2}, + Envelope{ + .minX = 7, .minY = 7, .maxX = 12, .maxY = 12, .rowIndex = 3}}, + 2); + + assertQuery(6, 6, 6, 6, {0, 1, 2}); + assertQuery(8, 8, 8, 8, {0, 1, 2, 3}); + assertQuery(9, 9, 9, 9, {0, 1, 3}); + assertQuery(3, 3, 3, 3, {0, 2}); + assertQuery(13, 13, 13, 13, {1}); +} + +TEST_F(SpatialIndexTest, testLargeBranchSize) { + makeIndex( + std::vector{ + Envelope{.minX = 0, .minY = 0, .maxX = 10, .maxY = 10, .rowIndex = 0}, + Envelope{.minX = 5, .minY = 5, .maxX = 15, .maxY = 15, .rowIndex = 1}, + Envelope{.minX = 2, .minY = 2, .maxX = 8, .maxY = 8, .rowIndex = 2}, + Envelope{ + .minX = 7, .minY = 7, .maxX = 12, .maxY = 12, .rowIndex = 3}}, + 512); + + assertQuery(6, 6, 6, 6, {0, 1, 2}); + assertQuery(8, 8, 8, 8, {0, 1, 2, 3}); + assertQuery(9, 9, 9, 9, {0, 1, 3}); + assertQuery(3, 3, 3, 3, {0, 2}); + assertQuery(13, 13, 13, 13, {1}); +} + +} // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/SpatialJoinTest.cpp b/velox/exec/tests/SpatialJoinTest.cpp new file mode 100644 index 000000000000..01606121a403 --- /dev/null +++ b/velox/exec/tests/SpatialJoinTest.cpp @@ -0,0 +1,767 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/core/PlanFragment.h" +#include "velox/core/PlanNode.h" +#include "velox/core/QueryConfig.h" +#include "velox/exec/tests/utils/AssertQueryBuilder.h" +#include "velox/exec/tests/utils/OperatorTestBase.h" +#include "velox/exec/tests/utils/PlanBuilder.h" + +using namespace facebook::velox; +using namespace facebook::velox::exec; +using namespace facebook::velox::exec::test; + +namespace facebook::velox::exec::test { + +class SpatialJoinTest : public OperatorTestBase { + public: + /* The polygons below have the following relations: + * [Legend: = equals, v overlap, / disjoint] + * + * A B C D + * A = v / / + * B v = / / + * C / / = / + * D / / / = + * + * Overlap means geometries share interior points (for full definition see + * (DE-9IM)[https://en.wikipedia.org/wiki/DE-9IM]), but neither contains the + * other. + */ + static constexpr std::string_view kPolygonA = + "POLYGON ((0 0, -0.5 2.5, 0 5, 2.5 5.5, 5 5, 5.5 2.5, 5 0, 2.5 -0.5, 0 0))"; + static constexpr std::string_view kPolygonB = + "POLYGON ((4 4, 3.5 7, 4 10, 7 10.5, 10 10, 10.5 7, 10 4, 7 3.5, 4 4))"; + static constexpr std::string_view kPolygonC = + "POLYGON ((15 15, 15 14, 14 14, 14 15, 15 15))"; + static constexpr std::string_view kPolygonD = + "POLYGON ((18 18, 18 19, 19 19, 19 18, 18 18))"; + + // A set of points: X in A, Y in A and B, Z in B, W outside of A and B + static constexpr std::string_view kPointX = "POINT (1 1)"; + static constexpr std::string_view kPointY = "POINT (4.5 4.5)"; + static constexpr std::string_view kPointZ = "POINT (6 6)"; + static constexpr std::string_view kPointW = "POINT (20 20)"; + static constexpr std::string_view kPointV = "POINT (15 15)"; + static constexpr std::string_view kPointS = "POINT (18 18)"; + static constexpr std::string_view kPointQ = "POINT (28 28)"; + static constexpr std::string_view kMultipointU = "MULTIPOINT (15 15)"; + static constexpr std::string_view kMultipointT = + "MULTIPOINT (14.5 14.5, 16 16)"; + static constexpr std::string_view kMultipointR = "MULTIPOINT (15 15, 19 19)"; + + protected: + void runTest( + const std::vector>& probeWkts, + const std::vector>& buildWkts, + const std::optional>>& radiiOpt, + const std::string& predicate, + core::JoinType joinType, + const std::vector>& expectedLeftWkts, + const std::vector>& expectedRightWkts) { + for (bool separateProbeBatches : {false, true}) { + for (size_t maxBatchSize : {128, 3, 2, 1}) { + for (int32_t maxDrivers : {1, 4}) { + runTestWithConfig( + probeWkts, + buildWkts, + radiiOpt, + predicate, + joinType, + expectedLeftWkts, + expectedRightWkts, + maxDrivers, + maxBatchSize, + separateProbeBatches); + } + } + } + } + + void runTestWithConfig( + const std::vector>& probeWkts, + const std::vector>& buildWkts, + const std::optional>>& radiiOpt, + const std::string& predicate, + core::JoinType joinType, + const std::vector>& expectedLeftWkts, + const std::vector>& expectedRightWkts, + int32_t maxDrivers, + size_t maxBatchSize, + bool separateBatches) { + std::vector> probeWktsStr( + probeWkts.begin(), probeWkts.end()); + std::vector> buildWktsStr( + buildWkts.begin(), buildWkts.end()); + std::vector> expectedLeftWktsStr( + expectedLeftWkts.begin(), expectedLeftWkts.end()); + std::vector> expectedRightWktsStr( + expectedRightWkts.begin(), expectedRightWkts.end()); + auto radii = radiiOpt.value_or( + std::vector>(buildWkts.size(), std::nullopt)); + VELOX_CHECK_EQ(radii.size(), buildWkts.size()); + std::optional radiusVariable = std::nullopt; + if (radiiOpt.has_value()) { + radiusVariable = "radius"; + } + + std::vector probeBatches; + std::vector buildBatches; + if (separateBatches) { + for (const auto& wkt : probeWktsStr) { + probeBatches.push_back(makeRowVector( + {"left_g"}, {makeNullableFlatVector({wkt})})); + } + if (probeBatches.empty()) { + probeBatches.push_back(makeRowVector( + {"left_g"}, {makeNullableFlatVector({})})); + } + + for (size_t idx = 0; idx < buildWktsStr.size(); ++idx) { + auto& wkt = buildWktsStr[idx]; + buildBatches.push_back(makeRowVector( + {"right_g", "radius"}, + {makeNullableFlatVector({wkt}), + makeNullableFlatVector({radii[idx]})})); + } + if (buildBatches.empty()) { + buildBatches.push_back(makeRowVector( + {"right_g", "radius"}, + {makeNullableFlatVector({}), + makeNullableFlatVector({})})); + } + } else { + probeBatches.push_back(makeRowVector( + {"left_g"}, {makeNullableFlatVector(probeWktsStr)})); + buildBatches.push_back(makeRowVector( + {"right_g", "radius"}, + {makeNullableFlatVector(buildWktsStr), + makeNullableFlatVector(radii)})); + } + auto expectedRows = makeRowVector( + {"left_g", "right_g"}, + {makeNullableFlatVector(expectedLeftWktsStr), + makeNullableFlatVector(expectedRightWktsStr)}); + + auto planNodeIdGenerator = std::make_shared(); + auto plan = + PlanBuilder(planNodeIdGenerator) + .values(probeBatches) + .project({"ST_GeometryFromText(left_g) AS left_g"}) + .localPartitionRoundRobinRow() + .spatialJoin( + PlanBuilder(planNodeIdGenerator) + .values(buildBatches) + .project( + {"ST_GeometryFromText(right_g) AS right_g", "radius"}) + .localPartition({}) + .planNode(), + predicate, + "left_g", + "right_g", + radiusVariable, + {"left_g", "right_g"}, + joinType) + .project( + {"ST_AsText(left_g) AS left_g", + "ST_AsText(right_g) AS right_g"}) + .planNode(); + AssertQueryBuilder builder{plan}; + builder.maxDrivers(maxDrivers) + .config(core::QueryConfig::kPreferredOutputBatchRows, maxBatchSize) + .config(core::QueryConfig::kMaxOutputBatchRows, maxBatchSize) + .assertResults({expectedRows}); + } +}; + +TEST_F(SpatialJoinTest, testTrivialSpatialJoin) { + runTest( + {"POINT (1 1)"}, + {"POINT (1 1)"}, + std::nullopt, + "ST_Intersects(left_g, right_g)", + core::JoinType::kInner, + {"POINT (1 1)"}, + {"POINT (1 1)"}); +} + +TEST_F(SpatialJoinTest, testSimpleSpatialInnerJoin) { + runTest( + {"POINT (1 1)", "POINT (1 2)"}, + {"POINT (1 1)", "POINT (2 1)"}, + std::nullopt, + "ST_Intersects(left_g, right_g)", + core::JoinType::kInner, + {"POINT (1 1)"}, + {"POINT (1 1)"}); +} + +TEST_F(SpatialJoinTest, testSimpleSpatialLeftJoin) { + runTest( + {"POINT (1 1)", "POINT (1 2)"}, + {"POINT (1 1)", "POINT (2 1)"}, + std::nullopt, + "ST_Intersects(left_g, right_g)", + core::JoinType::kLeft, + {"POINT (1 1)", "POINT (1 2)"}, + {"POINT (1 1)", std::nullopt}); +} + +TEST_F(SpatialJoinTest, testSpatialJoinNullRows) { + runTest( + {"POINT (0 0)", std::nullopt, "POINT (1 1)", std::nullopt}, + {"POINT (0 0)", "POINT (1 1)", std::nullopt, std::nullopt}, + std::nullopt, + "ST_Intersects(left_g, right_g)", + core::JoinType::kInner, + {"POINT (0 0)", "POINT (1 1)"}, + {"POINT (0 0)", "POINT (1 1)"}); + runTest( + {"POINT (0 0)", std::nullopt, "POINT (2 2)", std::nullopt}, + {"POINT (0 0)", "POINT (1 1)", std::nullopt, std::nullopt}, + std::nullopt, + "ST_Intersects(left_g, right_g)", + core::JoinType::kLeft, + {"POINT (0 0)", "POINT (2 2)", std::nullopt, std::nullopt}, + {"POINT (0 0)", std::nullopt, std::nullopt, std::nullopt}); +} + +// Test geometries that don't intersect but their envelopes do. +// Important to test spatial index +TEST_F(SpatialJoinTest, simpleSpatialJoinEnvelopes) { + runTest( + {"POINT (0.5 0.6)", "POINT (0.5 0.5)", "LINESTRING (0 0.1, 0.9 1)"}, + {"POLYGON ((0 0, 1 1, 1 0, 0 0))"}, + std::nullopt, + "ST_Intersects(left_g, right_g)", + core::JoinType::kInner, + {"POINT (0.5 0.5)"}, + {"POLYGON ((0 0, 1 1, 1 0, 0 0))"}); +} + +TEST_F(SpatialJoinTest, testSelfSpatialJoin) { + std::vector> inputWkts = { + kPolygonA, kPolygonB, kPolygonC, kPolygonD}; + std::vector> leftOutputWkts = { + kPolygonA, kPolygonA, kPolygonB, kPolygonB, kPolygonC, kPolygonD}; + std::vector> rightOutputWkts = { + kPolygonA, kPolygonB, kPolygonA, kPolygonB, kPolygonC, kPolygonD}; + + runTest( + inputWkts, + inputWkts, + std::nullopt, + "ST_Intersects(left_g, right_g)", + core::JoinType::kInner, + leftOutputWkts, + rightOutputWkts); + + runTest( + inputWkts, + inputWkts, + std::nullopt, + "ST_Intersects(left_g, right_g)", + core::JoinType::kLeft, + leftOutputWkts, + rightOutputWkts); + + runTest( + inputWkts, + inputWkts, + std::nullopt, + "ST_Overlaps(left_g, right_g)", + core::JoinType::kInner, + {kPolygonA, kPolygonB}, + {kPolygonB, kPolygonA}); + + runTest( + inputWkts, + inputWkts, + std::nullopt, + "ST_Intersects(left_g, right_g) AND ST_Overlaps(left_g, right_g)", + core::JoinType::kInner, + {kPolygonA, kPolygonB}, + {kPolygonB, kPolygonA}); + + runTest( + inputWkts, + inputWkts, + std::nullopt, + "ST_Overlaps(left_g, right_g)", + core::JoinType::kLeft, + {kPolygonA, kPolygonB, kPolygonC, kPolygonD}, + {kPolygonB, kPolygonA, std::nullopt, std::nullopt}); + + runTest( + inputWkts, + inputWkts, + std::nullopt, + "ST_Equals(left_g, right_g)", + core::JoinType::kInner, + inputWkts, + inputWkts); + + runTest( + inputWkts, + inputWkts, + std::nullopt, + "ST_Equals(left_g, right_g)", + core::JoinType::kLeft, + inputWkts, + inputWkts); +} + +TEST_F(SpatialJoinTest, pointPolygonSpatialJoin) { + std::vector> polygonWkts = { + kPolygonA, kPolygonB, kPolygonC, kPolygonD}; + std::vector> pointWkts = { + kPointX, + kPointY, + kPointZ, + kPointW, + kPointV, + kPointS, + kPointQ, + kMultipointU, + kMultipointT, + kMultipointR}; + + std::vector> pointOutputWkts = { + kPointX, + kPointY, + kPointY, + kPointZ, + kPointV, + kPointS, + kMultipointU, + kMultipointR, + kMultipointR, + kMultipointT}; + std::vector> polygonOutputWkts = { + kPolygonA, + kPolygonA, + kPolygonB, + kPolygonB, + kPolygonC, + kPolygonD, + kPolygonC, + kPolygonC, + kPolygonD, + kPolygonC}; + runTest( + pointWkts, + polygonWkts, + std::nullopt, + "ST_Intersects(left_g, right_g)", + core::JoinType::kInner, + pointOutputWkts, + polygonOutputWkts); +} + +TEST_F(SpatialJoinTest, testSimpleNullRowsJoin) { + runTest( + {"POINT (1 1)", std::nullopt, "POINT (1 2)"}, + {"POINT (1 1)", "POINT (2 1)", std::nullopt}, + std::nullopt, + "ST_Intersects(left_g, right_g)", + core::JoinType::kInner, + {"POINT (1 1)"}, + {"POINT (1 1)"}); +} + +TEST_F(SpatialJoinTest, testGeometryCollection) { + runTest( + {"GEOMETRYCOLLECTION (POINT (1 1))", + "GEOMETRYCOLLECTION EMPTY", + "POINT (1 1)"}, + {"GEOMETRYCOLLECTION (POINT (1 1))", + "GEOMETRYCOLLECTION EMPTY", + "POINT (1 1)"}, + std::nullopt, + "ST_Intersects(left_g, right_g)", + core::JoinType::kInner, + {"GEOMETRYCOLLECTION (POINT (1 1))", + "GEOMETRYCOLLECTION (POINT (1 1))", + "POINT (1 1)", + "POINT (1 1)"}, + {"GEOMETRYCOLLECTION (POINT (1 1))", + "POINT (1 1)", + "GEOMETRYCOLLECTION (POINT (1 1))", + "POINT (1 1)"}); + + runTest( + {"GEOMETRYCOLLECTION (POINT (1 1))", + "GEOMETRYCOLLECTION EMPTY", + "POINT (1 1)"}, + {"GEOMETRYCOLLECTION (POINT (1 2))", + "GEOMETRYCOLLECTION EMPTY", + "POINT (1 2)"}, + std::vector>{1.0, 1.0, 1.0}, + "ST_Distance(left_g, right_g) <= radius", + core::JoinType::kInner, + {"GEOMETRYCOLLECTION (POINT (1 1))", + "GEOMETRYCOLLECTION (POINT (1 1))", + "POINT (1 1)", + "POINT (1 1)"}, + {"GEOMETRYCOLLECTION (POINT (1 2))", + "POINT (1 2)", + "GEOMETRYCOLLECTION (POINT (1 2))", + "POINT (1 2)"}); +} + +TEST_F(SpatialJoinTest, testDistanceJoin) { + runTest( + {"POINT (1 2)", "POLYGON ((1 2, 2 2, 2 3, 1 3, 1 2))", std::nullopt}, + {"POINT (2 2)", + "POINT (1 1)", + std::nullopt, + "POINT (1 2)", + "POLYGON ((1 1, 1 0, 0 0, 0 1, 1 1))"}, + std::vector>{std::nullopt, 1.0, 0.0, 0.0, 1.0}, + "ST_Distance(left_g, right_g) <= radius", + core::JoinType::kInner, + {"POINT (1 2)", + "POLYGON ((1 2, 1 3, 2 3, 2 2, 1 2))", + "POINT (1 2)", + "POLYGON ((1 2, 1 3, 2 3, 2 2, 1 2))", + "POINT (1 2)", + "POLYGON ((1 2, 1 3, 2 3, 2 2, 1 2))"}, + {"POINT (1 1)", + "POINT (1 1)", + "POINT (1 2)", + "POINT (1 2)", + "POLYGON ((1 1, 1 0, 0 0, 0 1, 1 1))", + "POLYGON ((1 1, 1 0, 0 0, 0 1, 1 1))"}); +} + +TEST_F(SpatialJoinTest, testContainsPointsInPolygons) { + // Tests ST_Contains(polygon, point) - which polygons contain which points + std::vector> pointWkts = { + kPointX, kPointY, kPointZ, kPointW}; + std::vector> polygonWkts = { + kPolygonA, kPolygonB, kPolygonC, kPolygonD}; + + // Expected: A contains X, B contains Y, B contains Z, A contains Y, D + // contains nothing from our test set Note: Y is in both A and B since they + // overlap + std::vector> pointOutputWkts = { + kPointX, kPointY, kPointY, kPointZ}; + std::vector> polygonOutputWkts = { + kPolygonA, kPolygonA, kPolygonB, kPolygonB}; + + runTest( + pointWkts, + polygonWkts, + std::nullopt, + "ST_Contains(right_g, left_g)", + core::JoinType::kInner, + pointOutputWkts, + polygonOutputWkts); +} + +TEST_F(SpatialJoinTest, testContainsPolygonsInPolygons) { + // Tests ST_Contains(polygon, polygon) - which polygons contain which polygons + // From the Java test, polygon C contains polygon B (C is larger and covers B) + std::vector> polygonWkts = { + kPolygonA, kPolygonB, kPolygonC, kPolygonD}; + + // Each polygon contains itself, plus any additional containments + // Based on the spatial relations, we need to check which polygons actually + // contain others For now, test self-containment which should always work + std::vector> leftOutputWkts = { + kPolygonA, kPolygonB, kPolygonC, kPolygonD}; + std::vector> rightOutputWkts = { + kPolygonA, kPolygonB, kPolygonC, kPolygonD}; + + runTest( + polygonWkts, + polygonWkts, + std::nullopt, + "ST_Contains(right_g, left_g)", + core::JoinType::kInner, + leftOutputWkts, + rightOutputWkts); +} + +TEST_F(SpatialJoinTest, testContainsLeftJoin) { + // Tests ST_Contains with LEFT join - all probe rows should appear + std::vector> pointWkts = { + kPointX, kPointY, kPointZ, kPointW}; + std::vector> polygonWkts = { + kPolygonA, kPolygonB}; + + // W is outside both polygons, so it should have null for the right side + std::vector> pointOutputWkts = { + kPointX, kPointY, kPointY, kPointZ, kPointW}; + std::vector> polygonOutputWkts = { + kPolygonA, kPolygonA, kPolygonB, kPolygonB, std::nullopt}; + + runTest( + pointWkts, + polygonWkts, + std::nullopt, + "ST_Contains(right_g, left_g)", + core::JoinType::kLeft, + pointOutputWkts, + polygonOutputWkts); +} + +TEST_F(SpatialJoinTest, testTouches) { + // Test ST_Touches - geometries that touch at boundary but don't overlap + // Polygon and a point on its boundary + std::vector> probeWkts = { + "POINT (1 2)", "POINT (3 2)", "LINESTRING (0 0, 1 1)"}; + std::vector> buildWkts = { + "POLYGON ((1 1, 1 4, 4 4, 4 1, 1 1))", + "POLYGON ((1 1, 1 3, 3 3, 3 1, 1 1))"}; + + // Point (1,2) touches both polygons (on their boundaries) + // Point (3,2) touches second polygon (on its boundary) + // LineString (0 0, 1 1) touches both polygons (endpoint at (1,1)) + std::vector> probeOutputWkts = { + "POINT (1 2)", + "POINT (1 2)", + "POINT (3 2)", + "LINESTRING (0 0, 1 1)", + "LINESTRING (0 0, 1 1)"}; + std::vector> buildOutputWkts = { + "POLYGON ((1 1, 1 4, 4 4, 4 1, 1 1))", + "POLYGON ((1 1, 1 3, 3 3, 3 1, 1 1))", + "POLYGON ((1 1, 1 3, 3 3, 3 1, 1 1))", + "POLYGON ((1 1, 1 4, 4 4, 4 1, 1 1))", + "POLYGON ((1 1, 1 3, 3 3, 3 1, 1 1))"}; + + runTest( + probeWkts, + buildWkts, + std::nullopt, + "ST_Touches(left_g, right_g)", + core::JoinType::kInner, + probeOutputWkts, + buildOutputWkts); +} + +TEST_F(SpatialJoinTest, testTouchesPolygons) { + // Test ST_Touches with two polygons that touch at a corner + std::vector> probeWkts = { + "POLYGON ((1 1, 1 3, 3 3, 3 1, 1 1))", + "POLYGON ((5 5, 5 6, 6 6, 6 5, 5 5))"}; + std::vector> buildWkts = { + "POLYGON ((3 3, 3 5, 5 5, 5 3, 3 3))"}; + + // Both polygons touch the build polygon at corners: + // - First polygon touches at (3,3) + // - Second polygon touches at (5,5) + std::vector> probeOutputWkts = { + "POLYGON ((1 1, 1 3, 3 3, 3 1, 1 1))", + "POLYGON ((5 5, 5 6, 6 6, 6 5, 5 5))"}; + std::vector> buildOutputWkts = { + "POLYGON ((3 3, 3 5, 5 5, 5 3, 3 3))", + "POLYGON ((3 3, 3 5, 5 5, 5 3, 3 3))"}; + + runTest( + probeWkts, + buildWkts, + std::nullopt, + "ST_Touches(left_g, right_g)", + core::JoinType::kInner, + probeOutputWkts, + buildOutputWkts); +} + +TEST_F(SpatialJoinTest, testCrosses) { + // Test ST_Crosses - geometries that cross each other + // A linestring crossing a polygon + std::vector> probeWkts = { + "LINESTRING (0 0, 4 4)", // Crosses both polygons + "LINESTRING (5 0, 5 4)", // Outside both + "LINESTRING (1 1, 2 2)" // Contained in polygon, doesn't cross + }; + std::vector> buildWkts = { + "POLYGON ((1 1, 1 3, 3 3, 3 1, 1 1))"}; + + // Only the first linestring crosses the polygon + std::vector> probeOutputWkts = { + "LINESTRING (0 0, 4 4)"}; + std::vector> buildOutputWkts = { + "POLYGON ((1 1, 1 3, 3 3, 3 1, 1 1))"}; + + runTest( + probeWkts, + buildWkts, + std::nullopt, + "ST_Crosses(left_g, right_g)", + core::JoinType::kInner, + probeOutputWkts, + buildOutputWkts); +} + +TEST_F(SpatialJoinTest, testCrossesLineStrings) { + // Test ST_Crosses with two linestrings that cross each other + std::vector> probeWkts = { + "LINESTRING (0 0, 1 1)", // Crosses first build linestring + "LINESTRING (2 2, 3 3)" // Parallel, doesn't cross + }; + std::vector> buildWkts = { + "LINESTRING (1 0, 0 1)"}; + + // Only the first linestring crosses + std::vector> probeOutputWkts = { + "LINESTRING (0 0, 1 1)"}; + std::vector> buildOutputWkts = { + "LINESTRING (1 0, 0 1)"}; + + runTest( + probeWkts, + buildWkts, + std::nullopt, + "ST_Crosses(left_g, right_g)", + core::JoinType::kInner, + probeOutputWkts, + buildOutputWkts); +} + +TEST_F(SpatialJoinTest, testEmptyBuild) { + runTest( + {kPointX, std::nullopt, kPointY, kPointZ, kPointW}, + {}, + std::nullopt, + "ST_Intersects(left_g, right_g)", + core::JoinType::kInner, + {}, + {}); + runTest( + {kPointX, std::nullopt, kPointY, kPointZ, kPointW}, + {}, + std::nullopt, + "ST_Intersects(left_g, right_g)", + core::JoinType::kLeft, + {kPointX, std::nullopt, kPointY, kPointZ, kPointW}, + {std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt}); +} + +TEST_F(SpatialJoinTest, testEmptyProbe) { + runTest( + {}, + {kPointX, std::nullopt, kPointY, kPointZ, kPointW}, + std::nullopt, + "ST_Intersects(left_g, right_g)", + core::JoinType::kInner, + {}, + {}); + runTest( + {}, + {kPointX, std::nullopt, kPointY, kPointZ, kPointW}, + std::nullopt, + "ST_Intersects(left_g, right_g)", + core::JoinType::kLeft, + {}, + {}); +} + +TEST_F(SpatialJoinTest, failOnGroupedExecution) { + std::vector batches{ + makeRowVector({"wkt"}, {makeFlatVector({"POINT(0 0)"})})}; + core::PlanNodeId groupedScanNodeId; + auto planNodeIdGenerator = std::make_shared(); + auto planFragment = + PlanBuilder(planNodeIdGenerator) + .values(batches) + .capturePlanNodeId(groupedScanNodeId) + .project({"ST_GeometryFromText(wkt) AS left_g"}) + .spatialJoin( + PlanBuilder(planNodeIdGenerator) + .values(batches) + .project({"ST_GeometryFromText(wkt) AS right_g"}) + .localPartition({}) + .planNode(), + "ST_Intersects(left_g, right_g)", + "left_g", + "right_g", + std::nullopt, + {"left_g", "right_g"}, + core::JoinType::kInner) + .project( + {"ST_AsText(left_g) AS left_g", "ST_AsText(right_g) AS right_g"}) + .planFragment(); + planFragment.executionStrategy = core::ExecutionStrategy::kGrouped; + planFragment.groupedExecutionLeafNodeIds.emplace(groupedScanNodeId); + auto task = Task::create( + "task-grouped-join", + std::move(planFragment), + 0, + core::QueryCtx::create(driverExecutor_.get()), + Task::ExecutionMode::kParallel); + + VELOX_ASSERT_THROW( + task->start(1), "Spatial joins do not support grouped execution."); +} + +TEST_F(SpatialJoinTest, testLargeJoinSize) { + size_t numRows = 64; + size_t maxCoord = 17; + std::vector buildWkts; + buildWkts.reserve(numRows); + std::vector probeWkts; + probeWkts.reserve(numRows); + for (size_t i = 0; i < numRows; ++i) { + buildWkts.push_back( + fmt::format("POINT ({} {})", (i + 1) % maxCoord, (i + 2) % maxCoord)); + probeWkts.push_back( + fmt::format("POINT ({} {})", i % maxCoord, (i + 1) % maxCoord)); + } + + std::vector> buildWktsView; + buildWktsView.reserve(numRows); + std::vector> probeWktsView; + probeWktsView.reserve(numRows); + for (size_t i = 0; i < numRows; ++i) { + buildWktsView.push_back(buildWkts[i]); + probeWktsView.push_back(probeWkts[i]); + } + + std::vector> expectedLeftWkts; + expectedLeftWkts.reserve(numRows * numRows / maxCoord); + std::vector> expectedRightWkts; + expectedRightWkts.reserve(numRows * numRows / maxCoord); + for (size_t innerIdx = 0; innerIdx < numRows; ++innerIdx) { + for (size_t outerIdx = 0; outerIdx < numRows; ++outerIdx) { + if (probeWkts[outerIdx] == buildWkts[innerIdx]) { + expectedLeftWkts.push_back(probeWkts[outerIdx]); + expectedRightWkts.push_back(buildWkts[innerIdx]); + } + } + } + + for (bool separateProbeBatches : {false, true}) { + for (size_t maxBatchSize : {64, 13, 7, 5, 3, 2, 1}) { + runTestWithConfig( + buildWktsView, + probeWktsView, + std::nullopt, + "ST_Equals(left_g, right_g)", + core::JoinType::kInner, + expectedLeftWkts, + expectedRightWkts, + 1, + maxBatchSize, + separateProbeBatches); + } + } +} + +} // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/SpillTest.cpp b/velox/exec/tests/SpillTest.cpp index 9097abd7f664..1ba324c042ab 100644 --- a/velox/exec/tests/SpillTest.cpp +++ b/velox/exec/tests/SpillTest.cpp @@ -23,6 +23,7 @@ #include "velox/common/file/FileSystems.h" #include "velox/exec/OperatorUtils.h" #include "velox/exec/Spill.h" +#include "velox/exec/tests/utils/MergeTestBase.h" #include "velox/exec/tests/utils/TempDirectoryPath.h" #include "velox/serializers/PrestoSerializer.h" #include "velox/type/Timestamp.h" @@ -394,6 +395,31 @@ class SpillTest : public ::testing::TestWithParam, ASSERT_EQ(numFinishedFiles, expectedFiles); } + void hybridSpillStateTest( + int64_t targetFileSize, + int numPartitions, + int numBatches, + int numDuplicates, + const std::vector& compareFlags, + uint64_t expectedNumSpilledFiles) { + int mergeWayThresholdBegin = 0; + int mergeWayThresholdEnd = 4; + for (int i = mergeWayThresholdBegin; i < mergeWayThresholdEnd; i++) { + if (i == 1) { + // Skip invalid value. + continue; + } + spillStateTest( + targetFileSize, + numPartitions, + numBatches, + numDuplicates, + compareFlags, + expectedNumSpilledFiles, + i); + } + } + // 'numDuplicates' specifies the number of duplicates generated for each // distinct sort key value in test. void spillStateTest( @@ -402,16 +428,18 @@ class SpillTest : public ::testing::TestWithParam, int numBatches, int numDuplicates, const std::vector& compareFlags, - uint64_t expectedNumSpilledFiles) { + uint64_t expectedNumSpilledFiles, + int mergeWayThreshold) { const int numRowsPerBatch = 1'000; - SCOPED_TRACE(fmt::format( - "targetFileSize: {}, numPartitions: {}, numBatches: {}, numDuplicates: {}, nullsFirst: {}, ascending: {}", - targetFileSize, - numPartitions, - numBatches, - numDuplicates, - compareFlags.empty() ? true : compareFlags[0].nullsFirst, - compareFlags.empty() ? true : compareFlags[0].ascending)); + SCOPED_TRACE( + fmt::format( + "targetFileSize: {}, numPartitions: {}, numBatches: {}, numDuplicates: {}, nullsFirst: {}, ascending: {}", + targetFileSize, + numPartitions, + numBatches, + numDuplicates, + compareFlags.empty() ? true : compareFlags[0].nullsFirst, + compareFlags.empty() ? true : compareFlags[0].ascending)); const auto prevGStats = common::globalSpillStats(); @@ -488,13 +516,20 @@ class SpillTest : public ::testing::TestWithParam, ASSERT_EQ(stats.spilledBytes, totalFileBytes); ASSERT_EQ(prevGStats.spilledBytes + totalFileBytes, newGStats.spilledBytes); + bool usePreMerge = mergeWayThreshold >= 2; for (const auto& partitionId : partitionIds) { auto spillFiles = state_->finish(partitionId); ASSERT_EQ(state_->numFinishedFiles(partitionId), 0); auto spillPartition = SpillPartition(SpillPartitionId(partitionId), std::move(spillFiles)); - auto merge = - spillPartition.createOrderedReader(1 << 20, pool(), &spillStats_); + auto spillConfig = common::SpillConfig(); + spillConfig.numMaxMergeFiles = mergeWayThreshold; + spillConfig.readBufferSize = 1 << 20; + spillConfig.writeBufferSize = 1 << 20; + spillConfig.updateAndCheckSpillLimitCb = [](int64_t) {}; + spillConfig.fileCreateConfig = ""; + std::unique_ptr> merge = + spillPartition.createOrderedReader(spillConfig, pool(), &spillStats_); int numReadBatches = 0; // We expect all the rows in dense increasing order. for (auto i = 0; i < numBatches * numRowsPerBatch; ++i) { @@ -527,8 +562,10 @@ class SpillTest : public ::testing::TestWithParam, } } ASSERT_EQ(nullptr, merge->next()); - // We do two append writes per each input batch. - ASSERT_EQ(numBatches, numReadBatches); + if (!usePreMerge) { + // We do two append writes per each input batch. + ASSERT_EQ(numBatches, numReadBatches); + } } const auto finalStats = spillStats_.copy(); @@ -567,7 +604,67 @@ class SpillTest : public ::testing::TestWithParam, ASSERT_TRUE(fs->exists(spilledFile)); } // Verify stats. - ASSERT_EQ(runtimeStats_["spillFileSize"].count, spilledFiles.size()); + if (!usePreMerge) { + ASSERT_EQ(runtimeStats_["spillFileSize"].count, spilledFiles.size()); + } else { + ASSERT_GE(runtimeStats_["spillFileSize"].count, spilledFiles.size()); + } + } + + void gatherMergeTest( + int32_t numValues, + int numMergeWays, + int targetSize, + bool useRandom) { + auto goldenVector = makeRowVector({ + makeFlatVector(numValues, [&](auto row) { return row; }), + }); + std::vector> mergeWays(numMergeWays); + for (int32_t value = 0; value < numValues; value++) { + int way = useRandom ? folly::Random::rand32() % numMergeWays + : value % numMergeWays; + mergeWays[way].push_back(value); + } + std::vector sources; + std::vector> streams; + std::vector sortKeys = {{0, {true, true}}}; + for (int way = 0; way < numMergeWays; way++) { + auto source = makeRowVector({ + makeFlatVector( + mergeWays[way].size(), + [&](auto row) { return mergeWays[way][row]; }), + }); + sources.push_back(source); + streams.push_back( + std::make_unique( + way, sortKeys, source)); + } + auto mergeTree = + std::make_unique>(std::move(streams)); + RowVectorPtr targetVector = std::static_pointer_cast( + BaseVector::create(sources[0]->type(), targetSize, pool_.get())); + std::vector bufferSources(targetSize); + std::vector bufferSourceIndices(targetSize); + for (int32_t batch = 0; batch * targetSize < numValues; batch++) { + int32_t valueBegin = batch * targetSize; + int32_t valueEnd = valueBegin + targetSize; + valueEnd = std::min(valueEnd, numValues); + VectorPtr tmp = std::move(targetVector); + BaseVector::prepareForReuse(tmp, targetSize); + targetVector = std::static_pointer_cast(tmp); + for (auto& child : targetVector->children()) { + child->resize(targetSize); + } + int count = 0; + testingGatherMerge( + targetVector, *mergeTree, count, bufferSources, bufferSourceIndices); + EXPECT_EQ(count, valueEnd - valueBegin); + auto result = targetVector->childAt(0).get(); + auto golden = goldenVector->childAt(0).get(); + for (int32_t row = 0; row < valueEnd - valueBegin; row++) { + EXPECT_TRUE(result->equalValueAt(golden, row, valueBegin + row)); + } + } } folly::Random::DefaultGenerator rng_; @@ -591,18 +688,18 @@ TEST_P(SpillTest, spillState) { // triggered by batch write. // Test with distinct sort keys. - spillStateTest(kGB, 2, 8, 1, {CompareFlags{true, true}}, 8); - spillStateTest(kGB, 2, 8, 1, {CompareFlags{true, false}}, 8); - spillStateTest(kGB, 2, 8, 1, {CompareFlags{false, true}}, 8); - spillStateTest(kGB, 2, 8, 1, {CompareFlags{false, false}}, 8); - spillStateTest(kGB, 2, 8, 1, {}, 8); + hybridSpillStateTest(kGB, 2, 8, 1, {CompareFlags{true, true}}, 8); + hybridSpillStateTest(kGB, 2, 8, 1, {CompareFlags{true, false}}, 8); + hybridSpillStateTest(kGB, 2, 8, 1, {CompareFlags{false, true}}, 8); + hybridSpillStateTest(kGB, 2, 8, 1, {CompareFlags{false, false}}, 8); + hybridSpillStateTest(kGB, 2, 8, 1, {}, 8); // Test with duplicate sort keys. - spillStateTest(kGB, 2, 8, 8, {CompareFlags{true, true}}, 8); - spillStateTest(kGB, 2, 8, 8, {CompareFlags{true, false}}, 8); - spillStateTest(kGB, 2, 8, 8, {CompareFlags{false, true}}, 8); - spillStateTest(kGB, 2, 8, 8, {CompareFlags{false, false}}, 8); - spillStateTest(kGB, 2, 8, 8, {}, 8); + hybridSpillStateTest(kGB, 2, 8, 8, {CompareFlags{true, true}}, 8); + hybridSpillStateTest(kGB, 2, 8, 8, {CompareFlags{true, false}}, 8); + hybridSpillStateTest(kGB, 2, 8, 8, {CompareFlags{false, true}}, 8); + hybridSpillStateTest(kGB, 2, 8, 8, {CompareFlags{false, false}}, 8); + hybridSpillStateTest(kGB, 2, 8, 8, {}, 8); } TEST_P(SpillTest, spillTimestamp) { @@ -646,11 +743,17 @@ TEST_P(SpillTest, spillTimestamp) { state.testingNonEmptySpilledPartitionIdSet().contains(partitionId)); SpillPartition spillPartition(SpillPartitionId{0}, state.finish(partitionId)); + auto spillConfig = common::SpillConfig(); + spillConfig.numMaxMergeFiles = 2; + spillConfig.readBufferSize = 1 << 20; + spillConfig.writeBufferSize = 1 << 20; + spillConfig.updateAndCheckSpillLimitCb = [](int64_t) {}; + spillConfig.fileCreateConfig = ""; auto merge = - spillPartition.createOrderedReader(1 << 20, pool(), &spillStats_); + spillPartition.createOrderedReader(spillConfig, pool(), &spillStats_); ASSERT_TRUE(merge != nullptr); ASSERT_TRUE( - spillPartition.createOrderedReader(1 << 20, pool(), &spillStats_) == + spillPartition.createOrderedReader(spillConfig, pool(), &spillStats_) == nullptr); for (auto i = 0; i < timeValues.size(); ++i) { auto* stream = merge->next(); @@ -668,18 +771,18 @@ TEST_P(SpillTest, spillStateWithSmallTargetFileSize) { // write. // Test with distinct sort keys. - spillStateTest(1, 2, 8, 1, {CompareFlags{true, true}}, 8 * 2); - spillStateTest(1, 2, 8, 1, {CompareFlags{true, false}}, 8 * 2); - spillStateTest(1, 2, 8, 1, {CompareFlags{false, true}}, 8 * 2); - spillStateTest(1, 2, 8, 1, {CompareFlags{false, false}}, 8 * 2); - spillStateTest(1, 2, 8, 1, {}, 8 * 2); + hybridSpillStateTest(1, 2, 8, 1, {CompareFlags{true, true}}, 8 * 2); + hybridSpillStateTest(1, 2, 8, 1, {CompareFlags{true, false}}, 8 * 2); + hybridSpillStateTest(1, 2, 8, 1, {CompareFlags{false, true}}, 8 * 2); + hybridSpillStateTest(1, 2, 8, 1, {CompareFlags{false, false}}, 8 * 2); + hybridSpillStateTest(1, 2, 8, 1, {}, 8 * 2); // Test with duplicated sort keys. - spillStateTest(1, 2, 8, 8, {CompareFlags{true, false}}, 8 * 2); - spillStateTest(1, 2, 8, 8, {CompareFlags{true, true}}, 8 * 2); - spillStateTest(1, 2, 8, 8, {CompareFlags{false, false}}, 8 * 2); - spillStateTest(1, 2, 8, 8, {CompareFlags{false, true}}, 8 * 2); - spillStateTest(1, 2, 8, 8, {}, 8 * 2); + hybridSpillStateTest(1, 2, 8, 8, {CompareFlags{true, false}}, 8 * 2); + hybridSpillStateTest(1, 2, 8, 8, {CompareFlags{true, true}}, 8 * 2); + hybridSpillStateTest(1, 2, 8, 8, {CompareFlags{false, false}}, 8 * 2); + hybridSpillStateTest(1, 2, 8, 8, {CompareFlags{false, true}}, 8 * 2); + hybridSpillStateTest(1, 2, 8, 8, {}, 8 * 2); } TEST_P(SpillTest, spillPartitionId) { @@ -958,12 +1061,15 @@ TEST_P(SpillTest, spillPartitionFunctionBasic) { std::vector columns; columns.push_back( makeFlatVector(numRows, [](auto row) { return row; })); - columns.push_back(makeFlatVector( - numRows, [](auto row) { return fmt::format("key_{}", row); })); - columns.push_back(makeFlatVector( - numRows, [](auto row) { return fmt::format("key_{}_{}", row, row); })); - columns.push_back(makeFlatVector( - numRows, [](auto row) { return fmt::format("val_{}", row); })); + columns.push_back(makeFlatVector(numRows, [](auto row) { + return fmt::format("key_{}", row); + })); + columns.push_back(makeFlatVector(numRows, [](auto row) { + return fmt::format("key_{}_{}", row, row); + })); + columns.push_back(makeFlatVector(numRows, [](auto row) { + return fmt::format("val_{}", row); + })); inputVectors.push_back(makeRowVector(columns)); } @@ -1522,6 +1628,17 @@ TEST(SpillTest, scopedSpillInjectionRegex) { } } +TEST_P(SpillTest, gatherMerge) { + gatherMergeTest(1234, 2, 10, false); + gatherMergeTest(1234, 2, 100, false); + gatherMergeTest(1234, 10, 10, false); + gatherMergeTest(1234, 10, 100, false); + gatherMergeTest(1234, 2, 10, true); + gatherMergeTest(1234, 2, 100, true); + gatherMergeTest(1234, 10, 10, true); + gatherMergeTest(1234, 10, 100, true); +} + VELOX_INSTANTIATE_TEST_SUITE_P( SpillTestSuite, SpillTest, diff --git a/velox/exec/tests/SpillerBenchmarkBase.cpp b/velox/exec/tests/SpillerBenchmarkBase.cpp index ab94f79c62e3..8eab9d0d7560 100644 --- a/velox/exec/tests/SpillerBenchmarkBase.cpp +++ b/velox/exec/tests/SpillerBenchmarkBase.cpp @@ -14,6 +14,7 @@ * limitations under the License. */ +#include #include #include #include @@ -56,7 +57,7 @@ DEFINE_uint32( "The number of key columns"); DEFINE_uint32( spiller_benchmark_spill_executor_size, - std::thread::hardware_concurrency(), + folly::hardware_concurrency(), "The spiller executor size in number of threads"); DEFINE_uint32( spiller_benchmark_spill_vector_size, diff --git a/velox/exec/tests/SpillerTest.cpp b/velox/exec/tests/SpillerTest.cpp index e9f435495960..ecec2f6f844d 100644 --- a/velox/exec/tests/SpillerTest.cpp +++ b/velox/exec/tests/SpillerTest.cpp @@ -115,7 +115,7 @@ struct TestParam { poolSize, compressionKindToString(compressionKind), std::to_string(enablePrefixSort), - joinTypeName(joinType)); + core::JoinTypeName::toName(joinType)); } }; @@ -331,13 +331,14 @@ class SpillerTest : public exec::test::RowContainerTestBase { bool ascending = true, bool makeError = false, uint64_t readBufferSize = 1 << 20) { - SCOPED_TRACE(fmt::format( - "spillType: {} numDuplicates: {} outputBatchSize: {} ascending: {} makeError: {}", - typeName(type_), - numDuplicates, - outputBatchSize, - ascending, - makeError)); + SCOPED_TRACE( + fmt::format( + "spillType: {} numDuplicates: {} outputBatchSize: {} ascending: {} makeError: {}", + typeName(type_), + numDuplicates, + outputBatchSize, + ascending, + makeError)); constexpr int32_t kNumRows = 5'000; const auto prevGStats = common::globalSpillStats(); @@ -714,11 +715,11 @@ class SpillerTest : public exec::test::RowContainerTestBase { // We make a merge reader that merges the spill files and the rows that // are still in the RowContainer. auto merge = spillPartition->createOrderedReader( - spillConfig_.readBufferSize, pool(), &spillStats_); + spillConfig_, pool(), &spillStats_); ASSERT_TRUE(merge != nullptr); ASSERT_TRUE( spillPartition->createOrderedReader( - spillConfig_.readBufferSize, pool(), &spillStats_) == nullptr); + spillConfig_, pool(), &spillStats_) == nullptr); // We read the spilled data back and check that it matches the sorted // order of the partition. @@ -865,14 +866,15 @@ class SpillerTest : public exec::test::RowContainerTestBase { ss << partitionId.toString() << " "; } ss << "]"; - SCOPED_TRACE(fmt::format( - "Param: {}, numSpillers: {}, numBatchRows: {}, numAppendBatches: {}, targetFileSize: {}, spillPartitionIdSet: {}", - param_.toString(), - numSpillers, - numBatchRows, - numAppendBatches, - targetFileSize, - ss.str())); + SCOPED_TRACE( + fmt::format( + "Param: {}, numSpillers: {}, numBatchRows: {}, numAppendBatches: {}, targetFileSize: {}, spillPartitionIdSet: {}", + param_.toString(), + numSpillers, + numBatchRows, + numAppendBatches, + targetFileSize, + ss.str())); std::vector> inputsByPartition(numPartitions_); @@ -1573,7 +1575,7 @@ TEST_P(AggregationOutputOnly, basic) { ASSERT_EQ(spillPartitionSet.size(), 1); auto spillPartition = std::move(spillPartitionSet.begin()->second); auto merge = spillPartition->createOrderedReader( - spillConfig_.readBufferSize, pool(), &spillStats_); + spillConfig_, pool(), &spillStats_); for (auto i = 0; i < expectedNumSpilledRows; ++i) { auto* stream = merge->next(); @@ -1686,8 +1688,8 @@ TEST_P(SortOutputOnly, basic) { auto spillPartition = std::move(spillPartitionSet.begin()->second); const int expectedNumSpilledRows = numListedRows; - auto merge = spillPartition->createOrderedReader( - spillConfig_.readBufferSize, pool(), &spillStats_); + auto merge = + spillPartition->createOrderedReader(spillConfig_, pool(), &spillStats_); if (expectedNumSpilledRows == 0) { ASSERT_TRUE(merge == nullptr); } else { diff --git a/velox/exec/tests/SplitListenerTest.cpp b/velox/exec/tests/SplitListenerTest.cpp new file mode 100644 index 000000000000..9f7b3a35841e --- /dev/null +++ b/velox/exec/tests/SplitListenerTest.cpp @@ -0,0 +1,231 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/Task.h" +#include "velox/exec/tests/utils/AssertQueryBuilder.h" +#include "velox/exec/tests/utils/HiveConnectorTestBase.h" +#include "velox/exec/tests/utils/PlanBuilder.h" +#include "velox/exec/tests/utils/TempDirectoryPath.h" + +using namespace facebook::velox::exec; + +namespace facebook::velox::exec::test { +namespace { + +std::unordered_map> + numOfSplitsByUuid; + +class CountSplitListener : public SplitListener { + public: + CountSplitListener(const std::string& taskId, const std::string& taskUuid) + : SplitListener(taskId, taskUuid) {} + + void onTaskCompletion() override { + auto& map = numOfSplitsByUuid[taskUuid_]; + for (auto& [planNodeId, count] : counts_) { + map[planNodeId] += count; + } + } + + void onAddSplit( + const core::PlanNodeId& planNodeId, + const exec::Split& /*split*/) override { + counts_[planNodeId]++; + } + + private: + std::unordered_map counts_; +}; + +class CountAgainSplitListener : public SplitListener { + public: + CountAgainSplitListener( + const std::string& taskId, + const std::string& taskUuid) + : SplitListener(taskId, taskUuid) {} + + void onTaskCompletion() override { + auto& map = numOfSplitsByUuid[taskUuid_]; + for (auto& [planNodeId, count] : counts_) { + map[planNodeId] += count; + } + } + + void onAddSplit( + const core::PlanNodeId& planNodeId, + const exec::Split& /*split*/) override { + counts_[planNodeId]++; + } + + private: + std::unordered_map counts_; +}; + +template +class TestSplitListenerFactory : public SplitListenerFactory { + public: + ~TestSplitListenerFactory() override = default; + + std::unique_ptr create( + const std::string& taskId, + const std::string& taskUuid, + const core::QueryConfig& /*config*/) override { + return std::make_unique(taskId, taskUuid); + } +}; + +class DummyListenerFactory : public SplitListenerFactory { + public: + ~DummyListenerFactory() override = default; + + std::unique_ptr create( + const std::string& /*taskId*/, + const std::string& /*taskUuid*/, + const core::QueryConfig& /*config*/) override { + return nullptr; + } +}; + +class SplitListenerTest : public HiveConnectorTestBase { + public: + SplitListenerTest() { + countSplitListenerFactory_ = + std::make_shared>(); + countAgainSplitListenerFactory_ = + std::make_shared>(); + dummyListenerFactory_ = std::make_shared(); + } + + protected: + void makeTable() { + rowType_ = ROW({"c0"}, {BIGINT()}); + RowVectorPtr table = + makeRowVector({"c0"}, {makeFlatVector({1, 2, 3, 4, 5})}); + directory_ = TempDirectoryPath::create(); + auto directoryPath = directory_->getPath(); + auto tablePath = fmt::format("{}/t", directoryPath); + auto fs = filesystems::getFileSystem(tablePath, {}); + fs->mkdir(tablePath); + // Write the table three times to make multiple splits. + for (auto i = 0; i < 3; ++i) { + auto filePath = fmt::format("{}/f{}", tablePath, i); + writeToFile(filePath, table); + filePaths_.emplace_back(filePath); + } + } + + std::vector> getSplits() { + std::vector> splits; + splits.reserve(filePaths_.size()); + for (const auto& filePath : filePaths_) { + splits.emplace_back( + connector::hive::HiveConnectorSplitBuilder(filePath) + .connectorId(kHiveConnectorId) + .fileFormat(dwio::common::FileFormat::DWRF) + .build()); + } + return splits; + } + + RowTypePtr rowType_; + std::shared_ptr directory_; + std::vector filePaths_; + + std::shared_ptr> + countSplitListenerFactory_; + std::shared_ptr> + countAgainSplitListenerFactory_; + std::shared_ptr dummyListenerFactory_; +}; + +} // namespace +} // namespace facebook::velox::exec::test + +namespace facebook::velox::exec::test { + +TEST_F(SplitListenerTest, basic) { + ASSERT_TRUE(exec::registerSplitListenerFactory(countSplitListenerFactory_)); + // Not allowing register the same split listener factory twice. + ASSERT_FALSE(exec::registerSplitListenerFactory(countSplitListenerFactory_)); + ASSERT_TRUE(exec::registerSplitListenerFactory(dummyListenerFactory_)); + + makeTable(); + + auto planNodeIdGenerator = std::make_shared(); + core::PlanNodeId nodeId; + auto plan = PlanBuilder(planNodeIdGenerator) + .tableScan(rowType_) + .capturePlanNodeId(nodeId) + .project({"c0 + 1 as c0"}) + .planNode(); + + auto checkCount = [&](std::shared_ptr& task) { + const auto uuid = task->uuid(); + task = nullptr; + EXPECT_GT(numOfSplitsByUuid.count(uuid), 0); + EXPECT_EQ(numOfSplitsByUuid[uuid].size(), 1); + EXPECT_GT(numOfSplitsByUuid[uuid].count(nodeId), 0); + EXPECT_EQ(numOfSplitsByUuid[uuid][nodeId], 3); + }; + + for (const auto addWithSequence : {false, true}) { + const auto splits = getSplits(); + std::shared_ptr task; + AssertQueryBuilder(plan) + .splits(nodeId, splits) + .addSplitWithSequence(addWithSequence) + .copyResults(pool_.get(), task); + checkCount(task); + } + + ASSERT_TRUE(exec::unregisterSplitListenerFactory(countSplitListenerFactory_)); + ASSERT_TRUE(exec::unregisterSplitListenerFactory(dummyListenerFactory_)); +} + +TEST_F(SplitListenerTest, multipleListeners) { + ASSERT_TRUE(exec::registerSplitListenerFactory(countSplitListenerFactory_)); + ASSERT_TRUE( + exec::registerSplitListenerFactory(countAgainSplitListenerFactory_)); + + makeTable(); + + auto planNodeIdGenerator = std::make_shared(); + core::PlanNodeId nodeId; + auto plan = PlanBuilder(planNodeIdGenerator) + .tableScan(rowType_) + .capturePlanNodeId(nodeId) + .project({"c0 + 1 as c0"}) + .planNode(); + + const auto splits = getSplits(); + std::shared_ptr task; + AssertQueryBuilder(plan) + .splits(nodeId, splits) + .copyResults(pool_.get(), task); + + const auto uuid = task->uuid(); + task = nullptr; + EXPECT_GT(numOfSplitsByUuid.count(uuid), 0); + EXPECT_EQ(numOfSplitsByUuid[uuid].size(), 1); + EXPECT_GT(numOfSplitsByUuid[uuid].count(nodeId), 0); + EXPECT_EQ(numOfSplitsByUuid[uuid][nodeId], 6); + + ASSERT_TRUE(exec::unregisterSplitListenerFactory(countSplitListenerFactory_)); + ASSERT_TRUE( + exec::unregisterSplitListenerFactory(countAgainSplitListenerFactory_)); +} + +} // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/StreamingAggregationTest.cpp b/velox/exec/tests/StreamingAggregationTest.cpp index 224a044feca9..1e23fa58bf8a 100644 --- a/velox/exec/tests/StreamingAggregationTest.cpp +++ b/velox/exec/tests/StreamingAggregationTest.cpp @@ -27,8 +27,14 @@ namespace { using namespace facebook::velox::exec::test; -class StreamingAggregationTest : public HiveConnectorTestBase, - public testing::WithParamInterface { +struct TestParams { + int32_t streamingMinOutputBatchSize; + uint64_t preferredOutputBatchBytes; +}; + +class StreamingAggregationTest + : public HiveConnectorTestBase, + public testing::WithParamInterface { protected: void SetUp() override { HiveConnectorTestBase::SetUp(); @@ -36,7 +42,11 @@ class StreamingAggregationTest : public HiveConnectorTestBase, } int32_t flushRows() { - return GetParam(); + return GetParam().streamingMinOutputBatchSize; + } + + uint64_t preferredOutputBatchBytes() { + return GetParam().preferredOutputBatchBytes; } AssertQueryBuilder& config( @@ -48,7 +58,10 @@ class StreamingAggregationTest : public HiveConnectorTestBase, std::to_string(outputBatchSize)) .config( core::QueryConfig::kStreamingAggregationMinOutputBatchRows, - std::to_string(flushRows())); + std::to_string(flushRows())) + .config( + core::QueryConfig::kPreferredOutputBatchBytes, + std::to_string(preferredOutputBatchBytes())); } void testAggregation( @@ -139,21 +152,9 @@ class StreamingAggregationTest : public HiveConnectorTestBase, void testSortedAggregationWithBarrier( const std::vector& keys, - uint32_t outputBatchSize) { + uint32_t outputBatchSize, + uint32_t expectedNumOuputBatches) { const auto inputVectors = addPayload(keys, 2); - int expectedNumOuputBatchesWithBarrier{0}; - int numInputRows{0}; - for (const auto& inputVector : inputVectors) { - numInputRows += inputVector->size(); - if (outputBatchSize > inputVector->size()) { - ++expectedNumOuputBatchesWithBarrier; - } else { - expectedNumOuputBatchesWithBarrier += - bits::divRoundUp(inputVector->size(), outputBatchSize); - } - } - const int expectedNumOuputBatches = - bits::divRoundUp(numInputRows, outputBatchSize); std::vector> tempFiles; const int numSplits = keys.size(); @@ -168,8 +169,9 @@ class StreamingAggregationTest : public HiveConnectorTestBase, core::PlanNodeId aggregationNodeId; auto plan = PlanBuilder(planNodeIdGenerator) .startTableScan() - .outputType(std::dynamic_pointer_cast( - inputVectors[0]->type())) + .outputType( + std::dynamic_pointer_cast( + inputVectors[0]->type())) .endTableScan() .streamingAggregation( {"c0"}, @@ -201,8 +203,7 @@ class StreamingAggregationTest : public HiveConnectorTestBase, velox::exec::toPlanStats(taskStats) .at(aggregationNodeId) .outputVectors, - barrierExecution ? expectedNumOuputBatchesWithBarrier - : expectedNumOuputBatches); + expectedNumOuputBatches); } } @@ -247,22 +248,9 @@ class StreamingAggregationTest : public HiveConnectorTestBase, void testDistinctAggregationWithBarrier( const std::vector& keys, - uint32_t outputBatchSize) { + uint32_t outputBatchSize, + uint32_t expectedNumOuputBatches) { const auto inputVectors = addPayload(keys, 2); - int expectedNumOuputBatchesWithBarrier{0}; - int numInputRows{0}; - for (const auto& inputVector : inputVectors) { - numInputRows += inputVector->size(); - if (outputBatchSize > inputVector->size()) { - ++expectedNumOuputBatchesWithBarrier; - } else { - expectedNumOuputBatchesWithBarrier += - bits::divRoundUp(inputVector->size(), outputBatchSize); - } - } - const int expectedNumOuputBatches = - bits::divRoundUp(numInputRows, outputBatchSize); - std::vector> tempFiles; const int numSplits = keys.size(); for (int32_t i = 0; i < numSplits; ++i) { @@ -277,8 +265,9 @@ class StreamingAggregationTest : public HiveConnectorTestBase, core::PlanNodeId aggregationNodeId; auto plan = PlanBuilder(planNodeIdGenerator) .startTableScan() - .outputType(std::dynamic_pointer_cast( - inputVectors[0]->type())) + .outputType( + std::dynamic_pointer_cast( + inputVectors[0]->type())) .endTableScan() .streamingAggregation( {"c0"}, @@ -311,8 +300,7 @@ class StreamingAggregationTest : public HiveConnectorTestBase, velox::exec::toPlanStats(taskStats) .at(aggregationNodeId) .outputVectors, - barrierExecution ? expectedNumOuputBatchesWithBarrier - : expectedNumOuputBatches); + expectedNumOuputBatches); } } @@ -321,8 +309,9 @@ class StreamingAggregationTest : public HiveConnectorTestBase, auto plan = PlanBuilder(planNodeIdGenerator) .startTableScan() - .outputType(std::dynamic_pointer_cast( - inputVectors[0]->type())) + .outputType( + std::dynamic_pointer_cast( + inputVectors[0]->type())) .endTableScan() .streamingAggregation( {"c0"}, {}, {}, core::AggregationNode::Step::kSingle, false) @@ -346,8 +335,7 @@ class StreamingAggregationTest : public HiveConnectorTestBase, velox::exec::toPlanStats(taskStats) .at(aggregationNodeId) .outputVectors, - barrierExecution ? expectedNumOuputBatchesWithBarrier - : expectedNumOuputBatches); + expectedNumOuputBatches); } } } @@ -513,19 +501,6 @@ class StreamingAggregationTest : public HiveConnectorTestBase, const std::vector& keys, uint32_t outputBatchSize) { const auto inputVectors = addPayload(keys); - int expectedNumOuputBatchesWithBarrier{0}; - int numInputRows{0}; - for (const auto& inputVector : inputVectors) { - numInputRows += inputVector->size(); - if (outputBatchSize > inputVector->size()) { - ++expectedNumOuputBatchesWithBarrier; - } else { - expectedNumOuputBatchesWithBarrier += - bits::divRoundUp(inputVector->size(), outputBatchSize); - } - } - const int expectedNumOuputBatches = - bits::divRoundUp(numInputRows, outputBatchSize); std::vector> tempFiles; const int numSplits = keys.size(); for (int32_t i = 0; i < numSplits; ++i) { @@ -541,8 +516,9 @@ class StreamingAggregationTest : public HiveConnectorTestBase, auto plan = PlanBuilder(planNodeIdGenerator) .startTableScan() - .outputType(std::dynamic_pointer_cast( - inputVectors[0]->type())) + .outputType( + std::dynamic_pointer_cast( + inputVectors[0]->type())) .endTableScan() .streamingAggregation( keys[0]->type()->asRow().names(), @@ -578,12 +554,6 @@ class StreamingAggregationTest : public HiveConnectorTestBase, const auto taskStats = task->taskStats(); ASSERT_EQ(taskStats.numBarriers, barrierExecution ? numSplits : 0); ASSERT_EQ(taskStats.numFinishedSplits, numSplits); - ASSERT_EQ( - velox::exec::toPlanStats(taskStats) - .at(aggregationNodeId) - .outputVectors, - barrierExecution ? expectedNumOuputBatchesWithBarrier - : expectedNumOuputBatches); EXPECT_EQ(NonPODInt64::constructed, NonPODInt64::destructed); } } @@ -592,8 +562,9 @@ class StreamingAggregationTest : public HiveConnectorTestBase, core::PlanNodeId aggregationNodeId; auto plan = PlanBuilder(planNodeIdGenerator) .startTableScan() - .outputType(std::dynamic_pointer_cast( - inputVectors[0]->type())) + .outputType( + std::dynamic_pointer_cast( + inputVectors[0]->type())) .endTableScan() .streamingAggregation( keys[0]->type()->asRow().names(), @@ -626,12 +597,6 @@ class StreamingAggregationTest : public HiveConnectorTestBase, const auto taskStats = task->taskStats(); ASSERT_EQ(taskStats.numBarriers, barrierExecution ? numSplits : 0); ASSERT_EQ(taskStats.numFinishedSplits, numSplits); - ASSERT_EQ( - velox::exec::toPlanStats(taskStats) - .at(aggregationNodeId) - .outputVectors, - barrierExecution ? expectedNumOuputBatchesWithBarrier - : expectedNumOuputBatches); EXPECT_EQ(NonPODInt64::constructed, NonPODInt64::destructed); } } @@ -641,13 +606,32 @@ class StreamingAggregationTest : public HiveConnectorTestBase, VELOX_INSTANTIATE_TEST_SUITE_P( StreamingAggregationTest, StreamingAggregationTest, - testing::ValuesIn({0, 1, 64, std::numeric_limits::max()}), - [](const testing::TestParamInfo& info) { + testing::Values( + TestParams{0, 1}, + TestParams{0, 1024}, + TestParams{0, std::numeric_limits::max()}, + TestParams{1, 1}, + TestParams{1, 1024}, + TestParams{1, std::numeric_limits::max()}, + TestParams{64, 1}, + TestParams{64, 1024}, + TestParams{64, std::numeric_limits::max()}, + TestParams{std::numeric_limits::max(), 1}, + TestParams{std::numeric_limits::max(), 1024}, + TestParams{ + std::numeric_limits::max(), + std::numeric_limits::max()}), + [](const testing::TestParamInfo& info) { return fmt::format( - "streamingMinOutputBatchSize_{}", - info.param == std::numeric_limits::max() + "streamingMinOutputBatchSize_{}_preferredOutputBatchBytes_{}", + info.param.streamingMinOutputBatchSize == + std::numeric_limits::max() ? "inf" - : std::to_string(info.param)); + : std::to_string(info.param.streamingMinOutputBatchSize), + info.param.preferredOutputBatchBytes == + std::numeric_limits::max() + ? "inf" + : std::to_string(info.param.preferredOutputBatchBytes)); }); TEST_P(StreamingAggregationTest, smallInputBatches) { @@ -779,9 +763,7 @@ TEST_P(StreamingAggregationTest, closeUninitialized) { std::make_shared( BIGINT(), "c0"), std::make_shared( - BIGINT(), - std::vector{}, - "do-not-exist")}, + BIGINT(), "do-not-exist")}, source); }) .partialStreamingAggregation({"c0"}, {"sum(x)"}) @@ -793,8 +775,7 @@ TEST_P(StreamingAggregationTest, closeUninitialized) { } TEST_P(StreamingAggregationTest, sortedAggregations) { - auto size = 1024; - + const auto size = 512; std::vector keys = { makeFlatVector(size, [](auto row) { return row; }), makeFlatVector(size, [size](auto row) { return (size + row); }), @@ -804,7 +785,7 @@ TEST_P(StreamingAggregationTest, sortedAggregations) { 78, [size](auto row) { return (3 * size + row); }), }; - testSortedAggregation(keys, 1024); + testSortedAggregation(keys, 512); testSortedAggregation(keys, 32); } @@ -878,6 +859,189 @@ TEST_P(StreamingAggregationTest, clusteredInput) { } } +TEST_P(StreamingAggregationTest, clusteredInputWithOutputSplit) { + std::vector keysWithOverlap = { + makeNullableFlatVector({1, 1, std::nullopt, 2, 2}), + makeFlatVector({2, 3, 3, 4}), + makeFlatVector({5, 6, 6, 6}), + makeFlatVector({6, 6, 6, 6}), + makeFlatVector({6, 7, 8}), + }; + auto dataWithOverlap = addPayload(keysWithOverlap, 1); + auto planWithOverlap = PlanBuilder() + .values(dataWithOverlap) + .streamingAggregation( + {"c0"}, + {"arbitrary(c1)", "array_agg(c1)"}, + {}, + core::AggregationNode::Step::kSingle, + false) + .planNode(); + const auto expectedWithOverlap = makeRowVector({ + makeNullableFlatVector({1, std::nullopt, 2, 3, 4, 5, 6, 7, 8}), + makeFlatVector({0, 2, 3, 6, 8, 9, 10, 18, 19}), + makeArrayVector( + {{0, 1}, + {2}, + {3, 4, 5}, + {6, 7}, + {8}, + {9}, + {10, 11, 12, 13, 14, 15, 16, 17}, + {18}, + {19}}), + }); + for (auto batchSize : {1, 3, 20}) { + SCOPED_TRACE(fmt::format("batchSize={}", batchSize)); + config(AssertQueryBuilder(planWithOverlap), batchSize) + .assertResults(expectedWithOverlap); + } + + std::vector keysWithoutOverlap = { + makeNullableFlatVector({1, 1, std::nullopt, 2, 2}), + makeFlatVector({3, 3, 4, 4}), + makeFlatVector({5, 6, 6, 7}), + makeFlatVector({8, 8, 9, 9}), + makeFlatVector({10, 11, 12}), + }; + auto dataWithoutOverlap = addPayload(keysWithoutOverlap, 1); + auto planWithoutOverlap = PlanBuilder() + .values(dataWithoutOverlap) + .streamingAggregation( + {"c0"}, + {"arbitrary(c1)", "array_agg(c1)"}, + {}, + core::AggregationNode::Step::kSingle, + false) + .planNode(); + const auto expectedWithoutOverlap = makeRowVector( + {makeNullableFlatVector( + {1, std::nullopt, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}), + makeFlatVector({0, 2, 3, 5, 7, 9, 10, 12, 13, 15, 17, 18, 19}), + makeArrayVector( + {{0, 1}, + {2}, + {3, 4}, + {5, 6}, + {7, 8}, + {9}, + {10, 11}, + {12}, + {13, 14}, + {15, 16}, + {17}, + {18}, + {19}})}); + for (auto batchSize : {1, 3, 20}) { + SCOPED_TRACE(fmt::format("batchSize={}", batchSize)); + config(AssertQueryBuilder(planWithoutOverlap), batchSize) + .assertResults(expectedWithoutOverlap); + } + + std::vector mixedKeys = { + makeNullableFlatVector({1, 1, std::nullopt, std::nullopt, 2}), + makeFlatVector({3, 3, 4, 4}), + makeFlatVector({6, 6, 6, 7}), + makeFlatVector({7, 8, 9, 9}), + makeFlatVector({10, 11, 12}), + }; + auto mixedData = addPayload(mixedKeys, 1); + auto mixedPlan = PlanBuilder() + .values(mixedData) + .streamingAggregation( + {"c0"}, + {"arbitrary(c1)", "array_agg(c1)"}, + {}, + core::AggregationNode::Step::kSingle, + false) + .planNode(); + const auto expectedMixedResult = makeRowVector( + {makeNullableFlatVector( + {1, std::nullopt, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12}), + makeFlatVector({0, 2, 4, 5, 7, 9, 12, 14, 15, 17, 18, 19}), + makeArrayVector( + {{0, 1}, + {2, 3}, + {4}, + {5, 6}, + {7, 8}, + {9, 10, 11}, + {12, 13}, + {14}, + {15, 16}, + {17}, + {18}, + {19}})}); + for (auto batchSize : {1, 3, 20}) { + SCOPED_TRACE(fmt::format("batchSize={}", batchSize)); + config(AssertQueryBuilder(mixedPlan), batchSize) + .assertResults(expectedMixedResult); + } +} + +TEST_P(StreamingAggregationTest, clusteredInputWithNulls) { + std::vector keyVectors = { + makeFlatVector({1, 1, 1, 2, 2, 2, 3, 3, 3, 3}), + makeFlatVector({4, 4, 4, 4, 5, 5, 5, 5, 6, 6}), + makeFlatVector({7, 7, 7, 8}), + makeFlatVector({8, 8, 8, 9, 9, 9, 10, 10, 10}), + makeFlatVector({11, 11, 11}), + }; + std::vector dataVectors = { + makeRowVector( + {makeFlatVector({1, 1, 1, 2, 2, 2, 3, 3, 3, 3}), + makeFlatVector({1, 1, 1, 2, 2, 2, 3, 3, 3, 3})}, + [](auto row) { return row < 3; }), + makeRowVector( + {makeFlatVector({4, 4, 4, 4, 5, 5, 5, 5, 6, 6}), + makeFlatVector({4, 4, 4, 4, 5, 5, 5, 5, 6, 6})}, + [](auto row) { return row < 4 || row > 7; }), + + makeRowVector( + {makeFlatVector({7, 7, 7, 8}), + makeFlatVector({7, 7, 7, 8})}, + [](auto row) { return row > 2; }), + + makeRowVector( + {makeFlatVector({8, 8, 8, 9, 9, 9, 10, 10, 10}), + makeFlatVector({8, 8, 8, 9, 9, 9, 10, 10, 10})}, + [](auto row) { return row < 3; }), + + makeRowVector( + {makeFlatVector({11, 11, 11}), + makeFlatVector({11, 11, 11})}, + [](auto /*unused*/) { return true; })}; + ASSERT_EQ(keyVectors.size(), dataVectors.size()); + std::vector rowVectors; + for (int i = 0; i < keyVectors.size(); ++i) { + rowVectors.emplace_back(makeRowVector({keyVectors[i], dataVectors[i]})); + } + + const auto plan = + PlanBuilder() + .values(rowVectors) + .partialStreamingAggregation({"c0"}, {"count(c1)", "arbitrary(c1)"}) + .finalAggregation() + .planNode(); + + const auto expected = makeRowVector( + {makeNullableFlatVector({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}), + makeFlatVector({0, 3, 4, 0, 4, 0, 3, 0, 3, 3, 0}), + makeRowVector( + {makeFlatVector({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}), + makeFlatVector({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11})}, + [](auto row) { + if (row == 0 || row == 3 || row == 5 || row == 7 || row == 10) { + return true; + } + return false; + })}); + for (auto batchSize : {20}) { + SCOPED_TRACE(fmt::format("batchSize={}", batchSize)); + config(AssertQueryBuilder(plan), batchSize).assertResults(expected); + } +} + TEST_P(StreamingAggregationTest, sortedAggregationsWithBarrier) { const auto size = 1024; const std::vector keys = { @@ -889,8 +1053,8 @@ TEST_P(StreamingAggregationTest, sortedAggregationsWithBarrier) { 78, [size](auto row) { return (3 * size + row); }), }; - testSortedAggregationWithBarrier(keys, 1024); - testSortedAggregationWithBarrier(keys, 32); + testSortedAggregationWithBarrier(keys, 1024, 4); + testSortedAggregationWithBarrier(keys, 32, 4); } TEST_P(StreamingAggregationTest, clusteredInputWithBarrier) { @@ -909,16 +1073,17 @@ TEST_P(StreamingAggregationTest, clusteredInputWithBarrier) { auto planNodeIdGenerator = std::make_shared(); core::PlanNodeId streamingAggregationNodeId; - auto plan = PlanBuilder(planNodeIdGenerator) - .startTableScan() - .outputType(std::dynamic_pointer_cast( - inputVectors[0]->type())) - .endTableScan() - .partialStreamingAggregation( - {"c0"}, {"count(c1)", "arbitrary(c1)", "array_agg(c1)"}) - .capturePlanNodeId(streamingAggregationNodeId) - .finalAggregation() - .planNode(); + auto plan = + PlanBuilder(planNodeIdGenerator) + .startTableScan() + .outputType( + std::dynamic_pointer_cast(inputVectors[0]->type())) + .endTableScan() + .partialStreamingAggregation( + {"c0"}, {"count(c1)", "arbitrary(c1)", "array_agg(c1)"}) + .capturePlanNodeId(streamingAggregationNodeId) + .finalAggregation() + .planNode(); const auto expected = makeRowVector( {makeNullableFlatVector( {1, 2, std::nullopt, 3, 4, 9, 10, 11, 12, 17, 18, 19}), @@ -929,28 +1094,20 @@ TEST_P(StreamingAggregationTest, clusteredInputWithBarrier) { struct { int batchSize; bool barrierExecution; + int numExpectedOutputBatches; std::string debugString() const { return fmt::format( - "batchSize={}, barrierExecution={}", batchSize, barrierExecution); + "batchSize={}, barrierExecution={}, numExpectedOutputBatches={}", + batchSize, + barrierExecution, + numExpectedOutputBatches); } - } testSettings[] = {{3, true}, {3, false}, {20, true}, {20, false}}; + } testSettings[] = { + {3, true, 3}, {3, false, 3}, {20, true, 3}, {20, false, 1}}; for (const auto& testData : testSettings) { SCOPED_TRACE(testData.debugString()); - int expectedNumOuputBatchesWithBarrier{0}; - int numInputRows{0}; - for (const auto& inputVector : inputVectors) { - numInputRows += inputVector->size(); - if (testData.batchSize > inputVector->size()) { - ++expectedNumOuputBatchesWithBarrier; - } else { - expectedNumOuputBatchesWithBarrier += - bits::divRoundUp(inputVector->size(), testData.batchSize); - } - } - const int expectedNumOuputBatches = - bits::divRoundUp(numInputRows, testData.batchSize); auto task = AssertQueryBuilder(plan, duckDbQueryRunner_) .splits(makeHiveConnectorSplits(tempFiles)) @@ -967,8 +1124,7 @@ TEST_P(StreamingAggregationTest, clusteredInputWithBarrier) { velox::exec::toPlanStats(taskStats) .at(streamingAggregationNodeId) .outputVectors, - testData.barrierExecution ? expectedNumOuputBatchesWithBarrier - : expectedNumOuputBatches); + testData.numExpectedOutputBatches); } } @@ -983,8 +1139,8 @@ TEST_P(StreamingAggregationTest, distinctAggregationsWithBarrier) { 78, [size](auto row) { return (3 * size + row); }), }; - testDistinctAggregationWithBarrier(keys, 1024); - testDistinctAggregationWithBarrier(keys, 32); + testDistinctAggregationWithBarrier(keys, 1024, 4); + testDistinctAggregationWithBarrier(keys, 32, 4); std::vector multiKeys = { makeRowVector({ @@ -1003,5 +1159,287 @@ TEST_P(StreamingAggregationTest, distinctAggregationsWithBarrier) { testMultiKeyDistinctAggregationWithBarrier(multiKeys, 1024); testMultiKeyDistinctAggregationWithBarrier(multiKeys, 3); } + +TEST_P(StreamingAggregationTest, constantInput) { + auto data = makeRowVector({makeFlatVector({1, 1, 2, 2})}); + auto plan = PlanBuilder() + .values({data}) + .partialStreamingAggregation({"c0"}, {"array_agg(3)"}) + .finalAggregation() + .planNode(); + auto expected = makeRowVector({ + makeFlatVector({1, 2}), + makeArrayVector({{3, 3}, {3, 3}}), + }); + config(AssertQueryBuilder(plan), 1).assertResults(expected); + config(AssertQueryBuilder(plan), 10).assertResults(expected); +} + +TEST_P(StreamingAggregationTest, preferredOutputBatchBytes) { + // Use grouping keys that span one or more batches. + std::vector keys = { + makeNullableFlatVector({1, 1, std::nullopt, 2, 2}), + makeFlatVector({2, 3, 3, 4}), + makeFlatVector({5, 6, 6, 6}), + makeFlatVector({6, 6, 6, 6}), + makeFlatVector({6, 7, 8}), + }; + + auto data = addPayload(keys, 1); + + auto plan = PlanBuilder() + .values(data) + .partialStreamingAggregation( + {"c0"}, + {"count(1)", + "min(c1)", + "max(c1)", + "sum(c1)", + "sumnonpod(1)", + "sum(cast(NULL as INT))"}) + .finalAggregation() + .planNode(); + + auto results = + config(AssertQueryBuilder(plan), 1024).copyResultBatches(pool_.get()); + + // If streamingMinOutputBatchSize is set to 1, we expect an output batch for: + // {1, NULL}, {2}, {3, 4}, {5}, {6}, {7, 8}. + // Otherwise, we expect the output batches to be determined by + // preferredOutputBatchBytes. + size_t expectedOutputBatches; + if (GetParam().streamingMinOutputBatchSize == 1) { + expectedOutputBatches = 6; + } else if (GetParam().preferredOutputBatchBytes == 1) { + expectedOutputBatches = 5; + } else if (GetParam().preferredOutputBatchBytes == 1024) { + expectedOutputBatches = 2; + } else { + ASSERT_EQ( + GetParam().preferredOutputBatchBytes, + std::numeric_limits::max()); + expectedOutputBatches = 1; + } + + ASSERT_EQ(results.size(), expectedOutputBatches); +} + +// Tests that when noGroupsSpanBatches is set, the number of output batches +// matches the number of input batches when minOutputBatchRows is set to 1. +// When minOutputBatchRows is set to an extremely large value, we expect a +// single output batch. +TEST_F(StreamingAggregationTest, noGroupsSpanBatches) { + // Create input batches where no group spans across batches. + // Each batch has unique grouping keys that don't appear in other batches. + std::vector keys = { + makeFlatVector({1, 1, 2, 2}), + makeFlatVector({3, 3, 4, 4}), + makeFlatVector({5, 5, 6, 6}), + makeFlatVector({7, 7, 8, 8}), + makeFlatVector({9, 9, 10, 10}), + }; + + auto data = addPayload(keys, 1); + createDuckDbTable(data); + + struct { + int32_t minOutputBatchRows; + size_t expectedOutputBatches; + + std::string debugString() const { + return fmt::format( + "minOutputBatchRows={}, expectedOutputBatches={}", + minOutputBatchRows, + expectedOutputBatches); + } + } testSettings[] = { + // When minOutputBatchRows is 1, each input batch produces an output batch + {1, keys.size()}, + // When minOutputBatchRows is very large, all groups are batched together + // into a single output + {std::numeric_limits::max(), 1}, + }; + + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + + auto planNodeIdGenerator = std::make_shared(); + core::PlanNodeId aggregationNodeId; + auto plan = PlanBuilder(planNodeIdGenerator) + .values(data) + .streamingAggregation( + {"c0"}, + {"count(1)", "sum(c1)"}, + {}, + core::AggregationNode::Step::kSingle, + /*ignoreNullKeys=*/false, + /*noGroupsSpanBatches=*/true) + .capturePlanNodeId(aggregationNodeId) + .planNode(); + + auto task = + AssertQueryBuilder(plan, duckDbQueryRunner_) + .config( + core::QueryConfig::kStreamingAggregationMinOutputBatchRows, + std::to_string(testData.minOutputBatchRows)) + .assertResults("SELECT c0, count(1), sum(c1) FROM tmp GROUP BY c0"); + + // Verify the number of output batches. + const auto taskStats = task->taskStats(); + ASSERT_EQ( + velox::exec::toPlanStats(taskStats).at(aggregationNodeId).outputVectors, + testData.expectedOutputBatches); + } +} + +namespace { +class InputSourceNode : public core::PlanNode { + public: + InputSourceNode( + const core::PlanNodeId& id, + int numInitialInputCalls, + int numSkipInputCalls, + core::PlanNodePtr source) + : PlanNode{id}, + numInitialInputCalls_(numInitialInputCalls), + numSkipInputCalls_(numSkipInputCalls), + sources_{std::move(source)} {} + + const RowTypePtr& outputType() const override { + return sources_[0]->outputType(); + } + + const std::vector& sources() const override { + return sources_; + } + + int numInitialInputCalls() const { + return numInitialInputCalls_; + } + + int numSkipInputCalls() const { + return numSkipInputCalls_; + } + + std::string_view name() const override { + return "external blocking node"; + } + + private: + void addDetails(std::stringstream& /* stream */) const override {} + + const int numInitialInputCalls_; + const int numSkipInputCalls_; + const std::vector sources_; +}; + +class InputSourceOperator : public exec::Operator { + public: + InputSourceOperator( + int32_t operatorId, + exec::DriverCtx* driverCtx, + std::shared_ptr node) + : Operator( + driverCtx, + node->outputType(), + operatorId, + node->id(), + "InputSource"), + numInitialInputCalls_(node->numInitialInputCalls()), + numSkipInputCalls_(node->numSkipInputCalls()) {} + + bool needsInput() const override { + if (numInitialInputCalls_-- >= 0) { + return true; + } + if (numSkipInputCalls_-- >= 0) { + return false; + } + return !noMoreInput_; + } + + void addInput(RowVectorPtr input) override { + input_ = std::move(input); + } + + RowVectorPtr getOutput() override { + auto output = std::move(input_); + input_ = nullptr; + return output; + } + + exec::BlockingReason isBlocked(ContinueFuture* future) override { + return exec::BlockingReason::kNotBlocked; + } + + bool isFinished() override { + return noMoreInput_; + } + + private: + mutable int numInitialInputCalls_{0}; + mutable int numSkipInputCalls_{0}; + RowVectorPtr input_; +}; + +class SourceNodeTranslator : public exec::Operator::PlanNodeTranslator { + std::unique_ptr toOperator( + exec::DriverCtx* ctx, + int32_t id, + const core::PlanNodePtr& node) override { + if (auto castedNode = + std::dynamic_pointer_cast(node)) { + return std::make_unique(id, ctx, castedNode); + } + return nullptr; + } +}; +} // namespace + +TEST_P(StreamingAggregationTest, needsInputWhenSplitOutput) { + exec::Operator::registerOperator(std::make_unique()); + const auto size = 32; + const auto numBatches{5}; + std::vector batches; + for (int i = 0; i < numBatches; ++i) { + batches.push_back(makeRowVector( + {makeFlatVector( + size, + [i, size](auto row) { + return row == 0 ? i * size + row - 1 : i * size + row; + }), + makeFlatVector(size, [](auto row) { return row; })})); + } + createDuckDbTable(batches); + + auto planNodeIdGenerator = std::make_shared(); + core::PlanNodeId aggregationNodeId; + auto plan = PlanBuilder(planNodeIdGenerator) + .values(batches) + .streamingAggregation( + {"c0"}, + {"array_agg(c1)"}, + {}, + core::AggregationNode::Step::kSingle, + false) + .addNode([](std::string id, core::PlanNodePtr input) mutable { + return std::make_shared( + id, 2, numBatches - 2, input); + }) + .project({"c0", "a0"}) + .capturePlanNodeId(aggregationNodeId) + .planNode(); + + auto task = + AssertQueryBuilder(plan, duckDbQueryRunner_) + .serialExecution(true) + .config( + core::QueryConfig::kStreamingAggregationMinOutputBatchRows, "1") + .assertResults("SELECT c0, array_agg(c1) FROM tmp GROUP BY c0"); + const auto taskStats = task->taskStats(); + ASSERT_EQ( + velox::exec::toPlanStats(taskStats).at(aggregationNodeId).outputVectors, + 9); +} } // namespace } // namespace facebook::velox::exec diff --git a/velox/exec/tests/TableEvolutionFuzzer.cpp b/velox/exec/tests/TableEvolutionFuzzer.cpp index ac52462b9741..82863458aee2 100644 --- a/velox/exec/tests/TableEvolutionFuzzer.cpp +++ b/velox/exec/tests/TableEvolutionFuzzer.cpp @@ -16,13 +16,45 @@ #include "velox/exec/tests/TableEvolutionFuzzer.h" #include "velox/connectors/hive/HiveConnectorSplit.h" +#include "velox/dwio/common/tests/utils/FilterGenerator.h" +#include "velox/dwio/dwrf/common/Config.h" #include "velox/exec/Cursor.h" #include "velox/exec/tests/utils/PlanBuilder.h" +#include "velox/exec/tests/utils/QueryAssertions.h" #include "velox/exec/tests/utils/TempDirectoryPath.h" +#include "velox/expression/fuzzer/ExpressionFuzzer.h" +#include "velox/functions/FunctionRegistry.h" #include "velox/vector/fuzzer/VectorFuzzer.h" #include +#include + +DEFINE_bool( + enable_oom_injection_write_path, + false, + "When enabled OOMs will randomly be triggered while executing the write path " + "The goal of this mode is to ensure unexpected exceptions " + "aren't thrown and the process isn't killed in the process of cleaning " + "up after failures. Therefore, results are not compared when this is " + "enabled. Note that this option only works in debug builds."); + +DEFINE_bool( + enable_oom_injection_read_path, + false, + "When enabled OOMs will randomly be triggered while executing scan " + "plans. The goal of this mode is to ensure unexpected exceptions " + "aren't thrown and the process isn't killed in the process of cleaning " + "up after failures. Therefore, results are not compared when this is " + "enabled. Note that this option only works in debug builds."); + +DEFINE_int32( + aggregation_pushdown_frequency, + 5, + "Controls the frequency of aggregation pushdown. The aggregation pushdown " + "is enabled with probability 1/N where N is this value. For example, " + "N=5 means 20% chance, N=2 means 50% chance."); + namespace facebook::velox::exec::test { std::ostream& operator<<( @@ -45,6 +77,70 @@ VectorFuzzer::Options makeVectorFuzzerOptions() { return options; } +template +void removeFromVector(std::vector& vec, const T& value) { + auto it = std::find(vec.begin(), vec.end(), value); + if (it != vec.end()) { + vec.erase(it); + } +} + +bool hasUnsupportedMapKey(const TypePtr& type) { + switch (type->kind()) { + case TypeKind::MAP: { + auto mapType = type->asMap(); + // FlatMapColumnWriter only supports TINYINT, SMALLINT, INTEGER, BIGINT, + // VARCHAR, VARBINARY as KeyType + auto keyKind = mapType.keyType()->kind(); + if (keyKind != TypeKind::TINYINT && keyKind != TypeKind::SMALLINT && + keyKind != TypeKind::INTEGER && keyKind != TypeKind::BIGINT && + keyKind != TypeKind::VARCHAR && keyKind != TypeKind::VARBINARY) { + return true; + } + return hasUnsupportedMapKey(mapType.valueType()); + } + case TypeKind::ARRAY: + return hasUnsupportedMapKey(type->asArray().elementType()); + case TypeKind::ROW: { + auto& rowType = type->asRow(); + for (int i = 0; i < rowType.size(); ++i) { + if (hasUnsupportedMapKey(rowType.childAt(i))) { + return true; + } + } + return false; + } + default: + return false; + } +} + +bool hasMapColumns(const RowTypePtr& schema) { + VLOG(1) << "Checking if schema has map columns"; + for (int i = 0; i < schema->size(); ++i) { + if (schema->childAt(i)->isMap()) { + return true; + } + } + return false; +} + +bool hasEmptyElement(const RowVectorPtr& data, int columnIndex) { + auto mapVector = data->childAt(columnIndex)->as(); + if (!mapVector) { + return true; + } + + // Check if any map entry is empty (null or size = 0) + auto sizes = mapVector->sizes(); + for (int j = 0; j < mapVector->size(); ++j) { + if (mapVector->isNullAt(j) || sizes->as()[j] == 0) { + return true; // Found an empty map + } + } + return false; // No empty maps found +} + } // namespace TableEvolutionFuzzer::TableEvolutionFuzzer(const Config& config) @@ -91,23 +187,82 @@ TableEvolutionFuzzer::parseFileFormats(std::string input) { namespace { +// Helper function to randomly select aggregates from available columns +// without replacement. Returns a list of aggregate expressions. +void generateAggregatesForColumns( + const std::vector& availableColumns, + const std::vector& supportedAggFuncs, + const RowTypePtr& schema, + FuzzerGenerator& rng, + std::vector& aggregates) { + if (availableColumns.empty()) { + return; + } + + int numAggregates = std::min( + static_cast(availableColumns.size()), + std::min( + static_cast(5), + static_cast( + folly::Random::rand32(1, availableColumns.size() + 1, rng)))); + + std::unordered_set selectedIndices; + for (int i = 0; i < numAggregates; ++i) { + if (folly::Random::oneIn(2, rng)) { + int randomIdx; + do { + randomIdx = folly::Random::rand32(availableColumns.size(), rng); + } while (selectedIndices.count(randomIdx) > 0); + selectedIndices.insert(randomIdx); + + int colIdx = availableColumns[randomIdx]; + std::string aggFunc = supportedAggFuncs[folly::Random::rand32( + supportedAggFuncs.size(), rng)]; + aggregates.push_back( + fmt::format("{}({})", aggFunc, schema->nameOf(colIdx))); + } + } +} + std::vector> runTaskCursors( - const std::vector>& cursors, + const std::vector>& cursors, folly::Executor& executor) { std::vector>> futures; for (int i = 0; i < cursors.size(); ++i) { auto [promise, future] = folly::makePromiseContract>(); futures.push_back(std::move(future)); - executor.add([&, i, promise = std::move(promise)]() mutable { + auto cursorPtr = cursors[i]; + auto task = cursorPtr->task(); + executor.add([cursorPtr, task, promise = std::move(promise)]() mutable { std::vector results; try { - while (cursors[i]->moveNext()) { - auto& result = cursors[i]->current(); + while (cursorPtr->moveNext()) { + auto& result = cursorPtr->current(); result->loadedVector(); results.push_back(std::move(result)); } promise.setValue(std::move(results)); + } catch (VeloxRuntimeError& e) { + if (FLAGS_enable_oom_injection_write_path && + e.errorCode() == facebook::velox::error_code::kMemCapExceeded && + e.message() == ScopedOOMInjector::kErrorMessage) { + // If we enabled OOM injection we expect the exception thrown by the + // ScopedOOMInjector. + LOG(INFO) << "OOM injection triggered in write path: " << e.what(); + promise.setValue(std::move(results)); + } else if ( + FLAGS_enable_oom_injection_read_path && + e.errorCode() == facebook::velox::error_code::kMemCapExceeded && + e.message() == ScopedOOMInjector::kErrorMessage) { + // If we enabled OOM injection we expect the exception thrown by the + // ScopedOOMInjector. + LOG(INFO) << "OOM injection triggered in read path: " << e.what(); + promise.setValue(std::move(results)); + } else { + LOG(ERROR) << e.what(); + promise.setException(e); + } } catch (const std::exception& e) { LOG(ERROR) << e.what(); promise.setException(e); @@ -116,12 +271,12 @@ std::vector> runTaskCursors( } std::vector> results; constexpr std::chrono::seconds kTaskTimeout(10); + results.reserve(futures.size()); for (auto& future : futures) { results.push_back(std::move(future).get(kTaskTimeout)); } return results; } - // `tableBucketCount' is the bucket count of current table setup when reading. // `partitionBucketCount' is the bucket count when the partition was written. // `tableBucketCount' must be a multiple of `partitionBucketCount'. @@ -134,11 +289,14 @@ void buildScanSplitFromTableWriteResult( dwio::common::FileFormat fileFormat, const std::vector& writeResult, std::vector& splits) { + if (FLAGS_enable_oom_injection_write_path) { + return; + } VELOX_CHECK_EQ(writeResult.size(), 1); auto* fragments = writeResult[0]->childAt(1)->asChecked>(); for (int i = 1; i < writeResult[0]->size(); ++i) { - auto fragment = folly::parseJson(fragments->valueAt(i)); + auto fragment = folly::parseJson(std::string_view(fragments->valueAt(i))); auto fileName = fragment["fileWriteInfos"][0]["writeFileName"].asString(); auto hiveSplit = std::make_shared( TableEvolutionFuzzer::connectorId(), @@ -191,8 +349,18 @@ void checkResultsEqual( expectedRowIndex = 0; continue; } - VELOX_CHECK(actual[actualVectorIndex]->equalValueAt( - expected[expectedVectorIndex].get(), actualRowIndex, expectedRowIndex)); + VELOX_CHECK( + actual[actualVectorIndex]->equalValueAt( + expected[expectedVectorIndex].get(), + actualRowIndex, + expectedRowIndex), + "actualVectorIndex={} actualRowIndex={} expectedVectorIndex={} expectedRowIndex={}\nactual={}\nexpected={}", + actualVectorIndex, + actualRowIndex, + expectedVectorIndex, + expectedRowIndex, + actual[actualVectorIndex]->toString(actualRowIndex), + expected[expectedVectorIndex]->toString(expectedRowIndex)); ++actualRowIndex; ++expectedRowIndex; } @@ -210,6 +378,169 @@ void checkResultsEqual( VELOX_CHECK_EQ(expectedVectorIndex, expected.size()); } +common::SubfieldFilters generateSubfieldFilters( + RowTypePtr& rowType, + const RowVectorPtr& finalExpectedData) { + dwio::common::MutationSpec mutations; + std::vector hitRows; + + std::unique_ptr filterGenerator = + std::make_unique(rowType, 0); + + auto subfieldsVector = filterGenerator->makeFilterables(rowType->size(), 100); + + const auto& filterSpecs = + filterGenerator->makeRandomSpecs(subfieldsVector, 100); + + return filterGenerator->makeSubfieldFilters( + filterSpecs, {finalExpectedData}, &mutations, hitRows); +} + +fuzzer::ExpressionFuzzer::FuzzedExpressionData generateRemainingFilters( + const TableEvolutionFuzzer::Config& config, + unsigned currentSeed) { + // Use ExpressionFuzzer to generate complex expressions, but use the actual + // data types from finalExpectedData + // Configure ExpressionFuzzer to generate simpler expressions suitable for + // filters + fuzzer::ExpressionFuzzer::Options options; + options.enableComplexTypes = false; // Disable complex types to avoid issues + options.enableDecimalType = false; // Disable decimal types + options.maxLevelOfNesting = 3; // Reduce nesting to avoid complexity + options.nullRatio = 0.0; // No null values to avoid type resolution issues + // Only use simple comparison and logical functions suitable for filters + options.useOnlyFunctions = "eq,neq,lt,lte,gt,gte,and,or,not"; + options.specialForms = "and,or"; // Only simple special forms + + // Skip complex functions that generate unparseable expressions + options.skipFunctions = { + "regexp_like", + "regexp_extract", + "replace", + "replace_first", + "json_format", + "json_extract", + "json_parse", + "from_utf8", + "to_utf8", + "reverse", + "upper", + "lower", + "st_coorddim", + "is_null", + "is_not_null"}; + + auto signatureMap = getVectorFunctionSignatures(); + + // Configure VectorFuzzer to avoid null values and use the actual data types + VectorFuzzer::Options vectorFuzzerOptions; + vectorFuzzerOptions.nullRatio = 0.0; // No nulls + vectorFuzzerOptions.vectorSize = 100; + auto vectorFuzzer = + std::make_shared(vectorFuzzerOptions, config.pool); + + fuzzer::ExpressionFuzzer expressionFuzzer( + signatureMap, currentSeed, vectorFuzzer, options); + + return expressionFuzzer.fuzzExpressions(1); +} + +// Generate random aggregation configuration for pushdown testing. +// Only generates aggregations that are eligible for pushdown: +// - Supported aggregate functions: min, max, bool_and, bool_or +// - Each column can only be used by at most one aggregate +// - Grouping keys are optional (can be empty for global aggregation) +// - Columns with filters (subfield or remaining) are excluded to enable +// pushdown +std::optional generateAggregationConfig( + const RowTypePtr& schema, + FuzzerGenerator& rng, + const std::unordered_set& filteredColumns) { + // List of aggregate functions that support pushdown + // Note: Excluding 'sum' to avoid integer overflow in fuzzer with random data + static const std::vector supportedNumericAggs = {"min", "max"}; + static const std::vector supportedBooleanAggs = { + "bool_and", "bool_or"}; + static const std::vector supportedIntegerAggs = { + "bitwise_and_agg", "bitwise_or_agg", "bitwise_xor_agg"}; + + // Randomly decide number of grouping keys (0 to 2) + int numGroupingKeys = folly::Random::rand32(3, rng); + std::vector groupingKeys; + std::unordered_set usedColumnIndices; + + // Select random columns for grouping keys + for (int i = 0; i < numGroupingKeys && i < schema->size(); ++i) { + int colIdx = folly::Random::rand32(schema->size(), rng); + if (usedColumnIndices.count(colIdx) == 0) { + groupingKeys.push_back(schema->nameOf(colIdx)); + usedColumnIndices.insert(colIdx); + } + } + + // Generate aggregates on remaining columns + // For aggregation pushdown to work, each column should only be used once + // and columns with filters should be excluded + std::vector aggregates; + std::vector availableNumericColumns; + std::vector availableIntegerColumns; + std::vector availableBooleanColumns; + for (int i = 0; i < schema->size(); ++i) { + if (usedColumnIndices.count(i) == 0) { + auto columnName = schema->nameOf(i); + // Skip columns that have filters (subfield or remaining) + if (filteredColumns.count(columnName) > 0) { + continue; + } + + auto type = schema->childAt(i); + // Integer types: randomly choose between min/max or bitwise aggregations + // Note: Exclude DATE/Interval type as it doesn't support bitwise + // aggregations + if ((type->isInteger() || type->isBigint() || type->isSmallint() || + type->isTinyint()) && + !type->isDate() && !type->isIntervalDayTime() && + !type->isIntervalYearMonth()) { + if (folly::Random::oneIn(2, rng)) { + availableIntegerColumns.push_back(i); + } else { + availableNumericColumns.push_back(i); + } + } + // Float types support min/max only + else if ((type->isReal() || type->isDouble()) && !type->isDecimal()) { + availableNumericColumns.push_back(i); + } + // Boolean types support bool_and/bool_or + else if (type->isBoolean()) { + availableBooleanColumns.push_back(i); + } + } + } + + // Need at least one column to aggregate + if (availableNumericColumns.empty() && availableBooleanColumns.empty() && + availableIntegerColumns.empty()) { + return std::nullopt; + } + + // Randomly pick columns for aggregates without replacement + generateAggregatesForColumns( + availableNumericColumns, supportedNumericAggs, schema, rng, aggregates); + generateAggregatesForColumns( + availableBooleanColumns, supportedBooleanAggs, schema, rng, aggregates); + generateAggregatesForColumns( + availableIntegerColumns, supportedIntegerAggs, schema, rng, aggregates); + + if (aggregates.empty()) { + return std::nullopt; + } + + return AggregationConfig{ + .groupingKeys = std::move(groupingKeys), + .aggregates = std::move(aggregates)}; +} + } // namespace VectorPtr TableEvolutionFuzzer::liftToType( @@ -302,8 +633,9 @@ VectorPtr TableEvolutionFuzzer::liftToType( if (i < children.size()) { children[i] = liftToType(children[i], childType); } else { - children.push_back(BaseVector::createNullConstant( - childType, row->size(), config_.pool)); + children.push_back( + BaseVector::createNullConstant( + childType, row->size(), config_.pool)); } } return std::make_shared( @@ -315,42 +647,84 @@ VectorPtr TableEvolutionFuzzer::liftToType( } void TableEvolutionFuzzer::run() { - std::vector bucketColumnIndices; - for (int i = 0; i < config_.columnCount; ++i) { - if (folly::Random::oneIn(2 * config_.columnCount, rng_)) { - bucketColumnIndices.push_back(i); - } + ScopedOOMInjector oomInjectorWritePath( + [this]() -> bool { return folly::Random::oneIn(10, rng_); }, + 10); // Check the condition every 10 ms. + if (FLAGS_enable_oom_injection_write_path) { + oomInjectorWritePath.enable(); } - VLOG(1) << "bucketColumnIndices: [" << folly::join(", ", bucketColumnIndices) - << "]"; - auto testSetups = makeSetups(bucketColumnIndices); - auto tableOutputRootDir = TempDirectoryPath::create(); - std::vector> writeTasks( - 2 * config_.evolutionCount - 1); - for (int i = 0; i < config_.evolutionCount; ++i) { - auto data = vectorFuzzer_.fuzzRow(testSetups[i].schema, kVectorSize, false); - for (auto& child : data->children()) { - BaseVector::flattenVector(child); + + // Step 1: Randomly decide whether to generate remaining filters (50% chance) + bool shouldGenerateRemainingFilters = folly::Random::oneIn(2, rng_); + + fuzzer::ExpressionFuzzer::FuzzedExpressionData generatedRemainingFilters; + std::vector additionalColumnNames; + std::vector additionalColumnTypes; + + if (shouldGenerateRemainingFilters) { + // Generate remaining filters and extract new columns + generatedRemainingFilters = generateRemainingFilters(config_, currentSeed_); + + VLOG(1) << "Generated remaining filters from expression fuzzer: " + << generatedRemainingFilters.expressions[0]->toString(); + + // Extract all columns from generatedRemainingFilters.inputType + if (generatedRemainingFilters.inputType) { + for (int i = 0; i < generatedRemainingFilters.inputType->size(); ++i) { + const auto& columnName = generatedRemainingFilters.inputType->nameOf(i); + additionalColumnNames.push_back(columnName); + additionalColumnTypes.push_back( + generatedRemainingFilters.inputType->childAt(i)); + } } - auto actualDir = - fmt::format("{}/actual_{}", tableOutputRootDir->getPath(), i); - VELOX_CHECK(std::filesystem::create_directory(actualDir)); - writeTasks[2 * i] = - makeWriteTask(testSetups[i], data, actualDir, bucketColumnIndices); - if (i == config_.evolutionCount - 1) { - continue; + + if (!additionalColumnNames.empty()) { + VLOG(1) + << "Found " << additionalColumnNames.size() + << " columns from generateRemainingFilters, will add to schema evolution"; } - auto expectedDir = - fmt::format("{}/expected_{}", tableOutputRootDir->getPath(), i); - VELOX_CHECK(std::filesystem::create_directory(expectedDir)); - auto expectedData = std::static_pointer_cast( - liftToType(data, testSetups.back().schema)); - writeTasks[2 * i + 1] = makeWriteTask( - testSetups.back(), expectedData, expectedDir, bucketColumnIndices); + } else { + VLOG(1) << "Skipping remaining filter generation (50% randomization)"; } + + // Step 2: Test setup and bucketColumnIndices generation with additional + // columns + auto bucketColumnIndices = generateBucketColumnIndices(); + + // Track column name mappings during evolution + std::unordered_map columnNameMapping; + for (const auto& columnName : additionalColumnNames) { + columnNameMapping[columnName] = columnName; // Initially map to itself + } + + auto testSetups = makeSetups( + bucketColumnIndices, + additionalColumnNames, + additionalColumnTypes, + &columnNameMapping); + + // Step 3: Create and execute write tasks + auto tableOutputRootDir = TempDirectoryPath::create(); + std::vector> writeTasks( + 2 * config_.evolutionCount - 1); + RowVectorPtr finalExpectedData; + + folly::F14FastMap> globalMapColumnKeys; + std::vector globallyConsistentColumnIndexVector; + + createWriteTasks( + testSetups, + bucketColumnIndices, + tableOutputRootDir->getPath(), + writeTasks, + finalExpectedData, + globalMapColumnKeys, + globallyConsistentColumnIndexVector); + auto executor = folly::getGlobalCPUExecutor(); auto writeResults = runTaskCursors(writeTasks, *executor); + // Step 4: Create scan splits from write results std::optional selectedBucket; if (!bucketColumnIndices.empty()) { selectedBucket = @@ -358,38 +732,115 @@ void TableEvolutionFuzzer::run() { VLOG(1) << "selectedBucket=" << *selectedBucket; } - std::vector actualSplits, expectedSplits; - for (int i = 0; i < config_.evolutionCount; ++i) { - auto* result = &writeResults[2 * i]; - buildScanSplitFromTableWriteResult( - testSetups.back().schema, - bucketColumnIndices, - selectedBucket, - testSetups.back().bucketCount(), - testSetups[i].bucketCount(), - testSetups[i].fileFormat, - *result, - actualSplits); - if (i < config_.evolutionCount - 1) { - result = &writeResults[2 * i + 1]; + auto [actualSplits, expectedSplits] = createScanSplitsFromWriteResults( + writeResults, + testSetups, + bucketColumnIndices, + selectedBucket, + finalExpectedData); + + // Step 5: Setup scan tasks with filters and optional aggregation pushdown + auto rowType = testSetups.back().schema; + PushdownConfig pushdownConfig; + + // Generate subfield filters first + pushdownConfig.subfieldFiltersMap = + generateSubfieldFilters(rowType, finalExpectedData); + + // Extract field names used by subfield filters to avoid conflicts + std::unordered_set subfieldFilteredFields; + for (const auto& [subfield, filter] : pushdownConfig.subfieldFiltersMap) { + auto fieldName = subfield.toString(); + VLOG(1) << "Raw subfield: " << fieldName << ' ' << filter->toString(); + // Extract the root field name (before any nested access) + const size_t dotPos = fieldName.find('.'); + if (dotPos != std::string::npos) { + fieldName = fieldName.substr(0, dotPos); + VLOG(1) << "Subfield filter targets field: " << fieldName; + } + subfieldFilteredFields.insert(fieldName); + } + + // Apply generated remaining filters with updated column names, avoiding + // conflicts + if (shouldGenerateRemainingFilters) { + // Apply generated remaining filters + applyRemainingFilters( + generatedRemainingFilters, + columnNameMapping, + pushdownConfig, + subfieldFilteredFields); + } + + // Collect all filtered columns (both subfield and remaining filters) + std::unordered_set allFilteredColumns = subfieldFilteredFields; + + // Extract columns from remaining filter if present + if (!pushdownConfig.remainingFilter.empty()) { + for (const auto& name : rowType->names()) { + // Check if the column name appears in the remaining filter + if (pushdownConfig.remainingFilter.find(name) != std::string::npos) { + allFilteredColumns.insert(name); + } } - buildScanSplitFromTableWriteResult( - testSetups.back().schema, - bucketColumnIndices, - selectedBucket, - testSetups.back().bucketCount(), - testSetups.back().bucketCount(), - testSetups.back().fileFormat, - *result, - expectedSplits); } - std::vector> scanTasks(2); - scanTasks[0] = - makeScanTask(testSetups.back().schema, std::move(actualSplits)); - scanTasks[1] = - makeScanTask(testSetups.back().schema, std::move(expectedSplits)); + + // Enable aggregation testing + std::optional aggConfig; + bool shouldTestAggregation = + folly::Random::oneIn(FLAGS_aggregation_pushdown_frequency, rng_); + if (shouldTestAggregation) { + aggConfig = generateAggregationConfig(rowType, rng_, allFilteredColumns); + if (aggConfig.has_value()) { + VLOG(1) << "Testing aggregation pushdown with grouping keys: [" + << folly::join(", ", aggConfig->groupingKeys) + << "] and aggregates: [" + << folly::join(", ", aggConfig->aggregates) << "]"; + } else { + VLOG(1) << "Could not generate valid aggregation configuration"; + aggConfig = std::nullopt; + } + } + + std::vector> scanTasks(2); + // actual: TableScan -> Aggregation (allows pushdown) + pushdownConfig.aggregationConfig = aggConfig; + scanTasks[0] = makeScanTask( + rowType, + std::move(actualSplits), + pushdownConfig, + false, + false, // insertProjectToBlockPushdown + globalMapColumnKeys, + globallyConsistentColumnIndexVector); + + // expected: TableScan -> Project -> Aggregation (blocks pushdown) + // Insert a Project node to prevent aggregation pushdown + pushdownConfig.aggregationConfig = aggConfig; + scanTasks[1] = makeScanTask( + rowType, + std::move(expectedSplits), + pushdownConfig, + true, + true, // insertProjectToBlockPushdown + globalMapColumnKeys, + globallyConsistentColumnIndexVector); + + ScopedOOMInjector oomInjectorReadPath( + [this]() -> bool { return folly::Random::oneIn(10, rng_); }, + 10); // Check the condition every 10 ms. + if (FLAGS_enable_oom_injection_read_path) { + oomInjectorReadPath.enable(); + } + + // Step 6: Execute scan tasks and verify results auto scanResults = runTaskCursors(scanTasks, *executor); - checkResultsEqual(scanResults[0], scanResults[1]); + + // Skip result verification when OOM injection is enabled + if (!FLAGS_enable_oom_injection_write_path && + !FLAGS_enable_oom_injection_read_path) { + checkResultsEqual(scanResults[0], scanResults[1]); + } } int TableEvolutionFuzzer::Setup::bucketCount() const { @@ -416,13 +867,25 @@ TypePtr TableEvolutionFuzzer::makeNewType(int maxDepth) { return vectorFuzzer_.randType(scalarTypes, maxDepth); } -RowTypePtr TableEvolutionFuzzer::makeInitialSchema() { +RowTypePtr TableEvolutionFuzzer::makeInitialSchema( + const std::vector& additionalColumnNames, + const std::vector& additionalColumnTypes) { std::vector names(config_.columnCount); std::vector types(config_.columnCount); for (int i = 0; i < config_.columnCount; ++i) { names[i] = makeNewName(); types[i] = makeNewType(3); } + + // Add additional columns from generateRemainingFilters + for (int i = 0; i < additionalColumnNames.size(); ++i) { + names.push_back(additionalColumnNames[i]); + types.push_back(additionalColumnTypes[i]); + VLOG(1) << "Adding additional column to initial schema: " + << additionalColumnNames[i] << " of type " + << additionalColumnTypes[i]->toString(); + } + return ROW(std::move(names), std::move(types)); } @@ -449,6 +912,10 @@ TypePtr TableEvolutionFuzzer::evolveType(const TypePtr& old) { case TypeKind::SMALLINT: return INTEGER(); case TypeKind::INTEGER: + // Don't evolve DATE type to BIGINT + if (old->isDate()) { + return old; + } return BIGINT(); case TypeKind::REAL: return DOUBLE(); @@ -459,7 +926,8 @@ TypePtr TableEvolutionFuzzer::evolveType(const TypePtr& old) { RowTypePtr TableEvolutionFuzzer::evolveRowType( const RowType& old, - const std::vector& bucketColumnIndices) { + const std::vector& bucketColumnIndices, + std::unordered_map* columnNameMapping) { auto names = old.names(); auto types = old.children(); for (int i = 0, j = 0; i < old.size(); ++i) { @@ -471,7 +939,22 @@ RowTypePtr TableEvolutionFuzzer::evolveRowType( continue; } if (folly::Random::oneIn(4, rng_)) { - names[i] = makeNewName(); + auto oldName = names[i]; + auto newName = makeNewName(); + names[i] = newName; + + // Update column name mapping if provided + if (columnNameMapping) { + // Find if this column was originally from generateRemainingFilters + for (auto& [originalName, currentName] : *columnNameMapping) { + if (currentName == oldName) { + currentName = newName; + VLOG(1) << "Updated column name mapping: " << originalName << " -> " + << newName; + break; + } + } + } } types[i] = evolveType(types[i]); } @@ -483,14 +966,18 @@ RowTypePtr TableEvolutionFuzzer::evolveRowType( } std::vector TableEvolutionFuzzer::makeSetups( - const std::vector& bucketColumnIndices) { + const std::vector& bucketColumnIndices, + const std::vector& additionalColumnNames, + const std::vector& additionalColumnTypes, + std::unordered_map* columnNameMapping) { std::vector setups(config_.evolutionCount); for (int i = 0; i < config_.evolutionCount; ++i) { if (i == 0) { - setups[i].schema = makeInitialSchema(); - } else { setups[i].schema = - evolveRowType(*setups[i - 1].schema, bucketColumnIndices); + makeInitialSchema(additionalColumnNames, additionalColumnTypes); + } else { + setups[i].schema = evolveRowType( + *setups[i - 1].schema, bucketColumnIndices, columnNameMapping); } if (!bucketColumnIndices.empty()) { if (i == 0) { @@ -513,22 +1000,161 @@ std::unique_ptr TableEvolutionFuzzer::makeWriteTask( const Setup& setup, const RowVectorPtr& data, const std::string& outputDir, - const std::vector& bucketColumnIndices) { + const std::vector& bucketColumnIndices, + FuzzerGenerator& rng, + bool enableFlatMap, + folly::F14FastMap>& globalMapColumnKeys, + std::vector& globallyCompatibleFlatmapColumns) { auto builder = PlanBuilder().values({data}); + + // Create serdeParameters using proper dwrf::Config for flatmap configuration + std::unordered_map serdeParameters; + + if (hasMapColumns(setup.schema)) { + // Find all top-level map column indices that support flatmap + std::vector supportedMapColumnIndices; + + for (int i = 0; i < setup.schema->size(); ++i) { + if (setup.schema->childAt(i)->isMap()) { + // Check if this specific map column has any empty elements + if (hasEmptyElement(data, i)) { + removeFromVector(globallyCompatibleFlatmapColumns, i); + continue; + } + + if (!hasUnsupportedMapKey(setup.schema->childAt(i))) { + // %50 chance to enable flatmap for this map column. + if (enableFlatMap && folly::Random::oneIn(2, rng)) { + supportedMapColumnIndices.push_back(static_cast(i)); + VLOG(1) << "Write column " << setup.schema->nameOf(i) + << " as flatmap"; + + // Extract actual keys from the map data and collect directly into + // global set + SelectivityVector allRows(data->childAt(i)->size()); + DecodedVector decodedMap(*data->childAt(i), allRows); + auto* mapVector = decodedMap.base()->asChecked(); + if (mapVector->size() > 0) { + auto keys = mapVector->mapKeys(); + + if (keys) { + // Collect keys directly into the global set + auto& uniqueKeys = globalMapColumnKeys[static_cast(i)]; + + // Iterate through the decoded rows, not the raw mapVector + // indices + for (vector_size_t row = 0; row < data->childAt(i)->size(); + ++row) { + auto decodedIndex = decodedMap.index(row); + if (!decodedMap.isNullAt(row) && + !mapVector->isNullAt(decodedIndex)) { + // Get the map entry for this decoded row + auto mapOffset = mapVector->offsetAt(decodedIndex); + auto mapSize = mapVector->sizeAt(decodedIndex); + + // Process all keys in this map entry + for (vector_size_t keyIdx = 0; keyIdx < mapSize; ++keyIdx) { + auto keyPosition = mapOffset + keyIdx; + if (!keys->isNullAt(keyPosition)) { + std::string keyStr; + if (keys->type()->isVarchar() || + keys->type()->isVarbinary()) { + auto* keyVector = keys->asFlatVector(); + auto keyView = keyVector->valueAt(keyPosition); + keyStr = std::string(keyView); + } else if (keys->type()->isInteger()) { + auto* keyVector = keys->asFlatVector(); + auto keyVal = keyVector->valueAt(keyPosition); + keyStr = std::to_string(keyVal); + } else if (keys->type()->isBigint()) { + auto* keyVector = keys->asFlatVector(); + auto keyVal = keyVector->valueAt(keyPosition); + keyStr = std::to_string(keyVal); + } else if (keys->type()->isSmallint()) { + auto* keyVector = keys->asFlatVector(); + auto keyVal = keyVector->valueAt(keyPosition); + keyStr = std::to_string(keyVal); + } else if (keys->type()->isTinyint()) { + auto* keyVector = keys->asFlatVector(); + auto keyVal = keyVector->valueAt(keyPosition); + keyStr = std::to_string(keyVal); + } else { + // This should not be reached since + // hasUnsupportedMapKey filters out unsupported types + VELOX_UNREACHABLE( + "Unsupported map key type: {}", + keys->type()->toString()); + } + uniqueKeys.insert(keyStr); + } + } + } + } + } + } + } else { + // Remove this column from globallyCompatibleFlatmapColumns + removeFromVector(globallyCompatibleFlatmapColumns, i); + } + } else { + removeFromVector(globallyCompatibleFlatmapColumns, i); + } + } + } + + if (!supportedMapColumnIndices.empty()) { + auto config = std::make_shared(); + config->set(dwrf::Config::FLATTEN_MAP, true); + config->set>( + dwrf::Config::MAP_FLAT_COLS, supportedMapColumnIndices); + + // Convert to serdeParameters + auto configParams = config->toSerdeParams(); + serdeParameters.insert(configParams.begin(), configParams.end()); + } + } + if (bucketColumnIndices.empty()) { - builder.tableWrite(outputDir, setup.fileFormat); + if (!serdeParameters.empty()) { + builder.tableWrite( + outputDir, + /*partitionBy=*/{}, + /*bucketCount=*/0, + /*bucketedBy=*/{}, + /*sortBy=*/{}, + setup.fileFormat, + /*aggregates=*/{}, + /*connectorId=*/PlanBuilder::kHiveDefaultConnectorId, + serdeParameters); + } else { + builder.tableWrite(outputDir, setup.fileFormat); + } } else { std::vector bucketColumnNames; bucketColumnNames.reserve(bucketColumnIndices.size()); for (auto i : bucketColumnIndices) { bucketColumnNames.push_back(setup.schema->nameOf(i)); } - builder.tableWrite( - outputDir, - /*partitionBy=*/{}, - setup.bucketCount(), - bucketColumnNames, - setup.fileFormat); + if (!serdeParameters.empty()) { + builder.tableWrite( + outputDir, + /*partitionBy=*/{}, + setup.bucketCount(), + bucketColumnNames, + /*sortBy=*/{}, + setup.fileFormat, + /*aggregates=*/{}, + /*connectorId=*/PlanBuilder::kHiveDefaultConnectorId, + serdeParameters); + } else { + builder.tableWrite( + outputDir, + /*partitionBy=*/{}, + setup.bucketCount(), + bucketColumnNames, + /*sortBy=*/{}, + setup.fileFormat); + } } CursorParameters params; params.serialExecution = true; @@ -555,19 +1181,101 @@ VectorPtr TableEvolutionFuzzer::liftToPrimitiveType( std::vector({})); } +RowTypePtr TableEvolutionFuzzer::buildFlatmapAsStructSchema( + const RowTypePtr& tableSchema, + const folly::F14FastMap>& + globalMapColumnKeys, + const std::vector& globallyCompatibleFlatmapColumns) { + if (globallyCompatibleFlatmapColumns.empty()) { + return tableSchema; + } + + VLOG(1) << "Setting up struct reading for " + << globallyCompatibleFlatmapColumns.size() + << " flatmap columns with real keys"; + + auto names = tableSchema->names(); + auto types = tableSchema->children(); + + // Filter globalMapColumnKeys to only include globally compatible columns + std::unordered_map> filteredMapColumnKeys; + for (int mapColumnIndex : globallyCompatibleFlatmapColumns) { + if (globalMapColumnKeys.find(mapColumnIndex) != globalMapColumnKeys.end()) { + // Add 50% probability to include this column in filteredMapColumnKeys + if (folly::Random::oneIn(2, rng_)) { + filteredMapColumnKeys[mapColumnIndex] = + globalMapColumnKeys.at(mapColumnIndex); + } + } + } + + // Use the filteredMapColumnKeys for struct reading + for (const auto& [mapColumnIndex, keysSet] : filteredMapColumnKeys) { + // Convert map type to struct type for struct reading + auto finalMapType = types[mapColumnIndex]->asMap(); + auto finalValueType = finalMapType.valueType(); + // Convert F14FastSet to vector for ROW constructor + std::vector keys(keysSet.begin(), keysSet.end()); + // Construct struct schema with real keys from write time + final value + // type + std::vector finalStructFieldTypes(keys.size(), finalValueType); + auto finalStructSchema = ROW(keys, finalStructFieldTypes); + + // Replace the map type with struct type in the schema + types[mapColumnIndex] = finalStructSchema; + } + + // Build new schema using struct reading for flatmap columns + return ROW(names, types); +} + std::unique_ptr TableEvolutionFuzzer::makeScanTask( const RowTypePtr& tableSchema, - std::vector splits) { + std::vector splits, + const PushdownConfig& pushdownConfig, + bool useFiltersAsNode, + bool insertProjectToBlockPushdown, + const folly::F14FastMap>& + globalMapColumnKeys, + const std::vector& globallyCompatibleFlatmapColumns) { + // Build schema for flatmap as struct reading + RowTypePtr newSchemaUsingStructReadingFlatMap = buildFlatmapAsStructSchema( + tableSchema, globalMapColumnKeys, globallyCompatibleFlatmapColumns); + CursorParameters params; params.serialExecution = true; - // TODO: Mix in filter and aggregate pushdowns. - params.planNode = PlanBuilder() - .tableScan( - tableSchema, - /*subfieldFilters=*/{}, - /*remainingFilter=*/"", - tableSchema) - .planNode(); + + auto builder = PlanBuilder() + .filtersAsNode(useFiltersAsNode) + .tableScanWithPushDown( + newSchemaUsingStructReadingFlatMap, // Use struct + // schema for + // flatmap reading + /*pushdownConfig=*/pushdownConfig, + tableSchema, // Original schema as dataColumns + {}); + + // If insertProjectToBlockPushdown is set, insert an identity Project node + // to prevent Driver::mayPushdownAggregation() from allowing pushdown + if (insertProjectToBlockPushdown && + pushdownConfig.aggregationConfig.has_value()) { + // Create identity projection: simply pass through all columns + std::vector projectExprs; + for (const auto& name : newSchemaUsingStructReadingFlatMap->names()) { + projectExprs.push_back(name); + } + builder.project(projectExprs); + } + + // Add aggregation if enabled in pushdown config + if (pushdownConfig.aggregationConfig.has_value()) { + builder.singleAggregation( + pushdownConfig.aggregationConfig->groupingKeys, + pushdownConfig.aggregationConfig->aggregates); + } + + params.planNode = builder.planNode(); + auto cursor = TaskCursor::create(params); for (auto& split : splits) { cursor->task()->addSplit("0", std::move(split)); @@ -576,4 +1284,224 @@ std::unique_ptr TableEvolutionFuzzer::makeScanTask( return cursor; } +std::vector +TableEvolutionFuzzer::generateBucketColumnIndices() { + std::vector bucketColumnIndices; + for (int i = 0; i < config_.columnCount; ++i) { + if (folly::Random::oneIn(2 * config_.columnCount, rng_)) { + bucketColumnIndices.push_back(i); + } + } + VLOG(1) << "bucketColumnIndices: [" << folly::join(", ", bucketColumnIndices) + << "]"; + return bucketColumnIndices; +} + +std::pair, std::vector> +TableEvolutionFuzzer::createScanSplitsFromWriteResults( + const std::vector>& writeResults, + const std::vector& testSetups, + const std::vector& bucketColumnIndices, + std::optional selectedBucket, + const RowVectorPtr& finalExpectedData) { + std::vector actualSplits, expectedSplits; + + for (int i = 0; i < config_.evolutionCount; ++i) { + auto* result = &writeResults[2 * i]; + buildScanSplitFromTableWriteResult( + testSetups.back().schema, + bucketColumnIndices, + selectedBucket, + testSetups.back().bucketCount(), + testSetups[i].bucketCount(), + testSetups[i].fileFormat, + *result, + actualSplits); + + if (i < config_.evolutionCount - 1) { + result = &writeResults[2 * i + 1]; + } + buildScanSplitFromTableWriteResult( + testSetups.back().schema, + bucketColumnIndices, + selectedBucket, + testSetups.back().bucketCount(), + testSetups.back().bucketCount(), + testSetups.back().fileFormat, + *result, + expectedSplits); + } + + return {std::move(actualSplits), std::move(expectedSplits)}; +} + +void TableEvolutionFuzzer::createWriteTasks( + const std::vector& testSetups, + const std::vector& bucketColumnIndices, + const std::string& tableOutputRootDirPath, + std::vector>& writeTasks, + RowVectorPtr& finalExpectedData, + folly::F14FastMap>& globalMapColumnKeys, + std::vector& globallyConsistentColumnIndexVector) { + // Initialize globallyConsistentColumnIndexVector with all map column indices + // from the first schema, then filter out incompatible ones during processing + if (hasMapColumns(testSetups[0].schema)) { + for (int j = 0; j < testSetups[0].schema->size(); ++j) { + if (testSetups[0].schema->childAt(j)->isMap() && + !hasUnsupportedMapKey(testSetups[0].schema->childAt(j))) { + globallyConsistentColumnIndexVector.push_back(j); + } + } + } + + // Generate data and create write tasks in a single loop + for (int i = 0; i < config_.evolutionCount; ++i) { + // Generate fresh data for each evolution step independently + auto data = vectorFuzzer_.fuzzRow(testSetups[i].schema, kVectorSize, false); + for (auto& child : data->children()) { + BaseVector::flattenVector(child); + } + + auto actualDir = fmt::format("{}/actual_{}", tableOutputRootDirPath, i); + VELOX_CHECK(std::filesystem::create_directory(actualDir)); + + // Pass globally consistent columns to restrict flatmap usage + writeTasks[2 * i] = makeWriteTask( + testSetups[i], + data, + actualDir, + bucketColumnIndices, + rng_, + true, + globalMapColumnKeys, + globallyConsistentColumnIndexVector); + + if (i == config_.evolutionCount - 1) { + finalExpectedData = std::move(data); + continue; + } + auto expectedDir = fmt::format("{}/expected_{}", tableOutputRootDirPath, i); + VELOX_CHECK(std::filesystem::create_directory(expectedDir)); + auto expectedData = std::static_pointer_cast( + liftToType(data, testSetups.back().schema)); + + writeTasks[2 * i + 1] = makeWriteTask( + testSetups.back(), + expectedData, + expectedDir, + bucketColumnIndices, + rng_, + true, + globalMapColumnKeys, + globallyConsistentColumnIndexVector); + } +} + +void TableEvolutionFuzzer::applyRemainingFilters( + const fuzzer::ExpressionFuzzer::FuzzedExpressionData& + generatedRemainingFilters, + const std::unordered_map& columnNameMapping, + PushdownConfig& pushownConfig, + const std::unordered_set& subfieldFilteredFields) { + if (generatedRemainingFilters.expressions.empty() || + columnNameMapping.empty()) { + return; + } + + std::vector filterStrings; + for (const auto& expr : generatedRemainingFilters.expressions) { + auto filterString = expr->toString(); + VLOG(1) << "Processing remaining filter expression: " << filterString; + + // First, update column names in the filter string using columnNameMapping + for (const auto& [originalName, currentName] : columnNameMapping) { + // Simple string replacement - this is a basic approach + // In a more robust implementation, we would parse the expression tree + size_t pos = 0; + while ((pos = filterString.find(originalName, pos)) != + std::string::npos) { + // Check if this is a complete word (not part of another identifier) + bool isCompleteWord = true; + if (pos > 0 && + (std::isalnum(filterString[pos - 1]) || + filterString[pos - 1] == '_')) { + isCompleteWord = false; + } + if (pos + originalName.length() < filterString.length() && + (std::isalnum(filterString[pos + originalName.length()]) || + filterString[pos + originalName.length()] == '_')) { + isCompleteWord = false; + } + + if (isCompleteWord) { + filterString.replace(pos, originalName.length(), currentName); + pos += currentName.length(); + } else { + pos += originalName.length(); + } + } + } + + VLOG(1) << "After column name mapping: " << filterString; + + // Now check if this filter expression conflicts with subfield filters + bool hasConflict = false; + for (const auto& subfieldField : subfieldFilteredFields) { + // Check if the filter string contains references to fields that are + // already filtered by subfield filters + size_t pos = 0; + while ((pos = filterString.find(subfieldField, pos)) != + std::string::npos) { + // Check if this is a complete word (not part of another identifier) + bool isCompleteWord = true; + if (pos > 0 && + (std::isalnum(filterString[pos - 1]) || + filterString[pos - 1] == '_')) { + isCompleteWord = false; + } + if (pos + subfieldField.length() < filterString.length() && + (std::isalnum(filterString[pos + subfieldField.length()]) || + filterString[pos + subfieldField.length()] == '_')) { + isCompleteWord = false; + } + + if (isCompleteWord) { + hasConflict = true; + VLOG(1) + << "CONFLICT DETECTED! Skipping remaining filter due to conflict with subfield filter on field: " + << subfieldField << ", filter: " << filterString; + break; + } + pos += subfieldField.length(); + } + if (hasConflict) { + break; + } + } + + // Skip this filter if it conflicts with subfield filters + if (hasConflict) { + VLOG(1) << "Skipping filter due to conflict: " << filterString; + continue; + } + + VLOG(1) << "No conflict detected, proceeding with filter: " << filterString; + + // Fix DATE literal format: convert bare date to DATE literal format + // to prevent DuckDB parser from interpreting it as arithmetic expression + RE2 datePattern(R"(\b(\d{4}-\d{2}-\d{2})\b)"); + RE2::GlobalReplace(&filterString, datePattern, "DATE '\\1'"); + + filterStrings.push_back(filterString); + VLOG(1) << "Updated filter expression: " << filterString; + } + + if (filterStrings.size() == 1) { + pushownConfig.remainingFilter = filterStrings[0]; + } else if (filterStrings.size() > 1) { + pushownConfig.remainingFilter = + "(" + folly::join(") AND (", filterStrings) + ")"; + } +} + } // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/TableEvolutionFuzzer.h b/velox/exec/tests/TableEvolutionFuzzer.h index f1540fc1e793..80f0ff95538d 100644 --- a/velox/exec/tests/TableEvolutionFuzzer.h +++ b/velox/exec/tests/TableEvolutionFuzzer.h @@ -14,6 +14,8 @@ * limitations under the License. */ +#pragma once + #include "velox/exec/Cursor.h" #include "velox/exec/tests/utils/PlanBuilder.h" #include "velox/vector/fuzzer/VectorFuzzer.h" @@ -21,6 +23,8 @@ #include #include #include +#include +#include "velox/expression/fuzzer/ExpressionFuzzer.h" namespace facebook::velox::exec::test { @@ -76,22 +80,35 @@ class TableEvolutionFuzzer { TypePtr makeNewType(int maxDepth); - RowTypePtr makeInitialSchema(); + RowTypePtr makeInitialSchema( + const std::vector& additionalColumnNames = {}, + const std::vector& additionalColumnTypes = {}); TypePtr evolveType(const TypePtr& old); RowTypePtr evolveRowType( const RowType& old, - const std::vector& bucketColumnIndices); + const std::vector& bucketColumnIndices, + std::unordered_map* columnNameMapping = + nullptr); std::vector makeSetups( - const std::vector& bucketColumnIndices); + const std::vector& bucketColumnIndices, + const std::vector& additionalColumnNames = {}, + const std::vector& additionalColumnTypes = {}, + std::unordered_map* columnNameMapping = + nullptr); static std::unique_ptr makeWriteTask( const Setup& setup, const RowVectorPtr& data, const std::string& outputDir, - const std::vector& bucketColumnIndices); + const std::vector& bucketColumnIndices, + FuzzerGenerator& rng, + bool enableFlatMap, + folly::F14FastMap>& + globalMapColumnKeys, + std::vector& globallyCompatibleFlatmapColumns); template VectorPtr liftToPrimitiveType( @@ -102,7 +119,61 @@ class TableEvolutionFuzzer { std::unique_ptr makeScanTask( const RowTypePtr& tableSchema, - std::vector splits); + std::vector splits, + const PushdownConfig& pushdownConfig, + bool useFiltersAsNode, + bool insertProjectToBlockPushdown, + const folly::F14FastMap>& + globalMapColumnKeys = {}, + const std::vector& globallyCompatibleFlatmapColumns = {}); + + /// Builds schema for flatmap as struct reading by converting selected map + /// columns to struct types. + RowTypePtr buildFlatmapAsStructSchema( + const RowTypePtr& tableSchema, + const folly::F14FastMap>& + globalMapColumnKeys, + const std::vector& globallyCompatibleFlatmapColumns); + + /// Randomly generates bucket column indices for partitioning data. + /// Returns a vector of column indices that will be used for bucketing, + /// with each column having a 1/(2*columnCount) probability of being selected. + std::vector generateBucketColumnIndices(); + + /// Creates write tasks for all evolution steps. + /// Generates test data and creates TaskCursor objects for writing data + /// to temporary directories. Populates the writeTasks vector and sets + /// finalExpectedData to the data from the last evolution step. + void createWriteTasks( + const std::vector& testSetups, + const std::vector& bucketColumnIndices, + const std::string& tableOutputRootDirPath, + std::vector>& writeTasks, + RowVectorPtr& finalExpectedData, + folly::F14FastMap>& + globalMapColumnKeys, + std::vector& globallyConsistentColumnIndexVector); + + /// Creates scan splits from write results. + /// Converts the output of write tasks into scan splits that can be used + /// for reading the written data back during the scan phase. + std::pair, std::vector> + createScanSplitsFromWriteResults( + const std::vector>& writeResults, + const std::vector& testSetups, + const std::vector& bucketColumnIndices, + std::optional selectedBucket, + const RowVectorPtr& finalExpectedData); + + /// Applies remaining filters with updated column names. + /// Updates filter expressions to use evolved column names based on the + /// column name mapping tracked during schema evolution. + void applyRemainingFilters( + const fuzzer::ExpressionFuzzer::FuzzedExpressionData& + generatedRemainingFilters, + const std::unordered_map& columnNameMapping, + PushdownConfig& pushdownConfig, + const std::unordered_set& subfieldFilteredFields); const Config config_; VectorFuzzer vectorFuzzer_; diff --git a/velox/exec/tests/TableEvolutionFuzzerTest.cpp b/velox/exec/tests/TableEvolutionFuzzerTest.cpp index 09f672488cc5..250920369ec9 100644 --- a/velox/exec/tests/TableEvolutionFuzzerTest.cpp +++ b/velox/exec/tests/TableEvolutionFuzzerTest.cpp @@ -18,13 +18,16 @@ #include "velox/connectors/hive/HiveConnector.h" #include "velox/dwio/dwrf/RegisterDwrfReader.h" #include "velox/dwio/dwrf/RegisterDwrfWriter.h" +#include "velox/functions/prestosql/aggregates/RegisterAggregateFunctions.h" +#include "velox/functions/prestosql/registration/RegistrationFunctions.h" #include #include #include +#include "velox/parse/TypeResolver.h" DEFINE_uint32(seed, 0, ""); -DEFINE_int32(duration_sec, 30, ""); +DEFINE_int32(table_evolution_fuzzer_duration_sec, 30, ""); DEFINE_int32(column_count, 5, ""); DEFINE_int32(evolution_count, 5, ""); @@ -34,16 +37,12 @@ namespace { void registerFactories(folly::Executor* ioExecutor) { filesystems::registerLocalFileSystem(); - connector::registerConnectorFactory( - std::make_shared()); - auto hiveConnector = - connector::getConnectorFactory( - connector::hive::HiveConnectorFactory::kHiveConnectorName) - ->newConnector( - TableEvolutionFuzzer::connectorId(), - std::make_shared( - std::unordered_map()), - ioExecutor); + connector::hive::HiveConnectorFactory factory; + auto hiveConnector = factory.newConnector( + TableEvolutionFuzzer::connectorId(), + std::make_shared( + std::unordered_map()), + ioExecutor); connector::registerConnector(hiveConnector); dwio::common::registerFileSinks(); dwrf::registerDwrfReaderFactory(); @@ -61,7 +60,8 @@ TEST(TableEvolutionFuzzerTest, run) { exec::test::TableEvolutionFuzzer fuzzer(config); fuzzer.setSeed(FLAGS_seed); const auto startTime = std::chrono::system_clock::now(); - const auto deadline = startTime + std::chrono::seconds(FLAGS_duration_sec); + const auto deadline = startTime + + std::chrono::seconds(FLAGS_table_evolution_fuzzer_duration_sec); for (int i = 0; std::chrono::system_clock::now() < deadline; ++i) { LOG(INFO) << "Starting iteration " << i << ", seed=" << fuzzer.seed(); fuzzer.run(); @@ -84,5 +84,8 @@ int main(int argc, char** argv) { facebook::velox::memory::MemoryManager::Options{}); auto ioExecutor = folly::getGlobalIOExecutor(); facebook::velox::exec::test::registerFactories(ioExecutor.get()); + facebook::velox::functions::prestosql::registerAllScalarFunctions(); + facebook::velox::aggregate::prestosql::registerAllAggregateFunctions(); + facebook::velox::parse::registerTypeResolver(); return RUN_ALL_TESTS(); } diff --git a/velox/exec/tests/TableScanTest.cpp b/velox/exec/tests/TableScanTest.cpp index 1a0e94a97d38..44fd4fc0fe6d 100644 --- a/velox/exec/tests/TableScanTest.cpp +++ b/velox/exec/tests/TableScanTest.cpp @@ -20,11 +20,13 @@ #include #include #include +#include #include "velox/common/base/Fs.h" #include "velox/common/base/tests/GTestUtils.h" #include "velox/common/caching/AsyncDataCache.h" #include "velox/common/caching/tests/CacheTestUtil.h" +#include "velox/common/file/File.h" #include "velox/common/file/tests/FaultyFile.h" #include "velox/common/file/tests/FaultyFileSystem.h" #include "velox/common/memory/MemoryArbitrator.h" @@ -33,20 +35,19 @@ #include "velox/connectors/hive/HiveConnector.h" #include "velox/connectors/hive/HiveDataSource.h" #include "velox/connectors/hive/HivePartitionFunction.h" -#include "velox/dwio/common/CacheInputStream.h" #include "velox/dwio/common/tests/utils/DataFiles.h" +#include "velox/dwio/orc/reader/OrcReader.h" #include "velox/exec/Cursor.h" #include "velox/exec/Exchange.h" -#include "velox/exec/OutputBufferManager.h" #include "velox/exec/PlanNodeStats.h" #include "velox/exec/TableScan.h" #include "velox/exec/tests/utils/AssertQueryBuilder.h" #include "velox/exec/tests/utils/HiveConnectorTestBase.h" -#include "velox/exec/tests/utils/LocalExchangeSource.h" #include "velox/exec/tests/utils/PlanBuilder.h" #include "velox/exec/tests/utils/TableScanTestBase.h" #include "velox/exec/tests/utils/TempDirectoryPath.h" #include "velox/expression/ExprToSubfieldFilter.h" +#include "velox/functions/lib/IsNull.h" #include "velox/type/Timestamp.h" #include "velox/type/Type.h" #include "velox/type/tests/SubfieldFiltersBuilder.h" @@ -55,13 +56,13 @@ using namespace facebook::velox; using namespace facebook::velox::cache; using namespace facebook::velox::connector::hive; using namespace facebook::velox::core; -using namespace facebook::velox::exec; using namespace facebook::velox::common::test; using namespace facebook::velox::exec::test; using namespace facebook::velox::tests::utils; DECLARE_int32(cache_prefetch_min_pct); +namespace facebook::velox::exec { namespace { void verifyCacheStats( const FileHandleCacheStats& cacheStats, @@ -72,9 +73,13 @@ void verifyCacheStats( EXPECT_EQ(cacheStats.numHits, numHits); EXPECT_EQ(cacheStats.numLookups, numLookups); } -} // namespace -class TableScanTest : public TableScanTestBase {}; +class TableScanTest : public TableScanTestBase { + void SetUp() override { + TableScanTestBase::SetUp(); + orc::registerOrcReaderFactory(); + } +}; TEST_F(TableScanTest, allColumns) { auto vectors = makeVectors(10, 1'000); @@ -286,7 +291,7 @@ TEST_F(TableScanTest, partitionKeyAlias) { writeToFile(filePath->getPath(), vectors); createDuckDbTable(vectors); - ColumnHandleMap assignments = { + connector::ColumnHandleMap assignments = { {"a", regularColumn("c0", BIGINT())}, {"ds_alias", partitionKey("ds", VARCHAR())}}; @@ -355,8 +360,8 @@ TEST_F(TableScanTest, timestamp) { op = PlanBuilder(pool_.get()) .startTableScan() .outputType(ROW({"c0", "c1"}, {BIGINT(), TIMESTAMP()})) - .subfieldFilter("c1 is null") .dataColumns(dataColumns) + .subfieldFilter("c1 is null") .endTableScan() .planNode(); assertQuery(op, {filePath}, "SELECT c0, c1 FROM tmp WHERE c1 is null"); @@ -364,8 +369,8 @@ TEST_F(TableScanTest, timestamp) { op = PlanBuilder(pool_.get()) .startTableScan() .outputType(ROW({"c0", "c1"}, {BIGINT(), TIMESTAMP()})) - .subfieldFilter("c1 < '1970-01-01 01:30:00'::TIMESTAMP") .dataColumns(dataColumns) + .subfieldFilter("c1 < '1970-01-01 01:30:00'::TIMESTAMP") .endTableScan() .planNode(); assertQuery( @@ -384,8 +389,8 @@ TEST_F(TableScanTest, timestamp) { op = PlanBuilder(pool_.get()) .startTableScan() .outputType(ROW({"c0"}, {BIGINT()})) - .subfieldFilter("c1 is null") .dataColumns(dataColumns) + .subfieldFilter("c1 is null") .endTableScan() .planNode(); assertQuery(op, {filePath}, "SELECT c0 FROM tmp WHERE c1 is null"); @@ -393,8 +398,8 @@ TEST_F(TableScanTest, timestamp) { op = PlanBuilder(pool_.get()) .startTableScan() .outputType(ROW({"c0"}, {BIGINT()})) - .subfieldFilter("c1 < timestamp'1970-01-01 01:30:00'") .dataColumns(dataColumns) + .subfieldFilter("c1 < timestamp'1970-01-01 01:30:00'") .endTableScan() .planNode(); assertQuery( @@ -520,8 +525,7 @@ TEST_F(TableScanTest, subfieldPruningRowType) { writeToFile(filePath->getPath(), vectors); std::vector requiredSubfields; requiredSubfields.emplace_back("e.c"); - std::unordered_map> - assignments; + connector::ColumnHandleMap assignments; assignments["e"] = std::make_shared( "e", HiveColumnHandle::ColumnType::kRegular, @@ -577,8 +581,7 @@ TEST_F(TableScanTest, subfieldPruningRemainingFilterSubfieldsMissing) { writeToFile(filePath->getPath(), vectors); std::vector requiredSubfields; requiredSubfields.emplace_back("e.c"); - std::unordered_map> - assignments; + connector::ColumnHandleMap assignments; assignments["e"] = std::make_shared( "e", HiveColumnHandle::ColumnType::kRegular, @@ -633,8 +636,7 @@ TEST_F(TableScanTest, subfieldPruningRemainingFilterRootFieldMissing) { auto vectors = makeVectors(10, 1'000, rowType); auto filePath = TempFilePath::create(); writeToFile(filePath->getPath(), vectors); - std::unordered_map> - assignments; + connector::ColumnHandleMap assignments; assignments["d"] = std::make_shared( "d", HiveColumnHandle::ColumnType::kRegular, BIGINT(), BIGINT()); auto op = PlanBuilder() @@ -676,8 +678,7 @@ TEST_F(TableScanTest, subfieldPruningRemainingFilterStruct) { for (int filterColumn = kWholeColumn; filterColumn <= kSubfieldOnly; ++filterColumn) { SCOPED_TRACE(fmt::format("{} {}", outputColumn, filterColumn)); - std::unordered_map> - assignments; + connector::ColumnHandleMap assignments; assignments["d"] = std::make_shared( "d", HiveColumnHandle::ColumnType::kRegular, BIGINT(), BIGINT()); if (outputColumn > kNoOutput) { @@ -762,8 +763,7 @@ TEST_F(TableScanTest, subfieldPruningRemainingFilterMap) { for (int filterColumn = kWholeColumn; filterColumn <= kSubfieldOnly; ++filterColumn) { SCOPED_TRACE(fmt::format("{} {}", outputColumn, filterColumn)); - std::unordered_map> - assignments; + connector::ColumnHandleMap assignments; assignments["a"] = std::make_shared( "a", HiveColumnHandle::ColumnType::kRegular, BIGINT(), BIGINT()); if (outputColumn > kNoOutput) { @@ -780,8 +780,7 @@ TEST_F(TableScanTest, subfieldPruningRemainingFilterMap) { } std::string remainingFilter; if (filterColumn == kWholeColumn) { - remainingFilter = - "coalesce(b, cast(null AS MAP(BIGINT, BIGINT)))[0] == 0"; + remainingFilter = "coalesce(b, map_concat(b, b))[0] == 0"; } else { remainingFilter = "b[0] == 0"; } @@ -862,8 +861,7 @@ TEST_F(TableScanTest, subfieldPruningMapType) { requiredSubfields.emplace_back("c[0]"); requiredSubfields.emplace_back("c[2]"); requiredSubfields.emplace_back("c[4]"); - std::unordered_map> - assignments; + connector::ColumnHandleMap assignments; assignments["c"] = std::make_shared( "c", HiveColumnHandle::ColumnType::kRegular, @@ -946,8 +944,7 @@ TEST_F(TableScanTest, subfieldPruningArrayType) { writeToFile(filePath->getPath(), vectors); std::vector requiredSubfields; requiredSubfields.emplace_back("c[3]"); - std::unordered_map> - assignments; + connector::ColumnHandleMap assignments; assignments["c"] = std::make_shared( "c", HiveColumnHandle::ColumnType::kRegular, @@ -1064,8 +1061,8 @@ TEST_F(TableScanTest, missingColumns) { op = PlanBuilder(pool_.get()) .startTableScan() .outputType(outputType) - .subfieldFilter("c1 <= 100.1") .dataColumns(dataColumns) + .subfieldFilter("c1 <= 100.1") .endTableScan() .planNode(); assertQuery(op, filePaths, "SELECT * FROM tmp WHERE c1 <= 100.1", 0); @@ -1074,8 +1071,8 @@ TEST_F(TableScanTest, missingColumns) { op = PlanBuilder(pool_.get()) .startTableScan() .outputType(outputType) - .subfieldFilter("c1 <= 2000.1") .dataColumns(dataColumns) + .subfieldFilter("c1 <= 2000.1") .endTableScan() .planNode(); @@ -1085,8 +1082,8 @@ TEST_F(TableScanTest, missingColumns) { op = PlanBuilder(pool_.get()) .startTableScan() .outputType(outputTypeC0) - .subfieldFilter("c1 <= 3000.1") .dataColumns(dataColumns) + .subfieldFilter("c1 <= 3000.1") .endTableScan() .planNode(); @@ -1096,8 +1093,8 @@ TEST_F(TableScanTest, missingColumns) { op = PlanBuilder(pool_.get()) .startTableScan() .outputType(ROW({}, {})) - .subfieldFilter("c1 <= 4000.1") .dataColumns(dataColumns) + .subfieldFilter("c1 <= 4000.1") .endTableScan() .singleAggregation({}, {"count(1)"}) .planNode(); @@ -1109,7 +1106,7 @@ TEST_F(TableScanTest, missingColumns) { filters[common::Subfield("c1")] = lessThanOrEqualDouble(1050.0, true); auto tableHandle = std::make_shared( kHiveConnectorId, "tmp", true, std::move(filters), nullptr, dataColumns); - ColumnHandleMap assignments; + connector::ColumnHandleMap assignments; assignments["c0"] = regularColumn("c0", BIGINT()); op = PlanBuilder(pool_.get()) .startTableScan() @@ -1125,8 +1122,8 @@ TEST_F(TableScanTest, missingColumns) { op = PlanBuilder(pool_.get()) .startTableScan() .outputType(ROW({}, {})) - .subfieldFilter("c1 is null") .dataColumns(dataColumns) + .subfieldFilter("c1 is null") .endTableScan() .singleAggregation({}, {"count(1)"}) .planNode(); @@ -1347,6 +1344,18 @@ TEST_F(TableScanTest, batchSize) { EXPECT_GT(opStats.outputPositions / opStats.outputVectors, 1); EXPECT_LT(opStats.outputPositions / opStats.outputVectors, numRows); } + { + SCOPED_TRACE("Projection"); + plan = PlanBuilder().tableScan(ROW({}, {}), {}, "", rowType).planNode(); + auto task = AssertQueryBuilder(plan) + .splits(makeHiveConnectorSplits({filePath})) + .config( + QueryConfig::kPreferredOutputBatchBytes, + std::to_string(1 + numRows / 8)) + .assertResults(makeRowVector(ROW({}, {}), numRows)); + const auto opStats = task->taskStats().pipelineStats[0].operatorStats[0]; + EXPECT_EQ(opStats.outputVectors, 1); + } } // Test that adding the same split with the same sequence id does not cause @@ -1767,6 +1776,33 @@ TEST_F(TableScanTest, preloadEmptySplit) { assertQuery(op, filePaths, "SELECT * FROM tmp", 1); } +TEST_F(TableScanTest, readAsLowerCase) { + auto rowType = + ROW({"Товары", "国Ⅵ", "\uFF21", "\uFF22"}, + {BIGINT(), DOUBLE(), REAL(), INTEGER()}); + auto vectors = makeVectors(10, 1'000, rowType); + auto filePath = TempFilePath::create(); + writeToFile(filePath->getPath(), vectors); + createDuckDbTable(vectors); + + // Test reading table with non-ascii names. + auto op = PlanBuilder() + .tableScan( + ROW({"товары", "国ⅵ", "\uFF41", "\uFF42"}, + {BIGINT(), DOUBLE(), REAL(), INTEGER()})) + .planNode(); + auto split = + exec::test::HiveConnectorSplitBuilder(filePath->getPath()).build(); + + AssertQueryBuilder(op, duckDbQueryRunner_) + .connectorSessionProperty( + kHiveConnectorId, + connector::hive::HiveConfig::kFileColumnNamesReadAsLowerCaseSession, + "true") + .split(split) + .assertResults("SELECT * FROM tmp"); +} + TEST_F(TableScanTest, partitionedTableVarcharKey) { auto rowType = ROW({"c0", "c1"}, {BIGINT(), DOUBLE()}); auto vectors = makeVectors(10, 1'000, rowType); @@ -1855,7 +1891,7 @@ TEST_F(TableScanTest, partitionedTableDateKey) { .partitionKey("pkey", partitionValue) .build(); auto outputType = ROW({"pkey", "c0", "c1"}, {DATE(), BIGINT(), DOUBLE()}); - ColumnHandleMap assignments = { + connector::ColumnHandleMap assignments = { {"pkey", partitionKey("pkey", DATE())}, {"c0", regularColumn("c0", BIGINT())}, {"c1", regularColumn("c1", DOUBLE())}}; @@ -1895,7 +1931,7 @@ TEST_F(TableScanTest, partitionedTableTimestampKey) { .partitionKey("pkey", partitionValue) .build(); - ColumnHandleMap assignments = { + connector::ColumnHandleMap assignments = { {"pkey", partitionKey("pkey", TIMESTAMP())}, {"c0", regularColumn("c0", BIGINT())}, {"c1", regularColumn("c1", DOUBLE())}}; @@ -1933,8 +1969,10 @@ TEST_F(TableScanTest, partitionedTableTimestampKey) { kReadTimestampPartitionValueAsLocalTimeSession, asLocalTime ? "true" : "false") .splits({split}) - .assertResults(fmt::format( - "SELECT {}, * FROM tmp", asLocalTime ? tsValueAsLocal : tsValue)); + .assertResults( + fmt::format( + "SELECT {}, * FROM tmp", + asLocalTime ? tsValueAsLocal : tsValue)); }; expect(true); @@ -1960,9 +1998,10 @@ TEST_F(TableScanTest, partitionedTableTimestampKey) { kReadTimestampPartitionValueAsLocalTimeSession, asLocalTime ? "true" : "false") .splits({split}) - .assertResults(fmt::format( - "SELECT c0, {}, c1 FROM tmp", - asLocalTime ? tsValueAsLocal : tsValue)); + .assertResults( + fmt::format( + "SELECT c0, {}, c1 FROM tmp", + asLocalTime ? tsValueAsLocal : tsValue)); }; expect(true); expect(false); @@ -1987,9 +2026,10 @@ TEST_F(TableScanTest, partitionedTableTimestampKey) { kReadTimestampPartitionValueAsLocalTimeSession, asLocalTime ? "true" : "false") .splits({split}) - .assertResults(fmt::format( - "SELECT c0, c1, {} FROM tmp", - asLocalTime ? tsValueAsLocal : tsValue)); + .assertResults( + fmt::format( + "SELECT c0, c1, {} FROM tmp", + asLocalTime ? tsValueAsLocal : tsValue)); }; expect(true); expect(false); @@ -2014,8 +2054,10 @@ TEST_F(TableScanTest, partitionedTableTimestampKey) { kReadTimestampPartitionValueAsLocalTimeSession, asLocalTime ? "true" : "false") .splits({split}) - .assertResults(fmt::format( - "SELECT {} FROM tmp", asLocalTime ? tsValueAsLocal : tsValue)); + .assertResults( + fmt::format( + "SELECT {} FROM tmp", + asLocalTime ? tsValueAsLocal : tsValue)); }; expect(true); expect(false); @@ -2062,8 +2104,10 @@ TEST_F(TableScanTest, partitionedTableTimestampKey) { kReadTimestampPartitionValueAsLocalTimeSession, asLocalTime ? "true" : "false") .splits({split}) - .assertResults(fmt::format( - "SELECT {}, * FROM tmp", asLocalTime ? tsValueAsLocal : tsValue)); + .assertResults( + fmt::format( + "SELECT {}, * FROM tmp", + asLocalTime ? tsValueAsLocal : tsValue)); }; expect(true); expect(false); @@ -2223,7 +2267,8 @@ TEST_F(TableScanTest, statsBasedSkipping) { // c0 <= -1 -> whole file should be skipped based on stats auto subfieldFilters = singleSubfieldFilter("c0", lessThanOrEqual(-1)); - ColumnHandleMap assignments = {{"c1", regularColumn("c1", INTEGER())}}; + connector::ColumnHandleMap assignments = { + {"c1", regularColumn("c1", INTEGER())}}; auto assertQuery = [&](const std::string& query) { auto tableHandle = makeTableHandle( @@ -2470,7 +2515,7 @@ TEST_F(TableScanTest, statsBasedSkippingWithoutDecompression) { auto assertQuery = [&](const std::string& filter) { auto rowType = asRowType(rowVector->type()); return TableScanTest::assertQuery( - PlanBuilder(pool_.get()).tableScan(rowType, {filter}).planNode(), + PlanBuilder(pool_.get()).tableScan(rowType, {}, {filter}).planNode(), filePaths, "SELECT * FROM tmp WHERE " + filter); }; @@ -3031,6 +3076,7 @@ TEST_F(TableScanTest, bucketConversion) { return splits; }; { + SCOPED_TRACE("Basic"); auto outputType = ROW({"c1"}, {BIGINT()}); auto plan = PlanBuilder().tableScan(outputType, {}, "", schema).planNode(); std::vector c1; @@ -3087,6 +3133,28 @@ TEST_F(TableScanTest, bucketConversion) { auto expected = makeRowVector({"c2", "c1"}, {data, data}); AssertQueryBuilder(plan).splits(makeSplits()).assertResults(expected); } + { + SCOPED_TRACE("Dynamic filters"); + auto outputType = ROW({"c1"}, {BIGINT()}); + auto build = makeRowVector({"cc1"}, {makeFlatVector({2, 3})}); + auto planNodeIdGenerator = std::make_shared(); + core::PlanNodeId scanNodeId; + auto plan = + PlanBuilder(planNodeIdGenerator) + .tableScan(outputType, {}, "", schema) + .capturePlanNodeId(scanNodeId) + .hashJoin( + {"c1"}, + {"cc1"}, + PlanBuilder(planNodeIdGenerator).values({build}).planNode(), + "", + {"c1"}) + .planNode(); + auto expected = makeRowVector({makeConstant(2, 1)}); + AssertQueryBuilder(plan) + .splits(scanNodeId, makeSplits()) + .assertResults(expected); + } } TEST_F(TableScanTest, bucketConversionWithSubfieldPruning) { @@ -3137,6 +3205,52 @@ TEST_F(TableScanTest, bucketConversionWithSubfieldPruning) { ASSERT_EQ(j, result->size()); } +TEST_F(TableScanTest, bucketConversionLazyColumn) { + auto vector = makeRowVector({ + makeFlatVector({1, 3, 5}), + makeFlatVector({4, 4, 6}), + }); + auto schema = asRowType(vector->type()); + auto file = TempFilePath::create(); + writeToFile(file->getPath(), {vector}); + constexpr int kNewNumBuckets = 4; + std::vector> splits; + { + // First split requires bucket conversion. + std::vector> handles; + handles.push_back(makeColumnHandle("c0", INTEGER(), {})); + auto split = makeHiveConnectorSplit(file->getPath()); + split->tableBucketNumber = 1; + split->bucketConversion = {kNewNumBuckets, 2, std::move(handles)}; + splits.push_back(split); + } + // Second split no bucket conversion, and non-empty after filter. + vector = makeRowVector({ + makeFlatVector({1, 3, 5}), + makeFlatVector({4, 5, 6}), + }); + auto file2 = TempFilePath::create(); + writeToFile(file2->getPath(), {vector}); + splits.push_back(makeHiveConnectorSplit(file2->getPath())); + { + // Third split requires bucket conversion, empty after filter. + std::vector> handles; + handles.push_back(makeColumnHandle("c0", INTEGER(), {})); + auto split = makeHiveConnectorSplit(file->getPath()); + split->tableBucketNumber = 3; + split->bucketConversion = {kNewNumBuckets, 2, std::move(handles)}; + splits.push_back(split); + } + auto outputType = ROW({"c1"}, {BIGINT()}); + auto plan = + PlanBuilder().tableScan(outputType, {"c1 = 5"}, "", schema).planNode(); + auto expected = makeRowVector({makeConstant(5, 1)}); + AssertQueryBuilder(plan) + .splits(splits) + .config(core::QueryConfig::kMaxSplitPreloadPerDriver, "0") + .assertResults(expected); +} + TEST_F(TableScanTest, integerNotEqualFilter) { auto rowType = ROW( {"c0", "c1", "c2", "c3"}, {TINYINT(), SMALLINT(), INTEGER(), BIGINT()}); @@ -3417,7 +3531,8 @@ TEST_F(TableScanTest, remainingFilter) { "SELECT * FROM tmp WHERE c1 > c0 AND c0 >= 0"); // Remaining filter uses columns that are not used otherwise. - ColumnHandleMap assignments = {{"c2", regularColumn("c2", DOUBLE())}}; + connector::ColumnHandleMap assignments = { + {"c2", regularColumn("c2", DOUBLE())}}; assertQuery( PlanBuilder(pool_.get()) @@ -4026,7 +4141,8 @@ TEST_F(TableScanTest, interleaveLazyEager) { auto eagerFile = TempFilePath::create(); writeToFile(eagerFile->getPath(), rowsWithNulls); - ColumnHandleMap assignments = {{"c0", regularColumn("c0", column->type())}}; + connector::ColumnHandleMap assignments = { + {"c0", regularColumn("c0", column->type())}}; CursorParameters params; params.planNode = PlanBuilder() .startTableScan() @@ -4208,6 +4324,53 @@ TEST_F(TableScanTest, parallelPrepare) { .copyResults(pool_.get()); } +TEST_F(TableScanTest, parallelPrepareWithSubfieldFilters) { + // Test metadataFilter is correctly transferred during split prefetch. + constexpr int32_t kNumParallel = 100; + auto data = makeRowVector({ + makeFlatVector(100, [](auto row) { return row; }), + makeFlatVector(100, [](auto row) { return row * 2; }), + }); + + auto filePath = TempFilePath::create(); + writeToFile(filePath->getPath(), {data}); + + auto subfieldFilters = SubfieldFiltersBuilder() + .add("c0", greaterThanOrEqual(10)) + .add("c1", lessThan(150)) + .build(); + + auto outputType = ROW({"c0", "c1"}, {INTEGER(), BIGINT()}); + auto remainingFilter = parseExpr("c0 % 2 = 0", outputType); + auto tableHandle = + makeTableHandle(std::move(subfieldFilters), remainingFilter); + auto assignments = allRegularColumns(outputType); + + auto plan = exec::test::PlanBuilder(pool_.get()) + .startTableScan() + .outputType(outputType) + .tableHandle(tableHandle) + .assignments(assignments) + .endTableScan() + .planNode(); + + std::vector splits; + for (auto i = 0; i < kNumParallel; ++i) { + splits.push_back(makeHiveSplit(filePath->getPath())); + } + + auto result = AssertQueryBuilder(plan) + .config( + core::QueryConfig::kMaxSplitPreloadPerDriver, + std::to_string(kNumParallel)) + .splits(splits) + .copyResults(pool_.get()); + + // Verify results: c0 >= 10 AND c1 < 150 AND c0 % 2 = 0 + // So rows [10,12,14,...,74] = 33 rows per split + ASSERT_EQ(result->size(), 33 * kNumParallel); +} + TEST_F(TableScanTest, dictionaryMemo) { constexpr int kSize = 100; const char* baseStrings[] = { @@ -4839,7 +5002,7 @@ TEST_F(TableScanTest, varbinaryPartitionKey) { writeToFile(filePath->getPath(), vectors); createDuckDbTable(vectors); - ColumnHandleMap assignments = { + connector::ColumnHandleMap assignments = { {"a", regularColumn("c0", BIGINT())}, {"ds_alias", partitionKey("ds", VARBINARY())}}; @@ -4896,7 +5059,8 @@ TEST_F(TableScanTest, timestampPartitionKey) { return splits; }; - ColumnHandleMap assignments = {{"t", partitionKey("t", TIMESTAMP())}}; + connector::ColumnHandleMap assignments = { + {"t", partitionKey("t", TIMESTAMP())}}; auto plan = PlanBuilder() .startTableScan() .outputType(ROW({"t"}, {TIMESTAMP()})) @@ -5001,6 +5165,88 @@ TEST_F(TableScanTest, flatMapReadOffset) { .assertResults(expected); } +TEST_F(TableScanTest, flatMapKeyTypeEvolution) { + auto vector = + makeRowVector({makeMapVector({{{1, 2}, {3, 4}}})}); + auto config = std::make_shared(); + config->set(dwrf::Config::FLATTEN_MAP, true); + config->set>(dwrf::Config::MAP_FLAT_COLS, {0}); + auto file = TempFilePath::create(); + writeToFile(file->getPath(), {vector}, config); + auto split = makeHiveConnectorSplit(file->getPath()); + auto schema = ROW({"c0"}, {MAP(BIGINT(), BIGINT())}); + { + SCOPED_TRACE("Read as map"); + auto plan = PlanBuilder().tableScan(schema).planNode(); + auto expected = + makeRowVector({makeMapVector({{{1, 2}, {3, 4}}})}); + AssertQueryBuilder(plan).split(split).assertResults(expected); + } + { + SCOPED_TRACE("Read as struct"); + auto readSchema = ROW({"c0"}, {ROW({"1", "3"}, {BIGINT(), BIGINT()})}); + auto plan = PlanBuilder().tableScan(readSchema, {}, "", schema).planNode(); + auto expected = makeRowVector({makeRowVector( + {"1", "3"}, + {makeConstant(2, 1), makeConstant(4, 1)})}); + AssertQueryBuilder(plan).split(split).assertResults(expected); + } +} + +TEST_F(TableScanTest, flatMapLazyRowValue) { + auto c0 = makeMapVector( + {0, 2}, + makeFlatVector({1, 2, 1, 2}), + makeRowVector({ + makeFlatVector({3, 4, 5, 6}), + makeFlatVector({7, 8, 9, 10}), + })); + auto vector = makeRowVector({c0, makeFlatVector({false, true})}); + auto config = std::make_shared(); + config->set(dwrf::Config::FLATTEN_MAP, true); + config->set>(dwrf::Config::MAP_FLAT_COLS, {0}); + auto file = TempFilePath::create(); + writeToFile(file->getPath(), {vector}, config); + auto split = makeHiveConnectorSplit(file->getPath()); + { + SCOPED_TRACE("Read as map"); + auto plan = PlanBuilder() + .tableScan( + ROW("c0", c0->type()), + {"c0 is not null", "c1 = true"}, + "", + vector->rowType()) + .planNode(); + auto expected = makeRowVector({wrapInDictionary(makeIndices({1}), c0)}); + AssertQueryBuilder(plan).split(split).assertResults(expected); + } + { + SCOPED_TRACE("Read as struct"); + auto valueType = ROW({"c0", "c1"}, {BIGINT(), BIGINT()}); + auto readSchema = ROW("c0", ROW({"1", "2"}, valueType)); + auto plan = PlanBuilder() + .tableScan( + readSchema, + {"c0 is not null", "c1 = true"}, + "", + vector->rowType()) + .planNode(); + auto expected = makeRowVector({makeRowVector( + {"1", "2"}, + { + makeRowVector({ + makeConstant(5, 1), + makeConstant(9, 1), + }), + makeRowVector({ + makeConstant(6, 1), + makeConstant(10, 1), + }), + })}); + AssertQueryBuilder(plan).split(split).assertResults(expected); + } +} + TEST_F(TableScanTest, dynamicFilters) { // Make sure filters on same column from multiple downstream operators are // merged properly without overwriting each other. @@ -5067,8 +5313,7 @@ TEST_F(TableScanTest, dynamicFilterWithRowIndexColumn) { {"row_index", "a"}, {makeFlatVector(5, folly::identity), makeFlatVector(5, folly::identity)}); - std::unordered_map> - assignments; + connector::ColumnHandleMap assignments; assignments["a"] = std::make_shared( "a", connector::hive::HiveColumnHandle::ColumnType::kRegular, @@ -5112,6 +5357,63 @@ TEST_F(TableScanTest, dynamicFilterWithRowIndexColumn) { .assertResults(resVector); } +TEST_F(TableScanTest, bloomFilterPushdown) { + auto build = makeRowVector( + {"b"}, + { + makeFlatVector( + 10'001 + VectorHasher::kMaxDistinct, + [](auto i) { return 1000 * i; }), + }); + auto probe = makeRowVector( + {"a"}, + { + makeFlatVector( + 2 * build->size(), [](auto i) { return 500 * i; }), + }); + std::shared_ptr files[2]; + files[0] = TempFilePath::create(); + writeToFile(files[0]->getPath(), {probe}); + files[1] = TempFilePath::create(); + writeToFile(files[1]->getPath(), {build}); + auto idGenerator = std::make_shared(); + core::PlanNodeId probeScanId, buildScanId, joinId; + auto plan = PlanBuilder(idGenerator) + .tableScan(ROW({"a"}, {BIGINT()})) + .capturePlanNodeId(probeScanId) + .hashJoin( + {"a"}, + {"b"}, + PlanBuilder(idGenerator) + .tableScan(ROW({"b"}, {BIGINT()})) + .capturePlanNodeId(buildScanId) + .planNode(), + /*filter=*/"", + {"a"}) + .capturePlanNodeId(joinId) + .planNode(); + for (bool parallelBuild : {false, true}) { + SCOPED_TRACE(fmt::format("parallelBuild={}", parallelBuild)); + AssertQueryBuilder builder(plan); + builder + .config( + core::QueryConfig::kHashProbeBloomFilterPushdownMaxSize, + std::to_string(4 * build->size())) + .split(probeScanId, makeHiveConnectorSplit(files[0]->getPath())) + .split(buildScanId, makeHiveConnectorSplit(files[1]->getPath())); + if (parallelBuild) { + builder.serialExecution(false).maxDrivers(2).config( + core::QueryConfig::kMinTableRowsForParallelJoinBuild, "1"); + } + auto task = builder.assertResults(build); + auto planStats = toPlanStats(task->taskStats()); + ASSERT_EQ( + planStats.at(joinId).customStats.at("dynamicFiltersProduced").sum, + parallelBuild ? 2 : 1); + ASSERT_GT(planStats.at(joinId).customStats.at("bloomFilterSize").sum, 0); + } +} + // TODO: re-enable this test once we add back driver suspension support for // table scan. TEST_F(TableScanTest, DISABLED_memoryArbitrationWithSlowTableScan) { @@ -5519,11 +5821,11 @@ TEST_F(TableScanTest, footerIOCount) { .assertResults( BaseVector::create(vector->type(), 0, pool())); auto stats = getTableScanRuntimeStats(task); - ASSERT_EQ(stats.at("numStorageRead").sum, 1); + ASSERT_EQ(stats.at("storageReadBytes").count, 1); ASSERT_GT(stats.at("footerBufferOverread").sum, 0); } -TEST_F(TableScanTest, statsBasedFilterReorderDisabled) { +TEST_F(TableScanTest, statsBasedFilterReorderBothEnabledAndDisabled) { gflags::FlagSaver gflagSaver; // Disable prefetch to avoid test flakiness. FLAGS_cache_prefetch_min_pct = 200; @@ -5552,8 +5854,8 @@ TEST_F(TableScanTest, statsBasedFilterReorderDisabled) { } createDuckDbTable(vectors); - for (auto disableReoder : {false}) { - SCOPED_TRACE(fmt::format("disableReoder {}", disableReoder)); + for (auto disableReorder : {false, true}) { + SCOPED_TRACE(fmt::format("disableReorder {}", disableReorder)); auto* cache = cache::AsyncDataCache::getInstance(); cache->clear(); @@ -5586,7 +5888,7 @@ TEST_F(TableScanTest, statsBasedFilterReorderDisabled) { kHiveConnectorId, connector::hive::HiveConfig:: kReadStatsBasedFilterReorderDisabledSession, - disableReoder ? "true" : "false") + disableReorder ? "true" : "false") // Disable coalesce so that each column stream has a separate read // per split at least. .connectorSessionProperty( @@ -5610,7 +5912,7 @@ TEST_F(TableScanTest, statsBasedFilterReorderDisabled) { auto tableScanStats = getTableScanStats(task); ASSERT_EQ(tableScanStats.customStats.count("storageReadBytes"), 1); ASSERT_GT(tableScanStats.customStats["storageReadBytes"].sum, 0); - ASSERT_EQ(tableScanStats.customStats["storageReadBytes"].count, 1); + ASSERT_GT(tableScanStats.customStats["storageReadBytes"].count, 0); ASSERT_EQ(tableScanStats.numSplits, numSplits); } @@ -5622,7 +5924,7 @@ TEST_F(TableScanTest, statsBasedFilterReorderDisabled) { kHiveConnectorId, connector::hive::HiveConfig:: kReadStatsBasedFilterReorderDisabledSession, - disableReoder ? "true" : "false") + disableReorder ? "true" : "false") .connectorSessionProperty( kHiveConnectorId, connector::hive::HiveConfig::kMaxCoalescedBytesSession, @@ -5640,17 +5942,566 @@ TEST_F(TableScanTest, statsBasedFilterReorderDisabled) { "SELECT c0 FROM tmp WHERE (c1 IN (1,7,11) OR c1 IS NULL) AND (c3 IN (1,7,11) OR c3 IS NULL)"); auto tableScanStats = getTableScanStats(task); - if (disableReoder) { + if (disableReorder) { ASSERT_EQ(tableScanStats.customStats.count("storageReadBytes"), 0); } else { + // Cache hit if (tableScanStats.customStats.count("storageReadBytes") == 0) { continue; } + // Cache miss, should behave like first time run ASSERT_EQ(tableScanStats.customStats.count("storageReadBytes"), 1); ASSERT_GT(tableScanStats.customStats["storageReadBytes"].sum, 0); - ASSERT_EQ(tableScanStats.customStats["storageReadBytes"].count, 1); + ASSERT_GT(tableScanStats.customStats["storageReadBytes"].count, 0); } ASSERT_EQ(tableScanStats.numSplits, numSplits); } } } + +TEST_F(TableScanTest, prevBatchEmptyAdaptivity) { + auto rowType = ROW({"c0", "c1"}, {BIGINT(), VARCHAR()}); + + const vector_size_t size = 100; + const size_t stringBytes = 1024 * 1024; + const size_t preferredOutputBatchBytes = 10UL << 20; + + const std::string sampleString(stringBytes, 'a'); + StringView sampleStringView(sampleString); + auto rowVector = makeRowVector( + {makeFlatVector( + size, + [&](auto row) { + return row % 100 == 50 ? 51 : row % 100; + }), // so that the filter "c0 = 50" cannot rely on the min-max range + // to filter out all data in the data source even before the + // first batch is read + makeFlatVector( + size, [&](auto /*unused*/) { return sampleStringView; })}); + + auto filePath = TempFilePath::create(); + writeToFile(filePath->getPath(), rowVector); + createDuckDbTable({rowVector}); + + auto plan = PlanBuilder().tableScan(rowType, {"c0 = 50"}).planNode(); + { + auto task = AssertQueryBuilder(duckDbQueryRunner_) + .plan(plan) + .split(makeHiveConnectorSplit(filePath->getPath())) + .config( + QueryConfig::kMaxOutputBatchRows, + folly::to(size * 4)) + .config( + QueryConfig::kPreferredOutputBatchBytes, + folly::to(preferredOutputBatchBytes)) + .assertResults("SELECT * FROM tmp WHERE c0 = 50"); + const auto opStats = task->taskStats().pipelineStats[0].operatorStats[0]; + const auto numBatchesRead = + opStats.runtimeStats.at("dataSourceReadWallNanos").count - 1; + const auto batchSizeWithoutAdaptivity = + QueryConfig({}).preferredOutputBatchBytes() / + (sizeof(int64_t) + stringBytes + sizeof(StringView)); + const auto numBatchesReadWithoutAdaptivity = + bits::divRoundUp(size, batchSizeWithoutAdaptivity); + EXPECT_GT(numBatchesReadWithoutAdaptivity, numBatchesRead); + } +} + +TEST_F(TableScanTest, textfileEscape) { + auto expected = makeRowVector( + {"c0", "c1"}, + { + makeFlatVector({"a,bc", "d"}), + makeFlatVector({"e", "e"}), + }); + + const auto tempFile = TempFilePath::create(); + const auto tempPath = tempFile->getPath(); + remove(tempPath.c_str()); + LocalWriteFile localWriteFile(tempPath); + localWriteFile.append("a\\,bc,e\nd,e"); + localWriteFile.close(); + + std::unordered_map customSplitInfo; + std::unordered_map> partitionKeys; + std::unordered_map serdeParameters{ + {dwio::common::SerDeOptions::kFieldDelim, ","}, + {dwio::common::SerDeOptions::kEscapeChar, "\\"}}; + + auto split = std::make_shared( + kHiveConnectorId, + tempPath, + dwio::common::FileFormat(dwio::common::FileFormat::TEXT), + 0, + std::numeric_limits::max(), + partitionKeys, + std::nullopt, + customSplitInfo, + nullptr, + serdeParameters); + + auto inputType = asRowType(expected->type()); + auto plan = + PlanBuilder(pool()).tableScan(inputType, {}, "", inputType).planNode(); + + auto task = facebook::velox::exec::test::AssertQueryBuilder(plan) + .split(split) + .assertResults(expected); + auto planStats = facebook::velox::exec::toPlanStats(task->taskStats()); + auto scanNodeId = plan->id(); + auto it = planStats.find(scanNodeId); + ASSERT_TRUE(it != planStats.end()); + auto rawInputBytes = it->second.rawInputBytes; + auto overreadBytes = getTableScanRuntimeStats(task).at("overreadBytes").sum; + + ASSERT_EQ(rawInputBytes, 11); + ASSERT_EQ(overreadBytes, 0); +} + +TEST_F(TableScanTest, textfileChunkReadEntireFile) { + auto expected = makeRowVector( + {"c0", "c1"}, + { + makeFlatVector({"row1_col1", "row2_col1", "row3_col1"}), + makeFlatVector({"row1_col2", "row2_col2", "row3_col2"}), + }); + + const auto tempFile = TempFilePath::create(); + const auto tempPath = tempFile->getPath(); + remove(tempPath.c_str()); + LocalWriteFile localWriteFile(tempPath); + + localWriteFile.append("row1_col1,row1_col2\n"); + localWriteFile.append("row2_col1,row2_col2\n"); + localWriteFile.append("row3_col1,row3_col2\n"); + + // Add extra padding data that might be read but not used + localWriteFile.append("extra_row1,extra_data1\n"); + localWriteFile.append("extra_row2,extra_data2\n"); + localWriteFile.close(); + + std::unordered_map customSplitInfo; + std::unordered_map> partitionKeys; + std::unordered_map serdeParameters{ + {dwio::common::SerDeOptions::kFieldDelim, ","}}; + + // Create a split that only reads part of the file (first 60 bytes) + // This should cause the reader to potentially overread beyond the split + // boundary + auto split = std::make_shared( + kHiveConnectorId, + tempPath, + dwio::common::FileFormat(dwio::common::FileFormat::TEXT), + 0, + 59, // Limit to first 60 bytes instead of reading entire file + partitionKeys, + std::nullopt, + customSplitInfo, + nullptr, + serdeParameters); + + auto inputType = asRowType(expected->type()); + auto plan = + PlanBuilder(pool()).tableScan(inputType, {}, "", inputType).planNode(); + + auto task = facebook::velox::exec::test::AssertQueryBuilder(plan) + .split(split) + .assertResults(expected); + + auto planStats = facebook::velox::exec::toPlanStats(task->taskStats()); + auto scanNodeId = plan->id(); + auto it = planStats.find(scanNodeId); + ASSERT_TRUE(it != planStats.end()); + auto rawInputBytes = it->second.rawInputBytes; + + // Entire file was read in a single chunk even though range is [0,59] + ASSERT_EQ(rawInputBytes, 106); +} + +TEST_F(TableScanTest, textfileLarge) { + constexpr int kNumRows = + 100000; // This will generate well over 8388608 bytes (per chunk read) + constexpr int kNumCols = 10; + + constexpr int loadQuantum = 8 << 20; // loadQuantum_ as of June 2025 + + // Helper function to generate column data + auto generateColumnData = [](int row, int col) { + return fmt::format("row{}_col{}_padding_data_to_increase_size", row, col); + }; + + // Helper function to generate CSV row + auto generateCsvRow = [&](int row) { + std::vector cols; + cols.reserve(kNumCols); + for (int col = 0; col < kNumCols; ++col) { + cols.push_back(generateColumnData(row, col)); + } + return fmt::format("{}\n", fmt::join(cols, ",")); + }; + + // Create expected result (only first row since split limit is 10 bytes) + std::vector expectedRow; + expectedRow.reserve(kNumCols); + for (int col = 0; col < kNumCols; ++col) { + expectedRow.push_back(generateColumnData(0, col)); + } + + std::vector columnNames; + std::vector columnVectors; + columnNames.reserve(kNumCols); + columnVectors.reserve(kNumCols); + + for (int col = 0; col < kNumCols; ++col) { + columnNames.push_back(fmt::format("c{}", col)); + columnVectors.push_back(makeFlatVector({expectedRow[col]})); + } + + auto expected = makeRowVector(columnNames, columnVectors); + + // Create large file + const auto tempFile = TempFilePath::create(); + const auto tempPath = tempFile->getPath(); + remove(tempPath.c_str()); + LocalWriteFile localWriteFile(tempPath); + + for (int row = 0; row < kNumRows; ++row) { + localWriteFile.append(generateCsvRow(row)); + } + localWriteFile.close(); + + ASSERT_GE(std::filesystem::file_size(tempPath), loadQuantum); + + std::unordered_map customSplitInfo; + std::unordered_map> partitionKeys; + std::unordered_map serdeParameters{ + {dwio::common::SerDeOptions::kFieldDelim, ","}}; + + auto split = std::make_shared( + kHiveConnectorId, + tempPath, + dwio::common::FileFormat(dwio::common::FileFormat::TEXT), + 0, + 10, // Limit to only first row + partitionKeys, + std::nullopt, + customSplitInfo, + nullptr, + serdeParameters); + + auto inputType = asRowType(expected->type()); + auto plan = + PlanBuilder(pool()).tableScan(inputType, {}, "", inputType).planNode(); + + auto task = facebook::velox::exec::test::AssertQueryBuilder(plan) + .split(split) + .assertResults(expected); + + auto planStats = facebook::velox::exec::toPlanStats(task->taskStats()); + auto scanNodeId = plan->id(); + auto it = planStats.find(scanNodeId); + ASSERT_TRUE(it != planStats.end()); + auto rawInputBytes = it->second.rawInputBytes; + + // Verify we did not read the entire file but only a chunk + ASSERT_EQ(rawInputBytes, loadQuantum); + ASSERT_GT(getTableScanRuntimeStats(task)["totalScanTime"].sum, 0); + ASSERT_GT(getTableScanRuntimeStats(task)["ioWaitWallNanos"].sum, 0); +} + +TEST_F(TableScanTest, duplicateFieldProject) { + auto vector = makeRowVector( + {"id", "name"}, + { + makeFlatVector({1, 2}), + makeFlatVector({"Alice", "John"}), + }); + + auto file = TempFilePath::create(); + writeToFile(file->getPath(), vector); + createDuckDbTable({vector}); + + auto plan = PlanBuilder() + .tableScan(vector->rowType()) + .filter("name = 'John'") + .project({"id AS t0", "id AS t1"}) + .planNode(); + + AssertQueryBuilder(plan, duckDbQueryRunner_) + .split(makeHiveConnectorSplit(file->getPath())) + .assertResults("SELECT id, id FROM tmp WHERE name = 'John'"); +} + +TEST_F(TableScanTest, parallelUnitLoader) { + auto vectors = makeVectors(10, 1'000); + auto filePath = TempFilePath::create(); + writeToFile( + filePath->getPath(), + vectors, + std::make_shared(), + []() { return std::make_unique(1000, 0); }); + createDuckDbTable(vectors); + auto plan = tableScanNode(); + auto task = + AssertQueryBuilder(plan) + .splits(makeHiveConnectorSplits({filePath})) + .connectorSessionProperty( + kHiveConnectorId, + connector::hive::HiveConfig::kParallelUnitLoadCountSession, + std::to_string(3)) + .assertTypeAndNumRows(rowType_, 10'000); + auto stats = getTableScanRuntimeStats(task); + // Verify that parallel unit loader is enabled. + ASSERT_GT(stats.count("waitForUnitReadyNanos"), 0); +} + +TEST_F(TableScanTest, filterColumnHandles) { + auto data = makeVectors(1, 10, ROW({"a", "b"}, BIGINT())); + auto filePath = TempFilePath::create(); + writeToFile(filePath->getPath(), data); + auto split = exec::test::HiveConnectorSplitBuilder(filePath->getPath()) + .partitionKey("ds", "2025-10-23") + .build(); + auto plan = PlanBuilder() + .startTableScan() + .outputType(ROW({"x"}, {BIGINT()})) + .assignments({{"x", regularColumn("a", BIGINT())}}) + .dataColumns(asRowType(data[0]->type())) + .filterColumnHandles({ + partitionKey("ds", VARCHAR()), + regularColumn("a", BIGINT()), + }) + .remainingFilter("length(ds) + a % 2 > 0") + .endTableScan() + .planNode(); + AssertQueryBuilder(plan).split(split).assertResults( + makeRowVector({data[0]->childAt(0)})); +} + +TEST_F(TableScanTest, columnPostProcessorWithSubfieldFilters) { + auto data = makeFlatVector(10, folly::identity); + auto vector = makeRowVector({data, data, data}); + auto file = TempFilePath::create(); + writeToFile(file->getPath(), {vector}); + ChainedVectorLoader::PostVectorLoadProcessor postProc; + postProc = [&](auto& column) { + if (column->isLazy()) { + auto* lazy = column->template asUnchecked(); + if (lazy->isLoaded()) { + column = lazy->loadedVectorShared(); + postProc(column); + } else { + lazy->chain(postProc); + } + return; + } + if (column->encoding() == VectorEncoding::Simple::DICTIONARY) { + auto alphabet = column->valueVector(); + postProc(alphabet); + column->setValueVector(std::move(alphabet)); + return; + } + auto* values = + column->template asChecked>()->mutableRawValues(); + for (vector_size_t i = 0; i < column->size(); ++i) { + ++values[i]; + } + }; + auto c0Handle = std::make_shared( + "c0", + HiveColumnHandle::ColumnType::kRegular, + BIGINT(), + BIGINT(), + std::vector{}, + HiveColumnHandle::ColumnParseParameters{}, + postProc); + auto c1Handle = regularColumn("c1", BIGINT()); + auto c2Handle = std::make_shared( + "c2", + HiveColumnHandle::ColumnType::kRegular, + BIGINT(), + BIGINT(), + std::vector{}, + HiveColumnHandle::ColumnParseParameters{}, + postProc); + auto expected = makeRowVector({ + makeFlatVector({1, 3, 5, 7, 9}), + makeFlatVector({0, 2, 4, 6, 8}), + makeFlatVector({1, 3, 5, 7, 9}), + }); + { + SCOPED_TRACE("Subfield filters"); + auto plan = + PlanBuilder(pool()) + .startTableScan() + .outputType(asRowType(vector->type())) + .subfieldFilter("c0 in (0, 2, 4, 6, 8)") + .assignments({{"c0", c0Handle}, {"c1", c1Handle}, {"c2", c2Handle}}) + .endTableScan() + .planNode(); + auto split = makeHiveConnectorSplit(file->getPath()); + AssertQueryBuilder(plan).split(split).assertResults(expected); + } + { + SCOPED_TRACE("Remaining filter"); + auto plan = + PlanBuilder() + .startTableScan() + .outputType(asRowType(vector->type())) + .remainingFilter("c0 % 2 = 0") + .assignments({{"c0", c0Handle}, {"c1", c1Handle}, {"c2", c2Handle}}) + .endTableScan() + .planNode(); + auto split = makeHiveConnectorSplit(file->getPath()); + AssertQueryBuilder(plan).split(split).assertResults(expected); + } +} + +TEST_F(TableScanTest, shortDecimalFilter) { + functions::registerIsNotNullFunction("isnotnull"); + + std::vector> values = { + 123456789123456789L, + 987654321123456L, + std::nullopt, + 2000000000000000L, + 5000000000000000L, + 987654321987654321L, + 100000000000000L, + 1230000000123456L, + 120000000123456L, + std::nullopt}; + auto rowVector = makeRowVector( + {"a"}, + { + makeNullableFlatVector(values, DECIMAL(18, 6)), + }); + createDuckDbTable({rowVector}); + + auto filePath = facebook::velox::test::getDataFilePath( + "velox/exec/tests", "data/decimal.orc"); + auto split = exec::test::HiveConnectorSplitBuilder(filePath) + .start(0) + .length(fs::file_size(filePath)) + .fileFormat(dwio::common::FileFormat::ORC) + .build(); + + auto rowType = rowVector->rowType(); + // Is not null. + auto op = + PlanBuilder().tableScan(rowType, {}, "isnotnull(a)", rowType).planNode(); + assertQuery(op, split, "SELECT a FROM tmp where a is not null"); + + // Is null. + op = PlanBuilder().tableScan(rowType, {}, "is_null(a)", rowType).planNode(); + assertQuery(op, split, "SELECT a FROM tmp where a is null"); + + // BigintRange. + op = + PlanBuilder() + .tableScan( + rowType, + {}, + "a > 2000000000.0::DECIMAL(18, 6) and a < 6000000000.0::DECIMAL(18, 6)", + rowType) + .planNode(); + assertQuery( + op, + split, + "SELECT a FROM tmp where a > 2000000000.0 and a < 6000000000.0"); + + // NegatedBigintRange. + op = + PlanBuilder() + .tableScan( + rowType, + {}, + "not(a between 2000000000.0::DECIMAL(18, 6) and 6000000000.0::DECIMAL(18, 6))", + rowType) + .planNode(); + assertQuery( + op, + split, + "SELECT a FROM tmp where a < 2000000000.0 or a > 6000000000.0"); +} + +TEST_F(TableScanTest, longDecimalFilter) { + functions::registerIsNotNullFunction("isnotnull"); + + std::vector> shortValues = { + 123456789123456789L, + 987654321123456L, + std::nullopt, + 2000000000000000L, + 5000000000000000L, + 987654321987654321L, + 100000000000000L, + 1230000000123456L, + 120000000123456L, + std::nullopt}; + + std::vector> longValues = { + HugeInt::parse("123456789123456789123456789" + std::string(9, '0')), + HugeInt::parse("987654321123456789" + std::string(9, '0')), + std::nullopt, + HugeInt::parse("2" + std::string(37, '0')), + HugeInt::parse("5" + std::string(37, '0')), + HugeInt::parse("987654321987654321987654321" + std::string(9, '0')), + HugeInt::parse("1" + std::string(26, '0')), + HugeInt::parse("123000000012345678" + std::string(10, '0')), + HugeInt::parse("120000000123456789" + std::string(9, '0')), + HugeInt::parse("9" + std::string(37, '0'))}; + + auto rowVector = makeRowVector( + {"a", "b"}, + { + makeNullableFlatVector(shortValues, DECIMAL(18, 6)), + makeNullableFlatVector(longValues, DECIMAL(38, 18)), + }); + createDuckDbTable({rowVector}); + + auto filePath = facebook::velox::test::getDataFilePath( + "velox/exec/tests", "data/decimal.orc"); + auto split = exec::test::HiveConnectorSplitBuilder(filePath) + .start(0) + .length(fs::file_size(filePath)) + .fileFormat(dwio::common::FileFormat::ORC) + .build(); + + auto outputType = ROW({"b"}, {DECIMAL(38, 18)}); + auto dataColumns = rowVector->rowType(); + + auto op = PlanBuilder() + .tableScan(outputType, {}, "isnotnull(b)", dataColumns) + .planNode(); + assertQuery(op, split, "SELECT b FROM tmp where b is not null"); + + // Is null. + op = PlanBuilder() + .tableScan(outputType, {}, "is_null(b)", dataColumns) + .planNode(); + assertQuery(op, split, "SELECT b FROM tmp where b is null"); + + // HugeintRange. + op = + PlanBuilder() + .tableScan( + outputType, + {}, + "b > 2000000000.0::DECIMAL(38, 18) and b < 6000000000.0::DECIMAL(38, 18)", + dataColumns) + .planNode(); + assertQuery( + op, + split, + "SELECT b FROM tmp where b > 2000000000.0 and b < 6000000000.0"); + + // Test filter column not being projected out. + op = PlanBuilder() + .tableScan(outputType, {}, "a is null", dataColumns) + .planNode(); + assertQuery(op, split, "SELECT b FROM tmp WHERE a is null"); +} + +} // namespace +} // namespace facebook::velox::exec diff --git a/velox/exec/tests/TableWriterTest.cpp b/velox/exec/tests/TableWriterTest.cpp index 412ab6c5f734..8a78a9208df5 100644 --- a/velox/exec/tests/TableWriterTest.cpp +++ b/velox/exec/tests/TableWriterTest.cpp @@ -82,7 +82,7 @@ TEST_F(BasicTableWriterTestBase, roundTrip) { ->as>(); ASSERT_TRUE(details->isNullAt(0)); ASSERT_FALSE(details->isNullAt(1)); - folly::dynamic obj = folly::parseJson(details->valueAt(1)); + folly::dynamic obj = folly::parseJson(std::string_view(details->valueAt(1))); ASSERT_EQ(size, obj["rowCount"].asInt()); auto fileWriteInfos = obj["fileWriteInfos"]; @@ -93,10 +93,12 @@ TEST_F(BasicTableWriterTestBase, roundTrip) { // Read from 'writeFileName' and verify the data matches the original. plan = PlanBuilder().tableScan(rowType).planNode(); - auto copy = AssertQueryBuilder(plan) - .split(makeHiveConnectorSplit(fmt::format( - "{}/{}", targetDirectoryPath->getPath(), writeFileName))) - .copyResults(pool()); + auto copy = + AssertQueryBuilder(plan) + .split(makeHiveConnectorSplit( + fmt::format( + "{}/{}", targetDirectoryPath->getPath(), writeFileName))) + .copyResults(pool()); assertEqualResults({data}, {copy}); } @@ -178,7 +180,7 @@ TEST_F(BasicTableWriterTestBase, targetFileName) { auto results = AssertQueryBuilder(plan).copyResults(pool()); auto* details = results->childAt(TableWriteTraits::kFragmentChannel) ->asUnchecked>(); - auto detail = folly::parseJson(details->valueAt(1)); + auto detail = folly::parseJson(std::string_view(details->valueAt(1))); auto fileWriteInfos = detail["fileWriteInfos"]; ASSERT_EQ(1, fileWriteInfos.size()); ASSERT_EQ(fileWriteInfos[0]["writeFileName"].asString(), kFileName); @@ -205,66 +207,72 @@ class PartitionedTableWriterTest for (bool multiDrivers : multiDriverOptions) { for (FileFormat fileFormat : fileFormats) { for (bool scaleWriter : {false, true}) { - testParams.push_back(TestParam{ - fileFormat, - TestMode::kPartitioned, - CommitStrategy::kNoCommit, - HiveBucketProperty::Kind::kHiveCompatible, - false, - multiDrivers, - CompressionKind_ZSTD, - scaleWriter} - .value); - testParams.push_back(TestParam{ - fileFormat, - TestMode::kPartitioned, - CommitStrategy::kTaskCommit, - HiveBucketProperty::Kind::kHiveCompatible, - false, - multiDrivers, - CompressionKind_ZSTD, - scaleWriter} - .value); - testParams.push_back(TestParam{ - fileFormat, - TestMode::kBucketed, - CommitStrategy::kNoCommit, - HiveBucketProperty::Kind::kHiveCompatible, - false, - multiDrivers, - CompressionKind_ZSTD, - scaleWriter} - .value); - testParams.push_back(TestParam{ - fileFormat, - TestMode::kBucketed, - CommitStrategy::kTaskCommit, - HiveBucketProperty::Kind::kHiveCompatible, - false, - multiDrivers, - CompressionKind_ZSTD, - scaleWriter} - .value); - testParams.push_back(TestParam{ - fileFormat, - TestMode::kBucketed, - CommitStrategy::kNoCommit, - HiveBucketProperty::Kind::kPrestoNative, - false, - multiDrivers, - CompressionKind_ZSTD, - scaleWriter} - .value); - testParams.push_back(TestParam{ - fileFormat, - TestMode::kBucketed, - CommitStrategy::kTaskCommit, - HiveBucketProperty::Kind::kPrestoNative, - false, - multiDrivers, - CompressionKind_ZSTD, - scaleWriter} - .value); + testParams.push_back( + TestParam{ + fileFormat, + TestMode::kPartitioned, + CommitStrategy::kNoCommit, + HiveBucketProperty::Kind::kHiveCompatible, + false, + multiDrivers, + CompressionKind_ZSTD, + scaleWriter} + .value); + testParams.push_back( + TestParam{ + fileFormat, + TestMode::kPartitioned, + CommitStrategy::kTaskCommit, + HiveBucketProperty::Kind::kHiveCompatible, + false, + multiDrivers, + CompressionKind_ZSTD, + scaleWriter} + .value); + testParams.push_back( + TestParam{ + fileFormat, + TestMode::kBucketed, + CommitStrategy::kNoCommit, + HiveBucketProperty::Kind::kHiveCompatible, + false, + multiDrivers, + CompressionKind_ZSTD, + scaleWriter} + .value); + testParams.push_back( + TestParam{ + fileFormat, + TestMode::kBucketed, + CommitStrategy::kTaskCommit, + HiveBucketProperty::Kind::kHiveCompatible, + false, + multiDrivers, + CompressionKind_ZSTD, + scaleWriter} + .value); + testParams.push_back( + TestParam{ + fileFormat, + TestMode::kBucketed, + CommitStrategy::kNoCommit, + HiveBucketProperty::Kind::kPrestoNative, + false, + multiDrivers, + CompressionKind_ZSTD, + scaleWriter} + .value); + testParams.push_back( + TestParam{ + fileFormat, + TestMode::kBucketed, + CommitStrategy::kTaskCommit, + HiveBucketProperty::Kind::kPrestoNative, + false, + multiDrivers, + CompressionKind_ZSTD, + scaleWriter} + .value); } } } @@ -288,26 +296,28 @@ class UnpartitionedTableWriterTest for (bool multiDrivers : multiDriverOptions) { for (FileFormat fileFormat : fileFormats) { for (bool scaleWriter : {false, true}) { - testParams.push_back(TestParam{ - fileFormat, - TestMode::kUnpartitioned, - CommitStrategy::kNoCommit, - HiveBucketProperty::Kind::kHiveCompatible, - false, - multiDrivers, - CompressionKind_NONE, - scaleWriter} - .value); - testParams.push_back(TestParam{ - fileFormat, - TestMode::kUnpartitioned, - CommitStrategy::kTaskCommit, - HiveBucketProperty::Kind::kHiveCompatible, - false, - multiDrivers, - CompressionKind_NONE, - scaleWriter} - .value); + testParams.push_back( + TestParam{ + fileFormat, + TestMode::kUnpartitioned, + CommitStrategy::kNoCommit, + HiveBucketProperty::Kind::kHiveCompatible, + false, + multiDrivers, + CompressionKind_NONE, + scaleWriter} + .value); + testParams.push_back( + TestParam{ + fileFormat, + TestMode::kUnpartitioned, + CommitStrategy::kTaskCommit, + HiveBucketProperty::Kind::kHiveCompatible, + false, + multiDrivers, + CompressionKind_NONE, + scaleWriter} + .value); } } } @@ -331,26 +341,28 @@ class BucketedUnpartitionedTableWriterTest const std::vector bucketModes = {TestMode::kOnlyBucketed}; for (bool multiDrivers : multiDriverOptions) { for (FileFormat fileFormat : fileFormats) { - testParams.push_back(TestParam{ - fileFormat, - TestMode::kOnlyBucketed, - CommitStrategy::kNoCommit, - HiveBucketProperty::Kind::kHiveCompatible, - true, - multiDrivers, - facebook::velox::common::CompressionKind_ZSTD, - /*scaleWriter=*/false} - .value); - testParams.push_back(TestParam{ - fileFormat, - TestMode::kOnlyBucketed, - CommitStrategy::kTaskCommit, - HiveBucketProperty::Kind::kHiveCompatible, - true, - multiDrivers, - facebook::velox::common::CompressionKind_NONE, - /*scaleWriter=*/false} - .value); + testParams.push_back( + TestParam{ + fileFormat, + TestMode::kOnlyBucketed, + CommitStrategy::kNoCommit, + HiveBucketProperty::Kind::kHiveCompatible, + true, + multiDrivers, + facebook::velox::common::CompressionKind_ZSTD, + /*scaleWriter=*/false} + .value); + testParams.push_back( + TestParam{ + fileFormat, + TestMode::kOnlyBucketed, + CommitStrategy::kTaskCommit, + HiveBucketProperty::Kind::kHiveCompatible, + true, + multiDrivers, + facebook::velox::common::CompressionKind_NONE, + /*scaleWriter=*/false} + .value); } } return testParams; @@ -375,86 +387,83 @@ class BucketedTableOnlyWriteTest for (bool multiDrivers : multiDriverOptions) { for (FileFormat fileFormat : fileFormats) { for (auto bucketMode : bucketModes) { - testParams.push_back(TestParam{ - fileFormat, - bucketMode, - CommitStrategy::kNoCommit, - HiveBucketProperty::Kind::kHiveCompatible, - false, - multiDrivers, - CompressionKind_ZSTD, - /*scaleWriter=*/false} - .value); - testParams.push_back(TestParam{ - fileFormat, - bucketMode, - CommitStrategy::kNoCommit, - HiveBucketProperty::Kind::kHiveCompatible, - true, - multiDrivers, - CompressionKind_ZSTD, - /*scaleWriter=*/false} - .value); - testParams.push_back(TestParam{ - fileFormat, - bucketMode, - CommitStrategy::kTaskCommit, - HiveBucketProperty::Kind::kHiveCompatible, - false, - multiDrivers, - CompressionKind_ZSTD, - /*scaleWriter=*/false} - .value); - testParams.push_back(TestParam{ - fileFormat, - bucketMode, - CommitStrategy::kTaskCommit, - HiveBucketProperty::Kind::kHiveCompatible, - true, - multiDrivers, - CompressionKind_ZSTD, - /*scaleWriter=*/false} - .value); - testParams.push_back(TestParam{ - fileFormat, - bucketMode, - CommitStrategy::kNoCommit, - HiveBucketProperty::Kind::kPrestoNative, - false, - multiDrivers, - CompressionKind_ZSTD, - /*scaleWriter=*/false} - .value); - testParams.push_back(TestParam{ - fileFormat, - bucketMode, - CommitStrategy::kNoCommit, - HiveBucketProperty::Kind::kPrestoNative, - true, - multiDrivers, - CompressionKind_ZSTD, - /*scaleWriter=*/false} - .value); - testParams.push_back(TestParam{ - fileFormat, - bucketMode, - CommitStrategy::kTaskCommit, - HiveBucketProperty::Kind::kPrestoNative, - false, - multiDrivers, - CompressionKind_ZSTD, - /*scaleWriter=*/false} - .value); - testParams.push_back(TestParam{ - fileFormat, - bucketMode, - CommitStrategy::kNoCommit, - HiveBucketProperty::Kind::kPrestoNative, - true, - multiDrivers, - CompressionKind_ZSTD, - /*scaleWriter=*/false} - .value); + testParams.push_back( + TestParam{ + fileFormat, + bucketMode, + CommitStrategy::kNoCommit, + HiveBucketProperty::Kind::kHiveCompatible, + false, + multiDrivers, + CompressionKind_ZSTD, + /*scaleWriter=*/false} + .value); + testParams.push_back( + TestParam{ + fileFormat, + bucketMode, + CommitStrategy::kNoCommit, + HiveBucketProperty::Kind::kHiveCompatible, + true, + multiDrivers, + CompressionKind_ZSTD, + /*scaleWriter=*/false} + .value); + testParams.push_back( + TestParam{ + fileFormat, + bucketMode, + CommitStrategy::kTaskCommit, + HiveBucketProperty::Kind::kHiveCompatible, + false, + multiDrivers, + CompressionKind_ZSTD, + /*scaleWriter=*/false} + .value); + testParams.push_back( + TestParam{ + fileFormat, + bucketMode, + CommitStrategy::kTaskCommit, + HiveBucketProperty::Kind::kHiveCompatible, + true, + multiDrivers, + CompressionKind_ZSTD, + /*scaleWriter=*/false} + .value); + testParams.push_back( + TestParam{ + fileFormat, + bucketMode, + CommitStrategy::kNoCommit, + HiveBucketProperty::Kind::kPrestoNative, + false, + multiDrivers, + CompressionKind_ZSTD, + /*scaleWriter=*/false} + .value); + testParams.push_back( + TestParam{ + fileFormat, + bucketMode, + CommitStrategy::kNoCommit, + HiveBucketProperty::Kind::kPrestoNative, + true, + multiDrivers, + CompressionKind_ZSTD, + /*scaleWriter=*/false} + .value); + testParams.push_back( + TestParam{ + fileFormat, + bucketMode, + CommitStrategy::kTaskCommit, + HiveBucketProperty::Kind::kPrestoNative, + false, + multiDrivers, + CompressionKind_ZSTD, + /*scaleWriter=*/false} + .value); } } } @@ -480,26 +489,28 @@ class BucketSortOnlyTableWriterTest for (bool multiDrivers : multiDriverOptions) { for (FileFormat fileFormat : fileFormats) { for (auto bucketMode : bucketModes) { - testParams.push_back(TestParam{ - fileFormat, - bucketMode, - CommitStrategy::kNoCommit, - HiveBucketProperty::Kind::kHiveCompatible, - true, - multiDrivers, - facebook::velox::common::CompressionKind_ZSTD, - /*scaleWriter=*/false} - .value); - testParams.push_back(TestParam{ - fileFormat, - bucketMode, - CommitStrategy::kTaskCommit, - HiveBucketProperty::Kind::kHiveCompatible, - true, - multiDrivers, - facebook::velox::common::CompressionKind_NONE, - /*scaleWriter=*/false} - .value); + testParams.push_back( + TestParam{ + fileFormat, + bucketMode, + CommitStrategy::kNoCommit, + HiveBucketProperty::Kind::kHiveCompatible, + true, + multiDrivers, + facebook::velox::common::CompressionKind_ZSTD, + /*scaleWriter=*/false} + .value); + testParams.push_back( + TestParam{ + fileFormat, + bucketMode, + CommitStrategy::kTaskCommit, + HiveBucketProperty::Kind::kHiveCompatible, + true, + multiDrivers, + facebook::velox::common::CompressionKind_NONE, + /*scaleWriter=*/false} + .value); } } } @@ -523,26 +534,28 @@ class PartitionedWithoutBucketTableWriterTest for (bool multiDrivers : multiDriverOptions) { for (FileFormat fileFormat : fileFormats) { for (bool scaleWriter : {false, true}) { - testParams.push_back(TestParam{ - fileFormat, - TestMode::kPartitioned, - CommitStrategy::kNoCommit, - HiveBucketProperty::Kind::kHiveCompatible, - false, - multiDrivers, - CompressionKind_ZSTD, - scaleWriter} - .value); - testParams.push_back(TestParam{ - fileFormat, - TestMode::kPartitioned, - CommitStrategy::kTaskCommit, - HiveBucketProperty::Kind::kHiveCompatible, - false, - true, - CompressionKind_ZSTD, - scaleWriter} - .value); + testParams.push_back( + TestParam{ + fileFormat, + TestMode::kPartitioned, + CommitStrategy::kNoCommit, + HiveBucketProperty::Kind::kHiveCompatible, + false, + multiDrivers, + CompressionKind_ZSTD, + scaleWriter} + .value); + testParams.push_back( + TestParam{ + fileFormat, + TestMode::kPartitioned, + CommitStrategy::kTaskCommit, + HiveBucketProperty::Kind::kHiveCompatible, + false, + multiDrivers, + CompressionKind_ZSTD, + scaleWriter} + .value); } } } @@ -565,126 +578,138 @@ class AllTableWriterTest : public TableWriterTestBase, for (bool multiDrivers : multiDriverOptions) { for (FileFormat fileFormat : fileFormats) { for (bool scaleWriter : {false, true}) { - testParams.push_back(TestParam{ - fileFormat, - TestMode::kUnpartitioned, - CommitStrategy::kNoCommit, - HiveBucketProperty::Kind::kHiveCompatible, - false, - multiDrivers, - CompressionKind_ZSTD, - scaleWriter} - .value); - testParams.push_back(TestParam{ - fileFormat, - TestMode::kUnpartitioned, - CommitStrategy::kTaskCommit, - HiveBucketProperty::Kind::kHiveCompatible, - false, - multiDrivers, - CompressionKind_ZSTD, - scaleWriter} - .value); - testParams.push_back(TestParam{ - fileFormat, - TestMode::kPartitioned, - CommitStrategy::kNoCommit, - HiveBucketProperty::Kind::kHiveCompatible, - false, - multiDrivers, - CompressionKind_ZSTD, - scaleWriter} - .value); - testParams.push_back(TestParam{ - fileFormat, - TestMode::kPartitioned, - CommitStrategy::kTaskCommit, - HiveBucketProperty::Kind::kHiveCompatible, - false, - multiDrivers, - CompressionKind_ZSTD, - scaleWriter} - .value); - testParams.push_back(TestParam{ - fileFormat, - TestMode::kBucketed, - CommitStrategy::kNoCommit, - HiveBucketProperty::Kind::kHiveCompatible, - false, - multiDrivers, - CompressionKind_ZSTD, - scaleWriter} - .value); - testParams.push_back(TestParam{ - fileFormat, - TestMode::kBucketed, - CommitStrategy::kTaskCommit, - HiveBucketProperty::Kind::kHiveCompatible, - false, - multiDrivers, - CompressionKind_ZSTD, - scaleWriter} - .value); - testParams.push_back(TestParam{ - fileFormat, - TestMode::kBucketed, - CommitStrategy::kNoCommit, - HiveBucketProperty::Kind::kPrestoNative, - false, - multiDrivers, - CompressionKind_ZSTD, - scaleWriter} - .value); - testParams.push_back(TestParam{ - fileFormat, - TestMode::kBucketed, - CommitStrategy::kTaskCommit, - HiveBucketProperty::Kind::kPrestoNative, - false, - multiDrivers, - CompressionKind_ZSTD, - scaleWriter} - .value); - testParams.push_back(TestParam{ - fileFormat, - TestMode::kOnlyBucketed, - CommitStrategy::kNoCommit, - HiveBucketProperty::Kind::kHiveCompatible, - false, - multiDrivers, - CompressionKind_ZSTD, - scaleWriter} - .value); - testParams.push_back(TestParam{ - fileFormat, - TestMode::kOnlyBucketed, - CommitStrategy::kTaskCommit, - HiveBucketProperty::Kind::kHiveCompatible, - false, - multiDrivers, - CompressionKind_ZSTD, - scaleWriter} - .value); - testParams.push_back(TestParam{ - fileFormat, - TestMode::kOnlyBucketed, - CommitStrategy::kNoCommit, - HiveBucketProperty::Kind::kPrestoNative, - false, - multiDrivers, - CompressionKind_ZSTD, - scaleWriter} - .value); - testParams.push_back(TestParam{ - fileFormat, - TestMode::kOnlyBucketed, - CommitStrategy::kTaskCommit, - HiveBucketProperty::Kind::kPrestoNative, - false, - multiDrivers, - CompressionKind_ZSTD, - scaleWriter} - .value); + testParams.push_back( + TestParam{ + fileFormat, + TestMode::kUnpartitioned, + CommitStrategy::kNoCommit, + HiveBucketProperty::Kind::kHiveCompatible, + false, + multiDrivers, + CompressionKind_ZSTD, + scaleWriter} + .value); + testParams.push_back( + TestParam{ + fileFormat, + TestMode::kUnpartitioned, + CommitStrategy::kTaskCommit, + HiveBucketProperty::Kind::kHiveCompatible, + false, + multiDrivers, + CompressionKind_ZSTD, + scaleWriter} + .value); + testParams.push_back( + TestParam{ + fileFormat, + TestMode::kPartitioned, + CommitStrategy::kNoCommit, + HiveBucketProperty::Kind::kHiveCompatible, + false, + multiDrivers, + CompressionKind_ZSTD, + scaleWriter} + .value); + testParams.push_back( + TestParam{ + fileFormat, + TestMode::kPartitioned, + CommitStrategy::kTaskCommit, + HiveBucketProperty::Kind::kHiveCompatible, + false, + multiDrivers, + CompressionKind_ZSTD, + scaleWriter} + .value); + testParams.push_back( + TestParam{ + fileFormat, + TestMode::kBucketed, + CommitStrategy::kNoCommit, + HiveBucketProperty::Kind::kHiveCompatible, + false, + multiDrivers, + CompressionKind_ZSTD, + scaleWriter} + .value); + testParams.push_back( + TestParam{ + fileFormat, + TestMode::kBucketed, + CommitStrategy::kTaskCommit, + HiveBucketProperty::Kind::kHiveCompatible, + false, + multiDrivers, + CompressionKind_ZSTD, + scaleWriter} + .value); + testParams.push_back( + TestParam{ + fileFormat, + TestMode::kBucketed, + CommitStrategy::kNoCommit, + HiveBucketProperty::Kind::kPrestoNative, + false, + multiDrivers, + CompressionKind_ZSTD, + scaleWriter} + .value); + testParams.push_back( + TestParam{ + fileFormat, + TestMode::kBucketed, + CommitStrategy::kTaskCommit, + HiveBucketProperty::Kind::kPrestoNative, + false, + multiDrivers, + CompressionKind_ZSTD, + scaleWriter} + .value); + testParams.push_back( + TestParam{ + fileFormat, + TestMode::kOnlyBucketed, + CommitStrategy::kNoCommit, + HiveBucketProperty::Kind::kHiveCompatible, + false, + multiDrivers, + CompressionKind_ZSTD, + scaleWriter} + .value); + testParams.push_back( + TestParam{ + fileFormat, + TestMode::kOnlyBucketed, + CommitStrategy::kTaskCommit, + HiveBucketProperty::Kind::kHiveCompatible, + false, + multiDrivers, + CompressionKind_ZSTD, + scaleWriter} + .value); + testParams.push_back( + TestParam{ + fileFormat, + TestMode::kOnlyBucketed, + CommitStrategy::kNoCommit, + HiveBucketProperty::Kind::kPrestoNative, + false, + multiDrivers, + CompressionKind_ZSTD, + scaleWriter} + .value); + testParams.push_back( + TestParam{ + fileFormat, + TestMode::kOnlyBucketed, + CommitStrategy::kTaskCommit, + HiveBucketProperty::Kind::kPrestoNative, + false, + multiDrivers, + CompressionKind_ZSTD, + scaleWriter} + .value); } } } @@ -1444,6 +1469,7 @@ TEST_P(UnpartitionedTableWriterTest, differentCompression) { } if (compressionKind == CompressionKind_NONE || compressionKind == CompressionKind_ZLIB || + compressionKind == CompressionKind_GZIP || compressionKind == CompressionKind_ZSTD) { auto result = AssertQueryBuilder(plan) .config( @@ -1630,6 +1656,17 @@ TEST_P(BucketedTableOnlyWriteTest, bucketCountLimit) { SCOPED_TRACE(testParam_.toString()); auto input = makeVectors(1, 100); createDuckDbTable(input); + + // Get the HiveConfig to access the configurable maxBucketCount + auto defaultHiveConfig = + std::make_shared(std::make_shared( + std::unordered_map())); + + auto emptySession = std::make_shared( + std::unordered_map()); + uint32_t maxBucketCount = + defaultHiveConfig->maxBucketCount(emptySession.get()); + struct { uint32_t bucketCount; bool expectedError; @@ -1641,10 +1678,10 @@ TEST_P(BucketedTableOnlyWriteTest, bucketCountLimit) { } testSettings[] = { {1, false}, {3, false}, - {HiveDataSink::maxBucketCount() - 1, false}, - {HiveDataSink::maxBucketCount(), true}, - {HiveDataSink::maxBucketCount() + 1, true}, - {HiveDataSink::maxBucketCount() * 2, true}}; + {maxBucketCount - 1, false}, + {maxBucketCount, true}, + {maxBucketCount + 1, true}, + {maxBucketCount * 2, true}}; for (const auto& testData : testSettings) { SCOPED_TRACE(testData.debugString()); auto outputDirectory = TempDirectoryPath::create(); @@ -1780,7 +1817,8 @@ TEST_P(AllTableWriterTest, tableWriteOutputCheck) { } if (!fragmentVector->isNullAt(i)) { ASSERT_FALSE(fragmentVector->isNullAt(i)); - folly::dynamic obj = folly::parseJson(fragmentVector->valueAt(i)); + folly::dynamic obj = + folly::parseJson(std::string_view(fragmentVector->valueAt(i))); if (testMode_ == TestMode::kUnpartitioned) { ASSERT_EQ(obj["targetPath"], outputDirectory->getPath()); ASSERT_EQ(obj["writePath"], outputDirectory->getPath()); @@ -1789,13 +1827,17 @@ TEST_P(AllTableWriterTest, tableWriteOutputCheck) { for (const auto& partitionBy : partitionedBy_) { partitionDirRe += fmt::format("/{}=.+", partitionBy); } - ASSERT_TRUE(RE2::FullMatch( - obj["targetPath"].asString(), - fmt::format("{}{}", outputDirectory->getPath(), partitionDirRe))) + ASSERT_TRUE( + RE2::FullMatch( + obj["targetPath"].asString(), + fmt::format( + "{}{}", outputDirectory->getPath(), partitionDirRe))) << obj["targetPath"].asString(); - ASSERT_TRUE(RE2::FullMatch( - obj["writePath"].asString(), - fmt::format("{}{}", outputDirectory->getPath(), partitionDirRe))) + ASSERT_TRUE( + RE2::FullMatch( + obj["writePath"].asString(), + fmt::format( + "{}{}", outputDirectory->getPath(), partitionDirRe))) << obj["writePath"].asString(); } numRows += obj["rowCount"].asInt(); @@ -1820,7 +1862,7 @@ TEST_P(AllTableWriterTest, tableWriteOutputCheck) { ASSERT_EQ(writeFileName, targetFileName); } else { const std::string kParquetSuffix = ".parquet"; - if (folly::StringPiece(targetFileName).endsWith(kParquetSuffix)) { + if (targetFileName.ends_with(kParquetSuffix)) { // Remove the .parquet suffix. auto trimmedFilename = targetFileName.substr( 0, targetFileName.size() - kParquetSuffix.size()); @@ -1831,9 +1873,11 @@ TEST_P(AllTableWriterTest, tableWriteOutputCheck) { } } if (!commitContextVector->isNullAt(i)) { - ASSERT_TRUE(RE2::FullMatch( - commitContextVector->valueAt(i).getString(), - fmt::format(".*{}.*", commitStrategyToString(commitStrategy_)))) + ASSERT_TRUE( + RE2::FullMatch( + commitContextVector->valueAt(i).getString(), + fmt::format( + ".*{}.*", CommitStrategyName::toName(commitStrategy_)))) << commitContextVector->valueAt(i); } } @@ -1853,7 +1897,7 @@ TEST_P(AllTableWriterTest, tableWriteOutputCheck) { auto obj = TableWriteTraits::getTableCommitContext(result); ASSERT_EQ( obj[TableWriteTraits::kCommitStrategyContextKey], - commitStrategyToString(commitStrategy_)); + CommitStrategyName::toName(commitStrategy_)); ASSERT_EQ(obj[TableWriteTraits::klastPageContextKey], true); ASSERT_EQ(obj[TableWriteTraits::kLifeSpanContextKey], "TaskWide"); } @@ -1900,60 +1944,52 @@ TEST_P(AllTableWriterTest, columnStatsDataTypes) { auto outputDirectory = TempDirectoryPath::create(); std::vector groupingKeyFields; + groupingKeyFields.reserve(partitionedBy_.size()); for (int i = 0; i < partitionedBy_.size(); ++i) { - groupingKeyFields.emplace_back(std::make_shared( - partitionTypes_.at(i), partitionedBy_.at(i))); + groupingKeyFields.emplace_back( + std::make_shared( + partitionTypes_.at(i), partitionedBy_.at(i))); } // aggregation node core::TypedExprPtr intInputField = std::make_shared(SMALLINT(), "c2"); auto minCallExpr = std::make_shared( - SMALLINT(), std::vector{intInputField}, "min"); + SMALLINT(), "min", intInputField); auto maxCallExpr = std::make_shared( - SMALLINT(), std::vector{intInputField}, "max"); + SMALLINT(), "max", intInputField); auto distinctCountCallExpr = std::make_shared( - VARBINARY(), - std::vector{intInputField}, - "approx_distinct"); + VARBINARY(), "approx_distinct", intInputField); core::TypedExprPtr strInputField = std::make_shared(VARCHAR(), "c5"); auto maxDataSizeCallExpr = std::make_shared( - BIGINT(), - std::vector{strInputField}, - "max_data_size_for_stats"); + BIGINT(), "max_data_size_for_stats", strInputField); auto sumDataSizeCallExpr = std::make_shared( - BIGINT(), - std::vector{strInputField}, - "sum_data_size_for_stats"); + BIGINT(), "sum_data_size_for_stats", strInputField); core::TypedExprPtr boolInputField = std::make_shared(BOOLEAN(), "c6"); auto countCallExpr = std::make_shared( - BIGINT(), std::vector{boolInputField}, "count"); + BIGINT(), "count", boolInputField); auto countIfCallExpr = std::make_shared( - BIGINT(), std::vector{boolInputField}, "count_if"); + BIGINT(), "count_if", boolInputField); core::TypedExprPtr mapInputField = std::make_shared( MAP(DATE(), BIGINT()), "c7"); auto countMapCallExpr = std::make_shared( - BIGINT(), std::vector{mapInputField}, "count"); + BIGINT(), "count", mapInputField); auto sumDataSizeMapCallExpr = std::make_shared( - BIGINT(), - std::vector{mapInputField}, - "sum_data_size_for_stats"); + BIGINT(), "sum_data_size_for_stats", mapInputField); core::TypedExprPtr arrayInputField = std::make_shared( MAP(DATE(), BIGINT()), "c7"); auto countArrayCallExpr = std::make_shared( - BIGINT(), std::vector{mapInputField}, "count"); + BIGINT(), "count", mapInputField); auto sumDataSizeArrayCallExpr = std::make_shared( - BIGINT(), - std::vector{mapInputField}, - "sum_data_size_for_stats"); + BIGINT(), "sum_data_size_for_stats", mapInputField); const std::vector aggregateNames = { "min", @@ -1983,7 +2019,7 @@ TEST_P(AllTableWriterTest, columnStatsDataTypes) { }; }; - std::vector aggregates = { + const std::vector aggregates = { makeAggregate(minCallExpr), makeAggregate(maxCallExpr), makeAggregate(distinctCountCallExpr), @@ -1996,22 +2032,18 @@ TEST_P(AllTableWriterTest, columnStatsDataTypes) { makeAggregate(countArrayCallExpr), makeAggregate(sumDataSizeArrayCallExpr), }; - const auto aggregationNode = std::make_shared( - core::PlanNodeId(), - core::AggregationNode::Step::kPartial, + const core::ColumnStatsSpec columnStatsSpec{ groupingKeyFields, - std::vector{}, + core::AggregationNode::Step::kPartial, aggregateNames, - aggregates, - false, // ignoreNullKeys - PlanBuilder().values({input}).planNode()); + aggregates}; auto plan = PlanBuilder() .values({input}) .addNode(addTableWriter( rowType_, rowType_->names(), - aggregationNode, + columnStatsSpec, std::make_shared( kHiveConnectorId, makeHiveInsertTableHandle( @@ -2036,7 +2068,7 @@ TEST_P(AllTableWriterTest, columnStatsDataTypes) { const auto distinctCountStatsVector = result->childAt(nextColumnStatsIndex++)->asFlatVector(); HashStringAllocator allocator{pool_.get()}; - DenseHll denseHll{ + DenseHll<> denseHll{ std::string(distinctCountStatsVector->valueAt(0)).c_str(), &allocator}; ASSERT_EQ(denseHll.cardinality(), 1000); const auto maxDataSizeStatsVector = @@ -2087,20 +2119,15 @@ TEST_P(AllTableWriterTest, columnStats) { output.emplace_back("min"); types.emplace_back(BIGINT()); const auto writerOutputType = ROW(std::move(output), std::move(types)); - - // aggregation node - auto aggregationNode = generateAggregationNode( - "c0", - groupingKeys, - core::AggregationNode::Step::kPartial, - PlanBuilder().values({input}).planNode()); + auto columnStatsSpec = generateColumnStatsSpec( + "c0", groupingKeys, core::AggregationNode::Step::kPartial); auto plan = PlanBuilder() .values({input}) .addNode(addTableWriter( rowType_, rowType_->names(), - aggregationNode, + columnStatsSpec, std::make_shared( kHiveConnectorId, makeHiveInsertTableHandle( @@ -2188,16 +2215,13 @@ TEST_P(AllTableWriterTest, columnStatsWithTableWriteMerge) { const auto writerOutputType = ROW(std::move(output), std::move(types)); // aggregation node - auto aggregationNode = generateAggregationNode( - "c0", - groupingKeys, - core::AggregationNode::Step::kPartial, - PlanBuilder().values({input}).planNode()); + auto columnStatsSpec = generateColumnStatsSpec( + "c0", groupingKeys, core::AggregationNode::Step::kPartial); auto tableWriterPlan = PlanBuilder().values({input}).addNode(addTableWriter( rowType_, rowType_->names(), - aggregationNode, + columnStatsSpec, std::make_shared( kHiveConnectorId, makeHiveInsertTableHandle( @@ -2209,17 +2233,10 @@ TEST_P(AllTableWriterTest, columnStatsWithTableWriteMerge) { false, commitStrategy_)); - auto mergeAggregationNode = generateAggregationNode( - "min", - groupingKeys, - core::AggregationNode::Step::kIntermediate, - std::move(tableWriterPlan.planNode())); - auto finalPlan = tableWriterPlan.capturePlanNodeId(tableWriteNodeId_) .localPartition(std::vector{}) - .tableWriteMerge(std::move(mergeAggregationNode)) + .tableWriteMerge() .planNode(); - auto result = AssertQueryBuilder(finalPlan).copyResults(pool()); auto rowVector = result->childAt(0)->asFlatVector(); auto fragmentVector = result->childAt(1)->asFlatVector(); @@ -2516,13 +2533,14 @@ DEBUG_ONLY_TEST_P(BucketSortOnlyTableWriterTest, outputBatchRows) { maxOutputBytes, expectedOutputCount); } - } testSettings[] = {// we have 4 buckets thus 4 writers. - {10000, "1000kB", 4}, - // when maxOutputRows = 1, 1000 rows triggers 1000 writes - {1, "1kB", 1000}, - // estimatedRowSize is ~62bytes, when maxOutputSize = 62 * - // 100, 1000 rows triggers ~10 writes - {10000, "6200B", 12}}; + } testSettings[] = { + // we have 4 buckets thus 4 writers. + {10000, "1000kB", 4}, + // when maxOutputRows = 1, 1000 rows triggers 1000 writes + {1, "1kB", 1000}, + // estimatedRowSize is ~62bytes, when maxOutputSize = 62 * + // 100, 1000 rows triggers ~10 writes + {10000, "6200B", 12}}; for (const auto& testData : testSettings) { SCOPED_TRACE(testData.debugString()); @@ -2674,38 +2692,78 @@ DEBUG_ONLY_TEST_P(BucketSortOnlyTableWriterTest, yield) { VELOX_INSTANTIATE_TEST_SUITE_P( TableWriterTest, UnpartitionedTableWriterTest, - testing::ValuesIn(UnpartitionedTableWriterTest::getTestParams())); + testing::ValuesIn(UnpartitionedTableWriterTest::getTestParams()), + [](const testing::TestParamInfo& info) { + const auto testParams = + static_cast(info.param); + return fmt::format( + "UnpartitionedTableWriterTest_{}", testParams.toString()); + }); VELOX_INSTANTIATE_TEST_SUITE_P( TableWriterTest, BucketedUnpartitionedTableWriterTest, - testing::ValuesIn(BucketedUnpartitionedTableWriterTest::getTestParams())); + testing::ValuesIn(BucketedUnpartitionedTableWriterTest::getTestParams()), + [](const testing::TestParamInfo& info) { + const auto testParams = + static_cast(info.param); + return fmt::format( + "BucketedUnpartitionedTableWriterTest_{}", testParams.toString()); + }); VELOX_INSTANTIATE_TEST_SUITE_P( TableWriterTest, PartitionedTableWriterTest, - testing::ValuesIn(PartitionedTableWriterTest::getTestParams())); + testing::ValuesIn(PartitionedTableWriterTest::getTestParams()), + [](const testing::TestParamInfo& info) { + const auto testParams = + static_cast(info.param); + return fmt::format( + "PartitionedTableWriterTest_{}", testParams.toString()); + }); VELOX_INSTANTIATE_TEST_SUITE_P( TableWriterTest, BucketedTableOnlyWriteTest, - testing::ValuesIn(BucketedTableOnlyWriteTest::getTestParams())); + testing::ValuesIn(BucketedTableOnlyWriteTest::getTestParams()), + [](const testing::TestParamInfo& info) { + const auto testParams = + static_cast(info.param); + return fmt::format( + "BucketedTableOnlyWriteTest_{}", testParams.toString()); + }); VELOX_INSTANTIATE_TEST_SUITE_P( TableWriterTest, AllTableWriterTest, - testing::ValuesIn(AllTableWriterTest::getTestParams())); + testing::ValuesIn(AllTableWriterTest::getTestParams()), + [](const testing::TestParamInfo& info) { + const auto testParams = + static_cast(info.param); + return fmt::format("AllTableWriterTest_{}", testParams.toString()); + }); VELOX_INSTANTIATE_TEST_SUITE_P( TableWriterTest, PartitionedWithoutBucketTableWriterTest, - testing::ValuesIn( - PartitionedWithoutBucketTableWriterTest::getTestParams())); + testing::ValuesIn(PartitionedWithoutBucketTableWriterTest::getTestParams()), + [](const testing::TestParamInfo& info) { + const auto testParams = + static_cast(info.param); + return fmt::format( + "PartitionedWithoutBucketTableWriterTest_{}", testParams.toString()); + }); VELOX_INSTANTIATE_TEST_SUITE_P( TableWriterTest, BucketSortOnlyTableWriterTest, - testing::ValuesIn(BucketSortOnlyTableWriterTest::getTestParams())); + testing::ValuesIn(BucketSortOnlyTableWriterTest::getTestParams()), + [](const testing::TestParamInfo& info) { + const auto testParams = + static_cast(info.param); + return fmt::format( + "BucketSortOnlyTableWriterTest_{}", testParams.toString()); + }); class TableWriterArbitrationTest : public HiveConnectorTestBase { protected: @@ -2738,7 +2796,7 @@ DEBUG_ONLY_TEST_F(TableWriterArbitrationTest, reclaimFromTableWriter) { const int batchSize = 1'000; options.vectorSize = batchSize; options.stringVariableLength = false; - options.stringLength = 1'000; + options.stringLength = 500; VectorFuzzer fuzzer(options, pool()); const int numBatches = 20; std::vector vectors; @@ -2758,8 +2816,10 @@ DEBUG_ONLY_TEST_F(TableWriterArbitrationTest, reclaimFromTableWriter) { const int numPrevArbitrationFailures = arbitrator->stats().numFailures; const int numPrevNonReclaimableAttempts = arbitrator->stats().numNonReclaimableAttempts; - auto queryCtx = core::QueryCtx::create( - executor_.get(), QueryConfig{{}}, {}, nullptr, std::move(queryPool)); + auto queryCtx = core::QueryCtx::Builder() + .executor(executor_.get()) + .pool(std::move(queryPool)) + .build(); ASSERT_EQ(queryCtx->pool()->capacity(), kQueryMemoryCapacity); std::atomic_int numInputs{0}; @@ -2877,8 +2937,10 @@ DEBUG_ONLY_TEST_F(TableWriterArbitrationTest, reclaimFromSortTableWriter) { const int numPrevArbitrationFailures = arbitrator->stats().numFailures; const int numPrevNonReclaimableAttempts = arbitrator->stats().numNonReclaimableAttempts; - auto queryCtx = core::QueryCtx::create( - executor_.get(), QueryConfig{{}}, {}, nullptr, std::move(queryPool)); + auto queryCtx = core::QueryCtx::Builder() + .executor(executor_.get()) + .pool(std::move(queryPool)) + .build(); ASSERT_EQ(queryCtx->pool()->capacity(), kQueryMemoryCapacity); const auto spillStats = common::globalSpillStats(); @@ -2980,10 +3042,11 @@ DEBUG_ONLY_TEST_F(TableWriterArbitrationTest, writerFlushThreshold) { const std::vector testParams{ {0, 0}, {0, 1UL << 30}, {64UL << 20, 1UL << 30}}; for (const auto& testParam : testParams) { - SCOPED_TRACE(fmt::format( - "bytesToReserve: {}, writerFlushThreshold: {}", - succinctBytes(testParam.bytesToReserve), - succinctBytes(testParam.writerFlushThreshold))); + SCOPED_TRACE( + fmt::format( + "bytesToReserve: {}, writerFlushThreshold: {}", + succinctBytes(testParam.bytesToReserve), + succinctBytes(testParam.writerFlushThreshold))); auto queryPool = memory::memoryManager()->addRootPool( "writerFlushThreshold", kQueryMemoryCapacity); @@ -2991,8 +3054,11 @@ DEBUG_ONLY_TEST_F(TableWriterArbitrationTest, writerFlushThreshold) { const int numPrevArbitrationFailures = arbitrator->stats().numFailures; const int numPrevNonReclaimableAttempts = arbitrator->stats().numNonReclaimableAttempts; - auto queryCtx = core::QueryCtx::create( - executor_.get(), QueryConfig{{}}, {}, nullptr, std::move(queryPool)); + auto queryCtx = core::QueryCtx::Builder() + .executor(executor_.get()) + .pool(std::move(queryPool)) + .build(); + ASSERT_EQ(queryCtx->pool()->capacity(), kQueryMemoryCapacity); memory::MemoryPool* compressionPool{nullptr}; @@ -3099,8 +3165,11 @@ DEBUG_ONLY_TEST_F( const int numPrevArbitrationFailures = arbitrator->stats().numFailures; const int numPrevNonReclaimableAttempts = arbitrator->stats().numNonReclaimableAttempts; - auto queryCtx = core::QueryCtx::create( - executor_.get(), QueryConfig{{}}, {}, nullptr, std::move(queryPool)); + auto queryCtx = core::QueryCtx::Builder() + .executor(executor_.get()) + .pool(std::move(queryPool)) + .build(); + ASSERT_EQ(queryCtx->pool()->capacity(), kQueryMemoryCapacity); std::atomic injectFakeAllocationOnce{true}; @@ -3182,8 +3251,10 @@ DEBUG_ONLY_TEST_F( const int numPrevNonReclaimableAttempts = arbitrator->stats().numNonReclaimableAttempts; const int numPrevReclaimedBytes = arbitrator->stats().reclaimedUsedBytes; - auto queryCtx = core::QueryCtx::create( - executor_.get(), QueryConfig{{}}, {}, nullptr, std::move(queryPool)); + auto queryCtx = core::QueryCtx::Builder() + .executor(executor_.get()) + .pool(std::move(queryPool)) + .build(); ASSERT_EQ(queryCtx->pool()->capacity(), kQueryMemoryCapacity); std::atomic writerNoMoreInput{false}; @@ -3282,8 +3353,10 @@ DEBUG_ONLY_TEST_F( const int numPrevArbitrationFailures = arbitrator->stats().numFailures; const int numPrevNonReclaimableAttempts = arbitrator->stats().numNonReclaimableAttempts; - auto queryCtx = core::QueryCtx::create( - executor_.get(), QueryConfig{{}}, {}, nullptr, std::move(queryPool)); + auto queryCtx = core::QueryCtx::Builder() + .executor(executor_.get()) + .pool(std::move(queryPool)) + .build(); ASSERT_EQ(queryCtx->pool()->capacity(), kQueryMemoryCapacity); std::atomic injectFakeAllocationOnce{true}; @@ -3374,8 +3447,11 @@ DEBUG_ONLY_TEST_F(TableWriterArbitrationTest, tableFileWriteError) { auto queryPool = memory::memoryManager()->addRootPool( "tableFileWriteError", kQueryMemoryCapacity); - auto queryCtx = core::QueryCtx::create( - executor_.get(), QueryConfig{{}}, {}, nullptr, std::move(queryPool)); + auto queryCtx = core::QueryCtx::Builder() + .executor(executor_.get()) + .pool(std::move(queryPool)) + .build(); + ASSERT_EQ(queryCtx->pool()->capacity(), kQueryMemoryCapacity); std::atomic_bool injectWriterErrorOnce{true}; @@ -3442,8 +3518,10 @@ DEBUG_ONLY_TEST_F(TableWriterArbitrationTest, tableWriteSpillUseMoreMemory) { auto queryPool = memory::memoryManager()->addRootPool( "tableWriteSpillUseMoreMemory", kQueryMemoryCapacity / 4); - auto queryCtx = core::QueryCtx::create( - executor_.get(), QueryConfig{{}}, {}, nullptr, std::move(queryPool)); + auto queryCtx = core::QueryCtx::Builder() + .executor(executor_.get()) + .pool(std::move(queryPool)) + .build(); ASSERT_EQ(queryCtx->pool()->capacity(), kQueryMemoryCapacity / 4); auto fakeLeafPool = queryCtx->pool()->addLeafChild( @@ -3529,14 +3607,18 @@ DEBUG_ONLY_TEST_F(TableWriterArbitrationTest, tableWriteReclaimOnClose) { auto queryPool = memory::memoryManager()->addRootPool( "tableWriteSpillUseMoreMemory", kQueryMemoryCapacity); - auto queryCtx = core::QueryCtx::create( - executor_.get(), QueryConfig{{}}, {}, nullptr, std::move(queryPool)); + auto queryCtx = core::QueryCtx::Builder() + .executor(executor_.get()) + .pool(std::move(queryPool)) + .build(); ASSERT_EQ(queryCtx->pool()->capacity(), kQueryMemoryCapacity); auto fakeQueryPool = memory::memoryManager()->addRootPool("fake", kQueryMemoryCapacity); - auto fakeQueryCtx = core::QueryCtx::create( - executor_.get(), QueryConfig{{}}, {}, nullptr, std::move(fakeQueryPool)); + auto fakeQueryCtx = core::QueryCtx::Builder() + .executor(executor_.get()) + .pool(std::move(fakeQueryPool)) + .build(); ASSERT_EQ(fakeQueryCtx->pool()->capacity(), kQueryMemoryCapacity); auto fakeLeafPool = fakeQueryCtx->pool()->addLeafChild( @@ -3622,8 +3704,10 @@ DEBUG_ONLY_TEST_F( .data; auto queryPool = memory::memoryManager()->addRootPool( "tableWriteSpillUseMoreMemory", kQueryMemoryCapacity); - auto queryCtx = core::QueryCtx::create( - executor_.get(), QueryConfig{{}}, {}, nullptr, std::move(queryPool)); + auto queryCtx = core::QueryCtx::Builder() + .executor(executor_.get()) + .pool(std::move(queryPool)) + .build(); ASSERT_EQ(queryCtx->pool()->capacity(), kQueryMemoryCapacity); std::atomic_bool writerCloseWaitFlag{true}; diff --git a/velox/exec/tests/TaskTest.cpp b/velox/exec/tests/TaskTest.cpp index 0aa7d1ed8bea..d74826cc3480 100644 --- a/velox/exec/tests/TaskTest.cpp +++ b/velox/exec/tests/TaskTest.cpp @@ -25,6 +25,7 @@ #include "velox/common/testutil/TestValue.h" #include "velox/connectors/hive/HiveConnectorSplit.h" #include "velox/exec/Cursor.h" +#include "velox/exec/HashAggregation.h" #include "velox/exec/OutputBufferManager.h" #include "velox/exec/PlanNodeStats.h" #include "velox/exec/Values.h" @@ -255,7 +256,8 @@ class ExternalBlocker { public: folly::SemiFuture continueFuture() { if (isBlocked_) { - auto [promise, future] = makeVeloxContinuePromiseContract(); + auto [promise, future] = + makeVeloxContinuePromiseContract("ExternalBlocker::continueFuture"); continuePromise_ = std::move(promise); return std::move(future); } @@ -459,6 +461,55 @@ class TestBadMemoryTranslator : public exec::Operator::PlanNodeTranslator { return nullptr; } }; + +// Test operator that calls the protected shouldYield() method to verify it +// correctly delegates to the driver's shouldYield() implementation and +// respects CPU time slice limits. +class TestShouldYieldOperator : public exec::Operator { + public: + TestShouldYieldOperator( + int32_t operatorId, + exec::DriverCtx* driverCtx, + const RowTypePtr& outputType, + const std::string& nodeId) + : Operator(driverCtx, outputType, operatorId, nodeId, "TestShouldYield") { + } + + bool needsInput() const override { + return !noMoreInput_ && !input_; + } + + void addInput(RowVectorPtr input) override { + input_ = std::move(input); + } + + RowVectorPtr getOutput() override { + if (!input_) { + return nullptr; + } + + // Test the protected shouldYield() method + shouldYieldResult_ = shouldYield(); + + return input_; + } + + exec::BlockingReason isBlocked(ContinueFuture* /*unused*/) override { + return exec::BlockingReason::kNotBlocked; + } + + bool isFinished() override { + return noMoreInput_ && !input_; + } + + bool getShouldYieldResult() const { + return shouldYieldResult_; + } + + private: + RowVectorPtr input_; + bool shouldYieldResult_{false}; +}; } // namespace class TaskTest : public HiveConnectorTestBase { @@ -474,7 +525,8 @@ class TaskTest : public HiveConnectorTestBase { plan, 0, core::QueryCtx::create(), - Task::ExecutionMode::kSerial); + Task::ExecutionMode::kSerial, + exec::Consumer{}); for (const auto& [nodeId, paths] : filePaths) { for (const auto& path : paths) { @@ -523,7 +575,8 @@ TEST_F(TaskTest, toJson) { std::move(plan), 0, core::QueryCtx::create(driverExecutor_.get()), - Task::ExecutionMode::kParallel); + Task::ExecutionMode::kParallel, + exec::Consumer{}); ASSERT_EQ( task->toString(), @@ -588,7 +641,8 @@ TEST_F(TaskTest, wrongPlanNodeForSplit) { std::move(plan), 0, core::QueryCtx::create(driverExecutor_.get()), - Task::ExecutionMode::kParallel); + Task::ExecutionMode::kParallel, + exec::Consumer{}); // Add split for the source node. task->addSplit("0", exec::Split(folly::copy(connectorSplit))); @@ -644,7 +698,8 @@ TEST_F(TaskTest, wrongPlanNodeForSplit) { std::move(plan), 0, core::QueryCtx::create(driverExecutor_.get()), - Task::ExecutionMode::kParallel); + Task::ExecutionMode::kParallel, + exec::Consumer{}); errorMessage = "Splits can be associated only with leaf plan nodes which require splits. Plan node ID 0 doesn't refer to such plan node."; VELOX_ASSERT_THROW( @@ -671,7 +726,8 @@ TEST_F(TaskTest, duplicatePlanNodeIds) { std::move(plan), 0, core::QueryCtx::create(driverExecutor_.get()), - Task::ExecutionMode::kParallel), + Task::ExecutionMode::kParallel, + exec::Consumer{}), "Plan node IDs must be unique. Found duplicate ID: 0.") } @@ -847,8 +903,8 @@ TEST_F(TaskTest, hasMixedExecutionGroupJoin) { task->start(1); - ASSERT_FALSE( - task->hasMixedExecutionGroupJoin(dynamic_cast( + ASSERT_FALSE(task->hasMixedExecutionGroupJoin( + dynamic_cast( nonMixedGroupedModeJoinNode.get()))); ASSERT_TRUE(task->hasMixedExecutionGroupJoin( dynamic_cast(mixedGroupedModeJoinNode.get()))); @@ -1178,7 +1234,8 @@ TEST_F(TaskTest, serialExecutionExternalBlockable) { plan, 0, core::QueryCtx::create(), - Task::ExecutionMode::kSerial); + Task::ExecutionMode::kSerial, + exec::Consumer{}); std::vector results; for (;;) { auto result = nonBlockingTask->next(&continueFuture); @@ -1204,7 +1261,8 @@ TEST_F(TaskTest, serialExecutionExternalBlockable) { plan, 0, core::QueryCtx::create(), - Task::ExecutionMode::kSerial); + Task::ExecutionMode::kSerial, + exec::Consumer{}); // Before we block, we expect `next` to get data normally. results.push_back(blockingTask->next(&continueFuture)); EXPECT_TRUE(results.back() != nullptr); @@ -1241,16 +1299,17 @@ TEST_F(TaskTest, supportSerialExecutionMode) { .project({"c0 % 10"}) .partitionedOutput({}, 1, std::vector{"p0"}) .planFragment(); - auto task = Task::create( - "single.execution.task.0", - plan, - 0, - core::QueryCtx::create(), - Task::ExecutionMode::kSerial); - // PartitionedOutput does not support serial execution mode, therefore the // task doesn't support it either. - ASSERT_FALSE(task->supportSerialExecutionMode()); + VELOX_ASSERT_THROW( + Task::create( + "single.execution.task.0", + plan, + 0, + core::QueryCtx::create(), + Task::ExecutionMode::kSerial, + exec::Consumer{}), + ""); } TEST_F(TaskTest, updateBroadCastOutputBuffers) { @@ -1266,7 +1325,8 @@ TEST_F(TaskTest, updateBroadCastOutputBuffers) { plan, 0, core::QueryCtx::create(driverExecutor_.get()), - Task::ExecutionMode::kParallel); + Task::ExecutionMode::kParallel, + exec::Consumer{}); task->start(1, 1); @@ -1284,7 +1344,8 @@ TEST_F(TaskTest, updateBroadCastOutputBuffers) { plan, 0, core::QueryCtx::create(driverExecutor_.get()), - Task::ExecutionMode::kParallel); + Task::ExecutionMode::kParallel, + exec::Consumer{}); task->start(1, 1); @@ -1572,8 +1633,13 @@ DEBUG_ONLY_TEST_F(TaskTest, inconsistentExecutionMode) { auto plan = PlanBuilder().values({data, data, data}).project({"c0"}).planFragment(); auto queryCtx = core::QueryCtx::create(driverExecutor_.get()); - auto task = - Task::create("task.0", plan, 0, queryCtx, Task::ExecutionMode::kSerial); + auto task = Task::create( + "task.0", + plan, + 0, + queryCtx, + Task::ExecutionMode::kSerial, + exec::Consumer{}); task->next(); VELOX_ASSERT_THROW(task->start(4, 1), "Inconsistent task execution mode."); @@ -1907,17 +1973,24 @@ TEST_F(TaskTest, driverCreationMemoryAllocationCheck) { .planFragment(); for (bool singleThreadExecution : {false, true}) { SCOPED_TRACE(fmt::format("singleThreadExecution: ", singleThreadExecution)); - auto badTask = Task::create( - "driverCreationMemoryAllocationCheck", - plan, - 0, - core::QueryCtx::create( - singleThreadExecution ? nullptr : driverExecutor_.get()), - singleThreadExecution ? Task::ExecutionMode::kSerial - : Task::ExecutionMode::kParallel); if (singleThreadExecution) { - VELOX_ASSERT_THROW(badTask->next(), "Unexpected memory pool allocations"); + VELOX_ASSERT_THROW( + Task::create( + "driverCreationMemoryAllocationCheck", + plan, + 0, + core::QueryCtx::create(nullptr), + Task::ExecutionMode::kSerial, + exec::Consumer{}), + "Unexpected memory pool allocations"); } else { + auto badTask = Task::create( + "driverCreationMemoryAllocationCheck", + plan, + 0, + core::QueryCtx::create(driverExecutor_.get()), + Task::ExecutionMode::kParallel, + exec::Consumer{}); VELOX_ASSERT_THROW( badTask->start(1), "Unexpected memory pool allocations"); } @@ -1944,36 +2017,34 @@ TEST_F(TaskTest, spillDirectoryCallback) { {{core::QueryConfig::kSpillEnabled, "true"}, {core::QueryConfig::kAggregationSpillEnabled, "true"}}); params.maxDrivers = 1; - - auto cursor = TaskCursor::create(params); - - std::shared_ptr task = cursor->task(); - auto tmpRootDir = exec::test::TempDirectoryPath::create(); - auto tmpParentSpillDir = fmt::format( + auto spillRootDir = exec::test::TempDirectoryPath::create(); + auto spillParentDir = fmt::format( "{}{}/parent_spill/", tests::utils::FaultyFileSystem::scheme(), - tmpRootDir->getPath()); - auto tmpSpillDir = fmt::format( + spillRootDir->getPath()); + auto spillDir = fmt::format( "{}{}/parent_spill/spill/", tests::utils::FaultyFileSystem::scheme(), - tmpRootDir->getPath()); + spillRootDir->getPath()); - EXPECT_FALSE(task->hasCreateSpillDirectoryCb()); - - task->setCreateSpillDirectoryCb([tmpParentSpillDir, tmpSpillDir]() { - auto filesystem = filesystems::getFileSystem(tmpParentSpillDir, nullptr); + params.spillDirectory = spillDir; + params.spillDirectoryCallback = [spillParentDir, spillDir]() { + auto filesystem = filesystems::getFileSystem(spillParentDir, nullptr); filesystems::DirectoryOptions options; options.values.emplace( filesystems::DirectoryOptions::kMakeDirectoryConfig.toString(), "dummy.config=123"); - filesystem->mkdir(tmpParentSpillDir, options); - filesystem->mkdir(tmpSpillDir); - return tmpSpillDir; - }); + filesystem->mkdir(spillParentDir, options); + filesystem->mkdir(spillDir); + return spillDir; + }; + auto cursor = TaskCursor::create(params); + std::shared_ptr task = cursor->task(); EXPECT_TRUE(task->hasCreateSpillDirectoryCb()); + auto fs = std::dynamic_pointer_cast( - filesystems::getFileSystem(tmpParentSpillDir, nullptr)); + filesystems::getFileSystem(spillParentDir, nullptr)); fs->setFileSystemInjectionError( std::make_exception_ptr(std::runtime_error("test exception")), @@ -1989,7 +2060,7 @@ TEST_F(TaskTest, spillDirectoryCallback) { auto mkdirOp = static_cast(op); if (mkdirOp->path == - fmt::format("{}/parent_spill/", tmpRootDir->getPath())) { + fmt::format("{}/parent_spill/", spillRootDir->getPath())) { parentDirectoryCreated = true; auto it = mkdirOp->options.values.find( filesystems::DirectoryOptions::kMakeDirectoryConfig.toString()); @@ -1997,7 +2068,7 @@ TEST_F(TaskTest, spillDirectoryCallback) { EXPECT_EQ(it->second, "dummy.config=123"); } if (mkdirOp->path == - fmt::format("{}/parent_spill/spill/", tmpRootDir->getPath())) { + fmt::format("{}/parent_spill/spill/", spillRootDir->getPath())) { spillDirectoryCreated = true; } return; @@ -2042,13 +2113,13 @@ TEST_F(TaskTest, spillDirectoryLifecycleManagement) { {{core::QueryConfig::kSpillEnabled, "true"}, {core::QueryConfig::kAggregationSpillEnabled, "true"}}); params.maxDrivers = 1; + const auto rootTempDir = exec::test::TempDirectoryPath::create(); + const auto tmpDirectoryPath = + rootTempDir->getPath() + "/spillDirectoryLifecycleManagement"; + params.spillDirectory = tmpDirectoryPath; auto cursor = TaskCursor::create(params); std::shared_ptr task = cursor->task(); - auto rootTempDir = exec::test::TempDirectoryPath::create(); - auto tmpDirectoryPath = - rootTempDir->getPath() + "/spillDirectoryLifecycleManagement"; - task->setSpillDirectory(tmpDirectoryPath, false); TestScopedSpillInjection scopedSpillInjection(100); while (cursor->moveNext()) { @@ -2104,7 +2175,6 @@ TEST_F(TaskTest, spillDirNotCreated) { auto* task = cursor->task().get(); auto rootTempDir = exec::test::TempDirectoryPath::create(); auto tmpDirectoryPath = rootTempDir->getPath() + "/spillDirNotCreated"; - task->setSpillDirectory(tmpDirectoryPath, false); while (cursor->moveNext()) { } @@ -2150,7 +2220,8 @@ DEBUG_ONLY_TEST_F(TaskTest, resumeAfterTaskFinish) { std::move(plan), 0, core::QueryCtx::create(driverExecutor_.get()), - Task::ExecutionMode::kParallel); + Task::ExecutionMode::kParallel, + exec::Consumer{}); task->start(4, 1); // Request pause and then unblock operators to proceed. @@ -2181,7 +2252,12 @@ DEBUG_ONLY_TEST_F(TaskTest, serialLongRunningOperatorInTaskReclaimerAbort) { auto queryCtx = core::QueryCtx::create(driverExecutor_.get()); auto blockingTask = Task::create( - "blocking.task.0", plan, 0, queryCtx, Task::ExecutionMode::kSerial); + "blocking.task.0", + plan, + 0, + queryCtx, + Task::ExecutionMode::kSerial, + exec::Consumer{}); // Before we block, we expect `next` to get data normally. EXPECT_NE(nullptr, blockingTask->next()); @@ -2256,7 +2332,12 @@ DEBUG_ONLY_TEST_F(TaskTest, longRunningOperatorInTaskReclaimerAbort) { auto queryCtx = core::QueryCtx::create(driverExecutor_.get()); auto blockingTask = Task::create( - "blocking.task.0", plan, 0, queryCtx, Task::ExecutionMode::kParallel); + "blocking.task.0", + plan, + 0, + queryCtx, + Task::ExecutionMode::kParallel, + exec::Consumer{}); blockingTask->start(4, 1); const std::string abortErrorMessage("Synthetic Exception"); @@ -2322,7 +2403,8 @@ DEBUG_ONLY_TEST_F(TaskTest, taskReclaimStats) { std::move(plan), 0, std::move(queryCtx), - Task::ExecutionMode::kParallel); + Task::ExecutionMode::kParallel, + exec::Consumer{}); task->start(4, 1); const int numReclaims{10}; @@ -2396,7 +2478,8 @@ DEBUG_ONLY_TEST_F(TaskTest, taskPauseTime) { std::move(plan), 0, std::move(queryCtx), - Task::ExecutionMode::kParallel); + Task::ExecutionMode::kParallel, + exec::Consumer{}); task->start(4, 1); // Wait for the task driver starts to run. @@ -2445,7 +2528,8 @@ TEST_F(TaskTest, updateStatsWhileCloseOffThreadDriver) { std::move(plan), 0, core::QueryCtx::create(driverExecutor_.get()), - Task::ExecutionMode::kParallel); + Task::ExecutionMode::kParallel, + exec::Consumer{}); task->start(4, 1); std::this_thread::sleep_for(std::chrono::milliseconds(100)); task->testingVisitDrivers( @@ -2490,7 +2574,8 @@ DEBUG_ONLY_TEST_F(TaskTest, driverEnqueAfterFailedAndPausedTask) { std::move(plan), 0, core::QueryCtx::create(driverExecutor_.get()), - Task::ExecutionMode::kParallel); + Task::ExecutionMode::kParallel, + exec::Consumer{}); task->start(4, 1); // Request pause. @@ -2533,10 +2618,11 @@ DEBUG_ONLY_TEST_F(TaskTest, taskReclaimFailure) { .config(core::QueryConfig::kSpillEnabled, true) .config(core::QueryConfig::kAggregationSpillEnabled, true) .maxDrivers(1) - .plan(PlanBuilder() - .values(inputVectors) - .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) - .planNode()) + .plan( + PlanBuilder() + .values(inputVectors) + .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) + .planNode()) .assertResults( "SELECT c0, c1, array_agg(c2) FROM tmp GROUP BY c0, c1"), spillTableError); @@ -2561,10 +2647,11 @@ DEBUG_ONLY_TEST_F(TaskTest, taskDeletionPromise) { std::thread queryThread([&]() { AssertQueryBuilder(duckDbQueryRunner_) .maxDrivers(1) - .plan(PlanBuilder() - .values(inputVectors) - .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) - .planNode()) + .plan( + PlanBuilder() + .values(inputVectors) + .singleAggregation({"c0", "c1"}, {"array_agg(c2)"}) + .planNode()) .assertResults("SELECT c0, c1, array_agg(c2) FROM tmp GROUP BY c0, c1"); }); @@ -2610,7 +2697,8 @@ DEBUG_ONLY_TEST_F(TaskTest, taskCancellation) { std::move(plan), 0, core::QueryCtx::create(driverExecutor_.get()), - Task::ExecutionMode::kParallel); + Task::ExecutionMode::kParallel, + exec::Consumer{}); task->start(4, 1); auto cancellationToken = task->getCancellationToken(); ASSERT_FALSE(cancellationToken.isCancellationRequested()); @@ -2668,13 +2756,12 @@ TEST_F(TaskTest, invalidPlanNodeForBarrier) { makeFlatVector(1'000, [](auto row) { return row; }), }); - // Filter + Project. const auto plan = PlanBuilder() .values({data, data}) .filter("c0 < 100") .project({"c0 + 5"}) .planFragment(); - ASSERT_FALSE(plan.supportsBarrier()); + ASSERT_TRUE(plan.firstNodeNotSupportingBarrier()); const auto task = Task::create( "invalidPlanNodeForBarrier", @@ -2683,7 +2770,12 @@ TEST_F(TaskTest, invalidPlanNodeForBarrier) { core::QueryCtx::create(), Task::ExecutionMode::kSerial); ASSERT_TRUE(!task->underBarrier()); - VELOX_ASSERT_THROW(task->requestBarrier(), "Task doesn't support barrier"); + VELOX_ASSERT_THROW( + task->requestBarrier(), + "Name of the first node that doesn't support barriered execution:"); + while (auto next = task->next()) { + } + ASSERT_TRUE(task->isFinished()); } TEST_F(TaskTest, barrierAfterNoMoreSplits) { @@ -2718,6 +2810,9 @@ TEST_F(TaskTest, barrierAfterNoMoreSplits) { VELOX_ASSERT_THROW( task->requestBarrier(), "Can't start barrier on task which has already received no more splits"); + while (auto next = task->next()) { + } + ASSERT_TRUE(task->isFinished()); } TEST_F(TaskTest, invalidTaskModeForBarrier) { @@ -2733,7 +2828,7 @@ TEST_F(TaskTest, invalidTaskModeForBarrier) { .filter("c0 < 100") .project({"c0 + 5"}) .planFragment(); - ASSERT_TRUE(plan.supportsBarrier()); + ASSERT_TRUE(plan.firstNodeNotSupportingBarrier() == nullptr); const auto task = Task::create( "invalidTaskModeForBarrier", @@ -2742,7 +2837,9 @@ TEST_F(TaskTest, invalidTaskModeForBarrier) { core::QueryCtx::create(), Task::ExecutionMode::kParallel); ASSERT_TRUE(!task->underBarrier()); - VELOX_ASSERT_THROW(task->requestBarrier(), "Task doesn't support barrier"); + VELOX_ASSERT_THROW( + task->requestBarrier(), + "(Parallel vs. Serial) Task doesn't support barriered execution."); } TEST_F(TaskTest, addSplitAfterBarrier) { @@ -2760,7 +2857,7 @@ TEST_F(TaskTest, addSplitAfterBarrier) { .filter("c0 < 100") .project({"c0 + 5"}) .planFragment(); - ASSERT_TRUE(plan.supportsBarrier()); + ASSERT_TRUE(plan.firstNodeNotSupportingBarrier() == nullptr); const auto task = Task::create( "barrierAfterNoMoreSplits", @@ -3144,4 +3241,185 @@ TEST_F(TaskTest, testTerminateDuringBarrierWithUnion) { ASSERT_EQ(task->taskStats().numBarriers, 1); ASSERT_EQ(task->taskStats().numFinishedSplits, 3); } + +TEST_F(TaskTest, expressionStatsInBetweenBarriers) { + // Verify that expression stats are collected in between barriers and at the + // end. + // This projection ensures that we verify that inputs of special + // form (coalesce) are also included. + const std::string projection = "coalesce(c0 + 1, 0)"; + const int numRows{10}; + auto data = makeRowVector({ + makeFlatVector(numRows, [](auto row) { return row; }), + }); + auto filePath = TempFilePath::create(); + writeToFile(filePath->getPath(), {data}); + + core::PlanNodeId scanId; + core::PlanNodeId projectNodeId; + auto plan = PlanBuilder() + .tableScan(asRowType(data->type())) + .capturePlanNodeId(scanId) + .project({projection}) + .capturePlanNodeId(projectNodeId) + .planFragment(); + + auto queryCtx = core::QueryCtx::create(); + queryCtx->testingOverrideConfigUnsafe( + {{core::QueryConfig::kMaxOutputBatchRows, "10"}, + {core::QueryConfig::kOperatorTrackExpressionStats, "true"}}); + const auto task = Task::create( + "expressionStatsInBetweenBarriers", + plan, + 0, + std::move(queryCtx), + Task::ExecutionMode::kSerial); + ASSERT_TRUE(!task->underBarrier()); + task->addSplit( + scanId, exec::Split(makeHiveConnectorSplit(filePath->getPath()))); + auto barrierFuture = task->requestBarrier(); + ASSERT_TRUE(task->underBarrier()); + RowVectorPtr result; + do { + ContinueFuture dummyFuture{ContinueFuture::makeEmpty()}; + result = task->next(&dummyFuture); + ASSERT_FALSE(dummyFuture.valid()); + } while (result != nullptr); + auto taskStats = task->taskStats(); + ASSERT_EQ(taskStats.numBarriers, 1); + ASSERT_EQ(taskStats.numFinishedSplits, 1); + auto verifyExpressionStats = [nodeId = projectNodeId]( + const TaskStats& taskStats, + uint64_t expectedNumProcessedRows) { + ASSERT_EQ(taskStats.pipelineStats.size(), 1); + ASSERT_EQ(taskStats.pipelineStats[0].operatorStats.size(), 2); + auto& projectStats = taskStats.pipelineStats[0].operatorStats[1]; + ASSERT_EQ(projectStats.planNodeId, nodeId); + auto& expressionStats = projectStats.expressionStats; + // Ensure only the non-special form expression is tracked. + ASSERT_EQ(expressionStats.size(), 1); + auto it = expressionStats.find("plus"); + ASSERT_TRUE(it != expressionStats.end()); + ASSERT_EQ(it->second.numProcessedRows, expectedNumProcessedRows); + }; + verifyExpressionStats(taskStats, 10); + ASSERT_TRUE(barrierFuture.isReady()); + barrierFuture.wait(); + task->addSplit( + scanId, exec::Split(makeHiveConnectorSplit(filePath->getPath()))); + task->noMoreSplits(scanId); + do { + result = task->next(); + } while (result != nullptr); + VELOX_CHECK(waitForTaskCompletion(task.get())); + taskStats = task->taskStats(); + ASSERT_EQ(taskStats.numFinishedSplits, 2); + verifyExpressionStats(taskStats, 20); +} + +DEBUG_ONLY_TEST_F(TaskTest, taskExecutionEndTime) { + std::vector vectors; + std::vector> tempFiles; + const int numSplits{5}; + const int numRowsPerSplit{1'000}; + for (int32_t i = 0; i < numSplits; ++i) { + vectors.push_back(makeRowVector( + {makeFlatVector(numRowsPerSplit, [](auto row) { return row; }), + makeFlatVector( + numRowsPerSplit, [](auto row) { return row * 2; })})); + tempFiles.push_back(TempFilePath::create()); + } + writeToFiles(toFilePaths(tempFiles), vectors); + createDuckDbTable(vectors); + + const int injectedDelaySecs{2}; + std::atomic_bool injectDelayOnce{true}; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::Driver::runInternal::getOutput", + std::function( + ([&](const velox::exec::Operator* op) { + if (op->operatorCtx()->operatorType() == "HashAggregation") { + return; + } + auto task = op->operatorCtx()->task(); + if (!task->testingAllSplitsFinished()) { + return; + } + if (!injectDelayOnce.exchange(false)) { + return; + } + std::this_thread::sleep_for( + std::chrono::seconds(injectedDelaySecs)); // No Lint. + }))); + + core::PlanNodeId tableScanNodeId; + auto plan = test::PlanBuilder() + .tableScan(asRowType(vectors.back()->type())) + .capturePlanNodeId(tableScanNodeId) + .singleAggregation({"c0"}, {"sum(c1)"}) + .planNode(); + auto task = AssertQueryBuilder(plan, duckDbQueryRunner_) + .splits(tableScanNodeId, makeHiveConnectorSplits(tempFiles)) + .assertResults("SELECT c0, sum(c1) FROM tmp GROUP BY c0"); + const auto taskStats = task->taskStats(); + ASSERT_GE( + taskStats.executionEndTimeMs - taskStats.executionStartTimeMs, + injectedDelaySecs * 1'000); +} + +DEBUG_ONLY_TEST_F(TaskTest, operatorShouldYieldMethod) { + auto planNodeIdGenerator = std::make_shared(); + const auto data = makeRowVector({ + makeFlatVector(3, [](auto row) { return row; }), + }); + const uint64_t kDriverCpuTimeSliceLimitMs = 100; + + struct { + bool hasDelay; + std::string debugString() const { + return fmt::format("hasDelay: {}", hasDelay); + } + } testSetting[]{{true}, {false}}; + + for (const auto& testData : testSetting) { + SCOPED_TRACE(testData.debugString()); + std::atomic shouldYieldResult{false}; + + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::Values::getOutput", + std::function( + [&](const exec::Values* values) { + auto testShouldYieldOp = + std::make_unique( + 0, + values->operatorCtx()->driverCtx(), + ROW({"c0"}, {BIGINT()}), + planNodeIdGenerator->next()); + + testShouldYieldOp->addInput( + makeRowVector({makeFlatVector({1})})); + if (testData.hasDelay) { + std::this_thread::sleep_for( + std::chrono::milliseconds(2 * kDriverCpuTimeSliceLimitMs)); + } + + // This will test shouldYield() internally + testShouldYieldOp->getOutput(); + shouldYieldResult = testShouldYieldOp->getShouldYieldResult(); + })); + + auto queryCtx = core::QueryCtx::create( + executor_.get(), + core::QueryConfig({ + {core::QueryConfig::kDriverCpuTimeSliceLimitMs, + folly::to(kDriverCpuTimeSliceLimitMs)}, + })); + + auto plan = PlanBuilder(planNodeIdGenerator).values({data}).planNode(); + AssertQueryBuilder(plan).queryCtx(queryCtx).copyResults(pool()); + + ASSERT_EQ(testData.hasDelay, shouldYieldResult.load()); + } +} + } // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/ThreadDebugInfoTest.cpp b/velox/exec/tests/ThreadDebugInfoTest.cpp index 75bab5b37593..6bcb1801e145 100644 --- a/velox/exec/tests/ThreadDebugInfoTest.cpp +++ b/velox/exec/tests/ThreadDebugInfoTest.cpp @@ -63,6 +63,7 @@ template struct InduceSegFaultFunction { template void call(TResult& out, const TInput& in) { + LOG(ERROR) << "error"; int* nullpointer = nullptr; *nullpointer = 6; } @@ -117,9 +118,10 @@ DEBUG_ONLY_TEST_F(ThreadDebugInfoDeathTest, withinTheCallingThread) { #ifndef IS_BUILDING_WITH_SAN ASSERT_DEATH( - (task->next()), + task->next(), ".*Fatal signal handler. Query Id= TaskCursorQuery_0 Task Id= single.execution.task.0.*"); #endif + task->requestCancel(); } DEBUG_ONLY_TEST_F(ThreadDebugInfoDeathTest, noThreadContextSet) { diff --git a/velox/exec/tests/TopNRowNumberTest.cpp b/velox/exec/tests/TopNRowNumberTest.cpp index fac57658c2c1..7b21a3dee4b5 100644 --- a/velox/exec/tests/TopNRowNumberTest.cpp +++ b/velox/exec/tests/TopNRowNumberTest.cpp @@ -29,12 +29,25 @@ namespace { class TopNRowNumberTest : public OperatorTestBase { protected: - TopNRowNumberTest() { + explicit TopNRowNumberTest(core::TopNRowNumberNode::RankFunction function) + : functionName_(core::TopNRowNumberNode::rankFunctionName(function)) {} + + void SetUp() override { + exec::test::OperatorTestBase::SetUp(); filesystems::registerLocalFileSystem(); } + + const std::string functionName_; }; -TEST_F(TopNRowNumberTest, basic) { +class MultiTopNRowNumberTest : public TopNRowNumberTest, + public testing::WithParamInterface< + core::TopNRowNumberNode::RankFunction> { + public: + MultiTopNRowNumberTest() : TopNRowNumberTest(GetParam()) {} +}; + +TEST_P(MultiTopNRowNumberTest, basic) { auto data = makeRowVector({ // Partitioning key. makeFlatVector({1, 1, 2, 2, 1, 2, 1}), @@ -50,38 +63,41 @@ TEST_F(TopNRowNumberTest, basic) { // Emit row numbers. auto plan = PlanBuilder() .values({data}) - .topNRowNumber({"c0"}, {"c1"}, limit, true) + .topNRank(functionName_, {"c0"}, {"c1"}, limit, true) .planNode(); assertQuery( plan, fmt::format( - "SELECT * FROM (SELECT *, row_number() over (partition by c0 order by c1) as rn FROM tmp) " + "SELECT * FROM (SELECT *, {}() over (partition by c0 order by c1) as rn FROM tmp) " " WHERE rn <= {}", + functionName_, limit)); // Do not emit row numbers. plan = PlanBuilder() .values({data}) - .topNRowNumber({"c0"}, {"c1"}, limit, false) + .topNRank(functionName_, {"c0"}, {"c1"}, limit, false) .planNode(); assertQuery( plan, fmt::format( - "SELECT c0, c1, c2 FROM (SELECT *, row_number() over (partition by c0 order by c1) as rn FROM tmp) " + "SELECT c0, c1, c2 FROM (SELECT *, {}() over (partition by c0 order by c1) as rn FROM tmp) " " WHERE rn <= {}", + functionName_, limit)); // No partitioning keys. plan = PlanBuilder() .values({data}) - .topNRowNumber({}, {"c1"}, limit, true) + .topNRank(functionName_, {}, {"c1"}, limit, true) .planNode(); assertQuery( plan, fmt::format( - "SELECT * FROM (SELECT *, row_number() over (order by c1) as rn FROM tmp) " + "SELECT * FROM (SELECT *, {}() over (order by c1) as rn FROM tmp) " " WHERE rn <= {}", + functionName_, limit)); }; @@ -91,7 +107,7 @@ TEST_F(TopNRowNumberTest, basic) { testLimit(5); } -TEST_F(TopNRowNumberTest, largeOutput) { +TEST_P(MultiTopNRowNumberTest, largeOutput) { // Make 10 vectors. Use different types for partitioning key, sorting key and // data. Use order of columns different from partitioning keys, followed by // sorting keys, followed by data. @@ -119,13 +135,14 @@ TEST_F(TopNRowNumberTest, largeOutput) { core::PlanNodeId topNRowNumberId; auto plan = PlanBuilder() .values(data) - .topNRowNumber({"p"}, {"s"}, limit, true) + .topNRank(functionName_, {"p"}, {"s"}, limit, true) .capturePlanNodeId(topNRowNumberId) .planNode(); auto sql = fmt::format( - "SELECT * FROM (SELECT *, row_number() over (partition by p order by s) as rn FROM tmp) " + "SELECT * FROM (SELECT *, {}() over (partition by p order by s) as rn FROM tmp) " " WHERE rn <= {}", + functionName_, limit); AssertQueryBuilder(plan, duckDbQueryRunner_) .config(core::QueryConfig::kPreferredOutputBatchBytes, "1024") @@ -154,15 +171,17 @@ TEST_F(TopNRowNumberTest, largeOutput) { // No partitioning keys. plan = PlanBuilder() .values(data) - .topNRowNumber({}, {"s"}, limit, true) + .topNRank(functionName_, {}, {"s"}, limit, true) .planNode(); AssertQueryBuilder(plan, duckDbQueryRunner_) .config(core::QueryConfig::kPreferredOutputBatchBytes, "1024") - .assertResults(fmt::format( - "SELECT * FROM (SELECT *, row_number() over (order by s) as rn FROM tmp) " - " WHERE rn <= {}", - limit)); + .assertResults( + fmt::format( + "SELECT * FROM (SELECT *, {}() over (order by s) as rn FROM tmp) " + " WHERE rn <= {}", + functionName_, + limit)); }; testLimit(1); @@ -172,7 +191,7 @@ TEST_F(TopNRowNumberTest, largeOutput) { testLimit(2000); } -TEST_F(TopNRowNumberTest, manyPartitions) { +TEST_P(MultiTopNRowNumberTest, manyPartitions) { const vector_size_t size = 10'000; auto data = split( makeRowVector( @@ -203,13 +222,14 @@ TEST_F(TopNRowNumberTest, manyPartitions) { core::PlanNodeId topNRowNumberId; auto plan = PlanBuilder() .values(data) - .topNRowNumber({"p"}, {"s"}, limit, true) + .topNRank(functionName_, {"p"}, {"s"}, limit, true) .capturePlanNodeId(topNRowNumberId) .planNode(); auto sql = fmt::format( - "SELECT * FROM (SELECT *, row_number() over (partition by p order by s) as rn FROM tmp) " + "SELECT * FROM (SELECT *, {}() over (partition by p order by s) as rn FROM tmp) " " WHERE rn <= {}", + functionName_, limit); assertQuery(plan, sql); @@ -243,7 +263,7 @@ TEST_F(TopNRowNumberTest, manyPartitions) { testLimit(1, 1); } -TEST_F(TopNRowNumberTest, fewPartitions) { +TEST_P(MultiTopNRowNumberTest, fewPartitions) { const vector_size_t size = 10'000; auto data = split( makeRowVector( @@ -274,13 +294,14 @@ TEST_F(TopNRowNumberTest, fewPartitions) { core::PlanNodeId topNRowNumberId; auto plan = PlanBuilder() .values(data) - .topNRowNumber({"p"}, {"s"}, limit, true) + .topNRank(functionName_, {"p"}, {"s"}, limit, true) .capturePlanNodeId(topNRowNumberId) .planNode(); auto sql = fmt::format( - "SELECT * FROM (SELECT *, row_number() over (partition by p order by s) as rn FROM tmp) " + "SELECT * FROM (SELECT *, {}() over (partition by p order by s) as rn FROM tmp) " " WHERE rn <= {}", + functionName_, limit); assertQuery(plan, sql); @@ -312,7 +333,7 @@ TEST_F(TopNRowNumberTest, fewPartitions) { testLimit(100); } -TEST_F(TopNRowNumberTest, abandonPartialEarly) { +TEST_P(MultiTopNRowNumberTest, abandonPartialEarly) { auto data = makeRowVector( {"p", "s"}, { @@ -326,9 +347,9 @@ TEST_F(TopNRowNumberTest, abandonPartialEarly) { auto runPlan = [&](int32_t minRows) { auto plan = PlanBuilder() .values(split(data, 10)) - .topNRowNumber({"p"}, {"s"}, 99, false) + .topNRank(functionName_, {"p"}, {"s"}, 99, false) .capturePlanNodeId(topNRowNumberId) - .topNRowNumber({"p"}, {"s"}, 99, true) + .topNRank(functionName_, {"p"}, {"s"}, 99, true) .planNode(); auto task = AssertQueryBuilder(plan, duckDbQueryRunner_) @@ -337,8 +358,10 @@ TEST_F(TopNRowNumberTest, abandonPartialEarly) { fmt::format("{}", minRows)) .config(core::QueryConfig::kAbandonPartialTopNRowNumberMinPct, "80") .assertResults( - "SELECT * FROM (SELECT *, row_number() over (partition by p order by s) as rn FROM tmp) " - "WHERE rn <= 99"); + fmt::format( + "SELECT * FROM (SELECT *, {}() over (partition by p order by s) as rn FROM tmp) " + "WHERE rn <= 99", + functionName_)); return exec::toPlanStats(task->taskStats()); }; @@ -360,7 +383,7 @@ TEST_F(TopNRowNumberTest, abandonPartialEarly) { } } -TEST_F(TopNRowNumberTest, planNodeValidation) { +TEST_P(MultiTopNRowNumberTest, planNodeValidation) { auto data = makeRowVector( ROW({"a", "b", "c", "d", "e"}, { @@ -377,7 +400,7 @@ TEST_F(TopNRowNumberTest, planNodeValidation) { int32_t limit = 10) { PlanBuilder() .values({data}) - .topNRowNumber(partitionKeys, sortingKeys, limit, true) + .topNRank(functionName_, partitionKeys, sortingKeys, limit, true) .planNode(); }; @@ -403,15 +426,16 @@ TEST_F(TopNRowNumberTest, planNodeValidation) { plan({"a", "b"}, {"c"}, 0), "Limit must be greater than zero"); } -TEST_F(TopNRowNumberTest, maxSpillBytes) { +TEST_P(MultiTopNRowNumberTest, maxSpillBytes) { const auto rowType = ROW({"c0", "c1", "c2"}, {INTEGER(), INTEGER(), VARCHAR()}); const auto vectors = createVectors(rowType, 1024, 15 << 20); auto planNodeIdGenerator = std::make_shared(); auto plan = PlanBuilder(planNodeIdGenerator) .values(vectors) - .topNRowNumber({"c0"}, {"c1"}, 100, true) + .topNRank(functionName_, {"c0"}, {"c1"}, 100, true) .planNode(); + struct { int32_t maxSpilledBytes; bool expectedExceedLimit; @@ -451,7 +475,7 @@ TEST_F(TopNRowNumberTest, maxSpillBytes) { // This test verifies that TopNRowNumber operator reclaim all the memory after // spill. -DEBUG_ONLY_TEST_F(TopNRowNumberTest, memoryUsageCheckAfterReclaim) { +DEBUG_ONLY_TEST_P(MultiTopNRowNumberTest, memoryUsageCheckAfterReclaim) { std::atomic_int inputCount{0}; SCOPED_TESTVALUE_SET( "facebook::velox::exec::Driver::runInternal::addInput", @@ -496,13 +520,14 @@ DEBUG_ONLY_TEST_F(TopNRowNumberTest, memoryUsageCheckAfterReclaim) { core::PlanNodeId topNRowNumberId; auto plan = PlanBuilder() .values(data) - .topNRowNumber({"p"}, {"s"}, 1'000, true) + .topNRank(functionName_, {"p"}, {"s"}, 1'000, true) .capturePlanNodeId(topNRowNumberId) .planNode(); - const auto sql = - "SELECT * FROM (SELECT *, row_number() over (partition by p order by s) as rn FROM tmp) " - " WHERE rn <= 1000"; + const auto sql = fmt::format( + "SELECT * FROM (SELECT *, {}() over (partition by p order by s) as rn FROM tmp) " + " WHERE rn <= 1000", + functionName_); auto task = AssertQueryBuilder(plan, duckDbQueryRunner_) .config(core::QueryConfig::kSpillEnabled, "true") .config(core::QueryConfig::kTopNRowNumberSpillEnabled, "true") @@ -520,7 +545,7 @@ DEBUG_ONLY_TEST_F(TopNRowNumberTest, memoryUsageCheckAfterReclaim) { // This test verifies that TopNRowNumber operator can be closed twice which // might be triggered by memory pool abort. -DEBUG_ONLY_TEST_F(TopNRowNumberTest, doubleClose) { +DEBUG_ONLY_TEST_P(MultiTopNRowNumberTest, doubleClose) { const std::string errorMessage("doubleClose"); SCOPED_TESTVALUE_SET( "facebook::velox::exec::Driver::runInternal::noMoreInput", @@ -556,15 +581,72 @@ DEBUG_ONLY_TEST_F(TopNRowNumberTest, doubleClose) { core::PlanNodeId topNRowNumberId; auto plan = PlanBuilder() .values(data) - .topNRowNumber({"p"}, {"s"}, 1'000, true) + .topNRank(functionName_, {"p"}, {"s"}, 1'000, true) .capturePlanNodeId(topNRowNumberId) .planNode(); - const auto sql = - "SELECT * FROM (SELECT *, row_number() over (partition by p order by s) as rn FROM tmp) " - " WHERE rn <= 1000"; + const auto sql = fmt::format( + "SELECT * FROM (SELECT *, {}() over (partition by p order by s) as rn FROM tmp) " + " WHERE rn <= 1000", + functionName_); VELOX_ASSERT_THROW(assertQuery(plan, sql), errorMessage); } + +// This test verifies that TopNRowNumber operator handles OOM that occurs in the +// middle of groupProbe, after inserting some new rows into the row container. +DEBUG_ONLY_TEST_P(MultiTopNRowNumberTest, oomInGroupProbe) { + const std::string errorMessage("Simulated OOM in groupProbe"); + std::atomic_int insertCount{0}; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::HashTable::insertEntry", + std::function( + ([&](memory::MemoryPool* /*pool*/) { + // Trigger OOM after inserting some rows to simulate failure in the + // middle of groupProbe insertion. + if (++insertCount == 100) { + VELOX_FAIL(errorMessage); + } + }))); + + const vector_size_t size = 10'000; + auto data = split( + makeRowVector( + {"d", "s", "p"}, + { + // Data. + makeFlatVector( + size, [](auto row) { return row; }, nullEvery(11)), + // Sorting key. + makeFlatVector( + size, + [](auto row) { return (size - row) * 10; }, + [](auto row) { return row == 123; }), + // Partitioning key. Make sure to spread rows from the same + // partition across multiple batches. + makeFlatVector( + size, [](auto row) { return row % 5'000; }, nullEvery(7)), + }), + 10); + + core::PlanNodeId topNRowNumberId; + auto plan = PlanBuilder() + .values(data) + .topNRank(functionName_, {"p"}, {"s"}, 1'000, true) + .capturePlanNodeId(topNRowNumberId) + .planNode(); + + VELOX_ASSERT_THROW( + AssertQueryBuilder(plan).copyResults(pool_.get()), errorMessage); +} + +VELOX_INSTANTIATE_TEST_SUITE_P( + TopNRowNumberTest, + MultiTopNRowNumberTest, + testing::ValuesIn( + std::vector( + {core::TopNRowNumberNode::RankFunction::kRowNumber, + core::TopNRowNumberNode::RankFunction::kRank, + core::TopNRowNumberNode::RankFunction::kDenseRank}))); } // namespace } // namespace facebook::velox::exec diff --git a/velox/exec/tests/TraceUtilTest.cpp b/velox/exec/tests/TraceUtilTest.cpp index 14a62f2159c4..acb94280ba52 100644 --- a/velox/exec/tests/TraceUtilTest.cpp +++ b/velox/exec/tests/TraceUtilTest.cpp @@ -139,6 +139,29 @@ TEST_F(TraceUtilTest, traceDirectoryLayoutUtilities) { "/traceRoot/queryId/taskId/1/1/1/op_split_trace.split"); } +TEST_F(TraceUtilTest, traceDirectoryTrailingSlashHandling) { + const std::string queryId = "queryId"; + const std::string taskId = "taskId"; + + const std::string expectedPath = "/traceRoot/queryId"; + const std::string expectedTaskPath = "/traceRoot/queryId/taskId"; + + // Test with trailing slash + const std::string traceDirWithSlash = "/traceRoot/"; + ASSERT_EQ(getQueryTraceDirectory(traceDirWithSlash, queryId), expectedPath); + ASSERT_EQ( + getTaskTraceDirectory(traceDirWithSlash, queryId, taskId), + expectedTaskPath); + + // Test without trailing slash + const std::string traceDirWithoutSlash = "/traceRoot"; + ASSERT_EQ( + getQueryTraceDirectory(traceDirWithoutSlash, queryId), expectedPath); + ASSERT_EQ( + getTaskTraceDirectory(traceDirWithoutSlash, queryId, taskId), + expectedTaskPath); +} + TEST_F(TraceUtilTest, getTaskIds) { const auto rootDir = TempDirectoryPath::create(); const auto rootPath = rootDir->getPath(); diff --git a/velox/exec/tests/TreeOfLosersTest.cpp b/velox/exec/tests/TreeOfLosersTest.cpp index 8e193e40fcd2..0a56a4f49701 100644 --- a/velox/exec/tests/TreeOfLosersTest.cpp +++ b/velox/exec/tests/TreeOfLosersTest.cpp @@ -181,8 +181,9 @@ TEST_F(TreeOfLosersTest, allSorted) { TEST_F(TreeOfLosersTest, allEmpty) { for (bool testNextEqual : {false, true}) { for (int numStreams : {0, 1, 5, 100}) { - SCOPED_TRACE(fmt::format( - "numStreams: {}, testNextEqual", numStreams, testNextEqual)); + SCOPED_TRACE( + fmt::format( + "numStreams: {}, testNextEqual", numStreams, testNextEqual)); std::vector> mergeStreams; for (int i = 0; i < numStreams; ++i) { mergeStreams.push_back( @@ -211,12 +212,13 @@ TEST_F(TreeOfLosersTest, randomWithDuplicates) { for (int iter = 0; iter < 10; ++iter) { const int numCount = std::max(1, folly::Random::rand32(1000'000)); const int numStreams = std::max(3, folly::Random::rand32(100)); - SCOPED_TRACE(fmt::format( - "iter: {}, testNextEqual: {}, numCount: {}, numStreams: {}", - iter, - testNextEqual, - numCount, - numStreams)); + SCOPED_TRACE( + fmt::format( + "iter: {}, testNextEqual: {}, numCount: {}, numStreams: {}", + iter, + testNextEqual, + numCount, + numStreams)); std::vector> streamNumVectors(numStreams); for (int i = 0; i < numCount; ++i) { const int streamIndex = folly::Random::rand32(numStreams); diff --git a/velox/exec/tests/UnnestTest.cpp b/velox/exec/tests/UnnestTest.cpp index a3bc2b20b8f2..7d8e1b7d6b78 100644 --- a/velox/exec/tests/UnnestTest.cpp +++ b/velox/exec/tests/UnnestTest.cpp @@ -67,6 +67,160 @@ TEST_P(UnnestTest, basicArray) { assertQuery(params, "SELECT c0, UNNEST(c1) FROM tmp WHERE c0 % 7 > 0"); } +TEST_P(UnnestTest, arrayWithIdentityMap) { + struct { + int32_t vectorSize; + int32_t arraySize; + int32_t outputBatchSize; + int32_t expectedOutputBatches; + + std::string debugString() const { + return fmt::format( + "vectorSize: {}, arraySize: {}, outputBatchSize: {}, expectedOutputBatches: {}", + vectorSize, + arraySize, + outputBatchSize, + expectedOutputBatches); + } + } testSettings[] = { + {100, 1, 1, 100}, + {100, 1, 100, 1}, + {100, 4, 1, 400}, + {100, 4, 100, 4}, + {100, 4, 400, 1}, + {1024, 4, 256, 16}, + {1024, 4, 1024, 4}, + {1024, 4, 4096, 1}, + {1024, 1, 256, 4}, + {1024, 4, 7, 586}, + {1024, 1, 7, 147}}; + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + + auto vector = makeRowVector( + {makeFlatVector( + testData.vectorSize, [](auto row) { return row; }), + makeArrayVector( + testData.vectorSize, + [&](auto /*unused*/) { return testData.arraySize; }, + [](auto row, auto index) { return index * (row % 3); })}); + createDuckDbTable({vector}); + + auto planNodeIdGenerator = std::make_shared(); + core::PlanNodeId unnestPlanNodeId; + auto plan = PlanBuilder(planNodeIdGenerator) + .values({vector}) + .unnest({"c0"}, {"c1"}) + .capturePlanNodeId(unnestPlanNodeId) + .planNode(); + + auto task = AssertQueryBuilder(plan, duckDbQueryRunner_) + .config( + core::QueryConfig::kPreferredOutputBatchRows, + std::to_string(testData.outputBatchSize)) + .assertResults({"SELECT c0, UNNEST(c1) FROM tmp"}); + const auto taskStats = task->taskStats(); + ASSERT_EQ( + exec::toPlanStats(taskStats).at(unnestPlanNodeId).outputVectors, + testData.expectedOutputBatches); + } +} + +TEST_P(UnnestTest, arrayWithoutIdentityMap) { + struct { + int32_t vectorSize; + int32_t arraySize1; + int32_t arraySize2; + int32_t outputBatchSize; + int32_t expectedOutputBatches; + + std::string debugString() const { + return fmt::format( + "vectorSize: {}, arraySize1: {}, arraySize2: {}, outputBatchSize: {}, expectedOutputBatches: {}", + vectorSize, + arraySize1, + arraySize2, + outputBatchSize, + expectedOutputBatches); + } + } testSettings[] = { + {100, 1, 2, 100, 2}, + {100, 1, 2, 200, 1}, + {1024, 1, 4, 256, 16}, + {1024, 3, 4, 256, 16}, + {1024, 1, 4, 4096, 1}, + {1024, 3, 4, 4096, 1}, + {1024, 1, 4, 7, 586}, + {1024, 3, 4, 7, 586}}; + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + + auto vector = makeRowVector( + {makeFlatVector( + testData.vectorSize, [](auto row) { return row; }), + makeArrayVector( + testData.vectorSize, + [&](auto /*unused*/) { return testData.arraySize1; }, + [](auto row, auto index) { return index * (row % 3); }), + makeArrayVector( + testData.vectorSize, + [&](auto /*unused*/) { return testData.arraySize2; }, + [](auto row, auto index) { return index * (row % 3); })}); + createDuckDbTable({vector}); + + auto planNodeIdGenerator = std::make_shared(); + core::PlanNodeId unnestPlanNodeId; + auto plan = PlanBuilder(planNodeIdGenerator) + .values({vector}) + .unnest({"c0"}, {"c1", "c2"}) + .capturePlanNodeId(unnestPlanNodeId) + .planNode(); + + auto task = + AssertQueryBuilder(plan, duckDbQueryRunner_) + .config( + core::QueryConfig::kPreferredOutputBatchRows, + std::to_string(testData.outputBatchSize)) + .assertResults({"SELECT c0, UNNEST(c1), UNNEST(c2) FROM tmp"}); + const auto taskStats = task->taskStats(); + ASSERT_EQ( + exec::toPlanStats(taskStats).at(unnestPlanNodeId).outputVectors, + testData.expectedOutputBatches); + } +} + +TEST_P(UnnestTest, arrayWithNull) { + const auto vector = makeRowVector( + {makeFlatVector(1024, [](auto row) { return row; }), + makeArrayVector( + 1024, + [&](auto /*unused*/) { return 3; }, + [](auto row, auto index) { return index * (row % 3); }, + nullEvery(6)), + makeArrayVector( + 1024, + [&](auto /*unused*/) { return 4; }, + [](auto row, auto index) { return index * (row % 3); })}); + createDuckDbTable({vector}); + + auto planNodeIdGenerator = std::make_shared(); + core::PlanNodeId unnestPlanNodeId; + auto plan = PlanBuilder(planNodeIdGenerator) + .values({vector}) + .unnest({"c0"}, {"c1", "c2"}) + .capturePlanNodeId(unnestPlanNodeId) + .planNode(); + + auto task = + AssertQueryBuilder(plan, duckDbQueryRunner_) + .config( + core::QueryConfig::kPreferredOutputBatchRows, std::to_string(25)) + .assertResults({"SELECT c0, UNNEST(c1), UNNEST(c2) FROM tmp"}); + const auto taskStats = task->taskStats(); + ASSERT_EQ( + exec::toPlanStats(taskStats).at(unnestPlanNodeId).outputVectors, 164); +} + TEST_P(UnnestTest, arrayWithOrdinality) { auto array = vectorMaker_.arrayVectorNullable( {{{1, 2, std::nullopt, 4}}, @@ -131,6 +285,119 @@ TEST_P(UnnestTest, arrayWithOrdinality) { assertQuery(params, expectedInDict); } +TEST_P(UnnestTest, arrayWithMarker) { + const auto array = makeArrayVectorFromJson( + {"[1, 2, null, 4]", "null", "[5, 6]", "[]", "[null]", "[7, 8, 9]"}); + const auto input = makeRowVector( + {makeNullableFlatVector({1.1, 2.2, 3.3, 4.4, 5.5, std::nullopt}), + array}); + + const auto expected = makeRowVector( + {makeNullableFlatVector( + {1.1, + 1.1, + 1.1, + 1.1, + 3.3, + 3.3, + 5.5, + std::nullopt, + std::nullopt, + std::nullopt}), + makeNullableFlatVector( + {1, 2, std::nullopt, 4, 5, 6, std::nullopt, 7, 8, 9})}); + + const auto expectedWithOrdinality = makeRowVector( + {expected->childAt(0), + expected->childAt(1), + makeNullableFlatVector({1, 2, 3, 4, 1, 2, 1, 1, 2, 3})}); + + const auto expectedWithMarker = makeRowVector( + {makeNullableFlatVector( + {1.1, + 1.1, + 1.1, + 1.1, + 2.2, + 3.3, + 3.3, + 4.4, + 5.5, + std::nullopt, + std::nullopt, + std::nullopt}), + makeNullableFlatVector( + {1, + 2, + std::nullopt, + 4, + std::nullopt, + 5, + 6, + std::nullopt, + std::nullopt, + 7, + 8, + 9}), + makeNullableFlatVector( + {true, + true, + true, + true, + false, + true, + true, + false, + true, + true, + true, + true})}); + const auto expectedWithBoth = makeRowVector( + {expectedWithMarker->childAt(0), + expectedWithMarker->childAt(1), + makeNullableFlatVector({1, 2, 3, 4, 0, 1, 2, 0, 1, 1, 2, 3}), + expectedWithMarker->childAt(2)}); + + struct { + bool hasOrdinality; + bool hasEmptyUnnestValue; + RowVectorPtr input; + RowVectorPtr expected; + + std::string debugString() const { + return fmt::format( + "hasOrdinality: {}, hasEmptyUnnestValue: {}, input: {}, expected: {}", + hasOrdinality, + hasEmptyUnnestValue, + input->toString(0, input->size()), + expected->toString(0, expected->size())); + } + } testSettings[] = { + {false, false, input, expected}, + {true, false, input, expectedWithOrdinality}, + {false, true, input, expectedWithMarker}, + {true, true, input, expectedWithBoth}}; + + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + + std::optional ordinalityName; + if (testData.hasOrdinality) { + ordinalityName = "ordinal"; + } + std::optional markerName; + if (testData.hasEmptyUnnestValue) { + markerName = "emptyUnnestValue"; + } + auto op = PlanBuilder() + .values({testData.input}) + .unnest({"c0"}, {"c1"}, ordinalityName, markerName) + .planNode(); + auto params = makeCursorParameters(op); + assertQuery(params, testData.expected); + } +} + TEST_P(UnnestTest, basicMap) { auto vector = makeRowVector( {makeFlatVector(100, [](auto row) { return row; }), @@ -199,6 +466,92 @@ TEST_P(UnnestTest, mapWithOrdinality) { assertQuery(params, expectedInDict); } +TEST_P(UnnestTest, mapWithMarker) { + const auto map = makeNullableMapVector( + {{{{1, 1.1}, {2, std::nullopt}}}, + common::testutil::optionalEmpty, + {{{3, 3.3}, {4, 4.4}, {5, 5.5}}}, + std::nullopt, + {{{6, std::nullopt}}}}); + + const auto input = + makeRowVector({makeNullableFlatVector({1, 2, 3, 4, 5}), map}); + + const auto expected = makeRowVector( + {makeNullableFlatVector({1, 1, 3, 3, 3, 5}), + makeNullableFlatVector({1, 2, 3, 4, 5, 6}), + makeNullableFlatVector( + {1.1, std::nullopt, 3.3, 4.4, 5.5, std::nullopt})}); + + const auto expectedWithOrdinality = makeRowVector( + {expected->childAt(0), + expected->childAt(1), + expected->childAt(2), + makeNullableFlatVector({1, 2, 1, 2, 3, 1})}); + + const auto expectedWithMarker = makeRowVector( + {makeNullableFlatVector({1, 1, 2, 3, 3, 3, 4, 5}), + makeNullableFlatVector( + {1, 2, std::nullopt, 3, 4, 5, std::nullopt, 6}), + makeNullableFlatVector( + {1.1, + std::nullopt, + std::nullopt, + 3.3, + 4.4, + 5.5, + std::nullopt, + std::nullopt}), + makeNullableFlatVector( + {true, true, false, true, true, true, false, true})}); + + const auto expectedWithBoth = makeRowVector( + {expectedWithMarker->childAt(0), + expectedWithMarker->childAt(1), + expectedWithMarker->childAt(2), + makeNullableFlatVector({1, 2, 0, 1, 2, 3, 0, 1}), + expectedWithMarker->childAt(3)}); + + struct { + bool hasOrdinality; + bool hasEmptyUnnestValue; + RowVectorPtr input; + RowVectorPtr expected; + + std::string debugString() const { + return fmt::format( + "hasOrdinality: {}, hasEmptyUnnestValue: {}, input: {}, expected: {}", + hasOrdinality, + hasEmptyUnnestValue, + input->toString(0, input->size()), + expected->toString(0, expected->size())); + } + } testSettings[] = { + {false, false, input, expected}, + {true, false, input, expectedWithOrdinality}, + {false, true, input, expectedWithMarker}, + {true, true, input, expectedWithBoth}}; + + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + + std::optional ordinalityName; + if (testData.hasOrdinality) { + ordinalityName = "ordinal"; + } + std::optional markerName; + if (testData.hasEmptyUnnestValue) { + markerName = "emptyUnnestValue"; + } + auto op = PlanBuilder() + .values({testData.input}) + .unnest({"c0"}, {"c1"}, ordinalityName, markerName) + .planNode(); + auto params = makeCursorParameters(op); + assertQuery(params, testData.expected); + } +} + TEST_P(UnnestTest, multipleColumns) { std::vector offsets(100, 0); for (int i = 1; i < 100; ++i) { @@ -225,8 +578,8 @@ TEST_P(UnnestTest, multipleColumns) { .unnest({"c0"}, {"c1", "c2", "c3"}) .planNode(); - // DuckDB doesn't support Unnest from MAP column. Hence,using 2 separate array - // columns with the keys and values part of the MAP to validate. + // DuckDB doesn't support Unnest from MAP column. Hence,using 2 separate + // array columns with the keys and values part of the MAP to validate. auto duckDbVector = makeRowVector( {makeFlatVector(100, [](auto row) { return row; }), makeArrayVector( @@ -278,8 +631,8 @@ TEST_P(UnnestTest, multipleColumnsWithOrdinality) { .unnest({"c0"}, {"c1", "c2", "c3"}, "ordinal") .planNode(); - // DuckDB doesn't support Unnest from MAP column. Hence,using 2 separate array - // columns with the keys and values part of the MAP to validate. + // DuckDB doesn't support Unnest from MAP column. Hence,using 2 separate + // array columns with the keys and values part of the MAP to validate. auto ordinalitySize = [&](auto row) { if (row % 42 == 0) { return offsets[row + 1] - offsets[row]; @@ -495,15 +848,16 @@ TEST_P(UnnestTest, barrier) { // Unnest 1K rows into 3K rows. auto planNodeIdGenerator = std::make_shared(); core::PlanNodeId unnestPlanNodeId; - const auto plan = PlanBuilder(planNodeIdGenerator) - .startTableScan() - .outputType(std::dynamic_pointer_cast( - vectors[0]->type())) - .endTableScan() - .project({"sequence(1, 3) as s"}) - .unnest({}, {"s"}) - .capturePlanNodeId(unnestPlanNodeId) - .planNode(); + const auto plan = + PlanBuilder(planNodeIdGenerator) + .startTableScan() + .outputType( + std::dynamic_pointer_cast(vectors[0]->type())) + .endTableScan() + .project({"sequence(1, 3) as s"}) + .unnest({}, {"s"}) + .capturePlanNodeId(unnestPlanNodeId) + .planNode(); const auto expectedResult = makeRowVector({ makeFlatVector( @@ -527,7 +881,7 @@ TEST_P(UnnestTest, barrier) { const int numExpectedOutputVectors = bits::divRoundUp(numRowsPerSplit * 3, testData.numOutputRows) * numSplits; - auto task = AssertQueryBuilder(plan, duckDbQueryRunner_) + auto task = AssertQueryBuilder(plan) .config(core::QueryConfig::kSparkPartitionId, "0") .config( core::QueryConfig::kMaxSplitPreloadPerDriver, @@ -545,18 +899,76 @@ TEST_P(UnnestTest, barrier) { ASSERT_EQ( exec::toPlanStats(taskStats).at(unnestPlanNodeId).outputRows, numSplits * numRowsPerSplit * 3); - // NOTE: unnest operator produce the same number of output batches no matter - // it is under barrier execution mode or not. + // NOTE: unnest operator produce the same number of output batches no + // matter it is under barrier execution mode or not. ASSERT_EQ( exec::toPlanStats(taskStats).at(unnestPlanNodeId).outputVectors, numExpectedOutputVectors); } } +TEST_P(UnnestTest, spiltOutput) { + std::vector vectors; + const auto numBatches = 3; + const auto inputBatchSize = 256; + for (int32_t i = 0; i < 3; ++i) { + auto vector = makeRowVector({ + makeFlatVector(inputBatchSize, [](auto row) { return row; }), + }); + vectors.push_back(vector); + } + createDuckDbTable(vectors); + + // Unnest 256 rows into 768 rows. + auto planNodeIdGenerator = std::make_shared(); + core::PlanNodeId unnestPlanNodeId; + const auto plan = PlanBuilder(planNodeIdGenerator) + .values(vectors) + .project({"sequence(1, 3) as s"}) + .unnest({}, {"s"}) + .capturePlanNodeId(unnestPlanNodeId) + .planNode(); + + const auto expectedResult = makeRowVector({ + makeFlatVector( + numBatches * 3 * inputBatchSize, + [](auto row) { return 1 + row % 3; }), + }); + + struct { + bool produceSingleOutput; + int expectedNumOutputExectors; + + std::string toString() const { + return fmt::format( + "produceSingleOutput {}, expectedNumOutputExectors {}", + produceSingleOutput, + expectedNumOutputExectors); + } + } testSettings[] = { + {true, numBatches}, + {false, bits::divRoundUp(inputBatchSize * 3, GetParam()) * numBatches}}; + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.toString()); + auto task = AssertQueryBuilder(plan) + .config( + core::QueryConfig::kPreferredOutputBatchRows, + std::to_string(GetParam())) + .config( + core::QueryConfig::kUnnestSplitOutput, + testData.produceSingleOutput ? "false" : "true") + .assertResults(expectedResult); + const auto taskStats = task->taskStats(); + ASSERT_EQ( + exec::toPlanStats(taskStats).at(unnestPlanNodeId).outputVectors, + testData.expectedNumOutputExectors); + } +} + VELOX_INSTANTIATE_TEST_SUITE_P( UnnestTest, UnnestTest, - testing::ValuesIn(/*batchSize*/ {2, 17, 33, 1024}), + testing::ValuesIn(/*batchSize*/ {2, 17, 33, 512}), [](const testing::TestParamInfo& info) { return fmt::format("outputBatchSize_{}", info.param); }); diff --git a/velox/exec/tests/VectorHasherTest.cpp b/velox/exec/tests/VectorHasherTest.cpp index 2a3cb8ef9f78..cb3224162f40 100644 --- a/velox/exec/tests/VectorHasherTest.cpp +++ b/velox/exec/tests/VectorHasherTest.cpp @@ -462,14 +462,15 @@ TEST_F(VectorHasherTest, stringDistinctOverflow) { for (auto i = 0; i < 7; ++i) { auto& stringVec = strings[i]; stringVec.resize(numRows); - batches.emplace_back(makeFlatVector( - numRows, [&i, &stringVec, numRows](vector_size_t row) { - const auto num = numRows * i + row; - stringVec[row] = (row != 0) - ? fmt::format("abcdefghijabcdefghij{}", num) - : fmt::format("s{}", num); - return StringView(stringVec[row]); - })); + batches.emplace_back( + makeFlatVector( + numRows, [&i, &stringVec, numRows](vector_size_t row) { + const auto num = numRows * i + row; + stringVec[row] = (row != 0) + ? fmt::format("abcdefghijabcdefghij{}", num) + : fmt::format("s{}", num); + return StringView(stringVec[row]); + })); } SelectivityVector rows(numRows, true); @@ -680,9 +681,9 @@ TEST_F(VectorHasherTest, merge) { VectorHasher emptyHasher(BIGINT(), 0); VectorHasher otherEmptyHasher(BIGINT(), 0); EXPECT_TRUE(emptyHasher.empty()); - emptyHasher.merge(otherHasher); - hasher.merge(emptyHasher); - hasher.merge(otherEmptyHasher); + emptyHasher.merge(otherHasher, 1'000'000); + hasher.merge(emptyHasher, 1'000'000); + hasher.merge(otherEmptyHasher, 1'000'000); uint64_t numRange; uint64_t numDistinct; hasher.cardinality(0, numRange, numDistinct); @@ -720,6 +721,45 @@ TEST_F(VectorHasherTest, merge) { EXPECT_EQ(numDistinct - 1, ids.size()); } +TEST_F(VectorHasherTest, mergeMaxNumDistinct) { + constexpr vector_size_t kSize = 100; + SelectivityVector rows(kSize); + raw_vector hashes(kSize); + + auto vector1 = + makeFlatVector(kSize, [](vector_size_t row) { return row; }); + auto vector2 = makeFlatVector( + kSize, [](vector_size_t row) { return 1000 + row; }); + auto vector3 = makeFlatVector( + kSize, [](vector_size_t row) { return 2000 + row; }); + + VectorHasher hasher1(BIGINT(), 0); + hasher1.decode(*vector1, rows); + hasher1.computeValueIds(rows, hashes); + + VectorHasher hasher2(BIGINT(), 0); + hasher2.decode(*vector2, rows); + hasher2.computeValueIds(rows, hashes); + + VectorHasher hasher3(BIGINT(), 0); + hasher3.decode(*vector3, rows); + hasher3.computeValueIds(rows, hashes); + + hasher1.merge(hasher2, kSize * 2); + uint64_t numRange; + uint64_t numDistinct; + hasher1.cardinality(0, numRange, numDistinct); + EXPECT_EQ(numDistinct, kSize * 2 + 1); + + hasher1.merge(hasher3, kSize * 2); + hasher1.cardinality(0, numRange, numDistinct); + EXPECT_EQ(numDistinct, VectorHasher::kRangeTooLarge); + + hasher1.merge(hasher3, kSize * 10); + hasher1.cardinality(0, numRange, numDistinct); + EXPECT_EQ(numDistinct, VectorHasher::kRangeTooLarge); +} + TEST_F(VectorHasherTest, computeValueIdsBigint) { testComputeValueIds(false); testComputeValueIds(true); @@ -1197,3 +1237,84 @@ TEST_F(VectorHasherTest, customComparisonRow) { {std::nullopt, 0, 1, 0, 1, 0, 1}, velox::test::BIGINT_TYPE_WITH_CUSTOM_COMPARISON())})); } + +TEST_F(VectorHasherTest, customComparisonValueIds) { + // Test that VectorHasher created with custom comparison type + // has value IDs disabled (distinctOverflow_ and rangeOverflow_ set). + auto vectorHasher = exec::VectorHasher::create( + velox::test::BIGINT_TYPE_WITH_CUSTOM_COMPARISON(), 1); + + // Test that types with custom comparison do not support value IDs. + EXPECT_FALSE(vectorHasher->typeSupportsValueIds()); + + // Verify that mayUseValueIds() returns false for custom comparison types. + EXPECT_FALSE(vectorHasher->mayUseValueIds()); +} + +DEBUG_ONLY_TEST_F(VectorHasherTest, customComparisonNoValueIds) { + // Test that custom comparison types cannot use value IDs for optimization. + auto data = makeRowVector({makeNullableFlatVector( + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, + velox::test::BIGINT_TYPE_WITH_CUSTOM_COMPARISON())}); + + auto hasher = exec::VectorHasher::create( + velox::test::BIGINT_TYPE_WITH_CUSTOM_COMPARISON(), 0); + + SelectivityVector allRows(data->size()); + raw_vector result(data->size()); + std::fill(result.begin(), result.end(), 0); + + hasher->decode(*data->childAt(0), allRows); + + VELOX_ASSERT_THROW( + hasher->computeValueIds(allRows, result), "Value IDs cannot be used"); + VectorHasher::ScratchMemory scratchMemory; + VELOX_ASSERT_THROW( + hasher->lookupValueIds(*data->childAt(0), allRows, scratchMemory, result), + "Value IDs cannot be used"); + EXPECT_EQ(nullptr, hasher->getFilter(true)); + VELOX_ASSERT_THROW( + hasher->enableValueRange(1, 50), "Value IDs cannot be used"); + VELOX_ASSERT_THROW( + hasher->enableValueRange(1, 50), "Value IDs cannot be used"); +} + +DEBUG_ONLY_TEST_F(VectorHasherTest, computeValueIdsForRowsCustomComparison) { + // Test that computeValueIdsForRows throws an exception for types with custom + // comparison. + auto hasher = exec::VectorHasher::create( + velox::test::BIGINT_TYPE_WITH_CUSTOM_COMPARISON(), 0); + + constexpr int32_t kNumGroups = 5; + constexpr int32_t kRowSize = 16; + constexpr int32_t kValueOffset = 0; + constexpr int32_t kNullByte = 8; + constexpr uint8_t kNullMask = 1; + + // Allocate memory for row-wise data. + std::vector> rowData(kNumGroups); + std::vector groups(kNumGroups); + + for (int i = 0; i < kNumGroups; ++i) { + rowData[i].resize(kRowSize, 0); + groups[i] = rowData[i].data(); + + // Set values for all rows (no nulls for simplicity). + *reinterpret_cast(groups[i] + kValueOffset) = i * 256; + } + + raw_vector result(kNumGroups); + std::fill(result.begin(), result.end(), 0); + + // computeValueIdsForRows should throw an exception for types with custom + // comparison. + VELOX_ASSERT_THROW( + hasher->computeValueIdsForRows( + groups.data(), + kNumGroups, + kValueOffset, + kNullByte, + kNullMask, + result), + "Value IDs cannot be used"); +} diff --git a/velox/exec/tests/VeloxIn10MinDemo.cpp b/velox/exec/tests/VeloxIn10MinDemo.cpp index 87d571a1788c..e1c6ec5d5504 100644 --- a/velox/exec/tests/VeloxIn10MinDemo.cpp +++ b/velox/exec/tests/VeloxIn10MinDemo.cpp @@ -14,10 +14,10 @@ * limitations under the License. */ #include +#include #include "velox/common/memory/Memory.h" #include "velox/connectors/tpch/TpchConnector.h" #include "velox/connectors/tpch/TpchConnectorSplit.h" -#include "velox/core/Expressions.h" #include "velox/exec/tests/utils/AssertQueryBuilder.h" #include "velox/exec/tests/utils/PlanBuilder.h" #include "velox/expression/Expr.h" @@ -47,25 +47,17 @@ class VeloxIn10MinDemo : public VectorTestBase { // Register type resolver with DuckDB SQL parser. parse::registerTypeResolver(); - // Register the TPC-H Connector Factory. - connector::registerConnectorFactory( - std::make_shared()); - // Create and register a TPC-H connector. - auto tpchConnector = - connector::getConnectorFactory( - connector::tpch::TpchConnectorFactory::kTpchConnectorName) - ->newConnector( - kTpchConnectorId, - std::make_shared( - std::unordered_map())); + connector::tpch::TpchConnectorFactory factory; + auto tpchConnector = factory.newConnector( + kTpchConnectorId, + std::make_shared( + std::unordered_map())); connector::registerConnector(tpchConnector); } ~VeloxIn10MinDemo() { connector::unregisterConnector(kTpchConnectorId); - connector::unregisterConnectorFactory( - connector::tpch::TpchConnectorFactory::kTpchConnectorName); } /// Parse SQL expression into a typed expression tree using DuckDB SQL parser. @@ -99,8 +91,9 @@ class VeloxIn10MinDemo : public VectorTestBase { /// Make TPC-H split to add to TableScan node. exec::Split makeTpchSplit() const { - return exec::Split(std::make_shared( - kTpchConnectorId, /*cacheable=*/true, 1, 0)); + return exec::Split( + std::make_shared( + kTpchConnectorId, /*cacheable=*/true, 1, 0)); } /// Run the demo. @@ -108,7 +101,7 @@ class VeloxIn10MinDemo : public VectorTestBase { std::shared_ptr executor_{ std::make_shared( - std::thread::hardware_concurrency())}; + folly::hardware_concurrency())}; std::shared_ptr queryCtx_{ core::QueryCtx::create(executor_.get())}; std::unique_ptr execCtx_{ diff --git a/velox/exec/tests/WindowTest.cpp b/velox/exec/tests/WindowTest.cpp index c2e84f9b0f32..0c1be2f231c7 100644 --- a/velox/exec/tests/WindowTest.cpp +++ b/velox/exec/tests/WindowTest.cpp @@ -13,9 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include "velox/exec/Window.h" +#include #include "velox/common/base/Exceptions.h" #include "velox/common/base/tests/GTestUtils.h" #include "velox/common/file/FileSystems.h" +#include "velox/exec/OrderBy.h" #include "velox/exec/PlanNodeStats.h" #include "velox/exec/RowsStreamingWindowBuild.h" #include "velox/exec/SortWindowBuild.h" @@ -61,12 +64,13 @@ class WindowTest : public OperatorTestBase { 0, 0, "none", + 0, prefixSortConfig); } const std::shared_ptr executor_{ std::make_shared( - std::thread::hardware_concurrency())}; + folly::hardware_concurrency())}; tsan_atomic nonReclaimableSection_{false}; }; @@ -113,6 +117,112 @@ TEST_F(WindowTest, spill) { ASSERT_GT(stats.spilledPartitions, 0); } +TEST_F(WindowTest, spillBatchReadTinyPartitions) { + const vector_size_t size = 1'000; + const uint32_t minReadBatchRows = 100; + // Each tiny partition has 1 row. + const uint32_t partitionRows = 1; + auto data = makeRowVector( + {"d", "p", "s"}, + { + // Payload. + makeFlatVector(size, [](auto row) { return row; }), + // Partition key. + makeFlatVector( + size, [](auto row) { return row / partitionRows; }), + // Sorting key. + makeFlatVector(size, [](auto row) { return row; }), + }); + + createDuckDbTable({data}); + + core::PlanNodeId windowId; + auto plan = PlanBuilder() + .values(split(data, 10)) + .window({"row_number() over (partition by p order by s)"}) + .capturePlanNodeId(windowId) + .planNode(); + + auto spillDirectory = TempDirectoryPath::create(); + TestScopedSpillInjection scopedSpillInjection(100); + auto task = + AssertQueryBuilder(plan, duckDbQueryRunner_) + .config(core::QueryConfig::kPreferredOutputBatchBytes, "1024") + .config(core::QueryConfig::kSpillEnabled, "true") + .config(core::QueryConfig::kWindowSpillEnabled, "true") + .config( + core::QueryConfig::kWindowSpillMinReadBatchRows, minReadBatchRows) + .spillDirectory(spillDirectory->getPath()) + .assertResults( + "SELECT *, row_number() over (partition by p order by s) FROM tmp"); + + auto taskStats = exec::toPlanStats(task->taskStats()); + const auto& stats = taskStats.at(windowId); + + ASSERT_GT(stats.spilledBytes, 0); + ASSERT_GT(stats.spilledRows, 0); + ASSERT_GT(stats.spilledFiles, 0); + ASSERT_GT(stats.spilledPartitions, 0); + ASSERT_EQ( + stats.operatorStats.at("Window") + ->customStats[Window::kWindowSpillReadNumBatches] + .sum, + size / minReadBatchRows); +} + +TEST_F(WindowTest, spillBatchReadHugePartitions) { + const vector_size_t size = 1'000; + const uint32_t minReadBatchRows = 100; + // Each huge partition has 200 rows, which is larger than minReadBatchRows. + const uint32_t partitionRows = 200; + auto data = makeRowVector( + {"d", "p", "s"}, + { + // Payload. + makeFlatVector(size, [](auto row) { return row; }), + // Partition key. + makeFlatVector( + size, [](auto row) { return row / partitionRows; }), + // Sorting key. + makeFlatVector(size, [](auto row) { return row; }), + }); + + createDuckDbTable({data}); + + core::PlanNodeId windowId; + auto plan = PlanBuilder() + .values(split(data, 10)) + .window({"row_number() over (partition by p order by s)"}) + .capturePlanNodeId(windowId) + .planNode(); + + auto spillDirectory = TempDirectoryPath::create(); + TestScopedSpillInjection scopedSpillInjection(100); + auto task = + AssertQueryBuilder(plan, duckDbQueryRunner_) + .config(core::QueryConfig::kPreferredOutputBatchBytes, "1024") + .config(core::QueryConfig::kSpillEnabled, "true") + .config(core::QueryConfig::kWindowSpillEnabled, "true") + .config( + core::QueryConfig::kWindowSpillMinReadBatchRows, minReadBatchRows) + .spillDirectory(spillDirectory->getPath()) + .assertResults( + "SELECT *, row_number() over (partition by p order by s) FROM tmp"); + + auto taskStats = exec::toPlanStats(task->taskStats()); + const auto& stats = taskStats.at(windowId); + + ASSERT_GT(stats.spilledBytes, 0); + ASSERT_GT(stats.spilledRows, 0); + ASSERT_GT(stats.spilledFiles, 0); + ASSERT_GT(stats.spilledPartitions, 0); + ASSERT_EQ( + stats.operatorStats.at("Window") + ->customStats[Window::kWindowSpillReadNumBatches] + .sum, + size / partitionRows); +} + TEST_F(WindowTest, spillUnsupported) { const vector_size_t size = 1'000; auto data = makeRowVector( @@ -178,10 +288,11 @@ TEST_F(WindowTest, rowBasedStreamingWindowOOM) { auto planNodeIdGenerator = std::make_shared(); CursorParameters params; auto queryCtx = core::QueryCtx::create(executor_.get()); - queryCtx->testingOverrideMemoryPool(memory::memoryManager()->addRootPool( - queryCtx->queryId(), - 8'388'608 /* 8MB */, - exec::MemoryReclaimer::create())); + queryCtx->testingOverrideMemoryPool( + memory::memoryManager()->addRootPool( + queryCtx->queryId(), + 8'388'608 /* 8MB */, + exec::MemoryReclaimer::create())); params.queryCtx = queryCtx; @@ -323,6 +434,122 @@ DEBUG_ONLY_TEST_F(WindowTest, valuesRowsStreamingWindowBuild) { ASSERT_TRUE(isStreamCreated.load()); } +TEST_F(WindowTest, prePartitionedSortBuild) { + const vector_size_t size = 1'000; + const int numPartitions = 37; + const int numSubPartitions = 4; + auto data = makeRowVector( + {"p", "s"}, + { + // Partition key. + makeFlatVector( + size, [](auto row) { return row % numPartitions; }), + // Sorting key. + makeFlatVector(size, [](auto row) { return row; }), + }); + + createDuckDbTable({data}); + + core::PlanNodeId windowId; + auto plan = + PlanBuilder() + .values(split(data, 10)) + .window({"row_number() over (partition by p order by s desc)"}) + .capturePlanNodeId(windowId) + .planNode(); + + AssertQueryBuilder(plan, duckDbQueryRunner_) + .config(core::QueryConfig::kPreferredOutputBatchBytes, "1024") + .config( + core::QueryConfig::kWindowNumSubPartitions, + std::to_string(numSubPartitions)) + .assertResults( + "SELECT *, row_number() over (partition by p order by s desc) FROM tmp ORDER BY s"); +} + +TEST_F(WindowTest, prePartitionedSortBuildSkewed) { + const vector_size_t size = 1'000; + const int numPartitions = 4; + const int numSubPartitions = 16; + auto data = makeRowVector( + {"p", "s"}, + { + // Partition key. + makeFlatVector( + size, [](auto row) { return row % numPartitions; }), + // Sorting key. + makeFlatVector(size, [](auto row) { return row; }), + }); + + createDuckDbTable({data}); + + core::PlanNodeId windowId; + auto plan = + PlanBuilder() + .values(split(data, 10)) + .window({"row_number() over (partition by p order by s desc)"}) + .capturePlanNodeId(windowId) + .planNode(); + + AssertQueryBuilder(plan, duckDbQueryRunner_) + .config(core::QueryConfig::kPreferredOutputBatchBytes, "1024") + .config( + core::QueryConfig::kWindowNumSubPartitions, + std::to_string(numSubPartitions)) + .assertResults( + "SELECT *, row_number() over (partition by p order by s desc) FROM tmp ORDER BY s"); +} + +TEST_F(WindowTest, prePartitionedBuildWithSpill) { + const vector_size_t size = 1'000; + const int numPartitions = 37; + const int numSubPartitions = 4; + auto data = makeRowVector( + {"d", "p", "s"}, + { + // Payload. + makeFlatVector(size, [](auto row) { return row; }), + // Partition key. + makeFlatVector( + size, [](auto row) { return row % numPartitions; }), + // Sorting key. + makeFlatVector(size, [](auto row) { return row; }), + }); + + createDuckDbTable({data}); + + core::PlanNodeId windowId; + auto plan = + PlanBuilder() + .values(split(data, 10)) + .window({"row_number() over (partition by p order by s desc)"}) + .capturePlanNodeId(windowId) + .planNode(); + + auto spillDirectory = TempDirectoryPath::create(); + TestScopedSpillInjection scopedSpillInjection(100); + auto task = + AssertQueryBuilder(plan, duckDbQueryRunner_) + .config(core::QueryConfig::kPreferredOutputBatchBytes, "1024") + .config( + core::QueryConfig::kWindowNumSubPartitions, + std::to_string(numSubPartitions)) + .config(core::QueryConfig::kSpillEnabled, "true") + .config(core::QueryConfig::kWindowSpillEnabled, "true") + .config(core::QueryConfig::kOrderBySpillEnabled, "false") + .spillDirectory(spillDirectory->getPath()) + .assertResults( + "SELECT *, row_number() over (partition by p order by s desc) FROM tmp ORDER BY s"); + + auto taskStats = exec::toPlanStats(task->taskStats()); + const auto& stats = taskStats.at(windowId); + + ASSERT_GT(stats.spilledBytes, 0); + ASSERT_GT(stats.spilledRows, 0); + ASSERT_GT(stats.spilledFiles, 0); + ASSERT_GT(stats.spilledPartitions, 0); +} + DEBUG_ONLY_TEST_F(WindowTest, aggregationWithNonDefaultFrame) { const vector_size_t size = 1'00; @@ -436,9 +663,8 @@ TEST_F(WindowTest, missingFunctionSignature) { auto callExpr = std::make_shared( BIGINT(), - std::vector{ - std::make_shared(VARCHAR(), "c1")}, - "sum"); + "sum", + std::make_shared(VARCHAR(), "c1")); VELOX_ASSERT_THROW( runWindow(callExpr), @@ -446,9 +672,8 @@ TEST_F(WindowTest, missingFunctionSignature) { callExpr = std::make_shared( VARCHAR(), - std::vector{ - std::make_shared(BIGINT(), "c2")}, - "sum"); + "sum", + std::make_shared(BIGINT(), "c2")); VELOX_ASSERT_THROW( runWindow(callExpr), @@ -552,10 +777,11 @@ TEST_F(WindowTest, nagativeFrameArg) { .planNode(); VELOX_ASSERT_USER_THROW( AssertQueryBuilder(plan, duckDbQueryRunner_) - .assertResults(fmt::format( - "SELECT *, regr_count(c0, c1) over (partition by p0, p1 order by row_number ROWS between {} PRECEDING and {} FOLLOWING) from tmp", - startOffset, - endOffset)), + .assertResults( + fmt::format( + "SELECT *, regr_count(c0, c1) over (partition by p0, p1 order by row_number ROWS between {} PRECEDING and {} FOLLOWING) from tmp", + startOffset, + endOffset)), testData.debugString()); } } @@ -645,11 +871,12 @@ DEBUG_ONLY_TEST_F(WindowTest, reserveMemorySort) { for (const auto [usePrefixSort, spillEnabled, enableSpillPrefixSort] : testSettings) { - SCOPED_TRACE(fmt::format( - "usePrefixSort: {}, spillEnabled: {}, enableSpillPrefixSort: {}", - usePrefixSort, - spillEnabled, - enableSpillPrefixSort)); + SCOPED_TRACE( + fmt::format( + "usePrefixSort: {}, spillEnabled: {}, enableSpillPrefixSort: {}", + usePrefixSort, + spillEnabled, + enableSpillPrefixSort)); auto spillDirectory = exec::test::TempDirectoryPath::create(); auto spillConfig = getSpillConfig(spillDirectory->getPath(), enableSpillPrefixSort); @@ -658,12 +885,14 @@ DEBUG_ONLY_TEST_F(WindowTest, reserveMemorySort) { velox::common::PrefixSortConfig prefixSortConfig = velox::common::PrefixSortConfig{ std::numeric_limits::max(), 130, 12}; + folly::Synchronized opStats; auto sortWindowBuild = std::make_unique( plan, pool_.get(), std::move(prefixSortConfig), spillEnabled ? &spillConfig : nullptr, &nonReclaimableSection_, + &opStats, &spillStats); TestScopedSpillInjection scopedSpillInjection(0); @@ -713,18 +942,20 @@ TEST_F(WindowTest, NaNFrameBound) { if (startBound == "following" && endBound == "preceding") { continue; } - frames.push_back(fmt::format( - "{} over (order by s0 {} range between off0 {} and off1 {})", - call, - order, - startBound, - endBound)); - frames.push_back(fmt::format( - "{} over (order by s0 {} range between off1 {} and off0 {})", - call, - order, - startBound, - endBound)); + frames.push_back( + fmt::format( + "{} over (order by s0 {} range between off0 {} and off1 {})", + call, + order, + startBound, + endBound)); + frames.push_back( + fmt::format( + "{} over (order by s0 {} range between off1 {} and off0 {})", + call, + order, + startBound, + endBound)); } } } @@ -749,5 +980,63 @@ TEST_F(WindowTest, NaNFrameBound) { } } +DEBUG_ONLY_TEST_F(WindowTest, releaseWindowBuildInTime) { + const vector_size_t size = 1'000; + auto data = makeRowVector( + {"d", "p0", "s"}, + { + // Payload Data. + makeFlatVector(size, [](auto row) { return row; }), + // Partition key. + makeFlatVector(size, [](auto row) { return row % 11; }), + // Sorting key. + makeFlatVector(size, [](auto row) { return row; }), + }); + + createDuckDbTable({data}); + + core::PlanNodeId windowId; + core::PlanNodeId orderById; + auto plan = PlanBuilder() + .values(split(data, 10)) + .window({"row_number() over (partition by p0 order by s)"}) + .capturePlanNodeId(windowId) + .orderBy({"d"}, false) + .capturePlanNodeId(orderById) + .planNode(); + + std::atomic windowPool{nullptr}; + + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::Driver::runInternal::getOutput", + std::function([&](exec::Operator* op) { + auto* windowOp = dynamic_cast(op); + if (windowOp == nullptr || windowPool != nullptr) { + return; + } + windowPool = windowOp->pool(); + })); + + std::atomic_bool checkOnce{false}; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::Driver::runInternal::noMoreInput", + std::function([&](exec::Operator* op) { + if (dynamic_cast(op) == nullptr || + checkOnce.exchange(true)) { + return; + } + ASSERT_LT( + windowPool.load()->usedBytes(), windowPool.load()->peakBytes() / 3); + })); + + auto task = + AssertQueryBuilder(plan, duckDbQueryRunner_) + .config(core::QueryConfig::kPreferredOutputBatchBytes, "1024") + .assertResults( + "SELECT *, row_number() over (partition by p0 order by s) " + "FROM tmp " + "ORDER BY d"); +} + } // namespace } // namespace facebook::velox::exec diff --git a/velox/exec/tests/data/decimal.orc b/velox/exec/tests/data/decimal.orc new file mode 100644 index 000000000000..b89662b65f68 Binary files /dev/null and b/velox/exec/tests/data/decimal.orc differ diff --git a/velox/exec/tests/utils/AggregationResolver.cpp b/velox/exec/tests/utils/AggregationResolver.cpp new file mode 100644 index 000000000000..44cc0058b143 --- /dev/null +++ b/velox/exec/tests/utils/AggregationResolver.cpp @@ -0,0 +1,136 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/tests/utils/AggregationResolver.h" + +#include "velox/exec/Aggregate.h" +#include "velox/expression/FunctionCallToSpecialForm.h" + +namespace facebook::velox::exec::test { +namespace { +std::string throwAggregateFunctionDoesntExist(const std::string& name) { + std::stringstream error; + error << "Aggregate function doesn't exist: " << name << "."; + exec::aggregateFunctions().withRLock([&](const auto& functionsMap) { + if (functionsMap.empty()) { + error << " Registry of aggregate functions is empty. " + "Make sure to register some aggregate functions."; + } + }); + VELOX_USER_FAIL(error.str()); +} + +std::string throwAggregateFunctionSignatureNotSupported( + const std::string& name, + const std::vector& types, + const std::vector>& + signatures) { + std::stringstream error; + error << "Aggregate function signature is not supported: " + << toString(name, types) + << ". Supported signatures: " << toString(signatures) << "."; + VELOX_USER_FAIL(error.str()); +} +} // namespace + +TypePtr resolveAggregateType( + const std::string& aggregateName, + core::AggregationNode::Step step, + const std::vector& rawInputTypes, + bool nullOnFailure) { + if (auto signatures = exec::getAggregateFunctionSignatures(aggregateName)) { + for (const auto& signature : signatures.value()) { + exec::SignatureBinder binder(*signature, rawInputTypes); + if (binder.tryBind()) { + return binder.tryResolveType( + exec::isPartialOutput(step) ? signature->intermediateType() + : signature->returnType()); + } + } + + if (nullOnFailure) { + return nullptr; + } + + throwAggregateFunctionSignatureNotSupported( + aggregateName, rawInputTypes, signatures.value()); + } + + // We may be parsing lambda expression used in a lambda aggregate function. In + // this case, 'aggregateName' would refer to a scalar function. + // + // TODO Enhance the parser to allow for specifying separate resolver for + // lambda expressions. + if (auto type = + exec::resolveTypeForSpecialForm(aggregateName, rawInputTypes)) { + return type; + } + + if (auto type = parse::resolveScalarFunctionType( + aggregateName, rawInputTypes, true)) { + return type; + } + + if (nullOnFailure) { + return nullptr; + } + + throwAggregateFunctionDoesntExist(aggregateName); + return nullptr; +} + +AggregateTypeResolver::AggregateTypeResolver(core::AggregationNode::Step step) + : step_(step), previousHook_(core::Expressions::getResolverHook()) { + core::Expressions::setTypeResolverHook( + [&](const auto& inputs, const auto& expr, bool nullOnFailure) { + return resolveType(inputs, expr, nullOnFailure); + }); +} + +AggregateTypeResolver::~AggregateTypeResolver() { + core::Expressions::setTypeResolverHook(previousHook_); +} + +TypePtr AggregateTypeResolver::resolveType( + const std::vector& inputs, + const std::shared_ptr& expr, + bool nullOnFailure) const { + auto functionName = expr->name(); + + // Use raw input types (if available) to resolve intermediate and final + // result types. + if (exec::isRawInput(step_)) { + std::vector types; + for (auto& input : inputs) { + types.push_back(input->type()); + } + + return resolveAggregateType(functionName, step_, types, nullOnFailure); + } + + if (!rawInputTypes_.empty()) { + return resolveAggregateType( + functionName, step_, rawInputTypes_, nullOnFailure); + } + + if (!nullOnFailure) { + VELOX_USER_FAIL( + "Cannot resolve aggregation function return type without raw input types: {}", + functionName); + } + return nullptr; +} +} // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/utils/AggregationResolver.h b/velox/exec/tests/utils/AggregationResolver.h new file mode 100644 index 000000000000..65cefb6620c6 --- /dev/null +++ b/velox/exec/tests/utils/AggregationResolver.h @@ -0,0 +1,74 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "velox/core/ITypedExpr.h" +#include "velox/core/PlanNode.h" +#include "velox/parse/Expressions.h" +#include "velox/parse/TypeResolver.h" + +namespace facebook::velox::exec::test { +class AggregateTypeResolver { + public: + explicit AggregateTypeResolver(core::AggregationNode::Step step); + + ~AggregateTypeResolver(); + + void setRawInputTypes(const std::vector& types) { + rawInputTypes_ = types; + } + + private: + TypePtr resolveType( + const std::vector& inputs, + const std::shared_ptr& expr, + bool nullOnFailure) const; + + const core::AggregationNode::Step step_; + const core::Expressions::TypeResolverHook previousHook_; + std::vector rawInputTypes_; +}; + +/// Resolves the output type for an aggregate function given its name, +/// aggregation step, and input types. +/// +/// The resolution process follows these steps: +/// 1. Looks up registered aggregate function signatures by name +/// 2. Attempts to bind the provided raw input types to available signatures +/// 3. Returns the appropriate output type based on the aggregation step: +/// - For partial steps (kPartial, kIntermediate): returns intermediate type +/// - For final steps (kSingle, kFinal): returns the final result type +/// 4. Falls back to scalar function or special form resolution if aggregate +/// function lookup fails (useful for lambda expressions in aggregates) +/// +/// @param aggregateName The name of the aggregate function (e.g., "sum", +/// "count", "avg") +/// @param step The aggregation step indicating the stage in multi-phase +/// aggregation +/// @param rawInputTypes The types of the raw input arguments to the aggregate +/// function +/// @param nullOnFailure If true, returns nullptr when type resolution fails; +/// if false, throws an exception on failure +/// @return The resolved output type for the aggregate function, or nullptr if +/// nullOnFailure is true and resolution fails +/// @throws VeloxUserError if type resolution fails and nullOnFailure is false +TypePtr resolveAggregateType( + const std::string& aggregateName, + core::AggregationNode::Step step, + const std::vector& rawInputTypes, + bool nullOnFailure); +} // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/utils/ArbitratorTestUtil.cpp b/velox/exec/tests/utils/ArbitratorTestUtil.cpp index 8c3bb9dc8ace..5499898d0666 100644 --- a/velox/exec/tests/utils/ArbitratorTestUtil.cpp +++ b/velox/exec/tests/utils/ArbitratorTestUtil.cpp @@ -19,9 +19,6 @@ #include "velox/dwio/dwrf/common/Config.h" #include "velox/exec/TableWriter.h" -using namespace facebook::velox; -using namespace facebook::velox::exec; -using namespace facebook::velox::exec::test; using namespace facebook::velox::memory; namespace facebook::velox::exec::test { @@ -31,18 +28,11 @@ std::shared_ptr newQueryCtx( folly::Executor* executor, int64_t memoryCapacity, const std::string& queryId) { - std::unordered_map> configs; - std::shared_ptr pool = - memoryManager->addRootPool("", memoryCapacity); - auto queryCtx = core::QueryCtx::create( - executor, - core::QueryConfig({}), - configs, - cache::AsyncDataCache::getInstance(), - std::move(pool), - nullptr, - queryId); - return queryCtx; + return core::QueryCtx::Builder() + .executor(executor) + .pool(memoryManager->addRootPool("", memoryCapacity)) + .queryId(queryId) + .build(); } std::unique_ptr createMemoryManager( diff --git a/velox/exec/tests/utils/ArbitratorTestUtil.h b/velox/exec/tests/utils/ArbitratorTestUtil.h index 8021db656081..6260c78883ed 100644 --- a/velox/exec/tests/utils/ArbitratorTestUtil.h +++ b/velox/exec/tests/utils/ArbitratorTestUtil.h @@ -112,7 +112,7 @@ std::shared_ptr newQueryCtx( std::unique_ptr createMemoryManager( int64_t arbitratorCapacity = kMemoryCapacity, uint64_t memoryPoolInitCapacity = kMemoryPoolInitCapacity, - uint64_t maxReclaimWaitMs = 5 * 60 * 1'000, + uint64_t maxReclaimWaitMs = 60 * 1'000, uint64_t fastExponentialGrowthCapacityLimit = 0, double slowCapacityGrowPct = 0); diff --git a/velox/exec/tests/utils/AssertQueryBuilder.cpp b/velox/exec/tests/utils/AssertQueryBuilder.cpp index 1b68b77cf717..26008449aedb 100644 --- a/velox/exec/tests/utils/AssertQueryBuilder.cpp +++ b/velox/exec/tests/utils/AssertQueryBuilder.cpp @@ -266,7 +266,26 @@ RowVectorPtr AssertQueryBuilder::copyResults( return copy; } -uint64_t AssertQueryBuilder::runWithoutResults(std::shared_ptr& task) { +std::vector AssertQueryBuilder::copyResultBatches( + memory::MemoryPool* pool) { + auto [cursor, results] = readCursor(); + + if (results.empty()) { + return results; + } + + std::vector copies; + copies.reserve(results.size()); + for (const auto& result : results) { + copies.push_back( + BaseVector::create(result->type(), result->size(), pool)); + copies.back()->copy(result.get(), 0, 0, result->size()); + } + + return copies; +} + +uint64_t AssertQueryBuilder::countResults(std::shared_ptr& task) { auto [cursor, results] = readCursor(); uint64_t count = 0; for (const auto& result : results) { @@ -276,6 +295,11 @@ uint64_t AssertQueryBuilder::runWithoutResults(std::shared_ptr& task) { return count; } +uint64_t AssertQueryBuilder::countResults() { + std::shared_ptr task; + return countResults(task); +} + std::pair, std::vector> AssertQueryBuilder::readCursor() { VELOX_CHECK_NOT_NULL(params_.planNode); @@ -287,17 +311,14 @@ AssertQueryBuilder::readCursor() { static std::atomic cursorQueryId{0}; const std::string queryId = fmt::format("TaskCursorQuery_{}", cursorQueryId++); - auto queryPool = memory::memoryManager()->addRootPool( - queryId, params_.maxQueryCapacity); - params_.queryCtx = core::QueryCtx::create( - executor_.get(), - core::QueryConfig({}), - std:: - unordered_map>{}, - cache::AsyncDataCache::getInstance(), - std::move(queryPool), - nullptr, - queryId); + + params_.queryCtx = core::QueryCtx::Builder() + .executor(executor_.get()) + .pool( + memory::memoryManager()->addRootPool( + queryId, params_.maxQueryCapacity)) + .queryId(queryId) + .build(); } } if (!configs_.empty()) { diff --git a/velox/exec/tests/utils/AssertQueryBuilder.h b/velox/exec/tests/utils/AssertQueryBuilder.h index 8e9f15d9831b..127d142f1168 100644 --- a/velox/exec/tests/utils/AssertQueryBuilder.h +++ b/velox/exec/tests/utils/AssertQueryBuilder.h @@ -15,6 +15,7 @@ */ #pragma once +#include #include "velox/exec/tests/utils/QueryAssertions.h" namespace facebook::velox::exec::test { @@ -182,8 +183,7 @@ class AssertQueryBuilder { const TypePtr& expectedType, vector_size_t expectedNumRows); - /// Run the query and collect all results into a single vector. Throws if - /// query returns empty result. + /// Run the query and collect all results into a single vector. RowVectorPtr copyResults(memory::MemoryPool* pool); /// Similar to above method and also returns the task. @@ -191,8 +191,15 @@ class AssertQueryBuilder { memory::MemoryPool* pool, std::shared_ptr& task); + /// Run the query and copy the result Vectors as their original batches. + std::vector copyResultBatches(memory::MemoryPool* pool); + /// Run the query and return the number of result rows. - uint64_t runWithoutResults(std::shared_ptr& task); + uint64_t countResults(std::shared_ptr& task); + + /// Run the query and return the number of result rows without requiring a + /// task parameter. + uint64_t countResults(); private: std::pair, std::vector> @@ -200,7 +207,7 @@ class AssertQueryBuilder { static std::unique_ptr newExecutor() { return std::make_unique( - std::thread::hardware_concurrency()); + folly::hardware_concurrency()); } // Used by the created task as the default driver executor. diff --git a/velox/exec/tests/utils/CMakeLists.txt b/velox/exec/tests/utils/CMakeLists.txt index acc1119c9942..cf3ef9dbb102 100644 --- a/velox/exec/tests/utils/CMakeLists.txt +++ b/velox/exec/tests/utils/CMakeLists.txt @@ -14,13 +14,14 @@ add_library(velox_temp_path TempFilePath.cpp TempDirectoryPath.cpp) -target_link_libraries( - velox_temp_path velox_exception) +target_link_libraries(velox_temp_path velox_exception) add_library( velox_exec_test_lib + AggregationResolver.cpp AssertQueryBuilder.cpp ArbitratorTestUtil.cpp + FilterToExpression.cpp HiveConnectorTestBase.cpp IndexLookupJoinTestBase.cpp LocalExchangeSource.cpp @@ -34,7 +35,8 @@ add_library( TpchQueryBuilder.cpp VectorTestUtil.cpp PortUtil.cpp - SerializedPageUtil.cpp) + SerializedPageUtil.cpp +) target_link_libraries( velox_exec_test_lib @@ -50,27 +52,14 @@ target_link_libraries( velox_dwio_common velox_dwio_dwrf_reader velox_dwio_dwrf_writer + velox_dwio_text_reader_register velox_dwio_common_test_utils velox_file_test_utils velox_type_fbhive velox_hive_connector velox_tpch_connector + velox_tpcds_connector velox_presto_serializer velox_functions_prestosql - velox_aggregates) - -if(${VELOX_BUILD_RUNNER}) - add_library(velox_exec_runner_test_util DistributedPlanBuilder.cpp - LocalRunnerTestBase.cpp) - - target_link_libraries( - velox_exec_runner_test_util - velox_temp_path - velox_exec_test_lib - velox_exec - velox_file_test_utils - velox_hive_connector - velox_tpch_connector - velox_local_runner) - -endif() + velox_aggregates +) diff --git a/velox/exec/tests/utils/DistributedPlanBuilder.cpp b/velox/exec/tests/utils/DistributedPlanBuilder.cpp deleted file mode 100644 index da09ee09af71..000000000000 --- a/velox/exec/tests/utils/DistributedPlanBuilder.cpp +++ /dev/null @@ -1,151 +0,0 @@ -/* - * Copyright (c) Facebook, Inc. and its affiliates. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "velox/exec/tests/utils/DistributedPlanBuilder.h" - -namespace facebook::velox::exec::test { - -DistributedPlanBuilder::DistributedPlanBuilder( - const runner::MultiFragmentPlan::Options& options, - std::shared_ptr planNodeIdGenerator, - memory::MemoryPool* pool) - : PlanBuilder(planNodeIdGenerator, pool), options_(options), root_(this) { - root_->stack_.push_back(this); - newFragment(); - current_->width = options_.numWorkers; -} - -DistributedPlanBuilder::DistributedPlanBuilder(DistributedPlanBuilder& root) - : PlanBuilder(root.planNodeIdGenerator(), root.pool()), - options_(root.options_), - root_(&root) { - root_->stack_.push_back(this); - newFragment(); - current_->width = options_.numWorkers; -} - -DistributedPlanBuilder::~DistributedPlanBuilder() { - VELOX_CHECK_EQ(root_->stack_.size(), 1); -} - -std::vector DistributedPlanBuilder::fragments() { - newFragment(); - return std::move(fragments_); -} - -void DistributedPlanBuilder::newFragment() { - if (current_) { - gatherScans(planNode_); - current_->fragment = core::PlanFragment(std::move(planNode_)); - fragments_.push_back(std::move(*current_)); - } - current_ = std::make_unique( - fmt::format("{}.{}", options_.queryId, root_->fragmentCounter_++)); - planNode_ = nullptr; -} - -PlanBuilder& DistributedPlanBuilder::shufflePartitioned( - const std::vector& partitionKeys, - int numPartitions, - bool replicateNullsAndAny, - const std::vector& outputLayout) { - partitionedOutput( - partitionKeys, numPartitions, replicateNullsAndAny, outputLayout); - auto* output = - dynamic_cast(planNode_.get()); - VELOX_CHECK_NOT_NULL(output); - auto producerPrefix = current_->taskPrefix; - newFragment(); - current_->width = numPartitions; - exchange(output->outputType(), VectorSerde::Kind::kPresto); - auto* exchange = dynamic_cast(planNode_.get()); - VELOX_CHECK_NOT_NULL(exchange); - current_->inputStages.push_back( - runner::InputStage{exchange->id(), producerPrefix}); - return *this; -} - -core::PlanNodePtr DistributedPlanBuilder::shufflePartitionedResult( - const std::vector& partitionKeys, - int numPartitions, - bool replicateNullsAndAny, - const std::vector& outputLayout) { - partitionedOutput( - partitionKeys, numPartitions, replicateNullsAndAny, outputLayout); - auto* output = - dynamic_cast(planNode_.get()); - VELOX_CHECK_NOT_NULL(output); - const auto producerPrefix = current_->taskPrefix; - auto result = planNode_; - newFragment(); - root_->stack_.pop_back(); - auto* consumer = root_->stack_.back(); - if (consumer->current_->width != 0) { - VELOX_CHECK_EQ( - numPartitions, - consumer->current_->width, - "The consumer width should match the producer fanout"); - } else { - consumer->current_->width = numPartitions; - } - - for (auto& fragment : fragments_) { - root_->fragments_.push_back(std::move(fragment)); - } - exchange(output->outputType(), VectorSerde::Kind::kPresto); - auto* exchange = dynamic_cast(planNode_.get()); - consumer->current_->inputStages.push_back( - runner::InputStage{exchange->id(), producerPrefix}); - return std::move(planNode_); -} - -core::PlanNodePtr DistributedPlanBuilder::shuffleBroadcastResult() { - partitionedOutputBroadcast(); - auto* output = - dynamic_cast(planNode_.get()); - VELOX_CHECK_NOT_NULL(output); - const auto producerPrefix = current_->taskPrefix; - auto result = planNode_; - newFragment(); - - VELOX_CHECK_GE(root_->stack_.size(), 2); - root_->stack_.pop_back(); - auto* consumer = root_->stack_.back(); - VELOX_CHECK_GE(consumer->current_->width, 1); - VELOX_CHECK_EQ(fragments_.back().numBroadcastDestinations, 0); - fragments_.back().numBroadcastDestinations = consumer->current_->width; - - for (auto& fragment : fragments_) { - root_->fragments_.push_back(std::move(fragment)); - } - exchange(output->outputType(), VectorSerde::Kind::kPresto); - auto* exchange = dynamic_cast(planNode_.get()); - VELOX_CHECK_NOT_NULL(exchange); - consumer->current_->inputStages.push_back( - runner::InputStage{exchange->id(), producerPrefix}); - return std::move(planNode_); -} - -void DistributedPlanBuilder::gatherScans(const core::PlanNodePtr& plan) { - if (auto scan = std::dynamic_pointer_cast(plan)) { - current_->scans.push_back(scan); - return; - } - for (auto& source : plan->sources()) { - gatherScans(source); - } -} -} // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/utils/DistributedPlanBuilder.h b/velox/exec/tests/utils/DistributedPlanBuilder.h deleted file mode 100644 index 3f54d7915339..000000000000 --- a/velox/exec/tests/utils/DistributedPlanBuilder.h +++ /dev/null @@ -1,86 +0,0 @@ -/* - * Copyright (c) Facebook, Inc. and its affiliates. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "velox/exec/tests/utils/PlanBuilder.h" -#include "velox/runner/MultiFragmentPlan.h" - -namespace facebook::velox::exec::test { - -/// Builder for distributed plans. Adds a shuffle() and related -/// methods for building PartitionedOutput-Exchange pairs between -/// fragments. Not thread safe. -class DistributedPlanBuilder : public PlanBuilder { - public: - /// Constructs a top level DistributedPlanBuilder. - DistributedPlanBuilder( - const runner::MultiFragmentPlan::Options& options, - std::shared_ptr planNodeIdGenerator, - memory::MemoryPool* pool = nullptr); - - /// Constructs a child builder. Used for branching plans, e.g. the subplan for - /// a join build side. - DistributedPlanBuilder(DistributedPlanBuilder& root); - - ~DistributedPlanBuilder() override; - - /// Returns the planned fragments. The builder will be empty after this. This - /// is only called on the root builder. - std::vector fragments(); - - PlanBuilder& shufflePartitioned( - const std::vector& keys, - int numPartitions, - bool replicateNullsAndAny, - const std::vector& outputLayout = {}) override; - - core::PlanNodePtr shufflePartitionedResult( - const std::vector& keys, - int numPartitions, - bool replicateNullsAndAny, - const std::vector& outputLayout = {}) override; - - core::PlanNodePtr shuffleBroadcastResult() override; - - private: - void newFragment(); - - void gatherScans(const core::PlanNodePtr& plan); - - const runner::MultiFragmentPlan::Options& options_; - DistributedPlanBuilder* const root_; - - // Stack of outstanding builders. The last element is the immediately - // enclosing one. When returning an ExchangeNode from returnShuffle, the stack - // is used to establish the linkage between the fragment of the returning - // builder and the fragment current in the calling builder. Only filled in the - // root builder. - std::vector stack_; - - // Fragment counter. Only used in root builder. - int32_t fragmentCounter_{0}; - - // The fragment being built. Will be moved to the root builder's 'fragments_' - // when complete. - std::unique_ptr current_; - - // The fragments gathered under this builder. Moved to the root builder when - // returning the subplan. - std::vector fragments_; -}; - -} // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/utils/ExpressionBuilder.h b/velox/exec/tests/utils/ExpressionBuilder.h new file mode 100644 index 000000000000..d2d208b29647 --- /dev/null +++ b/velox/exec/tests/utils/ExpressionBuilder.h @@ -0,0 +1,478 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "velox/parse/Expressions.h" + +namespace facebook::velox::expr_builder { + +/// Fluent Expression Builder. +/// +/// This file contains fluent methods that make it convenient to create +/// (untyped) expression trees. This provides similar functionality to a SQL +/// parser, without bringing dependency on external libraries or bringing leaked +/// semantics from other systems. +/// +/// The untyped expressions can then be turned into typed expressions ready for +/// execution using type binding from `core::Expressions::inferTypes()`. +/// +/// The API provided is as close to the actual expression trees as possible. +/// Comparisons, arithmetics, conjuncts, function calls, literals, aliases, and +/// more are supported with this API. +/// +/// For example, to create a column reference, you can: +/// +/// > using namespace expr_builder; +/// > core::ExprPtr e = col("c0"); +/// +/// You can also use the "_c" C++ literal provided: +/// +/// > core::ExprPtr e = "c0"_c; +/// +/// Expressions created using ExpressionBuilder functions can be used in any +/// places that accept a ExprPtr. In practice, they create a ExprWrapper object, +/// but ExprWrappers are implicitly convertible to ExprSet. +/// +/// Nested column references can be specified as either: +/// +/// > col("parent", "child"); +/// > col("parent").subfield("child"); +/// +/// To debug the expression generated, you can simply: +/// +/// > LOG(INFO) << *e; +/// +/// Comparisons and other expressions can be fluently created using C++ +/// overloaded operators: +/// +/// > col("c") > 10; // "c > 10" +/// > col("c") != "bar"; // "c != 'bar'" +/// > col("c") == nullptr; // "c = null" +/// +/// C++ literals are automatically converted into ConstantExpr (expression +/// literals) when part of an expression. To explicitly create a literal you can +/// use: +/// +/// > lit(10.3); +/// > lit("str"); +/// +/// Casts can be done using one of the two formats: +/// +/// > lit(3).cast(TINYINT()); +/// > cast("str", VARBINARY()); +/// +/// Null checking filters: +/// +/// > isNull(col("c")) // "c is null" +/// > !isNull(col("c")) // "c is not null" +/// +/// Conjuncts and "between": +/// +/// > (col("a") && col("b")) || col("c"); // "(a AND b) OR c" +/// > between(col("a"), 0, 10); // "a between 0 and 10" +/// +/// You can also use fluent version of these APIs: +/// +/// > col("a").between(0, 10); // "a between 0 and 10" +/// +/// Arithmetic operators are also overloaded: +/// +/// > col("c") * 100 + col("b"); // "c * 100 + b" +/// +/// In any expression, as long as one of the sides is an expression node, the +/// correct expression will be created. For example, both version work as +/// expected: +/// +/// > col("c") * 100; // "c * 100" +/// > 100 * col("c"); // "100 * c" +/// +/// When building long expressions, be careful about C++ constant folding and +/// operator precedence: +/// +/// > col("c") + 5 * 100; +/// +/// C++ will fold "5 * 100" and generate the expression "c + 500". To force the +/// expected behavior, you can explicitly spell out the literal: +/// +/// > col("c") + 5 * lit(100); +/// > col("c") + lit(5) * 100; +/// +/// Both will generate "col + 5 * 100", which is "plus(col, multiply(5, 100))". +/// +/// Generic function calls can be created using `call()`: +/// +/// > call("func", 10); // "func(10)" +/// +/// `call()` supports arbitrary parameters, which can be other expressions or +/// (C++) literals. +/// +/// Lambdas can be created using the following syntax: +/// +/// > lambda({"x", "y"}, col("x") * col("y") + 1) +/// +/// Where the first parameter is a vector of the lambda arguments, and the +/// second the lambda body. +/// +/// All functions above can be nested and combined in arbitrary ways. +/// +/// > 10L * col("c1") > call("func", 3.4, col("g") / col("h"), call("j")); +/// +/// is the same as "10 * c1 > func(3.4, g / h, j())". +/// +/// Comparisons, arithmetics, and other operators are mapped to function names +/// according to the table below. It is the user's responsibility to make sure +/// that there names map to their appropriate implementation: +/// +/// ------------------------------- +/// | C++ | Function Name | +/// ------------------------------- +/// | operator== | eq | +/// | operator!= | neq | +/// | operator< | lt | +/// | operator<= | lte | +/// | operator> | gt | +/// | operator>= | gte | +/// | operator! | not | +/// | operator&& | and | +/// | operator|| | or | +/// | operator+ | plus | +/// | operator- | minus | +/// | operator* | multiply | +/// | operator/ | divide | +/// | operator% | mode | +/// | operator== | eq | +/// ------------------------------- + +namespace detail { + +class ExprWrapper; + +/// Either builds a ConstantExpr (literal) based on a scalar value, or passes +/// through an ExprWrapper already constructed. +template +inline ExprWrapper toExprWrapper(T value); + +// Specialization for long to avoid ambiguity. +inline ExprWrapper toExprWrapper(long value); + +template <> +inline ExprWrapper toExprWrapper(ExprWrapper expr); + +/// Wrapper library used so we can safely overload operators. +class ExprWrapper { + public: + ExprWrapper(const core::ExprPtr& expr) : expr_(expr) {} + + std::string toString() const { + return expr_->toString(); + } + + core::ExprPtr expr() const { + return expr_; + } + + /// Add an alias to the current expression: + /// + /// > col("c0").alias("my_column"); + ExprWrapper& alias(const std::string& newAlias) { + expr_ = expr_->withAlias(newAlias); + return *this; + } + + /// Add a "subfield" expression to enable access of subfields in + /// rows/structs: + /// + /// > col("parent_col").subfield("child_name"); + ExprWrapper& subfield(std::string childName) { + expr_ = std::make_shared( + std::move(childName), std::nullopt, std::vector{expr_}); + return *this; + } + + /// Add a "cast" to the current expression: + /// + /// > col("c0").cast(VARBINARY()); + /// > lit(10).cast(TINYINT()); + ExprWrapper& cast(const TypePtr& castType) { + expr_ = + std::make_shared(castType, expr_, false, std::nullopt); + return *this; + } + + /// Add a "try_cast" to the current expression: + /// + /// > col("c0").tryCast(VARBINARY()); + /// > lit(10).tryCast(TINYINT()); + ExprWrapper& tryCast(const TypePtr& castType) { + expr_ = + std::make_shared(castType, expr_, true, std::nullopt); + return *this; + } + + /// Add a "is_null" to the current expression: + /// + /// > col("c0").isNull(); + ExprWrapper& isNull() { + expr_ = std::make_shared( + "is_null", std::vector{expr_}, std::nullopt); + return *this; + } + + /// Add a "between" clause to the current expression wrapper: + /// + /// > col("a").between(1, 10); + template + ExprWrapper& between(const T1& value1, const T2& value2) { + expr_ = std::make_shared( + "between", + std::vector{ + expr_, + detail::toExprWrapper(value1), + detail::toExprWrapper(value2)}, + std::nullopt); + return *this; + } + + /// If equality is used against an actual ExpPtr (not the wrapper), this will + /// compare the expressions themselves. + /// + /// It won't assume this is generating a eq() Velox expression. + bool operator==(const core::ExprPtr& other) const { + return *expr_ == *other; + } + + /// Provide better gtest failure messages. + friend std::ostream& operator<<(std::ostream& os, const ExprWrapper& expr) { + return os << expr.expr_->toString(); + } + + /// For convenience, enable implicit conversions to ExprPtr. + operator core::ExprPtr() const { + return expr_; + } + + private: + core::ExprPtr expr_; +}; + +/// Unpacks a list of variadic template parameters in a +/// std::vector. The elements could be ExprWrapper or C++ +/// literals, which will get converted to ConstantExpr. +/// +/// Base of recursion. +inline std::vector unpackList() { + return {}; +} + +template +inline std::vector unpackList(TFirst first, TArgs&&... args) { + std::vector head = {toExprWrapper(first)}; + auto tail = unpackList(std::forward(args)...); + head.insert(head.end(), tail.begin(), tail.end()); + return head; +} + +} // namespace detail + +/// Column references. +inline detail::ExprWrapper col(std::string name) { + return {std::make_shared( + std::move(name), std::nullopt)}; +} + +/// Enable users to use a custom C++ literal to add a column reference. +/// For example: "col"_c +inline detail::ExprWrapper operator"" _c(const char* str, size_t len) { + return col(std::string(str, len)); +} + +/// Nested column names. Ror rows/struct member references. +inline detail::ExprWrapper col(std::string parentName, std::string childName) { + return col(std::move(parentName)).subfield(std::move(childName)); +} + +/// Literals. +inline detail::ExprWrapper lit(int64_t value) { + return {std::make_shared(BIGINT(), value, std::nullopt)}; +} + +inline detail::ExprWrapper lit(int32_t value) { + return {std::make_shared(INTEGER(), value, std::nullopt)}; +} + +inline detail::ExprWrapper lit(int16_t value) { + return { + std::make_shared(SMALLINT(), value, std::nullopt)}; +} + +inline detail::ExprWrapper lit(int8_t value) { + return {std::make_shared(TINYINT(), value, std::nullopt)}; +} + +inline detail::ExprWrapper lit(bool value) { + return {std::make_shared(BOOLEAN(), value, std::nullopt)}; +} + +inline detail::ExprWrapper lit(double value) { + return {std::make_shared(DOUBLE(), value, std::nullopt)}; +} + +inline detail::ExprWrapper lit(float value) { + return {std::make_shared(REAL(), value, std::nullopt)}; +} + +/// Different string flavors. +inline detail::ExprWrapper lit(const char* value) { + return {std::make_shared(VARCHAR(), value, std::nullopt)}; +} + +inline detail::ExprWrapper lit(const std::string_view& value) { + return {std::make_shared( + VARCHAR(), std::string(value), std::nullopt)}; +} + +inline detail::ExprWrapper lit(const std::string& value) { + return {std::make_shared(VARCHAR(), value, std::nullopt)}; +} + +/// lit(nullptr). +inline detail::ExprWrapper lit(std::nullptr_t) { + return {std::make_shared( + UNKNOWN(), variant::null(TypeKind::UNKNOWN), std::nullopt)}; +} + +/// Macro to reduce builerplate when overloading C++ operators. The template +/// magic basically means that the overload is matched if either left or right +/// operands are an ExprWrapper. This is so that both "c"_f + 10 and 10 + "c"_f +/// are supported, for example. +/// +/// If either left or right side are ExprWrapper, we either convert the other +/// side as a constant/literal, or use it as-is if it is already an ExprWrapper. +#define VELOX_EXPR_BUILDER_OPERATOR(__op, __name) \ + template \ + inline std::enable_if_t< \ + std::is_same_v || \ + std::is_same_v, \ + detail::ExprWrapper> \ + __op(T1 lhs, T2 rhs) { \ + return {std::make_shared( \ + __name, \ + std::vector{ \ + detail::toExprWrapper(lhs), detail::toExprWrapper(rhs)}, \ + std::nullopt)}; \ + } + +/// Define C++ operator overload for comparisons. +VELOX_EXPR_BUILDER_OPERATOR(operator==, "eq"); +VELOX_EXPR_BUILDER_OPERATOR(operator!=, "neq"); +VELOX_EXPR_BUILDER_OPERATOR(operator<, "lt"); +VELOX_EXPR_BUILDER_OPERATOR(operator<=, "lte"); +VELOX_EXPR_BUILDER_OPERATOR(operator>, "gt"); +VELOX_EXPR_BUILDER_OPERATOR(operator>=, "gte"); + +/// Define C++ operator overload for arithmetics. +VELOX_EXPR_BUILDER_OPERATOR(operator+, "plus"); +VELOX_EXPR_BUILDER_OPERATOR(operator-, "minus"); +VELOX_EXPR_BUILDER_OPERATOR(operator*, "multiply"); +VELOX_EXPR_BUILDER_OPERATOR(operator/, "divide"); +VELOX_EXPR_BUILDER_OPERATOR(operator%, "mod"); + +VELOX_EXPR_BUILDER_OPERATOR(operator&&, "and"); +VELOX_EXPR_BUILDER_OPERATOR(operator||, "or"); + +/// "not" is an unary operator. +inline detail::ExprWrapper operator!(detail::ExprWrapper expr) { + return {std::make_shared( + "not", std::vector{expr.expr()}, std::nullopt)}; +} + +/// "is_null" is also unary. +template +inline detail::ExprWrapper isNull(const T& expr) { + return detail::toExprWrapper(expr).isNull(); +} + +/// "alias" as a free function. +template +inline detail::ExprWrapper alias(TInput lhs, const std::string& newAlias) { + return detail::toExprWrapper(lhs).alias(newAlias); +} + +/// "cast" as a free function. +template +inline detail::ExprWrapper cast(TInput lhs, const TypePtr& castType) { + return detail::toExprWrapper(lhs).cast(castType); +} + +/// "tryCast" as a free function. +template +inline detail::ExprWrapper tryCast(TInput lhs, const TypePtr& castType) { + return detail::toExprWrapper(lhs).tryCast(castType); +} + +/// "between" as a free function. +template +inline detail::ExprWrapper +between(detail::ExprWrapper lhs, const T1& value1, const T2& value2) { + return lhs.between(value1, value2); +} + +/// Creates a lambda expressions, given the function parameters and an +/// expression for the function body. +template +inline detail::ExprWrapper lambda( + std::initializer_list args, + const TInput& body) { + return {std::make_shared( + std::move(args), detail::toExprWrapper(body))}; +} + +/// Convenience lambda builder for single argument lambdas. +template +inline detail::ExprWrapper lambda(std::string arg, const TInput& body) { + return lambda({std::move(arg)}, body); +} + +/// Regular function calls. First parameter is the function name, followed by +/// their parameters. Parameters can be other expression nodes or literals. +template +inline detail::ExprWrapper call(std::string name, TArgs&&... args) { + return {std::make_shared( + std::move(name), + detail::unpackList(std::forward(args)...), + std::nullopt)}; +} + +namespace detail { + +template +inline ExprWrapper toExprWrapper(T value) { + return lit(value); +} + +inline ExprWrapper toExprWrapper(long value) { + return lit(static_cast(value)); +} + +template <> +inline ExprWrapper toExprWrapper(ExprWrapper expr) { + return expr; +} + +} // namespace detail + +} // namespace facebook::velox::expr_builder diff --git a/velox/exec/tests/utils/FilterToExpression.cpp b/velox/exec/tests/utils/FilterToExpression.cpp new file mode 100644 index 000000000000..7546694935d7 --- /dev/null +++ b/velox/exec/tests/utils/FilterToExpression.cpp @@ -0,0 +1,634 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/tests/utils/FilterToExpression.h" +#include "velox/core/Expressions.h" + +namespace facebook::velox::core::test { + +core::TypedExprPtr createBooleanExpr( + const std::vector& conditions) { + if (conditions.empty()) { + return std::make_shared(BOOLEAN(), variant(true)); + } else if (conditions.size() == 1) { + return conditions[0]; + } else { + return std::make_shared(BOOLEAN(), conditions, "and"); + } +} + +core::TypedExprPtr handleNullAllowed( + const core::TypedExprPtr& expression, + const core::TypedExprPtr& subfieldExpr, + bool nullAllowed) { + if (nullAllowed) { + auto isNullExpr = + std::make_shared(BOOLEAN(), "is_null", subfieldExpr); + + return std::make_shared( + BOOLEAN(), "or", expression, isNullExpr); + } else { + return expression; + } +} + +core::TypedExprPtr getTypedExprFromSubfield( + const common::Subfield& subfield, + const RowTypePtr& rowType) { + VELOX_CHECK(subfield.valid(), "Invalid subfield"); + + const auto& path = subfield.path(); + const auto& fieldName = subfield.baseName(); + + core::TypedExprPtr expr = std::make_shared( + rowType->findChild(fieldName), fieldName); + for (size_t i = 1; i < path.size(); ++i) { + const auto& element = path[i]; + + switch (element->kind()) { + case common::SubfieldKind::kNestedField: { + const auto* nestedField = element->as(); + const auto& nestedName = nestedField->name(); + + VELOX_CHECK( + expr->type()->isRow(), "Expected ROW type for nested field access"); + const auto& rowType = expr->type()->asRow(); + + expr = std::make_shared( + rowType.findChild(nestedName), expr, nestedName); + break; + } + + default: + VELOX_NYI( + "SubfieldKind other than kNestedField are not supported: {}", + common::SubfieldKindName::toName(element->kind())); + } + } + + return expr; +} + +core::TypedExprPtr filterToExpr( + const common::Subfield& subfield, + const common::Filter* filter, + const RowTypePtr& rowType, + memory::MemoryPool* pool) { + auto subfieldExpr = getTypedExprFromSubfield(subfield, rowType); + + auto subfieldType = subfieldExpr->type(); + auto isComplexType = + (subfieldType->kind() == TypeKind::ROW || + subfieldType->kind() == TypeKind::MAP || + subfieldType->kind() == TypeKind::ARRAY); + + // Complex types (ROW, MAP, ARRAY) only support IsNull and IsNotNull filters. + if (isComplexType && filter->kind() != common::FilterKind::kIsNull && + filter->kind() != common::FilterKind::kIsNotNull) { + VELOX_FAIL( + "Should not get any filter on complex type other than IsNull or IsNotNUll"); + } + + switch (filter->kind()) { + case common::FilterKind::kAlwaysFalse: + return std::make_shared(BOOLEAN(), variant(false)); + + case common::FilterKind::kAlwaysTrue: + return std::make_shared(BOOLEAN(), variant(true)); + + case common::FilterKind::kIsNull: + return std::make_shared( + BOOLEAN(), "is_null", subfieldExpr); + + case common::FilterKind::kIsNotNull: { + auto isNullExpr = + std::make_shared(BOOLEAN(), "is_null", subfieldExpr); + return std::make_shared(BOOLEAN(), "not", isNullExpr); + } + + case common::FilterKind::kBoolValue: { + auto boolFilter = static_cast(filter); + auto boolValue = std::make_shared( + BOOLEAN(), variant(boolFilter->testBool(true))); + + auto eqExpr = std::make_shared( + BOOLEAN(), "eq", subfieldExpr, boolValue); + + return handleNullAllowed(eqExpr, subfieldExpr, boolFilter->nullAllowed()); + } + + case common::FilterKind::kBigintRange: { + auto rangeFilter = static_cast(filter); + auto subfieldType = subfieldExpr->type(); + + // Special handling for TPCH DATE types. + if (subfieldType->isDate()) { + const int64_t kMaxInt64 = std::numeric_limits::max(); + const int64_t kMinInt64 = std::numeric_limits::min(); + + std::vector conditions; + + if (rangeFilter->lower() > kMinInt64) { + auto lower = std::make_shared( + subfieldType, + variant(static_cast(rangeFilter->lower()))); + + conditions.push_back( + std::make_shared( + BOOLEAN(), "gte", subfieldExpr, lower)); + } + + if (rangeFilter->upper() < kMaxInt64) { + auto upper = std::make_shared( + subfieldType, + variant(static_cast(rangeFilter->upper()))); + + conditions.push_back( + std::make_shared( + BOOLEAN(), "lte", subfieldExpr, upper)); + } + + auto rangeExpr = createBooleanExpr(conditions); + return handleNullAllowed( + rangeExpr, subfieldExpr, rangeFilter->nullAllowed()); + } else { + ConstantTypedExprPtr lower; + ConstantTypedExprPtr upper; + + switch (subfieldType->kind()) { + case TypeKind::TINYINT: + lower = std::make_shared( + subfieldType, + variant(static_cast(rangeFilter->lower()))); + upper = std::make_shared( + subfieldType, + variant(static_cast(rangeFilter->upper()))); + break; + case TypeKind::SMALLINT: + lower = std::make_shared( + subfieldType, + variant(static_cast(rangeFilter->lower()))); + upper = std::make_shared( + subfieldType, + variant(static_cast(rangeFilter->upper()))); + break; + case TypeKind::INTEGER: + lower = std::make_shared( + subfieldType, + variant(static_cast(rangeFilter->lower()))); + upper = std::make_shared( + subfieldType, + variant(static_cast(rangeFilter->upper()))); + break; + case TypeKind::BIGINT: + default: + lower = std::make_shared( + subfieldType, variant(rangeFilter->lower())); + upper = std::make_shared( + subfieldType, variant(rangeFilter->upper())); + break; + } + + CallTypedExprPtr rangeExpr; + if (rangeFilter->lower() == rangeFilter->upper()) { + rangeExpr = std::make_shared( + BOOLEAN(), "eq", subfieldExpr, lower); + } else { + auto greaterOrEqual = std::make_shared( + BOOLEAN(), "gte", subfieldExpr, lower); + + auto lessOrEqual = std::make_shared( + BOOLEAN(), "lte", subfieldExpr, upper); + + rangeExpr = std::make_shared( + BOOLEAN(), "and", greaterOrEqual, lessOrEqual); + } + + return handleNullAllowed( + rangeExpr, subfieldExpr, rangeFilter->nullAllowed()); + } + } + + case common::FilterKind::kNegatedBigintRange: { + auto negatedRangeFilter = + static_cast(filter); + bool nullAllowed = negatedRangeFilter->nullAllowed(); + const common::Filter* nonNegatedFilter = + negatedRangeFilter->getNonNegated(); + + auto rangeExpr = filterToExpr(subfield, nonNegatedFilter, rowType, pool); + + auto notRangeExpr = + std::make_shared(BOOLEAN(), "not", rangeExpr); + + return handleNullAllowed(notRangeExpr, subfieldExpr, nullAllowed); + } + + case common::FilterKind::kDoubleRange: { + auto doubleFilter = static_cast(filter); + auto subfieldType = subfieldExpr->type(); + + std::vector conditions; + + if (!doubleFilter->lowerUnbounded()) { + ConstantTypedExprPtr lower; + + switch (subfieldType->kind()) { + case TypeKind::REAL: + lower = std::make_shared( + subfieldType, + variant(static_cast(doubleFilter->lower()))); + break; + case TypeKind::DOUBLE: + default: + lower = std::make_shared( + subfieldType, variant(doubleFilter->lower())); + break; + } + + if (doubleFilter->lowerExclusive()) { + conditions.push_back( + std::make_shared( + BOOLEAN(), "gt", subfieldExpr, lower)); + } else { + conditions.push_back( + std::make_shared( + BOOLEAN(), "gte", subfieldExpr, lower)); + } + } + + if (!doubleFilter->upperUnbounded()) { + ConstantTypedExprPtr upper; + + switch (subfieldType->kind()) { + case TypeKind::REAL: + upper = std::make_shared( + subfieldType, + variant(static_cast(doubleFilter->upper()))); + break; + case TypeKind::DOUBLE: + default: + upper = std::make_shared( + subfieldType, variant(doubleFilter->upper())); + break; + } + + if (doubleFilter->upperExclusive()) { + conditions.push_back( + std::make_shared( + BOOLEAN(), "lt", subfieldExpr, upper)); + } else { + conditions.push_back( + std::make_shared( + BOOLEAN(), "lte", subfieldExpr, upper)); + } + } + + auto rangeExpr = createBooleanExpr(conditions); + return handleNullAllowed( + rangeExpr, subfieldExpr, doubleFilter->nullAllowed()); + } + + case common::FilterKind::kFloatRange: { + auto floatFilter = static_cast(filter); + auto subfieldType = subfieldExpr->type(); + + std::vector conditions; + + if (!floatFilter->lowerUnbounded()) { + auto lowerValue = floatFilter->lower(); + auto lower = std::make_shared( + subfieldType, variant(static_cast(lowerValue))); + + if (floatFilter->lowerExclusive()) { + conditions.push_back( + std::make_shared( + BOOLEAN(), "gt", subfieldExpr, lower)); + } else { + conditions.push_back( + std::make_shared( + BOOLEAN(), "gte", subfieldExpr, lower)); + } + } + + if (!floatFilter->upperUnbounded()) { + auto upperValue = floatFilter->upper(); + auto upper = std::make_shared( + subfieldType, variant(static_cast(upperValue))); + + if (floatFilter->upperExclusive()) { + conditions.push_back( + std::make_shared( + BOOLEAN(), "lt", subfieldExpr, upper)); + } else { + conditions.push_back( + std::make_shared( + BOOLEAN(), "lte", subfieldExpr, upper)); + } + } + + auto rangeExpr = createBooleanExpr(conditions); + return handleNullAllowed( + rangeExpr, subfieldExpr, floatFilter->nullAllowed()); + } + + case common::FilterKind::kBytesRange: { + auto bytesFilter = static_cast(filter); + auto subfieldType = subfieldExpr->type(); + + std::vector conditions; + + if (!bytesFilter->isLowerUnbounded()) { + auto lower = std::make_shared( + subfieldType, variant(bytesFilter->lower())); + + if (bytesFilter->isLowerExclusive()) { + conditions.push_back( + std::make_shared( + BOOLEAN(), "gt", subfieldExpr, lower)); + } else { + conditions.push_back( + std::make_shared( + BOOLEAN(), "gte", subfieldExpr, lower)); + } + } + + if (!bytesFilter->isUpperUnbounded()) { + auto upper = std::make_shared( + subfieldType, variant(bytesFilter->upper())); + + if (bytesFilter->isUpperExclusive()) { + conditions.push_back( + std::make_shared( + BOOLEAN(), "lt", subfieldExpr, upper)); + } else { + conditions.push_back( + std::make_shared( + BOOLEAN(), "lte", subfieldExpr, upper)); + } + } + + auto rangeExpr = createBooleanExpr(conditions); + return handleNullAllowed( + rangeExpr, subfieldExpr, bytesFilter->nullAllowed()); + } + + case common::FilterKind::kBigintValuesUsingHashTable: + case common::FilterKind::kBigintValuesUsingBitmask: { + std::vector values; + int64_t min, max; + bool nullAllowed; + + if (filter->kind() == common::FilterKind::kBigintValuesUsingHashTable) { + auto hashFilter = + static_cast(filter); + values = hashFilter->values(); + min = hashFilter->min(); + max = hashFilter->max(); + nullAllowed = hashFilter->nullAllowed(); + } else { + auto bitmaskFilter = + static_cast(filter); + values = bitmaskFilter->values(); + min = bitmaskFilter->min(); + max = bitmaskFilter->max(); + nullAllowed = bitmaskFilter->nullAllowed(); + } + + auto subfieldType = subfieldExpr->type(); + std::vector valueExprs; + valueExprs.reserve(values.size()); + + for (const auto& value : values) { + switch (subfieldType->kind()) { + case TypeKind::TINYINT: + valueExprs.push_back( + std::make_shared( + subfieldType, variant(static_cast(value)))); + break; + case TypeKind::SMALLINT: + valueExprs.push_back( + std::make_shared( + subfieldType, variant(static_cast(value)))); + break; + case TypeKind::INTEGER: + valueExprs.push_back( + std::make_shared( + subfieldType, variant(static_cast(value)))); + break; + case TypeKind::BIGINT: + default: + valueExprs.push_back( + std::make_shared( + subfieldType, variant(value))); + break; + } + } + + auto arrayExpr = std::make_shared( + ARRAY(subfieldType), valueExprs, "array_constructor"); + + auto inExpr = std::make_shared( + BOOLEAN(), "in", subfieldExpr, arrayExpr); + + // Optimization: Add range check (field >= min AND field <= max) before IN + // check. + ConstantTypedExprPtr minConstant, maxConstant; + switch (subfieldType->kind()) { + case TypeKind::TINYINT: + minConstant = std::make_shared( + subfieldType, variant(static_cast(min))); + maxConstant = std::make_shared( + subfieldType, variant(static_cast(max))); + break; + case TypeKind::SMALLINT: + minConstant = std::make_shared( + subfieldType, variant(static_cast(min))); + maxConstant = std::make_shared( + subfieldType, variant(static_cast(max))); + break; + case TypeKind::INTEGER: + minConstant = std::make_shared( + subfieldType, variant(static_cast(min))); + maxConstant = std::make_shared( + subfieldType, variant(static_cast(max))); + break; + case TypeKind::BIGINT: + default: + minConstant = + std::make_shared(subfieldType, variant(min)); + maxConstant = + std::make_shared(subfieldType, variant(max)); + break; + } + + // Create a range check expression to validate field bounds. + auto gteMinExpr = std::make_shared( + BOOLEAN(), "gte", subfieldExpr, minConstant); + auto lteMaxExpr = std::make_shared( + BOOLEAN(), "lte", subfieldExpr, maxConstant); + auto rangeCheckExpr = std::make_shared( + BOOLEAN(), "and", gteMinExpr, lteMaxExpr); + + // Combine the range check with the IN check for optimization. + auto optimizedInExpr = std::make_shared( + BOOLEAN(), "and", rangeCheckExpr, inExpr); + + return handleNullAllowed(optimizedInExpr, subfieldExpr, nullAllowed); + } + + case common::FilterKind::kNegatedBigintValuesUsingHashTable: { + auto hashFilter = + static_cast(filter); + bool nullAllowed = hashFilter->nullAllowed(); + const common::Filter* nonNegatedFilter = hashFilter->getNonNegated(); + + auto inExpr = filterToExpr(subfield, nonNegatedFilter, rowType, pool); + + auto notInExpr = + std::make_shared(BOOLEAN(), "not", inExpr); + + return handleNullAllowed(notInExpr, subfieldExpr, nullAllowed); + } + case common::FilterKind::kNegatedBigintValuesUsingBitmask: { + auto bitmaskFilter = + static_cast(filter); + bool nullAllowed = bitmaskFilter->nullAllowed(); + const common::Filter* nonNegatedFilter = bitmaskFilter->getNonNegated(); + + auto inExpr = filterToExpr(subfield, nonNegatedFilter, rowType, pool); + + auto notInExpr = + std::make_shared(BOOLEAN(), "not", inExpr); + + return handleNullAllowed(notInExpr, subfieldExpr, nullAllowed); + } + + case common::FilterKind::kBytesValues: { + auto bytesFilter = static_cast(filter); + const auto& values = bytesFilter->values(); + auto subfieldType = subfieldExpr->type(); + + std::vector valueExprs; + for (const auto& value : values) { + valueExprs.push_back( + std::make_shared(subfieldType, variant(value))); + } + + auto arrayExpr = std::make_shared( + ARRAY(subfieldType), valueExprs, "array_constructor"); + + auto inExpr = std::make_shared( + BOOLEAN(), "in", subfieldExpr, arrayExpr); + + return handleNullAllowed( + inExpr, subfieldExpr, bytesFilter->nullAllowed()); + } + + case common::FilterKind::kNegatedBytesValues: { + auto negatedBytesFilter = + static_cast(filter); + bool nullAllowed = negatedBytesFilter->nullAllowed(); + const common::Filter* nonNegatedFilter = + negatedBytesFilter->getNonNegated(); + + auto inExpr = filterToExpr(subfield, nonNegatedFilter, rowType, pool); + + auto notInExpr = + std::make_shared(BOOLEAN(), "not", inExpr); + + return handleNullAllowed(notInExpr, subfieldExpr, nullAllowed); + } + + case common::FilterKind::kTimestampRange: { + auto timestampFilter = static_cast(filter); + auto lower = std::make_shared( + TIMESTAMP(), variant(timestampFilter->lower())); + auto upper = std::make_shared( + TIMESTAMP(), variant(timestampFilter->upper())); + + CallTypedExprPtr rangeExpr; + if (timestampFilter->isSingleValue()) { + rangeExpr = std::make_shared( + BOOLEAN(), "eq", subfieldExpr, lower); + } else { + auto greaterOrEqual = std::make_shared( + BOOLEAN(), "gte", subfieldExpr, lower); + + auto lessOrEqual = std::make_shared( + BOOLEAN(), "lte", subfieldExpr, upper); + + rangeExpr = std::make_shared( + BOOLEAN(), "and", greaterOrEqual, lessOrEqual); + } + + return handleNullAllowed( + rangeExpr, subfieldExpr, timestampFilter->nullAllowed()); + } + + case common::FilterKind::kBigintMultiRange: { + auto multiRangeFilter = + static_cast(filter); + const auto& ranges = multiRangeFilter->ranges(); + + std::vector rangeExprs; + for (const auto& range : ranges) { + auto rangeExpr = filterToExpr(subfield, range.get(), rowType, pool); + rangeExprs.push_back(rangeExpr); + } + + return std::make_shared(BOOLEAN(), rangeExprs, "or"); + } + + case common::FilterKind::kMultiRange: { + auto multiRangeFilter = static_cast(filter); + const auto& filters = multiRangeFilter->filters(); + + std::vector filterExprs; + for (const auto& subFilter : filters) { + auto filterExpr = + filterToExpr(subfield, subFilter.get(), rowType, pool); + filterExprs.push_back(filterExpr); + } + + return std::make_shared(BOOLEAN(), filterExprs, "or"); + } + + case common::FilterKind::kNegatedBytesRange: { + auto negatedBytesFilter = + static_cast(filter); + bool nullAllowed = negatedBytesFilter->nullAllowed(); + const common::Filter* nonNegatedFilter = + negatedBytesFilter->getNonNegated(); + + auto rangeExpr = filterToExpr(subfield, nonNegatedFilter, rowType, pool); + + auto notRangeExpr = + std::make_shared(BOOLEAN(), "not", rangeExpr); + + return handleNullAllowed(notRangeExpr, subfieldExpr, nullAllowed); + } + + case common::FilterKind::kHugeintRange: + case common::FilterKind::kHugeintValuesUsingHashTable: + default: + VELOX_NYI( + "Filter type not yet implemented in filterToExpr: {}", + filter->toString()); + return subfieldExpr; + } +} +} // namespace facebook::velox::core::test diff --git a/velox/exec/tests/utils/FilterToExpression.h b/velox/exec/tests/utils/FilterToExpression.h new file mode 100644 index 000000000000..e0664cb03a31 --- /dev/null +++ b/velox/exec/tests/utils/FilterToExpression.h @@ -0,0 +1,45 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "velox/common/memory/MemoryPool.h" +#include "velox/core/ITypedExpr.h" +#include "velox/type/Filter.h" +#include "velox/type/Subfield.h" + +namespace facebook::velox::core::test { +/// Converts a Filter object to a TypedExpr object that can be used in Velox's +/// expression evaluation system. +/// +/// This function takes a filter that applies to a specific subfield and +/// converts it to an equivalent expression tree. +/// +/// @param subfield The subfield to which the filter applies. This is used to +/// create the base expression that the filter condition will +/// be applied to. +/// @param filter The filter to convert. This can be any subclass of Filter, +/// such as AlwaysTrue, IsNull, BigintRange, BytesValues, etc. +/// @param rowType The row type that contains the subfield. This is used to +/// resolve the subfield path and determine its type. +/// @param pool Memory pool to use for allocations. +/// @return A TypedExpr object representing the filter condition applied to +/// the subfield. +core::TypedExprPtr filterToExpr( + const common::Subfield& subfield, + const common::Filter* filter, + const RowTypePtr& rowType, + memory::MemoryPool* pool); +} // namespace facebook::velox::core::test diff --git a/velox/exec/tests/utils/HashJoinTestBase.h b/velox/exec/tests/utils/HashJoinTestBase.h index b063389cc607..3f1bc5509f92 100644 --- a/velox/exec/tests/utils/HashJoinTestBase.h +++ b/velox/exec/tests/utils/HashJoinTestBase.h @@ -42,10 +42,24 @@ using facebook::velox::test::BatchMaker; struct TestParam { int64_t numDrivers{1}; + bool parallelBuildSideRowsEnabled; - explicit TestParam(int _numDrivers) : numDrivers(_numDrivers) {} + explicit TestParam(int _numDrivers) + : numDrivers(_numDrivers), parallelBuildSideRowsEnabled(false) {} + + TestParam(int _numDrivers, bool _parallelBuildSideRowsEnabled) + : numDrivers(_numDrivers), + parallelBuildSideRowsEnabled(_parallelBuildSideRowsEnabled) {} }; +// Required for GTest to generate unique parameterized test names. +inline std::string TestParamToName(const TestParam& param) { + return fmt::format( + "{}_drivers_{}_parallelBuildSideRowsEnabled", + param.numDrivers, + param.parallelBuildSideRowsEnabled ? "with" : "without"); +} + using SplitInput = std::unordered_map>; @@ -407,6 +421,11 @@ class HashJoinBuilder { return *this; } + HashJoinBuilder& parallelizeJoinBuildRows(bool value) { + parallelJoinBuildRowsEnabled_ = value; + return *this; + } + HashJoinBuilder& spillDirectory(const std::string& spillDirectory) { spillDirectory_ = spillDirectory; return *this; @@ -467,8 +486,9 @@ class HashJoinBuilder { } for (const auto& testData : testSettings) { - SCOPED_TRACE(fmt::format( - "{} numDrivers: {}", testData.debugString(), numDrivers_)); + SCOPED_TRACE( + fmt::format( + "{} numDrivers: {}", testData.debugString(), numDrivers_)); auto planNodeIdGenerator = std::make_shared(); std::shared_ptr joinNode; auto planNode = @@ -596,15 +616,14 @@ class HashJoinBuilder { builder.splits(splitEntry.first, splitEntry.second); } } - auto queryCtx = core::QueryCtx::create( - executor_, - core::QueryConfig{{}}, - std::unordered_map>{}, - cache::AsyncDataCache::getInstance(), - memory::MemoryManager::getInstance()->addRootPool( - "query_pool", - memory::kMaxMemory, - memory::MemoryReclaimer::create())); + auto queryCtx = core::QueryCtx::Builder() + .executor(executor_) + .pool( + memory::MemoryManager::getInstance()->addRootPool( + "query_pool", + memory::kMaxMemory, + memory::MemoryReclaimer::create())) + .build(); std::shared_ptr spillDirectory; int32_t spillPct{0}; if (injectSpill) { @@ -628,6 +647,9 @@ class HashJoinBuilder { config( core::QueryConfig::kHashProbeFinishEarlyOnEmptyBuild, hashProbeFinishEarlyOnEmptyBuild_ ? "true" : "false"); + config( + core::QueryConfig::kParallelOutputJoinBuildRowsEnabled, + parallelJoinBuildRowsEnabled_ ? "true" : "false"); if (maxDriverYieldTimeMs != 0) { config( core::QueryConfig::kDriverCpuTimeSliceLimitMs, @@ -759,6 +781,7 @@ class HashJoinBuilder { std::shared_ptr queryPool_; std::string spillDirectory_; bool hashProbeFinishEarlyOnEmptyBuild_{true}; + bool parallelJoinBuildRowsEnabled_{false}; SplitInput inputSplits_; std::function makeInputSplits_; @@ -773,7 +796,8 @@ class HashJoinTestBase : public HiveConnectorTestBase { HashJoinTestBase() : HashJoinTestBase(TestParam(1)) {} explicit HashJoinTestBase(const TestParam& param) - : numDrivers_(param.numDrivers) {} + : numDrivers_(param.numDrivers), + parallelBuildSideRowsEnabled_(param.parallelBuildSideRowsEnabled) {} void SetUp() override { HiveConnectorTestBase::SetUp(); @@ -791,7 +815,7 @@ class HashJoinTestBase : public HiveConnectorTestBase { } // Make splits with each plan node having a number of source files. - SplitInput makeSpiltInput( + SplitInput makeSplitInput( const std::vector& nodeIds, const std::vector>>& files) { VELOX_CHECK_EQ(nodeIds.size(), files.size()); @@ -945,7 +969,8 @@ class HashJoinTestBase : public HiveConnectorTestBase { case core::JoinType::kRightSemiProject: return core::JoinType::kLeftSemiProject; default: - VELOX_FAIL("Cannot flip join type: {}", core::joinTypeName(joinType)); + VELOX_FAIL( + "Cannot flip join type: {}", core::JoinTypeName::toName(joinType)); } } @@ -965,6 +990,7 @@ class HashJoinTestBase : public HiveConnectorTestBase { } const int32_t numDrivers_; + const bool parallelBuildSideRowsEnabled_; // The default left and right table types used for test. RowTypePtr probeType_; diff --git a/velox/exec/tests/utils/HiveConnectorTestBase.cpp b/velox/exec/tests/utils/HiveConnectorTestBase.cpp index 94aea88367c2..6e12c0d98aa1 100644 --- a/velox/exec/tests/utils/HiveConnectorTestBase.cpp +++ b/velox/exec/tests/utils/HiveConnectorTestBase.cpp @@ -18,12 +18,13 @@ #include "velox/common/file/FileSystems.h" #include "velox/common/file/tests/FaultyFileSystem.h" +#include "velox/connectors/hive/HiveConnector.h" #include "velox/dwio/common/tests/utils/BatchMaker.h" #include "velox/dwio/dwrf/RegisterDwrfReader.h" #include "velox/dwio/dwrf/RegisterDwrfWriter.h" -#include "velox/dwio/dwrf/reader/DwrfReader.h" #include "velox/dwio/dwrf/writer/FlushPolicy.h" #include "velox/dwio/dwrf/writer/Writer.h" +#include "velox/dwio/text/RegisterTextReader.h" #include "velox/exec/tests/utils/AssertQueryBuilder.h" namespace facebook::velox::exec::test { @@ -35,20 +36,17 @@ HiveConnectorTestBase::HiveConnectorTestBase() { void HiveConnectorTestBase::SetUp() { OperatorTestBase::SetUp(); - connector::registerConnectorFactory( - std::make_shared()); - auto hiveConnector = - connector::getConnectorFactory( - connector::hive::HiveConnectorFactory::kHiveConnectorName) - ->newConnector( - kHiveConnectorId, - std::make_shared( - std::unordered_map()), - ioExecutor_.get()); + connector::hive::HiveConnectorFactory factory; + auto hiveConnector = factory.newConnector( + kHiveConnectorId, + std::make_shared( + std::unordered_map()), + ioExecutor_.get()); connector::registerConnector(hiveConnector); dwio::common::registerFileSinks(); dwrf::registerDwrfReaderFactory(); dwrf::registerDwrfWriterFactory(); + text::registerTextReaderFactory(); } void HiveConnectorTestBase::TearDown() { @@ -58,18 +56,17 @@ void HiveConnectorTestBase::TearDown() { dwrf::unregisterDwrfReaderFactory(); dwrf::unregisterDwrfWriterFactory(); connector::unregisterConnector(kHiveConnectorId); - connector::unregisterConnectorFactory( - connector::hive::HiveConnectorFactory::kHiveConnectorName); + text::unregisterTextReaderFactory(); OperatorTestBase::TearDown(); } void HiveConnectorTestBase::resetHiveConnector( const std::shared_ptr& config) { connector::unregisterConnector(kHiveConnectorId); + + connector::hive::HiveConnectorFactory factory; auto hiveConnector = - connector::getConnectorFactory( - connector::hive::HiveConnectorFactory::kHiveConnectorName) - ->newConnector(kHiveConnectorId, config, ioExecutor_.get()); + factory.newConnector(kHiveConnectorId, config, ioExecutor_.get()); connector::registerConnector(hiveConnector); } @@ -206,23 +203,23 @@ HiveConnectorTestBase::makeHiveConnectorSplits( infoColumns) { auto file = filesystems::getFileSystem(filePath, nullptr)->openFileForRead(filePath); - const int64_t fileSize = file->size(); + const uint64_t fileSize = file->size(); // Take the upper bound. - const int64_t splitSize = std::ceil((fileSize) / splitCount); + const uint64_t splitSize = std::ceil((fileSize) / splitCount); std::vector> splits; // Add all the splits. - for (int i = 0; i < splitCount; i++) { + for (uint32_t i = 0; i < splitCount; i++) { auto splitBuilder = HiveConnectorSplitBuilder(filePath) .fileFormat(format) .start(i * splitSize) .length(splitSize); if (infoColumns.has_value()) { - for (auto infoColumn : infoColumns.value()) { + for (const auto& infoColumn : infoColumns.value()) { splitBuilder.infoColumn(infoColumn.first, infoColumn.second); } } if (partitionKeys.has_value()) { - for (auto partitionKey : partitionKeys.value()) { + for (const auto& partitionKey : partitionKeys.value()) { splitBuilder.partitionKey(partitionKey.first, partitionKey.second); } } @@ -262,7 +259,8 @@ std::vector> HiveConnectorTestBase::makeHiveConnectorSplits( const std::vector>& filePaths) { std::vector> splits; - for (auto filePath : filePaths) { + splits.reserve(filePaths.size()); + for (const auto& filePath : filePaths) { splits.push_back(makeHiveConnectorSplit( filePath->getPath(), filePath->fileSize(), diff --git a/velox/exec/tests/utils/HiveConnectorTestBase.h b/velox/exec/tests/utils/HiveConnectorTestBase.h index 2173acf133e7..b3660714f509 100644 --- a/velox/exec/tests/utils/HiveConnectorTestBase.h +++ b/velox/exec/tests/utils/HiveConnectorTestBase.h @@ -15,24 +15,18 @@ */ #pragma once -#include "velox/connectors/hive/HiveConnector.h" #include "velox/connectors/hive/HiveConnectorSplit.h" #include "velox/connectors/hive/HiveDataSink.h" #include "velox/connectors/hive/TableHandle.h" #include "velox/dwio/dwrf/common/Config.h" #include "velox/dwio/dwrf/writer/FlushPolicy.h" -#include "velox/exec/Operator.h" #include "velox/exec/tests/utils/OperatorTestBase.h" #include "velox/exec/tests/utils/TempFilePath.h" -#include "velox/type/tests/SubfieldFiltersBuilder.h" namespace facebook::velox::exec::test { static const std::string kHiveConnectorId = "test-hive"; -using ColumnHandleMap = - std::unordered_map>; - class HiveConnectorTestBase : public OperatorTestBase { public: HiveConnectorTestBase(); @@ -236,8 +230,9 @@ class HiveConnectorTestBase : public OperatorTestBase { const std::string& name, const TypePtr& type); - static ColumnHandleMap allRegularColumns(const RowTypePtr& rowType) { - ColumnHandleMap assignments; + static connector::ColumnHandleMap allRegularColumns( + const RowTypePtr& rowType) { + connector::ColumnHandleMap assignments; assignments.reserve(rowType->size()); for (uint32_t i = 0; i < rowType->size(); ++i) { const auto& name = rowType->nameOf(i); @@ -253,7 +248,7 @@ class HiveConnectorSplitBuilder : public connector::hive::HiveConnectorSplitBuilder { public: explicit HiveConnectorSplitBuilder(std::string filePath) - : connector::hive::HiveConnectorSplitBuilder(filePath) { + : connector::hive::HiveConnectorSplitBuilder(std::move(filePath)) { connectorId(kHiveConnectorId); } }; diff --git a/velox/exec/tests/utils/IndexLookupJoinTestBase.cpp b/velox/exec/tests/utils/IndexLookupJoinTestBase.cpp index 8ec2d9710745..8bb2bb2a0072 100644 --- a/velox/exec/tests/utils/IndexLookupJoinTestBase.cpp +++ b/velox/exec/tests/utils/IndexLookupJoinTestBase.cpp @@ -15,15 +15,29 @@ */ #include "velox/exec/tests/utils/IndexLookupJoinTestBase.h" +#include "velox/exec/tests/utils/AssertQueryBuilder.h" #include "velox/exec/tests/utils/PlanBuilder.h" -namespace fecebook::velox::exec::test { +namespace facebook::velox::exec::test { +using namespace facebook::velox::test; -facebook::velox::RowTypePtr IndexLookupJoinTestBase::concat( - const facebook::velox::RowTypePtr& a, - const facebook::velox::RowTypePtr& b) { +namespace { +std::vector appendMarker(const std::vector columns) { + std::vector resultColumns; + resultColumns.reserve(columns.size() + 1); + for (const auto& column : columns) { + resultColumns.push_back(column); + } + resultColumns.push_back("__match__"); + return resultColumns; +} +} // namespace + +RowTypePtr IndexLookupJoinTestBase::concat( + const RowTypePtr& a, + const RowTypePtr& b) { std::vector names = a->names(); - std::vector types = a->children(); + std::vector types = a->children(); names.insert(names.end(), b->names().begin(), b->names().end()); types.insert(types.end(), b->children().begin(), b->children().end()); return ROW(std::move(names), std::move(types)); @@ -37,15 +51,15 @@ int IndexLookupJoinTestBase::getNumRows(const std::vector& cardinalities) { return numRows; } -std::vector -IndexLookupJoinTestBase::generateProbeInput( +std::vector IndexLookupJoinTestBase::generateProbeInput( size_t numBatches, size_t batchSize, size_t numDuplicateProbeRows, SequenceTableData& tableData, - std::shared_ptr& pool, + std::shared_ptr& pool, const std::vector& probeJoinKeys, - const std::vector inColumns, + bool hasNullKeys, + const std::vector& inColumns, const std::vector>& betweenColumns, std::optional equalMatchPct, std::optional inMatchPct, @@ -53,16 +67,33 @@ IndexLookupJoinTestBase::generateProbeInput( VELOX_CHECK_LE( probeJoinKeys.size() + betweenColumns.size() + inColumns.size(), keyType_->size()); - std::vector probeInputs; + std::vector probeInputs; probeInputs.reserve(numBatches); - facebook::velox::VectorFuzzer::Options opts; + VectorFuzzer::Options opts; opts.vectorSize = batchSize * numDuplicateProbeRows; opts.allowSlice = false; // TODO: add nullable handling later. - opts.nullRatio = 0.0; - facebook::velox::VectorFuzzer fuzzer(opts, pool.get()); + opts.nullRatio = hasNullKeys ? 0.1 : 0.0; + VectorFuzzer fuzzer(opts, pool.get()); for (int i = 0; i < numBatches; ++i) { probeInputs.push_back(fuzzer.fuzzInputRow(probeType_)); + // NOTE: index connector doesn't expect in condition column rray elements to + // be null. + if ((!inMatchPct.has_value() || tableData.keyData->size() == 0) && + hasNullKeys) { + for (int i = 0; i < probeType_->size(); ++i) { + const auto columnType = probeType_->childAt(i); + if (columnType->isArray()) { + opts.nullRatio = 0.0; + VectorFuzzer vectorFuzzer(opts, pool.get()); + probeInputs.back()->childAt(i) = vectorFuzzer.fuzz(columnType); + VELOX_CHECK(!probeInputs.back()->childAt(i)->mayHaveNulls()); + VELOX_CHECK_EQ( + probeInputs.back()->childAt(i)->size(), + probeInputs.back()->size()); + } + } + } } if (tableData.keyData->size() == 0) { @@ -70,42 +101,49 @@ IndexLookupJoinTestBase::generateProbeInput( } const auto numTableRows = tableData.keyData->size(); - std::vector> tableKeyVectors; + std::vector> tableKeyVectors; for (int i = 0; i < probeJoinKeys.size(); ++i) { auto keyVector = tableData.keyData->childAt(i); keyVector->loadedVector(); - facebook::velox::BaseVector::flattenVector(keyVector); + BaseVector::flattenVector(keyVector); tableKeyVectors.push_back( - std::dynamic_pointer_cast>( - keyVector)); + std::dynamic_pointer_cast>(keyVector)); } if (equalMatchPct.has_value()) { VELOX_CHECK_GE(equalMatchPct.value(), 0); VELOX_CHECK_LE(equalMatchPct.value(), 100); - for (int i = 0, totalRows = 0; i < numBatches; ++i) { - std::vector> probeKeyVectors; + for (int i = 0; i < numBatches; ++i) { + std::vector> probeKeyVectors; for (int j = 0; j < probeJoinKeys.size(); ++j) { - probeKeyVectors.push_back(facebook::velox::BaseVector::create< - facebook::velox::FlatVector>( - probeType_->findChild(probeJoinKeys[j]), - probeInputs[i]->size(), - pool.get())); + probeKeyVectors.push_back( + BaseVector::create>( + probeType_->findChild(probeJoinKeys[j]), + probeInputs[i]->size(), + pool.get())); } for (int row = 0; row < probeInputs[i]->size(); - row += numDuplicateProbeRows, totalRows += numDuplicateProbeRows) { - if ((totalRows / numDuplicateProbeRows) % 100 < equalMatchPct.value()) { - const auto matchKeyRow = folly::Random::rand64(numTableRows); + row += numDuplicateProbeRows) { + const auto hit = + (folly::Random::rand32(rng_) % 100) < equalMatchPct.value(); + if (hit) { + const auto matchKeyRow = folly::Random::rand32(numTableRows, rng_); for (int j = 0; j < probeJoinKeys.size(); ++j) { for (int k = 0; k < numDuplicateProbeRows; ++k) { + if (probeKeyVectors[j]->isNullAt(row + k)) { + continue; + } probeKeyVectors[j]->set( row + k, tableKeyVectors[j]->valueAt(matchKeyRow)); } } } else { for (int j = 0; j < probeJoinKeys.size(); ++j) { - const auto randomValue = folly::Random::rand32() % 4096; + const auto randomValue = folly::Random::rand32(rng_) % 4096; for (int k = 0; k < numDuplicateProbeRows; ++k) { + if (probeKeyVectors[j]->isNullAt(row + k)) { + continue; + } probeKeyVectors[j]->set( row + k, tableData.maxKeys[j] + 1 + randomValue); } @@ -125,9 +163,8 @@ IndexLookupJoinTestBase::generateProbeInput( for (int i = 0; i < inColumns.size(); ++i) { const auto inColumnName = inColumns[i]; const auto inColumnChannel = probeType_->getChildIdx(inColumnName); - auto inColumnType = - std::dynamic_pointer_cast( - probeType_->childAt(inColumnChannel)); + auto inColumnType = std::dynamic_pointer_cast( + probeType_->childAt(inColumnChannel)); VELOX_CHECK_NOT_NULL(inColumnType); const auto tableKeyChannel = probeJoinKeys.size() + i; VELOX_CHECK(keyType_->childAt(tableKeyChannel) @@ -144,10 +181,11 @@ IndexLookupJoinTestBase::generateProbeInput( for (int i = 0; i < numBatches; ++i) { probeInputs[i]->childAt(inColumnChannel) = makeArrayVector( probeInputs[i]->size(), - [&](auto row) -> facebook::velox::vector_size_t { - return maxValue - minValue + 1; - }, - [&](auto /*unused*/, auto index) { return minValue + index; }); + [&](auto row) -> vector_size_t { return maxValue - minValue + 1; }, + [&](auto /*unused*/, auto index) { return minValue + index; }, + [&](auto row) { + return probeInputs[i]->childAt(inColumnChannel)->isNullAt(row); + }); } } } @@ -177,15 +215,22 @@ IndexLookupJoinTestBase::generateProbeInput( if (lowerBoundChannel.has_value()) { probeInputs[i]->childAt(lowerBoundChannel.value()) = makeFlatVector( - probeInputs[i]->size(), [&](auto /*unused*/) { + probeInputs[i]->size(), + [&](auto /*unused*/) { return tableData.minKeys[tableKeyChannel]; + }, + [&](auto row) { + return probeInputs[i] + ->childAt(lowerBoundChannel.value()) + ->isNullAt(row); }); } const auto upperBoundColumn = betweenColumn.second; if (upperBoundChannel.has_value()) { probeInputs[i]->childAt(upperBoundChannel.value()) = makeFlatVector( - probeInputs[i]->size(), [&](auto /*unused*/) -> int64_t { + probeInputs[i]->size(), + [&](auto /*unused*/) -> int64_t { if (betweenMatchPct.value() == 0) { return tableData.minKeys[tableKeyChannel] - 1; } else { @@ -194,6 +239,11 @@ IndexLookupJoinTestBase::generateProbeInput( tableData.minKeys[tableKeyChannel]) * betweenMatchPct.value() / 100; } + }, + [&](auto row) { + return probeInputs[i] + ->childAt(upperBoundChannel.value()) + ->isNullAt(row); }); } } @@ -202,64 +252,68 @@ IndexLookupJoinTestBase::generateProbeInput( return probeInputs; } -facebook::velox::core::PlanNodePtr IndexLookupJoinTestBase::makeLookupPlan( - const std::shared_ptr& - planNodeIdGenerator, - facebook::velox::core::TableScanNodePtr indexScanNode, - const std::vector& probeVectors, +PlanNodePtr IndexLookupJoinTestBase::makeLookupPlan( + const std::shared_ptr& planNodeIdGenerator, + TableScanNodePtr indexScanNode, + const std::vector& probeVectors, const std::vector& leftKeys, const std::vector& rightKeys, const std::vector& joinConditions, - facebook::velox::core::JoinType joinType, + bool hasMarker, + JoinType joinType, const std::vector& outputColumns, - facebook::velox::core::PlanNodeId& joinNodeId) { + PlanNodeId& joinNodeId) { VELOX_CHECK_EQ(leftKeys.size(), rightKeys.size()); VELOX_CHECK_LE(leftKeys.size(), keyType_->size()); - return facebook::velox::exec::test::PlanBuilder( - planNodeIdGenerator, pool_.get()) + return PlanBuilder(planNodeIdGenerator, pool_.get()) .values(probeVectors) - .indexLookupJoin( - leftKeys, - rightKeys, - indexScanNode, - joinConditions, - outputColumns, - joinType) + .startIndexLookupJoin() + .leftKeys(leftKeys) + .rightKeys(rightKeys) + .indexSource(indexScanNode) + .joinConditions(joinConditions) + .hasMarker(hasMarker) + .outputLayout(hasMarker ? appendMarker(outputColumns) : outputColumns) + .joinType(joinType) + .endIndexLookupJoin() .capturePlanNodeId(joinNodeId) .planNode(); } -facebook::velox::core::PlanNodePtr IndexLookupJoinTestBase::makeLookupPlan( - const std::shared_ptr& - planNodeIdGenerator, - facebook::velox::core::TableScanNodePtr indexScanNode, +PlanNodePtr IndexLookupJoinTestBase::makeLookupPlan( + const std::shared_ptr& planNodeIdGenerator, + TableScanNodePtr indexScanNode, const std::vector& leftKeys, const std::vector& rightKeys, const std::vector& joinConditions, - facebook::velox::core::JoinType joinType, + const std::string& filter, + bool hasMarker, + JoinType joinType, const std::vector& outputColumns) { VELOX_CHECK_EQ(leftKeys.size(), rightKeys.size()); VELOX_CHECK_LE(leftKeys.size(), keyType_->size()); - return facebook::velox::exec::test::PlanBuilder( - planNodeIdGenerator, pool_.get()) + return PlanBuilder(planNodeIdGenerator, pool_.get()) .startTableScan() .outputType(probeType_) .endTableScan() .captureScanNodeId(probeScanNodeId_) - .indexLookupJoin( - leftKeys, - rightKeys, - indexScanNode, - joinConditions, - outputColumns, - joinType) + .startIndexLookupJoin() + .leftKeys(leftKeys) + .rightKeys(rightKeys) + .indexSource(indexScanNode) + .joinConditions(joinConditions) + .filter(filter) + .hasMarker(hasMarker) + .outputLayout(hasMarker ? appendMarker(outputColumns) : outputColumns) + .joinType(joinType) + .endIndexLookupJoin() .capturePlanNodeId(joinNodeId_) .planNode(); } void IndexLookupJoinTestBase::createDuckDbTable( const std::string& tableName, - const std::vector& data) { + const std::vector& data) { // Change each column with prefix 'c' to simplify the duckdb table // column naming. std::vector columnNames; @@ -267,7 +321,7 @@ void IndexLookupJoinTestBase::createDuckDbTable( for (int i = 0; i < data[0]->type()->size(); ++i) { columnNames.push_back(fmt::format("c{}", i)); } - std::vector duckDbInputs; + std::vector duckDbInputs; duckDbInputs.reserve(data.size()); for (const auto& dataVector : data) { duckDbInputs.emplace_back( @@ -276,29 +330,20 @@ void IndexLookupJoinTestBase::createDuckDbTable( duckDbQueryRunner_.createTable(tableName, duckDbInputs); } -facebook::velox::core::TableScanNodePtr -IndexLookupJoinTestBase::makeIndexScanNode( - const std::shared_ptr& - planNodeIdGenerator, - const std::shared_ptr - indexTableHandle, - const facebook::velox::RowTypePtr& outputType, - std::unordered_map< - std::string, - std::shared_ptr>& - assignments) { - auto planBuilder = facebook::velox::exec::test::PlanBuilder( - planNodeIdGenerator, pool_.get()); - auto indexTableScan = - std::dynamic_pointer_cast( - facebook::velox::exec::test::PlanBuilder::TableScanBuilder( - planBuilder) - .tableHandle(indexTableHandle) - .outputType(outputType) - .assignments(assignments) - .endTableScan() - .capturePlanNodeId(indexScanNodeId_) - .planNode()); +TableScanNodePtr IndexLookupJoinTestBase::makeIndexScanNode( + const std::shared_ptr& planNodeIdGenerator, + const connector::ConnectorTableHandlePtr& indexTableHandle, + const RowTypePtr& outputType, + const connector::ColumnHandleMap& assignments) { + auto planBuilder = PlanBuilder(planNodeIdGenerator, pool_.get()); + auto indexTableScan = std::dynamic_pointer_cast( + PlanBuilder::TableScanBuilder(planBuilder) + .tableHandle(indexTableHandle) + .outputType(outputType) + .assignments(assignments) + .endTableScan() + .capturePlanNodeId(indexScanNodeId_) + .planNode()); VELOX_CHECK_NOT_NULL(indexTableScan); return indexTableScan; } @@ -306,14 +351,14 @@ IndexLookupJoinTestBase::makeIndexScanNode( void IndexLookupJoinTestBase::generateIndexTableData( const std::vector& keyCardinalities, SequenceTableData& tableData, - std::shared_ptr& pool) { + std::shared_ptr& pool) { VELOX_CHECK_EQ(keyCardinalities.size(), keyType_->size()); const auto numRows = getNumRows(keyCardinalities); - facebook::velox::VectorFuzzer::Options opts; + VectorFuzzer::Options opts; opts.vectorSize = numRows; opts.nullRatio = 0.0; opts.allowSlice = false; - facebook::velox::VectorFuzzer fuzzer(opts, pool.get()); + VectorFuzzer fuzzer(opts, pool.get()); tableData.keyData = fuzzer.fuzzInputFlatRow(keyType_); tableData.valueData = fuzzer.fuzzInputFlatRow(valueType_); @@ -339,7 +384,7 @@ void IndexLookupJoinTestBase::generateIndexTableData( tableData.maxKeys[i] = maxKey; } - std::vector tableColumns; + std::vector tableColumns; VELOX_CHECK_EQ(tableType_->size(), keyType_->size() + valueType_->size()); tableColumns.reserve(tableType_->size()); for (auto i = 0; i < keyType_->size(); ++i) { @@ -351,9 +396,9 @@ void IndexLookupJoinTestBase::generateIndexTableData( tableData.tableData = makeRowVector(tableType_->names(), tableColumns); } -facebook::velox::RowTypePtr IndexLookupJoinTestBase::makeScanOutputType( +RowTypePtr IndexLookupJoinTestBase::makeScanOutputType( std::vector outputNames) { - std::vector types; + std::vector types; for (int i = 0; i < outputNames.size(); ++i) { if (valueType_->getChildIdxIfExists(outputNames[i]).has_value()) { types.push_back(valueType_->findChild(outputNames[i])); @@ -361,65 +406,114 @@ facebook::velox::RowTypePtr IndexLookupJoinTestBase::makeScanOutputType( } types.push_back(keyType_->findChild(outputNames[i])); } - return facebook::velox::ROW(std::move(outputNames), std::move(types)); + return ROW(std::move(outputNames), std::move(types)); } bool IndexLookupJoinTestBase::isFilter(const std::string& conditionSql) const { const auto inputType = concat(keyType_, probeType_); - return facebook::velox::exec::test::PlanBuilder::parseIndexJoinCondition( + return PlanBuilder::parseIndexJoinCondition( conditionSql, inputType, pool_.get()) ->isFilter(); } -std::shared_ptr -IndexLookupJoinTestBase::runLookupQuery( - const facebook::velox::core::PlanNodePtr& plan, +std::shared_ptr IndexLookupJoinTestBase::runLookupQuery( + const PlanNodePtr& plan, int numPrefetchBatches, const std::string& duckDbVefifySql) { - return facebook::velox::exec::test::AssertQueryBuilder(duckDbQueryRunner_) + return AssertQueryBuilder(duckDbQueryRunner_) .plan(plan) .config( - facebook::velox::core::QueryConfig:: - kIndexLookupJoinMaxPrefetchBatches, + QueryConfig::kIndexLookupJoinMaxPrefetchBatches, std::to_string(numPrefetchBatches)) .assertResults(duckDbVefifySql); } -std::shared_ptr -IndexLookupJoinTestBase::runLookupQuery( - const facebook::velox::core::PlanNodePtr& plan, - const std::vector< - std::shared_ptr>& probeFiles, +std::shared_ptr IndexLookupJoinTestBase::runLookupQuery( + const PlanNodePtr& plan, + const std::vector>& probeFiles, bool serialExecution, bool barrierExecution, int maxOutputRows, int numPrefetchBatches, const std::string& duckDbVefifySql) { - return facebook::velox::exec::test::AssertQueryBuilder(duckDbQueryRunner_) + return AssertQueryBuilder(duckDbQueryRunner_) .plan(plan) .splits(probeScanNodeId_, makeHiveConnectorSplits(probeFiles)) .serialExecution(serialExecution) .barrierExecution(barrierExecution) + .config(QueryConfig::kMaxOutputBatchRows, std::to_string(maxOutputRows)) .config( - facebook::velox::core::QueryConfig::kMaxOutputBatchRows, - std::to_string(maxOutputRows)) - .config( - facebook::velox::core::QueryConfig:: - kIndexLookupJoinMaxPrefetchBatches, + QueryConfig::kIndexLookupJoinMaxPrefetchBatches, std::to_string(numPrefetchBatches)) .assertResults(duckDbVefifySql); } -std::vector> +void IndexLookupJoinTestBase::verifyResultWithMatchColumn( + const PlanNodePtr& planWithoutMatchColumn, + const PlanNodeId& probeScanNodeIdWithoutMatchColumn, + const PlanNodePtr& planWithMatchColumn, + const PlanNodeId& probeScanNodeIdWithMatchColumn, + const std::vector>& probeFiles) { + VectorPtr expectedResult = AssertQueryBuilder(duckDbQueryRunner_) + .plan(planWithoutMatchColumn) + .splits( + probeScanNodeIdWithoutMatchColumn, + makeHiveConnectorSplits(probeFiles)) + .copyResults(pool()); + BaseVector::flattenVector(expectedResult); + + VectorPtr resultWithMatchColumn = AssertQueryBuilder(duckDbQueryRunner_) + .plan(planWithMatchColumn) + .splits( + probeScanNodeIdWithMatchColumn, + makeHiveConnectorSplits(probeFiles)) + .copyResults(pool()); + BaseVector::flattenVector(resultWithMatchColumn); + auto rowResultWithMatchMatchColumn = + std::dynamic_pointer_cast(resultWithMatchColumn); + const auto resultWithMatchColumnType = + std::dynamic_pointer_cast( + rowResultWithMatchMatchColumn->type()); + std::vector childVectors; + std::unordered_set lookupColumnNameSet( + valueType_->names().begin(), valueType_->names().end()); + std::vector lookupColumnVectors; + for (int i = 0; i < rowResultWithMatchMatchColumn->childrenSize() - 1; ++i) { + childVectors.push_back(rowResultWithMatchMatchColumn->childAt(i)); + if (lookupColumnNameSet.contains(resultWithMatchColumnType->nameOf(i))) { + lookupColumnVectors.push_back(rowResultWithMatchMatchColumn->childAt(i)); + } + } + auto resultWithoutMatchColumn = makeRowVector(childVectors); + assertEqualVectors(expectedResult, resultWithoutMatchColumn); + // Verify the match column if it is expected. + const auto matchColumn = + rowResultWithMatchMatchColumn + ->childAt(rowResultWithMatchMatchColumn->childrenSize() - 1) + ->asFlatVector(); + for (int i = 0; i < resultWithMatchColumn->size(); ++i) { + const bool match = matchColumn->valueAt(i); + if (match) { + for (const auto& lookupColumnVector : lookupColumnVectors) { + ASSERT_FALSE(lookupColumnVector->isNullAt(i)); + } + } else { + for (const auto& lookupColumnVector : lookupColumnVectors) { + ASSERT_TRUE(lookupColumnVector->isNullAt(i)); + } + } + } +} + +std::vector> IndexLookupJoinTestBase::createProbeFiles( - const std::vector& probeVectors) { - std::vector> - probeFiles; + const std::vector& probeVectors) { + std::vector> probeFiles; probeFiles.reserve(probeVectors.size()); for (auto i = 0; i < probeVectors.size(); ++i) { - probeFiles.push_back(facebook::velox::exec::test::TempFilePath::create()); + probeFiles.push_back(TempFilePath::create()); } writeToFiles(toFilePaths(probeFiles), probeVectors); return probeFiles; } -} // namespace fecebook::velox::exec::test +} // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/utils/IndexLookupJoinTestBase.h b/velox/exec/tests/utils/IndexLookupJoinTestBase.h index 4fa0c3472826..db96fbae95df 100644 --- a/velox/exec/tests/utils/IndexLookupJoinTestBase.h +++ b/velox/exec/tests/utils/IndexLookupJoinTestBase.h @@ -13,30 +13,37 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - +#pragma once #include "velox/connectors/Connector.h" #include "velox/core/PlanNode.h" -#include "velox/exec/tests/utils/AssertQueryBuilder.h" #include "velox/exec/tests/utils/HiveConnectorTestBase.h" #include "velox/parse/PlanNodeIdGenerator.h" -namespace fecebook::velox::exec::test { -class IndexLookupJoinTestBase - : public facebook::velox::exec::test::HiveConnectorTestBase { +namespace facebook::velox::exec::test { + +using namespace facebook::velox; +using namespace facebook::velox::core; +using namespace facebook::velox::exec; +using namespace facebook::velox::exec::test; + +class IndexLookupJoinTestBase : public HiveConnectorTestBase { protected: IndexLookupJoinTestBase() = default; + void SetUp() override { + HiveConnectorTestBase::SetUp(); + rng_.seed(123); + } + struct SequenceTableData { - facebook::velox::RowVectorPtr keyData; - facebook::velox::RowVectorPtr valueData; - facebook::velox::RowVectorPtr tableData; + RowVectorPtr keyData; + RowVectorPtr valueData; + RowVectorPtr tableData; std::vector minKeys; std::vector maxKeys; }; - static facebook::velox::RowTypePtr concat( - const facebook::velox::RowTypePtr& a, - const facebook::velox::RowTypePtr& b); + static RowTypePtr concat(const RowTypePtr& a, const RowTypePtr& b); bool isFilter(const std::string& conditionSql) const; @@ -50,6 +57,7 @@ class IndexLookupJoinTestBase /// @param tableData: contains the sequence table data including key vectors /// and min/max key values. /// @param probeJoinKeys: the prefix key colums used for equality joins. + /// @param hasNullKeys: whether the probe input has null keys. /// @param inColumns: the ordered list of in conditions. /// @param betweenColumns: the ordered list of between conditions. /// @param equalMatchPct: percentage of rows in the probe input that matches @@ -58,14 +66,15 @@ class IndexLookupJoinTestBase /// the rows in index table with between conditions. /// @param inMatchPct: percentage of rows in the probe input that matches the /// rows in index table with in conditions. - std::vector generateProbeInput( + std::vector generateProbeInput( size_t numBatches, size_t batchSize, size_t numDuplicateProbeRows, SequenceTableData& tableData, - std::shared_ptr& pool, + std::shared_ptr& pool, const std::vector& probeJoinKeys, - const std::vector inColumns = {}, + bool hasNullKeys = false, + const std::vector& inColumns = {}, const std::vector>& betweenColumns = {}, std::optional equalMatchPct = std::nullopt, @@ -77,58 +86,59 @@ class IndexLookupJoinTestBase /// @param probeVectors: the probe input vectors. /// @param leftKeys: the left join keys of index lookup join. /// @param rightKeys: the right join keys of index lookup join. + /// @param hasMarker: whether the index join output includes a match + /// column at the end. /// @param joinType: the join type of index lookup join. /// @param outputColumns: the output column names of index lookup join. /// @param joinNodeId: returns the plan node id of the index lookup join /// node. - facebook::velox::core::PlanNodePtr makeLookupPlan( - const std::shared_ptr& - planNodeIdGenerator, - facebook::velox::core::TableScanNodePtr indexScanNode, - const std::vector& probeVectors, + PlanNodePtr makeLookupPlan( + const std::shared_ptr& planNodeIdGenerator, + TableScanNodePtr indexScanNode, + const std::vector& probeVectors, const std::vector& leftKeys, const std::vector& rightKeys, const std::vector& joinConditions, - facebook::velox::core::JoinType joinType, + bool hasMarker, + core::JoinType joinType, const std::vector& outputColumns, - facebook::velox::core::PlanNodeId& joinNodeId); + core::PlanNodeId& joinNodeId); /// Makes lookup join plan with the following parameters: + /// @param planNodeIdGenerator: generator for creating unique plan node IDs. /// @param indexScanNode: the index table scan node. - /// @param probeVectors: the probe input vectors. /// @param leftKeys: the left join keys of index lookup join. /// @param rightKeys: the right join keys of index lookup join. + /// @param joinConditions: the join conditions for index lookup join that + /// can't be converted into simple equality join conditions. + /// @param filter: additional filter condition SQL string to apply on join + /// results. Can be empty string if no additional filter is needed. + /// @param hasMarker: whether the index join output includes a match + /// column at the end. /// @param joinType: the join type of index lookup join. /// @param outputColumns: the output column names of index lookup join. - /// @param joinNodeId: returns the plan node id of the index lookup join - /// node. - /// @param probeScanNodeId: returns the plan node id of the probe table scan - facebook::velox::core::PlanNodePtr makeLookupPlan( - const std::shared_ptr& - planNodeIdGenerator, - facebook::velox::core::TableScanNodePtr indexScanNode, + PlanNodePtr makeLookupPlan( + const std::shared_ptr& planNodeIdGenerator, + TableScanNodePtr indexScanNode, const std::vector& leftKeys, const std::vector& rightKeys, const std::vector& joinConditions, - facebook::velox::core::JoinType joinType, + const std::string& filter, + bool hasMarker, + JoinType joinType, const std::vector& outputColumns); void createDuckDbTable( const std::string& tableName, - const std::vector& data); + const std::vector& data); /// Makes index table scan node with the specified index table handle. /// @param outputType: the output schema of the index table scan node. - facebook::velox::core::TableScanNodePtr makeIndexScanNode( - const std::shared_ptr& - planNodeIdGenerator, - const std::shared_ptr - indexTableHandle, - const facebook::velox::RowTypePtr& outputType, - std::unordered_map< - std::string, - std::shared_ptr>& - assignments); + TableScanNodePtr makeIndexScanNode( + const std::shared_ptr& planNodeIdGenerator, + const connector::ConnectorTableHandlePtr& indexTableHandle, + const RowTypePtr& outputType, + const connector::ColumnHandleMap& assignments); /// Generate sequence storage table which will be persisted by mock zippydb /// client for testing. @@ -141,41 +151,47 @@ class IndexLookupJoinTestBase void generateIndexTableData( const std::vector& keyCardinalities, SequenceTableData& tableData, - std::shared_ptr& pool); + std::shared_ptr& pool); /// Write 'probeVectors' to a number of files with one per each file. - std::vector> - createProbeFiles( - const std::vector& probeVectors); + std::vector> createProbeFiles( + const std::vector& probeVectors); /// Makes output schema from the index table scan node with the specified /// column names. - facebook::velox::RowTypePtr makeScanOutputType( - std::vector outputNames); + RowTypePtr makeScanOutputType(std::vector outputNames); - std::shared_ptr runLookupQuery( - const facebook::velox::core::PlanNodePtr& plan, + std::shared_ptr runLookupQuery( + const PlanNodePtr& plan, int numPrefetchBatches, const std::string& duckDbVefifySql); - std::shared_ptr runLookupQuery( - const facebook::velox::core::PlanNodePtr& plan, - const std::vector< - std::shared_ptr>& - probeFiles, + std::shared_ptr runLookupQuery( + const PlanNodePtr& plan, + const std::vector>& probeFiles, bool serialExecution, bool barrierExecution, int maxBatchRows, int numPrefetchBatches, const std::string& duckDbVefifySql); - facebook::velox::RowTypePtr keyType_; - std::optional partitionType_; - facebook::velox::RowTypePtr valueType_; - facebook::velox::RowTypePtr tableType_; - facebook::velox::RowTypePtr probeType_; - facebook::velox::core::PlanNodeId joinNodeId_; - facebook::velox::core::PlanNodeId indexScanNodeId_; - facebook::velox::core::PlanNodeId probeScanNodeId_; + /// Verifies the results of the index lookup join query with and without match + /// column. + void verifyResultWithMatchColumn( + const PlanNodePtr& planWithoutMatchColumn, + const PlanNodeId& probeScanNodeIdWithoutMatchColumn, + const PlanNodePtr& planWithMatchColumn, + const PlanNodeId& probeScanNodeIdWithMatchColumn, + const std::vector>& probeFiles); + + RowTypePtr keyType_; + std::optional partitionType_; + RowTypePtr valueType_; + RowTypePtr tableType_; + RowTypePtr probeType_; + PlanNodeId joinNodeId_; + PlanNodeId indexScanNodeId_; + PlanNodeId probeScanNodeId_; + folly::Random::DefaultGenerator rng_; }; -} // namespace fecebook::velox::exec::test +} // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/utils/LocalExchangeSource.cpp b/velox/exec/tests/utils/LocalExchangeSource.cpp index ca5545fd5a3f..87cdbc966567 100644 --- a/velox/exec/tests/utils/LocalExchangeSource.cpp +++ b/velox/exec/tests/utils/LocalExchangeSource.cpp @@ -17,6 +17,7 @@ #include #include #include "velox/common/testutil/TestValue.h" +#include "velox/exec/Operator.h" #include "velox/exec/OutputBufferManager.h" namespace facebook::velox::exec::test { @@ -89,7 +90,7 @@ class LocalExchangeSource : public exec::ExchangeSource { if (data.empty()) { sequence = requestedSequence; } - std::vector> pages; + std::vector> pages; bool atEnd = false; int64_t totalBytes = 0; for (auto& inputPage : data) { @@ -100,7 +101,8 @@ class LocalExchangeSource : public exec::ExchangeSource { } totalBytes += inputPage->length(); inputPage->unshare(); - pages.push_back(std::make_unique(std::move(inputPage))); + pages.push_back( + std::make_unique(std::move(inputPage))); inputPage = nullptr; } numPages_ += pages.size(); @@ -189,7 +191,7 @@ class LocalExchangeSource : public exec::ExchangeSource { {"localExchangeSource.numPages", RuntimeMetric(numPages_)}, {"localExchangeSource.totalBytes", RuntimeMetric(totalBytes_, RuntimeCounter::Unit::kBytes)}, - {ExchangeClient::kBackgroundCpuTimeMs, + {Operator::kBackgroundCpuTimeNanos, RuntimeMetric(123 * 1000000, RuntimeCounter::Unit::kNanos)}, }; } @@ -270,7 +272,7 @@ class LocalExchangeSource : public exec::ExchangeSource { } bool checkSetRequestPromise() { - VeloxPromise promise; + VeloxPromise promise{VeloxPromise::makeEmpty()}; { std::lock_guard l(queue_->mutex()); promise = std::move(promise_); diff --git a/velox/exec/tests/utils/LocalExchangeSource.h b/velox/exec/tests/utils/LocalExchangeSource.h index 2a69f6343af2..30a9dd60a111 100644 --- a/velox/exec/tests/utils/LocalExchangeSource.h +++ b/velox/exec/tests/utils/LocalExchangeSource.h @@ -14,7 +14,8 @@ * limitations under the License. */ #pragma once -#include "velox/exec/Exchange.h" + +#include "velox/exec/ExchangeSource.h" namespace facebook::velox::exec::test { diff --git a/velox/exec/tests/utils/LocalRunnerTestBase.cpp b/velox/exec/tests/utils/LocalRunnerTestBase.cpp deleted file mode 100644 index 398df844ef17..000000000000 --- a/velox/exec/tests/utils/LocalRunnerTestBase.cpp +++ /dev/null @@ -1,142 +0,0 @@ -/* - * Copyright (c) Facebook, Inc. and its affiliates. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "velox/exec/tests/utils/LocalRunnerTestBase.h" -#include "velox/connectors/hive/HiveConfig.h" -#include "velox/exec/tests/utils/LocalExchangeSource.h" - -namespace facebook::velox::exec::test { - -void LocalRunnerTestBase::SetUp() { - HiveConnectorTestBase::SetUp(); - exec::ExchangeSource::factories().clear(); - exec::ExchangeSource::registerFactory(createLocalExchangeSource); - ensureTestData(); -} - -std::shared_ptr LocalRunnerTestBase::makeQueryCtx( - const std::string& queryId, - memory::MemoryPool* rootPool) { - auto& config = config_; - auto hiveConfig = hiveConfig_; - std::unordered_map> - connectorConfigs; - connectorConfigs[kHiveConnectorId] = - std::make_shared(std::move(hiveConfig)); - - return core::QueryCtx::create( - schemaExecutor_.get(), - core::QueryConfig(config), - std::move(connectorConfigs), - cache::AsyncDataCache::getInstance(), - rootPool->shared_from_this(), - nullptr, - queryId); -} - -void LocalRunnerTestBase::ensureTestData() { - if (!files_) { - makeTables(testTables_, files_); - } - // Destroy and rebuild the testing connector. The connector will - // show the metadata if the connector is wired for metadata. - setupConnector(); -} - -void LocalRunnerTestBase::setupConnector() { - connector::unregisterConnector(kHiveConnectorId); - - std::unordered_map configs; - configs[connector::hive::HiveConfig::kLocalDataPath] = testDataPath_; - configs[connector::hive::HiveConfig::kLocalFileFormat] = localFileFormat_; - auto hiveConnector = - connector::getConnectorFactory( - connector::hive::HiveConnectorFactory::kHiveConnectorName) - ->newConnector( - kHiveConnectorId, - std::make_shared(std::move(configs)), - ioExecutor_.get()); - connector::registerConnector(hiveConnector); -} - -void LocalRunnerTestBase::makeTables( - std::vector specs, - std::shared_ptr& directory) { - if (initialized_) { - return; - } - initialized_ = true; - if (testDataPath_.empty()) { - directory = exec::test::TempDirectoryPath::create(); - testDataPath_ = directory->getPath(); - } - for (auto& spec : specs) { - auto tablePath = fmt::format("{}/{}", testDataPath_, spec.name); - auto fs = filesystems::getFileSystem(tablePath, {}); - fs->mkdir(tablePath); - for (auto i = 0; i < spec.numFiles; ++i) { - auto vectors = HiveConnectorTestBase::makeVectors( - spec.columns, spec.numVectorsPerFile, spec.rowsPerVector); - if (spec.customizeData) { - for (auto& vector : vectors) { - spec.customizeData(vector); - } - } - auto filePath = fmt::format("{}/f{}", tablePath, i); - tableFilePaths_[spec.name].push_back(filePath); - writeToFile(filePath, vectors); - } - } -} - -std::shared_ptr -LocalRunnerTestBase::makeSimpleSplitSourceFactory( - const runner::MultiFragmentPlanPtr& plan) { - std::unordered_map< - core::PlanNodeId, - std::vector>> - nodeSplitMap; - for (auto& fragment : plan->fragments()) { - for (auto& scan : fragment.scans) { - auto& name = scan->tableHandle()->name(); - auto files = tableFilePaths_[name]; - VELOX_CHECK(!files.empty(), "No splits known for {}", name); - std::vector> splits; - for (auto& file : files) { - splits.push_back(connector::hive::HiveConnectorSplitBuilder(file) - .connectorId(kHiveConnectorId) - .fileFormat(dwio::common::FileFormat::DWRF) - .build()); - } - nodeSplitMap[scan->id()] = std::move(splits); - } - }; - return std::make_shared( - std::move(nodeSplitMap)); -} - -std::vector readCursor( - std::shared_ptr runner) { - // 'result' borrows memory from cursor so the life cycle must be shorter. - std::vector result; - - while (auto rows = runner->next()) { - result.push_back(rows); - } - return result; -} - -} // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/utils/LocalRunnerTestBase.h b/velox/exec/tests/utils/LocalRunnerTestBase.h deleted file mode 100644 index 0303a153d503..000000000000 --- a/velox/exec/tests/utils/LocalRunnerTestBase.h +++ /dev/null @@ -1,102 +0,0 @@ -/* - * Copyright (c) Facebook, Inc. and its affiliates. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "velox/exec/tests/utils/HiveConnectorTestBase.h" -#include "velox/exec/tests/utils/TempDirectoryPath.h" -#include "velox/runner/LocalRunner.h" - -namespace facebook::velox::exec::test { - -struct TableSpec { - std::string name; - RowTypePtr columns; - int32_t rowsPerVector{10000}; - int32_t numVectorsPerFile{5}; - int32_t numFiles{5}; - - /// Function Applied to generated RowVectors for the table before writing. - /// May be used to insert non-random data on top of the random datafrom - /// HiveConnectorTestBase::makeVectors. - std::function customizeData; -}; - -/// Test helper class that manages a TestCase with a set of generated -/// tables and a HiveConnector that exposes the files and their -/// metadata. The lifetime the test data is the test case consisting -/// of multiple google unit test cases. -class LocalRunnerTestBase : public HiveConnectorTestBase { - protected: - static void SetUpTestCase() { - HiveConnectorTestBase::SetUpTestCase(); - schemaExecutor_ = std::make_unique(4); - } - - static void TearDownTestCase() { - initialized_ = false; - files_.reset(); - HiveConnectorTestBase::TearDownTestCase(); - } - - void SetUp() override; - - void ensureTestData(); - - /// Re-creates the connector with kHiveConnectorId with a config - /// that points to the temp directory created by 'this'. If the - /// connector factory is wired to capture metadata then the metadata - /// will be available through the connector. - void setupConnector(); - - /// Returns a split source factory that contains splits for the table scans in - /// 'plan'. 'plan' should refer to testing tables created by 'this'. - std::shared_ptr - makeSimpleSplitSourceFactory(const runner::MultiFragmentPlanPtr& plan); - - void makeTables( - std::vector specs, - std::shared_ptr& directory); - - // Creates a QueryCtx with 'pool'. 'pool' must be a root pool. - static std::shared_ptr makeQueryCtx( - const std::string& queryId, - memory::MemoryPool* pool); - - // Configs for creating QueryCtx. - inline static std::unordered_map config_; - inline static std::unordered_map hiveConfig_; - - // The specification of the test data. The data is created in ensureTestData() - // called from each SetUp()(. - inline static std::vector testTables_; - - // The top level directory with the test data. - inline static bool initialized_; - inline static std::string testDataPath_; - inline static std::string localFileFormat_{"dwrf"}; - inline static std::shared_ptr files_; - /// Map from table name to list of file system paths. - inline static std::unordered_map> - tableFilePaths_; - inline static std::unique_ptr schemaExecutor_; -}; - -/// Reads all results from 'runner'. -std::vector readCursor( - std::shared_ptr runner); - -} // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/utils/MergeTestBase.h b/velox/exec/tests/utils/MergeTestBase.h index 54d4bdbac700..e20118f8df1a 100644 --- a/velox/exec/tests/utils/MergeTestBase.h +++ b/velox/exec/tests/utils/MergeTestBase.h @@ -15,8 +15,9 @@ */ #include "velox/common/base/Exceptions.h" +#include "velox/common/base/TreeOfLosers.h" #include "velox/common/time/Timer.h" -#include "velox/exec/TreeOfLosers.h" +#include "velox/exec/Spill.h" #include @@ -100,6 +101,44 @@ class TestingStream final : public MergeStream { std::vector numbers_; }; +class TestingSpillMergeStream : public SpillMergeStream { + public: + TestingSpillMergeStream( + uint32_t id, + const std::vector& sortingKeys, + RowVectorPtr rowVector) + : id_(id), sortingKeys_(sortingKeys) { + rowVector_ = rowVector; + size_ = rowVector_->size(); + } + + uint32_t id() const override { + return id_; + } + + private: + const std::vector& sortingKeys() const override { + VELOX_CHECK(!closed_); + return sortingKeys_; + } + + void nextBatch() override { + VELOX_CHECK(!closed_); + index_ = 0; + size_ = 0; + close(); + rowVector_.reset(); + } + + void close() override { + VELOX_CHECK(!closed_); + SpillMergeStream::close(); + } + + uint32_t id_; + const std::vector sortingKeys_; +}; + // Test data for merging. struct TestData { // Globally sorted sequence of test keys. diff --git a/velox/exec/tests/utils/OperatorTestBase.cpp b/velox/exec/tests/utils/OperatorTestBase.cpp index 297010acb4cc..e235f1249057 100644 --- a/velox/exec/tests/utils/OperatorTestBase.cpp +++ b/velox/exec/tests/utils/OperatorTestBase.cpp @@ -24,7 +24,6 @@ #include "velox/exec/tests/utils/LocalExchangeSource.h" #include "velox/functions/prestosql/aggregates/RegisterAggregateFunctions.h" #include "velox/functions/prestosql/registration/RegistrationFunctions.h" -#include "velox/parse/Expressions.h" #include "velox/parse/ExpressionsParser.h" #include "velox/parse/TypeResolver.h" #include "velox/serializers/CompactRowSerializer.h" @@ -154,6 +153,7 @@ void OperatorTestBase::SetUp() { void OperatorTestBase::TearDown() { waitForAllTasksToBeDeleted(); stopPeriodicStatsReporter(); + executor_.reset(); // There might be lingering exchange source on executor even after all tasks // are deleted. This can cause memory leak because exchange source holds // reference to memory pool. We need to make sure they are properly cleaned. diff --git a/velox/exec/tests/utils/PlanBuilder.cpp b/velox/exec/tests/utils/PlanBuilder.cpp index 1ddf46c61976..b204ba76d265 100644 --- a/velox/exec/tests/utils/PlanBuilder.cpp +++ b/velox/exec/tests/utils/PlanBuilder.cpp @@ -17,6 +17,7 @@ #include "velox/exec/tests/utils/PlanBuilder.h" #include "velox/connectors/hive/HiveConnector.h" #include "velox/connectors/hive/TableHandle.h" +#include "velox/connectors/tpcds/TpcdsConnector.h" #include "velox/connectors/tpch/TpchConnector.h" #include "velox/duckdb/conversion/DuckParser.h" #include "velox/exec/Aggregate.h" @@ -24,9 +25,10 @@ #include "velox/exec/RoundRobinPartitionFunction.h" #include "velox/exec/TableWriter.h" #include "velox/exec/WindowFunction.h" +#include "velox/exec/tests/utils/AggregationResolver.h" +#include "velox/exec/tests/utils/FilterToExpression.h" #include "velox/expression/Expr.h" #include "velox/expression/ExprToSubfieldFilter.h" -#include "velox/expression/FunctionCallToSpecialForm.h" #include "velox/expression/SignatureBinder.h" #include "velox/expression/VectorReaders.h" #include "velox/parse/Expressions.h" @@ -72,16 +74,14 @@ PlanBuilder& PlanBuilder::tableScan( const std::vector& subfieldFilters, const std::string& remainingFilter, const RowTypePtr& dataColumns, - const std::unordered_map< - std::string, - std::shared_ptr>& assignments) { + const connector::ColumnHandleMap& assignments) { return TableScanBuilder(*this) .filtersAsNode(filtersAsNode_ ? planNodeIdGenerator_ : nullptr) .outputType(outputType) .assignments(assignments) + .dataColumns(dataColumns) .subfieldFilters(subfieldFilters) .remainingFilter(remainingFilter) - .dataColumns(dataColumns) .endTableScan(); } @@ -92,28 +92,42 @@ PlanBuilder& PlanBuilder::tableScan( const std::vector& subfieldFilters, const std::string& remainingFilter, const RowTypePtr& dataColumns, - const std::unordered_map< - std::string, - std::shared_ptr>& assignments) { + const connector::ColumnHandleMap& assignments) { return TableScanBuilder(*this) .filtersAsNode(filtersAsNode_ ? planNodeIdGenerator_ : nullptr) .tableName(tableName) .outputType(outputType) .columnAliases(columnAliases) + .dataColumns(dataColumns) + .subfieldFilters(subfieldFilters) .remainingFilter(remainingFilter) - .dataColumns(dataColumns) .assignments(assignments) .endTableScan(); } +PlanBuilder& PlanBuilder::tableScanWithPushDown( + const RowTypePtr& outputType, + const PushdownConfig& pushdownConfig, + const RowTypePtr& dataColumns, + const connector::ColumnHandleMap& assignments) { + return TableScanBuilder(*this) + .filtersAsNode(filtersAsNode_ ? planNodeIdGenerator_ : nullptr) + .outputType(outputType) + .assignments(assignments) + .dataColumns(dataColumns) + .subfieldFiltersMap(pushdownConfig.subfieldFiltersMap) + .remainingFilter(pushdownConfig.remainingFilter) + .endTableScan(); +} + PlanBuilder& PlanBuilder::tpchTableScan( tpch::Table table, std::vector columnNames, double scaleFactor, - std::string_view connectorId) { - std::unordered_map> - assignmentsMap; + std::string_view connectorId, + const std::string& filter) { + connector::ColumnHandleMap assignmentsMap; std::vector outputTypes; assignmentsMap.reserve(columnNames.size()); @@ -126,23 +140,150 @@ PlanBuilder& PlanBuilder::tpchTableScan( outputTypes.emplace_back(resolveTpchColumn(table, columnName)); } auto rowType = ROW(std::move(columnNames), std::move(outputTypes)); + + core::TypedExprPtr filterExpression; + if (!filter.empty()) { + auto expression = parse::parseExpr(filter, options_); + filterExpression = + core::Expressions::inferTypes(expression, rowType, pool_); + } + + auto tableHandle = std::make_shared( + std::string(connectorId), + table, + scaleFactor, + std::move(filterExpression)); + return TableScanBuilder(*this) .filtersAsNode(filtersAsNode_ ? planNodeIdGenerator_ : nullptr) .outputType(rowType) - .tableHandle(std::make_shared( - std::string(connectorId), table, scaleFactor)) + .tableHandle(tableHandle) .assignments(assignmentsMap) .endTableScan(); } +PlanBuilder& PlanBuilder::tpcdsTableScan( + tpcds::Table table, + std::vector columnNames, + double scaleFactor, + std::string_view connectorId) { + std:: + unordered_map> + assignmentsMap; + std::vector outputTypes; + + assignmentsMap.reserve(columnNames.size()); + outputTypes.reserve(columnNames.size()); + + for (const auto& columnName : columnNames) { + assignmentsMap.emplace( + columnName, + std::make_shared(columnName)); + outputTypes.emplace_back(resolveTpcdsColumn(table, columnName)); + } + auto rowType = ROW(std::move(columnNames), std::move(outputTypes)); + return TableScanBuilder(*this) + .outputType(rowType) + .tableHandle( + std::make_shared( + std::string(connectorId), table, scaleFactor)) + .assignments(assignmentsMap) + .endTableScan(); +} + +namespace { + +// Analyzes 'expr' to determine if it can be expressed as a subfield filter. +// Returns a pair of subfield and filter if so. Otherwise, throws. +// +// Supports all expressions supported by +// exec::ExprToSubfieldFilterParser::leafCallToSubfieldFilter + negations and +// disjunctions over same subfield. +// +// Examples: +// a = 1 +// a = 1 OR a > 10 +// not (a = 1) +std::pair> toSubfieldFilter( + const core::TypedExprPtr& expr, + core::ExpressionEvaluator* evaluator) { + if (expr->isCallKind(); + auto* call = expr->asUnchecked()) { + if (call->name() == "or") { + VELOX_CHECK_EQ(call->inputs().size(), 2); + auto left = toSubfieldFilter(call->inputs()[0], evaluator); + auto right = toSubfieldFilter(call->inputs()[1], evaluator); + VELOX_CHECK(left.first == right.first); + auto filter = exec::ExprToSubfieldFilterParser::makeOrFilter( + std::move(left.second), std::move(right.second)); + VELOX_CHECK_NOT_NULL(filter); + return {std::move(left.first), std::move(filter)}; + } + + if (call->name() == "not") { + const auto& input = call->inputs()[0]; + if (input->isCallKind(); + auto* inner = input->asUnchecked()) { + if (auto result = + exec::ExprToSubfieldFilterParser::getInstance() + ->leafCallToSubfieldFilter(*inner, evaluator, true)) { + return std::move(result.value()); + } + } + } else { + if (auto result = + exec::ExprToSubfieldFilterParser::getInstance() + ->leafCallToSubfieldFilter(*call, evaluator, false)) { + return std::move(result.value()); + } + } + } + VELOX_UNSUPPORTED( + "Unsupported expression for range filter: {}", expr->toString()); +} +} // namespace + PlanBuilder::TableScanBuilder& PlanBuilder::TableScanBuilder::subfieldFilters( std::vector subfieldFilters) { - subfieldFilters_.clear(); - subfieldFilters_.reserve(subfieldFilters.size()); + VELOX_CHECK(subfieldFiltersMap_.empty()); + + if (subfieldFilters.empty()) { + return *this; + } + + // Parse subfield filters + auto queryCtx = core::QueryCtx::create(); + exec::SimpleExpressionEvaluator evaluator(queryCtx.get(), planBuilder_.pool_); + const RowTypePtr& parseType = dataColumns_ ? dataColumns_ : outputType_; for (const auto& filter : subfieldFilters) { - subfieldFilters_.emplace_back( - parse::parseExpr(filter, planBuilder_.options_)); + auto untypedExpr = parse::parseExpr(filter, planBuilder_.options_); + + // Parse directly to subfieldFiltersMap_ + auto filterExpr = core::Expressions::inferTypes( + untypedExpr, parseType, planBuilder_.pool_); + auto [subfield, subfieldFilter] = toSubfieldFilter(filterExpr, &evaluator); + + auto it = columnAliases_.find(subfield.toString()); + if (it != columnAliases_.end()) { + subfield = common::Subfield(it->second); + } + VELOX_CHECK_EQ( + subfieldFiltersMap_.count(subfield), + 0, + "Duplicate subfield: {}", + subfield.toString()); + + subfieldFiltersMap_[std::move(subfield)] = std::move(subfieldFilter); + } + return *this; +} + +PlanBuilder::TableScanBuilder& +PlanBuilder::TableScanBuilder::subfieldFiltersMap( + const common::SubfieldFilters& filtersMap) { + for (const auto& [k, v] : filtersMap) { + subfieldFiltersMap_[k.clone()] = v->clone(); } return *this; } @@ -155,6 +296,11 @@ PlanBuilder::TableScanBuilder& PlanBuilder::TableScanBuilder::remainingFilter( return *this; } +PlanBuilder::TableScanBuilder& PlanBuilder::TableScanBuilder::sampleRate( + double sampleRate) { + sampleRate_ = sampleRate; + return *this; +} namespace { void addConjunct( const core::TypedExprPtr& conjunct, @@ -163,9 +309,7 @@ void addConjunct( conjunction = conjunct; } else { conjunction = std::make_shared( - BOOLEAN(), - std::vector{conjunction, conjunct}, - "and"); + BOOLEAN(), "and", conjunction, conjunct); } } } // namespace @@ -198,35 +342,30 @@ core::PlanNodePtr PlanBuilder::TableScanBuilder::build(core::PlanNodeId id) { } } - const RowTypePtr& parseType = dataColumns_ ? dataColumns_ : outputType_; + RowTypePtr parseType = dataColumns_ ? dataColumns_ : outputType_; + if (!filterColumnHandles_.empty()) { + auto names = parseType->names(); + auto types = parseType->children(); + for (auto& handle : filterColumnHandles_) { + if (!parseType->containsChild(handle->name())) { + names.push_back(handle->name()); + types.push_back(handle->hiveType()); + } + } + parseType = ROW(std::move(names), std::move(types)); + } core::TypedExprPtr filterNodeExpr; - common::SubfieldFilters filters; - filters.reserve(subfieldFilters_.size()); - auto queryCtx = core::QueryCtx::create(); - exec::SimpleExpressionEvaluator evaluator(queryCtx.get(), planBuilder_.pool_); - - for (const auto& filter : subfieldFilters_) { - auto filterExpr = - core::Expressions::inferTypes(filter, parseType, planBuilder_.pool_); - if (filtersAsNode_) { - addConjunct(filterExpr, filterNodeExpr); - } else { - auto [subfield, subfieldFilter] = - exec::toSubfieldFilter(filterExpr, &evaluator); - auto it = columnAliases_.find(subfield.toString()); - if (it != columnAliases_.end()) { - subfield = common::Subfield(it->second); - } - VELOX_CHECK_EQ( - filters.count(subfield), - 0, - "Duplicate subfield: {}", - subfield.toString()); + if (filtersAsNode_) { + for (const auto& [subfield, filter] : subfieldFiltersMap_) { + auto filterExpr = core::test::filterToExpr( + subfield, filter.get(), parseType, planBuilder_.pool_); - filters[std::move(subfield)] = std::move(subfieldFilter); + addConjunct(filterExpr, filterNodeExpr); } + + subfieldFiltersMap_.clear(); } core::TypedExprPtr remainingFilterExpr; @@ -245,12 +384,15 @@ core::PlanNodePtr PlanBuilder::TableScanBuilder::build(core::PlanNodeId id) { connectorId_, tableName_, true, - std::move(filters), + std::move(subfieldFiltersMap_), remainingFilterExpr, - dataColumns_); + dataColumns_, + /*tableParameters=*/std::unordered_map{}, + filterColumnHandles_); } core::PlanNodePtr result = std::make_shared( id, outputType_, tableHandle_, assignments_); + if (filtersAsNode_ && filterNodeExpr) { auto filterId = planNodeIdGenerator_->next(); result = @@ -314,32 +456,32 @@ core::PlanNodePtr PlanBuilder::TableWriterBuilder::build(core::PlanNodeId id) { std::make_shared(connectorId_, hiveHandle); } - std::shared_ptr aggregationNode; + std::optional columnStatsSpec; if (!aggregates_.empty()) { auto aggregatesAndNames = planBuilder_.createAggregateExpressionsAndNames( aggregates_, {}, core::AggregationNode::Step::kPartial); - aggregationNode = std::make_shared( - planBuilder_.nextPlanNodeId(), + std::vector groupingKeys; + groupingKeys.reserve(partitionBy_.size()); + for (const auto& partitionBy : partitionBy_) { + groupingKeys.push_back( + std::make_shared( + outputType->findChild(partitionBy), partitionBy)); + } + columnStatsSpec = core::ColumnStatsSpec( + std::move(groupingKeys), core::AggregationNode::Step::kPartial, - std::vector{}, // groupingKeys - std::vector{}, // preGroupedKeys - aggregatesAndNames.names, // ignoreNullKeys - aggregatesAndNames.aggregates, - false, - upstreamNode); - VELOX_CHECK_EQ( - aggregationNode->supportsBarrier(), aggregationNode->isPreGrouped()); + aggregatesAndNames.names, + aggregatesAndNames.aggregates); } - const auto writeNode = std::make_shared( id, outputType, outputType->names(), - aggregationNode, + columnStatsSpec, insertHandle_, false, - TableWriteTraits::outputType(aggregationNode), - connector::CommitStrategy::kNoCommit, + TableWriteTraits::outputType(columnStatsSpec), + commitStrategy_, upstreamNode); VELOX_CHECK(!writeNode->supportsBarrier()); return writeNode; @@ -393,9 +535,9 @@ parseOrderByClauses( std::vector> sortingKeys; std::vector sortingOrders; for (const auto& key : keys) { - auto [untypedExpr, sortOrder] = parse::parseOrderByExpr(key); + auto orderBy = parse::parseOrderByExpr(key); auto typedExpr = - core::Expressions::inferTypes(untypedExpr, inputType, pool); + core::Expressions::inferTypes(orderBy.expr, inputType, pool); auto sortingKey = std::dynamic_pointer_cast(typedExpr); @@ -404,7 +546,7 @@ parseOrderByClauses( "ORDER BY clause must use a column name, not an expression: {}", key); sortingKeys.emplace_back(sortingKey); - sortingOrders.emplace_back(sortOrder); + sortingOrders.emplace_back(orderBy.ascending, orderBy.nullsFirst); } return {sortingKeys, sortingOrders}; @@ -444,7 +586,7 @@ PlanBuilder& PlanBuilder::projectExpressions( } else if ( auto fieldExpr = dynamic_cast(projections[i].get())) { - projectNames.push_back(fieldExpr->getFieldName()); + projectNames.push_back(fieldExpr->name()); } else { projectNames.push_back(fmt::format("p{}", i)); } @@ -466,7 +608,7 @@ PlanBuilder& PlanBuilder::projectExpressions( expressions.push_back(projections[i]); if (auto fieldExpr = dynamic_cast(projections[i].get())) { - projectNames.push_back(fieldExpr->getFieldName()); + projectNames.push_back(fieldExpr->name()); } else { projectNames.push_back(fmt::format("p{}", i)); } @@ -483,12 +625,79 @@ PlanBuilder& PlanBuilder::projectExpressions( PlanBuilder& PlanBuilder::project(const std::vector& projections) { VELOX_CHECK_NOT_NULL(planNode_, "Project cannot be the source node"); std::vector> expressions; + expressions.reserve(projections.size()); for (auto i = 0; i < projections.size(); ++i) { expressions.push_back(parse::parseExpr(projections[i], options_)); } return projectExpressions(expressions); } +PlanBuilder& PlanBuilder::parallelProject( + const std::vector>& projectionGroups, + const std::vector& noLoadColumns) { + VELOX_CHECK_NOT_NULL(planNode_, "ParallelProject cannot be the source node"); + + std::vector names; + + std::vector> exprGroups; + exprGroups.reserve(projectionGroups.size()); + + size_t i = 0; + + for (const auto& group : projectionGroups) { + std::vector typedExprs; + typedExprs.reserve(group.size()); + + for (const auto& expr : group) { + const auto typedExpr = inferTypes(parse::parseExpr(expr, options_)); + typedExprs.push_back(typedExpr); + + if (auto fieldExpr = + dynamic_cast(typedExpr.get())) { + names.push_back(fieldExpr->name()); + } else { + names.push_back(fmt::format("p{}", i)); + } + + ++i; + } + exprGroups.push_back(std::move(typedExprs)); + } + + planNode_ = std::make_shared( + nextPlanNodeId(), + std::move(names), + std::move(exprGroups), + noLoadColumns, + planNode_); + + return *this; +} + +PlanBuilder& PlanBuilder::lazyDereference( + const std::vector& projections) { + VELOX_CHECK_NOT_NULL(planNode_, "LazyDeference cannot be the source node"); + std::vector expressions; + std::vector projectNames; + for (auto i = 0; i < projections.size(); ++i) { + auto expr = inferTypes(parse::parseExpr(projections[i], options_)); + expressions.push_back(expr); + if (auto* fieldExpr = + dynamic_cast(expr.get())) { + projectNames.push_back(fieldExpr->name()); + } else { + projectNames.push_back(fmt::format("p{}", i)); + } + } + planNode_ = std::make_shared( + nextPlanNodeId(), + std::move(projectNames), + std::move(expressions), + planNode_); + VELOX_CHECK(planNode_->supportsBarrier()); + return *this; +} + PlanBuilder& PlanBuilder::appendColumns( const std::vector& newColumns) { VELOX_CHECK_NOT_NULL(planNode_, "Project cannot be the source node"); @@ -507,15 +716,20 @@ PlanBuilder& PlanBuilder::optionalFilter(const std::string& optionalFilter) { return filter(optionalFilter); } -PlanBuilder& PlanBuilder::filter(const std::string& filter) { +PlanBuilder& PlanBuilder::filter(const core::ExprPtr& filterExpr) { VELOX_CHECK_NOT_NULL(planNode_, "Filter cannot be the source node"); - auto expr = parseExpr(filter, planNode_->outputType(), options_, pool_); - planNode_ = - std::make_shared(nextPlanNodeId(), expr, planNode_); + auto typedExpr = + core::Expressions::inferTypes(filterExpr, planNode_->outputType(), pool_); + planNode_ = std::make_shared( + nextPlanNodeId(), typedExpr, planNode_); VELOX_CHECK(planNode_->supportsBarrier()); return *this; } +PlanBuilder& PlanBuilder::filter(const std::string& filterExpr) { + return filter(parse::parseExpr(filterExpr, options_)); +} + PlanBuilder& PlanBuilder::tableWrite( const std::string& outputDirectoryPath, const dwio::common::FileFormat fileFormat, @@ -579,7 +793,9 @@ PlanBuilder& PlanBuilder::tableWrite( const std::string& outputFileName, const common::CompressionKind compressionKind, const RowTypePtr& schema, - const bool ensureFiles) { + const bool ensureFiles, + const connector::CommitStrategy commitStrategy, + std::shared_ptr insertTableHandle) { return TableWriterBuilder(*this) .outputDirectoryPath(outputDirectoryPath) .outputFileName(outputFileName) @@ -595,147 +811,74 @@ PlanBuilder& PlanBuilder::tableWrite( .options(options) .compressionKind(compressionKind) .ensureFiles(ensureFiles) + .commitStrategy(commitStrategy) + .insertHandle(insertTableHandle) .endTableWriter(); } -PlanBuilder& PlanBuilder::tableWriteMerge( - const std::shared_ptr& aggregationNode) { - planNode_ = std::make_shared( - nextPlanNodeId(), - TableWriteTraits::outputType(aggregationNode), - aggregationNode, - planNode_); - VELOX_CHECK(!planNode_->supportsBarrier()); - return *this; -} - namespace { - -std::string throwAggregateFunctionDoesntExist(const std::string& name) { - std::stringstream error; - error << "Aggregate function doesn't exist: " << name << "."; - exec::aggregateFunctions().withRLock([&](const auto& functionsMap) { - if (functionsMap.empty()) { - error << " Registry of aggregate functions is empty. " - "Make sure to register some aggregate functions."; - } - }); - VELOX_USER_FAIL(error.str()); -} - -std::string throwAggregateFunctionSignatureNotSupported( - const std::string& name, - const std::vector& types, - const std::vector>& - signatures) { - std::stringstream error; - error << "Aggregate function signature is not supported: " - << toString(name, types) - << ". Supported signatures: " << toString(signatures) << "."; - VELOX_USER_FAIL(error.str()); -} - -TypePtr resolveAggregateType( - const std::string& aggregateName, - core::AggregationNode::Step step, - const std::vector& rawInputTypes, - bool nullOnFailure) { - if (auto signatures = exec::getAggregateFunctionSignatures(aggregateName)) { - for (const auto& signature : signatures.value()) { - exec::SignatureBinder binder(*signature, rawInputTypes); - if (binder.tryBind()) { - return binder.tryResolveType( - exec::isPartialOutput(step) ? signature->intermediateType() - : signature->returnType()); - } +// Finds the table writer source node rooted from 'node'. +const core::TableWriteNodePtr findTableWrite(const core::PlanNodePtr planNode) { + if (auto writer = + std::dynamic_pointer_cast(planNode)) { + return writer; + } + for (const auto& source : planNode->sources()) { + if (auto writer = findTableWrite(source)) { + return writer; } - - if (nullOnFailure) { - return nullptr; - } - - throwAggregateFunctionSignatureNotSupported( - aggregateName, rawInputTypes, signatures.value()); - } - - // We may be parsing lambda expression used in a lambda aggregate function. In - // this case, 'aggregateName' would refer to a scalar function. - // - // TODO Enhance the parser to allow for specifying separate resolver for - // lambda expressions. - if (auto type = - exec::resolveTypeForSpecialForm(aggregateName, rawInputTypes)) { - return type; } - - if (auto type = parse::resolveScalarFunctionType( - aggregateName, rawInputTypes, true)) { - return type; - } - - if (nullOnFailure) { - return nullptr; - } - - throwAggregateFunctionDoesntExist(aggregateName); return nullptr; } +} // namespace -class AggregateTypeResolver { - public: - explicit AggregateTypeResolver(core::AggregationNode::Step step) - : step_(step), previousHook_(core::Expressions::getResolverHook()) { - core::Expressions::setTypeResolverHook( - [&](const auto& inputs, const auto& expr, bool nullOnFailure) { - return resolveType(inputs, expr, nullOnFailure); - }); - } - - ~AggregateTypeResolver() { - core::Expressions::setTypeResolverHook(previousHook_); - } - - void setRawInputTypes(const std::vector& types) { - rawInputTypes_ = types; - } - - private: - TypePtr resolveType( - const std::vector& inputs, - const std::shared_ptr& expr, - bool nullOnFailure) const { - auto functionName = expr->getFunctionName(); - - // Use raw input types (if available) to resolve intermediate and final - // result types. - if (exec::isRawInput(step_)) { - std::vector types; - for (auto& input : inputs) { - types.push_back(input->type()); - } - - return resolveAggregateType(functionName, step_, types, nullOnFailure); - } +PlanBuilder& PlanBuilder::tableWriteMerge() { + VELOX_CHECK_NOT_NULL(planNode_, "TableWriteMerge cannot be the source node"); + auto writer = findTableWrite(planNode_); + VELOX_CHECK_NOT_NULL( + writer, "TableWriteMerge can only be added after TableWrite node"); - if (!rawInputTypes_.empty()) { - return resolveAggregateType( - functionName, step_, rawInputTypes_, nullOnFailure); + std::optional columnStatsSpec; + if (writer->hasColumnStatsSpec()) { + const auto writerSpec = writer->columnStatsSpec().value(); + VELOX_CHECK_EQ( + writerSpec.aggregationStep, core::AggregationNode::Step::kPartial); + std::vector> aggregateRawInputs; + const auto numAggregates = writerSpec.aggregates.size(); + aggregateRawInputs.reserve(numAggregates); + for (const auto& aggregate : writerSpec.aggregates) { + aggregateRawInputs.push_back(aggregate.rawInputTypes); } - - if (!nullOnFailure) { - VELOX_USER_FAIL( - "Cannot resolve aggregation function return type without raw input types: {}", - functionName); + const auto& inputType = planNode_->outputType(); + + std::vector aggregateNames; + aggregateNames.reserve(numAggregates); + std::vector aggregates; + aggregates.reserve(numAggregates); + for (int i = 0; i < numAggregates; ++i) { + core::AggregationNode::Aggregate aggregate = writerSpec.aggregates[i]; + aggregate.call = std::make_shared( + aggregate.call->type(), + aggregate.call->name(), + field(inputType, writerSpec.aggregateNames[i])); + aggregates.push_back(std::move(aggregate)); + aggregateNames.push_back(fmt::format("a{}", i)); } - return nullptr; + columnStatsSpec = core::ColumnStatsSpec{ + writerSpec.groupingKeys, + core::AggregationNode::Step::kIntermediate, + std::move(aggregateNames), + std::move(aggregates)}; } - const core::AggregationNode::Step step_; - const core::Expressions::TypeResolverHook previousHook_; - std::vector rawInputTypes_; -}; - -} // namespace + planNode_ = std::make_shared( + nextPlanNodeId(), + TableWriteTraits::outputType(columnStatsSpec), + columnStatsSpec, + planNode_); + VELOX_CHECK(!planNode_->supportsBarrier()); + return *this; +} core::PlanNodePtr PlanBuilder::createIntermediateOrFinalAggregation( core::AggregationNode::Step step, @@ -785,6 +928,7 @@ core::PlanNodePtr PlanBuilder::createIntermediateOrFinalAggregation( partialAggNode->aggregateNames(), aggregates, partialAggNode->ignoreNullKeys(), + partialAggNode->noGroupsSpanBatches(), planNode_); VELOX_CHECK_EQ( aggregationNode->supportsBarrier(), aggregationNode->isPreGrouped()); @@ -926,17 +1070,17 @@ PlanBuilder::AggregatesAndNames PlanBuilder::createAggregateExpressionsAndNames( } } - for (const auto& [keyExpr, order] : untypedExpr.orderBy) { + for (const auto& orderBy : untypedExpr.orderBy) { auto sortingKey = std::dynamic_pointer_cast( - inferTypes(keyExpr)); + inferTypes(orderBy.expr)); VELOX_CHECK_NOT_NULL( sortingKey, "ORDER BY clause must use a column name, not an expression: {}", aggregate); agg.sortingKeys.push_back(sortingKey); - agg.sortingOrders.push_back(order); + agg.sortingOrders.emplace_back(orderBy.ascending, orderBy.nullsFirst); } aggs.emplace_back(agg); @@ -990,6 +1134,7 @@ PlanBuilder& PlanBuilder::aggregation( globalGroupingSets, groupId, ignoreNullKeys, + /*noGroupsSpanBatches=*/false, planNode_); VELOX_CHECK_EQ( aggregationNode->supportsBarrier(), aggregationNode->isPreGrouped()); @@ -1002,7 +1147,8 @@ PlanBuilder& PlanBuilder::streamingAggregation( const std::vector& aggregates, const std::vector& masks, core::AggregationNode::Step step, - bool ignoreNullKeys) { + bool ignoreNullKeys, + bool noGroupsSpanBatches) { auto aggregatesAndNames = createAggregateExpressionsAndNames(aggregates, masks, step); auto aggregationNode = std::make_shared( @@ -1013,6 +1159,7 @@ PlanBuilder& PlanBuilder::streamingAggregation( aggregatesAndNames.names, aggregatesAndNames.aggregates, ignoreNullKeys, + noGroupsSpanBatches, planNode_); VELOX_CHECK_EQ( aggregationNode->supportsBarrier(), aggregationNode->isPreGrouped()); @@ -1035,7 +1182,7 @@ PlanBuilder& PlanBuilder::groupId( fieldAccessExpr, "Grouping key {} is not valid projection", groupingKey); - std::string inputField = fieldAccessExpr->getFieldName(); + std::string inputField = fieldAccessExpr->name(); std::string outputField = untypedExpr->alias().has_value() ? // This is a projection with a column alias with the format @@ -1043,7 +1190,7 @@ PlanBuilder& PlanBuilder::groupId( untypedExpr->alias().value() : // This is a projection without a column alias. - fieldAccessExpr->getFieldName(); + fieldAccessExpr->name(); core::GroupIdNode::GroupingKeyInfo keyInfos; keyInfos.output = outputField; @@ -1107,7 +1254,7 @@ PlanBuilder& PlanBuilder::expand( auto fieldExpr = dynamic_cast( untypedExpression.get()); VELOX_CHECK_NOT_NULL(fieldExpr); - aliases.push_back(fieldExpr->getFieldName()); + aliases.push_back(fieldExpr->name()); } projectExpr.push_back(typedExpression); } else { @@ -1121,8 +1268,9 @@ PlanBuilder& PlanBuilder::expand( dynamic_cast(untypedExpression.get()); VELOX_CHECK_NOT_NULL(constantExpr); VELOX_CHECK(constantExpr->value().isNull()); - projectExpr.push_back(std::make_shared( - expectedType, variant::null(expectedType->kind()))); + projectExpr.push_back( + std::make_shared( + expectedType, variant::null(expectedType->kind()))); } } } @@ -1181,7 +1329,7 @@ PlanBuilder& PlanBuilder::topN( PlanBuilder& PlanBuilder::limit(int64_t offset, int64_t count, bool isPartial) { planNode_ = std::make_shared( nextPlanNodeId(), offset, count, isPartial, planNode_); - VELOX_CHECK(!planNode_->supportsBarrier()); + VELOX_CHECK(planNode_->supportsBarrier()); return *this; } @@ -1196,7 +1344,7 @@ PlanBuilder& PlanBuilder::assignUniqueId( const int32_t taskUniqueId) { planNode_ = std::make_shared( nextPlanNodeId(), idName, taskUniqueId, planNode_); - VELOX_CHECK(!planNode_->supportsBarrier()); + VELOX_CHECK(planNode_->supportsBarrier()); return *this; } @@ -1673,6 +1821,45 @@ PlanBuilder& PlanBuilder::nestedLoopJoin( return *this; } +PlanBuilder& PlanBuilder::spatialJoin( + const core::PlanNodePtr& right, + const std::string& joinCondition, + const std::string& probeGeometry, + const std::string& buildGeometry, + const std::optional& radius, + const std::vector& outputLayout, + core::JoinType joinType) { + VELOX_CHECK_NOT_NULL(planNode_, "SpatialJoin cannot be the source node"); + auto probeType = planNode_->outputType(); + auto buildType = right->outputType(); + auto resultType = concat(probeType, buildType); + auto outputType = extract(resultType, outputLayout); + + VELOX_CHECK(!joinCondition.empty(), "SpatialJoin condition cannot be empty"); + core::TypedExprPtr joinConditionExpr = + parseExpr(joinCondition, resultType, options_, pool_); + + auto probeGeometryField = field(probeType, probeGeometry); + auto buildGeometryField = field(buildType, buildGeometry); + std::optional radiusField; + if (radius.has_value()) { + radiusField = field(buildType, radius.value()); + } + + planNode_ = std::make_shared( + nextPlanNodeId(), + joinType, + std::move(joinConditionExpr), + std::move(probeGeometryField), + std::move(buildGeometryField), + std::move(radiusField), + std::move(planNode_), + right, + outputType); + VELOX_CHECK(!planNode_->supportsBarrier()); + return *this; +} + namespace { core::TypedExprPtr removeCastTypedExpr(const core::TypedExprPtr& expr) { core::TypedExprPtr convertedTypedExpr = expr; @@ -1853,8 +2040,21 @@ core::IndexLookupConditionPtr PlanBuilder::parseIndexJoinCondition( castIndexConditionInputExpr( typedCallExpr->inputs()[2], keyColumnExpr->type())); } + + if (typedCallExpr->name() == "eq") { + VELOX_CHECK_EQ(typedCallExpr->inputs().size(), 2); + const auto keyColumnExpr = + std::dynamic_pointer_cast( + removeCastTypedExpr(typedCallExpr->inputs()[0])); + VELOX_CHECK_NOT_NULL( + keyColumnExpr, "{}", typedCallExpr->inputs()[0]->toString()); + return std::make_shared( + keyColumnExpr, + castIndexConditionInputExpr( + typedCallExpr->inputs()[1], keyColumnExpr->type())); + } VELOX_USER_FAIL( - "Invalid index join condition: {}, and we only support in and between conditions", + "Invalid index join condition: {}, and we only support in, between, and equal conditions", joinCondition); } @@ -1863,12 +2063,20 @@ PlanBuilder& PlanBuilder::indexLookupJoin( const std::vector& rightKeys, const core::TableScanNodePtr& right, const std::vector& joinConditions, + const std::string& filter, + bool hasMarker, const std::vector& outputLayout, core::JoinType joinType) { VELOX_CHECK_NOT_NULL(planNode_, "indexLookupJoin cannot be the source node"); - const auto inputType = concat(planNode_->outputType(), right->outputType()); + auto inputType = concat(planNode_->outputType(), right->outputType()); + if (hasMarker) { + auto names = inputType->names(); + names.push_back(outputLayout.back()); + auto types = inputType->children(); + types.push_back(BOOLEAN()); + inputType = ROW(std::move(names), std::move(types)); + } auto outputType = extract(inputType, outputLayout); - auto leftKeyFields = fields(planNode_->outputType(), leftKeys); auto rightKeyFields = fields(right->outputType(), rightKeys); @@ -1879,12 +2087,20 @@ PlanBuilder& PlanBuilder::indexLookupJoin( parseIndexJoinCondition(joinCondition, inputType, pool_)); } + // Parse filter expression if provided + core::TypedExprPtr filterExpr; + if (!filter.empty()) { + filterExpr = parseExpr(filter, inputType, options_, pool_); + } + planNode_ = std::make_shared( nextPlanNodeId(), joinType, std::move(leftKeyFields), std::move(rightKeyFields), std::move(joinConditionPtrs), + filterExpr, + hasMarker, std::move(planNode_), right, std::move(outputType)); @@ -1895,7 +2111,8 @@ PlanBuilder& PlanBuilder::indexLookupJoin( PlanBuilder& PlanBuilder::unnest( const std::vector& replicateColumns, const std::vector& unnestColumns, - const std::optional& ordinalColumn) { + const std::optional& ordinalColumn, + const std::optional& markerName) { VELOX_CHECK_NOT_NULL(planNode_, "Unnest cannot be the source node"); std::vector> replicateFields; @@ -1931,6 +2148,7 @@ PlanBuilder& PlanBuilder::unnest( unnestFields, unnestNames, ordinalColumn, + markerName, planNode_); VELOX_CHECK(planNode_->supportsBarrier()); return *this; @@ -2008,7 +2226,7 @@ class WindowTypeResolver { types.push_back(input->type()); } - auto functionName = expr->getFunctionName(); + const auto& functionName = expr->name(); return resolveWindowType(functionName, types, nullOnFailure); } @@ -2083,8 +2301,9 @@ parseOrderByKeys( std::vector sortingKeys; std::vector sortingOrders; - for (const auto& [untypedExpr, sortOrder] : windowExpr.orderBy) { - auto typedExpr = core::Expressions::inferTypes(untypedExpr, inputRow, pool); + for (const auto& orderBy : windowExpr.orderBy) { + auto typedExpr = + core::Expressions::inferTypes(orderBy.expr, inputRow, pool); auto sortingKey = std::dynamic_pointer_cast(typedExpr); VELOX_CHECK_NOT_NULL( @@ -2092,7 +2311,7 @@ parseOrderByKeys( "ORDER BY clause must use a column name, not an expression: {}", windowString); sortingKeys.emplace_back(sortingKey); - sortingOrders.emplace_back(sortOrder); + sortingOrders.emplace_back(orderBy.ascending, orderBy.nullsFirst); } return {sortingKeys, sortingOrders}; } @@ -2243,7 +2462,8 @@ PlanBuilder& PlanBuilder::rowNumber( return *this; } -PlanBuilder& PlanBuilder::topNRowNumber( +PlanBuilder& PlanBuilder::topNRank( + std::string_view function, const std::vector& partitionKeys, const std::vector& sortingKeys, int32_t limit, @@ -2257,6 +2477,7 @@ PlanBuilder& PlanBuilder::topNRowNumber( } planNode_ = std::make_shared( nextPlanNodeId(), + core::TopNRowNumberNode::rankFunctionFromName(function), fields(partitionKeys), sortingFields, sortingOrders, @@ -2267,6 +2488,15 @@ PlanBuilder& PlanBuilder::topNRowNumber( return *this; } +PlanBuilder& PlanBuilder::topNRowNumber( + const std::vector& partitionKeys, + const std::vector& sortingKeys, + int32_t limit, + bool generateRowNumber) { + return topNRank( + "row_number", partitionKeys, sortingKeys, limit, generateRowNumber); +} + PlanBuilder& PlanBuilder::markDistinct( std::string markerKey, const std::vector& distinctKeys) { @@ -2377,4 +2607,51 @@ core::TypedExprPtr PlanBuilder::inferTypes( return core::Expressions::inferTypes( untypedExpr, planNode_->outputType(), pool_); } + +core::PlanNodePtr PlanBuilder::IndexLookupJoinBuilder::build( + const core::PlanNodeId& id) { + VELOX_CHECK_NOT_NULL( + planBuilder_.planNode_, "IndexLookupJoin cannot be the source node"); + auto inputType = + concat(planBuilder_.planNode_->outputType(), indexSource_->outputType()); + if (hasMarker_) { + auto names = inputType->names(); + names.push_back(outputLayout_.back()); + auto types = inputType->children(); + types.push_back(BOOLEAN()); + inputType = ROW(std::move(names), std::move(types)); + } + auto outputType = extract(inputType, outputLayout_); + auto leftKeyFields = + PlanBuilder::fields(planBuilder_.planNode_->outputType(), leftKeys_); + auto rightKeyFields = + PlanBuilder::fields(indexSource_->outputType(), rightKeys_); + + std::vector joinConditionPtrs{}; + joinConditionPtrs.reserve(joinConditions_.size()); + for (const auto& joinCondition : joinConditions_) { + joinConditionPtrs.push_back( + PlanBuilder::parseIndexJoinCondition( + joinCondition, inputType, planBuilder_.pool_)); + } + + // Parse filter expression if provided + core::TypedExprPtr filterExpr; + if (!filter_.empty()) { + filterExpr = parseExpr( + filter_, inputType, planBuilder_.options_, planBuilder_.pool_); + } + + return std::make_shared( + id, + joinType_, + std::move(leftKeyFields), + std::move(rightKeyFields), + std::move(joinConditionPtrs), + filterExpr, + hasMarker_, + std::move(planBuilder_.planNode_), + indexSource_, + std::move(outputType)); +} } // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/utils/PlanBuilder.h b/velox/exec/tests/utils/PlanBuilder.h index 0307d4dcf491..072579380287 100644 --- a/velox/exec/tests/utils/PlanBuilder.h +++ b/velox/exec/tests/utils/PlanBuilder.h @@ -20,7 +20,6 @@ #include #include #include -#include "velox/common/memory/Memory.h" #include "velox/connectors/hive/HiveDataSink.h" #include "velox/parse/ExpressionsParser.h" #include "velox/parse/IExpr.h" @@ -30,8 +29,24 @@ namespace facebook::velox::tpch { enum class Table : uint8_t; } +namespace facebook::velox::tpcds { +enum class Table : uint8_t; +} + namespace facebook::velox::exec::test { +struct AggregationConfig { + std::vector groupingKeys; + std::vector aggregates; +}; + +struct PushdownConfig { + common::SubfieldFilters subfieldFiltersMap; + std::string remainingFilter; + // Aggregation pushdown configuration + std::optional aggregationConfig; +}; + /// A builder class with fluent API for building query plans. Plans are built /// bottom up starting with the source node (table scan or similar). Expressions /// and orders can be specified using SQL. See filter, project and orderBy @@ -109,6 +124,8 @@ class PlanBuilder { static constexpr const std::string_view kHiveDefaultConnectorId{"test-hive"}; static constexpr const std::string_view kTpchDefaultConnectorId{"test-tpch"}; + static constexpr const std::string_view kTpcdsDefaultConnectorId{ + "test-tpcds"}; /// /// TableScan @@ -136,9 +153,7 @@ class PlanBuilder { const std::vector& subfieldFilters = {}, const std::string& remainingFilter = "", const RowTypePtr& dataColumns = nullptr, - const std::unordered_map< - std::string, - std::shared_ptr>& assignments = {}); + const connector::ColumnHandleMap& assignments = {}); /// Add a TableScanNode to scan a Hive table. /// @@ -168,9 +183,20 @@ class PlanBuilder { const std::vector& subfieldFilters = {}, const std::string& remainingFilter = "", const RowTypePtr& dataColumns = nullptr, - const std::unordered_map< - std::string, - std::shared_ptr>& assignments = {}); + const connector::ColumnHandleMap& assignments = {}); + + /// Add a TableScanNode to scan a Hive table with direct SubfieldFilters. + /// + /// @param outputType List of column names and types to read from the table. + /// @param PushdownConfig Contains pushdown configs for the table scan. + /// @param dataColumns Optional data columns that may differ from outputType. + /// @param assignments Optional ColumnHandles. + + PlanBuilder& tableScanWithPushDown( + const RowTypePtr& outputType, + const PushdownConfig& pushdownConfig, + const RowTypePtr& dataColumns = nullptr, + const connector::ColumnHandleMap& assignments = {}); /// Add a TableScanNode to scan a TPC-H table. /// @@ -179,11 +205,27 @@ class PlanBuilder { /// @param columnNames The columns to be returned from that table. /// @param scaleFactor The TPC-H scale factor. /// @param connectorId The TPC-H connector id. + /// @param filter Optional SQL expression to filter the data at the connector + /// level. PlanBuilder& tpchTableScan( tpch::Table table, std::vector columnNames, double scaleFactor = 1, - std::string_view connectorId = kTpchDefaultConnectorId); + std::string_view connectorId = kTpchDefaultConnectorId, + const std::string& filter = ""); + + /// Add a TableScanNode to scan a TPC-DS table. + /// + /// @param tpcdsTableHandle The handle that specifies the target TPC-DS table + /// and scale factor. + /// @param columnNames The columns to be returned from that table. + /// @param scaleFactor The TPC-DS scale factor. + /// @param connectorId The TPC-DS connector id. + PlanBuilder& tpcdsTableScan( + tpcds::Table table, + std::vector columnNames, + double scaleFactor = 0.01, + std::string_view connectorId = kTpcdsDefaultConnectorId); /// Helper class to build a custom TableScanNode. /// Uses a planBuilder instance to get the next plan id, memory pool, and @@ -239,6 +281,10 @@ class PlanBuilder { /// > column >= v2 TableScanBuilder& subfieldFilters(std::vector subfieldFilters); + // @param subfieldFiltersMap A map of Subfield to Filters. + TableScanBuilder& subfieldFiltersMap( + const common::SubfieldFilters& filtersMap); + /// @param subfieldFilter A single SQL expression to be applied to an /// individual column. TableScanBuilder& subfieldFilter(std::string subfieldFilter) { @@ -250,6 +296,8 @@ class PlanBuilder { /// AND'ed with all the subfieldFilters. TableScanBuilder& remainingFilter(std::string remainingFilter); + TableScanBuilder& sampleRate(double sampleRate); + /// @param dataColumns can be different from 'outputType' for the purposes /// of testing queries using missing columns. It is used, if specified, for /// parseExpr call and as 'dataColumns' for the TableHandle. You supply more @@ -272,19 +320,22 @@ class PlanBuilder { /// @param tableHandle Optional tableHandle. Other builder arguments such as /// the `subfieldFilters` and `remainingFilter` will be ignored. TableScanBuilder& tableHandle( - std::shared_ptr tableHandle) { + connector::ConnectorTableHandlePtr tableHandle) { tableHandle_ = std::move(tableHandle); return *this; } + TableScanBuilder& filterColumnHandles( + std::vector filterColumnHandles) { + filterColumnHandles_ = std::move(filterColumnHandles); + return *this; + } + /// @param assignments Optional ColumnHandles. /// outputType names should match the keys in the 'assignments' map. The /// 'assignments' map may contain more columns than 'outputType' if some /// columns are only used by pushed-down filters. - TableScanBuilder& assignments( - std::unordered_map< - std::string, - std::shared_ptr> assignments) { + TableScanBuilder& assignments(connector::ColumnHandleMap assignments) { assignments_ = std::move(assignments); return *this; } @@ -303,19 +354,22 @@ class PlanBuilder { std::string tableName_{"hive_table"}; std::string connectorId_{kHiveDefaultConnectorId}; RowTypePtr outputType_; - std::vector subfieldFilters_; core::ExprPtr remainingFilter_; + double sampleRate_{1.0}; RowTypePtr dataColumns_; + std::vector filterColumnHandles_; std::unordered_map columnAliases_; - std::shared_ptr tableHandle_; - std::unordered_map> - assignments_; + connector::ConnectorTableHandlePtr tableHandle_; + connector::ColumnHandleMap assignments_; // produce filters as a FilterNode instead of pushdown. bool filtersAsNode_{false}; // Generates the id of a FilterNode if 'filtersAsNode_'. std::shared_ptr planNodeIdGenerator_; + + // SubfieldFilters object containing filters to apply. + common::SubfieldFilters subfieldFiltersMap_; }; /// Start a TableScanBuilder. @@ -324,6 +378,91 @@ class PlanBuilder { return *tableScanBuilder_; } + /// Helper class to build a custom IndexLookupJoinNode. + class IndexLookupJoinBuilder { + public: + explicit IndexLookupJoinBuilder(PlanBuilder& builder) + : planBuilder_(builder) {} + + /// @param leftKeys Join keys from the table scan side, the preceding plan + /// node. Cannot be empty. + IndexLookupJoinBuilder& leftKeys(std::vector leftKeys) { + leftKeys_ = std::move(leftKeys); + return *this; + } + + /// @param rightKeys Join keys from the index lookup side, the plan node + /// specified in 'right' parameter. The number and types of left and right + /// keys must be the same. + IndexLookupJoinBuilder& rightKeys(std::vector rightKeys) { + rightKeys_ = std::move(rightKeys); + return *this; + } + + /// @param indexSource The right input source with index lookup support. + IndexLookupJoinBuilder& indexSource( + const core::TableScanNodePtr& indexSource) { + indexSource_ = indexSource; + return *this; + } + + IndexLookupJoinBuilder& joinConditions( + std::vector joinConditions) { + joinConditions_ = std::move(joinConditions); + return *this; + } + + IndexLookupJoinBuilder& hasMarker(bool hasMarker) { + hasMarker_ = hasMarker; + return *this; + } + + IndexLookupJoinBuilder& outputLayout( + std::vector outputLayout) { + outputLayout_ = std::move(outputLayout); + return *this; + } + + /// @param filter SQL expression for the additional join filter. Can + /// use columns from both probe and build sides of the join. + IndexLookupJoinBuilder& filter(std::string filter) { + filter_ = std::move(filter); + return *this; + } + + /// @param joinType Type of the join supported: inner, left. + IndexLookupJoinBuilder& joinType(core::JoinType joinType) { + joinType_ = joinType; + return *this; + } + + /// Stop the IndexLookupJoinBuilder. + PlanBuilder& endIndexLookupJoin() { + planBuilder_.planNode_ = build(planBuilder_.nextPlanNodeId()); + return planBuilder_; + } + + private: + /// Build the plan node IndexLookupJoinNode. + core::PlanNodePtr build(const core::PlanNodeId& id); + + PlanBuilder& planBuilder_; + std::vector leftKeys_; + std::vector rightKeys_; + core::TableScanNodePtr indexSource_; + std::vector joinConditions_; + std::string filter_; + bool hasMarker_{false}; + std::vector outputLayout_; + core::JoinType joinType_{core::JoinType::kInner}; + }; + + /// Start an IndexLookupJoinBuilder. + IndexLookupJoinBuilder& startIndexLookupJoin() { + indexLookupJoinBuilder_.reset(new IndexLookupJoinBuilder(*this)); + return *indexLookupJoinBuilder_; + } + /// /// TableWriter /// @@ -442,6 +581,15 @@ class PlanBuilder { return *this; } + /// Specifies commitStrategy for writing to the connector. + /// @param commitStrategy The commit strategy to use for the table write + /// operation. + TableWriterBuilder& commitStrategy( + connector::CommitStrategy commitStrategy) { + commitStrategy_ = commitStrategy; + return *this; + } + /// Stop the TableWriterBuilder. PlanBuilder& endTableWriter() { planBuilder_.planNode_ = build(planBuilder_.nextPlanNodeId()); @@ -473,6 +621,8 @@ class PlanBuilder { common::CompressionKind compressionKind_{common::CompressionKind_NONE}; bool ensureFiles_{false}; + connector::CommitStrategy commitStrategy_{ + connector::CommitStrategy::kNoCommit}; }; /// Start a TableWriterBuilder. @@ -497,8 +647,8 @@ class PlanBuilder { bool parallelizable = false, size_t repeatTimes = 1); - PlanBuilder& filtersAsNode(bool _filtersAsNode) { - filtersAsNode_ = _filtersAsNode; + PlanBuilder& filtersAsNode(bool filtersAsNode) { + filtersAsNode_ = filtersAsNode; return *this; } @@ -562,6 +712,22 @@ class PlanBuilder { /// will produce projected columns named sum_ab, c and p2. PlanBuilder& project(const std::vector& projections); + /// Add a ParallelProjectNode using groups of independent SQL expressions. + /// + /// @param projectionGroups One or more groups of expressions that depend on + /// disjunct sets of inputs. + /// @param noLoadColumn Optional columns to pass through as is without + /// loading. These columns must be distinct from the set of columns used in + /// 'projectionGroups'. + PlanBuilder& parallelProject( + const std::vector>& projectionGroups, + const std::vector& noLoadColumns = {}); + + /// Add a LazyDereferenceNode to the plan. + /// @param projections Same format as in `project`, but can only contain + /// field/subfield accesses. + PlanBuilder& lazyDereference(const std::vector& projections); + /// Add a ProjectNode to keep all existing columns and append more columns /// using specified expressions. /// @param newColumns A list of one or more expressions to use for computing @@ -573,6 +739,7 @@ class PlanBuilder { /// the type. PlanBuilder& projectExpressions( const std::vector& projections); + PlanBuilder& projectExpressions( const std::vector& projections); @@ -584,7 +751,10 @@ class PlanBuilder { /// Add a FilterNode using specified SQL expression. /// /// @param filter SQL expression of type boolean. - PlanBuilder& filter(const std::string& filter); + PlanBuilder& filter(const std::string& filterExpr); + + /// Same as above, but takes an untyped expression. + PlanBuilder& filter(const core::ExprPtr& filterExpr); /// Similar to filter() except 'optionalFilter' could be empty and the /// function will skip creating a FilterNode in that case. @@ -675,6 +845,12 @@ class PlanBuilder { /// output of the previous operator. /// @param ensureFiles When this option is set the HiveDataSink will always /// create a file even if there is no data. + /// @param commitStrategy The commit strategy to use for the table write + /// operation, default is kNoCommit. + /// @param insertTableHandle Encapsulates information needed to write data + /// to a table through a connector. If not specified, tableWrite will build + /// a HiveInsertTableHandle with columnHandles, bucketProperty and + /// locationHandle. PlanBuilder& tableWrite( const std::string& outputDirectoryPath, const std::vector& partitionBy, @@ -691,19 +867,21 @@ class PlanBuilder { const std::string& outputFileName = "", const common::CompressionKind = common::CompressionKind_NONE, const RowTypePtr& schema = nullptr, - const bool ensureFiles = false); + const bool ensureFiles = false, + const connector::CommitStrategy commitStrategy = + connector::CommitStrategy::kNoCommit, + std::shared_ptr insertTableHandle = nullptr); /// Add a TableWriteMergeNode. - PlanBuilder& tableWriteMerge( - const std::shared_ptr& aggregationNode = nullptr); + PlanBuilder& tableWriteMerge(); /// Add an AggregationNode representing partial aggregation with the /// specified grouping keys, aggregates and optional masks. /// - /// Aggregates are specified as function calls over unmodified input columns, - /// e.g. sum(a), avg(b), min(c). SQL statement AS can be used to specify names - /// for the aggregation result columns. In the absence of AS statement, result - /// columns are named a0, a1, a2, etc. + /// Aggregates are specified as function calls over unmodified input + /// columns, e.g. sum(a), avg(b), min(c). SQL statement AS can be used to + /// specify names for the aggregation result columns. In the absence of AS + /// statement, result columns are named a0, a1, a2, etc. /// /// For example, /// @@ -713,8 +891,8 @@ class PlanBuilder { /// /// partialAggregation({"k1", "k2"}, {"min(a) AS min_a", "max(b)"}) /// - /// will produce output columns k1, k2, min_a and a1, assuming the names of - /// the first two input columns are k1 and k2. + /// will produce output columns k1, k2, min_a and a1, assuming the names + /// of the first two input columns are k1 and k2. PlanBuilder& partialAggregation( const std::vector& groupingKeys, const std::vector& aggregates, @@ -871,7 +1049,8 @@ class PlanBuilder { const std::vector& aggregates, const std::vector& masks, core::AggregationNode::Step step, - bool ignoreNullKeys); + bool ignoreNullKeys, + bool noGroupsSpanBatches = false); /// Add a GroupIdNode using the specified grouping keys, grouping sets, /// aggregation inputs and a groupId column name. @@ -1147,6 +1326,24 @@ class PlanBuilder { const std::vector& outputLayout, core::JoinType joinType = core::JoinType::kInner); + /// Add a SpatialJoinNode to join two inputs using spatial join condition. + /// + /// @param right Right-side input. Typically, to reduce memory usage, the + /// smaller input is placed on the right-side. + /// @param joinCondition SQL expression as the spatial join condition. Can + /// use columns from both probe and build sides of the join. + /// @param outputLayout Output layout consisting of columns from probe and + /// build sides. + /// @param joinType Type of the join: inner (only one supported for now + PlanBuilder& spatialJoin( + const core::PlanNodePtr& right, + const std::string& joinCondition, + const std::string& probeGeometry, + const std::string& buildGeometry, + const std::optional& radius, + const std::vector& outputLayout, + core::JoinType joinType = core::JoinType::kInner); + static core::IndexLookupConditionPtr parseIndexJoinCondition( const std::string& joinCondition, const RowTypePtr& rowType, @@ -1157,6 +1354,11 @@ class PlanBuilder { /// node. Second input is specified in 'right' parameter and must be a /// table source with the connector table handle with index lookup support. /// + /// @param leftKeys Join keys from the probe side, the preceding plan node. + /// Cannot be empty. + /// @param rightKeys Join keys from the index lookup side, the plan node + /// specified in 'right' parameter. The number and types of left and right + /// keys must be the same. /// @param right The right input source with index lookup support. /// @param joinConditions SQL expressions as the join conditions. Each join /// condition must use columns from both sides. For the right side, it can @@ -1169,14 +1371,23 @@ class PlanBuilder { /// where "a" is the index column from right side and "b", "c" are either /// condition column from left side or a constant but at least one of them /// must not be constant. They all have the same type. + /// @param filter SQL expression for the additional join filter to apply on + /// join results. This supports filters that can't be converted into join + /// conditions or lookup conditions. Can be an empty string if no additional + /// filter is needed. + /// @param hasMarker if true, 'outputLayout' should include a boolean + /// column at the end to indicate if a join output row has a match or not. + /// This only applies for left join. + /// @param outputLayout Output layout consisting of columns from probe and + /// build sides. /// @param joinType Type of the join supported: inner, left. - /// - /// See hashJoin method for the description of the other parameters. PlanBuilder& indexLookupJoin( const std::vector& leftKeys, const std::vector& rightKeys, const core::TableScanNodePtr& right, const std::vector& joinConditions, + const std::string& filter, + bool hasMarker, const std::vector& outputLayout, core::JoinType joinType = core::JoinType::kInner); @@ -1198,10 +1409,16 @@ class PlanBuilder { /// @param ordinalColumn An optional name for the 'ordinal' column to produce. /// This column contains the index of the element of the unnested array or /// map. If not specified, the output will not contain this column. + /// @param markerName An optional name for the marker column to produce. + /// This column contains a boolean indicating whether the output row has + /// non-empty unnested value. If not specified, the output will not contain + /// this column and the unnest operator also skips producing output rows + /// with empty unnest value. PlanBuilder& unnest( const std::vector& replicateColumns, const std::vector& unnestColumns, - const std::optional& ordinalColumn = std::nullopt); + const std::optional& ordinalColumn = std::nullopt, + const std::optional& markerName = std::nullopt); /// Add a WindowNode to compute one or more windowFunctions. /// @param windowFunctions A list of one or more window function SQL like @@ -1237,14 +1454,23 @@ class PlanBuilder { std::optional limit = std::nullopt, bool generateRowNumber = true); - /// Add a TopNRowNumberNode to compute single row_number window function with - /// a limit applied to sorted partitions. + /// Add a TopNRowNumberNode to compute row_number + /// function with a limit applied to sorted partitions. PlanBuilder& topNRowNumber( const std::vector& partitionKeys, const std::vector& sortingKeys, int32_t limit, bool generateRowNumber); + /// Add a TopNRowNumberNode to compute row_number, rank or dense_rank window + /// function with a limit applied to sorted partitions. + PlanBuilder& topNRank( + std::string_view function, + const std::vector& partitionKeys, + const std::vector& sortingKeys, + int32_t limit, + bool generateRowNumber); + /// Add a MarkDistinctNode to compute aggregate mask channel /// @param markerKey Name of output mask channel /// @param distinctKeys List of columns to be marked distinct. @@ -1424,6 +1650,7 @@ class PlanBuilder { core::PlanNodePtr planNode_; parse::ParseOptions options_; std::shared_ptr tableScanBuilder_; + std::shared_ptr indexLookupJoinBuilder_; std::shared_ptr tableWriterBuilder_; private: diff --git a/velox/exec/tests/utils/QueryAssertions.cpp b/velox/exec/tests/utils/QueryAssertions.cpp index 720b4943b78c..9eef2c5a60f3 100644 --- a/velox/exec/tests/utils/QueryAssertions.cpp +++ b/velox/exec/tests/utils/QueryAssertions.cpp @@ -20,9 +20,9 @@ #include "duckdb/common/types.hpp" // @manual #include "velox/duckdb/conversion/DuckConversion.h" #include "velox/exec/Cursor.h" -#include "velox/exec/tests/utils/ArbitratorTestUtil.h" #include "velox/exec/tests/utils/QueryAssertions.h" -#include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h" +#include "velox/type/Type.h" +#include "velox/vector/VariantToVector.h" #include "velox/vector/VectorTypeUtils.h" using facebook::velox::duckdb::duckdbTimestampToVelox; @@ -102,8 +102,9 @@ ::duckdb::Value duckValueAt( vector_size_t index) { auto type = vector->type(); if (type->isDate()) { - return ::duckdb::Value::DATE(::duckdb::Date::EpochDaysToDate( - vector->as>()->valueAt(index))); + return ::duckdb::Value::DATE( + ::duckdb::Date::EpochDaysToDate( + vector->as>()->valueAt(index))); } return ::duckdb::Value(vector->as>()->valueAt(index)); } @@ -218,9 +219,10 @@ ::duckdb::Value duckValueAt( const auto& mapValues = mapVector->mapValues(); auto offset = mapVector->offsetAt(mapRow); auto size = mapVector->sizeAt(mapRow); - auto mapType = ::duckdb::ListType::GetChildType(::duckdb::LogicalType::MAP( - duckdb::fromVeloxType(mapKeys->type()), - duckdb::fromVeloxType(mapValues->type()))); + auto mapType = ::duckdb::ListType::GetChildType( + ::duckdb::LogicalType::MAP( + duckdb::fromVeloxType(mapKeys->type()), + duckdb::fromVeloxType(mapValues->type()))); if (size == 0) { return ::duckdb::Value::MAP(mapType, ::duckdb::vector<::duckdb::Value>()); } @@ -276,7 +278,8 @@ variant variantAt( int32_t row, int32_t column) { return variant::binary( - StringView(::duckdb::StringValue::Get(dataChunk->GetValue(column, row)))); + std::string( + ::duckdb::StringValue::Get(dataChunk->GetValue(column, row)))); } template <> @@ -321,7 +324,7 @@ variant variantAt(const ::duckdb::Value& value) { template <> variant variantAt(const ::duckdb::Value& value) { - return variant::binary(StringView(::duckdb::StringValue::Get(value))); + return variant::binary(std::string(::duckdb::StringValue::Get(value))); } variant nullVariant(const TypePtr& type) { @@ -437,12 +440,14 @@ std::vector materialize( } else if (type->isDecimal()) { row.push_back(duckdb::decimalVariant(dataChunk->GetValue(j, i))); } else if (type->isIntervalDayTime()) { - auto value = variant(::duckdb::Interval::GetMicro( - dataChunk->GetValue(j, i).GetValue<::duckdb::interval_t>())); + auto value = variant( + ::duckdb::Interval::GetMicro( + dataChunk->GetValue(j, i).GetValue<::duckdb::interval_t>())); row.push_back(value); } else if (type->isDate()) { - auto value = variant(::duckdb::Date::EpochDays( - dataChunk->GetValue(j, i).GetValue<::duckdb::date_t>())); + auto value = variant( + ::duckdb::Date::EpochDays( + dataChunk->GetValue(j, i).GetValue<::duckdb::date_t>())); row.push_back(value); } else { auto value = VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH( @@ -455,93 +460,6 @@ std::vector materialize( return rows; } -template -variant variantAt(VectorPtr vector, int32_t row) { - using T = typename KindToFlatVector::WrapperType; - return variant(vector->as>()->valueAt(row)); -} - -template <> -variant variantAt(VectorPtr vector, int32_t row) { - return variant::binary(vector->as>()->valueAt(row)); -} - -variant variantAt(const VectorPtr& vector, vector_size_t row); - -variant arrayVariantAt(const VectorPtr& vector, vector_size_t row) { - auto arrayVector = vector->wrappedVector()->as(); - auto& elements = arrayVector->elements(); - - auto wrappedRow = vector->wrappedIndex(row); - auto offset = arrayVector->offsetAt(wrappedRow); - auto size = arrayVector->sizeAt(wrappedRow); - - std::vector array; - array.reserve(size); - for (auto i = 0; i < size; i++) { - auto innerRow = offset + i; - array.push_back(variantAt(elements, innerRow)); - } - return variant::array(array); -} - -variant mapVariantAt(const VectorPtr& vector, vector_size_t row) { - auto mapVector = vector->wrappedVector()->as(); - auto& mapKeys = mapVector->mapKeys(); - auto& mapValues = mapVector->mapValues(); - - auto wrappedRow = vector->wrappedIndex(row); - auto offset = mapVector->offsetAt(wrappedRow); - auto size = mapVector->sizeAt(wrappedRow); - - std::map map; - for (auto i = 0; i < size; i++) { - auto innerRow = offset + i; - auto key = variantAt(mapKeys, innerRow); - auto value = variantAt(mapValues, innerRow); - map.insert({key, value}); - } - return variant::map(map); -} - -variant rowVariantAt(const VectorPtr& vector, vector_size_t row) { - auto rowValues = vector->wrappedVector()->as(); - auto wrappedRow = vector->wrappedIndex(row); - - std::vector values; - for (auto& child : rowValues->children()) { - values.push_back(variantAt(child, wrappedRow)); - } - return variant::row(std::move(values)); -} - -variant variantAt(const VectorPtr& vector, vector_size_t row) { - if (vector->isNullAt(row)) { - return nullVariant(vector->type()); - } - - auto typeKind = vector->typeKind(); - if (typeKind == TypeKind::ROW) { - return rowVariantAt(vector, row); - } - - if (typeKind == TypeKind::ARRAY) { - return arrayVariantAt(vector, row); - } - - if (typeKind == TypeKind::MAP) { - return mapVariantAt(vector, row); - } - - if (isTimestampWithTimeZoneType(vector->type())) { - return variant::typeWithCustomComparison( - vector->as>()->valueAt(row), - TIMESTAMP_WITH_TIME_ZONE()); - } - - return VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH(variantAt, typeKind, vector, row); -} - MaterializedRow getColumns( const MaterializedRow& row, const std::vector& columnIndices) { @@ -797,7 +715,7 @@ std::string toTypeString(const MaterializedRow& row) { if (i > 0) { out << ", "; } - out << mapTypeKindToName(row[i].kind()); + out << TypeKindName::toName(row[i].kind()); } out << ")"; return out.str(); @@ -899,7 +817,7 @@ std::vector materialize(const RowVectorPtr& vector) { MaterializedRow row; row.reserve(numColumns); for (size_t j = 0; j < numColumns; ++j) { - row.push_back(variantAt(simpleVectors[j], i)); + row.push_back(simpleVectors[j]->variantAt(i)); } rows.push_back(row); } @@ -1433,11 +1351,12 @@ std::pair, std::vector> readCursor( // 'result' borrows memory from cursor so the life cycle must be shorter. std::vector result; auto* task = cursor->task().get(); - while (!cursor->noMoreSplits()) { addSplits(cursor.get()); while (cursor->moveNext()) { - result.push_back(cursor->current()); + auto vector = cursor->current(); + vector->loadedVector(); + result.push_back(std::move(vector)); testingMaybeTriggerAbort(task); } } @@ -1507,7 +1426,7 @@ bool waitForTaskStateChange( void waitForAllTasksToBeDeleted(uint64_t maxWaitUs) { uint64_t waitUs = 0; while (Task::numRunningTasks() != 0) { - constexpr uint64_t kWaitInternalUs = 1'000; + constexpr uint64_t kWaitInternalUs = 50'000; std::this_thread::sleep_for(std::chrono::microseconds(kWaitInternalUs)); waitUs += kWaitInternalUs; if (waitUs >= maxWaitUs) { @@ -1529,6 +1448,15 @@ void waitForAllTasksToBeDeleted(uint64_t maxWaitUs) { folly::join("\n", pendingTaskStats)); } +void cancelAllTasks() { + std::vector> pendingTasks = Task::getRunningTasks(); + for (const auto& task : pendingTasks) { + if (task->isRunning()) { + task->requestCancel(); + } + } +} + std::shared_ptr assertQuery( const core::PlanNodePtr& plan, std::function addSplits, diff --git a/velox/exec/tests/utils/QueryAssertions.h b/velox/exec/tests/utils/QueryAssertions.h index 3acfa88885d0..db766aa1e5bf 100644 --- a/velox/exec/tests/utils/QueryAssertions.h +++ b/velox/exec/tests/utils/QueryAssertions.h @@ -221,6 +221,11 @@ bool waitForTaskStateChange( /// during this wait call. This is for testing purpose for now. void waitForAllTasksToBeDeleted(uint64_t maxWaitUs = 3'000'000); +/// Cancels all currently running tasks across all available task managers. +/// This is primarily used in testing scenarios to clean up active tasks +/// and ensure test isolation between test cases. +void cancelAllTasks(); + std::shared_ptr assertQuery( const core::PlanNodePtr& plan, const std::string& duckDbSql, @@ -307,6 +312,13 @@ bool assertEqualResults( const core::PlanNodePtr& plan1, const core::PlanNodePtr& plan2); +bool assertEqualResults( + const MaterializedRowMultiset& expectedRows, + const TypePtr& expectedType, + const MaterializedRowMultiset& actualRows, + const TypePtr& actualType, + const std::string& message); + /// Ensure both datasets have the same type and number of rows. void assertEqualTypeAndNumRows( const TypePtr& expectedType, diff --git a/velox/exec/tests/utils/RowContainerTestBase.h b/velox/exec/tests/utils/RowContainerTestBase.h index c9d126fa869e..5bde0826f5b4 100644 --- a/velox/exec/tests/utils/RowContainerTestBase.h +++ b/velox/exec/tests/utils/RowContainerTestBase.h @@ -76,7 +76,8 @@ class RowContainerTestBase : public testing::Test, std::unique_ptr makeRowContainer( const std::vector& keyTypes, const std::vector& dependentTypes, - bool isJoinBuild = true) { + bool isJoinBuild = true, + bool useListRowIndex = false) { auto container = std::make_unique( keyTypes, !isJoinBuild, @@ -86,6 +87,7 @@ class RowContainerTestBase : public testing::Test, isJoinBuild, true, true, + useListRowIndex, pool_.get()); VELOX_CHECK(container->testingMutable()); return container; diff --git a/velox/exec/tests/utils/SerializedPageUtil.cpp b/velox/exec/tests/utils/SerializedPageUtil.cpp index 4316f1a783fa..a498fe5b8c1d 100644 --- a/velox/exec/tests/utils/SerializedPageUtil.cpp +++ b/velox/exec/tests/utils/SerializedPageUtil.cpp @@ -20,7 +20,7 @@ using namespace facebook::velox; namespace facebook::velox::exec::test { -std::unique_ptr toSerializedPage( +std::unique_ptr toSerializedPage( const RowVectorPtr& vector, VectorSerde::Kind serdeKind, const std::shared_ptr& bufferManager, @@ -34,7 +34,8 @@ std::unique_ptr toSerializedPage( auto listener = bufferManager->newListener(); IOBufOutputStream stream(*pool, listener.get(), data->size()); data->flush(&stream); - return std::make_unique(stream.getIOBuf(), nullptr, size); + return std::make_unique( + stream.getIOBuf(), nullptr, size); } } // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/utils/SerializedPageUtil.h b/velox/exec/tests/utils/SerializedPageUtil.h index 6890693d6345..b074068c3ebd 100644 --- a/velox/exec/tests/utils/SerializedPageUtil.h +++ b/velox/exec/tests/utils/SerializedPageUtil.h @@ -23,7 +23,7 @@ namespace facebook::velox::exec::test { /// Helper function for serializing RowVector to PrestoPage format. -std::unique_ptr toSerializedPage( +std::unique_ptr toSerializedPage( const RowVectorPtr& vector, VectorSerde::Kind serdeKind, const std::shared_ptr& bufferManager, diff --git a/velox/exec/tests/utils/TableScanTestBase.cpp b/velox/exec/tests/utils/TableScanTestBase.cpp index ef571d75b6d3..394020fe1b5a 100644 --- a/velox/exec/tests/utils/TableScanTestBase.cpp +++ b/velox/exec/tests/utils/TableScanTestBase.cpp @@ -15,7 +15,6 @@ */ #include "velox/exec/tests/utils/TableScanTestBase.h" - #include "velox/exec/tests/utils/AssertQueryBuilder.h" #include "velox/exec/tests/utils/LocalExchangeSource.h" #include "velox/exec/tests/utils/PlanBuilder.h" @@ -159,7 +158,7 @@ void TableScanTestBase::testPartitionedTableImpl( .build(); auto outputType = ROW({"pkey", "c0", "c1"}, {partitionType, BIGINT(), DOUBLE()}); - ColumnHandleMap assignments = { + connector::ColumnHandleMap assignments = { {"pkey", partitionKey("pkey", partitionType)}, {"c0", regularColumn("c0", BIGINT())}, {"c1", regularColumn("c1", DOUBLE())}}; diff --git a/velox/exec/tests/utils/TableScanTestBase.h b/velox/exec/tests/utils/TableScanTestBase.h index b51f629a49d8..f7e2929f2847 100644 --- a/velox/exec/tests/utils/TableScanTestBase.h +++ b/velox/exec/tests/utils/TableScanTestBase.h @@ -20,6 +20,7 @@ #include #include #include +#include "velox/connectors/hive/FileHandle.h" #include "velox/exec/PlanNodeStats.h" #include "velox/exec/tests/utils/HiveConnectorTestBase.h" #include "velox/type/Type.h" diff --git a/velox/exec/tests/utils/TableWriterTestBase.cpp b/velox/exec/tests/utils/TableWriterTestBase.cpp index ebf2eb1f0599..6ec932f748db 100644 --- a/velox/exec/tests/utils/TableWriterTestBase.cpp +++ b/velox/exec/tests/utils/TableWriterTestBase.cpp @@ -42,7 +42,7 @@ CompressionKind TableWriterTestBase::TestParam::compressionKind() const { } bool TableWriterTestBase::TestParam::multiDrivers() const { - return (value >> 40) != 0; + return (value & (1L << 40)) != 0; } FileFormat TableWriterTestBase::TestParam::fileFormat() const { @@ -71,10 +71,10 @@ bool TableWriterTestBase::TestParam::scaleWriter() const { std::string TableWriterTestBase::TestParam::toString() const { return fmt::format( - "FileFormat[{}] TestMode[{}] commitStrategy[{}] bucketKind[{}] bucketSort[{}] multiDrivers[{}] compression[{}] scaleWriter[{}]", + "FileFormat_{}_TestMode_{}_commitStrategy_{}_bucketKind_{}_bucketSort_{}_multiDrivers_{}_compression_{}_scaleWriter_{}", dwio::common::toString((fileFormat())), testModeString(testMode()), - commitStrategyToString(commitStrategy()), + CommitStrategyName::toName(commitStrategy()), HiveBucketProperty::kindString(bucketKind()), bucketSort(), multiDrivers(), @@ -91,35 +91,26 @@ std::string TableWriterTestBase::testModeString(TestMode mode) { case TestMode::kBucketed: return "BUCKETED"; case TestMode::kOnlyBucketed: - return "BUCKETED (NOT PARTITIONED)"; + return "BUCKETED_WITHOUT_PARTITION"; } VELOX_UNREACHABLE(); } // static -std::shared_ptr -TableWriterTestBase::generateAggregationNode( +core::ColumnStatsSpec TableWriterTestBase::generateColumnStatsSpec( const std::string& name, const std::vector& groupingKeys, - AggregationNode::Step step, - const PlanNodePtr& source) { + AggregationNode::Step step) { core::TypedExprPtr inputField = std::make_shared(BIGINT(), name); - auto callExpr = std::make_shared( - BIGINT(), std::vector{inputField}, "min"); + auto callExpr = + std::make_shared(BIGINT(), "min", inputField); std::vector aggregateNames = {"min"}; std::vector aggregates = { core::AggregationNode::Aggregate{ callExpr, {{BIGINT()}}, nullptr, {}, {}}}; - return std::make_shared( - core::PlanNodeId(), - step, - groupingKeys, - std::vector{}, - aggregateNames, - aggregates, - false, // ignoreNullKeys - source); + return core::ColumnStatsSpec{ + std::move(groupingKeys), step, aggregateNames, aggregates}; } // static. @@ -127,7 +118,7 @@ std::function TableWriterTestBase::addTableWriter( const RowTypePtr& inputColumns, const std::vector& tableColumnNames, - const std::shared_ptr& aggregationNode, + const std::optional& columnStatsSpec, const std::shared_ptr& insertHandle, bool hasPartitioningScheme, connector::CommitStrategy commitStrategy) { @@ -137,10 +128,10 @@ TableWriterTestBase::addTableWriter( nodeId, inputColumns, tableColumnNames, - aggregationNode, + columnStatsSpec, insertHandle, hasPartitioningScheme, - TableWriteTraits::outputType(aggregationNode), + TableWriteTraits::outputType(columnStatsSpec), commitStrategy, std::move(source)); }; @@ -408,8 +399,9 @@ TableWriterTestBase::makeHiveConnectorSplits(const std::string& directoryPath) { std::vector> splits; for (auto& path : fs::recursive_directory_iterator(directoryPath)) { if (path.is_regular_file()) { - splits.push_back(HiveConnectorTestBase::makeHiveConnectorSplits( - path.path().string(), 1, fileFormat_)[0]); + splits.push_back( + HiveConnectorTestBase::makeHiveConnectorSplits( + path.path().string(), 1, fileFormat_)[0]); } } return splits; @@ -435,8 +427,9 @@ TableWriterTestBase::makeHiveConnectorSplits( const std::vector& filePaths) { std::vector> splits; for (const auto& filePath : filePaths) { - splits.push_back(HiveConnectorTestBase::makeHiveConnectorSplits( - filePath.string(), 1, fileFormat_)[0]); + splits.push_back( + HiveConnectorTestBase::makeHiveConnectorSplits( + filePath.string(), 1, fileFormat_)[0]); } return splits; } @@ -550,7 +543,7 @@ PlanNodePtr TableWriterTestBase::createInsertPlan( const connector::hive::LocationHandle::TableType& outputTableType, const CommitStrategy& outputCommitStrategy, bool aggregateResult, - std::shared_ptr aggregationNode) { + const std::optional& columnStatsSpec) { return createInsertPlan( inputPlan, inputPlan.planNode()->outputType(), @@ -563,7 +556,7 @@ PlanNodePtr TableWriterTestBase::createInsertPlan( outputTableType, outputCommitStrategy, aggregateResult, - aggregationNode); + columnStatsSpec); } PlanNodePtr TableWriterTestBase::createInsertPlan( @@ -578,7 +571,7 @@ PlanNodePtr TableWriterTestBase::createInsertPlan( const connector::hive::LocationHandle::TableType& outputTableType, const CommitStrategy& outputCommitStrategy, bool aggregateResult, - std::shared_ptr aggregationNode) { + const std::optional& columnStatsSpec) { if (numTableWriters == 1) { return createInsertPlanWithSingleWriter( inputPlan, @@ -591,7 +584,7 @@ PlanNodePtr TableWriterTestBase::createInsertPlan( outputTableType, outputCommitStrategy, aggregateResult, - aggregationNode); + columnStatsSpec); } else if (bucketProperty_ == nullptr) { return createInsertPlanWithForNonBucketedTable( inputPlan, @@ -603,7 +596,7 @@ PlanNodePtr TableWriterTestBase::createInsertPlan( outputTableType, outputCommitStrategy, aggregateResult, - aggregationNode); + columnStatsSpec); } else { return createInsertPlanForBucketTable( inputPlan, @@ -616,7 +609,7 @@ PlanNodePtr TableWriterTestBase::createInsertPlan( outputTableType, outputCommitStrategy, aggregateResult, - aggregationNode); + columnStatsSpec); } } @@ -631,7 +624,7 @@ PlanNodePtr TableWriterTestBase::createInsertPlanWithSingleWriter( const connector::hive::LocationHandle::TableType& outputTableType, const CommitStrategy& outputCommitStrategy, bool aggregateResult, - std::shared_ptr aggregationNode) { + std::optional columnStatsSpec) { const bool addScaleWriterExchange = scaleWriter_ && (bucketProperty != nullptr); auto insertPlan = inputPlan; @@ -647,7 +640,7 @@ PlanNodePtr TableWriterTestBase::createInsertPlanWithSingleWriter( .addNode(addTableWriter( inputRowType, tableRowType->names(), - aggregationNode, + columnStatsSpec, createInsertTableHandle( tableRowType, outputTableType, @@ -658,14 +651,6 @@ PlanNodePtr TableWriterTestBase::createInsertPlanWithSingleWriter( false, outputCommitStrategy)) .capturePlanNodeId(tableWriteNodeId_); - if (addScaleWriterExchange) { - if (!partitionedBy.empty()) { - insertPlan.scaleWriterlocalPartition( - inputColumnNames(partitionedBy, tableRowType, inputRowType)); - } else { - insertPlan.scaleWriterlocalPartitionRoundRobin(); - } - } if (aggregateResult) { insertPlan.project({TableWriteTraits::rowCountColumnName()}) .singleAggregation( @@ -686,7 +671,7 @@ PlanNodePtr TableWriterTestBase::createInsertPlanForBucketTable( const connector::hive::LocationHandle::TableType& outputTableType, const CommitStrategy& outputCommitStrategy, bool aggregateResult, - std::shared_ptr aggregationNode) { + std::optional columnStatsSpec) { // Since we might do column rename, so generate bucket property based on // the data type from 'inputPlan'. std::vector bucketColumns; @@ -706,7 +691,7 @@ PlanNodePtr TableWriterTestBase::createInsertPlanForBucketTable( .addNode(addTableWriter( inputRowType, tableRowType->names(), - nullptr, + std::nullopt, createInsertTableHandle( tableRowType, outputTableType, @@ -752,7 +737,7 @@ PlanNodePtr TableWriterTestBase::createInsertPlanWithForNonBucketedTable( const connector::hive::LocationHandle::TableType& outputTableType, const CommitStrategy& outputCommitStrategy, bool aggregateResult, - std::shared_ptr aggregationNode) { + std::optional columnStatsSpec) { auto insertPlan = inputPlan; if (scaleWriter_) { if (!partitionedBy.empty()) { @@ -766,7 +751,7 @@ PlanNodePtr TableWriterTestBase::createInsertPlanWithForNonBucketedTable( .addNode(addTableWriter( inputRowType, tableRowType->names(), - nullptr, + std::nullopt, createInsertTableHandle( tableRowType, outputTableType, @@ -798,9 +783,10 @@ std::string TableWriterTestBase::partitionNameToPredicate( for (auto i = 0; i < partitionKeyValues.size(); ++i) { if (partitionTypes[i]->isVarchar() || partitionTypes[i]->isVarbinary() || partitionTypes[i]->isDate()) { - conjuncts.push_back(partitionKeyValues[i] - .replace(partitionKeyValues[i].find("="), 1, "='") - .append("'")); + conjuncts.push_back( + partitionKeyValues[i] + .replace(partitionKeyValues[i].find("="), 1, "='") + .append("'")); } else { conjuncts.push_back(partitionKeyValues[i]); } @@ -816,9 +802,10 @@ std::string TableWriterTestBase::partitionNameToPredicate( for (auto i = 0; i < partitionDirNames.size(); ++i) { if (partitionTypes_[i]->isVarchar() || partitionTypes_[i]->isVarbinary() || partitionTypes_[i]->isDate()) { - conjuncts.push_back(partitionKeyValues[i] - .replace(partitionKeyValues[i].find("="), 1, "='") - .append("'")); + conjuncts.push_back( + partitionKeyValues[i] + .replace(partitionKeyValues[i].find("="), 1, "='") + .append("'")); } else { conjuncts.push_back(partitionDirNames[i]); } @@ -831,20 +818,22 @@ void TableWriterTestBase::verifyUnbucketedFilePath( const std::string& targetDir) { ASSERT_EQ(filePath.parent_path().string(), targetDir); if (commitStrategy_ == CommitStrategy::kNoCommit) { - ASSERT_TRUE(RE2::FullMatch( - filePath.filename().string(), - fmt::format( - "test_cursor.+_[0-{}]_{}_.+", - numTableWriterCount_ - 1, - tableWriteNodeId_))) + ASSERT_TRUE( + RE2::FullMatch( + filePath.filename().string(), + fmt::format( + "test_cursor.+_[0-{}]_{}_.+", + numTableWriterCount_ - 1, + tableWriteNodeId_))) << filePath.filename().string(); } else { - ASSERT_TRUE(RE2::FullMatch( - filePath.filename().string(), - fmt::format( - ".tmp.velox.test_cursor.+_[0-{}]_{}_.+", - numTableWriterCount_ - 1, - tableWriteNodeId_))) + ASSERT_TRUE( + RE2::FullMatch( + filePath.filename().string(), + fmt::format( + ".tmp.velox.test_cursor.+_[0-{}]_{}_.+", + numTableWriterCount_ - 1, + tableWriteNodeId_))) << filePath.filename().string(); } } @@ -860,25 +849,29 @@ void TableWriterTestBase::verifyBucketedFileName( const std::filesystem::path& filePath) { if (commitStrategy_ == CommitStrategy::kNoCommit) { if (fileFormat_ == FileFormat::PARQUET) { - ASSERT_TRUE(RE2::FullMatch( - filePath.filename().string(), - "0[0-9]+_0_TaskCursorQuery_[0-9]+\\.parquet$")) + ASSERT_TRUE( + RE2::FullMatch( + filePath.filename().string(), + "0[0-9]+_0_TaskCursorQuery_[0-9]+\\.parquet$")) << filePath.filename().string(); } else { - ASSERT_TRUE(RE2::FullMatch( - filePath.filename().string(), "0[0-9]+_0_TaskCursorQuery_[0-9]+")) + ASSERT_TRUE( + RE2::FullMatch( + filePath.filename().string(), "0[0-9]+_0_TaskCursorQuery_[0-9]+")) << filePath.filename().string(); } } else { if (fileFormat_ == FileFormat::PARQUET) { - ASSERT_TRUE(RE2::FullMatch( - filePath.filename().string(), - ".tmp.velox.0[0-9]+_0_TaskCursorQuery_[0-9]+_.+\\.parquet$")) + ASSERT_TRUE( + RE2::FullMatch( + filePath.filename().string(), + ".tmp.velox.0[0-9]+_0_TaskCursorQuery_[0-9]+_.+\\.parquet$")) << filePath.filename().string(); } else { - ASSERT_TRUE(RE2::FullMatch( - filePath.filename().string(), - ".tmp.velox.0[0-9]+_0_TaskCursorQuery_[0-9]+_.+")) + ASSERT_TRUE( + RE2::FullMatch( + filePath.filename().string(), + ".tmp.velox.0[0-9]+_0_TaskCursorQuery_[0-9]+_.+")) << filePath.filename().string(); } } diff --git a/velox/exec/tests/utils/TableWriterTestBase.h b/velox/exec/tests/utils/TableWriterTestBase.h index a98e4e44d722..f0c71f92beee 100644 --- a/velox/exec/tests/utils/TableWriterTestBase.h +++ b/velox/exec/tests/utils/TableWriterTestBase.h @@ -105,7 +105,7 @@ class TableWriterTestBase : public HiveConnectorTestBase { static std::function addTableWriter( const RowTypePtr& inputColumns, const std::vector& tableColumnNames, - const std::shared_ptr& aggregationNode, + const std::optional& columnStatsSpec, const std::shared_ptr& insertHandle, bool hasPartitioningScheme, connector::CommitStrategy commitStrategy = @@ -115,11 +115,10 @@ class TableWriterTestBase : public HiveConnectorTestBase { const std::vector& partitionedKeys, const RowTypePtr& rowType); - static std::shared_ptr generateAggregationNode( + static core::ColumnStatsSpec generateColumnStatsSpec( const std::string& name, const std::vector& groupingKeys, - AggregationNode::Step step, - const PlanNodePtr& source); + AggregationNode::Step step); std::shared_ptr assertQueryWithWriterConfigs( const core::PlanNodePtr& plan, @@ -207,7 +206,7 @@ class TableWriterTestBase : public HiveConnectorTestBase { connector::hive::LocationHandle::TableType::kNew, const CommitStrategy& outputCommitStrategy = CommitStrategy::kNoCommit, bool aggregateResult = true, - std::shared_ptr aggregationNode = nullptr); + const std::optional& columnStatsSpec = std::nullopt); PlanNodePtr createInsertPlan( PlanBuilder& inputPlan, @@ -222,7 +221,7 @@ class TableWriterTestBase : public HiveConnectorTestBase { connector::hive::LocationHandle::TableType::kNew, const CommitStrategy& outputCommitStrategy = CommitStrategy::kNoCommit, bool aggregateResult = true, - std::shared_ptr aggregationNode = nullptr); + const std::optional& columnStatsSpec = std::nullopt); PlanNodePtr createInsertPlanWithSingleWriter( PlanBuilder& inputPlan, @@ -235,7 +234,7 @@ class TableWriterTestBase : public HiveConnectorTestBase { const connector::hive::LocationHandle::TableType& outputTableType, const CommitStrategy& outputCommitStrategy, bool aggregateResult, - std::shared_ptr aggregationNode); + std::optional columnStatsSpec); PlanNodePtr createInsertPlanForBucketTable( PlanBuilder& inputPlan, @@ -248,7 +247,7 @@ class TableWriterTestBase : public HiveConnectorTestBase { const connector::hive::LocationHandle::TableType& outputTableType, const CommitStrategy& outputCommitStrategy, bool aggregateResult, - std::shared_ptr aggregationNode); + std::optional columnStatsSpec); // Return the corresponding column names in 'inputRowType' of // 'tableColumnNames' from 'tableRowType'. @@ -267,7 +266,7 @@ class TableWriterTestBase : public HiveConnectorTestBase { const connector::hive::LocationHandle::TableType& outputTableType, const CommitStrategy& outputCommitStrategy, bool aggregateResult, - std::shared_ptr aggregationNode); + std::optional columnStatsSpec); // Parameter partitionName is string formatted in the Hive style // key1=value1/key2=value2/... Parameter partitionTypes are types of partition diff --git a/velox/exec/tests/utils/TestIndexStorageConnector.cpp b/velox/exec/tests/utils/TestIndexStorageConnector.cpp index 3dfdfe04e3b2..2dc3eb1a0ee6 100644 --- a/velox/exec/tests/utils/TestIndexStorageConnector.cpp +++ b/velox/exec/tests/utils/TestIndexStorageConnector.cpp @@ -25,15 +25,90 @@ using facebook::velox::common::testutil::TestValue; namespace facebook::velox::exec::test { + +// static +std::shared_ptr TestIndexTable::create( + size_t numEqualJoinKeys, + const RowVectorPtr& keyData, + const RowVectorPtr& valueData, + velox::memory::MemoryPool& pool) { + VELOX_CHECK_GE(numEqualJoinKeys, 1); + VELOX_CHECK_NOT_NULL(keyData); + VELOX_CHECK_NOT_NULL(valueData); + + auto keyType = asRowType(keyData->type()); + VELOX_CHECK_GE(keyType->size(), numEqualJoinKeys); + + auto valueType = asRowType(valueData->type()); + VELOX_CHECK_GE(valueType->size(), 1); + + const auto numRows = keyData->size(); + VELOX_CHECK_EQ(numRows, valueData->size()); + + std::vector> hashers; + hashers.reserve(numEqualJoinKeys); + std::vector keyVectors; + keyVectors.reserve(numEqualJoinKeys); + for (auto i = 0; i < numEqualJoinKeys; ++i) { + hashers.push_back(std::make_unique(keyType->childAt(i), i)); + keyVectors.push_back(keyData->childAt(i)); + } + + std::vector dependentTypes; + std::vector dependentVectors; + for (auto i = numEqualJoinKeys; i < keyType->size(); ++i) { + dependentTypes.push_back(keyType->childAt(i)); + dependentVectors.push_back(keyData->childAt(i)); + } + + for (auto i = 0; i < valueType->size(); ++i) { + dependentTypes.push_back(valueType->childAt(i)); + dependentVectors.push_back(valueData->childAt(i)); + } + + // Create the table. + auto table = HashTable::createForJoin( + std::move(hashers), + /*dependentTypes=*/dependentTypes, + /*allowDuplicates=*/true, + /*hasProbedFlag=*/false, + /*minTableSizeForParallelJoinBuild=*/1, + &pool); + + // Insert data into the row container. + auto* rowContainer = table->rows(); + std::vector decodedVectors; + for (auto& vector : keyData->children()) { + decodedVectors.emplace_back(*vector); + } + for (auto& vector : valueData->children()) { + decodedVectors.emplace_back(*vector); + } + + const auto nextOffset = rowContainer->nextOffset(); + VELOX_CHECK_GT(nextOffset, 0); + for (auto row = 0; row < numRows; ++row) { + auto* newRow = rowContainer->newRow(); + *reinterpret_cast(newRow + nextOffset) = nullptr; + for (auto col = 0; col < decodedVectors.size(); ++col) { + rowContainer->store(decodedVectors[col], row, newRow, col); + } + } + + // Build the table index. + table->prepareJoinTable( + {}, BaseHashTable::kNoSpillInputStartPartitionBit, 1'000'000); + return std::make_shared( + std::move(keyType), std::move(valueType), std::move(table)); +} + namespace { core::TypedExprPtr toJoinConditionExpr( const std::vector>& joinConditions, const std::shared_ptr& indexTable, const RowTypePtr& inputType, - const std::unordered_map< - std::string, - std::shared_ptr>& columnHandles) { + const connector::ColumnHandleMap& columnHandles) { if (joinConditions.empty()) { return nullptr; } @@ -46,23 +121,35 @@ core::TypedExprPtr toJoinConditionExpr( if (auto inCondition = std::dynamic_pointer_cast( condition)) { - conditionExprs.push_back(std::make_shared( - BOOLEAN(), - std::vector{ - inCondition->list, std::move(indexColumnExpr)}, - "contains")); + conditionExprs.push_back( + std::make_shared( + BOOLEAN(), + "contains", + inCondition->list, + std::move(indexColumnExpr))); continue; } if (auto betweenCondition = std::dynamic_pointer_cast( condition)) { - conditionExprs.push_back(std::make_shared( - BOOLEAN(), - std::vector{ + conditionExprs.push_back( + std::make_shared( + BOOLEAN(), + "between", std::move(indexColumnExpr), betweenCondition->lower, - betweenCondition->upper}, - "between")); + betweenCondition->upper)); + continue; + } + if (auto equalCondition = + std::dynamic_pointer_cast( + condition)) { + conditionExprs.push_back( + std::make_shared( + BOOLEAN(), + "eq", + std::move(indexColumnExpr), + equalCondition->value)); continue; } VELOX_FAIL("Invalid index join condition: {}", condition->toString()); @@ -101,7 +188,9 @@ TestIndexSource::TestIndexSource( const RowTypePtr& outputType, size_t numEqualJoinKeys, const core::TypedExprPtr& joinConditionExpr, - const std::shared_ptr& tableHandle, + const TestIndexTableHandlePtr& tableHandle, + const std::unordered_map& + columnHandles, connector::ConnectorQueryCtx* connectorQueryCtx, folly::Executor* executor) : tableHandle_(tableHandle), @@ -118,7 +207,9 @@ TestIndexSource::TestIndexSource( : nullptr), pool_(connectorQueryCtx_->memoryPool()->shared_from_this()), executor_(executor) { - VELOX_CHECK_NOT_NULL(executor_); + if (tableHandle_->asyncLookup()) { + VELOX_CHECK_NOT_NULL(executor_); + } VELOX_CHECK_LE(outputType_->size(), valueType_->size() + keyType_->size()); VELOX_CHECK_LE(numEqualJoinKeys_, keyType_->size()); for (int i = 0; i < numEqualJoinKeys_; ++i) { @@ -128,7 +219,7 @@ TestIndexSource::TestIndexSource( keyType_->toString(), inputType_->toString()); } - initOutputProjections(); + initOutputProjections(columnHandles); initConditionProjections(); } @@ -186,20 +277,24 @@ void TestIndexSource::initConditionProjections() { conditionInputType_ = ROW(std::move(names), std::move(types)); } -void TestIndexSource::initOutputProjections() { +void TestIndexSource::initOutputProjections( + const std::unordered_map& + columnHandles) { VELOX_CHECK(lookupOutputProjections_.empty()); + lookupOutputProjections_.reserve(outputType_->size()); for (auto outputChannel = 0; outputChannel < outputType_->size(); ++outputChannel) { const auto outputName = outputType_->nameOf(outputChannel); - if (valueType_->containsChild(outputName)) { - const auto tableValueChannel = valueType_->getChildIdx(outputName); + const auto columnName = columnHandles.at(outputName)->name(); + if (valueType_->containsChild(columnName)) { + const auto tableValueChannel = valueType_->getChildIdx(columnName); // The hash table layout is: index columns, value columns. lookupOutputProjections_.emplace_back( keyType_->size() + tableValueChannel, outputChannel); continue; } - const auto tableKeyChannel = keyType_->getChildIdx(outputName); + const auto tableKeyChannel = keyType_->getChildIdx(columnName); lookupOutputProjections_.emplace_back(tableKeyChannel, outputChannel); } VELOX_CHECK_EQ(lookupOutputProjections_.size(), outputType_->size()); @@ -250,10 +345,21 @@ TestIndexSource::ResultIterator::ResultIterator( lookupResultIter_->reset(*lookupResult_); } +bool TestIndexSource::ResultIterator::hasNext() { + // If we have an async result ready, we have more to return. + if (asyncResult_.has_value()) { + return true; + } + + // If the iterator is not at end, there are more results to fetch. + return !lookupResultIter_->atEnd(); +} + std::optional> TestIndexSource::ResultIterator::next( vector_size_t size, ContinueFuture& future) { + const auto lookupSize = std::min(size, kMaxLookupSize); source_->checkNotFailed(); if (hasPendingRequest_.exchange(true)) { @@ -261,7 +367,7 @@ TestIndexSource::ResultIterator::next( } if (executor_ && !asyncResult_.has_value()) { - asyncLookup(size, future); + asyncLookup(lookupSize, future); return std::nullopt; } @@ -274,7 +380,7 @@ TestIndexSource::ResultIterator::next( asyncResult_.reset(); return result; } - return syncLookup(size); + return syncLookup(lookupSize); } void TestIndexSource::ResultIterator::extractLookupColumns( @@ -460,24 +566,33 @@ std::shared_ptr TestIndexConnector::createIndexSource( size_t numJoinKeys, const std::vector& joinConditions, const RowTypePtr& outputType, - const std::shared_ptr& tableHandle, - const std::unordered_map< - std::string, - std::shared_ptr>& columnHandles, + const connector::ConnectorTableHandlePtr& tableHandle, + const connector::ColumnHandleMap& columnHandles, connector::ConnectorQueryCtx* connectorQueryCtx) { VELOX_CHECK_GE(inputType->size(), numJoinKeys + joinConditions.size()); auto testIndexTableHandle = - std::dynamic_pointer_cast(tableHandle); + std::dynamic_pointer_cast(tableHandle); VELOX_CHECK_NOT_NULL(testIndexTableHandle); const auto& indexTable = testIndexTableHandle->indexTable(); auto joinConditionExpr = toJoinConditionExpr(joinConditions, indexTable, inputType, columnHandles); + + std::unordered_map testColumnHandles; + for (const auto& [name, handle] : columnHandles) { + auto testColumnHandle = + std::dynamic_pointer_cast(handle); + VELOX_CHECK_NOT_NULL(testColumnHandle); + + testColumnHandles.emplace(name, testColumnHandle); + } + return std::make_shared( inputType, outputType, numJoinKeys, - std::move(joinConditionExpr), - std::move(testIndexTableHandle), + joinConditionExpr, + testIndexTableHandle, + testColumnHandles, connectorQueryCtx, executor_); } diff --git a/velox/exec/tests/utils/TestIndexStorageConnector.h b/velox/exec/tests/utils/TestIndexStorageConnector.h index 08a1c5d69b94..64f34c025ac7 100644 --- a/velox/exec/tests/utils/TestIndexStorageConnector.h +++ b/velox/exec/tests/utils/TestIndexStorageConnector.h @@ -16,7 +16,6 @@ #pragma once #include "velox/exec/HashTable.h" -#include "velox/exec/OperatorUtils.h" #include "velox/type/Type.h" namespace facebook::velox::exec::test { @@ -38,6 +37,13 @@ struct TestIndexTable { : keyType(std::move(_keyType)), dataType(std::move(_dataType)), table(std::move(_table)) {} + + // Create index table with the given key and value inputs. + static std::shared_ptr create( + size_t numEqualJoinKeys, + const RowVectorPtr& keyData, + const RowVectorPtr& valueData, + velox::memory::MemoryPool& pool); }; // The index table handle which provides the index table for index lookup. @@ -74,16 +80,23 @@ class TestIndexTableHandle : public connector::ConnectorTableHandle { obj["name"] = name(); obj["connectorId"] = connectorId(); obj["asyncLookup"] = asyncLookup_; + // For testing purpose only, we serialize the index table pointer as an + // long integer. + obj["indexTable"] = reinterpret_cast(indexTable_.get()); return obj; } - static std::shared_ptr create( + static std::shared_ptr create( const folly::dynamic& obj, void* context) { // NOTE: this is only for testing purpose so we don't support to serialize // the table. + auto ptr = obj["indexTable"].asInt(); + auto indexTablePtr = reinterpret_cast(ptr); return std::make_shared( - obj["connectorId"].getString(), nullptr, obj["asyncLookup"].asBool()); + obj["connectorId"].getString(), + std::shared_ptr(indexTablePtr, [](TestIndexTable*) {}), + obj["asyncLookup"].asBool()); } static void registerSerDe() { @@ -105,6 +118,39 @@ class TestIndexTableHandle : public connector::ConnectorTableHandle { const bool asyncLookup_; }; +using TestIndexTableHandlePtr = std::shared_ptr; + +class TestIndexColumnHandle : public connector::ColumnHandle { + public: + explicit TestIndexColumnHandle(std::string name) : name_{std::move(name)} {} + + const std::string& name() const override { + return name_; + } + + folly::dynamic serialize() const override { + auto obj = serializeBase("TestIndexColumnHandle"); + obj["columnName"] = name_; + return obj; + } + + static connector::ColumnHandlePtr create(const folly::dynamic& obj) { + auto name = obj["columnName"].asString(); + + return std::make_shared(name); + } + + static void registerSerDe() { + auto& registry = DeserializationRegistryForSharedPtr(); + registry.Register("TestIndexColumnHandle", TestIndexColumnHandle::create); + } + + private: + const std::string name_; +}; + +using TestIndexColumnHandlePtr = std::shared_ptr; + class TestIndexSource : public connector::IndexSource, public std::enable_shared_from_this { public: @@ -113,7 +159,9 @@ class TestIndexSource : public connector::IndexSource, const RowTypePtr& outputType, size_t numEqualJoinKeys, const core::TypedExprPtr& joinConditionExpr, - const std::shared_ptr& tableHandle, + const TestIndexTableHandlePtr& tableHandle, + const std::unordered_map& + columnHandles, connector::ConnectorQueryCtx* connectorQueryCtx, folly::Executor* executor); @@ -146,11 +194,15 @@ class TestIndexSource : public connector::IndexSource, std::unique_ptr lookupResult, folly::Executor* executor); + bool hasNext() override; + std::optional> next( vector_size_t size, ContinueFuture& future) override; private: + static constexpr vector_size_t kMaxLookupSize = 8192; + // Initializes the buffer used to store row pointers or indices for output // match result processing. template @@ -224,14 +276,16 @@ class TestIndexSource : public connector::IndexSource, void checkNotFailed(); // Initialize the output projections for lookup result processing. - void initOutputProjections(); + void initOutputProjections( + const std::unordered_map& + columnHandles); // Initialize the condition filter input type and projections if configured. void initConditionProjections(); void recordCpuTiming(const CpuWallTiming& timing); - const std::shared_ptr tableHandle_; + const TestIndexTableHandlePtr tableHandle_; const RowTypePtr inputType_; const RowTypePtr outputType_; const RowTypePtr keyType_; @@ -274,9 +328,8 @@ class TestIndexConnector : public connector::Connector { std::unique_ptr createDataSource( const RowTypePtr&, - const std::shared_ptr&, - const std:: - unordered_map>&, + const connector::ConnectorTableHandlePtr&, + const connector::ColumnHandleMap&, connector::ConnectorQueryCtx*) override { VELOX_UNSUPPORTED("{} not implemented", __FUNCTION__); } @@ -286,15 +339,13 @@ class TestIndexConnector : public connector::Connector { size_t numJoinKeys, const std::vector& joinConditions, const RowTypePtr& outputType, - const std::shared_ptr& tableHandle, - const std::unordered_map< - std::string, - std::shared_ptr>& columnHandles, + const connector::ConnectorTableHandlePtr& tableHandle, + const connector::ColumnHandleMap& columnHandles, connector::ConnectorQueryCtx* connectorQueryCtx) override; std::unique_ptr createDataSink( RowTypePtr, - std::shared_ptr, + connector::ConnectorInsertTableHandlePtr, connector::ConnectorQueryCtx*, connector::CommitStrategy) override { VELOX_UNSUPPORTED("{} not implemented", __FUNCTION__); @@ -316,5 +367,12 @@ class TestIndexConnectorFactory : public connector::ConnectorFactory { folly::Executor* cpuExecutor) override { return std::make_shared(id, properties, cpuExecutor); } + + static void registerConnector(folly::CPUThreadPoolExecutor* cpuExecutor) { + TestIndexConnectorFactory factory; + std::shared_ptr connector = + factory.newConnector(kTestIndexConnectorName, {}, nullptr, cpuExecutor); + connector::registerConnector(connector); + } }; } // namespace facebook::velox::exec::test diff --git a/velox/experimental/breeze/CMakeLists.txt b/velox/experimental/breeze/CMakeLists.txt index 122f0c995b04..90001307ee8a 100644 --- a/velox/experimental/breeze/CMakeLists.txt +++ b/velox/experimental/breeze/CMakeLists.txt @@ -42,9 +42,7 @@ option(BUILD_TRACING "Build tracing." ON) option(BUILD_GENERATE_TEST_FIXTURES "Generate test fixtures at build time." ON) if(NOT DEFINED PERFTEST_EXT_TYPES) - set(PERFTEST_EXT_TYPES - 0 - CACHE STRING "Extended test types for perf tests") + set(PERFTEST_EXT_TYPES 0 CACHE STRING "Extended test types for perf tests") endif() if(NOT DEFINED VELOX_PROJECT_SOURCE_DIR) @@ -56,7 +54,8 @@ list( CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake" "${VELOX_PROJECT_SOURCE_DIR}/CMake" - "${VELOX_PROJECT_SOURCE_DIR}/CMake/third-party") + "${VELOX_PROJECT_SOURCE_DIR}/CMake/third-party" +) # Include Velox ThirdPartyToolchain dependencies macros include(ResolveDependency) diff --git a/velox/experimental/breeze/breeze/platforms/cuda.cuh b/velox/experimental/breeze/breeze/platforms/cuda.cuh index 661e0743b460..02378f9af452 100644 --- a/velox/experimental/breeze/breeze/platforms/cuda.cuh +++ b/velox/experimental/breeze/breeze/platforms/cuda.cuh @@ -371,12 +371,12 @@ CudaSpecialization::atomic_add(address.data()); + *reinterpret_cast(address.data()); unsigned long long assumed; do { assumed = old; old = atomicCAS( - reinterpret_cast(address.data()), assumed, + reinterpret_cast(address.data()), assumed, __double_as_longlong(value + __longlong_as_double(assumed))); } while (assumed != old); @@ -394,9 +394,9 @@ CudaSpecialization::atomic_add(address.data()), - *reinterpret_cast(&value)); - return *reinterpret_cast(&result); + atomicAdd(reinterpret_cast(address.data()), + *reinterpret_cast(&value)); + return *reinterpret_cast(&result); } // specialization for T=Slice @@ -409,10 +409,10 @@ __device__ __forceinline__ void CudaSpecialization::atomic_min< static_assert(sizeof(float) == sizeof(unsigned), "unexpected type sizes"); float current = atomic_load(address); while (current > value) { - unsigned old = atomicCAS(reinterpret_cast(address.data()), - *reinterpret_cast(¤t), - *reinterpret_cast(&value)); - current = *reinterpret_cast(&old); + unsigned old = atomicCAS(reinterpret_cast(address.data()), + *reinterpret_cast(¤t), + *reinterpret_cast(&value)); + current = *reinterpret_cast(&old); if (current == value) { break; } @@ -429,10 +429,10 @@ __device__ __forceinline__ void CudaSpecialization::atomic_max< static_assert(sizeof(float) == sizeof(unsigned), "unexpected type sizes"); float current = atomic_load(address); while (current < value) { - unsigned old = atomicCAS(reinterpret_cast(address.data()), - *reinterpret_cast(¤t), - *reinterpret_cast(&value)); - current = *reinterpret_cast(&old); + unsigned old = atomicCAS(reinterpret_cast(address.data()), + *reinterpret_cast(¤t), + *reinterpret_cast(&value)); + current = *reinterpret_cast(&old); if (current == value) { break; } @@ -452,10 +452,10 @@ CudaSpecialization::atomic_min value) { unsigned long long old = - atomicCAS(reinterpret_cast(address.data()), - *reinterpret_cast(¤t), - *reinterpret_cast(&value)); - current = *reinterpret_cast(&old); + atomicCAS(reinterpret_cast(address.data()), + *reinterpret_cast(¤t), + *reinterpret_cast(&value)); + current = *reinterpret_cast(&old); if (current == value) { break; } @@ -475,10 +475,10 @@ CudaSpecialization::atomic_max(address.data()), - *reinterpret_cast(¤t), - *reinterpret_cast(&value)); - current = *reinterpret_cast(&old); + atomicCAS(reinterpret_cast(address.data()), + *reinterpret_cast(¤t), + *reinterpret_cast(&value)); + current = *reinterpret_cast(&old); if (current == value) { break; } @@ -496,10 +496,10 @@ __device__ __forceinline__ long long CudaSpecialization::atomic_cas< address, long long compare, long long value) { unsigned long long old = - atomicCAS(reinterpret_cast(address.data()), - *reinterpret_cast(&compare), - *reinterpret_cast(&value)); - return *reinterpret_cast(&old); + atomicCAS(reinterpret_cast(address.data()), + *reinterpret_cast(&compare), + *reinterpret_cast(&value)); + return *reinterpret_cast(&old); } // specialization for T=Slice @@ -518,9 +518,9 @@ __device__ __forceinline__ long long CudaSpecialization::atomic_cas< unsigned long long>::pointer_type; unsigned long long old = atomicCAS(reinterpret_cast(address.data()), - *reinterpret_cast(&compare), - *reinterpret_cast(&value)); - return *reinterpret_cast(&old); + *reinterpret_cast(&compare), + *reinterpret_cast(&value)); + return *reinterpret_cast(&old); } // specialization for T=Slice @@ -532,10 +532,10 @@ __device__ __forceinline__ float CudaSpecialization::atomic_cas< address, float compare, float value) { static_assert(sizeof(float) == sizeof(unsigned), "unexpected type sizes"); - unsigned old = atomicCAS(reinterpret_cast(address.data()), - *reinterpret_cast(&compare), - *reinterpret_cast(&value)); - return *reinterpret_cast(&old); + unsigned old = atomicCAS(reinterpret_cast(address.data()), + *reinterpret_cast(&compare), + *reinterpret_cast(&value)); + return *reinterpret_cast(&old); } #if __CUDA_ARCH__ >= 800 diff --git a/velox/experimental/breeze/breeze/platforms/hip.hpp b/velox/experimental/breeze/breeze/platforms/hip.hpp index 2ddef322d5e4..9f9b38d19f76 100644 --- a/velox/experimental/breeze/breeze/platforms/hip.hpp +++ b/velox/experimental/breeze/breeze/platforms/hip.hpp @@ -70,11 +70,11 @@ struct HipPlatform { } template __device__ __forceinline__ T atomic_load(SliceT address) { - return *reinterpret_cast(address.data()); + return *reinterpret_cast(address.data()); } template __device__ __forceinline__ void atomic_store(SliceT address, T value) { - *reinterpret_cast(address.data()) = value; + *reinterpret_cast(address.data()) = value; } template __device__ __forceinline__ T atomic_cas(SliceT address, T compare, T value) { diff --git a/velox/experimental/breeze/breeze/platforms/metal.h b/velox/experimental/breeze/breeze/platforms/metal.h index c1a3e28882fc..498fa31858c2 100644 --- a/velox/experimental/breeze/breeze/platforms/metal.h +++ b/velox/experimental/breeze/breeze/platforms/metal.h @@ -78,16 +78,15 @@ struct MetalPlatform { MetalSpecialization::atomic_store(address, value); } template