12
12
13
13
import logging
14
14
from pathlib import Path
15
- from typing import TYPE_CHECKING
15
+ from typing import TYPE_CHECKING , Optional , Callable
16
16
17
17
import numpy as np
18
18
from qiskit import QuantumCircuit
19
19
from qiskit .converters import circuit_to_dag , dag_to_circuit
20
- from qiskit .circuit import ClassicalRegister , QuantumRegister
21
- from qiskit .transpiler import PassManager
20
+ from qiskit .circuit import ClassicalRegister , QuantumRegister , Instruction
21
+ from qiskit .transpiler import PassManager , Target
22
+ from qiskit .dagcircuit import DAGCircuit
22
23
from qiskit_ibm_transpiler .ai .routing import AIRouting
23
24
24
25
from mqt .predictor .utils import calc_supermarq_features
26
+ from mqt .predictor .rl .actions import Action
25
27
26
28
if TYPE_CHECKING :
27
29
from numpy .random import Generator
32
34
33
35
logger = logging .getLogger ("mqt-predictor" )
34
36
35
- def extract_cregs_and_measurements (qc ):
37
+ def extract_cregs_and_measurements (qc : QuantumCircuit ) -> tuple [list [ClassicalRegister ], list [tuple [Instruction , list , list ]]]:
38
+ """
39
+ Extracts classical registers and measurement operations from a quantum circuit.
40
+
41
+ Args:
42
+ qc: The input QuantumCircuit.
43
+
44
+ Returns:
45
+ A tuple containing a list of classical registers and a list of measurement operations.
46
+ """
36
47
cregs = [ClassicalRegister (cr .size , name = cr .name ) for cr in qc .cregs ]
37
48
measurements = [
38
49
(item .operation , item .qubits , item .clbits )
@@ -41,7 +52,16 @@ def extract_cregs_and_measurements(qc):
41
52
]
42
53
return cregs , measurements
43
54
44
- def remove_cregs (qc ):
55
+ def remove_cregs (qc : QuantumCircuit ) -> QuantumCircuit :
56
+ """
57
+ Removes classical registers and measurement operations from the circuit.
58
+
59
+ Args:
60
+ qc: The input QuantumCircuit.
61
+
62
+ Returns:
63
+ A new QuantumCircuit with only quantum operations (no cregs or measurements).
64
+ """
45
65
qregs = [QuantumRegister (qr .size , name = qr .name ) for qr in qc .qregs ]
46
66
new_qc = QuantumCircuit (* qregs )
47
67
old_to_new = {}
@@ -55,7 +75,24 @@ def remove_cregs(qc):
55
75
new_qc .append (instr , qargs )
56
76
return new_qc
57
77
58
- def add_cregs_and_measurements (qc , cregs , measurements , qubit_map = None ):
78
+ def add_cregs_and_measurements (
79
+ qc : QuantumCircuit ,
80
+ cregs : list [ClassicalRegister ],
81
+ measurements : list [tuple [Instruction , list , list ]],
82
+ qubit_map : Optional [dict ] = None ,
83
+ ) -> QuantumCircuit :
84
+ """
85
+ Adds classical registers and measurement operations back to the quantum circuit.
86
+
87
+ Args:
88
+ qc: The quantum circuit to which cregs and measurements are added.
89
+ cregs: List of ClassicalRegister to add.
90
+ measurements: List of measurement instructions as tuples (Instruction, qubits, clbits).
91
+ qubit_map: Optional dictionary mapping original qubits to new qubits.
92
+
93
+ Returns:
94
+ The modified QuantumCircuit with cregs and measurements added.
95
+ """
59
96
for cr in cregs :
60
97
qc .add_register (cr )
61
98
for instr , qargs , cargs in measurements :
@@ -68,10 +105,15 @@ def add_cregs_and_measurements(qc, cregs, measurements, qubit_map=None):
68
105
69
106
class SafeAIRouting (AIRouting ):
70
107
"""
71
- Remove cregs before AIRouting and add them back afterwards
72
- Necessary because there are cases AIRouting can't handle
108
+ Custom AIRouting wrapper that removes classical registers before routing.
109
+
110
+ This prevents failures in AIRouting when classical bits are present by
111
+ temporarily removing classical registers and measurements and restoring
112
+ them after routing is completed.
73
113
"""
74
- def run (self , dag ):
114
+ def run (self , dag : DAGCircuit ) -> DAGCircuit :
115
+ """Run the routing pass on a DAGCircuit."""
116
+
75
117
# 1. Convert input dag to circuit
76
118
qc_orig = dag_to_circuit (dag )
77
119
@@ -101,26 +143,36 @@ def run(self, dag):
101
143
else :
102
144
try :
103
145
idx = qc_routed .qubits .index (phys )
104
- except ValueError :
105
- raise RuntimeError (f"Physical qubit { phys } not found in output circuit!" )
146
+ except ValueError as err :
147
+ raise RuntimeError (f"Physical qubit { phys } not found in output circuit!" ) from err
106
148
qubit_map [virt ] = qc_routed .qubits [idx ]
107
149
# 7. Restore classical registers and measurement instructions
108
150
qc_final = add_cregs_and_measurements (qc_routed , cregs , measurements , qubit_map )
109
151
# 8. Return as dag
110
152
return circuit_to_dag (qc_final )
111
153
112
154
def best_of_n_passmanager (
113
- action , device , qc , max_iteration = (20 ,20 ),
114
- metric_fn = None ,
115
- ):
155
+ action : Action ,
156
+ device : Target ,
157
+ qc : QuantumCircuit ,
158
+ max_iteration : tuple [int , int ] = (20 , 20 ),
159
+ metric_fn : Optional [Callable [[QuantumCircuit ], float ]] = None ,
160
+ )-> tuple [QuantumCircuit , dict [str , any ]]:
116
161
"""
117
162
Runs the given transpile_pass multiple times and keeps the best result.
118
- action: the action dict with a 'transpile_pass' key (lambda/device->[passes])
119
- device: the backend or device
120
- qc: input circuit
121
- max_iteration: number of times to try
122
- metric_fn: function(circ) -> float for scoring
123
- require_layout: skip outputs with missing layouts
163
+
164
+ Args:
165
+ action: The action dictionary with a 'transpile_pass' key
166
+ (lambda device -> [passes]).
167
+ device: The target backend or device.
168
+ qc: The input quantum circuit.
169
+ max_iteration: A tuple (layout_trials, routing_trials) specifying
170
+ how many times to try.
171
+ metric_fn: Optional function to score circuits; defaults to circuit depth.
172
+
173
+ Returns:
174
+ A tuple containing the best transpiled circuit and its corresponding
175
+ property set.
124
176
"""
125
177
best_val = None
126
178
best_result = None
@@ -249,7 +301,7 @@ def get_openqasm_gates() -> list[str]:
249
301
"rccx" ,
250
302
]
251
303
252
- def create_feature_dict (qc : QuantumCircuit , basis_gates : list [ str ], coupling_map ) -> dict [str , int | NDArray [np .float64 ]]:
304
+ def create_feature_dict (qc : QuantumCircuit ) -> dict [str , int | NDArray [np .float64 ]]:
253
305
"""Creates a feature dictionary for a given quantum circuit.
254
306
255
307
Arguments:
0 commit comments