Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
263 changes: 150 additions & 113 deletions cirq-core/cirq/transformers/dynamical_decoupling.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,11 @@

from functools import reduce
from itertools import cycle
from enum import Enum
from typing import TYPE_CHECKING

from attrs import frozen

import numpy as np

from cirq import ops, protocols
Expand Down Expand Up @@ -133,10 +136,6 @@ def _calc_busy_moment_range_of_each_qubit(circuit: FrozenCircuit) -> dict[ops.Qi
return busy_moment_range_by_qubit


def _is_insertable_moment(moment: Moment, single_qubit_gate_moments_only: bool) -> bool:
return not single_qubit_gate_moments_only or _is_single_qubit_gate_moment(moment)


def _merge_single_qubit_ops_to_phxz(
q: ops.Qid, operations: tuple[ops.Operation, ...]
) -> ops.Operation:
Expand All @@ -149,34 +148,127 @@ def _merge_single_qubit_ops_to_phxz(
return gate.on(q)


def _calc_pulled_through(moment: Moment, input_pauli_ops: ops.PauliString) -> ops.PauliString:
"""Calculates the pulled_through such that circuit(input_pauli_ops, moment.clifford_ops) is
equivalent to circuit(moment.clifford_ops, pulled_through).
"""
clifford_ops_in_moment: list[ops.Operation] = [
op for op in moment.operations if _is_clifford_op(op)
]
return input_pauli_ops.after(clifford_ops_in_moment)
@frozen
class _CircuitRepr:
class _GateType(Enum):
UNKOWN = 0
WALL_GATE = 1
DOOR_GATE = 2
INSERTABLE_GATE = 3

gate_types: dict[ops.Qid, dict[int, _CircuitRepr._GateType]]
need_to_stop: dict[ops.Qid, dict[int, bool]]
circuit: FrozenCircuit

def __init__(self, circuit: cirq.FrozenCircuit, single_qubit_gate_moments_only: bool):
object.__setattr__(self, 'circuit', circuit)

gate_types: dict[ops.Qid, dict[int, _CircuitRepr._GateType]] = {
q: {mid: _CircuitRepr._GateType.UNKOWN for mid in range(len(circuit))}
for q in circuit.all_qubits()
}
mergable: dict[ops.Qid, dict[int, bool]] = {
q: {mid: False for mid in range(len(circuit))} for q in circuit.all_qubits()
}
busy_moment_range_by_qubit = _calc_busy_moment_range_of_each_qubit(circuit)

# Set gate types for each (q, mid)
for mid, moment in enumerate(circuit):
is_insertable_moment = (
not single_qubit_gate_moments_only or _is_single_qubit_gate_moment(moment)
)
for q in circuit.all_qubits():
if mid < busy_moment_range_by_qubit[q][0] or mid > busy_moment_range_by_qubit[q][1]:
gate_types[q][mid] = _CircuitRepr._GateType.WALL_GATE
continue
op_at_q = moment.operation_at(q)
if op_at_q is None:
if is_insertable_moment:
gate_types[q][mid] = _CircuitRepr._GateType.INSERTABLE_GATE
mergable[q][mid] = True
else:
gate_types[q][mid] = _CircuitRepr._GateType.DOOR_GATE
else:
if _is_clifford_op(op_at_q):
gate_types[q][mid] = _CircuitRepr._GateType.DOOR_GATE
mergable[q][mid] = _is_single_qubit_operation(op_at_q)
else:
gate_types[q][mid] = _CircuitRepr._GateType.WALL_GATE
object.__setattr__(self, 'gate_types', gate_types)

need_to_stop: dict[ops.Qid, dict[int, bool]] = {
q: {mid: False for mid in range(len(circuit))} for q in circuit.all_qubits()
}
# Reversely find the last mergeable gate of each qubit, set them as need_to_stop.
for q in circuit.all_qubits():
self._backward_set_stopping_slots(q, len(circuit) - 1, mergable, need_to_stop)
# Reversely check for each wall gate, mark the closest mergeable gate as need_to_stop.
for mid in range(len(circuit)):
for q in circuit.all_qubits():
if self.gate_types[q][mid] == _CircuitRepr._GateType.WALL_GATE:
self._backward_set_stopping_slots(q, mid - 1, mergable, need_to_stop)
object.__setattr__(self, 'need_to_stop', need_to_stop)

def _backward_set_stopping_slots(
self,
q: ops.Qid,
from_mid: int,
mergable: dict[ops.Qid, dict[int, bool]],
need_to_stop: dict[ops.Qid, dict[int, bool]],
):
affected_qubits: set[ops.Qid] = {q}
for back_mid in range(from_mid, -1, -1):
for back_q in set(affected_qubits):
if self.gate_types[back_q][back_mid] == _CircuitRepr._GateType.WALL_GATE:
affected_qubits.remove(back_q)
continue
if mergable[back_q][back_mid]:
need_to_stop[back_q][back_mid] = True
affected_qubits.remove(back_q)
continue
op_at_q = self.circuit[back_mid].operation_at(back_q) or ops.I(q)
affected_qubits.update(op_at_q.qubits)
if not affected_qubits:
break

def __repr__(self) -> str:
if not self.gate_types:
return "CircuitRepr(empty)"

qubits = sorted(list(self.gate_types.keys()))
if not qubits:
return "CircuitRepr(no qubits)"
num_moments = len(self.gate_types[qubits[0]])

def _get_stop_qubits(moment: Moment) -> set[ops.Qid]:
stop_pulling_through_qubits: set[ops.Qid] = set()
for op in moment:
if (not _is_clifford_op(op) and not _is_single_qubit_operation(op)) or not has_unitary(
op
): # multi-qubit clifford op or non-mergable op.
stop_pulling_through_qubits.update(op.qubits)
return stop_pulling_through_qubits
type_map = {
_CircuitRepr._GateType.WALL_GATE: 'w',
_CircuitRepr._GateType.DOOR_GATE: 'd',
_CircuitRepr._GateType.INSERTABLE_GATE: 'i',
_CircuitRepr._GateType.UNKOWN: 'u',
}

max_qubit_len = max(len(str(q)) for q in qubits) if qubits else 0

def _need_merge_pulled_through(op_at_q: ops.Operation, is_at_last_busy_moment: bool) -> bool:
"""With a pulling through pauli gate before op_at_q, need to merge with the
pauli in the conditions below."""
# The op must be mergable and single-qubit
if not (_is_single_qubit_operation(op_at_q) and has_unitary(op_at_q)):
return False
# Either non-Clifford or at the last busy moment
return is_at_last_busy_moment or not _is_clifford_op(op_at_q)
header = f"{'':>{max_qubit_len}} |"
for i in range(num_moments):
header += f" {i:^3} |"

separator = f"{'-' * max_qubit_len}-+"
separator += '-----+' * num_moments

lines = ["CircuitRepr:", header, separator]

for q in qubits:
row_str = f"{str(q):>{max_qubit_len}} |"
for mid in range(num_moments):
gate_type = self.gate_types[q][mid]
char = type_map.get(gate_type, '?')
stop = self.need_to_stop[q][mid]
cell = f"{char},s" if stop else f" {char} "
row_str += f" {cell} |"
lines.append(row_str)

return "\n".join(lines)


@transformer_api.transformer
Expand All @@ -188,7 +280,7 @@ def add_dynamical_decoupling(
single_qubit_gate_moments_only: bool = True,
) -> cirq.Circuit:
"""Adds dynamical decoupling gate operations to a given circuit.
This transformer might add new moments and thus change the structure of the original circuit.
This transformer preserves the structure of the original circuit.

Args:
circuit: Input circuit to transform.
Expand All @@ -202,11 +294,15 @@ def add_dynamical_decoupling(
Returns:
A copy of the input circuit with dynamical decoupling operations.
"""
base_dd_sequence, pauli_map = _parse_dd_sequence(schema)

if context is not None and context.deep:
raise ValueError("Deep transformation is not supported.")

orig_circuit = circuit.freeze()

busy_moment_range_by_qubit = _calc_busy_moment_range_of_each_qubit(orig_circuit)
repr = _CircuitRepr(orig_circuit, single_qubit_gate_moments_only)

base_dd_sequence, pauli_map = _parse_dd_sequence(schema)
# Stores all the moments of the output circuit chronologically.
transformed_moments: list[Moment] = []
# A PauliString stores the result of 'pulling' Pauli gates past each operations
Expand All @@ -215,90 +311,31 @@ def add_dynamical_decoupling(
# Iterator of gate to be used in dd sequence for each qubit.
dd_iter_by_qubits = {q: cycle(base_dd_sequence) for q in circuit.all_qubits()}

def _update_pulled_through(q: ops.Qid, insert_gate: ops.Gate) -> ops.Operation:
nonlocal pulled_through, pauli_map
pulled_through *= pauli_map[insert_gate].on(q)
return insert_gate.on(q)

# Insert and pull remaining Pauli ops through the whole circuit.
# General ideas are
# * Pull through Clifford gates.
# * Stop at multi-qubit non-Clifford ops (and other non-mergable ops).
# * Merge to single-qubit non-Clifford ops.
# * Insert a new moment if necessary.
# After pulling through pulled_through at `moment`, we expect a transformation of
# (pulled_through, moment) -> (updated_moment, updated_pulled_through) or
# (pulled_through, moment) -> (new_moment, updated_moment, updated_pulled_through)
# Moments structure changes are split into 3 steps:
# 1, (..., last_moment, pulled_through1, moment, ...)
# -> (..., last_moment, new_moment or None, pulled_through2, moment, ...)
# 2, (..., pulled_through2, moment, ...) -> (..., pulled_through3, updated_moment, ...)
# 3, (..., pulled_through3, updated_moment, ...)
# -> (..., updated_moment, pulled_through4, ...)
for moment_id, moment in enumerate(orig_circuit.moments):
# Step 1, insert new_moment if necessary.
# In detail: stop pulling through for multi-qubit non-Clifford ops or gates without
# unitary representation (e.g., measure gates). If there are remaining pulled through ops,
# insert into a new moment before current moment.
stop_pulling_through_qubits: set[ops.Qid] = _get_stop_qubits(moment)
new_moment_ops: list[ops.Operation] = []
for q in stop_pulling_through_qubits:
# Insert the remaining pulled_through
remaining_pulled_through_gate = pulled_through.get(q)
if remaining_pulled_through_gate is not None:
new_moment_ops.append(_update_pulled_through(q, remaining_pulled_through_gate))
# Reset dd sequence
dd_iter_by_qubits[q] = cycle(base_dd_sequence)
# Need to insert a new moment before current moment
if new_moment_ops:
# Fill insertable idle moments in the new moment using dd sequence
for q in orig_circuit.all_qubits() - stop_pulling_through_qubits:
if busy_moment_range_by_qubit[q][0] < moment_id <= busy_moment_range_by_qubit[q][1]:
new_moment_ops.append(_update_pulled_through(q, next(dd_iter_by_qubits[q])))
transformed_moments.append(Moment(new_moment_ops))

# Step 2, calc updated_moment with insertions / merges.
updated_moment_ops: set[cirq.Operation] = set()
for q in orig_circuit.all_qubits():
op_at_q = moment.operation_at(q)
remaining_pulled_through_gate = pulled_through.get(q)
updated_op = op_at_q
if op_at_q is None: # insert into idle op
if not _is_insertable_moment(moment, single_qubit_gate_moments_only):
continue
if (
busy_moment_range_by_qubit[q][0] < moment_id < busy_moment_range_by_qubit[q][1]
): # insert next pauli gate in the dd sequence
updated_op = _update_pulled_through(q, next(dd_iter_by_qubits[q]))
elif ( # insert the remaining pulled through if beyond the ending busy moment
moment_id > busy_moment_range_by_qubit[q][1]
and remaining_pulled_through_gate is not None
):
updated_op = _update_pulled_through(q, remaining_pulled_through_gate)
elif (
remaining_pulled_through_gate is not None
): # merge pulled-through of q to op_at_q if needed
if _need_merge_pulled_through(
op_at_q, moment_id == busy_moment_range_by_qubit[q][1]
):
remaining_op = _update_pulled_through(q, remaining_pulled_through_gate)
updated_op = _merge_single_qubit_ops_to_phxz(q, (remaining_op, op_at_q))
if updated_op is not None:
updated_moment_ops.add(updated_op)

if updated_moment_ops:
updated_moment = Moment(updated_moment_ops)
transformed_moments.append(updated_moment)

# Step 3, update pulled through.
# In detail: pulling current `pulled_through` through updated_moment.
pulled_through = _calc_pulled_through(updated_moment, pulled_through)

# Insert a new moment if there are remaining pulled-through operations.
ending_moment_ops = []
for affected_q, combined_op_in_pauli in pulled_through.items():
ending_moment_ops.append(combined_op_in_pauli.on(affected_q))
if ending_moment_ops:
transformed_moments.append(Moment(ending_moment_ops))
new_op_at_q = moment.operation_at(q)
if repr.gate_types[q][moment_id] == _CircuitRepr._GateType.INSERTABLE_GATE:
new_gate = next(dd_iter_by_qubits[q])
new_op_at_q = new_gate.on(q)
pulled_through *= pauli_map[new_gate].on(q)
if repr.need_to_stop[q][moment_id]:
to_be_merged = pulled_through.get(q)
if to_be_merged is not None:
new_op_at_q = _merge_single_qubit_ops_to_phxz(
q, [to_be_merged, new_op_at_q or ops.I(q)]
)
pulled_through *= to_be_merged.on(q)
if new_op_at_q is not None:
updated_moment_ops.add(new_op_at_q)

updated_moment = Moment(updated_moment_ops)
clifford_ops = [op for op in updated_moment if _is_clifford_op(op)]
pulled_through = pulled_through.after(clifford_ops)
transformed_moments.append(updated_moment)

# DO NOT SUBMIT
# if pulled_through.qubits() is not None:
# raise RuntimeError("Expect empty pulled through after propogating all moments.")

return Circuit.from_moments(*transformed_moments)
Loading
Loading