-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathkernel_filter.py
More file actions
23 lines (18 loc) · 802 Bytes
/
kernel_filter.py
File metadata and controls
23 lines (18 loc) · 802 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import torch
import KernelFilter
class kernel_filter_function(torch.autograd.Function):
    """Autograd bridge to the compiled ``KernelFilter`` extension.

    ``forward`` delegates to ``KernelFilter.forward(grid, kernel, dilation)``.
    ``backward`` delegates to ``KernelFilter.backward(grid, kernel, grad, dilation)``,
    which returns the gradients with respect to ``grid`` and ``kernel``.
    ``dilation`` is a non-tensor argument and receives no gradient.
    """

    @staticmethod
    def forward(ctx, grid, kernel, dilation):
        """Run the extension's forward pass and stash state for ``backward``.

        Args:
            ctx: autograd context object.
            grid: input tensor passed straight to the extension
                (expected layout is defined by ``KernelFilter`` — TODO confirm).
            kernel: filter tensor passed straight to the extension.
            dilation: non-tensor value forwarded unchanged to the extension
                (presumably a positive int; the wrapping module defaults it to 1).

        Returns:
            The tensor produced by ``KernelFilter.forward``.
        """
        # Both input tensors are needed again to compute gradients in backward.
        ctx.save_for_backward(grid, kernel)
        # Non-tensor state cannot go through save_for_backward; keep it on ctx.
        ctx.dilation = dilation
        return KernelFilter.forward(grid, kernel, dilation)

    @staticmethod
    def backward(ctx, backprop):
        """Compute gradients for ``grid`` and ``kernel`` via the extension.

        Args:
            ctx: autograd context populated by ``forward``.
            backprop: gradient of the loss with respect to ``forward``'s output.

        Returns:
            A tuple ``(grid_grad, kernel_grad, None)`` with one entry per
            ``forward`` input. ``None`` is returned for the non-tensor
            ``dilation``.
        """
        grid, kernel = ctx.saved_tensors
        # Autograd may deliver a non-contiguous gradient (e.g. an expanded or
        # transposed view). A compiled kernel that walks raw memory would then
        # read the wrong elements, so force a dense layout. This is a no-op
        # when the tensor is already contiguous.
        grid_grad_output, kernel_grad_output = KernelFilter.backward(
            grid, kernel, backprop.contiguous(), ctx.dilation
        )
        return grid_grad_output, kernel_grad_output, None
class KernelFilterClass(torch.nn.Module):
    """``nn.Module`` front-end for ``kernel_filter_function``.

    The module owns no parameters or buffers. Each call simply routes its
    arguments through ``kernel_filter_function.apply`` so that autograd
    records the custom backward.
    """

    def __init__(self):
        super().__init__()

    def forward(self, grid: torch.Tensor, kernel: torch.Tensor, dilation: int = 1):
        """Apply the kernel filter to ``grid`` using ``kernel``.

        Args:
            grid: input tensor handed to the autograd function.
            kernel: filter tensor handed to the autograd function.
            dilation: forwarded unchanged; defaults to 1.

        Returns:
            The output of ``kernel_filter_function.apply``.
        """
        filtered = kernel_filter_function.apply(grid, kernel, dilation)
        return filtered