Skip to content

Commit 63395d1

Browse files
authored
Merge pull request #583 from CUQI-DTU/add_576_variable_class
Add support for algebra
2 parents e495208 + 77881de commit 63395d1

File tree

5 files changed

+461
-0
lines changed

5 files changed

+461
-0
lines changed

cuqi/experimental/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
""" Experimental module for testing new features and ideas. """
22
from . import mcmc
3+
from . import algebra
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from ._ast import VariableNode

cuqi/experimental/algebra/_ast.py

Lines changed: 325 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,325 @@
1+
"""
2+
CUQIpy specific implementation of an abstract syntax tree (AST) for algebra on variables.
3+
4+
The AST is used to record the operations applied to variables allowing a delayed evaluation
5+
of said operations when needed by traversing the tree with the __call__ method.
6+
7+
For example, the following code
8+
9+
x = VariableNode('x')
10+
y = VariableNode('y')
11+
z = 2*x + 3*y
12+
13+
will create the following AST:
14+
15+
z = AddNode(
16+
MultiplyNode(
17+
ValueNode(2),
18+
VariableNode('x')
19+
),
20+
MultiplyNode(
21+
ValueNode(3),
22+
VariableNode('y')
23+
)
24+
)
25+
26+
which can be evaluated by calling the __call__ method:
27+
28+
z(x=1, y=2) # returns 8
29+
30+
"""
31+
32+
from abc import ABC, abstractmethod
33+
34+
convert_to_node = lambda x: x if isinstance(x, Node) else ValueNode(x)
35+
""" Converts any non-Node object to a ValueNode object. """
36+
37+
# ====== Base classes for the nodes ======
38+
39+
40+
class Node(ABC):
41+
"""Base class for all nodes in the abstract syntax tree.
42+
43+
Responsible for building the AST by creating nodes that represent the operations applied to variables.
44+
45+
Each subclass must implement the __call__ method that will evaluate the node given the input parameters.
46+
47+
"""
48+
49+
@abstractmethod
50+
def __call__(self, **kwargs):
51+
"""Evaluate node at a given parameter value. This will traverse the sub-tree originated at this node and evaluate it given the recorded operations."""
52+
pass
53+
54+
@abstractmethod
55+
def __repr__(self):
56+
"""String representation of the node. Used for printing the AST."""
57+
pass
58+
59+
def __add__(self, other):
60+
return AddNode(self, convert_to_node(other))
61+
62+
def __radd__(self, other):
63+
return AddNode(convert_to_node(other), self)
64+
65+
def __sub__(self, other):
66+
return SubtractNode(self, convert_to_node(other))
67+
68+
def __rsub__(self, other):
69+
return SubtractNode(convert_to_node(other), self)
70+
71+
def __mul__(self, other):
72+
return MultiplyNode(self, convert_to_node(other))
73+
74+
def __rmul__(self, other):
75+
return MultiplyNode(convert_to_node(other), self)
76+
77+
def __truediv__(self, other):
78+
return DivideNode(self, convert_to_node(other))
79+
80+
def __rtruediv__(self, other):
81+
return DivideNode(convert_to_node(other), self)
82+
83+
def __pow__(self, other):
84+
return PowerNode(self, convert_to_node(other))
85+
86+
def __rpow__(self, other):
87+
return PowerNode(convert_to_node(other), self)
88+
89+
def __neg__(self):
90+
return NegateNode(self)
91+
92+
def __abs__(self):
93+
return AbsNode(self)
94+
95+
def __getitem__(self, i):
96+
return GetItemNode(self, convert_to_node(i))
97+
98+
def __matmul__(self, other):
99+
return MatMulNode(self, convert_to_node(other))
100+
101+
def __rmatmul__(self, other):
102+
return MatMulNode(convert_to_node(other), self)
103+
104+
105+
class UnaryNode(Node, ABC):
106+
"""Base class for all unary nodes in the abstract syntax tree.
107+
108+
Parameters
109+
----------
110+
child : Node
111+
The direct child node on which the unary operation is performed.
112+
113+
"""
114+
115+
def __init__(self, child: Node):
116+
self.child = child
117+
118+
119+
class BinaryNode(Node, ABC):
120+
"""Base class for all binary nodes in the abstract syntax tree.
121+
122+
The op_symbol attribute is used for printing the operation in the __repr__ method.
123+
124+
Parameters
125+
----------
126+
left : Node
127+
Left child node to the binary operation.
128+
129+
right : Node
130+
Right child node to the binary operation.
131+
132+
"""
133+
134+
@property
135+
@abstractmethod
136+
def op_symbol(self):
137+
"""Symbol used to represent the operation in the __repr__ method."""
138+
pass
139+
140+
def __init__(self, left: Node, right: Node):
141+
self.left = left
142+
self.right = right
143+
144+
def __repr__(self):
145+
return f"{self.left} {self.op_symbol} {self.right}"
146+
147+
148+
class BinaryNodeWithParenthesis(BinaryNode, ABC):
149+
"""Base class for all binary nodes in the abstract syntax tree that should be printed with parenthesis."""
150+
151+
def __repr__(self):
152+
left = f"({self.left})" if isinstance(self.left, BinaryNode) else str(self.left)
153+
right = (
154+
f"({self.right})" if isinstance(self.right, BinaryNode) else str(self.right)
155+
)
156+
return f"{left} {self.op_symbol} {right}"
157+
158+
class BinaryNodeWithParenthesisNoSpace(BinaryNode, ABC):
159+
"""Base class for all binary nodes in the abstract syntax tree that should be printed with parenthesis but no space."""
160+
161+
def __repr__(self):
162+
left = f"({self.left})" if isinstance(self.left, BinaryNode) else str(self.left)
163+
right = (
164+
f"({self.right})" if isinstance(self.right, BinaryNode) else str(self.right)
165+
)
166+
return f"{left}{self.op_symbol}{right}"
167+
168+
169+
# ====== Specific implementations of the "leaf" nodes ======
170+
171+
172+
class VariableNode(Node):
173+
"""Node that represents a generic variable, e.g. "x" or "y".
174+
175+
Parameters
176+
----------
177+
name : str
178+
Name of the variable. Used for printing and to retrieve the given input value
179+
of the variable in the kwargs dictionary when evaluating the tree.
180+
181+
"""
182+
183+
def __init__(self, name):
184+
self.name = name
185+
186+
def __call__(self, **kwargs):
187+
"""Retrieves the value of the variable from the passed kwargs. If no value is found, it raises a KeyError."""
188+
if not self.name in kwargs:
189+
raise KeyError(
190+
f"Variable '{self.name}' not found in the given input parameters. Unable to evaluate the expression."
191+
)
192+
return kwargs[self.name]
193+
194+
def __repr__(self):
195+
return self.name
196+
197+
198+
class ValueNode(Node):
199+
"""Node that represents a constant value. The value can be any python object that is not a Node.
200+
201+
Parameters
202+
----------
203+
value : object
204+
The python object that represents the value of the node.
205+
206+
"""
207+
208+
def __init__(self, value):
209+
self.value = value
210+
211+
def __call__(self, **kwargs):
212+
"""Returns the value of the node."""
213+
return self.value
214+
215+
def __repr__(self):
216+
return str(self.value)
217+
218+
219+
# ====== Specific implementations of the "internal" nodes ======
220+
221+
222+
class AddNode(BinaryNode):
223+
"""Node that represents the addition operation."""
224+
225+
@property
226+
def op_symbol(self):
227+
return "+"
228+
229+
def __call__(self, **kwargs):
230+
return self.left(**kwargs) + self.right(**kwargs)
231+
232+
233+
class SubtractNode(BinaryNode):
234+
"""Node that represents the subtraction operation."""
235+
236+
@property
237+
def op_symbol(self):
238+
return "-"
239+
240+
def __call__(self, **kwargs):
241+
return self.left(**kwargs) - self.right(**kwargs)
242+
243+
244+
class MultiplyNode(BinaryNodeWithParenthesis):
245+
"""Node that represents the multiplication operation."""
246+
247+
@property
248+
def op_symbol(self):
249+
return "*"
250+
251+
def __call__(self, **kwargs):
252+
return self.left(**kwargs) * self.right(**kwargs)
253+
254+
255+
class DivideNode(BinaryNodeWithParenthesis):
256+
"""Node that represents the division operation."""
257+
258+
@property
259+
def op_symbol(self):
260+
return "/"
261+
262+
def __call__(self, **kwargs):
263+
return self.left(**kwargs) / self.right(**kwargs)
264+
265+
266+
class PowerNode(BinaryNodeWithParenthesisNoSpace):
267+
"""Node that represents the power operation."""
268+
269+
@property
270+
def op_symbol(self):
271+
return "^"
272+
273+
def __call__(self, **kwargs):
274+
return self.left(**kwargs) ** self.right(**kwargs)
275+
276+
277+
class GetItemNode(BinaryNode):
278+
"""Node that represents the get item operation. Here the left node is the object and the right node is the index."""
279+
280+
def __call__(self, **kwargs):
281+
return self.left(**kwargs)[self.right(**kwargs)]
282+
283+
def __repr__(self):
284+
left = f"({self.left})" if isinstance(self.left, BinaryNode) else str(self.left)
285+
return f"{left}[{self.right}]"
286+
287+
@property
288+
def op_symbol(self):
289+
pass
290+
291+
292+
class NegateNode(UnaryNode):
293+
"""Node that represents the arithmetic negation operation."""
294+
295+
def __call__(self, **kwargs):
296+
return -self.child(**kwargs)
297+
298+
def __repr__(self):
299+
child = (
300+
f"({self.child})"
301+
if isinstance(self.child, (BinaryNode, UnaryNode))
302+
else str(self.child)
303+
)
304+
return f"-{child}"
305+
306+
307+
class AbsNode(UnaryNode):
308+
"""Node that represents the absolute value operation."""
309+
310+
def __call__(self, **kwargs):
311+
return abs(self.child(**kwargs))
312+
313+
def __repr__(self):
314+
return f"abs({self.child})"
315+
316+
317+
class MatMulNode(BinaryNodeWithParenthesis):
318+
"""Node that represents the matrix multiplication operation."""
319+
320+
@property
321+
def op_symbol(self):
322+
return "@"
323+
324+
def __call__(self, **kwargs):
325+
return self.left(**kwargs) @ self.right(**kwargs)

demos/howtos/algebra.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
"""
2+
Algebra in CUQIpy
3+
=================
4+
5+
CUQIpy provides a simple algebraic framework for defining and manipulating
6+
variables.
7+
8+
In this example, we will demonstrate how to define a simple algebraic structure
9+
and perform some basic algebraic operations.
10+
11+
"""
12+
13+
#%%
14+
# Defining Variables
15+
# ------------------
16+
# To utilize the algebraic framework in CUQIpy, we first need to define some
17+
# variables to apply the algebraic operations on. In this example, we will
18+
# define variables `x` and `y`.
19+
20+
from cuqi.experimental.algebra import VariableNode
21+
22+
x = VariableNode('x')
23+
y = VariableNode('y')
24+
25+
#%%
26+
# Recording Algebraic Operations
27+
# --------------------
28+
# We can now perform some algebraic operations on the variables `x` and `y`.
29+
# The algebraic operations are recorded in a computational graph (abstract syntax tree).
30+
# The operations are recoding the correct ordering and adhering to the rules of
31+
# algebra.
32+
33+
print("Basic operations:")
34+
print(x + 1)
35+
print(x + y)
36+
print(x * y)
37+
print(x / y)
38+
39+
print("\nComplex operations:")
40+
print(x**2 + 2*x*y + y**2)
41+
print((x + y)**2)
42+
43+
print("\nProgrammatric operations:")
44+
print(x[2]+y[3])
45+
46+
# %%
47+
# Utilizing the Computational Graph
48+
# ---------------------------------
49+
# The computational graph can be utilized to evaluate the algebraic expressions
50+
# when desired. This means we can define mathematical expressions without
51+
# needing to evaluate them immediately.
52+
53+
# Define a mathematical expression
54+
expr1 = (x + y)**2
55+
56+
# Evaluate the expression (using the __call__ method)
57+
print(f"Expression {expr1} evaluated at x=2, y=3 yields {expr1(x=2, y=3)}")
58+
59+
# Another example
60+
expr2 = x**2 + 2*x*y + y**2 + 16
61+
print(f"Expression {expr2} evaluated at x=2, y=3 yields {expr2(x=2, y=3)}")
62+
63+
# Another example utilizing array indexing
64+
expr3 = x[1] + y[2]
65+
print(f"Expression {expr3} evaluated at x=[1,2,3], y=[4,5,6] yields {expr3(x=[1,2,3], y=[4,5,6])}")
66+
# %%

0 commit comments

Comments
 (0)