22 @inbounds CUDA. @atomic dx[ix, iy, c, n] += value
33end
44
5- function grid_sample_kernel!(n_elem, output, input, grid, padding_mode)
5+ function grid_sample_kernel!(n_elem, output:: AbstractArray{T, 4} , input:: AbstractArray{T, 4} , grid:: AbstractArray{V, 4} , padding_mode) where {T,V}
66 index = (threadIdx(). x - 1 ) + (blockIdx(). x - 1 ) * blockDim(). x
77 if index < n_elem
88 iW, iH, iC, _ = size(input)
@@ -16,7 +16,7 @@ function grid_sample_kernel!(n_elem, output, input, grid, padding_mode)
1616 nothing
1717end
1818
19- function ∇grid_sample_kernel!(n_elem, dx, dgrid, Δ, input, grid, padding_mode)
19+ function ∇grid_sample_kernel!(n_elem, dx:: AbstractArray{T, 4} , dgrid:: AbstractArray{V, 4} , Δ :: AbstractArray{T, 4} , input:: AbstractArray{T, 4} , grid:: AbstractArray{V, 4} , padding_mode) where {T,V}
2020 index = (threadIdx(). x - 1 ) + (blockIdx(). x - 1 ) * blockDim(). x
2121 if index < n_elem
2222 iW, iH, iC, _ = size(input)
@@ -59,3 +59,74 @@ function NNlib.∇grid_sample(Δ::CuArray{T, 4}, x::CuArray{T, 4}, grid::CuArray
5959 kernel(n_elem, dx, dgrid, Δ, x, grid, pad; threads= threads, blocks= blocks)
6060 dx, dgrid
6161end
62+
63+
# 5-D device-array method of `NNlib._safe_add!`.
#
# Adds `value` to `dx[ix, iy, iz, c, n]` through `CUDA.@atomic`, so that
# several GPU threads may accumulate into the same element.
# `@inbounds` disables bounds checking: the caller is responsible for passing
# indices that lie inside `dx`.
# NOTE(review): presumably invoked from NNlib's gradient kernel while
# scattering into `dx` — confirm against NNlib.
@inline function NNlib._safe_add!(
    dx::CuDeviceArray{T, 5}, value, ix, iy, iz, c, n,
) where {T}
    @inbounds CUDA.@atomic dx[ix, iy, iz, c, n] += value
end
67+
# CUDA kernel for the 5-D forward pass of `grid_sample`.
#
# Each thread derives a zero-based linear id from its block/thread position.
# Threads whose id is `>= n_elem` do nothing. The remaining threads decompose
# the id into one-based `(w, h, d, n)` coordinates, with `w` varying fastest,
# using dimensions 2-4 of `grid` as the extents, and then delegate to
# `NNlib._grid_sample_kernel!` for that location.
#
# Arguments:
# - `n_elem`: number of threads that have work to do.
# - `output`, `input`, `grid`: 5-D arrays forwarded unchanged to NNlib.
# - `padding_mode`: forwarded unchanged to NNlib.
#
# Always returns `nothing`.
function grid_sample_kernel!(
    n_elem,
    output::AbstractArray{T, 5},
    input::AbstractArray{T, 5},
    grid::AbstractArray{V, 5},
    padding_mode,
) where {T, V}
    index = (threadIdx().x - 1) + (blockIdx().x - 1) * blockDim().x
    # Surplus threads of the last block have nothing to process.
    index < n_elem || return nothing

    iW, iH, iD, iC, _ = size(input)
    _, gW, gH, gD, _ = size(grid)

    # Number of linear ids covered by one depth slice / one batch entry.
    slice_len = gW * gH
    batch_len = gW * gH * gD

    w = index % gW + 1
    h = (index ÷ gW) % gH + 1
    d = (index ÷ slice_len) % gD + 1
    n = index ÷ batch_len + 1

    NNlib._grid_sample_kernel!(
        output, input, grid, padding_mode, w, h, d, n, iW, iH, iD, iC,
    )
    return nothing
end
85+
# CUDA kernel for the 5-D backward pass of `grid_sample`.
#
# Uses the same thread-to-coordinate mapping as the forward kernel: the
# zero-based linear thread id is split into one-based `(w, h, d, n)`
# coordinates (with `w` varying fastest) using dimensions 2-4 of `grid`, and
# the per-location work is delegated to `NNlib._∇grid_sample_kernel!`.
# Threads whose id is `>= n_elem` do nothing.
#
# Arguments:
# - `n_elem`: number of threads that have work to do.
# - `dx`, `dgrid`: 5-D gradient buffers forwarded to NNlib.
# - `Δ`, `input`, `grid`: 5-D arrays forwarded unchanged to NNlib.
# - `padding_mode`: forwarded unchanged to NNlib.
#
# Always returns `nothing`.
function ∇grid_sample_kernel!(
    n_elem,
    dx::AbstractArray{T, 5},
    dgrid::AbstractArray{V, 5},
    Δ::AbstractArray{T, 5},
    input::AbstractArray{T, 5},
    grid::AbstractArray{V, 5},
    padding_mode,
) where {T, V}
    index = (threadIdx().x - 1) + (blockIdx().x - 1) * blockDim().x
    # Surplus threads of the last block have nothing to process.
    index < n_elem || return nothing

    iW, iH, iD, iC, _ = size(input)
    _, gW, gH, gD, _ = size(grid)

    # Number of linear ids covered by one depth slice / one batch entry.
    slice_len = gW * gH
    batch_len = gW * gH * gD

    w = index % gW + 1
    h = (index ÷ gW) % gH + 1
    d = (index ÷ slice_len) % gD + 1
    n = index ÷ batch_len + 1

    NNlib._∇grid_sample_kernel!(
        dx, dgrid, Δ, input, grid, padding_mode, w, h, d, n, iW, iH, iD, iC,
    )
    return nothing
end
103+
# GPU method of `NNlib.grid_sample` for 5-D `CuArray`s.
#
# Allocates an output of size `(gW, gH, gD, xC, xN)`, where `gW, gH, gD` are
# dimensions 2-4 of `grid` and `xC, xN` are dimensions 4-5 of `x`, then
# launches `grid_sample_kernel!` with one thread per `(w, h, d, n)` location.
#
# Arguments:
# - `x`: 5-D input array.
# - `grid`: 5-D sampling grid; its first dimension is not inspected here.
#   NOTE(review): the last dimension of `grid` is assumed to match
#   `size(x, 5)` — this is not validated here.
# - `padding_mode`: wrapped in `Val` and forwarded to the kernel
#   (default `:zeros`).
#
# Returns the newly allocated output array.
function NNlib.grid_sample(
    x::CuArray{T, 5}, grid::CuArray{V, 5}; padding_mode = :zeros,
) where {T, V}
    pad = Val(padding_mode)
    _, _, _, xC, xN = size(x)
    _, gW, gH, gD, _ = size(grid)
    n_elem = gW * gH * gD * xN
    y = similar(x, T, (gW, gH, gD, xC, xN))

    # With no sampling locations `y` is empty. Launching would compute
    # `threads == 0` and `cld(0, 0)` throws a `DivideError`, so return early.
    n_elem == 0 && return y

    kernel = @cuda launch=false grid_sample_kernel!(n_elem, y, x, grid, pad)
    config = launch_configuration(kernel.fun; max_threads=256)
    # Never request more threads per block than there are elements.
    threads = min(n_elem, config.threads)
    blocks = cld(n_elem, threads)
    kernel(n_elem, y, x, grid, pad; threads=threads, blocks=blocks)
    return y
end
118+
# GPU method of `NNlib.∇grid_sample` for 5-D `CuArray`s.
#
# Allocates `dx` (zero-filled, same size as `x`) and `dgrid` (uninitialised,
# same shape as `grid`), then launches `∇grid_sample_kernel!` with one thread
# per `(w, h, d, n)` location, where `w, h, d` range over dimensions 2-4 of
# `grid` and `n` over `size(x, 5)`.
#
# Arguments:
# - `Δ`: 5-D incoming gradient, forwarded to the kernel.
# - `x`: 5-D input array of the forward pass.
# - `grid`: 5-D sampling grid of the forward pass.
# - `padding_mode`: wrapped in `Val` and forwarded to the kernel
#   (default `:zeros`).
#
# Returns the tuple `(dx, dgrid)`.
function NNlib.∇grid_sample(
    Δ::CuArray{T, 5}, x::CuArray{T, 5}, grid::CuArray{V, 5};
    padding_mode = :zeros,
) where {T, V}
    pad = Val(padding_mode)
    xN = size(x, 5)
    _, gW, gH, gD, _ = size(grid)
    n_elem = gW * gH * gD * xN

    # `dx` starts at zero; `_safe_add!` performs `+=` on its target, so it is
    # presumably accumulated into by the kernel — confirm against NNlib.
    dx = CUDA.zeros(T, size(x))
    # NOTE(review): `dgrid` is uninitialised; this relies on the kernel
    # writing every element — confirm against NNlib._∇grid_sample_kernel!.
    dgrid = similar(grid)

    # With no sampling locations there is nothing to launch. Proceeding would
    # compute `threads == 0` and `cld(0, 0)` throws a `DivideError`.
    if n_elem == 0
        return (dx, dgrid)
    end

    kernel = @cuda launch=false ∇grid_sample_kernel!(n_elem, dx, dgrid, Δ, x, grid, pad)
    config = launch_configuration(kernel.fun; max_threads=256)
    # Never request more threads per block than there are elements.
    threads = min(n_elem, config.threads)
    blocks = cld(n_elem, threads)
    kernel(n_elem, dx, dgrid, Δ, x, grid, pad; threads=threads, blocks=blocks)
    return (dx, dgrid)
end
0 commit comments