Skip to content

Commit 408506d

Browse files
add tests
1 parent f45f1a0 commit 408506d

File tree

3 files changed

+106
-28
lines changed

3 files changed

+106
-28
lines changed

cirq-core/cirq/transformers/gauge_compiling/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,8 @@
4242
from cirq.transformers.gauge_compiling.cphase_gauge import (
4343
CPhaseGaugeTransformer as CPhaseGaugeTransformer,
4444
)
45+
46+
47+
from cirq.transformers.gauge_compiling.idle_moments_gauge import (
48+
IdleMomentsGauge as IdleMomentsGauge,
49+
)

cirq-core/cirq/transformers/gauge_compiling/idle_moments_gague.py renamed to cirq-core/cirq/transformers/gauge_compiling/idle_moments_gauge.py

Lines changed: 19 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import attrs
1919
import numpy as np
2020

21+
import cirq.circuits as circuits
2122
import cirq.ops as ops
2223
import cirq.protocols as protocols
2324
import cirq.transformers.transformer_api as transformer_api
@@ -79,16 +80,16 @@ def _get_structure(
7980
yield (stop + 1, n - 1)
8081

8182

82-
def _merge(g1: cirq.Gate, g2: cirq.Gate) -> cirq.Gate:
83+
def _merge(g1: cirq.Gate, g2: cirq.Gate, q: cirq.Qid, tags: Sequence) -> cirq.Operation:
8384
u1 = protocols.unitary(g1)
8485
u2 = protocols.unitary(g2)
85-
return ops.PhasedXZGate.from_matrix(u2 @ u1)
86+
return ops.PhasedXZGate.from_matrix(u2 @ u1)(q).with_tags(*tags)
8687

8788

8889
@transformer_api.transformer
8990
@attrs.frozen
9091
class IdleMomentsGauge:
91-
"""A transformer that inserts identity-preserving "gauge" gates around idle qubit moments.
92+
r"""A transformer that inserts identity-preserving "gauge" gates around idle qubit moments.
9293
9394
This transformer identifies sequences of consecutive idle moments on a single qubit
9495
that meet a `min_length` threshold. For each such sequence, it inserts a randomly
@@ -104,8 +105,8 @@ class IdleMomentsGauge:
104105
105106
gauges: A sequence of `cirq.Gate` objects to randomly select from.
106107
Can be a custom tuple or a string alias:
107-
- `"pauli"`: Uses single-qubit Pauli gates (I, X, Y, Z).
108-
- `"clifford"`: Uses all 24 single-qubit Clifford gates.
108+
- `"pauli"`: Uses single-qubit Pauli gates (I, X, Y, Z).
109+
- `"clifford"`: Uses all 24 single-qubit Clifford gates.
109110
110111
gauges_inverse: An optional sequence of `cirq.Gate` objects representing
111112
the inverses of gates in `gauges`. The `k`-th gate in `gauges_inverse`
@@ -123,6 +124,7 @@ class IdleMomentsGauge:
123124
gauge_ending: If `True`, applies a gauge to idle moments at the circuit's end,
124125
after the last qubit operation. Defaults to `False`.
125126
"""
127+
126128
min_length: int = attrs.field(
127129
validator=(attrs.validators.instance_of(int), attrs.validators.ge(1))
128130
)
@@ -211,7 +213,7 @@ def __call__(
211213
for q in op.qubits:
212214
active_moments[q].append((m_id, is_mergable))
213215

214-
single_qubit_moments = [{q: op.gate for op in m if len(op.qubits) == 1} for m in circuit]
216+
single_qubit_moments = [{q: op for op in m if len(op.qubits) == 1} for m in circuit]
215217
non_single_qubit_moments = [[op for op in m if len(op.qubits) != 1] for m in circuit]
216218

217219
for q, active in active_moments.items():
@@ -220,34 +222,23 @@ def __call__(
220222
):
221223
gate_index = rng.choice(len(self.gauges))
222224
gate = self.gauges[gate_index]
223-
gate_inv = self.gauges[gate_index]
225+
gate_inv = self.gauges_inverse[gate_index]
224226

225-
if existing_gate := single_qubit_moments[s].get(q, None):
226-
single_qubit_moments[s][q] = _merge(existing_gate, gate)
227+
if existing_op := single_qubit_moments[s].get(q, None):
228+
single_qubit_moments[s][q] = _merge(existing_op.gate, gate, q, existing_op.tags)
227229
else:
228-
single_qubit_moments[s][q] = gate
230+
single_qubit_moments[s][q] = gate(q)
229231

230-
if existing_gate := single_qubit_moments[e].get(q, None):
231-
single_qubit_moments[e][q] = _merge(gate_inv, existing_gate)
232+
if existing_op := single_qubit_moments[e].get(q, None):
233+
single_qubit_moments[e][q] = _merge(
234+
gate_inv, existing_op.gate, q, existing_op.tags
235+
)
232236
else:
233-
single_qubit_moments[e][q] = gate_inv
237+
single_qubit_moments[e][q] = gate_inv(q)
234238

235-
return cirq.Circuit.from_moments(
239+
return circuits.Circuit.from_moments(
236240
*(
237-
[g(q) for q, g in sq.items()] + nsq
241+
[op for op in sq.values()] + nsq
238242
for sq, nsq in zip(single_qubit_moments, non_single_qubit_moments, strict=True)
239243
)
240244
)
241-
242-
243-
if __name__ == '__main__':
244-
tr = IdleMomentsGauge(2, gauges='pauli', gauge_beginning=True)
245-
print(tr)
246-
247-
import cirq
248-
249-
# c = cirq.Circuit.from_moments(cirq.X(cirq.q(0)), [], [], cirq.X(cirq.q(0)))
250-
# c = cirq.Circuit.from_moments([], [], cirq.X(cirq.q(0)))
251-
c = cirq.Circuit.from_moments([], [], cirq.X(cirq.q(0)).with_tags('ignore'))
252-
print(c)
253-
print(tr(c, context=cirq.TransformerContext(tags_to_ignore=("ignore",))))
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
import cirq
2+
from cirq.transformers import gauge_compiling as gc
3+
4+
5+
def test_add_gauge_merges_gates():
6+
tr = gc.IdleMomentsGauge(2, gauges='pauli')
7+
8+
circuit = cirq.Circuit.from_moments([], [], [], cirq.X(cirq.q(0)), [], [], cirq.X(cirq.q(0)))
9+
transformed_circuit = tr(circuit, rng_or_seed=0)
10+
11+
assert transformed_circuit == cirq.Circuit.from_moments(
12+
[],
13+
[],
14+
[],
15+
cirq.PhasedXZGate(axis_phase_exponent=0.5, x_exponent=1, z_exponent=0)(cirq.q(0)),
16+
[],
17+
[],
18+
cirq.PhasedXZGate(axis_phase_exponent=0.5, x_exponent=1, z_exponent=0)(cirq.q(0)),
19+
)
20+
21+
22+
def test_add_gauge_respects_ignore_tag():
23+
tr = gc.IdleMomentsGauge(2, gauges='pauli')
24+
25+
circuit = cirq.Circuit.from_moments(
26+
cirq.X(cirq.q(0)), [], [], cirq.X(cirq.q(0)).with_tags('ignore')
27+
)
28+
transformed_circuit = tr(
29+
circuit, context=cirq.TransformerContext(tags_to_ignore=("ignore",)), rng_or_seed=0
30+
)
31+
assert transformed_circuit == cirq.Circuit.from_moments(
32+
cirq.PhasedXZGate(axis_phase_exponent=0.5, x_exponent=1, z_exponent=0)(cirq.q(0)),
33+
[],
34+
cirq.Z(cirq.q(0)),
35+
cirq.X(cirq.q(0)).with_tags('ignore'),
36+
)
37+
38+
39+
def test_add_gauge_on_prefix():
40+
tr = gc.IdleMomentsGauge(3, gauges='clifford', gauge_beginning=True)
41+
42+
circuit = cirq.Circuit.from_moments([], [], [], cirq.CNOT(cirq.q(0), cirq.q(1)))
43+
transformed_circuit = tr(circuit, rng_or_seed=0)
44+
assert transformed_circuit == cirq.Circuit.from_moments(
45+
[
46+
cirq.SingleQubitCliffordGate.all_single_qubit_cliffords[20](cirq.q(0)),
47+
cirq.SingleQubitCliffordGate.all_single_qubit_cliffords[15](cirq.q(1)),
48+
],
49+
[],
50+
[
51+
cirq.SingleQubitCliffordGate.all_single_qubit_cliffords[20](cirq.q(0)) ** -1,
52+
cirq.SingleQubitCliffordGate.all_single_qubit_cliffords[15](cirq.q(1)) ** -1,
53+
],
54+
cirq.CNOT(cirq.q(0), cirq.q(1)),
55+
)
56+
57+
58+
def test_add_gauge_on_suffix():
59+
tr = gc.IdleMomentsGauge(3, gauges='inv_clifford', gauge_ending=True)
60+
61+
circuit = cirq.Circuit.from_moments(cirq.CNOT(cirq.q(0), cirq.q(1)), [], [], [])
62+
transformed_circuit = tr(circuit, rng_or_seed=0)
63+
assert transformed_circuit == cirq.Circuit.from_moments(
64+
cirq.CNOT(cirq.q(0), cirq.q(1)),
65+
[
66+
cirq.SingleQubitCliffordGate.all_single_qubit_cliffords[20](cirq.q(0)) ** -1,
67+
cirq.SingleQubitCliffordGate.all_single_qubit_cliffords[15](cirq.q(1)) ** -1,
68+
],
69+
[],
70+
[
71+
cirq.SingleQubitCliffordGate.all_single_qubit_cliffords[20](cirq.q(0)),
72+
cirq.SingleQubitCliffordGate.all_single_qubit_cliffords[15](cirq.q(1)),
73+
],
74+
)
75+
76+
77+
def test_add_gauge_respects_min_length():
78+
tr = gc.IdleMomentsGauge(2, gauges=[cirq.X])
79+
80+
circuit = cirq.Circuit.from_moments(cirq.X(cirq.q(0)), [], cirq.X(cirq.q(0)))
81+
transformed_circuit = tr(circuit, rng_or_seed=0)
82+
assert transformed_circuit == circuit

0 commit comments

Comments
 (0)