Skip to content

Commit 3faf4f3

Browse files
committed
Some mypy issues
1 parent 05f8553 commit 3faf4f3

2 files changed

Lines changed: 20 additions & 9 deletions

File tree

ci/run_mypy.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
set -e -E -u -o pipefail
1616

17+
mypy --version
1718
mypy \
1819
--config-file ./pyproject.toml \
1920
--exclude=legateboost/test \

legateboost/models/tree.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import warnings
33
from dataclasses import dataclass
44
from enum import IntEnum
5-
from typing import Any, Callable, List, Sequence, Union, cast
5+
from typing import Any, Callable, Dict, List, Sequence, Union, cast
66

77
import numpy as np
88
import numpy.typing as npt
@@ -333,13 +333,23 @@ class TreeAsNumpy:
333333
# this container is a convenience to not have 7 function arguments
334334
class OnnxSoa:
335335
def __init__(self, size: int, n_outputs: int) -> None:
336-
self.nodes_modes = np.full(size, "BRANCH_LEQ")
337-
self.nodes_featureids = np.full(size, -1, dtype=np.int32)
338-
self.nodes_truenodeids = np.full(size, -1, dtype=np.int32)
339-
self.nodes_falsenodeids = np.full(size, -1, dtype=np.int32)
340-
self.nodes_nodeids = np.arange(size, dtype=np.int32)
341-
self.nodes_values = np.full(size, -1.0, dtype=np.float64)
342-
self.leaf_weights = np.full((size, n_outputs), -1.0, dtype=np.float64)
336+
self.nodes_modes: npt.NDArray[str] = np.full(size, "BRANCH_LEQ")
337+
self.nodes_featureids: npt.NDArray[np.int32] = np.full(
338+
size, -1, dtype=np.int32
339+
)
340+
self.nodes_truenodeids: npt.NDArray[np.int32] = np.full(
341+
size, -1, dtype=np.int32
342+
)
343+
self.nodes_falsenodeids: npt.NDArray[np.int32] = np.full(
344+
size, -1, dtype=np.int32
345+
)
346+
self.nodes_nodeids: npt.NDArray[np.int32] = np.arange(size, dtype=np.int32)
347+
self.nodes_values: npt.NDArray[np.float64] = np.full(
348+
size, -1.0, dtype=np.float64
349+
)
350+
self.leaf_weights: npt.NDArray[np.float64] = np.full(
351+
(size, n_outputs), -1.0, dtype=np.float64
352+
)
343353

344354
def recurse_tree(
345355
self, tree: TreeAsNumpy, soa: OnnxSoa, old_node_idx: int, new_node_idx: int
@@ -395,7 +405,7 @@ def to_onnx(self, X: cn.array) -> Any:
395405

396406
onnx_nodes = []
397407

398-
kwargs = {}
408+
kwargs: Dict[str, Any] = {}
399409
# TreeEnsembleRegressor asks us to pass these as tensors when X.dtype is double
400410
# we simply pass a set of indices as leaf weights and then add a node later to
401411
# look up the (vector valued) leaf weights

0 commit comments

Comments
 (0)