56 changes: 55 additions & 1 deletion graphqomb/feedforward.py
@@ -9,18 +9,18 @@
- `propagate_correction_map`: Propagate the correction map through a measurement at the target node.
"""

from __future__ import annotations

import sys
from collections.abc import Iterable, Mapping
from collections.abc import Set as AbstractSet
from graphlib import TopologicalSorter
from typing import Any

import typing_extensions

from graphqomb.common import Plane
from graphqomb.common import Axis, Plane, determine_pauli_axis
from graphqomb.graphstate import BaseGraphState, odd_neighbors

if sys.version_info >= (3, 10):
from typing import TypeGuard
@@ -277,3 +277,57 @@
new_zflow[parent] ^= {child_z}

return new_xflow, new_zflow


def pauli_simplification( # noqa: C901
graph: BaseGraphState,
xflow: Mapping[int, AbstractSet[int]],
zflow: Mapping[int, AbstractSet[int]] | None = None,
) -> tuple[dict[int, set[int]], dict[int, set[int]]]:
r"""Simplify the correction maps by removing redundant Pauli corrections.

Parameters
----------
graph : `BaseGraphState`
Underlying graph state.
xflow : `collections.abc.Mapping`\[`int`, `collections.abc.Set`\[`int`\]\]
Correction map for X.
zflow : `collections.abc.Mapping`\[`int`, `collections.abc.Set`\[`int`\]\] | `None`
Correction map for Z. If `None`, it is derived from `xflow` via odd neighborhoods.

Returns
-------
`tuple`\[`dict`\[`int`, `set`\[`int`\]\], `dict`\[`int`, `set`\[`int`\]\]\]
Updated correction maps for X and Z after simplification.
"""
if zflow is None:
zflow = {node: odd_neighbors(xflow[node], graph) - {node} for node in xflow}

new_xflow = {k: set(vs) for k, vs in xflow.items()}
new_zflow = {k: set(vs) for k, vs in zflow.items()}

inv_xflow: dict[int, set[int]] = {}
inv_zflow: dict[int, set[int]] = {}
for k, vs in xflow.items():
for v in vs:
inv_xflow.setdefault(v, set()).add(k)
for k, vs in zflow.items():
for v in vs:
inv_zflow.setdefault(v, set()).add(k)

for node in graph.physical_nodes - graph.output_node_indices.keys():
meas_basis = graph.meas_bases.get(node)
meas_axis = determine_pauli_axis(meas_basis)

if meas_axis == Axis.X:
for parent in inv_xflow.get(node, set()):
new_xflow[parent] -= {node}
elif meas_axis == Axis.Z:
for parent in inv_zflow.get(node, set()):
new_zflow[parent] -= {node}
elif meas_axis == Axis.Y:
for parent in inv_xflow.get(node, set()) & inv_zflow.get(node, set()):
new_xflow[parent] -= {node}
new_zflow[parent] -= {node}

return new_xflow, new_zflow
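
For illustration, the effect of the rule above can be reproduced with plain dicts outside the graphqomb types. The node numbers, flows, and string-based axis map below are invented for this sketch; the real function operates on a BaseGraphState and uses determine_pauli_axis on the stored measurement bases, so this is only a stand-in for the X-axis branch.

# Standalone sketch of the Pauli-simplification rule (illustrative, not the graphqomb API).
# xflow[p] is the set of nodes receiving an X correction conditioned on p's outcome;
# zflow[p] likewise for Z corrections.
xflow = {0: {1, 2}, 1: {2}}
zflow = {0: {2}, 1: set()}
pauli_axis = {2: "X"}  # hypothetical: node 2 is measured along the X axis

new_xflow = {k: set(v) for k, v in xflow.items()}
new_zflow = {k: set(v) for k, v in zflow.items()}

# An X byproduct does not disturb an X-axis measurement, so the correction is redundant.
for parent, targets in xflow.items():
    for target in targets:
        if pauli_axis.get(target) == "X":
            new_xflow[parent].discard(target)

print(new_xflow)  # {0: {1}, 1: set()}
print(new_zflow)  # unchanged: {0: {2}, 1: set()}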
320 changes: 320 additions & 0 deletions graphqomb/greedy_scheduler.py
@@ -0,0 +1,320 @@
"""Greedy heuristic scheduler for fast MBQC pattern scheduling.

This module provides fast greedy scheduling algorithms as an alternative to
CP-SAT based optimization. The greedy algorithms produce approximate solutions
much faster than CP-SAT, making them suitable for large-scale graphs or for
cases where optimality is not critical.

This module provides:

- `greedy_minimize_time`: Fast greedy scheduler optimizing for minimal execution time
- `greedy_minimize_space`: Fast greedy scheduler optimizing for minimal qubit usage
"""

from __future__ import annotations

from graphlib import TopologicalSorter
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from collections.abc import Mapping
from collections.abc import Set as AbstractSet

from graphqomb.graphstate import BaseGraphState


def greedy_minimize_time( # noqa: C901, PLR0912
graph: BaseGraphState,
dag: Mapping[int, AbstractSet[int]],
max_qubit_count: int | None = None,
) -> tuple[dict[int, int], dict[int, int]]:
r"""Fast greedy scheduler optimizing for minimal execution time (makespan).

This algorithm uses a straightforward greedy approach:
1. At each time step, measure all nodes that can be measured
2. Prepare all neighbors of measured nodes just before measurement

Parameters
----------
graph : `BaseGraphState`
The graph state to schedule
dag : `collections.abc.Mapping`\[`int`, `collections.abc.Set`\[`int`\]\]
The directed acyclic graph representing measurement dependencies
max_qubit_count : `int` | `None`
Optional upper bound on the number of simultaneously active qubits. If `None`, no limit is enforced.

Returns
-------
`tuple`\[`dict`\[`int`, `int`\], `dict`\[`int`, `int`\]\]
A tuple of (prepare_time, measure_time) dictionaries

Raises
------
RuntimeError
If no nodes can be measured at a given time step, indicating a possible
cyclic dependency or incomplete preparation, or if the schedule cannot
satisfy `max_qubit_count`.
"""
prepare_time: dict[int, int] = {}
measure_time: dict[int, int] = {}

unmeasured = graph.physical_nodes - graph.output_node_indices.keys()

# Build inverse DAG: for each node, track which nodes must be measured before it
inv_dag: dict[int, set[int]] = {node: set() for node in graph.physical_nodes}
for parent, children in dag.items():
for child in children:
inv_dag[child].add(parent)

prepared: set[int] = set(graph.input_node_indices.keys())
alive: set[int] = set(graph.input_node_indices.keys())

if max_qubit_count is not None and len(alive) > max_qubit_count:
msg = "Initial number of active qubits exceeds max_qubit_count."
raise RuntimeError(msg)

current_time = 0

# Nodes whose dependencies are all resolved and are not yet measured
measure_candidates: set[int] = {node for node in unmeasured if not inv_dag[node]}

# Cache neighbors to avoid repeated set constructions in tight loops
neighbors_map = {node: graph.neighbors(node) for node in graph.physical_nodes}

while unmeasured: # noqa: PLR1702
if not measure_candidates:
msg = "No nodes can be measured; possible cyclic dependency or incomplete preparation."
raise RuntimeError(msg)

if max_qubit_count is not None:
# Choose measurement nodes from measure_candidates while respecting max_qubit_count
to_measure, to_prepare = _determine_measure_nodes(
neighbors_map,
measure_candidates,
prepared,
alive,
max_qubit_count,
)
needs_prep = False
for neighbor in to_prepare:
if neighbor not in prepared:
prepare_time[neighbor] = current_time
prepared.add(neighbor)
alive.add(neighbor)
needs_prep = True  # at least one qubit was prepared this step

# If this neighbor has no unresolved dependencies, it becomes a measurement candidate
if not inv_dag[neighbor] and neighbor in unmeasured:
measure_candidates.add(neighbor)
else:
# Without a qubit limit, measure all current measurement candidates
to_measure = set(measure_candidates)
needs_prep = False
for node in to_measure:
for neighbor in neighbors_map[node]:
if neighbor not in prepared:
prepare_time[neighbor] = current_time
prepared.add(neighbor)
alive.add(neighbor)
needs_prep = True

if not inv_dag[neighbor] and neighbor in unmeasured:
measure_candidates.add(neighbor)

# Measure at current_time if no prep needed, otherwise at current_time + 1
meas_time = current_time + 1 if needs_prep else current_time

for node in to_measure:
measure_time[node] = meas_time
alive.remove(node)
unmeasured.remove(node)
measure_candidates.remove(node)

# Remove measured node from dependencies of all its children in the DAG
for child in dag.get(node, ()):
inv_dag[child].remove(node)
if not inv_dag[child] and child in unmeasured:
measure_candidates.add(child)

current_time = meas_time + 1

return prepare_time, measure_time
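
The measure-everything-ready layering can be seen in isolation on a hand-made dependency DAG. The sketch below is a simplified stand-in that ignores qubit preparation and max_qubit_count; the three-node DAG is invented for illustration.

# Simplified layering sketch: dag[p] holds the children that must wait for p.
dag = {0: {2}, 1: {2}, 2: set()}  # hypothetical dependencies

inv_dag = {node: set() for node in dag}
for parent, children in dag.items():
    for child in children:
        inv_dag[child].add(parent)

unmeasured = set(dag)
measure_time = {}
current_time = 0
while unmeasured:
    ready = {node for node in unmeasured if not inv_dag[node]}
    if not ready:
        msg = "No nodes can be measured; possible cyclic dependency."
        raise RuntimeError(msg)
    for node in ready:
        measure_time[node] = current_time
        unmeasured.remove(node)
        for child in dag[node]:
            inv_dag[child].discard(node)
    current_time += 1

print(measure_time)  # nodes 0 and 1 at step 0, node 2 at step 1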


def _determine_measure_nodes(
neighbors_map: Mapping[int, AbstractSet[int]],
measure_candidates: AbstractSet[int],
prepared: AbstractSet[int],
alive: AbstractSet[int],
max_qubit_count: int,
) -> tuple[set[int], set[int]]:
r"""Determine which nodes to measure without exceeding max qubit count.

Parameters
----------
neighbors_map : `collections.abc.Mapping`\[`int`, `collections.abc.Set`\[`int`\]\]
Mapping from node to its neighbors.
measure_candidates : `collections.abc.Set`\[`int`\]
The candidate nodes available for measurement.
prepared : `collections.abc.Set`\[`int`\]
The set of currently prepared nodes.
alive : `collections.abc.Set`\[`int`\]
The set of currently active (prepared but not yet measured) nodes.
max_qubit_count : `int`
The maximum allowed number of active qubits.

Returns
-------
`tuple`\[`set`\[`int`\], `set`\[`int`\]\]
A tuple of (to_measure, to_prepare) sets indicating which nodes to measure and prepare.

Raises
------
RuntimeError
If no nodes can be measured without exceeding the max qubit count.
"""
to_measure: set[int] = set()
to_prepare: set[int] = set()

for node in measure_candidates:
# Neighbors that still need to be prepared for this node
new_neighbors = neighbors_map[node] - prepared
additional_to_prepare = new_neighbors - to_prepare

# Projected number of active qubits after preparing these neighbors
projected_active = len(alive) + len(to_prepare) + len(additional_to_prepare)

if projected_active <= max_qubit_count:
to_measure.add(node)
to_prepare |= new_neighbors

if not to_measure:
msg = "Cannot schedule more measurements without exceeding max qubit count. Please increase max_qubit_count."
raise RuntimeError(msg)

return to_measure, to_prepare
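
Because this helper only takes plain mappings and sets, it can be exercised directly. The adjacency, candidate sets, and cap below are made-up numbers, and the call goes through a private name, so treat it purely as an illustrative check.

# Illustrative check of the qubit-cap filter (private helper, invented numbers).
from graphqomb.greedy_scheduler import _determine_measure_nodes

# Candidates 0 and 1 are both ready; measuring 1 would also require node 3,
# pushing the active-qubit count past the cap of 3, so only node 0 is selected.
neighbors_map = {0: {2}, 1: {2, 3}, 2: {0, 1}, 3: {1}}
to_measure, to_prepare = _determine_measure_nodes(
    neighbors_map,
    measure_candidates={0, 1},
    prepared={0, 1},
    alive={0, 1},
    max_qubit_count=3,
)
print(to_measure, to_prepare)  # {0} {2}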


def greedy_minimize_space(  # noqa: C901, PLR0914
graph: BaseGraphState,
dag: Mapping[int, AbstractSet[int]],
) -> tuple[dict[int, int], dict[int, int]]:
r"""Fast greedy scheduler optimizing for minimal qubit usage (space).

This algorithm uses a greedy approach to minimize the number of active
qubits at each time step:
1. At each time step, select the next node to measure that minimizes the
number of new qubits that need to be prepared.
2. Prepare neighbors of the measured node just before measurement.

Parameters
----------
graph : `BaseGraphState`
The graph state to schedule
dag : `collections.abc.Mapping`\[`int`, `collections.abc.Set`\[`int`\]\]
The directed acyclic graph representing measurement dependencies

Returns
-------
`tuple`\[`dict`\[`int`, `int`\], `dict`\[`int`, `int`\]\]
A tuple of (prepare_time, measure_time) dictionaries

Raises
------
RuntimeError
If no nodes can be measured at a given time step, indicating a possible
cyclic dependency or incomplete preparation.
"""
prepare_time: dict[int, int] = {}
measure_time: dict[int, int] = {}

unmeasured = graph.physical_nodes - graph.output_node_indices.keys()

topo_order = list(TopologicalSorter(dag).static_order())
topo_order.reverse() # from parents to children
topo_rank = {node: i for i, node in enumerate(topo_order)}

# Build inverse DAG: for each node, track which nodes must be measured before it
inv_dag: dict[int, set[int]] = {node: set() for node in graph.physical_nodes}
for parent, children in dag.items():
for child in children:
inv_dag[child].add(parent)

prepared: set[int] = set(graph.input_node_indices.keys())
alive: set[int] = set(graph.input_node_indices.keys())
current_time = 0

# Cache neighbors once as the graph is static during scheduling
neighbors_map = {node: graph.neighbors(node) for node in graph.physical_nodes}

measure_candidates: set[int] = {node for node in unmeasured if not inv_dag[node]}

while unmeasured:
if not measure_candidates:
msg = "No nodes can be measured; possible cyclic dependency or incomplete preparation."
raise RuntimeError(msg)

# calculate costs and pick the best node to measure
best_node_candidate: set[int] = set()
best_cost = float("inf")
for node in measure_candidates:
cost = _calc_activate_cost(node, neighbors_map, prepared)
if cost < best_cost:
best_cost = cost
best_node_candidate = {node}
elif cost == best_cost:
best_node_candidate.add(node)

# tie-breaker: choose the node that appears first in topological order
default_rank = len(topo_rank)
best_node = min(best_node_candidate, key=lambda n: topo_rank.get(n, default_rank))

# Prepare neighbors at current_time
needs_prep = False
for neighbor in neighbors_map[best_node]:
if neighbor not in prepared:
prepare_time[neighbor] = current_time
prepared.add(neighbor)
alive.add(neighbor)
needs_prep = True

# Measure at current_time if no prep needed, otherwise at current_time + 1
meas_time = current_time + 1 if needs_prep else current_time
measure_time[best_node] = meas_time
unmeasured.remove(best_node)
alive.remove(best_node)

measure_candidates.remove(best_node)

# Remove measured node from dependencies of all its children in the DAG
for child in dag.get(best_node, ()):
inv_dag[child].remove(best_node)
if not inv_dag[child] and child in unmeasured:
measure_candidates.add(child)

current_time = meas_time + 1

return prepare_time, measure_time
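
The selection rule (lowest activation cost, ties broken by topological rank) can be sketched on its own. The four-node fragment below is hypothetical, and everything is assumed already prepared except node 2.

from graphlib import TopologicalSorter

# Hypothetical four-node fragment: dag maps parent -> children that must wait for it.
dag = {0: {2}, 1: {3}, 2: set(), 3: set()}
neighbors_map = {0: {2, 3}, 1: {3}, 2: {0}, 3: {0, 1}}
prepared = {0, 1, 3}  # node 2 is the only unprepared neighbour

# Same tie-breaking machinery as greedy_minimize_space.
topo_order = list(TopologicalSorter(dag).static_order())
topo_order.reverse()  # parents before children
topo_rank = {node: i for i, node in enumerate(topo_order)}

candidates = {0, 1}
costs = {node: len(neighbors_map[node] - prepared) for node in candidates}
best_cost = min(costs.values())
tied = {node for node, cost in costs.items() if cost == best_cost}
best = min(tied, key=lambda n: topo_rank.get(n, len(topo_rank)))
print(best)  # 1: measuring it activates no new qubits (cost 0 vs. 1 for node 0)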


def _calc_activate_cost(
node: int,
neighbors_map: Mapping[int, AbstractSet[int]],
prepared: AbstractSet[int],
) -> int:
r"""Calculate the cost of activating (preparing) a node.

The cost is defined as the number of new qubits that would become active
(prepared but not yet measured) if this node were to be measured next.

Parameters
----------
node : `int`
The node to evaluate.
neighbors_map : `collections.abc.Mapping`\[`int`, `collections.abc.Set`\[`int`\]\]
Cached neighbor sets for graph nodes.
prepared : `collections.abc.Set`\[`int`\]
The set of currently prepared nodes.

Returns
-------
`int`
The activation cost for the node.
"""
return len(neighbors_map[node] - prepared)
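
A direct spot-check of the cost formula, with invented sets (again calling the private helper only for illustration):

from graphqomb.greedy_scheduler import _calc_activate_cost

# Node 5's neighbours are {1, 2, 3}; only node 1 is prepared,
# so measuring 5 next would bring two new qubits online.
print(_calc_activate_cost(5, {5: {1, 2, 3}}, prepared={1}))  # 2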