Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 35 additions & 51 deletions tket-py/src/pattern.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,16 @@ pub mod portmatching;
use crate::circuit::Tk2Circuit;
use crate::rewrite::PyCircuitRewrite;
use crate::utils::{create_py_exception, ConvertPyErr};
use derive_more::From;

use hugr::{HugrView, Node};
use pyo3::prelude::*;
use tket::portmatching::{CircuitPattern, PatternMatch, PatternMatcher};
use tket::Circuit;
use tket::portmatching::{Rule, RuleMatcher};

/// The module definition
pub fn module(py: Python<'_>) -> PyResult<Bound<'_, PyModule>> {
let m = PyModule::new(py, "pattern")?;
m.add_class::<Rule>()?;
m.add_class::<RuleMatcher>()?;
m.add_class::<PyRule>()?;
m.add_class::<PyRuleMatcher>()?;
m.add_class::<self::portmatching::PyCircuitPattern>()?;
m.add_class::<self::portmatching::PyPatternMatcher>()?;
m.add_class::<self::portmatching::PyPatternMatch>()?;
Expand Down Expand Up @@ -45,84 +44,69 @@ create_py_exception!(
"Conversion error from circuit to pattern."
);

#[derive(Clone)]
#[pyclass]
/// A rewrite rule defined by a left hand side and right hand side of an equation.
pub struct Rule(pub [Circuit; 2]);
#[pyclass]
#[pyo3(name = "Rule")]
#[repr(transparent)]
#[derive(Debug, Clone, From)]
pub struct PyRule(pub Rule);

#[pymethods]
impl Rule {
impl PyRule {
#[new]
fn new_rule(l: &Bound<PyAny>, r: &Bound<PyAny>) -> PyResult<Rule> {
fn new_rule(l: &Bound<PyAny>, r: &Bound<PyAny>) -> PyResult<PyRule> {
let l = Tk2Circuit::new(l)?;
let r = Tk2Circuit::new(r)?;

Ok(Rule([l.circ, r.circ]))
let rule = Rule::new(l.circ, r.circ);
Ok(PyRule(rule))
}

/// The left hand side of the rule.
///
/// This is the pattern that will be matched against the target circuit.
fn lhs(&self) -> Tk2Circuit {
Tk2Circuit {
circ: self.0[0].clone(),
}
Tk2Circuit { circ: self.0.lhs() }
}

/// The right hand side of the rule.
///
/// This is the replacement that will be applied to the target circuit.
fn rhs(&self) -> Tk2Circuit {
Tk2Circuit {
circ: self.0[1].clone(),
}
Tk2Circuit { circ: self.0.rhs() }
}
}

#[pyclass]
struct RuleMatcher {
matcher: PatternMatcher,
rights: Vec<Circuit>,
#[pyo3(name = "RuleMatcher")]
#[repr(transparent)]
#[derive(Debug, Clone, From)]
struct PyRuleMatcher {
rmatcher: RuleMatcher,
}

#[pymethods]
impl RuleMatcher {
impl PyRuleMatcher {
#[new]
pub fn from_rules(rules: Vec<Rule>) -> PyResult<Self> {
let (lefts, rights): (Vec<_>, Vec<_>) =
rules.into_iter().map(|Rule([l, r])| (l, r)).unzip();
let patterns: Result<Vec<CircuitPattern>, _> =
lefts.iter().map(CircuitPattern::try_from_circuit).collect();
let matcher = PatternMatcher::from_patterns(patterns.convert_pyerrs()?);

Ok(Self { matcher, rights })
pub fn from_rules(rules: Vec<PyRule>) -> PyResult<Self> {
let rules: Vec<Rule> = rules.into_iter().map(|r| r.0).collect();
let rmatcher = RuleMatcher::from_rules(rules).convert_pyerrs()?;

Ok(Self { rmatcher })
}

pub fn find_match(&self, target: &Tk2Circuit) -> PyResult<Option<PyCircuitRewrite>> {
let circ = &target.circ;
if let Some(pmatch) = self.matcher.find_matches_iter(circ).next() {
Ok(Some(self.match_to_rewrite(pmatch, circ)?))
} else {
Ok(None)
}
self.rmatcher
.find_match(circ)
.convert_pyerrs()
.map(|optn| optn.map(|rewrite| rewrite.into()))
}

pub fn find_matches(&self, target: &Tk2Circuit) -> PyResult<Vec<PyCircuitRewrite>> {
let circ = &target.circ;
self.matcher
.find_matches_iter(circ)
.map(|m| self.match_to_rewrite(m, circ))
.collect()
}
}

impl RuleMatcher {
fn match_to_rewrite(
&self,
pmatch: PatternMatch,
target: &Circuit<impl HugrView<Node = Node>>,
) -> PyResult<PyCircuitRewrite> {
let r = self.rights.get(pmatch.pattern_id().0).unwrap().clone();
let rw = pmatch.to_rewrite(target, r).convert_pyerrs()?;
Ok(rw.into())
self.rmatcher
.find_matches(circ)
.convert_pyerrs()
.map(|vec| vec.into_iter().map(|rewrite| rewrite.into()).collect())
}
}
2 changes: 1 addition & 1 deletion tket/src/portmatching.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ pub mod pattern;
use hugr::types::EdgeKind;
use hugr::{HugrView, OutgoingPort};
use itertools::Itertools;
pub use matcher::{PatternMatch, PatternMatcher};
pub use matcher::{PatternMatch, PatternMatcher, Rule, RuleMatcher};
pub use pattern::CircuitPattern;

use derive_more::{Display, Error};
Expand Down
82 changes: 80 additions & 2 deletions tket/src/portmatching/matcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ use std::{
path::{Path, PathBuf},
};

use super::{CircuitPattern, NodeID, PEdge, PNode};
use derive_more::{Display, Error, From};
use super::{pattern::InvalidPattern, CircuitPattern, NodeID, PEdge, PNode};
use derive_more::{Display, Error, From, Into};
use hugr::hugr::views::sibling_subgraph::{
InvalidReplacement, InvalidSubgraph, InvalidSubgraphBoundary, TopoConvexChecker,
};
Expand Down Expand Up @@ -369,6 +369,84 @@ impl PatternMatcher {
}
}

/// A rewrite rule defined by a left hand side and right hand side of an equation.
#[derive(Clone, Debug, From, Into)]
pub struct Rule(pub [Circuit; 2]);

impl Rule {
/// Construct a rule from a pattern and a replacement.
pub fn new(l: Circuit, r: Circuit) -> Self {
Self([l, r])
}

/// The left hand side of the rule.
///
/// This is the pattern that will be matched against the target circuit.
pub fn lhs(&self) -> Circuit {
self.0[0].clone()
}

/// The right hand side of the rule.
///
/// This is the replacement that will be applied to the target circuit.
pub fn rhs(&self) -> Circuit {
self.0[1].clone()
}
}

/// A matcher object for a given set of rewrite rules.
#[derive(Clone, Debug, From, Into)]
pub struct RuleMatcher {
matcher: PatternMatcher,
rights: Vec<Circuit>,
}

impl RuleMatcher {
/// Construct a matcher from a set of rules.
pub fn from_rules(rules: Vec<Rule>) -> Result<Self, InvalidPattern> {
let (lefts, rights): (Vec<_>, Vec<_>) =
rules.into_iter().map(|Rule([l, r])| (l, r)).unzip();
let patterns: Result<Vec<CircuitPattern>, _> =
lefts.iter().map(CircuitPattern::try_from_circuit).collect();
let matcher = PatternMatcher::from_patterns(patterns?);

Ok(Self { matcher, rights })
}

/// Find the first match.
pub fn find_match(
&self,
target: &Circuit,
) -> Result<Option<CircuitRewrite>, InvalidReplacement> {
if let Some(pmatch) = self.matcher.find_matches_iter(target).next() {
Ok(Some(self.match_to_rewrite(pmatch, target)?))
} else {
Ok(None)
}
}

/// Find all matches.
pub fn find_matches(
&self,
target: &Circuit,
) -> Result<Vec<CircuitRewrite>, InvalidReplacement> {
self.matcher
.find_matches_iter(target)
.map(|m| self.match_to_rewrite(m, target))
.collect()
}

fn match_to_rewrite(
&self,
pmatch: PatternMatch,
target: &Circuit<impl HugrView<Node = Node>>,
) -> Result<CircuitRewrite, InvalidReplacement> {
let r = self.rights.get(pmatch.pattern_id().0).unwrap().clone();
let rw = pmatch.to_rewrite(target, r)?;
Ok(rw)
}
}

/// Errors that can occur when constructing matches.
#[derive(Debug, Display, Clone, PartialEq, Eq, Error)]
#[non_exhaustive]
Expand Down
Loading