diff --git a/python/cudaq/kernel/analysis.py b/python/cudaq/kernel/analysis.py index 3f8f86c145f..a28adc15716 100644 --- a/python/cudaq/kernel/analysis.py +++ b/python/cudaq/kernel/analysis.py @@ -36,6 +36,19 @@ def visit_Assign(self, node): node.value, 'id') and node.value.id in self.measureResultsVars: self.measureResultsVars.append(target.id) return + # Check if the new variable is assigned from a measurement result + if hasattr(node, 'value') and isinstance( + node.value, + ast.Name) and node.value.id in self.measureResultsVars: + self.measureResultsVars.append(target.id) + return + # Check if the new variable uses measurement results + if hasattr(node, 'value') and isinstance( + node.value, ast.BoolOp) and 'values' in node.value.__dict__: + for value in node.value.__dict__['values']: + if hasattr(value, 'id') and value.id in self.measureResultsVars: + self.measureResultsVars.append(target.id) + return if not 'func' in node.value.__dict__: return creatorFunc = node.value.func @@ -54,8 +67,14 @@ def getVariableName(self, node): return '' def checkForMeasureResult(self, value): - return self.isMeasureCallOp(value) or self.getVariableName( - value) in self.measureResultsVars + if self.isMeasureCallOp(value): + return True + if self.getVariableName(value) in self.measureResultsVars: + return True + if isinstance(value, ast.BoolOp) and 'values' in value.__dict__: + for val in value.__dict__['values']: + if self.getVariableName(val) in self.measureResultsVars: + return True def visit_If(self, node): condition = node.test diff --git a/python/tests/mlir/ast_conditionals.py b/python/tests/mlir/ast_conditionals.py index b2445af0372..91198cf4cc6 100644 --- a/python/tests/mlir/ast_conditionals.py +++ b/python/tests/mlir/ast_conditionals.py @@ -204,6 +204,49 @@ def test14(): # CHECK-LABEL: func.func @__nvqpp__mlirgen__test14() attributes {"cudaq-entrypoint", "cudaq-kernel", qubitMeasurementFeedback = true} { + @cudaq.kernel + def test15(): + qubits = cudaq.qvector(2) + h(qubits) + foo = mx(qubits[0]) + bar = foo + + if not bar: + reset(qubits[0]) + + print(test15) + + # CHECK-LABEL: func.func @__nvqpp__mlirgen__test15() attributes {"cudaq-entrypoint", "cudaq-kernel", qubitMeasurementFeedback = true} { + + @cudaq.kernel + def test16(): + qubits = cudaq.qvector(2) + x(qubits) + foo = mx(qubits[0]) + bar = my(qubits[1]) + qux = foo or bar + + if qux: + h(qubits[0]) + + print(test16) + + # CHECK-LABEL: func.func @__nvqpp__mlirgen__test16() attributes {"cudaq-entrypoint", "cudaq-kernel", qubitMeasurementFeedback = true} { + + @cudaq.kernel + def test17(): + qubits = cudaq.qvector(2) + x(qubits) + foo = mx(qubits[0]) + bar = my(qubits[1]) + + if not foo and bar: + h(qubits[0]) + + print(test17) + + # CHECK-LABEL: func.func @__nvqpp__mlirgen__test17() attributes {"cudaq-entrypoint", "cudaq-kernel", qubitMeasurementFeedback = true} { + # leave for gdb debugging if __name__ == "__main__":