Skip to content

Commit f4e14ff

Browse files
committed
ty for type checking
1 parent 7a1f78e commit f4e14ff

File tree

6 files changed

+158
-22
lines changed

6 files changed

+158
-22
lines changed

pyproject.toml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,11 @@ dev = [
3333
"pytest-cov>=7.0.0",
3434
"ruff>=0.14.8",
3535
"scipy>=1.16.3",
36+
"ty>=0.0.5",
3637
"xarray-fancy-repr>=0.0.2",
3738
"zarr>=3.1.5",
3839
]
39-
docs = [
40-
"jupyter-book>=2.1.0",
41-
]
40+
docs = ["jupyter-book>=2.1.0"]
4241
vim = ["jupyterlab-vim>=4.1.4"]
4342

4443
[tool.pytest.ini_options]

src/linked_indices/benchmark_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616

1717
def timeit_benchmark(
18-
stmt: str | callable,
18+
stmt: str | Callable[[], Any],
1919
setup: str = "pass",
2020
globals: dict[str, Any] | None = None,
2121
repeat: int = 7,

src/linked_indices/example_data.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def speech_annotations() -> pd.DataFrame:
6565
[4.5, 2.0, "how"],
6666
[7.0, 2.5, "are"],
6767
]
68-
return pd.DataFrame(data, columns=["onset", "duration", "word"])
68+
return pd.DataFrame(data, columns=["onset", "duration", "word"]) # type: ignore[arg-type]
6969

7070

7171
def multi_level_annotations() -> tuple[pd.DataFrame, pd.DataFrame]:
@@ -104,7 +104,8 @@ def multi_level_annotations() -> tuple[pd.DataFrame, pd.DataFrame]:
104104
[6.0, 3.5, "test", "noun"],
105105
]
106106
word_df = pd.DataFrame(
107-
word_data, columns=["onset", "duration", "word", "part_of_speech"]
107+
word_data,
108+
columns=["onset", "duration", "word", "part_of_speech"], # type: ignore[arg-type]
108109
)
109110

110111
# Phoneme-level annotations (more fine-grained)
@@ -118,7 +119,7 @@ def multi_level_annotations() -> tuple[pd.DataFrame, pd.DataFrame]:
118119
[6.0, 1.2, "t"],
119120
[7.2, 2.3, "st"],
120121
]
121-
phoneme_df = pd.DataFrame(phoneme_data, columns=["onset", "duration", "phoneme"])
122+
phoneme_df = pd.DataFrame(phoneme_data, columns=["onset", "duration", "phoneme"]) # type: ignore[arg-type]
122123

123124
return word_df, phoneme_df
124125

@@ -166,7 +167,7 @@ def mixed_event_annotations() -> pd.DataFrame:
166167
[0.0, 5.0, "image_A", "stimulus"],
167168
[5.0, 5.0, "image_B", "stimulus"],
168169
]
169-
return pd.DataFrame(data, columns=["onset", "duration", "label", "event_type"])
170+
return pd.DataFrame(data, columns=["onset", "duration", "label", "event_type"]) # type: ignore[arg-type]
170171

171172

172173
def generate_audio_signal(

src/linked_indices/multi_interval_index.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from collections.abc import Mapping
22
from dataclasses import dataclass, field
33
from numbers import Integral
4-
from typing import Any
4+
from typing import Any, cast
55

66
from collections import defaultdict
77

@@ -322,6 +322,7 @@ def from_variables(cls, variables, *, options):
322322
{name: var}, options=options
323323
)
324324

325+
assert interval_index is not None, f"No interval index found for {dim_name}"
325326
interval_dims[dim_name] = IntervalDimInfo(
326327
dim_name=dim_name,
327328
coord_name=coord_name,
@@ -347,8 +348,10 @@ def from_variables(cls, variables, *, options):
347348
debug=debug,
348349
)
349350

350-
def create_variables(self, variables):
351-
idx_variables = {}
351+
def create_variables(
352+
self, variables: Mapping[Any, Variable] | None = None
353+
) -> dict[Any, Variable]:
354+
idx_variables: dict[Any, Variable] = {}
352355

353356
idx_variables.update(self._continuous_index.create_variables(variables))
354357

@@ -431,7 +434,7 @@ def _get_overlapping_slice(
431434
)
432435

433436
# Find which intervals overlap
434-
overlaps = interval_index.overlaps(query_interval)
437+
overlaps = interval_index.overlaps(query_interval) # type: ignore[union-attr]
435438
overlap_indices = np.where(overlaps)[0]
436439

437440
if self._debug:
@@ -603,7 +606,7 @@ def isel(
603606
continue
604607

605608
overlap_slice = self._get_overlapping_slice(
606-
info.interval_index.index,
609+
cast(pd.IntervalIndex, info.interval_index.index),
607610
time_range,
608611
)
609612

@@ -681,7 +684,7 @@ def sel(self, labels, method=None, tolerance=None):
681684
) is not _MISSING or (
682685
dim_name := self._label_to_dim.get(key, _MISSING)
683686
) is not _MISSING:
684-
info = self._interval_dims[dim_name]
687+
info = self._interval_dims[cast(str, dim_name)]
685688
# Use label index if key is a label, otherwise use interval index
686689
idx = info.label_indexes.get(key, info.interval_index)
687690
sel_res = idx.sel({key: value}, method=method, tolerance=tolerance)
@@ -714,7 +717,7 @@ def sel(self, labels, method=None, tolerance=None):
714717

715718
# Find overlapping intervals
716719
overlap_slice = self._get_overlapping_slice(
717-
info.interval_index.index,
720+
cast(pd.IntervalIndex, info.interval_index.index),
718721
time_range,
719722
)
720723

src/linked_indices/nd_index.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ def from_variables(cls, variables: Mapping[str, Variable], *, options):
168168

169169
nd_coords[name] = NDCoord(
170170
name=name,
171-
dims=var.dims,
171+
dims=tuple(str(d) for d in var.dims),
172172
values=var.values,
173173
)
174174

@@ -184,11 +184,13 @@ def from_variables(cls, variables: Mapping[str, Variable], *, options):
184184

185185
return cls(nd_coords=nd_coords, slice_method=slice_method, debug=debug)
186186

187-
def create_variables(self, variables):
187+
def create_variables(
188+
self, variables: Mapping[Any, Variable] | None = None
189+
) -> dict[Any, Variable]:
188190
"""Create index variables for the dataset."""
189191
from xarray.core.variable import Variable as XrVariable
190192

191-
idx_variables = {}
193+
idx_variables: dict[Any, Variable] = {}
192194
for name, ndc in self._nd_coords.items():
193195
idx_variables[name] = XrVariable(dims=ndc.dims, data=ndc.values)
194196
return idx_variables
@@ -236,7 +238,7 @@ def _binary_search(
236238
self, flat_values: np.ndarray, value: float, method: str | None, coord_name: str
237239
) -> int:
238240
"""O(log n) binary search for sorted arrays."""
239-
idx = np.searchsorted(flat_values, value)
241+
idx = int(np.searchsorted(flat_values, value))
240242
n = len(flat_values)
241243

242244
if method == "nearest":
@@ -276,7 +278,7 @@ def _find_nearest_index(
276278
idx = (
277279
searchsorted_idx
278280
if searchsorted_idx is not None
279-
else np.searchsorted(flat_values, value)
281+
else int(np.searchsorted(flat_values, value))
280282
)
281283

282284
if idx == 0:
@@ -287,9 +289,9 @@ def _find_nearest_index(
287289
left_val = flat_values[idx - 1]
288290
right_val = flat_values[idx]
289291
if abs(value - left_val) <= abs(value - right_val):
290-
return idx - 1
292+
return int(idx - 1)
291293
else:
292-
return idx
294+
return int(idx)
293295

294296
def _linear_search(
295297
self, values: np.ndarray, value: float, method: str | None, coord_name: str

0 commit comments

Comments
 (0)