Skip to content

Commit 5a77428

Browse files
committed
More type annotations
1 parent 72807e5 commit 5a77428

File tree

13 files changed

+105
-70
lines changed

13 files changed

+105
-70
lines changed

pyproject.toml

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,5 +128,15 @@ artifacts = [
128128

129129
[tool.pyright]
130130
typeCheckingMode = "standard"
131-
ignore = ["apps/marimo"]
132-
strict = ["docs", "apps"]
131+
ignore = [
132+
"apps/marimo",
133+
"src/itables_for_dash/ITable.py",
134+
]
135+
strict = [
136+
"docs",
137+
"apps",
138+
"src/itables_for_dash",
139+
"src/itables/__init__.py",
140+
"src/itables/dash.py",
141+
"src/itables/streamlit.py",
142+
]

src/itables/datatables_format.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,22 @@
11
import json
22
import warnings
3-
from typing import Optional
3+
from typing import Any, Optional, Sequence, Union
44

55
import numpy as np
66
import pandas as pd
77
import pandas.io.formats.format as fmt
88

9+
from .typing import DataFrameOrSeries
10+
911
try:
1012
import polars as pl
1113
except ImportError:
1214
pl = None
1315

1416

15-
def _format_column(x, escape_html: bool):
17+
def _format_column(
18+
x: "pd.Series[Any]", escape_html: bool
19+
) -> "Union[pd.Series[Any],Sequence[Any]]":
1620
dtype_kind = x.dtype.kind
1721
if dtype_kind in ["b", "i"]:
1822
return x
@@ -28,21 +32,23 @@ def _format_column(x, escape_html: bool):
2832
# Older versions of Pandas don't have 'leading_space'
2933
x = fmt.format_array(x._values, None, justify="all") # type: ignore
3034

35+
y: "Union[pd.Series[Any], Sequence[Any]]" = x
3136
if dtype_kind == "f":
3237
try:
33-
x = np.array(x).astype(float)
38+
z = np.array(x).astype(float)
3439
except ValueError:
40+
z = x
3541
pass
3642

37-
x = [escape_non_finite_float(f) for f in x]
43+
y = [escape_non_finite_float(f) for f in z]
3844

3945
if escape_html:
40-
return [escape_html_chars(i) for i in x]
46+
return [escape_html_chars(i) for i in y]
4147

42-
return x
48+
return y
4349

4450

45-
def escape_non_finite_float(value):
51+
def escape_non_finite_float(value: Any) -> Any:
4652
"""Encode non-finite float values to strings that will be parsed by parseJSON"""
4753
if not isinstance(value, float):
4854
return value
@@ -55,7 +61,7 @@ def escape_non_finite_float(value):
5561
return value
5662

5763

58-
def escape_html_chars(value):
64+
def escape_html_chars(value: Any) -> Any:
5965
"""Escape HTML special characters"""
6066
if isinstance(value, str):
6167
from pandas.io.formats.printing import pprint_thing # type: ignore
@@ -66,7 +72,9 @@ def escape_html_chars(value):
6672
return value
6773

6874

69-
def generate_encoder(warn_on_unexpected_types=True):
75+
def generate_encoder(warn_on_unexpected_types: bool = True) -> Any:
76+
"""Generate a JSON encoder that can handle special types like numpy"""
77+
7078
class TableValuesEncoder(json.JSONEncoder):
7179
def default(self, o):
7280
if isinstance(o, (bool, int, float, str)):
@@ -97,11 +105,11 @@ def default(self, o):
97105

98106

99107
def datatables_rows(
100-
df,
108+
df: DataFrameOrSeries,
101109
column_count: Optional[int] = None,
102110
warn_on_unexpected_types: bool = False,
103111
escape_html: bool = True,
104-
):
112+
) -> str:
105113
"""Format the values in the table and return the data, row by row, as requested by DataTables"""
106114
# We iterate over columns using an index rather than the column name
107115
# to avoid an issue in case of duplicated column names #89

src/itables/downsample.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,11 @@ def as_nbytes(mem: Union[int, float, str]) -> int:
3333

3434

3535
def downsample(
36-
df, max_rows: int = 0, max_columns: int = 0, max_bytes: Union[int, str] = 0
37-
):
36+
df: DataFrameOrSeries,
37+
max_rows: int = 0,
38+
max_columns: int = 0,
39+
max_bytes: Union[int, str] = 0,
40+
) -> tuple[DataFrameOrSeries, str]:
3841
"""Return a subset of the dataframe that fits the limits"""
3942
org_rows, org_columns, org_bytes = len(df), len(df.columns), nbytes(df)
4043
max_bytes_numeric = as_nbytes(max_bytes)
@@ -67,8 +70,9 @@ def downsample(
6770

6871

6972
def shrink_towards_target_aspect_ratio(
70-
rows, columns, shrink_factor, target_aspect_ratio
71-
):
73+
rows: int, columns: int, shrink_factor: float, target_aspect_ratio: float
74+
) -> tuple[int, int]:
75+
"""Return the number of rows and columns of the shrunk dataframe"""
7276
# current and target aspect ratio
7377
aspect_ratio = rows / float(columns)
7478

src/itables/javascript.py

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def init_notebook_mode(
135135
all_interactive: bool = True,
136136
connected: bool = GOOGLE_COLAB,
137137
dt_bundle: Optional[Union[Path, str]] = None,
138-
):
138+
) -> None:
139139
"""Load the DataTables library and the corresponding css (if connected=False),
140140
and (if all_interactive=True), activate the DataTables representation for all the Pandas DataFrames and Series.
141141
@@ -189,7 +189,8 @@ def init_notebook_mode(
189189
display(HTML(generate_init_offline_itables_html(dt_bundle)))
190190

191191

192-
def get_animated_logo(display_logo_when_loading):
192+
def get_animated_logo(display_logo_when_loading: bool) -> str:
193+
"""Return the HTML for the loading logo of ITables"""
193194
if not display_logo_when_loading:
194195
return ""
195196
return f"<a href=https://mwouts.github.io/itables/>{read_package_file('logo/loading.svg')}</a>"
@@ -265,15 +266,15 @@ def _flat_header(df, show_index):
265266
return header
266267

267268

268-
def _tfoot_from_thead(thead):
269+
def _tfoot_from_thead(thead: str) -> str:
269270
header_rows = thead.split("</tr>")
270271
last_row = header_rows[-1]
271272
assert not last_row.strip(), last_row
272273
header_rows = header_rows[:-1]
273274
return "".join(row + "</tr>" for row in header_rows[::-1] if "<tr" in row) + "\n"
274275

275276

276-
def get_keys_to_be_evaluated(data) -> list[list[Union[int, str]]]:
277+
def get_keys_to_be_evaluated(data: Any) -> list[list[Union[int, str]]]:
277278
"""
278279
This function returns the keys that need to be evaluated
279280
in the ITable arguments
@@ -296,7 +297,7 @@ def get_keys_to_be_evaluated(data) -> list[list[Union[int, str]]]:
296297
return keys_to_be_evaluated
297298

298299

299-
def replace_value(template, pattern, value):
300+
def replace_value(template: str, pattern: str, value: str) -> str:
300301
"""Set the given pattern to the desired value in the template,
301302
after making sure that the pattern is found exactly once."""
302303
count = template.count(pattern)
@@ -311,15 +312,15 @@ def replace_value(template, pattern, value):
311312
return template.replace(pattern, value)
312313

313314

314-
def _datatables_repr_(df):
315+
def _datatables_repr_(df: DataFrameOrSeries) -> str:
315316
return to_html_datatable(df, connected=_CONNECTED)
316317

317318

318319
def to_html_datatable(
319320
df: DataFrameOrSeries,
320321
caption: Optional[str] = None,
321322
**kwargs: Unpack[ITableOptions],
322-
):
323+
) -> str:
323324
"""
324325
Return the HTML representation of the given
325326
dataframe as an interactive datatable
@@ -548,8 +549,14 @@ def get_itables_extension_arguments(
548549

549550

550551
def warn_if_selected_rows_are_not_visible(
551-
selected_rows, full_row_count, data_row_count, warn_on_selected_rows_not_rendered
552-
):
552+
selected_rows: Optional[Sequence[int]],
553+
full_row_count: int,
554+
data_row_count: int,
555+
warn_on_selected_rows_not_rendered: bool,
556+
) -> Sequence[int]:
557+
"""
558+
Issue a warning if the selected rows are not within the range of rendered rows.
559+
"""
553560
if selected_rows is None:
554561
return []
555562

@@ -589,7 +596,11 @@ def warn_if_selected_rows_are_not_visible(
589596
return [i for i in selected_rows if i < bottom_limit or i >= top_limit]
590597

591598

592-
def check_table_id(table_id: Optional[str], kwargs, df=None) -> str:
599+
def check_table_id(
600+
table_id: Optional[str],
601+
kwargs: Union[ITableOptions, DTForITablesOptions],
602+
df: Optional[DataFrameOrSeries] = None,
603+
) -> str:
593604
"""Make sure that the table_id is a valid HTML id.
594605
595606
See also https://stackoverflow.com/questions/70579/html-valid-id-attribute-values
@@ -614,7 +625,9 @@ def check_table_id(table_id: Optional[str], kwargs, df=None) -> str:
614625
return table_id
615626

616627

617-
def set_default_options(kwargs: ITableOptions, *, use_to_html: bool, app_mode: bool):
628+
def set_default_options(
629+
kwargs: ITableOptions, *, use_to_html: bool, app_mode: bool
630+
) -> None:
618631
if not app_mode:
619632
kwargs["connected"] = kwargs.get(
620633
"connected", ("dt_url" in kwargs) or _CONNECTED
@@ -681,7 +694,7 @@ def html_table_from_template(
681694
connected: bool,
682695
display_logo_when_loading: bool,
683696
kwargs: DTForITablesOptions,
684-
):
697+
) -> str:
685698
if "css" in kwargs:
686699
raise TypeError(
687700
"The 'css' argument has been deprecated, see the new "
@@ -800,7 +813,7 @@ def _filter_control(control, downsampling_warning):
800813
return None
801814

802815

803-
def safe_reset_index(df):
816+
def safe_reset_index(df: pd.DataFrame) -> pd.DataFrame:
804817
try:
805818
return df.reset_index()
806819
except ValueError:
@@ -825,6 +838,6 @@ def show(
825838
df: DataFrameOrSeries,
826839
caption: Optional[str] = None,
827840
**kwargs: Unpack[ITableOptions],
828-
):
841+
) -> None:
829842
"""Render the given dataframe as an interactive datatable"""
830843
display(HTML(to_html_datatable(df, caption, **kwargs)))

src/itables/options.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@
139139
)
140140

141141
"""Check that options have correct types"""
142-
warn_on_unexpected_option_type = (
142+
warn_on_unexpected_option_type: bool = (
143143
warn_on_undocumented_option and typing.is_typeguard_available()
144144
)
145145
if warn_on_unexpected_option_type:

src/itables/sample_dfs.py

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from datetime import datetime, timedelta
44
from functools import lru_cache
55
from itertools import cycle
6-
from typing import Any, cast
6+
from typing import Any, Sequence, cast
77

88
import numpy as np
99
import pandas as pd
@@ -28,9 +28,9 @@
2828
"timedelta",
2929
]
3030

31-
PANDAS_VERSION_MAJOR, PANDAS_VERSION_MINOR, _ = pd.__version__.split(".", 2)
32-
PANDAS_VERSION_MAJOR = int(PANDAS_VERSION_MAJOR)
33-
PANDAS_VERSION_MINOR = int(PANDAS_VERSION_MINOR)
31+
_pd_version_major, _pd_version_minor, _ = pd.__version__.split(".", 2)
32+
PANDAS_VERSION_MAJOR = int(_pd_version_major)
33+
PANDAS_VERSION_MINOR = int(_pd_version_minor)
3434
if PANDAS_VERSION_MAJOR == 0:
3535
COLUMN_TYPES = [type for type in COLUMN_TYPES if type != "boolean"]
3636
if PANDAS_VERSION_MAJOR == 2 and PANDAS_VERSION_MINOR == 1:
@@ -84,19 +84,20 @@ def get_countries(html: bool = True, climate_zone: bool = False) -> pd.DataFrame
8484
return df
8585

8686

87-
def get_population():
87+
def get_population() -> "pd.Series[float]":
8888
"""A Pandas Series with the world population (from the world bank data)"""
8989
return pd.read_csv(find_package_file("samples/population.csv")).set_index(
9090
"Country"
9191
)["SP.POP.TOTL"]
9292

9393

94-
def get_indicators():
94+
def get_indicators() -> pd.DataFrame:
9595
"""A Pandas DataFrame with a subset of the world bank indicators"""
9696
return pd.read_csv(find_package_file("samples/indicators.csv"))
9797

9898

99-
def get_df_complex_index():
99+
def get_df_complex_index() -> pd.DataFrame:
100+
"""A Pandas DataFrame with a complex index"""
100101
df = get_countries()
101102
df = df.reset_index().set_index(["region", "country"])
102103
df.columns = pd.MultiIndex.from_arrays(
@@ -116,7 +117,7 @@ def get_df_complex_index():
116117
return df
117118

118119

119-
def get_dict_of_test_dfs(N=100, M=100) -> dict[str, pd.DataFrame]:
120+
def get_dict_of_test_dfs(N: int = 100, M: int = 100) -> dict[str, pd.DataFrame]:
120121
NM_values = np.reshape(np.linspace(start=0.0, stop=1.0, num=N * M), (N, M))
121122

122123
return {
@@ -281,7 +282,7 @@ def get_dict_of_test_dfs(N=100, M=100) -> dict[str, pd.DataFrame]:
281282
}
282283

283284

284-
def get_dict_of_polars_test_dfs(N=100, M=100) -> dict[str, Any]:
285+
def get_dict_of_polars_test_dfs(N: int = 100, M: int = 100) -> dict[str, Any]:
285286
import polars as pl
286287
import pyarrow as pa
287288

@@ -300,7 +301,8 @@ def get_dict_of_polars_test_dfs(N=100, M=100) -> dict[str, Any]:
300301
return polars_dfs
301302

302303

303-
def get_dict_of_test_series():
304+
def get_dict_of_test_series() -> dict[str, Any]:
305+
"""A dictionary of Pandas Series"""
304306
series = {}
305307
for df_name, df in get_dict_of_test_dfs().items():
306308
if len(df.columns) > 6:
@@ -313,7 +315,8 @@ def get_dict_of_test_series():
313315
return series
314316

315317

316-
def get_dict_of_polars_test_series():
318+
def get_dict_of_polars_test_series() -> dict[str, Any]:
319+
"""A dictionary of Polars Series"""
317320
import polars as pl
318321
import pyarrow as pa
319322

@@ -342,7 +345,8 @@ def generate_date_series():
342345
return pd.Series(pd.date_range("1677-09-23", "2262-04-10", freq="D"))
343346

344347

345-
def generate_random_series(rows, type):
348+
def generate_random_series(rows: int, type: str) -> Any:
349+
"""Generate a random Pandas Series of the given type and number of rows"""
346350
if type == "bool":
347351
return pd.Series(np.random.binomial(n=1, p=0.5, size=rows), dtype=bool)
348352
if type == "boolean":
@@ -382,15 +386,18 @@ def generate_random_series(rows, type):
382386
raise NotImplementedError(type)
383387

384388

385-
def generate_random_df(rows, columns, column_types=COLUMN_TYPES):
389+
def generate_random_df(
390+
rows: int, columns: int, column_types: Sequence[str] = COLUMN_TYPES
391+
) -> pd.DataFrame:
386392
rows = int(rows)
387393
types = np.random.choice(column_types, size=columns)
388-
columns = [
394+
columns_names = [
389395
"Column{}OfType{}".format(col, type.title()) for col, type in enumerate(types)
390396
]
391397

392398
series = {
393-
col: generate_random_series(rows, type) for col, type in zip(columns, types)
399+
col: generate_random_series(rows, type)
400+
for col, type in zip(columns_names, types)
394401
}
395402
index = pd.Index(range(rows))
396403
for x in series.values():
@@ -399,7 +406,7 @@ def generate_random_df(rows, columns, column_types=COLUMN_TYPES):
399406
return pd.DataFrame(series)
400407

401408

402-
def get_pandas_styler():
409+
def get_pandas_styler() -> Any:
403410
"""This function returns a Pandas Styler object
404411
405412
Cf. https://pandas.pydata.org/docs/user_guide/style.html

0 commit comments

Comments
 (0)