diff --git a/cirq-core/cirq/transformers/__init__.py b/cirq-core/cirq/transformers/__init__.py index e3d1c9a0d35..ce718222d6c 100644 --- a/cirq-core/cirq/transformers/__init__.py +++ b/cirq-core/cirq/transformers/__init__.py @@ -160,3 +160,6 @@ from cirq.transformers.insertion_sort import ( insertion_sort_transformer as insertion_sort_transformer, ) + + +from cirq.transformers.pauli_insertion import PauliInsertionTransformer as PauliInsertionTransformer diff --git a/cirq-core/cirq/transformers/pauli_insertion.py b/cirq-core/cirq/transformers/pauli_insertion.py new file mode 100644 index 00000000000..df8188dc127 --- /dev/null +++ b/cirq-core/cirq/transformers/pauli_insertion.py @@ -0,0 +1,133 @@ +# Copyright 2025 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A pauli insertion transformer.""" + +from __future__ import annotations + +import inspect +from collections.abc import Mapping + +import numpy as np + +from cirq import circuits, ops +from cirq.transformers import transformer_api + +_PAULIS: tuple[ops.Gate, ops.Gate, ops.Gate, ops.Gate] = (ops.I, ops.X, ops.Y, ops.Z) # type: ignore[has-type] + + +@transformer_api.transformer +class PauliInsertionTransformer: + r"""Creates a pauli insertion transformer. + + A pauli insertion operation samples paulis from $\{I, X, Y, Z\}^2$ with the given + probabilities and adds it before the target 2Q gate/operation. This procedure is commonly + used in zero noise extrapolation (ZNE), see appendix D of https://arxiv.org/abs/2503.20870. + """ + + def __init__( + self, + target: ops.Gate | ops.GateFamily | ops.Gateset | type[ops.Gate], + probabilities: np.ndarray | Mapping[tuple[ops.Qid, ops.Qid], np.ndarray] | None = None, + ): + """Makes a pauli insertion transformer that samples 2Q paulis with the given probabilities. + + Args: + target: The target gate, gatefamily, gateset, or type (e.g. ZZPowGAte). + probabilities: Optional ndarray or mapping[qubit-pair, nndarray] representing the + probabilities of sampling 2Q paulis. The order of the paulis is IXYZ. + If at operation `op` a pair (i, j) is sampled then _PAULIS[i] is applied + to op.qubits[0] and _PAULIS[j] is applied to op.qubits[1]. + If None, assume uniform distribution. + """ + if probabilities is None: + probabilities = np.ones((4, 4)) / 16 + elif isinstance(probabilities, dict): + probabilities = {k: np.asarray(v) for k, v in probabilities.items()} + for probs in probabilities.values(): + assert np.isclose(probs.sum(), 1) + assert probs.shape == (4, 4) + else: + probabilities = np.asarray(probabilities) + assert np.isclose(probabilities.sum(), 1) + assert probabilities.shape == (4, 4) + self.probabilities = probabilities + + if inspect.isclass(target): + self.target: ops.GateFamily | ops.Gateset = ops.GateFamily(target) + elif isinstance(target, ops.Gate): + self.target = ops.Gateset(target) + else: + assert isinstance(target, (ops.Gateset, ops.GateFamily)) + self.target = target + + def _is_target(self, op: ops.Operation) -> bool: + if isinstance(self.probabilities, dict) and op.qubits not in self.probabilities: + return False + return op in self.target + + def _sample( + self, qubits: tuple[ops.Qid, ...], rng: np.random.Generator + ) -> tuple[ops.Gate, ops.Gate]: + if isinstance(self.probabilities, dict): + assert len(qubits) == 2 + flat_probs = self.probabilities[qubits].reshape(-1) + else: + flat_probs = self.probabilities.reshape(-1) + i, j = np.unravel_index(rng.choice(16, p=flat_probs), (4, 4)) + return _PAULIS[i], _PAULIS[j] + + def __call__( + self, + circuit: circuits.AbstractCircuit, + *, + rng_or_seed: np.random.Generator | int | None = None, + context: transformer_api.TransformerContext | None = None, + ): + context = ( + context + if isinstance(context, transformer_api.TransformerContext) + else transformer_api.TransformerContext() + ) + rng = ( + rng_or_seed + if isinstance(rng_or_seed, np.random.Generator) + else np.random.default_rng(rng_or_seed) + ) + + if context.deep: + raise ValueError(f"this transformer doesn't support deep {context=}") + + tags_to_ignore = frozenset(context.tags_to_ignore) + new_circuit: list[circuits.Moment] = [] + for moment in circuit: + if any(tag in tags_to_ignore for tag in moment.tags): + new_circuit.append(moment) + continue + new_moment = [] + for op in moment: + if any(tag in tags_to_ignore for tag in op.tags): + continue + if not self._is_target(op): + continue + pair = self._sample(op.qubits, rng) + for pauli, q in zip(pair, op.qubits): + if new_circuit and (q not in new_circuit[-1].qubits): + new_circuit[-1] += pauli(q) + else: + new_moment.append(pauli(q)) + if new_moment: + new_circuit.append(circuits.Moment(new_moment)) + new_circuit.append(moment) + return circuits.Circuit.from_moments(*new_circuit) diff --git a/cirq-core/cirq/transformers/pauli_insertion_test.py b/cirq-core/cirq/transformers/pauli_insertion_test.py new file mode 100644 index 00000000000..130c95eed9a --- /dev/null +++ b/cirq-core/cirq/transformers/pauli_insertion_test.py @@ -0,0 +1,104 @@ +# Copyright 2025 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pytest + +import cirq + +_PAULIS = [cirq.I, cirq.X, cirq.Y, cirq.Z] + + +def _random_probs(n: int, seed: int | None = None): + rng = np.random.default_rng(seed) + for _ in range(n): + probs = rng.random((4, 4)) + probs /= probs.sum() + yield probs + + +@pytest.mark.parametrize('probs', _random_probs(3, 0)) +@pytest.mark.parametrize( + 'target', + [cirq.ZZPowGate, cirq.ZZ**0.324, cirq.Gateset(cirq.ZZ**0.324), cirq.GateFamily(cirq.ZZ**0.324)], +) +def test_pauli_insertion_with_probabilities(probs, target): + c = cirq.Circuit(cirq.ZZ(*cirq.LineQubit.range(2)) ** 0.324) + transformer = cirq.transformers.PauliInsertionTransformer(target, probs) + count = np.zeros((4, 4)) + rng = np.random.default_rng(0) + for _ in range(100): + nc = transformer(c, rng_or_seed=rng) + assert len(nc) == 2 + u, v = nc[0] + i = _PAULIS.index(u.gate) + j = _PAULIS.index(v.gate) + count[i, j] += 1 + count = count / count.sum() + np.testing.assert_allclose(count, probs, atol=0.1) + + +@pytest.mark.parametrize('probs', _random_probs(3, 0)) +def test_pauli_insertion_with_probabilities_doesnot_create_moment(probs): + c = cirq.Circuit.from_moments([], [cirq.ZZ(*cirq.LineQubit.range(2)) ** 0.324]) + transformer = cirq.transformers.PauliInsertionTransformer(cirq.ZZPowGate, probs) + count = np.zeros((4, 4)) + rng = np.random.default_rng(0) + for _ in range(100): + nc = transformer(c, rng_or_seed=rng) + assert len(nc) == 2 + u, v = nc[0] + i = _PAULIS.index(u.gate) + j = _PAULIS.index(v.gate) + count[i, j] += 1 + count = count / count.sum() + np.testing.assert_allclose(count, probs, atol=0.1) + + +def test_invalid_context_raises(): + c = cirq.Circuit(cirq.ZZ(*cirq.LineQubit.range(2)) ** 0.324) + transformer = cirq.transformers.PauliInsertionTransformer(cirq.ZZPowGate) + with pytest.raises(ValueError): + _ = transformer(c, context=cirq.TransformerContext(deep=True)) + + +def test_transformer_ignores_tagged_ops(): + op = cirq.ZZ(*cirq.LineQubit.range(2)) ** 0.324 + c = cirq.Circuit(op.with_tags('ignore')) + transformer = cirq.transformers.PauliInsertionTransformer(cirq.ZZPowGate) + + assert transformer(c, context=cirq.TransformerContext(tags_to_ignore=('ignore',))) == c + + +def test_transformer_ignores_tagged_moments(): + op = cirq.ZZ(*cirq.LineQubit.range(2)) ** 0.324 + c = cirq.Circuit(cirq.Moment(op).with_tags('ignore')) + transformer = cirq.transformers.PauliInsertionTransformer(cirq.ZZPowGate) + + assert transformer(c, context=cirq.TransformerContext(tags_to_ignore=('ignore',))) == c + + +def test_transformer_ignores_with_probs_map(): + qs = tuple(cirq.LineQubit.range(3)) + op = cirq.ZZ(*qs[:2]) ** 0.324 + c = cirq.Circuit(cirq.Moment(op)) + transformer = cirq.transformers.PauliInsertionTransformer( + cirq.ZZPowGate, {qs[1:]: np.ones((4, 4)) / 16} + ) + + assert transformer(c) == c # qubits are not in target + + c = cirq.Circuit(cirq.Moment(op.with_qubits(*qs[1:]))) + nc = transformer(c) + assert len(nc) == 2