diff --git a/cirq-core/cirq/ops/matrix_gates.py b/cirq-core/cirq/ops/matrix_gates.py index edc42f9a05e..68dcf2d9e74 100644 --- a/cirq-core/cirq/ops/matrix_gates.py +++ b/cirq-core/cirq/ops/matrix_gates.py @@ -22,7 +22,7 @@ from cirq import _import, linalg, protocols from cirq._compat import proper_repr -from cirq.ops import global_phase_op, identity, phased_x_z_gate, raw_types +from cirq.ops import global_phase_op, phased_x_z_gate, raw_types if TYPE_CHECKING: import cirq @@ -170,8 +170,8 @@ def _decompose_(self, qubits: tuple[cirq.Qid, ...]) -> cirq.OP_TREE: return NotImplemented # The above algorithms ignore phase, but phase is important to maintain if the gate is # controlled. Here, we add it back in with a global phase op. - ident = identity.IdentityGate(qid_shape=self._qid_shape).on(*qubits) # Preserve qid order - u = protocols.unitary(Circuit(ident, *decomposed)).reshape(self._matrix.shape) + circuit = Circuit(*decomposed) + u = circuit.unitary(qubit_order=qubits, qubits_that_should_be_present=qubits) phase_delta = linalg.phase_delta(u, self._matrix) # Phase delta is on the complex unit circle, so if real(phase_delta) >= 1, that means # no phase delta. (>1 is rounding error). diff --git a/cirq-core/cirq/ops/matrix_gates_test.py b/cirq-core/cirq/ops/matrix_gates_test.py index 9a83211cd25..fd2a674d999 100644 --- a/cirq-core/cirq/ops/matrix_gates_test.py +++ b/cirq-core/cirq/ops/matrix_gates_test.py @@ -413,3 +413,16 @@ def test_matrixgate_name_serialization(): gate_after_serialization3 = cirq.read_json(json_text=cirq.to_json(gate3)) assert gate3._name == '' assert gate_after_serialization3._name == '' + + +def test_decompose_when_qubits_not_in_ascending_order(): + # Previous code for preserving global phase would misorder qubits + q0, q1 = cirq.LineQubit.range(2) + circuit1 = cirq.Circuit() + matrix = cirq.testing.random_unitary(4, random_state=0) + circuit1.append(cirq.MatrixGate(matrix).on(q1, q0)) + u1 = cirq.unitary(circuit1) + decomposed = cirq.decompose(circuit1) + circuit2 = cirq.Circuit(decomposed) + u2 = cirq.unitary(circuit2) + np.testing.assert_allclose(u1, u2, atol=1e-14)