10 changes: 10 additions & 0 deletions external/CMakeLists.txt
@@ -17,6 +17,11 @@ set(KINETO_URL
use_mirror(VARIABLE KINETO_URL URL ${KINETO_URL})
set(KINETO_MD5 f9b550591b3899fb267270c19484933f)

set(CUDNN_FRONTEND_URL
https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v0.9.1.zip)
use_mirror(VARIABLE CUDNN_FRONTEND_URL URL ${CUDNN_FRONTEND_URL})
set(CUDNN_FRONTEND_MD5 0d28ff6aaa984dac4f7d16acfc48de72)

set(EXTERNAL_TARGETS)

if(WITH_TBB) # set(WITH_${threading_runtime_item} ON) in threading.cmake
@@ -33,6 +38,11 @@ list(APPEND EXTERNAL_TARGETS fmt)
add_subdirectory(kineto)
list(APPEND EXTERNAL_TARGETS kineto)

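# cudnn-frontend is only used by the CUDA backend.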
if(BUILD_CUDA)
add_subdirectory(cudnn_frontend)
list(APPEND EXTERNAL_TARGETS cudnn_frontend)
endif()

mark_targets_as_system(${EXTERNAL_TARGETS})

set_property(GLOBAL PROPERTY EXTERNAL_TARGETS ${EXTERNAL_TARGETS})
16 changes: 16 additions & 0 deletions external/cudnn_frontend/CMakeLists.txt
@@ -0,0 +1,16 @@
include(FetchContent)
FetchContent_Declare(
cudnn_frontend
URL ${CUDNN_FRONTEND_URL}
URL_HASH MD5=${CUDNN_FRONTEND_MD5}
)
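# Do not build the cudnn-frontend sample programs.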
set(CUDNN_FRONTEND_BUILD_SAMPLES OFF)
FetchContent_MakeAvailable(cudnn_frontend)

set(CUDNN_FRONTEND_INSTALL_DIR ${THIRD_PARTY_DIR}/cudnn_frontend)
install(
TARGETS cudnn_frontend
EXPORT oneflow
LIBRARY DESTINATION ${CUDNN_FRONTEND_INSTALL_DIR}/lib
ARCHIVE DESTINATION ${CUDNN_FRONTEND_INSTALL_DIR}/lib
)
258 changes: 258 additions & 0 deletions oneflow/core/device/cudnn_conv_util.cpp
@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
#ifdef WITH_CUDA
#include "oneflow/core/framework/infer_util.h"
#include "oneflow/core/device/cudnn_conv_util.h"
#include "oneflow/core/device/cuda_util.h"
#include "oneflow/core/common/cached_caller.h"
@@ -22,6 +23,7 @@ limitations under the License.
#include "oneflow/core/job/global_for.h"
#include "oneflow/core/job/global_for.h"
#include "oneflow/core/framework/op_kernel.h"
#include "oneflow/core/job/lazy_mode.h"

namespace oneflow {

@@ -82,6 +84,7 @@ perf_t GetBestAlgorithm(const CudnnConvArgs& args, CudnnConvResource* res,
FOR_RANGE(size_t, i, 0, perf_vec.size()) {
// Note: Shouldn't all returned results be successful?
CHECK_EQ(perf_vec[i].status, CUDNN_STATUS_SUCCESS);
// TODO: the workspace size limit may cause the selected algorithm to mismatch PyTorch's result for large tensors
if (perf_vec[i].memory > args.params.max_ws_size) { continue; }
if (args.deterministic && perf_vec[i].determinism == CUDNN_NON_DETERMINISTIC) { continue; }
found_algo_idx = i;
@@ -332,6 +335,106 @@ CudnnConvArgs::CudnnConvArgs(const user_op::KernelComputeContext& ctx, DataType
params.max_ws_size = max_workspace_size;
}

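// Builds a cudnn_frontend tensor descriptor from a tensor's shape and stride;
// `id` is a one-character UID ('x', 'y' or 'w') that is later matched against
// the data pointers in the variant pack, and a 32-byte alignment is assumed.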
cudnn_frontend::Tensor GetTensorDescriptor(const user_op::Tensor* t, const int64_t id) {
auto dim = t->shape_view();
auto stride = t->stride();
return cudnn_frontend::TensorBuilder()
.setDim(dim.size(), dim.data())
.setStride(stride.size(), stride.data())
.setId(id)
.setAlignment(32)
.setDataType(GetCudnnDataType(t->data_type()))
.build();
}

cudnn_frontend::Tensor GetTensorDescriptor(const user_op::TensorDesc& t, const int64_t id) {
auto dim = t.shape();
auto stride = t.stride();
return cudnn_frontend::TensorBuilder()
.setDim(dim.size(), dim.data())
.setStride(stride.size(), stride.data())
.setId(id)
.setAlignment(32)
.setDataType(GetCudnnDataType(t.data_type()))
.build();
}

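// Builds the convolution descriptor from the op attributes. Half and bfloat16
// data use float as the compute type, and the int32 attribute vectors are
// widened to int64 as expected by the frontend builder.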
cudnn_frontend::ConvDesc GetConvDescriptor(const user_op::InferContext& ctx,
cudnnDataType_t data_type) {
if (data_type == CUDNN_DATA_HALF || data_type == CUDNN_DATA_BFLOAT16) {
data_type = CUDNN_DATA_FLOAT;
}

std::vector<int64_t> padding;
const auto& padding_before = ctx.Attr<std::vector<int32_t>>("padding_before");
copy(padding_before.begin(), padding_before.end(), back_inserter(padding));

std::vector<int64_t> stride;
const auto& strides = ctx.Attr<std::vector<int32_t>>("strides");
copy(strides.begin(), strides.end(), back_inserter(stride));

std::vector<int64_t> dilation;
const auto& dilation_rate = ctx.Attr<std::vector<int32_t>>("dilation_rate");
copy(dilation_rate.begin(), dilation_rate.end(), back_inserter(dilation));

uint64_t ndim = stride.size();
return cudnn_frontend::ConvDescBuilder()
.setDataType(data_type)
.setMathMode(CUDNN_CROSS_CORRELATION)
.setNDims(ndim)
.setStrides(ndim, stride.data())
.setPrePadding(ndim, padding.data())
.setPostPadding(ndim, padding.data())
.setDilation(ndim, dilation.data())
.build();
}

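// Same as above, but reads the attributes from a kernel compute context.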
cudnn_frontend::ConvDesc GetConvDescriptor(const user_op::KernelComputeContext& ctx,
cudnnDataType_t data_type) {
if (data_type == CUDNN_DATA_HALF || data_type == CUDNN_DATA_BFLOAT16) {
data_type = CUDNN_DATA_FLOAT;
}

std::vector<int64_t> padding;
const auto& padding_before = ctx.Attr<std::vector<int32_t>>("padding_before");
copy(padding_before.begin(), padding_before.end(), back_inserter(padding));

std::vector<int64_t> stride;
const auto& strides = ctx.Attr<std::vector<int32_t>>("strides");
copy(strides.begin(), strides.end(), back_inserter(stride));

std::vector<int64_t> dilation;
const auto& dilation_rate = ctx.Attr<std::vector<int32_t>>("dilation_rate");
copy(dilation_rate.begin(), dilation_rate.end(), back_inserter(dilation));

uint64_t ndim = stride.size();
return cudnn_frontend::ConvDescBuilder()
.setDataType(data_type)
.setMathMode(CUDNN_CROSS_CORRELATION)
.setNDims(ndim)
.setStrides(ndim, stride.data())
.setPrePadding(ndim, padding.data())
.setPostPadding(ndim, padding.data())
.setDilation(ndim, dilation.data())
.build();
}

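// CudnnConvArgsV8 bundles the frontend descriptors of one convolution; the two
// constructors cover shape inference (TensorDesc) and kernel execution
// (Tensor). beta is fixed to 0 so the output is overwritten, not accumulated.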
CudnnConvArgsV8::CudnnConvArgsV8(const user_op::InferContext& ctx, const user_op::TensorDesc& x,
const user_op::TensorDesc& y, const user_op::TensorDesc& w)
: xdesc(GetTensorDescriptor(x, 'x')),
ydesc(GetTensorDescriptor(y, 'y')),
wdesc(GetTensorDescriptor(w, 'w')),
cdesc(GetConvDescriptor(ctx, GetCudnnDataType(y.data_type()))),
beta(0.0f) {}

CudnnConvArgsV8::CudnnConvArgsV8(const user_op::KernelComputeContext& ctx, const user_op::Tensor* x,
const user_op::Tensor* y, const user_op::Tensor* w)
: xdesc(GetTensorDescriptor(x, 'x')),
ydesc(GetTensorDescriptor(y, 'y')),
wdesc(GetTensorDescriptor(w, 'w')),
cdesc(GetConvDescriptor(ctx, GetCudnnDataType(y->data_type()))),
beta(0.0f) {}

ManagedCudnnConvResource::ManagedCudnnConvResource(const CudnnConvArgs& args)
: handle_(nullptr), x_dptr_(nullptr), w_dptr_(nullptr), y_dptr_(nullptr), ws_dptr_(nullptr) {
x_byte_size_ = ByteSize4Tensor(args.params.x_dims, args.params.x_ndim, args.params.x_data_type);
@@ -424,6 +527,161 @@ cudnnStatus_t GetCudnnConvWorkspaceSize(const CudnnConvArgs& args, CudnnConvReso
args.wdesc.Get(), algo, sz);
}

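// Wraps a single convolution operation (forward, backward-data or
// backward-filter, selected by `desc`) into a one-node operation graph.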
cudnn_frontend::OperationGraph BuildConvOpGraph(const cudnnHandle_t handle,
const cudnnBackendDescriptorType_t desc,
const cudnn_frontend::Tensor& xdesc,
const cudnn_frontend::Tensor& ydesc,
const cudnn_frontend::Tensor& wdesc,
const cudnn_frontend::ConvDesc& cdesc, float beta) {
auto conv_op = cudnn_frontend::OperationBuilder(desc)
.setxDesc(xdesc)
.setyDesc(ydesc)
.setwDesc(wdesc)
.setcDesc(cdesc)
.setBeta(beta)
.build();
std::array<cudnn_frontend::Operation const*, 1> ops = {&conv_op};
auto op_graph = cudnn_frontend::OperationGraphBuilder()
.setHandle(handle)
.setOperationGraph(ops.size(), ops.data())
.build();
return op_graph;
}

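// Keeps only engine configs that do not down-convert inputs and, when
// determinism is required, that are not marked non-deterministic.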
void FilterEngineConfigs(cudnn_frontend::EngineConfigList& from,
cudnn_frontend::EngineConfigList& to, bool deterministic) {
auto filter = [=](cudnnBackendDescriptor_t c) {
if (deterministic) {
if (cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_NONDETERMINISTIC>(c)) {
return true;
}
}
if (cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS>(c)) {
return true;
}
return false;
};
cudnn_frontend::filter(from, to, filter);
}

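// Returns the engine config generators: a heuristics source (HEUR_MODE_B when
// ONEFLOW_CUDNN_USE_HEURISTIC_MODE_B is set, otherwise HEUR_MODE_INSTANT)
// followed by the fallback list, both filtered through FilterEngineConfigs.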
std::vector<cudnn_frontend::GeneratorSource> GetGeneratorSources(
const cudnnBackendDescriptorType_t desc) {
bool deterministic = Singleton<ResourceDesc, ForSession>::Get()
->resource()
.cudnn_conf()
.cudnn_conv_use_deterministic_algo_only();
bool heuristic = ParseBooleanFromEnv("ONEFLOW_CUDNN_USE_HEURISTIC_MODE_B", false);
auto heur_mode = heuristic ? CUDNN_HEUR_MODE_B : CUDNN_HEUR_MODE_INSTANT;
// Method for engine config generator based on heuristics
const auto heurgen_method =
[deterministic,
heur_mode](cudnn_frontend::OperationGraph& opGraph) -> cudnn_frontend::EngineConfigList {
auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
.setOperationGraph(opGraph)
.setHeurMode(heur_mode)
.build();
auto& engine_configs = heuristics.getEngineConfig(heuristics.getEngineConfigCount());
cudnn_frontend::EngineConfigList filtered_configs;
FilterEngineConfigs(engine_configs, filtered_configs, deterministic);
return filtered_configs;
};
// Method for engine config generator based on fallback list
const auto fallback_method =
[desc,
deterministic](cudnn_frontend::OperationGraph& opGraph) -> cudnn_frontend::EngineConfigList {
auto fallback = cudnn_frontend::EngineFallbackListBuilder()
.setOperationGraph(opGraph)
.setOperation(desc)
.build();
auto& fallback_list = fallback.getFallbackList();
cudnn_frontend::EngineConfigList filtered_configs;
FilterEngineConfigs(fallback_list, filtered_configs, deterministic);
return filtered_configs;
};
std::vector<cudnn_frontend::GeneratorSource> sources = {heurgen_method, fallback_method};
return sources;
}

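// Builds the operation graph for this convolution and returns its filtered
// engine configs; the graph tag is written to `tag` so callers can build
// execution plans from the configs later.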
cudnn_frontend::EngineConfigList CudnnFrontendGetConfigs(const cudnnHandle_t handle,
const cudnnBackendDescriptorType_t desc,
const cudnn_frontend::Tensor& xdesc,
const cudnn_frontend::Tensor& ydesc,
const cudnn_frontend::Tensor& wdesc,
const cudnn_frontend::ConvDesc& cdesc,
float beta, std::string& tag) {
auto op_graph = BuildConvOpGraph(handle, desc, xdesc, ydesc, wdesc, cdesc, beta);
tag = op_graph.getTag();
auto sources = GetGeneratorSources(desc);
cudnn_frontend::EngineConfigGenerator generator(sources.size(), sources.data());
auto configs = generator.generate_engine_config(op_graph);
return configs;
}

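// Checks an execution plan's tag against an errata (deny-list) JSON; when no
// errata file is configured this returns false and no plan is excluded.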
bool PlanErrataException(const cudnnHandle_t handle, const std::string& executionPlanTag) {
static nlohmann::json errata_json_handle;
static bool has_json = cudnn_frontend::load_from_config(errata_json_handle, "");
if (!has_json) {
return false;
} else {
return cudnn_frontend::check_errata(errata_json_handle, executionPlanTag, handle,
[]() { return true; });
}
}

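// Executes one plan, binding the data pointers to the UIDs 'x', 'y' and 'w'
// used when the tensor descriptors were built, plus the workspace buffer.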
void RunConvPlan(const cudnnHandle_t handle, user_op::Tensor* x, user_op::Tensor* y,
user_op::Tensor* w, user_op::Tensor* buf,
const cudnn_frontend::ExecutionPlan& plan) {
void* data[] = {x->mut_dptr(), y->mut_dptr(), w->mut_dptr()};
int64_t ids[] = {'x', 'y', 'w'};
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(buf->mut_dptr())
.setDataPointers(3, data)
.setUids(3, ids)
.build();
OF_CUDNN_CHECK(cudnnBackendExecute(handle, plan.get_raw_desc(), variantPack.get_raw_desc()));
}

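// Tries the filtered configs in order and runs the first plan that builds and
// is not excluded by the errata list; configs whose plan construction throws
// cudnn_frontend::cudnnException are skipped.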
void TryConfigs(const cudnnHandle_t handle, user_op::Tensor* x, user_op::Tensor* y,
user_op::Tensor* w, user_op::Tensor* buf, cudnn_frontend::EngineConfigList& configs,
const std::string& tag) {
for (auto& config : configs) {
try {
auto plan = cudnn_frontend::ExecutionPlanBuilder()
.setHandle(handle)
.setEngineConfig(config, tag)
.build();
if (PlanErrataException(handle, plan.getTag())) { continue; }
RunConvPlan(handle, x, y, w, buf, plan);
return;
} catch (cudnn_frontend::cudnnException& e) {}
}
}

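// Runs a convolution through the cuDNN v8 graph API: generates the engine
// configs for `args` and executes the first working plan, using `b` as the
// workspace buffer.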
void CudnnFrontendRunConv(const cudnnHandle_t handle, const cudnnBackendDescriptorType_t desc,
user_op::Tensor* x, user_op::Tensor* y, user_op::Tensor* w,
user_op::Tensor* b, const CudnnConvArgsV8& args) {
std::string tag;
auto configs = CudnnFrontendGetConfigs(handle, desc, args.xdesc, args.ydesc, args.wdesc,
args.cdesc, args.beta, tag);
TryConfigs(handle, x, y, w, b, configs, tag);
}

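// Returns the workspace size of the first buildable plan that passes the
// errata check, falling back to 1 byte so callers always reserve a non-empty
// buffer.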
size_t GetCudnnConvWorkspaceSizeV8(const cudnnHandle_t handle,
cudnn_frontend::EngineConfigList& configs,
const std::string& tag) {
for (auto& config : configs) {
try {
auto plan = cudnn_frontend::ExecutionPlanBuilder()
.setHandle(handle)
.setEngineConfig(config, tag)
.build();
if (PlanErrataException(handle, plan.getTag())) { continue; }
if (plan.getWorkspaceSize() > 0L) { return plan.getWorkspaceSize(); }
} catch (cudnn_frontend::cudnnException& e) {}
}
return 1L;
}

template<>
struct CudnnConvAlgorithmSearch<cudnnConvolutionFwdAlgoPerf_t> {
using perf_t = cudnnConvolutionFwdAlgoPerf_t;
34 changes: 34 additions & 0 deletions oneflow/core/device/cudnn_conv_util.h
@@ -18,8 +18,12 @@ limitations under the License.

#ifdef WITH_CUDA

#include "cudnn_frontend.h"
#include "cudnn_frontend_EngineConfigGenerator.h"
#include "oneflow/core/common/tensor_desc.h"
#include "oneflow/core/device/cudnn_util.h"
#include "oneflow/core/common/protobuf.h"
#include "oneflow/core/framework/user_op_tensor.h"

namespace oneflow {

@@ -93,6 +97,20 @@ struct CudnnConvArgs final {
bool enable_pseudo_half);
};

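// Descriptor bundle for convolutions run through the cuDNN v8 graph API
// (cudnn-frontend).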
struct CudnnConvArgsV8 final {
cudnn_frontend::Tensor xdesc;
cudnn_frontend::Tensor ydesc;
cudnn_frontend::Tensor wdesc;
cudnn_frontend::ConvDesc cdesc;
float beta;

OF_DISALLOW_COPY_AND_MOVE(CudnnConvArgsV8);
explicit CudnnConvArgsV8(const user_op::InferContext& ctx, const user_op::TensorDesc& x,
const user_op::TensorDesc& y, const user_op::TensorDesc& w);
explicit CudnnConvArgsV8(const user_op::KernelComputeContext& ctx, const user_op::Tensor* x,
const user_op::Tensor* y, const user_op::Tensor* w);
};

class CudnnConvResource {
public:
CudnnConvResource() = default;
@@ -168,6 +186,22 @@ cudnnStatus_t GetCudnnConvWorkspaceSize(const CudnnConvArgs& args, CudnnConvReso
cudnnStatus_t GetCudnnConvWorkspaceSize(const CudnnConvArgs& args, CudnnConvResource* res,
cudnnConvolutionBwdFilterAlgo_t algo, size_t* sz);

cudnn_frontend::EngineConfigList CudnnFrontendGetConfigs(const cudnnHandle_t handle,
const cudnnBackendDescriptorType_t desc,
const cudnn_frontend::Tensor& xdesc,
const cudnn_frontend::Tensor& ydesc,
const cudnn_frontend::Tensor& wdesc,
const cudnn_frontend::ConvDesc& cdesc,
float beta, std::string& tag);

void CudnnFrontendRunConv(const cudnnHandle_t handle, const cudnnBackendDescriptorType_t desc,
user_op::Tensor* x, user_op::Tensor* y, user_op::Tensor* w,
user_op::Tensor* b, const CudnnConvArgsV8& args);

size_t GetCudnnConvWorkspaceSizeV8(const cudnnHandle_t handle,
cudnn_frontend::EngineConfigList& configs,
const std::string& tag);
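
// A minimal usage sketch (illustrative only; `handle`, `in`, `out`, `weight`
// and `tmp_buffer` are hypothetical names owned by the calling kernel):
//
//   CudnnConvArgsV8 args(ctx, in, out, weight);
//   CudnnFrontendRunConv(handle, CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR,
//                        in, out, weight, tmp_buffer, args);
//
// At infer time, GetCudnnConvWorkspaceSizeV8 sizes the workspace from the
// configs returned by CudnnFrontendGetConfigs.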

template<typename perf_t>
perf_t FindCudnnConvAlgorithm(CudnnConvArgs* args);
