Skip to content

Commit 0c5f437

Browse files
authored
Distributing typing information (#127) (#135)
* Typing set up * Typing information - auto generated * Update typing information - switched to inline
1 parent b384b8c commit 0c5f437

File tree

6 files changed

+37
-23
lines changed

6 files changed

+37
-23
lines changed

pyproject.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,12 @@ dependencies = [
2929
"meds == 0.3.3",
3030
]
3131

32+
[tool.setuptools]
33+
include-package-data = true
34+
35+
[tool.setuptools.package-data]
36+
your_package = ["*.pyi", "py.typed"]
37+
3238
[tool.setuptools_scm]
3339

3440
[project.scripts]

src/aces/__main__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def get_and_validate_label_schema(df: pl.DataFrame) -> pa.Table:
131131

132132

133133
@hydra.main(version_base=None, config_path=str(config_yaml.parent.resolve()), config_name=config_yaml.stem)
134-
def main(cfg: DictConfig):
134+
def main(cfg: DictConfig) -> None:
135135
import os
136136
from datetime import datetime
137137
from pathlib import Path

src/aces/config.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from dataclasses import field
1010
from datetime import timedelta
1111
from pathlib import Path
12+
from typing import Any
1213

1314
import networkx as nx
1415
import polars as pl
@@ -28,7 +29,7 @@
2829

2930
@dataclasses.dataclass
3031
class PlainPredicateConfig:
31-
code: str | dict
32+
code: str | dict[str, Any]
3233
value_min: float | None = None
3334
value_max: float | None = None
3435
value_min_inclusive: bool | None = None
@@ -294,7 +295,7 @@ class DerivedPredicateConfig:
294295
expr: str
295296
static: bool = False
296297

297-
def __post_init__(self):
298+
def __post_init__(self) -> None:
298299
if not self.expr:
299300
raise ValueError("Derived predicates must have a non-empty expression field.")
300301

@@ -652,7 +653,7 @@ class WindowConfig:
652653
index_timestamp: str | None = None
653654

654655
@classmethod
655-
def _check_reference(cls, reference: str):
656+
def _check_reference(cls, reference: str) -> None:
656657
"""Checks to ensure referenced events are valid."""
657658
err_str = (
658659
"Window boundary reference must be either a valid alphanumeric/'_' string "
@@ -713,7 +714,7 @@ def _parse_boundary(cls, boundary: str) -> dict[str, str]:
713714
cls._check_reference(ref)
714715
return {"referenced": ref, "offset": None, "event_bound": None, "occurs_before": None}
715716

716-
def __post_init__(self):
717+
def __post_init__(self) -> None:
717718
# Parse the has constraints from the string representation to the tuple representation
718719
if self.has is not None:
719720
for each_constraint in self.has:
@@ -1132,7 +1133,11 @@ class TaskExtractorConfig:
11321133
index_timestamp_window: str | None = None
11331134

11341135
@classmethod
1135-
def load(cls, config_path: str | Path, predicates_path: str | Path = None) -> TaskExtractorConfig:
1136+
def load(
1137+
cls: TaskExtractorConfig,
1138+
config_path: str | Path,
1139+
predicates_path: str | Path = None,
1140+
) -> TaskExtractorConfig:
11361141
"""Load a configuration file from the given path and return it as a dict.
11371142
11381143
Args:
@@ -1420,7 +1425,7 @@ def load(cls, config_path: str | Path, predicates_path: str | Path = None) -> Ta
14201425

14211426
return cls(predicates=predicate_objs, trigger=trigger, windows=windows)
14221427

1423-
def _initialize_predicates(self):
1428+
def _initialize_predicates(self) -> None:
14241429
"""Initialize the predicates tree from the configuration object and check validity.
14251430
14261431
Raises:
@@ -1467,7 +1472,7 @@ def _initialize_predicates(self):
14671472
f"Graph: {nx.write_network_text(self._predicate_dag_graph)}"
14681473
)
14691474

1470-
def _initialize_windows(self):
1475+
def _initialize_windows(self) -> None:
14711476
"""Initialize the windows tree from the configuration object and check validity.
14721477
14731478
Raises:
@@ -1614,7 +1619,7 @@ def _initialize_windows(self):
16141619

16151620
self.window_nodes = window_nodes
16161621

1617-
def __post_init__(self):
1622+
def __post_init__(self) -> None:
16181623
self._initialize_predicates()
16191624
self._initialize_windows()
16201625

@@ -1627,12 +1632,12 @@ def predicates_DAG(self) -> nx.DiGraph:
16271632
return self._predicate_dag_graph
16281633

16291634
@property
1630-
def plain_predicates(self) -> dict[str:PlainPredicateConfig]:
1635+
def plain_predicates(self) -> dict[str, PlainPredicateConfig]:
16311636
"""Returns a dictionary of plain predicates in {name: code} format."""
16321637
return {p: cfg for p, cfg in self.predicates.items() if cfg.is_plain}
16331638

16341639
@property
1635-
def derived_predicates(self) -> OrderedDict[str:DerivedPredicateConfig]:
1640+
def derived_predicates(self) -> OrderedDict[str, DerivedPredicateConfig]:
16361641
"""Returns an ordered dictionary mapping derived predicates to their configs in a proper order."""
16371642
return {
16381643
p: self.predicates[p]

src/aces/expand_shards.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def expand_shards(*shards: str) -> str:
7171
return ",".join(result)
7272

7373

74-
def main():
74+
def main() -> None:
7575
print(expand_shards(*sys.argv[1:]))
7676

7777

src/aces/types.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
"""
66

77
import dataclasses
8+
from collections.abc import Iterator
89
from datetime import timedelta
10+
from typing import Any
911

1012
import polars as pl
1113

@@ -59,14 +61,14 @@ class TemporalWindowBounds:
5961
offset: timedelta | None = None
6062

6163
# Needed to make it accessible like a tuple.
62-
def __iter__(self):
64+
def __iter__(self) -> Iterator[Any]:
6365
return (getattr(self, field.name) for field in dataclasses.fields(self))
6466

6567
# Needed to make it scriptable.
66-
def __getitem__(self, key):
68+
def __getitem__(self, key: int) -> Any:
6769
return tuple(getattr(self, field.name) for field in dataclasses.fields(self))[key]
6870

69-
def __post_init__(self):
71+
def __post_init__(self) -> None:
7072
if self.offset is None:
7173
self.offset = timedelta(0)
7274

@@ -205,7 +207,7 @@ class ToEventWindowBounds:
205207
right_inclusive: bool
206208
offset: timedelta | None = None
207209

208-
def __post_init__(self):
210+
def __post_init__(self) -> None:
209211
if self.end_event == "":
210212
raise ValueError("The 'end_event' must be a non-empty string.")
211213

@@ -226,11 +228,11 @@ def __post_init__(self):
226228
self.offset = timedelta(0)
227229

228230
# Needed to make it accessible like a tuple.
229-
def __iter__(self):
231+
def __iter__(self) -> Iterator[Any]:
230232
return (getattr(self, field.name) for field in dataclasses.fields(self))
231233

232234
# Needed to make it scriptable.
233-
def __getitem__(self, key):
235+
def __getitem__(self, key: int) -> Any:
234236
return tuple(getattr(self, field.name) for field in dataclasses.fields(self))[key]
235237

236238
@property

src/aces/utils.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
11
import io
22
import os
33
import sys
4+
from collections.abc import Generator
45
from contextlib import contextmanager
56
from datetime import timedelta
67

78
import hydra
8-
from bigtree import print_tree
9+
from bigtree import Node, print_tree
910
from loguru import logger
1011
from pytimeparse import parse
1112

1213

13-
def parse_timedelta(time_str: str) -> timedelta:
14+
def parse_timedelta(time_str: str = None) -> timedelta:
1415
"""Parse a time string and return a timedelta object.
1516
1617
Using time expression parser: https://github.com/wroberts/pytimeparse
@@ -40,7 +41,7 @@ def parse_timedelta(time_str: str) -> timedelta:
4041

4142

4243
@contextmanager
43-
def capture_output():
44+
def capture_output() -> Generator[io.StringIO, None, None]:
4445
"""A context manager to capture stdout output.
4546
4647
This can eventually be eliminated if https://github.com/kayjan/bigtree/issues/285 is resolved.
@@ -60,14 +61,14 @@ def capture_output():
6061
sys.stdout = old_out # Restore the original stdout
6162

6263

63-
def log_tree(node):
64+
def log_tree(node: Node) -> None:
6465
"""Logs the tree structure using logging.info."""
6566
with capture_output() as captured:
6667
print_tree(node, style="const_bold") # This will print to the captured StringIO instead of stdout
6768
logger.info("\n" + captured.getvalue()) # Log the captured output
6869

6970

70-
def hydra_loguru_init(filename) -> None:
71+
def hydra_loguru_init(filename: str) -> None:
7172
"""Must be called from a hydra main!"""
7273
hydra_path = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir
7374
logger.add(os.path.join(hydra_path, filename))

0 commit comments

Comments
 (0)