|
| 1 | + |
| 2 | +import numpy as np |
| 3 | +import dask |
| 4 | +import dask.array as da |
| 5 | +import matplotlib.pyplot as plt |
| 6 | +from skimage.registration import phase_cross_correlation |
| 7 | +from skimage.transform import warp_polar |
| 8 | +from skimage.transform import AffineTransform |
| 9 | +from scipy.ndimage import affine_transform, rotate |
| 10 | +# Utility |
| 11 | +from tqdm import tqdm |
| 12 | +import tifffile |
| 13 | +import sys |
| 14 | +import os |
| 15 | + |
| 16 | +def get_xy_drift(_data, ref_channel): |
| 17 | + xy_movie = da.average(_data[ref_channel], axis = 1).compute() |
| 18 | + |
| 19 | + # correct XY, relative to first frame: |
| 20 | + print('Determining drift in XY') |
| 21 | + shifts, error, phasediff = [], [], [] |
| 22 | + for t in tqdm(range(len(xy_movie)-1)): |
| 23 | + |
| 24 | + s, e, p = phase_cross_correlation(xy_movie[t], |
| 25 | + xy_movie[t+1], |
| 26 | + normalization=None) |
| 27 | + |
| 28 | + |
| 29 | + |
| 30 | + shifts.append(s) |
| 31 | + error.append(e) |
| 32 | + phasediff.append(p) |
| 33 | + |
| 34 | + shifts_xy = np.cumsum(np.array(shifts), axis = 0) |
| 35 | + shifts_xy = shifts_xy.tolist() |
| 36 | + shifts_xy.insert(0,[0,0]) |
| 37 | + shifts_xy = np.asarray(shifts_xy) |
| 38 | + |
| 39 | + plt.title('XY-Drift') |
| 40 | + plt.plot(shifts_xy[:,0], label = 'x') |
| 41 | + plt.plot(shifts_xy[:,1], label = 'y') |
| 42 | + plt.legend() |
| 43 | + plt.xlabel('Timesteps') |
| 44 | + plt.ylabel('Drift in pixel') |
| 45 | + plt.savefig('XY-Drift.svg') |
| 46 | + plt.clf() |
| 47 | + return shifts_xy |
| 48 | + |
| 49 | +def get_z_drift(_data, ref_channel): |
| 50 | + z_movie = da.average(da.swapaxes(_data[ref_channel], 2,1), axis = 1).compute() |
| 51 | + |
| 52 | + # correct XY, relative to first frame: |
| 53 | + print('Determining drift in Z') |
| 54 | + shifts, error, phasediff = [], [], [] |
| 55 | + for t in tqdm(range(len(z_movie)-1)): |
| 56 | + |
| 57 | + s, e, p = phase_cross_correlation( z_movie[t], |
| 58 | + z_movie[t+1], |
| 59 | + normalization=None) |
| 60 | + |
| 61 | + |
| 62 | + |
| 63 | + shifts.append(s) |
| 64 | + error.append(e) |
| 65 | + phasediff.append(p) |
| 66 | + |
| 67 | + shifts_z = np.cumsum(np.array(shifts), axis = 0) |
| 68 | + for i in shifts_z: |
| 69 | + i[1] = 0 |
| 70 | + |
| 71 | + shifts_z = shifts_z.tolist() |
| 72 | + shifts_z.insert(0,[0,0]) |
| 73 | + shifts_z = np.asarray(shifts_z) |
| 74 | + |
| 75 | + |
| 76 | + plt.title('Z-Drift') |
| 77 | + plt.plot(shifts_z[:,0], label = 'z') |
| 78 | + plt.legend() |
| 79 | + plt.xlabel('Timesteps') |
| 80 | + plt.ylabel('Drift in pixel') |
| 81 | + plt.savefig('Z-Drift.svg') |
| 82 | + plt.clf() |
| 83 | + |
| 84 | + return shifts_z |
| 85 | + |
| 86 | +def translate_stack(_image, _shift): |
| 87 | + |
| 88 | + # define warp matrix: |
| 89 | + tform = AffineTransform(translation=-_shift) |
| 90 | + |
| 91 | + # construct shifted stack |
| 92 | + _image_out = [] |
| 93 | + |
| 94 | + for c,C in enumerate(_image): |
| 95 | + _out_stack = [] |
| 96 | + for z,Z in enumerate(C): |
| 97 | + |
| 98 | + # here is the delayed function that allows to scatter everthing nicely across workers |
| 99 | + _out_stack.append(dask.delayed(affine_transform)(Z, np.array(tform))) |
| 100 | + |
| 101 | + # format arrays for proper export |
| 102 | + _out_stack = da.stack([da.from_delayed(x, shape=np.shape(Z), dtype=np.dtype(Z)) for x in _out_stack]) |
| 103 | + _image_out.append(_out_stack) |
| 104 | + |
| 105 | + return da.stack(_image_out) |
| 106 | + |
| 107 | + |
| 108 | +def apply_xy_drift(_data, xy_drift): |
| 109 | + |
| 110 | + # Expectes _data to have the following shape: |
| 111 | + # np.shape_data = (channel, time, z, y, x) |
| 112 | + |
| 113 | + # Swap axes to iterate over time which aligns with chunks |
| 114 | + _data = da.swapaxes(_data, 0,1) |
| 115 | + |
| 116 | + print('Scheduling tasks...') |
| 117 | + _translated_data = [] |
| 118 | + for t,T in tqdm(enumerate(_data)): |
| 119 | + _translated_data.append(translate_stack(T, xy_drift[t])) |
| 120 | + |
| 121 | + # data formatting |
| 122 | + _data_out = da.stack(_translated_data) |
| 123 | + # just in case |
| 124 | + _data = da.swapaxes(_data, 0,1) |
| 125 | + _data_out = da.swapaxes(_data_out, 0,1) |
| 126 | + |
| 127 | + print('XY-shift has been scheduled.') |
| 128 | + |
| 129 | + return _data_out |
| 130 | + |
| 131 | +def rotate_stack(_image, _alpha): |
| 132 | + |
| 133 | + # construct shifted stack |
| 134 | + _image_out = [] |
| 135 | + |
| 136 | + for c,C in enumerate(_image): |
| 137 | + _out_stack = [] |
| 138 | + for z,Z in enumerate(C): |
| 139 | + |
| 140 | + # here is the delayed function that allows to scatter everthing nicely across workers |
| 141 | + _out_stack.append(dask.delayed(rotate)(Z, -_alpha,reshape=False)) |
| 142 | + |
| 143 | + # format arrays for proper export |
| 144 | + _out_stack = da.stack([da.from_delayed(x, shape=np.shape(Z), dtype=np.dtype(Z)) for x in _out_stack]) |
| 145 | + _image_out.append(_out_stack) |
| 146 | + |
| 147 | + return da.stack(_image_out) |
| 148 | + |
| 149 | + |
| 150 | + |
| 151 | +def apply_z_drift(_data, z_drift): |
| 152 | + |
| 153 | + # swap axes around |
| 154 | + _data = da.swapaxes(_data, 0,1) |
| 155 | + _data = da.swapaxes(_data, 3,2) |
| 156 | + |
| 157 | + _data_out = [] |
| 158 | + for t,T in tqdm(enumerate(_data)): |
| 159 | + _data_out.append(translate_stack(T, z_drift[t])) |
| 160 | + |
| 161 | + # Stack the list to a dask array and swap back the axes. |
| 162 | + _data_out = da.stack(_data_out) |
| 163 | + _data = da.swapaxes(_data,2,3) |
| 164 | + _data = da.swapaxes(_data,0,1) |
| 165 | + _data_out = da.swapaxes(_data_out,2,3) |
| 166 | + _data_out = da.swapaxes(_data_out,0,1) |
| 167 | + |
| 168 | + return _data_out |
| 169 | + |
| 170 | +def crop_data(data, xy_drift, z_drift): |
| 171 | + y_crop = [int(np.max(xy_drift[:,0])),int(np.shape(data)[-1]-abs(np.min(xy_drift[:,0])))] |
| 172 | + x_crop = [int(np.max(xy_drift[:,1])),int(np.shape(data)[-2]-abs(np.min(xy_drift[:,1])))] |
| 173 | + z_crop = [int(np.max(z_drift[:,0])),int(np.shape(data)[-3]-abs(np.min(z_drift[:,0])))] |
| 174 | + |
| 175 | + cropped = data[:,:, |
| 176 | + z_crop[0]:z_crop[1], |
| 177 | + y_crop[0]:y_crop[1], |
| 178 | + x_crop[0]:x_crop[1]] |
| 179 | + return cropped |
| 180 | + |
| 181 | +def get_rotation(data, ref_channel): |
| 182 | + xy_movie = da.average(data[ref_channel], axis = 1).compute() |
| 183 | + |
| 184 | + # get image radius |
| 185 | + radius = int(min(da.shape(data)[3], da.shape(data)[4])/2) |
| 186 | + |
| 187 | + print('Determining rotation') |
| 188 | + shifts, error, phasediff = [], [], [] |
| 189 | + for t in tqdm(range(len(xy_movie)-1)): |
| 190 | + |
| 191 | + s, e, p = phase_cross_correlation( warp_polar(xy_movie[t], radius = radius), |
| 192 | + warp_polar(xy_movie[t+1], radius = radius), |
| 193 | + normalization=None) |
| 194 | + |
| 195 | + |
| 196 | + |
| 197 | + shifts.append(s) |
| 198 | + error.append(e) |
| 199 | + phasediff.append(p) |
| 200 | + |
| 201 | + shifts_a = np.cumsum(np.array(shifts), axis = 0) |
| 202 | + for i in shifts: |
| 203 | + i[0] = 0 |
| 204 | + |
| 205 | + shifts_a = shifts_a.tolist() |
| 206 | + shifts_a.insert(0,[0,0]) |
| 207 | + shifts_a = np.asarray(shifts_a) |
| 208 | + |
| 209 | + plt.title('alpha-Drift') |
| 210 | + plt.plot(shifts_a[:,0], label = 'alpha') |
| 211 | + plt.legend() |
| 212 | + plt.xlabel('Timesteps') |
| 213 | + plt.ylabel('Rotation in degree') |
| 214 | + plt.savefig('Rotation-Drift.svg') |
| 215 | + plt.clf() |
| 216 | + |
| 217 | + return shifts_a[:,0] |
| 218 | + |
| 219 | +def apply_alpha_drift(_data, alpha_drift): |
| 220 | + |
| 221 | + # swap axes around |
| 222 | + _data = da.swapaxes(_data, 0,1) |
| 223 | + |
| 224 | + _data_out = [] |
| 225 | + for t,T in tqdm(enumerate(_data)): |
| 226 | + _data_out.append(rotate_stack(T, alpha_drift[t])) |
| 227 | + |
| 228 | + # Stack the list to a dask array and swap back the axes. |
| 229 | + _data_out = da.stack(_data_out) |
| 230 | + _data = da.swapaxes(_data,0,1) |
| 231 | + _data_out = da.swapaxes(_data_out,0,1) |
| 232 | + |
| 233 | + |
| 234 | + return _data_out |
| 235 | + |
0 commit comments