Skip to content

Commit 4815c4b

Browse files
committed
Fix batched inference
1 parent 29a1bff commit 4815c4b

File tree

2 files changed

+3
-4
lines changed

2 files changed

+3
-4
lines changed

interpolator.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,8 @@ def debug_forward(self, x0, x1, batch_dt) -> Dict[str, List[torch.Tensor]]:
125125
# Note: In film_net we fix time to be 0.5, and recursively invoke the interpo-
126126
# lator for multi-frame interpolation. Below, we create a constant tensor of
127127
# shape [B]. We use the `time` tensor to infer the batch size.
128-
backward_flow = util.multiply_pyramid(backward_flow_pyramid, batch_dt[:, 0])
129-
forward_flow = util.multiply_pyramid(forward_flow_pyramid, 1 - batch_dt[:, 0])
128+
backward_flow = util.multiply_pyramid(backward_flow_pyramid, batch_dt)
129+
forward_flow = util.multiply_pyramid(forward_flow_pyramid, 1 - batch_dt)
130130

131131
pyramids_to_warp = [
132132
util.concatenate_pyramids(image_pyramids[0][:self.fusion_pyramid_levels],
@@ -154,6 +154,5 @@ def debug_forward(self, x0, x1, batch_dt) -> Dict[str, List[torch.Tensor]]:
154154
'backward_flow_pyramid': backward_flow_pyramid,
155155
}
156156

157-
@torch.jit.export
158157
def forward(self, x0, x1, batch_dt) -> torch.Tensor:
159158
return self.debug_forward(x0, x1, batch_dt)['image'][0]

util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def multiply_pyramid(pyramid: List[torch.Tensor],
102102
# the batch of images from BxHxWxC-format to CxHxWxB. This can then be
103103
# multiplied with a batch of scalars, then we transpose back to the standard
104104
# BxHxWxC form.
105-
return [image * scalar for image in pyramid]
105+
return [image * scalar[..., None, None] for image in pyramid]
106106

107107

108108
def flow_pyramid_synthesis(

0 commit comments

Comments
 (0)