Skip to content

Commit 4c10317

Browse files
authored
FIX: set i/o types for function implementations (#522)
1 parent bb8b60b commit 4c10317

File tree

8 files changed

+21
-18
lines changed

8 files changed

+21
-18
lines changed

benchmarks/ampform.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,11 @@
2020
from ampform.helicity import HelicityModel
2121
from qrules.combinatorics import StateDefinition
2222

23+
from tensorwaves.function import ParametrizedBackendFunction
2324
from tensorwaves.interface import (
2425
DataSample,
2526
FitResult,
27+
Function,
2628
ParameterValue,
2729
ParametrizedFunction,
2830
)
@@ -55,7 +57,7 @@ def formulate_amplitude_model(
5557

5658
def create_function(
5759
model: HelicityModel, backend: str, max_complexity: int | None = None
58-
) -> ParametrizedFunction:
60+
) -> ParametrizedBackendFunction:
5961
return create_parametrized_function(
6062
expression=model.expression.doit(),
6163
parameters=model.parameter_defaults,
@@ -66,7 +68,7 @@ def create_function(
6668

6769
def generate_data(
6870
model: HelicityModel,
69-
function: ParametrizedFunction,
71+
function: Function[DataSample, np.ndarray],
7072
data_sample_size: int,
7173
phsp_sample_size: int,
7274
backend: str,
@@ -103,7 +105,7 @@ def generate_data(
103105
def fit(
104106
data: DataSample,
105107
phsp: DataSample,
106-
function: ParametrizedFunction,
108+
function: ParametrizedFunction[DataSample, np.ndarray],
107109
initial_parameters: Mapping[str, ParameterValue],
108110
backend: str,
109111
) -> FitResult:

benchmarks/expression.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def _generate_domain(
6161

6262
def _generate_data(
6363
size: int,
64-
function: Function,
64+
function: Function[DataSample, np.ndarray],
6565
rng: np.random.Generator,
6666
bunch_size: int = 10_000,
6767
) -> DataSample:

src/tensorwaves/data/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ class IntensityDistributionGenerator(DataGenerator):
7171
def __init__(
7272
self,
7373
domain_generator: DataGenerator,
74-
function: Function,
74+
function: Function[DataSample, np.ndarray],
7575
domain_transformer: DataTransformer | None = None,
7676
bunch_size: int = 50_000,
7777
) -> None:

src/tensorwaves/data/transform.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from ._attrs import to_tuple
1717

1818
if TYPE_CHECKING: # pragma: no cover
19+
import numpy as np
1920
import sympy as sp
2021

2122

@@ -55,7 +56,9 @@ def __call__(self, data: DataSample) -> DataSample:
5556
class SympyDataTransformer(DataTransformer):
5657
"""Implementation of a `.DataTransformer`."""
5758

58-
def __init__(self, functions: Mapping[str, Function]) -> None:
59+
def __init__(
60+
self, functions: Mapping[str, Function[DataSample, np.ndarray]]
61+
) -> None:
5962
if any(not isinstance(f, Function) for f in functions.values()):
6063
msg = (
6164
f"Not all values in the mapping are an instance of {Function.__name__}"
@@ -64,7 +67,7 @@ def __init__(self, functions: Mapping[str, Function]) -> None:
6467
self.__functions = dict(functions)
6568

6669
@property
67-
def functions(self) -> dict[str, Function]:
70+
def functions(self) -> dict[str, Function[DataSample, np.ndarray]]:
6871
"""Read-only access to the internal mapping of functions."""
6972
return dict(self.__functions)
7073

src/tensorwaves/estimator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def create_cached_function(
2929
backend: str,
3030
free_parameters: Iterable[sp.Symbol],
3131
use_cse: bool = True,
32-
) -> tuple[ParametrizedFunction, DataTransformer]:
32+
) -> tuple[ParametrizedFunction[DataSample, np.ndarray], DataTransformer]:
3333
"""Create a function and data transformer for cached computations.
3434
3535
Once it is known which parameters in an expression are to be optimized, this
@@ -118,7 +118,7 @@ class ChiSquared(Estimator):
118118

119119
def __init__( # noqa: PLR0913
120120
self,
121-
function: ParametrizedFunction,
121+
function: ParametrizedFunction[DataSample, np.ndarray],
122122
domain: DataSample,
123123
observed_values: np.ndarray,
124124
weights: np.ndarray | None = None,
@@ -185,7 +185,7 @@ class UnbinnedNLL(Estimator):
185185

186186
def __init__( # noqa: PLR0913
187187
self,
188-
function: ParametrizedFunction,
188+
function: ParametrizedFunction[DataSample, np.ndarray],
189189
data: DataSample,
190190
phsp: DataSample,
191191
phsp_volume: float = 1.0,

src/tensorwaves/function/__init__.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
from __future__ import annotations
44

55
import inspect
6-
from typing import TYPE_CHECKING, Callable, Iterable, Mapping
6+
from typing import Callable, Iterable, Mapping
77

88
import attrs
9+
import numpy as np
910
from attrs import field, frozen
1011

1112
from tensorwaves.interface import (
@@ -15,9 +16,6 @@
1516
ParametrizedFunction,
1617
)
1718

18-
if TYPE_CHECKING:
19-
import numpy as np
20-
2119

2220
def _all_str(
2321
_: PositionalArgumentFunction, __: attrs.Attribute, value: Iterable[str]
@@ -66,7 +64,7 @@ def _to_tuple(argument_order: Iterable[str]) -> tuple[str, ...]:
6664

6765

6866
@frozen
69-
class PositionalArgumentFunction(Function):
67+
class PositionalArgumentFunction(Function[DataSample, np.ndarray]):
7068
"""Wrapper around a function with positional arguments.
7169
7270
This class provides a :meth:`~.Function.__call__` that can take a `.DataSample` for
@@ -90,7 +88,7 @@ def __call__(self, data: DataSample) -> np.ndarray:
9088
return self.function(*args)
9189

9290

93-
class ParametrizedBackendFunction(ParametrizedFunction):
91+
class ParametrizedBackendFunction(ParametrizedFunction[DataSample, np.ndarray]):
9492
"""Implements `.ParametrizedFunction` for a specific computational back-end.
9593
9694
.. seealso:: :func:`.create_parametrized_function`

src/tensorwaves/interface.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def __call__(self, data: InputType) -> OutputType: ...
4141
"""Allowed types for parameter values."""
4242

4343

44-
class ParametrizedFunction(Function[DataSample, np.ndarray]):
44+
class ParametrizedFunction(Function[InputType, OutputType]):
4545
"""Interface of a callable function.
4646
4747
A `ParametrizedFunction` identifies certain variables in a mathematical expression

tests/optimizer/test_fit_simple_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def generate_domain(
3838
def generate_data(
3939
size: int,
4040
boundaries: dict[str, tuple[float, float]],
41-
function: Function,
41+
function: Function[DataSample, np.ndarray],
4242
rng: np.random.Generator,
4343
bunch_size: int = 10_000,
4444
) -> DataSample:

0 commit comments

Comments
 (0)