Skip to content

Commit 110836d

Browse files
committed
drift correction bits
1 parent 2b8b826 commit 110836d

3 files changed

Lines changed: 251 additions & 3 deletions

File tree

src/napatrackmater/__init__.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@
4141
inception_dual_model_prediction
4242
)
4343

44+
from .drift import (
45+
affine_transform, apply_alpha_drift, apply_xy_drift, apply_z_drift, crop_data, get_rotation, get_xy_drift, get_z_drift
46+
)
47+
4448
from .CloudAutoEncoder import CloudAutoEncoder
4549
import json
4650
from csbdeep.utils.tf import keras_import
@@ -128,7 +132,16 @@ def load_json(fpath):
128132
"create_h5",
129133
"normalize_image_in_chunks",
130134
"vision_inception_model_prediction",
131-
"inception_dual_model_prediction"
135+
"inception_dual_model_prediction",
136+
"affine_transform",
137+
"apply_alpha_drift",
138+
"apply_xy_drift",
139+
"apply_z_drift",
140+
"crop_data",
141+
"get_rotation",
142+
"get_xy_drift",
143+
"get_z_drift"
144+
132145
)
133146

134147
clear_models_and_aliases(CloudAutoEncoder)

src/napatrackmater/_version.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
__version__ = version = "5.7.7"
2-
__version_tuple__ = version_tuple = (5, 7, 7)
1+
__version__ = version = "5.7.8"
2+
__version_tuple__ = version_tuple = (5, 7, 8)

src/napatrackmater/drift.py

Lines changed: 235 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,235 @@
1+
2+
import numpy as np
3+
import dask
4+
import dask.array as da
5+
import matplotlib.pyplot as plt
6+
from skimage.registration import phase_cross_correlation
7+
from skimage.transform import warp_polar
8+
from skimage.transform import AffineTransform
9+
from scipy.ndimage import affine_transform, rotate
10+
# Utility
11+
from tqdm import tqdm
12+
import tifffile
13+
import sys
14+
import os
15+
16+
def get_xy_drift(_data, ref_channel):
17+
xy_movie = da.average(_data[ref_channel], axis = 1).compute()
18+
19+
# correct XY, relative to first frame:
20+
print('Determining drift in XY')
21+
shifts, error, phasediff = [], [], []
22+
for t in tqdm(range(len(xy_movie)-1)):
23+
24+
s, e, p = phase_cross_correlation(xy_movie[t],
25+
xy_movie[t+1],
26+
normalization=None)
27+
28+
29+
30+
shifts.append(s)
31+
error.append(e)
32+
phasediff.append(p)
33+
34+
shifts_xy = np.cumsum(np.array(shifts), axis = 0)
35+
shifts_xy = shifts_xy.tolist()
36+
shifts_xy.insert(0,[0,0])
37+
shifts_xy = np.asarray(shifts_xy)
38+
39+
plt.title('XY-Drift')
40+
plt.plot(shifts_xy[:,0], label = 'x')
41+
plt.plot(shifts_xy[:,1], label = 'y')
42+
plt.legend()
43+
plt.xlabel('Timesteps')
44+
plt.ylabel('Drift in pixel')
45+
plt.savefig('XY-Drift.svg')
46+
plt.clf()
47+
return shifts_xy
48+
49+
def get_z_drift(_data, ref_channel):
50+
z_movie = da.average(da.swapaxes(_data[ref_channel], 2,1), axis = 1).compute()
51+
52+
# correct XY, relative to first frame:
53+
print('Determining drift in Z')
54+
shifts, error, phasediff = [], [], []
55+
for t in tqdm(range(len(z_movie)-1)):
56+
57+
s, e, p = phase_cross_correlation( z_movie[t],
58+
z_movie[t+1],
59+
normalization=None)
60+
61+
62+
63+
shifts.append(s)
64+
error.append(e)
65+
phasediff.append(p)
66+
67+
shifts_z = np.cumsum(np.array(shifts), axis = 0)
68+
for i in shifts_z:
69+
i[1] = 0
70+
71+
shifts_z = shifts_z.tolist()
72+
shifts_z.insert(0,[0,0])
73+
shifts_z = np.asarray(shifts_z)
74+
75+
76+
plt.title('Z-Drift')
77+
plt.plot(shifts_z[:,0], label = 'z')
78+
plt.legend()
79+
plt.xlabel('Timesteps')
80+
plt.ylabel('Drift in pixel')
81+
plt.savefig('Z-Drift.svg')
82+
plt.clf()
83+
84+
return shifts_z
85+
86+
def translate_stack(_image, _shift):
87+
88+
# define warp matrix:
89+
tform = AffineTransform(translation=-_shift)
90+
91+
# construct shifted stack
92+
_image_out = []
93+
94+
for c,C in enumerate(_image):
95+
_out_stack = []
96+
for z,Z in enumerate(C):
97+
98+
# here is the delayed function that allows to scatter everthing nicely across workers
99+
_out_stack.append(dask.delayed(affine_transform)(Z, np.array(tform)))
100+
101+
# format arrays for proper export
102+
_out_stack = da.stack([da.from_delayed(x, shape=np.shape(Z), dtype=np.dtype(Z)) for x in _out_stack])
103+
_image_out.append(_out_stack)
104+
105+
return da.stack(_image_out)
106+
107+
108+
def apply_xy_drift(_data, xy_drift):
109+
110+
# Expectes _data to have the following shape:
111+
# np.shape_data = (channel, time, z, y, x)
112+
113+
# Swap axes to iterate over time which aligns with chunks
114+
_data = da.swapaxes(_data, 0,1)
115+
116+
print('Scheduling tasks...')
117+
_translated_data = []
118+
for t,T in tqdm(enumerate(_data)):
119+
_translated_data.append(translate_stack(T, xy_drift[t]))
120+
121+
# data formatting
122+
_data_out = da.stack(_translated_data)
123+
# just in case
124+
_data = da.swapaxes(_data, 0,1)
125+
_data_out = da.swapaxes(_data_out, 0,1)
126+
127+
print('XY-shift has been scheduled.')
128+
129+
return _data_out
130+
131+
def rotate_stack(_image, _alpha):
132+
133+
# construct shifted stack
134+
_image_out = []
135+
136+
for c,C in enumerate(_image):
137+
_out_stack = []
138+
for z,Z in enumerate(C):
139+
140+
# here is the delayed function that allows to scatter everthing nicely across workers
141+
_out_stack.append(dask.delayed(rotate)(Z, -_alpha,reshape=False))
142+
143+
# format arrays for proper export
144+
_out_stack = da.stack([da.from_delayed(x, shape=np.shape(Z), dtype=np.dtype(Z)) for x in _out_stack])
145+
_image_out.append(_out_stack)
146+
147+
return da.stack(_image_out)
148+
149+
150+
151+
def apply_z_drift(_data, z_drift):
152+
153+
# swap axes around
154+
_data = da.swapaxes(_data, 0,1)
155+
_data = da.swapaxes(_data, 3,2)
156+
157+
_data_out = []
158+
for t,T in tqdm(enumerate(_data)):
159+
_data_out.append(translate_stack(T, z_drift[t]))
160+
161+
# Stack the list to a dask array and swap back the axes.
162+
_data_out = da.stack(_data_out)
163+
_data = da.swapaxes(_data,2,3)
164+
_data = da.swapaxes(_data,0,1)
165+
_data_out = da.swapaxes(_data_out,2,3)
166+
_data_out = da.swapaxes(_data_out,0,1)
167+
168+
return _data_out
169+
170+
def crop_data(data, xy_drift, z_drift):
171+
y_crop = [int(np.max(xy_drift[:,0])),int(np.shape(data)[-1]-abs(np.min(xy_drift[:,0])))]
172+
x_crop = [int(np.max(xy_drift[:,1])),int(np.shape(data)[-2]-abs(np.min(xy_drift[:,1])))]
173+
z_crop = [int(np.max(z_drift[:,0])),int(np.shape(data)[-3]-abs(np.min(z_drift[:,0])))]
174+
175+
cropped = data[:,:,
176+
z_crop[0]:z_crop[1],
177+
y_crop[0]:y_crop[1],
178+
x_crop[0]:x_crop[1]]
179+
return cropped
180+
181+
def get_rotation(data, ref_channel):
182+
xy_movie = da.average(data[ref_channel], axis = 1).compute()
183+
184+
# get image radius
185+
radius = int(min(da.shape(data)[3], da.shape(data)[4])/2)
186+
187+
print('Determining rotation')
188+
shifts, error, phasediff = [], [], []
189+
for t in tqdm(range(len(xy_movie)-1)):
190+
191+
s, e, p = phase_cross_correlation( warp_polar(xy_movie[t], radius = radius),
192+
warp_polar(xy_movie[t+1], radius = radius),
193+
normalization=None)
194+
195+
196+
197+
shifts.append(s)
198+
error.append(e)
199+
phasediff.append(p)
200+
201+
shifts_a = np.cumsum(np.array(shifts), axis = 0)
202+
for i in shifts:
203+
i[0] = 0
204+
205+
shifts_a = shifts_a.tolist()
206+
shifts_a.insert(0,[0,0])
207+
shifts_a = np.asarray(shifts_a)
208+
209+
plt.title('alpha-Drift')
210+
plt.plot(shifts_a[:,0], label = 'alpha')
211+
plt.legend()
212+
plt.xlabel('Timesteps')
213+
plt.ylabel('Rotation in degree')
214+
plt.savefig('Rotation-Drift.svg')
215+
plt.clf()
216+
217+
return shifts_a[:,0]
218+
219+
def apply_alpha_drift(_data, alpha_drift):
220+
221+
# swap axes around
222+
_data = da.swapaxes(_data, 0,1)
223+
224+
_data_out = []
225+
for t,T in tqdm(enumerate(_data)):
226+
_data_out.append(rotate_stack(T, alpha_drift[t]))
227+
228+
# Stack the list to a dask array and swap back the axes.
229+
_data_out = da.stack(_data_out)
230+
_data = da.swapaxes(_data,0,1)
231+
_data_out = da.swapaxes(_data_out,0,1)
232+
233+
234+
return _data_out
235+

0 commit comments

Comments
 (0)