Skip to content

Commit 83ebfcb

Browse files
lmondadaaborgna-q
andauthored
feat(badger): cx and rz const functions and strategies for LexicographicCostFunction (#625)
A couple of points to note: 1. I've made some minor breaking changes to the Rust API of `LexicographicCostFunction`. I think it is cleaner now. 2. I had the choice between keeping `fn` pointers as the cost function type within `LexicographicCostFunction` or moving to `Box<Fn>`. I've stuck to the former for the moment, but I didn't figure out a simple way to reuse the same code for `Tk2Op::CX` and `Tk2Op::RzF64` without using closures. The current code has some duplication as a result, but I think it's bearable. 3. I've tried running badger with `cost_fn='rz'`, but the Rz gate count does not decrease at all. I've looked for an obvious bug but I don't think it is within these changes... Let me know if you disagree with 1. or 2 and what you think we should do about 3. --- ### Changelog metadata BEGIN_COMMIT_OVERRIDE feat: `BadgerOptimiser.load_precompiled`, `BadgerOptimiser.compile_eccs` and `passes.badger_pass` now take an optional `cost_fn` parameter to specify the cost function to minimise. Supported values are `'cx'` (default behaviour) and `'rz'`. END_COMMIT_OVERRIDE --------- Co-authored-by: Agustín Borgna <[email protected]>
1 parent 295b0df commit 83ebfcb

File tree

5 files changed

+150
-21
lines changed

5 files changed

+150
-21
lines changed

tket2-py/src/optimiser.rs

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
use std::io::BufWriter;
44
use std::{fs, num::NonZeroUsize, path::PathBuf};
55

6+
use pyo3::exceptions::PyValueError;
67
use pyo3::prelude::*;
78
use tket2::optimiser::badger::BadgerOptions;
89
use tket2::optimiser::{BadgerLogger, DefaultBadgerOptimiser};
@@ -24,20 +25,60 @@ pub fn module(py: Python<'_>) -> PyResult<Bound<'_, PyModule>> {
2425
#[pyclass(name = "BadgerOptimiser")]
2526
pub struct PyBadgerOptimiser(DefaultBadgerOptimiser);
2627

28+
/// The cost function to use for the Badger optimiser.
29+
#[derive(Debug, Clone, Copy, Default)]
30+
pub enum BadgerCostFunction {
31+
/// Minimise CX count.
32+
#[default]
33+
CXCount,
34+
/// Minimise Rz count.
35+
RzCount,
36+
}
37+
38+
impl<'py> FromPyObject<'py> for BadgerCostFunction {
39+
fn extract(ob: &'py PyAny) -> PyResult<Self> {
40+
let str = ob.extract::<&str>()?;
41+
match str {
42+
"cx" => Ok(BadgerCostFunction::CXCount),
43+
"rz" => Ok(BadgerCostFunction::RzCount),
44+
_ => Err(PyErr::new::<PyValueError, _>(format!(
45+
"Invalid cost function: {}. Expected 'cx' or 'rz'.",
46+
str
47+
))),
48+
}
49+
}
50+
}
51+
2752
#[pymethods]
2853
impl PyBadgerOptimiser {
2954
/// Create a new [`PyDefaultBadgerOptimiser`] from a precompiled rewriter.
3055
#[staticmethod]
31-
pub fn load_precompiled(path: PathBuf) -> Self {
32-
Self(DefaultBadgerOptimiser::default_with_rewriter_binary(path).unwrap())
56+
pub fn load_precompiled(path: PathBuf, cost_fn: Option<BadgerCostFunction>) -> Self {
57+
let opt = match cost_fn.unwrap_or_default() {
58+
BadgerCostFunction::CXCount => {
59+
DefaultBadgerOptimiser::default_with_rewriter_binary(path).unwrap()
60+
}
61+
BadgerCostFunction::RzCount => {
62+
DefaultBadgerOptimiser::rz_opt_with_rewriter_binary(path).unwrap()
63+
}
64+
};
65+
Self(opt)
3366
}
3467

3568
/// Create a new [`PyDefaultBadgerOptimiser`] from ECC sets.
3669
///
3770
/// This will compile the rewriter from the provided ECC JSON file.
3871
#[staticmethod]
39-
pub fn compile_eccs(path: &str) -> Self {
40-
Self(DefaultBadgerOptimiser::default_with_eccs_json_file(path).unwrap())
72+
pub fn compile_eccs(path: &str, cost_fn: Option<BadgerCostFunction>) -> Self {
73+
let opt = match cost_fn.unwrap_or_default() {
74+
BadgerCostFunction::CXCount => {
75+
DefaultBadgerOptimiser::default_with_eccs_json_file(path).unwrap()
76+
}
77+
BadgerCostFunction::RzCount => {
78+
DefaultBadgerOptimiser::rz_opt_with_eccs_json_file(path).unwrap()
79+
}
80+
};
81+
Self(opt)
4182
}
4283

4384
/// Run the optimiser on a circuit.

tket2-py/tket2/_tket2/optimiser.pyi

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import TypeVar
1+
from typing import TypeVar, Literal
22
from .circuit import Tk2Circuit
33
from pytket._tket.circuit import Circuit
44

@@ -8,12 +8,26 @@ CircuitClass = TypeVar("CircuitClass", Circuit, Tk2Circuit)
88

99
class BadgerOptimiser:
1010
@staticmethod
11-
def load_precompiled(filename: Path) -> BadgerOptimiser:
12-
"""Load a precompiled rewriter from a file."""
11+
def load_precompiled(
12+
filename: Path, cost_fn: Literal["cx", "rz"] | None = None
13+
) -> BadgerOptimiser:
14+
"""
15+
Load a precompiled rewriter from a file.
16+
17+
:param filename: The path to the file containing the precompiled rewriter.
18+
:param cost_fn: The cost function to use.
19+
"""
1320

1421
@staticmethod
15-
def compile_eccs(filename: Path) -> BadgerOptimiser:
16-
"""Compile a set of ECCs and create a new rewriter ."""
22+
def compile_eccs(
23+
filename: Path, cost_fn: Literal["cx", "rz"] | None = None
24+
) -> BadgerOptimiser:
25+
"""
26+
Compile a set of ECCs and create a new rewriter.
27+
28+
:param filename: The path to the file containing the ECCs.
29+
:param cost_fn: The cost function to use.
30+
"""
1731

1832
def optimise(
1933
self,

tket2-py/tket2/passes.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from pathlib import Path
2-
from typing import Optional
2+
from typing import Optional, Literal
33

44
from pytket import Circuit
55
from pytket.passes import CustomPass, BasePass
@@ -37,13 +37,17 @@ def badger_pass(
3737
max_circuit_count: Optional[int] = None,
3838
log_dir: Optional[Path] = None,
3939
rebase: bool = False,
40+
cost_fn: Literal["cx", "rz"] | None = None,
4041
) -> BasePass:
4142
"""Construct a Badger pass.
4243
4344
The Badger optimiser requires a pre-compiled rewriter produced by the
4445
`compile-rewriter <https://github.com/CQCL/tket2/tree/main/badger-optimiser>`_
4546
utility. If `rewriter` is not specified, a default one will be used.
4647
48+
The cost function to minimise can be specified by passing `cost_fn` as `'cx'`
49+
or `'rz'`. If not specified, the default is `'cx'`.
50+
4751
The arguments `max_threads`, `timeout`, `progress_timeout`, `max_circuit_count`,
4852
`log_dir` and `rebase` are optional and will be passed on to the Badger
4953
optimiser if provided."""
@@ -56,7 +60,7 @@ def badger_pass(
5660
)
5761

5862
rewriter = tket2_eccs.nam_6_3()
59-
opt = optimiser.BadgerOptimiser.load_precompiled(rewriter)
63+
opt = optimiser.BadgerOptimiser.load_precompiled(rewriter, cost_fn=cost_fn)
6064

6165
def apply(circuit: Circuit) -> Circuit:
6266
"""Apply Badger optimisation to the circuit."""

tket2/src/optimiser/badger.rs

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -518,7 +518,7 @@ mod badger_default {
518518
/// A sane default optimiser using the given ECC sets.
519519
pub fn default_with_eccs_json_file(eccs_path: impl AsRef<Path>) -> io::Result<Self> {
520520
let rewriter = ECCRewriter::try_from_eccs_json_file(eccs_path)?;
521-
let strategy = LexicographicCostFunction::default_cx();
521+
let strategy = LexicographicCostFunction::default_cx_strategy();
522522
Ok(BadgerOptimiser::new(rewriter, strategy))
523523
}
524524

@@ -528,7 +528,24 @@ mod badger_default {
528528
rewriter_path: impl AsRef<Path>,
529529
) -> Result<Self, RewriterSerialisationError> {
530530
let rewriter = ECCRewriter::load_binary(rewriter_path)?;
531-
let strategy = LexicographicCostFunction::default_cx();
531+
let strategy = LexicographicCostFunction::default_cx_strategy();
532+
Ok(BadgerOptimiser::new(rewriter, strategy))
533+
}
534+
535+
/// An optimiser minimising Rz gate count using the given ECC sets.
536+
pub fn rz_opt_with_eccs_json_file(eccs_path: impl AsRef<Path>) -> io::Result<Self> {
537+
let rewriter = ECCRewriter::try_from_eccs_json_file(eccs_path)?;
538+
let strategy = LexicographicCostFunction::rz_count().into_greedy_strategy();
539+
Ok(BadgerOptimiser::new(rewriter, strategy))
540+
}
541+
542+
/// An optimiser minimising Rz gate count using a precompiled binary rewriter.
543+
#[cfg(feature = "binary-eccs")]
544+
pub fn rz_opt_with_rewriter_binary(
545+
rewriter_path: impl AsRef<Path>,
546+
) -> Result<Self, RewriterSerialisationError> {
547+
let rewriter = ECCRewriter::load_binary(rewriter_path)?;
548+
let strategy = LexicographicCostFunction::rz_count().into_greedy_strategy();
532549
Ok(BadgerOptimiser::new(rewriter, strategy))
533550
}
534551
}

tket2/src/rewrite/strategy.rs

Lines changed: 61 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
//! not increase some coarse cost function (e.g. CX count), whilst
1717
//! ordering them according to a lexicographic ordering of finer cost
1818
//! functions (e.g. total gate count). See
19-
//! [`LexicographicCostFunction::default_cx`]) for a default implementation.
19+
//! [`LexicographicCostFunction::default_cx_strategy`]) for a default implementation.
2020
//! - [`GammaStrategyCost`] ignores rewrites that increase the cost
2121
//! function beyond a percentage given by a f64 parameter gamma.
2222
@@ -29,7 +29,7 @@ use hugr::HugrView;
2929
use itertools::Itertools;
3030

3131
use crate::circuit::cost::{is_cx, is_quantum, CircuitCost, CostDelta, LexicographicCost};
32-
use crate::Circuit;
32+
use crate::{op_matches, Circuit, Tk2Op};
3333

3434
use super::trace::RewriteTrace;
3535
use super::CircuitRewrite;
@@ -345,12 +345,66 @@ impl LexicographicCostFunction<fn(&OpType) -> usize, 2> {
345345
/// is used to rank circuits with equal CX count.
346346
///
347347
/// This is probably a good default for NISQ-y circuit optimisation.
348-
#[inline]
348+
pub fn default_cx_strategy() -> ExhaustiveGreedyStrategy<Self> {
349+
Self::cx_count().into_greedy_strategy()
350+
}
351+
352+
/// Non-increasing rewrite strategy based on CX count.
353+
///
354+
/// A fine-grained cost function given by the total number of quantum gates
355+
/// is used to rank circuits with equal CX count.
356+
///
357+
/// This is probably a good default for NISQ-y circuit optimisation.
358+
///
359+
/// Deprecated: Use `default_cx_strategy` instead.
360+
// TODO: Remove this method in the next breaking release.
361+
#[deprecated(since = "0.5.1", note = "Use `default_cx_strategy` instead.")]
349362
pub fn default_cx() -> ExhaustiveGreedyStrategy<Self> {
363+
Self::default_cx_strategy()
364+
}
365+
366+
/// Non-increasing rewrite cost function based on CX gate count.
367+
///
368+
/// A fine-grained cost function given by the total number of quantum gates
369+
/// is used to rank circuits with equal Rz gate count.
370+
#[inline]
371+
pub fn cx_count() -> Self {
350372
Self {
351373
cost_fns: [|op| is_cx(op) as usize, |op| is_quantum(op) as usize],
352374
}
353-
.into()
375+
}
376+
377+
// TODO: Ideally, do not count Clifford rotations in the cost function.
378+
/// Non-increasing rewrite cost function based on Rz gate count.
379+
///
380+
/// A fine-grained cost function given by the total number of quantum gates
381+
/// is used to rank circuits with equal Rz gate count.
382+
#[inline]
383+
pub fn rz_count() -> Self {
384+
Self {
385+
cost_fns: [
386+
|op| op_matches(op, Tk2Op::Rz) as usize,
387+
|op| is_quantum(op) as usize,
388+
],
389+
}
390+
}
391+
392+
/// Consume the cost function and create a greedy rewrite strategy out of
393+
/// it.
394+
pub fn into_greedy_strategy(self) -> ExhaustiveGreedyStrategy<Self> {
395+
ExhaustiveGreedyStrategy { strat_cost: self }
396+
}
397+
398+
/// Consume the cost function and create a threshold rewrite strategy out
399+
/// of it.
400+
pub fn into_threshold_strategy(self) -> ExhaustiveThresholdStrategy<Self> {
401+
ExhaustiveThresholdStrategy { strat_cost: self }
402+
}
403+
}
404+
405+
impl Default for LexicographicCostFunction<fn(&OpType) -> usize, 2> {
406+
fn default() -> Self {
407+
LexicographicCostFunction::cx_count()
354408
}
355409
}
356410

@@ -440,7 +494,6 @@ mod tests {
440494
circuit::Circuit,
441495
rewrite::{CircuitRewrite, Subcircuit},
442496
utils::build_simple_circuit,
443-
Tk2Op,
444497
};
445498

446499
fn n_cx(n_gates: usize) -> Circuit {
@@ -512,7 +565,7 @@ mod tests {
512565
rw_to_empty(&circ, cx_gates[9..10].to_vec()),
513566
];
514567

515-
let strategy = LexicographicCostFunction::default_cx();
568+
let strategy = LexicographicCostFunction::cx_count().into_greedy_strategy();
516569
let rewritten = strategy.apply_rewrites(rws, &circ).collect_vec();
517570
let exp_circ_lens = HashSet::from_iter([3, 7, 9]);
518571
let circ_lens: HashSet<_> = rewritten.iter().map(|r| r.circ.num_operations()).collect();
@@ -557,7 +610,7 @@ mod tests {
557610

558611
#[test]
559612
fn test_exhaustive_default_cx_cost() {
560-
let strat = LexicographicCostFunction::default_cx();
613+
let strat = LexicographicCostFunction::cx_count().into_greedy_strategy();
561614
let circ = n_cx(3);
562615
assert_eq!(strat.circuit_cost(&circ), (3, 3).into());
563616
let circ = build_simple_circuit(2, |circ| {
@@ -572,7 +625,7 @@ mod tests {
572625

573626
#[test]
574627
fn test_exhaustive_default_cx_threshold() {
575-
let strat = LexicographicCostFunction::default_cx().strat_cost;
628+
let strat = LexicographicCostFunction::cx_count();
576629
assert!(strat.under_threshold(&(3, 0).into(), &(3, 0).into()));
577630
assert!(strat.under_threshold(&(3, 0).into(), &(3, 5).into()));
578631
assert!(!strat.under_threshold(&(3, 10).into(), &(4, 0).into()));

0 commit comments

Comments
 (0)