diff --git a/cirq-google/cirq_google/__init__.py b/cirq-google/cirq_google/__init__.py index 91250d859d3..81531df86e9 100644 --- a/cirq-google/cirq_google/__init__.py +++ b/cirq-google/cirq_google/__init__.py @@ -64,6 +64,7 @@ PhysicalZTag as PhysicalZTag, SYC as SYC, SycamoreGate as SycamoreGate, + WaitGateWithUnit as WaitGateWithUnit, WILLOW as WILLOW, WillowGate as WillowGate, ) diff --git a/cirq-google/cirq_google/experimental/analog_experiments/__init__.py b/cirq-google/cirq_google/experimental/analog_experiments/__init__.py index 13bda79ea9b..63af6311a74 100644 --- a/cirq-google/cirq_google/experimental/analog_experiments/__init__.py +++ b/cirq-google/cirq_google/experimental/analog_experiments/__init__.py @@ -18,3 +18,7 @@ FrequencyMap as FrequencyMap, AnalogTrajectory as AnalogTrajectory, ) + +from cirq_google.experimental.analog_experiments.generic_analog_circuit import ( + GenericAnalogCircuitBuilder as GenericAnalogCircuitBuilder, +) diff --git a/cirq-google/cirq_google/experimental/analog_experiments/analog_trajectory_util.py b/cirq-google/cirq_google/experimental/analog_experiments/analog_trajectory_util.py index 4635aa4370e..9f60e034d74 100644 --- a/cirq-google/cirq_google/experimental/analog_experiments/analog_trajectory_util.py +++ b/cirq-google/cirq_google/experimental/analog_experiments/analog_trajectory_util.py @@ -36,11 +36,13 @@ class FrequencyMap: duration: duration of step qubit_freqs: dict describing qubit frequencies at end of step (None if idle) couplings: dict describing coupling rates at end of step + is_wait_step: a bool indicating only wait gate should be added. """ duration: su.ValueOrSymbol qubit_freqs: dict[str, su.ValueOrSymbol | None] couplings: dict[tuple[str, str], su.ValueOrSymbol] + is_wait_step: bool def _is_parameterized_(self) -> bool: return ( @@ -68,6 +70,7 @@ def _resolve_parameters_( couplings={ k: su.direct_symbol_replacement(v, resolver_) for k, v in self.couplings.items() }, + is_wait_step=self.is_wait_step, ) @@ -129,9 +132,11 @@ def from_sparse_trajectory( full_trajectory: list[FrequencyMap] = [] init_qubit_freq_dict: dict[str, tu.Value | None] = {q: None for q in qubits} init_g_dict: dict[tuple[str, str], tu.Value] = {p: 0 * tu.MHz for p in pairs} - full_trajectory.append(FrequencyMap(0 * tu.ns, init_qubit_freq_dict, init_g_dict)) + full_trajectory.append(FrequencyMap(0 * tu.ns, init_qubit_freq_dict, init_g_dict, False)) for dt, qubit_freq_dict, g_dict in sparse_trajectory: + # When both qubit_freq_dict and g_dict is empty, it is a wait step. + is_wait_step = not (qubit_freq_dict or g_dict) # If no freq provided, set equal to previous new_qubit_freq_dict = { q: qubit_freq_dict.get(q, full_trajectory[-1].qubit_freqs.get(q)) for q in qubits @@ -141,7 +146,7 @@ def from_sparse_trajectory( p: g_dict.get(p, full_trajectory[-1].couplings.get(p)) for p in pairs # type: ignore[misc] } - full_trajectory.append(FrequencyMap(dt, new_qubit_freq_dict, new_g_dict)) + full_trajectory.append(FrequencyMap(dt, new_qubit_freq_dict, new_g_dict, is_wait_step)) return cls(full_trajectory=full_trajectory, qubits=qubits, pairs=pairs) def get_full_trajectory_with_resolved_idles( diff --git a/cirq-google/cirq_google/experimental/analog_experiments/analog_trajectory_util_test.py b/cirq-google/cirq_google/experimental/analog_experiments/analog_trajectory_util_test.py index a39abd1a01f..d55afd19de0 100644 --- a/cirq-google/cirq_google/experimental/analog_experiments/analog_trajectory_util_test.py +++ b/cirq-google/cirq_google/experimental/analog_experiments/analog_trajectory_util_test.py @@ -26,6 +26,7 @@ def freq_map() -> atu.FrequencyMap: 10 * tu.ns, {"q0_0": 5 * tu.GHz, "q0_1": 6 * tu.GHz, "q0_2": sympy.Symbol("f_q0_2")}, {("q0_0", "q0_1"): 5 * tu.MHz, ("q0_1", "q0_2"): sympy.Symbol("g_q0_1_q0_2")}, + False, ) @@ -42,6 +43,7 @@ def test_freq_map_resolve(freq_map: atu.FrequencyMap) -> None: 10 * tu.ns, {"q0_0": 5 * tu.GHz, "q0_1": 6 * tu.GHz, "q0_2": 6 * tu.GHz}, {("q0_0", "q0_1"): 5 * tu.MHz, ("q0_1", "q0_2"): 7 * tu.MHz}, + False, ) @@ -52,36 +54,47 @@ def test_freq_map_resolve(freq_map: atu.FrequencyMap) -> None: def sparse_trajectory() -> list[FreqMapType]: traj1: FreqMapType = (20 * tu.ns, {"q0_1": 5 * tu.GHz}, {}) traj2: FreqMapType = (30 * tu.ns, {"q0_2": 8 * tu.GHz}, {}) - traj3: FreqMapType = ( + traj3: FreqMapType = (35 * tu.ns, {}, {}) + traj4: FreqMapType = ( 40 * tu.ns, {"q0_0": 8 * tu.GHz, "q0_1": None, "q0_2": None}, {("q0_0", "q0_1"): 5 * tu.MHz, ("q0_1", "q0_2"): 8 * tu.MHz}, ) - return [traj1, traj2, traj3] + return [traj1, traj2, traj3, traj4] def test_full_traj(sparse_trajectory: list[FreqMapType]) -> None: analog_traj = atu.AnalogTrajectory.from_sparse_trajectory(sparse_trajectory) - assert len(analog_traj.full_trajectory) == 4 + assert len(analog_traj.full_trajectory) == 5 assert analog_traj.full_trajectory[0] == atu.FrequencyMap( 0 * tu.ns, {"q0_0": None, "q0_1": None, "q0_2": None}, {("q0_0", "q0_1"): 0 * tu.MHz, ("q0_1", "q0_2"): 0 * tu.MHz}, + False, ) assert analog_traj.full_trajectory[1] == atu.FrequencyMap( 20 * tu.ns, {"q0_0": None, "q0_1": 5 * tu.GHz, "q0_2": None}, {("q0_0", "q0_1"): 0 * tu.MHz, ("q0_1", "q0_2"): 0 * tu.MHz}, + False, ) assert analog_traj.full_trajectory[2] == atu.FrequencyMap( 30 * tu.ns, {"q0_0": None, "q0_1": 5 * tu.GHz, "q0_2": 8 * tu.GHz}, {("q0_0", "q0_1"): 0 * tu.MHz, ("q0_1", "q0_2"): 0 * tu.MHz}, + False, ) assert analog_traj.full_trajectory[3] == atu.FrequencyMap( + 35 * tu.ns, + {"q0_0": None, "q0_1": 5 * tu.GHz, "q0_2": 8 * tu.GHz}, + {("q0_0", "q0_1"): 0 * tu.MHz, ("q0_1", "q0_2"): 0 * tu.MHz}, + True, + ) + assert analog_traj.full_trajectory[4] == atu.FrequencyMap( 40 * tu.ns, {"q0_0": 8 * tu.GHz, "q0_1": None, "q0_2": None}, {("q0_0", "q0_1"): 5 * tu.MHz, ("q0_1", "q0_2"): 8 * tu.MHz}, + False, ) @@ -92,26 +105,36 @@ def test_get_full_trajectory_with_resolved_idles(sparse_trajectory: list[FreqMap {"q0_0": 5 * tu.GHz, "q0_1": 6 * tu.GHz, "q0_2": 7 * tu.GHz} ) - assert len(resolved_full_traj) == 4 + assert len(resolved_full_traj) == 5 assert resolved_full_traj[0] == atu.FrequencyMap( 0 * tu.ns, {"q0_0": 5 * tu.GHz, "q0_1": 6 * tu.GHz, "q0_2": 7 * tu.GHz}, {("q0_0", "q0_1"): 0 * tu.MHz, ("q0_1", "q0_2"): 0 * tu.MHz}, + False, ) assert resolved_full_traj[1] == atu.FrequencyMap( 20 * tu.ns, {"q0_0": 5 * tu.GHz, "q0_1": 5 * tu.GHz, "q0_2": 7 * tu.GHz}, {("q0_0", "q0_1"): 0 * tu.MHz, ("q0_1", "q0_2"): 0 * tu.MHz}, + False, ) assert resolved_full_traj[2] == atu.FrequencyMap( 30 * tu.ns, {"q0_0": 5 * tu.GHz, "q0_1": 5 * tu.GHz, "q0_2": 8 * tu.GHz}, {("q0_0", "q0_1"): 0 * tu.MHz, ("q0_1", "q0_2"): 0 * tu.MHz}, + False, ) assert resolved_full_traj[3] == atu.FrequencyMap( + 35 * tu.ns, + {"q0_0": 5 * tu.GHz, "q0_1": 5 * tu.GHz, "q0_2": 8 * tu.GHz}, + {("q0_0", "q0_1"): 0 * tu.MHz, ("q0_1", "q0_2"): 0 * tu.MHz}, + True, + ) + assert resolved_full_traj[4] == atu.FrequencyMap( 40 * tu.ns, {"q0_0": 8 * tu.GHz, "q0_1": 6 * tu.GHz, "q0_2": 7 * tu.GHz}, {("q0_0", "q0_1"): 5 * tu.MHz, ("q0_1", "q0_2"): 8 * tu.MHz}, + False, ) diff --git a/cirq-google/cirq_google/experimental/analog_experiments/generic_analog_circuit.py b/cirq-google/cirq_google/experimental/analog_experiments/generic_analog_circuit.py new file mode 100644 index 00000000000..1859d8cb788 --- /dev/null +++ b/cirq-google/cirq_google/experimental/analog_experiments/generic_analog_circuit.py @@ -0,0 +1,144 @@ +# Copyright 2025 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import re + +import cirq +from cirq_google.experimental.analog_experiments import analog_trajectory_util as atu +from cirq_google.ops import analog_detune_gates as adg, wait_gate as wg +from cirq_google.study import symbol_util as su + + +def _get_neighbor_freqs( + qubit_pair: tuple[str, str], qubit_freq_dict: dict[str, su.ValueOrSymbol | None] +) -> tuple[su.ValueOrSymbol | None, su.ValueOrSymbol | None]: + """Get neighbor freqs from qubit_freq_dict given the pair.""" + sorted_pair = sorted(qubit_pair, key=_to_grid_qubit) + return (qubit_freq_dict[sorted_pair[0]], qubit_freq_dict[sorted_pair[1]]) + + +@functools.cache +def _to_grid_qubit(qubit_name: str) -> cirq.GridQubit: + match = re.compile(r"^q(\d+)_(\d+)$").match(qubit_name) + if match is None: + raise ValueError(f"Invalid qubit name format: '{qubit_name}'. Expected 'q_'.") + return cirq.GridQubit(int(match[1]), int(match[2])) + + +def _coupler_name_from_qubit_pair(qubit_pair: tuple[str, str]) -> str: + sorted_pair = sorted(qubit_pair, key=_to_grid_qubit) + return f"c_{sorted_pair[0]}_{sorted_pair[1]}" + + +def _get_neighbor_coupler_freqs( + qubit_name: str, coupler_g_dict: dict[tuple[str, str], su.ValueOrSymbol] +) -> dict[str, su.ValueOrSymbol]: + """Get neighbor coupler coupling strength g given qubit name.""" + return { + _coupler_name_from_qubit_pair(pair): g + for pair, g in coupler_g_dict.items() + if qubit_name in pair + } + + +class GenericAnalogCircuitBuilder: + """Class for making arbitrary analog circuits. The circuit is defined by an + AnalogTrajectory object. The class constructs the circuit from AnalogDetune + pulses, which automatically calculate the necessary bias amps to both qubits + and couplers, using tu.Values from analog calibration whenever available. + + Attributes: + trajectory: AnalogTrajectory object defining the circuit + g_ramp_shaping: coupling ramps are shaped according to ramp_shape_exp if True + qubits: list of qubits in the circuit + pairs: list of couplers in the circuit + ramp_shape_exp: exponent of g_ramp (g proportional to t^ramp_shape_exp) + interpolate_coupling_cal: interpolates between calibrated coupling tu.Values if True + linear_qubit_ramp: if True, the qubit ramp is linear. if false, a cosine shaped + ramp is used. + """ + + def __init__( + self, + trajectory: atu.AnalogTrajectory, + g_ramp_shaping: bool = False, + ramp_shape_exp: int = 1, + interpolate_coupling_cal: bool = False, + linear_qubit_ramp: bool = True, + ): + self.trajectory = trajectory + self.g_ramp_shaping = g_ramp_shaping + self.ramp_shape_exp = ramp_shape_exp + self.interpolate_coupling_cal = interpolate_coupling_cal + self.linear_qubit_ramp = linear_qubit_ramp + + def make_circuit(self) -> cirq.Circuit: + """Assemble moments described in trajectory.""" + prev_freq_map = self.trajectory.full_trajectory[0] + moments = [] + for freq_map in self.trajectory.full_trajectory[1:]: + if freq_map.is_wait_step: + targets = [_to_grid_qubit(q) for q in self.trajectory.qubits] + wait_gate = wg.WaitGateWithUnit( + freq_map.duration, qid_shape=cirq.qid_shape(targets) + ) + moment = cirq.Moment(wait_gate.on(*targets)) + else: + moment = self.make_one_moment(freq_map, prev_freq_map) + moments.append(moment) + prev_freq_map = freq_map + + return cirq.Circuit.from_moments(*moments) + + def make_one_moment( + self, freq_map: atu.FrequencyMap, prev_freq_map: atu.FrequencyMap + ) -> cirq.Moment: + """Make one moment of analog detune qubit and coupler gates given freqs.""" + qubit_gates = [] + for q, freq in freq_map.qubit_freqs.items(): + qubit_gates.append( + adg.AnalogDetuneQubit( + length=freq_map.duration, + w=freq_map.duration, + target_freq=freq, + prev_freq=prev_freq_map.qubit_freqs.get(q), + neighbor_coupler_g_dict=_get_neighbor_coupler_freqs(q, freq_map.couplings), + prev_neighbor_coupler_g_dict=_get_neighbor_coupler_freqs( + q, prev_freq_map.couplings + ), + linear_rise=self.linear_qubit_ramp, + ).on(_to_grid_qubit(q)) + ) + coupler_gates = [] + for p, g_max in freq_map.couplings.items(): + # Currently skipping the step if these are the same. + # However, change in neighbor qubit freq could potentially change coupler amp + if g_max == prev_freq_map.couplings[p]: + continue + + coupler_gates.append( + adg.AnalogDetuneCouplerOnly( + length=freq_map.duration, + w=freq_map.duration, + g_0=prev_freq_map.couplings[p], + g_max=g_max, + g_ramp_exponent=self.ramp_shape_exp, + neighbor_qubits_freq=_get_neighbor_freqs(p, freq_map.qubit_freqs), + prev_neighbor_qubits_freq=_get_neighbor_freqs(p, prev_freq_map.qubit_freqs), + interpolate_coupling_cal=self.interpolate_coupling_cal, + ).on(*sorted([_to_grid_qubit(p[0]), _to_grid_qubit(p[1])])) + ) + + return cirq.Moment(qubit_gates + coupler_gates) diff --git a/cirq-google/cirq_google/experimental/analog_experiments/generic_analog_circuit_test.py b/cirq-google/cirq_google/experimental/analog_experiments/generic_analog_circuit_test.py new file mode 100644 index 00000000000..3996c504a61 --- /dev/null +++ b/cirq-google/cirq_google/experimental/analog_experiments/generic_analog_circuit_test.py @@ -0,0 +1,158 @@ +# Copyright 2025 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import sympy +import tunits as tu + +import cirq +from cirq_google.experimental.analog_experiments import ( + analog_trajectory_util as atu, + generic_analog_circuit as gac, +) +from cirq_google.ops.analog_detune_gates import AnalogDetuneCouplerOnly, AnalogDetuneQubit + + +def test_get_neighbor_freqs() -> None: + pair = ("q0_0", "q0_1") + qubit_freq_dict = {"q0_0": 5 * tu.GHz, "q0_1": sympy.Symbol("f_q"), "q0_2": 6 * tu.GHz} + neighbor_freqs = gac._get_neighbor_freqs(pair, qubit_freq_dict) + assert neighbor_freqs == (5 * tu.GHz, sympy.Symbol("f_q")) + + +def test_to_grid_qubit() -> None: + grid_qubit = gac._to_grid_qubit("q0_1") + assert grid_qubit == cirq.GridQubit(0, 1) + + with pytest.raises(ValueError, match="Invalid qubit name format"): + gac._to_grid_qubit("q1") + + +def test_coupler_name_from_qubit_pair() -> None: + pair = ("q0_0", "q0_1") + coupler_name = gac._coupler_name_from_qubit_pair(pair) + assert coupler_name == "c_q0_0_q0_1" + + pair = ("q9_0", "q10_0") + coupler_name = gac._coupler_name_from_qubit_pair(pair) + assert coupler_name == "c_q9_0_q10_0" + + pair = ("q7_8", "q7_7") + coupler_name = gac._coupler_name_from_qubit_pair(pair) + assert coupler_name == "c_q7_7_q7_8" + + +def test_make_one_moment_of_generic_analog_circuit() -> None: + freq_map = atu.FrequencyMap( + duration=3 * tu.ns, + qubit_freqs={"q0_0": 5 * tu.GHz, "q0_1": 6 * tu.GHz, "q0_2": sympy.Symbol("f_q0_2")}, + couplings={("q0_0", "q0_1"): 5 * tu.MHz, ("q0_1", "q0_2"): 6 * tu.MHz}, + is_wait_step=False, + ) + prev_freq_map = atu.FrequencyMap( + duration=9 * tu.ns, + qubit_freqs={"q0_0": 4 * tu.GHz, "q0_1": 6 * tu.GHz, "q0_2": sympy.Symbol("f_q0_2")}, + couplings={("q0_0", "q0_1"): 2 * tu.MHz, ("q0_1", "q0_2"): 3 * tu.MHz}, + is_wait_step=False, + ) + + trajectory = None # we don't need trajector in this test. + builder = gac.GenericAnalogCircuitBuilder(trajectory) # type: ignore + moment = builder.make_one_moment(freq_map, prev_freq_map) + + assert len(moment.operations) == 5 + # Three detune qubit gates + assert moment.operations[0] == AnalogDetuneQubit( + length=3 * tu.ns, + w=3 * tu.ns, + target_freq=5 * tu.GHz, + prev_freq=4 * tu.GHz, + neighbor_coupler_g_dict={"c_q0_0_q0_1": 5 * tu.MHz}, + prev_neighbor_coupler_g_dict={"c_q0_0_q0_1": 2 * tu.MHz}, + linear_rise=True, + ).on(cirq.GridQubit(0, 0)) + assert moment.operations[1] == AnalogDetuneQubit( + length=3 * tu.ns, + w=3 * tu.ns, + target_freq=6 * tu.GHz, + prev_freq=6 * tu.GHz, + neighbor_coupler_g_dict={"c_q0_0_q0_1": 5 * tu.MHz, "c_q0_1_q0_2": 6 * tu.MHz}, + prev_neighbor_coupler_g_dict={"c_q0_0_q0_1": 2 * tu.MHz, "c_q0_1_q0_2": 3 * tu.MHz}, + linear_rise=True, + ).on(cirq.GridQubit(0, 1)) + assert moment.operations[2] == AnalogDetuneQubit( + length=3 * tu.ns, + w=3 * tu.ns, + target_freq=sympy.Symbol("f_q0_2"), + prev_freq=sympy.Symbol("f_q0_2"), + neighbor_coupler_g_dict={"c_q0_1_q0_2": 6 * tu.MHz}, + prev_neighbor_coupler_g_dict={"c_q0_1_q0_2": 3 * tu.MHz}, + linear_rise=True, + ).on(cirq.GridQubit(0, 2)) + + # Two detune coupler only gates + assert moment.operations[3] == AnalogDetuneCouplerOnly( + length=3 * tu.ns, + w=3 * tu.ns, + g_0=2 * tu.MHz, + g_max=5 * tu.MHz, + g_ramp_exponent=1, + neighbor_qubits_freq=(5 * tu.GHz, 6 * tu.GHz), + prev_neighbor_qubits_freq=(4 * tu.GHz, 6 * tu.GHz), + interpolate_coupling_cal=False, + ).on(cirq.GridQubit(0, 0), cirq.GridQubit(0, 1)) + assert moment.operations[4] == AnalogDetuneCouplerOnly( + length=3 * tu.ns, + w=3 * tu.ns, + g_0=3 * tu.MHz, + g_max=6 * tu.MHz, + g_ramp_exponent=1, + neighbor_qubits_freq=(6 * tu.GHz, sympy.Symbol("f_q0_2")), + prev_neighbor_qubits_freq=(6 * tu.GHz, sympy.Symbol("f_q0_2")), + interpolate_coupling_cal=False, + ).on(cirq.GridQubit(0, 1), cirq.GridQubit(0, 2)) + + +def test_generic_analog_make_circuit() -> None: + trajectory = atu.AnalogTrajectory.from_sparse_trajectory( + [ + (5 * tu.ns, {"q0_0": 5 * tu.GHz}, {}), + (sympy.Symbol('t'), {}, {}), + ( + 10 * tu.ns, + {"q0_0": 8 * tu.GHz, "q0_1": sympy.Symbol('f')}, + {("q0_0", "q0_1"): -5 * tu.MHz}, + ), + (3 * tu.ns, {}, {}), + (2 * tu.ns, {"q0_1": 4 * tu.GHz}, {}), + ] + ) + builder = gac.GenericAnalogCircuitBuilder(trajectory) + circuit = builder.make_circuit() + + assert len(circuit) == 5 + for op in circuit[0].operations: + assert isinstance(op.gate, AnalogDetuneQubit) + for op in circuit[1].operations: + assert isinstance(op.gate, cirq.WaitGate) + + assert isinstance(circuit[2].operations[0].gate, AnalogDetuneQubit) + assert isinstance(circuit[2].operations[1].gate, AnalogDetuneQubit) + assert isinstance(circuit[2].operations[2].gate, AnalogDetuneCouplerOnly) + + for op in circuit[3].operations: + assert isinstance(op.gate, cirq.WaitGate) + + for op in circuit[4].operations: + assert isinstance(op.gate, AnalogDetuneQubit) diff --git a/cirq-google/cirq_google/json_resolver_cache.py b/cirq-google/cirq_google/json_resolver_cache.py index 0f2339c66bf..8df67ffb5f8 100644 --- a/cirq-google/cirq_google/json_resolver_cache.py +++ b/cirq-google/cirq_google/json_resolver_cache.py @@ -56,6 +56,7 @@ def _old_xmon(*args, **kwargs): cirq_google.experimental.PerQubitDepolarizingWithDampedReadoutNoiseModel ), 'SycamoreGate': cirq_google.SycamoreGate, + 'WaitGateWithUnit': cirq_google.WaitGateWithUnit, 'WillowGate': cirq_google.WillowGate, # cirq_google.GateTabulation has been removed and replaced by cirq.TwoQubitGateTabulation. 'GateTabulation': TwoQubitGateTabulation, diff --git a/cirq-google/cirq_google/json_test_data/WaitGateWithUnit.json b/cirq-google/cirq_google/json_test_data/WaitGateWithUnit.json new file mode 100644 index 00000000000..700e06e86a8 --- /dev/null +++ b/cirq-google/cirq_google/json_test_data/WaitGateWithUnit.json @@ -0,0 +1,7 @@ +{ + "cirq_type": "WaitGateWithUnit", + "duration": { + "cirq_type": "sympy.Symbol", + "name": "d" + } +} \ No newline at end of file diff --git a/cirq-google/cirq_google/json_test_data/WaitGateWithUnit.repr b/cirq-google/cirq_google/json_test_data/WaitGateWithUnit.repr new file mode 100644 index 00000000000..e083ed001fe --- /dev/null +++ b/cirq-google/cirq_google/json_test_data/WaitGateWithUnit.repr @@ -0,0 +1 @@ +cirq_google.WaitGateWithUnit(duration=sympy.Symbol("d")) \ No newline at end of file diff --git a/cirq-google/cirq_google/ops/__init__.py b/cirq-google/cirq_google/ops/__init__.py index 13a70bbe5f8..565fcdcef5c 100644 --- a/cirq-google/cirq_google/ops/__init__.py +++ b/cirq-google/cirq_google/ops/__init__.py @@ -38,4 +38,6 @@ DynamicalDecouplingTag as DynamicalDecouplingTag, ) +from cirq_google.ops.wait_gate import WaitGateWithUnit as WaitGateWithUnit + from cirq_google.ops.willow_gate import WillowGate as WillowGate, WILLOW as WILLOW diff --git a/cirq-google/cirq_google/ops/wait_gate.py b/cirq-google/cirq_google/ops/wait_gate.py new file mode 100644 index 00000000000..72aa1cae750 --- /dev/null +++ b/cirq-google/cirq_google/ops/wait_gate.py @@ -0,0 +1,66 @@ +# Copyright 2025 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import sympy +import tunits as tu + +import cirq +from cirq_google.study import symbol_util as su + + +class WaitGateWithUnit(cirq.WaitGate): + """A wrapper on top of WaitGate that can work with units.""" + + def __init__( + self, + duration: su.ValueOrSymbol, + num_qubits: int | None = None, + qid_shape: tuple[int, ...] | None = None, + ): + if not isinstance(duration, su.ValueOrSymbol): + raise ValueError("The duration must either be a tu.Value or a sympy.Symbol.") + # Override the original duration + self._duration: su.ValueOrSymbol = duration # type: ignore[assignment] + + # The rest is copy-pasted from WaitGate. We just cannot use + # super().__init__ because of the duration. + if qid_shape is None: + if num_qubits is None: + # Assume one qubit for backwards compatibility + qid_shape = (2,) + else: + qid_shape = (2,) * num_qubits + if num_qubits is None: + num_qubits = len(qid_shape) + if not qid_shape: + raise ValueError('Waiting on an empty set of qubits.') + if num_qubits != len(qid_shape): + raise ValueError('len(qid_shape) != num_qubits') + self._qid_shape = qid_shape + + @property + def duration(self) -> sympy.Symbol | cirq.Duration: + if isinstance(self._duration, sympy.Symbol): + return self._duration + return cirq.Duration(nanos=self._duration[tu.ns]) + + def _resolve_parameters_( + self, resolver: cirq.ParamResolver, recursive: bool + ) -> WaitGateWithUnit: + if isinstance(self._duration, sympy.Symbol): + _duration = su.direct_symbol_replacement(self._duration, resolver) + return WaitGateWithUnit(_duration, qid_shape=self._qid_shape) + return self diff --git a/cirq-google/cirq_google/ops/wait_gate_test.py b/cirq-google/cirq_google/ops/wait_gate_test.py new file mode 100644 index 00000000000..f8c0690b55f --- /dev/null +++ b/cirq-google/cirq_google/ops/wait_gate_test.py @@ -0,0 +1,65 @@ +# Copyright 2025 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +import sympy +import tunits as tu + +import cirq +from cirq_google.ops import wait_gate as wg + + +def test_wait_gate_with_unit_init() -> None: + g = wg.WaitGateWithUnit(1 * tu.us) + assert g.duration == cirq.Duration(nanos=1000) + + g = wg.WaitGateWithUnit(1 * tu.us, num_qubits=2) + assert g._qid_shape == (2, 2) + + g = wg.WaitGateWithUnit(sympy.Symbol("d")) + assert g.duration == sympy.Symbol("d") + + with pytest.raises(ValueError, match="either be a tu.Value or a sympy.Symbol."): + wg.WaitGateWithUnit(10) + + with pytest.raises(ValueError, match="Waiting on an empty set of qubits."): + wg.WaitGateWithUnit(10 * tu.ns, qid_shape=()) + + with pytest.raises(ValueError, match="num_qubits"): + wg.WaitGateWithUnit(10 * tu.ns, qid_shape=(2, 2), num_qubits=5) + + +def test_wait_gate_with_units_resolving() -> None: + gate = wg.WaitGateWithUnit(sympy.Symbol("d")) + + resolved_gate = cirq.resolve_parameters(gate, {"d": 10 * tu.ns}) + assert resolved_gate.duration == cirq.Duration(nanos=10) + + gate = wg.WaitGateWithUnit(10 * tu.ns) + assert gate._resolve_parameters_(cirq.ParamResolver({}), True) == gate + + +def test_wait_gate_equality() -> None: + gate1 = wg.WaitGateWithUnit(10 * tu.ns) + gate2 = wg.WaitGateWithUnit(10 * tu.ns) + assert gate1 == gate2 + + gate_symbol_1 = wg.WaitGateWithUnit(sympy.Symbol("a")) + gate_symbol_2 = wg.WaitGateWithUnit(sympy.Symbol("a")) + assert gate_symbol_1 == gate_symbol_2 + assert gate_symbol_1 != gate1 + + +def test_wait_gate_jsonify() -> None: + gate = wg.WaitGateWithUnit(sympy.Symbol("d")) + assert gate == cirq.read_json(json_text=cirq.to_json(gate)) diff --git a/cirq-google/cirq_google/study/symbol_util.py b/cirq-google/cirq_google/study/symbol_util.py index b2360543201..75b67ba2e29 100644 --- a/cirq-google/cirq_google/study/symbol_util.py +++ b/cirq-google/cirq_google/study/symbol_util.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from typing import AbstractSet, Any, TypeAlias import sympy