Skip to content

Commit 8ce1a4b

Browse files
Provide native support for U3 gate (#7717)
Adjust QasmUGate to have the same unitary as qiskit U3Gate. Fixes #7634 and #5959 --------- Co-authored-by: Pavol Juhas <[email protected]>
1 parent 12b3906 commit 8ce1a4b

File tree

3 files changed

+32
-4
lines changed

3 files changed

+32
-4
lines changed

cirq-core/cirq/circuits/qasm_output.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,10 +75,12 @@ def __repr__(self) -> str:
7575

7676
def _decompose_(self, qubits):
7777
q = qubits[0]
78+
phase_correction_half_turns = (self.phi + self.lmda) / 2
7879
return [
7980
ops.rz(self.lmda * np.pi).on(q),
8081
ops.ry(self.theta * np.pi).on(q),
8182
ops.rz(self.phi * np.pi).on(q),
83+
ops.global_phase_operation(1j ** (2 * phase_correction_half_turns)),
8284
]
8385

8486
def _value_equality_values_(self):

cirq-core/cirq/circuits/qasm_output_test.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,17 @@ def _make_qubits(n):
2929
return [cirq.NamedQubit(f'q{i}') for i in range(n)]
3030

3131

32+
def _qiskit_ugate_unitary(gate: QasmUGate) -> np.ndarray:
33+
# Ref: https://quantum.cloud.ibm.com/docs/en/api/qiskit/qiskit.circuit.library.U3Gate#u3gate
34+
th, ph, lm = np.pi * np.array([gate.theta, gate.phi, gate.lmda])
35+
return np.array(
36+
[
37+
[np.cos(th / 2), -np.exp(1j * lm) * np.sin(th / 2)],
38+
[np.exp(1j * ph) * np.sin(th / 2), np.exp(1j * (ph + lm)) * np.cos(th / 2)],
39+
]
40+
)
41+
42+
3243
def test_u_gate_repr() -> None:
3344
gate = QasmUGate(0.1, 0.2, 0.3)
3445
assert repr(gate) == 'cirq.circuits.qasm_output.QasmUGate(theta=0.1, phi=0.2, lmda=0.3)'
@@ -43,6 +54,20 @@ def test_u_gate_eq() -> None:
4354
cirq.approx_eq(gate4, gate3, atol=1e-16)
4455

4556

57+
@pytest.mark.parametrize("_", range(10))
58+
def test_u_gate_from_qiskit_ugate_unitary(_) -> None:
59+
# QasmUGate at (theta, phi, lmda) is the same as QasmUGate at
60+
# (2 - theta, phi + 1, lmda + 1) and a global phase factor of -1.
61+
# QasmUGate.from_matrix resolves theta at [0, 1] and ignores possible global
62+
# phase. To avoid phase discrepancy we limit theta to the [0, 1] interval.
63+
theta = np.random.uniform(0, 1)
64+
phi = np.random.uniform(0, 2)
65+
lmda = np.random.uniform(0, 2)
66+
u = _qiskit_ugate_unitary(QasmUGate(theta, phi, lmda))
67+
g = QasmUGate.from_matrix(u)
68+
np.testing.assert_allclose(cirq.unitary(g), u, atol=1e-7)
69+
70+
4671
def test_qasm_two_qubit_gate_repr() -> None:
4772
cirq.testing.assert_equivalent_repr(
4873
QasmTwoQubitGate.from_matrix(cirq.testing.random_unitary(4))
@@ -53,13 +78,14 @@ def test_qasm_u_qubit_gate_unitary() -> None:
5378
u = cirq.testing.random_unitary(2)
5479
g = QasmUGate.from_matrix(u)
5580
cirq.testing.assert_allclose_up_to_global_phase(cirq.unitary(g), u, atol=1e-7)
56-
5781
cirq.testing.assert_implements_consistent_protocols(g)
82+
np.testing.assert_allclose(cirq.unitary(g), _qiskit_ugate_unitary(g), atol=1e-7)
5883

5984
u = cirq.unitary(cirq.Y)
6085
g = QasmUGate.from_matrix(u)
6186
cirq.testing.assert_allclose_up_to_global_phase(cirq.unitary(g), u, atol=1e-7)
6287
cirq.testing.assert_implements_consistent_protocols(g)
88+
np.testing.assert_allclose(cirq.unitary(g), _qiskit_ugate_unitary(g), atol=1e-7)
6389

6490

6591
def test_qasm_two_qubit_gate_unitary() -> None:
@@ -200,7 +226,7 @@ def test_h_gate_with_parameter() -> None:
200226
)
201227

202228

203-
def test_qasm_global_pahse() -> None:
229+
def test_qasm_global_phase() -> None:
204230
output = cirq.QasmOutput((cirq.global_phase_operation(np.exp(1j * 5))), ())
205231
assert (
206232
str(output)

cirq-core/cirq/contrib/qasm_import/_parser_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1089,8 +1089,8 @@ def test_standard_gates_wrong_params_error(qasm_gate: str, num_params: int) -> N
10891089
# Mapping of two-qubit gates and `num_params`
10901090
two_qubit_param_gates = {
10911091
# TODO: fix and enable commented gates below
1092-
# ('cu1', cirq.ControlledGate(QasmUGate(0, 0, 0.1 / np.pi))): 1,
1093-
# ('cu3', cirq.ControlledGate(QasmUGate(0.1 / np.pi, 0.2 / np.pi, 0.3 / np.pi))): 3,
1092+
('cu1', cirq.ControlledGate(QasmUGate(0, 0, 0.1 / np.pi))): 1,
1093+
('cu3', cirq.ControlledGate(QasmUGate(0.1 / np.pi, 0.2 / np.pi, 0.3 / np.pi))): 3,
10941094
# ('cu', cirq.ControlledGate(QasmUGate(0.1 / np.pi, 0.2 / np.pi, 0.3 / np.pi))): 3,
10951095
('crx', cirq.ControlledGate(cirq.rx(0.1))): 1,
10961096
('cry', cirq.ControlledGate(cirq.ry(0.1))): 1,

0 commit comments

Comments
 (0)