Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 0 additions & 32 deletions conv2d_backward.cpp

This file was deleted.

89 changes: 0 additions & 89 deletions cpp_functions.cpp

This file was deleted.

25 changes: 9 additions & 16 deletions scnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from torch._six import container_abcs
from torch.nn.modules.conv import _ConvNd
from torch.nn.modules.utils import _pair
from torch.utils.cpp_extension import load

from tqdm import tqdm

Expand All @@ -37,10 +36,6 @@ def forward_amp_decorator(func):
def backward_amp_decorator(func):
return func

# Load and compile cpp code to call cudnn conv2d backward function
dirname = os.path.dirname(__file__)
filename = os.path.join(dirname, "cpp_functions.cpp")
cpp_functions = load(name="cpp_functions", sources=[filename], verbose=False)

# inspired by torch/nn/modules/utils.py
def _ntuple(n):
Expand Down Expand Up @@ -125,9 +120,7 @@ def backward(ctx, grad_output):
# TODO: performance improvements possible by only backpropping valid input
# grad_input_padding = _grad_input_padding(grad_output, inpt.shape, stride, padding, (weight.shape[2], weight.shape[3]))
# TODO: use this!?
grad_in = cpp_functions.backward_input(inpt.shape, grad_output, weight.to(inpt.dtype), padding,
stride, dilation, groups,
torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic)
grad_in = torch.nn.grad.conv2d_input(inpt.shape, weight, grad_output, stride, padding, dilation, groups)
else:
grad_in = None

Expand Down Expand Up @@ -211,14 +204,14 @@ def backward(ctx, grad_output):

# Calculate the kernel gradients with the new unseen gradient values
relevant_grad = relevant_grad.contiguous()

grad_weight = cpp_functions.backward(weight.shape,
relevant_grad.to(weight.dtype),
relevant_input.to(weight.dtype),
(0, 0), # padding
stride[1:3], dilation, groups,
torch.backends.cudnn.benchmark, # benchmark
torch.backends.cudnn.deterministic) # deterministic
grad_weight = torch.nn.grad.conv2d_weight(relevant_input.to(weight.dtype),
weight.shape,
relevant_grad.to(weight.dtype),
stride[1:3],
(0, 0), # padding
dilation,
groups)

if bias is not None:
grad_bias = relevant_grad[0].sum((1, 2))
Expand Down