-
Notifications
You must be signed in to change notification settings - Fork 177
Expand file tree
/
Copy pathpytorch_wrapper.cpp
More file actions
39 lines (33 loc) · 1.91 KB
/
pytorch_wrapper.cpp
File metadata and controls
39 lines (33 loc) · 1.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
#include <torch/script.h>
#include "involution2d_wrapper.h"
// Operator schemas for the `involution` op namespace. Backend-specific
// kernels are attached separately with TORCH_LIBRARY_IMPL; this block only
// declares the signatures the dispatcher will accept.
//
// Every op shares the convolution-style hyper-parameters
// (kernel_size, stride, padding, dilation, groups).
//   involution2d                        forward op: (input, weight) -> Tensor.
//   _involution2d_backward_grad_input   takes `input_shape`; presumably
//                                       returns the gradient w.r.t. input
//                                       (TODO confirm).
//   _involution2d_backward_grad_weight  takes `weight_shape`; presumably
//                                       returns the gradient w.r.t. weight
//                                       (TODO confirm).
//   _involution2d_backward              returns a Tensor[] list; presumably
//                                       both gradients at once (TODO confirm).
TORCH_LIBRARY(involution, m) {
    // torch::Library::def returns Library&, so the declarations are chained
    // into a single statement.
    m.def("involution2d(Tensor input, Tensor weight, int[] kernel_size, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor")
     .def("_involution2d_backward_grad_input(Tensor grad, Tensor weight, int[] input_shape, int[] kernel_size, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor")
     .def("_involution2d_backward_grad_weight(Tensor grad, Tensor input, int[] weight_shape, int[] kernel_size, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor")
     .def("_involution2d_backward(Tensor grad, Tensor weight, Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor[]");
}
// CPU kernels for all four `involution` ops, taken from the
// involution::cpu namespace declared in involution2d_wrapper.h.
TORCH_LIBRARY_IMPL(involution, CPU, m) {
    // Block-scope alias keeps each registration on one short line.
    namespace kernels = involution::cpu;

    // torch::Library::impl returns Library&, so registrations are chained.
    m.impl("involution2d", kernels::involution2d_forward)
     .impl("_involution2d_backward_grad_input", kernels::involution2d_backward_grad_input)
     .impl("_involution2d_backward_grad_weight", kernels::involution2d_backward_grad_weight)
     .impl("_involution2d_backward", kernels::involution2d_backward);
}
#ifdef USE_CUDA
// CUDA kernels for all four `involution` ops, taken from the
// involution::cuda namespace. Compiled only when USE_CUDA is defined.
TORCH_LIBRARY_IMPL(involution, CUDA, m) {
    // Block-scope alias keeps each registration on one short line.
    namespace kernels = involution::cuda;

    // torch::Library::impl returns Library&, so registrations are chained.
    m.impl("involution2d", kernels::involution2d_forward)
     .impl("_involution2d_backward_grad_input", kernels::involution2d_backward_grad_input)
     .impl("_involution2d_backward_grad_weight", kernels::involution2d_backward_grad_weight)
     .impl("_involution2d_backward", kernels::involution2d_backward);
}
#endif
// Autograd kernel for `involution2d` on CPU tensors. Only the forward op is
// registered under an Autograd key in this file. The backward ops get no
// autograd kernel here; presumably they are only invoked from inside
// involution2d_autograd's backward pass (TODO confirm in involution2d_wrapper.h).
TORCH_LIBRARY_IMPL(involution, AutogradCPU, m) {
    namespace kernels = involution::cpu;
    m.impl("involution2d", kernels::involution2d_autograd);
}
#ifdef USE_CUDA
// Autograd kernel for `involution2d` on CUDA tensors. As with AutogradCPU,
// only the forward op gets an autograd registration.
TORCH_LIBRARY_IMPL(involution, AutogradCUDA, m) {
    namespace kernels = involution::cuda;
    m.impl("involution2d", kernels::involution2d_autograd);
}

// Autocast (mixed-precision) kernel for `involution2d`. It comes from
// involution::cuda, so it is registered only in CUDA builds.
// NOTE(review): newer PyTorch releases call the CUDA autocast key
// `AutocastCUDA`. Confirm that `Autocast` is still accepted by the targeted
// PyTorch version.
TORCH_LIBRARY_IMPL(involution, Autocast, m) {
    namespace kernels = involution::cuda;
    m.impl("involution2d", kernels::involution2d_autocast);
}
#endif