Skip to content

Commit 00ce851

Browse files
authored
FEAT: implement ChainedDataTransformer (#470)
1 parent 340004b commit 00ce851

File tree

4 files changed

+79
-2
lines changed

4 files changed

+79
-2
lines changed

.cspell.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,7 @@
169169
"qrules",
170170
"rightarrow",
171171
"rtfd",
172+
"rtol",
172173
"scipy",
173174
"sdist",
174175
"seealso",

src/tensorwaves/data/_attrs.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from __future__ import annotations
2+
3+
from typing import Iterable
4+
5+
from tensorwaves.interface import DataTransformer
6+
7+
8+
def to_tuple(items: Iterable[DataTransformer]) -> tuple[DataTransformer, ...]:
9+
return tuple(items)

src/tensorwaves/data/transform.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
from typing import TYPE_CHECKING, Mapping
55

6+
from attrs import field, frozen
7+
68
from tensorwaves.function import PositionalArgumentFunction
79
from tensorwaves.function.sympy import (
810
_get_free_symbols, # pyright: ignore[reportPrivateUsage]
@@ -12,10 +14,38 @@
1214
)
1315
from tensorwaves.interface import DataSample, DataTransformer, Function
1416

17+
from ._attrs import to_tuple
18+
1519
if TYPE_CHECKING: # pragma: no cover
1620
import sympy as sp
1721

1822

23+
@frozen
24+
class ChainedDataTransformer(DataTransformer):
25+
"""Combine multiple `.DataTransformer` classes into one.
26+
27+
Args:
28+
transformer: Ordered list of transformers that you want to chain.
29+
extend: Set to `True` in order to keep keys of each output `.DataSample` and
30+
collect them into the final, chained `.DataSample`.
31+
"""
32+
33+
transformers: tuple[DataTransformer, ...] = field(converter=to_tuple)
34+
extend: bool = True
35+
36+
def __call__(self, data: DataSample) -> DataSample:
37+
new_data = dict(data)
38+
weights = new_data.get("weights")
39+
for transformer in self.transformers:
40+
if self.extend:
41+
new_data.update(transformer(new_data))
42+
else:
43+
new_data = transformer(new_data)
44+
if weights is not None:
45+
new_data["weights"] = weights
46+
return new_data
47+
48+
1949
class IdentityTransformer(DataTransformer):
2050
"""`.DataTransformer` that leaves a `.DataSample` intact."""
2151

tests/data/test_transform.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,47 @@
1-
# pylint: disable=invalid-name
1+
from __future__ import annotations
2+
23
import numpy as np
34
import pytest
45
import sympy as sp
56
from numpy import sqrt
67

7-
from tensorwaves.data import IdentityTransformer, SympyDataTransformer
8+
from tensorwaves.data.transform import (
9+
ChainedDataTransformer,
10+
IdentityTransformer,
11+
SympyDataTransformer,
12+
)
13+
14+
15+
class TestChainedDataTransformer:
16+
@pytest.mark.parametrize("extend", [False, True])
17+
def test_identity_chain(self, extend: bool):
18+
x, y, v, w = sp.symbols("x y v w")
19+
transform1 = _create_transformer({v: 2 * x - 5, w: -0.2 * y + 3})
20+
transform2 = _create_transformer({x: 0.5 * (v + 5), y: 5 * (3 - w)})
21+
chained_transform = ChainedDataTransformer([transform1, transform2], extend)
22+
rng = np.random.default_rng(seed=0)
23+
data = {"x": rng.uniform(size=100), "y": rng.uniform(size=100)}
24+
transformed_data = chained_transform(data)
25+
for key in data: # pylint: disable=consider-using-dict-items
26+
np.testing.assert_allclose(data[key], transformed_data[key], rtol=1e-13)
27+
if extend:
28+
assert set(transformed_data) == {"x", "y", "v", "w"}
29+
else:
30+
assert set(transformed_data) == {"x", "y"}
31+
32+
def test_single_chain(self):
33+
transform = IdentityTransformer()
34+
chained_transform = ChainedDataTransformer([transform])
35+
data = {
36+
"x": np.ones(5),
37+
"y": np.ones(5),
38+
}
39+
assert data == chained_transform(data)
40+
assert data is not chained_transform(data) # DataSample returned as new dict
41+
42+
43+
def _create_transformer(expressions: dict[sp.Symbol, sp.Expr]) -> SympyDataTransformer:
44+
return SympyDataTransformer.from_sympy(expressions, backend="jax")
845

946

1047
class TestIdentityTransformer:

0 commit comments

Comments
 (0)