Skip to content

Commit 1c47d96

Browse files
committed
Merge branch 'Dispatch2' into FeyncalcFixes
2 parents 75c1cbf + 916a8a7 commit 1c47d96

File tree

1 file changed

+91
-10
lines changed

1 file changed

+91
-10
lines changed

mathics/builtin/patterns.py

Lines changed: 91 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,13 @@
3434

3535

3636
from mathics.version import __version__ # noqa used in loading to check consistency.
37-
from mathics.builtin.base import Builtin, BinaryOperator, PostfixOperator
37+
from mathics.builtin.base import Builtin, BinaryOperator, PostfixOperator, AtomBuiltin
3838
from mathics.builtin.base import PatternObject, PatternError
3939
from mathics.builtin.lists import python_levelspec, InvalidLevelspecError
4040

4141
from mathics.core.expression import (
42+
Atom,
43+
String,
4244
Symbol,
4345
Expression,
4446
Number,
@@ -103,13 +105,16 @@ class RuleDelayed(BinaryOperator):
103105

104106

105107
def create_rules(rules_expr, expr, name, evaluation, extra_args=[]):
106-
if rules_expr.has_form("Dispatch", None):
107-
rules_expr = rules_expr.leaves[0]
108+
if isinstance(rules_expr, Dispatch):
109+
return rules_expr.rules, False
110+
elif rules_expr.has_form("Dispatch", None):
111+
return Dispatch(rules_expr._leaves, evaluation)
112+
108113
if rules_expr.has_form("List", None):
109114
rules = rules_expr.leaves
110115
else:
111116
rules = [rules_expr]
112-
any_lists = any(item.has_form("List", None) for item in rules)
117+
any_lists = any(item.has_form(("List", "Dispatch"), None) for item in rules)
113118
if any_lists:
114119
all_lists = all(item.has_form("List", None) for item in rules)
115120
if all_lists:
@@ -287,10 +292,8 @@ def apply(self, expr, rules, evaluation):
287292
"ReplaceAll[expr_, rules_]"
288293
try:
289294
rules, ret = create_rules(rules, expr, "ReplaceAll", evaluation)
290-
291295
if ret:
292296
return rules
293-
294297
result, applied = expr.apply_rules(rules, evaluation)
295298
return result
296299
except PatternError:
@@ -1468,18 +1471,57 @@ def yield_match(vars, rest):
14681471
)
14691472

14701473

1471-
class Dispatch(Builtin):
1474+
class Dispatch(Atom):
1475+
def __init__(self, rulelist, evaluation):
1476+
self.src = Expression(SymbolList, *rulelist)
1477+
self.rules = [Rule(rule._leaves[0], rule._leaves[1]) for rule in rulelist]
1478+
self._leaves = None
1479+
self._head = Symbol("Dispatch")
1480+
1481+
def get_sort_key(self):
1482+
return self.src.get_sort_key()
1483+
1484+
def get_atom_name(self):
1485+
return "System`Dispatch"
1486+
1487+
def __repr__(self):
1488+
return "dispatch"
1489+
1490+
def atom_to_boxes(self, f, evaluation):
1491+
leaves = self.src.format(evaluation, f.get_name())
1492+
return Expression(
1493+
"RowBox",
1494+
Expression(
1495+
SymbolList, String("Dispatch"), String("["), leaves, String("]")
1496+
),
1497+
)
1498+
1499+
1500+
class DispatchAtom(AtomBuiltin):
14721501
"""
14731502
<dl>
14741503
<dt>'Dispatch[$rulelist$]'
14751504
<dd>Introduced for compatibility. Currently, it just return $rulelist$.
14761505
In the future, it should return an optimized DispatchRules atom,
14771506
containing an optimized set of rules.
14781507
</dl>
1479-
1508+
>> rules = {{a_,b_}->a^b, {1,2}->3., F[x_]->x^2};
1509+
>> F[2] /. rules
1510+
= 4
1511+
>> dispatchrules = Dispatch[rules]
1512+
= Dispatch[{{a_, b_} -> a ^ b, {1, 2} -> 3., F[x_] -> x ^ 2}]
1513+
>> F[2] /. dispatchrules
1514+
= 4
14801515
"""
14811516

1482-
def apply_stub(self, rules, evaluation):
1517+
messages = {
1518+
"invrpl": "`1` is not a valid rule or list of rules.",
1519+
}
1520+
1521+
def __repr__(self):
1522+
return "dispatchatom"
1523+
1524+
def apply_create(self, rules, evaluation):
14831525
"""Dispatch[rules_List]"""
14841526
# TODO:
14851527
# The next step would be to enlarge this method, in order to
@@ -1489,4 +1531,43 @@ def apply_stub(self, rules, evaluation):
14891531
# compiled patters, and modify Replace and ReplaceAll to handle this
14901532
# kind of objects.
14911533
#
1492-
return rules
1534+
if isinstance(rules, Dispatch):
1535+
return rules
1536+
if rules.is_symbol():
1537+
rules = rules.evaluate(evaluation)
1538+
1539+
if rules.has_form("List", None):
1540+
rules = rules._leaves
1541+
else:
1542+
rules = [rules]
1543+
1544+
all_list = all(rule.has_form("List", None) for rule in rules)
1545+
if all_list:
1546+
leaves = [self.apply_create(rule, evaluation) for rule in rules]
1547+
return Expression(SymbolList, *leaves)
1548+
flatten_list = []
1549+
for rule in rules:
1550+
if rule.is_symbol():
1551+
rule = rule.evaluate(evaluation)
1552+
if rule.has_form("List", None):
1553+
flatten_list.extend(rule._leaves)
1554+
elif rule.has_form(("Rule", "RuleDelayed"), 2):
1555+
flatten_list.append(rule)
1556+
elif isinstance(rule, Dispatch):
1557+
flatten_list.extend(rule.src._leaves)
1558+
else:
1559+
# WMA does not raise this message: just leave it unevaluated,
1560+
# and raise an error when the dispatch rule is used.
1561+
evaluation.message("Dispatch", "invrpl", rule)
1562+
return
1563+
try:
1564+
return Dispatch(flatten_list, evaluation)
1565+
except:
1566+
return
1567+
1568+
def apply_normal(self, dispatch, evaluation):
1569+
"""Normal[dispatch_Dispatch]"""
1570+
if isinstance(dispatch, Dispatch):
1571+
return dispatch.src
1572+
else:
1573+
return dispatch._leaves[0]

0 commit comments

Comments
 (0)