Skip to content

Commit f1119b4

Browse files
committed
dd v2
1 parent 65c5691 commit f1119b4

File tree

2 files changed

+225
-300
lines changed

2 files changed

+225
-300
lines changed

cirq-core/cirq/transformers/dynamical_decoupling.py

Lines changed: 150 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,11 @@
1818

1919
from functools import reduce
2020
from itertools import cycle
21+
from enum import Enum
2122
from typing import TYPE_CHECKING
2223

24+
from attrs import frozen
25+
2326
import numpy as np
2427

2528
from cirq import ops, protocols
@@ -133,10 +136,6 @@ def _calc_busy_moment_range_of_each_qubit(circuit: FrozenCircuit) -> dict[ops.Qi
133136
return busy_moment_range_by_qubit
134137

135138

136-
def _is_insertable_moment(moment: Moment, single_qubit_gate_moments_only: bool) -> bool:
137-
return not single_qubit_gate_moments_only or _is_single_qubit_gate_moment(moment)
138-
139-
140139
def _merge_single_qubit_ops_to_phxz(
141140
q: ops.Qid, operations: tuple[ops.Operation, ...]
142141
) -> ops.Operation:
@@ -149,34 +148,127 @@ def _merge_single_qubit_ops_to_phxz(
149148
return gate.on(q)
150149

151150

152-
def _calc_pulled_through(moment: Moment, input_pauli_ops: ops.PauliString) -> ops.PauliString:
153-
"""Calculates the pulled_through such that circuit(input_pauli_ops, moment.clifford_ops) is
154-
equivalent to circuit(moment.clifford_ops, pulled_through).
155-
"""
156-
clifford_ops_in_moment: list[ops.Operation] = [
157-
op for op in moment.operations if _is_clifford_op(op)
158-
]
159-
return input_pauli_ops.after(clifford_ops_in_moment)
151+
@frozen
152+
class _CircuitRepr:
153+
class _GateType(Enum):
154+
UNKOWN = 0
155+
WALL_GATE = 1
156+
DOOR_GATE = 2
157+
INSERTABLE_GATE = 3
158+
159+
gate_types: dict[ops.Qid, dict[int, _CircuitRepr._GateType]]
160+
need_to_stop: dict[ops.Qid, dict[int, bool]]
161+
circuit: FrozenCircuit
162+
163+
def __init__(self, circuit: cirq.FrozenCircuit, single_qubit_gate_moments_only: bool):
164+
object.__setattr__(self, 'circuit', circuit)
165+
166+
gate_types: dict[ops.Qid, dict[int, _CircuitRepr._GateType]] = {
167+
q: {mid: _CircuitRepr._GateType.UNKOWN for mid in range(len(circuit))}
168+
for q in circuit.all_qubits()
169+
}
170+
mergable: dict[ops.Qid, dict[int, bool]] = {
171+
q: {mid: False for mid in range(len(circuit))} for q in circuit.all_qubits()
172+
}
173+
busy_moment_range_by_qubit = _calc_busy_moment_range_of_each_qubit(circuit)
174+
175+
# Set gate types for each (q, mid)
176+
for mid, moment in enumerate(circuit):
177+
is_insertable_moment = (
178+
not single_qubit_gate_moments_only or _is_single_qubit_gate_moment(moment)
179+
)
180+
for q in circuit.all_qubits():
181+
if mid < busy_moment_range_by_qubit[q][0] or mid > busy_moment_range_by_qubit[q][1]:
182+
gate_types[q][mid] = _CircuitRepr._GateType.WALL_GATE
183+
continue
184+
op_at_q = moment.operation_at(q)
185+
if op_at_q is None:
186+
if is_insertable_moment:
187+
gate_types[q][mid] = _CircuitRepr._GateType.INSERTABLE_GATE
188+
mergable[q][mid] = True
189+
else:
190+
gate_types[q][mid] = _CircuitRepr._GateType.DOOR_GATE
191+
else:
192+
if _is_clifford_op(op_at_q):
193+
gate_types[q][mid] = _CircuitRepr._GateType.DOOR_GATE
194+
mergable[q][mid] = _is_single_qubit_operation(op_at_q)
195+
else:
196+
gate_types[q][mid] = _CircuitRepr._GateType.WALL_GATE
197+
object.__setattr__(self, 'gate_types', gate_types)
198+
199+
need_to_stop: dict[ops.Qid, dict[int, bool]] = {
200+
q: {mid: False for mid in range(len(circuit))} for q in circuit.all_qubits()
201+
}
202+
# Reversely find the last mergeable gate of each qubit, set them as need_to_stop.
203+
for q in circuit.all_qubits():
204+
self._backward_set_stopping_slots(q, len(circuit) - 1, mergable, need_to_stop)
205+
# Reversely check for each wall gate, mark the closest mergeable gate as need_to_stop.
206+
for mid in range(len(circuit)):
207+
for q in circuit.all_qubits():
208+
if self.gate_types[q][mid] == _CircuitRepr._GateType.WALL_GATE:
209+
self._backward_set_stopping_slots(q, mid - 1, mergable, need_to_stop)
210+
object.__setattr__(self, 'need_to_stop', need_to_stop)
211+
212+
def _backward_set_stopping_slots(
213+
self,
214+
q: ops.Qid,
215+
from_mid: int,
216+
mergable: dict[ops.Qid, dict[int, bool]],
217+
need_to_stop: dict[ops.Qid, dict[int, bool]],
218+
):
219+
affected_qubits: set[ops.Qid] = {q}
220+
for back_mid in range(from_mid, -1, -1):
221+
for back_q in set(affected_qubits):
222+
if self.gate_types[back_q][back_mid] == _CircuitRepr._GateType.WALL_GATE:
223+
affected_qubits.remove(back_q)
224+
continue
225+
if mergable[back_q][back_mid]:
226+
need_to_stop[back_q][back_mid] = True
227+
affected_qubits.remove(back_q)
228+
continue
229+
op_at_q = self.circuit[back_mid].operation_at(back_q) or ops.I(q)
230+
affected_qubits.update(op_at_q.qubits)
231+
if not affected_qubits:
232+
break
233+
234+
def __repr__(self) -> str:
235+
if not self.gate_types:
236+
return "CircuitRepr(empty)"
160237

238+
qubits = sorted(list(self.gate_types.keys()))
239+
if not qubits:
240+
return "CircuitRepr(no qubits)"
241+
num_moments = len(self.gate_types[qubits[0]])
161242

162-
def _get_stop_qubits(moment: Moment) -> set[ops.Qid]:
163-
stop_pulling_through_qubits: set[ops.Qid] = set()
164-
for op in moment:
165-
if (not _is_clifford_op(op) and not _is_single_qubit_operation(op)) or not has_unitary(
166-
op
167-
): # multi-qubit clifford op or non-mergable op.
168-
stop_pulling_through_qubits.update(op.qubits)
169-
return stop_pulling_through_qubits
243+
type_map = {
244+
_CircuitRepr._GateType.WALL_GATE: 'w',
245+
_CircuitRepr._GateType.DOOR_GATE: 'd',
246+
_CircuitRepr._GateType.INSERTABLE_GATE: 'i',
247+
_CircuitRepr._GateType.UNKOWN: 'u',
248+
}
170249

250+
max_qubit_len = max(len(str(q)) for q in qubits) if qubits else 0
171251

172-
def _need_merge_pulled_through(op_at_q: ops.Operation, is_at_last_busy_moment: bool) -> bool:
173-
"""With a pulling through pauli gate before op_at_q, need to merge with the
174-
pauli in the conditions below."""
175-
# The op must be mergable and single-qubit
176-
if not (_is_single_qubit_operation(op_at_q) and has_unitary(op_at_q)):
177-
return False
178-
# Either non-Clifford or at the last busy moment
179-
return is_at_last_busy_moment or not _is_clifford_op(op_at_q)
252+
header = f"{'':>{max_qubit_len}} |"
253+
for i in range(num_moments):
254+
header += f" {i:^3} |"
255+
256+
separator = f"{'-' * max_qubit_len}-+"
257+
separator += '-----+' * num_moments
258+
259+
lines = ["CircuitRepr:", header, separator]
260+
261+
for q in qubits:
262+
row_str = f"{str(q):>{max_qubit_len}} |"
263+
for mid in range(num_moments):
264+
gate_type = self.gate_types[q][mid]
265+
char = type_map.get(gate_type, '?')
266+
stop = self.need_to_stop[q][mid]
267+
cell = f"{char},s" if stop else f" {char} "
268+
row_str += f" {cell} |"
269+
lines.append(row_str)
270+
271+
return "\n".join(lines)
180272

181273

182274
@transformer_api.transformer
@@ -188,7 +280,7 @@ def add_dynamical_decoupling(
188280
single_qubit_gate_moments_only: bool = True,
189281
) -> cirq.Circuit:
190282
"""Adds dynamical decoupling gate operations to a given circuit.
191-
This transformer might add new moments and thus change the structure of the original circuit.
283+
This transformer preserves the structure of the original circuit.
192284
193285
Args:
194286
circuit: Input circuit to transform.
@@ -202,11 +294,15 @@ def add_dynamical_decoupling(
202294
Returns:
203295
A copy of the input circuit with dynamical decoupling operations.
204296
"""
205-
base_dd_sequence, pauli_map = _parse_dd_sequence(schema)
297+
298+
if context is not None and context.deep:
299+
raise ValueError("Deep transformation is not supported.")
300+
206301
orig_circuit = circuit.freeze()
207302

208-
busy_moment_range_by_qubit = _calc_busy_moment_range_of_each_qubit(orig_circuit)
303+
repr = _CircuitRepr(orig_circuit, single_qubit_gate_moments_only)
209304

305+
base_dd_sequence, pauli_map = _parse_dd_sequence(schema)
210306
# Stores all the moments of the output circuit chronologically.
211307
transformed_moments: list[Moment] = []
212308
# A PauliString stores the result of 'pulling' Pauli gates past each operations
@@ -215,90 +311,31 @@ def add_dynamical_decoupling(
215311
# Iterator of gate to be used in dd sequence for each qubit.
216312
dd_iter_by_qubits = {q: cycle(base_dd_sequence) for q in circuit.all_qubits()}
217313

218-
def _update_pulled_through(q: ops.Qid, insert_gate: ops.Gate) -> ops.Operation:
219-
nonlocal pulled_through, pauli_map
220-
pulled_through *= pauli_map[insert_gate].on(q)
221-
return insert_gate.on(q)
222-
223-
# Insert and pull remaining Pauli ops through the whole circuit.
224-
# General ideas are
225-
# * Pull through Clifford gates.
226-
# * Stop at multi-qubit non-Clifford ops (and other non-mergable ops).
227-
# * Merge to single-qubit non-Clifford ops.
228-
# * Insert a new moment if necessary.
229-
# After pulling through pulled_through at `moment`, we expect a transformation of
230-
# (pulled_through, moment) -> (updated_moment, updated_pulled_through) or
231-
# (pulled_through, moment) -> (new_moment, updated_moment, updated_pulled_through)
232-
# Moments structure changes are split into 3 steps:
233-
# 1, (..., last_moment, pulled_through1, moment, ...)
234-
# -> (..., last_moment, new_moment or None, pulled_through2, moment, ...)
235-
# 2, (..., pulled_through2, moment, ...) -> (..., pulled_through3, updated_moment, ...)
236-
# 3, (..., pulled_through3, updated_moment, ...)
237-
# -> (..., updated_moment, pulled_through4, ...)
238314
for moment_id, moment in enumerate(orig_circuit.moments):
239-
# Step 1, insert new_moment if necessary.
240-
# In detail: stop pulling through for multi-qubit non-Clifford ops or gates without
241-
# unitary representation (e.g., measure gates). If there are remaining pulled through ops,
242-
# insert into a new moment before current moment.
243-
stop_pulling_through_qubits: set[ops.Qid] = _get_stop_qubits(moment)
244-
new_moment_ops: list[ops.Operation] = []
245-
for q in stop_pulling_through_qubits:
246-
# Insert the remaining pulled_through
247-
remaining_pulled_through_gate = pulled_through.get(q)
248-
if remaining_pulled_through_gate is not None:
249-
new_moment_ops.append(_update_pulled_through(q, remaining_pulled_through_gate))
250-
# Reset dd sequence
251-
dd_iter_by_qubits[q] = cycle(base_dd_sequence)
252-
# Need to insert a new moment before current moment
253-
if new_moment_ops:
254-
# Fill insertable idle moments in the new moment using dd sequence
255-
for q in orig_circuit.all_qubits() - stop_pulling_through_qubits:
256-
if busy_moment_range_by_qubit[q][0] < moment_id <= busy_moment_range_by_qubit[q][1]:
257-
new_moment_ops.append(_update_pulled_through(q, next(dd_iter_by_qubits[q])))
258-
transformed_moments.append(Moment(new_moment_ops))
259-
260-
# Step 2, calc updated_moment with insertions / merges.
261315
updated_moment_ops: set[cirq.Operation] = set()
262316
for q in orig_circuit.all_qubits():
263-
op_at_q = moment.operation_at(q)
264-
remaining_pulled_through_gate = pulled_through.get(q)
265-
updated_op = op_at_q
266-
if op_at_q is None: # insert into idle op
267-
if not _is_insertable_moment(moment, single_qubit_gate_moments_only):
268-
continue
269-
if (
270-
busy_moment_range_by_qubit[q][0] < moment_id < busy_moment_range_by_qubit[q][1]
271-
): # insert next pauli gate in the dd sequence
272-
updated_op = _update_pulled_through(q, next(dd_iter_by_qubits[q]))
273-
elif ( # insert the remaining pulled through if beyond the ending busy moment
274-
moment_id > busy_moment_range_by_qubit[q][1]
275-
and remaining_pulled_through_gate is not None
276-
):
277-
updated_op = _update_pulled_through(q, remaining_pulled_through_gate)
278-
elif (
279-
remaining_pulled_through_gate is not None
280-
): # merge pulled-through of q to op_at_q if needed
281-
if _need_merge_pulled_through(
282-
op_at_q, moment_id == busy_moment_range_by_qubit[q][1]
283-
):
284-
remaining_op = _update_pulled_through(q, remaining_pulled_through_gate)
285-
updated_op = _merge_single_qubit_ops_to_phxz(q, (remaining_op, op_at_q))
286-
if updated_op is not None:
287-
updated_moment_ops.add(updated_op)
288-
289-
if updated_moment_ops:
290-
updated_moment = Moment(updated_moment_ops)
291-
transformed_moments.append(updated_moment)
292-
293-
# Step 3, update pulled through.
294-
# In detail: pulling current `pulled_through` through updated_moment.
295-
pulled_through = _calc_pulled_through(updated_moment, pulled_through)
296-
297-
# Insert a new moment if there are remaining pulled-through operations.
298-
ending_moment_ops = []
299-
for affected_q, combined_op_in_pauli in pulled_through.items():
300-
ending_moment_ops.append(combined_op_in_pauli.on(affected_q))
301-
if ending_moment_ops:
302-
transformed_moments.append(Moment(ending_moment_ops))
317+
new_op_at_q = moment.operation_at(q)
318+
if repr.gate_types[q][moment_id] == _CircuitRepr._GateType.INSERTABLE_GATE:
319+
new_gate = next(dd_iter_by_qubits[q])
320+
new_op_at_q = new_gate.on(q)
321+
pulled_through *= pauli_map[new_gate].on(q)
322+
if repr.need_to_stop[q][moment_id]:
323+
to_be_merged = pulled_through.get(q)
324+
if to_be_merged is not None:
325+
new_op_at_q = _merge_single_qubit_ops_to_phxz(
326+
q, [to_be_merged, new_op_at_q or ops.I(q)]
327+
)
328+
pulled_through *= to_be_merged.on(q)
329+
if new_op_at_q is not None:
330+
updated_moment_ops.add(new_op_at_q)
331+
332+
updated_moment = Moment(updated_moment_ops)
333+
clifford_ops = [op for op in updated_moment if _is_clifford_op(op)]
334+
pulled_through = pulled_through.after(clifford_ops)
335+
transformed_moments.append(updated_moment)
336+
337+
# DO NOT SUBMIT
338+
# if pulled_through.qubits() is not None:
339+
# raise RuntimeError("Expect empty pulled through after propogating all moments.")
303340

304341
return Circuit.from_moments(*transformed_moments)

0 commit comments

Comments
 (0)