Skip to content

Commit f391cd1

Browse files
committed
Add documentation to ME_EMSC class
1 parent 7f66001 commit f391cd1

File tree

1 file changed

+103
-42
lines changed

1 file changed

+103
-42
lines changed

src/biospectools/preprocessing/me_emsc.py

Lines changed: 103 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,11 @@
99

1010
from biospectools.preprocessing import EMSC
1111
from biospectools.preprocessing.emsc import EMSCDetails
12-
from biospectools.preprocessing.criterions import \
13-
BaseStopCriterion, TolStopCriterion, EmptyCriterionError
12+
from biospectools.preprocessing.criterions import (
13+
BaseStopCriterion,
14+
TolStopCriterion,
15+
EmptyCriterionError,
16+
)
1417
from biospectools.utils.deprecated import deprecated_alias
1518

1619

@@ -24,13 +27,15 @@ class MeEMSCDetails:
2427
n_mie_components: int
2528

2629
def __init__(
27-
self,
28-
criterions: List[BaseStopCriterion],
29-
n_mie_components: int, spatial_shape=None):
30+
self,
31+
criterions: List[BaseStopCriterion],
32+
n_mie_components: int,
33+
spatial_shape=None,
34+
):
3035
self.criterions = criterions
3136
self.n_mie_components = n_mie_components
3237
if self.n_mie_components <= 0:
33-
raise ValueError('n_components must be greater than 0')
38+
raise ValueError("n_components must be greater than 0")
3439

3540
self._extract_from_criterions()
3641

@@ -47,8 +52,8 @@ def _extract_from_criterions(self):
4752
rmses, iters, coefs, resds = np_arrs
4853
for c in self.criterions:
4954
try:
50-
self.emscs.append(c.best_value['emsc'])
51-
emsc_inns: EMSCDetails = c.best_value['internals']
55+
self.emscs.append(c.best_value["emsc"])
56+
emsc_inns: EMSCDetails = c.best_value["internals"]
5257
coefs.append(emsc_inns.coefs[0])
5358
resds.append(emsc_inns.residuals[0])
5459
rmses.append(c.best_score)
@@ -60,17 +65,19 @@ def _extract_from_criterions(self):
6065
rmses.append(np.nan)
6166
iters.append(0)
6267

63-
self.rmses, self.n_iterations, self.coefs, self.residuals = \
64-
[np.array(np.broadcast_arrays(*arr)) for arr in np_arrs]
68+
self.rmses, self.n_iterations, self.coefs, self.residuals = [
69+
np.array(np.broadcast_arrays(*arr)) for arr in np_arrs
70+
]
6571

6672
@property
6773
def scaling_coefs(self) -> np.ndarray:
6874
return self.coefs[..., 0]
6975

7076
@property
7177
def mie_components_coefs(self) -> np.ndarray:
72-
assert self.n_mie_components > 0, \
73-
'Number of mie components must be greater than zero'
78+
assert (
79+
self.n_mie_components > 0
80+
), "Number of mie components must be greater than zero"
7481
return self.coefs[..., 1:1 + self.n_mie_components]
7582

7683
@property
@@ -79,20 +86,72 @@ def polynomial_coefs(self) -> np.ndarray:
7986

8087

8188
class MeEMSC:
89+
"""Mie Extinction Extended multiplicative signal correction (ME-EMSC) [1]_.
90+
91+
Parameters
92+
----------
93+
reference : `(K_channels,) ndarray`
94+
Reference spectrum.
95+
wavenumbers : `(K_channels,) ndarray`, optional
96+
Wavenumbers for the spectra must be passed.
97+
n_components : ´Optional[int]´ default None
98+
Number of components of from the PCA decomposed scattering curves
99+
to use in the EMSC. If None the PC components will be chosen
100+
such that they have 99.96 percent explained variance.
101+
n0s : ´(M_curves,) ndarray´ default None
102+
Refractive indices to use to generate the scattering curves.
103+
radiuses : ´(M_curves,) ndarray´´ default None
104+
Radii to use to generate the scattering curves.
105+
h : ´float´ default 0.25
106+
Scaling factor in getting the imaginary part of the
107+
complex refractive index from the pure absorbance.
108+
weights : ´(K_channels,) ndarray´ default None
109+
Weights to scale the reference spectrum at every iteration.
110+
max_iter : ´int´ default 30
111+
Maximum number of iterations the algorithm can run.
112+
tol : ´float´ default 1e-4
113+
How much we require the residuals to decrease.
114+
patience : ´int´ default 1
115+
Number of iterations with residuals increasing less than
116+
tolerance before we stop the correction.
117+
positive_ref : ´bool´ default True
118+
If True the reference spectrum will be forced to be strictly positive
119+
by setting all negative values to zero.
120+
verbose : ´bool´, default False
121+
If True track the progress of correction.
122+
123+
124+
Other Parameters
125+
----------------
126+
_model : `(K_channels, 2 + n_components) ndarray`
127+
Matrix that is used to solve least squares. First column is a
128+
baseline constant and second the reference spectrum scaling constant
129+
followed by principal components of the decomposed scatter curves.
130+
_norm_wns : `(K_channels,) ndarray`
131+
Normalized wavenumbers to -1, 1 range
132+
133+
References
134+
----------
135+
.. [1] J. H. Solheim et al. *An open-source code for Mie extinction
136+
extended multiplicative signal correction for infrared microscopy
137+
spectra of cells and tissue* Journal of biophotonics, 12(8), 2019.
138+
"""
139+
82140
def __init__(
83-
self,
84-
reference: np.ndarray,
85-
wavenumbers: np.ndarray,
86-
n_components: Optional[int] = None,
87-
n0s: np.ndarray = None,
88-
radiuses: np.ndarray = None,
89-
h: float = 0.25,
90-
weights: np.ndarray = None,
91-
max_iter: int = 30,
92-
tol: float = 1e-4,
93-
patience: int = 1,
94-
positive_ref: bool = True,
95-
verbose: bool = False):
141+
self,
142+
reference: np.ndarray,
143+
wavenumbers: np.ndarray,
144+
n_components: Optional[int] = None,
145+
n0s: np.ndarray = None,
146+
radiuses: np.ndarray = None,
147+
h: float = 0.25,
148+
weights: np.ndarray = None,
149+
max_iter: int = 30,
150+
tol: float = 1e-4,
151+
patience: int = 1,
152+
positive_ref: bool = True,
153+
verbose: bool = False,
154+
):
96155
self.reference = reference
97156
self.wavenumbers = wavenumbers
98157
self.weights = weights
@@ -104,9 +163,10 @@ def __init__(
104163
self.positive_ref = positive_ref
105164
self.verbose = verbose
106165

107-
@deprecated_alias(internals='details')
108-
def transform(self, spectra: np.ndarray, details=False) \
109-
-> U[np.ndarray, T[np.ndarray, MeEMSCDetails]]:
166+
@deprecated_alias(internals="details")
167+
def transform(
168+
self, spectra: np.ndarray, details=False
169+
) -> U[np.ndarray, T[np.ndarray, MeEMSCDetails]]:
110170
ref_x = self.reference
111171
if self.positive_ref:
112172
ref_x[ref_x < 0] = 0
@@ -131,7 +191,8 @@ def transform(self, spectra: np.ndarray, details=False) \
131191

132192
if details:
133193
dtls = MeEMSCDetails(
134-
criterions, self.mie_decomposer.n_components, spatial_shape)
194+
criterions, self.mie_decomposer.n_components, spatial_shape
195+
)
135196
return correcteds, dtls
136197
return correcteds
137198

@@ -140,18 +201,17 @@ def _correct_spectrum(self, basic_emsc, pure_guess, spectrum):
140201
while not self.stop_criterion:
141202
emsc = self._build_emsc(pure_guess, basic_emsc)
142203
pure_guess, inn = emsc.transform(
143-
spectrum[None], details=True, check_correlation=False)
204+
spectrum[None], details=True, check_correlation=False
205+
)
144206
pure_guess = pure_guess[0]
145207
rmse = np.sqrt(np.mean(inn.residuals ** 2))
146-
iter_result = \
147-
{'corrected': pure_guess, 'internals': inn, 'emsc': emsc}
208+
iter_result = {"corrected": pure_guess, "internals": inn, "emsc": emsc}
148209
self.stop_criterion.add(rmse, iter_result)
149-
return self.stop_criterion.best_value['corrected']
210+
return self.stop_criterion.best_value["corrected"]
150211

151212
def _build_emsc(self, reference, basic_emsc: EMSC) -> EMSC:
152213
# scale with basic EMSC:
153-
reference = basic_emsc.transform(
154-
reference[None], check_correlation=False)[0]
214+
reference = basic_emsc.transform(reference[None], check_correlation=False)[0]
155215
if np.all(np.isnan(reference)):
156216
raise np.linalg.LinAlgError()
157217

@@ -198,20 +258,21 @@ def generate(self, pure_absorbance, wavenumbers):
198258
return qexts
199259

200260
def _calculate_qext_curves(self, nprs, nkks, wavenumbers):
201-
rho = self.alpha0s * (1 + self.gammas*nkks) * wavenumbers # noqa: F841
261+
rho = self.alpha0s * (1 + self.gammas * nkks) * wavenumbers # noqa: F841
202262
tanbeta = nprs / (1 / self.gammas + nkks)
203263
beta = np.arctan(tanbeta) # noqa: F841
204264
qexts = ne.evaluate(
205-
'2 - 4 * exp(-rho * tanbeta) * cos(beta) / rho * sin(rho - beta)'
206-
'- 4 * exp(-rho * tanbeta) * (cos(beta) / rho) ** 2 * cos(rho - 2 * beta)'
207-
'+ 4 * (cos(beta) / rho) ** 2 * cos(2 * beta)')
265+
"2 - 4 * exp(-rho * tanbeta) * cos(beta) / rho * sin(rho - beta)"
266+
"- 4 * exp(-rho * tanbeta) * (cos(beta) / rho) ** 2 * cos(rho - 2 * beta)"
267+
"+ 4 * (cos(beta) / rho) ** 2 * cos(2 * beta)"
268+
)
208269
return qexts.reshape(-1, len(wavenumbers))
209270

210271
def _get_refractive_index(self, pure_absorbance, wavenumbers):
211272
pad_size = 200
212273
# Extend absorbance spectrum
213274
wns_ext = self._extrapolate_wns(wavenumbers, pad_size)
214-
pure_ext = np.pad(pure_absorbance, pad_size, mode='edge')
275+
pure_ext = np.pad(pure_absorbance, pad_size, mode="edge")
215276

216277
# Calculate refractive index
217278
nprs_ext = pure_ext / wns_ext
@@ -224,7 +285,7 @@ def _get_refractive_index(self, pure_absorbance, wavenumbers):
224285
return nprs, nkks
225286

226287
def _extrapolate_wns(self, wns, pad_size):
227-
f = interp1d(np.arange(len(wns)), wns, fill_value='extrapolate')
288+
f = interp1d(np.arange(len(wns)), wns, fill_value="extrapolate")
228289
idxs_ext = np.arange(-pad_size, len(wns) + pad_size)
229290
wns_ext = f(idxs_ext)
230291
return wns_ext
@@ -243,7 +304,7 @@ def find_orthogonal_components(self, qexts: np.ndarray):
243304
self.n_components = self._estimate_n_components(qexts)
244305
self.svd.n_components = self.n_components
245306
# do not refit svd, since it was fitted during _estimation
246-
return self.svd.components_[:self.n_components]
307+
return self.svd.components_[: self.n_components]
247308

248309
self.svd.fit(qexts)
249310
return self.svd.components_

0 commit comments

Comments
 (0)