Skip to content

Commit 5bcd574

Browse files
FEAT: make ParameterFlattener public (#546)
1 parent 6e61809 commit 5bcd574

File tree

4 files changed

+36
-5
lines changed

4 files changed

+36
-5
lines changed

src/tensorwaves/optimizer/minuit.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,11 @@
1111
from tqdm.auto import tqdm
1212

1313
from tensorwaves.interface import Estimator, FitResult, Optimizer, ParameterValue
14-
15-
from ._parameter import ParameterFlattener
16-
from .callbacks import Callback, _create_log # pyright: ignore[reportPrivateUsage]
14+
from tensorwaves.optimizer.callbacks import (
15+
Callback,
16+
_create_log, # pyright: ignore[reportPrivateUsage]
17+
)
18+
from tensorwaves.optimizer.parameter import ParameterFlattener
1719

1820
if TYPE_CHECKING:
1921
from collections.abc import Iterable, Mapping

src/tensorwaves/optimizer/_parameter.py renamed to src/tensorwaves/optimizer/parameter.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
"""Utility to splice (flatten) and merge (unflatten) complex parameters for 'real-only' optimizers."""
2+
13
from __future__ import annotations
24

35
from typing import TYPE_CHECKING
@@ -9,6 +11,13 @@
911

1012

1113
class ParameterFlattener:
14+
"""Utility-class to flatten complex parameters.
15+
16+
Args:
17+
parameters: Original parameter-dictionary (unflattened). Non-complex values will
18+
not be affected by any method.
19+
"""
20+
1221
def __init__(self, parameters: Mapping[str, ParameterValue]) -> None:
1322
self.__real_imag_to_complex_name: dict[str, str] = {}
1423
self.__complex_to_real_imag_name: dict[str, tuple[str, str]] = {}
@@ -23,6 +32,17 @@ def __init__(self, parameters: Mapping[str, ParameterValue]) -> None:
2332
def unflatten(
2433
self, flattened_parameters: dict[str, float]
2534
) -> dict[str, ParameterValue]:
35+
"""Reverse the flattening operation.
36+
37+
Takes a parameter-dictionary and merges all real and imaginary values whose
38+
respective keys have been registered in the constructor of the
39+
`ParameterFlattener` into a complex number. Specifically, while this works also
40+
on inputs which have not been generated by :meth:`.flatten` their outputs might
41+
be unexpected.
42+
43+
Args:
44+
flattened_parameters: parameter `dict` whose values are to be unflattened.
45+
"""
2646
parameters: dict[str, ParameterValue] = {
2747
k: v
2848
for k, v in flattened_parameters.items()
@@ -39,6 +59,15 @@ def unflatten(
3959
return parameters
4060

4161
def flatten(self, parameters: Mapping[str, ParameterValue]) -> dict[str, float]:
62+
"""Flatten the parameter-values whose keys have been registered in the constructor.
63+
64+
Splits all complex values whose keys have been registered in the constructor of
65+
`ParameterFlattener` into their real and imaginary parts. Their keys are
66+
predetermined by the constructor. Other key-value pairs remain unchanged.
67+
68+
Args:
69+
parameters: parameter `dict` whose values are to be flattened.
70+
"""
4271
flattened_parameters: dict[str, float] = {}
4372
for par_name, value in parameters.items():
4473
if isinstance(value, complex):

src/tensorwaves/optimizer/scipy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111

1212
from tensorwaves.function._backend import raise_missing_module_error
1313
from tensorwaves.interface import Estimator, FitResult, Optimizer, ParameterValue
14+
from tensorwaves.optimizer.parameter import ParameterFlattener
1415

15-
from ._parameter import ParameterFlattener
1616
from .callbacks import Callback, _create_log # pyright: ignore[reportPrivateUsage]
1717

1818
if TYPE_CHECKING:

tests/optimizer/test_parameter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pytest
22

3-
from tensorwaves.optimizer._parameter import ParameterFlattener
3+
from tensorwaves.optimizer.parameter import ParameterFlattener
44

55

66
@pytest.fixture(scope="module")

0 commit comments

Comments
 (0)