Skip to content

Commit

Permalink
further expand Solve[]'s ability
Browse files Browse the repository at this point in the history
1. bring domain check back into solve.eval, so that things like Abs()
   can be evaluated
2. create system symbols such as System`Reals for domain check
3. refactor, moving most logics out of Solve
  • Loading branch information
BlankShrimp committed Sep 20, 2023
1 parent 5f32b4d commit 8e10644
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 102 deletions.
111 changes: 24 additions & 87 deletions mathics/builtin/numbers/calculus.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"""

from itertools import product
from typing import Optional, Union
from typing import Optional

import numpy as np
import sympy
Expand Down Expand Up @@ -71,23 +71,27 @@
from mathics.core.systemsymbols import (
SymbolAnd,
SymbolAutomatic,
SymbolComplex,
SymbolConditionalExpression,
SymbolD,
SymbolDerivative,
SymbolInfinity,
SymbolInfix,
SymbolInteger,
SymbolIntegrate,
SymbolLeft,
SymbolLog,
SymbolNIntegrate,
SymbolO,
SymbolReal,
SymbolRule,
SymbolSequence,
SymbolSeries,
SymbolSeriesData,
SymbolSimplify,
SymbolUndefined,
)
from mathics.eval.calculus import solve_sympy
from mathics.eval.makeboxes import format_element
from mathics.eval.nevaluator import eval_N

Expand Down Expand Up @@ -2208,105 +2212,38 @@ class Solve(Builtin):
messages = {
"eqf": "`1` is not a well-formed equation.",
"svars": 'Equations may not give solutions for all "solve" variables.',
"fulldim": "The solution set contains a full-dimensional component; use Reduce for complete solution information.",
}

# FIXME: the problem with removing the domain parameter from the outside
# is that the we can't make use of this information inside
# the evaluation method where it is may be needed.
rules = {
"Solve[eqs_, vars_, Complexes]": "Solve[eqs, vars]",
"Solve[eqs_, vars_, Reals]": (
"Cases[Solve[eqs, vars], {Rule[x_,y_?RealValuedNumberQ]}]"
),
"Solve[eqs_, vars_, Integers]": (
"Cases[Solve[eqs, vars], {Rule[x_,y_Integer]}]"
),
"Solve[eqs_, vars_]": "Solve[eqs, vars, Complexes]"
}
summary_text = "find generic solutions for variables"

def eval(self, eqs, vars, evaluation: Evaluation):
"Solve[eqs_, vars_]"
def eval(self, eqs, vars, domain, evaluation: Evaluation):
"Solve[eqs_, vars_, domain_]"

vars_original = vars
head_name = vars.get_head_name()
variables = vars
head_name = variables.get_head_name()
if head_name == "System`List":
vars = vars.elements
variables = variables.elements
else:
vars = [vars]
for var in vars:
variables = [variables]
for var in variables:
if (
(isinstance(var, Atom) and not isinstance(var, Symbol)) or
head_name in ("System`Plus", "System`Times", "System`Power") or # noqa
A_CONSTANT & var.get_attributes(evaluation.definitions)
):

evaluation.message("Solve", "ivar", vars_original)
evaluation.message("Solve", "ivar", vars)
return

vars_sympy = [var.to_sympy() for var in vars]
if None in vars_sympy:
sympy_variables = [var.to_sympy() for var in variables]
if None in sympy_variables:
evaluation.message("Solve", "ivar")
return
all_var_tuples = list(zip(vars, vars_sympy))

def cut_var_dimension(expressions: Union[Expression, list[Expression]]):
'''delete unused variables to avoid SymPy's PolynomialError
: Not a zero-dimensional system in e.g. Solve[x^2==1&&z^2==-1,{x,y,z}]'''
if not isinstance(expressions, list):
expressions = [expressions]
subset_vars = set()
subset_vars_sympy = set()
for var, var_sympy in all_var_tuples:
pattern = Pattern.create(var)
for equation in expressions:
if not equation.is_free(pattern, evaluation):
subset_vars.add(var)
subset_vars_sympy.add(var_sympy)
return subset_vars, subset_vars_sympy

def solve_sympy(equations: Union[Expression, list[Expression]]):
if not isinstance(equations, list):
equations = [equations]
equations_sympy = []
denoms_sympy = []
subset_vars, subset_vars_sympy = cut_var_dimension(equations)
for equation in equations:
if equation is SymbolTrue:
continue
elif equation is SymbolFalse:
return []
elements = equation.elements
for left, right in [(elements[index], elements[index + 1]) for index in range(len(elements) - 1)]:
# ↑ to deal with things like a==b==c==d
left = left.to_sympy()
right = right.to_sympy()
if left is None or right is None:
return []
equation_sympy = left - right
equation_sympy = sympy.together(equation_sympy)
equation_sympy = sympy.cancel(equation_sympy)
equations_sympy.append(equation_sympy)
numer, denom = equation_sympy.as_numer_denom()
denoms_sympy.append(denom)
try:
results = sympy.solve(equations_sympy, subset_vars_sympy, dict=True) # no transform_dict needed with dict=True
# Filter out results for which denominator is 0
# (SymPy should actually do that itself, but it doesn't!)
results = [
sol
for sol in results
if all(sympy.simplify(denom.subs(sol)) != 0 for denom in denoms_sympy)
]
return results
except sympy.PolynomialError:
# raised for e.g. Solve[x^2==1&&z^2==-1,{x,y,z}] when not deleting
# unused variables beforehand
return []
except NotImplementedError:
return []
except TypeError as exc:
if str(exc).startswith("expected Symbol, Function or Derivative"):
evaluation.message("Solve", "ivar", vars_original)
variable_tuples = list(zip(variables, sympy_variables))

def solve_recur(expression: Expression):
'''solve And, Or and List within the scope of sympy,
Expand Down Expand Up @@ -2334,7 +2271,7 @@ def solve_recur(expression: Expression):
inequations.append(sub_condition)
else:
inequations.append(child.to_sympy())
solutions.extend(solve_sympy(equations))
solutions.extend(solve_sympy(evaluation, equations, variables, domain))
conditions = sympy.And(*inequations)
result = [sol for sol in solutions if conditions.subs(sol)]
return result, None if solutions else conditions
Expand All @@ -2344,7 +2281,7 @@ def solve_recur(expression: Expression):
conditions = []
for child in expression.elements:
if child.has_form("Equal", 2):
solutions.extend(solve_sympy(child))
solutions.extend(solve_sympy(evaluation, child, variables, domain))
elif child.get_head_name() in ('System`And', 'System`Or'): # I don't believe List would be in here
sub_solution, sub_condition = solve_recur(child)
solutions.extend(sub_solution)
Expand All @@ -2363,8 +2300,8 @@ def solve_recur(expression: Expression):
if conditions is not None:
evaluation.message("Solve", "fulldim")
else:
if eqs.has_form("Equal", 2):
solutions = solve_sympy(eqs)
if eqs.get_head_name() == "System`Equal":
solutions = solve_sympy(evaluation, eqs, variables, domain)
else:
evaluation.message("Solve", "fulldim")
return ListExpression(ListExpression())
Expand All @@ -2374,7 +2311,7 @@ def solve_recur(expression: Expression):
return ListExpression(ListExpression())

if any(
sol and any(var not in sol for var in vars_sympy) for sol in solutions
sol and any(var not in sol for var in sympy_variables) for sol in solutions
):
evaluation.message("Solve", "svars")

Expand All @@ -2383,7 +2320,7 @@ def solve_recur(expression: Expression):
ListExpression(
*(
Expression(SymbolRule, var, from_sympy(sol[var_sympy]))
for var, var_sympy in all_var_tuples
for var, var_sympy in variable_tuples
if var_sympy in sol
),
)
Expand Down
6 changes: 3 additions & 3 deletions mathics/core/atoms.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,9 +774,9 @@ def get_sort_key(self, pattern_sort=False) -> tuple:
def sameQ(self, other) -> bool:
"""Mathics SameQ"""
return (
isinstance(other, Complex)
and self.real == other.real
and self.imag == other.imag
isinstance(other, Complex) and
self.real == other.real and
self.imag == other.imag
)

def round(self, d=None) -> "Complex":
Expand Down
49 changes: 43 additions & 6 deletions mathics/core/convert/sympy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
Conversion to SymPy is handled directly in BaseElement descendants.
"""

from collections.abc import Iterable
from typing import Optional, Type, Union

import sympy
Expand All @@ -13,9 +14,6 @@
# Import the singleton class
from sympy.core.numbers import S

BasicSympy = sympy.Expr


from mathics.core.atoms import (
MATHICS3_COMPLEX_I,
Complex,
Expand All @@ -40,6 +38,7 @@
)
from mathics.core.list import ListExpression
from mathics.core.number import FP_MANTISA_BINARY_DIGITS
from mathics.core.rules import Pattern
from mathics.core.symbols import (
Symbol,
SymbolFalse,
Expand All @@ -62,16 +61,21 @@
SymbolGreater,
SymbolGreaterEqual,
SymbolIndeterminate,
SymbolIntegers,
SymbolLess,
SymbolLessEqual,
SymbolMatrixPower,
SymbolO,
SymbolPi,
SymbolPiecewise,
SymbolReals,
SymbolSlot,
SymbolUnequal,
)

BasicSympy = sympy.Expr


SymbolPrime = Symbol("Prime")
SymbolRoot = Symbol("Root")
SymbolRootSum = Symbol("RootSum")
Expand Down Expand Up @@ -130,6 +134,39 @@ def to_sympy_matrix(data, **kwargs) -> Optional[sympy.MutableDenseMatrix]:
return None


def apply_domain_to_symbols(symbols: Iterable[sympy.Symbol], domain) -> dict[sympy.Symbol, sympy.Symbol]:
"""Create new sympy symbols with domain applied.
Return a dict maps old to new.
"""
# FIXME: this substitute solution would break when Solve[Abs[x]==3, x],where x=-3 and x=3.
# However, substituting symbol prior to actual solving would cause sympy to have biased assumption,
# it would refuse to solve Abs() when symbol is in Complexes
result = {}
for symbol in symbols:
if domain == SymbolReals:
new_symbol = sympy.Symbol(repr(symbol), real=True)
elif domain == SymbolIntegers:
new_symbol = sympy.Symbol(repr(symbol), integer=True)
else:
new_symbol = symbol
result[symbol] = new_symbol
return result


def cut_dimension(evaluation, expressions: Union[Expression, list[Expression]], symbols: Iterable[sympy.Symbol]) -> set[sympy.Symbol]:
'''delete unused variables to avoid SymPy's PolynomialError
: Not a zero-dimensional system in e.g. Solve[x^2==1&&z^2==-1,{x,y,z}]'''
if not isinstance(expressions, list):
expressions = [expressions]
subset = set()
for symbol in symbols:
pattern = Pattern.create(symbol)
for equation in expressions:
if not equation.is_free(pattern, evaluation):
subset.add(symbol)
return subset


class SympyExpression(BasicSympy):
is_Function = True
nargs = None
Expand Down Expand Up @@ -363,9 +400,9 @@ def old_from_sympy(expr) -> BaseElement:
if is_Cn_expr(name):
return Expression(SymbolC, Integer(int(name[1:])))
if name.startswith(sympy_symbol_prefix):
name = name[len(sympy_symbol_prefix) :]
name = name[len(sympy_symbol_prefix):]
if name.startswith(sympy_slot_prefix):
index = name[len(sympy_slot_prefix) :]
index = name[len(sympy_slot_prefix):]
return Expression(SymbolSlot, Integer(int(index)))
elif expr.is_NumberSymbol:
name = str(expr)
Expand Down Expand Up @@ -517,7 +554,7 @@ def old_from_sympy(expr) -> BaseElement:
*[from_sympy(arg) for arg in expr.args]
)
if name.startswith(sympy_symbol_prefix):
name = name[len(sympy_symbol_prefix) :]
name = name[len(sympy_symbol_prefix):]
args = [from_sympy(arg) for arg in expr.args]
builtin = sympy_to_mathics.get(name)
if builtin is not None:
Expand Down
3 changes: 3 additions & 0 deletions mathics/core/systemsymbols.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
SymbolCompile = Symbol("System`Compile")
SymbolCompiledFunction = Symbol("System`CompiledFunction")
SymbolComplex = Symbol("System`Complex")
SymbolComplexes = Symbol("System`Complexes")
SymbolComplexInfinity = Symbol("System`ComplexInfinity")
SymbolCondition = Symbol("System`Condition")
SymbolConditionalExpression = Symbol("System`ConditionalExpression")
Expand Down Expand Up @@ -124,6 +125,7 @@
SymbolInfix = Symbol("System`Infix")
SymbolInputForm = Symbol("System`InputForm")
SymbolInteger = Symbol("System`Integer")
SymbolIntegers = Symbol("System`Integers")
SymbolIntegrate = Symbol("System`Integrate")
SymbolLeft = Symbol("System`Left")
SymbolLength = Symbol("System`Length")
Expand Down Expand Up @@ -200,6 +202,7 @@
SymbolRational = Symbol("System`Rational")
SymbolRe = Symbol("System`Re")
SymbolReal = Symbol("System`Real")
SymbolReals = Symbol("System`Reals")
SymbolRealAbs = Symbol("System`RealAbs")
SymbolRealDigits = Symbol("System`RealDigits")
SymbolRealSign = Symbol("System`RealSign")
Expand Down
32 changes: 26 additions & 6 deletions test/builtin/calculus/test_solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,34 @@ def test_solve():
"Issue #1235",
),
(
"Solve[{x^2==4 && x < 0},{x}]",
"{x->-2}",
"",
"Solve[Abs[-2/3*(lambda + 2) + 8/3 + 4] == 4, lambda,Reals]",
"{{lambda -> 2}, {lambda -> 14}}",
"abs()",
),
(
"Solve[{x^2==4 && x < 0 && x > -4},{x}]",
"{x->-2}",
"",
"Solve[q^3 == (20-12)/(4-3), q,Reals]",
"{{q -> 2}}",
"domain check",
),
(
"Solve[x + Pi/3 == 2k*Pi + Pi/6 || x + Pi/3 == 2k*Pi + 5Pi/6, x,Reals]",
"{{x -> -Pi / 6 + 2 k Pi}, {x -> Pi / 2 + 2 k Pi}}",
"logics involved",
),
(
"Solve[m - 1 == 0 && -(m + 1) != 0, m,Reals]",
"{{m -> 1}}",
"logics and constraints",
),
(
"Solve[(lambda + 1)/6 == 1/(mu - 1) == lambda/4, {lambda, mu},Reals]",
"{{lambda -> 2, mu -> 3}}",
"chained equations",
),
(
"Solve[2*x0*Log[x0] + x0 - 2*a*x0 == -1 && x0^2*Log[x0] - a*x0^2 + b == b - x0, {x0, a, b},Reals]",
"{{x0 -> 1, a -> 1}}",
"excess variable b",
),
):
session.evaluate("Clear[h]; Clear[g]; Clear[f];")
Expand Down

0 comments on commit 8e10644

Please sign in to comment.