Skip to content

Commit 8ff5a5e

Browse files
authored
Merge pull request #758 from bashtage/improve-typing
TYP: Improve typing accuracy
2 parents 38beda6 + cc9580a commit 8ff5a5e

File tree

6 files changed

+84
-13
lines changed

6 files changed

+84
-13
lines changed

arch/bootstrap/base.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1271,15 +1271,15 @@ def _resample(self) -> tuple[tuple[ArrayLike, ...], dict[str, ArrayLike]]:
12711271
pos_data: list[Union[NDArray, pd.DataFrame, pd.Series]] = []
12721272
for values in self._args:
12731273
if isinstance(values, (pd.Series, pd.DataFrame)):
1274-
assert isinstance(indices, NDArray)
1274+
assert isinstance(indices, np.ndarray)
12751275
pos_data.append(values.iloc[indices])
12761276
else:
12771277
assert isinstance(values, np.ndarray)
12781278
pos_data.append(values[indices])
12791279
named_data: dict[str, Union[NDArray, pd.DataFrame, pd.Series]] = {}
12801280
for key, values in self._kwargs.items():
12811281
if isinstance(values, (pd.Series, pd.DataFrame)):
1282-
assert isinstance(indices, NDArray)
1282+
assert isinstance(indices, np.ndarray)
12831283
named_data[key] = values.iloc[indices]
12841284
else:
12851285
assert isinstance(values, np.ndarray)

arch/tests/univariate/test_mean.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from arch.typing import Literal
3030
from arch.univariate.base import (
3131
ARCHModel,
32+
ARCHModelFixedResult,
3233
ARCHModelForecast,
3334
ARCHModelResult,
3435
_align_forecast,
@@ -139,7 +140,7 @@ def simulated_data(request):
139140
for a, b in itertools.product(simple_mean_models, analytic_volatility_processes)
140141
],
141142
)
142-
def forecastable_model(request):
143+
def forecastable_model(request) -> tuple[ARCHModelResult, ARCHModelFixedResult]:
143144
mod: ARCHModel
144145
vol: VolatilityProcess
145146
mod, vol = request.param

arch/unitroot/unitroot.py

+9-7
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
diag,
1818
diff,
1919
empty,
20+
float64,
2021
full,
2122
hstack,
2223
inf,
@@ -235,7 +236,7 @@ def _autolag_ols_low_memory(
235236
trendx.append(empty((nobs, 0)))
236237
else:
237238
if "tt" in trend:
238-
tt = arange(1, nobs + 1, dtype=float)[:, None] ** 2
239+
tt = arange(1, nobs + 1, dtype=float64)[:, None] ** 2
239240
tt *= sqrt(5) / float(nobs) ** (5 / 2)
240241
trendx.append(tt)
241242
if "t" in trend:
@@ -402,6 +403,7 @@ def _df_select_lags(
402403
if max_lags is None:
403404
max_lags = int(ceil(12.0 * power(nobs / 100.0, 1 / 4.0)))
404405
max_lags = max(min(max_lags, max_max_lags), 0)
406+
assert isinstance(max_lags, int)
405407
if max_lags > 119:
406408
warnings.warn(
407409
"The value of max_lags was not specified and has been calculated as "
@@ -1970,8 +1972,8 @@ def auto_bandwidth(
19701972
float
19711973
The estimated optimal bandwidth.
19721974
"""
1973-
y = ensure1d(y, "y")
1974-
if y.shape[0] < 2:
1975+
y_arr = ensure1d(y, "y")
1976+
if y_arr.shape[0] < 2:
19751977
raise ValueError("Data must contain more than one observation")
19761978

19771979
lower_kernel = kernel.lower()
@@ -1987,12 +1989,12 @@ def auto_bandwidth(
19871989
else:
19881990
raise ValueError("Unknown kernel")
19891991

1990-
n = int(4 * ((len(y) / 100) ** n_power))
1992+
n = int(4 * ((len(y_arr) / 100) ** n_power))
19911993
sig = (n + 1) * [0]
19921994

19931995
for i in range(n + 1):
1994-
a = list(y[i:])
1995-
b = list(y[: len(y) - i])
1996+
a = list(y_arr[i:])
1997+
b = list(y_arr[: len(y_arr) - i])
19961998
sig[i] = int(npsum([i * j for (i, j) in zip(a, b)]))
19971999

19982000
sigma_m1 = sig[1 : len(sig)] # sigma without the 1st element
@@ -2018,6 +2020,6 @@ def auto_bandwidth(
20182020
else: # kernel == "qs":
20192021
gamma = 1.3221 * (((s2 / s0) ** 2) ** t_power)
20202022

2021-
bandwidth = gamma * power(len(y), t_power)
2023+
bandwidth = gamma * power(len(y_arr), t_power)
20222024

20232025
return bandwidth

arch/univariate/base.py

+35-2
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from copy import deepcopy
88
import datetime as dt
99
from functools import cached_property
10-
from typing import Any, Callable, Optional, Union, cast
10+
from typing import Any, Callable, Optional, Union, cast, overload
1111
import warnings
1212

1313
import numpy as np
@@ -410,6 +410,38 @@ def _fit_parameterless_model(
410410
deepcopy(self),
411411
)
412412

413+
@overload
414+
def _loglikelihood(
415+
self,
416+
parameters: Float64Array,
417+
sigma2: Float64Array,
418+
backcast: Union[float, Float64Array],
419+
var_bounds: Float64Array,
420+
) -> float: # pragma: no cover
421+
... # pragma: no cover
422+
423+
@overload
424+
def _loglikelihood(
425+
self,
426+
parameters: Float64Array,
427+
sigma2: Float64Array,
428+
backcast: Union[float, Float64Array],
429+
var_bounds: Float64Array,
430+
individual: Literal[False] = ...,
431+
) -> float: # pragma: no cover
432+
... # pragma: no cover
433+
434+
@overload
435+
def _loglikelihood(
436+
self,
437+
parameters: Float64Array,
438+
sigma2: Float64Array,
439+
backcast: Union[float, Float64Array],
440+
var_bounds: Float64Array,
441+
individual: Literal[True] = ...,
442+
) -> Float64Array: # pragma: no cover
443+
... # pragma: no cover
444+
413445
def _loglikelihood(
414446
self,
415447
parameters: Float64Array,
@@ -706,6 +738,7 @@ def fit(
706738
if starting_values is not None:
707739
assert sv is not None
708740
sv = ensure1d(sv, "starting_values")
741+
assert isinstance(sv, (np.ndarray, pd.Series))
709742
valid = sv.shape[0] == num_params
710743
if a.shape[0] > 0:
711744
satisfies_constraints = a.dot(sv) - b >= 0
@@ -1362,7 +1395,7 @@ def _set_tight_x(axis: Axes, index: pd.Index) -> None:
13621395
ax = fig.add_subplot(2, 1, 1)
13631396
ax.plot(self._index.values, self.resid / self.conditional_volatility)
13641397
ax.set_title("Standardized Residuals")
1365-
ax.axes.xaxis.set_ticklabels([])
1398+
ax.set_xticklabels([])
13661399
_set_tight_x(ax, self._index)
13671400

13681401
ax = fig.add_subplot(2, 1, 2)

arch/univariate/mean.py

+35-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from collections.abc import Mapping, Sequence
77
import copy
8-
from typing import TYPE_CHECKING, Callable, Optional, Union, cast
8+
from typing import TYPE_CHECKING, Callable, Optional, Union, cast, overload
99

1010
import numpy as np
1111
import pandas as pd
@@ -546,7 +546,9 @@ def _check_specification(self) -> None:
546546
if isinstance(self._x, pd.Series):
547547
self._x = pd.DataFrame(self._x)
548548
elif self._x.ndim == 1:
549+
assert isinstance(self._x, np.ndarray)
549550
self._x = np.asarray(self._x)[:, None]
551+
assert isinstance(self._x, (np.ndarray, pd.DataFrame))
550552
if self._x.ndim != 2 or self._x.shape[0] != self._y.shape[0]:
551553
raise ValueError(
552554
"x must be nobs by n, where nobs is the same as "
@@ -1723,6 +1725,38 @@ def resids(
17231725
def starting_values(self) -> Float64Array:
17241726
return np.r_[super().starting_values(), 0.0]
17251727

1728+
@overload
1729+
def _loglikelihood(
1730+
self,
1731+
parameters: Float64Array,
1732+
sigma2: Float64Array,
1733+
backcast: Union[float, Float64Array],
1734+
var_bounds: Float64Array,
1735+
) -> float: # pragma: no cover
1736+
... # pragma: no cover
1737+
1738+
@overload
1739+
def _loglikelihood(
1740+
self,
1741+
parameters: Float64Array,
1742+
sigma2: Float64Array,
1743+
backcast: Union[float, Float64Array],
1744+
var_bounds: Float64Array,
1745+
individual: Literal[False] = ...,
1746+
) -> float: # pragma: no cover
1747+
... # pragma: no cover
1748+
1749+
@overload
1750+
def _loglikelihood(
1751+
self,
1752+
parameters: Float64Array,
1753+
sigma2: Float64Array,
1754+
backcast: Union[float, Float64Array],
1755+
var_bounds: Float64Array,
1756+
individual: Literal[True] = ...,
1757+
) -> Float64Array: # pragma: no cover
1758+
... # pragma: no cover
1759+
17261760
def _loglikelihood(
17271761
self,
17281762
parameters: Float64Array,

arch/univariate/volatility.py

+1
Original file line numberDiff line numberDiff line change
@@ -2111,6 +2111,7 @@ def compute_variance(
21112111
var_bounds: Float64Array,
21122112
) -> Float64Array:
21132113
lam = parameters[0] if self._estimate_lam else self.lam
2114+
assert isinstance(lam, float)
21142115
return ewma_recursion(lam, resids, sigma2, resids.shape[0], float(backcast))
21152116

21162117
def constraints(self) -> tuple[Float64Array, Float64Array]:

0 commit comments

Comments
 (0)