Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions sima/motion/dftreg.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import numpy as np
from scipy.ndimage.interpolation import shift
import time
import warnings
from . import motion
try:
from pyfftw.interfaces.numpy_fft import fftn, ifftn
Expand Down Expand Up @@ -123,11 +124,13 @@ def _estimate(self, dataset):
displacements = []

for sequence in dataset:
num_planes = sequence.shape[1]
num_channels = sequence.shape[4]
num_frames, num_planes, _, _, num_channels = sequence.shape
if num_channels > 1:
raise NotImplementedError("Error: only one colour channel \
can be used for DFT motion correction. Using channel 1.")
warnings.warn("Warning: only one colour channel \
can be used for DFT motion correction. Using channel 0.")

# get results into a shape sima likes
frame_shifts = np.zeros([num_frames, num_planes, 2])

for plane_idx in range(num_planes):
# load into memory... need to pass numpy array to dftreg.
Expand Down Expand Up @@ -172,8 +175,7 @@ def _estimate(self, dataset):
else:
dy, dx = output

# get results into a shape sima likes
frame_shifts = np.zeros([len(frames), num_planes, 2])
# add plane shift info
for idx, frame in enumerate(sequence):
frame_shifts[idx, plane_idx] = [dy[idx], dx[idx]]
displacements.append(frame_shifts)
Expand Down
17 changes: 16 additions & 1 deletion sima/motion/motion.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,11 @@ def correct(self, dataset, savedir, channel_names=None, info=None,
else:
mc_sequences = sequences
displacements = self.estimate(sima.ImagingDataset(mc_sequences, None))

# enforce integer displacements
displacements = [d if issubclass(d.dtype.type, np.integer) else \
d.round().astype(np.int64) for d in displacements]

disp_dim = displacements[0].shape[-1]
max_disp = np.ceil(
np.max(list(it.chain.from_iterable(d.reshape(-1, disp_dim)
Expand Down Expand Up @@ -188,7 +193,17 @@ def _estimate(self, dataset):
downsampled_dataset)
displacements = []
for d_disps in downsampled_displacements:
disps = np.repeat(d_disps, 2, axis=2) # Repeat the displacements
if d_disps.ndim == 3: # whole frame displacements
if self._offset == 0:
disps = d_disps[:]
disps[..., 0] *= 2 # multiply y-shifts by 2
displacements.append(disps)
continue
else: # duplicate displacements for all rows
disps = np.moveaxis(np.stack(
[d_disps] * dataset.frame_shape[0]), 0, 2)
else: # line-by-line displacements
disps = np.repeat(d_disps, 2, axis=2) # Repeat the displacements
disps[:, :, :, 0] *= 2 # multiply y-shifts by 2
disps[:, :, 1::2, -1] += self._offset # shift even rows by offset
displacements.append(disps)
Expand Down