1717from bqskit .ext import bqskit_to_qiskit , qiskit_to_bqskit
1818from bqskit .ir .circuit import Circuit
1919from mqt .bench .targets import get_available_device_names , get_device
20- from pytket ._tket .passes import BasePass as TketBasePass # noqa: PLC2701
2120from pytket .circuit import Qubit
2221from pytket .extensions .qiskit import qiskit_to_tk , tk_to_qiskit
2322from qiskit import QuantumCircuit
2726
2827from mqt .predictor .rl .actions import CompilationOrigin , PassType , get_actions_by_pass_type
2928from mqt .predictor .rl .parsing import (
30- PreProcessTKETRoutingAfterQiskitLayout ,
3129 final_layout_bqskit_to_qiskit ,
3230 final_layout_pytket_to_qiskit ,
3331)
3432
3533if TYPE_CHECKING :
34+ from collections .abc import Callable
35+
36+ from pytket ._tket .passes import BasePass as TketBasePass
3637 from qiskit .passmanager .base_tasks import Task
3738 from qiskit .transpiler import Target
3839
3940 from mqt .predictor .rl .actions import Action
41+ from mqt .predictor .rl .parsing import (
42+ PreProcessTKETRoutingAfterQiskitLayout ,
43+ )
4044
4145
4246@pytest .fixture
@@ -58,8 +62,8 @@ def test_bqskit_o2_action(available_actions_dict: dict[PassType, list[Action]])
5862 qc .cx (0 , 1 )
5963
6064 bqskit_qc = qiskit_to_bqskit (qc )
61- assert callable ( action_bqskit_o2 .transpile_pass )
62- bqskit_qc_optimized = action_bqskit_o2 . transpile_pass (bqskit_qc )
65+ factory = cast ( "Callable[[Circuit], Circuit]" , action_bqskit_o2 .transpile_pass )
66+ bqskit_qc_optimized = factory (bqskit_qc )
6367 assert isinstance (bqskit_qc_optimized , Circuit )
6468 optimized_qc = bqskit_to_qiskit (bqskit_qc_optimized )
6569
@@ -83,9 +87,8 @@ def test_bqskit_synthesis_action(device: Target, available_actions_dict: dict[Pa
8387 check_nat_gates (qc )
8488 assert not check_nat_gates .property_set ["all_gates_in_basis" ]
8589
86- assert callable (action_bqskit_synthesis_action .transpile_pass )
87- lambda_ = action_bqskit_synthesis_action .transpile_pass (device )
88- assert callable (lambda_ )
90+ factory = cast ("Callable[[Target], Callable[[Circuit], Circuit]]" , action_bqskit_synthesis_action .transpile_pass )
91+ lambda_ = factory (device )
8992 bqskit_qc = qiskit_to_bqskit (qc )
9093 if "rigetti" in device .description or "ionq" in device .description or "iqm" in device .description :
9194 with pytest .raises (ValueError , match = re .escape ("not supported in BQSKIT" )):
@@ -122,10 +125,11 @@ def test_bqskit_mapping_action_swaps_necessary(available_actions_dict: dict[Pass
122125
123126 device = get_device ("ibm_falcon_27" )
124127 bqskit_qc = qiskit_to_bqskit (qc )
125- assert callable (bqskit_mapping_action .transpile_pass )
126- lambda_ = bqskit_mapping_action .transpile_pass (device )
127- assert callable (lambda_ )
128- bqskit_qc_mapped , input_mapping , output_mapping = lambda_ (bqskit_qc )
128+ factory = cast (
129+ "Callable[[Target], Callable[[Circuit], tuple[Circuit, tuple[int, ...], tuple[int, ...]]]]" ,
130+ bqskit_mapping_action .transpile_pass ,
131+ )
132+ bqskit_qc_mapped , input_mapping , output_mapping = factory (device )(bqskit_qc )
129133 mapped_qc = bqskit_to_qiskit (bqskit_qc_mapped )
130134 layout = final_layout_bqskit_to_qiskit (input_mapping , output_mapping , mapped_qc , qc )
131135
@@ -186,10 +190,11 @@ def test_bqskit_mapping_action_no_swaps_necessary(available_actions_dict: dict[P
186190 device = get_device ("quantinuum_h2_56" )
187191
188192 bqskit_qc = qiskit_to_bqskit (qc_no_swap_needed )
189- assert callable (bqskit_mapping_action .transpile_pass )
190- lambda_ = bqskit_mapping_action .transpile_pass (device )
191- assert callable (lambda_ )
192- bqskit_qc_mapped , input_mapping , output_mapping = lambda_ (bqskit_qc )
193+ factory = cast (
194+ "Callable[[Target], Callable[[Circuit], tuple[Circuit, tuple[int, ...], tuple[int, ...]]]]" ,
195+ bqskit_mapping_action .transpile_pass ,
196+ )
197+ bqskit_qc_mapped , input_mapping , output_mapping = factory (device )(bqskit_qc )
193198 mapped_qc = bqskit_to_qiskit (bqskit_qc_mapped )
194199 layout = final_layout_bqskit_to_qiskit (input_mapping , output_mapping , mapped_qc , qc_no_swap_needed )
195200 assert layout is not None
@@ -211,8 +216,8 @@ def test_tket_routing(available_actions_dict: dict[PassType, list[Action]]) -> N
211216 device = get_device ("quantinuum_h2_56" )
212217
213218 layout_action = available_actions_dict [PassType .LAYOUT ][0 ]
214- assert callable ( layout_action .transpile_pass )
215- passes_ = cast ( "list[Task]" , layout_action . transpile_pass ( device ) )
219+ factory = cast ( "Callable[[Target], list[Task]]" , layout_action .transpile_pass )
220+ passes_ = factory ( device )
216221 pm = PassManager (passes_ )
217222 layouted_qc = pm .run (qc )
218223 initial_layout = pm .property_set ["layout" ]
@@ -225,11 +230,11 @@ def test_tket_routing(available_actions_dict: dict[PassType, list[Action]]) -> N
225230 assert routing_action is not None
226231
227232 tket_qc = qiskit_to_tk (layouted_qc , preserve_param_uuid = True )
228- assert callable (routing_action .transpile_pass )
229- passes = routing_action .transpile_pass (device )
230- assert isinstance (passes , list )
233+ factory = cast (
234+ "Callable[[Target], list[TketBasePass | PreProcessTKETRoutingAfterQiskitLayout]]" , routing_action .transpile_pass
235+ )
236+ passes = factory (device )
231237 for pass_ in passes :
232- assert isinstance (pass_ , TketBasePass | PreProcessTKETRoutingAfterQiskitLayout )
233238 pass_ .apply (tket_qc )
234239
235240 qbs = tket_qc .qubits
0 commit comments