Skip to content
23 changes: 22 additions & 1 deletion graphix/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import dataclasses
from copy import copy
from dataclasses import dataclass
from types import MappingProxyType
Expand All @@ -16,7 +17,7 @@
from graphix import command
from graphix.clifford import Clifford
from graphix.command import CommandKind, Node
from graphix.fundamentals import Axis
from graphix.fundamentals import Axis, Plane
from graphix.measurements import Domains, Outcome, PauliMeasurement

if TYPE_CHECKING:
Expand Down Expand Up @@ -460,3 +461,23 @@ def incorporate_pauli_results(pattern: Pattern) -> Pattern:
result.add(cmd)
result.reorder_output_nodes(pattern.output_nodes)
return result


def remove_useless_domains(pattern: Pattern) -> Pattern:
"""Return an equivalent pattern where results from Pauli presimulation are integrated in corrections."""
new_pattern = graphix.pattern.Pattern(input_nodes=pattern.input_nodes)
new_pattern.results = pattern.results
for cmd in pattern:
if cmd.kind == CommandKind.M:
if cmd.angle == 0:
if cmd.plane == Plane.XY:
new_cmd = dataclasses.replace(cmd, s_domain=set())
else:
new_cmd = dataclasses.replace(cmd, t_domain=set())
else:
new_cmd = cmd
new_pattern.add(new_cmd)
else:
new_pattern.add(cmd)
new_pattern.reorder_output_nodes(pattern.output_nodes)
return new_pattern
19 changes: 19 additions & 0 deletions graphix/random_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,3 +426,22 @@ def rand_circuit(
ind = rng.integers(len(gate_choice))
gate_choice[ind](j)
return circuit


def rand_state_vector(nqubits: int, rng: Generator | None = None) -> npt.NDArray[np.complex128]:
"""
Generate a random normalized complex state vector of size 2^n.

Parameters
----------
nqubits : int
The power of 2 for the vector size

Returns
-------
numpy.ndarray
Normalized complex vector of size 2^nqubits
"""
rng = ensure_rng(rng)
vec = rng.random(2**nqubits) + 1j * rng.random(2**nqubits)
return vec / np.linalg.norm(vec)
84 changes: 36 additions & 48 deletions graphix/transpiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ def _cnot_command(
E(nodes=(control_node, ancilla[0])),
E(nodes=(ancilla[0], ancilla[1])),
M(node=target_node),
M(node=ancilla[0]),
M(node=ancilla[0], s_domain={target_node}),
X(node=ancilla[1], domain={ancilla[0]}),
Z(node=ancilla[1], domain={target_node}),
Z(node=control_node, domain={target_node}),
Expand Down Expand Up @@ -552,7 +552,7 @@ def _s_command(cls, input_node: int, ancilla: Sequence[int]) -> tuple[int, list[
E(nodes=(input_node, ancilla[0])),
E(nodes=(ancilla[0], ancilla[1])),
M(node=input_node, angle=-0.5),
M(node=ancilla[0]),
M(node=ancilla[0], s_domain={input_node}),
X(node=ancilla[1], domain={ancilla[0]}),
Z(node=ancilla[1], domain={input_node}),
)
Expand Down Expand Up @@ -584,7 +584,7 @@ def _x_command(cls, input_node: int, ancilla: Sequence[int]) -> tuple[int, list[
E(nodes=(input_node, ancilla[0])),
E(nodes=(ancilla[0], ancilla[1])),
M(node=input_node),
M(node=ancilla[0], angle=-1),
M(node=ancilla[0], angle=-1, s_domain={input_node}),
X(node=ancilla[1], domain={ancilla[0]}),
Z(node=ancilla[1], domain={input_node}),
)
Expand Down Expand Up @@ -620,10 +620,10 @@ def _y_command(cls, input_node: int, ancilla: Sequence[int]) -> tuple[int, list[
E(nodes=(ancilla[2], ancilla[3])),
M(node=input_node, angle=0.5),
M(node=ancilla[0], angle=1.0, s_domain={input_node}),
M(node=ancilla[1], angle=-0.5, s_domain={input_node}),
M(node=ancilla[2]),
X(node=ancilla[3], domain={ancilla[0], ancilla[2]}),
Z(node=ancilla[3], domain={ancilla[0], ancilla[1]}),
M(node=ancilla[1], angle=-0.5, s_domain={ancilla[0]}, t_domain={input_node}),
M(node=ancilla[2], s_domain={ancilla[1]}, t_domain={ancilla[0]}),
X(node=ancilla[3], domain={ancilla[2]}),
Z(node=ancilla[3], domain={ancilla[1]}),
)
)
return ancilla[3], seq
Expand Down Expand Up @@ -653,7 +653,7 @@ def _z_command(cls, input_node: int, ancilla: Sequence[int]) -> tuple[int, list[
E(nodes=(input_node, ancilla[0])),
E(nodes=(ancilla[0], ancilla[1])),
M(node=input_node, angle=-1),
M(node=ancilla[0]),
M(node=ancilla[0], s_domain={input_node}),
X(node=ancilla[1], domain={ancilla[0]}),
Z(node=ancilla[1], domain={input_node}),
)
Expand Down Expand Up @@ -725,10 +725,10 @@ def _ry_command(cls, input_node: int, ancilla: Sequence[int], angle: Angle) -> t
E(nodes=(ancilla[2], ancilla[3])),
M(node=input_node, angle=0.5),
M(node=ancilla[0], angle=-angle / np.pi, s_domain={input_node}),
M(node=ancilla[1], angle=-0.5, s_domain={input_node}),
M(node=ancilla[2]),
X(node=ancilla[3], domain={ancilla[0], ancilla[2]}),
Z(node=ancilla[3], domain={ancilla[0], ancilla[1]}),
M(node=ancilla[1], angle=-0.5, s_domain={ancilla[0]}, t_domain={input_node}),
M(node=ancilla[2], s_domain={ancilla[1]}, t_domain={ancilla[0]}),
X(node=ancilla[3], domain={ancilla[2]}),
Z(node=ancilla[3], domain={ancilla[1]}),
)
)
return ancilla[3], seq
Expand Down Expand Up @@ -760,7 +760,7 @@ def _rz_command(cls, input_node: int, ancilla: Sequence[int], angle: Angle) -> t
E(nodes=(input_node, ancilla[0])),
E(nodes=(ancilla[0], ancilla[1])),
M(node=input_node, angle=-angle / np.pi),
M(node=ancilla[0]),
M(node=ancilla[0], s_domain={input_node}),
X(node=ancilla[1], domain={ancilla[0]}),
Z(node=ancilla[1], domain={input_node}),
)
Expand Down Expand Up @@ -829,49 +829,37 @@ def _ccx_command(
E(nodes=(ancilla[16], ancilla[17])),
M(node=target_node),
M(node=ancilla[0], s_domain={target_node}),
M(node=ancilla[1], s_domain={ancilla[0]}),
M(node=ancilla[1], s_domain={ancilla[0]}, t_domain={target_node}),
M(node=control_node1),
M(node=ancilla[2], angle=-1.75, s_domain={ancilla[1], target_node}),
M(node=ancilla[2], angle=-1.75, s_domain={ancilla[1]}, t_domain={ancilla[0]}),
M(node=ancilla[14], s_domain={control_node1}),
M(node=ancilla[3], s_domain={ancilla[2], ancilla[0]}),
M(node=ancilla[5], angle=-0.25, s_domain={ancilla[3], ancilla[1], ancilla[14], target_node}),
M(node=control_node2, angle=-0.25),
M(node=ancilla[6], s_domain={ancilla[5], ancilla[2], ancilla[0]}),
M(node=ancilla[9], s_domain={control_node2, ancilla[5], ancilla[2]}),
M(
node=ancilla[7],
angle=-1.75,
s_domain={ancilla[6], ancilla[3], ancilla[1], ancilla[14], target_node},
),
M(node=ancilla[10], angle=-1.75, s_domain={ancilla[9], ancilla[14]}),
M(node=ancilla[4], angle=-0.25, s_domain={ancilla[14]}),
M(node=ancilla[8], s_domain={ancilla[7], ancilla[5], ancilla[2], ancilla[0]}),
M(node=ancilla[11], s_domain={ancilla[10], control_node2, ancilla[5], ancilla[2]}),
M(node=ancilla[3], s_domain={ancilla[2]}, t_domain={ancilla[1], ancilla[14]}),
M(node=ancilla[5], angle=-0.25, s_domain={ancilla[3]}, t_domain={ancilla[2]}),
M(node=control_node2, angle=-0.25, t_domain={ancilla[5], ancilla[0]}),
M(node=ancilla[6], s_domain={ancilla[5]}, t_domain={ancilla[3]}),
M(node=ancilla[9], s_domain={control_node2}, t_domain={ancilla[14]}),
M(node=ancilla[7], angle=-1.75, s_domain={ancilla[6]}, t_domain={ancilla[5]}),
M(node=ancilla[10], angle=-1.75, s_domain={ancilla[9]}, t_domain={control_node2}),
M(
node=ancilla[12],
node=ancilla[4],
angle=-0.25,
s_domain={ancilla[8], ancilla[6], ancilla[3], ancilla[1], target_node},
s_domain={ancilla[14]},
t_domain={control_node1, control_node2, ancilla[2], ancilla[7], ancilla[10]},
),
M(node=ancilla[8], s_domain={ancilla[7]}, t_domain={ancilla[14], ancilla[6]}),
M(node=ancilla[11], s_domain={ancilla[10]}, t_domain={ancilla[9], ancilla[14]}),
M(node=ancilla[12], angle=-0.25, s_domain={ancilla[8]}, t_domain={ancilla[7]}),
M(
node=ancilla[16],
s_domain={
ancilla[4],
control_node1,
ancilla[2],
control_node2,
ancilla[7],
ancilla[10],
ancilla[2],
control_node2,
ancilla[5],
},
s_domain={ancilla[4]},
t_domain={ancilla[14]},
),
X(node=ancilla[17], domain={ancilla[14], ancilla[16]}),
X(node=ancilla[15], domain={ancilla[9], ancilla[11]}),
X(node=ancilla[13], domain={ancilla[0], ancilla[2], ancilla[5], ancilla[7], ancilla[12]}),
Z(node=ancilla[17], domain={ancilla[4], ancilla[5], ancilla[7], ancilla[10], control_node1}),
Z(node=ancilla[15], domain={control_node2, ancilla[2], ancilla[5], ancilla[10]}),
Z(node=ancilla[13], domain={ancilla[1], ancilla[3], ancilla[6], ancilla[8], target_node}),
X(node=ancilla[17], domain={ancilla[16]}),
X(node=ancilla[15], domain={ancilla[11]}),
X(node=ancilla[13], domain={ancilla[12]}),
Z(node=ancilla[17], domain={ancilla[4]}),
Z(node=ancilla[15], domain={ancilla[10]}),
Z(node=ancilla[13], domain={ancilla[8]}),
)
)
return ancilla[17], ancilla[15], ancilla[13], seq
Expand Down
18 changes: 17 additions & 1 deletion tests/test_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from graphix.command import C, Command, CommandKind, E, M, N, X, Z
from graphix.fundamentals import Plane
from graphix.gflow import gflow_from_pattern
from graphix.optimization import StandardizedPattern, incorporate_pauli_results
from graphix.optimization import StandardizedPattern, incorporate_pauli_results, remove_useless_domains
from graphix.pattern import Pattern
from graphix.random_objects import rand_circuit
from graphix.states import PlanarState
Expand Down Expand Up @@ -86,6 +86,22 @@ def test_flow_after_pauli_preprocessing(fx_bg: PCG64, jumps: int) -> None:
assert f is not None


@pytest.mark.parametrize("jumps", range(1, 11))
def test_remove_useless_domains(fx_bg: PCG64, jumps: int) -> None:
rng = Generator(fx_bg.jumped(jumps))
nqubits = 3
depth = 3
circuit = rand_circuit(nqubits, depth, rng)
pattern = circuit.transpile().pattern
pattern.standardize()
pattern.shift_signals()
pattern.perform_pauli_measurements()
pattern2 = remove_useless_domains(pattern)
state = pattern.simulate_pattern(rng=rng)
state2 = pattern2.simulate_pattern(rng=rng)
assert np.abs(np.dot(state.flatten().conjugate(), state2.flatten())) == pytest.approx(1)


def test_to_space_optimal_pattern() -> None:
pattern = Pattern(
cmds=[
Expand Down
62 changes: 39 additions & 23 deletions tests/test_transpiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,34 @@

from graphix import instruction
from graphix.fundamentals import Plane
from graphix.random_objects import rand_circuit, rand_gate
from graphix.gflow import flow_from_pattern
from graphix.random_objects import rand_circuit, rand_gate, rand_state_vector
from graphix.transpiler import Circuit

if TYPE_CHECKING:
from collections.abc import Callable
from typing import TypeAlias

from graphix.instruction import Instruction

InstructionTestCase: TypeAlias = Callable[[Generator], Instruction]

INSTRUCTION_TEST_CASES: list[InstructionTestCase] = [
lambda _rng: instruction.CCX(0, (1, 2)),
lambda rng: instruction.RZZ(0, 1, rng.random() * 2 * np.pi),
lambda _rng: instruction.CNOT(0, 1),
lambda _rng: instruction.SWAP((0, 1)),
lambda _rng: instruction.H(0),
lambda _rng: instruction.S(0),
lambda _rng: instruction.X(0),
lambda _rng: instruction.Y(0),
lambda _rng: instruction.Z(0),
lambda _rng: instruction.I(0),
lambda rng: instruction.RX(0, rng.random() * 2 * np.pi),
lambda rng: instruction.RY(0, rng.random() * 2 * np.pi),
lambda rng: instruction.RZ(0, rng.random() * 2 * np.pi),
]


class TestTranspilerUnitGates:
def test_cnot(self, fx_rng: Generator) -> None:
Expand Down Expand Up @@ -156,27 +178,21 @@ def test_add_extend(self) -> None:
circuit2 = Circuit(3, instr=circuit.instruction)
assert circuit.instruction == circuit2.instruction

@pytest.mark.parametrize(
"instruction",
[
instruction.CCX(0, (1, 2)),
instruction.RZZ(0, 1, np.pi / 4),
instruction.CNOT(0, 1),
instruction.SWAP((0, 1)),
instruction.H(0),
instruction.S(0),
instruction.X(0),
instruction.Y(0),
instruction.Z(0),
instruction.I(0),
instruction.RX(0, 0),
instruction.RY(0, 0),
instruction.RZ(0, 0),
],
)
def test_instructions(self, fx_rng: Generator, instruction: Instruction) -> None:
circuit = Circuit(3, instr=[instruction])
@pytest.mark.parametrize("instruction", INSTRUCTION_TEST_CASES)
def test_instruction_flow(self, fx_rng: Generator, instruction: InstructionTestCase) -> None:
circuit = Circuit(3, instr=[instruction(fx_rng)])
pattern = circuit.transpile().pattern
state = circuit.simulate_statevector().statevec
state_mbqc = pattern.simulate_pattern(rng=fx_rng)
pattern.standardize()
f, _l = flow_from_pattern(pattern)
assert f is not None

@pytest.mark.parametrize("jumps", range(1, 11))
@pytest.mark.parametrize("instruction", INSTRUCTION_TEST_CASES)
def test_instructions(self, fx_bg: PCG64, jumps: int, instruction: InstructionTestCase) -> None:
rng = Generator(fx_bg.jumped(jumps))
circuit = Circuit(3, instr=[instruction(rng)])
pattern = circuit.transpile().pattern
input_state = rand_state_vector(3, rng=rng)
state = circuit.simulate_statevector(input_state=input_state).statevec
state_mbqc = pattern.simulate_pattern(input_state=input_state, rng=rng)
assert np.abs(np.dot(state_mbqc.flatten().conjugate(), state.flatten())) == pytest.approx(1)
Loading