Skip to content

Commit 17e6d6c

Browse files
authored
Merge pull request Quasars#31 from ngergihun/masking
Masking
2 parents c0f3ed2 + ad0aeb3 commit 17e6d6c

File tree

2 files changed

+188
-48
lines changed

2 files changed

+188
-48
lines changed

pySNOM/images.py

Lines changed: 67 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -189,36 +189,57 @@ def type_from_channelname(channelname):
189189

190190

191191
class Transformation:
192+
192193
def transform(self, data):
193194
raise NotImplementedError()
194195

195196

196-
class LineLevel(Transformation):
197+
class MaskedTransformation(Transformation):
198+
199+
def calculate(self, data, mask=None):
200+
raise NotImplementedError()
201+
202+
def correct(self, data, correction):
203+
""" Applies the calculated corrections to the data """
204+
if self.datatype == DataTypes.Amplitude:
205+
return data / correction
206+
else:
207+
return data - correction
208+
209+
def transform(self, data, mask=None):
210+
""" Calculates and applies the corrections to the data taking into account the mask if given """
211+
correction = self.calculate(data, mask=mask)
212+
return self.correct(data, correction)
213+
214+
215+
class LineLevel(MaskedTransformation):
197216
def __init__(self, method="median", datatype=DataTypes.Phase):
198217
self.method = method
199218
self.datatype = datatype
200219

201-
def transform(self, data):
202-
if self.method == "median":
220+
def calculate(self, data, mask = None):
221+
if mask is not None:
222+
data = mask*data
223+
224+
if self.method == 'median':
203225
norm = np.nanmedian(data, axis=1, keepdims=True)
204226
elif self.method == "mean":
205227
norm = np.nanmean(data, axis=1, keepdims=True)
206228
elif self.method == "difference":
207229
if self.datatype == DataTypes.Amplitude:
208230
norm = np.nanmedian(data[1:] / data[:-1], axis=1, keepdims=True)
231+
norm = np.append(norm,1)
209232
else:
210233
norm = np.nanmedian(data[1:] - data[:-1], axis=1, keepdims=True)
211-
data = data[:-1] # difference does not make sense for the last row
234+
norm = np.append(norm,0) # difference does not make sense for the last row
235+
norm = np.reshape(norm, (norm.size,1))
212236
else:
213237
if self.datatype == DataTypes.Amplitude:
214238
norm = 1
215239
else:
216240
norm = 0
217241

218-
if self.datatype == DataTypes.Amplitude:
219-
return data / norm
220-
else:
221-
return data - norm
242+
return norm
222243

223244

224245
class RotatePhase(Transformation):
@@ -248,43 +269,41 @@ def transform(self, data):
248269
)
249270

250271

251-
class SimpleNormalize(Transformation):
272+
class SimpleNormalize(MaskedTransformation):
252273
def __init__(self, method="median", value=1.0, datatype=DataTypes.Phase):
253274
self.method = method
254275
self.value = value
255276
self.datatype = datatype
256277

257-
def transform(self, data):
278+
def calculate(self, data, mask=None):
279+
""" Calculates and returns the image corrections using mask (if given) without applying it to the data"""
280+
if mask is not None:
281+
data = mask*data
282+
258283
match self.method:
259-
case "median":
260-
if self.datatype == DataTypes.Amplitude:
261-
return data / np.nanmedian(data)
262-
else:
263-
return data - np.nanmedian(data)
264-
case "mean":
265-
if self.datatype == DataTypes.Amplitude:
266-
return data / np.nanmean(data)
267-
else:
268-
return data - np.nanmean(data)
269-
case "manual":
270-
if self.datatype == DataTypes.Amplitude:
271-
return data / self.value
272-
else:
273-
return data - self.value
274-
case "min":
275-
if self.datatype == DataTypes.Amplitude:
276-
return data / np.nanmin(data)
277-
else:
278-
return data - np.nanmin(data)
279-
280-
281-
class BackgroundPolyFit(Transformation):
284+
case 'median':
285+
norm = np.nanmedian(data)
286+
case 'mean':
287+
norm = np.nanmean(data)
288+
case 'manual':
289+
norm = self.value
290+
case 'min':
291+
norm = np.nanmin(data)
292+
293+
return norm
294+
295+
296+
class BackgroundPolyFit(MaskedTransformation):
282297
def __init__(self, xorder=1, yorder=1, datatype=DataTypes.Phase):
283298
self.xorder = int(xorder)
284299
self.yorder = int(yorder)
285300
self.datatype = datatype
286301

287-
def transform(self, data):
302+
def calculate(self, data, mask = None):
303+
""" Calculates and returns the fitted polynomial background using mask (if given) without applying it to the data"""
304+
if mask is not None:
305+
data = mask*data
306+
288307
Z = copy.deepcopy(data)
289308
x = list(range(0, Z.shape[1]))
290309
y = list(range(0, Z.shape[0]))
@@ -310,18 +329,21 @@ def get_basis(x, y, max_order_x=1, max_order_y=1):
310329
A = np.vstack(basis).T
311330
c, r, rank, s = np.linalg.lstsq(A, b, rcond=None)
312331

313-
background = np.sum(
314-
c[:, None, None]
315-
* np.array(get_basis(X, Y, self.xorder, self.yorder)).reshape(
316-
len(basis), *X.shape
317-
),
318-
axis=0,
319-
)
332+
background = np.sum(c[:, None, None] * np.array(get_basis(X, Y, self.xorder, self.yorder)).reshape(len(basis), *X.shape),axis=0)
333+
320334
except ValueError:
321335
background = np.ones(np.shape(data))
322336
print("X and Y order must be integer!")
323337

324-
if self.datatype == DataTypes["Amplitude"]:
325-
return Z / background, background
326-
else:
327-
return Z - background, background
338+
return background
339+
340+
341+
# TODO: Helper functions to create masks or turn other types of masks into 1/Nan mask
342+
def mask_from_booleans(bool_mask, bad_values = False):
343+
""" Turn a boolean array to an array conatining nans and ones"""
344+
mshape = np.shape(bool_mask)
345+
return np.where(bool_mask==bad_values,np.nan*np.ones(mshape),np.ones(mshape))
346+
347+
def mask_from_datacondition(condition):
348+
mshape = np.shape(condition)
349+
return np.where(condition,np.nan*np.ones(mshape),np.ones(mshape))

pySNOM/tests/test_transform.py

Lines changed: 121 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import numpy as np
44

5-
from pySNOM.images import LineLevel, DataTypes
5+
from pySNOM.images import LineLevel, BackgroundPolyFit, SimpleNormalize, DataTypes, mask_from_datacondition
66

77

88
class TestLineLevel(unittest.TestCase):
@@ -46,13 +46,131 @@ def test_difference(self):
4646
d = np.arange(12).reshape(3, -1)[:, [0, 1, 3]]
4747
l = LineLevel(method="difference", datatype=DataTypes.Phase)
4848
out = l.transform(d)
49-
np.testing.assert_almost_equal(out, [[-4.0, -3.0, -1.0], [0.0, 1.0, 3.0]])
49+
np.testing.assert_almost_equal(out, [[-4., -3., -1.], [0., 1., 3.], [8., 9., 11.]])
5050
l = LineLevel(method="difference", datatype=DataTypes.Amplitude)
5151
out = l.transform(d)
5252
np.testing.assert_almost_equal(
53-
out, [[0.0, 0.2, 0.6], [2.2222222, 2.7777778, 3.8888889]]
53+
out, [[0., 0.2, 0.6], [ 2.2222222, 2.7777778, 3.8888889], [8., 9., 11.]]
5454
)
5555

56+
def test_masking_mean(self):
57+
d = np.zeros([8,10])
58+
d[2:6,3:7] = 1
59+
mask = np.ones([8,10])
60+
mask[2:6,3:7] = np.nan
61+
62+
l = LineLevel(method="mean", datatype=DataTypes.Phase)
63+
64+
out = l.transform(d)
65+
np.testing.assert_almost_equal(out[5,0], -0.4)
66+
out = l.transform(d,mask=mask)
67+
np.testing.assert_almost_equal(out[5,0], 0.0)
68+
69+
def test_masking_median(self):
70+
d = np.zeros([8,10])
71+
d[2:6,2:9] = 1
72+
mask = np.ones([8,10])
73+
mask[2:6,2:9] = np.nan
74+
75+
l = LineLevel(method="median", datatype=DataTypes.Phase)
76+
77+
out = l.transform(d)
78+
np.testing.assert_almost_equal(out[5,0], -1.0)
79+
out = l.transform(d,mask=mask)
80+
np.testing.assert_almost_equal(out[5,0], 0.0)
81+
82+
def test_difference_masking(self):
83+
d = np.zeros([8,10])
84+
d[2:6,2:9] = 1
85+
mask = np.ones([8,10])
86+
mask[2:6,2:9] = np.nan
87+
88+
l = LineLevel(method="difference", datatype=DataTypes.Phase)
89+
out = l.transform(d)
90+
np.testing.assert_almost_equal(out[5,0], 1.0)
91+
np.testing.assert_almost_equal(out[1,0], -1.0)
92+
np.testing.assert_almost_equal(out[5,2], 2.0)
93+
94+
class TestBackgroundPolyFit(unittest.TestCase):
95+
def test_withmask(self):
96+
d = np.ones([10,10])
97+
d[4:8,4:8] = 10
98+
mask = np.ones([10,10])
99+
mask[4:8,4:8] = np.nan
100+
101+
t = BackgroundPolyFit(xorder=1,yorder=1,datatype=DataTypes.Phase)
102+
out = t.transform(d,mask=mask)
103+
np.testing.assert_almost_equal(out[0,0], 0.0)
104+
np.testing.assert_almost_equal(out[9,9], 0.0)
105+
106+
t = BackgroundPolyFit(xorder=1,yorder=1,datatype=DataTypes.Amplitude)
107+
out = t.transform(d,mask=mask)
108+
np.testing.assert_almost_equal(out[0,0], 1.0)
109+
np.testing.assert_almost_equal(out[9,9], 1.0)
110+
111+
def test_withoutmask(self):
112+
d = np.ones([10,10])
113+
d[4:8,4:8] = 10
114+
mask = np.ones([10,10])
115+
mask[4:8,4:8] = np.nan
116+
117+
t = BackgroundPolyFit(xorder=1,yorder=1,datatype=DataTypes.Phase)
118+
out = t.transform(d)
119+
np.testing.assert_almost_equal(out[0,0], -0.2975206611570238)
120+
np.testing.assert_almost_equal(out[9,9], -3.439338842975202)
121+
122+
t = BackgroundPolyFit(xorder=1,yorder=1,datatype=DataTypes.Amplitude)
123+
out = t.transform(d)
124+
np.testing.assert_almost_equal(out[0,0], 0.7707006369426758)
125+
np.testing.assert_almost_equal(out[9,9], 0.22525876833718098)
126+
127+
class TestHelperFunction(unittest.TestCase):
128+
129+
def test_mask_from_condition(self):
130+
d = np.ones([2,2])
131+
d[0,0] = 0
132+
out = mask_from_datacondition(d<1)
133+
134+
np.testing.assert_equal(out[0,0], np.nan)
135+
136+
class TestSimpleNormalize(unittest.TestCase):
137+
138+
def test_median(self):
139+
d = np.zeros([8,10])
140+
d[2:9,1:9] = 1
141+
mask = np.ones([8,10])
142+
mask[2:9,1:9] = np.nan
143+
144+
l = SimpleNormalize(method="median", datatype=DataTypes.Phase)
145+
146+
out = l.transform(d)
147+
np.testing.assert_almost_equal(out[0,0], -1.0)
148+
out = l.transform(d,mask=mask)
149+
np.testing.assert_almost_equal(out[0,0], 0.0)
150+
151+
def test_mean(self):
152+
d = np.zeros([8,10])
153+
d[2:9,1:9] = 1
154+
mask = np.ones([8,10])
155+
mask[2:9,1:9] = np.nan
156+
157+
l = SimpleNormalize(method="mean", datatype=DataTypes.Phase)
158+
159+
out = l.transform(d)
160+
np.testing.assert_almost_equal(out[0,0], -0.6)
161+
out = l.transform(d,mask=mask)
162+
np.testing.assert_almost_equal(out[0,0], 0.0)
163+
164+
def test_min(self):
165+
d=np.asarray([1.0, 2.0, 3.0])
166+
mask = np.asarray([np.nan, 1, 1])
167+
168+
l = SimpleNormalize(method="min", datatype=DataTypes.Phase)
169+
170+
out = l.transform(d)
171+
np.testing.assert_almost_equal(out, [0.0, 1.0, 2.0])
172+
out = l.transform(d,mask=mask)
173+
np.testing.assert_almost_equal(out, [-1.0,0.0,1.0])
56174

57175
if __name__ == "__main__":
58176
unittest.main()

0 commit comments

Comments
 (0)