diff --git a/cirq-core/cirq/circuits/circuit.py b/cirq-core/cirq/circuits/circuit.py index 48265a888aa..53ab4f88b6c 100644 --- a/cirq-core/cirq/circuits/circuit.py +++ b/cirq-core/cirq/circuits/circuit.py @@ -1777,25 +1777,48 @@ def __init__( return flattened_contents = tuple(ops.flatten_to_ops_or_moments(contents)) if all(isinstance(c, Moment) for c in flattened_contents): - self._placement_cache = None + self._placement_cache = None # Direct moment loading doesn't use/need cache initially. self._moments[:] = cast(Iterable[Moment], flattened_contents) return + # EARLIEST strategy during __init__ will use _load_contents_with_earliest_strategy, with _compat.block_overlapping_deprecation('.*'): if strategy == InsertStrategy.EARLIEST: self._load_contents_with_earliest_strategy(flattened_contents) else: + # For non-EARLIEST strategies, the cache is not useful during this initial loading. + if self._placement_cache is not None: + self._placement_cache = None self.append(flattened_contents, strategy=strategy) - def _mutated(self, *, preserve_placement_cache=False) -> None: + def _mutated(self, *, preserve_placement_cache: bool = False) -> None: """Clear cached properties in response to this circuit being mutated.""" self._all_qubits = None self._frozen = None self._is_measurement = None self._is_parameterized = None self._parameter_names = None - if not preserve_placement_cache: + if not preserve_placement_cache and self._placement_cache is not None: self._placement_cache = None + def _rebuild_placement_cache(self) -> None: + """Rebuilds the placement cache from the current _moments. + + This method is called when an EARLIEST append is attempted and the + cache has been invalidated. + """ + self._placement_cache = _PlacementCache() + cache = self._placement_cache # Shorthand + + for i, moment_val in enumerate(self._moments): + for op_in_moment in moment_val.operations: + for q_op in op_in_moment.qubits: + cache._qubit_indices[q_op] = i + for mk_op in protocols.measurement_key_objs(op_in_moment): + cache._mkey_indices[mk_op] = i + for ck_op in protocols.control_keys(op_in_moment): + cache._ckey_indices[ck_op] = i + cache._length = len(self._moments) + @classmethod def _from_moments(cls, moments: Iterable[cirq.Moment]) -> Circuit: new_circuit = Circuit() @@ -2116,8 +2139,14 @@ def insert( ValueError: Bad insertion strategy. """ # limit index to 0..len(self._moments), also deal with indices smaller 0 - k = max(min(index if index >= 0 else len(self._moments) + index, len(self._moments)), 0) - if strategy != InsertStrategy.EARLIEST or k != len(self._moments): + k_idx = max(min(index if index >= 0 else len(self._moments) + index, len(self._moments)), 0) + + # FIX: Move cache logic here, use original index + is_earliest_append = strategy == InsertStrategy.EARLIEST and index == len(self._moments) + if is_earliest_append: + if self._placement_cache is None: + self._rebuild_placement_cache() + elif self._placement_cache is not None: self._placement_cache = None mops = list(ops.flatten_to_ops_or_moments(moment_or_operation_tree)) if self._placement_cache: @@ -2126,6 +2155,8 @@ def insert( batches = [[mop] for mop in mops] # Each op goes into its own moment. else: batches = list(_group_into_moment_compatible(mops)) + + current_k_idx = k_idx # Use a local variable for the current index within the batch loop for batch in batches: # Insert a moment if inline/earliest and _any_ op in the batch requires it. if ( @@ -2133,28 +2164,32 @@ def insert( and not isinstance(batch[0], Moment) and strategy in (InsertStrategy.INLINE, InsertStrategy.EARLIEST) and not all( - (strategy is InsertStrategy.EARLIEST and self._can_add_op_at(k, op)) - or (k > 0 and self._can_add_op_at(k - 1, op)) + (strategy is InsertStrategy.EARLIEST and self._can_add_op_at(current_k_idx, op)) + or (current_k_idx > 0 and self._can_add_op_at(current_k_idx - 1, op)) for op in cast(list[cirq.Operation], batch) ) ): - self._moments.insert(k, Moment()) + self._moments.insert(current_k_idx, Moment()) if strategy is InsertStrategy.INLINE: - k += 1 + current_k_idx += 1 max_p = 0 + current_strategy_for_batch = strategy for moment_or_op in batch: # Determine Placement if self._placement_cache: p = self._placement_cache.append(moment_or_op) elif isinstance(moment_or_op, Moment): - p = k - elif strategy in (InsertStrategy.NEW, InsertStrategy.NEW_THEN_INLINE): - self._moments.insert(k, Moment()) - p = k - elif strategy is InsertStrategy.INLINE: - p = k - 1 + p = current_k_idx + elif current_strategy_for_batch in ( + InsertStrategy.NEW, + InsertStrategy.NEW_THEN_INLINE, + ): + self._moments.insert(current_k_idx, Moment()) + p = current_k_idx + elif current_strategy_for_batch is InsertStrategy.INLINE: + p = current_k_idx - 1 else: # InsertStrategy.EARLIEST: - p = self.earliest_available_moment(moment_or_op, end_moment_index=k) + p = self.earliest_available_moment(moment_or_op, end_moment_index=current_k_idx) # Place if isinstance(moment_or_op, Moment): self._moments.insert(p, moment_or_op) @@ -2164,12 +2199,20 @@ def insert( self._moments[p] = self._moments[p].with_operation(moment_or_op) # Iterate max_p = max(p, max_p) - if strategy is InsertStrategy.NEW_THEN_INLINE: - strategy = InsertStrategy.INLINE - k += 1 - k = max(k, max_p + 1) - self._mutated(preserve_placement_cache=True) - return k + if current_strategy_for_batch is InsertStrategy.NEW_THEN_INLINE: + current_strategy_for_batch = InsertStrategy.INLINE + if p == current_k_idx: + current_k_idx += 1 + # Update current_k_idx to be after the newly inserted operations for the next batch. + current_k_idx = max(current_k_idx, max_p + 1) + # At the end, update k_idx to reflect the final position + k_idx = current_k_idx + # Preserve cache only if it was an EARLIEST append and the cache was + # successfully used/rebuilt. + self._mutated( + preserve_placement_cache=(is_earliest_append and self._placement_cache is not None) + ) + return k_idx def insert_into_range(self, operations: cirq.OP_TREE, start: int, end: int) -> int: """Writes operations inline into an area of the circuit. diff --git a/cirq-core/cirq/circuits/circuit_test.py b/cirq-core/cirq/circuits/circuit_test.py index 7169297cc51..652a51d140f 100644 --- a/cirq-core/cirq/circuits/circuit_test.py +++ b/cirq-core/cirq/circuits/circuit_test.py @@ -26,7 +26,7 @@ import sympy import cirq -from cirq import circuits, ops +from cirq import circuits, Moment, ops from cirq.testing.devices import ValidatingTestDevice @@ -4910,3 +4910,181 @@ def test_append_speed(): duration = time.perf_counter() - t assert len(c) == moments assert duration < 5 + + +def test_placement_cache_rebuild_and_invalidation(): + q0, q1, q2 = cirq.LineQubit.range(3) + c = cirq.Circuit() + + # Initial appends (EARLIEST by default) should use/build cache + c.append(cirq.H(q0)) + c.append(cirq.H(q1)) # Moment 0: H(q0), H(q1) + assert c._placement_cache is not None, "Cache should be active after initial appends" + assert len(c) == 1 + assert c._placement_cache._length == 1 + assert c._placement_cache._qubit_indices.get(q0) == 0 + assert c._placement_cache._qubit_indices.get(q1) == 0 + + c.append(cirq.CNOT(q0, q1)) # Moment 1: CNOT(q0, q1) + assert c._placement_cache is not None, "Cache should remain active for subsequent appends" + assert len(c) == 2 + assert c._placement_cache._length == 2 + # Verify qubit indices in cache + assert c._placement_cache._qubit_indices[q0] == 1 # CNOT is in moment 1 + assert c._placement_cache._qubit_indices[q1] == 1 # CNOT is in moment 1 + + # Operation that invalidates the cache: insert with NEW strategy + c.insert(1, cirq.X(q2), strategy=cirq.InsertStrategy.NEW) + # Circuit: H(q0),H(q1) | X(q2) | CNOT(q0,q1) + # Moment 0: H(q0), H(q1) + # Moment 1: X(q2) + # Moment 2: CNOT(q0, q1) + assert c._placement_cache is None, "Cache should be invalidated by insert (NEW)" + assert len(c) == 3 + + # First EARLIEST append after invalidation should trigger rebuild + c.append(cirq.Y(q0)) # Appends to a new moment 3 + # Circuit: H(q0),H(q1) | X(q2) | CNOT(q0,q1) | Y(q0) + assert c._placement_cache is not None, "Cache should be rebuilt on first EARLIEST append" + assert len(c) == 4 + assert c._placement_cache._length == 4 # Cache should now know about all 4 moments + + # Check cache integrity after rebuild + assert c._placement_cache._qubit_indices[q0] == 3 # Y(q0) is latest @ moment 3 + assert c._placement_cache._qubit_indices[q1] == 2 # CNOT(q0,q1) is latest for q1 @ moment 2 + assert c._placement_cache._qubit_indices[q2] == 1 # X(q2) is latest @ moment 1 + + # Subsequent EARLIEST appends should use the rebuilt cache + c.append(cirq.Z(q1)) # Appends to moment 3 with Y(q0) + # Circuit: H(q0),H(q1) | X(q2) | CNOT(q0,q1) | Y(q0), Z(q1) + assert c._placement_cache is not None, "Cache should stay active" + assert len(c) == 4 + assert c._placement_cache._length == 4 + assert c._placement_cache._qubit_indices[q1] == 3 # Z(q1) is now latest for q1 @ moment 3 + + # Another invalidating operation: __setitem__ + c[0] = Moment([ops.S(q0)]) # Moment 0: S(q0) + assert c._placement_cache is None, "Cache should be invalidated by __setitem__" + + # EARLIEST append to trigger rebuild again + c.append(ops.T(q2)) # Appends to moment 4 + # Circuit: S(q0) | X(q2) | CNOT(q0,q1) | Y(q0), Z(q1), T(q2) + assert c._placement_cache is not None, "Cache should be rebuilt again" + assert len(c) == 4 + assert c._placement_cache._length == 4 + assert c._placement_cache._qubit_indices[q0] == 3 # Y(q0) still latest for q0 @ moment 3 + assert c._placement_cache._qubit_indices[q1] == 3 # Z(q1) still latest for q1 @ moment 3 + assert c._placement_cache._qubit_indices[q2] == 2 # T(q2) is new latest for q2 @ moment 2 + + # Test with a different invalidating operation: clear_operations_touching + c.clear_operations_touching([q0], [0]) # Removes S(q0) from moment 0 + assert c._placement_cache is None, "Cache invalidated by clear_operations_touching" + + # Rebuild + c.append(cirq.H(q0)) # Appends to moment 4 (with T(q2)) + # Circuit: EMPTY | X(q2) | CNOT(q0,q1) | Y(q0), Z(q1) | T(q2), H(q0) + assert c._placement_cache is not None + assert len(c) == 5 + assert c._placement_cache._length == 5 + assert c._placement_cache._qubit_indices[q0] == 4 # H(q0) is new latest @ moment 4 + + # Test initialization with non-EARLIEST strategy invalidates cache if it was active + c_new_strat = cirq.Circuit(cirq.H(q0), cirq.H(q1), strategy=cirq.InsertStrategy.NEW) + assert ( + c_new_strat._placement_cache is None + ), "Cache should be None for NEW strategy init if ops are provided" + assert len(c_new_strat) == 2 + + # Test initialization with all moments also results in a None cache + c_all_moments = cirq.Circuit(Moment(cirq.H(q0)), Moment(cirq.H(q1))) + assert c_all_moments._placement_cache is None, "Cache should be None for all-moments init" + + # Test that if EARLIEST append itself creates new moments, cache length is correct + c_len_test = cirq.Circuit() + c_len_test.append(Moment(cirq.X(q0))) # Moment 0, cache len 1 + assert c_len_test._placement_cache is not None + c_len_test.append(Moment(cirq.Y(q0))) # Moment 1, cache len 2 + assert c_len_test._placement_cache is not None + assert c_len_test._placement_cache._length == 2 + c_len_test.append(cirq.Z(q0)) # This should go into a new moment 2 because Y(q0) is in moment 1 + assert c_len_test._placement_cache is not None + assert c_len_test._placement_cache._length == 3 + assert c_len_test._placement_cache._qubit_indices[q0] == 2 + + # Test inserting a Moment object with EARLIEST strategy when cache is active + c_moment_append = cirq.Circuit(cirq.H(q0)) # Cache active, length 1 + assert c_moment_append._placement_cache is not None + new_moment = Moment(cirq.X(q1)) + c_moment_append.append(new_moment) # Appends the moment at index 1 + assert len(c_moment_append) == 2 + assert c_moment_append[1] == new_moment + assert c_moment_append._placement_cache is not None + assert c_moment_append._placement_cache._length == 2 + assert c_moment_append._placement_cache._qubit_indices[q0] == 0 + assert c_moment_append._placement_cache._qubit_indices[q1] == 1 + + # Test _mutated correctly clears cache when preserve_placement_cache is False + c_mutated_test = cirq.Circuit(cirq.H(q0)) + assert c_mutated_test._placement_cache is not None + c_mutated_test._mutated(preserve_placement_cache=False) # Explicit call for testing this path + assert c_mutated_test._placement_cache is None + + # Test _mutated preserves cache when preserve_placement_cache is True + c_mutated_preserve = cirq.Circuit(cirq.H(q0)) + assert c_mutated_preserve._placement_cache is not None + initial_cache_obj = c_mutated_preserve._placement_cache + c_mutated_preserve._mutated(preserve_placement_cache=True) + assert c_mutated_preserve._placement_cache is initial_cache_obj + + # Test scenario: Init with EARLIEST (cache active) + # -> insert (NEW) (cache invalid) -> append (EARLIEST) (cache rebuild) + q_a, q_b = cirq.LineQubit.range(2) + circuit_scenario = cirq.Circuit(cirq.H(q_a), cirq.H(q_b), strategy=cirq.InsertStrategy.EARLIEST) + assert circuit_scenario._placement_cache is not None # Active + assert circuit_scenario._placement_cache._length == 1 + + circuit_scenario.insert(0, cirq.X(q_a), strategy=cirq.InsertStrategy.NEW) # Invalidates + # Circuit: X(q_a) | H(q_a), H(q_b) + assert circuit_scenario._placement_cache is None + + circuit_scenario.append(cirq.CNOT(q_a, q_b)) # Rebuilds and appends + # Circuit: X(q_a) | H(q_a), H(q_b) | CNOT(q_a, q_b) + assert circuit_scenario._placement_cache is not None + assert circuit_scenario._placement_cache._length == 3 + assert circuit_scenario._placement_cache._qubit_indices[q_a] == 2 + assert circuit_scenario._placement_cache._qubit_indices[q_b] == 2 + + +def test_cache_correctness_with_measurement_and_control_keys(): + q0, q1 = cirq.LineQubit.range(2) + c = cirq.Circuit() + + # Append measurement + c.append(cirq.measure(q0, key="m0")) # Moment 0 + assert c._placement_cache is not None + assert c._placement_cache._mkey_indices.get(cirq.MeasurementKey.parse_serialized("m0")) == 0 + assert c._placement_cache._length == 1 + + # Append classically controlled op + c.append(cirq.X(q1).with_classical_controls("m0")) # Moment 1 + assert c._placement_cache is not None + assert c._placement_cache._ckey_indices.get(cirq.MeasurementKey.parse_serialized("m0")) == 1 + assert c._placement_cache._length == 2 + assert c._placement_cache._qubit_indices[q1] == 1 + + # Invalidate cache + c.insert( + 0, cirq.H(q0), strategy=cirq.InsertStrategy.NEW + ) # Moment 0: H(q0) | measure(q0) | X(q1).c("m0") + assert c._placement_cache is None + + # Rebuild and append another measurement + c.append(cirq.measure(q1, key="m1")) # Moment 3 + # Circuit: H(q0) | measure(q0 key="m0") | X(q1).c("m0") | measure(q1 key="m1") + assert c._placement_cache is not None + assert c._placement_cache._length == 4 + assert c._placement_cache._mkey_indices.get(cirq.MeasurementKey.parse_serialized("m0")) == 1 + assert c._placement_cache._mkey_indices.get(cirq.MeasurementKey.parse_serialized("m1")) == 3 + assert c._placement_cache._ckey_indices.get(cirq.MeasurementKey.parse_serialized("m0")) == 2 + assert c._placement_cache._qubit_indices[q0] == 1 # measure(q0) is latest for q0 + assert c._placement_cache._qubit_indices[q1] == 3 # measure(q1) is latest for q1