diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 6a0cc6870e..ff941305d0 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -305,6 +305,7 @@ [(#2568)](https://github.com/PennyLaneAI/catalyst/pull/2568) [(#2578)](https://github.com/PennyLaneAI/catalyst/pull/2578) [(#2711)](https://github.com/PennyLaneAI/catalyst/pull/2711) + [(#2765)](https://github.com/PennyLaneAI/catalyst/pull/2765) The framework is interfaced with a new `graph_decomposition` pass decorator with key capabilities: diff --git a/frontend/test/pytest/from_plxpr/test_decompose_transform.py b/frontend/test/pytest/from_plxpr/test_decompose_transform.py index 47f6f35a39..ad1f2030d7 100644 --- a/frontend/test/pytest/from_plxpr/test_decompose_transform.py +++ b/frontend/test/pytest/from_plxpr/test_decompose_transform.py @@ -15,6 +15,8 @@ This module tests the decompose transformation. """ +# pylint: disable=too-many-lines + from contextlib import nullcontext as does_not_raise from functools import partial @@ -318,6 +320,32 @@ def circuit(): resources = qp.specs(circuit, level="device")()["resources"].gate_types assert resources == expected_resources + def test_empty_rule(self): + """Test that a decomposition rule with no ops is handled correctly.""" + + @decomposition_rule(op_type="PauliX") + def empty_decomp(_wire): + pass + + @qp.qjit(capture=True) + @graph_decomposition( + gate_set={"PauliY"}, + fixed_decomps={"PauliX": empty_decomp}, + ) + @qp.qnode(qp.device("lightning.qubit", wires=1)) + def circuit(): + qp.X(0) + qp.Y(0) + + # register the empty decomposition rule + empty_decomp(int) + + return qp.expval(qp.Z(0)) + + expected_resources = {"PauliY": 1} + resources = qp.specs(circuit, level="device")()["resources"].gate_types + assert resources == expected_resources + @pytest.mark.xfail( reason="graph-decomposition supports pre-compiled rules, alt_decomps and fix_decomps" ) diff --git a/mlir/lib/Quantum/Transforms/DecompGraphSolver/DGBuilder.cpp b/mlir/lib/Quantum/Transforms/DecompGraphSolver/DGBuilder.cpp index 2e8386e083..894cd45d83 100644 --- a/mlir/lib/Quantum/Transforms/DecompGraphSolver/DGBuilder.cpp +++ b/mlir/lib/Quantum/Transforms/DecompGraphSolver/DGBuilder.cpp @@ -182,6 +182,12 @@ struct DecompositionGraph::Impl { // Connect rule vertex to output operator vertex boost::add_edge(rule_vertex, output_vertex, GraphWeightedEdge{}, graph); + // Empty rules (with no inputs) are effectively just target gates + // and don't need to be connected to input operator vertices + if (rule.isEmpty()) { + continue; + } + // Connect rule vertex to input operator vertices for (const auto &input : rule.inputs) { const auto input_id = registerOp(input.op); diff --git a/mlir/lib/Quantum/Transforms/DecompGraphSolver/DGSolver.cpp b/mlir/lib/Quantum/Transforms/DecompGraphSolver/DGSolver.cpp index a81ee57a5b..c83ccddce5 100644 --- a/mlir/lib/Quantum/Transforms/DecompGraphSolver/DGSolver.cpp +++ b/mlir/lib/Quantum/Transforms/DecompGraphSolver/DGSolver.cpp @@ -67,9 +67,6 @@ ChosenDecompRule DecompositionSolver::evalRule(const RuleNode &rule) } } - if (rule.inputs.empty() && total_cost == 0.0) { - return invalidRule(solution.op); // invalid rule - } solution.totalCost = total_cost; return solution; } @@ -81,6 +78,15 @@ ChosenDecompRule DecompositionSolver::bestRule(const OperatorNode &op) return invalidRule(op); // no valid rules } + // if there is an empty rule (with no inputs) + // for the given operator, pick this as the + // best rule with zero cost + for (const auto &rule : all_rules) { + if (rule.isEmpty()) { + return evalRule(rule); + } + } + std::optional best_rule; for (const auto &rule : all_rules) { diff --git a/mlir/lib/Quantum/Transforms/DecompGraphSolver/DGTypes.hpp b/mlir/lib/Quantum/Transforms/DecompGraphSolver/DGTypes.hpp index fd10f8da26..929be8f1e6 100644 --- a/mlir/lib/Quantum/Transforms/DecompGraphSolver/DGTypes.hpp +++ b/mlir/lib/Quantum/Transforms/DecompGraphSolver/DGTypes.hpp @@ -153,6 +153,8 @@ struct RuleNode { { return name == other.name && output == other.output && origin == other.origin; } + + bool isEmpty() const { return inputs.empty(); } }; /** diff --git a/mlir/unittests/DecompGraphSolver/Test_DecompGraphSolver.cpp b/mlir/unittests/DecompGraphSolver/Test_DecompGraphSolver.cpp index f7da9f517a..6767c066c3 100644 --- a/mlir/unittests/DecompGraphSolver/Test_DecompGraphSolver.cpp +++ b/mlir/unittests/DecompGraphSolver/Test_DecompGraphSolver.cpp @@ -533,3 +533,25 @@ TEST_CASE("Test GraphSolver with MultiRZ decompositions", "[DecompGraph::Solver] REQUIRE(chosen_rule_multiRZ5.ruleName == "multiRZ5_to_rz"); REQUIRE(chosen_rule_multiRZ5.totalCost == 1.0 * 5); } + +TEST_CASE("Test GraphSolver with empty decomposition rules", "[DecompGraph::Solver]") +{ + const OperatorNode hadamard{"Hadamard"}; + const OperatorNode globalPhase{"GlobalPhase"}; + + const WeightedGateset gateset{{{globalPhase, 1.0}}}; + + const std::vector rules{ + {"hadamard_to_globalPhase", hadamard, {}}, + }; + + const DecompositionGraph graph({hadamard}, gateset, rules); + DecompositionSolver solver(graph); + const auto result = solver.solve(); + REQUIRE(result.size() == 1); + const auto &chosen_rule = result.at(hadamard); + REQUIRE_FALSE(chosen_rule.isBasis); + REQUIRE(chosen_rule.ruleName == "hadamard_to_globalPhase"); + REQUIRE(chosen_rule.inputs.empty()); + REQUIRE(chosen_rule.totalCost == 0.0); +}