Skip to content

Commit 0753cd0

Browse files
authored
Merge pull request Quasars#29 from ngergihun/imagestacking
NEW FEATURE: Imagestacking
2 parents 17e6d6c + e9487e7 commit 0753cd0

File tree

4 files changed

+290
-2
lines changed

4 files changed

+290
-2
lines changed

pySNOM/images.py

Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44
from enum import Enum
55
from pySNOM.defaults import Defaults
66

7+
from skimage.transform import warp
8+
from skimage.registration import optical_flow_tvl1, phase_cross_correlation
9+
from scipy.ndimage import fourier_shift
10+
711
MeasurementModes = Enum(
812
"MeasurementModes",
913
["None", "AFM", "PsHet", "WLI", "PTE", "TappingAFMIR", "ContactAFM"],
@@ -347,3 +351,196 @@ def mask_from_booleans(bool_mask, bad_values = False):
347351
def mask_from_datacondition(condition):
348352
mshape = np.shape(condition)
349353
return np.where(condition,np.nan*np.ones(mshape),np.ones(mshape))
354+
355+
356+
class CalculateOpticalFlow(Transformation):
357+
"""Calculates the pixel coordiate drifts between reference and template image"""
358+
359+
def __init__(self, image_ref):
360+
self.image_ref = image_ref
361+
362+
def transform(self, image):
363+
v, u = optical_flow_tvl1(
364+
self.image_ref / np.nanmax(self.image_ref), image / np.nanmax(image)
365+
)
366+
return v, u
367+
368+
369+
class WrapImage(Transformation):
370+
"""Applies the pixel-by-pixel drift correction calculated by OpticalFlow"""
371+
372+
def __init__(self, v, u):
373+
self.v = v
374+
self.u = u
375+
376+
def transform(self, image):
377+
nr, nc = image.shape
378+
row_coords, col_coords = np.meshgrid(
379+
np.arange(nr), np.arange(nc), indexing="ij"
380+
)
381+
return warp(
382+
image, np.array([row_coords + self.v, col_coords + self.u]), mode="edge"
383+
)
384+
385+
386+
class CalculateXCorrDrift(Transformation):
387+
"""Calculates the drift between reference and template image"""
388+
389+
def __init__(self, image_ref):
390+
self.image_ref = image_ref
391+
392+
def transform(self, image):
393+
shift, _, _ = phase_cross_correlation(self.image_ref, image)
394+
return shift
395+
396+
397+
class CorrectImageDrift(Transformation):
398+
"""Rearranges image pixels to correct image shift calculated by cross-correlation"""
399+
400+
def __init__(self, shift):
401+
self.shift = shift
402+
403+
def transform(self, image):
404+
offset_phase = fourier_shift(np.fft.fftn(image), self.shift)
405+
offset_phase = np.fft.ifftn(offset_phase)
406+
return offset_phase.real
407+
408+
409+
class AlignImageStack(Transformation):
410+
"""Calculates the drift between the given images and organize the comman areas into an aligned stack"""
411+
412+
def __init__(self):
413+
pass
414+
415+
def calculate(self, images):
416+
shifts = []
417+
crossrect = [0, 0, np.shape(images[0])[0], np.shape(images[0])[1]]
418+
if len(images) > 1:
419+
xcorr = CalculateXCorrDrift(images[0])
420+
for i in range(len(images)):
421+
if i > 0:
422+
shifts.append(xcorr.transform(images[i]))
423+
crossrect = shifted_cross_section(
424+
rect1=crossrect,
425+
rect2=[
426+
-shifts[-1][0],
427+
shifts[-1][1],
428+
np.shape(images[i])[0],
429+
np.shape(images[i])[1],
430+
],
431+
)
432+
return shifts, crossrect
433+
else:
434+
return None
435+
436+
def transform(self, images, shifts, crossrect):
437+
aligned_stack = []
438+
for i in range(len(images)):
439+
if i > 0:
440+
shifter = CorrectImageDrift(shifts[i - 1])
441+
aligned_stack.append(shifter.transform(images[i]))
442+
aligned_stack[i] = cut_image(aligned_stack[i], crossrect)
443+
else:
444+
aligned_stack.append(cut_image(images[i], crossrect))
445+
return aligned_stack
446+
447+
448+
def sort_image_stack(images, wns):
449+
"""Sort the image stack based on the wavenumber list"""
450+
451+
idxs = np.argsort(np.asarray(wns))
452+
images = [images[i] for i in idxs]
453+
wns = [wns[i] for i in idxs]
454+
455+
return images, wns
456+
457+
458+
def create_nparray_stack(measlist):
459+
"""Creates a numpy array stack from a list of measurements, organized as [ rows, columns, wavelengths ] (compatible with quasar io utils)"""
460+
461+
stack = np.zeros(
462+
(np.shape(measlist[0])[0], np.shape(measlist[0])[1], len(measlist))
463+
)
464+
465+
for i, meas in enumerate(measlist):
466+
stack[:, :, i] = meas
467+
468+
return stack
469+
470+
471+
def dict_from_imagestack(X, channelname, wn=None, is_interferogram = True):
472+
"""Converts the image stack into a pySNOM spectra or interferograms compatible dictionary"""
473+
final_dict = {}
474+
params = {}
475+
476+
X = np.asarray(X)
477+
478+
params["PixelArea"] = [X.shape[1], X.shape[2], X.shape[0]]
479+
params["Averaging"] = 1
480+
params["Scan"] = "Fourier Scan"
481+
482+
final_dict[channelname] = flatten_stack(X)
483+
484+
y_loc = np.repeat(np.arange(X.shape[1]), X.shape[2])
485+
x_loc = np.tile(np.arange(X.shape[2]), X.shape[1])
486+
487+
final_dict["Row"] = np.repeat(y_loc, X.shape[0])
488+
final_dict["Column"] = np.repeat(x_loc, X.shape[0])
489+
490+
if is_interferogram:
491+
depth_channel_name = "M"
492+
else:
493+
depth_channel_name = "Wavenumber"
494+
495+
if wn is not None:
496+
final_dict[depth_channel_name] = np.tile(wn, X.shape[1] * X.shape[2])
497+
else:
498+
final_dict[depth_channel_name] = np.tile(
499+
np.arange(X.shape[0]), X.shape[1] * X.shape[2]
500+
)
501+
502+
return final_dict, params
503+
504+
def flatten_stack(imagestack):
505+
""" Flatten out values in an image stack to be aneble to add it to spectral dictionaries """
506+
imagestack = np.asarray(imagestack)
507+
flattened_values = imagestack.reshape((imagestack.shape[0], imagestack.shape[1] * imagestack.shape[2]))
508+
return np.ravel(flattened_values, order="F")
509+
510+
def shifted_cross_section(rect1: list, rect2: list):
511+
"""Calculates the cross-section of two rectangle shifted to each other"""
512+
x1 = rect1[1]
513+
x2 = rect2[1]
514+
y1 = rect1[0]
515+
y2 = rect2[0]
516+
W1 = rect1[3]
517+
W2 = rect2[3]
518+
H1 = rect1[2]
519+
H2 = rect2[2]
520+
521+
if y2 > y1:
522+
Hn = H1 - (y2 - y1)
523+
yn = y2
524+
elif (y2 < y1) and (y1 + H1 > y2 + H2): # Negative shift and higher than H2
525+
Hn = H2 + (y2 - y1)
526+
yn = y1
527+
else:
528+
Hn = H1
529+
yn = y1
530+
531+
if x2 > x1: # Positive shift
532+
Wn = W1 - (x2 - x1)
533+
xn = x2
534+
elif (x2 < x1) and (x1 + W1 > x2 + W2): # Negative shift and higher than W2
535+
Wn = W2 + (x2 - x1)
536+
xn = x1
537+
else:
538+
Wn = W1
539+
xn = x1
540+
541+
return int(yn), int(xn), int(Hn), int(Wn)
542+
543+
544+
def cut_image(image, rect):
545+
"""Cuts the part of the image array defined by rectangle"""
546+
return image[-(rect[2]) : -(rect[0] + 1), rect[1] : rect[1] + rect[3]]

pySNOM/readers.py

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22
import gsffile
33
import numpy as np
44
import pandas as pd
5-
5+
import os
6+
import pathlib
7+
import re
68

79
class Reader:
810
def __init__(self, fullfilepath=None):
@@ -342,3 +344,66 @@ def read(self):
342344
data[key] = np.asarray(data[key])
343345

344346
return data, params
347+
348+
349+
class ImageStackReader(Reader):
350+
"""Reads a list of images from the subfolders of the specified folder by loading the files that contain the pattern string int the filename"""
351+
352+
def __init__(self, folder=None):
353+
super().__init__(folder)
354+
self.folder = self.filename
355+
356+
def read(self, pattern):
357+
imagestack = []
358+
wns = []
359+
filepaths = get_filenames(self.folder, pattern)
360+
361+
for i, path in enumerate(filepaths):
362+
data_reader = GsfReader(path)
363+
imagestack.append(data_reader.read().data)
364+
365+
try:
366+
txtpath = recreate_infofile_name_from_path(path)
367+
inforeader = NeaInfoReader(txtpath)
368+
infodict = inforeader.read()
369+
if infodict["TargetWavelength"] == "":
370+
wn = infodict["InterferometerCenterDistance"][0]
371+
wns.append(wn)
372+
else:
373+
wn = infodict["TargetWavelength"]
374+
if wn < 50.0:
375+
wn = 10000 / wn
376+
wns.append(wn)
377+
except:
378+
wns.append(i)
379+
380+
idxs = np.argsort(np.asarray(wns))
381+
imagestack = [imagestack[i] for i in idxs]
382+
wns = [wns[i] for i in idxs]
383+
384+
return imagestack, wns
385+
386+
387+
def get_filenames(folder, pattern):
388+
"""Returns the filepath of all files in the subfolders of the specified folder that contain pattern string in the filename"""
389+
390+
filepaths = []
391+
392+
for subfolder in os.listdir(folder):
393+
if os.path.isdir(os.path.join(folder, subfolder)):
394+
for name in os.listdir(os.path.join(folder, subfolder)):
395+
if re.search(pattern, name):
396+
subpath = os.path.join(subfolder, name)
397+
filepaths.append(os.path.join(folder, subpath))
398+
399+
return filepaths
400+
401+
402+
def recreate_infofile_name_from_path(filepath):
403+
"""Recreates the name of the info file from the path of the data file"""
404+
405+
pathparts = list(pathlib.PurePath(filepath).parts)
406+
newparts = pathparts[:-1]
407+
newparts.append(pathparts[-2] + ".txt")
408+
409+
return str(pathlib.PurePath(*newparts))

pySNOM/tests/test_transform.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import numpy as np
44

5-
from pySNOM.images import LineLevel, BackgroundPolyFit, SimpleNormalize, DataTypes, mask_from_datacondition
5+
from pySNOM.images import LineLevel, BackgroundPolyFit, SimpleNormalize, DataTypes, AlignImageStack, mask_from_datacondition, dict_from_imagestack
66

77

88
class TestLineLevel(unittest.TestCase):
@@ -172,5 +172,30 @@ def test_min(self):
172172
out = l.transform(d,mask=mask)
173173
np.testing.assert_almost_equal(out, [-1.0,0.0,1.0])
174174

175+
176+
class TestAlignImageStack(unittest.TestCase):
177+
def test_stackalignment(self):
178+
image1 = np.zeros((50, 100))
179+
image2 = np.zeros((50, 100))
180+
image1[10:40, 10:40] = 1
181+
image2[20:50, 20:50] = 1
182+
183+
aligner = AlignImageStack()
184+
shifts, crossrect = aligner.calculate([image1, image2])
185+
np.testing.assert_equal(shifts, [np.asarray([-10.0, -10.0])])
186+
np.testing.assert_equal(crossrect, [10, 0, 40, 90])
187+
188+
out = aligner.transform([image1, image2], shifts, crossrect)
189+
np.testing.assert_equal(np.shape(out), (2, 29, 90))
190+
191+
class TestHelperFunctions(unittest.TestCase):
192+
def test_dictfromimagestack(self):
193+
stack = [np.zeros((50, 100)), np.zeros((50, 100))]
194+
195+
out, outparams = dict_from_imagestack(stack,"O2A")
196+
197+
self.assertEqual(outparams["PixelArea"], [50,100,2])
198+
self.assertTrue("M" in list(out.keys()))
199+
175200
if __name__ == "__main__":
176201
unittest.main()

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@ setuptools
44
scipy
55
gsffile
66
pandas>=1.4.0
7+
scikit-image

0 commit comments

Comments
 (0)