Skip to content

Fix: Optimize circuit append performance by rebuilding placement cache #7475

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 65 additions & 22 deletions cirq-core/cirq/circuits/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -1777,25 +1777,48 @@ def __init__(
return
flattened_contents = tuple(ops.flatten_to_ops_or_moments(contents))
if all(isinstance(c, Moment) for c in flattened_contents):
self._placement_cache = None
self._placement_cache = None # Direct moment loading doesn't use/need cache initially.
self._moments[:] = cast(Iterable[Moment], flattened_contents)
return
# EARLIEST strategy during __init__ will use _load_contents_with_earliest_strategy,
with _compat.block_overlapping_deprecation('.*'):
if strategy == InsertStrategy.EARLIEST:
self._load_contents_with_earliest_strategy(flattened_contents)
else:
# For non-EARLIEST strategies, the cache is not useful during this initial loading.
if self._placement_cache is not None:
self._placement_cache = None
self.append(flattened_contents, strategy=strategy)

def _mutated(self, *, preserve_placement_cache=False) -> None:
def _mutated(self, *, preserve_placement_cache: bool = False) -> None:
"""Clear cached properties in response to this circuit being mutated."""
self._all_qubits = None
self._frozen = None
self._is_measurement = None
self._is_parameterized = None
self._parameter_names = None
if not preserve_placement_cache:
if not preserve_placement_cache and self._placement_cache is not None:
self._placement_cache = None

def _rebuild_placement_cache(self) -> None:
"""Rebuilds the placement cache from the current _moments.

This method is called when an EARLIEST append is attempted and the
cache has been invalidated.
"""
self._placement_cache = _PlacementCache()
cache = self._placement_cache # Shorthand

for i, moment_val in enumerate(self._moments):
for op_in_moment in moment_val.operations:
for q_op in op_in_moment.qubits:
cache._qubit_indices[q_op] = i
for mk_op in protocols.measurement_key_objs(op_in_moment):
cache._mkey_indices[mk_op] = i
for ck_op in protocols.control_keys(op_in_moment):
cache._ckey_indices[ck_op] = i
cache._length = len(self._moments)

@classmethod
def _from_moments(cls, moments: Iterable[cirq.Moment]) -> Circuit:
new_circuit = Circuit()
Expand Down Expand Up @@ -2116,8 +2139,14 @@ def insert(
ValueError: Bad insertion strategy.
"""
# limit index to 0..len(self._moments), also deal with indices smaller 0
k = max(min(index if index >= 0 else len(self._moments) + index, len(self._moments)), 0)
if strategy != InsertStrategy.EARLIEST or k != len(self._moments):
k_idx = max(min(index if index >= 0 else len(self._moments) + index, len(self._moments)), 0)

# FIX: Move cache logic here, use original index
is_earliest_append = strategy == InsertStrategy.EARLIEST and index == len(self._moments)
if is_earliest_append:
if self._placement_cache is None:
self._rebuild_placement_cache()
elif self._placement_cache is not None:
self._placement_cache = None
mops = list(ops.flatten_to_ops_or_moments(moment_or_operation_tree))
if self._placement_cache:
Expand All @@ -2126,35 +2155,41 @@ def insert(
batches = [[mop] for mop in mops] # Each op goes into its own moment.
else:
batches = list(_group_into_moment_compatible(mops))

current_k_idx = k_idx # Use a local variable for the current index within the batch loop
for batch in batches:
# Insert a moment if inline/earliest and _any_ op in the batch requires it.
if (
not self._placement_cache
and not isinstance(batch[0], Moment)
and strategy in (InsertStrategy.INLINE, InsertStrategy.EARLIEST)
and not all(
(strategy is InsertStrategy.EARLIEST and self._can_add_op_at(k, op))
or (k > 0 and self._can_add_op_at(k - 1, op))
(strategy is InsertStrategy.EARLIEST and self._can_add_op_at(current_k_idx, op))
or (current_k_idx > 0 and self._can_add_op_at(current_k_idx - 1, op))
for op in cast(list[cirq.Operation], batch)
)
):
self._moments.insert(k, Moment())
self._moments.insert(current_k_idx, Moment())
if strategy is InsertStrategy.INLINE:
k += 1
current_k_idx += 1
max_p = 0
current_strategy_for_batch = strategy
for moment_or_op in batch:
# Determine Placement
if self._placement_cache:
p = self._placement_cache.append(moment_or_op)
elif isinstance(moment_or_op, Moment):
p = k
elif strategy in (InsertStrategy.NEW, InsertStrategy.NEW_THEN_INLINE):
self._moments.insert(k, Moment())
p = k
elif strategy is InsertStrategy.INLINE:
p = k - 1
p = current_k_idx
elif current_strategy_for_batch in (
InsertStrategy.NEW,
InsertStrategy.NEW_THEN_INLINE,
):
self._moments.insert(current_k_idx, Moment())
p = current_k_idx
elif current_strategy_for_batch is InsertStrategy.INLINE:
p = current_k_idx - 1
else: # InsertStrategy.EARLIEST:
p = self.earliest_available_moment(moment_or_op, end_moment_index=k)
p = self.earliest_available_moment(moment_or_op, end_moment_index=current_k_idx)
# Place
if isinstance(moment_or_op, Moment):
self._moments.insert(p, moment_or_op)
Expand All @@ -2164,12 +2199,20 @@ def insert(
self._moments[p] = self._moments[p].with_operation(moment_or_op)
# Iterate
max_p = max(p, max_p)
if strategy is InsertStrategy.NEW_THEN_INLINE:
strategy = InsertStrategy.INLINE
k += 1
k = max(k, max_p + 1)
self._mutated(preserve_placement_cache=True)
return k
if current_strategy_for_batch is InsertStrategy.NEW_THEN_INLINE:
current_strategy_for_batch = InsertStrategy.INLINE
if p == current_k_idx:
current_k_idx += 1
# Update current_k_idx to be after the newly inserted operations for the next batch.
current_k_idx = max(current_k_idx, max_p + 1)
# At the end, update k_idx to reflect the final position
k_idx = current_k_idx
# Preserve cache only if it was an EARLIEST append and the cache was
# successfully used/rebuilt.
self._mutated(
preserve_placement_cache=(is_earliest_append and self._placement_cache is not None)
)
return k_idx

def insert_into_range(self, operations: cirq.OP_TREE, start: int, end: int) -> int:
"""Writes operations inline into an area of the circuit.
Expand Down
180 changes: 179 additions & 1 deletion cirq-core/cirq/circuits/circuit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import sympy

import cirq
from cirq import circuits, ops
from cirq import circuits, Moment, ops
from cirq.testing.devices import ValidatingTestDevice


Expand Down Expand Up @@ -4910,3 +4910,181 @@ def test_append_speed():
duration = time.perf_counter() - t
assert len(c) == moments
assert duration < 5


def test_placement_cache_rebuild_and_invalidation():
q0, q1, q2 = cirq.LineQubit.range(3)
c = cirq.Circuit()

# Initial appends (EARLIEST by default) should use/build cache
c.append(cirq.H(q0))
c.append(cirq.H(q1)) # Moment 0: H(q0), H(q1)
assert c._placement_cache is not None, "Cache should be active after initial appends"
assert len(c) == 1
assert c._placement_cache._length == 1
assert c._placement_cache._qubit_indices.get(q0) == 0
assert c._placement_cache._qubit_indices.get(q1) == 0

c.append(cirq.CNOT(q0, q1)) # Moment 1: CNOT(q0, q1)
assert c._placement_cache is not None, "Cache should remain active for subsequent appends"
assert len(c) == 2
assert c._placement_cache._length == 2
# Verify qubit indices in cache
assert c._placement_cache._qubit_indices[q0] == 1 # CNOT is in moment 1
assert c._placement_cache._qubit_indices[q1] == 1 # CNOT is in moment 1

# Operation that invalidates the cache: insert with NEW strategy
c.insert(1, cirq.X(q2), strategy=cirq.InsertStrategy.NEW)
# Circuit: H(q0),H(q1) | X(q2) | CNOT(q0,q1)
# Moment 0: H(q0), H(q1)
# Moment 1: X(q2)
# Moment 2: CNOT(q0, q1)
assert c._placement_cache is None, "Cache should be invalidated by insert (NEW)"
assert len(c) == 3

# First EARLIEST append after invalidation should trigger rebuild
c.append(cirq.Y(q0)) # Appends to a new moment 3
# Circuit: H(q0),H(q1) | X(q2) | CNOT(q0,q1) | Y(q0)
assert c._placement_cache is not None, "Cache should be rebuilt on first EARLIEST append"
assert len(c) == 4
assert c._placement_cache._length == 4 # Cache should now know about all 4 moments

# Check cache integrity after rebuild
assert c._placement_cache._qubit_indices[q0] == 3 # Y(q0) is latest @ moment 3
assert c._placement_cache._qubit_indices[q1] == 2 # CNOT(q0,q1) is latest for q1 @ moment 2
assert c._placement_cache._qubit_indices[q2] == 1 # X(q2) is latest @ moment 1

# Subsequent EARLIEST appends should use the rebuilt cache
c.append(cirq.Z(q1)) # Appends to moment 3 with Y(q0)
# Circuit: H(q0),H(q1) | X(q2) | CNOT(q0,q1) | Y(q0), Z(q1)
assert c._placement_cache is not None, "Cache should stay active"
assert len(c) == 4
assert c._placement_cache._length == 4
assert c._placement_cache._qubit_indices[q1] == 3 # Z(q1) is now latest for q1 @ moment 3

# Another invalidating operation: __setitem__
c[0] = Moment([ops.S(q0)]) # Moment 0: S(q0)
assert c._placement_cache is None, "Cache should be invalidated by __setitem__"

# EARLIEST append to trigger rebuild again
c.append(ops.T(q2)) # Appends to moment 4
# Circuit: S(q0) | X(q2) | CNOT(q0,q1) | Y(q0), Z(q1), T(q2)
assert c._placement_cache is not None, "Cache should be rebuilt again"
assert len(c) == 4
assert c._placement_cache._length == 4
assert c._placement_cache._qubit_indices[q0] == 3 # Y(q0) still latest for q0 @ moment 3
assert c._placement_cache._qubit_indices[q1] == 3 # Z(q1) still latest for q1 @ moment 3
assert c._placement_cache._qubit_indices[q2] == 2 # T(q2) is new latest for q2 @ moment 2

# Test with a different invalidating operation: clear_operations_touching
c.clear_operations_touching([q0], [0]) # Removes S(q0) from moment 0
assert c._placement_cache is None, "Cache invalidated by clear_operations_touching"

# Rebuild
c.append(cirq.H(q0)) # Appends to moment 4 (with T(q2))
# Circuit: EMPTY | X(q2) | CNOT(q0,q1) | Y(q0), Z(q1) | T(q2), H(q0)
assert c._placement_cache is not None
assert len(c) == 5
assert c._placement_cache._length == 5
assert c._placement_cache._qubit_indices[q0] == 4 # H(q0) is new latest @ moment 4

# Test initialization with non-EARLIEST strategy invalidates cache if it was active
c_new_strat = cirq.Circuit(cirq.H(q0), cirq.H(q1), strategy=cirq.InsertStrategy.NEW)
assert (
c_new_strat._placement_cache is None
), "Cache should be None for NEW strategy init if ops are provided"
assert len(c_new_strat) == 2

# Test initialization with all moments also results in a None cache
c_all_moments = cirq.Circuit(Moment(cirq.H(q0)), Moment(cirq.H(q1)))
assert c_all_moments._placement_cache is None, "Cache should be None for all-moments init"

# Test that if EARLIEST append itself creates new moments, cache length is correct
c_len_test = cirq.Circuit()
c_len_test.append(Moment(cirq.X(q0))) # Moment 0, cache len 1
assert c_len_test._placement_cache is not None
c_len_test.append(Moment(cirq.Y(q0))) # Moment 1, cache len 2
assert c_len_test._placement_cache is not None
assert c_len_test._placement_cache._length == 2
c_len_test.append(cirq.Z(q0)) # This should go into a new moment 2 because Y(q0) is in moment 1
assert c_len_test._placement_cache is not None
assert c_len_test._placement_cache._length == 3
assert c_len_test._placement_cache._qubit_indices[q0] == 2

# Test inserting a Moment object with EARLIEST strategy when cache is active
c_moment_append = cirq.Circuit(cirq.H(q0)) # Cache active, length 1
assert c_moment_append._placement_cache is not None
new_moment = Moment(cirq.X(q1))
c_moment_append.append(new_moment) # Appends the moment at index 1
assert len(c_moment_append) == 2
assert c_moment_append[1] == new_moment
assert c_moment_append._placement_cache is not None
assert c_moment_append._placement_cache._length == 2
assert c_moment_append._placement_cache._qubit_indices[q0] == 0
assert c_moment_append._placement_cache._qubit_indices[q1] == 1

# Test _mutated correctly clears cache when preserve_placement_cache is False
c_mutated_test = cirq.Circuit(cirq.H(q0))
assert c_mutated_test._placement_cache is not None
c_mutated_test._mutated(preserve_placement_cache=False) # Explicit call for testing this path
assert c_mutated_test._placement_cache is None

# Test _mutated preserves cache when preserve_placement_cache is True
c_mutated_preserve = cirq.Circuit(cirq.H(q0))
assert c_mutated_preserve._placement_cache is not None
initial_cache_obj = c_mutated_preserve._placement_cache
c_mutated_preserve._mutated(preserve_placement_cache=True)
assert c_mutated_preserve._placement_cache is initial_cache_obj

# Test scenario: Init with EARLIEST (cache active)
# -> insert (NEW) (cache invalid) -> append (EARLIEST) (cache rebuild)
q_a, q_b = cirq.LineQubit.range(2)
circuit_scenario = cirq.Circuit(cirq.H(q_a), cirq.H(q_b), strategy=cirq.InsertStrategy.EARLIEST)
assert circuit_scenario._placement_cache is not None # Active
assert circuit_scenario._placement_cache._length == 1

circuit_scenario.insert(0, cirq.X(q_a), strategy=cirq.InsertStrategy.NEW) # Invalidates
# Circuit: X(q_a) | H(q_a), H(q_b)
assert circuit_scenario._placement_cache is None

circuit_scenario.append(cirq.CNOT(q_a, q_b)) # Rebuilds and appends
# Circuit: X(q_a) | H(q_a), H(q_b) | CNOT(q_a, q_b)
assert circuit_scenario._placement_cache is not None
assert circuit_scenario._placement_cache._length == 3
assert circuit_scenario._placement_cache._qubit_indices[q_a] == 2
assert circuit_scenario._placement_cache._qubit_indices[q_b] == 2


def test_cache_correctness_with_measurement_and_control_keys():
q0, q1 = cirq.LineQubit.range(2)
c = cirq.Circuit()

# Append measurement
c.append(cirq.measure(q0, key="m0")) # Moment 0
assert c._placement_cache is not None
assert c._placement_cache._mkey_indices.get(cirq.MeasurementKey.parse_serialized("m0")) == 0
assert c._placement_cache._length == 1

# Append classically controlled op
c.append(cirq.X(q1).with_classical_controls("m0")) # Moment 1
assert c._placement_cache is not None
assert c._placement_cache._ckey_indices.get(cirq.MeasurementKey.parse_serialized("m0")) == 1
assert c._placement_cache._length == 2
assert c._placement_cache._qubit_indices[q1] == 1

# Invalidate cache
c.insert(
0, cirq.H(q0), strategy=cirq.InsertStrategy.NEW
) # Moment 0: H(q0) | measure(q0) | X(q1).c("m0")
assert c._placement_cache is None

# Rebuild and append another measurement
c.append(cirq.measure(q1, key="m1")) # Moment 3
# Circuit: H(q0) | measure(q0 key="m0") | X(q1).c("m0") | measure(q1 key="m1")
assert c._placement_cache is not None
assert c._placement_cache._length == 4
assert c._placement_cache._mkey_indices.get(cirq.MeasurementKey.parse_serialized("m0")) == 1
assert c._placement_cache._mkey_indices.get(cirq.MeasurementKey.parse_serialized("m1")) == 3
assert c._placement_cache._ckey_indices.get(cirq.MeasurementKey.parse_serialized("m0")) == 2
assert c._placement_cache._qubit_indices[q0] == 1 # measure(q0) is latest for q0
assert c._placement_cache._qubit_indices[q1] == 3 # measure(q1) is latest for q1