Skip to content

Commit 491e6da

Browse files
authored
Add Grid Sampling for 3D images. (#627)
1 parent b56ff50 commit 491e6da

File tree

4 files changed

+508
-66
lines changed

4 files changed

+508
-66
lines changed

ext/NNlibCUDAExt/sampling.jl

Lines changed: 73 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
@inbounds CUDA.@atomic dx[ix, iy, c, n] += value
33
end
44

5-
function grid_sample_kernel!(n_elem, output, input, grid, padding_mode)
5+
function grid_sample_kernel!(n_elem, output::AbstractArray{T, 4}, input::AbstractArray{T, 4}, grid::AbstractArray{V, 4}, padding_mode) where {T,V}
66
index = (threadIdx().x - 1) + (blockIdx().x - 1) * blockDim().x
77
if index < n_elem
88
iW, iH, iC, _ = size(input)
@@ -16,7 +16,7 @@ function grid_sample_kernel!(n_elem, output, input, grid, padding_mode)
1616
nothing
1717
end
1818

19-
function ∇grid_sample_kernel!(n_elem, dx, dgrid, Δ, input, grid, padding_mode)
19+
function ∇grid_sample_kernel!(n_elem, dx::AbstractArray{T, 4}, dgrid::AbstractArray{V, 4}, Δ::AbstractArray{T, 4}, input::AbstractArray{T, 4}, grid::AbstractArray{V, 4}, padding_mode) where {T,V}
2020
index = (threadIdx().x - 1) + (blockIdx().x - 1) * blockDim().x
2121
if index < n_elem
2222
iW, iH, iC, _ = size(input)
@@ -59,3 +59,74 @@ function NNlib.∇grid_sample(Δ::CuArray{T, 4}, x::CuArray{T, 4}, grid::CuArray
5959
kernel(n_elem, dx, dgrid, Δ, x, grid, pad; threads=threads, blocks=blocks)
6060
dx, dgrid
6161
end
62+
63+
64+
@inline function NNlib._safe_add!(dx::CuDeviceArray{T, 5}, value, ix, iy, iz, c, n) where T
65+
@inbounds CUDA.@atomic dx[ix, iy, iz, c, n] += value
66+
end
67+
68+
function grid_sample_kernel!(n_elem, output::AbstractArray{T, 5}, input::AbstractArray{T, 5}, grid::AbstractArray{V, 5}, padding_mode) where {T,V}
69+
index = (threadIdx().x - 1) + (blockIdx().x - 1) * blockDim().x
70+
if index < n_elem
71+
iW, iH,iD, iC, _ = size(input)
72+
_, gW, gH, gD, _ = size(grid)
73+
74+
w = index % gW + 1
75+
h = (index ÷ gW) % gH + 1
76+
d = (index ÷ (gW * gH)) % gD + 1
77+
n = index ÷ (gW * gH * gD) + 1
78+
# n = index ÷ (gW * gH) + 1
79+
# d = (index ÷ (gW * gH * n)) + 1
80+
81+
NNlib._grid_sample_kernel!(output, input, grid, padding_mode, w, h, d, n, iW, iH, iD, iC)
82+
end
83+
nothing
84+
end
85+
86+
function ∇grid_sample_kernel!(n_elem, dx::AbstractArray{T, 5}, dgrid::AbstractArray{V, 5}, Δ::AbstractArray{T, 5}, input::AbstractArray{T, 5}, grid::AbstractArray{V, 5}, padding_mode) where {T,V}
87+
index = (threadIdx().x - 1) + (blockIdx().x - 1) * blockDim().x
88+
if index < n_elem
89+
iW, iH, iD, iC, _ = size(input)
90+
_, gW, gH, gD, _ = size(grid)
91+
92+
w = index % gW + 1
93+
h = (index ÷ gW) % gH + 1
94+
d = (index ÷ (gW * gH)) % gD + 1
95+
n = index ÷ (gW * gH * gD) + 1
96+
# n = index ÷ (gW * gH) + 1
97+
# d = (index ÷ (gW * gH * n)) + 1
98+
99+
NNlib._∇grid_sample_kernel!(dx, dgrid, Δ, input, grid, padding_mode, w, h, d, n, iW, iH, iD, iC)
100+
end
101+
nothing
102+
end
103+
104+
function NNlib.grid_sample(x::CuArray{T, 5}, grid::CuArray{V, 5}; padding_mode = :zeros) where {T, V}
105+
pad = Val(padding_mode)
106+
_, _, _, xC, xN = size(x)
107+
_, gW, gH, gD, _ = size(grid)
108+
n_elem = gW * gH * gD * xN
109+
y = similar(x, T, (gW, gH, gD, xC, xN))
110+
111+
kernel = @cuda launch=false grid_sample_kernel!(n_elem, y, x, grid, pad)
112+
config = launch_configuration(kernel.fun; max_threads=256)
113+
threads = min(n_elem, config.threads)
114+
blocks = cld(n_elem, threads)
115+
kernel(n_elem, y, x, grid, pad; threads=threads, blocks=blocks)
116+
y
117+
end
118+
119+
function NNlib.∇grid_sample(Δ::CuArray{T, 5}, x::CuArray{T, 5}, grid::CuArray{V, 5}; padding_mode = :zeros) where {T, V}
120+
pad = Val(padding_mode)
121+
xN = size(x, 5)
122+
_, gW, gH, gD, _ = size(grid)
123+
n_elem = gW * gH * gD * xN
124+
dx, dgrid = CUDA.zeros(T, size(x)), similar(grid)
125+
126+
kernel = @cuda launch=false ∇grid_sample_kernel!(n_elem, dx, dgrid, Δ, x, grid, pad)
127+
config = launch_configuration(kernel.fun; max_threads=256)
128+
threads = min(n_elem, config.threads)
129+
blocks = cld(n_elem, threads)
130+
kernel(n_elem, dx, dgrid, Δ, x, grid, pad; threads=threads, blocks=blocks)
131+
dx, dgrid
132+
end

0 commit comments

Comments
 (0)