Skip to content

Commit 3031b1f

Browse files
authored
Expand product sweep parameters iteratively rather than recursively (#7523)
The new unit test assertion involving a product of 1025 sweeps crashes the current implementation with a "max recursion depth exceeded" error, but passes with the new implementation. This is particularly relevant when dict_to_product_sweep is called with a large input dictionary. This re-introduces the change reverted in #7522 and addresses the issue which caused the failure: `itertools.chain.from_iterable` produces an `Iterator` as its output, meaning the output object can only be iterated through once before it is exhausted. However, `Params` is a type alias for `Iterable[tuple['cirq.TParamKey', 'cirq.TParamVal']]`, meaning it is expected for a given `Params` to be iterated through repeatedly. This is also necessary in order for the `Params` produced by a `Product` sweep to function correctly, as the `Params` yielded by the individual factors' `param_tuples()` can appear multiple times in the compound `Params` of the product's `param_tuples()`. To properly allow the yielded `Params` to be iterated through repeatedly, we now use `lambda values: tuple(itertools.chain.from_iterable(values)` on line 234 of sweeps.py whereas the previous implementation effectively had `lambda values: itertools.chain.from_iterables(values)`. Collecting into a `tuple` guarantees that we yield a repeatedly-iterable `Params` object. In addition to the new test `test_nested_product_zip()` which reproduced the error in the previous implementation, this change has also been tested against the internal library tests which were broken by the previous implementation.
1 parent 8e8fd80 commit 3031b1f

File tree

2 files changed

+29
-10
lines changed

2 files changed

+29
-10
lines changed

cirq-core/cirq/study/sweeps.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -230,16 +230,10 @@ def __len__(self) -> int:
230230
return length
231231

232232
def param_tuples(self) -> Iterator[Params]:
233-
def _gen(factors):
234-
if not factors:
235-
yield ()
236-
else:
237-
first, rest = factors[0], factors[1:]
238-
for first_values in first.param_tuples():
239-
for rest_values in _gen(rest):
240-
yield first_values + rest_values
241-
242-
return _gen(self.factors)
233+
yield from map(
234+
lambda values: tuple(itertools.chain.from_iterable(values)),
235+
itertools.product(*(factor.param_tuples() for factor in self.factors)),
236+
)
243237

244238
def __repr__(self) -> str:
245239
factors_repr = ', '.join(repr(f) for f in self.factors)

cirq-core/cirq/study/sweeps_test.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,30 @@ def test_product():
158158
assert _values(sweep, 'b') == [3, 3, 4, 4, 3, 3, 4, 4]
159159
assert _values(sweep, 'c') == [5, 6, 5, 6, 5, 6, 5, 6]
160160

161+
sweep = cirq.Points('a', [1, 2]) * (cirq.Points('b', [3, 4, 5]))
162+
assert list(map(list, sweep.param_tuples())) == [
163+
[('a', 1), ('b', 3)],
164+
[('a', 1), ('b', 4)],
165+
[('a', 1), ('b', 5)],
166+
[('a', 2), ('b', 3)],
167+
[('a', 2), ('b', 4)],
168+
[('a', 2), ('b', 5)],
169+
]
170+
171+
sweep = cirq.Product(*[cirq.Points(str(i), [0]) for i in range(1025)])
172+
assert list(map(list, sweep.param_tuples())) == [[(str(i), 0) for i in range(1025)]]
173+
174+
175+
def test_nested_product_zip():
176+
sweep = cirq.Product(
177+
cirq.Product(cirq.Points('a', [0]), cirq.Points('b', [0])),
178+
cirq.Zip(cirq.Points('c', [0, 1]), cirq.Points('d', [0, 1])),
179+
)
180+
assert list(map(list, sweep.param_tuples())) == [
181+
[('a', 0), ('b', 0), ('c', 0), ('d', 0)],
182+
[('a', 0), ('b', 0), ('c', 1), ('d', 1)],
183+
]
184+
161185

162186
def test_zip_addition():
163187
zip_sweep = cirq.Zip(cirq.Points('a', [1, 2]), cirq.Points('b', [3, 4]))
@@ -172,6 +196,7 @@ def test_empty_product():
172196
sweep = cirq.Product()
173197
assert len(sweep) == len(list(sweep)) == 1
174198
assert str(sweep) == 'Product()'
199+
assert list(map(list, sweep.param_tuples())) == [[]]
175200

176201

177202
def test_slice_access_error():

0 commit comments

Comments
 (0)