|
| 1 | +""" |
| 2 | +CUQIpy specific implementation of an abstract syntax tree (AST) for algebra on variables. |
| 3 | +
|
| 4 | +The AST is used to record the operations applied to variables allowing a delayed evaluation |
| 5 | +of said operations when needed by traversing the tree with the __call__ method. |
| 6 | +
|
| 7 | +For example, the following code |
| 8 | +
|
| 9 | + x = VariableNode('x') |
| 10 | + y = VariableNode('y') |
| 11 | + z = 2*x + 3*y |
| 12 | +
|
| 13 | +will create the following AST: |
| 14 | +
|
| 15 | +z = AddNode( |
| 16 | + MultiplyNode( |
| 17 | + ValueNode(2), |
| 18 | + VariableNode('x') |
| 19 | + ), |
| 20 | + MultiplyNode( |
| 21 | + ValueNode(3), |
| 22 | + VariableNode('y') |
| 23 | + ) |
| 24 | + ) |
| 25 | +
|
| 26 | +which can be evaluated by calling the __call__ method: |
| 27 | +
|
| 28 | + z(x=1, y=2) # returns 8 |
| 29 | +
|
| 30 | +""" |
| 31 | + |
| 32 | +from abc import ABC, abstractmethod |
| 33 | + |
| 34 | +convert_to_node = lambda x: x if isinstance(x, Node) else ValueNode(x) |
| 35 | +""" Converts any non-Node object to a ValueNode object. """ |
| 36 | + |
| 37 | +# ====== Base classes for the nodes ====== |
| 38 | + |
| 39 | + |
| 40 | +class Node(ABC): |
| 41 | + """Base class for all nodes in the abstract syntax tree. |
| 42 | +
|
| 43 | + Responsible for building the AST by creating nodes that represent the operations applied to variables. |
| 44 | +
|
| 45 | + Each subclass must implement the __call__ method that will evaluate the node given the input parameters. |
| 46 | +
|
| 47 | + """ |
| 48 | + |
| 49 | + @abstractmethod |
| 50 | + def __call__(self, **kwargs): |
| 51 | + """Evaluate node at a given parameter value. This will traverse the sub-tree originated at this node and evaluate it given the recorded operations.""" |
| 52 | + pass |
| 53 | + |
| 54 | + @abstractmethod |
| 55 | + def __repr__(self): |
| 56 | + """String representation of the node. Used for printing the AST.""" |
| 57 | + pass |
| 58 | + |
| 59 | + def __add__(self, other): |
| 60 | + return AddNode(self, convert_to_node(other)) |
| 61 | + |
| 62 | + def __radd__(self, other): |
| 63 | + return AddNode(convert_to_node(other), self) |
| 64 | + |
| 65 | + def __sub__(self, other): |
| 66 | + return SubtractNode(self, convert_to_node(other)) |
| 67 | + |
| 68 | + def __rsub__(self, other): |
| 69 | + return SubtractNode(convert_to_node(other), self) |
| 70 | + |
| 71 | + def __mul__(self, other): |
| 72 | + return MultiplyNode(self, convert_to_node(other)) |
| 73 | + |
| 74 | + def __rmul__(self, other): |
| 75 | + return MultiplyNode(convert_to_node(other), self) |
| 76 | + |
| 77 | + def __truediv__(self, other): |
| 78 | + return DivideNode(self, convert_to_node(other)) |
| 79 | + |
| 80 | + def __rtruediv__(self, other): |
| 81 | + return DivideNode(convert_to_node(other), self) |
| 82 | + |
| 83 | + def __pow__(self, other): |
| 84 | + return PowerNode(self, convert_to_node(other)) |
| 85 | + |
| 86 | + def __rpow__(self, other): |
| 87 | + return PowerNode(convert_to_node(other), self) |
| 88 | + |
| 89 | + def __neg__(self): |
| 90 | + return NegateNode(self) |
| 91 | + |
| 92 | + def __abs__(self): |
| 93 | + return AbsNode(self) |
| 94 | + |
| 95 | + def __getitem__(self, i): |
| 96 | + return GetItemNode(self, convert_to_node(i)) |
| 97 | + |
| 98 | + def __matmul__(self, other): |
| 99 | + return MatMulNode(self, convert_to_node(other)) |
| 100 | + |
| 101 | + def __rmatmul__(self, other): |
| 102 | + return MatMulNode(convert_to_node(other), self) |
| 103 | + |
| 104 | + |
| 105 | +class UnaryNode(Node, ABC): |
| 106 | + """Base class for all unary nodes in the abstract syntax tree. |
| 107 | +
|
| 108 | + Parameters |
| 109 | + ---------- |
| 110 | + child : Node |
| 111 | + The direct child node on which the unary operation is performed. |
| 112 | +
|
| 113 | + """ |
| 114 | + |
| 115 | + def __init__(self, child: Node): |
| 116 | + self.child = child |
| 117 | + |
| 118 | + |
| 119 | +class BinaryNode(Node, ABC): |
| 120 | + """Base class for all binary nodes in the abstract syntax tree. |
| 121 | +
|
| 122 | + The op_symbol attribute is used for printing the operation in the __repr__ method. |
| 123 | +
|
| 124 | + Parameters |
| 125 | + ---------- |
| 126 | + left : Node |
| 127 | + Left child node to the binary operation. |
| 128 | +
|
| 129 | + right : Node |
| 130 | + Right child node to the binary operation. |
| 131 | +
|
| 132 | + """ |
| 133 | + |
| 134 | + @property |
| 135 | + @abstractmethod |
| 136 | + def op_symbol(self): |
| 137 | + """Symbol used to represent the operation in the __repr__ method.""" |
| 138 | + pass |
| 139 | + |
| 140 | + def __init__(self, left: Node, right: Node): |
| 141 | + self.left = left |
| 142 | + self.right = right |
| 143 | + |
| 144 | + def __repr__(self): |
| 145 | + return f"{self.left} {self.op_symbol} {self.right}" |
| 146 | + |
| 147 | + |
| 148 | +class BinaryNodeWithParenthesis(BinaryNode, ABC): |
| 149 | + """Base class for all binary nodes in the abstract syntax tree that should be printed with parenthesis.""" |
| 150 | + |
| 151 | + def __repr__(self): |
| 152 | + left = f"({self.left})" if isinstance(self.left, BinaryNode) else str(self.left) |
| 153 | + right = ( |
| 154 | + f"({self.right})" if isinstance(self.right, BinaryNode) else str(self.right) |
| 155 | + ) |
| 156 | + return f"{left} {self.op_symbol} {right}" |
| 157 | + |
| 158 | +class BinaryNodeWithParenthesisNoSpace(BinaryNode, ABC): |
| 159 | + """Base class for all binary nodes in the abstract syntax tree that should be printed with parenthesis but no space.""" |
| 160 | + |
| 161 | + def __repr__(self): |
| 162 | + left = f"({self.left})" if isinstance(self.left, BinaryNode) else str(self.left) |
| 163 | + right = ( |
| 164 | + f"({self.right})" if isinstance(self.right, BinaryNode) else str(self.right) |
| 165 | + ) |
| 166 | + return f"{left}{self.op_symbol}{right}" |
| 167 | + |
| 168 | + |
| 169 | +# ====== Specific implementations of the "leaf" nodes ====== |
| 170 | + |
| 171 | + |
| 172 | +class VariableNode(Node): |
| 173 | + """Node that represents a generic variable, e.g. "x" or "y". |
| 174 | +
|
| 175 | + Parameters |
| 176 | + ---------- |
| 177 | + name : str |
| 178 | + Name of the variable. Used for printing and to retrieve the given input value |
| 179 | + of the variable in the kwargs dictionary when evaluating the tree. |
| 180 | +
|
| 181 | + """ |
| 182 | + |
| 183 | + def __init__(self, name): |
| 184 | + self.name = name |
| 185 | + |
| 186 | + def __call__(self, **kwargs): |
| 187 | + """Retrieves the value of the variable from the passed kwargs. If no value is found, it raises a KeyError.""" |
| 188 | + if not self.name in kwargs: |
| 189 | + raise KeyError( |
| 190 | + f"Variable '{self.name}' not found in the given input parameters. Unable to evaluate the expression." |
| 191 | + ) |
| 192 | + return kwargs[self.name] |
| 193 | + |
| 194 | + def __repr__(self): |
| 195 | + return self.name |
| 196 | + |
| 197 | + |
| 198 | +class ValueNode(Node): |
| 199 | + """Node that represents a constant value. The value can be any python object that is not a Node. |
| 200 | +
|
| 201 | + Parameters |
| 202 | + ---------- |
| 203 | + value : object |
| 204 | + The python object that represents the value of the node. |
| 205 | +
|
| 206 | + """ |
| 207 | + |
| 208 | + def __init__(self, value): |
| 209 | + self.value = value |
| 210 | + |
| 211 | + def __call__(self, **kwargs): |
| 212 | + """Returns the value of the node.""" |
| 213 | + return self.value |
| 214 | + |
| 215 | + def __repr__(self): |
| 216 | + return str(self.value) |
| 217 | + |
| 218 | + |
| 219 | +# ====== Specific implementations of the "internal" nodes ====== |
| 220 | + |
| 221 | + |
| 222 | +class AddNode(BinaryNode): |
| 223 | + """Node that represents the addition operation.""" |
| 224 | + |
| 225 | + @property |
| 226 | + def op_symbol(self): |
| 227 | + return "+" |
| 228 | + |
| 229 | + def __call__(self, **kwargs): |
| 230 | + return self.left(**kwargs) + self.right(**kwargs) |
| 231 | + |
| 232 | + |
| 233 | +class SubtractNode(BinaryNode): |
| 234 | + """Node that represents the subtraction operation.""" |
| 235 | + |
| 236 | + @property |
| 237 | + def op_symbol(self): |
| 238 | + return "-" |
| 239 | + |
| 240 | + def __call__(self, **kwargs): |
| 241 | + return self.left(**kwargs) - self.right(**kwargs) |
| 242 | + |
| 243 | + |
| 244 | +class MultiplyNode(BinaryNodeWithParenthesis): |
| 245 | + """Node that represents the multiplication operation.""" |
| 246 | + |
| 247 | + @property |
| 248 | + def op_symbol(self): |
| 249 | + return "*" |
| 250 | + |
| 251 | + def __call__(self, **kwargs): |
| 252 | + return self.left(**kwargs) * self.right(**kwargs) |
| 253 | + |
| 254 | + |
| 255 | +class DivideNode(BinaryNodeWithParenthesis): |
| 256 | + """Node that represents the division operation.""" |
| 257 | + |
| 258 | + @property |
| 259 | + def op_symbol(self): |
| 260 | + return "/" |
| 261 | + |
| 262 | + def __call__(self, **kwargs): |
| 263 | + return self.left(**kwargs) / self.right(**kwargs) |
| 264 | + |
| 265 | + |
| 266 | +class PowerNode(BinaryNodeWithParenthesisNoSpace): |
| 267 | + """Node that represents the power operation.""" |
| 268 | + |
| 269 | + @property |
| 270 | + def op_symbol(self): |
| 271 | + return "^" |
| 272 | + |
| 273 | + def __call__(self, **kwargs): |
| 274 | + return self.left(**kwargs) ** self.right(**kwargs) |
| 275 | + |
| 276 | + |
| 277 | +class GetItemNode(BinaryNode): |
| 278 | + """Node that represents the get item operation. Here the left node is the object and the right node is the index.""" |
| 279 | + |
| 280 | + def __call__(self, **kwargs): |
| 281 | + return self.left(**kwargs)[self.right(**kwargs)] |
| 282 | + |
| 283 | + def __repr__(self): |
| 284 | + left = f"({self.left})" if isinstance(self.left, BinaryNode) else str(self.left) |
| 285 | + return f"{left}[{self.right}]" |
| 286 | + |
| 287 | + @property |
| 288 | + def op_symbol(self): |
| 289 | + pass |
| 290 | + |
| 291 | + |
| 292 | +class NegateNode(UnaryNode): |
| 293 | + """Node that represents the arithmetic negation operation.""" |
| 294 | + |
| 295 | + def __call__(self, **kwargs): |
| 296 | + return -self.child(**kwargs) |
| 297 | + |
| 298 | + def __repr__(self): |
| 299 | + child = ( |
| 300 | + f"({self.child})" |
| 301 | + if isinstance(self.child, (BinaryNode, UnaryNode)) |
| 302 | + else str(self.child) |
| 303 | + ) |
| 304 | + return f"-{child}" |
| 305 | + |
| 306 | + |
| 307 | +class AbsNode(UnaryNode): |
| 308 | + """Node that represents the absolute value operation.""" |
| 309 | + |
| 310 | + def __call__(self, **kwargs): |
| 311 | + return abs(self.child(**kwargs)) |
| 312 | + |
| 313 | + def __repr__(self): |
| 314 | + return f"abs({self.child})" |
| 315 | + |
| 316 | + |
| 317 | +class MatMulNode(BinaryNodeWithParenthesis): |
| 318 | + """Node that represents the matrix multiplication operation.""" |
| 319 | + |
| 320 | + @property |
| 321 | + def op_symbol(self): |
| 322 | + return "@" |
| 323 | + |
| 324 | + def __call__(self, **kwargs): |
| 325 | + return self.left(**kwargs) @ self.right(**kwargs) |
0 commit comments