-
Notifications
You must be signed in to change notification settings - Fork 46
add onednn w8a16 gemm #24
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,3 @@ | ||
# Vendored oneDNN sources; fetched on demand via `git submodule update --init`
# during the CMake configure step and linked statically into the extension.
[submodule "third_party/oneDNN"]
	path = third_party/oneDNN
	url = https://github.com/oneapi-src/oneDNN.git
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,6 +20,8 @@ message(STATUS "Target device: ${VLLM_TARGET_DEVICE}") | |
|
|
||
| include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake) | ||
|
|
||
| list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/Modules) | ||
|
|
||
| # Suppress potential warnings about unused manually-specified variables | ||
| set(ignoreMe "${VLLM_PYTHON_PATH}") | ||
|
|
||
|
|
@@ -66,6 +68,7 @@ append_cmake_prefix_path("torch" "torch.utils.cmake_prefix_path") | |
| # Import torch cmake configuration. | ||
| find_package(Torch REQUIRED) | ||
|
|
||
| find_package(oneDNN QUIET) | ||
|
|
||
| # | ||
| # Forward the non-CUDA device extensions to external CMake scripts. | ||
|
|
@@ -191,8 +194,10 @@ if(VLLM_GPU_LANG STREQUAL "SYCL") | |
| endif() | ||
|
|
||
| if(ONEDNN_FOUND) | ||
| set(_ONEDNN_SRC) | ||
| file(GLOB _ONEDNN_SRC csrc/xpu/onednn/*.cpp) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How about *.h? |
||
| list(APPEND VLLM_EXT_XPU_SRC | ||
| "csrc/xpu/onednn/*.cpp" | ||
| ${_ONEDNN_SRC} | ||
| ) | ||
| include_directories(${ONEDNN_INCLUDE_DIR}) | ||
| link_libraries(${ONEDNN_LIBRARY}) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -43,3 +43,19 @@ VLLM_TARGET_DEVICE=xpu python3 setup.py bdist_wheel | |
|
|
||
| ### how to use in vLLM | ||
|
||
| Please refer to the temporary branch https://github.com/jikunshang/vllm/tree/xpu_kernel to install and test vLLM; it replaces the `rms_norm` kernel from IPEX with the implementation from vllm-xpu-kernels. | ||
|
|
||
| ### Why Static Linking DNNL Instead of Shared Linking? | ||
|
|
||
| We chose to **statically link oneDNN (DNNL)** rather than using it as a shared library for the following reasons: | ||
|
|
||
| #### 1. **Version Compatibility** | ||
|
|
||
| Static linking ensures our application always uses the exact version of DNNL. With shared libraries, there's a risk that system-installed versions might be incompatible or introduce subtle bugs due to API/ABI changes. | ||
|
|
||
| #### 2. **Performance Consistency** | ||
|
|
||
| By linking statically, we avoid potential performance variability introduced by different builds or configurations of DNNL that might be present on the host system. | ||
|
|
||
| #### 3. **Avoiding Runtime Errors** | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe one more reason: torch-xpu also use static link. cc @rogerxfeng8
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. updated. |
||
|
|
||
| Using shared libraries requires correct paths and environment setup (`LD_LIBRARY_PATH` on Linux). Static linking avoids issues where DNNL cannot be found or loaded at runtime. | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,71 @@ | ||
# - Try to find oneDNN
#
# The following are set after configuration is done:
#  ONEDNN_FOUND       : set to true if oneDNN is found.
#  ONEDNN_INCLUDE_DIR : path to oneDNN include dir.
#  ONEDNN_LIBRARY     : list of libraries for oneDNN
#
IF(NOT ONEDNN_FOUND)
  SET(ONEDNN_FOUND OFF)

  SET(ONEDNN_LIBRARY)
  SET(ONEDNN_INCLUDE_DIR)
  SET(DNNL_INCLUDES)

  SET(THIRD_PARTY_DIR "${PROJECT_SOURCE_DIR}/third_party")
  SET(ONEDNN_DIR "oneDNN")
  SET(ONEDNN_ROOT "${THIRD_PARTY_DIR}/${ONEDNN_DIR}")

  # NOTE: the NAMES keyword is required when listing more than one header.
  # Without it, find_path() takes only the first free argument as a file
  # name and silently treats "dnnl.h" as a search path.
  FIND_PATH(ONEDNN_INCLUDE_DIR NAMES dnnl.hpp dnnl.h
            PATHS ${ONEDNN_ROOT} PATH_SUFFIXES include NO_DEFAULT_PATH)
  IF(NOT ONEDNN_INCLUDE_DIR)
    # Sources not checked out yet: fetch the submodule, then retry the search.
    FIND_PACKAGE(Git)
    IF(NOT Git_FOUND)
      MESSAGE(FATAL_ERROR "Can not find Git executable!")
    ENDIF()
    EXECUTE_PROCESS(
      COMMAND ${GIT_EXECUTABLE} submodule update --init ${ONEDNN_DIR}
      WORKING_DIRECTORY ${THIRD_PARTY_DIR} COMMAND_ERROR_IS_FATAL ANY)
    FIND_PATH(ONEDNN_INCLUDE_DIR NAMES dnnl.hpp dnnl.h
              PATHS ${ONEDNN_ROOT} PATH_SUFFIXES include NO_DEFAULT_PATH)
  ENDIF()

  IF(NOT ONEDNN_INCLUDE_DIR)
    MESSAGE(FATAL_ERROR "oneDNN source files not found!")
  ENDIF()

  # Build configuration for the vendored oneDNN: static library, SYCL GPU
  # runtime, no tests/examples.
  SET(DNNL_ENABLE_PRIMITIVE_CACHE TRUE CACHE BOOL "oneDNN sycl primitive cache" FORCE)
  SET(DNNL_LIBRARY_TYPE STATIC CACHE STRING "" FORCE)
  # THREADPOOL must be set explicitly: oneDNN otherwise defaults the CPU
  # runtime to OMP, which adds a runtime dependency on libiomp5.so.
  SET(DNNL_CPU_RUNTIME "THREADPOOL" CACHE STRING "oneDNN cpu backend" FORCE)
  SET(DNNL_GPU_RUNTIME "SYCL" CACHE STRING "oneDNN gpu backend" FORCE)
  SET(DNNL_BUILD_TESTS FALSE CACHE BOOL "build with oneDNN tests" FORCE)
  SET(DNNL_BUILD_EXAMPLES FALSE CACHE BOOL "build with oneDNN examples" FORCE)
  SET(DNNL_ENABLE_CONCURRENT_EXEC TRUE CACHE BOOL "multi-thread primitive execution" FORCE)
  SET(DNNL_EXPERIMENTAL TRUE CACHE BOOL "use one pass for oneDNN BatchNorm" FORCE)

  ADD_SUBDIRECTORY(${ONEDNN_ROOT} oneDNN EXCLUDE_FROM_ALL)
  SET(ONEDNN_LIBRARY ${DNNL_LIBRARY_NAME})
  IF(NOT TARGET ${ONEDNN_LIBRARY})
    MESSAGE(FATAL_ERROR "Failed to include oneDNN target")
  ENDIF()

  # Silence known-noisy warnings coming from oneDNN's own sources.
  IF(NOT APPLE AND CMAKE_COMPILER_IS_GNUCC)
    TARGET_COMPILE_OPTIONS(${ONEDNN_LIBRARY} PRIVATE -Wno-uninitialized)
    TARGET_COMPILE_OPTIONS(${ONEDNN_LIBRARY} PRIVATE -Wno-strict-overflow)
    TARGET_COMPILE_OPTIONS(${ONEDNN_LIBRARY} PRIVATE -Wno-error=strict-overflow)
  ENDIF()

  TARGET_COMPILE_OPTIONS(${ONEDNN_LIBRARY} PRIVATE -Wno-tautological-compare)
  GET_TARGET_PROPERTY(DNNL_INCLUDES ${ONEDNN_LIBRARY} INCLUDE_DIRECTORIES)
  TARGET_LINK_LIBRARIES(${ONEDNN_LIBRARY} PRIVATE ze_loader)
  LIST(APPEND ONEDNN_INCLUDE_DIR ${DNNL_INCLUDES})

  # Upper level targets should not load header files from oneDNN's third party.
  LIST(FILTER ONEDNN_INCLUDE_DIR EXCLUDE REGEX
       ".*third_party/oneDNN/third_party.*")

  SET(ONEDNN_FOUND ON)
  MESSAGE(STATUS "Found oneDNN: TRUE")
ENDIF()
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,54 @@ | ||
| #include <vector> | ||
jikunshang marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| #include "fp8_gemm_w8a16.h" | ||
|
|
||
| torch::Tensor fp8_gemm_w8a16(const torch::Tensor& A, const torch::Tensor& B, | ||
| bool trans_B, | ||
| const std::optional<torch::Tensor>& B_scale_, | ||
| const std::optional<torch::Tensor>& bias_) { | ||
| TORCH_CHECK(A.dim() == 2 || A.dim() == 3, | ||
| "fp8_gemm_w8a16 only support 2D and 3D inputs!\n"); | ||
| TORCH_CHECK(B.dim() == 2, "fp8_gemm_w8a16 only support 2D weights!\n"); | ||
|
|
||
| std::vector<int64_t> result_shape; | ||
| if (A.dim() == 2) { | ||
| if (trans_B) { | ||
| result_shape = {A.size(0), B.size(0)}; | ||
| } else { | ||
| result_shape = {A.size(0), B.size(1)}; | ||
| } | ||
| // src{m, k}, wei{k, n}, bias{n}, dst{m, n} | ||
| } else { | ||
| if (trans_B) { | ||
| result_shape = {A.size(0), A.size(1), B.size(0)}; | ||
| } else { | ||
| result_shape = {A.size(0), A.size(1), B.size(1)}; | ||
| } | ||
| // src{b, m, k}, wei{k, n}, bias{n}, dst{b, m, n} | ||
| } | ||
|
|
||
| // deal with input shape [m, b, k] stride [k, m * k, 1] | ||
| auto k = A.size(A.dim() - 1); | ||
| auto n = result_shape.back(); | ||
| auto res_stride = A.strides().vec(); | ||
| for (int i = 0; i < res_stride.size() - 1; i++) { | ||
| res_stride[i] = res_stride[i] / k * n; | ||
| } | ||
|
|
||
| torch::Tensor result = | ||
| at::empty_strided(result_shape, res_stride, A.options()); | ||
|
|
||
| // check if nt format | ||
| bool is_nt = true; | ||
| if (trans_B) { | ||
| is_nt = B.strides()[B.dim() - 1] == 1; | ||
| } else { | ||
| is_nt = B.strides()[B.dim() - 2] == 1; | ||
| } | ||
|
|
||
| torch::Tensor B_scale = B_scale_.has_value() | ||
| ? B_scale_.value() | ||
| : at::ones({1}, B.options().dtype(A.dtype())); | ||
|
|
||
| oneDNN::dnnl_matmul_w8a16_fp8(result, A, B, is_nt, bias_, B_scale); | ||
| return result; | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,133 @@ | ||
| #pragma once | ||
|
|
||
| #include <c10/xpu/XPUStream.h> | ||
| #include <dnnl.hpp> | ||
| #include <torch/torch.h> | ||
|
|
||
| #include "onednn_ext.h" | ||
|
|
||
| namespace oneDNN { | ||
|
|
||
| using bias_type_t = at::native::onednn::bias_type_t; | ||
| using trans_type_t = at::native::onednn::trans_type_t; | ||
| using GpuStreamManager = at::native::onednn::GpuStreamManager; | ||
| using GpuEngineManager = at::native::onednn::GpuEngineManager; | ||
|
|
||
| static inline void dnnl_matmul_w8a16_fp8( | ||
| torch::Tensor& result, const torch::Tensor& mat1, const torch::Tensor& mat2, | ||
| bool trans_b, const std::optional<torch::Tensor>& bias, | ||
| const torch::Tensor& m2_sc, const int64_t group_size = 0) { | ||
| TORCH_CHECK(mat2.scalar_type() == at::ScalarType::Float8_e5m2 || | ||
| mat2.scalar_type() == at::ScalarType::Float8_e4m3fn, | ||
| "weight must be f8_e5m2 or f8_e4m3fn for fp8 matmul"); | ||
| auto src_sz = mat1.sizes(); | ||
| auto o_sz = result.sizes(); | ||
|
|
||
| const int m = std::reduce(src_sz.begin(), src_sz.end() - 1, 1, | ||
| std::multiplies<int64_t>()); | ||
| const int n = o_sz.back(); // presume channel last format | ||
| const int k = *(src_sz.end() - 1); | ||
|
|
||
| // get joint dtypes | ||
| joint_dtypes_t jd; | ||
| auto in_dtype = mat1.scalar_type(); | ||
| auto wei_dtype = mat2.scalar_type(); | ||
| if (in_dtype == at::ScalarType::Half) { | ||
| jd = wei_dtype == at::ScalarType::Float8_e5m2 ? joint_dtypes_t::f16_f8_e5m2 | ||
| : joint_dtypes_t::f16_f8_e4m3; | ||
| } else if (in_dtype == at::ScalarType::BFloat16) { | ||
| jd = wei_dtype == at::ScalarType::Float8_e5m2 | ||
| ? joint_dtypes_t::bf16_f8_e5m2 | ||
| : joint_dtypes_t::bf16_f8_e4m3; | ||
| } else { | ||
| TORCH_INTERNAL_ASSERT( | ||
| false, "Unsupported data type for fp8 matmul: ", mat1.scalar_type()); | ||
| } | ||
|
|
||
| // get bias type | ||
| bias_type_t b_type; | ||
| if (bias.has_value() && bias.value().defined()) { | ||
| auto& b = bias.value(); | ||
| const auto nuelm = b.numel(); | ||
| if (nuelm == 1) { | ||
| b_type = bias_type_t::scalar; | ||
| } else if (nuelm == m * n) { | ||
| b_type = bias_type_t::mn; | ||
| } else if (b.size(b.dim() - 1) == n && nuelm == n) { | ||
| b_type = bias_type_t::n; | ||
| } else if (b.size(b.dim() - 1) == 1 && nuelm == m) { | ||
| b_type = bias_type_t::m; | ||
| } else if (nuelm == 0) { | ||
| b_type = bias_type_t::none; | ||
| } else { | ||
| TORCH_CHECK(0, "unsupported bias dim in matmul ...", b.sizes()); | ||
| } | ||
| } else { | ||
| b_type = bias_type_t::none; | ||
| } | ||
|
|
||
| trans_type_t tt = trans_type_t::nn; | ||
| if (trans_b) { | ||
| // transpose mat2 | ||
| tt = trans_type_t::nt; | ||
| } | ||
|
|
||
| // get lda ldb and ldc | ||
| auto mat1_strides = mat1.strides(); | ||
| int64_t leading_dim = -1; | ||
| if (mat1.dim() == 2) { | ||
| leading_dim = 0; | ||
| } else if (mat1.dim() == 3) { | ||
| leading_dim = mat1_strides[0] < mat1_strides[1] ? 0 : 1; | ||
| } else { | ||
| TORCH_CHECK(false, | ||
| "Unsupported input dimension for fp8 matmul: ", mat1.dim()); | ||
| } | ||
| int64_t lda = mat1_strides[leading_dim]; | ||
| int64_t ldb = mat2.strides()[mat2.dim() - 1] == 1 | ||
| ? mat2.strides()[mat2.dim() - 2] | ||
| : mat2.strides()[mat2.dim() - 1]; | ||
| int64_t ldc = result.strides()[leading_dim]; | ||
|
|
||
| auto f_attr = [&](dnnl::primitive_attr& pattr) { | ||
| pattr.set_scratchpad_mode(dnnl::scratchpad_mode::user); | ||
| }; | ||
|
|
||
| int arg_off = 0; | ||
|
|
||
| // ************************************************************ | ||
| // get device, engine, stream | ||
| const int dev_id = c10::xpu::getCurrentXPUStream().device_index(); | ||
| at::Device curDevice = at::Device(at::kXPU, dev_id); | ||
| auto engine = GpuEngineManager::Instance().get_engine(curDevice); | ||
|
|
||
| auto& matmul_ext = matmul_primitive_create_and_cache( | ||
| jd, tt, b_type, m, n, k, lda, ldb, ldc, dev_id, f_attr, group_size); | ||
|
|
||
| matmul_ext.set_attribute(arg_off++, DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, | ||
| m2_sc.data_ptr(), [&]() { | ||
| return at::native::onednn::make_onednn_memory( | ||
| get_onednn_md(m2_sc), engine, | ||
| m2_sc.data_ptr()); | ||
| }); | ||
|
|
||
| std::vector<std::pair<int, void*>> arg_handles; | ||
| arg_handles.reserve(8); | ||
|
|
||
| arg_handles.emplace_back(DNNL_ARG_SRC, mat1.data_ptr()); | ||
| arg_handles.emplace_back(DNNL_ARG_WEIGHTS, mat2.data_ptr()); | ||
| arg_handles.emplace_back(DNNL_ARG_DST, result.data_ptr()); | ||
| if (b_type != bias_type_t::none) { | ||
| arg_handles.emplace_back(DNNL_ARG_BIAS, bias.value().data_ptr()); | ||
| } | ||
|
|
||
| int scratchpad_size = matmul_ext.get_scratchpad_size(); | ||
| torch::Tensor scratchpad_tensor = at::empty( | ||
| {scratchpad_size}, mat1.options().dtype(at::kByte), c10::nullopt); | ||
| arg_handles.emplace_back(DNNL_ARG_SCRATCHPAD, scratchpad_tensor.data_ptr()); | ||
|
|
||
| auto strm = GpuStreamManager::Instance().get_stream(); | ||
jikunshang marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| DPCPP_ONEDNN_EXEC_WITH_ARGHANDLES(matmul_ext, strm, engine, arg_handles, | ||
| arg_off); | ||
| } | ||
| } // namespace oneDNN | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
REQUIRED is preferred which stops with error.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Got it, I will change this.
BTW, we also have env ONEDNN_FOUND to detect if the onednn is found.