Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,7 @@
[(#2568)](https://github.com/PennyLaneAI/catalyst/pull/2568)
[(#2578)](https://github.com/PennyLaneAI/catalyst/pull/2578)
[(#2711)](https://github.com/PennyLaneAI/catalyst/pull/2711)
[(#2765)](https://github.com/PennyLaneAI/catalyst/pull/2765)

The framework is interfaced with a new `graph_decomposition` pass decorator
with key capabilities:
Expand Down
28 changes: 28 additions & 0 deletions frontend/test/pytest/from_plxpr/test_decompose_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
This module tests the decompose transformation.
"""

# pylint: disable=too-many-lines

from contextlib import nullcontext as does_not_raise
from functools import partial

Expand Down Expand Up @@ -318,6 +320,32 @@ def circuit():
resources = qp.specs(circuit, level="device")()["resources"].gate_types
assert resources == expected_resources

def test_empty_rule(self):
"""Test that a decomposition rule with no ops is handled correctly."""

@decomposition_rule(op_type="PauliX")
def empty_decomp(_wire):
pass

@qp.qjit(capture=True)
@graph_decomposition(
gate_set={"PauliY"},
fixed_decomps={"PauliX": empty_decomp},
)
@qp.qnode(qp.device("lightning.qubit", wires=1))
def circuit():
qp.X(0)
qp.Y(0)

# register the empty decomposition rule
empty_decomp(int)

return qp.expval(qp.Z(0))

expected_resources = {"PauliY": 1}
resources = qp.specs(circuit, level="device")()["resources"].gate_types
assert resources == expected_resources

@pytest.mark.xfail(
reason="graph-decomposition supports pre-compiled rules, alt_decomps and fix_decomps"
)
Expand Down
6 changes: 6 additions & 0 deletions mlir/lib/Quantum/Transforms/DecompGraphSolver/DGBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,12 @@ struct DecompositionGraph::Impl {
// Connect rule vertex to output operator vertex
boost::add_edge(rule_vertex, output_vertex, GraphWeightedEdge{}, graph);

// Empty rules (with no inputs) are effectively just target gates
// and don't need to be connected to input operator vertices
if (rule.isEmpty()) {
continue;
}

// Connect rule vertex to input operator vertices
for (const auto &input : rule.inputs) {
const auto input_id = registerOp(input.op);
Expand Down
12 changes: 9 additions & 3 deletions mlir/lib/Quantum/Transforms/DecompGraphSolver/DGSolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,6 @@ ChosenDecompRule DecompositionSolver::evalRule(const RuleNode &rule)
}
}

if (rule.inputs.empty() && total_cost == 0.0) {
return invalidRule(solution.op); // invalid rule
}
solution.totalCost = total_cost;
return solution;
}
Expand All @@ -81,6 +78,15 @@ ChosenDecompRule DecompositionSolver::bestRule(const OperatorNode &op)
return invalidRule(op); // no valid rules
}

// if there is an empty rule (with no inputs)
// for the given operator, pick this as the
// best rule with zero cost
for (const auto &rule : all_rules) {
if (rule.isEmpty()) {
return evalRule(rule);
}
}

std::optional<ChosenDecompRule> best_rule;

for (const auto &rule : all_rules) {
Expand Down
2 changes: 2 additions & 0 deletions mlir/lib/Quantum/Transforms/DecompGraphSolver/DGTypes.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,8 @@ struct RuleNode {
{
return name == other.name && output == other.output && origin == other.origin;
}

bool isEmpty() const { return inputs.empty(); }
};

/**
Expand Down
22 changes: 22 additions & 0 deletions mlir/unittests/DecompGraphSolver/Test_DecompGraphSolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -533,3 +533,25 @@ TEST_CASE("Test GraphSolver with MultiRZ decompositions", "[DecompGraph::Solver]
REQUIRE(chosen_rule_multiRZ5.ruleName == "multiRZ5_to_rz");
REQUIRE(chosen_rule_multiRZ5.totalCost == 1.0 * 5);
}

TEST_CASE("Test GraphSolver with empty decomposition rules", "[DecompGraph::Solver]")
{
const OperatorNode hadamard{"Hadamard"};
const OperatorNode globalPhase{"GlobalPhase"};

const WeightedGateset gateset{{{globalPhase, 1.0}}};

const std::vector<RuleNode> rules{
{"hadamard_to_globalPhase", hadamard, {}},
};

const DecompositionGraph graph({hadamard}, gateset, rules);
DecompositionSolver solver(graph);
const auto result = solver.solve();
REQUIRE(result.size() == 1);
const auto &chosen_rule = result.at(hadamard);
REQUIRE_FALSE(chosen_rule.isBasis);
REQUIRE(chosen_rule.ruleName == "hadamard_to_globalPhase");
REQUIRE(chosen_rule.inputs.empty());
REQUIRE(chosen_rule.totalCost == 0.0);
}
Loading