|
10 | 10 | Atom, |
11 | 11 | Expression, |
12 | 12 | Integer, |
| 13 | + Integer0, |
13 | 14 | Integer1, |
| 15 | + RationalOneHalf, |
14 | 16 | Number, |
15 | 17 | Symbol, |
16 | 18 | SymbolFalse, |
17 | 19 | SymbolNull, |
18 | 20 | SymbolTrue, |
19 | 21 | ) |
20 | 22 | from mathics.core.convert import from_sympy, sympy_symbol_prefix |
| 23 | +from mathics.core.rules import Pattern |
21 | 24 |
|
22 | 25 | import sympy |
23 | 26 |
|
@@ -62,7 +65,6 @@ def _expand(expr): |
62 | 65 |
|
63 | 66 | if kwargs["modulus"] is not None and kwargs["modulus"] <= 0: |
64 | 67 | return Integer(0) |
65 | | - |
66 | 68 | # A special case for trigonometric functions |
67 | 69 | if "trig" in kwargs and kwargs["trig"]: |
68 | 70 | if expr.has_form("Sin", 1): |
@@ -149,7 +151,6 @@ def unconvert_subexprs(expr): |
149 | 151 | ) |
150 | 152 |
|
151 | 153 | sympy_expr = convert_sympy(expr) |
152 | | - |
153 | 154 | if deep: |
154 | 155 | # thread over everything |
155 | 156 | for (i, sub_expr,) in enumerate(sub_exprs): |
@@ -192,7 +193,6 @@ def unconvert_subexprs(expr): |
192 | 193 | sympy_expr = sympy_expr.expand(**hints) |
193 | 194 | result = from_sympy(sympy_expr) |
194 | 195 | result = unconvert_subexprs(result) |
195 | | - |
196 | 196 | return result |
197 | 197 |
|
198 | 198 |
|
@@ -1413,3 +1413,202 @@ def apply(self, expr, form, h, evaluation): |
1413 | 1413 | return Expression( |
1414 | 1414 | "List", *[Expression(h, *[i for i in s]) for s in exponents] |
1415 | 1415 | ) |
| 1416 | + |
| 1417 | + |
| 1418 | +class Collect(Builtin): |
| 1419 | + """ |
| 1420 | + <dl> |
| 1421 | + <dt>'Collect[$expr$, $x$]' |
| 1422 | + <dd> Expands $expr$ and collect together terms having the same power of $x$. |
| 1423 | + <dt>'Collect[$expr$, {$x_1$, $x_2$, ...}]' |
| 1424 | + <dd> Expands $expr$ and collect together terms having the same powers of |
| 1425 | + $x_1$, $x_2$, .... |
| 1426 | + <dt>'Collect[$expr$, {$x_1$, $x_2$, ...}, $filter$]' |
| 1427 | + <dd> After collect the terms, applies $filter$ to each coefficient. |
| 1428 | + </dl> |
| 1429 | +
|
| 1430 | + >> Collect[(x+y)^3, y] |
| 1431 | + = x ^ 3 + 3 x ^ 2 y + 3 x y ^ 2 + y ^ 3 |
| 1432 | + >> Collect[2 Sin[x z] (x+2 y^2 + Sin[y] x), y] |
| 1433 | + = 2 x Sin[x z] + 2 x Sin[x z] Sin[y] + 4 y ^ 2 Sin[x z] |
| 1434 | + >> Collect[3 x y+2 Sin[x z] (x+2 y^2 + x) + (x+y)^3, y] |
| 1435 | + = 4 x Sin[x z] + x ^ 3 + y (3 x + 3 x ^ 2) + y ^ 2 (3 x + 4 Sin[x z]) + y ^ 3 |
| 1436 | + >> Collect[3 x y+2 Sin[x z] (x+2 y^2 + x) + (x+y)^3, {x,y}] |
| 1437 | + = 4 x Sin[x z] + x ^ 3 + 3 x y + 3 x ^ 2 y + 4 y ^ 2 Sin[x z] + 3 x y ^ 2 + y ^ 3 |
| 1438 | + >> Collect[3 x y+2 Sin[x z] (x+2 y^2 + x) + (x+y)^3, {x,y}, h] |
| 1439 | + = x h[4 Sin[x z]] + x ^ 3 h[1] + x y h[3] + x ^ 2 y h[3] + y ^ 2 h[4 Sin[x z]] + x y ^ 2 h[3] + y ^ 3 h[1] |
| 1440 | + """ |
| 1441 | + |
| 1442 | + rules = { |
| 1443 | + "Collect[expr_, varlst_]": "Collect[expr, varlst, Identity]", |
| 1444 | + } |
| 1445 | + |
| 1446 | + def apply_var_filter(self, expr, varlst, filt, evaluation): |
| 1447 | + """Collect[expr_, varlst_, filt_]""" |
| 1448 | + from mathics.builtin.patterns import match |
| 1449 | + |
| 1450 | + if varlst.is_symbol(): |
| 1451 | + var_exprs = [varlst] |
| 1452 | + elif varlst.has_form("List", None): |
| 1453 | + var_exprs = varlst.get_leaves() |
| 1454 | + else: |
| 1455 | + var_exprs = [varlst] |
| 1456 | + |
| 1457 | + if len(var_exprs) > 1: |
| 1458 | + target_pat = Pattern.create(Expression("Alternatives", *var_exprs)) |
| 1459 | + var_pats = [Pattern.create(var) for var in var_exprs] |
| 1460 | + else: |
| 1461 | + target_pat = Pattern.create(varlst) |
| 1462 | + var_pats = [target_pat] |
| 1463 | + |
| 1464 | + expr = expand( |
| 1465 | + expr, |
| 1466 | + numer=True, |
| 1467 | + denom=False, |
| 1468 | + deep=False, |
| 1469 | + trig=False, |
| 1470 | + modulus=None, |
| 1471 | + target_pat=target_pat, |
| 1472 | + ) |
| 1473 | + if filt == Symbol("Identity"): |
| 1474 | + filt = None |
| 1475 | + |
| 1476 | + def key_powers(lst): |
| 1477 | + key = Expression("Plus", *lst) |
| 1478 | + key = key.evaluate(evaluation) |
| 1479 | + if key.is_numeric(): |
| 1480 | + return key.to_python() |
| 1481 | + return 0 |
| 1482 | + |
| 1483 | + def powers_list(pf): |
| 1484 | + powers = [Integer0 for i, p in enumerate(var_pats)] |
| 1485 | + if pf is None: |
| 1486 | + return powers |
| 1487 | + if pf.is_symbol(): |
| 1488 | + for i, pat in enumerate(var_pats): |
| 1489 | + if match(pf, pat, evaluation): |
| 1490 | + powers[i] = Integer(1) |
| 1491 | + return powers |
| 1492 | + if pf.has_form("Sqrt", 1): |
| 1493 | + for i, pat in enumerate(var_pats): |
| 1494 | + if match(pf._leaves[0], pat, evaluation): |
| 1495 | + powers[i] = RationalOneHalf |
| 1496 | + return powers |
| 1497 | + if pf.has_form("Power", 2): |
| 1498 | + for i, pat in enumerate(var_pats): |
| 1499 | + matchval = match(pf._leaves[0], pat, evaluation) |
| 1500 | + if matchval: |
| 1501 | + powers[i] = pf._leaves[1] |
| 1502 | + return powers |
| 1503 | + if pf.has_form("Times", None): |
| 1504 | + contrib = [powers_list(factor) for factor in pf._leaves] |
| 1505 | + for i in range(len(var_pats)): |
| 1506 | + powers[i] = Expression("Plus", *[c[i] for c in contrib]).evaluate( |
| 1507 | + evaluation |
| 1508 | + ) |
| 1509 | + return powers |
| 1510 | + return powers |
| 1511 | + |
| 1512 | + def split_coeff_pow(term: Expression): |
| 1513 | + """ |
| 1514 | + This function factorizes term in a coefficent free |
| 1515 | + of powers of the target variables, and a factor with |
| 1516 | + that powers. |
| 1517 | + """ |
| 1518 | + coeffs = [] |
| 1519 | + powers = [] |
| 1520 | + # First, split factors on those which are powers of the variables |
| 1521 | + # and the rest. |
| 1522 | + if term.is_free(target_pat, evaluation): |
| 1523 | + coeffs.append(term) |
| 1524 | + elif ( |
| 1525 | + term.is_symbol() |
| 1526 | + or term.has_form("Power", 2) |
| 1527 | + or term.has_form("Sqrt", 1) |
| 1528 | + ): |
| 1529 | + powers.append(term) |
| 1530 | + elif term.has_form("Times", None): |
| 1531 | + for factor in term.leaves: |
| 1532 | + if factor.is_free(target_pat, evaluation): |
| 1533 | + coeffs.append(factor) |
| 1534 | + elif match(factor, target_pat, evaluation): |
| 1535 | + powers.append(factor) |
| 1536 | + elif ( |
| 1537 | + factor.has_form("Power", 2) or factor.has_form("Sqrt", 1) |
| 1538 | + ) and match(factor._leaves[0], target_pat, evaluation): |
| 1539 | + powers.append(factor) |
| 1540 | + else: |
| 1541 | + coeffs.append(factor) |
| 1542 | + else: |
| 1543 | + coeffs.append(term) |
| 1544 | + # Now, rebuild both factors |
| 1545 | + if len(coeffs) == 0: |
| 1546 | + coeffs = None |
| 1547 | + elif len(coeffs) == 1: |
| 1548 | + coeffs = coeffs[0] |
| 1549 | + else: |
| 1550 | + coeffs = Expression("Times", *coeffs) |
| 1551 | + if len(powers) == 0: |
| 1552 | + powers = None |
| 1553 | + elif len(powers) == 1: |
| 1554 | + powers = powers[0] |
| 1555 | + else: |
| 1556 | + powers = Expression("Times", *sorted(powers)) |
| 1557 | + return coeffs, powers |
| 1558 | + |
| 1559 | + if expr.is_free(target_pat, evaluation): |
| 1560 | + if filt: |
| 1561 | + return Expression(filt, expr).evaluate(evaluation) |
| 1562 | + else: |
| 1563 | + return expr |
| 1564 | + elif expr.is_symbol() or expr.has_form("Power", 2) or expr.has_form("Sqrt", 1): |
| 1565 | + if filt: |
| 1566 | + return Expression( |
| 1567 | + "Times", Expression(filt, Integer1).evaluate(evaluation), expr |
| 1568 | + ) |
| 1569 | + else: |
| 1570 | + return expr |
| 1571 | + elif expr.has_form("Plus", None): |
| 1572 | + coeff_dict = {} |
| 1573 | + powers_dict = {} |
| 1574 | + powers_order = {} |
| 1575 | + for term in expr._leaves: |
| 1576 | + coeff, powers = split_coeff_pow(term) |
| 1577 | + pl = powers_list(powers) |
| 1578 | + key = str(pl) |
| 1579 | + if not key in powers_dict: |
| 1580 | + powers_dict[key] = powers |
| 1581 | + coeff_dict[key] = [] |
| 1582 | + powers_order[key] = key_powers(pl) |
| 1583 | + |
| 1584 | + coeff_dict[key].append(Integer1 if coeff is None else coeff) |
| 1585 | + |
| 1586 | + terms = [] |
| 1587 | + for key in sorted( |
| 1588 | + coeff_dict, key=lambda kv: powers_order[kv], reverse=False |
| 1589 | + ): |
| 1590 | + val = coeff_dict[key] |
| 1591 | + if len(val) == 0: |
| 1592 | + continue |
| 1593 | + elif len(val) == 1: |
| 1594 | + coeff = val[0] |
| 1595 | + else: |
| 1596 | + coeff = Expression("Plus", *val) |
| 1597 | + if filt: |
| 1598 | + coeff = Expression(filt, coeff).evaluate(evaluation) |
| 1599 | + |
| 1600 | + powerfactor = powers_dict[key] |
| 1601 | + if powerfactor: |
| 1602 | + terms.append(Expression("Times", coeff, powerfactor)) |
| 1603 | + else: |
| 1604 | + terms.append(coeff) |
| 1605 | + |
| 1606 | + return Expression("Plus", *terms) |
| 1607 | + else: |
| 1608 | + if filt: |
| 1609 | + return Expression(filt, expr).evaluate(evaluation) |
| 1610 | + else: |
| 1611 | + return expr |
| 1612 | + |
| 1613 | + |
| 1614 | +# tejimeto |
0 commit comments