Skip to content

Commit aa36e66

Browse files
authored
ENH: set data keys as first positional arguments (#488)
1 parent 8ebaccf commit aa36e66

File tree

3 files changed

+7
-3
lines changed

3 files changed

+7
-3
lines changed

src/tensorwaves/function/sympy/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,10 @@ def create_parametrized_function(
109109
[0.0, 0.0, 0.0, 0.0, 0.0]
110110
"""
111111
free_symbols = _get_free_symbols(expression)
112-
sorted_symbols = sorted(free_symbols, key=lambda s: s.name)
112+
parameter_set = set(parameters)
113+
parameter_symbols = sorted(free_symbols & parameter_set, key=lambda s: s.name)
114+
data_symbols = sorted(free_symbols - parameter_set, key=lambda s: s.name)
115+
sorted_symbols = tuple(data_symbols + parameter_symbols) # for partial+gradient
113116
lambdified_function = _lambdify_normal_or_fast(
114117
expression=expression,
115118
symbols=sorted_symbols,

tests/function/test_function.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ def function(self) -> ParametrizedBackendFunction:
3434
return create_parametrized_function(expression, parameters, backend="numpy")
3535

3636
def test_argument_order(self, function: ParametrizedBackendFunction):
37-
assert function.argument_order == ("c_1", "c_2", "c_3", "c_4", "x")
37+
"""Test whether data arguments come before parameters."""
38+
assert function.argument_order == ("x", "c_1", "c_2", "c_3", "c_4")
3839

3940
@pytest.mark.parametrize(
4041
("test_data", "expected_results"),

tests/test_estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def test_create_cached_function(backend):
112112

113113
assert isinstance(cached_function, ParametrizedBackendFunction)
114114
assert isinstance(cache_transformer, SympyDataTransformer)
115-
assert cached_function.argument_order == ("a", "c", "f0", "x")
115+
assert cached_function.argument_order == ("f0", "x", "a", "c") # data args first
116116
assert set(cached_function.parameters) == {"a", "c"}
117117
assert set(cache_transformer.functions) == {"f0", "x"}
118118

0 commit comments

Comments
 (0)