diff --git a/graphqomb/feedforward.py b/graphqomb/feedforward.py index fdb0102e..f176d859 100644 --- a/graphqomb/feedforward.py +++ b/graphqomb/feedforward.py @@ -19,7 +19,7 @@ import typing_extensions -from graphqomb.common import Plane +from graphqomb.common import Plane, determine_pauli_axis, Axis from graphqomb.graphstate import BaseGraphState, odd_neighbors if sys.version_info >= (3, 10): @@ -277,3 +277,57 @@ def propagate_correction_map( # noqa: C901, PLR0912 new_zflow[parent] ^= {child_z} return new_xflow, new_zflow + + +def pauli_simplification( # noqa: C901 + graph: BaseGraphState, + xflow: Mapping[int, AbstractSet[int]], + zflow: Mapping[int, AbstractSet[int]] | None = None, +) -> tuple[dict[int, set[int]], dict[int, set[int]]]: + r"""Simplify the correction maps by removing redundant Pauli corrections. + + Parameters + ---------- + graph : `BaseGraphState` + Underlying graph state. + xflow : `collections.abc.Mapping`\[`int`, `collections.abc.Set`\[`int`\]\] + Correction map for X. + zflow : `collections.abc.Mapping`\[`int`, `collections.abc.Set`\[`int`\]\] | `None` + Correction map for Z. If `None`, it is generated from xflow by odd neighbors. + + Returns + ------- + `tuple`\[`dict`\[`int`, `set`\[`int`\]\], `dict`\[`int`, `set`\[`int`\]\]] + Updated correction maps for X and Z after simplification. + """ + if zflow is None: + zflow = {node: odd_neighbors(xflow[node], graph) - {node} for node in xflow} + + new_xflow = {k: set(vs) for k, vs in xflow.items()} + new_zflow = {k: set(vs) for k, vs in zflow.items()} + + inv_xflow: dict[int, set[int]] = {} + inv_zflow: dict[int, set[int]] = {} + for k, vs in xflow.items(): + for v in vs: + inv_xflow.setdefault(v, set()).add(k) + for k, vs in zflow.items(): + for v in vs: + inv_zflow.setdefault(v, set()).add(k) + + for node in graph.physical_nodes - graph.output_node_indices.keys(): + meas_basis = graph.meas_bases.get(node) + meas_axis = determine_pauli_axis(meas_basis) + + if meas_axis == Axis.X: + for parent in inv_xflow.get(node, set()): + new_xflow[parent] -= {node} + elif meas_axis == Axis.Z: + for parent in inv_zflow.get(node, set()): + new_zflow[parent] -= {node} + elif meas_axis == Axis.Y: + for parent in inv_xflow.get(node, set()) & inv_zflow.get(node, set()): + new_xflow[parent] -= {node} + new_zflow[parent] -= {node} + + return new_xflow, new_zflow diff --git a/graphqomb/greedy_scheduler.py b/graphqomb/greedy_scheduler.py new file mode 100644 index 00000000..84737311 --- /dev/null +++ b/graphqomb/greedy_scheduler.py @@ -0,0 +1,320 @@ +"""Greedy heuristic scheduler for fast MBQC pattern scheduling. + +This module provides fast greedy scheduling algorithms as an alternative to +CP-SAT based optimization. The greedy algorithms provide approximate solutions +with speedup compared to CP-SAT, making them suitable for large-scale +graphs or when optimality is not critical. + +This module provides: + +- `greedy_minimize_time`: Fast greedy scheduler optimizing for minimal execution time +- `greedy_minimize_space`: Fast greedy scheduler optimizing for minimal qubit usage +""" + +from __future__ import annotations + +from graphlib import TopologicalSorter +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from collections.abc import Mapping + from collections.abc import Set as AbstractSet + + from graphqomb.graphstate import BaseGraphState + + +def greedy_minimize_time( # noqa: C901, PLR0912 + graph: BaseGraphState, + dag: Mapping[int, AbstractSet[int]], + max_qubit_count: int | None = None, +) -> tuple[dict[int, int], dict[int, int]]: + r"""Fast greedy scheduler optimizing for minimal execution time (makespan). + + This algorithm uses a straightforward greedy approach: + 1. At each time step, measure all nodes that can be measured + 2. Prepare all neighbors of measured nodes just before measurement + + Parameters + ---------- + graph : `BaseGraphState` + The graph state to schedule + dag : `collections.abc.Mapping`\[`int`, `collections.abc.Set`\[`int`\]\] + The directed acyclic graph representing measurement dependencies + + Returns + ------- + `tuple`\[`dict`\[`int`, `int`\], `dict`\[`int`, `int`\]\] + A tuple of (prepare_time, measure_time) dictionaries + + Raises + ------ + RuntimeError + If no nodes can be measured at a given time step, indicating a possible + """ + prepare_time: dict[int, int] = {} + measure_time: dict[int, int] = {} + + unmeasured = graph.physical_nodes - graph.output_node_indices.keys() + + # Build inverse DAG: for each node, track which nodes must be measured before it + inv_dag: dict[int, set[int]] = {node: set() for node in graph.physical_nodes} + for parent, children in dag.items(): + for child in children: + inv_dag[child].add(parent) + + prepared: set[int] = set(graph.input_node_indices.keys()) + alive: set[int] = set(graph.input_node_indices.keys()) + + if max_qubit_count is not None and len(alive) > max_qubit_count: + msg = "Initial number of active qubits exceeds max_qubit_count." + raise RuntimeError(msg) + + current_time = 0 + + # Nodes whose dependencies are all resolved and are not yet measured + measure_candidates: set[int] = {node for node in unmeasured if not inv_dag[node]} + + # Cache neighbors to avoid repeated set constructions in tight loops + neighbors_map = {node: graph.neighbors(node) for node in graph.physical_nodes} + + while unmeasured: # noqa: PLR1702 + if not measure_candidates: + msg = "No nodes can be measured; possible cyclic dependency or incomplete preparation." + raise RuntimeError(msg) + + if max_qubit_count is not None: + # Choose measurement nodes from measure_candidates while respecting max_qubit_count + to_measure, to_prepare = _determine_measure_nodes( + neighbors_map, + measure_candidates, + prepared, + alive, + max_qubit_count, + ) + needs_prep = False + for neighbor in to_prepare: + if neighbor not in prepared: + prepare_time[neighbor] = current_time + prepared.add(neighbor) + alive.add(neighbor) + needs_prep = True # toggle prep flag + + # If this neighbor already had no dependencies, it becomes measure candidate + if not inv_dag[neighbor] and neighbor in unmeasured: + measure_candidates.add(neighbor) + else: + # Without a qubit limit, measure all currently measure candidates + to_measure = set(measure_candidates) + needs_prep = False + for node in to_measure: + for neighbor in neighbors_map[node]: + if neighbor not in prepared: + prepare_time[neighbor] = current_time + prepared.add(neighbor) + alive.add(neighbor) + needs_prep = True + + if not inv_dag[neighbor] and neighbor in unmeasured: + measure_candidates.add(neighbor) + + # Measure at current_time if no prep needed, otherwise at current_time + 1 + meas_time = current_time + 1 if needs_prep else current_time + + for node in to_measure: + measure_time[node] = meas_time + alive.remove(node) + unmeasured.remove(node) + measure_candidates.remove(node) + + # Remove measured node from dependencies of all its children in the DAG + for child in dag.get(node, ()): + inv_dag[child].remove(node) + if not inv_dag[child] and child in unmeasured: + measure_candidates.add(child) + + current_time = meas_time + 1 + + return prepare_time, measure_time + + +def _determine_measure_nodes( + neighbors_map: Mapping[int, AbstractSet[int]], + measure_candidates: AbstractSet[int], + prepared: AbstractSet[int], + alive: AbstractSet[int], + max_qubit_count: int, +) -> tuple[set[int], set[int]]: + r"""Determine which nodes to measure without exceeding max qubit count. + + Parameters + ---------- + neighbors_map : `collections.abc.Mapping`\[`int`, `collections.abc.Set`\[`int`\]\] + Mapping from node to its neighbors. + measure_candidates : `collections.abc.Set`\[`int`\] + The candidate nodes available for measurement. + prepared : `collections.abc.Set`\[`int`\] + The set of currently prepared nodes. + alive : `collections.abc.Set`\[`int`\] + The set of currently active (prepared but not yet measured) nodes. + max_qubit_count : `int` + The maximum allowed number of active qubits. + + Returns + ------- + `tuple`\[`set`\[`int`\], `set`\[`int`\]\] + A tuple of (to_measure, to_prepare) sets indicating which nodes to measure and prepare. + + Raises + ------ + RuntimeError + If no nodes can be measured without exceeding the max qubit count. + """ + to_measure: set[int] = set() + to_prepare: set[int] = set() + + for node in measure_candidates: + # Neighbors that still need to be prepared for this node + new_neighbors = neighbors_map[node] - prepared + additional_to_prepare = new_neighbors - to_prepare + + # Projected number of active qubits after preparing these neighbors + projected_active = len(alive) + len(to_prepare) + len(additional_to_prepare) + + if projected_active <= max_qubit_count: + to_measure.add(node) + to_prepare |= new_neighbors + + if not to_measure: + msg = "Cannot schedule more measurements without exceeding max qubit count. Please increase max_qubit_count." + raise RuntimeError(msg) + + return to_measure, to_prepare + + +def greedy_minimize_space( # noqa: C901, PLR0912 + graph: BaseGraphState, + dag: Mapping[int, AbstractSet[int]], +) -> tuple[dict[int, int], dict[int, int]]: + r"""Fast greedy scheduler optimizing for minimal qubit usage (space). + + This algorithm uses a greedy approach to minimize the number of active + qubits at each time step: + 1. At each time step, select the next node to measure that minimizes the + number of new qubits that need to be prepared. + 2. Prepare neighbors of the measured node just before measurement. + + Parameters + ---------- + graph : `BaseGraphState` + The graph state to schedule + dag : `collections.abc.Mapping`\[`int`, `collections.abc.Set`\[`int`\]\] + The directed acyclic graph representing measurement dependencies + + Returns + ------- + `tuple`\[`dict`\[`int`, `int`\], `dict`\[`int`, `int`\] + A tuple of (prepare_time, measure_time) dictionaries + + Raises + ------ + RuntimeError + If no nodes can be measured at a given time step, indicating a possible + cyclic dependency or incomplete preparation. + """ + prepare_time: dict[int, int] = {} + measure_time: dict[int, int] = {} + + unmeasured = graph.physical_nodes - graph.output_node_indices.keys() + + topo_order = list(TopologicalSorter(dag).static_order()) + topo_order.reverse() # from parents to children + topo_rank = {node: i for i, node in enumerate(topo_order)} + + # Build inverse DAG: for each node, track which nodes must be measured before it + inv_dag: dict[int, set[int]] = {node: set() for node in graph.physical_nodes} + for parent, children in dag.items(): + for child in children: + inv_dag[child].add(parent) + + prepared: set[int] = set(graph.input_node_indices.keys()) + alive: set[int] = set(graph.input_node_indices.keys()) + current_time = 0 + + # Cache neighbors once as the graph is static during scheduling + neighbors_map = {node: graph.neighbors(node) for node in graph.physical_nodes} + + measure_candidates: set[int] = {node for node in unmeasured if not inv_dag[node]} + + while unmeasured: + if not measure_candidates: + msg = "No nodes can be measured; possible cyclic dependency or incomplete preparation." + raise RuntimeError(msg) + + # calculate costs and pick the best node to measure + best_node_candidate: set[int] = set() + best_cost = float("inf") + for node in measure_candidates: + cost = _calc_activate_cost(node, neighbors_map, prepared) + if cost < best_cost: + best_cost = cost + best_node_candidate = {node} + elif cost == best_cost: + best_node_candidate.add(node) + + # tie-breaker: choose the node that appears first in topological order + default_rank = len(topo_rank) + best_node = min(best_node_candidate, key=lambda n: topo_rank.get(n, default_rank)) + + # Prepare neighbors at current_time + needs_prep = False + for neighbor in neighbors_map[best_node]: + if neighbor not in prepared: + prepare_time[neighbor] = current_time + prepared.add(neighbor) + alive.add(neighbor) + needs_prep = True + + # Measure at current_time if no prep needed, otherwise at current_time + 1 + meas_time = current_time + 1 if needs_prep else current_time + measure_time[best_node] = meas_time + unmeasured.remove(best_node) + alive.remove(best_node) + + measure_candidates.remove(best_node) + + # Remove measured node from dependencies of all its children in the DAG + for child in dag.get(best_node, ()): + inv_dag[child].remove(best_node) + if not inv_dag[child] and child in unmeasured: + measure_candidates.add(child) + + current_time = meas_time + 1 + + return prepare_time, measure_time + + +def _calc_activate_cost( + node: int, + neighbors_map: Mapping[int, AbstractSet[int]], + prepared: AbstractSet[int], +) -> int: + r"""Calculate the cost of activating (preparing) a node. + + The cost is defined as the number of new qubits that would become active + (prepared but not yet measured) if this node were to be measured next. + + Parameters + ---------- + node : `int` + The node to evaluate. + neighbors_map : `collections.abc.Mapping`\[`int`, `collections.abc.Set`\[`int`\]\] + Cached neighbor sets for graph nodes. + prepared : `collections.abc.Set`\[`int`\] + The set of currently prepared nodes. + + Returns + ------- + `int` + The activation cost for the node. + """ + return len(neighbors_map[node] - prepared) diff --git a/graphqomb/pattern.py b/graphqomb/pattern.py index b9718da4..baf4170f 100644 --- a/graphqomb/pattern.py +++ b/graphqomb/pattern.py @@ -101,6 +101,68 @@ def depth(self) -> int: """ return sum(1 for cmd in self.commands if isinstance(cmd, TICK)) + @property + def volume(self) -> int: + """Calculate tha volume, summation of space for each timeslice. + + Returns + ------- + `int` + Volume of the pattern + """ + return sum(self.space) + + @property + def max_volume(self) -> int: + """Calculate the maximum volume, defined as max_space * depth. + + Returns + ------- + `int` + Maximum volume of the pattern + """ + return self.max_space * self.depth + + @property + def idle_times(self) -> dict[int, int]: + r"""Calculate the idle times for each qubit in the pattern. + + Returns + ------- + `dict`\[`int`, `int`\] + A dictionary mapping each qubit index to its idle time. + """ + idle_times: dict[int, int] = {} + prepared_time: dict[int, int] = dict.fromkeys(self.input_node_indices, -1) + + current_time = 0 + for cmd in self.commands: + if isinstance(cmd, TICK): + current_time += 1 + elif isinstance(cmd, N): + prepared_time[cmd.node] = current_time + elif isinstance(cmd, M): + idle_times[cmd.node] = current_time - prepared_time[cmd.node] + + for output_node in self.output_node_indices: + if output_node in prepared_time: + idle_times[output_node] = current_time - prepared_time[output_node] + + return idle_times + + @property + def throughput(self) -> float: + """Calculate the number of measurements per TICK in the pattern. + + Returns + ------- + `float` + Number of measurements per TICK + """ + num_measurements = sum(1 for cmd in self.commands if isinstance(cmd, M)) + num_ticks = self.depth + return num_measurements / num_ticks + def is_runnable(pattern: Pattern) -> None: """Check if the pattern is runnable. diff --git a/graphqomb/schedule_solver.py b/graphqomb/schedule_solver.py index 90cbeb05..fa12f498 100644 --- a/graphqomb/schedule_solver.py +++ b/graphqomb/schedule_solver.py @@ -37,6 +37,7 @@ class ScheduleConfig: strategy: Strategy max_qubit_count: int | None = None max_time: int | None = None + use_greedy: bool = False @dataclass diff --git a/graphqomb/scheduler.py b/graphqomb/scheduler.py index 889f045d..423f7469 100644 --- a/graphqomb/scheduler.py +++ b/graphqomb/scheduler.py @@ -12,6 +12,7 @@ from typing import TYPE_CHECKING, NamedTuple from graphqomb.feedforward import dag_from_flow +from graphqomb.greedy_scheduler import greedy_minimize_space, greedy_minimize_time from graphqomb.schedule_solver import ScheduleConfig, Strategy, solve_schedule if TYPE_CHECKING: @@ -484,14 +485,15 @@ def solve_schedule( config: ScheduleConfig | None = None, timeout: int = 60, ) -> bool: - r"""Compute the schedule using the constraint programming solver. + r"""Compute the schedule using constraint programming or greedy heuristics. Parameters ---------- config : `ScheduleConfig` | `None`, optional - The scheduling configuration. If None, defaults to MINIMIZE_SPACE strategy. + The scheduling configuration. If None, defaults to MINIMIZE_TIME strategy. timeout : `int`, optional - Maximum solve time in seconds, by default 60 + Maximum solve time in seconds for CP-SAT solver, by default 60. + Ignored when use_greedy=True. Returns ------- @@ -506,7 +508,16 @@ def solve_schedule( if config is None: config = ScheduleConfig(Strategy.MINIMIZE_TIME) - result = solve_schedule(self.graph, self.dag, config, timeout) + if config.use_greedy: + # Use fast greedy heuristics + if config.strategy == Strategy.MINIMIZE_TIME: + result = greedy_minimize_time(self.graph, self.dag, max_qubit_count=config.max_qubit_count) + else: # Strategy.MINIMIZE_SPACE + result = greedy_minimize_space(self.graph, self.dag) + else: + # Use CP-SAT solver for optimal solution + result = solve_schedule(self.graph, self.dag, config, timeout) + if result is None: return False diff --git a/tests/test_greedy_scheduler.py b/tests/test_greedy_scheduler.py new file mode 100644 index 00000000..b74a995b --- /dev/null +++ b/tests/test_greedy_scheduler.py @@ -0,0 +1,476 @@ +"""Test greedy scheduling algorithms.""" + +import time + +import pytest + +from graphqomb.graphstate import GraphState +from graphqomb.greedy_scheduler import ( + greedy_minimize_space, + greedy_minimize_time, +) +from graphqomb.schedule_solver import ScheduleConfig, Strategy +from graphqomb.scheduler import Scheduler + + +def test_greedy_minimize_time_simple() -> None: + """Test greedy_minimize_time on a simple graph.""" + # Create a simple 3-node chain graph + graph = GraphState() + node0 = graph.add_physical_node() + node1 = graph.add_physical_node() + node2 = graph.add_physical_node() + graph.add_physical_edge(node0, node1) + graph.add_physical_edge(node1, node2) + qindex = 0 + graph.register_input(node0, qindex) + graph.register_output(node2, qindex) + + flow = {node0: {node1}, node1: {node2}} + scheduler = Scheduler(graph, flow) + + # Run greedy scheduler + prepare_time, measure_time = greedy_minimize_time(graph, scheduler.dag) + + # Check that all non-input nodes have preparation times + assert node1 in prepare_time + assert node0 not in prepare_time # Input node should not be prepared + + # Check that all non-output nodes have measurement times + assert node0 in measure_time + assert node1 in measure_time + assert node2 not in measure_time # Output node should not be measured + + # Verify DAG constraints: node0 measured before node1 + assert measure_time[node0] < measure_time[node1] + + +def test_greedy_minimize_space_simple() -> None: + """Test greedy_minimize_space on a simple graph.""" + # Create a simple 3-node chain graph + graph = GraphState() + node0 = graph.add_physical_node() + node1 = graph.add_physical_node() + node2 = graph.add_physical_node() + graph.add_physical_edge(node0, node1) + graph.add_physical_edge(node1, node2) + qindex = 0 + graph.register_input(node0, qindex) + graph.register_output(node2, qindex) + + flow = {node0: {node1}, node1: {node2}} + scheduler = Scheduler(graph, flow) + + # Run greedy scheduler + prepare_time, measure_time = greedy_minimize_space(graph, scheduler.dag) + + # Check that all non-input nodes have preparation times + assert node1 in prepare_time + assert node0 not in prepare_time # Input node should not be prepared + + # Check that all non-output nodes have measurement times + assert node0 in measure_time + assert node1 in measure_time + assert node2 not in measure_time # Output node should not be measured + + # Verify DAG constraints + assert measure_time[node0] < measure_time[node1] + + +def _compute_max_alive_qubits( + graph: GraphState, + prepare_time: dict[int, int], + measure_time: dict[int, int], +) -> int: + """Compute the maximum number of alive qubits over time. + + A node is considered alive at time t if: + - It is an input node and t >= -1 and t < measurement time (if any), or + - It has a preparation time p and t >= p and t < measurement time (if any). + + Returns + ------- + int + The maximum number of alive qubits at any time step. + """ + # Determine time range to check + max_t = max(set(prepare_time.values()) | set(measure_time.values()), default=0) + + max_alive = len(graph.input_node_indices) # At least inputs are alive at t = -1 + for t in range(max_t + 1): + alive_nodes = set() + for node in graph.physical_nodes: + # Determine preparation time + prep_t = -1 if node in graph.input_node_indices else prepare_time.get(node) + + if prep_t is None or t < prep_t: + continue + + # Determine measurement time (None for outputs or unscheduled) + meas_t = measure_time.get(node) + + if meas_t is None or t < meas_t: + alive_nodes.add(node) + + max_alive = max(max_alive, len(alive_nodes)) + + return max_alive + + +def test_greedy_minimize_time_with_max_qubit_count_respects_limit() -> None: + """Verify that greedy_minimize_time respects max_qubit_count.""" + graph = GraphState() + # chain graph: 0-1-2-3 + n0 = graph.add_physical_node() + n1 = graph.add_physical_node() + n2 = graph.add_physical_node() + n3 = graph.add_physical_node() + graph.add_physical_edge(n0, n1) + graph.add_physical_edge(n1, n2) + graph.add_physical_edge(n2, n3) + + qindex = 0 + graph.register_input(n0, qindex) + graph.register_output(n3, qindex) + + flow = {n0: {n1}, n1: {n2}, n2: {n3}} + scheduler = Scheduler(graph, flow) + + # Set max_qubit_count to 2 (a feasible value for this graph) + prepare_time, measure_time = greedy_minimize_time(graph, scheduler.dag, max_qubit_count=2) + + # Check basic properties + assert n1 in prepare_time + assert n0 not in prepare_time + assert n0 in measure_time + assert n2 in measure_time + assert n3 not in measure_time + + # Verify that the number of alive qubits never exceeds the limit + max_alive = _compute_max_alive_qubits(graph, prepare_time, measure_time) + assert max_alive <= 2 + + +def test_greedy_minimize_time_with_too_small_max_qubit_count_raises() -> None: + """Verify that greedy_minimize_time raises RuntimeError when max_qubit_count is too small.""" + graph = GraphState() + # chain graph: 0-1-2 (at least 2 qubits are needed) + n0 = graph.add_physical_node() + n1 = graph.add_physical_node() + n2 = graph.add_physical_node() + graph.add_physical_edge(n0, n1) + graph.add_physical_edge(n1, n2) + + qindex = 0 + graph.register_input(n0, qindex) + graph.register_output(n2, qindex) + + flow = {n0: {n1}, n1: {n2}} + scheduler = Scheduler(graph, flow) + + # max_qubit_count=1 is not feasible, so expect RuntimeError + with pytest.raises(RuntimeError, match="max_qubit_count"): + greedy_minimize_time(graph, scheduler.dag, max_qubit_count=1) + + +def test_greedy_scheduler_via_solve_schedule() -> None: + """Test greedy scheduler through Scheduler.solve_schedule with use_greedy=True.""" + # Create a simple graph + graph = GraphState() + node0 = graph.add_physical_node() + node1 = graph.add_physical_node() + node2 = graph.add_physical_node() + graph.add_physical_edge(node0, node1) + graph.add_physical_edge(node1, node2) + qindex = 0 + graph.register_input(node0, qindex) + graph.register_output(node2, qindex) + + flow = {node0: {node1}, node1: {node2}} + scheduler = Scheduler(graph, flow) + + # Test with greedy MINIMIZE_TIME + config = ScheduleConfig(strategy=Strategy.MINIMIZE_TIME, use_greedy=True) + success = scheduler.solve_schedule(config) + assert success + + # Verify schedule is valid + scheduler.validate_schedule() + + # Test with greedy MINIMIZE_SPACE + scheduler2 = Scheduler(graph, flow) + config = ScheduleConfig(strategy=Strategy.MINIMIZE_SPACE, use_greedy=True) + success = scheduler2.solve_schedule(config) + assert success + + # Verify schedule is valid + scheduler2.validate_schedule() + + +def test_greedy_vs_cpsat_correctness() -> None: + """Test that greedy scheduler produces valid schedules compared to CP-SAT.""" + # Create a slightly larger graph + graph = GraphState() + nodes = [graph.add_physical_node() for _ in range(5)] + + # Create a chain + for i in range(4): + graph.add_physical_edge(nodes[i], nodes[i + 1]) + + qindex = 0 + graph.register_input(nodes[0], qindex) + graph.register_output(nodes[4], qindex) + + flow = {nodes[i]: {nodes[i + 1]} for i in range(4)} + + # Test greedy scheduler + scheduler_greedy = Scheduler(graph, flow) + config = ScheduleConfig(strategy=Strategy.MINIMIZE_TIME, use_greedy=True) + success_greedy = scheduler_greedy.solve_schedule(config) + assert success_greedy + + # Verify greedy schedule is valid + scheduler_greedy.validate_schedule() + + # Test CP-SAT scheduler + scheduler_cpsat = Scheduler(graph, flow) + config = ScheduleConfig(strategy=Strategy.MINIMIZE_TIME, use_greedy=False) + success_cpsat = scheduler_cpsat.solve_schedule(config, timeout=10) + assert success_cpsat + + # Verify CP-SAT schedule is valid + scheduler_cpsat.validate_schedule() + + # Both should produce valid schedules + # Note: Greedy may not be optimal, so we don't compare quality here + + +def test_greedy_scheduler_larger_graph() -> None: + """Test greedy scheduler on a larger graph to ensure scalability.""" + # Create a larger graph with branching structure + graph = GraphState() + num_layers = 4 + nodes_per_layer = 3 + + # Build layered graph + all_nodes = [] + for layer in range(num_layers): + layer_nodes = [graph.add_physical_node() for _ in range(nodes_per_layer)] + all_nodes.append(layer_nodes) + + # Connect to previous layer (if not first layer) + if layer > 0: + for i, node in enumerate(layer_nodes): + # Connect to corresponding node in previous layer + prev_node = all_nodes[layer - 1][i] + graph.add_physical_edge(prev_node, node) + + # Register inputs (first layer) and outputs (last layer) + for i, node in enumerate(all_nodes[0]): + graph.register_input(node, i) + for i, node in enumerate(all_nodes[-1]): + graph.register_output(node, i) + + # Build flow (simple forward flow) + flow = {} + for layer in range(num_layers - 1): + for i, node in enumerate(all_nodes[layer]): + if node not in graph.output_node_indices: + flow[node] = {all_nodes[layer + 1][i]} + + # Test greedy scheduler + scheduler = Scheduler(graph, flow) + config = ScheduleConfig(strategy=Strategy.MINIMIZE_TIME, use_greedy=True) + success = scheduler.solve_schedule(config) + assert success + + # Validate the schedule + scheduler.validate_schedule() + + # Check that we got reasonable results + assert scheduler.num_slices() > 0 + assert scheduler.num_slices() <= num_layers * 2 # Reasonable upper bound + + +@pytest.mark.parametrize("strategy", [Strategy.MINIMIZE_TIME, Strategy.MINIMIZE_SPACE]) +def test_greedy_scheduler_both_strategies(strategy: Strategy) -> None: + """Test greedy scheduler with both optimization strategies.""" + # Create a graph + graph = GraphState() + node0 = graph.add_physical_node() + node1 = graph.add_physical_node() + node2 = graph.add_physical_node() + node3 = graph.add_physical_node() + graph.add_physical_edge(node0, node1) + graph.add_physical_edge(node1, node2) + graph.add_physical_edge(node2, node3) + qindex = 0 + graph.register_input(node0, qindex) + graph.register_output(node3, qindex) + + flow = {node0: {node1}, node1: {node2}, node2: {node3}} + scheduler = Scheduler(graph, flow) + + # Test with specified strategy + config = ScheduleConfig(strategy=strategy, use_greedy=True) + success = scheduler.solve_schedule(config) + assert success + + # Validate schedule + scheduler.validate_schedule() + + +def test_greedy_minimize_space_wrapper() -> None: + """Test the greedy_minimize_space wrapper function.""" + # Create a simple graph + graph = GraphState() + node0 = graph.add_physical_node() + node1 = graph.add_physical_node() + node2 = graph.add_physical_node() + graph.add_physical_edge(node0, node1) + graph.add_physical_edge(node1, node2) + qindex = 0 + graph.register_input(node0, qindex) + graph.register_output(node2, qindex) + + flow = {node0: {node1}, node1: {node2}} + scheduler = Scheduler(graph, flow) + + # Test MINIMIZE_TIME + result = greedy_minimize_time(graph, scheduler.dag) + assert result is not None + prepare_time, measure_time = result + assert len(prepare_time) > 0 + assert len(measure_time) > 0 + + # Test MINIMIZE_SPACE + result = greedy_minimize_space(graph, scheduler.dag) + assert result is not None + prepare_time, measure_time = result + assert len(prepare_time) > 0 + assert len(measure_time) > 0 + + +def test_greedy_scheduler_performance() -> None: + """Test that greedy scheduler is significantly faster than CP-SAT on larger graphs.""" + # Create a larger graph (chain of 20 nodes) + graph = GraphState() + nodes = [graph.add_physical_node() for _ in range(20)] + + for i in range(19): + graph.add_physical_edge(nodes[i], nodes[i + 1]) + + qindex = 0 + graph.register_input(nodes[0], qindex) + graph.register_output(nodes[-1], qindex) + + flow = {nodes[i]: {nodes[i + 1]} for i in range(19)} + + # Time greedy scheduler + scheduler_greedy = Scheduler(graph, flow) + config = ScheduleConfig(strategy=Strategy.MINIMIZE_TIME, use_greedy=True) + + start_greedy = time.perf_counter() + success_greedy = scheduler_greedy.solve_schedule(config) + end_greedy = time.perf_counter() + greedy_time = end_greedy - start_greedy + + assert success_greedy + scheduler_greedy.validate_schedule() + + # Time CP-SAT scheduler + scheduler_cpsat = Scheduler(graph, flow) + + start_cpsat = time.perf_counter() + config = ScheduleConfig(strategy=Strategy.MINIMIZE_TIME, use_greedy=False) + success_cpsat = scheduler_cpsat.solve_schedule(config, timeout=10) + end_cpsat = time.perf_counter() + cpsat_time = end_cpsat - start_cpsat + + assert success_cpsat + scheduler_cpsat.validate_schedule() + + # Print timing information for debugging + print(f"\nGreedy time: {greedy_time:.4f}s") + print(f"CP-SAT time: {cpsat_time:.4f}s") + print(f"Speedup: {cpsat_time / greedy_time:.1f}x") + + # Greedy should be significantly faster (at least 5x for this size) + # Note: We use a conservative factor to avoid flaky tests + assert greedy_time < cpsat_time + + +def test_greedy_scheduler_dag_constraints() -> None: + """Test that greedy scheduler respects DAG constraints.""" + # Create a graph with more complex dependencies + graph = GraphState() + nodes = [graph.add_physical_node() for _ in range(6)] + + # Create edges forming a DAG structure + # 0 -> 1 -> 3 -> 5 + # 2 -> 4 -> + graph.add_physical_edge(nodes[0], nodes[1]) + graph.add_physical_edge(nodes[1], nodes[2]) + graph.add_physical_edge(nodes[1], nodes[3]) + graph.add_physical_edge(nodes[2], nodes[4]) + graph.add_physical_edge(nodes[3], nodes[5]) + graph.add_physical_edge(nodes[4], nodes[5]) + + qindex = 0 + graph.register_input(nodes[0], qindex) + graph.register_output(nodes[5], qindex) + + # Create flow with dependencies + flow = { + nodes[0]: {nodes[1]}, + nodes[1]: {nodes[2], nodes[3]}, + nodes[2]: {nodes[4]}, + nodes[3]: {nodes[5]}, + nodes[4]: {nodes[5]}, + } + + scheduler = Scheduler(graph, flow) + config = ScheduleConfig(strategy=Strategy.MINIMIZE_TIME, use_greedy=True) + + # Note: This flow creates a cyclic DAG (nodes 3 and 4 have circular dependency) + # The greedy scheduler should raise RuntimeError for invalid flows + with pytest.raises(RuntimeError, match="No nodes can be measured"): + scheduler.solve_schedule(config) + + +def test_greedy_scheduler_edge_constraints() -> None: + """Test that greedy scheduler respects edge constraints (neighbor preparation).""" + # Create a simple graph + graph = GraphState() + node0 = graph.add_physical_node() + node1 = graph.add_physical_node() + node2 = graph.add_physical_node() + graph.add_physical_edge(node0, node1) + graph.add_physical_edge(node1, node2) + qindex = 0 + graph.register_input(node0, qindex) + graph.register_output(node2, qindex) + + flow = {node0: {node1}, node1: {node2}} + scheduler = Scheduler(graph, flow) + config = ScheduleConfig(strategy=Strategy.MINIMIZE_TIME, use_greedy=True) + success = scheduler.solve_schedule(config) + assert success + + # Validate edge constraints via validate_schedule + scheduler.validate_schedule() + + # Manually check: neighbors must be prepared before measurement + # node0 (input) is prepared at time -1, node1 prepared at some time + # node0 must be measured after node1 is prepared + # This is ensured by the auto-scheduled entanglement times + + # Check that entanglement times were auto-scheduled correctly + edge01 = (node0, node1) + edge12 = (node1, node2) + assert scheduler.entangle_time[edge01] is not None + assert scheduler.entangle_time[edge12] is not None + + # Entanglement must happen before measurement + assert scheduler.entangle_time[edge01] < scheduler.measure_time[node0] + assert scheduler.entangle_time[edge12] < scheduler.measure_time[node1]