Skip to content

Commit 34fb75a

Browse files
michaelosthege authored and LarsHalle committed
Fix all typing errors and run mypy in pre-commit
1 parent 5d1277d commit 34fb75a

File tree

11 files changed

+159
-105
lines changed

11 files changed

+159
-105
lines changed

.pre-commit-config.yaml

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -19,3 +19,7 @@ repos:
1919
rev: 22.3.0
2020
hooks:
2121
- id: black
22+
- repo: https://github.com/pre-commit/mirrors-mypy
23+
rev: v1.3.0
24+
hooks:
25+
- id: mypy

bletl/core.py

Lines changed: 30 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,7 @@
88
import urllib.request
99
import warnings
1010
from collections.abc import Iterable
11-
from typing import Optional, Union
11+
from typing import Optional, Sequence, Union
1212

1313
import numpy
1414
import pandas
@@ -92,14 +92,14 @@ def get_parser(filepath: Union[str, pathlib.Path]) -> BLDParser:
9292
def _parse(
9393
filepath: str,
9494
drop_incomplete_cycles: bool,
95-
lot_number: int,
96-
temp: int,
97-
cal_0: float = None,
98-
cal_100: float = None,
99-
phi_min: float = None,
100-
phi_max: float = None,
101-
pH_0: float = None,
102-
dpH: float = None,
95+
lot_number: Optional[int],
96+
temp: Optional[int],
97+
cal_0: Optional[float] = None,
98+
cal_100: Optional[float] = None,
99+
phi_min: Optional[float] = None,
100+
phi_max: Optional[float] = None,
101+
pH_0: Optional[float] = None,
102+
dpH: Optional[float] = None,
103103
) -> BLData:
104104
"""Parses a raw BioLector CSV file into a BLData object.
105105
@@ -138,29 +138,39 @@ def _parse(
138138
When the file contents do not match with a known BioLector result file format.
139139
"""
140140
parser = get_parser(filepath)
141-
data = parser.parse(filepath, lot_number, temp, cal_0, cal_100, phi_min, phi_max, pH_0, dpH)
141+
data = parser.parse(
142+
filepath,
143+
lot_number=lot_number,
144+
temp=temp,
145+
cal_0=cal_0,
146+
cal_100=cal_100,
147+
phi_min=phi_min,
148+
phi_max=phi_max,
149+
pH_0=pH_0,
150+
dpH=dpH,
151+
)
142152

143153
if (not data.measurements.empty) and drop_incomplete_cycles:
144154
index_names, measurements = utils._unindex(data.measurements)
145155
latest_full_cycle = utils._last_full_cycle(measurements)
146156
measurements = measurements[measurements.cycle <= latest_full_cycle]
147-
data._measurements = utils._reindex(measurements, index_names)
157+
data._measurements = utils._reindex(measurements, index_names) # type: ignore
148158

149159
return data
150160

151161

152162
def parse(
153-
filepaths,
163+
filepaths: Union[str, Sequence[str]],
154164
*,
155165
drop_incomplete_cycles: bool = True,
156-
lot_number: int = None,
157-
temp: int = None,
158-
cal_0: float = None,
159-
cal_100: float = None,
160-
phi_min: float = None,
161-
phi_max: float = None,
162-
pH_0: float = None,
163-
dpH: float = None,
166+
lot_number: Optional[int] = None,
167+
temp: Optional[int] = None,
168+
cal_0: Optional[float] = None,
169+
cal_100: Optional[float] = None,
170+
phi_min: Optional[float] = None,
171+
phi_max: Optional[float] = None,
172+
pH_0: Optional[float] = None,
173+
dpH: Optional[float] = None,
164174
) -> BLData:
165175
"""Parses a raw BioLector CSV file into a BLData object and applies calibration.
166176

bletl/growth.py

Lines changed: 43 additions & 36 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
import logging
22
import typing
3+
from typing import Dict, Optional, Sequence, Tuple, Union
34

45
import arviz
56
import calibr8
@@ -10,7 +11,7 @@
1011
try:
1112
import pytensor.tensor as pt
1213
except ModuleNotFoundError:
13-
import aesara.tensor as pt
14+
import aesara.tensor as pt # type: ignore
1415

1516

1617
_log = logging.getLogger(__file__)
@@ -22,13 +23,13 @@ class GrowthRateResult:
2223
def __init__(
2324
self,
2425
*,
25-
t_data: numpy.ndarray,
26-
t_segments: numpy.ndarray,
27-
y: numpy.ndarray,
26+
t_data: Union[Sequence[float], numpy.ndarray],
27+
t_segments: Union[Sequence[float], numpy.ndarray],
28+
y: Union[Sequence[float], numpy.ndarray],
2829
calibration_model: calibr8.CalibrationModel,
29-
switchpoints: typing.Dict[float, str],
30+
switchpoints: Dict[float, str],
3031
pmodel: pm.Model,
31-
theta_map: dict,
32+
theta_map: Dict[str, numpy.ndarray],
3233
):
3334
"""Creates a result object of a growth rate analysis.
3435
@@ -47,9 +48,9 @@ def __init__(
4748
theta_map : dict
4849
the PyMC MAP estimate
4950
"""
50-
self._t_data = t_data
51-
self._t_segments = t_segments
52-
self._y = y
51+
self._t_data = numpy.asarray(t_data)
52+
self._t_segments = numpy.asarray(t_segments)
53+
self._y = numpy.asarray(y)
5354
self._switchpoints = switchpoints
5455
self.calibration_model = calibration_model
5556
self._pmodel = pmodel
@@ -73,17 +74,17 @@ def y(self) -> numpy.ndarray:
7374
return self._y
7475

7576
@property
76-
def switchpoints(self) -> typing.Dict[float, str]:
77+
def switchpoints(self) -> Dict[float, str]:
7778
"""Dictionary (by time) of known and detected switchpoints."""
7879
return self._switchpoints
7980

8081
@property
81-
def known_switchpoints(self) -> typing.Tuple[float]:
82+
def known_switchpoints(self) -> Tuple[float, ...]:
8283
"""Time values of previously known switchpoints in the model."""
8384
return tuple(t for t, label in self.switchpoints.items() if label != "detected")
8485

8586
@property
86-
def detected_switchpoints(self) -> typing.Tuple[float]:
87+
def detected_switchpoints(self) -> Tuple[float, ...]:
8788
"""Time values of switchpoints that were autodetected from the fit."""
8889
return tuple(t for t, label in self.switchpoints.items() if label == "detected")
8990

@@ -93,12 +94,12 @@ def pmodel(self) -> pm.Model:
9394
return self._pmodel
9495

9596
@property
96-
def theta_map(self) -> dict:
97+
def theta_map(self) -> Dict[str, numpy.ndarray]:
9798
"""MAP estimate of the model parameters."""
9899
return self._theta_map
99100

100101
@property
101-
def idata(self) -> typing.Optional[arviz.InferenceData]:
102+
def idata(self) -> Optional[arviz.InferenceData]:
102103
"""ArviZ InferenceData object of the MCMC trace."""
103104
return self._idata
104105

@@ -113,18 +114,20 @@ def x_map(self) -> numpy.ndarray:
113114
return self.theta_map["X"]
114115

115116
@property
116-
def mu_mcmc(self) -> typing.Optional[numpy.ndarray]:
117+
def mu_mcmc(self) -> Optional[numpy.ndarray]:
117118
"""Posterior samples of growth rates in segments between data points."""
118119
if not self.idata:
119120
return None
121+
assert hasattr(self.idata, "posterior")
120122
return self.idata.posterior.mu_t.stack(sample=("chain", "draw")).values.T
121123

122124
@property
123-
def x_mcmc(self) -> typing.Optional[numpy.ndarray]:
125+
def x_mcmc(self) -> Optional[numpy.ndarray]:
124126
"""Posterior samples of biomass curve."""
125-
if not self.idata:
127+
if self.idata is None:
126128
return None
127-
return self._idata.posterior["X"].stack(sample=("chain", "draw")).T
129+
assert hasattr(self.idata, "posterior")
130+
return self.idata.posterior["X"].stack(sample=("chain", "draw")).T
128131

129132
def sample(self, **kwargs) -> None:
130133
"""Runs MCMC sampling with default settings on the growth model.
@@ -157,8 +160,8 @@ def _make_random_walk(
157160
nu: float = 1,
158161
length: int,
159162
student_t: bool,
160-
initval: numpy.ndarray = None,
161-
dims: typing.Optional[str] = None,
163+
initval: Optional[numpy.ndarray] = None,
164+
dims: Optional[str] = None,
162165
):
163166
"""Create a random walk with either a Normal or Student-t distribution.
164167
@@ -215,7 +218,11 @@ def _make_random_walk(
215218

216219

217220
def _get_smoothed_mu(
218-
t: numpy.ndarray, y: numpy.ndarray, cm_cdw: calibr8.CalibrationModel, *, clip=0.5
221+
t: Sequence[float],
222+
y: Sequence[float],
223+
cm_cdw: calibr8.CalibrationModel,
224+
*,
225+
clip: float = 0.5,
219226
) -> numpy.ndarray:
220227
"""Calculate a rough estimate of the specific growth rate from smoothed observations.
221228
@@ -236,10 +243,10 @@ def _get_smoothed_mu(
236243
A vector of specific growth rates.
237244
"""
238245
# apply moving average to reduce backscatter noise
239-
y = numpy.convolve(y, numpy.ones(5) / 5, "same")
246+
yarr = numpy.convolve(y, numpy.ones(5) / 5, "same")
240247

241248
# convert to biomass
242-
X = cm_cdw.predict_independent(y)
249+
X = cm_cdw.predict_independent(yarr)
243250

244251
# calculate growth rate
245252
dX = numpy.diff(X)
@@ -259,17 +266,17 @@ def _get_smoothed_mu(
259266

260267

261268
def fit_mu_t(
262-
t: typing.Sequence[float],
263-
y: typing.Sequence[float],
269+
t: Sequence[float],
270+
y: Sequence[float],
264271
calibration_model: calibr8.CalibrationModel,
265272
*,
266-
switchpoints: typing.Optional[typing.Union[typing.Sequence[float], typing.Dict[float, str]]] = None,
273+
switchpoints: Optional[Union[Sequence[float], Dict[float, str]]] = None,
267274
mcmc_samples: int = 0,
268275
mu_prior: float = 0,
269276
drift_scale: float,
270277
nu: float = 5,
271278
x0_prior: float = 0.25,
272-
student_t: typing.Optional[bool] = None,
279+
student_t: Optional[bool] = None,
273280
switchpoint_prob: float = 0.01,
274281
replicate_id: str = "unnamed",
275282
):
@@ -357,7 +364,7 @@ def fit_mu_t(
357364
mu_segments = []
358365
i_from = 0
359366
for i, t_switch in enumerate(t_switchpoints_known):
360-
i_to = numpy.argmax(t > t_switch)
367+
i_to = int(numpy.argmax(t > t_switch))
361368
i_len = len(t[i_from:i_to])
362369
name = f"mu_phase_{i}"
363370
slc = slice(i_from, i_to)
@@ -460,10 +467,10 @@ def fit_mu_t(
460467

461468
def detect_switchpoints(
462469
switchpoint_prob: float,
463-
t_data: typing.Sequence[float],
470+
t_data: Sequence[float],
464471
pmodel: pm.Model,
465-
theta_map: typing.Dict[str, numpy.ndarray],
466-
) -> typing.Dict[float, str]:
472+
theta_map: Dict[str, numpy.ndarray],
473+
) -> Dict[float, str]:
467474
"""Helper function to detect switchpoints from a fitted random walk.
468475
469476
Parameters
@@ -509,15 +516,15 @@ def detect_switchpoints(
509516
# To get our <number of segments> length vector to align with the <number of points>,
510517
# we prepend a 0.5 as a placeholder for the CDF of the initial point of the random walk.
511518
cdf_evals += [0.5, *numpy.exp(logcdfs)]
512-
cdf_evals = numpy.array(cdf_evals)
513-
if len(cdf_evals) != len(t_data) - 1:
519+
cdf_evals_arr = numpy.array(cdf_evals)
520+
if len(cdf_evals_arr) != len(t_data) - 1:
514521
raise Exception(
515-
f"Failed to find all random walk segments. Found {len(cdf_evals)}, expected {len(t_data) - 1}."
522+
f"Failed to find all random walk segments. Found {len(cdf_evals_arr)}, expected {len(t_data) - 1}."
516523
)
517524
# Filter for the elements that lie outside of the [0.005, 0.995] interval (if switchpoint_prob=0.01).
518525
significance_mask = numpy.logical_or(
519-
cdf_evals < (switchpoint_prob / 2),
520-
cdf_evals > (1 - switchpoint_prob / 2),
526+
cdf_evals_arr < (switchpoint_prob / 2),
527+
cdf_evals_arr > (1 - switchpoint_prob / 2),
521528
)
522529
# Collect switchpoint information from points with significant CDF values.
523530
# Here we don't need to filter known switchpoints, because these correspond to the first

bletl/parsing/bl1.py

Lines changed: 16 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -92,12 +92,12 @@ def calibrate_with_lot(self, data: BLData, lot_number: Optional[int] = None, tem
9292
def calibrate_with_parameters(
9393
self,
9494
data: BLData,
95-
cal_0: float = None,
96-
cal_100: float = None,
97-
phi_min: float = None,
98-
phi_max: float = None,
99-
pH_0: float = None,
100-
dpH: float = None,
95+
cal_0: Optional[float] = None,
96+
cal_100: Optional[float] = None,
97+
phi_min: Optional[float] = None,
98+
phi_max: Optional[float] = None,
99+
pH_0: Optional[float] = None,
100+
dpH: Optional[float] = None,
101101
):
102102
def process_backscatter(raw_data_df, cycle_ref_df, global_ref):
103103
"""
@@ -182,14 +182,14 @@ def process_DO(raw_data_df, cal_0, cal_100):
182182
def parse(
183183
self,
184184
filepath,
185-
lot_number: int = None,
186-
temp: int = None,
187-
cal_0: float = None,
188-
cal_100: float = None,
189-
phi_min: float = None,
190-
phi_max: float = None,
191-
pH_0: float = None,
192-
dpH: float = None,
185+
lot_number: Optional[int] = None,
186+
temp: Optional[int] = None,
187+
cal_0: Optional[float] = None,
188+
cal_100: Optional[float] = None,
189+
phi_min: Optional[float] = None,
190+
phi_max: Optional[float] = None,
191+
pH_0: Optional[float] = None,
192+
dpH: Optional[float] = None,
193193
):
194194
headerlines, data = split_header_data(filepath)
195195

@@ -476,6 +476,8 @@ def fetch_calibration_data(lot_number: int, temp: int):
476476
Dictionary containing calibration data.
477477
Can be readily used in calibration function.
478478
"""
479+
assert utils.__spec__ is not None
480+
assert utils.__spec__.origin is not None
479481
module_path = pathlib.Path(utils.__spec__.origin).parents[0]
480482
calibration_file = pathlib.Path(module_path, "cache", "CalibrationLot.ini")
481483

0 commit comments

Comments
 (0)