Skip to content

Commit 9c10f77

Browse files
fix: The ordering of terms differentiated formulae was potentially incorrect after v1.1.0. (#236)
1 parent 419f4d0 commit 9c10f77

File tree

2 files changed

+26
-28
lines changed

2 files changed

+26
-28
lines changed

formulaic/formula.py

+20-28
Original file line numberDiff line numberDiff line change
@@ -322,20 +322,17 @@ def differentiate(
322322
use_sympy: bool = False,
323323
) -> _SelfType:
324324
"""
325-
EXPERIMENTAL: Take the gradient of this formula. When used a linear
326-
regression, evaluating a trained model on model matrices generated by
327-
this formula is equivalent to estimating the gradient of that fitted
328-
form with respect to `wrt`.
325+
Take the gradient of this formula with respect to the variables in
326+
`wrt`.
327+
328+
When used a linear regression context, making predictions based on the
329+
model matrices generated the differentiated formula is equivalent to
330+
estimating the gradient of the fitted model with respect to `wrt`.
329331
330332
Args:
331333
wrt: The variables with respect to which the gradient should be
332334
taken.
333335
use_sympy: Whether to use sympy to perform symbolic differentiation.
334-
335-
336-
Notes:
337-
This method is provisional and may be removed in any future major
338-
version.
339336
"""
340337

341338

@@ -482,27 +479,25 @@ def differentiate( # pylint: disable=redefined-builtin
482479
use_sympy: bool = False,
483480
) -> SimpleFormula:
484481
"""
485-
EXPERIMENTAL: Take the gradient of this formula. When used a linear
486-
regression, evaluating a trained model on model matrices generated by
487-
this formula is equivalent to estimating the gradient of that fitted
488-
form with respect to `wrt`.
482+
Take the gradient of this formula with respect to the variables in
483+
`wrt`.
484+
485+
When used a linear regression context, making predictions based on the
486+
model matrices generated the differentiated formula is equivalent to
487+
estimating the gradient of the fitted model with respect to `wrt`.
489488
490489
Args:
491490
wrt: The variables with respect to which the gradient should be
492491
taken.
493492
use_sympy: Whether to use sympy to perform symbolic differentiation.
494-
495-
496-
Notes:
497-
This method is provisional and may be removed in any future major
498-
version.
499493
"""
500494
return SimpleFormula(
501495
[
502496
differentiate_term(term, wrt, use_sympy=use_sympy)
503497
for term in self.__terms
504498
],
505-
_ordering=self.ordering,
499+
# Preserve term ordering even if differentiation modifies degrees/etc.
500+
_ordering=OrderingMethod.NONE,
506501
)
507502

508503
def get_model_matrix(
@@ -784,20 +779,17 @@ def differentiate( # pylint: disable=redefined-builtin
784779
use_sympy: bool = False,
785780
) -> SimpleFormula:
786781
"""
787-
EXPERIMENTAL: Take the gradient of this formula. When used a linear
788-
regression, evaluating a trained model on model matrices generated by
789-
this formula is equivalent to estimating the gradient of that fitted
790-
form with respect to `wrt`.
782+
Take the gradient of this formula with respect to the variables in
783+
`wrt`.
784+
785+
When used a linear regression context, making predictions based on the
786+
model matrices generated the differentiated formula is equivalent to
787+
estimating the gradient of the fitted model with respect to `wrt`.
791788
792789
Args:
793790
wrt: The variables with respect to which the gradient should be
794791
taken.
795792
use_sympy: Whether to use sympy to perform symbolic differentiation.
796-
797-
798-
Notes:
799-
This method is provisional and may be removed in any future major
800-
version.
801793
"""
802794
return cast(
803795
SimpleFormula,

tests/test_formula.py

+6
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,9 @@ def test_differentiate(self):
197197
assert f.differentiate("a") == ["1", "0", "0"]
198198
assert f.differentiate("c") == ["0", "0", "0"]
199199

200+
g = Formula("a:b + b:c + c:d - 1")
201+
assert g.differentiate("b") == ["a", "c", "0"] # order preserved
202+
200203
def test_differentiate_with_sympy(self):
201204
pytest.importorskip("sympy")
202205
f = Formula("a + b + log(c) - 1")
@@ -208,6 +211,9 @@ def test_differentiate_with_sympy(self):
208211
"rhs": ["0", "(1/x)"],
209212
}
210213

214+
h = Formula("a + {a**2} + b - 1").differentiate("a", use_sympy=True)
215+
assert h == ["1", "(2*a)", "0"] # order preserved
216+
211217
def test_repr(self, formula_expr, formula_exprs):
212218
assert repr(formula_expr) == "1 + a + b + c + a:b + a:c + b:c + a:b:c"
213219
assert repr(formula_exprs) == ".lhs:\n a\n.rhs:\n 1 + b"

0 commit comments

Comments
 (0)