diff --git a/cirq-core/cirq/transformers/__init__.py b/cirq-core/cirq/transformers/__init__.py index e3d1c9a0d35..a6e37eb0882 100644 --- a/cirq-core/cirq/transformers/__init__.py +++ b/cirq-core/cirq/transformers/__init__.py @@ -150,6 +150,7 @@ SpinInversionGaugeTransformer as SpinInversionGaugeTransformer, SqrtCZGaugeTransformer as SqrtCZGaugeTransformer, SqrtISWAPGaugeTransformer as SqrtISWAPGaugeTransformer, + CPhaseGaugeTransformerMM as CPhaseGaugeTransformerMM, ) from cirq.transformers.randomized_measurements import ( diff --git a/cirq-core/cirq/transformers/gauge_compiling/__init__.py b/cirq-core/cirq/transformers/gauge_compiling/__init__.py index f67eb6ee409..fe783f70961 100644 --- a/cirq-core/cirq/transformers/gauge_compiling/__init__.py +++ b/cirq-core/cirq/transformers/gauge_compiling/__init__.py @@ -42,3 +42,11 @@ from cirq.transformers.gauge_compiling.cphase_gauge import ( CPhaseGaugeTransformer as CPhaseGaugeTransformer, ) + +from cirq.transformers.gauge_compiling.multi_moment_gauge_compiling import ( + MultiMomentGaugeTransformer as MultiMomentGaugeTransformer, +) + +from cirq.transformers.gauge_compiling.multi_moment_cphase_gauge import ( + CPhaseGaugeTransformerMM as CPhaseGaugeTransformerMM, +) diff --git a/cirq-core/cirq/transformers/gauge_compiling/multi_moment_cphase_gauge.py b/cirq-core/cirq/transformers/gauge_compiling/multi_moment_cphase_gauge.py new file mode 100644 index 00000000000..de08ff06284 --- /dev/null +++ b/cirq-core/cirq/transformers/gauge_compiling/multi_moment_cphase_gauge.py @@ -0,0 +1,245 @@ +# Copyright 2025 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A Multi-Moment Gauge Transformer for the cphase gate.""" + +from __future__ import annotations + +from typing import cast + +import numpy as np + +from cirq import circuits, ops +from cirq.transformers.gauge_compiling.multi_moment_gauge_compiling import ( + MultiMomentGaugeTransformer, +) + + +class _PauliAndZPow: + """In pulling through, one qubit gate can be represented by a Pauli and an Rz gate. + The order is --Pauli--ZPowGate--. + """ + + pauli: ops.Pauli | ops.IdentityGate = ops.I + zpow: ops.ZPowGate = ops.ZPowGate(exponent=0) + + commuting_gates = {ops.I, ops.Z} # I,Z Commute with ZPowGate and CZPowGate; X,Y anti-commute. + + def __init__( + self, + pauli: ops.Pauli | ops.IdentityGate = ops.I, + zpow: ops.ZPowGate = ops.ZPowGate(exponent=0), + ) -> None: + self.pauli = pauli + self.zpow = zpow + + def _merge_left_zpow(self, left: ops.ZPowGate): + """Merges ZPowGate from left.""" + if self.pauli in self.commuting_gates: + self.zpow = ops.ZPowGate(exponent=left.exponent + self.zpow.exponent) + else: + self.zpow = ops.ZPowGate(exponent=-left.exponent + self.zpow.exponent) + + def _merge_right_zpow(self, right: ops.ZPowGate): + """Merges ZPowGate from right.""" + self.zpow = ops.ZPowGate(exponent=right.exponent + self.zpow.exponent) + + def _merge_left_pauli(self, left: ops.Pauli): + """Merges --left_pauli--self--.""" + if self.pauli == ops.I: + self.pauli = left + else: + self.pauli = left.phased_pauli_product(self.pauli)[1] + + def _merge_right_pauli(self, right: ops.Pauli): + """Merges --self--right_pauli--.""" + if self.pauli == ops.I: + self.pauli = right + else: + self.pauli = right.phased_pauli_product(self.pauli)[1] + if right not in self.commuting_gates: + self.zpow = ops.ZPowGate(exponent=-self.zpow.exponent) + + def merge_left(self, left: _PauliAndZPow) -> None: + """Inplace merge other from left.""" + self._merge_left_zpow(left.zpow) + if left.pauli != ops.I: + self._merge_left_pauli(cast(ops.Pauli, left.pauli)) + + def merge_right(self, right: _PauliAndZPow) -> None: + """Inplace merge other from right.""" + if right.pauli != ops.I: + self._merge_right_pauli(cast(ops.Pauli, right.pauli)) + self._merge_right_zpow(right.zpow) + + def after_cphase( + self, cphase: ops.CZPowGate + ) -> tuple[ops.CZPowGate, _PauliAndZPow, _PauliAndZPow]: + """Pull self through cphase. + + Returns: + A tuple of + (updated cphase gate, pull_through of this qubit, pull_through of the other qubit). + """ + if self.pauli in self.commuting_gates: + return cphase, self, _PauliAndZPow() + else: + # Taking self.pauli==X gate as an example: + # 0: ─X─Z^t──@────── 0: ─X──@─────Z^t─ 0: ─@──────X──Z^t── + # │ ==> │ ==> │ + # 1: ────────@^exp── 1: ────@^exp───── 1: ─@^-exp─Z^exp─── + # Similarly for X|Y on qubit 0/1, the result is always flipping cphase and + # add an extra Rz rotation on the other qubit. + return ( + cast(ops.CZPowGate, cphase**-1), + self, + _PauliAndZPow(zpow=ops.ZPowGate(exponent=cphase.exponent)), + ) + + def after_pauli(self, pauli: ops.Pauli | ops.IdentityGate) -> _PauliAndZPow: + """Calculates ─self─pauli─ ==> ─pauli─output─.""" + if pauli in self.commuting_gates: + return _PauliAndZPow(self.pauli, self.zpow) + else: + return _PauliAndZPow(self.pauli, ops.ZPowGate(exponent=-self.zpow.exponent)) + + def after_zpow(self, zpow: ops.ZPowGate) -> tuple[ops.ZPowGate, _PauliAndZPow]: + """Calculates ─self─zpow─ ==> ─zpow'─output─.""" + if self.pauli in self.commuting_gates: + return zpow, self + else: + return ops.ZPowGate(exponent=-zpow.exponent), self + + def __str__(self) -> str: + return f"─{self.pauli}──{self.zpow}─" + + def to_single_qubit_gate(self) -> ops.PhasedXZGate | ops.ZPowGate | ops.IdentityGate: + """Converts the _PhasedXYAndRz to a single-qubit gate.""" + exp = self.zpow.exponent + match self.pauli: + case ops.I: + if exp % 2 == 0: + return ops.I + return self.zpow + case ops.X: + return ops.PhasedXZGate(x_exponent=1, z_exponent=exp, axis_phase_exponent=0) + case ops.Y: + return ops.PhasedXZGate(x_exponent=1, z_exponent=exp - 1, axis_phase_exponent=0) + case _: # ops.Z + if (exp + 1) % 2 == 0: + return ops.I + return ops.ZPowGate(exponent=1 + exp) + + +def _pull_through_single_cphase( + cphase: ops.CZPowGate, input0: _PauliAndZPow, input1: _PauliAndZPow +) -> tuple[ops.CZPowGate, _PauliAndZPow, _PauliAndZPow]: + """Pulls input0 and input1 through a CZPowGate. + Input: + 0: ─(input0)─@───── + │ + 1: ─(input1)─@^exp─ + Output: + 0: ─@────────(output0)─ + │ + 1: ─@^+/-exp─(output1)─ + """ + + # Step 1; pull input0 through CZPowGate. + # 0: ─input0─@───── 0: ────────@─────────output0─ + # │ ==> │ + # 1: ─input1─@^exp─ 1: ─input1─@^+/-exp──output1─ + output_cphase, output0, output1 = input0.after_cphase(cphase) + + # Step 2; similar to step 1, pull input1 through CZPowGate. + # 0: ─@──────────pulled0────output0─ 0: ─@────────output0─ + # ==> │ ==> │ + # 1: ─@^+/-exp───pulled1────output1─ 1: ─@^+/-exp─output1─ + output_cphase, pulled1, pulled0 = input1.after_cphase(output_cphase) + output0.merge_left(pulled0) + output1.merge_left(pulled1) + + return output_cphase, output0, output1 + + +_TARGET_GATESET: ops.Gateset = ops.Gateset(ops.CZPowGate) +_SUPPORTED_GATESET: ops.Gateset = ops.Gateset(ops.Pauli, ops.IdentityGate, ops.Rz, ops.ZPowGate) + + +class CPhaseGaugeTransformerMM(MultiMomentGaugeTransformer): + + def __init__(self, supported_gates=_SUPPORTED_GATESET): + super().__init__(target=_TARGET_GATESET, supported_gates=supported_gates) + + def sample_left_moment( + self, active_qubits: frozenset[ops.Qid], rng: np.random.Generator = np.random.default_rng() + ) -> circuits.Moment: + return circuits.Moment( + [ + rng.choice( + np.array([ops.I, ops.X, ops.Y, ops.Z], dtype=ops.Gate), + p=[0.25, 0.25, 0.25, 0.25], + ).on(q) + for q in active_qubits + ] + ) + + def gauge_on_moments(self, moments_to_gauge) -> list[circuits.Moment]: + active_qubits = circuits.Circuit.from_moments(*moments_to_gauge).all_qubits() + left_moment = self.sample_left_moment(active_qubits) + pulled: dict[ops.Qid, _PauliAndZPow] = { + op.qubits[0]: _PauliAndZPow(pauli=cast(ops.Pauli | ops.IdentityGate, op.gate)) + for op in left_moment + if op.gate + } + ret: list[circuits.Moment] = [left_moment] + # The loop iterates through each moment of the target block, propagating + # the `pulled` gauge from left to right. In each iteration, `prev` holds + # the gauge to the left of the current `moment`, and the loop computes + # the transformed `moment` and the new `pulled` gauge to its right. + for moment in moments_to_gauge: + # Calculate --prev--moment-- ==> --updated_momment--pulled-- + prev = pulled + pulled = {} + ops_at_updated_moment: list[ops.Operation] = [] + for op in moment: + # Pull prev through ops at the moment. + if op.gate: + match op.gate: + case ops.CZPowGate(): + q0, q1 = op.qubits + new_gate, pulled[q0], pulled[q1] = _pull_through_single_cphase( + op.gate, prev[q0], prev[q1] + ) + ops_at_updated_moment.append(new_gate.on(q0, q1)) + case ops.Pauli() | ops.IdentityGate(): + q = op.qubits[0] + ops_at_updated_moment.append(op) + pulled[q] = prev[q].after_pauli(op.gate) + case ops.ZPowGate(): + q = op.qubits[0] + new_zpow, pulled[q] = prev[q].after_zpow(op.gate) + ops_at_updated_moment.append(new_zpow.on(q)) + case _: + raise ValueError(f"Gate type {type(op.gate)} is not supported.") + # Keep the other ops of prev + for q, gate in prev.items(): + if q not in pulled: + pulled[q] = gate + ret.append(circuits.Moment(ops_at_updated_moment)) + last_moment = circuits.Moment( + [gate.to_single_qubit_gate().on(q) for q, gate in pulled.items()] + ) + ret.append(last_moment) + return ret diff --git a/cirq-core/cirq/transformers/gauge_compiling/multi_moment_cphase_gauge_test.py b/cirq-core/cirq/transformers/gauge_compiling/multi_moment_cphase_gauge_test.py new file mode 100644 index 00000000000..18d20076de6 --- /dev/null +++ b/cirq-core/cirq/transformers/gauge_compiling/multi_moment_cphase_gauge_test.py @@ -0,0 +1,242 @@ +# Copyright 2025 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from copy import deepcopy +from unittest.mock import patch + +import numpy as np +import pytest + +import cirq +from cirq import I, X, Y, Z, ZPowGate +from cirq.transformers.gauge_compiling.multi_moment_cphase_gauge import ( + _PauliAndZPow, + CPhaseGaugeTransformerMM, +) + + +def test_gauge_on_single_cphase(): + """Test case. + Input: + 0: ───@─────── + │ + 1: ───@^0.2─── + Example output: + 0: ───X───@────────PhXZ(a=0,x=1,z=0)─── + │ + 1: ───I───@^-0.2───Z^0.2─────────────── + """ + q0, q1 = cirq.LineQubit.range(2) + + input_circuit = cirq.Circuit(cirq.Moment(cirq.CZ(q0, q1) ** 0.2)) + cphase_transformer = CPhaseGaugeTransformerMM() + + for g1 in [X, Y, Z, I]: + for g2 in [X, Y, Z, I]: # Test with all possible samples of the left moment. + with patch.object( + cphase_transformer, "sample_left_moment", return_value=[g1(q0), g2(q1)] + ): + output_circuit = cphase_transformer(input_circuit) + cirq.testing.assert_circuits_have_same_unitary_given_final_permutation( + input_circuit, output_circuit, {q: q for q in input_circuit.all_qubits()} + ) + + +def test_gauge_on_cz_moments(): + """Test case. + Input: + ┌──┐ + 0: ───@────@─────H───────@───@─── + │ │ │ │ + 1: ───@────┼@────────────@───@─── + ││ + 2: ───@────@┼────────@───@───@─── + │ │ │ │ │ + 3: ───@─────@────────@───@───@─── + └──┘ + Example output: + ┌──┐ + 0: ───X───@────@─────PhXZ(a=0,x=1,z=1)──────H───X───────@───@───PhXZ(a=0,x=1,z=2)──── + │ │ │ │ + 1: ───I───@────┼@────Z──────────────────────────X───────@───@───PhXZ(a=2,x=1,z=-2)─── + ││ + 2: ───Y───@────@┼────PhXZ(a=1.5,x=1,z=-1)───────Z───@───@───@───Z──────────────────── + │ │ │ │ │ + 3: ───Z───@─────@────Z^0────────────────────────I───@───@───@───Z^0────────────────── + └──┘ + """ + q0, q1, q2, q3 = cirq.LineQubit.range(4) + input_circuit = cirq.Circuit( + cirq.Moment(cirq.CZ(q0, q1), cirq.CZ(q2, q3)), + cirq.Moment(cirq.CZ(q0, q2), cirq.CZ(q1, q3)), + cirq.Moment(cirq.H(q0)), + cirq.Moment(cirq.CZ(q2, q3)), + cirq.Moment(cirq.CZ(q0, q1), cirq.CZ(q2, q3)), + cirq.Moment(cirq.CZ(q0, q1), cirq.CZ(q2, q3)), + ) + transformer = CPhaseGaugeTransformerMM() + + output_circuit = transformer(input_circuit) + cirq.testing.assert_circuits_have_same_unitary_given_final_permutation( + input_circuit, output_circuit, {q: q for q in input_circuit.all_qubits()} + ) + + +def test_is_target_moment(): + q0, q1, q2 = cirq.LineQubit.range(3) + + target_moments = [ + cirq.Moment(cirq.CZ(q0, q1) ** 0.2), + cirq.Moment(cirq.CZ(q0, q1) ** 0.2, cirq.X(q2)), + ] + non_target_moments = [ + cirq.Moment(cirq.X(q0), cirq.Y(q1)), + cirq.Moment(cirq.CZ(q0, q1) ** 0.2, cirq.Rz(rads=-0.8).on(q2)), + cirq.Moment(cirq.CZ(q0, q1).with_tags("ignore")), + ] + cphase_transformer = CPhaseGaugeTransformerMM(supported_gates=cirq.Gateset(cirq.Pauli)) + for m in target_moments: + assert cphase_transformer.is_target_moment(m) + for m in non_target_moments: + assert not cphase_transformer.is_target_moment( + m, cirq.TransformerContext(tags_to_ignore={'ignore'}) + ) + + +def test_gauge_on_cphase_moments(): + """Test case. + Input: + ┌──┐ + 0: ───@────────@─────H───Rz(-0.255π)───────────@───────@─────── + │ │ │ │ + 1: ───@^0.2────┼@──────────────────────────────@^0.1───@─────── + ││ + 2: ───@────────@┼────────@─────────────@───────@───────@─────── + │ │ │ │ │ │ + 3: ───@─────────@────────@^0.2─────────@^0.2───@───────@^0.2─── + └──┘ + Example output: + ┌──┐ + 0: ───Y───@─────────@─────PhXZ(a=0,x=1,z=0)───H───X───Rz(0.255π)────────────@───────@────────PhXZ(a=0,x=1,z=1.1)─── + │ │ │ │ + 1: ───I───@^-0.2────┼@────Z^0.2───────────────────Y─────────────────────────@^0.1───@────────PhXZ(a=0,x=1,z=0.1)─── + ││ + 2: ───X───@─────────@┼────PhXZ(a=0,x=1,z=1)───────X───@────────────@────────@───────@────────PhXZ(a=0,x=1,z=0)───── + │ │ │ │ │ │ + 3: ───Z───@──────────@────I───────────────────────I───@^-0.2───────@^-0.2───@───────@^-0.2───Z^-0.4──────────────── + └──┘ + """ # noqa: E501 + q0, q1, q2, q3 = cirq.LineQubit.range(4) + cphase_transformer = CPhaseGaugeTransformerMM() + for _ in range(5): + input_circuit = cirq.Circuit( + cirq.Moment(cirq.CZ(q0, q1) ** 0.2, cirq.CZ(q2, q3)), + cirq.Moment(cirq.CZ(q0, q2), cirq.CZ(q1, q3)), + cirq.Moment(cirq.H(q0)), + cirq.Moment(cirq.CZ(q2, q3) ** 0.2, cirq.Rz(rads=-0.8).on(q0)), + cirq.Moment(cirq.CZ(q2, q3) ** 0.2), + cirq.Moment(cirq.CZ(q0, q1) ** 0.1, cirq.CZ(q2, q3)), + cirq.Moment(cirq.CZ(q0, q1), cirq.CZ(q2, q3) ** 0.2), + ) + + output_circuit = cphase_transformer(input_circuit) + cirq.testing.assert_circuits_have_same_unitary_given_final_permutation( + input_circuit, output_circuit, {q: q for q in input_circuit.all_qubits()} + ) + + +def test_gauge_on_czpow_only_moments(): + q0, q1, q2 = cirq.LineQubit.range(3) + + input_circuit = cirq.Circuit(cirq.Moment(cirq.CZ(q0, q1) ** 0.2, X(q2))) + cphase_transformer = CPhaseGaugeTransformerMM(supported_gates=cirq.Gateset()) + output_circuit = cphase_transformer(input_circuit) + + # Since X isn't in supported_gates, the moment won't be gauged. + assert input_circuit == output_circuit + + +def test_gauge_on_supported_gates(): + q0, q1, q2, q3 = cirq.LineQubit.range(4) + cphase_transformer = CPhaseGaugeTransformerMM() + for g1 in [X, Z**0.6, I, Z]: + for g2 in [Y, cirq.Rz(rads=0.2), Z**0.7]: + input_circuit = cirq.Circuit( + cirq.Moment(cirq.CZ(q0, q1) ** 0.2, g1(q2), g2(q3)), + cirq.Moment(cirq.CZ(q0, q2), g2(q1), g1(q3)), + ) + output_circuit = cphase_transformer(input_circuit) + cirq.testing.assert_circuits_have_same_unitary_given_final_permutation( + input_circuit, output_circuit, {q: q for q in input_circuit.all_qubits()} + ) + + +def test_gauge_on_unsupported_gates(): + q0, q1, q2, q3 = cirq.LineQubit.range(4) + + cphase_transformer = CPhaseGaugeTransformerMM(supported_gates=cirq.Gateset(cirq.CNOT)) + with pytest.raises(ValueError, match="Gate type .* is not supported."): + cphase_transformer(cirq.Circuit(cirq.CNOT(q0, q1), cirq.CZ(q2, q3))) + + +def test_pauli_and_phxz_util_str(): + assert str(_PauliAndZPow(pauli=X)) == '─X──Z**0─' + assert str(_PauliAndZPow(pauli=X, zpow=Z**0.1)) == '─X──Z**0.1─' + + +def test_pauli_and_phxz_util_gate_merges(): + """Tests _PauliAndZPow's merge_left() and merge_right().""" + for left_pauli in [X, Y, Z, I]: + for right_pauli in [X, Y, Z, I]: + left = _PauliAndZPow(pauli=left_pauli, zpow=ZPowGate(exponent=0.2)) + right = _PauliAndZPow(pauli=right_pauli, zpow=ZPowGate(exponent=0.6)) + merge1 = deepcopy(right) + merge1.merge_left(left) + merge2 = deepcopy(left) + merge2.merge_right(right) + + assert np.allclose( + cirq.unitary(merge1.to_single_qubit_gate()), + cirq.unitary(merge2.to_single_qubit_gate()), + ) + q = cirq.LineQubit(0) + cirq.testing.assert_allclose_up_to_global_phase( + cirq.unitary( + cirq.Circuit( + left.to_single_qubit_gate().on(q), right.to_single_qubit_gate().on(q) + ) + ), + cirq.unitary(merge1.to_single_qubit_gate()), + atol=1e-6, + ) + + +def test_pauli_and_phxz_util_to_1q_gate(): + """Tests _PauliAndZPow.to_single_qubit_gate().""" + q = cirq.LineQubit(0) + for pauli in [cirq.X, cirq.Y, cirq.Z, cirq.I]: + for zpow in [cirq.ZPowGate(exponent=exp) for exp in [0, 0.1, 0.5, 1, 10.2]]: + cirq.testing.assert_circuits_have_same_unitary_given_final_permutation( + cirq.Circuit(pauli(q), zpow(q)), + cirq.Circuit(_PauliAndZPow(pauli=pauli, zpow=zpow).to_single_qubit_gate().on(q)), + {q: q}, + ) + + +def test_deep_not_supported(): + with pytest.raises(ValueError, match="GaugeTransformer cannot be used with deep=True"): + t = CPhaseGaugeTransformerMM() + t(cirq.Circuit(), context=cirq.TransformerContext(deep=True)) diff --git a/cirq-core/cirq/transformers/gauge_compiling/multi_moment_gauge_compiling.py b/cirq-core/cirq/transformers/gauge_compiling/multi_moment_gauge_compiling.py new file mode 100644 index 00000000000..2461958e04e --- /dev/null +++ b/cirq-core/cirq/transformers/gauge_compiling/multi_moment_gauge_compiling.py @@ -0,0 +1,135 @@ +# Copyright 2025 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Defines the abstraction for multi-moment gauge compiling as a cirq transformer.""" + +import abc + +import numpy as np + +from cirq import circuits, ops +from cirq.transformers import transformer_api + + +@transformer_api.transformer +class MultiMomentGaugeTransformer(abc.ABC): + """A gauge transformer that wraps target blocks of moments with single-qubit gates. + + In detail, a "gauging moment" of single-qubit gates is inserted before a target block of + moments. These gates are then commuted through the block, resulting in a corresponding + moment of gates after it. + + q₀: ... ───LG0───╭───────────╮────RG0───... + │ │ + q₁: ... ───LG1───┤ moments ├────RG1───... + │ to be │ + q₂: ... ───LG2───┤ gauged on ├────RG2───... + │ │ + q₃: ... ───LG3───╰───────────╯────RG3───... + """ + + def __init__( + self, + target: ops.Gate | ops.Gateset | ops.GateFamily, + supported_gates: ops.Gateset = ops.Gateset(), + ) -> None: + """Constructs a MultiMomentGaugeTransformer. + + Args: + target: Specifies the two-qubit gates, gate families, or gate sets that will + be targeted during gauge compiling. The gauge moment must contain at least + one of the target gates. + supported_gates: Determines what other gates, in addition to the target gates, + are permitted within the gauge moments. If a moment contains a gate not found + in either target or supported_gates, it won't be gauged. + """ + self.target = ops.GateFamily(target) if isinstance(target, ops.Gate) else target + self.supported_gates = ( + ops.GateFamily(supported_gates) + if isinstance(supported_gates, ops.Gate) + else supported_gates + ) + + @abc.abstractmethod + def gauge_on_moments(self, moments_to_gauge: list[circuits.Moment]) -> list[circuits.Moment]: + """Gauges a block of moments. + + Args: + moments_to_gauge: A list of moments to be gauged. + + Returns: + A list of moments after gauging. + """ + + @abc.abstractmethod + def sample_left_moment( + self, active_qubits: frozenset[ops.Qid], rng: np.random.Generator + ) -> circuits.Moment: + """Samples a random single-qubit moment to be inserted before the target block. + + Args: + active_qubits: The qubits on which the sampled gates should be applied. + rng: A pseudorandom number generator. + + Returns: + The sampled moment. + """ + + def is_target_moment( + self, moment: circuits.Moment, context: transformer_api.TransformerContext | None = None + ) -> bool: + """Checks if a moment is a target for gauging. + + A moment is a target moment if it contains at least one target op and + all its operations are supported by this transformer. + """ + has_target_gates: bool = False + for op in moment: + if ( + context + and isinstance(op, ops.TaggedOperation) + and set(op.tags).intersection(context.tags_to_ignore) + ): # skip the moment if the op is tagged with a tag in tags_to_ignore + return False + if op.gate: + if op in self.target: + has_target_gates = True + elif op not in self.supported_gates: + return False + return has_target_gates + + def __call__( + self, + circuit: circuits.AbstractCircuit, + *, + context: transformer_api.TransformerContext | None = None, + ) -> circuits.AbstractCircuit: + if context is None: + context = transformer_api.TransformerContext(deep=False) + if context.deep: + raise ValueError('GaugeTransformer cannot be used with deep=True') + output_moments: list[circuits.Moment] = [] + moments_to_gauge: list[circuits.Moment] = [] + for moment in circuit: + if self.is_target_moment(moment, context): + moments_to_gauge.append(moment) + else: + if moments_to_gauge: + output_moments.extend(self.gauge_on_moments(moments_to_gauge)) + moments_to_gauge.clear() + output_moments.append(moment) + if moments_to_gauge: + output_moments.extend(self.gauge_on_moments(moments_to_gauge)) + + return circuits.Circuit.from_moments(*output_moments)