|
2 | 2 | import warnings |
3 | 3 | from dataclasses import dataclass |
4 | 4 | from enum import IntEnum |
5 | | -from typing import Any, Callable, List, Sequence, Union, cast |
| 5 | +from typing import Any, Callable, Dict, List, Sequence, Union, cast |
6 | 6 |
|
7 | 7 | import numpy as np |
8 | 8 | import numpy.typing as npt |
@@ -333,13 +333,23 @@ class TreeAsNumpy: |
333 | 333 | # this container is a convenience to not have 7 function arguments |
334 | 334 | class OnnxSoa: |
335 | 335 | def __init__(self, size: int, n_outputs: int) -> None: |
336 | | - self.nodes_modes = np.full(size, "BRANCH_LEQ") |
337 | | - self.nodes_featureids = np.full(size, -1, dtype=np.int32) |
338 | | - self.nodes_truenodeids = np.full(size, -1, dtype=np.int32) |
339 | | - self.nodes_falsenodeids = np.full(size, -1, dtype=np.int32) |
340 | | - self.nodes_nodeids = np.arange(size, dtype=np.int32) |
341 | | - self.nodes_values = np.full(size, -1.0, dtype=np.float64) |
342 | | - self.leaf_weights = np.full((size, n_outputs), -1.0, dtype=np.float64) |
| 336 | + self.nodes_modes: npt.NDArray[str] = np.full(size, "BRANCH_LEQ") |
| 337 | + self.nodes_featureids: npt.NDArray[np.int32] = np.full( |
| 338 | + size, -1, dtype=np.int32 |
| 339 | + ) |
| 340 | + self.nodes_truenodeids: npt.NDArray[np.int32] = np.full( |
| 341 | + size, -1, dtype=np.int32 |
| 342 | + ) |
| 343 | + self.nodes_falsenodeids: npt.NDArray[np.int32] = np.full( |
| 344 | + size, -1, dtype=np.int32 |
| 345 | + ) |
| 346 | + self.nodes_nodeids: npt.NDArray[np.int32] = np.arange(size, dtype=np.int32) |
| 347 | + self.nodes_values: npt.NDArray[np.float64] = np.full( |
| 348 | + size, -1.0, dtype=np.float64 |
| 349 | + ) |
| 350 | + self.leaf_weights: npt.NDArray[np.float64] = np.full( |
| 351 | + (size, n_outputs), -1.0, dtype=np.float64 |
| 352 | + ) |
343 | 353 |
|
344 | 354 | def recurse_tree( |
345 | 355 | self, tree: TreeAsNumpy, soa: OnnxSoa, old_node_idx: int, new_node_idx: int |
@@ -395,7 +405,7 @@ def to_onnx(self, X: cn.array) -> Any: |
395 | 405 |
|
396 | 406 | onnx_nodes = [] |
397 | 407 |
|
398 | | - kwargs = {} |
| 408 | + kwargs: Dict[str, Any] = {} |
399 | 409 | # TreeEnsembleRegressor asks us to pass these as tensors when X.dtype is double |
400 | 410 | # we simply pass a set of indices as leaf weights and then add a node later to |
401 | 411 | # look up the (vector valued) leaf weights |
|
0 commit comments