Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 46 additions & 9 deletions pyzx/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def __repr__(self) -> str:
vs.append(f'{v}')
else:
vs.append(f'{v}^{c}')
return '*'.join(vs)
return ''.join(vs)

def __mul__(self, other: 'Term') -> 'Term':
"""Return the product of two terms, combining exponents.
Expand Down Expand Up @@ -276,7 +276,7 @@ def __str__(self) -> str:
elif c == 1:
ts.append(f'{t}')
else:
ts.append(f'{c}{t}')
ts.append(f'{c}{t}')
return ' + '.join(ts)

def __eq__(self, other: object) -> bool:
Expand Down Expand Up @@ -387,16 +387,25 @@ def new_const(coeff: Union[int, Fraction]) -> Poly:


poly_grammar = Lark("""
start : "(" start ")" | term ("+" term)*
term : (intf | frac)? factor ("*" factor)*
?factor : intf | frac | pi | pifrac | var
start : expr
?expr : expr "+" term -> add
| expr "-" term -> sub
| term
term : neg_term | pos_term
neg_term : "-" pos_term
pos_term : factor (("*" | "⋅")? factor)*
?factor : base ("^" exponent)?
base : intf | frac | decimal | pi | pifrac | var | "(" expr ")"
exponent : intf
var : CNAME
intf : INT
decimal : DECIMAL
pi : "\\pi" | "pi"
frac : INT "/" INT
pifrac : [INT] pi "/" INT

%import common.INT
%import common.DECIMAL
%import common.CNAME
%import common.WS
%ignore WS
Expand All @@ -411,15 +420,40 @@ class PolyTransformer(Transformer):
"""
def __init__(self, new_var: Callable[[str], Poly]):
super().__init__()

self._new_var = new_var

def start(self, items: List[Poly]) -> Poly:
return reduce(add, items)
def start(self, items: List[Any]) -> Poly:
return items[0]

def add(self, items: List[Any]) -> Poly:
return items[0] + items[1]

def sub(self, items: List[Any]) -> Poly:
return items[0] - items[1]

def term(self, items: List[Any]) -> Poly:
return items[0]

def neg_term(self, items: List[Any]) -> Poly:
return -items[0] # Negate the pos_term

def term(self, items: List[Poly]) -> Poly:
def pos_term(self, items: List[Any]) -> Poly:
return reduce(mul, items)

def factor(self, items: List[Any]) -> Poly:
if len(items) == 1:
return items[0]
# Handle exponentiation: base^exponent
base = items[0]
exponent = items[1]
return base ** exponent

def base(self, items: List[Any]) -> Poly:
return items[0]

def exponent(self, items: List[Any]) -> int:
return items[0].terms[0][0]

def var(self, items: List[Any]) -> Poly:
v = str(items[0])
return self._new_var(v)
Expand All @@ -430,6 +464,9 @@ def pi(self, _: List[Any]) -> Poly:
def intf(self, items: List[Any]) -> Poly:
return new_const(int(items[0]))

def decimal(self, items: List[Any]) -> Poly:
return new_const(Fraction(float(items[0])).limit_denominator())

def frac(self, items: List[Any]) -> Poly:
return new_const(Fraction(int(items[0]), int(items[1])))

Expand Down