diff --git a/conv2d_backward.cpp b/conv2d_backward.cpp deleted file mode 100644 index 71b77be..0000000 --- a/conv2d_backward.cpp +++ /dev/null @@ -1,32 +0,0 @@ -#include - -#include -#include -#include - -at::Tensor backward_weight( - c10::ArrayRef weight_size, - const at::Tensor& grad_output, - const at::Tensor& input, - c10::ArrayRef padding, - c10::ArrayRef stride, - c10::ArrayRef dilation, - int64_t groups, - bool benchmark, - bool deterministic) { - - return at::cudnn_convolution_backward_weight( - weight_size, - grad_output, - input, - padding, - stride, - dilation, - groups, - benchmark, - deterministic); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("backward", &backward_weight, "Conv2d backward cudnn"); -} diff --git a/cpp_functions.cpp b/cpp_functions.cpp deleted file mode 100755 index aa0c245..0000000 --- a/cpp_functions.cpp +++ /dev/null @@ -1,89 +0,0 @@ -#include - -#include -#include -#include - -#include -#include -#include -#include - -at::Tensor backward_weight( - c10::ArrayRef weight_size, - const at::Tensor& grad_output, - const at::Tensor& input, - c10::ArrayRef padding, - c10::ArrayRef stride, - c10::ArrayRef dilation, - int64_t groups, - bool benchmark, - bool deterministic, - bool allow_fp32) { - - return at::cudnn_convolution_backward_weight( - weight_size, - grad_output, - input, - padding, - stride, - dilation, - groups, - benchmark, - deterministic, - allow_fp32); -} -at::Tensor backward_input( - c10::ArrayRef input_size, - const at::Tensor& grad_output, - const at::Tensor& weight, - c10::ArrayRef padding, - c10::ArrayRef stride, - c10::ArrayRef dilation, - int64_t groups, - bool benchmark, - bool deterministic, - bool allow_fp32) { - - return at::cudnn_convolution_backward_input( - input_size, - grad_output, - weight, - padding, - stride, - dilation, - groups, - benchmark, - deterministic, - allow_fp32); -} - -// From pytorch/torch/csrc/Module.cpp -void DLPack_Capsule_Destructor(PyObject* data) { - HANDLE_TH_ERRORS - DLManagedTensor * dlMTensor = (DLManagedTensor *)PyCapsule_GetPointer(data, "dltensor"); - if (dlMTensor) { - // the dlMTensor has not been consumed, call deleter ourselves - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - dlMTensor->deleter(const_cast(dlMTensor)); - } else { - // the dlMTensor has been consumed - // PyCapsule_GetPointer has set an error indicator - PyErr_Clear(); - } - END_HANDLE_TH_ERRORS_RET() -} - -namespace py = pybind11; -using namespace pybind11::literals; - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("backward", &backward_weight, "Conv backward_weight cudnn"); - m.def("backward_input", &backward_input, "Conv backward_input cudnn"); - m.def("to_dlpack_with_device_id", [](const at::Tensor& data, int64_t device_id) { - DLManagedTensor* dlMTensor = at::toDLPack(data); - dlMTensor->dl_tensor.ctx.device_id = device_id; - auto capsule = py::capsule(dlMTensor, "dltensor", DLPack_Capsule_Destructor); - return capsule; - }, "Specify device_id in dlpack, for cupy to copy to right GPU"); -} diff --git a/scnn.py b/scnn.py index 6ad1fc2..565b8e9 100644 --- a/scnn.py +++ b/scnn.py @@ -18,7 +18,6 @@ from torch._six import container_abcs from torch.nn.modules.conv import _ConvNd from torch.nn.modules.utils import _pair -from torch.utils.cpp_extension import load from tqdm import tqdm @@ -37,10 +36,6 @@ def forward_amp_decorator(func): def backward_amp_decorator(func): return func -# Load and compile cpp code to call cudnn conv2d backward function -dirname = os.path.dirname(__file__) -filename = os.path.join(dirname, "cpp_functions.cpp") -cpp_functions = load(name="cpp_functions", sources=[filename], verbose=False) # inspired by torch/nn/modules/utils.py def _ntuple(n): @@ -125,9 +120,7 @@ def backward(ctx, grad_output): # TODO: performance improvements possible by only backpropping valid input # grad_input_padding = _grad_input_padding(grad_output, inpt.shape, stride, padding, (weight.shape[2], weight.shape[3])) # TODO: use this!? - grad_in = cpp_functions.backward_input(inpt.shape, grad_output, weight.to(inpt.dtype), padding, - stride, dilation, groups, - torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic) + grad_in = torch.nn.grad.conv2d_input(inpt.shape, weight, grad_output, stride, padding, dilation, groups) else: grad_in = None @@ -211,14 +204,14 @@ def backward(ctx, grad_output): # Calculate the kernel gradients with the new unseen gradient values relevant_grad = relevant_grad.contiguous() - - grad_weight = cpp_functions.backward(weight.shape, - relevant_grad.to(weight.dtype), - relevant_input.to(weight.dtype), - (0, 0), # padding - stride[1:3], dilation, groups, - torch.backends.cudnn.benchmark, # benchmark - torch.backends.cudnn.deterministic) # deterministic + + grad_weight = torch.nn.grad.conv2d_weight(relevant_input.to(weight.dtype), + weight.shape, + relevant_grad.to(weight.dtype), + stride[1:3], + (0, 0), # padding + dilation, + groups) if bias is not None: grad_bias = relevant_grad[0].sum((1, 2))