1+ """Utility to splice (flatten) and merge (unflatten) complex parameters for 'real-only' optimizers."""
2+
13from __future__ import annotations
24
35from typing import TYPE_CHECKING
911
1012
1113class ParameterFlattener :
14+ """Utility-class to flatten complex parameters.
15+
16+ Args:
17+ parameters: Original parameter-dictionary (unflattened). Non-complex values will
18+ not be affected by any method.
19+ """
20+
1221 def __init__ (self , parameters : Mapping [str , ParameterValue ]) -> None :
1322 self .__real_imag_to_complex_name : dict [str , str ] = {}
1423 self .__complex_to_real_imag_name : dict [str , tuple [str , str ]] = {}
@@ -23,6 +32,17 @@ def __init__(self, parameters: Mapping[str, ParameterValue]) -> None:
2332 def unflatten (
2433 self , flattened_parameters : dict [str , float ]
2534 ) -> dict [str , ParameterValue ]:
35+ """Reverse the flattening operation.
36+
37+ Takes a parameter-dictionary and merges all real and imaginary values whose
38+ respective keys have been registered in the constructor of the
39+ `ParameterFlattener` into a complex number. Specifically, while this works also
40+ on inputs which have not been generated by :meth:`.flatten` their outputs might
41+ be unexpected.
42+
43+ Args:
44+ flattened_parameters: parameter `dict` whose values are to be unflattened.
45+ """
2646 parameters : dict [str , ParameterValue ] = {
2747 k : v
2848 for k , v in flattened_parameters .items ()
@@ -39,6 +59,15 @@ def unflatten(
3959 return parameters
4060
4161 def flatten (self , parameters : Mapping [str , ParameterValue ]) -> dict [str , float ]:
62+ """Flatten the parameter-values whose keys have been registered in the constructor.
63+
64+ Splits all complex values whose keys have been registered in the constructor of
65+ `ParameterFlattener` into their real and imaginary parts. Their keys are
66+ predetermined by the constructor. Other key-value pairs remain unchanged.
67+
68+ Args:
69+ parameters: parameter `dict` whose values are to be flattened.
70+ """
4271 flattened_parameters : dict [str , float ] = {}
4372 for par_name , value in parameters .items ():
4473 if isinstance (value , complex ):
0 commit comments