diff --git a/cirq-core/cirq/transformers/connected_component.py b/cirq-core/cirq/transformers/connected_component.py new file mode 100644 index 00000000000..65365ddd3b5 --- /dev/null +++ b/cirq-core/cirq/transformers/connected_component.py @@ -0,0 +1,288 @@ +# Copyright 2025 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Defines a connected component of operations, to be used in merge transformers.""" + +from __future__ import annotations + +from typing import Callable, cast, Sequence, TYPE_CHECKING + +from typing_extensions import override + +from cirq import ops, protocols + +if TYPE_CHECKING: + import cirq + + +class Component: + """Internal representation for a connected component of operations. + + It uses the disjoint-set data structure to implement merge efficiently. + Additional merge conditions can be added by deriving from the Component + class and overriding the merge function (see ComponentWithOps and + ComponentWithCircuitOp) below. + """ + + # Properties for the disjoint set data structure + parent: Component | None = None + rank: int = 0 + + # True if the component can be merged + is_mergeable: bool + + # Circuit moment containing the component + moment_id: int + # Union of all op qubits in the component + qubits: frozenset[cirq.Qid] + # Union of all measurement keys in the component + mkeys: frozenset[cirq.MeasurementKey] + # Union of all control keys in the component + ckeys: frozenset[cirq.MeasurementKey] + # Initial operation in the component + op: cirq.Operation + + def __init__(self, op: cirq.Operation, moment_id: int, is_mergeable=True): + """Initializes a singleton component.""" + self.is_mergeable = is_mergeable + self.moment_id = moment_id + self.qubits = frozenset(op.qubits) + self.mkeys = protocols.measurement_key_objs(op) + self.ckeys = protocols.control_keys(op) + self.op = op + + def find(self) -> Component: + """Finds the component representative.""" + + root = self + while root.parent is not None: + root = root.parent + x = self + while x != root: + parent = x.parent + x.parent = root + x = cast(Component, parent) + return root + + def merge(self, c: Component, merge_left=True) -> Component | None: + """Attempts to merge two components. + + If merge_left is True, c is merged into this component, and the representative + will keep this component's moment. If merge_left is False, this component is + merged into c, and the representative will keep c's moment. + + Args: + c: other component to merge + merge_left: True to keep self's moment for the merged component, False to + keep c's moment for the merged component. + + Returns: + None, if the components can't be merged. + Otherwise the new component representative. + """ + x = self.find() + y = c.find() + + if not x.is_mergeable or not y.is_mergeable: + return None + + if x == y: + return x + + if x.rank < y.rank: + if merge_left: + # As y will be the new representative, copy moment id from x + y.moment_id = x.moment_id + x, y = y, x + elif not merge_left: + # As x will be the new representative, copy moment id from y + x.moment_id = y.moment_id + + y.parent = x + if x.rank == y.rank: + x.rank += 1 + + x.qubits = x.qubits.union(y.qubits) + x.mkeys = x.mkeys.union(y.mkeys) + x.ckeys = x.ckeys.union(y.ckeys) + return x + + +class ComponentWithOps(Component): + """Component that keeps track of operations. + + Encapsulates a method can_merge that is used to decide if two components + can be merged. + """ + + # List of all operations in the component + ops: list[cirq.Operation] + + # Method to decide if two components can be merged based on their operations + can_merge: Callable[[Sequence[cirq.Operation], Sequence[cirq.Operation]], bool] + + def __init__( + self, + op: cirq.Operation, + moment_id: int, + can_merge: Callable[[Sequence[cirq.Operation], Sequence[cirq.Operation]], bool], + is_mergeable=True, + ): + super().__init__(op, moment_id, is_mergeable) + self.ops = [op] + self.can_merge = can_merge + + @override + def merge(self, c: Component, merge_left=True) -> Component | None: + """Attempts to merge two components. + + Returns: + None if can_merge is False, otherwise the new representative. + The representative will have ops = a.ops + b.ops. + """ + x = cast(ComponentWithOps, self.find()) + y = cast(ComponentWithOps, c.find()) + + if x == y: + return x + + if not x.is_mergeable or not y.is_mergeable or not x.can_merge(x.ops, y.ops): + return None + + root = cast(ComponentWithOps, super(ComponentWithOps, x).merge(y, merge_left)) + root.ops = x.ops + y.ops + # Clear the ops list in the non-representative set to avoid memory consumption + if x != root: + x.ops = [] + else: + y.ops = [] + return root + + +class ComponentWithCircuitOp(Component): + """Component that keeps track of operations as a CircuitOperation. + + Encapsulates a method merge_func that is used to merge two components. + """ + + # CircuitOperation containing all the operations in the component, + # or a single Operation if the component is a singleton + circuit_op: cirq.Operation + + merge_func: Callable[[ops.Operation, ops.Operation], ops.Operation | None] + + def __init__( + self, + op: cirq.Operation, + moment_id: int, + merge_func: Callable[[ops.Operation, ops.Operation], ops.Operation | None], + is_mergeable=True, + ): + super().__init__(op, moment_id, is_mergeable) + self.circuit_op = op + self.merge_func = merge_func + + @override + def merge(self, c: Component, merge_left=True) -> Component | None: + """Attempts to merge two components. + + If merge_left is True, the merge will use this component representative's + merge_func. If merge_left is False, the merge will use c representative's + merge_func. + + Returns: + None if merge_func returns None, otherwise the new representative. + """ + x = cast(ComponentWithCircuitOp, self.find()) + y = cast(ComponentWithCircuitOp, c.find()) + + if x == y: + return x + + if not x.is_mergeable or not y.is_mergeable: + return None + + if merge_left: + new_op = x.merge_func(x.circuit_op, y.circuit_op) + else: + new_op = y.merge_func(x.circuit_op, y.circuit_op) + if not new_op: + return None + + root = cast(ComponentWithCircuitOp, super(ComponentWithCircuitOp, x).merge(y, merge_left)) + + root.circuit_op = new_op + # The merge_func can be arbitrary, so we need to recompute the component properties + root.qubits = frozenset(new_op.qubits) + root.mkeys = protocols.measurement_key_objs(new_op) + root.ckeys = protocols.control_keys(new_op) + + # Clear the circuit op in the non-representative set to avoid memory consumption + if x != root: + del x.circuit_op + else: + del y.circuit_op + return root + + +class ComponentFactory: + """Factory for components.""" + + is_mergeable: Callable[[cirq.Operation], bool] + + def __init__(self, is_mergeable: Callable[[cirq.Operation], bool]): + self.is_mergeable = is_mergeable + + def new_component(self, op: cirq.Operation, moment_id: int, is_mergeable=True) -> Component: + return Component(op, moment_id, self.is_mergeable(op) and is_mergeable) + + +class ComponentWithOpsFactory(ComponentFactory): + """Factory for components with operations.""" + + can_merge: Callable[[Sequence[cirq.Operation], Sequence[cirq.Operation]], bool] + + def __init__( + self, + is_mergeable: Callable[[cirq.Operation], bool], + can_merge: Callable[[Sequence[cirq.Operation], Sequence[cirq.Operation]], bool], + ): + super().__init__(is_mergeable) + self.can_merge = can_merge + + @override + def new_component(self, op: cirq.Operation, moment_id: int, is_mergeable=True) -> Component: + return ComponentWithOps( + op, moment_id, self.can_merge, self.is_mergeable(op) and is_mergeable + ) + + +class ComponentWithCircuitOpFactory(ComponentFactory): + """Factory for components with operations as CircuitOperation.""" + + merge_func: Callable[[ops.Operation, ops.Operation], ops.Operation | None] + + def __init__( + self, + is_mergeable: Callable[[cirq.Operation], bool], + merge_func: Callable[[ops.Operation, ops.Operation], ops.Operation | None], + ): + super().__init__(is_mergeable) + self.merge_func = merge_func + + @override + def new_component(self, op: cirq.Operation, moment_id: int, is_mergeable=True) -> Component: + return ComponentWithCircuitOp( + op, moment_id, self.merge_func, self.is_mergeable(op) and is_mergeable + ) diff --git a/cirq-core/cirq/transformers/connected_component_test.py b/cirq-core/cirq/transformers/connected_component_test.py new file mode 100644 index 00000000000..e8db273c477 --- /dev/null +++ b/cirq-core/cirq/transformers/connected_component_test.py @@ -0,0 +1,322 @@ +# Copyright 2025 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import cast + +import cirq +from cirq.transformers.connected_component import ( + Component, + ComponentFactory, + ComponentWithCircuitOp, + ComponentWithCircuitOpFactory, + ComponentWithOpsFactory, +) + + +def test_find_returns_itself_for_singleton(): + q = cirq.NamedQubit('x') + c = Component(op=cirq.X(q), moment_id=0) + assert c.find() == c + + +def test_merge_components(): + q = cirq.NamedQubit('x') + c = [Component(op=cirq.X(q), moment_id=i) for i in range(5)] + c[1].merge(c[0]) + c[2].merge(c[1]) + c[4].merge(c[3]) + c[3].merge(c[0]) + # Disjoint set structure: + # c[4] + # / \ + # c[1] c[3] + # / \ + # c[0] c[2] + assert c[0].parent == c[1] + assert c[2].parent == c[1] + assert c[1].parent == c[4] + assert c[3].parent == c[4] + + for i in range(5): + assert c[i].find() == c[4] + # Find() compressed all paths + for i in range(4): + assert c[i].parent == c[4] + + +def test_merge_same_component(): + q = cirq.NamedQubit('x') + c = [Component(op=cirq.X(q), moment_id=i) for i in range(3)] + c[1].merge(c[0]) + c[2].merge(c[1]) + # Disjoint set structure: + # c[1] + # / \ + # c[0] c[2] + assert c[0].merge(c[2]) == c[1] + + +def test_merge_returns_None_if_one_component_is_not_mergeable(): + q = cirq.NamedQubit('x') + c0 = Component(op=cirq.X(q), moment_id=0, is_mergeable=True) + c1 = Component(op=cirq.X(q), moment_id=1, is_mergeable=False) + assert c0.merge(c1) is None + + +def test_factory_merge_returns_None_if_is_mergeable_is_false(): + q = cirq.NamedQubit('x') + + def is_mergeable(op: cirq.Operation) -> bool: + del op + return False + + factory = ComponentFactory(is_mergeable=is_mergeable) + c0 = factory.new_component(op=cirq.X(q), moment_id=0, is_mergeable=True) + c1 = factory.new_component(op=cirq.X(q), moment_id=1, is_mergeable=True) + assert c0.merge(c1) is None + + +def test_merge_qubits_with_merge_left_true(): + q0 = cirq.NamedQubit('x') + q1 = cirq.NamedQubit('y') + c0 = Component(op=cirq.X(q0), moment_id=0) + c1 = Component(op=cirq.X(q1), moment_id=0) + c2 = Component(op=cirq.X(q1), moment_id=1) + c1.merge(c2) + c0.merge(c1, merge_left=True) + assert c0.find() == c1 + assert c1.qubits == frozenset([q0, q1]) + + +def test_merge_qubits_with_merge_left_false(): + q0 = cirq.NamedQubit('x') + q1 = cirq.NamedQubit('y') + c0 = Component(op=cirq.X(q0), moment_id=0) + c1 = Component(op=cirq.X(q0), moment_id=0) + c2 = Component(op=cirq.X(q1), moment_id=1) + c0.merge(c1) + c1.merge(c2, merge_left=False) + assert c2.find() == c0 + assert c0.qubits == frozenset([q0, q1]) + + +def test_merge_moment_with_merge_left_true(): + q0 = cirq.NamedQubit('x') + q1 = cirq.NamedQubit('y') + c0 = Component(op=cirq.X(q0), moment_id=0) + c1 = Component(op=cirq.X(q1), moment_id=1) + c2 = Component(op=cirq.X(q1), moment_id=1) + c1.merge(c2) + c0.merge(c1, merge_left=True) + assert c0.find() == c1 + # c1 is the set representative but kept c0's moment + assert c1.moment_id == 0 + + +def test_merge_moment_with_merge_left_false(): + q0 = cirq.NamedQubit('x') + q1 = cirq.NamedQubit('y') + c0 = Component(op=cirq.X(q0), moment_id=0) + c1 = Component(op=cirq.X(q0), moment_id=0) + c2 = Component(op=cirq.X(q1), moment_id=1) + c0.merge(c1) + c1.merge(c2, merge_left=False) + assert c2.find() == c0 + # c0 is the set representative but kept c2's moment + assert c0.moment_id == 1 + + +def test_component_with_ops_merge(): + def is_mergeable(op: cirq.Operation) -> bool: + del op + return True + + def can_merge(ops1: list[cirq.Operation], ops2: list[cirq.Operation]) -> bool: + del ops1, ops2 + return True + + factory = ComponentWithOpsFactory(is_mergeable, can_merge) + + q = cirq.LineQubit.range(3) + ops = [cirq.X(q[i]) for i in range(3)] + c = [factory.new_component(op=ops[i], moment_id=i) for i in range(3)] + + c[0].merge(c[1]) + c[1].merge(c[2]) + assert c[0].find().ops == ops + + +def test_component_with_ops_merge_same_component(): + def is_mergeable(op: cirq.Operation) -> bool: + del op + return True + + def can_merge(ops1: list[cirq.Operation], ops2: list[cirq.Operation]) -> bool: + del ops1, ops2 + return True + + factory = ComponentWithOpsFactory(is_mergeable, can_merge) + + q = cirq.NamedQubit('x') + c = [factory.new_component(op=cirq.X(q), moment_id=i) for i in range(3)] + c[1].merge(c[0]) + c[2].merge(c[1]) + assert c[0].merge(c[2]) == c[1] + + +def test_component_with_ops_merge_when_merge_fails(): + def is_mergeable(op: cirq.Operation) -> bool: + del op + return True + + def can_merge(ops1: list[cirq.Operation], ops2: list[cirq.Operation]) -> bool: + del ops1, ops2 + return False + + factory = ComponentWithOpsFactory(is_mergeable, can_merge) + + q = cirq.LineQubit.range(3) + ops = [cirq.X(q[i]) for i in range(3)] + c = [factory.new_component(op=ops[i], moment_id=i) for i in range(3)] + + c[0].merge(c[1]) + c[1].merge(c[2]) + # No merge happened + for i in range(3): + assert c[i].find() == c[i] + + +def test_component_with_ops_merge_when_is_mergeable_is_false(): + def is_mergeable(op: cirq.Operation) -> bool: + del op + return False + + def can_merge(ops1: list[cirq.Operation], ops2: list[cirq.Operation]) -> bool: + del ops1, ops2 + return True + + factory = ComponentWithOpsFactory(is_mergeable, can_merge) + + q = cirq.LineQubit.range(3) + ops = [cirq.X(q[i]) for i in range(3)] + c = [factory.new_component(op=ops[i], moment_id=i) for i in range(3)] + + c[0].merge(c[1]) + c[1].merge(c[2]) + # No merge happened + for i in range(3): + assert c[i].find() == c[i] + + +def test_component_with_circuit_op_merge(): + def is_mergeable(op: cirq.Operation) -> bool: + del op + return True + + def merge_func(op1: cirq.Operation, op2: cirq.Operation) -> cirq.Operation: + del op2 + return op1 + + factory = ComponentWithCircuitOpFactory(is_mergeable, merge_func) + + q = cirq.LineQubit.range(3) + ops = [cirq.X(q[i]) for i in range(3)] + c = [factory.new_component(op=ops[i], moment_id=i) for i in range(3)] + + c[0].merge(c[1]) + c[1].merge(c[2]) + for i in range(3): + assert c[i].find().circuit_op == ops[0] + + +def test_component_with_circuit_op_merge_same_component(): + def is_mergeable(op: cirq.Operation) -> bool: + del op + return True + + def merge_func(op1: cirq.Operation, op2: cirq.Operation) -> cirq.Operation: + del op2 + return op1 + + factory = ComponentWithCircuitOpFactory(is_mergeable, merge_func) + + q = cirq.NamedQubit('x') + c = [factory.new_component(op=cirq.X(q), moment_id=i) for i in range(3)] + c[1].merge(c[0]) + c[2].merge(c[1]) + assert c[0].merge(c[2]) == c[1] + + +def test_component_with_circuit_op_merge_func_is_none(): + def is_mergeable(op: cirq.Operation) -> bool: + del op + return True + + def merge_func(op1: cirq.Operation, op2: cirq.Operation) -> None: + del op1, op2 + return None + + factory = ComponentWithCircuitOpFactory(is_mergeable, merge_func) + + q = cirq.LineQubit.range(3) + ops = [cirq.X(q[i]) for i in range(3)] + c = [factory.new_component(op=ops[i], moment_id=i) for i in range(3)] + + c[0].merge(c[1]) + c[1].merge(c[2]) + # No merge happened + for i in range(3): + assert c[i].find() == c[i] + + +def test_component_with_circuit_op_merge_when_is_mergeable_is_false(): + def is_mergeable(op: cirq.Operation) -> bool: + del op + return False + + def merge_func(op1: cirq.Operation, op2: cirq.Operation) -> cirq.Operation: + del op2 + return op1 + + factory = ComponentWithCircuitOpFactory(is_mergeable, merge_func) + + q = cirq.LineQubit.range(3) + ops = [cirq.X(q[i]) for i in range(3)] + c = [factory.new_component(op=ops[i], moment_id=i) for i in range(3)] + + c[0].merge(c[1]) + c[1].merge(c[2]) + # No merge happened + for i in range(3): + assert c[i].find() == c[i] + + +def test_component_with_circuit_op_merge_when_merge_left_is_false(): + def merge_func_x(op1: cirq.Operation, op2: cirq.Operation) -> cirq.Operation: + del op2 + return op1 + + def merge_func_y(op1: cirq.Operation, op2: cirq.Operation) -> cirq.Operation: + del op1 + return op2 + + q = cirq.LineQubit.range(2) + x = ComponentWithCircuitOp(cirq.X(q[0]), moment_id=0, merge_func=merge_func_x) + y = ComponentWithCircuitOp(cirq.X(q[1]), moment_id=1, merge_func=merge_func_y) + + root = cast(ComponentWithCircuitOp, x.merge(y, merge_left=False)) + # The merge used merge_func_y because merge_left=False + assert root.circuit_op == cirq.X(q[1]) diff --git a/cirq-core/cirq/transformers/transformer_primitives.py b/cirq-core/cirq/transformers/transformer_primitives.py index 14396764011..a9c165dfc13 100644 --- a/cirq-core/cirq/transformers/transformer_primitives.py +++ b/cirq-core/cirq/transformers/transformer_primitives.py @@ -17,12 +17,20 @@ from __future__ import annotations import bisect +import copy import dataclasses from collections import defaultdict from typing import Callable, cast, Hashable, Sequence, TYPE_CHECKING from cirq import circuits, ops, protocols from cirq.circuits.circuit import CIRCUIT_TYPE +from cirq.transformers.connected_component import ( + Component, + ComponentFactory, + ComponentWithCircuitOp, + ComponentWithCircuitOpFactory, + ComponentWithOpsFactory, +) if TYPE_CHECKING: import cirq @@ -282,17 +290,17 @@ def map_operations_and_unroll( @dataclasses.dataclass class _MergedCircuit: - """An optimized internal representation of a circuit, tailored for `cirq.merge_operations` + """An optimized internal representation of a circuit, tailored for merge operations Attributes: - qubit_indexes: Mapping from qubits to (sorted) list of moment indexes containing operations - acting on the qubit. - mkey_indexes: Mapping from measurement keys to (sorted) list of moment indexes containing - measurement operations with the same key. - ckey_indexes: Mapping from measurement keys to (sorted) list of moment indexes containing - classically controlled operations controlled on the same key. - ops_by_index: List of circuit moments containing operations. We use a dictionary instead - of a set to store operations to preserve insertion order. + qubit_indexes: Mapping from qubits to (sorted) list of component moments containing + operations acting on the qubit. + mkey_indexes: Mapping from measurement keys to (sorted) list of component moments + containing measurement operations with the same key. + ckey_indexes: Mapping from measurement keys to (sorted) list of component moments + containing classically controlled operations controlled on the same key. + components_by_index: List of circuit moments containing components. We use a dictionary + instead of a set to store components to preserve insertion order. """ qubit_indexes: dict[cirq.Qid, list[int]] = dataclasses.field( @@ -304,54 +312,224 @@ class _MergedCircuit: ckey_indexes: dict[cirq.MeasurementKey, list[int]] = dataclasses.field( default_factory=lambda: defaultdict(lambda: [-1]) ) - ops_by_index: list[dict[cirq.Operation, int]] = dataclasses.field(default_factory=list) + components_by_index: list[dict[Component, int]] = dataclasses.field(default_factory=list) def append_empty_moment(self) -> None: - self.ops_by_index.append({}) + self.components_by_index.append({}) - def add_op_to_moment(self, moment_index: int, op: cirq.Operation) -> None: - self.ops_by_index[moment_index][op] = 0 - for q in op.qubits: - if moment_index > self.qubit_indexes[q][-1]: - self.qubit_indexes[q].append(moment_index) - else: - bisect.insort(self.qubit_indexes[q], moment_index) - for mkey in protocols.measurement_key_objs(op): - bisect.insort(self.mkey_indexes[mkey], moment_index) - for ckey in protocols.control_keys(op): - bisect.insort(self.ckey_indexes[ckey], moment_index) - - def remove_op_from_moment(self, moment_index: int, op: cirq.Operation) -> None: - self.ops_by_index[moment_index].pop(op) - for q in op.qubits: - if self.qubit_indexes[q][-1] == moment_index: - self.qubit_indexes[q].pop() - else: - self.qubit_indexes[q].remove(moment_index) - for mkey in protocols.measurement_key_objs(op): - self.mkey_indexes[mkey].remove(moment_index) - for ckey in protocols.control_keys(op): - self.ckey_indexes[ckey].remove(moment_index) - - def get_mergeable_ops( - self, op: cirq.Operation, op_qs: set[cirq.Qid] - ) -> tuple[int, list[cirq.Operation]]: - # Find the index of previous moment which can be merged with `op`. - idx = max([self.qubit_indexes[q][-1] for q in op_qs], default=-1) - idx = max([idx] + [self.mkey_indexes[ckey][-1] for ckey in protocols.control_keys(op)]) - idx = max( - [idx] + [self.ckey_indexes[mkey][-1] for mkey in protocols.measurement_key_objs(op)] - ) - # Return the set of overlapping ops in moment with index `idx`. + def add_moment(self, index: list[int], moment_id: int) -> None: + """Adds a moment to a sorted list of moment indexes. + + Optimized for the majority case when the new moment is higher than any moment in the list. + """ + if index[-1] < moment_id: + index.append(moment_id) + else: + bisect.insort(index, moment_id) + + def remove_moment(self, index: list[int], moment_id: int) -> None: + """Removes a moment from a sorted list of moment indexes. + + Optimized for the majority case when the moment is last in the list. + """ + if index[-1] == moment_id: + index.pop() + else: + index.remove(moment_id) + + def add_component(self, c: Component) -> None: + """Adds a new components to merged circuit.""" + self.components_by_index[c.moment_id][c] = 0 + for q in c.qubits: + self.add_moment(self.qubit_indexes[q], c.moment_id) + for mkey in c.mkeys: + self.add_moment(self.mkey_indexes[mkey], c.moment_id) + for ckey in c.ckeys: + self.add_moment(self.ckey_indexes[ckey], c.moment_id) + + def remove_component(self, c: Component, c_data: Component) -> None: + """Removes a component from the merged circuit. + + Args: + c: reference to the component to be removed + c_data: copy of the data in c before any component merges involving c + (this is necessary as component merges alter the component data) + """ + self.components_by_index[c_data.moment_id].pop(c) + for q in c_data.qubits: + self.remove_moment(self.qubit_indexes[q], c_data.moment_id) + for mkey in c_data.mkeys: + self.remove_moment(self.mkey_indexes[mkey], c_data.moment_id) + for ckey in c_data.ckeys: + self.remove_moment(self.ckey_indexes[ckey], c_data.moment_id) + + def get_mergeable_components(self, c: Component, c_qs: set[cirq.Qid]) -> list[Component]: + """Finds all components that can be merged with c. + + Args: + c: component to be merged with existing components + c_qs: subset of c.qubits used to decide which components are mergeable + + Returns: + list of mergeable components + """ + # Find the index of previous moment which can be merged with `c`. + idx = max([self.qubit_indexes[q][-1] for q in c_qs], default=-1) + idx = max([idx] + [self.mkey_indexes[ckey][-1] for ckey in c.ckeys]) + idx = max([idx] + [self.ckey_indexes[mkey][-1] for mkey in c.mkeys]) + # Return the set of overlapping components in moment with index `idx`. if idx == -1: - return idx, [] + return [] + + return [c for c in self.components_by_index[idx] if not c_qs.isdisjoint(c.qubits)] + + def get_cirq_circuit( + self, components: list[Component], merged_circuit_op_tag: str + ) -> cirq.Circuit: + """Returns the merged circuit. + + Args: + components: all components in creation order + merged_circuit_op_tag: tag to use for CircuitOperations + + Returns: + the circuit with merged components as a CircuitOperation + """ + component_ops: dict[Component, list[cirq.Operation]] = defaultdict(list) + + # Traverse the components in creation order and collect operations + for c in components: + root = c.find() + component_ops[root].append(c.op) + + moments = [] + for m in self.components_by_index: + ops = [] + for c in m.keys(): + if isinstance(c, ComponentWithCircuitOp): + ops.append(c.circuit_op) + continue + if len(component_ops[c]) == 1: + ops.append(component_ops[c][0]) + else: + ops.append( + circuits.CircuitOperation( + circuits.FrozenCircuit(component_ops[c]) + ).with_tags(merged_circuit_op_tag) + ) + moments.append(circuits.Moment(ops)) + return circuits.Circuit(moments) + + +def _merge_operations_impl( + circuit: CIRCUIT_TYPE, + factory: ComponentFactory, + *, + merged_circuit_op_tag: str = "Merged connected component", + tags_to_ignore: Sequence[Hashable] = (), + deep: bool = False, +) -> CIRCUIT_TYPE: + """Merges operations in a circuit. + + Two operations op1 and op2 are merge-able if + - There is no other operations between op1 and op2 in the circuit + - is_subset(op1.qubits, op2.qubits) or is_subset(op2.qubits, op1.qubits) + + The method iterates on the input circuit moment-by-moment from left to right and attempts + to repeatedly merge each operation in the latest moment with all the corresponding merge-able + operations to its left. - return idx, [ - left_op for left_op in self.ops_by_index[idx] if not op_qs.isdisjoint(left_op.qubits) - ] + Operations are wrapped in a component and then component.merge is called to merge two + components. The factory can provide components with different implementations of the merge + function, allowing for optimizations. - def get_cirq_circuit(self) -> cirq.Circuit: - return circuits.Circuit(circuits.Moment(m.keys()) for m in self.ops_by_index) + If op1 and op2 are merged, both op1 and op2 are deleted from the circuit and + the merged component is inserted at the index corresponding to the larger + of op1/op2. If both op1 and op2 act on the same number of qubits, the merged component is + inserted in the smaller moment index to minimize circuit depth. + + At the end every component with more than one operation is replaced by a CircuitOperation. + + Args: + circuit: Input circuit to apply the transformations on. The input circuit is not mutated. + factory: Factory that creates components from an operation. + merged_circuit_op_tag: tag used for CircuitOperations created from merged components. + tags_to_ignore: Sequence of tags which should be ignored during the merge: operations with + these tags will not be merged. + deep: If true, the transformer primitive will be recursively applied to all circuits + wrapped inside circuit operations. + + + Returns: + Copy of input circuit with merged operations. + """ + components = [] # List of all components in creation order + tags_to_ignore_set = set(tags_to_ignore) + + merged_circuit = _MergedCircuit() + for moment_idx, current_moment in enumerate(cast(list['cirq.Moment'], circuit)): + merged_circuit.append_empty_moment() + for op in sorted(current_moment.operations, key=lambda op: op.qubits): + if ( + deep + and isinstance(op.untagged, circuits.CircuitOperation) + and tags_to_ignore_set.isdisjoint(op.tags) + ): + op_untagged = op.untagged + merged_op = op_untagged.replace( + circuit=_merge_operations_impl( + op_untagged.circuit, + factory, + merged_circuit_op_tag=merged_circuit_op_tag, + tags_to_ignore=tags_to_ignore, + deep=True, + ) + ).with_tags(*op.tags) + c = factory.new_component(merged_op, moment_idx, is_mergeable=False) + components.append(c) + merged_circuit.add_component(c) + continue + + c = factory.new_component( + op, moment_idx, is_mergeable=tags_to_ignore_set.isdisjoint(op.tags) + ) + components.append(c) + if not c.is_mergeable: + merged_circuit.add_component(c) + continue + + c_qs = set(c.qubits) + left_comp = merged_circuit.get_mergeable_components(c, c_qs) + if len(left_comp) == 1 and c_qs.issubset(left_comp[0].qubits): + # Make a shallow copy of the left component data before merge + left_c_data = copy.copy(left_comp[0]) + # Case-1: Try to merge c with the larger component on the left. + new_comp = left_comp[0].merge(c, merge_left=True) + if new_comp is not None: + merged_circuit.remove_component(left_comp[0], left_c_data) + merged_circuit.add_component(new_comp) + else: + merged_circuit.add_component(c) + continue + + while left_comp and c_qs: + # Case-2: left_c will merge right into `c` whenever possible. + for left_c in left_comp: + is_merged = False + if c_qs.issuperset(left_c.qubits): + # Make a shallow copy of the left component data before merge + left_c_data = copy.copy(left_c) + # Try to merge left_c into c + new_comp = left_c.merge(c, merge_left=False) + if new_comp is not None: + merged_circuit.remove_component(left_c, left_c_data) + c, is_merged = new_comp, True + if not is_merged: + c_qs -= left_c.qubits + left_comp = merged_circuit.get_mergeable_components(c, c_qs) + merged_circuit.add_component(c) + ret_circuit = merged_circuit.get_cirq_circuit(components, merged_circuit_op_tag) + return _to_target_circuit_type(ret_circuit, circuit) def merge_operations( @@ -407,12 +585,8 @@ def merge_operations( ValueError if the merged operation acts on new qubits outside the set of qubits corresponding to the original operations to be merged. """ - _circuit_op_tag = "_internal_tag_to_mark_circuit_ops_in_circuit" - tags_to_ignore_set = set(tags_to_ignore) | {_circuit_op_tag} def apply_merge_func(op1: ops.Operation, op2: ops.Operation) -> ops.Operation | None: - if not all(tags_to_ignore_set.isdisjoint(op.tags) for op in [op1, op2]): - return None new_op = merge_func(op1, op2) qubit_set = frozenset(op1.qubits + op2.qubits) if new_op is not None and not qubit_set.issuperset(new_op.qubits): @@ -422,63 +596,16 @@ def apply_merge_func(op1: ops.Operation, op2: ops.Operation) -> ops.Operation | ) return new_op - merged_circuit = _MergedCircuit() - for moment_idx, current_moment in enumerate(cast(list['cirq.Moment'], circuit)): - merged_circuit.append_empty_moment() - for op in sorted(current_moment.operations, key=lambda op: op.qubits): - if ( - deep - and isinstance(op.untagged, circuits.CircuitOperation) - and tags_to_ignore_set.isdisjoint(op.tags) - ): - op_untagged = op.untagged - merged_circuit.add_op_to_moment( - moment_idx, - op_untagged.replace( - circuit=merge_operations( - op_untagged.circuit, - merge_func, - tags_to_ignore=tags_to_ignore, - deep=True, - ) - ).with_tags(*op.tags, _circuit_op_tag), - ) - continue - - op_qs = set(op.qubits) - left_idx, left_ops = merged_circuit.get_mergeable_ops(op, op_qs) - if len(left_ops) == 1 and op_qs.issubset(left_ops[0].qubits): - # Case-1: Try to merge op with the larger operation on the left. - new_op = apply_merge_func(left_ops[0], op) - if new_op is not None: - merged_circuit.remove_op_from_moment(left_idx, left_ops[0]) - merged_circuit.add_op_to_moment(left_idx, new_op) - else: - merged_circuit.add_op_to_moment(moment_idx, op) - continue + def is_mergeable(op: cirq.Operation): + del op + return True - while left_ops and op_qs: - # Case-2: left_ops will merge right into `op` whenever possible. - for left_op in left_ops: - is_merged = False - if op_qs.issuperset(left_op.qubits): - # Try to merge left_op into op - new_op = apply_merge_func(left_op, op) - if new_op is not None: - merged_circuit.remove_op_from_moment(left_idx, left_op) - op, is_merged = new_op, True - if not is_merged: - op_qs -= frozenset(left_op.qubits) - left_idx, left_ops = merged_circuit.get_mergeable_ops(op, op_qs) - merged_circuit.add_op_to_moment(moment_idx, op) - ret_circuit = merged_circuit.get_cirq_circuit() - if deep: - ret_circuit = map_operations( - ret_circuit, - lambda o, _: o.untagged.with_tags(*(set(o.tags) - {_circuit_op_tag})), - deep=True, - ) - return _to_target_circuit_type(ret_circuit, circuit) + return _merge_operations_impl( + circuit, + ComponentWithCircuitOpFactory(is_mergeable, apply_merge_func), + tags_to_ignore=tags_to_ignore, + deep=deep, + ) def merge_operations_to_circuit_op( @@ -491,10 +618,9 @@ def merge_operations_to_circuit_op( ) -> CIRCUIT_TYPE: """Merges connected components of operations and wraps each component into a circuit operation. - Uses `cirq.merge_operations` to identify connected components of operations. Moment structure - is preserved for operations that do not participate in merging. For merged operations, the - newly created circuit operations are constructed by inserting operations using EARLIEST - strategy. + Moment structure is preserved for operations that do not participate in merging. + For merged operations, the newly created circuit operations are constructed by inserting + operations using EARLIEST strategy. If you need more control on moment structure of newly created circuit operations, consider using `cirq.merge_operations` directly with a custom `merge_func`. @@ -514,24 +640,17 @@ def merge_operations_to_circuit_op( Copy of input circuit with valid connected components wrapped in tagged circuit operations. """ - def merge_func(op1: cirq.Operation, op2: cirq.Operation) -> cirq.Operation | None: - def get_ops(op: cirq.Operation): - op_untagged = op.untagged - return ( - [*op_untagged.circuit.all_operations()] - if isinstance(op_untagged, circuits.CircuitOperation) - and merged_circuit_op_tag in op.tags - else [op] - ) - - left_ops, right_ops = get_ops(op1), get_ops(op2) - if not can_merge(left_ops, right_ops): - return None - return circuits.CircuitOperation(circuits.FrozenCircuit(left_ops, right_ops)).with_tags( - merged_circuit_op_tag - ) + def is_mergeable(op: cirq.Operation): + del op + return True - return merge_operations(circuit, merge_func, tags_to_ignore=tags_to_ignore, deep=deep) + return _merge_operations_impl( + circuit, + ComponentWithOpsFactory(is_mergeable, can_merge), + merged_circuit_op_tag=merged_circuit_op_tag, + tags_to_ignore=tags_to_ignore, + deep=deep, + ) def merge_k_qubit_unitaries_to_circuit_op( @@ -544,10 +663,9 @@ def merge_k_qubit_unitaries_to_circuit_op( ) -> CIRCUIT_TYPE: """Merges connected components of operations, acting on <= k qubits, into circuit operations. - Uses `cirq.merge_operations_to_circuit_op` to identify and merge connected components of - unitary operations acting on at-most k-qubits. Moment structure is preserved for operations - that do not participate in merging. For merged operations, the newly created circuit operations - are constructed by inserting operations using EARLIEST strategy. + Moment structure is preserved for operations that do not participate in merging. + For merged operations, the newly created circuit operations are constructed by inserting + operations using EARLIEST strategy. Args: circuit: Input circuit to apply the transformations on. The input circuit is not mutated. @@ -563,18 +681,14 @@ def merge_k_qubit_unitaries_to_circuit_op( Copy of input circuit with valid connected components wrapped in tagged circuit operations. """ - def can_merge(ops1: Sequence[cirq.Operation], ops2: Sequence[cirq.Operation]) -> bool: - return all( - protocols.num_qubits(op) <= k and protocols.has_unitary(op) - for op_list in [ops1, ops2] - for op in op_list - ) + def is_mergeable(op: cirq.Operation): + return protocols.num_qubits(op) <= k and protocols.has_unitary(op) - return merge_operations_to_circuit_op( + return _merge_operations_impl( circuit, - can_merge, - tags_to_ignore=tags_to_ignore, + ComponentFactory(is_mergeable), merged_circuit_op_tag=merged_circuit_op_tag or f"Merged {k}q unitary connected component.", + tags_to_ignore=tags_to_ignore, deep=deep, ) diff --git a/cirq-core/cirq/transformers/transformer_primitives_test.py b/cirq-core/cirq/transformers/transformer_primitives_test.py index e1152b60aff..eddc866266c 100644 --- a/cirq-core/cirq/transformers/transformer_primitives_test.py +++ b/cirq-core/cirq/transformers/transformer_primitives_test.py @@ -877,3 +877,187 @@ def merge_func(op1, op2): cirq.testing.assert_same_circuits( cirq.align_left(cirq.merge_operations(circuit, merge_func)), expected_circuit ) + + +def test_merge_3q_unitaries_to_circuit_op_3q_gate_absorbs_overlapping_2q_gates(): + q = cirq.LineQubit.range(3) + c_orig = cirq.Circuit( + cirq.Moment( + cirq.H(q[0]).with_tags("ignore"), + cirq.H(q[1]).with_tags("ignore"), + cirq.H(q[2]).with_tags("ignore"), + ), + cirq.Moment(cirq.CNOT(q[0], q[2]), cirq.X(q[1]).with_tags("ignore")), + cirq.CNOT(q[0], q[1]), + cirq.CNOT(q[1], q[2]), + cirq.CCZ(*q), + strategy=cirq.InsertStrategy.NEW, + ) + cirq.testing.assert_has_diagram( + c_orig, + ''' + ┌──────────┐ +0: ───H[ignore]────@─────────────@───────@─── + │ │ │ +1: ───H[ignore]────┼X[ignore]────X───@───@─── + │ │ │ +2: ───H[ignore]────X─────────────────X───@─── + └──────────┘ +''', + ) + + c_new = cirq.merge_k_qubit_unitaries_to_circuit_op( + c_orig, k=3, merged_circuit_op_tag="merged", tags_to_ignore=["ignore"] + ) + cirq.testing.assert_has_diagram( + cirq.drop_empty_moments(c_new), + ''' + [ 0: ───@───@───────@─── ] + [ │ │ │ ] +0: ───H[ignore]───────────────[ 1: ───┼───X───@───@─── ]─────────── + [ │ │ │ ] + [ 2: ───X───────X───@─── ][merged] + │ +1: ───H[ignore]───X[ignore]───#2─────────────────────────────────── + │ +2: ───H[ignore]───────────────#3─────────────────────────────────── +''', + ) + + +def test_merge_3q_unitaries_to_circuit_op_3q_gate_absorbs_disjoint_gates(): + q = cirq.LineQubit.range(3) + c_orig = cirq.Circuit( + cirq.Moment(cirq.CNOT(q[0], q[1]), cirq.X(q[2])), + cirq.CCZ(*q), + strategy=cirq.InsertStrategy.NEW, + ) + cirq.testing.assert_has_diagram( + c_orig, + ''' +0: ───@───@─── + │ │ +1: ───X───@─── + │ +2: ───X───@─── +''', + ) + + c_new = cirq.merge_k_qubit_unitaries_to_circuit_op( + c_orig, k=3, merged_circuit_op_tag="merged", tags_to_ignore=["ignore"] + ) + cirq.testing.assert_has_diagram( + cirq.drop_empty_moments(c_new), + ''' + [ 0: ───@───@─── ] + [ │ │ ] +0: ───[ 1: ───X───@─── ]─────────── + [ │ ] + [ 2: ───X───@─── ][merged] + │ +1: ───#2─────────────────────────── + │ +2: ───#3─────────────────────────── +''', + ) + + +def test_merge_3q_unitaries_to_circuit_op_3q_gate_doesnt_absorb_unmergeable_gate(): + q = cirq.LineQubit.range(3) + c_orig = cirq.Circuit( + cirq.CCZ(*q), + cirq.Moment(cirq.CNOT(q[0], q[1]), cirq.X(q[2]).with_tags("ignore")), + cirq.CCZ(*q), + strategy=cirq.InsertStrategy.NEW, + ) + cirq.testing.assert_has_diagram( + c_orig, + ''' +0: ───@───@───────────@─── + │ │ │ +1: ───@───X───────────@─── + │ │ +2: ───@───X[ignore]───@─── +''', + ) + + c_new = cirq.merge_k_qubit_unitaries_to_circuit_op( + c_orig, k=3, merged_circuit_op_tag="merged", tags_to_ignore=["ignore"] + ) + cirq.testing.assert_has_diagram( + cirq.drop_empty_moments(c_new), + ''' + [ 0: ───@───@─── ] + [ │ │ ] +0: ───[ 1: ───@───X─── ]───────────────────────@─── + [ │ ] │ + [ 2: ───@─────── ][merged] │ + │ │ +1: ───#2───────────────────────────────────────@─── + │ │ +2: ───#3───────────────────────────X[ignore]───@─── +''', + ) + + +def test_merge_3q_unitaries_to_circuit_op_prefer_to_merge_into_earlier_op(): + q = cirq.LineQubit.range(6) + c_orig = cirq.Circuit( + cirq.Moment( + cirq.CCZ(*q[0:3]), cirq.X(q[3]), cirq.H(q[4]), cirq.H(q[5]).with_tags("ignore") + ), + cirq.Moment(cirq.CNOT(q[0], q[1]), cirq.X(q[2]).with_tags("ignore"), cirq.CCZ(*q[3:6])), + cirq.Moment( + cirq.X(q[0]), + cirq.X(q[1]), + cirq.X(q[2]), + cirq.X(q[3]).with_tags("ignore"), + cirq.CNOT(*q[4:6]), + ), + cirq.Moment(cirq.CCZ(*q[0:3]), cirq.CCZ(*q[3:6])), + strategy=cirq.InsertStrategy.NEW, + ) + cirq.testing.assert_has_diagram( + c_orig, + ''' +0: ───@───────────@───────────X───────────@─── + │ │ │ +1: ───@───────────X───────────X───────────@─── + │ │ +2: ───@───────────X[ignore]───X───────────@─── + +3: ───X───────────@───────────X[ignore]───@─── + │ │ +4: ───H───────────@───────────@───────────@─── + │ │ │ +5: ───H[ignore]───@───────────X───────────@─── +''', + ) + + c_new = cirq.merge_k_qubit_unitaries_to_circuit_op( + c_orig, k=3, merged_circuit_op_tag="merged", tags_to_ignore=["ignore"] + ) + cirq.testing.assert_has_diagram( + cirq.drop_empty_moments(c_new), + ''' + [ 0: ───@───@───X─── ] [ 0: ───────@─── ] + [ │ │ ] [ │ ] +0: ───[ 1: ───@───X───X─── ]────────────────────────────────────────────────────────[ 1: ───────@─── ]─────────── + [ │ ] [ │ ] + [ 2: ───@─────────── ][merged] [ 2: ───X───@─── ][merged] + │ │ +1: ───#2────────────────────────────────────────────────────────────────────────────#2─────────────────────────── + │ │ +2: ───#3───────────────────────────────X[ignore]────────────────────────────────────#3─────────────────────────── + + [ 3: ───X───@─────── ] + [ │ ] +3: ────────────────────────────────────[ 4: ───H───@───@─── ]───────────X[ignore]───@──────────────────────────── + [ │ │ ] │ + [ 5: ───────@───X─── ][merged] │ + │ │ +4: ────────────────────────────────────#2───────────────────────────────────────────@──────────────────────────── + │ │ +5: ───H[ignore]────────────────────────#3───────────────────────────────────────────@──────────────────────────── +''', # noqa: E501 + )