51 changes: 31 additions & 20 deletions tomosipo/torch_support.py
@@ -22,12 +22,8 @@
 
 class OperatorFunction(Function):
     @staticmethod
-    def forward(ctx, input, operator, num_extra_dims=0, is_2d=False):
+    def forward(input, operator, num_extra_dims=0, is_2d=False):
         extra_dims = input.size()[:num_extra_dims]
-        if input.requires_grad:
-            ctx.operator = operator
-            ctx.extra_dims = extra_dims
-            ctx.is_2d = is_2d
 
         expected_ndim = (2 if is_2d else 3) + num_extra_dims
         assert (
@@ -51,34 +47,49 @@ def forward(ctx, input, operator, num_extra_dims=0, is_2d=False):
         else:
             for subspace in itertools.product(*[range(dim_size) for dim_size in extra_dims]):
                 operator(input[subspace], out=output[subspace])
 
         if is_2d:
             output = torch.squeeze(output, dim=-3)
         return output
 
+    @staticmethod
+    def setup_context(ctx, inputs, output):
+        _, operator, num_extra_dims, is_2d = inputs
+        ctx.operator = operator
+        ctx.num_extra_dims = num_extra_dims
+        ctx.is_2d = is_2d
+
     @staticmethod
     def backward(ctx, grad_output):
         operator = ctx.operator
-        extra_dims = ctx.extra_dims
+        num_extra_dims = ctx.num_extra_dims
         is_2d = ctx.is_2d
 
-        grad_input = grad_output.new_empty(extra_dims + operator.domain_shape, dtype=torch.float32)
-
-        if is_2d:
-            grad_output = torch.unsqueeze(grad_output, dim=-3)
-
-        if len(extra_dims) == 0:
-            operator.T(grad_output, out=grad_input)
-        else:
-            for subspace in itertools.product(*[range(dim_size) for dim_size in extra_dims]):
-                operator.T(grad_output[subspace], out=grad_input[subspace])
-
-        if is_2d:
-            grad_input = torch.squeeze(grad_input, dim=-3)
+        grad_input = OperatorFunction.apply(grad_output, operator.T, num_extra_dims, is_2d)
 
         # do not return gradient for operator
        return grad_input, None, None, None
 
+    @staticmethod
+    def jvp(ctx, grad_input, *args):
+        operator = ctx.operator
+        num_extra_dims = ctx.num_extra_dims
+        is_2d = ctx.is_2d
+
+        return OperatorFunction.apply(grad_input, operator, num_extra_dims, is_2d)
+
+    @staticmethod
+    def vmap(info, in_dims, input, operator, num_extra_dims, is_2d):
+        batch_dim = in_dims[0]
+        if batch_dim > num_extra_dims + 1:
+            # If batching along a dimension that is interspersed with spatial
+            # dimensions, move it in front to conform to batching logic
+            # implemented in `OperatorFunction.forward`
+            input = input.movedim(batch_dim, 0)
+            return OperatorFunction.apply(input, operator, num_extra_dims + 1, is_2d), 0
+        else:
+            return OperatorFunction.apply(input, operator, num_extra_dims + 1, is_2d), batch_dim
+
 
 
 def to_autograd(operator, num_extra_dims=0, is_2d=False):
     """Converts an operator to an autograd function
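Note (not part of the diff): this change moves `OperatorFunction` to the newer `torch.autograd.Function` layout, where `forward` no longer takes `ctx` and state is stored in a separate `setup_context`, and it adds `jvp` and `vmap` staticmethods so the wrapped operator can participate in forward-mode AD and `torch.func` transforms. Reverse mode now re-enters `OperatorFunction.apply` with `operator.T` instead of filling a preallocated gradient buffer. A minimal usage sketch follows; the geometry sizes and the `vg`/`pg`/`A`/`project` names are illustrative assumptions, not taken from the PR.

```python
# Hedged sketch: exercises the new setup_context / jvp / vmap path.
# Geometry sizes and variable names below are illustrative assumptions;
# only to_autograd and OperatorFunction come from tomosipo/torch_support.py.
import torch
import tomosipo as ts
from tomosipo.torch_support import to_autograd

vg = ts.volume(shape=(1, 64, 64))           # single-slice volume (assumed size)
pg = ts.parallel(angles=32, shape=(1, 96))  # matching parallel-beam geometry
A = ts.operator(vg, pg)

project = to_autograd(A)

# Reverse mode: backward() now re-enters OperatorFunction.apply with operator.T.
x = torch.zeros(*A.domain_shape, requires_grad=True)
project(x).sum().backward()

# Forward mode: handled by the new jvp staticmethod.
_, dy = torch.func.jvp(project, (x.detach(),), (torch.ones_like(x),))

# Batching: handled by the new vmap staticmethod; the batch dimension is
# folded into the "extra" dims that forward already iterates over.
xs = torch.zeros(4, *A.domain_shape)
ys = torch.func.vmap(project)(xs)
```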