Skip to content

Commit 2662947

Browse files
Liu KeyuLiu Keyu
authored andcommitted
Fix: resolve pre-commit issues and add missing annotations
Fix: resolve pre-commit issues and add missing annotations Fix: resolve pre-commit issues and add missing annotations
1 parent 78dc1aa commit 2662947

File tree

2 files changed

+75
-22
lines changed

2 files changed

+75
-22
lines changed

src/mqt/predictor/rl/helper.py

Lines changed: 73 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,18 @@
1212

1313
import logging
1414
from pathlib import Path
15-
from typing import TYPE_CHECKING
15+
from typing import TYPE_CHECKING, Optional, Callable
1616

1717
import numpy as np
1818
from qiskit import QuantumCircuit
1919
from qiskit.converters import circuit_to_dag, dag_to_circuit
20-
from qiskit.circuit import ClassicalRegister, QuantumRegister
21-
from qiskit.transpiler import PassManager
20+
from qiskit.circuit import ClassicalRegister, QuantumRegister, Instruction
21+
from qiskit.transpiler import PassManager, Target
22+
from qiskit.dagcircuit import DAGCircuit
2223
from qiskit_ibm_transpiler.ai.routing import AIRouting
2324

2425
from mqt.predictor.utils import calc_supermarq_features
26+
from mqt.predictor.rl.actions import Action
2527

2628
if TYPE_CHECKING:
2729
from numpy.random import Generator
@@ -32,7 +34,16 @@
3234

3335
logger = logging.getLogger("mqt-predictor")
3436

35-
def extract_cregs_and_measurements(qc):
37+
def extract_cregs_and_measurements(qc: QuantumCircuit) -> tuple[list[ClassicalRegister], list[tuple[Instruction, list, list]]]:
38+
"""
39+
Extracts classical registers and measurement operations from a quantum circuit.
40+
41+
Args:
42+
qc: The input QuantumCircuit.
43+
44+
Returns:
45+
A tuple containing a list of classical registers and a list of measurement operations.
46+
"""
3647
cregs = [ClassicalRegister(cr.size, name=cr.name) for cr in qc.cregs]
3748
measurements = [
3849
(item.operation, item.qubits, item.clbits)
@@ -41,7 +52,16 @@ def extract_cregs_and_measurements(qc):
4152
]
4253
return cregs, measurements
4354

44-
def remove_cregs(qc):
55+
def remove_cregs(qc: QuantumCircuit) -> QuantumCircuit:
56+
"""
57+
Removes classical registers and measurement operations from the circuit.
58+
59+
Args:
60+
qc: The input QuantumCircuit.
61+
62+
Returns:
63+
A new QuantumCircuit with only quantum operations (no cregs or measurements).
64+
"""
4565
qregs = [QuantumRegister(qr.size, name=qr.name) for qr in qc.qregs]
4666
new_qc = QuantumCircuit(*qregs)
4767
old_to_new = {}
@@ -55,7 +75,24 @@ def remove_cregs(qc):
5575
new_qc.append(instr, qargs)
5676
return new_qc
5777

58-
def add_cregs_and_measurements(qc, cregs, measurements, qubit_map=None):
78+
def add_cregs_and_measurements(
79+
qc: QuantumCircuit,
80+
cregs: list[ClassicalRegister],
81+
measurements: list[tuple[Instruction, list, list]],
82+
qubit_map: Optional[dict] = None,
83+
) -> QuantumCircuit:
84+
"""
85+
Adds classical registers and measurement operations back to the quantum circuit.
86+
87+
Args:
88+
qc: The quantum circuit to which cregs and measurements are added.
89+
cregs: List of ClassicalRegister to add.
90+
measurements: List of measurement instructions as tuples (Instruction, qubits, clbits).
91+
qubit_map: Optional dictionary mapping original qubits to new qubits.
92+
93+
Returns:
94+
The modified QuantumCircuit with cregs and measurements added.
95+
"""
5996
for cr in cregs:
6097
qc.add_register(cr)
6198
for instr, qargs, cargs in measurements:
@@ -68,10 +105,15 @@ def add_cregs_and_measurements(qc, cregs, measurements, qubit_map=None):
68105

69106
class SafeAIRouting(AIRouting):
70107
"""
71-
Remove cregs before AIRouting and add them back afterwards
72-
Necessary because there are cases AIRouting can't handle
108+
Custom AIRouting wrapper that removes classical registers before routing.
109+
110+
This prevents failures in AIRouting when classical bits are present by
111+
temporarily removing classical registers and measurements and restoring
112+
them after routing is completed.
73113
"""
74-
def run(self, dag):
114+
def run(self, dag: DAGCircuit) -> DAGCircuit:
115+
"""Run the routing pass on a DAGCircuit."""
116+
75117
# 1. Convert input dag to circuit
76118
qc_orig = dag_to_circuit(dag)
77119

@@ -101,26 +143,36 @@ def run(self, dag):
101143
else:
102144
try:
103145
idx = qc_routed.qubits.index(phys)
104-
except ValueError:
105-
raise RuntimeError(f"Physical qubit {phys} not found in output circuit!")
146+
except ValueError as err:
147+
raise RuntimeError(f"Physical qubit {phys} not found in output circuit!") from err
106148
qubit_map[virt] = qc_routed.qubits[idx]
107149
# 7. Restore classical registers and measurement instructions
108150
qc_final = add_cregs_and_measurements(qc_routed, cregs, measurements, qubit_map)
109151
# 8. Return as dag
110152
return circuit_to_dag(qc_final)
111153

112154
def best_of_n_passmanager(
113-
action, device, qc, max_iteration=(20,20),
114-
metric_fn=None,
115-
):
155+
action: Action,
156+
device: Target,
157+
qc: QuantumCircuit,
158+
max_iteration: tuple[int, int] = (20, 20),
159+
metric_fn: Optional[Callable[[QuantumCircuit], float]] = None,
160+
)-> tuple[QuantumCircuit, dict[str, any]]:
116161
"""
117162
Runs the given transpile_pass multiple times and keeps the best result.
118-
action: the action dict with a 'transpile_pass' key (lambda/device->[passes])
119-
device: the backend or device
120-
qc: input circuit
121-
max_iteration: number of times to try
122-
metric_fn: function(circ) -> float for scoring
123-
require_layout: skip outputs with missing layouts
163+
164+
Args:
165+
action: The action dictionary with a 'transpile_pass' key
166+
(lambda device -> [passes]).
167+
device: The target backend or device.
168+
qc: The input quantum circuit.
169+
max_iteration: A tuple (layout_trials, routing_trials) specifying
170+
how many times to try.
171+
metric_fn: Optional function to score circuits; defaults to circuit depth.
172+
173+
Returns:
174+
A tuple containing the best transpiled circuit and its corresponding
175+
property set.
124176
"""
125177
best_val = None
126178
best_result = None
@@ -249,7 +301,7 @@ def get_openqasm_gates() -> list[str]:
249301
"rccx",
250302
]
251303

252-
def create_feature_dict(qc: QuantumCircuit, basis_gates: list[str], coupling_map) -> dict[str, int | NDArray[np.float64]]:
304+
def create_feature_dict(qc: QuantumCircuit) -> dict[str, int | NDArray[np.float64]]:
253305
"""Creates a feature dictionary for a given quantum circuit.
254306
255307
Arguments:

src/mqt/predictor/rl/predictorenv.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,8 @@ def apply_action(self, action_index: int) -> QuantumCircuit | None:
317317

318318
def _apply_qiskit_action(self, action: Action, action_index: int) -> QuantumCircuit:
319319
if action.get("stochastic", False):
320-
metric_fn = lambda circ: circ.count_ops().get("swap", 0)
320+
def metric_fn(circ: QuantumCircuit) -> float:
321+
return circ.count_ops().get("swap", 0)
321322
# for stochastic actions, pass the layout/routing trials parameter
322323
max_iteration = self.max_iter
323324
if "Sabre" in action["name"] and "AIRouting" not in action["name"]:

0 commit comments

Comments
 (0)