Skip to content
This repository was archived by the owner on Jul 21, 2021. It is now read-only.

support different H*W #3

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions tests/test_deform_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ def test_th_map_coordinates():

def test_th_batch_map_coordinates():
np.random.seed(42)
input = np.random.random((4, 100, 100))
coords = (np.random.random((4, 200, 2)) * 99)
input = np.random.random((4, 100, 156))
coords = np.random.random((4, 100*156, 2)) * 99

sp_mapped_vals = sp_batch_map_coordinates(input, coords)
th_mapped_vals = th_batch_map_coordinates(
Expand All @@ -36,8 +36,8 @@ def test_th_batch_map_coordinates():

def test_th_batch_map_offsets():
np.random.seed(42)
input = np.random.random((4, 100, 100))
offsets = (np.random.random((4, 100, 100, 2)) * 2)
input = np.random.random((4, 100, 156))
offsets = (np.random.random((4, 100, 156, 2)) * 2)

sp_mapped_vals = sp_batch_map_offsets(input, offsets)
th_mapped_vals = th_batch_map_offsets(
Expand All @@ -48,14 +48,19 @@ def test_th_batch_map_offsets():

def test_th_batch_map_offsets_grad():
np.random.seed(42)
input = np.random.random((4, 100, 100))
offsets = (np.random.random((4, 100, 100, 2)) * 2)
input = np.random.random((4, 100, 156))
offsets = (np.random.random((4, 100, 156, 2)) * 2)

input = Variable(torch.from_numpy(input), requires_grad=True)
offsets = Variable(torch.from_numpy(offsets), requires_grad=True)

th_mapped_vals = th_batch_map_offsets(input, offsets)
e = torch.from_numpy(np.random.random((4, 100, 100)))
e = torch.from_numpy(np.random.random((4, 100, 156)))
th_mapped_vals.backward(e)
assert not np.allclose(input.grad.data.numpy(), 0)
assert not np.allclose(offsets.grad.data.numpy(), 0)

if __name__ == '__main__':
test_th_batch_map_coordinates()
test_th_batch_map_offsets()
test_th_batch_map_offsets_grad()
35 changes: 25 additions & 10 deletions torch_deform_conv/deform_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,13 @@ def th_map_coordinates(input, coords, order=1):

def sp_batch_map_coordinates(inputs, coords):
"""Reference implementation for batch_map_coordinates"""
coords = coords.clip(0, inputs.shape[1] - 1)
# coords = coords.clip(0, inputs.shape[1] - 1)

assert (coords.shape[2] == 2)
height = coords[:,:,0].clip(0, inputs.shape[1] - 1)
weight = coords[:,:,1].clip(0, inputs.shape[2] - 1)
np.concatenate((np.expand_dims(height, axis=2), np.expand_dims(weight, axis=2)), 2)

mapped_vals = np.array([
sp_map_coordinates(input, coord.T, mode='nearest', order=1)
for input, coord in zip(inputs, coords)
Expand All @@ -87,10 +93,17 @@ def th_batch_map_coordinates(input, coords, order=1):
"""

batch_size = input.size(0)
input_size = input.size(1)
input_height = input.size(1)
input_weight = input.size(2)

n_coords = coords.size(1)

coords = torch.clamp(coords, 0, input_size - 1)
# coords = torch.clamp(coords, 0, input_size - 1)

coords = torch.cat((torch.clamp(coords.narrow(2, 0, 1), 0, input_height - 1), torch.clamp(coords.narrow(2, 1, 1), 0, input_weight - 1)), 2)

assert (coords.size(1) == n_coords)

coords_lt = coords.floor().long()
coords_rb = coords.ceil().long()
coords_lb = torch.stack([coords_lt[..., 0], coords_rb[..., 1]], 2)
Expand Down Expand Up @@ -125,21 +138,22 @@ def sp_batch_map_offsets(input, offsets):
"""Reference implementation for tf_batch_map_offsets"""

batch_size = input.shape[0]
input_size = input.shape[1]
input_height = input.shape[1]
input_weight = input.shape[2]

offsets = offsets.reshape(batch_size, -1, 2)
grid = np.stack(np.mgrid[:input_size, :input_size], -1).reshape(-1, 2)
grid = np.stack(np.mgrid[:input_height, :input_weight], -1).reshape(-1, 2)
grid = np.repeat([grid], batch_size, axis=0)
coords = offsets + grid
coords = coords.clip(0, input_size - 1)
# coords = coords.clip(0, input_size - 1)

mapped_vals = sp_batch_map_coordinates(input, coords)
return mapped_vals


def th_generate_grid(batch_size, input_size, dtype, cuda):
def th_generate_grid(batch_size, input_height, input_weight, dtype, cuda):
grid = np.meshgrid(
range(input_size), range(input_size), indexing='ij'
range(input_height), range(input_weight), indexing='ij'
)
grid = np.stack(grid, axis=-1)
grid = grid.reshape(-1, 2)
Expand All @@ -162,11 +176,12 @@ def th_batch_map_offsets(input, offsets, grid=None, order=1):
torch.Tensor. shape = (b, s, s)
"""
batch_size = input.size(0)
input_size = input.size(1)
input_height = input.size(1)
input_weight = input.size(2)

offsets = offsets.view(batch_size, -1, 2)
if grid is None:
grid = th_generate_grid(batch_size, input_size, offsets.data.type(), offsets.data.is_cuda)
grid = th_generate_grid(batch_size, input_height, input_weight, offsets.data.type(), offsets.data.is_cuda)

coords = offsets + grid

Expand Down
8 changes: 4 additions & 4 deletions torch_deform_conv/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,12 @@ def forward(self, x):

@staticmethod
def _get_grid(self, x):
batch_size, input_size= x.size(0), x.size(1)
batch_size, input_height, input_weight = x.size(0), x.size(1), x.size(2)
dtype, cuda = x.data.type(), x.data.is_cuda
if self._grid_param == (batch_size, input_size, dtype, cuda):
if self._grid_param == (batch_size, input_height, input_weight, dtype, cuda):
return self._grid
self._grid_param = (batch_size, input_size, dtype, cuda)
self._grid = th_generate_grid(batch_size, input_size, dtype, cuda)
self._grid_param = (batch_size, input_height, input_weight, dtype, cuda)
self._grid = th_generate_grid(batch_size, input_height, input_weight, dtype, cuda)
return self._grid

@staticmethod
Expand Down