diff --git a/tket2-exts/pyproject.toml b/tket2-exts/pyproject.toml index ad8660a2c..c456e4d46 100644 --- a/tket2-exts/pyproject.toml +++ b/tket2-exts/pyproject.toml @@ -34,4 +34,4 @@ repository = "https://github.com/CQCL/tket2/tree/main/tket2-exts" [build-system] requires = ["hatchling"] -build-backend = "hatchling.build" \ No newline at end of file +build-backend = "hatchling.build" diff --git a/tket2-py/src/circuit.rs b/tket2-py/src/circuit.rs index d139c3849..3a843ae9e 100644 --- a/tket2-py/src/circuit.rs +++ b/tket2-py/src/circuit.rs @@ -76,7 +76,7 @@ create_py_exception!( ); create_py_exception!( - tket2::serialize::pytket::TK1ConvertError, + tket2::serialize::pytket::Tk1ConvertError, PyTK1ConvertError, "Error type for the conversion between tket2 and tket1 operations." ); diff --git a/tket2/Cargo.toml b/tket2/Cargo.toml index 8bb87cb29..ac650f673 100644 --- a/tket2/Cargo.toml +++ b/tket2/Cargo.toml @@ -50,14 +50,17 @@ itertools = { workspace = true } petgraph = { workspace = true } portmatching = { workspace = true, optional = true, features = ["serde"] } derive_more = { workspace = true, features = [ - "error", + "debug", "display", + "error", "from", "into", + "sum", + "add", ] } hugr = { workspace = true } hugr-core = { workspace = true } -portgraph = { workspace = true, features = ["serde"] } +portgraph = { workspace = true, features = ["serde", "petgraph"] } strum = { workspace = true, features = ["derive"] } fxhash = { workspace = true } indexmap = { workspace = true } diff --git a/tket2/src/extension.rs b/tket2/src/extension.rs index 43724d867..3453f4e20 100644 --- a/tket2/src/extension.rs +++ b/tket2/src/extension.rs @@ -4,7 +4,7 @@ use std::sync::Arc; -use crate::serialize::pytket::OpaqueTk1Op; +use crate::serialize::pytket::extension::OpaqueTk1Op; use crate::Tk2Op; use hugr::extension::simple_op::MakeOpDef; use hugr::extension::{ @@ -32,10 +32,10 @@ use sympy::SympyOpDef; pub const TKET1_EXTENSION_ID: ExtensionId = IdentList::new_unchecked("TKET1"); /// The name for opaque TKET1 operations. -pub const TKET1_OP_NAME: SmolStr = SmolStr::new_inline("TKET1 Json Op"); +pub const TKET1_OP_NAME: SmolStr = SmolStr::new_inline("tk1op"); /// The ID of an opaque TKET1 operation metadata. -pub const TKET1_PAYLOAD_NAME: SmolStr = SmolStr::new_inline("TKET1 Json Payload"); +pub const TKET1_PAYLOAD_NAME: SmolStr = SmolStr::new_inline("TKET1-json-payload"); /// Current version of the TKET 1 extension pub const TKET1_EXTENSION_VERSION: Version = Version::new(0, 1, 0); diff --git a/tket2/src/extension/rotation.rs b/tket2/src/extension/rotation.rs index da2a0b265..986d30f72 100644 --- a/tket2/src/extension/rotation.rs +++ b/tket2/src/extension/rotation.rs @@ -31,7 +31,7 @@ lazy_static! { } /// Identifier for the rotation type. -const ROTATION_TYPE_ID: SmolStr = SmolStr::new_inline("rotation"); +pub const ROTATION_TYPE_ID: SmolStr = SmolStr::new_inline("rotation"); /// Rotation type (as [CustomType]) pub fn rotation_custom_type(extension_ref: &Weak) -> CustomType { CustomType::new( diff --git a/tket2/src/ops.rs b/tket2/src/ops.rs index dac1000b6..a001043bc 100644 --- a/tket2/src/ops.rs +++ b/tket2/src/ops.rs @@ -2,7 +2,7 @@ use std::sync::{Arc, Weak}; use crate::extension::bool::bool_type; use crate::extension::rotation::rotation_type; -use crate::extension::sympy::{SympyOpDef, SYM_OP_ID}; +use crate::extension::sympy::SympyOpDef; use crate::extension::{TKET2_EXTENSION, TKET2_EXTENSION_ID as EXTENSION_ID}; use hugr::ops::custom::ExtensionOp; use hugr::types::Type; @@ -14,7 +14,7 @@ use hugr::{ }, ops::OpType, type_row, - types::{type_param::TypeArg, Signature}, + types::Signature, }; use derive_more::{Display, Error}; @@ -190,29 +190,6 @@ pub fn symbolic_constant_op(arg: String) -> OpType { SympyOpDef.with_expr(arg).into() } -/// match against a symbolic constant -pub(crate) fn match_symb_const_op(op: &OpType) -> Option { - // Extract the symbol for a symbolic operation node. - let symbol_from_typeargs = |args: &[TypeArg]| -> String { - args.first() - .and_then(|arg| match arg { - TypeArg::String { arg } => Some(arg.clone()), - _ => None, - }) - .unwrap_or_else(|| panic!("Found an invalid type arg in a symbolic operation node.")) - }; - - if let OpType::ExtensionOp(e) = op { - if e.def().name() == &SYM_OP_ID && e.def().extension_id() == &EXTENSION_ID { - Some(symbol_from_typeargs(e.args())) - } else { - None - } - } else { - None - } -} - #[cfg(test)] pub(crate) mod test { diff --git a/tket2/src/serialize/pytket.rs b/tket2/src/serialize/pytket.rs index 51ba823ca..cfc9a2d07 100644 --- a/tket2/src/serialize/pytket.rs +++ b/tket2/src/serialize/pytket.rs @@ -1,16 +1,16 @@ //! Serialization and deserialization of circuits using the `pytket` JSON format. mod decoder; -mod encoder; -mod op; -mod param; +pub mod encoder; +pub mod extension; -use hugr::types::Type; +pub use encoder::{default_encoder_config, Tk1EncoderConfig, Tk1EncoderContext}; +pub use extension::PytketEmitter; -use hugr::Node; +use hugr::core::HugrNode; + +use hugr::{Hugr, Wire}; use itertools::Itertools; -// Required for serialising ops in the tket1 hugr extension. -pub(crate) use op::serialised::OpaqueTk1Op; #[cfg(test)] mod tests; @@ -24,13 +24,11 @@ use hugr::ops::OpType; use derive_more::{Display, Error, From}; use tket_json_rs::circuit_json::SerialCircuit; -use tket_json_rs::optype::OpType as SerialOpType; use tket_json_rs::register::{Bit, ElementId, Qubit}; use crate::circuit::Circuit; -use self::decoder::Tk1Decoder; -use self::encoder::Tk1Encoder; +use self::decoder::Tk1DecoderContext; pub use crate::passes::pytket::lower_to_pytket; @@ -61,16 +59,28 @@ pub trait TKETDecode: Sized { type EncodeError; /// Convert the serialized circuit to a circuit. fn decode(self) -> Result; - /// Convert a circuit to a new serialized circuit. + /// Convert a circuit to a serialized pytket circuit. + /// + /// Uses a default set of emitters to translate operations. + /// If the circuit contains non-std operations or types, + /// use [`TKETDecode::encode_with_config`] instead. fn encode(circuit: &Circuit) -> Result; + /// Convert a circuit to a serialized pytket circuit. + /// + /// You may use [`TKETDecode::encode`] if the circuit does not contain + /// non-std operations or types. + fn encode_with_config( + circuit: &Circuit, + config: Tk1EncoderConfig, + ) -> Result; } impl TKETDecode for SerialCircuit { - type DecodeError = TK1ConvertError; - type EncodeError = TK1ConvertError; + type DecodeError = Tk1ConvertError; + type EncodeError = Tk1ConvertError; fn decode(self) -> Result { - let mut decoder = Tk1Decoder::try_new(&self)?; + let mut decoder = Tk1DecoderContext::try_new(&self)?; if !self.phase.is_empty() { // TODO - add a phase gate @@ -84,32 +94,37 @@ impl TKETDecode for SerialCircuit { Ok(decoder.finish().into()) } - fn encode(circ: &Circuit) -> Result { - let mut encoder = Tk1Encoder::new(circ)?; - for com in circ.commands() { - let optype = com.optype(); - encoder.add_command(com.clone(), optype)?; - } - Ok(encoder.finish(circ)) + fn encode(circuit: &Circuit) -> Result { + let config = default_encoder_config(); + Self::encode_with_config(circuit, config) + } + + fn encode_with_config( + circuit: &Circuit, + config: Tk1EncoderConfig, + ) -> Result { + let mut encoder = Tk1EncoderContext::new(circuit, config)?; + encoder.run_encoder(circuit)?; + encoder.finish(circuit) } } /// Load a TKET1 circuit from a JSON file. -pub fn load_tk1_json_file(path: impl AsRef) -> Result { +pub fn load_tk1_json_file(path: impl AsRef) -> Result { let file = fs::File::open(path)?; let reader = io::BufReader::new(file); load_tk1_json_reader(reader) } /// Load a TKET1 circuit from a JSON reader. -pub fn load_tk1_json_reader(json: impl io::Read) -> Result { +pub fn load_tk1_json_reader(json: impl io::Read) -> Result { let ser: SerialCircuit = serde_json::from_reader(json)?; let circ: Circuit = ser.decode()?; Ok(circ) } /// Load a TKET1 circuit from a JSON string. -pub fn load_tk1_json_str(json: &str) -> Result { +pub fn load_tk1_json_str(json: &str) -> Result { let reader = json.as_bytes(); load_tk1_json_reader(reader) } @@ -122,7 +137,7 @@ pub fn load_tk1_json_str(json: &str) -> Result { /// /// Returns an error if the circuit is not flat or if it contains operations not /// supported by pytket. -pub fn save_tk1_json_file(circ: &Circuit, path: impl AsRef) -> Result<(), TK1ConvertError> { +pub fn save_tk1_json_file(circ: &Circuit, path: impl AsRef) -> Result<(), Tk1ConvertError> { let file = fs::File::create(path)?; let writer = io::BufWriter::new(file); save_tk1_json_writer(circ, writer) @@ -136,7 +151,7 @@ pub fn save_tk1_json_file(circ: &Circuit, path: impl AsRef) -> Result<(), /// /// Returns an error if the circuit is not flat or if it contains operations not /// supported by pytket. -pub fn save_tk1_json_writer(circ: &Circuit, w: impl io::Write) -> Result<(), TK1ConvertError> { +pub fn save_tk1_json_writer(circ: &Circuit, w: impl io::Write) -> Result<(), Tk1ConvertError> { let serial_circ = SerialCircuit::encode(circ)?; serde_json::to_writer(w, &serial_circ)?; Ok(()) @@ -150,82 +165,23 @@ pub fn save_tk1_json_writer(circ: &Circuit, w: impl io::Write) -> Result<(), TK1 /// /// Returns an error if the circuit is not flat or if it contains operations not /// supported by pytket. -pub fn save_tk1_json_str(circ: &Circuit) -> Result { +pub fn save_tk1_json_str(circ: &Circuit) -> Result { let mut buf = io::BufWriter::new(Vec::new()); save_tk1_json_writer(circ, &mut buf)?; let bytes = buf.into_inner().unwrap(); Ok(String::from_utf8(bytes)?) } -/// Error type for conversion between `Op` and `OpType`. -#[derive(Display, Debug, Error, From)] +/// Error type for conversion between pytket operations and tket2 ops. +#[derive(Display, derive_more::Debug, Error)] #[non_exhaustive] -pub enum OpConvertError { +#[debug(bounds(N: HugrNode))] +pub enum OpConvertError { /// The serialized operation is not supported. - #[display("Unsupported serialized pytket operation: {_0:?}")] - #[error(ignore)] // `_0` is not the error source - UnsupportedSerializedOp(SerialOpType), - /// The serialized operation is not supported. - #[display("Cannot serialize tket2 operation: {_0:?}")] - #[error(ignore)] // `_0` is not the error source - UnsupportedOpSerialization(OpType), - /// The operation has non-serializable inputs. - #[display( - "Operation {} in {node} has an unsupported input of type {typ}.", - optype - )] - UnsupportedInputType { - /// The unsupported type. - typ: Type, - /// The operation name. - optype: OpType, - /// The node. - node: Node, - }, - /// The operation has non-serializable outputs. - #[display( - "Operation {} in {node} has an unsupported output of type {typ}.", - optype - )] - UnsupportedOutputType { - /// The unsupported type. - typ: Type, - /// The operation name. - optype: OpType, - /// The node. - node: Node, - }, - /// A parameter input could not be evaluated. - #[display( - "The {typ} parameter input for operation {} in {node} could not be resolved.", - optype - )] - UnresolvedParamInput { - /// The parameter type. - typ: Type, - /// The operation with the missing input param. - optype: OpType, - /// The node. - node: Node, - }, - /// The operation has output-only qubits. - /// This is not currently supported by the encoder. - #[display("Operation {} in {node} has more output qubits than inputs.", optype)] - TooManyOutputQubits { - /// The unsupported type. - typ: Type, - /// The operation name. - optype: OpType, - /// The node. - node: Node, - }, - /// The opaque tket1 operation had an invalid type parameter. - #[display("Opaque TKET1 operation had an invalid type parameter. {error}")] - #[from] - InvalidOpaqueTypeParam { - /// The serialization error. - #[error(source)] - error: serde_json::Error, + #[display("Cannot serialize tket2 operation: {op}")] + UnsupportedOpSerialization { + /// The operation. + op: OpType, }, /// Tried to decode a tket1 operation with not enough parameters. #[display( @@ -256,21 +212,32 @@ pub enum OpConvertError { /// The given of parameters. args: Vec, }, + /// Tried to query the values associated with an unexplored wire. + /// + /// This reflects a bug in the operation encoding logic of an operation. + #[display("Could not find values associated with wire {wire}.")] + WireHasNoValues { + /// The wire that has no values. + wire: Wire, + }, + /// Tried to add values to an already registered wire. + /// + /// This reflects a bug in the operation encoding logic of an operation. + #[display("Tried to register values for wire {wire}, but it already has associated values.")] + WireAlreadyHasValues { + /// The wire that already has values. + wire: Wire, + }, } -/// Error type for conversion between `Op` and `OpType`. -#[derive(Debug, Display, Error, From)] +/// Error type for conversion between tket2 ops and pytket operations. +#[derive(derive_more::Debug, Display, Error, From)] #[non_exhaustive] -pub enum TK1ConvertError { +#[debug(bounds(N: HugrNode))] +pub enum Tk1ConvertError { /// Operation conversion error. #[from] - OpConversionError(OpConvertError), - /// The circuit has non-serializable inputs. - #[display("Circuit contains non-serializable input of type {typ}.")] - NonSerializableInputs { - /// The unsupported type. - typ: Type, - }, + OpConversionError(OpConvertError), /// The circuit uses multi-indexed registers. // // This could be supported in the future, if there is a need for it. @@ -291,6 +258,19 @@ pub enum TK1ConvertError { #[display("Unable to load pytket json file. {_0}")] #[from] FileLoadError(io::Error), + /// Custom user-defined error raised while encoding an operation. + #[display("Error while encoding operation: {msg}")] + CustomError { + /// The custom error message + msg: String, + }, +} + +impl Tk1ConvertError { + /// Create a new error with a custom message. + pub fn custom(msg: impl Into) -> Self { + Self::CustomError { msg: msg.into() } + } } /// A hashed register, used to identify registers in the [`Tk1Decoder::register_wire`] map, diff --git a/tket2/src/serialize/pytket/decoder.rs b/tket2/src/serialize/pytket/decoder.rs index 650ca449d..296537d45 100644 --- a/tket2/src/serialize/pytket/decoder.rs +++ b/tket2/src/serialize/pytket/decoder.rs @@ -1,5 +1,8 @@ //! Intermediate structure for decoding [`SerialCircuit`]s into [`Hugr`]s. +mod op; +mod param; + use std::collections::{HashMap, HashSet}; use hugr::builder::{Container, Dataflow, DataflowHugr, FunctionBuilder}; @@ -19,22 +22,22 @@ use tket_json_rs::circuit_json; use tket_json_rs::circuit_json::SerialCircuit; use tket_json_rs::register; -use super::op::Tk1Op; -use super::param::decode::{parse_pytket_param, PytketParam}; use super::{ - OpConvertError, RegisterHash, TK1ConvertError, METADATA_B_OUTPUT_REGISTERS, + OpConvertError, RegisterHash, Tk1ConvertError, METADATA_B_OUTPUT_REGISTERS, METADATA_B_REGISTERS, METADATA_OPGROUP, METADATA_PHASE, METADATA_Q_OUTPUT_REGISTERS, METADATA_Q_REGISTERS, }; use crate::extension::rotation::{rotation_type, RotationOp}; use crate::serialize::pytket::METADATA_INPUT_PARAMETERS; use crate::symbolic_constant_op; +use op::Tk1Op; +use param::{parse_pytket_param, PytketParam}; /// The state of an in-progress [`FunctionBuilder`] being built from a [`SerialCircuit`]. /// /// Mostly used to define helper internal methods. #[derive(Debug, Clone)] -pub(super) struct Tk1Decoder { +pub(super) struct Tk1DecoderContext { /// The Hugr being built. pub hugr: FunctionBuilder, /// A map from the tracked pytket registers to the [`Wire`]s in the circuit. @@ -47,9 +50,9 @@ pub(super) struct Tk1Decoder { parameters: IndexMap, } -impl Tk1Decoder { +impl Tk1DecoderContext { /// Initialize a new [`Tk1Decoder`], using the metadata from a [`SerialCircuit`]. - pub fn try_new(serialcirc: &SerialCircuit) -> Result { + pub fn try_new(serialcirc: &SerialCircuit) -> Result { let num_qubits = serialcirc.qubits.len(); let num_bits = serialcirc.bits.len(); let sig = @@ -107,7 +110,7 @@ impl Tk1Decoder { check_register(reg)?; Ok(RegisterHash::from(reg)) }) - .collect::, TK1ConvertError>>()?; + .collect::, Tk1ConvertError>>()?; // Map each register element to their starting wire. let register_wires: HashMap = ordered_registers @@ -116,7 +119,7 @@ impl Tk1Decoder { .zip(dangling_wires) .collect(); - Ok(Tk1Decoder { + Ok(Tk1DecoderContext { hugr: dfg, register_wires, ordered_registers, @@ -342,9 +345,9 @@ impl Tk1Decoder { } /// Only single-indexed registers are supported. -fn check_register(register: ®ister::ElementId) -> Result<(), TK1ConvertError> { +fn check_register(register: ®ister::ElementId) -> Result<(), Tk1ConvertError> { if register.1.len() != 1 { - Err(TK1ConvertError::MultiIndexedRegister { + Err(Tk1ConvertError::MultiIndexedRegister { register: register.0.clone(), }) } else { diff --git a/tket2/src/serialize/pytket/op.rs b/tket2/src/serialize/pytket/decoder/op.rs similarity index 72% rename from tket2/src/serialize/pytket/op.rs rename to tket2/src/serialize/pytket/decoder/op.rs index 3610bd420..41f62e65d 100644 --- a/tket2/src/serialize/pytket/op.rs +++ b/tket2/src/serialize/pytket/decoder/op.rs @@ -6,23 +6,21 @@ //! circuits by ensuring they always define a signature, and computing the //! explicit count of qubits and linear bits. -mod native; -pub(crate) mod serialised; - use hugr::ops::OpType; use hugr::IncomingPort; +use tk2op::NativeOp; use tket_json_rs::circuit_json; -use self::native::NativeOp; -use self::serialised::OpaqueTk1Op; -use super::OpConvertError; +use super::super::extension::OpaqueTk1Op; + +pub mod tk2op; /// An intermediary artifact when converting between TKET1 and TKET2 operations. /// /// This enum represents either operations that can be represented natively in TKET2, /// or operations that must be serialised as opaque TKET1 operations. #[derive(Clone, Debug, PartialEq, derive_more::From)] -pub enum Tk1Op { +pub(crate) enum Tk1Op { /// An operation with a native TKET2 counterpart. Native(NativeOp), /// An operation without a native TKET2 counterpart. @@ -30,32 +28,6 @@ pub enum Tk1Op { } impl Tk1Op { - /// Create a new `Tk1Op` from a hugr optype. - /// - /// Supports either native `Tk2Op`s or serialised tket1 `CustomOps`s. - /// - /// # Errors - /// - /// Returns an error if the operation is not supported by the TKET1 serialization. - pub fn try_from_optype(op: OpType) -> Result, OpConvertError> { - if let Some(tk2op) = op.cast() { - let native = NativeOp::try_from_tk2op(tk2op) - .ok_or_else(|| OpConvertError::UnsupportedOpSerialization(op))?; - // Skip serialisation for some special cases. - if native.serial_op().is_none() { - return Ok(None); - } - Ok(Some(Tk1Op::Native(native))) - } else { - // Unrecognised opaque operation. If it's an opaque tket1 op, return it. - // Otherwise, it's an unsupported operation and we should fail. - match OpaqueTk1Op::try_from_tket2(&op)? { - Some(opaque) => Ok(Some(Tk1Op::Opaque(opaque))), - None => Err(OpConvertError::UnsupportedOpSerialization(op.clone())), - } - } - } - /// Create a new `Tk1Op` from a tket1 `circuit_json::Operation`. /// /// If `serial_op` defines a signature then `num_qubits` and `num_qubits` are ignored. Otherwise, a signature is synthesised from those parameters. @@ -90,14 +62,6 @@ impl Tk1Op { } } - /// Get the [`tket_json_rs::circuit_json::Operation`] for the operation. - pub fn serialised_op(&self) -> Option { - match self { - Tk1Op::Native(native_op) => native_op.serialised_op(), - Tk1Op::Opaque(json_op) => Some(json_op.serialised_op().clone()), - } - } - /// Returns the ports corresponding to parameters for this operation. pub fn param_ports(&self) -> impl Iterator + '_ { match self { diff --git a/tket2/src/serialize/pytket/op/native.rs b/tket2/src/serialize/pytket/decoder/op/tk2op.rs similarity index 54% rename from tket2/src/serialize/pytket/op/native.rs rename to tket2/src/serialize/pytket/decoder/op/tk2op.rs index 4599f31df..45f4525ea 100644 --- a/tket2/src/serialize/pytket/op/native.rs +++ b/tket2/src/serialize/pytket/decoder/op/tk2op.rs @@ -1,4 +1,4 @@ -//! Operations that have corresponding representations in both `pytket` and `tket2`. +//! Encoder and decoder for tket2 operations with native pytket counterparts. use std::borrow::Cow; @@ -9,7 +9,6 @@ use hugr::std_extensions::arithmetic::float_types::float64_type; use hugr::types::Signature; use hugr::IncomingPort; -use tket_json_rs::circuit_json; use tket_json_rs::optype::OpType as Tk1OpType; use crate::extension::rotation::rotation_type; @@ -19,7 +18,7 @@ use crate::Tk2Op; /// /// Note that the signature of the native and serialised operations may differ. #[derive(Clone, Debug, PartialEq, Default)] -pub struct NativeOp { +pub(crate) struct NativeOp { /// The tket2 optype. op: OpType, /// The corresponding serialised optype. @@ -51,40 +50,6 @@ impl NativeOp { native_op } - /// Create a new `NativeOp` from a `circuit_json::Operation`. - pub fn try_from_tk2op(tk2op: Tk2Op) -> Option { - let serial_op = match tk2op { - Tk2Op::H => Tk1OpType::H, - Tk2Op::CX => Tk1OpType::CX, - Tk2Op::CY => Tk1OpType::CY, - Tk2Op::CZ => Tk1OpType::CZ, - Tk2Op::CRz => Tk1OpType::CRz, - Tk2Op::T => Tk1OpType::T, - Tk2Op::Tdg => Tk1OpType::Tdg, - Tk2Op::S => Tk1OpType::S, - Tk2Op::Sdg => Tk1OpType::Sdg, - Tk2Op::X => Tk1OpType::X, - Tk2Op::Y => Tk1OpType::Y, - Tk2Op::Z => Tk1OpType::Z, - Tk2Op::Rx => Tk1OpType::Rx, - Tk2Op::Rz => Tk1OpType::Rz, - Tk2Op::Ry => Tk1OpType::Ry, - Tk2Op::Toffoli => Tk1OpType::CCX, - Tk2Op::Reset => Tk1OpType::Reset, - Tk2Op::Measure => Tk1OpType::Measure, - // These operations do not have a direct pytket counterpart. - Tk2Op::MeasureFree => return None, - Tk2Op::QAlloc | Tk2Op::QFree | Tk2Op::TryQAlloc => { - // These operations are implicitly supported by the encoding, - // they do not create an explicit pytket operation but instead - // add new qubits to the circuit input/output. - return Some(Self::new(tk2op.into(), None)); - } - }; - - Some(Self::new(tk2op.into(), Some(serial_op))) - } - /// Returns the translated tket2 optype for this operation, if it exists. pub fn try_from_serial_optype(serial_op: Tk1OpType) -> Option { let op = match serial_op { @@ -114,39 +79,11 @@ impl NativeOp { Some(Self::new(op, Some(serial_op))) } - /// Converts this `NativeOp` into a tket_json_rs operation. - pub fn serialised_op(&self) -> Option { - let serial_op = self.serial_op.clone()?; - - // Since pytket operations are always linear, - // use the maximum of input and output bits/qubits. - let num_qubits = self.input_qubits.max(self.output_qubits); - let num_bits = self.input_bits.max(self.output_bits); - let num_params = self.num_params; - - let params = (num_params > 0).then(|| vec!["".into(); num_params]); - - let mut op = circuit_json::Operation::default(); - op.op_type = serial_op; - op.n_qb = Some(num_qubits as u32); - op.params = params; - op.signature = Some([vec!["Q".into(); num_qubits], vec!["B".into(); num_bits]].concat()); - Some(op) - } - /// Returns the dataflow signature for this operation. pub fn signature(&self) -> Option> { self.op.dataflow_signature() } - /// Returns the serial optype for this operation. - /// - /// Some special operations do not have a direct serialised counterpart, and - /// should be skipped during serialisation. - pub fn serial_op(&self) -> Option<&Tk1OpType> { - self.serial_op.as_ref() - } - /// Returns the tket2 optype for this operation. pub fn optype(&self) -> &OpType { &self.op @@ -196,35 +133,3 @@ impl NativeOp { } } } - -#[cfg(test)] -mod cfg { - use super::*; - use rstest::rstest; - use strum::IntoEnumIterator; - - #[rstest] - fn tk2_optype_correspondence() { - for tk2op in Tk2Op::iter() { - let Some(native_op) = NativeOp::try_from_tk2op(tk2op) else { - // Ignore unsupported ops. - continue; - }; - - let Some(serial_op) = native_op.serial_op.clone() else { - // Ignore ops that do not have a serialised equivalent. - // (But are still handled by the encoder). - continue; - }; - - let Some(native_op2) = NativeOp::try_from_serial_optype(serial_op.clone()) else { - panic!( - "{} serialises into {serial_op:?}, but failed to be deserialised.", - tk2op.exposed_name() - ) - }; - - assert_eq!(native_op, native_op2); - } - } -} diff --git a/tket2/src/serialize/pytket/param/param.pest b/tket2/src/serialize/pytket/decoder/param.pest similarity index 100% rename from tket2/src/serialize/pytket/param/param.pest rename to tket2/src/serialize/pytket/decoder/param.pest diff --git a/tket2/src/serialize/pytket/param/decode.rs b/tket2/src/serialize/pytket/decoder/param.rs similarity index 99% rename from tket2/src/serialize/pytket/param/decode.rs rename to tket2/src/serialize/pytket/decoder/param.rs index 186687a94..6d475af6c 100644 --- a/tket2/src/serialize/pytket/param/decode.rs +++ b/tket2/src/serialize/pytket/decoder/param.rs @@ -55,7 +55,7 @@ pub fn parse_pytket_param(param: &str) -> PytketParam<'_> { } #[derive(Parser)] -#[grammar = "serialize/pytket/param/param.pest"] +#[grammar = "serialize/pytket/decoder/param.pest"] struct ParamParser; lazy_static::lazy_static! { diff --git a/tket2/src/serialize/pytket/encoder.rs b/tket2/src/serialize/pytket/encoder.rs index 7e0d67311..cfd77412a 100644 --- a/tket2/src/serialize/pytket/encoder.rs +++ b/tket2/src/serialize/pytket/encoder.rs @@ -1,694 +1,796 @@ //! Intermediate structure for encoding [`Circuit`]s into [`SerialCircuit`]s. -use core::panic; -use std::collections::{HashMap, HashSet, VecDeque}; +mod config; +mod unit_generator; +mod unsupported_tracker; +mod value_tracker; + +pub use config::{default_encoder_config, Tk1EncoderConfig}; +use hugr::envelope::EnvelopeConfig; +use hugr::hugr::views::SiblingSubgraph; +use hugr::package::Package; +use hugr_core::hugr::internal::PortgraphNodeMap; +pub use value_tracker::{ + RegisterCount, TrackedBit, TrackedParam, TrackedQubit, TrackedValue, TrackedValues, + ValueTracker, +}; -use hugr::extension::prelude::{bool_t, qb_t}; use hugr::ops::{OpTrait, OpType}; -use hugr::std_extensions::arithmetic::float_types::float64_type; -use hugr::types::Type; -use hugr::{HugrView, Node, Wire}; +use hugr::types::EdgeKind; + +use std::borrow::Cow; +use std::collections::{BTreeSet, HashSet}; +use std::sync::Arc; + +use hugr::{HugrView, Wire}; use itertools::Itertools; use tket_json_rs::circuit_json::{self, SerialCircuit}; -use tket_json_rs::register::ElementId as RegisterUnit; +use unsupported_tracker::UnsupportedTracker; -use crate::circuit::command::{CircuitUnit, Command}; -use crate::circuit::Circuit; -use crate::extension::rotation::rotation_type; -use crate::serialize::pytket::RegisterHash; -use crate::Tk2Op; - -use super::op::Tk1Op; -use super::param::encode::fold_param_op; use super::{ - OpConvertError, TK1ConvertError, METADATA_B_OUTPUT_REGISTERS, METADATA_B_REGISTERS, - METADATA_INPUT_PARAMETERS, METADATA_OPGROUP, METADATA_PHASE, METADATA_Q_OUTPUT_REGISTERS, - METADATA_Q_REGISTERS, + OpConvertError, Tk1ConvertError, METADATA_B_OUTPUT_REGISTERS, METADATA_OPGROUP, METADATA_PHASE, + METADATA_Q_OUTPUT_REGISTERS, METADATA_Q_REGISTERS, }; +use crate::circuit::Circuit; /// The state of an in-progress [`SerialCircuit`] being built from a [`Circuit`]. -#[derive(Debug, Clone)] -pub(super) struct Tk1Encoder { +#[derive(derive_more::Debug)] +#[debug(bounds(H: HugrView))] +pub struct Tk1EncoderContext { /// The name of the circuit being encoded. name: Option, - /// Global phase value. Defaults to "0" + /// Global phase value. + /// + /// Defaults to "0" unless the circuit has a [METADATA_PHASE] metadata + /// entry. phase: String, - /// The current serialised commands + /// The already-encoded serialised pytket commands. commands: Vec, - /// A tracker for the qubits used in the circuit. - qubits: QubitTracker, - /// A tracker for the bits used in the circuit. - bits: BitTracker, - /// A tracker for the operation parameters used in the circuit. - parameters: ParameterTracker, + /// A tracker for qubit/bit/parameter values associated with the circuit's wires. + /// + /// Contains methods to update the registers in the circuit being built. + pub values: ValueTracker, + /// A tracker for unsupported regions of the circuit. + unsupported: UnsupportedTracker, + /// Configuration for the encoding. + /// + /// Contains custom operation/type/const emitters. + config: Arc>, } -impl Tk1Encoder { - /// Create a new [`JsonEncoder`] from a [`Circuit`]. - pub fn new(circ: &Circuit>) -> Result { +impl Tk1EncoderContext { + /// Create a new [`Tk1EncoderContext`] from a [`Circuit`]. + pub(super) fn new( + circ: &Circuit, + config: Tk1EncoderConfig, + ) -> Result> { let name = circ.name().map(str::to_string); let hugr = circ.hugr(); - // Check for unsupported input types. - for (_, _, typ) in circ.units() { - if ![rotation_type(), float64_type(), qb_t(), bool_t()].contains(&typ) { - return Err(TK1ConvertError::NonSerializableInputs { typ }); - } - } - // Recover other parameters stored in the metadata - // TODO: Check for invalid encoded metadata let phase = match hugr.get_metadata(circ.parent(), METADATA_PHASE) { Some(p) => p.as_str().unwrap().to_string(), None => "0".to_string(), }; - let qubit_tracker = QubitTracker::new(circ); - let bit_tracker = BitTracker::new(circ); - let parameter_tracker = ParameterTracker::new(circ); - Ok(Self { name, phase, commands: vec![], - qubits: qubit_tracker, - bits: bit_tracker, - parameters: parameter_tracker, + values: ValueTracker::new(circ, &config)?, + unsupported: UnsupportedTracker::new(circ), + config: Arc::new(config), }) } - /// Add a circuit command to the serialization. - pub fn add_command>( + /// Traverse the circuit in topological order, encoding the nodes as pytket commands. + /// + /// Returns the final [`SerialCircuit`] if successful. + pub(super) fn run_encoder( &mut self, - command: Command<'_, T>, - optype: &OpType, - ) -> Result<(), OpConvertError> { - // Register any output of the command that can be used as a TKET1 parameter. - if self.parameters.record_parameters(&command, optype)? { - // for now all ops that record parameters should be ignored (are - // just constants) - return Ok(()); - } - - // Special case for the QAlloc operation. - // This does not translate to a TKET1 operation, we just start tracking a new qubit register. - if optype == &Tk2Op::QAlloc.into() { - let Some((CircuitUnit::Linear(unit_id), _, _)) = command.outputs().next() else { - panic!("QAlloc should have a single qubit output.") - }; - debug_assert!(self.qubits.get(unit_id).is_none()); - self.qubits.add_qubit_register(unit_id); - return Ok(()); - } - - let Some(tk1op) = Tk1Op::try_from_optype(optype.clone())? else { - // This command should be ignored. - return Ok(()); - }; - - // Get the registers and wires associated with the operation's inputs. - let mut qubit_args = Vec::with_capacity(tk1op.qubit_inputs()); - let mut bit_args = Vec::with_capacity(tk1op.bit_inputs()); - let mut params = Vec::with_capacity(tk1op.num_params()); - for (unit, _, ty) in command.inputs() { - if ty == qb_t() { - let reg = self.unit_to_register(unit).unwrap_or_else(|| { - panic!( - "No register found for qubit input {unit} in node {}.", - command.node(), - ) - }); - qubit_args.push(reg); - } else if ty == bool_t() { - let reg = self.unit_to_register(unit).unwrap_or_else(|| { - panic!( - "No register found for bit input {unit} in node {}.", - command.node(), - ) - }); - bit_args.push(reg); - } else if [rotation_type(), float64_type()].contains(&ty) { - let CircuitUnit::Wire(param_wire) = unit else { - unreachable!("Angle types are not linear.") - }; - params.push(param_wire); - } else { - return Err(OpConvertError::UnsupportedInputType { - typ: ty.clone(), - optype: optype.clone(), - node: command.node(), - }); - } - } - - for (unit, _, ty) in command.outputs() { - if ty == qb_t() { - // If the qubit is not already in the qubit tracker, add it as a - // new register. - let CircuitUnit::Linear(unit_id) = unit else { - panic!("Qubit types are linear.") - }; - if self.qubits.get(unit_id).is_none() { - let reg = self.qubits.add_qubit_register(unit_id); - qubit_args.push(reg.clone()); - } - } else if ty == bool_t() { - // If the operation has any bit outputs, create a new one bit - // register. - // - // Note that we do not reassign input registers to the new - // output wires as we do not know if the bit value was modified - // by the operation, and the old value may be needed later. - // - // This may cause register duplication for opaque operations - // with input bits. - let CircuitUnit::Wire(wire) = unit else { - panic!("Bool types are not linear.") - }; - let reg = self.bits.add_bit_register(wire); - bit_args.push(reg.clone()); - } else { - return Err(OpConvertError::UnsupportedOutputType { - typ: ty.clone(), - optype: optype.clone(), - node: command.node(), - }); + circ: &Circuit, + ) -> Result<(), Tk1ConvertError> { + // Normally we'd use `SiblingGraph` here, but it doesn't support generic node types. + // See https://github.com/CQCL/hugr/issues/2010 + let (region, node_map) = circ.hugr().region_portgraph(circ.parent()); + let io_nodes = circ.io_nodes(); + + // TODO: Use weighted topological sort to try and explore unsupported + // ops first (that is, ops with no available emitter in `self.config`), + // to ensure we group them as much as possible. + let mut topo = petgraph::visit::Topo::new(®ion); + while let Some(pg_node) = topo.next(®ion) { + let node = node_map.from_portgraph(pg_node); + if io_nodes.contains(&node) { + // I/O nodes are handled by `new` and `finish`. + continue; } + self.try_encode_node(node, circ)?; } - - let opgroup: Option = command - .metadata(METADATA_OPGROUP) - .and_then(serde_json::Value::as_str) - .map(ToString::to_string); - - // Convert the command's operator to a pytket serialized one. This will - // return an error for operations that should have been caught by the - // `record_parameters` branch above (in addition to other unsupported - // ops). - let mut serial_op: circuit_json::Operation = tk1op - .serialised_op() - .ok_or_else(|| OpConvertError::UnsupportedOpSerialization(optype.clone()))?; - - if !params.is_empty() { - serial_op.params = Some( - params - .into_iter() - .filter_map(|w| self.parameters.get(&w)) - .cloned() - .collect(), - ) - } - // TODO: ops that contain free variables. - // (update decoder to ignore them too, but store them in the wrapped op) - - let mut args = qubit_args; - args.append(&mut bit_args); - let command = circuit_json::Command { - op: serial_op, - args, - opgroup, - }; - self.commands.push(command); - Ok(()) } /// Finish building and return the final [`SerialCircuit`]. - pub fn finish(self, circ: &Circuit>) -> SerialCircuit { - let (qubits, qubits_permutation) = self.qubits.finish(circ); - let (bits, mut bits_permutation) = self.bits.finish(circ); + pub(super) fn finish( + mut self, + circ: &Circuit, + ) -> Result> { + // Add any remaining unsupported nodes + // + // TODO: Test that unsupported subgraphs that don't affect any qubit/bit registers + // are correctly encoded in pytket commands. + while !self.unsupported.is_empty() { + let node = self.unsupported.iter().next().unwrap(); + let unsupported_subgraph = self.unsupported.extract_component(node); + self.emit_unsupported(unsupported_subgraph, circ)?; + } - let mut implicit_permutation = qubits_permutation; - implicit_permutation.append(&mut bits_permutation); + let final_values = self.values.finish(circ)?; let mut ser = SerialCircuit::new(self.name, self.phase); ser.commands = self.commands; - ser.qubits = qubits.into_iter().map_into().collect(); - ser.bits = bits.into_iter().map_into().collect(); - ser.implicit_permutation = implicit_permutation; + ser.qubits = final_values.qubits.into_iter().map_into().collect(); + ser.bits = final_values.bits.into_iter().map_into().collect(); + ser.implicit_permutation = final_values.qubit_permutation; ser.number_of_ws = None; - ser + Ok(ser) } - /// Translate a linear [`CircuitUnit`] into a [`RegisterUnit`], if possible. - fn unit_to_register(&self, unit: CircuitUnit) -> Option { - match unit { - CircuitUnit::Linear(i) => self.qubits.get(i).cloned(), - CircuitUnit::Wire(wire) => self.bits.get(&wire).cloned(), - } + /// Returns a reference to this encoder's configuration. + pub fn config(&self) -> &Tk1EncoderConfig { + &self.config } -} -/// A structure for tracking qubits used in the circuit being encoded. -/// -/// Nb: Although `tket-json-rs` has a "Register" struct, it's actually -/// an identifier for single qubits in the `Register::0` register. -/// We rename it to `RegisterUnit` here to avoid confusion. -#[derive(Debug, Clone, Default)] -struct QubitTracker { - /// The ordered TKET1 names for the input qubit registers. - inputs: Vec, - /// The ordered TKET1 names for the output qubit registers. - outputs: Option>, - /// The TKET1 qubit registers associated to each qubit unit of the circuit. - qubit_to_reg: HashMap, - /// A generator of new registers units to use for bit wires. - unit_generator: RegisterUnitGenerator, -} - -impl QubitTracker { - /// Create a new [`QubitTracker`] from the qubit inputs of a [`Circuit`]. - /// Reads the [`METADATA_Q_REGISTERS`] metadata entry with preset pytket qubit register names. + /// Returns the values associated with a wire. /// - /// If the circuit contains more qubit inputs than the provided list, - /// new registers are created for the remaining qubits. - pub fn new(circ: &Circuit>) -> Self { - let mut tracker = QubitTracker::default(); - - if let Some(input_regs) = circ - .hugr() - .get_metadata(circ.parent(), METADATA_Q_REGISTERS) - { - tracker.inputs = serde_json::from_value(input_regs.clone()).unwrap(); - } - let output_regs = circ - .hugr() - .get_metadata(circ.parent(), METADATA_Q_OUTPUT_REGISTERS) - .map(|regs| serde_json::from_value(regs.clone()).unwrap()); - if let Some(output_regs) = output_regs { - tracker.outputs = Some(output_regs); + /// Marks the port connection as explored. When all ports connected to the + /// wire have been explored, the wire is removed from the tracker. + /// + /// If the input wire is the output of an unsupported node, a subgraph of + /// unsupported nodes containing it will be emitted as a pytket barrier. + /// + /// This function SHOULD NOT be called before determining that the target + /// operation is supported, as it will mark the wire as explored and + /// potentially remove it from the tracker. To determine if a wire type is + /// supported, use [`Tk1EncoderConfig::type_to_pytket`] on the encoder + /// context's [`Tk1EncoderContext::config`]. + /// + /// ### Errors + /// + /// - [`OpConvertError::WireHasNoValues`] if the wire is not tracked or has + /// a type that cannot be converted to pytket values. + pub fn get_wire_values( + &mut self, + wire: Wire, + circ: &Circuit, + ) -> Result, Tk1ConvertError> { + if self.values.peek_wire_values(wire).is_some() { + return Ok(self.values.wire_values(wire).unwrap()); } - tracker.unit_generator = RegisterUnitGenerator::new( - "q", - tracker - .inputs - .iter() - .chain(tracker.outputs.iter().flatten()), - ); - - let qubit_count = circ.units().filter(|(_, _, ty)| ty == &qb_t()).count(); - - for i in 0..qubit_count { - // Use the given input register names if available, or create new ones. - if let Some(reg) = tracker.inputs.get(i) { - tracker.qubit_to_reg.insert(i, reg.clone()); - } else { - let reg = tracker.add_qubit_register(i).clone(); - tracker.inputs.push(reg); - } + // If the wire values have not been registered yet, it may be because + // the wire is the output of an unsupported node group. + // + // We need to emit the unsupported node here before returning the values. + if self.unsupported.is_unsupported(wire.node()) { + let unsupported_subgraph = self.unsupported.extract_component(wire.node()); + self.emit_unsupported(unsupported_subgraph, circ)?; + debug_assert!(!self.unsupported.is_unsupported(wire.node())); + return self.get_wire_values(wire, circ); } - tracker + Err(OpConvertError::WireHasNoValues { wire }.into()) } - /// Add a new register unit for a qubit wire. - pub fn add_qubit_register(&mut self, unit_id: usize) -> &RegisterUnit { - let reg = self.unit_generator.next(); - self.qubit_to_reg.insert(unit_id, reg); - self.qubit_to_reg.get(&unit_id).unwrap() - } - - /// Returns the register unit for a qubit wire, if it exists. - pub fn get(&self, unit_id: usize) -> Option<&RegisterUnit> { - self.qubit_to_reg.get(&unit_id) + /// Given a node in the HUGR, returns all the [`TrackedValue`]s associated + /// with its inputs. + /// + /// These values can be used with the [`Tk1EncoderContext::values`] tracker + /// to retrieve the corresponding pytket registers and parameter + /// expressions. See [`ValueTracker::qubit_register`], + /// [`ValueTracker::bit_register`], and [`ValueTracker::param_expression`]. + pub fn get_input_values( + &mut self, + node: H::Node, + circ: &Circuit, + ) -> Result> { + self.get_input_values_internal(node, circ, |_| true) } - /// Consumes the tracker and returns the final list of qubit registers, along - /// with the final permutation of the outputs. - pub fn finish( - mut self, - _circ: &Circuit>, - ) -> (Vec, Vec) { - // Ensure the input and output lists have the same registers. - let mut outputs = self.outputs.unwrap_or_default(); - let mut input_regs: HashSet = - self.inputs.iter().map(RegisterHash::from).collect(); - let output_regs: HashSet = outputs.iter().map(RegisterHash::from).collect(); - - for inp in &self.inputs { - if !output_regs.contains(&inp.into()) { - outputs.push(inp.clone()); + /// Auxiliary function used to collect the input values of a node. + /// See [`Tk1EncoderContext::get_input_values`]. + /// + /// Given a node in the HUGR, returns all the [`TrackedValue`]s associated + /// with its inputs. Calls + /// + /// Includes a filter to decide which incoming wires to include. + fn get_input_values_internal( + &mut self, + node: H::Node, + circ: &Circuit, + wire_filter: impl Fn(Wire) -> bool, + ) -> Result> { + let mut qubits: Vec = Vec::new(); + let mut bits: Vec = Vec::new(); + let mut params: Vec = Vec::new(); + + let optype = circ.hugr().get_optype(node); + let other_input_port = optype.other_input_port(); + for input in circ.hugr().node_inputs(node) { + // Ignore order edges. + if Some(input) == other_input_port { + continue; } - } - for out in &outputs { - if !input_regs.contains(&out.into()) { - self.inputs.push(out.clone()); + // Dataflow ports should have a single linked neighbour. + let Some((neigh, neigh_out)) = circ.hugr().single_linked_output(node, input) else { + return Err( + OpConvertError::UnsupportedOpSerialization { op: optype.clone() }.into(), + ); + }; + let wire = Wire::new(neigh, neigh_out); + if !wire_filter(wire) { + continue; } - } - input_regs.extend(output_regs); - // Add registers defined mid-circuit to both ends. - for reg in self.qubit_to_reg.into_values() { - if !input_regs.contains(&(®).into()) { - self.inputs.push(reg.clone()); - outputs.push(reg); + for value in self.get_wire_values(wire, circ)?.iter() { + match value { + TrackedValue::Qubit(qb) => qubits.push(*qb), + TrackedValue::Bit(b) => bits.push(*b), + TrackedValue::Param(p) => params.push(*p), + } } } + Ok(TrackedValues { + qubits, + bits, + params, + }) + } - // TODO: Look at the circuit outputs to determine the final permutation. - // - // We don't have the `CircuitUnit::Linear` assignments for the outputs - // here, so that requires some extra piping. - let permutation = outputs - .into_iter() - .zip(&self.inputs) - .map(|(out, inp)| circuit_json::ImplicitPermutation(inp.clone().into(), out.into())) - .collect_vec(); - - (self.inputs, permutation) + /// Helper to emit a new tket1 command corresponding to a single HUGR node. + /// + /// This call will fail if the node has parameter outputs. Use + /// [`Tk1EncoderContext::emit_node_with_out_params`] instead. + /// + /// See [`Tk1EncoderContext::emit_command`] for more general cases. + /// + /// ## Arguments + /// + /// - `tk1_operation`: The tket1 operation type to emit. + /// - `node`: The HUGR for which to emit the command. Qubits and bits are + /// automatically retrieved from the node's inputs/outputs. + /// - `circ`: The circuit containing the node. + pub fn emit_node( + &mut self, + tk1_optype: tket_json_rs::OpType, + node: H::Node, + circ: &Circuit, + ) -> Result<(), Tk1ConvertError> { + self.emit_node_with_out_params(tk1_optype, node, circ, |_| Vec::new()) } -} -/// A structure for tracking bits used in the circuit being encoded. -/// -/// Nb: Although `tket-json-rs` has a "Register" struct, it's actually -/// an identifier for single bits in the `Register::0` register. -/// We rename it to `RegisterUnit` here to avoid confusion. -#[derive(Debug, Clone, Default)] -struct BitTracker { - /// The ordered TKET1 names for the bit inputs. - inputs: Vec, - /// The expected order of TKET1 names for the bit outputs, - /// if that was stored in the metadata. - outputs: Option>, - /// Map each bit wire to a TKET1 register element. - bit_to_reg: HashMap, - /// Registers defined in the metadata, but not present in the circuit - /// inputs. - unused_registers: VecDeque, - /// A generator of new registers units to use for bit wires. - unit_generator: RegisterUnitGenerator, -} + /// Helper to emit a new tket1 command corresponding to a single HUGR node, + /// with parameter outputs. Use [`Tk1EncoderContext::emit_node`] for nodes + /// that don't require computing parameter outputs. + /// + /// See [`Tk1EncoderContext::emit_command`] for more general cases. + /// + /// ## Arguments + /// + /// - `tk1_operation`: The tket1 operation type to emit. + /// - `node`: The HUGR for which to emit the command. Qubits and bits are + /// automatically retrieved from the node's inputs/outputs. + /// - `circ`: The circuit containing the node. + /// - `output_params`: A function that computes the output parameter + /// expressions from the list of input parameters. If the number of parameters + /// does not match the expected number, the encoding will fail. + pub fn emit_node_with_out_params( + &mut self, + tk1_optype: tket_json_rs::OpType, + node: H::Node, + circ: &Circuit, + output_params: impl FnOnce(OutputParamArgs<'_>) -> Vec, + ) -> Result<(), Tk1ConvertError> { + self.emit_node_command( + node, + circ, + output_params, + move |qubit_count, bit_count, params| { + make_tk1_operation(tk1_optype, qubit_count, bit_count, params) + }, + ) + } -impl BitTracker { - /// Create a new [`BitTracker`] from the bit inputs of a [`Circuit`]. - /// Reads the [`METADATA_B_REGISTERS`] metadata entry with preset pytket bit register names. + /// Helper to emit a new tket1 command corresponding to a single HUGR node, + /// using a custom operation generator and computing output parameter + /// expressions. Use [`Tk1EncoderContext::emit_node`] or + /// [`Tk1EncoderContext::emit_node_with_out_params`] when pytket operation + /// can be defined directly from a [`tket_json_rs::OpType`]. /// - /// If the circuit contains more bit inputs than the provided list, - /// new registers are created for the remaining bits. + /// See [`Tk1EncoderContext::emit_command`] for a general case emitter. /// - /// TODO: Compute output bit permutations when finishing the circuit. - pub fn new(circ: &Circuit>) -> Self { - let mut tracker = BitTracker::default(); + /// ## Arguments + /// + /// - `node`: The HUGR for which to emit the command. Qubits and bits are + /// automatically retrieved from the node's inputs/outputs. + /// - `circ`: The circuit containing the node. + /// - `output_params`: A function that computes the output parameter + /// expressions from the list of input parameters. If the number of parameters + /// does not match the expected number, the encoding will fail. + /// - `make_operation`: A function that takes the number of qubits, bits, and + /// the list of input parameter expressions and returns a pytket operation. + /// See [`make_tk1_operation`] for a helper function to create it. + pub fn emit_node_command( + &mut self, + node: H::Node, + circ: &Circuit, + output_params: impl FnOnce(OutputParamArgs<'_>) -> Vec, + make_operation: impl FnOnce(usize, usize, &[String]) -> tket_json_rs::circuit_json::Operation, + ) -> Result<(), Tk1ConvertError> { + let TrackedValues { + mut qubits, + mut bits, + params, + } = self.get_input_values(node, circ)?; + let params: Vec = params + .into_iter() + .map(|p| self.values.param_expression(p).to_owned()) + .collect(); - if let Some(input_regs) = circ - .hugr() - .get_metadata(circ.parent(), METADATA_B_REGISTERS) - { - tracker.inputs = serde_json::from_value(input_regs.clone()).unwrap(); - } - let output_regs = circ + // Update the values in the node's outputs. + // + // We preserve the order of linear values in the input. + let mut qubit_iterator = qubits.iter().copied(); + let new_outputs = self.register_node_outputs( + node, + circ, + &mut qubit_iterator, + ¶ms, + output_params, + |_| true, + )?; + qubits.extend(new_outputs.qubits); + bits.extend(new_outputs.bits); + + // Preserve the pytket opgroup, if it got stored in the metadata. + let opgroup: Option = circ .hugr() - .get_metadata(circ.parent(), METADATA_B_OUTPUT_REGISTERS) - .map(|regs| serde_json::from_value(regs.clone()).unwrap()); - if let Some(output_regs) = output_regs { - tracker.outputs = Some(output_regs); - } + .get_metadata(node, METADATA_OPGROUP) + .and_then(serde_json::Value::as_str) + .map(ToString::to_string); - tracker.unit_generator = RegisterUnitGenerator::new( - "c", - tracker - .inputs - .iter() - .chain(tracker.outputs.iter().flatten()), - ); + let op = make_operation(qubits.len(), bits.len(), ¶ms); + self.emit_command(op, &qubits, &bits, opgroup); + Ok(()) + } - let bit_input_wires = circ.units().filter_map(|u| match u { - (CircuitUnit::Wire(w), _, ty) if ty == bool_t() => Some(w), - _ => None, + /// Helper to emit a node that transparently forwards its inputs to its + /// outputs, resulting in no pytket operation. + /// + /// It registers the node's input qubits and bits to the output + /// wires, without modifying the tket1 circuit being constructed. + /// Output parameters are more flexible, and output expressions can be + /// computed from the input parameters via the `output_params` function. + /// + /// The node's inputs should have exactly the same number of qubits and + /// bits. This method will return an error if that is not the case. + /// + /// You must also ensure that all input and output types are supported by + /// the encoder. Otherwise, the function will return an error. + /// + /// ## Arguments + /// + /// - `node`: The HUGR for which to emit the command. Qubits and bits are + /// automatically retrieved from the node's inputs/outputs. + /// - `circ`: The circuit containing the node. + /// - `output_params`: A function that computes the output parameter + /// expressions from the list of input parameters. If the number of parameters + /// does not match the expected number, the encoding will fail. + pub fn emit_transparent_node( + &mut self, + node: H::Node, + circ: &Circuit, + output_params: impl FnOnce(OutputParamArgs<'_>) -> Vec, + ) -> Result<(), Tk1ConvertError> { + let input_values = self.get_input_values(node, circ)?; + let output_counts = self.node_output_values(node, circ)?; + let total_out_count: RegisterCount = output_counts.iter().map(|(_, c)| *c).sum(); + + // Compute all the output parameters at once + let input_params: Vec = input_values + .params + .into_iter() + .map(|p| self.values.param_expression(p).to_owned()) + .collect_vec(); + let out_params = output_params(OutputParamArgs { + expected_count: total_out_count.params, + input_params: &input_params, }); - let mut unused_registers: HashSet = tracker.inputs.iter().cloned().collect(); - for (i, wire) in bit_input_wires.enumerate() { - // If the input is not used in the circuit, ignore it. - if circ - .hugr() - .linked_inputs(wire.node(), wire.source()) - .next() - .is_none() - { - continue; - } - - // Use the given input register names if available, or create new ones. - if let Some(reg) = tracker.inputs.get(i) { - unused_registers.remove(reg); - tracker.bit_to_reg.insert(wire, reg.clone()); - } else { - let reg = tracker.add_bit_register(wire).clone(); - tracker.inputs.push(reg); - }; + // Check that we got the expected number of outputs. + if input_values.qubits.len() != total_out_count.qubits { + return Err(Tk1ConvertError::custom(format!( + "Mismatched number of input and output qubits while trying to emit a transparent operation for {}. We have {} inputs but {} outputs.", + circ.hugr().get_optype(node), + input_values.qubits.len(), + total_out_count.qubits, + ))); + } + if input_values.bits.len() != total_out_count.bits { + return Err(Tk1ConvertError::custom(format!( + "Mismatched number of input and output bits while trying to emit a transparent operation for {}. We have {} inputs but {} outputs.", + circ.hugr().get_optype(node), + input_values.bits.len(), + total_out_count.bits, + ))); + } + if out_params.len() != total_out_count.params { + return Err(Tk1ConvertError::custom(format!( + "Not enough parameters in the input values for a {}. Expected {} but got {}.", + circ.hugr().get_optype(node), + total_out_count.params, + out_params.len() + ))); } - // If a register was defined in the metadata but not used in the circuit, - // we keep it so it can be assigned to an operation output. - tracker.unused_registers = unused_registers.into_iter().collect(); + // Now we can gather all inputs and assign them to the node outputs transparently. + let mut qubits = input_values.qubits.into_iter(); + let mut bits = input_values.bits.into_iter(); + let mut params = out_params.into_iter(); + for (wire, count) in output_counts { + let mut values: Vec = Vec::with_capacity(count.total()); + values.extend(qubits.by_ref().take(count.qubits).map(TrackedValue::Qubit)); + values.extend(bits.by_ref().take(count.bits).map(TrackedValue::Bit)); + for p in params.by_ref().take(count.params) { + values.push(self.values.new_param(p).into()); + } + self.values.register_wire(wire, values, circ)?; + } - tracker + Ok(()) } - /// Add a new register unit for a bit wire. - pub fn add_bit_register(&mut self, wire: Wire) -> &RegisterUnit { - let reg = self - .unused_registers - .pop_front() - .unwrap_or_else(|| self.unit_generator.next()); + /// Helper to emit a new tket1 command corresponding to subgraph of unsupported nodes, + /// encoded inside a pytket barrier. + /// + /// ## Arguments + /// + /// - `unsupported_nodes`: The list of nodes to encode as an unsupported subgraph. + fn emit_unsupported( + &mut self, + unsupported_nodes: BTreeSet, + circ: &Circuit, + ) -> Result<(), Tk1ConvertError> { + let subcircuit_id = format!("tk{}", unsupported_nodes.iter().min().unwrap()); - self.bit_to_reg.insert(wire, reg); - self.bit_to_reg.get(&wire).unwrap() - } + // TODO: Use a cached topo checker here instead of traversing the full graph each time we create a `SiblingSubgraph`. + // + // TopoConvexChecker likes to borrow the hugr, so it'd be too invasive to store in the `Context`. + let subgraph = SiblingSubgraph::try_from_nodes( + unsupported_nodes.iter().cloned().collect_vec(), + circ.hugr(), + ) + .unwrap_or_else(|_| { + panic!( + "Failed to create subgraph from unsupported nodes [{}]", + unsupported_nodes.iter().join(", ") + ) + }); + let input_nodes: HashSet<_> = subgraph + .incoming_ports() + .iter() + .flat_map(|inp| inp.iter().map(|(n, _)| *n)) + .collect(); + let output_nodes: HashSet<_> = subgraph.outgoing_ports().iter().map(|(n, _)| *n).collect(); - /// Returns the register unit for a bit wire, if it exists. - pub fn get(&self, wire: &Wire) -> Option<&RegisterUnit> { - self.bit_to_reg.get(wire) - } + let unsupported_hugr = subgraph.extract_subgraph(circ.hugr(), &subcircuit_id); + let payload = Package::from_hugr(unsupported_hugr) + .store_str(EnvelopeConfig::text()) + .unwrap(); - /// Consumes the tracker and returns the final list of bit registers, along - /// with the final permutation of the outputs. - pub fn finish( - mut self, - circ: &Circuit>, - ) -> (Vec, Vec) { - let mut circuit_output_order: Vec = Vec::with_capacity(self.inputs.len()); - for (node, port) in circ.hugr().all_linked_outputs(circ.output_node()) { - let wire = Wire::new(node, port); - if let Some(reg) = self.bit_to_reg.get(&wire) { - circuit_output_order.push(reg.clone()); - } + // Collects the input values for the subgraph. + // + // The [`UnsupportedTracker`] ensures that at this point all input wires must come from + // already-encoded nodes, and not from other unsupported nodes not in `unsupported_nodes`. + let mut op_values = TrackedValues::default(); + for node in &input_nodes { + let node_vals = self.get_input_values_internal(*node, circ, |w| { + unsupported_nodes.contains(&w.node()) + })?; + op_values.append(node_vals); } + let input_param_exprs: Vec = std::mem::take(&mut op_values.params) + .into_iter() + .map(|p| self.values.param_expression(p).to_owned()) + .collect(); - // Ensure the input and output lists have the same registers. - let mut outputs = self.outputs.unwrap_or_default(); - let mut input_regs: HashSet = - self.inputs.iter().map(RegisterHash::from).collect(); - let output_regs: HashSet = outputs.iter().map(RegisterHash::from).collect(); - - for inp in &self.inputs { - if !output_regs.contains(&inp.into()) { - outputs.push(inp.clone()); - } - } - for out in &outputs { - if !input_regs.contains(&out.into()) { - self.inputs.push(out.clone()); - } + // Update the values in the node's outputs, and extend `op_values` with + // any new output values. + // + // Output parameters are mapped to a fresh variable, that can be tracked + // back to the encoded subcircuit's function name. + let mut input_qubits = op_values.qubits.clone().into_iter(); + for &node in &output_nodes { + let new_outputs = self.register_node_outputs( + node, + circ, + &mut input_qubits, + &[], + |p| { + (0..p.expected_count) + .map(|i| format!("{subcircuit_id}_out{i}")) + .collect_vec() + }, + |_| true, + )?; + op_values.append(new_outputs); } - input_regs.extend(output_regs); - // Add registers defined mid-circuit to both ends. - for reg in self.bit_to_reg.into_values() { - if !input_regs.contains(&(®).into()) { - self.inputs.push(reg.clone()); - outputs.push(reg); - } - } + // Create pytket operation, and add the subcircuit as hugr + let mut tk1_op = make_tk1_operation( + tket_json_rs::OpType::Barrier, + op_values.qubits.len(), + op_values.bits.len(), + &input_param_exprs, + ); + tk1_op.data = Some(payload); - // And ensure `circuit_output_order` has all virtual registers added too. - let circuit_outputs: HashSet = circuit_output_order - .iter() - .map(RegisterHash::from) - .collect(); - for out in &outputs { - if !circuit_outputs.contains(&out.into()) { - circuit_output_order.push(out.clone()); - } - } + let opgroup = Some("tket2".to_string()); + self.emit_command(tk1_op, &op_values.qubits, &op_values.bits, opgroup); + Ok(()) + } - // Compute the final permutation. This is a combination of two mappings: - // - First, the original implicit permutation for the circuit, if this was decoded from pytket. - let original_permutation: HashMap = self - .inputs - .iter() - .zip(&outputs) - .map(|(inp, out)| (inp.clone(), RegisterHash::from(out))) - .collect(); - // - Second, the actual reordering of outputs seen at the circuit's output node. - let mut circuit_permutation: HashMap = outputs - .iter() - .zip(circuit_output_order) - .map(|(out, circ_out)| (RegisterHash::from(out), circ_out)) - .collect(); - // The final permutation is the composition of these two mappings. - let permutation = original_permutation - .into_iter() - .map(|(inp, out)| { - circuit_json::ImplicitPermutation( - inp.into(), - circuit_permutation.remove(&out).unwrap().into(), - ) - }) - .collect_vec(); + /// Emit a new tket1 command. + /// + /// This is a general-purpose command that can be used to emit any tket1 + /// operation, not necessarily corresponding to a specific HUGR node. + /// + /// Ensure that any output wires from the node being processed gets the + /// appropriate values registered by calling [`ValueTracker::register_wire`] + /// on the context's [`Tk1EncoderContext::values`] tracker. + /// + /// In general you should prefer using [`Tk1EncoderContext::emit_node`] or + /// [`Tk1EncoderContext::emit_node_with_out_params`] as they automatically + /// compute the input qubits and bits from the HUGR node, and ensure that + /// output wires get their new values registered on the tracker. + /// + /// ## Arguments + /// + /// - `tk1_operation`: The tket1 operation to emit. See + /// [`make_tk1_operation`] for a helper function to create it. + /// - `qubits`: The qubit registers to use as inputs/outputs of the pytket + /// op. Normally obtained from a HUGR node's inputs using + /// [`Tk1EncoderContext::get_input_values`] or allocated via + /// [`ValueTracker::new_qubit`]. + /// - `bits`: The bit registers to use as inputs/outputs of the pytket op. + /// Normally obtained from a HUGR node's inputs using + /// [`Tk1EncoderContext::get_input_values`] or allocated via + /// [`ValueTracker::new_bit`]. + /// - `opgroup`: A tket1 operation group identifier, if any. + pub fn emit_command( + &mut self, + tk1_operation: circuit_json::Operation, + qubits: &[TrackedQubit], + bits: &[TrackedBit], + opgroup: Option, + ) { + let qubit_regs = qubits.iter().map(|&qb| self.values.qubit_register(qb)); + let bit_regs = bits.iter().map(|&b| self.values.bit_register(b)); + let command = circuit_json::Command { + op: tk1_operation, + args: qubit_regs.chain(bit_regs).cloned().collect(), + opgroup, + }; - (self.inputs, permutation) + self.commands.push(command); } -} -/// A structure for tracking the parameters of a circuit being encoded. -#[derive(Debug, Clone, Default)] -struct ParameterTracker { - /// The parameters associated with each wire. - parameters: HashMap, -} + /// Encode a single circuit node into pytket commands and update the + /// encoder. + /// + /// Dispatches to the registered encoders, trying each in turn until one + /// successfully encodes the operation. + /// + /// Returns `true` if the node was successfully encoded, or `false` if none + /// of the encoders could process it and the node got added to the + /// unsupported set. + fn try_encode_node( + &mut self, + node: H::Node, + circ: &Circuit, + ) -> Result> { + let optype = circ.hugr().get_optype(node); -impl ParameterTracker { - /// Create a new [`ParameterTracker`] from the input parameters of a [`Circuit`]. - fn new(circ: &Circuit>) -> Self { - let mut tracker = ParameterTracker::default(); + // TODO: Boxes and non-custom optypes - let angle_input_wires = circ.units().filter_map(|u| match u { - (CircuitUnit::Wire(w), _, ty) if [rotation_type(), float64_type()].contains(&ty) => { - Some(w) + // Try to encode the operation using each of the registered encoders. + // + // If none of the encoders can handle the operation, we just add it to + // the unsupported tracker and move on. + match optype { + OpType::ExtensionOp(op) => { + let config = Arc::clone(&self.config); + if config.op_to_pytket(node, op, circ, self)? { + return Ok(true); + } } - _ => None, - }); - - // The input parameter names may be specified in the metadata. - let fixed_input_names: Vec = circ - .hugr() - .get_metadata(circ.parent(), METADATA_INPUT_PARAMETERS) - .and_then(|params| serde_json::from_value(params.clone()).ok()) - .unwrap_or_default(); - let extra_names = (fixed_input_names.len()..).map(|i| format!("f{i}")); - let mut param_name = fixed_input_names.into_iter().chain(extra_names); - - for wire in angle_input_wires { - tracker.add_parameter(wire, param_name.next().unwrap()); + OpType::LoadConstant(_) => { + self.emit_transparent_node(node, circ, |ps| ps.input_params.to_owned())?; + return Ok(true); + } + OpType::Const(op) => { + let config = Arc::clone(&self.config); + if let Some(values) = config.const_to_pytket(&op.value, self)? { + let wire = Wire::new(node, 0); + self.values.register_wire(wire, values.into_iter(), circ)?; + return Ok(true); + } + } + _ => {} } - tracker + self.unsupported.record_node(node, circ); + Ok(false) } - /// Record any output of the command that can be used as a TKET1 parameter. - /// Returns whether parameters were recorded. - /// Associates the output wires with the parameter expression. - fn record_parameters>( + /// Helper to register values for a node's output wires. + /// + /// Returns any new value associated with the output wires. + /// + /// ## Arguments + /// + /// - `node`: The node to register the outputs for. + /// - `circ`: The circuit containing the node. + /// - `qubit_values`: An iterator of existing qubit ids to re-use for the output. + /// Once all qubits have been used, new qubit ids will be generated. + /// - `input_params`: The list of input parameter expressions. + /// - `output_params`: A function that computes the output parameter + /// expressions from the list of input parameters. If the number of parameters + /// does not match the expected number, the encoding will fail. + /// - `wire_filter`: A function that takes a wire and returns true if the wire + /// at the output of the `node` should be registered. + fn register_node_outputs( &mut self, - command: &Command<'_, T>, - optype: &OpType, - ) -> Result { - let input_count = if let Some(signature) = optype.dataflow_signature() { - // Only consider commands where all inputs and some outputs are - // parameters that we can track. - let tracked_params: [Type; 2] = [rotation_type(), float64_type()]; - let all_inputs = signature - .input() - .iter() - .all(|ty| tracked_params.contains(ty)); - let some_output = signature - .output() - .iter() - .any(|ty| tracked_params.contains(ty)); - if !all_inputs || !some_output { - return Ok(false); - } - signature.input_count() - } else if let OpType::Const(_) = optype { - // `Const` is a special non-dataflow command we can handle. - // It has zero inputs. - 0 - } else { - // Not a parameter-generating command. - return Ok(false); - }; + node: H::Node, + circ: &Circuit, + qubits: &mut impl Iterator, + input_params: &[String], + output_params: impl FnOnce(OutputParamArgs<'_>) -> Vec, + wire_filter: impl Fn(Wire) -> bool, + ) -> Result> { + let output_counts = self.node_output_values(node, circ)?; + let total_out_count: RegisterCount = output_counts.iter().map(|(_, c)| *c).sum(); + + // Compute all the output parameters at once + let out_params = output_params(OutputParamArgs { + expected_count: total_out_count.params, + input_params, + }); - // Collect the input parameters. - let mut inputs = Vec::with_capacity(input_count); - for (unit, _, _) in command.inputs() { - let CircuitUnit::Wire(wire) = unit else { - panic!("Angle types are not linear") - }; - let Some(param) = self.parameters.get(&wire) else { - let typ = rotation_type(); - return Err(OpConvertError::UnresolvedParamInput { - typ, - optype: optype.clone(), - node: command.node(), - }); - }; - inputs.push(param.as_str()); + // Check that we got the expected number of outputs. + if out_params.len() != total_out_count.params { + return Err(Tk1ConvertError::custom(format!( + "Not enough parameters in the input values for a {}. Expected {} but got {}.", + circ.hugr().get_optype(node), + total_out_count.params, + out_params.len() + ))); } - let Some(param) = fold_param_op(optype, &inputs) else { - return Ok(false); - }; + // Update the values in the node's outputs. + // + // We preserve the order of linear values in thpe input + let mut params = out_params.into_iter(); + let mut new_outputs = TrackedValues::default(); + for (wire, count) in output_counts { + if !wire_filter(wire) { + continue; + } - for (unit, _, _) in command.outputs() { - if let CircuitUnit::Wire(wire) = unit { - self.add_parameter(wire, param.clone()) + let mut out_wire_values = Vec::with_capacity(count.total()); + out_wire_values.extend(qubits.by_ref().take(count.qubits).map(TrackedValue::Qubit)); + for _ in out_wire_values.len()..count.qubits { + // If we already assigned all input qubit ids, get a fresh one. + let qb = self.values.new_qubit(); + new_outputs.qubits.push(qb); + out_wire_values.push(TrackedValue::Qubit(qb)); + } + for _ in 0..count.bits { + let b = self.values.new_bit(); + new_outputs.bits.push(b); + out_wire_values.push(TrackedValue::Bit(b)); } + for expr in params.by_ref().take(count.params) { + let p = self.values.new_param(expr); + new_outputs.params.push(p); + out_wire_values.push(p.into()); + } + self.values.register_wire(wire, out_wire_values, circ)?; } - Ok(true) - } - /// Associate a parameter expression with a wire. - fn add_parameter(&mut self, wire: Wire, param: String) { - self.parameters.insert(wire, param); + Ok(new_outputs) } - /// Returns the parameter expression for a wire, if it exists. - fn get(&self, wire: &Wire) -> Option<&String> { - self.parameters.get(wire) + /// Return the output wires of a node that have an associated pytket [`RegisterCount`]. + #[allow(clippy::type_complexity)] + fn node_output_values( + &self, + node: H::Node, + circ: &Circuit, + ) -> Result, RegisterCount)>, Tk1ConvertError> { + let op = circ.hugr().get_optype(node); + let signature = op.dataflow_signature(); + let static_output = op.static_output_port(); + let other_output = op.other_output_port(); + let mut wire_counts = Vec::with_capacity(circ.hugr().num_outputs(node)); + for out_port in circ.hugr().node_outputs(node) { + let ty = if Some(out_port) == other_output { + // Ignore order edges + continue; + } else if Some(out_port) == static_output { + let EdgeKind::Const(ty) = op.static_output().unwrap() else { + return Err(Tk1ConvertError::custom(format!( + "Cannot emit a static output for a {op}." + ))); + }; + ty + } else { + let Some(ty) = signature + .as_ref() + .and_then(|s| s.out_port_type(out_port).cloned()) + else { + return Err(Tk1ConvertError::custom( + "Cannot emit a transparent node without a dataflow signature.", + )); + }; + ty + }; + + let wire = hugr::Wire::new(node, out_port); + let Some(count) = self.config().type_to_pytket(&ty)? else { + return Err(Tk1ConvertError::custom(format!( + "Found an unsupported type while encoding a {op}." + ))); + }; + wire_counts.push((wire, count)); + } + Ok(wire_counts) } } -/// A utility class for finding new unused qubit/bit names. -#[derive(Debug, Clone, Default)] -struct RegisterUnitGenerator { - /// The next index to use for a new register. - next_unit: u16, - /// The register name to use. - register: String, +/// Input passed to the output parameter computation methods in the emitting +/// functions of [`Tk1EncoderContext`]. +#[derive(Clone, Copy, Debug)] +pub struct OutputParamArgs<'a> { + /// The expected number of output parameters. + pub expected_count: usize, + /// The list of input parameter expressions. + pub input_params: &'a [String], } -impl RegisterUnitGenerator { - /// Create a new [`RegisterUnitGenerator`] - /// - /// Scans the set of existing registers to find the last used index, and - /// starts generating new unit names from there. - pub fn new<'a>( - register: impl ToString, - existing: impl IntoIterator, - ) -> Self { - let register = register.to_string(); - let mut last_unit: Option = None; - for reg in existing { - if reg.0 != register { - continue; - } - last_unit = Some(last_unit.unwrap_or_default().max(reg.1[0] as u16)); - } - RegisterUnitGenerator { - register, - next_unit: last_unit.map_or(0, |i| i + 1), - } - } - - /// Returns a fresh register unit. - pub fn next(&mut self) -> RegisterUnit { - let unit = self.next_unit; - self.next_unit += 1; - RegisterUnit(self.register.clone(), vec![unit as i64]) - } +/// Initialize a tket1 [Operation](circuit_json::Operation) to pass to +/// [`Tk1EncoderContext::emit_command`]. +/// +/// ## Arguments +/// - `tk1_optype`: The operation type to use. +/// - `qubit_count`: The number of qubits used by the operation. +/// - `bit_count`: The number of linear bits used by the operation. +/// - `params`: Parameters of the operation, expressed as string expressions. +/// Normally obtained from [`ValueTracker::param_expression`]. +pub fn make_tk1_operation( + tk1_optype: tket_json_rs::OpType, + qubit_count: usize, + bit_count: usize, + params: &[String], +) -> circuit_json::Operation { + let mut op = circuit_json::Operation::default(); + op.op_type = tk1_optype; + op.n_qb = Some(qubit_count as u32); + op.params = match params.is_empty() { + false => Some(params.to_owned()), + true => None, + }; + op.signature = Some([vec!["Q".into(); qubit_count], vec!["B".into(); bit_count]].concat()); + op } diff --git a/tket2/src/serialize/pytket/encoder/config.rs b/tket2/src/serialize/pytket/encoder/config.rs new file mode 100644 index 000000000..83e050e6c --- /dev/null +++ b/tket2/src/serialize/pytket/encoder/config.rs @@ -0,0 +1,252 @@ +//! Configuration for converting [`Circuit`]s into +//! [`tket_json_rs::circuit_json::SerialCircuit`] +//! +//! A configuration struct contains a list of custom emitters that define +//! translations of HUGR operations and types into pytket primitives. + +use std::collections::{BTreeSet, HashMap, VecDeque}; + +use hugr::extension::{ExtensionId, ExtensionSet}; +use hugr::ops::{ExtensionOp, Value}; +use hugr::types::{Signature, SumType, Type, TypeEnum}; + +use crate::serialize::pytket::extension::{ + FloatEmitter, PreludeEmitter, RotationEmitter, Tk1Emitter, Tk2Emitter, +}; +use crate::serialize::pytket::{PytketEmitter, Tk1ConvertError}; +use crate::Circuit; + +use super::value_tracker::RegisterCount; +use super::{Tk1EncoderContext, TrackedValues}; +use hugr::extension::prelude::bool_t; +use hugr::HugrView; +use itertools::Itertools; + +/// Default encoder configuration for [`Circuit`]s. +/// +/// Contains emitters for std and tket2 operations. +pub fn default_encoder_config() -> Tk1EncoderConfig { + let mut config = Tk1EncoderConfig::new(); + config.add_emitter(PreludeEmitter); + config.add_emitter(FloatEmitter); + config.add_emitter(RotationEmitter); + config.add_emitter(Tk1Emitter); + config.add_emitter(Tk2Emitter); + config +} + +/// Configuration for converting [`Circuit`] into +/// [`tket_json_rs::circuit_json::SerialCircuit`]. +/// +/// Contains custom emitters that define translations for HUGR operations, +/// types, and consts into pytket primitives. +#[derive(derive_more::Debug)] +#[debug(bounds(H: HugrView))] +pub struct Tk1EncoderConfig { + /// Operation emitters + #[debug(skip)] + pub(super) emitters: Vec>>, + /// Pre-computed map from extension ids to corresponding emitters in + /// `emitters`, identified by their index. + #[debug("{:?}", extension_emitters.keys().collect_vec())] + extension_emitters: HashMap>, + /// Emitters that request to be called for all operations. + no_extension_emitters: Vec, +} + +impl Tk1EncoderConfig { + /// Create a new [`Tk1EncoderConfig`] with no encoders. + pub fn new() -> Self { + Self { + emitters: vec![], + extension_emitters: HashMap::new(), + no_extension_emitters: vec![], + } + } + + /// Add an encoder to the configuration. + pub fn add_emitter(&mut self, encoder: impl PytketEmitter + 'static) { + let idx = self.emitters.len(); + + match encoder.extensions() { + Some(extensions) => { + for ext in extensions { + self.extension_emitters.entry(ext).or_default().push(idx); + } + } + // If the encoder does not specify an extension, it will be called + // for all operations. + None => self.no_extension_emitters.push(idx), + } + + self.emitters.push(Box::new(encoder)); + } + + /// List the extensions supported by the encoders. + /// + /// Some encoders may not specify an extension, in which case they will be called + /// for all operations irrespectively of the list returned here. + /// + /// Use [`Tk1EncoderConfig::add_emitter`] to extend this list. + pub fn supported_extensions(&self) -> impl Iterator { + self.extension_emitters.keys() + } + + /// Encode a HUGR operation using the registered custom encoders. + /// + /// Returns `true` if the operation was successfully converted. If that is + pub(super) fn op_to_pytket( + &self, + node: H::Node, + op: &ExtensionOp, + circ: &Circuit, + encoder: &mut Tk1EncoderContext, + ) -> Result> { + let mut result = false; + let extension = op.def().extension_id(); + for enc in self.emitters_for_extension(extension) { + if enc.op_to_pytket(node, op, circ, encoder)? { + result = true; + break; + } + } + Ok(result) + } + + /// Translate a HUGR type into a count of qubits, bits, and parameters, + /// using the registered custom encoders. + /// + /// Only tuple sums, bools, and custom types are supported. + /// Other types will return `None`. + pub fn type_to_pytket( + &self, + typ: &Type, + ) -> Result, Tk1ConvertError> { + match typ.as_type_enum() { + TypeEnum::Sum(sum) => { + if typ == &bool_t() { + return Ok(Some(RegisterCount { + qubits: 0, + bits: 1, + params: 0, + })); + } + if let Some(tuple) = sum.as_tuple() { + let count: Result, Tk1ConvertError> = tuple + .iter() + .map(|ty| { + match ty.clone().try_into() { + Ok(ty) => Ok(self.type_to_pytket(&ty)?), + // Sum types with row variables (variable tuple lengths) are not supported. + Err(_) => Ok(None), + } + }) + .sum(); + return count; + } + } + TypeEnum::Extension(custom) => { + let type_ext = custom.extension(); + for encoder in self.emitters_for_extension(type_ext) { + if let Some(count) = encoder.type_to_pytket(custom)? { + return Ok(Some(count)); + } + } + } + _ => {} + } + Ok(None) + } + + /// Encode a const value into the pytket context using the registered custom + /// encoders. + /// + /// Returns the values associated to the loaded constant, or `None` if the + /// constant could not be encoded. + pub(super) fn const_to_pytket( + &self, + value: &Value, + encoder: &mut Tk1EncoderContext, + ) -> Result, Tk1ConvertError> { + let mut values = TrackedValues::default(); + let mut queue = VecDeque::from([value]); + while let Some(value) = queue.pop_front() { + match value { + Value::Sum(sum) => { + if sum.sum_type == SumType::new_unary(2) { + //TODO: Add a bit and sets its value based on sum.tag + return Ok(None); + } + if sum.sum_type.as_tuple().is_some() { + for v in sum.values.iter() { + queue.push_back(v); + } + } + } + Value::Extension { e: opaque } => { + // Collect all extensions required to define the type. + let typ = opaque.value().get_type(); + // TODO: Use Type::used_extensions once it gets published. + // See + let type_exts = Signature::new(vec![], vec![typ]) + .used_extensions() + .unwrap_or_else(|e| { + panic!("Tried to encode a type with partially extension. {e}"); + }); + let exts_set = ExtensionSet::from_iter(type_exts.ids().cloned()); + + let mut encoded = false; + for e in self.emitters_for_extensions(&exts_set) { + if let Some(vs) = e.const_to_pytket(opaque, encoder)? { + values.append(vs); + encoded = true; + break; + } + } + if !encoded { + return Ok(None); + } + } + _ => return Ok(None), + } + } + Ok(Some(values)) + } + + /// Lists the emitters that can handle a given extension. + fn emitters_for_extension( + &self, + ext: &ExtensionId, + ) -> impl Iterator>> { + self.extension_emitters + .get(ext) + .into_iter() + .flatten() + .chain(self.no_extension_emitters.iter()) + .map(move |idx| &self.emitters[*idx]) + } + + /// Lists the emitters that can handle a given set of extensions. + fn emitters_for_extensions( + &self, + exts: &ExtensionSet, + ) -> impl Iterator>> { + let emitter_ids: BTreeSet = exts + .iter() + .flat_map(|ext| self.extension_emitters.get(ext).into_iter().flatten()) + .chain(self.no_extension_emitters.iter()) + .copied() + .collect(); + emitter_ids.into_iter().map(move |idx| &self.emitters[idx]) + } +} + +impl Default for Tk1EncoderConfig { + fn default() -> Self { + Self { + emitters: Default::default(), + extension_emitters: Default::default(), + no_extension_emitters: Default::default(), + } + } +} diff --git a/tket2/src/serialize/pytket/encoder/unit_generator.rs b/tket2/src/serialize/pytket/encoder/unit_generator.rs new file mode 100644 index 000000000..2c736b9e6 --- /dev/null +++ b/tket2/src/serialize/pytket/encoder/unit_generator.rs @@ -0,0 +1,44 @@ +//! This module contains the [`RegisterUnitGenerator`] struct, which is used to +//! generate fresh pytket register names for qubits and bits in a circuit. + +use tket_json_rs::register::ElementId as RegisterUnit; + +/// A utility class for finding new unused qubit/bit names. +#[derive(Debug, Clone, Default)] +pub struct RegisterUnitGenerator { + /// The next index to use for a new register. + next_unit: u16, + /// The register name to use. + register: String, +} + +impl RegisterUnitGenerator { + /// Create a new [`RegisterUnitGenerator`] + /// + /// Scans the set of existing registers to find the last used index, and + /// starts generating new unit names from there. + pub fn new<'a>( + register: impl ToString, + existing: impl IntoIterator, + ) -> Self { + let register = register.to_string(); + let mut last_unit: Option = None; + for reg in existing { + if reg.0 != register { + continue; + } + last_unit = Some(last_unit.unwrap_or_default().max(reg.1[0] as u16)); + } + RegisterUnitGenerator { + register, + next_unit: last_unit.map_or(0, |i| i + 1), + } + } + + /// Returns a fresh register unit. + pub fn next(&mut self) -> RegisterUnit { + let unit = self.next_unit; + self.next_unit += 1; + RegisterUnit(self.register.clone(), vec![unit as i64]) + } +} diff --git a/tket2/src/serialize/pytket/encoder/unsupported_tracker.rs b/tket2/src/serialize/pytket/encoder/unsupported_tracker.rs new file mode 100644 index 000000000..62133097e --- /dev/null +++ b/tket2/src/serialize/pytket/encoder/unsupported_tracker.rs @@ -0,0 +1,119 @@ +//! Tracking of subgraphs of unsupported nodes in the hugr. + +use std::collections::{BTreeSet, HashMap}; + +use hugr::core::HugrNode; +use hugr::HugrView; +use petgraph::unionfind::UnionFind; + +use crate::Circuit; + +/// A structure for tracking nodes in the hugr that cannot be encoded as TKET1 +/// operations. +/// +/// The nodes are accumulated in connected components of the hugr. When they +/// cannot be grown further, each component is encoded as a single TKET1 barrier +/// containing the unsupported operations as metadata. +#[derive(Debug, Clone)] +pub struct UnsupportedTracker { + /// Unsupported nodes in the hugr. + /// + /// Stores the index of each node in [`Self::components`]. + /// + /// Once a node has been extracted, it is removed from this map. + nodes: HashMap, + /// A UnionFind structure for tracking connected components of `Self::nodes`. + components: UnionFind, +} + +/// The connected component ID of a node. +/// +/// Multiple IDs may be merged together via the union-find structure in +/// [`UnsupportedTracker::components`]. +type ComponentId = usize; + +#[derive(Debug, Clone, Copy, Default)] +struct UnsupportedNode { + /// The index of the node in [`UnsupportedTracker::components`]. + component: ComponentId, +} + +impl UnsupportedTracker { + /// Create a new [`UnsupportedTracker`]. + pub fn new(_circ: &Circuit) -> Self { + Self { + nodes: HashMap::new(), + components: UnionFind::new_empty(), + } + } + + /// Returns `true` if the node is tracked as unsupported. + pub fn is_unsupported(&self, node: N) -> bool { + self.nodes.contains_key(&node) + } + + /// Record an unsupported node in the hugr. + pub fn record_node(&mut self, node: N, circ: &Circuit>) { + let node_data = UnsupportedNode { + component: self.components.new_set(), + }; + self.nodes.insert(node, node_data); + + // Take the union of the component with any currently tracked incoming + // neighbour. + for neighbour in circ.hugr().input_neighbours(node) { + if let Some(neigh_data) = self.nodes.get(&neighbour) { + self.components + .union(neigh_data.component, node_data.component); + } + } + } + + /// Returns the connected component of a node, and marks its elements as + /// extracted. + /// + /// Once a component has been extracted, no new nodes can be added to it and + /// calling [`UnsupportedTracker::record_node`] will use a new component + /// instead. + pub fn extract_component(&mut self, node: N) -> BTreeSet { + let node_data = self.nodes.remove(&node).unwrap(); + let component = node_data.component; + let representative = self.components.find_mut(component); + + // Compute the nodes in the component, and mark them as extracted. + // + // TODO: Implement efficient iteration over the nodes in a component on petgraph, + // and use it here. For now we just traverse all unextracted nodes. + let mut nodes = BTreeSet::new(); + nodes.insert(node); + for (&n, data) in &self.nodes { + if self.components.find_mut(data.component) == representative { + nodes.insert(n); + } + } + for n in &nodes { + self.nodes.remove(n); + } + + nodes + } + + /// Returns an iterator over the unextracted nodes in the tracker. + pub fn iter(&self) -> impl Iterator + '_ { + self.nodes.keys().copied() + } + + /// Returns `true` if there are no unextracted components in the tracker. + pub fn is_empty(&self) -> bool { + self.nodes.is_empty() + } +} + +impl Default for UnsupportedTracker { + fn default() -> Self { + Self { + nodes: HashMap::new(), + components: UnionFind::new_empty(), + } + } +} diff --git a/tket2/src/serialize/pytket/encoder/value_tracker.rs b/tket2/src/serialize/pytket/encoder/value_tracker.rs new file mode 100644 index 000000000..16d543f94 --- /dev/null +++ b/tket2/src/serialize/pytket/encoder/value_tracker.rs @@ -0,0 +1,666 @@ +//! Tracker for pytket values associated to wires in a hugr being encoded. +//! +//! Values can be qubits or bits (identified by a [`tket_json_rs::register::ElementId`]), +//! or a string-encoded parameter expression. +//! +//! Wires in the hugr may be associated with multiple values. +//! Qubit and bit wires map to a single register element, and float/rotation wires map to a string parameter. +//! But custom operations (e.g. arrays / sums) may map to multiple things. +//! +//! Extensions can define which elements they map to + +use std::borrow::Cow; +use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; + +use hugr::core::HugrNode; +use hugr::{HugrView, Wire}; +use itertools::Itertools; +use tket_json_rs::circuit_json; +use tket_json_rs::register::ElementId as RegisterUnit; + +use crate::circuit::Circuit; +use crate::serialize::pytket::{ + OpConvertError, RegisterHash, Tk1ConvertError, METADATA_B_REGISTERS, METADATA_INPUT_PARAMETERS, +}; + +use super::unit_generator::RegisterUnitGenerator; +use super::{ + Tk1EncoderConfig, METADATA_B_OUTPUT_REGISTERS, METADATA_Q_OUTPUT_REGISTERS, + METADATA_Q_REGISTERS, +}; + +/// A structure for tracking qubits used in the circuit being encoded. +/// +/// Nb: Although `tket-json-rs` has a "Register" struct, it's actually +/// an identifier for single qubits/bits in the `Register::0` register. +/// We rename it to `RegisterUnit` here to avoid confusion. +#[derive(derive_more::Debug, Clone)] +#[debug(bounds(N: std::fmt::Debug))] +pub struct ValueTracker { + /// List of generated qubit register names. + qubits: Vec, + /// List of generated bit register names. + bits: Vec, + /// List of seen parameters. + params: Vec, + + /// The tracked data for a wire in the hugr. + /// + /// Contains an ordered list of values associated with it, + /// and a counter of unexplored neighbours used to prune the map + /// once the wire is fully explored. + wires: BTreeMap, TrackedWire>, + + /// A fixed order for the output qubits. This is typically used by tket1 to + /// define implicit qubit permutations at the end of the circuit. + /// + /// When a circuit gets decoded from pytket, we store the order in a + /// [`METADATA_Q_OUTPUT_REGISTERS`] metadata entry. + output_qubits: Vec, + /// A fixed order for the output qubits. This is typically used by tket1 to + /// define implicit qubit permutations at the end of the circuit. + /// + /// When a circuit gets decoded from pytket, we store the order in a + /// [`METADATA_B_OUTPUT_REGISTERS`] metadata entry. + #[allow(unused)] + output_bits: Vec, + + /// Qubits in `qubits` that are not currently registered to any wire. + /// + /// We draw names from here when a new qubit name is needed, before + /// resorting to the `qubit_reg_generator`. + unused_qubits: BTreeSet, + /// Bits in `bits` that are not currently registered to any wire. + /// + /// We draw names from here when a new bit name is needed, before + /// resorting to the `bit_reg_generator`. + unused_bits: BTreeSet, + + /// A generator of new registers units to use for qubit wires. + qubit_reg_generator: RegisterUnitGenerator, + /// A generator of new registers units to use for bit wires. + bit_reg_generator: RegisterUnitGenerator, +} + +/// A lightweight identifier for a qubit value. +/// +/// Contains an index into the `qubits` array of [`ValueTracker`]. +#[derive( + Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default, derive_more::Display, +)] +#[display("qubit#{}", self.0)] +pub struct TrackedQubit(usize); + +/// A lightweight identifier for a bit value. +/// +/// Contains an index into the `bits` array of [`ValueTracker`]. +#[derive( + Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default, derive_more::Display, +)] +#[display("bit#{}", self.0)] +pub struct TrackedBit(usize); + +/// A lightweight identifier for a parameter value. +/// +/// Contains an index into the `params` array of [`ValueTracker`]. +#[derive( + Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default, derive_more::Display, +)] +#[display("param#{}", self.0)] +pub struct TrackedParam(usize); + +/// A lightweight identifier for a qubit/bit/parameter value. +/// +/// Contains an index into the corresponding value array in [`ValueTracker`]. +#[derive( + Debug, + Clone, + Copy, + PartialEq, + Eq, + PartialOrd, + Ord, + Hash, + derive_more::From, + derive_more::Display, +)] +#[non_exhaustive] +pub enum TrackedValue { + /// A qubit value. + /// + /// Index into the `qubits` array of [`ValueTracker`]. + Qubit(TrackedQubit), + /// A bit value. + /// + /// Index into the `bits` array of [`ValueTracker`]. + Bit(TrackedBit), + /// A parameter value. + /// + /// Index into the `params` array of [`ValueTracker`]. + Param(TrackedParam), +} + +/// Lists of tracked values, separated by type. +#[derive(Debug, Clone, Default, PartialEq, Eq, Hash)] +#[non_exhaustive] +pub struct TrackedValues { + /// Tracked qubit values. + pub qubits: Vec, + /// Tracked bit values. + pub bits: Vec, + /// Tracked parameter values. + pub params: Vec, +} + +/// Data associated with a tracked wire in the hugr. +#[derive(Debug, Clone)] +struct TrackedWire { + /// The values associated with the wire. + /// + /// This is a list of [`TrackedValue`]s, which can be qubits, bits, or + /// parameters. + /// + /// If the wire type was not translatable to pytket values, this attribute + /// will be `None`. + pub(self) values: Option>, + /// The number of unexplored neighbours of the wire. + /// + /// This is used to prune the [`ValueTracker::wires`] map once the wire is + /// fully explored. + pub(self) unexplored_neighbours: usize, +} + +/// A count of pytket qubits, bits, and sympy parameters. +/// +/// Used as return value for [`TrackedValues::count`]. +#[derive( + Clone, + Copy, + PartialEq, + Eq, + Hash, + Debug, + Default, + derive_more::Display, + derive_more::Add, + derive_more::Sub, + derive_more::Sum, +)] +#[display("{qubits} qubits, {bits} bits, {params} parameters")] +#[non_exhaustive] +pub struct RegisterCount { + /// Amount of qubits. + pub qubits: usize, + /// Amount of bits. + pub bits: usize, + /// Amount of sympy parameters. + pub params: usize, +} + +/// The result finalizing the value tracker. +/// +/// Contains the final list of qubit and bit registers, and the implicit +/// permutation of the output registers. +#[derive(Debug, Clone)] +pub struct ValueTrackerResult { + /// The final list of qubit registers. + pub qubits: Vec, + /// The final list of bit registers. + pub bits: Vec, + /// The implicit permutation of the qubit registers. + pub qubit_permutation: Vec, +} + +impl ValueTracker { + /// Create a new [`ValueTracker`] from the inputs of a [`Circuit`]. + /// + /// Reads a number of metadata values from the circuit root node, if present, to preserve information on circuits produced by + /// decoding a pytket circuit: + /// + /// - `METADATA_Q_REGISTERS`: The qubit input register names. + /// - `METADATA_Q_OUTPUT_REGISTERS`: The reordered qubit output register names. + /// - `METADATA_B_REGISTERS`: The bit input register names. + /// - `METADATA_B_OUTPUT_REGISTERS`: The reordered bit output register names. + /// - `METADATA_INPUT_PARAMETERS`: The input parameter names. + /// + pub(super) fn new>( + circ: &Circuit, + config: &Tk1EncoderConfig, + ) -> Result> { + let param_variable_names: Vec = + read_metadata_json_list(circ, METADATA_INPUT_PARAMETERS); + let mut tracker = ValueTracker { + qubits: read_metadata_json_list(circ, METADATA_Q_REGISTERS), + bits: read_metadata_json_list(circ, METADATA_B_REGISTERS), + params: Vec::with_capacity(param_variable_names.len()), + wires: BTreeMap::new(), + output_qubits: read_metadata_json_list(circ, METADATA_Q_OUTPUT_REGISTERS), + output_bits: read_metadata_json_list(circ, METADATA_B_OUTPUT_REGISTERS), + unused_qubits: BTreeSet::new(), + unused_bits: BTreeSet::new(), + qubit_reg_generator: RegisterUnitGenerator::default(), + bit_reg_generator: RegisterUnitGenerator::default(), + }; + + if !tracker.output_qubits.is_empty() { + let inputs: HashSet<_> = tracker.qubits.iter().cloned().collect(); + for q in &tracker.output_qubits { + if !inputs.contains(q) { + tracker.qubits.push(q.clone()); + } + } + } + tracker.unused_qubits = (0..tracker.qubits.len()).map(TrackedQubit).collect(); + tracker.unused_bits = (0..tracker.bits.len()).map(TrackedBit).collect(); + tracker.qubit_reg_generator = RegisterUnitGenerator::new("q", tracker.qubits.iter()); + tracker.bit_reg_generator = RegisterUnitGenerator::new("c", tracker.bits.iter()); + + // Generator of input parameter variable names. + let existing_param_vars: HashSet = param_variable_names.iter().cloned().collect(); + let mut param_gen = param_variable_names.into_iter().chain( + (0..) + .map(|i| format!("f{}", i)) + .filter(|name| !existing_param_vars.contains(name)), + ); + + // Register the circuit's inputs with the tracker. + let inp_node = circ.input_node(); + let signature = circ.circuit_signature(); + for (port, typ) in circ + .hugr() + .node_outputs(inp_node) + .zip(signature.input().iter()) + { + let wire = Wire::new(inp_node, port); + let Some(count) = config.type_to_pytket(typ)? else { + // If the input has a non-serializable type, it gets skipped. + // + // TODO: We should store the original signature somewhere in the circuit, + // so it can be reconstructed later. + tracker.register_wire::(wire, [], circ)?; + continue; + }; + + let mut wire_values = Vec::with_capacity(count.total()); + for _ in 0..count.qubits { + let qb = tracker.new_qubit(); + wire_values.push(TrackedValue::Qubit(qb)); + } + for _ in 0..count.bits { + let bit = tracker.new_bit(); + wire_values.push(TrackedValue::Bit(bit)); + } + for _ in 0..count.params { + let param = tracker.new_param(param_gen.next().unwrap()); + wire_values.push(TrackedValue::Param(param)); + } + + tracker.register_wire(wire, wire_values, circ)?; + } + + Ok(tracker) + } + + /// Create a new qubit register name. + /// + /// Picks unused names from the `qubits` list, if available, or generates + /// a new one with the internal generator. + pub fn new_qubit(&mut self) -> TrackedQubit { + self.unused_qubits.pop_first().unwrap_or_else(|| { + self.qubits.push(self.qubit_reg_generator.next()); + TrackedQubit(self.qubits.len() - 1) + }) + } + + /// Create a new bit register name. + /// + /// Picks unused names from the `bits` list, if available, or generates + /// a new one with the internal generator. + pub fn new_bit(&mut self) -> TrackedBit { + self.unused_bits.pop_first().unwrap_or_else(|| { + self.bits.push(self.bit_reg_generator.next()); + TrackedBit(self.bits.len() - 1) + }) + } + + /// Register a new parameter string expression. + /// + /// Returns a unique identifier for the expression. + pub fn new_param(&mut self, expression: impl ToString) -> TrackedParam { + self.params.push(expression.to_string()); + TrackedParam(self.params.len() - 1) + } + + /// Associate a list of values with a wire. + /// + /// Linear qubit IDs can be reused to mark the new position of the qubit in the + /// circuit. + /// Bit types are not linear, so each [`TrackedBit`] is associated with a unique bit + /// state in the circuit. The IDs may only be reused when no more users of the bit are + /// present in the circuit. + /// + /// ### Panics + /// + /// If the wire is already associated with a different set of values. + pub fn register_wire>( + &mut self, + wire: Wire, + values: impl IntoIterator, + circ: &Circuit>, + ) -> Result<(), OpConvertError> { + let values = values.into_iter().map(|v| v.into()).collect_vec(); + + // Remove any qubit/bit used here from the unused set. + for value in &values { + match value { + TrackedValue::Qubit(qb) => { + self.unused_qubits.remove(qb); + } + TrackedValue::Bit(bit) => { + self.unused_bits.remove(bit); + } + TrackedValue::Param(_) => {} + } + } + + let unexplored_neighbours = circ.hugr().linked_ports(wire.node(), wire.source()).count(); + let tracked = TrackedWire { + values: Some(values), + unexplored_neighbours, + }; + if self.wires.insert(wire, tracked).is_some() { + return Err(OpConvertError::WireAlreadyHasValues { wire }); + } + + if unexplored_neighbours == 0 { + // We can unregister the wire immediately, since it has no unexplored + // neighbours. This will free up the qubit and bit registers associated with it. + self.unregister_wire(wire) + .expect("Wire should be registered in the tracker"); + } + + Ok(()) + } + + /// Returns the values associated with a wire. + /// + /// Marks the port connection as explored. When all ports connected to the wire + /// are explored, the wire is removed from the tracker. + /// + /// To avoid this use `peek_wire_values` instead. + /// + /// Returns `None` if the wire did not have any values associated with it, + /// or if it had a type that cannot be translated into pytket values. + pub(super) fn wire_values(&mut self, wire: Wire) -> Option> { + let values = self.wires.get(&wire)?; + if values.unexplored_neighbours != 1 { + let wire = self.wires.get_mut(&wire).unwrap(); + wire.unexplored_neighbours -= 1; + let values = wire.values.as_ref()?; + return Some(Cow::Borrowed(values)); + } + let values = self.unregister_wire(wire)?; + Some(Cow::Owned(values)) + } + + /// Returns the values associated with a wire. + /// + /// The wire is not marked as explored. To improve performance, make sure to call + /// [`ValueTracker::wire_values`] once per wire connection. + /// + /// Returns `None` if the wire did not have any values associated with it, + /// or if it had a type that cannot be translated into pytket values. + pub(super) fn peek_wire_values(&self, wire: Wire) -> Option<&[TrackedValue]> { + let wire = self.wires.get(&wire)?; + let values = wire.values.as_ref()?; + Some(&values[..]) + } + + /// Unregister a wire, freeing up the qubit and bit registers associated with it. + /// + /// Panics if the wire is not registered. + fn unregister_wire(&mut self, wire: Wire) -> Option> { + let wire = self.wires.remove(&wire).unwrap(); + let values = wire.values?; + + // Free up the qubit and bit registers associated with the wire. + for value in &values { + match value { + TrackedValue::Qubit(qb) => { + self.unused_qubits.insert(*qb); + } + TrackedValue::Bit(bit) => { + self.unused_bits.insert(*bit); + } + TrackedValue::Param(_) => {} + } + } + + Some(values) + } + + /// Returns the qubit register associated with a qubit value. + pub fn qubit_register(&self, qb: TrackedQubit) -> &RegisterUnit { + &self.qubits[qb.0] + } + + /// Returns the bit register associated with a bit value. + pub fn bit_register(&self, bit: TrackedBit) -> &RegisterUnit { + &self.bits[bit.0] + } + + /// Returns the string-encoded parameter expression associated with a parameter value. + pub fn param_expression(&self, param: TrackedParam) -> &str { + &self.params[param.0] + } + + /// Finish the tracker and return the final list of qubit and bit registers. + /// + /// Looks at the circuit's output node to determine the final order of output. + pub(super) fn finish( + self, + circ: &Circuit>, + ) -> Result> { + // Ordered list of qubits and bits at the output of the circuit. + let mut qubit_outputs = Vec::with_capacity(self.qubits.len() - self.unused_qubits.len()); + let mut bit_outputs = Vec::with_capacity(self.bits.len() - self.unused_bits.len()); + for (node, port) in circ.hugr().all_linked_outputs(circ.output_node()) { + let wire = Wire::new(node, port); + let values = self + .peek_wire_values(wire) + .ok_or_else(|| OpConvertError::WireHasNoValues { wire })?; + for value in values { + match value { + TrackedValue::Qubit(qb) => qubit_outputs.push(self.qubit_register(*qb).clone()), + TrackedValue::Bit(bit) => bit_outputs.push(self.bit_register(*bit).clone()), + TrackedValue::Param(_) => { + // Parameters are not part of a pytket circuit output. + // We ignore them here. + } + } + } + } + + // Ensure that all original outputs are present in the pytket circuit. + if qubit_outputs.len() < self.output_qubits.len() { + let qbs = self + .unused_qubits + .iter() + .take(self.output_qubits.len() - qubit_outputs.len()) + .map(|&qb| self.qubit_register(qb).clone()); + qubit_outputs.extend(qbs); + } + + // Compute the final register permutations. + let (qubit_outputs, qubit_permutation) = + compute_final_permutation(qubit_outputs, &self.qubits, &self.output_qubits); + + Ok(ValueTrackerResult { + qubits: qubit_outputs, + bits: bit_outputs, + qubit_permutation, + }) + } +} + +impl TrackedValues { + /// Returns the number of qubits, bits, and parameters in the list. + pub fn count(&self) -> RegisterCount { + RegisterCount::new(self.qubits.len(), self.bits.len(), self.params.len()) + } + + /// Iterate over the values in the list. + pub fn iter(&self) -> impl Iterator + '_ { + self.qubits + .iter() + .map(|&qb| TrackedValue::Qubit(qb)) + .chain(self.bits.iter().map(|&bit| TrackedValue::Bit(bit))) + .chain(self.params.iter().map(|¶m| TrackedValue::Param(param))) + } + + /// Append tracked values to the list. + pub fn append(&mut self, other: TrackedValues) { + self.qubits.extend(other.qubits); + self.bits.extend(other.bits); + self.params.extend(other.params); + } +} + +impl IntoIterator for TrackedValues { + type Item = TrackedValue; + + type IntoIter = std::iter::Chain< + std::iter::Chain< + itertools::MapInto, TrackedValue>, + itertools::MapInto, TrackedValue>, + >, + itertools::MapInto, TrackedValue>, + >; + + fn into_iter(self) -> Self::IntoIter { + self.qubits + .into_iter() + .map_into() + .chain(self.bits.into_iter().map_into()) + .chain(self.params.into_iter().map_into()) + } +} + +impl RegisterCount { + /// Create a new [`RegisterCount`] from the number of qubits, bits, and parameters. + pub const fn new(qubits: usize, bits: usize, params: usize) -> Self { + RegisterCount { + qubits, + bits, + params, + } + } + + /// Create a new [`RegisterCount`] containing only qubits. + pub const fn only_qubits(qubits: usize) -> Self { + RegisterCount { + qubits, + bits: 0, + params: 0, + } + } + + /// Create a new [`RegisterCount`] containing only bits. + pub const fn only_bits(bits: usize) -> Self { + RegisterCount { + qubits: 0, + bits, + params: 0, + } + } + + /// Create a new [`RegisterCount`] containing only parameters. + pub const fn only_params(params: usize) -> Self { + RegisterCount { + qubits: 0, + bits: 0, + params, + } + } + + /// Returns the number of qubits, bits, and parameters associated with the wire. + pub const fn total(&self) -> usize { + self.qubits + self.bits + self.params + } +} + +/// Read a json-encoded vector of values from the circuit's root metadata. +fn read_metadata_json_list( + circ: &Circuit, + metadata_key: &str, +) -> Vec { + let Some(value) = circ.hugr().get_metadata(circ.parent(), metadata_key) else { + return vec![]; + }; + + serde_json::from_value::>(value.clone()).unwrap_or_default() +} + +/// Compute the final unit permutation for a circuit. +/// +/// Arguments: +/// - `all_inputs`: The ordered list of registers declared in the circuit. +/// - `actual_outputs`: The final order of output registers, computed from the +/// wires at the output node of the circuit. +/// - `declared_outputs`: The list of output registers declared at the start of +/// the circuit, potentially in a different order than `declared_inputs`. +/// +/// Returns: +/// - The final list of output registers, including any extra registers +/// discarded mid-circuit. +/// - The final permutation of the output registers. +pub(super) fn compute_final_permutation( + mut actual_outputs: Vec, + all_inputs: &[RegisterUnit], + declared_outputs: &[RegisterUnit], +) -> (Vec, Vec) { + let mut declared_outputs: Vec<&RegisterUnit> = declared_outputs.iter().collect(); + let mut declared_outputs_hashes: HashSet = declared_outputs + .iter() + .map(|®| RegisterHash::from(reg)) + .collect(); + let mut actual_outputs_hashes: HashSet = + actual_outputs.iter().map(RegisterHash::from).collect(); + let mut input_hashes: HashMap = HashMap::default(); + for (i, inp) in all_inputs.iter().enumerate() { + let hash = inp.into(); + input_hashes.insert(hash, i); + // Fix the declared output order of registers. + if !declared_outputs_hashes.contains(&hash) { + declared_outputs.push(inp); + declared_outputs_hashes.insert(hash); + } + } + // Extend `actual_outputs` with extra registers seen in the circuit. + for reg in all_inputs { + let hash = reg.into(); + if !actual_outputs_hashes.contains(&hash) { + actual_outputs.push(reg.clone()); + actual_outputs_hashes.insert(hash); + } + } + + // Compute the final permutation. + // + // For each element `reg` at the output of the circuit, we find its position `i` at the input, + // and find out the pytket output register associated with that position in the `declared_outputs` list. + let permutation = actual_outputs + .iter() + .map(|reg| { + let hash = reg.into(); + let i = input_hashes.get(&hash).unwrap(); + let out = declared_outputs[*i].clone(); + circuit_json::ImplicitPermutation( + tket_json_rs::register::Qubit { id: reg.clone() }, + tket_json_rs::register::Qubit { id: out }, + ) + }) + .collect_vec(); + + (actual_outputs, permutation) +} diff --git a/tket2/src/serialize/pytket/extension.rs b/tket2/src/serialize/pytket/extension.rs new file mode 100644 index 000000000..1c67ad236 --- /dev/null +++ b/tket2/src/serialize/pytket/extension.rs @@ -0,0 +1,98 @@ +//! Extension encoder/decoders for the tket2 <-> `pytket` conversion. +//! +//! To add a new extension encoder, implement the [`PytketEmitter`] trait and add +//! it to the [`Tk1EncoderConfig`](crate::serialize::pytket::Tk1EncoderConfig) +//! used for decoding. +//! +//! This module contains decoders for some common extensions. The +//! [`default_encoder_config`](crate::serialize::pytket::default_encoder_config) +//! creates a configuration with the decoders for the standard library and tket2 +//! extension. + +mod float; +mod prelude; +mod rotation; +mod tk1; +mod tk2; + +pub use float::FloatEmitter; +use hugr::ops::constant::OpaqueValue; +pub use prelude::PreludeEmitter; +pub use rotation::RotationEmitter; +pub use tk1::Tk1Emitter; +pub use tk2::Tk2Emitter; + +pub(crate) use tk1::OpaqueTk1Op; + +use super::encoder::{RegisterCount, TrackedValues}; +use super::Tk1EncoderContext; +use crate::serialize::pytket::Tk1ConvertError; +use crate::Circuit; +use hugr::extension::ExtensionId; +use hugr::ops::ExtensionOp; +use hugr::types::CustomType; +use hugr::HugrView; + +/// An encoder of HUGR operations and types that transforms them into pytket +/// primitives. +/// +/// An [encoder configuration](crate::serialize::pytket::Tk1EncoderConfig) +/// contains a list of such encoders. When encountering a type, operation, or +/// constant in the HUGR being encoded, the configuration will call each of +/// the encoders declaring support for the specific extension sequentially until +/// one of them indicates a successful conversion. +pub trait PytketEmitter { + /// The name of the extension this encoder/decoder is for. + /// + /// [`PytketEmitter::op_to_pytket`] and [`PytketEmitter::type_to_pytket`] will + /// only be called for operations/types of these extensions. + /// + /// If the function returns `None`, the encoder will be called for all + /// operations/types irrespective of their extension. + fn extensions(&self) -> Option>; + + /// Given a node in the HUGR circuit and its operation type, try to convert + /// it to a pytket operation and add it to the pytket encoder. + /// + /// Returns `true` if the operation was successfully converted. If that is + /// the case, no further encoders will be called. + /// + /// If the operation is not supported by the encoder, return `false`. It's + /// important not to modify the `encoder` in this case, as that may + /// invalidate the context for other encoders that may be called afterwards. + fn op_to_pytket( + &self, + node: H::Node, + op: &ExtensionOp, + circ: &Circuit, + encoder: &mut Tk1EncoderContext, + ) -> Result> { + let _ = (node, op, circ, encoder); + Ok(false) + } + + /// Given a HUGR type, return the number of qubits, bits, and parameter + /// expressions of its pytket counterpart. + /// + /// If the type cannot be translated into a list of the aforementioned + /// values, return `None`. Operations dealing with such types will be + /// marked as unsupported and will be serialized as opaque operations. + fn type_to_pytket( + &self, + typ: &CustomType, + ) -> Result, Tk1ConvertError> { + let _ = typ; + Ok(None) + } + + /// Given an opaque constant value, add it to the pytket encoder and return + /// the values to associate to the loaded constant. + fn const_to_pytket( + &self, + value: &OpaqueValue, + encoder: &mut Tk1EncoderContext, + ) -> Result, Tk1ConvertError> { + let _ = (value, encoder); + Ok(None) + } +} diff --git a/tket2/src/serialize/pytket/extension/float.rs b/tket2/src/serialize/pytket/extension/float.rs new file mode 100644 index 000000000..c0c53e07e --- /dev/null +++ b/tket2/src/serialize/pytket/extension/float.rs @@ -0,0 +1,115 @@ +//! Encoder and decoder for floating point operations. + +use super::PytketEmitter; +use crate::serialize::pytket::encoder::{RegisterCount, Tk1EncoderContext, TrackedValues}; +use crate::serialize::pytket::Tk1ConvertError; +use crate::Circuit; +use hugr::extension::simple_op::MakeExtensionOp; +use hugr::extension::ExtensionId; +use hugr::ops::constant::OpaqueValue; +use hugr::ops::ExtensionOp; +use hugr::std_extensions::arithmetic::float_ops::FloatOps; +use hugr::std_extensions::arithmetic::{float_ops, float_types}; +use hugr::HugrView; + +/// Encoder for [prelude](hugr::extension::prelude) operations. +#[derive(Debug, Clone, Default)] +pub struct FloatEmitter; + +impl PytketEmitter for FloatEmitter { + fn extensions(&self) -> Option> { + Some(vec![float_ops::EXTENSION_ID, float_types::EXTENSION_ID]) + } + + fn op_to_pytket( + &self, + node: H::Node, + op: &ExtensionOp, + circ: &Circuit, + encoder: &mut Tk1EncoderContext, + ) -> Result> { + let Ok(rot_op) = FloatOps::from_extension_op(op) else { + return Ok(false); + }; + + match rot_op { + FloatOps::fadd + | FloatOps::fsub + | FloatOps::fneg + | FloatOps::fmul + | FloatOps::fdiv + | FloatOps::fpow + | FloatOps::ffloor + | FloatOps::fceil + | FloatOps::fround + | FloatOps::fmax + | FloatOps::fmin + | FloatOps::fabs => { + encoder.emit_transparent_node(node, circ, |ps| { + match FloatEmitter::encode_rotation_op(&rot_op, ps.input_params) { + Some(s) => vec![s], + None => Vec::new(), + } + })?; + Ok(true) + } + _ => Ok(false), + } + } + + fn type_to_pytket( + &self, + typ: &hugr::types::CustomType, + ) -> Result, Tk1ConvertError<::Node>> { + match typ.name() == &float_types::FLOAT_TYPE_ID { + true => Ok(Some(RegisterCount::only_params(1))), + false => Ok(None), + } + } + + fn const_to_pytket( + &self, + value: &OpaqueValue, + encoder: &mut Tk1EncoderContext, + ) -> Result, Tk1ConvertError> { + let Some(const_f) = value.value().downcast_ref::() else { + return Ok(None); + }; + + let float = const_f.value(); + // Special case for pi rotations + let val = if float == std::f64::consts::PI { + "pi".to_string() + } else { + float.to_string() + }; + + let param = encoder.values.new_param(val); + + let mut values = TrackedValues::default(); + values.params.push(param); + Ok(Some(values)) + } +} + +impl FloatEmitter { + /// Encode a rotation operation into a pytket param expression. + fn encode_rotation_op(op: &FloatOps, inputs: &[String]) -> Option { + let s = match op { + FloatOps::fadd => format!("({} + {})", inputs[0], inputs[1]), + FloatOps::fsub => format!("({} - {})", inputs[0], inputs[1]), + FloatOps::fneg => format!("(-{})", inputs[0]), + FloatOps::fmul => format!("({} * {})", inputs[0], inputs[1]), + FloatOps::fdiv => format!("({} / {})", inputs[0], inputs[1]), + FloatOps::fpow => format!("({} ** {})", inputs[0], inputs[1]), + FloatOps::ffloor => format!("floor({})", inputs[0]), + FloatOps::fceil => format!("ceil({})", inputs[0]), + FloatOps::fround => format!("round({})", inputs[0]), + FloatOps::fmax => format!("max({}, {})", inputs[0], inputs[1]), + FloatOps::fmin => format!("min({}, {})", inputs[0], inputs[1]), + FloatOps::fabs => format!("abs({})", inputs[0]), + _ => return None, + }; + Some(s) + } +} diff --git a/tket2/src/serialize/pytket/extension/prelude.rs b/tket2/src/serialize/pytket/extension/prelude.rs new file mode 100644 index 000000000..52a034329 --- /dev/null +++ b/tket2/src/serialize/pytket/extension/prelude.rs @@ -0,0 +1,92 @@ +//! Encoder and decoder for tket2 operations with native pytket counterparts. + +use super::PytketEmitter; +use crate::serialize::pytket::encoder::{RegisterCount, Tk1EncoderContext}; +use crate::serialize::pytket::Tk1ConvertError; +use crate::Circuit; +use hugr::extension::prelude::{TupleOpDef, PRELUDE_ID}; +use hugr::extension::simple_op::MakeExtensionOp; +use hugr::extension::ExtensionId; +use hugr::ops::ExtensionOp; +use hugr::types::TypeArg; +use hugr::HugrView; + +/// Encoder for [prelude](hugr::extension::prelude) operations. +#[derive(Debug, Clone, Default)] +pub struct PreludeEmitter; + +impl PytketEmitter for PreludeEmitter { + fn extensions(&self) -> Option> { + Some(vec![PRELUDE_ID]) + } + + fn op_to_pytket( + &self, + node: H::Node, + op: &ExtensionOp, + circ: &Circuit, + encoder: &mut Tk1EncoderContext, + ) -> Result> { + if let Ok(tuple_op) = TupleOpDef::from_extension_op(op) { + return self.tuple_op_to_pytket(node, op, &tuple_op, circ, encoder); + }; + Ok(false) + } + + fn type_to_pytket( + &self, + typ: &hugr::types::CustomType, + ) -> Result, Tk1ConvertError<::Node>> { + match typ.name().as_str() { + "usize" => Ok(Some(RegisterCount::only_bits(64))), + "qubit" => Ok(Some(RegisterCount::only_qubits(1))), + _ => Ok(None), + } + } +} + +impl PreludeEmitter { + /// Encode a prelude tuple operation. + /// + /// These just bundle/unbundle the values of the inputs/outputs. Since + /// pytket types are already flattened, the translation of these is just a + /// no-op. + fn tuple_op_to_pytket( + &self, + node: H::Node, + op: &ExtensionOp, + tuple_op: &TupleOpDef, + circ: &Circuit, + encoder: &mut Tk1EncoderContext, + ) -> Result> { + if !matches!(tuple_op, TupleOpDef::MakeTuple | TupleOpDef::UnpackTuple) { + // Unknown operation + return Ok(false); + }; + + // First, check if we are working with supported types. + // + // If any of the types cannot be translated to a pytket type, we return + // false so the operation is marked as unsupported as a whole. + let args = op.args().first(); + match args { + Some(TypeArg::Sequence { elems }) => { + for arg in elems { + let TypeArg::Type { ty } = arg else { + return Ok(false); + }; + let count = encoder.config().type_to_pytket(ty)?; + if count.is_none() { + return Ok(false); + } + } + } + _ => return Ok(false), + }; + + // Now we can gather all inputs and assign them to the node outputs transparently. + encoder.emit_transparent_node(node, circ, |ps| ps.input_params.to_owned())?; + + Ok(true) + } +} diff --git a/tket2/src/serialize/pytket/extension/rotation.rs b/tket2/src/serialize/pytket/extension/rotation.rs new file mode 100644 index 000000000..aacc89086 --- /dev/null +++ b/tket2/src/serialize/pytket/extension/rotation.rs @@ -0,0 +1,95 @@ +//! Encoder and decoder for rotation operations. + +use super::PytketEmitter; +use crate::extension::rotation::{ + ConstRotation, RotationOp, ROTATION_EXTENSION_ID, ROTATION_TYPE_ID, +}; +use crate::serialize::pytket::encoder::{RegisterCount, Tk1EncoderContext, TrackedValues}; +use crate::serialize::pytket::Tk1ConvertError; +use crate::Circuit; +use hugr::extension::simple_op::MakeExtensionOp; +use hugr::extension::ExtensionId; +use hugr::ops::constant::OpaqueValue; +use hugr::ops::ExtensionOp; +use hugr::HugrView; +use itertools::Itertools; + +/// Encoder for [prelude](hugr::extension::prelude) operations. +#[derive(Debug, Clone, Default)] +pub struct RotationEmitter; + +impl PytketEmitter for RotationEmitter { + fn extensions(&self) -> Option> { + Some(vec![ROTATION_EXTENSION_ID]) + } + + fn op_to_pytket( + &self, + node: H::Node, + op: &ExtensionOp, + circ: &Circuit, + encoder: &mut Tk1EncoderContext, + ) -> Result> { + let Ok(rot_op) = RotationOp::from_extension_op(op) else { + return Ok(false); + }; + + match rot_op { + RotationOp::from_halfturns_unchecked | RotationOp::to_halfturns => { + encoder.emit_transparent_node(node, circ, |ps| vec![ps.input_params[0].clone()])?; + Ok(true) + } + RotationOp::from_halfturns => { + // Unsupported due to having an Option as output. + Ok(false) + } + _ => { + encoder.emit_transparent_node(node, circ, |ps| { + RotationEmitter::encode_rotation_op(&rot_op, ps.input_params) + .into_iter() + .collect_vec() + })?; + Ok(true) + } + } + } + + fn type_to_pytket( + &self, + typ: &hugr::types::CustomType, + ) -> Result, Tk1ConvertError<::Node>> { + match typ.name() == &ROTATION_TYPE_ID { + true => Ok(Some(RegisterCount::only_params(1))), + false => Ok(None), + } + } + + fn const_to_pytket( + &self, + value: &OpaqueValue, + encoder: &mut Tk1EncoderContext, + ) -> Result, Tk1ConvertError> { + let Some(const_f) = value.value().downcast_ref::() else { + return Ok(None); + }; + + let param = encoder.values.new_param(const_f.half_turns()); + + let mut values = TrackedValues::default(); + values.params.push(param); + Ok(Some(values)) + } +} + +impl RotationEmitter { + /// Encode a rotation operation into a pytket param expression. + fn encode_rotation_op(op: &RotationOp, inputs: &[String]) -> Option { + let s = match op { + RotationOp::radd => format!("({} + {})", inputs[0], inputs[1]), + RotationOp::to_halfturns + | RotationOp::from_halfturns_unchecked + | RotationOp::from_halfturns => return None, + }; + Some(s) + } +} diff --git a/tket2/src/serialize/pytket/op/serialised.rs b/tket2/src/serialize/pytket/extension/tk1.rs similarity index 68% rename from tket2/src/serialize/pytket/op/serialised.rs rename to tket2/src/serialize/pytket/extension/tk1.rs index f5e18c5d6..5dbb2ac02 100644 --- a/tket2/src/serialize/pytket/op/serialised.rs +++ b/tket2/src/serialize/pytket/extension/tk1.rs @@ -1,18 +1,61 @@ -//! Wrapper over pytket operations that cannot be represented naturally in tket2. +//! Encoder for pytket operations that cannot be represented naturally in tket2. -use hugr::extension::prelude::{bool_t, qb_t}; +use crate::extension::rotation::rotation_type; +use crate::extension::{TKET1_EXTENSION, TKET1_EXTENSION_ID, TKET1_OP_NAME}; +use crate::serialize::pytket::{Tk1ConvertError, Tk1EncoderContext}; +use crate::Circuit; -use hugr::ops::custom::ExtensionOp; -use hugr::ops::OpType; +use super::PytketEmitter; +use hugr::extension::prelude::{bool_t, qb_t}; +use hugr::extension::ExtensionId; +use hugr::ops::ExtensionOp; use hugr::types::{Signature, TypeArg}; - -use hugr::IncomingPort; -use serde::de::Error; +use hugr::{HugrView, IncomingPort}; use tket_json_rs::circuit_json; -use crate::extension::rotation::rotation_type; -use crate::extension::{TKET1_EXTENSION, TKET1_EXTENSION_ID, TKET1_OP_NAME}; -use crate::serialize::pytket::OpConvertError; +/// Encoder for [TKET1_EXTENSION] operations. +/// +/// That is, operations originating from a pytket circuit that did not have a +/// native HUGR representation and were instead serialized as opaque black-box +/// operations. +#[derive(Debug, Clone, Default)] +pub struct Tk1Emitter; + +impl PytketEmitter for Tk1Emitter { + fn extensions(&self) -> Option> { + Some(vec![TKET1_EXTENSION_ID]) + } + + fn op_to_pytket( + &self, + node: H::Node, + op: &ExtensionOp, + circ: &Circuit, + encoder: &mut Tk1EncoderContext, + ) -> Result> { + if op.qualified_id() != format!("{TKET1_EXTENSION_ID}.{TKET1_OP_NAME}") { + return Ok(false); + } + let Some(TypeArg::String { arg }) = op.args().first() else { + return Err(Tk1ConvertError::custom( + "Opaque TKET1 operation did not have a json-encoded type argument.", + )); + }; + let op: OpaqueTk1Op = serde_json::from_str(arg)?; + + // Most operations map directly to a pytket one. + encoder.emit_node_command( + node, + circ, + // We don't support opaque pytket operations with parameter outputs. + |_args| Vec::new(), + // Emit the pre-defined pytket operation stored in the metadata. + move |_, _, _| op.serialised_op, + )?; + + Ok(true) + } +} /// A serialized operation, containing the operation type and all its attributes. /// @@ -27,7 +70,8 @@ use crate::serialize::pytket::OpConvertError; #[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)] pub struct OpaqueTk1Op { /// Internal operation data. - op: circuit_json::Operation, + #[serde(rename = "op")] + pub serialised_op: circuit_json::Operation, /// Number of qubits declared by the operation. pub num_qubits: usize, /// Number of bits declared by the operation. @@ -36,7 +80,7 @@ pub struct OpaqueTk1Op { /// /// If the input is `None`, the parameter does not use a Hugr port and is /// instead stored purely as metadata for the `Operation`. - param_inputs: Vec>, + pub param_inputs: Vec>, /// The number of non-None inputs in `param_inputs`, corresponding to the /// rotation_type() inputs to the Hugr operation. pub num_params: usize, @@ -58,7 +102,7 @@ impl OpaqueTk1Op { Some([vec!["Q".into(); num_qubits], vec!["B".into(); num_bits]].concat()); } let mut op = Self { - op, + serialised_op: op, num_qubits, num_bits, param_inputs: Vec::new(), @@ -68,38 +112,6 @@ impl OpaqueTk1Op { op } - /// Try to convert a tket2 operation into a `OpaqueTk1Op`. - /// - /// Only succeeds if the operation is a [`CustomOp`] containing a tket1 operation - /// from the [`TKET1_EXTENSION_ID`] extension. Returns `None` if the operation - /// is not a tket1 operation. - /// - /// # Errors - /// - /// Returns an [`OpConvertError`] if the operation is a tket1 operation, but it - /// contains invalid data. - pub fn try_from_tket2(op: &OpType) -> Result, OpConvertError> { - // TODO: Check `extensions.contains(&TKET1_EXTENSION_ID)`? - if op.to_string() != format!("{TKET1_EXTENSION_ID}.{TKET1_OP_NAME}") { - return Ok(None); - } - let OpType::ExtensionOp(custom_op) = op else { - assert!( - !matches!(op, OpType::OpaqueOp(_)), - "Opaque Ops should have been resolved into ExtensionOps" - ); - return Ok(None); - }; - let Some(TypeArg::String { arg }) = custom_op.args().first() else { - return Err(serde_json::Error::custom( - "Opaque TKET1 operation did not have a json-encoded type argument.", - ) - .into()); - }; - let op = serde_json::from_str(arg)?; - Ok(Some(op)) - } - /// Compute the signature of the operation. /// /// The signature returned has `num_qubits` qubit inputs, followed by @@ -121,11 +133,6 @@ impl OpaqueTk1Op { self.param_inputs.iter().filter_map(|&i| i) } - /// Returns the lower level `circuit_json::Operation` contained by this struct. - pub fn serialised_op(&self) -> &circuit_json::Operation { - &self.op - } - /// Wraps the op into a [`TKET1_OP_NAME`] opaque operation. pub fn as_extension_op(&self) -> ExtensionOp { let payload = TypeArg::String { @@ -140,7 +147,7 @@ impl OpaqueTk1Op { /// /// Updates the internal `num_params` and `param_inputs` fields. fn compute_param_fields(&mut self) { - let Some(params) = self.op.params.as_ref() else { + let Some(params) = self.serialised_op.params.as_ref() else { self.param_inputs = vec![]; self.num_params = 0; return; diff --git a/tket2/src/serialize/pytket/extension/tk2.rs b/tket2/src/serialize/pytket/extension/tk2.rs new file mode 100644 index 000000000..74df488ed --- /dev/null +++ b/tket2/src/serialize/pytket/extension/tk2.rs @@ -0,0 +1,110 @@ +//! Encoder and decoder for tket2 operations with native pytket counterparts. + +use super::PytketEmitter; +use crate::extension::sympy::SympyOp; +use crate::extension::TKET2_EXTENSION_ID; +use crate::serialize::pytket::encoder::Tk1EncoderContext; +use crate::serialize::pytket::Tk1ConvertError; +use crate::{Circuit, Tk2Op}; +use hugr::extension::simple_op::MakeExtensionOp; +use hugr::extension::ExtensionId; +use hugr::ops::ExtensionOp; +use hugr::{HugrView, Wire}; +use tket_json_rs::optype::OpType as Tk1OpType; + +/// Encoder for [Tk2Op] operations. +#[derive(Debug, Clone, Default)] +pub struct Tk2Emitter; + +impl PytketEmitter for Tk2Emitter { + fn extensions(&self) -> Option> { + Some(vec![TKET2_EXTENSION_ID]) + } + + fn op_to_pytket( + &self, + node: H::Node, + op: &ExtensionOp, + circ: &Circuit, + encoder: &mut Tk1EncoderContext, + ) -> Result> { + if let Ok(tk2op) = Tk2Op::from_extension_op(op) { + self.encode_tk2_op(node, tk2op, circ, encoder) + } else if let Ok(sympy_op) = SympyOp::from_extension_op(op) { + self.encode_sympy_op(node, sympy_op, circ, encoder) + } else { + Ok(false) + } + } +} + +impl Tk2Emitter { + /// Encode a tket2 operation into a pytket operation. + fn encode_tk2_op( + &self, + node: H::Node, + tk2op: Tk2Op, + circ: &Circuit, + encoder: &mut Tk1EncoderContext, + ) -> Result> { + let serial_op = match tk2op { + Tk2Op::H => Tk1OpType::H, + Tk2Op::CX => Tk1OpType::CX, + Tk2Op::CY => Tk1OpType::CY, + Tk2Op::CZ => Tk1OpType::CZ, + Tk2Op::CRz => Tk1OpType::CRz, + Tk2Op::T => Tk1OpType::T, + Tk2Op::Tdg => Tk1OpType::Tdg, + Tk2Op::S => Tk1OpType::S, + Tk2Op::Sdg => Tk1OpType::Sdg, + Tk2Op::X => Tk1OpType::X, + Tk2Op::Y => Tk1OpType::Y, + Tk2Op::Z => Tk1OpType::Z, + Tk2Op::Rx => Tk1OpType::Rx, + Tk2Op::Rz => Tk1OpType::Rz, + Tk2Op::Ry => Tk1OpType::Ry, + Tk2Op::Toffoli => Tk1OpType::CCX, + Tk2Op::Reset => Tk1OpType::Reset, + Tk2Op::Measure => Tk1OpType::Measure, + // We translate `MeasureFree` the same way as a `Measure` operation. + // Since the node does not have outputs the qubit/bit will simply be ignored, + // but will appear when collecting the final pytket registers. + Tk2Op::MeasureFree => Tk1OpType::Measure, + // These operations are implicitly supported by the encoding, + // they do not create a new command but just modify the value trackers. + Tk2Op::QAlloc => { + let out_port = circ.hugr().node_outputs(node).next().unwrap(); + let wire = Wire::new(node, out_port); + let qb = encoder.values.new_qubit(); + encoder.values.register_wire(wire, [qb], circ)?; + return Ok(true); + } + // Since the qubit still gets connected at the end of the circuit, + // `QFree` is a no-op. + Tk2Op::QFree => { + return Ok(true); + } + // Unsupported + Tk2Op::TryQAlloc => { + return Ok(false); + } + }; + + // Most operations map directly to a pytket one. + encoder.emit_node(serial_op, node, circ)?; + + Ok(true) + } + + /// Encode a tket2 sympy operation into a pytket operation. + fn encode_sympy_op( + &self, + node: H::Node, + sympy_op: SympyOp, + circ: &Circuit, + encoder: &mut Tk1EncoderContext, + ) -> Result> { + encoder.emit_transparent_node(node, circ, |_| vec![sympy_op.expr.clone()])?; + Ok(true) + } +} diff --git a/tket2/src/serialize/pytket/param.rs b/tket2/src/serialize/pytket/param.rs deleted file mode 100644 index 951e7585e..000000000 --- a/tket2/src/serialize/pytket/param.rs +++ /dev/null @@ -1,4 +0,0 @@ -//! Encoding and decoding pytket operation parameters as hugr ops. - -pub mod decode; -pub mod encode; diff --git a/tket2/src/serialize/pytket/param/encode.rs b/tket2/src/serialize/pytket/param/encode.rs deleted file mode 100644 index 76c30e09a..000000000 --- a/tket2/src/serialize/pytket/param/encode.rs +++ /dev/null @@ -1,122 +0,0 @@ -//! Definitions for encoding hugr graphs into pytket op parameters. - -use hugr::ops::{OpType, Value}; -use hugr::std_extensions::arithmetic::float_ops::FloatOps; -use hugr::std_extensions::arithmetic::float_types::ConstF64; - -use crate::extension::rotation::{ConstRotation, RotationOp}; -use crate::extension::sympy::SympyOp; -use crate::ops::match_symb_const_op; - -/// Fold a rotation or float operation into a string, given the string -/// representations of its inputs. -/// -/// The folded op must have a single string output. -/// -/// Returns `None` if the operation cannot be folded. -pub fn fold_param_op(optype: &OpType, inputs: &[&str]) -> Option { - let param = match optype { - OpType::Const(const_op) => { - // New constant, register it if it can be interpreted as a parameter. - try_constant_to_param(const_op.value())? - } - OpType::LoadConstant(_op_type) => { - // Re-use the parameter from the input. - inputs[0].to_string() - } - // Encode some angle and float operations directly as strings using - // the already encoded inputs. Fail if the operation is not - // supported, and let the operation encoding process it instead. - OpType::ExtensionOp(_) => { - if let Some(s) = optype - .cast::() - .and_then(|op| encode_rotation_op(&op, inputs)) - { - s - } else if let Some(s) = optype - .cast::() - .and_then(|op| encode_float_op(&op, inputs)) - { - s - } else if let Some(s) = optype - .cast::() - .and_then(|op| encode_sympy_op(&op, inputs)) - { - s - } else { - return None; - } - } - _ => match_symb_const_op(optype)?.to_string(), - }; - Some(param) -} - -/// Convert a HUGR rotation or float constant to a TKET1 parameter. -/// -/// Angle parameters in TKET1 are encoded as a number of half-turns, -/// whereas HUGR uses radians. -#[inline] -fn try_constant_to_param(val: &Value) -> Option { - if let Some(const_angle) = val.get_custom_value::() { - let half_turns = const_angle.half_turns(); - Some(half_turns.to_string()) - } else if let Some(const_float) = val.get_custom_value::() { - let float = const_float.value(); - - // Special case for pi rotations - if float == std::f64::consts::PI { - Some("pi".to_string()) - } else { - Some(float.to_string()) - } - } else { - None - } -} - -/// Encode an [`RotationOp`]s as a string, given its encoded inputs. -/// -/// `inputs` contains the expressions to compute each input. -fn encode_rotation_op(op: &RotationOp, inputs: &[&str]) -> Option { - let s = match op { - RotationOp::radd => format!("({} + {})", inputs[0], inputs[1]), - // Encode/decode the rotation as pytket parameters, expressed as half-turns. - // Note that the tracked parameter strings are always written in half-turns, - // so the conversion here is a no-op. - RotationOp::to_halfturns => inputs[0].to_string(), - RotationOp::from_halfturns_unchecked => inputs[0].to_string(), - // The checked conversion returns an option, which we do not support. - RotationOp::from_halfturns => return None, - }; - Some(s) -} - -/// Encode an [`FloatOps`] as a string, given its encoded inputs. -fn encode_float_op(op: &FloatOps, inputs: &[&str]) -> Option { - let s = match op { - FloatOps::fadd => format!("({} + {})", inputs[0], inputs[1]), - FloatOps::fsub => format!("({} - {})", inputs[0], inputs[1]), - FloatOps::fneg => format!("(-{})", inputs[0]), - FloatOps::fmul => format!("({} * {})", inputs[0], inputs[1]), - FloatOps::fdiv => format!("({} / {})", inputs[0], inputs[1]), - FloatOps::fpow => format!("({} ** {})", inputs[0], inputs[1]), - FloatOps::ffloor => format!("floor({})", inputs[0]), - FloatOps::fceil => format!("ceil({})", inputs[0]), - FloatOps::fround => format!("round({})", inputs[0]), - FloatOps::fmax => format!("max({}, {})", inputs[0], inputs[1]), - FloatOps::fmin => format!("min({}, {})", inputs[0], inputs[1]), - FloatOps::fabs => format!("abs({})", inputs[0]), - _ => return None, - }; - Some(s) -} - -/// Encode a [`SympyOp`]s as a string. -fn encode_sympy_op(op: &SympyOp, inputs: &[&str]) -> Option { - if !inputs.is_empty() { - return None; - } - - Some(op.expr.clone()) -} diff --git a/tket2/src/serialize/pytket/tests.rs b/tket2/src/serialize/pytket/tests.rs index fc18a3b3a..45e34c7e2 100644 --- a/tket2/src/serialize/pytket/tests.rs +++ b/tket2/src/serialize/pytket/tests.rs @@ -10,12 +10,15 @@ use hugr::hugr::hugrmut::HugrMut; use hugr::std_extensions::arithmetic::float_ops::FloatOps; use hugr::types::Signature; use hugr::HugrView; +use itertools::Itertools; use rstest::{fixture, rstest}; use tket_json_rs::circuit_json::{self, SerialCircuit}; use tket_json_rs::optype; use tket_json_rs::register; -use super::{TKETDecode, METADATA_INPUT_PARAMETERS, METADATA_Q_OUTPUT_REGISTERS}; +use super::{ + TKETDecode, METADATA_INPUT_PARAMETERS, METADATA_Q_OUTPUT_REGISTERS, METADATA_Q_REGISTERS, +}; use crate::circuit::Circuit; use crate::extension::rotation::{rotation_type, ConstRotation, RotationOp}; use crate::extension::sympy::SympyOpDef; @@ -116,22 +119,55 @@ fn compare_serial_circs(a: &SerialCircuit, b: &SerialCircuit) { assert_eq!(a.name, b.name); assert_eq!(a.phase, b.phase); assert_eq!(&a.qubits, &b.qubits); - assert_eq!(&a.bits, &b.bits); assert_eq!(a.commands.len(), b.commands.len()); - // This comparison only works if both serial circuits share a topological - // ordering of commands. + let bits_a: HashSet<_> = a.bits.iter().collect(); + let bits_b: HashSet<_> = b.bits.iter().collect(); + assert_eq!(bits_a, bits_b); + + // We ignore the commands order here, as two encodings may swap + // non-dependant operations. + // + // The correct thing here would be to run a deterministic toposort and + // compare the commands in that order. This is just a quick check that + // everything is present, ignoring wire dependencies. // - // We also cannot compare the arguments directly, since we may permute them - // internally. + // Another problem is that `Command`s cannot be compared directly; + // - `command.op.signature`, and `n_qb` are optional and sometimes + // unset in pytket-generated circs. + // - qubit arguments names may differ if they have been allocated inside the circuit, + // as they depend on the traversal argument. Same with classical params. + // Here we define an ad-hoc subset that can be compared. // // TODO: Do a proper comparison independent of the toposort ordering, and // track register reordering. - for (a, b) in a.commands.iter().zip(b.commands.iter()) { - assert_eq!(a.op.op_type, b.op.op_type); - assert_eq!(a.op.params, b.op.params); - assert_eq!(a.args.len(), b.args.len()); + #[derive(PartialEq, Eq, Hash, Debug)] + struct CommandInfo { + op_type: tket_json_rs::OpType, + params: Vec, + n_args: usize, + } + + impl From<&tket_json_rs::circuit_json::Command> for CommandInfo { + fn from(command: &tket_json_rs::circuit_json::Command) -> Self { + CommandInfo { + op_type: command.op.op_type.clone(), + params: command.op.params.clone().unwrap_or_default(), + n_args: command.args.len(), + } + } + } + + let a_command_count: HashMap = a.commands.iter().map_into().counts(); + let b_command_count: HashMap = b.commands.iter().map_into().counts(); + for (a, &count_a) in &a_command_count { + let count_b = b_command_count.get(a).copied().unwrap_or_default(); + assert_eq!( + count_a, count_b, + "command {a:?} appears {count_a} times in rhs and {count_b} times in lhs" + ); } + assert_eq!(a_command_count.len(), b_command_count.len()); } /// A simple circuit with some preset qubit registers @@ -151,11 +187,17 @@ fn circ_preset_qubits() -> Circuit { let mut hugr = h.finish_hugr_with_outputs([qb0, qb1]).unwrap(); + // A preset register for the first qubit output + hugr.set_metadata( + hugr.entrypoint(), + METADATA_Q_REGISTERS, + serde_json::json!([["q", [2]], ["q", [10]], ["q", [8]]]), + ); // A preset register for the first qubit output hugr.set_metadata( hugr.entrypoint(), METADATA_Q_OUTPUT_REGISTERS, - serde_json::json!([["q", [1]]]), + serde_json::json!([["q", [10]]]), ); hugr.into() @@ -353,14 +395,14 @@ fn circuit_roundtrip(#[case] circ: Circuit, #[case] decoded_sig: Signature) { let deser_sig = deser.circuit_signature(); assert_eq!( - &deser_sig.input, &decoded_sig.input, + &decoded_sig.input, &deser_sig.input, "Input signature mismatch\n Expected: {}\n Actual: {}", - &deser_sig, &decoded_sig + &decoded_sig, &deser_sig ); assert_eq!( - &deser_sig.output, &decoded_sig.output, + &decoded_sig.output, &deser_sig.output, "Output signature mismatch\n Expected: {}\n Actual: {}", - &deser_sig, &decoded_sig + &decoded_sig, &deser_sig ); let reser = SerialCircuit::encode(&deser).unwrap();