Skip to content

Commit eb738db

Browse files
committed
numpy drift functions
1 parent 110836d commit eb738db

2 files changed

Lines changed: 216 additions & 2 deletions

File tree

src/napatrackmater/_version.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
__version__ = version = "5.7.8"
2-
__version_tuple__ = version_tuple = (5, 7, 8)
1+
__version__ = version = "5.7.9"
2+
__version_tuple__ = version_tuple = (5, 7, 9)

src/napatrackmater/drift.py

Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,3 +233,217 @@ def apply_alpha_drift(_data, alpha_drift):
233233

234234
return _data_out
235235

236+
def get_xy_drift_numpy(_data, ref_channel=0):
237+
"""
238+
Compute XY drift for NumPy array shape (channel, time, z, y, x)
239+
Returns numpy array of shape (time, 2)
240+
"""
241+
# replicate da.average
242+
xy_movie = np.average(_data[ref_channel], axis=1)
243+
244+
print('Determining drift in XY')
245+
shifts, error, phasediff = [], [], []
246+
for t in tqdm(range(len(xy_movie) - 1)):
247+
s, e, p = phase_cross_correlation(
248+
xy_movie[t], xy_movie[t + 1], normalization=None
249+
)
250+
shifts.append(s)
251+
error.append(e)
252+
phasediff.append(p)
253+
254+
# cumulative sum
255+
shifts_xy = np.cumsum(np.array(shifts), axis=0)
256+
# set up list form
257+
shifts_xy = shifts_xy.tolist()
258+
shifts_xy.insert(0, [0, 0])
259+
shifts_xy = np.asarray(shifts_xy)
260+
261+
# replicate plotting
262+
plt.title('XY-Drift')
263+
plt.plot(shifts_xy[:, 0], label='x')
264+
plt.plot(shifts_xy[:, 1], label='y')
265+
plt.legend()
266+
plt.xlabel('Timesteps')
267+
plt.ylabel('Drift in pixel')
268+
plt.savefig('XY-Drift.svg')
269+
plt.clf()
270+
271+
return shifts_xy
272+
273+
274+
def get_z_drift_numpy(_data, ref_channel=0):
275+
"""
276+
Compute Z drift for NumPy array shape (channel, time, z, y, x)
277+
Returns numpy array of shape (time, 2) with only z shifts
278+
"""
279+
# replicate da.swapaxes and da.average
280+
z_movie = np.average(np.swapaxes(_data[ref_channel], 2, 1), axis=1)
281+
282+
print('Determining drift in Z')
283+
shifts, error, phasediff = [], [], []
284+
for t in tqdm(range(len(z_movie) - 1)):
285+
s, e, p = phase_cross_correlation(
286+
z_movie[t], z_movie[t + 1], normalization=None
287+
)
288+
shifts.append(s)
289+
error.append(e)
290+
phasediff.append(p)
291+
292+
shifts_z = np.cumsum(np.array(shifts), axis=0)
293+
# zero out y component same as Dask
294+
for i in shifts_z:
295+
i[1] = 0
296+
297+
shifts_z = shifts_z.tolist()
298+
shifts_z.insert(0, [0, 0])
299+
shifts_z = np.asarray(shifts_z)
300+
301+
plt.title('Z-Drift')
302+
plt.plot(shifts_z[:, 0], label='z')
303+
plt.legend()
304+
plt.xlabel('Timesteps')
305+
plt.ylabel('Drift in pixel')
306+
plt.savefig('Z-Drift.svg')
307+
plt.clf()
308+
309+
return shifts_z
310+
311+
312+
def translate_stack_numpy(_image, _shift):
313+
"""
314+
Apply 2D XY affine transform to a stack shape (channel, z, y, x)
315+
"""
316+
tform = AffineTransform(translation=-_shift)
317+
_image_out = []
318+
for c, C in enumerate(_image):
319+
_out_stack = []
320+
for z, Z in enumerate(C):
321+
# replicate delayed call signature by using np.array(tform)
322+
_out_stack.append(affine_transform(Z, np.array(tform)))
323+
_image_out.append(np.stack(_out_stack, axis=0))
324+
return np.stack(_image_out, axis=0)
325+
326+
327+
def apply_xy_drift_numpy(_data, xy_drift):
328+
"""
329+
Apply XY drift to NumPy array shape (channel, time, z, y, x)
330+
Returns corrected array of same shape
331+
"""
332+
# replicate da.swapaxes(_data,0,1)
333+
_data = np.swapaxes(_data, 0, 1)
334+
335+
print('Scheduling tasks...')
336+
_translated_data = []
337+
for t, T in tqdm(list(enumerate(_data))):
338+
_translated_data.append(translate_stack_numpy(T, xy_drift[t]))
339+
340+
_data_out = np.stack(_translated_data, axis=0)
341+
# replicate swapping back
342+
_data_out = np.swapaxes(_data_out, 0, 1)
343+
344+
print('XY-shift has been scheduled.')
345+
return _data_out
346+
347+
348+
def rotate_stack_numpy(_image, _alpha):
349+
"""
350+
Apply 2D rotation to a stack shape (channel, z, y, x)
351+
"""
352+
_image_out = []
353+
for c, C in enumerate(_image):
354+
_out_stack = []
355+
for z, Z in enumerate(C):
356+
_out_stack.append(rotate(Z, -_alpha, reshape=False))
357+
_image_out.append(np.stack(_out_stack, axis=0))
358+
return np.stack(_image_out, axis=0)
359+
360+
361+
def apply_z_drift_numpy(_data, z_drift):
362+
"""
363+
Apply Z drift to NumPy array shape (channel, time, z, y, x)
364+
Returns corrected array of same shape
365+
"""
366+
# replicate axis swaps
367+
_data = np.swapaxes(_data, 0, 1)
368+
_data = np.swapaxes(_data, 3, 2)
369+
370+
_data_out = []
371+
for t, T in tqdm(list(enumerate(_data))):
372+
_data_out.append(translate_stack_numpy(T, z_drift[t]))
373+
374+
_data_out = np.stack(_data_out, axis=0)
375+
# swap back
376+
_data_out = np.swapaxes(_data_out, 2, 3)
377+
_data_out = np.swapaxes(_data_out, 0, 1)
378+
379+
return _data_out
380+
381+
382+
def crop_data_numpy(data, xy_drift, z_drift):
383+
"""
384+
Crop edges based on drift to remove border artifacts
385+
"""
386+
y_crop = [int(np.max(xy_drift[:, 0])), int(np.shape(data)[-1] - abs(np.min(xy_drift[:, 0])))]
387+
x_crop = [int(np.max(xy_drift[:, 1])), int(np.shape(data)[-2] - abs(np.min(xy_drift[:, 1])))]
388+
z_crop = [int(np.max(z_drift[:, 0])), int(np.shape(data)[-3] - abs(np.min(z_drift[:, 0])))]
389+
390+
return data[:, :,
391+
z_crop[0]:z_crop[1],
392+
y_crop[0]:y_crop[1],
393+
x_crop[0]:x_crop[1]]
394+
395+
396+
def get_rotation_numpy(data, ref_channel=0):
397+
"""
398+
Estimate rotation drift for NumPy array shape (channel, time, z, y, x)
399+
Returns numpy array of shape (time,)
400+
"""
401+
xy_movie = np.average(data[ref_channel], axis=1)
402+
radius = int(min(data.shape[3], data.shape[4]) / 2)
403+
404+
print('Determining rotation')
405+
shifts, error, phasediff = [], [], []
406+
for t in tqdm(range(len(xy_movie) - 1)):
407+
s, e, p = phase_cross_correlation(
408+
warp_polar(xy_movie[t], radius=radius),
409+
warp_polar(xy_movie[t + 1], radius=radius),
410+
normalization=None
411+
)
412+
shifts.append(s)
413+
error.append(e)
414+
phasediff.append(p)
415+
416+
shifts_a = np.cumsum(np.array(shifts), axis=0)
417+
# replicate bug: zero original shifts then ignore in cumsum
418+
for i in shifts:
419+
i[0] = 0
420+
421+
shifts_a = shifts_a.tolist()
422+
shifts_a.insert(0, [0, 0])
423+
shifts_a = np.asarray(shifts_a)
424+
425+
plt.title('alpha-Drift')
426+
plt.plot(shifts_a[:, 0], label='alpha')
427+
plt.legend()
428+
plt.xlabel('Timesteps')
429+
plt.ylabel('Rotation in degree')
430+
plt.savefig('Rotation-Drift.svg')
431+
plt.clf()
432+
433+
return shifts_a[:, 0]
434+
435+
436+
def apply_alpha_drift_numpy(_data, alpha_drift):
437+
"""
438+
Apply rotation drift to NumPy array shape (channel, time, z, y, x)
439+
Returns corrected array of same shape
440+
"""
441+
_data = np.swapaxes(_data, 0, 1)
442+
_data_out = []
443+
for t, T in tqdm(list(enumerate(_data))):
444+
_data_out.append(rotate_stack_numpy(T, alpha_drift[t]))
445+
446+
_data_out = np.stack(_data_out, axis=0)
447+
_data_out = np.swapaxes(_data_out, 0, 1)
448+
449+
return _data_out

0 commit comments

Comments
 (0)