From 81042ff25d02e1e0ebee569f41c36e1181d7176a Mon Sep 17 00:00:00 2001 From: Ali Asadi <10773383+maliasadi@users.noreply.github.com> Date: Thu, 4 Sep 2025 12:29:57 -0400 Subject: [PATCH 01/36] Graph-based Catalyst decomposition at MLIR (cherry-picking commits for e2e testing from v3) --- frontend/catalyst/device/qjit_device.py | 7 + frontend/catalyst/from_plxpr/decompose.py | 677 ++++++++++++++++++ frontend/catalyst/from_plxpr/from_plxpr.py | 93 ++- frontend/catalyst/jax_extras/lowering.py | 17 +- frontend/catalyst/jax_tracer.py | 4 +- frontend/catalyst/jit.py | 22 +- .../from_plxpr/test_capture_integration.py | 4 +- .../from_plxpr/test_from_plxpr_decompose.py | 310 ++++++++ 8 files changed, 1120 insertions(+), 14 deletions(-) create mode 100644 frontend/catalyst/from_plxpr/decompose.py create mode 100644 frontend/test/pytest/from_plxpr/test_from_plxpr_decompose.py diff --git a/frontend/catalyst/device/qjit_device.py b/frontend/catalyst/device/qjit_device.py index a129568639..b9c9426201 100644 --- a/frontend/catalyst/device/qjit_device.py +++ b/frontend/catalyst/device/qjit_device.py @@ -108,6 +108,13 @@ RUNTIME_MPS = ["ExpectationMP", "SampleMP", "VarianceMP", "CountsMP", "StateMP", "ProbabilityMP"] +# A list of operations that the can be represented +# in the Catalyst compiler. This is a superset of +# the operations supported by the runtime. +# FIXME: ops with OpName(params, wires) signatures +# can be represented in the Catalyst compiler. +COMPILER_OPERATIONS = RUNTIME_OPERATIONS + ["RotXZX"] + # The runtime interface does not care about specific gate properties, so set them all to True. RUNTIME_OPERATIONS = { op: OperatorProperties(invertible=True, controllable=True, differentiable=True) diff --git a/frontend/catalyst/from_plxpr/decompose.py b/frontend/catalyst/from_plxpr/decompose.py new file mode 100644 index 0000000000..c3d576ee26 --- /dev/null +++ b/frontend/catalyst/from_plxpr/decompose.py @@ -0,0 +1,677 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +A transform for the new MLIT-based Catalyst decomposition system. + +Note: this transform will be merged with the PennyLane decomposition transform as part of +the PennyLane <> Catalyst unification project. +""" + + +from __future__ import annotations + +import warnings +from collections import ChainMap +from collections.abc import Callable, Generator, Iterable, Sequence +from functools import partial + +import jax +import pennylane as qml + +# Support ctrl ops in decomposition (adapted from PL's DecomposeInterpreter) +from pennylane.capture.primitives import ctrl_transform_prim + +# GraphSolutionInterpreter: +from pennylane.decomposition import DecompositionGraph +from pennylane.decomposition.collect_resource_ops import CollectResourceOps +from pennylane.decomposition.decomposition_graph import DecompGraphSolution +from pennylane.decomposition.utils import translate_op_alias +from pennylane.operation import Operator + + +# pylint: disable=too-many-instance-attributes +class PreMlirDecomposeInterpreter(qml.capture.PlxprInterpreter): + """Plxpr Interpreter for applying the Catalyst compiler-specific decomposition transform + to callables or jaxpr when program capture is enabled. + + TODO: + - Enable graph-based for pre-mlir decomposition + (not priority for this stage -- needs further maintenance in PennyLane/decomposition) + - Add a more optimized support for PL's templates + + Note: + - This interpreter shares common code with PL's DecomposeInterpreter. + We will merge the two in the future near the completion of the unification project. + """ + + def __init__( + self, + *, + gate_set=None, + max_expansion=None, + ): # pylint: disable=too-many-arguments + + self.max_expansion = max_expansion + self._current_depth = 0 + self._target_gate_names = None + + # We use a ChainMap to store the environment frames, which allows us to push and pop + # environments without copying the interpreter instance when we evaluate a jaxpr of + # a dynamic decomposition. The name is different from the _env in the parent class + # (a dictionary) to avoid confusion. + self._env_map = ChainMap() + + gate_set, stopping_condition = _resolve_gate_set(gate_set) + self._gate_set = gate_set + self._stopping_condition = stopping_condition + + def setup(self) -> None: + """Setup the environment for the interpreter by pushing a new environment frame.""" + + # This is the local environment for the jaxpr evaluation, on the top of the stack, + # from which the interpreter reads and writes variables. + # ChainMap writes to the first dictionary in the chain by default. + self._env_map = self._env_map.new_child() + + def cleanup(self) -> None: + """Cleanup the environment by popping the top-most environment frame.""" + + # We delete the top-most environment frame after the evaluation is done. + self._env_map = self._env_map.parents + + def read(self, var): + """Extract the value corresponding to a variable.""" + return var.val if isinstance(var, jax.extend.core.Literal) else self._env_map[var] + + def stopping_condition(self, op: qml.operation.Operator) -> bool: + """Function to determine whether an operator needs to be decomposed or not. + + Args: + op (qml.operation.Operator): Operator to check. + + Returns: + bool: Whether ``op`` is valid or needs to be decomposed. ``True`` means + that the operator does not need to be decomposed. + """ + + if not op.has_decomposition: + if not self._stopping_condition(op): + warnings.warn( + f"Operator {op.name} does not define a decomposition and was not " + f"found in the target gate set. To remove this warning, add the operator " + f"name ({op.name}) or type ({type(op)}) to the gate set.", + UserWarning, + ) + return True + + return self._stopping_condition(op) + + def decompose_operation(self, op: qml.operation.Operator): + """Decompose a PennyLane operation instance if it does not satisfy the + provided gate set. + + Args: + op (Operator): a pennylane operator instance + + This method is only called when the operator's output is a dropped variable, + so the output will not affect later equations in the circuit. + + See also: :meth:`~.interpret_operation_eqn`, :meth:`~.interpret_operation`. + """ + + if self._stopping_condition(op): + return self.interpret_operation(op) + + max_expansion = ( + self.max_expansion - self._current_depth if self.max_expansion is not None else None + ) + + with qml.capture.pause(): + decomposition = list( + _operator_decomposition_gen( + op, + self.stopping_condition, + max_expansion=max_expansion, + ) + ) + + return [self.interpret_operation(decomp_op) for decomp_op in decomposition] + + def _evaluate_jaxpr_decomposition(self, op: qml.operation.Operator): + """Creates and evaluates a Jaxpr of the plxpr decomposition of an operator.""" + + if self._stopping_condition(op): + return self.interpret_operation(op) + + if self.max_expansion is not None and self._current_depth >= self.max_expansion: + return self.interpret_operation(op) + + compute_qfunc_decomposition = op.compute_qfunc_decomposition + + args = (*op.parameters, *op.wires) + + jaxpr_decomp = qml.capture.make_plxpr( + partial(compute_qfunc_decomposition, **op.hyperparameters) + )(*args) + + self._current_depth += 1 + # We don't need to copy the interpreter here, as the jaxpr of the decomposition + # is evaluated with a new environment frame placed on top of the stack. + out = self.eval(jaxpr_decomp.jaxpr, jaxpr_decomp.consts, *args) + self._current_depth -= 1 + + return out + + # pylint: disable=too-many-branches + def eval(self, jaxpr: jax.extend.core.Jaxpr, consts: Sequence, *args) -> list: + """ + Evaluates a jaxpr, which can also be generated by a dynamic decomposition. + + Args: + jaxpr_decomp (jax.extend.core.Jaxpr): the Jaxpr to evaluate + consts (list[TensorLike]): the constant variables for the jaxpr + *args: the arguments to use in the evaluation + """ + + self.setup() + + for arg, invar in zip(args, jaxpr.invars, strict=True): + self._env_map[invar] = arg + for const, constvar in zip(consts, jaxpr.constvars, strict=True): + self._env_map[constvar] = const + + for eq in jaxpr.eqns: + + prim_type = getattr(eq.primitive, "prim_type", "") + custom_handler = self._primitive_registrations.get(eq.primitive, None) + + if custom_handler: + + invals = [self.read(invar) for invar in eq.invars] + outvals = custom_handler(self, *invals, **eq.params) + + elif prim_type == "operator": + outvals = self.interpret_operation_eqn(eq) + elif prim_type == "measurement": + outvals = self.interpret_measurement_eqn(eq) + else: + invals = [self.read(invar) for invar in eq.invars] + subfuns, params = eq.primitive.get_bind_params(eq.params) + outvals = eq.primitive.bind(*subfuns, *invals, **params) + + if not eq.primitive.multiple_results: + outvals = [outvals] + + for outvar, outval in zip(eq.outvars, outvals, strict=True): + self._env_map[outvar] = outval + + outvals = [] + for var in jaxpr.outvars: + outval = self.read(var) + if isinstance(outval, qml.operation.Operator): + outvals.append(self.interpret_operation(outval)) + else: + outvals.append(outval) + + self.cleanup() + + return outvals + + def interpret_operation_eqn(self, eqn: jax.extend.core.JaxprEqn): + """Interpret an equation corresponding to an operator. + + If the operator has a dynamic decomposition defined, this method will + create and evaluate the jaxpr of the decomposition using the :meth:`~.eval` method. + + Args: + eqn (jax.extend.core.JaxprEqn): a jax equation for an operator. + + See also: :meth:`~.interpret_operation`. + + """ + + invals = (self.read(invar) for invar in eqn.invars) + + with qml.QueuingManager.stop_recording(): + op = eqn.primitive.impl(*invals, **eqn.params) + + if not eqn.outvars[0].__class__.__name__ == "DropVar": + return op + + return self.decompose_operation(op) + + +# pylint: disable=too-many-arguments +@PreMlirDecomposeInterpreter.register_primitive(ctrl_transform_prim) +def _(self, *invals, n_control, jaxpr, control_values, work_wires, n_consts): + consts = invals[:n_consts] + args = invals[n_consts:-n_control] + control_wires = invals[-n_control:] + + unroller = ControlTransformInterpreter( + control_wires, control_values=control_values, work_wires=work_wires + ) + + def wrapper(*inner_args): + return unroller.eval(jaxpr, consts, *inner_args) + + jaxpr = jax.make_jaxpr(wrapper)(*args) + return self.eval(jaxpr.jaxpr, jaxpr.consts, *args) + + +class GraphSolutionInterpreter(qml.capture.PlxprInterpreter): + """Interpreter for getting the decomposition graph solution + from a jaxpr when program capture is enabled. + + This interpreter should be used after the PreMlirDecomposeInterpreter. + """ + + def __init__( + self, + *, + gate_set=None, + stopping_condition=None, + max_expansion=None, + fixed_decomps=None, + alt_decomps=None, + ): # pylint: disable=too-many-arguments + + self.max_expansion = max_expansion + self._current_depth = 0 + + if not qml.decomposition.enabled_graph() and (fixed_decomps or alt_decomps): + raise TypeError( + "The keyword arguments fixed_decomps and alt_decomps are only available with " + "the new experimental graph-based decomposition system. Use qml.decomposition.enable_graph() " + "to enable the new system." + ) + + self._decomp_graph_solution = None + self._target_gate_names = None + self._fixed_decomps, self._alt_decomps = fixed_decomps, alt_decomps + + # We use a ChainMap to store the environment frames, which allows us to push and pop + # environments without copying the interpreter instance when we evaluate a jaxpr of + # a dynamic decomposition. The name is different from the _env in the parent class + # (a dictionary) to avoid confusion. + self._env_map = ChainMap() + + gate_set, stopping_condition = _resolve_gate_set(gate_set, stopping_condition) + self._gate_set = gate_set + self._stopping_condition = stopping_condition + + def setup(self) -> None: + """Setup the environment for the interpreter by pushing a new environment frame.""" + + # This is the local environment for the jaxpr evaluation, on the top of the stack, + # from which the interpreter reads and writes variables. + # ChainMap writes to the first dictionary in the chain by default. + self._env_map = self._env_map.new_child() + + def cleanup(self) -> None: + """Cleanup the environment by popping the top-most environment frame.""" + + # We delete the top-most environment frame after the evaluation is done. + self._env_map = self._env_map.parents + + def read(self, var): + """Extract the value corresponding to a variable.""" + return var.val if isinstance(var, jax.extend.core.Literal) else self._env_map[var] + + def stopping_condition(self, op: qml.operation.Operator) -> bool: + """Function to determine whether an operator needs to be decomposed or not. + + Args: + op (qml.operation.Operator): Operator to check. + + Returns: + bool: Whether ``op`` is valid or needs to be decomposed. ``True`` means + that the operator does not need to be decomposed. + """ + + # If the new graph-based decomposition is enabled, + # we don't rely on the has_decomposition attribute. + if qml.decomposition.enabled_graph(): + return self._stopping_condition(op) + + if not op.has_decomposition: + if not self._stopping_condition(op): + warnings.warn( + f"Operator {op.name} does not define a decomposition and was not " + f"found in the target gate set. To remove this warning, add the operator " + f"name ({op.name}) or type ({type(op)}) to the gate set.", + UserWarning, + ) + return True + + return self._stopping_condition(op) + + def decompose_operation(self, op: qml.operation.Operator): + """Decompose a PennyLane operation instance if it does not satisfy the + provided gate set. + + Args: + op (Operator): a pennylane operator instance + + This method is only called when the operator's output is a dropped variable, + so the output will not affect later equations in the circuit. + + See also: :meth:`~.interpret_operation_eqn`, :meth:`~.interpret_operation`. + """ + + if self._stopping_condition(op): + return self.interpret_operation(op) + + max_expansion = ( + self.max_expansion - self._current_depth if self.max_expansion is not None else None + ) + + with qml.capture.pause(): + decomposition = list( + _operator_decomposition_gen( + op, + self.stopping_condition, + max_expansion=max_expansion, + decomp_graph_solution=self._decomp_graph_solution, + ) + ) + + return [self.interpret_operation(decomp_op) for decomp_op in decomposition] + + def _evaluate_jaxpr_decomposition(self, op: qml.operation.Operator): + """Creates and evaluates a Jaxpr of the plxpr decomposition of an operator.""" + + if self._stopping_condition(op): + return self.interpret_operation(op) + + if self.max_expansion is not None and self._current_depth >= self.max_expansion: + return self.interpret_operation(op) + + if qml.decomposition.enabled_graph() and self._decomp_graph_solution.is_solved_for(op): + + rule = self._decomp_graph_solution.decomposition(op) + num_wires = len(op.wires) + + def compute_qfunc_decomposition(*_args, **_kwargs): + wires = qml.math.array(_args[-num_wires:], like="jax") + rule(*_args[:-num_wires], wires=wires, **_kwargs) + + else: + compute_qfunc_decomposition = op.compute_qfunc_decomposition + + args = (*op.parameters, *op.wires) + + jaxpr_decomp = qml.capture.make_plxpr( + partial(compute_qfunc_decomposition, **op.hyperparameters) + )(*args) + + self._current_depth += 1 + # We don't need to copy the interpreter here, as the jaxpr of the decomposition + # is evaluated with a new environment frame placed on top of the stack. + out = self.eval(jaxpr_decomp.jaxpr, jaxpr_decomp.consts, *args) + self._current_depth -= 1 + + return out + + # pylint: disable=too-many-branches + def eval(self, jaxpr: jax.extend.core.Jaxpr, consts: Sequence, *args) -> list: + """ + Evaluates a jaxpr, which can also be generated by a dynamic decomposition. + + Args: + jaxpr_decomp (jax.extend.core.Jaxpr): the Jaxpr to evaluate + consts (list[TensorLike]): the constant variables for the jaxpr + *args: the arguments to use in the evaluation + """ + + self.setup() + + for arg, invar in zip(args, jaxpr.invars, strict=True): + self._env_map[invar] = arg + for const, constvar in zip(consts, jaxpr.constvars, strict=True): + self._env_map[constvar] = const + + if qml.decomposition.enabled_graph() and not self._decomp_graph_solution: + + with qml.capture.pause(): + + collector = CollectResourceOps() + collector.eval(jaxpr, consts, *args) + operations = collector.state["ops"] + + if operations: + self._decomp_graph_solution = _construct_and_solve_decomp_graph( + operations, + self._gate_set, + self._fixed_decomps, + self._alt_decomps, + ) + + # for op, decomp in self._decomp_graph_solution.decompositions(): + # print(f"Decomposition for {op}: {decomp}") + + for eq in jaxpr.eqns: + + prim_type = getattr(eq.primitive, "prim_type", "") + custom_handler = self._primitive_registrations.get(eq.primitive, None) + + if custom_handler: + + invals = [self.read(invar) for invar in eq.invars] + outvals = custom_handler(self, *invals, **eq.params) + + elif prim_type == "operator": + outvals = self.interpret_operation_eqn(eq) + elif prim_type == "measurement": + outvals = self.interpret_measurement_eqn(eq) + else: + invals = [self.read(invar) for invar in eq.invars] + subfuns, params = eq.primitive.get_bind_params(eq.params) + outvals = eq.primitive.bind(*subfuns, *invals, **params) + + if not eq.primitive.multiple_results: + outvals = [outvals] + + for outvar, outval in zip(eq.outvars, outvals, strict=True): + self._env_map[outvar] = outval + + outvals = [] + for var in jaxpr.outvars: + outval = self.read(var) + if isinstance(outval, qml.operation.Operator): + outvals.append(self.interpret_operation(outval)) + else: + outvals.append(outval) + + self.cleanup() + + return outvals + + def interpret_operation_eqn(self, eqn: jax.extend.core.JaxprEqn): + """Interpret an equation corresponding to an operator. + + If the operator has a dynamic decomposition defined, this method will + create and evaluate the jaxpr of the decomposition using the :meth:`~.eval` method. + + Args: + eqn (jax.extend.core.JaxprEqn): a jax equation for an operator. + + See also: :meth:`~.interpret_operation`. + + """ + + invals = (self.read(invar) for invar in eqn.invars) + + with qml.QueuingManager.stop_recording(): + op = eqn.primitive.impl(*invals, **eqn.params) + + if not eqn.outvars[0].__class__.__name__ == "DropVar": + return op + + # _evaluate_jaxpr_decomposition should be used when the operator defines a + # compute_qfunc_decomposition, or if graph-based decomposition is enabled and + # a solution is found for this operator in the graph. + if ( + op.has_qfunc_decomposition + or qml.decomposition.enabled_graph() + and self._decomp_graph_solution.is_solved_for(op) + ): + return self._evaluate_jaxpr_decomposition(op) + + return self.decompose_operation(op) + + +# pylint: disable=too-many-arguments +@GraphSolutionInterpreter.register_primitive(ctrl_transform_prim) +def _(self, *invals, n_control, jaxpr, control_values, work_wires, n_consts): + consts = invals[:n_consts] + args = invals[n_consts:-n_control] + control_wires = invals[-n_control:] + + unroller = ControlTransformInterpreter( + control_wires, control_values=control_values, work_wires=work_wires + ) + + def wrapper(*inner_args): + return unroller.eval(jaxpr, consts, *inner_args) + + jaxpr = jax.make_jaxpr(wrapper)(*args) + return self.eval(jaxpr.jaxpr, jaxpr.consts, *args) + + +class ControlTransformInterpreter(qml.capture.PlxprInterpreter): + """Interpreter for replacing control transforms with individually controlled ops.""" + + def __init__(self, control_wires, control_values=None, work_wires=None): + super().__init__() + self.control_wires = control_wires + self.control_values = control_values + self.work_wires = work_wires + + def interpret_operation(self, op): + """Interpret operation.""" + with qml.capture.pause(): + ctrl_op = qml.ctrl( + op, + self.control_wires, + control_values=self.control_values, + work_wires=self.work_wires, + ) + super().interpret_operation(ctrl_op) + + +def _operator_decomposition_gen( + op: qml.operation.Operator, + acceptance_function: Callable[[qml.operation.Operator], bool], + max_expansion: int | None = None, + current_depth=0, +) -> Generator[qml.operation.Operator]: + """A generator that yields the next operation that is accepted.""" + + max_depth_reached = False + decomp = [] + + if max_expansion is not None and max_expansion <= current_depth: + max_depth_reached = True + + if acceptance_function(op) or max_depth_reached: + yield op + else: + decomp = op.decomposition() + current_depth += 1 + + for sub_op in decomp: + yield from _operator_decomposition_gen( + sub_op, + acceptance_function, + max_expansion=max_expansion, + current_depth=current_depth, + ) + + +def _resolve_gate_set( + gate_set: set[type | str] | dict[type | str, float] = None, + stopping_condition: Callable[[qml.operation.Operator], bool] = None, +) -> tuple[set[type | str] | dict[type | str, float], Callable[[qml.operation.Operator], bool]]: + """Resolve the gate set and the stopping condition from arguments. + + The ``gate_set`` can be provided in various forms, and the ``stopping_condition`` may or + may not be provided. This function will resolve the gate set and the stopping condition + to the following standardized form: + + - The ``gate_set`` is set of operator **types** and/or names, or a dictionary mapping operator + types and/or names to their respective costs. This is only used by the DecompositionGraph + - The ``stopping_condition`` is a function that takes an operator **instances** and returns + ``True`` if the operator does not need to be decomposed. This is used during decomposition. + + """ + + if gate_set is None: + gate_set = set(qml.ops.__all__) + + if isinstance(gate_set, (str, type)): + gate_set = {gate_set} + + if isinstance(gate_set, dict): + + if any(v < 0 for v in gate_set.values()): + raise ValueError("Negative gate weights provided to gate_set are not supported.") + + if isinstance(gate_set, Iterable): + + gate_types = tuple(gate for gate in gate_set if isinstance(gate, type)) + gate_names = {translate_op_alias(gate) for gate in gate_set if isinstance(gate, str)} + + def gate_set_contains(op: Operator) -> bool: + return (op.name in gate_names) or isinstance(op, gate_types) + + elif isinstance(gate_set, Callable): # pylint:disable=isinstance-second-argument-not-valid-type + + gate_set_contains = gate_set + + else: + raise TypeError("Invalid gate_set type. Must be an iterable, dictionary, or function.") + + if stopping_condition: + + # Even when the user provides a stopping condition, we still need to check + # whether an operator belongs to the target gate set. This is to prevent + # the case of an operator missing the stopping condition but doesn't have + # a decomposition assigned due to being in the target gate set. + def _stopping_condition(op): + return gate_set_contains(op) or stopping_condition(op) + + else: + # If the stopping condition is not explicitly provided, the default is to simply check + # whether an operator belongs to the target gate set. + _stopping_condition = gate_set_contains + + return gate_set, _stopping_condition + + +def _construct_and_solve_decomp_graph( + operations, target_gates, fixed_decomps, alt_decomps +) -> DecompGraphSolution: + """Create and solve a DecompositionGraph instance to optimize the decomposition.""" + + # Create the decomposition graph + decomp_graph = DecompositionGraph( + operations, + target_gates, + fixed_decomps=fixed_decomps, + alt_decomps=alt_decomps, + ) + + # Find the efficient pathways to the target gate set + return decomp_graph.solve() diff --git a/frontend/catalyst/from_plxpr/from_plxpr.py b/frontend/catalyst/from_plxpr/from_plxpr.py index 5bbf3b9e8c..7f1c19813c 100644 --- a/frontend/catalyst/from_plxpr/from_plxpr.py +++ b/frontend/catalyst/from_plxpr/from_plxpr.py @@ -44,6 +44,8 @@ from pennylane.transforms import unitary_to_rot as pl_unitary_to_rot from catalyst.device import extract_backend_info +from catalyst.device.qjit_device import COMPILER_OPERATIONS +from catalyst.from_plxpr.decompose import GraphSolutionInterpreter, PreMlirDecomposeInterpreter from catalyst.from_plxpr.qubit_handler import QubitHandler from catalyst.jax_extras import jaxpr_pad_consts, make_jaxpr2, transient_jax_config from catalyst.jax_primitives import ( @@ -176,14 +178,52 @@ def f(x): class WorkflowInterpreter(PlxprInterpreter): - """An interpreter that converts a qnode primitive from a plxpr variant to a catalxpr variant.""" + """An interpreter that converts a qnode primitive from a plxpr variant to a catalyst variant.""" def __init__(self): self._pass_pipeline = [] self.qubit_handler = None + self.compiler_decompose = False super().__init__() +def _decompose_to_compiler_gateset(qfunc_jaxpr, consts, non_const_args): + """First stage decomposition to compiler gate set. + + Currently, the compiler can only handle a limited set of gates and + may not support all generic gates and templates of the original circuit. + We perform a first stage decomposition to the compiler gate set, which includes + only a subset of the original gates that can be represented in MLIR using the + `quantum.custom` primitive. + """ + + # TODO: The compiler should be able to handle all gate + # adhering to quantum.custom primitive.This includes + # all the gates with parameters of type `TensorLike` + # and wires of type `WiresLike` with no hyperparams. + # Update `gate_set` to use this as the stopping condition + # of the decomposition transform. + gate_set = COMPILER_OPERATIONS + + decomp_args = () + decomp_kwargs = {"gate_set": gate_set} + + # disable the graph decomposition optimization + graph_decomp_status = False + if qml.decomposition.enabled_graph(): + graph_decomp_status = True + qml.decomposition.disable_graph() + + new_jaxpr = qml.transforms.decompose.plxpr_transform( + qfunc_jaxpr, consts, decomp_args, decomp_kwargs, *non_const_args + ) + + if graph_decomp_status: + qml.decomposition.enable_graph() + + return new_jaxpr + + # pylint: disable=unused-argument, too-many-arguments @WorkflowInterpreter.register_primitive(qnode_prim) def handle_qnode( @@ -263,14 +303,50 @@ def handle_transform( non_const_args = args[args_slice] targs = args[targs_slice] - if catalyst_pass_name is None: - # Use PL's ExpandTransformsInterpreter to expand this and any embedded - # transform according to PL rules. It works by overriding the primitive - # registration, making all embedded transforms follow the PL rules - # from now on, hence ignoring the Catalyst pass conversion - def wrapper(*args): - return ExpandTransformsInterpreter().eval(inner_jaxpr, consts, *args) + # Check if the transform is a decomposition transform + # If so, we'll set the compiler_decompose flag to trigger + # 1. Construct the graph with the list of ops and the target gateset + # 2. Capture and lower the decomposition qfuncs down to MLIR + # 3. Bypass the custom PLxPR DecomposeInterpreter class + # + # Notes: + # - The list of target gateset is always taken from the transform's attributes + # and passed down to the MLIR lowering as a quantum function attribute. + if ( + hasattr(pl_plxpr_transform, "__name__") + and pl_plxpr_transform.__name__ == "decompose_plxpr_to_plxpr" + and qml.decomposition.enabled_graph() + ): + self.compiler_decompose = True + # Use PL's ExpandTransformsInterpreter to expand this and any embedded + # transform according to PL rules. It works by overriding the primitive + # registration, making all embedded transforms follow the PL rules + # from now on, hence ignoring the Catalyst pass conversion + def wrapper(*args): + return ExpandTransformsInterpreter().eval(inner_jaxpr, consts, *args) + + if self.compiler_decompose: + gate_set = COMPILER_OPERATIONS + decomp_kwargs = {"gate_set": gate_set} + + pmd_interpreter = PreMlirDecomposeInterpreter(*targs, **decomp_kwargs) + + def pmd_wrapper(*args): + return pmd_interpreter.eval(inner_jaxpr, consts, *args) + + pmd_jaxpr = jax.make_jaxpr(pmd_wrapper)(*args) + + gds_interpreter = GraphSolutionInterpreter(*targs, **tkwargs) + + def gds_wrapper(*args): + return gds_interpreter.eval(pmd_jaxpr.jaxpr, consts, *args) + + gds_jaxpr = jax.make_jaxpr(gds_wrapper)(*args) + + return self.eval(gds_jaxpr.jaxpr, gds_jaxpr.consts, *non_const_args) + + if catalyst_pass_name is None: unravelled_jaxpr = jax.make_jaxpr(wrapper)(*non_const_args) final_jaxpr = pl_plxpr_transform( unravelled_jaxpr.jaxpr, unravelled_jaxpr.consts, targs, tkwargs, *non_const_args @@ -311,6 +387,7 @@ def __init__(self, device, shots, qubit_handler, cache, *, control_wires=(), con # TODO: we assume the qreg value passed into a scope is the unique qreg in the scope # In other words, we assume no new qreg will be allocated in the scope self.qubit_handler = qubit_handler + self.compiler_decompose = False self.subroutine_cache = cache self.control_wires = control_wires """Any control wires used for a subroutine.""" diff --git a/frontend/catalyst/jax_extras/lowering.py b/frontend/catalyst/jax_extras/lowering.py index 7dc1382593..27afc8ba1a 100644 --- a/frontend/catalyst/jax_extras/lowering.py +++ b/frontend/catalyst/jax_extras/lowering.py @@ -53,7 +53,7 @@ @debug_logger -def jaxpr_to_mlir(func_name, jaxpr): +def jaxpr_to_mlir(func_name, jaxpr, py_attrs=None): """Lower a Jaxpr into an MLIR module. Args: @@ -81,6 +81,7 @@ def jaxpr_to_mlir(func_name, jaxpr): platform="cpu", axis_context=axis_context, name_stack=name_stack, + py_attrs=py_attrs, ) return module, context @@ -99,6 +100,7 @@ def custom_lower_jaxpr_to_module( replicated_args=None, arg_shardings=None, result_shardings=None, + py_attrs=None, ): """Lowers a top-level jaxpr to an MHLO module. @@ -109,6 +111,10 @@ def custom_lower_jaxpr_to_module( https://github.com/google/jax/blob/c4d590b1b640cc9fcfdbe91bf3fe34c47bcde917/jax/interpreters/mlir.py#L625version released under the Apache License, Version 2.0, with the following copyright notice: + Note: We further modified this function to accept `py_attrs`, which allows for the passing + of custom attributes from Python to MLIR. This is currently used for passing + the target gate set information. + Copyright 2021 The JAX Authors. """ @@ -163,6 +169,15 @@ def custom_lower_jaxpr_to_module( continue if isinstance(op, FuncOp): op.attributes["llvm.linkage"] = ir.Attribute.parse("#llvm.linkage") + if py_attrs: # pass custom attributes from Python to MLIR + for attr_name, attr_value in py_attrs.items(): + try: + mlir_attr = get_mlir_attribute_from_pyval(list(attr_value)) + op.attributes[attr_name] = mlir_attr + except CompileError as e: + raise CompileError( + f"While converting Python attribute '{attr_name}': '{attr_value}' to MLIR: {e}" + ) from e if isinstance(op, ModuleOp): worklist += [*op.body.operations] diff --git a/frontend/catalyst/jax_tracer.py b/frontend/catalyst/jax_tracer.py index 5d83b8f3a8..6f42538594 100644 --- a/frontend/catalyst/jax_tracer.py +++ b/frontend/catalyst/jax_tracer.py @@ -617,7 +617,7 @@ def trace_to_jaxpr(func, static_argnums, abstracted_axes, args, kwargs, debug_in @debug_logger -def lower_jaxpr_to_mlir(jaxpr, func_name): +def lower_jaxpr_to_mlir(jaxpr, func_name, py_attrs=None): """Lower a JAXPR to MLIR. Args: @@ -632,7 +632,7 @@ def lower_jaxpr_to_mlir(jaxpr, func_name): MemrefCallable.clearcache() with transient_jax_config({"jax_dynamic_shapes": True}): - mlir_module, ctx = jaxpr_to_mlir(func_name, jaxpr) + mlir_module, ctx = jaxpr_to_mlir(func_name, jaxpr, py_attrs=py_attrs) return mlir_module, ctx diff --git a/frontend/catalyst/jit.py b/frontend/catalyst/jit.py index d7bdc00375..1621e4534b 100644 --- a/frontend/catalyst/jit.py +++ b/frontend/catalyst/jit.py @@ -546,6 +546,26 @@ def __init__(self, fn, compile_options): self.user_sig = get_type_annotations(fn) self._validate_configuration() + # Extract transform python kwargs from the function + # with both capture enabled and disabled + + # Note: as we are currently interested in decompose + # target gateset, we avoid passing any non-decompose kwargs + transform_lists = fn._transform_program if hasattr(fn, "_transform_program") else [] + decompose_transform_kwargs = [ + t.kwargs + for t in transform_lists + if hasattr(t, "plxpr_transform") + and hasattr(t.plxpr_transform, "__name__") + and "decompose" in t.plxpr_transform.__name__ + ] + + # TODO: Remove this in the future after enabling multiple decomposition support + # in the MLIR rewrite pass. + if len(decompose_transform_kwargs) > 1: + raise ValueError("Multiple decompose transform is not yet supported.") + self.py_attrs = decompose_transform_kwargs[0] if decompose_transform_kwargs else None + # If static_argnames are present, convert them to static_argnums if compile_options.static_argnames is not None: compile_options.static_argnums = merge_static_argname_into_argnum( @@ -773,7 +793,7 @@ def generate_ir(self): Tuple[ir.Module, str]: the in-memory MLIR module and its string representation """ - mlir_module, ctx = lower_jaxpr_to_mlir(self.jaxpr, self.__name__) + mlir_module, ctx = lower_jaxpr_to_mlir(self.jaxpr, self.__name__, py_attrs=self.py_attrs) # Inject Runtime Library-specific functions (e.g. setup/teardown). inject_functions(mlir_module, ctx, self.compile_options.seed) diff --git a/frontend/test/pytest/from_plxpr/test_capture_integration.py b/frontend/test/pytest/from_plxpr/test_capture_integration.py index bbbd07935b..45bf6846e9 100644 --- a/frontend/test/pytest/from_plxpr/test_capture_integration.py +++ b/frontend/test/pytest/from_plxpr/test_capture_integration.py @@ -1291,7 +1291,7 @@ def test_transform_decompose_workflow(self, backend): qml.capture.enable() @qjit(target="mlir") - @partial(qml.transforms.decompose, gate_set=[qml.RX, qml.RY, qml.RZ]) + @partial(qml.transforms.decompose, gate_set=["RX", "RY", "RZ"]) @qml.qnode(qml.device(backend, wires=2)) def captured_circuit(x: float, y: float, z: float): qml.Rot(x, y, z, 0) @@ -1305,7 +1305,7 @@ def captured_circuit(x: float, y: float, z: float): # Capture disabled @qjit - @partial(qml.transforms.decompose, gate_set=[qml.RX, qml.RY, qml.RZ]) + @partial(qml.transforms.decompose, gate_set=["RX", "RY", "RZ"]) @qml.qnode(qml.device(backend, wires=2)) def circuit(x: float, y: float, z: float): qml.Rot(x, y, z, 0) diff --git a/frontend/test/pytest/from_plxpr/test_from_plxpr_decompose.py b/frontend/test/pytest/from_plxpr/test_from_plxpr_decompose.py new file mode 100644 index 0000000000..aeca1a1398 --- /dev/null +++ b/frontend/test/pytest/from_plxpr/test_from_plxpr_decompose.py @@ -0,0 +1,310 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests the ``decompose`` transform with the new Catalyst graph-based decomposition system.""" +from functools import partial + +import numpy as np +import pennylane as qml +import pytest + +pytestmark = pytest.mark.usefixtures("disable_capture") + + +class TestDecomposeGraphEnabled: + """Tests the decompose transform with graph enabled.""" + + @pytest.mark.integration + def test_mixed_gate_set_specification(self): + """Tests that the gate_set can be specified as both a type and a string.""" + + qml.decomposition.enable_graph() + + tape = qml.tape.QuantumScript([qml.RX(0.5, wires=[0]), qml.CNOT(wires=[0, 1])]) + [new_tape], _ = qml.transforms.decompose(tape, gate_set={"RX", qml.CNOT}) + assert new_tape.operations == tape.operations + + qml.decomposition.disable_graph() + + @pytest.mark.integration + def test_gate_set_targeted_decompositions(self): + """Tests that a simple circuit is correctly decomposed into different gate sets.""" + + qml.decomposition.enable_graph() + + tape = qml.tape.QuantumScript( + [ + qml.H(0), # non-parametric op + qml.Rot(0.1, 0.2, 0.3, wires=[0]), # parametric single-qubit op + qml.MultiRZ(0.5, wires=[0, 1, 2]), # parametric multi-qubit op + ] + ) + + [new_tape], _ = qml.transforms.decompose(tape, gate_set={"Hadamard", "CNOT", "RZ", "RY"}) + assert new_tape.operations == [ + # H is in the target gate set + qml.H(0), + # Rot decomposes to ZYZ + qml.RZ(0.1, wires=[0]), + qml.RY(0.2, wires=[0]), + qml.RZ(0.3, wires=[0]), + # Decomposition of MultiRZ + qml.CNOT(wires=[2, 1]), + qml.CNOT(wires=[1, 0]), + qml.RZ(0.5, wires=[0]), + qml.CNOT(wires=[1, 0]), + qml.CNOT(wires=[2, 1]), + ] + + [new_tape], _ = qml.transforms.decompose(tape, gate_set={"RY", "RZ", "CZ", "GlobalPhase"}) + assert new_tape.operations == [ + # The H decomposes to RZ and RY + qml.RZ(np.pi, wires=[0]), + qml.RY(np.pi / 2, wires=[0]), + qml.GlobalPhase(-np.pi / 2), + # Rot decomposes to ZYZ + qml.RZ(0.1, wires=[0]), + qml.RY(0.2, wires=[0]), + qml.RZ(0.3, wires=[0]), + # CNOT decomposes to H and CZ, where H decomposes to RZ and RY + qml.RZ(np.pi, wires=[1]), + qml.RY(np.pi / 2, wires=[1]), + qml.GlobalPhase(-np.pi / 2), + qml.CZ(wires=[2, 1]), + qml.RZ(np.pi, wires=[1]), + qml.RY(np.pi / 2, wires=[1]), + qml.GlobalPhase(-np.pi / 2), + # second CNOT + qml.RZ(np.pi, wires=[0]), + qml.RY(np.pi / 2, wires=[0]), + qml.GlobalPhase(-np.pi / 2), + qml.CZ(wires=[1, 0]), + qml.RZ(np.pi, wires=[0]), + qml.RY(np.pi / 2, wires=[0]), + qml.GlobalPhase(-np.pi / 2), + # The middle RZ + qml.RZ(0.5, wires=[0]), + # The last two CNOTs + qml.RZ(np.pi, wires=[0]), + qml.RY(np.pi / 2, wires=[0]), + qml.GlobalPhase(-np.pi / 2), + qml.CZ(wires=[1, 0]), + qml.RZ(np.pi, wires=[0]), + qml.RY(np.pi / 2, wires=[0]), + qml.GlobalPhase(-np.pi / 2), + qml.RZ(np.pi, wires=[1]), + qml.RY(np.pi / 2, wires=[1]), + qml.GlobalPhase(-np.pi / 2), + qml.CZ(wires=[2, 1]), + qml.RZ(np.pi, wires=[1]), + qml.RY(np.pi / 2, wires=[1]), + qml.GlobalPhase(-np.pi / 2), + ] + + qml.decomposition.disable_graph() + + @pytest.mark.integration + def test_fixed_decomp(self): + """Tests that a fixed decomposition rule is used instead of the stock ones.""" + + qml.decomposition.enable_graph() + + @qml.register_resources({qml.RY: 2, qml.CZ: 1, qml.Z: 2}) + def my_cnot(wires, **__): + qml.RY(np.pi / 2, wires[1]) + qml.Z(wires[1]) + qml.CZ(wires=wires) + qml.RY(np.pi / 2, wires[1]) + qml.Z(wires[1]) + + tape = qml.tape.QuantumScript([qml.CNOT(wires=[1, 0])]) + [new_tape], _ = qml.transforms.decompose( + tape, + gate_set={"RY", "RZ", "CZ", "Hadamard", "GlobalPhase"}, + fixed_decomps={qml.CNOT: my_cnot}, + ) + assert new_tape.operations == [ + qml.RY(np.pi / 2, wires=[0]), + qml.RZ(np.pi, wires=[0]), + qml.GlobalPhase(-np.pi / 2), + qml.CZ(wires=[1, 0]), + qml.RY(np.pi / 2, wires=[0]), + qml.RZ(np.pi, wires=[0]), + qml.GlobalPhase(-np.pi / 2), + ] + + qml.decomposition.disable_graph() + + @pytest.mark.integration + def test_alt_decomp_not_used(self): + """Tests that alt_decomp isn't necessarily used if it's not efficient.""" + + qml.decomposition.enable_graph() + + @qml.register_resources({qml.RY: 2, qml.CZ: 1, qml.Z: 2}) + def my_cnot(wires, **__): + qml.RY(np.pi / 2, wires[1]) + qml.Z(wires[1]) + qml.CZ(wires=wires) + qml.RY(np.pi / 2, wires[1]) + qml.Z(wires[1]) + + tape = qml.tape.QuantumScript([qml.CNOT(wires=[1, 0])]) + [new_tape], _ = qml.transforms.decompose( + tape, + gate_set={"RY", "RZ", "CZ", "Hadamard", "GlobalPhase"}, + alt_decomps={qml.CNOT: [my_cnot]}, + ) + assert new_tape.operations == [ + qml.H(0), + qml.CZ(wires=[1, 0]), + qml.H(0), + ] + + qml.decomposition.disable_graph() + + @pytest.mark.integration + def test_alt_decomp(self): + """Tests that alternative decomposition rules are used when applicable.""" + + qml.decomposition.enable_graph() + + @qml.register_resources({qml.RY: 2, qml.CZ: 1, qml.Z: 2}) + def my_cnot(wires, **__): + qml.RY(np.pi / 2, wires[1]) + qml.Z(wires[1]) + qml.CZ(wires=wires) + qml.RY(np.pi / 2, wires[1]) + qml.Z(wires[1]) + + tape = qml.tape.QuantumScript([qml.CNOT(wires=[1, 0])]) + [new_tape], _ = qml.transforms.decompose( + tape, + gate_set={"RY", "RZ", "CZ", "PauliZ", "GlobalPhase"}, + alt_decomps={qml.CNOT: [my_cnot]}, + ) + assert new_tape.operations == [ + qml.RY(np.pi / 2, wires=[0]), + qml.Z(0), + qml.CZ(wires=[1, 0]), + qml.RY(np.pi / 2, wires=[0]), + qml.Z(0), + ] + + qml.decomposition.disable_graph() + + @pytest.mark.integration + def test_fall_back(self): + """Tests that op.decompose() is used for ops unsolved in the graph.""" + + qml.decomposition.enable_graph() + + class CustomOp(qml.operation.Operation): # pylint: disable=too-few-public-methods + """Dummy custom op.""" + + resource_keys = set() + + @property + def resource_params(self): + return {} + + def decomposition(self): + return [qml.H(self.wires[1]), qml.CNOT(self.wires), qml.H(self.wires[1])] + + @qml.register_resources({qml.CZ: 1}) + def my_decomp(wires, **__): + qml.CZ(wires=wires) + + tape = qml.tape.QuantumScript([CustomOp(wires=[0, 1])]) + [new_tape], _ = qml.transforms.decompose( + tape, gate_set={"CNOT", "Hadamard"}, fixed_decomps={CustomOp: my_decomp} + ) + assert new_tape.operations == [qml.H(1), qml.CNOT(wires=[0, 1]), qml.H(1)] + + qml.decomposition.disable_graph() + + # @pytest.mark.integration + # def test_controlled_decomp(self): + # """Tests decomposing a controlled operation.""" + + # # The C(MultiRZ) is decomposed by applying control on the base decomposition. + # # The decomposition of MultiRZ contains two CNOTs + # # So this also tests applying control on an PauliX based operation + # # The decomposition of MultiRZ also contains an RZ gate + # # So this also tests logic involving custom controlled operators. + # ops = [qml.ctrl(qml.MultiRZ(0.5, wires=[0, 1]), control=[2])] + # tape = qml.tape.QuantumScript(ops) + # [new_tape], _ = qml.transforms.decompose(tape, gate_set={"RZ", "CNOT", "Toffoli"}) + # assert new_tape.operations == [ + # # Decomposition of C(CNOT) + # qml.Toffoli(wires=[2, 1, 0]), + # # Decomposition of C(RZ) -> CRZ + # qml.RZ(0.25, wires=[0]), + # qml.CNOT(wires=[2, 0]), + # qml.RZ(-0.25, wires=[0]), + # qml.CNOT(wires=[2, 0]), + # # Decomposition of C(CNOT) + # qml.Toffoli(wires=[2, 1, 0]), + # ] + + # @pytest.mark.integration + # def test_adjoint_decomp(self): + # """Tests decomposing an adjoint operation.""" + + # class CustomOp(qml.operation.Operator): # pylint: disable=too-few-public-methods + + # resource_keys = set() + + # @property + # def resource_params(self) -> dict: + # return {} + + # @qml.register_resources({qml.RX: 1, qml.RY: 1, qml.RZ: 1}) + # def custom_decomp(theta, phi, omega, wires): + # qml.RX(theta, wires[0]) + # qml.RY(phi, wires[0]) + # qml.RZ(omega, wires[0]) + + # tape = qml.tape.QuantumScript( + # [ + # qml.adjoint(qml.RX(0.5, wires=[0])), + # qml.adjoint(qml.adjoint(qml.MultiRZ(0.5, wires=[0, 1]))), + # qml.adjoint(CustomOp(0.1, 0.2, 0.3, wires=[0])), + # ] + # ) + # [new_tape], _ = qml.transforms.decompose( + # tape, gate_set={"CNOT", "RX", "RY", "RZ"}, fixed_decomps={CustomOp: custom_decomp} + # ) + # assert new_tape.operations == [ + # qml.RX(-0.5, wires=[0]), + # qml.CNOT(wires=[1, 0]), + # qml.RZ(0.5, wires=[0]), + # qml.CNOT(wires=[1, 0]), + # qml.RZ(-0.3, wires=[0]), + # qml.RY(-0.2, wires=[0]), + # qml.RX(-0.1, wires=[0]), + # ] + + +def test_decompose_qnode(): + """Tests that the decompose transform works with a QNode.""" + + @partial(qml.transforms.decompose, gate_set={"CZ", "Hadamard"}) + @qml.qnode(qml.device("default.qubit", wires=2)) + def circuit(): + qml.CNOT(wires=[0, 1]) + return qml.expval(qml.PauliZ(0)) + + res = circuit() + assert qml.math.allclose(res, 1.0) From 02be03450300af844ae2f9e856925ca62b151566 Mon Sep 17 00:00:00 2001 From: Ali Asadi <10773383+maliasadi@users.noreply.github.com> Date: Fri, 5 Sep 2025 14:01:26 -0400 Subject: [PATCH 02/36] Tidy up --- frontend/catalyst/device/qjit_device.py | 3 +- frontend/catalyst/from_plxpr/decompose.py | 320 +++++------------- frontend/catalyst/from_plxpr/from_plxpr.py | 79 +++-- frontend/catalyst/jax_extras/lowering.py | 3 +- frontend/catalyst/jax_primitives.py | 2 +- frontend/test/lit/test_decomposition.py | 8 +- .../from_plxpr/test_from_plxpr_decompose.py | 4 + runtime/include/RuntimeCAPI.h | 1 + runtime/lib/capi/RuntimeCAPI.cpp | 8 + 9 files changed, 164 insertions(+), 264 deletions(-) diff --git a/frontend/catalyst/device/qjit_device.py b/frontend/catalyst/device/qjit_device.py index b9c9426201..b9cac2d619 100644 --- a/frontend/catalyst/device/qjit_device.py +++ b/frontend/catalyst/device/qjit_device.py @@ -83,6 +83,7 @@ "PSWAP", "QubitUnitary", "Rot", + "RotXZX", "RX", "RY", "RZ", @@ -113,7 +114,7 @@ # the operations supported by the runtime. # FIXME: ops with OpName(params, wires) signatures # can be represented in the Catalyst compiler. -COMPILER_OPERATIONS = RUNTIME_OPERATIONS + ["RotXZX"] +COMPILER_OPERATIONS = RUNTIME_OPERATIONS # The runtime interface does not care about specific gate properties, so set them all to True. RUNTIME_OPERATIONS = { diff --git a/frontend/catalyst/from_plxpr/decompose.py b/frontend/catalyst/from_plxpr/decompose.py index c3d576ee26..32780a5918 100644 --- a/frontend/catalyst/from_plxpr/decompose.py +++ b/frontend/catalyst/from_plxpr/decompose.py @@ -34,7 +34,6 @@ # GraphSolutionInterpreter: from pennylane.decomposition import DecompositionGraph -from pennylane.decomposition.collect_resource_ops import CollectResourceOps from pennylane.decomposition.decomposition_graph import DecompGraphSolution from pennylane.decomposition.utils import translate_op_alias from pennylane.operation import Operator @@ -173,7 +172,7 @@ def _evaluate_jaxpr_decomposition(self, op: qml.operation.Operator): return out - # pylint: disable=too-many-branches + # pylint: disable=too-many-branches, too-many-locals def eval(self, jaxpr: jax.extend.core.Jaxpr, consts: Sequence, *args) -> list: """ Evaluates a jaxpr, which can also be generated by a dynamic decomposition. @@ -270,6 +269,7 @@ def wrapper(*inner_args): return self.eval(jaxpr.jaxpr, jaxpr.consts, *args) +# pylint: disable=too-few-public-methods class GraphSolutionInterpreter(qml.capture.PlxprInterpreter): """Interpreter for getting the decomposition graph solution from a jaxpr when program capture is enabled. @@ -280,212 +280,98 @@ class GraphSolutionInterpreter(qml.capture.PlxprInterpreter): def __init__( self, *, + operations, gate_set=None, - stopping_condition=None, - max_expansion=None, fixed_decomps=None, alt_decomps=None, ): # pylint: disable=too-many-arguments - self.max_expansion = max_expansion - self._current_depth = 0 - - if not qml.decomposition.enabled_graph() and (fixed_decomps or alt_decomps): + if not qml.decomposition.enabled_graph(): raise TypeError( - "The keyword arguments fixed_decomps and alt_decomps are only available with " - "the new experimental graph-based decomposition system. Use qml.decomposition.enable_graph() " - "to enable the new system." + "The GraphSolutionInterpreter can only be used when" + "graph-based decomposition is enabled." ) - self._decomp_graph_solution = None + self._operations = operations + self._decomp_graph_solution = {} self._target_gate_names = None self._fixed_decomps, self._alt_decomps = fixed_decomps, alt_decomps - # We use a ChainMap to store the environment frames, which allows us to push and pop - # environments without copying the interpreter instance when we evaluate a jaxpr of - # a dynamic decomposition. The name is different from the _env in the parent class - # (a dictionary) to avoid confusion. - self._env_map = ChainMap() - - gate_set, stopping_condition = _resolve_gate_set(gate_set, stopping_condition) + gate_set, _ = _resolve_gate_set(gate_set) self._gate_set = gate_set - self._stopping_condition = stopping_condition - - def setup(self) -> None: - """Setup the environment for the interpreter by pushing a new environment frame.""" - - # This is the local environment for the jaxpr evaluation, on the top of the stack, - # from which the interpreter reads and writes variables. - # ChainMap writes to the first dictionary in the chain by default. - self._env_map = self._env_map.new_child() - - def cleanup(self) -> None: - """Cleanup the environment by popping the top-most environment frame.""" + self._env = {} - # We delete the top-most environment frame after the evaluation is done. - self._env_map = self._env_map.parents - - def read(self, var): - """Extract the value corresponding to a variable.""" - return var.val if isinstance(var, jax.extend.core.Literal) else self._env_map[var] - - def stopping_condition(self, op: qml.operation.Operator) -> bool: - """Function to determine whether an operator needs to be decomposed or not. + # pylint: disable=too-many-branches, too-many-locals + def eval(self, jaxpr: "jax.extend.core.Jaxpr", consts: Sequence, *args) -> list: + """Evaluate a jaxpr. Args: - op (qml.operation.Operator): Operator to check. + jaxpr (jax.extend.core.Jaxpr): the jaxpr to evaluate + consts (list[TensorLike]): the constant variables for the jaxpr + *args (tuple[TensorLike]): The arguments for the jaxpr. Returns: - bool: Whether ``op`` is valid or needs to be decomposed. ``True`` means - that the operator does not need to be decomposed. - """ - - # If the new graph-based decomposition is enabled, - # we don't rely on the has_decomposition attribute. - if qml.decomposition.enabled_graph(): - return self._stopping_condition(op) - - if not op.has_decomposition: - if not self._stopping_condition(op): - warnings.warn( - f"Operator {op.name} does not define a decomposition and was not " - f"found in the target gate set. To remove this warning, add the operator " - f"name ({op.name}) or type ({type(op)}) to the gate set.", - UserWarning, - ) - return True - - return self._stopping_condition(op) - - def decompose_operation(self, op: qml.operation.Operator): - """Decompose a PennyLane operation instance if it does not satisfy the - provided gate set. - - Args: - op (Operator): a pennylane operator instance - - This method is only called when the operator's output is a dropped variable, - so the output will not affect later equations in the circuit. - - See also: :meth:`~.interpret_operation_eqn`, :meth:`~.interpret_operation`. - """ - - if self._stopping_condition(op): - return self.interpret_operation(op) - - max_expansion = ( - self.max_expansion - self._current_depth if self.max_expansion is not None else None - ) - - with qml.capture.pause(): - decomposition = list( - _operator_decomposition_gen( - op, - self.stopping_condition, - max_expansion=max_expansion, - decomp_graph_solution=self._decomp_graph_solution, - ) - ) - - return [self.interpret_operation(decomp_op) for decomp_op in decomposition] - - def _evaluate_jaxpr_decomposition(self, op: qml.operation.Operator): - """Creates and evaluates a Jaxpr of the plxpr decomposition of an operator.""" - - if self._stopping_condition(op): - return self.interpret_operation(op) - - if self.max_expansion is not None and self._current_depth >= self.max_expansion: - return self.interpret_operation(op) - - if qml.decomposition.enabled_graph() and self._decomp_graph_solution.is_solved_for(op): - - rule = self._decomp_graph_solution.decomposition(op) - num_wires = len(op.wires) - - def compute_qfunc_decomposition(*_args, **_kwargs): - wires = qml.math.array(_args[-num_wires:], like="jax") - rule(*_args[:-num_wires], wires=wires, **_kwargs) - - else: - compute_qfunc_decomposition = op.compute_qfunc_decomposition - - args = (*op.parameters, *op.wires) - - jaxpr_decomp = qml.capture.make_plxpr( - partial(compute_qfunc_decomposition, **op.hyperparameters) - )(*args) - - self._current_depth += 1 - # We don't need to copy the interpreter here, as the jaxpr of the decomposition - # is evaluated with a new environment frame placed on top of the stack. - out = self.eval(jaxpr_decomp.jaxpr, jaxpr_decomp.consts, *args) - self._current_depth -= 1 - - return out - - # pylint: disable=too-many-branches - def eval(self, jaxpr: jax.extend.core.Jaxpr, consts: Sequence, *args) -> list: - """ - Evaluates a jaxpr, which can also be generated by a dynamic decomposition. + list[TensorLike]: the results of the execution. - Args: - jaxpr_decomp (jax.extend.core.Jaxpr): the Jaxpr to evaluate - consts (list[TensorLike]): the constant variables for the jaxpr - *args: the arguments to use in the evaluation """ - + self._env = {} self.setup() for arg, invar in zip(args, jaxpr.invars, strict=True): - self._env_map[invar] = arg + self._env[invar] = arg for const, constvar in zip(consts, jaxpr.constvars, strict=True): - self._env_map[constvar] = const + self._env[constvar] = const - if qml.decomposition.enabled_graph() and not self._decomp_graph_solution: + if self._operations and not self._decomp_graph_solution: - with qml.capture.pause(): + self._decomp_graph_solution = _solve_decomposition_graph( + self._operations, + self._gate_set, + fixed_decomps=self._fixed_decomps, + alt_decomps=self._alt_decomps, + ) - collector = CollectResourceOps() - collector.eval(jaxpr, consts, *args) - operations = collector.state["ops"] + # for op, rule_impl in self._decomp_graph_solution.items(): + # # print(op, rule_impl) + # def compute_qfunc_decomp(): + # rule_impl(int) - if operations: - self._decomp_graph_solution = _construct_and_solve_decomp_graph( - operations, - self._gate_set, - self._fixed_decomps, - self._alt_decomps, - ) + # if op.op.name in ("RX", "RY", "RZ", "PhaseShift", "Rot", "U1"): + # def compute_qfunc_decomp(): + # rule_impl(float, int) + # else: + # continue - # for op, decomp in self._decomp_graph_solution.decompositions(): - # print(f"Decomposition for {op}: {decomp}") + # jaxpr_decomp = qml.capture.make_plxpr( + # compute_qfunc_decomp + # )() - for eq in jaxpr.eqns: + # print(jaxpr_decomp) + # # out = self.eval(jaxpr_decomp.jaxpr, jaxpr_decomp.consts, tuple()) + # # print(out) - prim_type = getattr(eq.primitive, "prim_type", "") - custom_handler = self._primitive_registrations.get(eq.primitive, None) + for eqn in jaxpr.eqns: + primitive = eqn.primitive + custom_handler = self._primitive_registrations.get(primitive, None) if custom_handler: - - invals = [self.read(invar) for invar in eq.invars] - outvals = custom_handler(self, *invals, **eq.params) - - elif prim_type == "operator": - outvals = self.interpret_operation_eqn(eq) - elif prim_type == "measurement": - outvals = self.interpret_measurement_eqn(eq) + invals = [self.read(invar) for invar in eqn.invars] + outvals = custom_handler(self, *invals, **eqn.params) + elif getattr(primitive, "prim_type", "") == "operator": + outvals = self.interpret_operation_eqn(eqn) + elif getattr(primitive, "prim_type", "") == "measurement": + outvals = self.interpret_measurement_eqn(eqn) else: - invals = [self.read(invar) for invar in eq.invars] - subfuns, params = eq.primitive.get_bind_params(eq.params) - outvals = eq.primitive.bind(*subfuns, *invals, **params) + invals = [self.read(invar) for invar in eqn.invars] + subfuns, params = primitive.get_bind_params(eqn.params) + outvals = primitive.bind(*subfuns, *invals, **params) - if not eq.primitive.multiple_results: + if not primitive.multiple_results: outvals = [outvals] + for outvar, outval in zip(eqn.outvars, outvals, strict=True): + self._env[outvar] = outval - for outvar, outval in zip(eq.outvars, outvals, strict=True): - self._env_map[outvar] = outval - + # Read the final result of the Jaxpr from the environment outvals = [] for var in jaxpr.outvars: outval = self.read(var) @@ -493,44 +379,10 @@ def eval(self, jaxpr: jax.extend.core.Jaxpr, consts: Sequence, *args) -> list: outvals.append(self.interpret_operation(outval)) else: outvals.append(outval) - self.cleanup() - + self._env = {} return outvals - def interpret_operation_eqn(self, eqn: jax.extend.core.JaxprEqn): - """Interpret an equation corresponding to an operator. - - If the operator has a dynamic decomposition defined, this method will - create and evaluate the jaxpr of the decomposition using the :meth:`~.eval` method. - - Args: - eqn (jax.extend.core.JaxprEqn): a jax equation for an operator. - - See also: :meth:`~.interpret_operation`. - - """ - - invals = (self.read(invar) for invar in eqn.invars) - - with qml.QueuingManager.stop_recording(): - op = eqn.primitive.impl(*invals, **eqn.params) - - if not eqn.outvars[0].__class__.__name__ == "DropVar": - return op - - # _evaluate_jaxpr_decomposition should be used when the operator defines a - # compute_qfunc_decomposition, or if graph-based decomposition is enabled and - # a solution is found for this operator in the graph. - if ( - op.has_qfunc_decomposition - or qml.decomposition.enabled_graph() - and self._decomp_graph_solution.is_solved_for(op) - ): - return self._evaluate_jaxpr_decomposition(op) - - return self.decompose_operation(op) - # pylint: disable=too-many-arguments @GraphSolutionInterpreter.register_primitive(ctrl_transform_prim) @@ -576,6 +428,7 @@ def _operator_decomposition_gen( acceptance_function: Callable[[qml.operation.Operator], bool], max_expansion: int | None = None, current_depth=0, + decomp_graph_solution: DecompGraphSolution | None = None, ) -> Generator[qml.operation.Operator]: """A generator that yields the next operation that is accepted.""" @@ -587,6 +440,12 @@ def _operator_decomposition_gen( if acceptance_function(op) or max_depth_reached: yield op + elif decomp_graph_solution is not None and decomp_graph_solution.is_solved_for(op): + op_rule = decomp_graph_solution.decomposition(op) + with qml.queuing.AnnotatedQueue() as decomposed_ops: + op_rule(*op.parameters, wires=op.wires, **op.hyperparameters) + decomp = decomposed_ops.queue + current_depth += 1 else: decomp = op.decomposition() current_depth += 1 @@ -597,6 +456,7 @@ def _operator_decomposition_gen( acceptance_function, max_expansion=max_expansion, current_depth=current_depth, + decomp_graph_solution=decomp_graph_solution, ) @@ -604,18 +464,7 @@ def _resolve_gate_set( gate_set: set[type | str] | dict[type | str, float] = None, stopping_condition: Callable[[qml.operation.Operator], bool] = None, ) -> tuple[set[type | str] | dict[type | str, float], Callable[[qml.operation.Operator], bool]]: - """Resolve the gate set and the stopping condition from arguments. - - The ``gate_set`` can be provided in various forms, and the ``stopping_condition`` may or - may not be provided. This function will resolve the gate set and the stopping condition - to the following standardized form: - - - The ``gate_set`` is set of operator **types** and/or names, or a dictionary mapping operator - types and/or names to their respective costs. This is only used by the DecompositionGraph - - The ``stopping_condition`` is a function that takes an operator **instances** and returns - ``True`` if the operator does not need to be decomposed. This is used during decomposition. - - """ + """Resolve the gate set and the stopping condition from arguments.""" if gate_set is None: gate_set = set(qml.ops.__all__) @@ -660,18 +509,37 @@ def _stopping_condition(op): return gate_set, _stopping_condition -def _construct_and_solve_decomp_graph( - operations, target_gates, fixed_decomps, alt_decomps -) -> DecompGraphSolution: - """Create and solve a DecompositionGraph instance to optimize the decomposition.""" +# pylint: disable=protected-access +def _solve_decomposition_graph(operations, gate_set, fixed_decomps, alt_decomps): + """Get the decomposition graph solution for the given operations and gate set.""" + + # decomp_graph_solution + decomp_graph_solution = {} - # Create the decomposition graph decomp_graph = DecompositionGraph( operations, - target_gates, + gate_set, fixed_decomps=fixed_decomps, alt_decomps=alt_decomps, ) # Find the efficient pathways to the target gate set - return decomp_graph.solve() + solutions = decomp_graph.solve() + + def is_solved_for(op): + return ( + op in solutions._all_op_indices + and solutions._all_op_indices[op] in solutions._visitor.distances + ) + + for ( + op_node, + op_node_idx, + ) in solutions._all_op_indices.items(): + + if is_solved_for(op_node) and op_node_idx in solutions._visitor.predecessors: + d_node_idx = solutions._visitor.predecessors[op_node_idx] + decomp_graph_solution[op_node] = solutions._graph[d_node_idx].rule._impl + + print("[DEBUG PRINT] Decomposition graph solution:", decomp_graph_solution) + return decomp_graph_solution diff --git a/frontend/catalyst/from_plxpr/from_plxpr.py b/frontend/catalyst/from_plxpr/from_plxpr.py index 7f1c19813c..dc9798b125 100644 --- a/frontend/catalyst/from_plxpr/from_plxpr.py +++ b/frontend/catalyst/from_plxpr/from_plxpr.py @@ -33,6 +33,7 @@ from pennylane.capture.primitives import adjoint_transform_prim as plxpr_adjoint_transform_prim from pennylane.capture.primitives import ctrl_transform_prim as plxpr_ctrl_transform_prim from pennylane.capture.primitives import measure_prim as plxpr_measure_prim +from pennylane.decomposition.collect_resource_ops import CollectResourceOps from pennylane.ftqc.primitives import measure_in_basis_prim as plxpr_measure_in_basis_prim from pennylane.ops.functions.map_wires import _map_wires_transform as pl_map_wires from pennylane.transforms import cancel_inverses as pl_cancel_inverses @@ -45,7 +46,10 @@ from catalyst.device import extract_backend_info from catalyst.device.qjit_device import COMPILER_OPERATIONS -from catalyst.from_plxpr.decompose import GraphSolutionInterpreter, PreMlirDecomposeInterpreter +from catalyst.from_plxpr.decompose import ( + GraphSolutionInterpreter, + PreMlirDecomposeInterpreter, +) from catalyst.from_plxpr.qubit_handler import QubitHandler from catalyst.jax_extras import jaxpr_pad_consts, make_jaxpr2, transient_jax_config from catalyst.jax_primitives import ( @@ -279,6 +283,32 @@ def calling_convention(*args): } +# pylint: disable=too-many-arguments +def handle_graph_decomposition(*args, inner_jaxpr, consts, non_const_args, targs, tkwargs): + """Handle the graph decomposition for a given JAXPR.""" + + gate_set = COMPILER_OPERATIONS + decomp_kwargs = {"gate_set": gate_set} + + pmd_interpreter = PreMlirDecomposeInterpreter(*targs, **decomp_kwargs) + + def pmd_wrapper(*args): + return pmd_interpreter.eval(inner_jaxpr, consts, *args) + + pmd_jaxpr = jax.make_jaxpr(pmd_wrapper)(*args) + + ops_collector = CollectResourceOps() + ops_collector.eval(pmd_jaxpr.jaxpr, consts, *args) + pl_ops = ops_collector.state["ops"] + + gds_interpreter = GraphSolutionInterpreter(*targs, **tkwargs, operations=pl_ops) + + def gds_wrapper(*args): + return gds_interpreter.eval(pmd_jaxpr.jaxpr, consts, *args) + + return jax.make_jaxpr(gds_wrapper)(*args) + + # pylint: disable-next=redefined-outer-name def register_transform(pl_transform, pass_name, decomposition): """Register pennylane transforms and their conversion to Catalyst transforms""" @@ -304,12 +334,8 @@ def handle_transform( targs = args[targs_slice] # Check if the transform is a decomposition transform - # If so, we'll set the compiler_decompose flag to trigger - # 1. Construct the graph with the list of ops and the target gateset - # 2. Capture and lower the decomposition qfuncs down to MLIR - # 3. Bypass the custom PLxPR DecomposeInterpreter class # - # Notes: + # Note: # - The list of target gateset is always taken from the transform's attributes # and passed down to the MLIR lowering as a quantum function attribute. if ( @@ -319,34 +345,25 @@ def handle_transform( ): self.compiler_decompose = True - # Use PL's ExpandTransformsInterpreter to expand this and any embedded - # transform according to PL rules. It works by overriding the primitive - # registration, making all embedded transforms follow the PL rules - # from now on, hence ignoring the Catalyst pass conversion - def wrapper(*args): - return ExpandTransformsInterpreter().eval(inner_jaxpr, consts, *args) - if self.compiler_decompose: - gate_set = COMPILER_OPERATIONS - decomp_kwargs = {"gate_set": gate_set} - - pmd_interpreter = PreMlirDecomposeInterpreter(*targs, **decomp_kwargs) - - def pmd_wrapper(*args): - return pmd_interpreter.eval(inner_jaxpr, consts, *args) - - pmd_jaxpr = jax.make_jaxpr(pmd_wrapper)(*args) - - gds_interpreter = GraphSolutionInterpreter(*targs, **tkwargs) - - def gds_wrapper(*args): - return gds_interpreter.eval(pmd_jaxpr.jaxpr, consts, *args) - - gds_jaxpr = jax.make_jaxpr(gds_wrapper)(*args) - - return self.eval(gds_jaxpr.jaxpr, gds_jaxpr.consts, *non_const_args) + final_jaxpr = handle_graph_decomposition( + *args, + inner_jaxpr=inner_jaxpr, + consts=consts, + non_const_args=non_const_args, + targs=targs, + tkwargs=tkwargs, + ) + return self.eval(final_jaxpr.jaxpr, final_jaxpr.consts, *non_const_args) if catalyst_pass_name is None: + # Use PL's ExpandTransformsInterpreter to expand this and any embedded + # transform according to PL rules. It works by overriding the primitive + # registration, making all embedded transforms follow the PL rules + # from now on, hence ignoring the Catalyst pass conversion + def wrapper(*args): + return ExpandTransformsInterpreter().eval(inner_jaxpr, consts, *args) + unravelled_jaxpr = jax.make_jaxpr(wrapper)(*non_const_args) final_jaxpr = pl_plxpr_transform( unravelled_jaxpr.jaxpr, unravelled_jaxpr.consts, targs, tkwargs, *non_const_args diff --git a/frontend/catalyst/jax_extras/lowering.py b/frontend/catalyst/jax_extras/lowering.py index 27afc8ba1a..30c449d5fb 100644 --- a/frontend/catalyst/jax_extras/lowering.py +++ b/frontend/catalyst/jax_extras/lowering.py @@ -176,7 +176,8 @@ def custom_lower_jaxpr_to_module( op.attributes[attr_name] = mlir_attr except CompileError as e: raise CompileError( - f"While converting Python attribute '{attr_name}': '{attr_value}' to MLIR: {e}" + "While converting Python attribute" + f"'{attr_name}': '{attr_value}' to MLIR: {e}" ) from e if isinstance(op, ModuleOp): worklist += [*op.body.operations] diff --git a/frontend/catalyst/jax_primitives.py b/frontend/catalyst/jax_primitives.py index 8f754913e7..4f220e703a 100644 --- a/frontend/catalyst/jax_primitives.py +++ b/frontend/catalyst/jax_primitives.py @@ -395,7 +395,7 @@ def wrapper(*args, **kwargs): return wrapper -def decomposition_rule(func=None, *, is_qreg=False, num_params=0): +def decomposition_rule(func=None, *, is_qreg=True, num_params=0): """ Denotes the creation of a quantum definition in the intermediate representation. """ diff --git a/frontend/test/lit/test_decomposition.py b/frontend/test/lit/test_decomposition.py index 0c61109e9a..b73dd0bd8e 100644 --- a/frontend/test/lit/test_decomposition.py +++ b/frontend/test/lit/test_decomposition.py @@ -273,7 +273,7 @@ def decompose_to_matrix(): def test_decomposition_rule_wire_param(): """Test decomposition rule with passing a parameter that is a wire/integer""" - @decomposition_rule + @decomposition_rule(is_qreg=False) def Hadamard0(wire: WiresLike): qml.Hadamard(wire) @@ -303,7 +303,7 @@ def circuit(_: float): def test_decomposition_rule_gate_param_param(): """Test decomposition rule with passing a regular parameter""" - @decomposition_rule(num_params=1) + @decomposition_rule(is_qreg=False, num_params=1) def RX_on_wire_0(param: TensorLike, w0: WiresLike): qml.RX(param, wires=w0) @@ -336,7 +336,7 @@ def test_multiple_decomposition_rules(): @decomposition_rule def identity(): ... - @decomposition_rule(num_params=1) + @decomposition_rule(is_qreg=True) def all_wires_rx(param: TensorLike, w0: WiresLike, w1: WiresLike, w2: WiresLike): qml.RX(param, wires=w0) qml.RX(param, wires=w1) @@ -409,7 +409,7 @@ def shaped_wires_rule(param: TensorLike, wires: WiresLike): qml.RX(param, wires=wires[1]) qml.RX(param, wires=wires[2]) - @decomposition_rule(num_params=1, is_qreg=False) + @decomposition_rule(is_qreg=False, num_params=1) def expanded_wires_rule(param: TensorLike, w1, w2, w3): shaped_wires_rule(param, [w1, w2, w3]) diff --git a/frontend/test/pytest/from_plxpr/test_from_plxpr_decompose.py b/frontend/test/pytest/from_plxpr/test_from_plxpr_decompose.py index aeca1a1398..4e21eadff4 100644 --- a/frontend/test/pytest/from_plxpr/test_from_plxpr_decompose.py +++ b/frontend/test/pytest/from_plxpr/test_from_plxpr_decompose.py @@ -217,9 +217,13 @@ class CustomOp(qml.operation.Operation): # pylint: disable=too-few-public-metho @property def resource_params(self): + """Dummy resource params.""" + return {} def decomposition(self): + """Decomposition of CustomOp into H-CNOT-H.""" + return [qml.H(self.wires[1]), qml.CNOT(self.wires), qml.H(self.wires[1])] @qml.register_resources({qml.CZ: 1}) diff --git a/runtime/include/RuntimeCAPI.h b/runtime/include/RuntimeCAPI.h index 1333527c9e..fd98162de6 100644 --- a/runtime/include/RuntimeCAPI.h +++ b/runtime/include/RuntimeCAPI.h @@ -62,6 +62,7 @@ void __catalyst__qis__RX(double, QUBIT *, const Modifiers *); void __catalyst__qis__RY(double, QUBIT *, const Modifiers *); void __catalyst__qis__RZ(double, QUBIT *, const Modifiers *); void __catalyst__qis__Rot(double, double, double, QUBIT *, const Modifiers *); +void __catalyst__qis__RotXZX(double, double, double, QUBIT *, const Modifiers *); void __catalyst__qis__CNOT(QUBIT *, QUBIT *, const Modifiers *); void __catalyst__qis__CY(QUBIT *, QUBIT *, const Modifiers *); void __catalyst__qis__CZ(QUBIT *, QUBIT *, const Modifiers *); diff --git a/runtime/lib/capi/RuntimeCAPI.cpp b/runtime/lib/capi/RuntimeCAPI.cpp index 460cef97f1..3f324ccf4a 100644 --- a/runtime/lib/capi/RuntimeCAPI.cpp +++ b/runtime/lib/capi/RuntimeCAPI.cpp @@ -626,6 +626,14 @@ void __catalyst__qis__Rot(double phi, double theta, double omega, QUBIT *qubit, MODIFIERS_ARGS(modifiers)); } +void __catalyst__qis__RotXZX(double phi, double theta, double omega, QUBIT *qubit, + const Modifiers *modifiers) +{ + getQuantumDevicePtr()->NamedOperation("RotXZX", {phi, theta, omega}, + {reinterpret_cast(qubit)}, + MODIFIERS_ARGS(modifiers)); +} + void __catalyst__qis__CNOT(QUBIT *control, QUBIT *target, const Modifiers *modifiers) { RT_FAIL_IF(control == target, From c20baaff43896eb16f82df815d64c209aac0cd98 Mon Sep 17 00:00:00 2001 From: Ali Asadi <10773383+maliasadi@users.noreply.github.com> Date: Fri, 5 Sep 2025 14:07:09 -0400 Subject: [PATCH 03/36] Add py example --- test_new_decomp.py | 66 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) create mode 100644 test_new_decomp.py diff --git a/test_new_decomp.py b/test_new_decomp.py new file mode 100644 index 0000000000..df21b1b9c7 --- /dev/null +++ b/test_new_decomp.py @@ -0,0 +1,66 @@ +from functools import partial + +import jax +import numpy as np +import pennylane as qml +from pennylane.ftqc import RotXZX +from pennylane.typing import TensorLike +from pennylane.wires import WiresLike + +from catalyst import qjit +from catalyst.from_plxpr import from_plxpr +from catalyst.jax_primitives import decomposition_rule + +qml.capture.enable() + +qml.decomposition.enable_graph() + + +@decomposition_rule +def _ry_to_rz_rx(phi, wires: WiresLike, **__): + qml.RZ(-np.pi / 2, wires=wires) + qml.RX(phi, wires=wires) + qml.RZ(np.pi / 2, wires=wires) + + +@decomposition_rule +def _rot_to_rz_ry_rz(phi, theta, omega, wires: WiresLike, **__): + qml.RZ(phi, wires=wires) + qml.RY(theta, wires=wires) + qml.RZ(omega, wires=wires) + + +@decomposition_rule +def _u2_phaseshift_rot(phi, delta, wires, **__): + pi_half = qml.math.ones_like(delta) * (np.pi / 2) + qml.Rot(delta, pi_half, -delta, wires=wires) + qml.PhaseShift(delta, wires=wires) + qml.PhaseShift(phi, wires=wires) + + +@decomposition_rule +def _xzx_decompose(phi, theta, omega, wires, **__): + qml.RX(phi, wires=wires) + qml.RZ(theta, wires=wires) + qml.RX(omega, wires=wires) + + +@qml.qjit() +@partial(qml.transforms.decompose, gate_set={"RX", "RZ", "PhaseShift"}) +@qml.qnode(qml.device("lightning.qubit", wires=3)) +def circuit(): + + qml.RY(0.5, wires=0) + qml.Rot(0.1, 0.2, 0.3, wires=1) + qml.U2(0.4, 0.5, wires=2) + RotXZX(0.6, 0.7, 0.8, wires=0) + + _ry_to_rz_rx(0, 0) + _rot_to_rz_ry_rz(0, 0, 0, 1) + _u2_phaseshift_rot(0, 0, 2) + _xzx_decompose(0, 0, 0, 0) + + return qml.expval(qml.Z(0)) + + +print(circuit.mlir) From b6f41559f0b08b5e7a4c71f520df9986d1cde310 Mon Sep 17 00:00:00 2001 From: Ali Asadi <10773383+maliasadi@users.noreply.github.com> Date: Tue, 9 Sep 2025 08:57:13 -0400 Subject: [PATCH 04/36] cherry-picking decomp_gateset commits --- frontend/catalyst/from_plxpr/decompose.py | 1 - frontend/catalyst/from_plxpr/from_plxpr.py | 45 ++++++++++----- frontend/catalyst/jax_extras/lowering.py | 18 +----- frontend/catalyst/jax_primitives_utils.py | 5 ++ frontend/catalyst/jax_tracer.py | 20 +++---- frontend/catalyst/jit.py | 22 +------ frontend/test/lit/test_decomposition.py | 67 ++++++++++++++++++++++ frontend/test/lit/test_from_plxpr.py | 3 +- test_new_decomp.py | 45 ++++++++++++++- 9 files changed, 161 insertions(+), 65 deletions(-) diff --git a/frontend/catalyst/from_plxpr/decompose.py b/frontend/catalyst/from_plxpr/decompose.py index 32780a5918..8400db453d 100644 --- a/frontend/catalyst/from_plxpr/decompose.py +++ b/frontend/catalyst/from_plxpr/decompose.py @@ -541,5 +541,4 @@ def is_solved_for(op): d_node_idx = solutions._visitor.predecessors[op_node_idx] decomp_graph_solution[op_node] = solutions._graph[d_node_idx].rule._impl - print("[DEBUG PRINT] Decomposition graph solution:", decomp_graph_solution) return decomp_graph_solution diff --git a/frontend/catalyst/from_plxpr/from_plxpr.py b/frontend/catalyst/from_plxpr/from_plxpr.py index 80f633230a..3e64260b16 100644 --- a/frontend/catalyst/from_plxpr/from_plxpr.py +++ b/frontend/catalyst/from_plxpr/from_plxpr.py @@ -35,6 +35,7 @@ from pennylane.capture.primitives import measure_prim as plxpr_measure_prim from pennylane.decomposition.collect_resource_ops import CollectResourceOps from pennylane.ftqc.primitives import measure_in_basis_prim as plxpr_measure_in_basis_prim +from pennylane.ops import Adjoint, Controlled, ControlledOp from pennylane.ops.functions.map_wires import _map_wires_transform as pl_map_wires from pennylane.transforms import cancel_inverses as pl_cancel_inverses from pennylane.transforms import commute_controlled as pl_commute_controlled @@ -187,7 +188,7 @@ class WorkflowInterpreter(PlxprInterpreter): def __init__(self): self._pass_pipeline = [] self.qubit_handler = None - self.compiler_decompose = False + self.decomp_gateset = [] super().__init__() @@ -259,6 +260,9 @@ def calling_convention(*args): device_release_p.bind() return retvals + # Add gate_set attribute to the quantum kernel primitive + setattr(qnode, "decomp_gateset", self.decomp_gateset) + return quantum_kernel_p.bind( wrap_init(calling_convention, debug_info=qfunc_jaxpr.debug_info), *non_const_args, @@ -333,19 +337,34 @@ def handle_transform( non_const_args = args[args_slice] targs = args[targs_slice] - # Check if the transform is a decomposition transform - # - # Note: - # - The list of target gateset is always taken from the transform's attributes - # and passed down to the MLIR lowering as a quantum function attribute. + # If the transform is a decomposition transform + # and the graph-based decomposition is enabled if ( hasattr(pl_plxpr_transform, "__name__") and pl_plxpr_transform.__name__ == "decompose_plxpr_to_plxpr" and qml.decomposition.enabled_graph() ): - self.compiler_decompose = True + # Update the decomp_gateset to be used by the quantum kernel primitive + self.decomp_gateset = tkwargs.get("gate_set", []) + + # A helper function to get the name of a pennylane operator + def get_operator_name(op): + """Get the name of a pennylane operator, handling wrapped operators. + + Note: Controlled and Adjoint ops aren't supported in `gate_set` + by PennyLane's DecompositionGraph; unit tests were added in PennyLane. + """ + if isinstance(op, str): + return op + + return getattr(op._primitive, "name", "UnsupportedGate") - if self.compiler_decompose: + self.decomp_gateset = [get_operator_name(op) for op in self.decomp_gateset] + + # First decompose to the compiler gateset. + # Then, construct and solve the graph-based decomposition + # to get the optimized rules and lower them to PLxPR + # to Catalyst JAXPR to MLIR. final_jaxpr = handle_graph_decomposition( *args, inner_jaxpr=inner_jaxpr, @@ -375,10 +394,10 @@ def wrapper(*args): ) return self.eval(final_jaxpr.jaxpr, final_jaxpr.consts, *non_const_args) - else: - # Apply the corresponding Catalyst pass counterpart - self._pass_pipeline.insert(0, Pass(catalyst_pass_name)) - return self.eval(inner_jaxpr, consts, *non_const_args) + + # Apply the corresponding Catalyst pass counterpart + self._pass_pipeline.append(Pass(catalyst_pass_name)) + return self.eval(inner_jaxpr, consts, *non_const_args) # This is our registration factory for PL transforms. The loop below iterates @@ -404,7 +423,7 @@ def __init__(self, device, shots, qubit_handler, cache, *, control_wires=(), con # TODO: we assume the qreg value passed into a scope is the unique qreg in the scope # In other words, we assume no new qreg will be allocated in the scope self.qubit_handler = qubit_handler - self.compiler_decompose = False + self.decomp_gateset = [] self.subroutine_cache = cache self.control_wires = control_wires """Any control wires used for a subroutine.""" diff --git a/frontend/catalyst/jax_extras/lowering.py b/frontend/catalyst/jax_extras/lowering.py index 30c449d5fb..7dc1382593 100644 --- a/frontend/catalyst/jax_extras/lowering.py +++ b/frontend/catalyst/jax_extras/lowering.py @@ -53,7 +53,7 @@ @debug_logger -def jaxpr_to_mlir(func_name, jaxpr, py_attrs=None): +def jaxpr_to_mlir(func_name, jaxpr): """Lower a Jaxpr into an MLIR module. Args: @@ -81,7 +81,6 @@ def jaxpr_to_mlir(func_name, jaxpr, py_attrs=None): platform="cpu", axis_context=axis_context, name_stack=name_stack, - py_attrs=py_attrs, ) return module, context @@ -100,7 +99,6 @@ def custom_lower_jaxpr_to_module( replicated_args=None, arg_shardings=None, result_shardings=None, - py_attrs=None, ): """Lowers a top-level jaxpr to an MHLO module. @@ -111,10 +109,6 @@ def custom_lower_jaxpr_to_module( https://github.com/google/jax/blob/c4d590b1b640cc9fcfdbe91bf3fe34c47bcde917/jax/interpreters/mlir.py#L625version released under the Apache License, Version 2.0, with the following copyright notice: - Note: We further modified this function to accept `py_attrs`, which allows for the passing - of custom attributes from Python to MLIR. This is currently used for passing - the target gate set information. - Copyright 2021 The JAX Authors. """ @@ -169,16 +163,6 @@ def custom_lower_jaxpr_to_module( continue if isinstance(op, FuncOp): op.attributes["llvm.linkage"] = ir.Attribute.parse("#llvm.linkage") - if py_attrs: # pass custom attributes from Python to MLIR - for attr_name, attr_value in py_attrs.items(): - try: - mlir_attr = get_mlir_attribute_from_pyval(list(attr_value)) - op.attributes[attr_name] = mlir_attr - except CompileError as e: - raise CompileError( - "While converting Python attribute" - f"'{attr_name}': '{attr_value}' to MLIR: {e}" - ) from e if isinstance(op, ModuleOp): worklist += [*op.body.operations] diff --git a/frontend/catalyst/jax_primitives_utils.py b/frontend/catalyst/jax_primitives_utils.py index 73d00cb9ef..0f3c459ec6 100644 --- a/frontend/catalyst/jax_primitives_utils.py +++ b/frontend/catalyst/jax_primitives_utils.py @@ -26,6 +26,8 @@ from mlir_quantum.dialects._transform_ops_gen import ApplyRegisteredPassOp, NamedSequenceOp, YieldOp from mlir_quantum.dialects.catalyst import LaunchKernelOp +from catalyst.jax_extras.lowering import get_mlir_attribute_from_pyval + def get_call_jaxpr(jaxpr): """Extracts the `call_jaxpr` from a JAXPR if it exists.""" "" @@ -135,6 +137,9 @@ def only_single_expval(): func_op.attributes["diff_method"] = ir.StringAttr.get(diff_method) + if gateset := getattr(callable_, "decomp_gateset", []): + func_op.attributes["decomp_gateset"] = get_mlir_attribute_from_pyval(gateset) + return func_op diff --git a/frontend/catalyst/jax_tracer.py b/frontend/catalyst/jax_tracer.py index ca0f610923..5d83b8f3a8 100644 --- a/frontend/catalyst/jax_tracer.py +++ b/frontend/catalyst/jax_tracer.py @@ -617,7 +617,7 @@ def trace_to_jaxpr(func, static_argnums, abstracted_axes, args, kwargs, debug_in @debug_logger -def lower_jaxpr_to_mlir(jaxpr, func_name, py_attrs=None): +def lower_jaxpr_to_mlir(jaxpr, func_name): """Lower a JAXPR to MLIR. Args: @@ -632,7 +632,7 @@ def lower_jaxpr_to_mlir(jaxpr, func_name, py_attrs=None): MemrefCallable.clearcache() with transient_jax_config({"jax_dynamic_shapes": True}): - mlir_module, ctx = jaxpr_to_mlir(func_name, jaxpr, py_attrs=py_attrs) + mlir_module, ctx = jaxpr_to_mlir(func_name, jaxpr) return mlir_module, ctx @@ -1205,12 +1205,7 @@ def apply_transforms( # TODO: Ideally we should allow qnode transforms that don't modify the measurements to # operate in the permissive tracing mode, but that currently leads to a small number of # test failures due to the different result format produced in trace_quantum_function. - only_with_dynamic_one_shot = all( - "dynamic_one_shot_partial" in str(getattr(qnode, "transform", "")) - for qnode in qnode_program - ) - - if has_classical_outputs(flat_results) and not only_with_dynamic_one_shot: + if has_classical_outputs(flat_results): msg = ( "Transforming MeasurementProcesses is unsupported with non-MeasurementProcess " "QNode outputs. The selected device, options, or applied QNode transforms, may be " @@ -1485,9 +1480,12 @@ def check_full_raise(arr, func): meas_results = tree_unflatten(meas_trees, meas_tracers) # TODO: Allow the user to return whatever types they specify. - if tracing_mode == TracingMode.TRANSFORM and isinstance(meas_results, list): - result = meas_results[0] if len(meas_results) == 1 else tuple(meas_results) - transformed_results.append(result) + if tracing_mode == TracingMode.TRANSFORM: + assert isinstance(meas_results, list) + if len(meas_results) == 1: + transformed_results.append(meas_results[0]) + else: + transformed_results.append(tuple(meas_results)) else: transformed_results.append(meas_results) diff --git a/frontend/catalyst/jit.py b/frontend/catalyst/jit.py index 1621e4534b..d7bdc00375 100644 --- a/frontend/catalyst/jit.py +++ b/frontend/catalyst/jit.py @@ -546,26 +546,6 @@ def __init__(self, fn, compile_options): self.user_sig = get_type_annotations(fn) self._validate_configuration() - # Extract transform python kwargs from the function - # with both capture enabled and disabled - - # Note: as we are currently interested in decompose - # target gateset, we avoid passing any non-decompose kwargs - transform_lists = fn._transform_program if hasattr(fn, "_transform_program") else [] - decompose_transform_kwargs = [ - t.kwargs - for t in transform_lists - if hasattr(t, "plxpr_transform") - and hasattr(t.plxpr_transform, "__name__") - and "decompose" in t.plxpr_transform.__name__ - ] - - # TODO: Remove this in the future after enabling multiple decomposition support - # in the MLIR rewrite pass. - if len(decompose_transform_kwargs) > 1: - raise ValueError("Multiple decompose transform is not yet supported.") - self.py_attrs = decompose_transform_kwargs[0] if decompose_transform_kwargs else None - # If static_argnames are present, convert them to static_argnums if compile_options.static_argnames is not None: compile_options.static_argnums = merge_static_argname_into_argnum( @@ -793,7 +773,7 @@ def generate_ir(self): Tuple[ir.Module, str]: the in-memory MLIR module and its string representation """ - mlir_module, ctx = lower_jaxpr_to_mlir(self.jaxpr, self.__name__, py_attrs=self.py_attrs) + mlir_module, ctx = lower_jaxpr_to_mlir(self.jaxpr, self.__name__) # Inject Runtime Library-specific functions (e.g. setup/teardown). inject_functions(mlir_module, ctx, self.compile_options.seed) diff --git a/frontend/test/lit/test_decomposition.py b/frontend/test/lit/test_decomposition.py index b73dd0bd8e..1b26fdb15b 100644 --- a/frontend/test/lit/test_decomposition.py +++ b/frontend/test/lit/test_decomposition.py @@ -3,6 +3,7 @@ import pathlib import platform from copy import deepcopy +from functools import partial import jax import pennylane as qml @@ -508,3 +509,69 @@ def circuit_7(): test_decomposition_rule_caller() + + +def test_decompose_gateset_without_graph(): + """Test the decompose transform to a target gate set without the graph decomposition.""" + + qml.capture.enable() + + @qml.qjit(target="mlir") + @partial(qml.transforms.decompose, gate_set={"RX", "RZ"}) + @qml.qnode(qml.device("lightning.qubit", wires=1)) + # CHECK: public @circuit_8() -> tensor attributes {diff_method = "adjoint", llvm.linkage = #llvm.linkage, qnode} + def circuit_8(): + return qml.expval(qml.Z(0)) + + print(circuit_8.mlir) + + qml.capture.disable() + + +test_decompose_gateset_without_graph() + + +def test_decompose_gateset_with_graph(): + """Test the decompose transform to a target gate set with the graph decomposition.""" + + qml.capture.enable() + qml.decomposition.enable_graph() + + @qml.qjit(target="mlir") + @partial(qml.transforms.decompose, gate_set={"RX", "RZ"}) + @qml.qnode(qml.device("lightning.qubit", wires=1)) + # CHECK: public @circuit_9() -> tensor attributes {decomp_gateset + def circuit_9(): + return qml.expval(qml.Z(0)) + + print(circuit_9.mlir) + + qml.decomposition.disable_graph() + qml.capture.disable() + + +test_decompose_gateset_with_graph() + + +def test_decompose_gateset_operator_with_graph(): + """Test the decompose transform to a target gate set with the graph decomposition.""" + + qml.capture.enable() + qml.decomposition.enable_graph() + + @qml.qjit(target="mlir") + @partial( + qml.transforms.decompose, gate_set={qml.RX, qml.RZ, "PauliZ", qml.PauliX, qml.Hadamard} + ) + @qml.qnode(qml.device("lightning.qubit", wires=1)) + # CHECK: public @circuit_10() -> tensor attributes {decomp_gateset + def circuit_10(): + return qml.expval(qml.Z(0)) + + print(circuit_10.mlir) + + qml.decomposition.disable_graph() + qml.capture.disable() + + +test_decompose_gateset_operator_with_graph() diff --git a/frontend/test/lit/test_from_plxpr.py b/frontend/test/lit/test_from_plxpr.py index 7fd8618cc6..3564e86146 100644 --- a/frontend/test/lit/test_from_plxpr.py +++ b/frontend/test/lit/test_from_plxpr.py @@ -347,9 +347,10 @@ def test_pass_application(): qml.capture.enable() + # TODO: is there an ordering issue here? @qml.qjit(target="mlir") - @qml.transforms.cancel_inverses @qml.transforms.merge_rotations + @qml.transforms.cancel_inverses @qml.qnode(dev) def circuit(): return qml.probs() diff --git a/test_new_decomp.py b/test_new_decomp.py index df21b1b9c7..5505e54f24 100644 --- a/test_new_decomp.py +++ b/test_new_decomp.py @@ -12,10 +12,14 @@ from catalyst.jax_primitives import decomposition_rule qml.capture.enable() - qml.decomposition.enable_graph() +###################################### +# Custom decomposition rules +###################################### + + @decomposition_rule def _ry_to_rz_rx(phi, wires: WiresLike, **__): qml.RZ(-np.pi / 2, wires=wires) @@ -64,3 +68,42 @@ def circuit(): print(circuit.mlir) + + +################################################### +# MBQC Example with custom decomposition to RotXZX +################################################### + +qml.decomposition.enable_graph() +qml.capture.enable() + + +@qml.register_resources({qml.ftqc.RotXZX: 1}) +@decomposition_rule +def _rot_to_xzx(phi, theta, omega, wires, **__): + mat = qml.Rot.compute_matrix(phi, theta, omega) + lam, theta, phi = qml.math.decomposition.xzx_rotation_angles(mat) + qml.ftqc.RotXZX(lam, theta, phi, wires) + + +@qml.qjit() +@partial( + qml.transforms.decompose, + gate_set={"X", "Y", "Z", "S", "H", "CNOT", "RZ", "RotXZX", "GlobalPhase"}, + fixed_decomps={qml.Rot: _rot_to_xzx}, +) +@qml.qnode(qml.device("null.qubit", wires=3)) +def mbqc_circ(x: float, y: float): + qml.RX(x, 0) + qml.RY(y, 1) + + _rot_to_xzx( + float, float, float, int + ) # this needs to be here to include the custom decomposition in the graph + _ry_to_rz_rx(float, int) + _xzx_decompose(float, float, float, int) + + return qml.expval(qml.Z(0)), qml.expval(qml.Z(1)) + + +print(mbqc_circ.mlir) From 891dc5a0f49c09a5ab49672b5d7f95421d7beb1d Mon Sep 17 00:00:00 2001 From: Ali Asadi <10773383+maliasadi@users.noreply.github.com> Date: Tue, 9 Sep 2025 10:37:33 -0400 Subject: [PATCH 05/36] Make the visibility of decomp rules to public --- frontend/catalyst/from_plxpr/from_plxpr.py | 1 + frontend/catalyst/jax_primitives.py | 5 +- frontend/catalyst/jax_primitives_utils.py | 60 ++++++++++++++++++---- frontend/test/lit/test_decomposition.py | 26 +++++----- test_new_decomp.py | 26 +++++----- 5 files changed, 82 insertions(+), 36 deletions(-) diff --git a/frontend/catalyst/from_plxpr/from_plxpr.py b/frontend/catalyst/from_plxpr/from_plxpr.py index 3e64260b16..9f7deb650c 100644 --- a/frontend/catalyst/from_plxpr/from_plxpr.py +++ b/frontend/catalyst/from_plxpr/from_plxpr.py @@ -365,6 +365,7 @@ def get_operator_name(op): # Then, construct and solve the graph-based decomposition # to get the optimized rules and lower them to PLxPR # to Catalyst JAXPR to MLIR. + final_jaxpr = handle_graph_decomposition( *args, inner_jaxpr=inner_jaxpr, diff --git a/frontend/catalyst/jax_primitives.py b/frontend/catalyst/jax_primitives.py index 4f220e703a..fec8e41684 100644 --- a/frontend/catalyst/jax_primitives.py +++ b/frontend/catalyst/jax_primitives.py @@ -590,7 +590,10 @@ def _decomposition_rule_lowering(ctx, *, pyfun, func_jaxpr, **_): """Lower a quantum decomposition rule into MLIR in a single step process. The step is the compilation of the definition of the function fn. """ - lower_callable(ctx, pyfun, func_jaxpr) + + # Set the visibility of the decomposition rule to public + # to avoid the elimination by the compiler + lower_callable(ctx, pyfun, func_jaxpr, public=True) return () diff --git a/frontend/catalyst/jax_primitives_utils.py b/frontend/catalyst/jax_primitives_utils.py index 0f3c459ec6..26c27af4e6 100644 --- a/frontend/catalyst/jax_primitives_utils.py +++ b/frontend/catalyst/jax_primitives_utils.py @@ -46,7 +46,16 @@ def get_call_equation(jaxpr): def lower_jaxpr(ctx, jaxpr, context=None): - """Lowers a call primitive jaxpr, may be either func_p or quantum_kernel_p""" + """Lowers a call primitive jaxpr, may be either func_p or quantum_kernel_p + + Args: + ctx: LoweringRuleContext + jaxpr: JAXPR to be lowered + context: additional context to distinguish different FuncOps + + Returns: + FuncOp + """ equation = get_call_equation(jaxpr) call_jaxpr = equation.params["call_jaxpr"] callable_ = equation.params.get("fn") @@ -56,7 +65,8 @@ def lower_jaxpr(ctx, jaxpr, context=None): return lower_callable(ctx, callable_, call_jaxpr, pipeline=pipeline, context=context) -def lower_callable(ctx, callable_, call_jaxpr, pipeline=None, context=None): +# pylint: disable=too-many-arguments +def lower_callable(ctx, callable_, call_jaxpr, pipeline=None, context=None, public=False): """Lowers _callable to MLIR. If callable_ is a qnode, then we will first create a module, then @@ -68,6 +78,8 @@ def lower_callable(ctx, callable_, call_jaxpr, pipeline=None, context=None): ctx: LoweringRuleContext callable_: python function call_jaxpr: jaxpr representing callable_ + public: whether the visibility should be marked public + Returns: FuncOp """ @@ -75,25 +87,49 @@ def lower_callable(ctx, callable_, call_jaxpr, pipeline=None, context=None): pipeline = tuple() if not isinstance(callable_, qml.QNode): - return get_or_create_funcop(ctx, callable_, call_jaxpr, pipeline, context=context) + return get_or_create_funcop( + ctx, callable_, call_jaxpr, pipeline, context=context, public=public + ) return get_or_create_qnode_funcop(ctx, callable_, call_jaxpr, pipeline, context=context) -def get_or_create_funcop(ctx, callable_, call_jaxpr, pipeline, context=None): - """Get funcOp from cache, or create it from scratch""" +# pylint: disable=too-many-arguments +def get_or_create_funcop(ctx, callable_, call_jaxpr, pipeline, context=None, public=False): + """Get funcOp from cache, or create it from scratch + + Args: + ctx: LoweringRuleContext + callable_: python function + call_jaxpr: jaxpr representing callable_ + context: additional context to distinguish different FuncOps + public: whether the visibility should be marked public + + Returns: + FuncOp + """ if context is None: context = tuple() key = (callable_, *context, *pipeline) if func_op := get_cached(ctx, key): return func_op - func_op = lower_callable_to_funcop(ctx, callable_, call_jaxpr) + func_op = lower_callable_to_funcop(ctx, callable_, call_jaxpr, public=public) cache(ctx, key, func_op) return func_op -def lower_callable_to_funcop(ctx, callable_, call_jaxpr): - """Lower callable to either a FuncOp""" +def lower_callable_to_funcop(ctx, callable_, call_jaxpr, public=False): + """Lower callable to either a FuncOp + + Args: + ctx: LoweringRuleContext + callable_: python function + call_jaxpr: jaxpr representing callable_ + public: whether the visibility should be marked public + + Returns: + FuncOp + """ if isinstance(call_jaxpr, core.Jaxpr): call_jaxpr = core.ClosedJaxpr(call_jaxpr, ()) @@ -107,6 +143,11 @@ def lower_callable_to_funcop(ctx, callable_, call_jaxpr): kwargs["jaxpr"] = call_jaxpr kwargs["effects"] = [] kwargs["name_stack"] = ctx.name_stack + + # Make the visibility of the function public=True + # to avoid elimination by the compiler + kwargs["public"] = public + func_op = mlir.lower_jaxpr_to_fun(**kwargs) if isinstance(callable_, qml.QNode): @@ -137,7 +178,8 @@ def only_single_expval(): func_op.attributes["diff_method"] = ir.StringAttr.get(diff_method) - if gateset := getattr(callable_, "decomp_gateset", []): + gateset = getattr(callable_, "decomp_gateset", []) + if gateset: func_op.attributes["decomp_gateset"] = get_mlir_attribute_from_pyval(gateset) return func_op diff --git a/frontend/test/lit/test_decomposition.py b/frontend/test/lit/test_decomposition.py index 1b26fdb15b..0d26012b7a 100644 --- a/frontend/test/lit/test_decomposition.py +++ b/frontend/test/lit/test_decomposition.py @@ -289,7 +289,7 @@ def circuit(_: float): Hadamard0(int) return qml.probs() - # CHECK: func.func private @Hadamard0([[QBIT:%.+]]: !quantum.bit) -> !quantum.bit + # CHECK: func.func public @Hadamard0([[QBIT:%.+]]: !quantum.bit) -> !quantum.bit # CHECK-NEXT: [[QUBIT_OUT:%.+]] = quantum.custom "Hadamard"() [[QBIT]] : !quantum.bit # CHECK-NEXT: return [[QUBIT_OUT]] : !quantum.bit @@ -317,7 +317,7 @@ def circuit_2(_: float): RX_on_wire_0(float, int) return qml.probs() - # CHECK: func.func private @RX_on_wire_0([[PARAM_TENSOR:%.+]]: tensor, [[QUBIT:%.+]]: !quantum.bit) -> !quantum.bit + # CHECK: func.func public @RX_on_wire_0([[PARAM_TENSOR:%.+]]: tensor, [[QUBIT:%.+]]: !quantum.bit) -> !quantum.bit # CHECK-NEXT: [[PARAM:%.+]] = tensor.extract [[PARAM_TENSOR]][] : tensor # CHECK-NEXT: [[QUBIT_1:%.+]] = quantum.custom "RX"([[PARAM]]) [[QUBIT]] : !quantum.bit # CHECK-NEXT: return [[QUBIT_1]] : !quantum.bit @@ -356,8 +356,8 @@ def circuit_3(_: float): qml.Hadamard(0) return qml.probs() - # CHECK: func.func private @identity - # CHECK: func.func private @all_wires_rx + # CHECK: func.func public @identity + # CHECK: func.func public @all_wires_rx print(circuit_3.mlir) qml.capture.disable() @@ -385,7 +385,7 @@ def circuit_4(_: float): qml.Hadamard(0) return qml.probs() - # CHECK: func.func private @shaped_wires_rule([[QREG:%.+]]: !quantum.reg, [[PARAM_TENSOR:%.+]]: tensor, [[QUBITS:%.+]]: tensor<3xi64>) -> !quantum.reg + # CHECK: func.func public @shaped_wires_rule([[QREG:%.+]]: !quantum.reg, [[PARAM_TENSOR:%.+]]: tensor, [[QUBITS:%.+]]: tensor<3xi64>) -> !quantum.reg # CHECK-NEXT: [[IDX_0:%.+]] = stablehlo.slice [[QUBITS]] [0:1] : (tensor<3xi64>) -> tensor<1xi64> # CHECK-NEXT: [[RIDX_0:%.+]] = stablehlo.reshape [[IDX_0]] : (tensor<1xi64>) -> tensor # CHECK-NEXT: [[EXTRACTED:%.+]] = tensor.extract [[RIDX_0]][] : tensor @@ -422,7 +422,7 @@ def circuit_5(_: float): qml.Hadamard(0) return qml.probs() - # CHECK: func.func private @expanded_wires_rule(%arg0: tensor, %arg1: !quantum.bit, %arg2: !quantum.bit, %arg3: !quantum.bit) -> (!quantum.bit, !quantum.bit, !quantum.bit) + # CHECK: func.func public @expanded_wires_rule(%arg0: tensor, %arg1: !quantum.bit, %arg2: !quantum.bit, %arg3: !quantum.bit) -> (!quantum.bit, !quantum.bit, !quantum.bit) print(circuit_5.mlir) qml.capture.disable() @@ -453,7 +453,7 @@ def circuit_6(): cond_RX(float, jax.core.ShapedArray((1,), int)) return qml.probs() - # CHECK: func.func private @cond_RX([[QREG:%.+]]: !quantum.reg, [[PARAM_TENSOR:%.+]]: tensor, [[QUBITS:%.+]]: tensor<1xi64>) -> !quantum.reg + # CHECK: func.func public @cond_RX([[QREG:%.+]]: !quantum.reg, [[PARAM_TENSOR:%.+]]: tensor, [[QUBITS:%.+]]: tensor<1xi64>) -> !quantum.reg # CHECK-NEXT: [[ZERO:%.+]] = stablehlo.constant dense<0.000000e+00> : tensor # CHECK-NEXT: [[COND_TENSOR:%.+]] = stablehlo.compare NE, [[PARAM_TENSOR]], [[ZERO]], FLOAT : (tensor, tensor) -> tensor # CHECK-NEXT: [[COND:%.+]] = tensor.extract [[COND_TENSOR]][] : tensor @@ -480,17 +480,17 @@ def test_decomposition_rule_caller(): qml.capture.enable() @decomposition_rule(is_qreg=True) - def Op1_decomp(_: TensorLike, wires: WiresLike): + def Rule_Op1_decomp(_: TensorLike, wires: WiresLike): qml.Hadamard(wires=wires[0]) qml.Hadamard(wires=[1]) @decomposition_rule(is_qreg=True) - def Op2_decomp(param: TensorLike, wires: WiresLike): + def Rule_Op2_decomp(param: TensorLike, wires: WiresLike): qml.RX(param, wires=wires[0]) def decomps_caller(param: TensorLike, wires: WiresLike): - Op1_decomp(param, wires) - Op2_decomp(param, wires) + Rule_Op1_decomp(param, wires) + Rule_Op2_decomp(param, wires) @qml.qjit(autograph=False) @qml.qnode(qml.device("lightning.qubit", wires=1)) @@ -501,8 +501,8 @@ def circuit_7(): decomps_caller(float, jax.core.ShapedArray((2,), int)) return qml.probs() - # CHECK: func.func private @Op1_decomp(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<2xi64>) -> !quantum.reg - # CHECK: func.func private @Op2_decomp(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<2xi64>) -> !quantum.reg + # CHECK: func.func public @Rule_Op1_decomp(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<2xi64>) -> !quantum.reg + # CHECK: func.func public @Rule_Op2_decomp(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<2xi64>) -> !quantum.reg print(circuit_7.mlir) qml.capture.disable() diff --git a/test_new_decomp.py b/test_new_decomp.py index 5505e54f24..1d1ea58140 100644 --- a/test_new_decomp.py +++ b/test_new_decomp.py @@ -21,21 +21,21 @@ @decomposition_rule -def _ry_to_rz_rx(phi, wires: WiresLike, **__): +def Rule_ry_to_rz_rx(phi, wires: WiresLike, **__): qml.RZ(-np.pi / 2, wires=wires) qml.RX(phi, wires=wires) qml.RZ(np.pi / 2, wires=wires) @decomposition_rule -def _rot_to_rz_ry_rz(phi, theta, omega, wires: WiresLike, **__): +def Rule_rot_to_rz_ry_rz(phi, theta, omega, wires: WiresLike, **__): qml.RZ(phi, wires=wires) qml.RY(theta, wires=wires) qml.RZ(omega, wires=wires) @decomposition_rule -def _u2_phaseshift_rot(phi, delta, wires, **__): +def Rule_u2_phaseshift_rot(phi, delta, wires, **__): pi_half = qml.math.ones_like(delta) * (np.pi / 2) qml.Rot(delta, pi_half, -delta, wires=wires) qml.PhaseShift(delta, wires=wires) @@ -43,7 +43,7 @@ def _u2_phaseshift_rot(phi, delta, wires, **__): @decomposition_rule -def _xzx_decompose(phi, theta, omega, wires, **__): +def Rule_xzx_decompose(phi, theta, omega, wires, **__): qml.RX(phi, wires=wires) qml.RZ(theta, wires=wires) qml.RX(omega, wires=wires) @@ -59,10 +59,10 @@ def circuit(): qml.U2(0.4, 0.5, wires=2) RotXZX(0.6, 0.7, 0.8, wires=0) - _ry_to_rz_rx(0, 0) - _rot_to_rz_ry_rz(0, 0, 0, 1) - _u2_phaseshift_rot(0, 0, 2) - _xzx_decompose(0, 0, 0, 0) + Rule_ry_to_rz_rx(0, 0) + Rule_rot_to_rz_ry_rz(0, 0, 0, 1) + Rule_u2_phaseshift_rot(0, 0, 2) + Rule_xzx_decompose(0, 0, 0, 0) return qml.expval(qml.Z(0)) @@ -80,7 +80,7 @@ def circuit(): @qml.register_resources({qml.ftqc.RotXZX: 1}) @decomposition_rule -def _rot_to_xzx(phi, theta, omega, wires, **__): +def Rule_rot_to_xzx(phi, theta, omega, wires, **__): mat = qml.Rot.compute_matrix(phi, theta, omega) lam, theta, phi = qml.math.decomposition.xzx_rotation_angles(mat) qml.ftqc.RotXZX(lam, theta, phi, wires) @@ -90,18 +90,18 @@ def _rot_to_xzx(phi, theta, omega, wires, **__): @partial( qml.transforms.decompose, gate_set={"X", "Y", "Z", "S", "H", "CNOT", "RZ", "RotXZX", "GlobalPhase"}, - fixed_decomps={qml.Rot: _rot_to_xzx}, + fixed_decomps={qml.Rot: Rule_rot_to_xzx}, ) @qml.qnode(qml.device("null.qubit", wires=3)) def mbqc_circ(x: float, y: float): qml.RX(x, 0) qml.RY(y, 1) - _rot_to_xzx( + Rule_rot_to_xzx( float, float, float, int ) # this needs to be here to include the custom decomposition in the graph - _ry_to_rz_rx(float, int) - _xzx_decompose(float, float, float, int) + Rule_ry_to_rz_rx(float, int) + Rule_xzx_decompose(float, float, float, int) return qml.expval(qml.Z(0)), qml.expval(qml.Z(1)) From 599baf0746af2dc9f8c6628762f138978f9adca4 Mon Sep 17 00:00:00 2001 From: Ali Asadi <10773383+maliasadi@users.noreply.github.com> Date: Tue, 9 Sep 2025 11:37:55 -0400 Subject: [PATCH 06/36] Fix a couple of more issues --- frontend/catalyst/device/qjit_device.py | 12 ++-- frontend/catalyst/from_plxpr/from_plxpr.py | 54 +++-------------- frontend/catalyst/jax_primitives_utils.py | 4 +- frontend/catalyst/jax_tracer.py | 16 ++--- frontend/test/lit/test_decomposition.py | 60 +++++++++++++++++++ frontend/test/lit/test_from_plxpr.py | 2 +- .../from_plxpr/test_capture_integration.py | 4 +- test_new_decomp.py | 17 +++--- 8 files changed, 98 insertions(+), 71 deletions(-) diff --git a/frontend/catalyst/device/qjit_device.py b/frontend/catalyst/device/qjit_device.py index b9cac2d619..d3787f46eb 100644 --- a/frontend/catalyst/device/qjit_device.py +++ b/frontend/catalyst/device/qjit_device.py @@ -83,7 +83,6 @@ "PSWAP", "QubitUnitary", "Rot", - "RotXZX", "RX", "RY", "RZ", @@ -109,11 +108,14 @@ RUNTIME_MPS = ["ExpectationMP", "SampleMP", "VarianceMP", "CountsMP", "StateMP", "ProbabilityMP"] -# A list of operations that the can be represented -# in the Catalyst compiler. This is a superset of +# A list of operations that can be represented +# in the Catalyst compiler. This will be a superset of # the operations supported by the runtime. -# FIXME: ops with OpName(params, wires) signatures -# can be represented in the Catalyst compiler. +# FIXME: ops with OpName(params, wires) signatures can be +# represented in the Catalyst compiler. Unfortunately, +# the signature info is not sufficient as there are +# templates with the same signature that should be +# disambiguated. COMPILER_OPERATIONS = RUNTIME_OPERATIONS # The runtime interface does not care about specific gate properties, so set them all to True. diff --git a/frontend/catalyst/from_plxpr/from_plxpr.py b/frontend/catalyst/from_plxpr/from_plxpr.py index 9f7deb650c..31be46d38f 100644 --- a/frontend/catalyst/from_plxpr/from_plxpr.py +++ b/frontend/catalyst/from_plxpr/from_plxpr.py @@ -35,7 +35,6 @@ from pennylane.capture.primitives import measure_prim as plxpr_measure_prim from pennylane.decomposition.collect_resource_ops import CollectResourceOps from pennylane.ftqc.primitives import measure_in_basis_prim as plxpr_measure_in_basis_prim -from pennylane.ops import Adjoint, Controlled, ControlledOp from pennylane.ops.functions.map_wires import _map_wires_transform as pl_map_wires from pennylane.transforms import cancel_inverses as pl_cancel_inverses from pennylane.transforms import commute_controlled as pl_commute_controlled @@ -183,7 +182,7 @@ def f(x): class WorkflowInterpreter(PlxprInterpreter): - """An interpreter that converts a qnode primitive from a plxpr variant to a catalyst variant.""" + """An interpreter that converts a qnode primitive from a plxpr variant to a catalyst jaxpr variant.""" def __init__(self): self._pass_pipeline = [] @@ -192,43 +191,6 @@ def __init__(self): super().__init__() -def _decompose_to_compiler_gateset(qfunc_jaxpr, consts, non_const_args): - """First stage decomposition to compiler gate set. - - Currently, the compiler can only handle a limited set of gates and - may not support all generic gates and templates of the original circuit. - We perform a first stage decomposition to the compiler gate set, which includes - only a subset of the original gates that can be represented in MLIR using the - `quantum.custom` primitive. - """ - - # TODO: The compiler should be able to handle all gate - # adhering to quantum.custom primitive.This includes - # all the gates with parameters of type `TensorLike` - # and wires of type `WiresLike` with no hyperparams. - # Update `gate_set` to use this as the stopping condition - # of the decomposition transform. - gate_set = COMPILER_OPERATIONS - - decomp_args = () - decomp_kwargs = {"gate_set": gate_set} - - # disable the graph decomposition optimization - graph_decomp_status = False - if qml.decomposition.enabled_graph(): - graph_decomp_status = True - qml.decomposition.disable_graph() - - new_jaxpr = qml.transforms.decompose.plxpr_transform( - qfunc_jaxpr, consts, decomp_args, decomp_kwargs, *non_const_args - ) - - if graph_decomp_status: - qml.decomposition.enable_graph() - - return new_jaxpr - - # pylint: disable=unused-argument, too-many-arguments @WorkflowInterpreter.register_primitive(qnode_prim) def handle_qnode( @@ -288,12 +250,10 @@ def calling_convention(*args): # pylint: disable=too-many-arguments -def handle_graph_decomposition(*args, inner_jaxpr, consts, non_const_args, targs, tkwargs): +def handle_graph_decomposition(*args, inner_jaxpr, consts, targs, tkwargs, compiler_gateset): """Handle the graph decomposition for a given JAXPR.""" - gate_set = COMPILER_OPERATIONS - decomp_kwargs = {"gate_set": gate_set} - + decomp_kwargs = {"gate_set": compiler_gateset} pmd_interpreter = PreMlirDecomposeInterpreter(*targs, **decomp_kwargs) def pmd_wrapper(*args): @@ -357,7 +317,10 @@ def get_operator_name(op): if isinstance(op, str): return op - return getattr(op._primitive, "name", "UnsupportedGate") + # Return NoNameOp if the operator has no _primitive.name attribute. + # This is to avoid errors when we capture the program + # as we deal with such ops later in the decomposition graph. + return getattr(op._primitive, "name", "NoNameOp") self.decomp_gateset = [get_operator_name(op) for op in self.decomp_gateset] @@ -365,14 +328,13 @@ def get_operator_name(op): # Then, construct and solve the graph-based decomposition # to get the optimized rules and lower them to PLxPR # to Catalyst JAXPR to MLIR. - final_jaxpr = handle_graph_decomposition( *args, inner_jaxpr=inner_jaxpr, consts=consts, - non_const_args=non_const_args, targs=targs, tkwargs=tkwargs, + compiler_gateset=COMPILER_OPERATIONS + self.decomp_gateset, ) return self.eval(final_jaxpr.jaxpr, final_jaxpr.consts, *non_const_args) diff --git a/frontend/catalyst/jax_primitives_utils.py b/frontend/catalyst/jax_primitives_utils.py index 26c27af4e6..22c9f5c911 100644 --- a/frontend/catalyst/jax_primitives_utils.py +++ b/frontend/catalyst/jax_primitives_utils.py @@ -65,7 +65,7 @@ def lower_jaxpr(ctx, jaxpr, context=None): return lower_callable(ctx, callable_, call_jaxpr, pipeline=pipeline, context=context) -# pylint: disable=too-many-arguments +# pylint: disable=too-many-arguments, too-many-positional-arguments def lower_callable(ctx, callable_, call_jaxpr, pipeline=None, context=None, public=False): """Lowers _callable to MLIR. @@ -94,7 +94,7 @@ def lower_callable(ctx, callable_, call_jaxpr, pipeline=None, context=None, publ return get_or_create_qnode_funcop(ctx, callable_, call_jaxpr, pipeline, context=context) -# pylint: disable=too-many-arguments +# pylint: disable=too-many-arguments, too-many-positional-arguments def get_or_create_funcop(ctx, callable_, call_jaxpr, pipeline, context=None, public=False): """Get funcOp from cache, or create it from scratch diff --git a/frontend/catalyst/jax_tracer.py b/frontend/catalyst/jax_tracer.py index 5d83b8f3a8..172e294361 100644 --- a/frontend/catalyst/jax_tracer.py +++ b/frontend/catalyst/jax_tracer.py @@ -1205,7 +1205,12 @@ def apply_transforms( # TODO: Ideally we should allow qnode transforms that don't modify the measurements to # operate in the permissive tracing mode, but that currently leads to a small number of # test failures due to the different result format produced in trace_quantum_function. - if has_classical_outputs(flat_results): + only_with_dynamic_one_shot = all( + "dynamic_one_shot_partial" in str(getattr(qnode, "transform", "")) + for qnode in qnode_program + ) + + if has_classical_outputs(flat_results) and not only_with_dynamic_one_shot: msg = ( "Transforming MeasurementProcesses is unsupported with non-MeasurementProcess " "QNode outputs. The selected device, options, or applied QNode transforms, may be " @@ -1480,12 +1485,9 @@ def check_full_raise(arr, func): meas_results = tree_unflatten(meas_trees, meas_tracers) # TODO: Allow the user to return whatever types they specify. - if tracing_mode == TracingMode.TRANSFORM: - assert isinstance(meas_results, list) - if len(meas_results) == 1: - transformed_results.append(meas_results[0]) - else: - transformed_results.append(tuple(meas_results)) + if tracing_mode == TracingMode.TRANSFORM and isinstance(meas_results, list): + result = meas_results[0] if len(meas_results) == 1 else tuple(meas_results) + transformed_results.append(result) else: transformed_results.append(meas_results) diff --git a/frontend/test/lit/test_decomposition.py b/frontend/test/lit/test_decomposition.py index 0d26012b7a..4e676f219d 100644 --- a/frontend/test/lit/test_decomposition.py +++ b/frontend/test/lit/test_decomposition.py @@ -537,6 +537,15 @@ def test_decompose_gateset_with_graph(): qml.capture.enable() qml.decomposition.enable_graph() + @qml.qjit(target="mlir") + @partial(qml.transforms.decompose, gate_set={"RX"}) + @qml.qnode(qml.device("lightning.qubit", wires=1)) + # CHECK: public @circuit_9() -> tensor attributes {decomp_gateset = ["RX"] + def circuit_9(): + return qml.expval(qml.Z(0)) + + print(circuit_9.mlir) + @qml.qjit(target="mlir") @partial(qml.transforms.decompose, gate_set={"RX", "RZ"}) @qml.qnode(qml.device("lightning.qubit", wires=1)) @@ -559,6 +568,15 @@ def test_decompose_gateset_operator_with_graph(): qml.capture.enable() qml.decomposition.enable_graph() + @qml.qjit(target="mlir") + @partial(qml.transforms.decompose, gate_set={qml.RX}) + @qml.qnode(qml.device("lightning.qubit", wires=1)) + # CHECK: public @circuit_10() -> tensor attributes {decomp_gateset = ["RX"] + def circuit_10(): + return qml.expval(qml.Z(0)) + + print(circuit_10.mlir) + @qml.qjit(target="mlir") @partial( qml.transforms.decompose, gate_set={qml.RX, qml.RZ, "PauliZ", qml.PauliX, qml.Hadamard} @@ -570,8 +588,50 @@ def circuit_10(): print(circuit_10.mlir) + @qml.qjit(target="mlir") + @partial( + qml.transforms.decompose, gate_set={qml.RX, qml.RZ, qml.PauliZ, qml.PauliX, qml.Hadamard} + ) + @qml.qnode(qml.device("lightning.qubit", wires=1)) + # CHECK: public @circuit_11() -> tensor attributes {decomp_gateset + def circuit_11(): + return qml.expval(qml.Z(0)) + + print(circuit_11.mlir) + qml.decomposition.disable_graph() qml.capture.disable() test_decompose_gateset_operator_with_graph() + + +def test_decompose_gateset_with_rotxzx(): + """Test the decompose transform with a custom operator with the graph decomposition.""" + + qml.capture.enable() + qml.decomposition.enable_graph() + + @qml.qjit(target="mlir") + @partial(qml.transforms.decompose, gate_set={"RotXZX"}) + @qml.qnode(qml.device("lightning.qubit", wires=1)) + # CHECK: public @circuit_12() -> tensor attributes {decomp_gateset = ["RotXZX"] + def circuit_12(): + return qml.expval(qml.Z(0)) + + print(circuit_12.mlir) + + @qml.qjit(target="mlir") + @partial(qml.transforms.decompose, gate_set={qml.ftqc.RotXZX}) + @qml.qnode(qml.device("lightning.qubit", wires=1)) + # CHECK: public @circuit_12() -> tensor attributes {decomp_gateset = ["RotXZX"] + def circuit_12(): + return qml.expval(qml.Z(0)) + + print(circuit_12.mlir) + + qml.decomposition.disable_graph() + qml.capture.disable() + + +test_decompose_gateset_with_rotxzx() diff --git a/frontend/test/lit/test_from_plxpr.py b/frontend/test/lit/test_from_plxpr.py index 3564e86146..ab31c42fa4 100644 --- a/frontend/test/lit/test_from_plxpr.py +++ b/frontend/test/lit/test_from_plxpr.py @@ -347,7 +347,7 @@ def test_pass_application(): qml.capture.enable() - # TODO: is there an ordering issue here? + # TODO: was there an ordering issue here? @qml.qjit(target="mlir") @qml.transforms.merge_rotations @qml.transforms.cancel_inverses diff --git a/frontend/test/pytest/from_plxpr/test_capture_integration.py b/frontend/test/pytest/from_plxpr/test_capture_integration.py index 45bf6846e9..bbbd07935b 100644 --- a/frontend/test/pytest/from_plxpr/test_capture_integration.py +++ b/frontend/test/pytest/from_plxpr/test_capture_integration.py @@ -1291,7 +1291,7 @@ def test_transform_decompose_workflow(self, backend): qml.capture.enable() @qjit(target="mlir") - @partial(qml.transforms.decompose, gate_set=["RX", "RY", "RZ"]) + @partial(qml.transforms.decompose, gate_set=[qml.RX, qml.RY, qml.RZ]) @qml.qnode(qml.device(backend, wires=2)) def captured_circuit(x: float, y: float, z: float): qml.Rot(x, y, z, 0) @@ -1305,7 +1305,7 @@ def captured_circuit(x: float, y: float, z: float): # Capture disabled @qjit - @partial(qml.transforms.decompose, gate_set=["RX", "RY", "RZ"]) + @partial(qml.transforms.decompose, gate_set=[qml.RX, qml.RY, qml.RZ]) @qml.qnode(qml.device(backend, wires=2)) def circuit(x: float, y: float, z: float): qml.Rot(x, y, z, 0) diff --git a/test_new_decomp.py b/test_new_decomp.py index 1d1ea58140..aa69e3a44b 100644 --- a/test_new_decomp.py +++ b/test_new_decomp.py @@ -1,14 +1,10 @@ from functools import partial -import jax import numpy as np import pennylane as qml from pennylane.ftqc import RotXZX -from pennylane.typing import TensorLike from pennylane.wires import WiresLike -from catalyst import qjit -from catalyst.from_plxpr import from_plxpr from catalyst.jax_primitives import decomposition_rule qml.capture.enable() @@ -22,6 +18,7 @@ @decomposition_rule def Rule_ry_to_rz_rx(phi, wires: WiresLike, **__): + """Decomposition of RY gate using RZ and RX gates.""" qml.RZ(-np.pi / 2, wires=wires) qml.RX(phi, wires=wires) qml.RZ(np.pi / 2, wires=wires) @@ -29,6 +26,7 @@ def Rule_ry_to_rz_rx(phi, wires: WiresLike, **__): @decomposition_rule def Rule_rot_to_rz_ry_rz(phi, theta, omega, wires: WiresLike, **__): + """Decomposition of Rot gate using RZ and RY gates.""" qml.RZ(phi, wires=wires) qml.RY(theta, wires=wires) qml.RZ(omega, wires=wires) @@ -36,6 +34,7 @@ def Rule_rot_to_rz_ry_rz(phi, theta, omega, wires: WiresLike, **__): @decomposition_rule def Rule_u2_phaseshift_rot(phi, delta, wires, **__): + """Decomposition of U2 gate using Rot and PhaseShift gates.""" pi_half = qml.math.ones_like(delta) * (np.pi / 2) qml.Rot(delta, pi_half, -delta, wires=wires) qml.PhaseShift(delta, wires=wires) @@ -44,6 +43,7 @@ def Rule_u2_phaseshift_rot(phi, delta, wires, **__): @decomposition_rule def Rule_xzx_decompose(phi, theta, omega, wires, **__): + """Decomposition of Rot gate using RX and RZ gates in XZX format.""" qml.RX(phi, wires=wires) qml.RZ(theta, wires=wires) qml.RX(omega, wires=wires) @@ -53,7 +53,7 @@ def Rule_xzx_decompose(phi, theta, omega, wires, **__): @partial(qml.transforms.decompose, gate_set={"RX", "RZ", "PhaseShift"}) @qml.qnode(qml.device("lightning.qubit", wires=3)) def circuit(): - + """Circuit to test custom decomposition rules.""" qml.RY(0.5, wires=0) qml.Rot(0.1, 0.2, 0.3, wires=1) qml.U2(0.4, 0.5, wires=2) @@ -81,12 +81,14 @@ def circuit(): @qml.register_resources({qml.ftqc.RotXZX: 1}) @decomposition_rule def Rule_rot_to_xzx(phi, theta, omega, wires, **__): + """Decomposition of Rot gate using RotXZX gate.""" mat = qml.Rot.compute_matrix(phi, theta, omega) lam, theta, phi = qml.math.decomposition.xzx_rotation_angles(mat) qml.ftqc.RotXZX(lam, theta, phi, wires) @qml.qjit() +@qml.transforms.merge_rotations @partial( qml.transforms.decompose, gate_set={"X", "Y", "Z", "S", "H", "CNOT", "RZ", "RotXZX", "GlobalPhase"}, @@ -94,12 +96,11 @@ def Rule_rot_to_xzx(phi, theta, omega, wires, **__): ) @qml.qnode(qml.device("null.qubit", wires=3)) def mbqc_circ(x: float, y: float): + """MBQC example to test custom decomposition to RotXZX.""" qml.RX(x, 0) qml.RY(y, 1) - Rule_rot_to_xzx( - float, float, float, int - ) # this needs to be here to include the custom decomposition in the graph + Rule_rot_to_xzx(float, float, float, int) Rule_ry_to_rz_rx(float, int) Rule_xzx_decompose(float, float, float, int) From 5465b931c841ba2fdb931e3bbce182b4b28b632d Mon Sep 17 00:00:00 2001 From: Ali Asadi <10773383+maliasadi@users.noreply.github.com> Date: Tue, 9 Sep 2025 11:41:42 -0400 Subject: [PATCH 07/36] Fix func-redefined --- frontend/test/lit/test_decomposition.py | 18 +++++++++--------- test_new_decomp.py | 6 ++++++ 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/frontend/test/lit/test_decomposition.py b/frontend/test/lit/test_decomposition.py index 4e676f219d..d8c9fe33b9 100644 --- a/frontend/test/lit/test_decomposition.py +++ b/frontend/test/lit/test_decomposition.py @@ -540,11 +540,11 @@ def test_decompose_gateset_with_graph(): @qml.qjit(target="mlir") @partial(qml.transforms.decompose, gate_set={"RX"}) @qml.qnode(qml.device("lightning.qubit", wires=1)) - # CHECK: public @circuit_9() -> tensor attributes {decomp_gateset = ["RX"] - def circuit_9(): + # CHECK: public @simple_circuit_9() -> tensor attributes {decomp_gateset = ["RX"] + def simple_circuit_9(): return qml.expval(qml.Z(0)) - print(circuit_9.mlir) + print(simple_circuit_9.mlir) @qml.qjit(target="mlir") @partial(qml.transforms.decompose, gate_set={"RX", "RZ"}) @@ -571,11 +571,11 @@ def test_decompose_gateset_operator_with_graph(): @qml.qjit(target="mlir") @partial(qml.transforms.decompose, gate_set={qml.RX}) @qml.qnode(qml.device("lightning.qubit", wires=1)) - # CHECK: public @circuit_10() -> tensor attributes {decomp_gateset = ["RX"] - def circuit_10(): + # CHECK: public @simple_circuit_10() -> tensor attributes {decomp_gateset = ["RX"] + def simple_circuit_10(): return qml.expval(qml.Z(0)) - print(circuit_10.mlir) + print(simple_circuit_10.mlir) @qml.qjit(target="mlir") @partial( @@ -615,11 +615,11 @@ def test_decompose_gateset_with_rotxzx(): @qml.qjit(target="mlir") @partial(qml.transforms.decompose, gate_set={"RotXZX"}) @qml.qnode(qml.device("lightning.qubit", wires=1)) - # CHECK: public @circuit_12() -> tensor attributes {decomp_gateset = ["RotXZX"] - def circuit_12(): + # CHECK: public @simple_circuit_12() -> tensor attributes {decomp_gateset = ["RotXZX"] + def simple_circuit_12(): return qml.expval(qml.Z(0)) - print(circuit_12.mlir) + print(simple_circuit_12.mlir) @qml.qjit(target="mlir") @partial(qml.transforms.decompose, gate_set={qml.ftqc.RotXZX}) diff --git a/test_new_decomp.py b/test_new_decomp.py index aa69e3a44b..75462d4e3a 100644 --- a/test_new_decomp.py +++ b/test_new_decomp.py @@ -1,3 +1,9 @@ +""" +This file contains a few tests for the end-to-end custom decomposition rules + +TODO: remove the file after testing +""" + from functools import partial import numpy as np From d4dcb1998b90f0d0ca9bff9ed9418f8b6144a813 Mon Sep 17 00:00:00 2001 From: Ali Asadi <10773383+maliasadi@users.noreply.github.com> Date: Tue, 9 Sep 2025 14:10:23 -0400 Subject: [PATCH 08/36] Update comiled name of rules --- frontend/catalyst/jax_primitives_utils.py | 9 +- frontend/test/lit/test_decomposition.py | 132 ++++++++++++++++++++-- test_new_decomp.py | 26 ++--- 3 files changed, 146 insertions(+), 21 deletions(-) diff --git a/frontend/catalyst/jax_primitives_utils.py b/frontend/catalyst/jax_primitives_utils.py index 22c9f5c911..1f380040b3 100644 --- a/frontend/catalyst/jax_primitives_utils.py +++ b/frontend/catalyst/jax_primitives_utils.py @@ -139,7 +139,14 @@ def lower_callable_to_funcop(ctx, callable_, call_jaxpr, public=False): name = callable_.__name__ else: name = callable_.func.__name__ + ".partial" - kwargs["name"] = name + + # Make the function name more descriptive if it is a decomposition rule. + # This is expected by the MLIR decomposition pass. + kwargs["name"] = ( + "rule" + name + if public and name[0] == "_" and ("_to_" in name or "decompos" in name) + else name + ) kwargs["jaxpr"] = call_jaxpr kwargs["effects"] = [] kwargs["name_stack"] = ctx.name_stack diff --git a/frontend/test/lit/test_decomposition.py b/frontend/test/lit/test_decomposition.py index d8c9fe33b9..495762edd1 100644 --- a/frontend/test/lit/test_decomposition.py +++ b/frontend/test/lit/test_decomposition.py @@ -6,6 +6,7 @@ from functools import partial import jax +import numpy as np import pennylane as qml from pennylane.devices.capabilities import OperatorProperties from pennylane.typing import TensorLike @@ -480,17 +481,17 @@ def test_decomposition_rule_caller(): qml.capture.enable() @decomposition_rule(is_qreg=True) - def Rule_Op1_decomp(_: TensorLike, wires: WiresLike): + def rule_op1_decomp(_: TensorLike, wires: WiresLike): qml.Hadamard(wires=wires[0]) qml.Hadamard(wires=[1]) @decomposition_rule(is_qreg=True) - def Rule_Op2_decomp(param: TensorLike, wires: WiresLike): + def rule_op2_decomp(param: TensorLike, wires: WiresLike): qml.RX(param, wires=wires[0]) def decomps_caller(param: TensorLike, wires: WiresLike): - Rule_Op1_decomp(param, wires) - Rule_Op2_decomp(param, wires) + rule_op1_decomp(param, wires) + rule_op2_decomp(param, wires) @qml.qjit(autograph=False) @qml.qnode(qml.device("lightning.qubit", wires=1)) @@ -501,9 +502,8 @@ def circuit_7(): decomps_caller(float, jax.core.ShapedArray((2,), int)) return qml.probs() - # CHECK: func.func public @Rule_Op1_decomp(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<2xi64>) -> !quantum.reg - # CHECK: func.func public @Rule_Op2_decomp(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<2xi64>) -> !quantum.reg - + # CHECK: func.func public @rule_op1_decomp(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<2xi64>) -> !quantum.reg + # CHECK: func.func public @rule_op2_decomp(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<2xi64>) -> !quantum.reg print(circuit_7.mlir) qml.capture.disable() @@ -635,3 +635,121 @@ def circuit_12(): test_decompose_gateset_with_rotxzx() + + +def test_decomposition_rule_name_update(): + """Test the name of the decomposition rule is updated in the MLIR output.""" + + qml.capture.enable() + qml.decomposition.enable_graph() + + @decomposition_rule + def _ry_to_rz_rx(phi, wires: WiresLike, **__): + """Decomposition of RY gate using RZ and RX gates.""" + qml.RZ(-np.pi / 2, wires=wires) + qml.RX(phi, wires=wires) + qml.RZ(np.pi / 2, wires=wires) + + @decomposition_rule + def _rot_to_rz_ry_rz(phi, theta, omega, wires: WiresLike, **__): + """Decomposition of Rot gate using RZ and RY gates.""" + qml.RZ(phi, wires=wires) + qml.RY(theta, wires=wires) + qml.RZ(omega, wires=wires) + + @decomposition_rule + def _u2_phaseshift_rot_decomposition(phi, delta, wires, **__): + """Decomposition of U2 gate using Rot and PhaseShift gates.""" + pi_half = qml.math.ones_like(delta) * (np.pi / 2) + qml.Rot(delta, pi_half, -delta, wires=wires) + qml.PhaseShift(delta, wires=wires) + qml.PhaseShift(phi, wires=wires) + + @decomposition_rule + def _xzx_decompose(phi, theta, omega, wires, **__): + """Decomposition of Rot gate using RX and RZ gates in XZX format.""" + qml.RX(phi, wires=wires) + qml.RZ(theta, wires=wires) + qml.RX(omega, wires=wires) + + @qml.qjit(target="mlir") + @partial(qml.transforms.decompose, gate_set={"RX", "RZ", "PhaseShift"}) + @qml.qnode(qml.device("lightning.qubit", wires=3)) + # CHECK: public @circuit_13() -> tensor attributes {decomp_gateset + def circuit_13(): + _ry_to_rz_rx(float, int) + _rot_to_rz_ry_rz(float, float, float, int) + _u2_phaseshift_rot_decomposition(float, float, int) + _xzx_decompose(float, float, float, int) + return qml.expval(qml.Z(0)) + + # CHECK: func.func public @rule_ry_to_rz_rx(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor) -> !quantum.reg + # CHECK: func.func public @rule_rot_to_rz_ry_rz(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor) -> !quantum.reg + # CHECK: func.func public @rule_u2_phaseshift_rot_decomposition(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor, %arg3: tensor) -> !quantum.reg + # CHECK: func.func public @rule_xzx_decompose(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor) -> !quantum.reg + print(circuit_13.mlir) + + qml.decomposition.disable_graph() + qml.capture.disable() + + +test_decomposition_rule_name_update() + + +def test_decomposition_rule_name_update(): + """Test the name of the decomposition rule is updated in the MLIR output.""" + + qml.capture.enable() + qml.decomposition.enable_graph() + + @decomposition_rule + def _ry_to_rz_rx(phi, wires: WiresLike, **__): + """Decomposition of RY gate using RZ and RX gates.""" + qml.RZ(-np.pi / 2, wires=wires) + qml.RX(phi, wires=wires) + qml.RZ(np.pi / 2, wires=wires) + + @decomposition_rule + def _rot_to_rz_ry_rz(phi, theta, omega, wires: WiresLike, **__): + """Decomposition of Rot gate using RZ and RY gates.""" + qml.RZ(phi, wires=wires) + qml.RY(theta, wires=wires) + qml.RZ(omega, wires=wires) + + @decomposition_rule + def _u2_phaseshift_rot_decomposition(phi, delta, wires, **__): + """Decomposition of U2 gate using Rot and PhaseShift gates.""" + pi_half = qml.math.ones_like(delta) * (np.pi / 2) + qml.Rot(delta, pi_half, -delta, wires=wires) + qml.PhaseShift(delta, wires=wires) + qml.PhaseShift(phi, wires=wires) + + @decomposition_rule + def _xzx_decompose(phi, theta, omega, wires, **__): + """Decomposition of Rot gate using RX and RZ gates in XZX format.""" + qml.RX(phi, wires=wires) + qml.RZ(theta, wires=wires) + qml.RX(omega, wires=wires) + + @qml.qjit(target="mlir") + @partial(qml.transforms.decompose, gate_set={"RX", "RZ", "PhaseShift"}) + @qml.qnode(qml.device("lightning.qubit", wires=3)) + # CHECK: public @circuit_13() -> tensor attributes {decomp_gateset + def circuit_13(): + _ry_to_rz_rx(float, int) + _rot_to_rz_ry_rz(float, float, float, int) + _u2_phaseshift_rot_decomposition(float, float, int) + _xzx_decompose(float, float, float, int) + return qml.expval(qml.Z(0)) + + # CHECK: func.func public @rule_ry_to_rz_rx(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor) -> !quantum.reg + # CHECK: func.func public @rule_rot_to_rz_ry_rz(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor) -> !quantum.reg + # CHECK: func.func public @rule_u2_phaseshift_rot_decomposition(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor, %arg3: tensor) -> !quantum.reg + # CHECK: func.func public @rule_xzx_decompose(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor) -> !quantum.reg + print(circuit_13.mlir) + + qml.decomposition.disable_graph() + qml.capture.disable() + + +test_decomposition_rule_name_update() diff --git a/test_new_decomp.py b/test_new_decomp.py index 75462d4e3a..393be0e7fb 100644 --- a/test_new_decomp.py +++ b/test_new_decomp.py @@ -23,7 +23,7 @@ @decomposition_rule -def Rule_ry_to_rz_rx(phi, wires: WiresLike, **__): +def _ry_to_rz_rx(phi, wires: WiresLike, **__): """Decomposition of RY gate using RZ and RX gates.""" qml.RZ(-np.pi / 2, wires=wires) qml.RX(phi, wires=wires) @@ -31,7 +31,7 @@ def Rule_ry_to_rz_rx(phi, wires: WiresLike, **__): @decomposition_rule -def Rule_rot_to_rz_ry_rz(phi, theta, omega, wires: WiresLike, **__): +def _rot_to_rz_ry_rz(phi, theta, omega, wires: WiresLike, **__): """Decomposition of Rot gate using RZ and RY gates.""" qml.RZ(phi, wires=wires) qml.RY(theta, wires=wires) @@ -39,7 +39,7 @@ def Rule_rot_to_rz_ry_rz(phi, theta, omega, wires: WiresLike, **__): @decomposition_rule -def Rule_u2_phaseshift_rot(phi, delta, wires, **__): +def _u2_phaseshift_rot(phi, delta, wires, **__): """Decomposition of U2 gate using Rot and PhaseShift gates.""" pi_half = qml.math.ones_like(delta) * (np.pi / 2) qml.Rot(delta, pi_half, -delta, wires=wires) @@ -48,7 +48,7 @@ def Rule_u2_phaseshift_rot(phi, delta, wires, **__): @decomposition_rule -def Rule_xzx_decompose(phi, theta, omega, wires, **__): +def _xzx_decompose(phi, theta, omega, wires, **__): """Decomposition of Rot gate using RX and RZ gates in XZX format.""" qml.RX(phi, wires=wires) qml.RZ(theta, wires=wires) @@ -65,10 +65,10 @@ def circuit(): qml.U2(0.4, 0.5, wires=2) RotXZX(0.6, 0.7, 0.8, wires=0) - Rule_ry_to_rz_rx(0, 0) - Rule_rot_to_rz_ry_rz(0, 0, 0, 1) - Rule_u2_phaseshift_rot(0, 0, 2) - Rule_xzx_decompose(0, 0, 0, 0) + _ry_to_rz_rx(0, 0) + _rot_to_rz_ry_rz(0, 0, 0, 1) + _u2_phaseshift_rot(0, 0, 2) + _xzx_decompose(0, 0, 0, 0) return qml.expval(qml.Z(0)) @@ -86,7 +86,7 @@ def circuit(): @qml.register_resources({qml.ftqc.RotXZX: 1}) @decomposition_rule -def Rule_rot_to_xzx(phi, theta, omega, wires, **__): +def _rot_to_xzx(phi, theta, omega, wires, **__): """Decomposition of Rot gate using RotXZX gate.""" mat = qml.Rot.compute_matrix(phi, theta, omega) lam, theta, phi = qml.math.decomposition.xzx_rotation_angles(mat) @@ -98,7 +98,7 @@ def Rule_rot_to_xzx(phi, theta, omega, wires, **__): @partial( qml.transforms.decompose, gate_set={"X", "Y", "Z", "S", "H", "CNOT", "RZ", "RotXZX", "GlobalPhase"}, - fixed_decomps={qml.Rot: Rule_rot_to_xzx}, + fixed_decomps={qml.Rot: _rot_to_xzx}, ) @qml.qnode(qml.device("null.qubit", wires=3)) def mbqc_circ(x: float, y: float): @@ -106,9 +106,9 @@ def mbqc_circ(x: float, y: float): qml.RX(x, 0) qml.RY(y, 1) - Rule_rot_to_xzx(float, float, float, int) - Rule_ry_to_rz_rx(float, int) - Rule_xzx_decompose(float, float, float, int) + _rot_to_xzx(float, float, float, int) + _ry_to_rz_rx(float, int) + _xzx_decompose(float, float, float, int) return qml.expval(qml.Z(0)), qml.expval(qml.Z(1)) From e6bf756ce962464fc682dc70a42d54543f98d7fe Mon Sep 17 00:00:00 2001 From: Ali Asadi <10773383+maliasadi@users.noreply.github.com> Date: Tue, 9 Sep 2025 14:14:07 -0400 Subject: [PATCH 09/36] Update --- frontend/test/lit/test_decomposition.py | 59 ------------------------- 1 file changed, 59 deletions(-) diff --git a/frontend/test/lit/test_decomposition.py b/frontend/test/lit/test_decomposition.py index 495762edd1..3cf407aaaa 100644 --- a/frontend/test/lit/test_decomposition.py +++ b/frontend/test/lit/test_decomposition.py @@ -637,65 +637,6 @@ def circuit_12(): test_decompose_gateset_with_rotxzx() -def test_decomposition_rule_name_update(): - """Test the name of the decomposition rule is updated in the MLIR output.""" - - qml.capture.enable() - qml.decomposition.enable_graph() - - @decomposition_rule - def _ry_to_rz_rx(phi, wires: WiresLike, **__): - """Decomposition of RY gate using RZ and RX gates.""" - qml.RZ(-np.pi / 2, wires=wires) - qml.RX(phi, wires=wires) - qml.RZ(np.pi / 2, wires=wires) - - @decomposition_rule - def _rot_to_rz_ry_rz(phi, theta, omega, wires: WiresLike, **__): - """Decomposition of Rot gate using RZ and RY gates.""" - qml.RZ(phi, wires=wires) - qml.RY(theta, wires=wires) - qml.RZ(omega, wires=wires) - - @decomposition_rule - def _u2_phaseshift_rot_decomposition(phi, delta, wires, **__): - """Decomposition of U2 gate using Rot and PhaseShift gates.""" - pi_half = qml.math.ones_like(delta) * (np.pi / 2) - qml.Rot(delta, pi_half, -delta, wires=wires) - qml.PhaseShift(delta, wires=wires) - qml.PhaseShift(phi, wires=wires) - - @decomposition_rule - def _xzx_decompose(phi, theta, omega, wires, **__): - """Decomposition of Rot gate using RX and RZ gates in XZX format.""" - qml.RX(phi, wires=wires) - qml.RZ(theta, wires=wires) - qml.RX(omega, wires=wires) - - @qml.qjit(target="mlir") - @partial(qml.transforms.decompose, gate_set={"RX", "RZ", "PhaseShift"}) - @qml.qnode(qml.device("lightning.qubit", wires=3)) - # CHECK: public @circuit_13() -> tensor attributes {decomp_gateset - def circuit_13(): - _ry_to_rz_rx(float, int) - _rot_to_rz_ry_rz(float, float, float, int) - _u2_phaseshift_rot_decomposition(float, float, int) - _xzx_decompose(float, float, float, int) - return qml.expval(qml.Z(0)) - - # CHECK: func.func public @rule_ry_to_rz_rx(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor) -> !quantum.reg - # CHECK: func.func public @rule_rot_to_rz_ry_rz(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor) -> !quantum.reg - # CHECK: func.func public @rule_u2_phaseshift_rot_decomposition(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor, %arg3: tensor) -> !quantum.reg - # CHECK: func.func public @rule_xzx_decompose(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor) -> !quantum.reg - print(circuit_13.mlir) - - qml.decomposition.disable_graph() - qml.capture.disable() - - -test_decomposition_rule_name_update() - - def test_decomposition_rule_name_update(): """Test the name of the decomposition rule is updated in the MLIR output.""" From fbf5911ec33ae8c58fd628a5596a5068bb770a56 Mon Sep 17 00:00:00 2001 From: Ali Asadi <10773383+maliasadi@users.noreply.github.com> Date: Tue, 9 Sep 2025 16:00:50 -0400 Subject: [PATCH 10/36] apply multiple decomp pass --- frontend/catalyst/from_plxpr/from_plxpr.py | 13 +++++++++++-- frontend/test/lit/test_from_plxpr.py | 3 +-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/frontend/catalyst/from_plxpr/from_plxpr.py b/frontend/catalyst/from_plxpr/from_plxpr.py index 31be46d38f..2b183f4728 100644 --- a/frontend/catalyst/from_plxpr/from_plxpr.py +++ b/frontend/catalyst/from_plxpr/from_plxpr.py @@ -187,7 +187,11 @@ class WorkflowInterpreter(PlxprInterpreter): def __init__(self): self._pass_pipeline = [] self.qubit_handler = None + + # Compiler options for the new decomposition system + self.compiler_decompose = False self.decomp_gateset = [] + super().__init__() @@ -304,6 +308,8 @@ def handle_transform( and pl_plxpr_transform.__name__ == "decompose_plxpr_to_plxpr" and qml.decomposition.enabled_graph() ): + self.compiler_decompose = True + # Update the decomp_gateset to be used by the quantum kernel primitive self.decomp_gateset = tkwargs.get("gate_set", []) @@ -359,7 +365,7 @@ def wrapper(*args): return self.eval(final_jaxpr.jaxpr, final_jaxpr.consts, *non_const_args) # Apply the corresponding Catalyst pass counterpart - self._pass_pipeline.append(Pass(catalyst_pass_name)) + self._pass_pipeline.insert(0, Pass(catalyst_pass_name)) return self.eval(inner_jaxpr, consts, *non_const_args) @@ -386,13 +392,16 @@ def __init__(self, device, shots, qubit_handler, cache, *, control_wires=(), con # TODO: we assume the qreg value passed into a scope is the unique qreg in the scope # In other words, we assume no new qreg will be allocated in the scope self.qubit_handler = qubit_handler - self.decomp_gateset = [] self.subroutine_cache = cache self.control_wires = control_wires """Any control wires used for a subroutine.""" self.control_values = control_values """Any control values for executing a subroutine.""" + # Compiler options for the new decomposition system + self.compiler_decompose = False + self.decomp_gateset = [] + super().__init__() def interpret_operation(self, op, is_adjoint=False, control_values=(), control_wires=()): diff --git a/frontend/test/lit/test_from_plxpr.py b/frontend/test/lit/test_from_plxpr.py index ab31c42fa4..7fd8618cc6 100644 --- a/frontend/test/lit/test_from_plxpr.py +++ b/frontend/test/lit/test_from_plxpr.py @@ -347,10 +347,9 @@ def test_pass_application(): qml.capture.enable() - # TODO: was there an ordering issue here? @qml.qjit(target="mlir") - @qml.transforms.merge_rotations @qml.transforms.cancel_inverses + @qml.transforms.merge_rotations @qml.qnode(dev) def circuit(): return qml.probs() From 6253c901d26baa9a7af62c8b58b961573c7c7166 Mon Sep 17 00:00:00 2001 From: Ali Asadi <10773383+maliasadi@users.noreply.github.com> Date: Tue, 9 Sep 2025 16:38:11 -0400 Subject: [PATCH 11/36] pylint: disable=too-many-instance-attributes --- frontend/catalyst/from_plxpr/from_plxpr.py | 1 + 1 file changed, 1 insertion(+) diff --git a/frontend/catalyst/from_plxpr/from_plxpr.py b/frontend/catalyst/from_plxpr/from_plxpr.py index 2b183f4728..9a310a4982 100644 --- a/frontend/catalyst/from_plxpr/from_plxpr.py +++ b/frontend/catalyst/from_plxpr/from_plxpr.py @@ -377,6 +377,7 @@ def wrapper(*args): register_transform(pl_transform, pass_name, decomposition) +# pylint: disable=too-many-instance-attributes class PLxPRToQuantumJaxprInterpreter(PlxprInterpreter): """ Unlike the previous interpreters which modified the getattr and setattr From bd5fc7c5d12ab857bd8dece1e72e8dfafe45eeb3 Mon Sep 17 00:00:00 2001 From: Ali Asadi <10773383+maliasadi@users.noreply.github.com> Date: Tue, 9 Sep 2025 18:23:37 -0400 Subject: [PATCH 12/36] Add decompose-lowering to the pass pipeline --- frontend/catalyst/from_plxpr/from_plxpr.py | 2 + frontend/catalyst/passes/builtin_passes.py | 19 ++++++++ frontend/catalyst/passes/pass_api.py | 1 + mlir/include/Quantum/Transforms/Passes.h | 1 + mlir/include/Quantum/Transforms/Passes.td | 6 +++ .../Catalyst/Transforms/RegisterAllPasses.cpp | 1 + mlir/lib/Quantum/Transforms/CMakeLists.txt | 1 + .../Quantum/Transforms/decompose_lowering.cpp | 48 +++++++++++++++++++ 8 files changed, 79 insertions(+) create mode 100644 mlir/lib/Quantum/Transforms/decompose_lowering.cpp diff --git a/frontend/catalyst/from_plxpr/from_plxpr.py b/frontend/catalyst/from_plxpr/from_plxpr.py index 9a310a4982..f9c73e47cd 100644 --- a/frontend/catalyst/from_plxpr/from_plxpr.py +++ b/frontend/catalyst/from_plxpr/from_plxpr.py @@ -342,6 +342,8 @@ def get_operator_name(op): tkwargs=tkwargs, compiler_gateset=COMPILER_OPERATIONS + self.decomp_gateset, ) + + self._pass_pipeline.insert(0, Pass("decompose-lowering")) return self.eval(final_jaxpr.jaxpr, final_jaxpr.consts, *non_const_args) if catalyst_pass_name is None: diff --git a/frontend/catalyst/passes/builtin_passes.py b/frontend/catalyst/passes/builtin_passes.py index 2850e1a444..e4888ada6c 100644 --- a/frontend/catalyst/passes/builtin_passes.py +++ b/frontend/catalyst/passes/builtin_passes.py @@ -394,6 +394,25 @@ def circuit(x: float): return PassPipelineWrapper(qnode, "merge-rotations") +def decompose_lowering(qnode): + """ + Specify that the ``-decompose-lowering`` MLIR compiler pass + for applying the compiled decomposition rules to the QNode + recursively. + + Args: + fn (QNode): the QNode to apply the cancel inverses compiler pass to + + Returns: + ~.QNode: + + **Example** + // TODO: add example here + + """ + return PassPipelineWrapper(qnode, "decompose-lowering") + + def ions_decomposition(qnode): # pragma: nocover """ Specify that the ``--ions-decomposition`` MLIR compiler pass should be diff --git a/frontend/catalyst/passes/pass_api.py b/frontend/catalyst/passes/pass_api.py index 37872a1f0c..2c59d53e68 100644 --- a/frontend/catalyst/passes/pass_api.py +++ b/frontend/catalyst/passes/pass_api.py @@ -377,6 +377,7 @@ def _API_name_to_pass_name(): "disentangle_cnot": "disentangle-CNOT", "disentangle_swap": "disentangle-SWAP", "merge_rotations": "merge-rotations", + "decompose_lowering": "decompose-lowering", "ions_decomposition": "ions-decomposition", "to_ppr": "to-ppr", "commute_ppr": "commute-ppr", diff --git a/mlir/include/Quantum/Transforms/Passes.h b/mlir/include/Quantum/Transforms/Passes.h index 00f33d8fa4..33b25c0179 100644 --- a/mlir/include/Quantum/Transforms/Passes.h +++ b/mlir/include/Quantum/Transforms/Passes.h @@ -30,6 +30,7 @@ std::unique_ptr createRemoveChainedSelfInversePass(); std::unique_ptr createAnnotateFunctionPass(); std::unique_ptr createSplitMultipleTapesPass(); std::unique_ptr createMergeRotationsPass(); +std::unique_ptr createDecomposeLoweringPass(); std::unique_ptr createDisentangleCNOTPass(); std::unique_ptr createDisentangleSWAPPass(); std::unique_ptr createIonsDecompositionPass(); diff --git a/mlir/include/Quantum/Transforms/Passes.td b/mlir/include/Quantum/Transforms/Passes.td index f0a344190e..918b2032dc 100644 --- a/mlir/include/Quantum/Transforms/Passes.td +++ b/mlir/include/Quantum/Transforms/Passes.td @@ -110,6 +110,12 @@ def MergeRotationsPass : Pass<"merge-rotations"> { let constructor = "catalyst::createMergeRotationsPass()"; } +def DecomposeLoweringPass : Pass<"decompose-lowering"> { + let summary = "Replace quantum operations with compiled decomposition rules."; + + let constructor = "catalyst::createDecomposeLoweringPass()"; +} + def DisentangleCNOTPass : Pass<"disentangle-CNOT"> { let summary = "Replace a CNOT gate with two single qubit gates whenever possible."; diff --git a/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp b/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp index 0e7be8337b..73f8327216 100644 --- a/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp +++ b/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp @@ -68,6 +68,7 @@ void catalyst::registerAllCatalystPasses() mlir::registerPass(catalyst::createRegisterInactiveCallbackPass); mlir::registerPass(catalyst::createRemoveChainedSelfInversePass); mlir::registerPass(catalyst::createMergeRotationsPass); + mlir::registerPass(catalyst::createDecomposeLoweringPass); mlir::registerPass(catalyst::createScatterLoweringPass); mlir::registerPass(catalyst::createStablehloLegalizeControlFlowPass); mlir::registerPass(catalyst::createStablehloLegalizeSortPass); diff --git a/mlir/lib/Quantum/Transforms/CMakeLists.txt b/mlir/lib/Quantum/Transforms/CMakeLists.txt index 3a244ac4d6..ddc54e3148 100644 --- a/mlir/lib/Quantum/Transforms/CMakeLists.txt +++ b/mlir/lib/Quantum/Transforms/CMakeLists.txt @@ -14,6 +14,7 @@ file(GLOB SRC SplitMultipleTapes.cpp merge_rotation.cpp MergeRotationsPatterns.cpp + decompose_lowering.cpp DisentangleSWAP.cpp DisentangleCNOT.cpp ions_decompositions.cpp diff --git a/mlir/lib/Quantum/Transforms/decompose_lowering.cpp b/mlir/lib/Quantum/Transforms/decompose_lowering.cpp new file mode 100644 index 0000000000..8f0d6b638e --- /dev/null +++ b/mlir/lib/Quantum/Transforms/decompose_lowering.cpp @@ -0,0 +1,48 @@ +// Copyright 2025 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#define DEBUG_TYPE "decompose-lowering" + +#include "Catalyst/IR/CatalystDialect.h" +#include "mlir/Pass/Pass.h" +#include "llvm/Support/Debug.h" + +#include "Catalyst/IR/CatalystDialect.h" +#include "Quantum/IR/QuantumOps.h" +#include "Quantum/Transforms/Patterns.h" + +using namespace llvm; +using namespace mlir; +using namespace catalyst::quantum; + +namespace catalyst { +namespace quantum { +#define GEN_PASS_DEF_DECOMPOSELOWERINGPASS +#define GEN_PASS_DECL_DECOMPOSELOWERINGPASS +#include "Quantum/Transforms/Passes.h.inc" + +struct DecomposeLoweringPass : public impl::DecomposeLoweringPassBase { + using impl::DecomposeLoweringPassBase::DecomposeLoweringPassBase; + + void runOnOperation() override { llvm::errs() << "Decompose Lowering Pass!\n"; } +}; + +} // namespace quantum + +std::unique_ptr createDecomposeLoweringPass() +{ + return std::make_unique(); +} + +} // namespace catalyst From a22a4c526c9ac510789473d1186d987ca8dc5315 Mon Sep 17 00:00:00 2001 From: Ali Asadi <10773383+maliasadi@users.noreply.github.com> Date: Wed, 10 Sep 2025 14:48:41 -0400 Subject: [PATCH 13/36] provide support for decomp to apply after/before other passes --- frontend/catalyst/from_plxpr/decompose.py | 297 +-------------------- frontend/catalyst/from_plxpr/from_plxpr.py | 103 +++---- frontend/catalyst/jax_primitives_utils.py | 4 +- frontend/test/lit/test_decomposition.py | 16 +- frontend/test/lit/test_from_plxpr.py | 77 ++++++ test_new_decomp.py | 13 + 6 files changed, 159 insertions(+), 351 deletions(-) diff --git a/frontend/catalyst/from_plxpr/decompose.py b/frontend/catalyst/from_plxpr/decompose.py index 8400db453d..82289124fe 100644 --- a/frontend/catalyst/from_plxpr/decompose.py +++ b/frontend/catalyst/from_plxpr/decompose.py @@ -13,18 +13,12 @@ # limitations under the License. """ A transform for the new MLIT-based Catalyst decomposition system. - -Note: this transform will be merged with the PennyLane decomposition transform as part of -the PennyLane <> Catalyst unification project. """ from __future__ import annotations -import warnings -from collections import ChainMap -from collections.abc import Callable, Generator, Iterable, Sequence -from functools import partial +from collections.abc import Callable, Iterable, Sequence import jax import pennylane as qml @@ -34,241 +28,10 @@ # GraphSolutionInterpreter: from pennylane.decomposition import DecompositionGraph -from pennylane.decomposition.decomposition_graph import DecompGraphSolution from pennylane.decomposition.utils import translate_op_alias from pennylane.operation import Operator -# pylint: disable=too-many-instance-attributes -class PreMlirDecomposeInterpreter(qml.capture.PlxprInterpreter): - """Plxpr Interpreter for applying the Catalyst compiler-specific decomposition transform - to callables or jaxpr when program capture is enabled. - - TODO: - - Enable graph-based for pre-mlir decomposition - (not priority for this stage -- needs further maintenance in PennyLane/decomposition) - - Add a more optimized support for PL's templates - - Note: - - This interpreter shares common code with PL's DecomposeInterpreter. - We will merge the two in the future near the completion of the unification project. - """ - - def __init__( - self, - *, - gate_set=None, - max_expansion=None, - ): # pylint: disable=too-many-arguments - - self.max_expansion = max_expansion - self._current_depth = 0 - self._target_gate_names = None - - # We use a ChainMap to store the environment frames, which allows us to push and pop - # environments without copying the interpreter instance when we evaluate a jaxpr of - # a dynamic decomposition. The name is different from the _env in the parent class - # (a dictionary) to avoid confusion. - self._env_map = ChainMap() - - gate_set, stopping_condition = _resolve_gate_set(gate_set) - self._gate_set = gate_set - self._stopping_condition = stopping_condition - - def setup(self) -> None: - """Setup the environment for the interpreter by pushing a new environment frame.""" - - # This is the local environment for the jaxpr evaluation, on the top of the stack, - # from which the interpreter reads and writes variables. - # ChainMap writes to the first dictionary in the chain by default. - self._env_map = self._env_map.new_child() - - def cleanup(self) -> None: - """Cleanup the environment by popping the top-most environment frame.""" - - # We delete the top-most environment frame after the evaluation is done. - self._env_map = self._env_map.parents - - def read(self, var): - """Extract the value corresponding to a variable.""" - return var.val if isinstance(var, jax.extend.core.Literal) else self._env_map[var] - - def stopping_condition(self, op: qml.operation.Operator) -> bool: - """Function to determine whether an operator needs to be decomposed or not. - - Args: - op (qml.operation.Operator): Operator to check. - - Returns: - bool: Whether ``op`` is valid or needs to be decomposed. ``True`` means - that the operator does not need to be decomposed. - """ - - if not op.has_decomposition: - if not self._stopping_condition(op): - warnings.warn( - f"Operator {op.name} does not define a decomposition and was not " - f"found in the target gate set. To remove this warning, add the operator " - f"name ({op.name}) or type ({type(op)}) to the gate set.", - UserWarning, - ) - return True - - return self._stopping_condition(op) - - def decompose_operation(self, op: qml.operation.Operator): - """Decompose a PennyLane operation instance if it does not satisfy the - provided gate set. - - Args: - op (Operator): a pennylane operator instance - - This method is only called when the operator's output is a dropped variable, - so the output will not affect later equations in the circuit. - - See also: :meth:`~.interpret_operation_eqn`, :meth:`~.interpret_operation`. - """ - - if self._stopping_condition(op): - return self.interpret_operation(op) - - max_expansion = ( - self.max_expansion - self._current_depth if self.max_expansion is not None else None - ) - - with qml.capture.pause(): - decomposition = list( - _operator_decomposition_gen( - op, - self.stopping_condition, - max_expansion=max_expansion, - ) - ) - - return [self.interpret_operation(decomp_op) for decomp_op in decomposition] - - def _evaluate_jaxpr_decomposition(self, op: qml.operation.Operator): - """Creates and evaluates a Jaxpr of the plxpr decomposition of an operator.""" - - if self._stopping_condition(op): - return self.interpret_operation(op) - - if self.max_expansion is not None and self._current_depth >= self.max_expansion: - return self.interpret_operation(op) - - compute_qfunc_decomposition = op.compute_qfunc_decomposition - - args = (*op.parameters, *op.wires) - - jaxpr_decomp = qml.capture.make_plxpr( - partial(compute_qfunc_decomposition, **op.hyperparameters) - )(*args) - - self._current_depth += 1 - # We don't need to copy the interpreter here, as the jaxpr of the decomposition - # is evaluated with a new environment frame placed on top of the stack. - out = self.eval(jaxpr_decomp.jaxpr, jaxpr_decomp.consts, *args) - self._current_depth -= 1 - - return out - - # pylint: disable=too-many-branches, too-many-locals - def eval(self, jaxpr: jax.extend.core.Jaxpr, consts: Sequence, *args) -> list: - """ - Evaluates a jaxpr, which can also be generated by a dynamic decomposition. - - Args: - jaxpr_decomp (jax.extend.core.Jaxpr): the Jaxpr to evaluate - consts (list[TensorLike]): the constant variables for the jaxpr - *args: the arguments to use in the evaluation - """ - - self.setup() - - for arg, invar in zip(args, jaxpr.invars, strict=True): - self._env_map[invar] = arg - for const, constvar in zip(consts, jaxpr.constvars, strict=True): - self._env_map[constvar] = const - - for eq in jaxpr.eqns: - - prim_type = getattr(eq.primitive, "prim_type", "") - custom_handler = self._primitive_registrations.get(eq.primitive, None) - - if custom_handler: - - invals = [self.read(invar) for invar in eq.invars] - outvals = custom_handler(self, *invals, **eq.params) - - elif prim_type == "operator": - outvals = self.interpret_operation_eqn(eq) - elif prim_type == "measurement": - outvals = self.interpret_measurement_eqn(eq) - else: - invals = [self.read(invar) for invar in eq.invars] - subfuns, params = eq.primitive.get_bind_params(eq.params) - outvals = eq.primitive.bind(*subfuns, *invals, **params) - - if not eq.primitive.multiple_results: - outvals = [outvals] - - for outvar, outval in zip(eq.outvars, outvals, strict=True): - self._env_map[outvar] = outval - - outvals = [] - for var in jaxpr.outvars: - outval = self.read(var) - if isinstance(outval, qml.operation.Operator): - outvals.append(self.interpret_operation(outval)) - else: - outvals.append(outval) - - self.cleanup() - - return outvals - - def interpret_operation_eqn(self, eqn: jax.extend.core.JaxprEqn): - """Interpret an equation corresponding to an operator. - - If the operator has a dynamic decomposition defined, this method will - create and evaluate the jaxpr of the decomposition using the :meth:`~.eval` method. - - Args: - eqn (jax.extend.core.JaxprEqn): a jax equation for an operator. - - See also: :meth:`~.interpret_operation`. - - """ - - invals = (self.read(invar) for invar in eqn.invars) - - with qml.QueuingManager.stop_recording(): - op = eqn.primitive.impl(*invals, **eqn.params) - - if not eqn.outvars[0].__class__.__name__ == "DropVar": - return op - - return self.decompose_operation(op) - - -# pylint: disable=too-many-arguments -@PreMlirDecomposeInterpreter.register_primitive(ctrl_transform_prim) -def _(self, *invals, n_control, jaxpr, control_values, work_wires, n_consts): - consts = invals[:n_consts] - args = invals[n_consts:-n_control] - control_wires = invals[-n_control:] - - unroller = ControlTransformInterpreter( - control_wires, control_values=control_values, work_wires=work_wires - ) - - def wrapper(*inner_args): - return unroller.eval(jaxpr, consts, *inner_args) - - jaxpr = jax.make_jaxpr(wrapper)(*args) - return self.eval(jaxpr.jaxpr, jaxpr.consts, *args) - - # pylint: disable=too-few-public-methods class GraphSolutionInterpreter(qml.capture.PlxprInterpreter): """Interpreter for getting the decomposition graph solution @@ -280,7 +43,7 @@ class GraphSolutionInterpreter(qml.capture.PlxprInterpreter): def __init__( self, *, - operations, + operations=[], gate_set=None, fixed_decomps=None, alt_decomps=None, @@ -331,25 +94,6 @@ def eval(self, jaxpr: "jax.extend.core.Jaxpr", consts: Sequence, *args) -> list: alt_decomps=self._alt_decomps, ) - # for op, rule_impl in self._decomp_graph_solution.items(): - # # print(op, rule_impl) - # def compute_qfunc_decomp(): - # rule_impl(int) - - # if op.op.name in ("RX", "RY", "RZ", "PhaseShift", "Rot", "U1"): - # def compute_qfunc_decomp(): - # rule_impl(float, int) - # else: - # continue - - # jaxpr_decomp = qml.capture.make_plxpr( - # compute_qfunc_decomp - # )() - - # print(jaxpr_decomp) - # # out = self.eval(jaxpr_decomp.jaxpr, jaxpr_decomp.consts, tuple()) - # # print(out) - for eqn in jaxpr.eqns: primitive = eqn.primitive custom_handler = self._primitive_registrations.get(primitive, None) @@ -423,43 +167,6 @@ def interpret_operation(self, op): super().interpret_operation(ctrl_op) -def _operator_decomposition_gen( - op: qml.operation.Operator, - acceptance_function: Callable[[qml.operation.Operator], bool], - max_expansion: int | None = None, - current_depth=0, - decomp_graph_solution: DecompGraphSolution | None = None, -) -> Generator[qml.operation.Operator]: - """A generator that yields the next operation that is accepted.""" - - max_depth_reached = False - decomp = [] - - if max_expansion is not None and max_expansion <= current_depth: - max_depth_reached = True - - if acceptance_function(op) or max_depth_reached: - yield op - elif decomp_graph_solution is not None and decomp_graph_solution.is_solved_for(op): - op_rule = decomp_graph_solution.decomposition(op) - with qml.queuing.AnnotatedQueue() as decomposed_ops: - op_rule(*op.parameters, wires=op.wires, **op.hyperparameters) - decomp = decomposed_ops.queue - current_depth += 1 - else: - decomp = op.decomposition() - current_depth += 1 - - for sub_op in decomp: - yield from _operator_decomposition_gen( - sub_op, - acceptance_function, - max_expansion=max_expansion, - current_depth=current_depth, - decomp_graph_solution=decomp_graph_solution, - ) - - def _resolve_gate_set( gate_set: set[type | str] | dict[type | str, float] = None, stopping_condition: Callable[[qml.operation.Operator], bool] = None, diff --git a/frontend/catalyst/from_plxpr/from_plxpr.py b/frontend/catalyst/from_plxpr/from_plxpr.py index f9c73e47cd..904c0382dd 100644 --- a/frontend/catalyst/from_plxpr/from_plxpr.py +++ b/frontend/catalyst/from_plxpr/from_plxpr.py @@ -33,7 +33,6 @@ from pennylane.capture.primitives import adjoint_transform_prim as plxpr_adjoint_transform_prim from pennylane.capture.primitives import ctrl_transform_prim as plxpr_ctrl_transform_prim from pennylane.capture.primitives import measure_prim as plxpr_measure_prim -from pennylane.decomposition.collect_resource_ops import CollectResourceOps from pennylane.ftqc.primitives import measure_in_basis_prim as plxpr_measure_in_basis_prim from pennylane.ops.functions.map_wires import _map_wires_transform as pl_map_wires from pennylane.transforms import cancel_inverses as pl_cancel_inverses @@ -48,7 +47,6 @@ from catalyst.device.qjit_device import COMPILER_OPERATIONS from catalyst.from_plxpr.decompose import ( GraphSolutionInterpreter, - PreMlirDecomposeInterpreter, ) from catalyst.from_plxpr.qubit_handler import QubitHandler from catalyst.jax_extras import jaxpr_pad_consts, make_jaxpr2, transient_jax_config @@ -189,8 +187,8 @@ def __init__(self): self.qubit_handler = None # Compiler options for the new decomposition system - self.compiler_decompose = False - self.decomp_gateset = [] + self.requires_compiler_decompose = False + self.decompose_gatesets = [] # queue of gatesets super().__init__() @@ -209,7 +207,16 @@ def handle_qnode( consts = args[shots_len : n_consts + shots_len] non_const_args = args[shots_len + n_consts :] - closed_jaxpr = ClosedJaxpr(qfunc_jaxpr, consts) + closed_jaxpr = ( + ClosedJaxpr(qfunc_jaxpr, consts) + if not self.requires_compiler_decompose + else handle_compiler_decompose( + inner_jaxpr=qfunc_jaxpr, + consts=consts, + ncargs=non_const_args, + tgatesets=self.decompose_gatesets, + ) + ) def calling_convention(*args): device_init_p.bind( @@ -227,7 +234,7 @@ def calling_convention(*args): return retvals # Add gate_set attribute to the quantum kernel primitive - setattr(qnode, "decomp_gateset", self.decomp_gateset) + setattr(qnode, "decompose_gatesets", self.decompose_gatesets) return quantum_kernel_p.bind( wrap_init(calling_convention, debug_info=qfunc_jaxpr.debug_info), @@ -253,28 +260,26 @@ def calling_convention(*args): } -# pylint: disable=too-many-arguments -def handle_graph_decomposition(*args, inner_jaxpr, consts, targs, tkwargs, compiler_gateset): - """Handle the graph decomposition for a given JAXPR.""" - - decomp_kwargs = {"gate_set": compiler_gateset} - pmd_interpreter = PreMlirDecomposeInterpreter(*targs, **decomp_kwargs) +def handle_compiler_decompose(inner_jaxpr, consts, tgatesets, ncargs): + """Handle the compiler-specific decomposition for a given JAXPR.""" - def pmd_wrapper(*args): - return pmd_interpreter.eval(inner_jaxpr, consts, *args) + # disable the graph decomposition optimization + is_graph = qml.decomposition.enabled_graph() + if is_graph: + qml.decomposition.disable_graph() - pmd_jaxpr = jax.make_jaxpr(pmd_wrapper)(*args) + # First perform the pre-mlir decomposition to simplify the jaxpr + # by decomposing high-level gates and templates + gate_set = COMPILER_OPERATIONS + list(set().union(*tgatesets)) - ops_collector = CollectResourceOps() - ops_collector.eval(pmd_jaxpr.jaxpr, consts, *args) - pl_ops = ops_collector.state["ops"] - - gds_interpreter = GraphSolutionInterpreter(*targs, **tkwargs, operations=pl_ops) + final_jaxpr = qml.transforms.decompose.plxpr_transform( + inner_jaxpr, consts, (), {"gate_set": gate_set}, *ncargs + ) - def gds_wrapper(*args): - return gds_interpreter.eval(pmd_jaxpr.jaxpr, consts, *args) + if is_graph: + qml.decomposition.enable_graph() - return jax.make_jaxpr(gds_wrapper)(*args) + return final_jaxpr # pylint: disable-next=redefined-outer-name @@ -308,10 +313,8 @@ def handle_transform( and pl_plxpr_transform.__name__ == "decompose_plxpr_to_plxpr" and qml.decomposition.enabled_graph() ): - self.compiler_decompose = True - - # Update the decomp_gateset to be used by the quantum kernel primitive - self.decomp_gateset = tkwargs.get("gate_set", []) + if not self.requires_compiler_decompose: + self.requires_compiler_decompose = True # A helper function to get the name of a pennylane operator def get_operator_name(op): @@ -328,23 +331,35 @@ def get_operator_name(op): # as we deal with such ops later in the decomposition graph. return getattr(op._primitive, "name", "NoNameOp") - self.decomp_gateset = [get_operator_name(op) for op in self.decomp_gateset] - - # First decompose to the compiler gateset. - # Then, construct and solve the graph-based decomposition - # to get the optimized rules and lower them to PLxPR - # to Catalyst JAXPR to MLIR. - final_jaxpr = handle_graph_decomposition( - *args, - inner_jaxpr=inner_jaxpr, - consts=consts, - targs=targs, - tkwargs=tkwargs, - compiler_gateset=COMPILER_OPERATIONS + self.decomp_gateset, - ) + # Update the decompose_gatesets to be used by the quantum kernel primitive + tgateset = tkwargs.get("gate_set", []) + + # We treat decompose_gatesets as a queue of gatesets to be used + # by the decompose-lowering pass at MLIR + self.decompose_gatesets.insert(0, [get_operator_name(op) for op in tgateset]) + + # Note. We don't perform the compiler-specific decomposition here + # to be able to support multiple decomposition transforms + # and collect all the required gatesets + # as well as being able to support other transforms in between. + # The compiler specific transformation will be performed + # in the qnode handler. + + # Add the decompose-lowering pass to the start of the pipeline self._pass_pipeline.insert(0, Pass("decompose-lowering")) - return self.eval(final_jaxpr.jaxpr, final_jaxpr.consts, *non_const_args) + + # We still need to construct and solve the graph based on + # the current jaxpr based on the current gateset + # but we don't rewrite the jaxpr at this stage. + + gds_interpreter = GraphSolutionInterpreter(*targs, **tkwargs) + + def gds_wrapper(*args): + return gds_interpreter.eval(inner_jaxpr, consts, *args) + + final_jaxpr = jax.make_jaxpr(gds_wrapper)(*args) + return self.eval(final_jaxpr.jaxpr, consts, *non_const_args) if catalyst_pass_name is None: # Use PL's ExpandTransformsInterpreter to expand this and any embedded @@ -401,10 +416,6 @@ def __init__(self, device, shots, qubit_handler, cache, *, control_wires=(), con self.control_values = control_values """Any control values for executing a subroutine.""" - # Compiler options for the new decomposition system - self.compiler_decompose = False - self.decomp_gateset = [] - super().__init__() def interpret_operation(self, op, is_adjoint=False, control_values=(), control_wires=()): diff --git a/frontend/catalyst/jax_primitives_utils.py b/frontend/catalyst/jax_primitives_utils.py index 1f380040b3..a93b9b8078 100644 --- a/frontend/catalyst/jax_primitives_utils.py +++ b/frontend/catalyst/jax_primitives_utils.py @@ -185,9 +185,9 @@ def only_single_expval(): func_op.attributes["diff_method"] = ir.StringAttr.get(diff_method) - gateset = getattr(callable_, "decomp_gateset", []) + gateset = getattr(callable_, "decompose_gatesets", []) if gateset: - func_op.attributes["decomp_gateset"] = get_mlir_attribute_from_pyval(gateset) + func_op.attributes["decompose_gatesets"] = get_mlir_attribute_from_pyval(gateset) return func_op diff --git a/frontend/test/lit/test_decomposition.py b/frontend/test/lit/test_decomposition.py index 3cf407aaaa..bfbb94dc38 100644 --- a/frontend/test/lit/test_decomposition.py +++ b/frontend/test/lit/test_decomposition.py @@ -540,7 +540,7 @@ def test_decompose_gateset_with_graph(): @qml.qjit(target="mlir") @partial(qml.transforms.decompose, gate_set={"RX"}) @qml.qnode(qml.device("lightning.qubit", wires=1)) - # CHECK: public @simple_circuit_9() -> tensor attributes {decomp_gateset = ["RX"] + # CHECK: public @simple_circuit_9() -> tensor attributes {decompose_gatesets def simple_circuit_9(): return qml.expval(qml.Z(0)) @@ -549,7 +549,7 @@ def simple_circuit_9(): @qml.qjit(target="mlir") @partial(qml.transforms.decompose, gate_set={"RX", "RZ"}) @qml.qnode(qml.device("lightning.qubit", wires=1)) - # CHECK: public @circuit_9() -> tensor attributes {decomp_gateset + # CHECK: public @circuit_9() -> tensor attributes {decompose_gatesets def circuit_9(): return qml.expval(qml.Z(0)) @@ -571,7 +571,7 @@ def test_decompose_gateset_operator_with_graph(): @qml.qjit(target="mlir") @partial(qml.transforms.decompose, gate_set={qml.RX}) @qml.qnode(qml.device("lightning.qubit", wires=1)) - # CHECK: public @simple_circuit_10() -> tensor attributes {decomp_gateset = ["RX"] + # CHECK: public @simple_circuit_10() -> tensor attributes {decompose_gatesets def simple_circuit_10(): return qml.expval(qml.Z(0)) @@ -582,7 +582,7 @@ def simple_circuit_10(): qml.transforms.decompose, gate_set={qml.RX, qml.RZ, "PauliZ", qml.PauliX, qml.Hadamard} ) @qml.qnode(qml.device("lightning.qubit", wires=1)) - # CHECK: public @circuit_10() -> tensor attributes {decomp_gateset + # CHECK: public @circuit_10() -> tensor attributes {decompose_gatesets def circuit_10(): return qml.expval(qml.Z(0)) @@ -593,7 +593,7 @@ def circuit_10(): qml.transforms.decompose, gate_set={qml.RX, qml.RZ, qml.PauliZ, qml.PauliX, qml.Hadamard} ) @qml.qnode(qml.device("lightning.qubit", wires=1)) - # CHECK: public @circuit_11() -> tensor attributes {decomp_gateset + # CHECK: public @circuit_11() -> tensor attributes {decompose_gatesets def circuit_11(): return qml.expval(qml.Z(0)) @@ -615,7 +615,7 @@ def test_decompose_gateset_with_rotxzx(): @qml.qjit(target="mlir") @partial(qml.transforms.decompose, gate_set={"RotXZX"}) @qml.qnode(qml.device("lightning.qubit", wires=1)) - # CHECK: public @simple_circuit_12() -> tensor attributes {decomp_gateset = ["RotXZX"] + # CHECK: public @simple_circuit_12() -> tensor attributes {decompose_gatesets def simple_circuit_12(): return qml.expval(qml.Z(0)) @@ -624,7 +624,7 @@ def simple_circuit_12(): @qml.qjit(target="mlir") @partial(qml.transforms.decompose, gate_set={qml.ftqc.RotXZX}) @qml.qnode(qml.device("lightning.qubit", wires=1)) - # CHECK: public @circuit_12() -> tensor attributes {decomp_gateset = ["RotXZX"] + # CHECK: public @circuit_12() -> tensor attributes {decompose_gatesets def circuit_12(): return qml.expval(qml.Z(0)) @@ -675,7 +675,7 @@ def _xzx_decompose(phi, theta, omega, wires, **__): @qml.qjit(target="mlir") @partial(qml.transforms.decompose, gate_set={"RX", "RZ", "PhaseShift"}) @qml.qnode(qml.device("lightning.qubit", wires=3)) - # CHECK: public @circuit_13() -> tensor attributes {decomp_gateset + # CHECK: public @circuit_13() -> tensor attributes {decompose_gatesets def circuit_13(): _ry_to_rz_rx(float, int) _rot_to_rz_ry_rz(float, float, float, int) diff --git a/frontend/test/lit/test_from_plxpr.py b/frontend/test/lit/test_from_plxpr.py index 7fd8618cc6..3376e3b504 100644 --- a/frontend/test/lit/test_from_plxpr.py +++ b/frontend/test/lit/test_from_plxpr.py @@ -18,6 +18,8 @@ """Lit tests for the PLxPR to JAXPR with quantum primitives pipeline""" +from functools import partial + import pennylane as qml @@ -362,3 +364,78 @@ def circuit(): test_pass_application() + + +def test_pass_decomposition(): + """Application of pass decorator with decomposition.""" + + dev = qml.device("null.qubit", wires=1) + + qml.capture.enable() + qml.decomposition.enable_graph() + + @qml.qjit(target="mlir") + @qml.transforms.cancel_inverses + @qml.transforms.merge_rotations + @partial(qml.transforms.decompose, gate_set={"RX", "RZ"}) + @qml.qnode(dev) + def circuit(): + return qml.probs() + + # CHECK: [[first_pass:%.+]] = transform.apply_registered_pass "decompose-lowering" + # CHECK-NEXT: [[second_pass:%.+]] = transform.apply_registered_pass "merge-rotations" + # CHECK-NEXT: transform.apply_registered_pass "remove-chained-self-inverse" to [[second_pass]] + + print(circuit.mlir) + + @qml.qjit(target="mlir") + @qml.transforms.cancel_inverses + @partial(qml.transforms.decompose, gate_set={"RX", "RZ"}) + @qml.transforms.merge_rotations + @qml.qnode(dev) + def circuit(): + return qml.probs() + + # CHECK: [[first_pass:%.+]] = transform.apply_registered_pass "merge-rotations" + # CHECK-NEXT: [[second_pass:%.+]] = transform.apply_registered_pass "decompose-lowering" + # CHECK-NEXT: transform.apply_registered_pass "remove-chained-self-inverse" to [[second_pass]] + + print(circuit.mlir) + + @qml.qjit(target="mlir") + @partial(qml.transforms.decompose, gate_set={"RX", "RZ"}) + @qml.transforms.cancel_inverses + @qml.transforms.merge_rotations + @qml.qnode(dev) + def circuit(): + return qml.probs() + + # CHECK: [[first_pass:%.+]] = transform.apply_registered_pass "merge-rotations" + # CHECK-NEXT: [[second_pass:%.+]] = transform.apply_registered_pass "remove-chained-self-inverse" + # CHECK-NEXT: transform.apply_registered_pass "decompose-lowering" to [[second_pass]] + + print(circuit.mlir) + + @qml.qjit(target="mlir") + @partial(qml.transforms.decompose, gate_set={"RX"}) + @qml.transforms.cancel_inverses + @partial(qml.transforms.decompose, gate_set={"RZ"}) + @qml.transforms.merge_rotations + @partial(qml.transforms.decompose, gate_set={"RX", "RZ"}) + @qml.qnode(dev) + def circuit(): + return qml.probs() + + # CHECK: [[first_pass:%.+]] = transform.apply_registered_pass "decompose-lowering" + # CHECK-NEXT: [[merge_rot:%.+]] = transform.apply_registered_pass "merge-rotations" to [[first_pass]] + # CHECK-NEXT: [[decomp_to_rz:%.+]] = transform.apply_registered_pass "decompose-lowering" to [[merge_rot]] + # CHECK-NEXT: [[remove_chained:%.+]] = transform.apply_registered_pass "remove-chained-self-inverse" to [[decomp_to_rz]] + # CHECK-NEXT: transform.apply_registered_pass "decompose-lowering" to [[remove_chained]] + + print(circuit.mlir) + + qml.decomposition.disable_graph() + qml.capture.disable() + + +test_pass_decomposition() diff --git a/test_new_decomp.py b/test_new_decomp.py index 393be0e7fb..3c6ab32551 100644 --- a/test_new_decomp.py +++ b/test_new_decomp.py @@ -94,12 +94,25 @@ def _rot_to_xzx(phi, theta, omega, wires, **__): @qml.qjit() +@partial( + qml.transforms.decompose, + gate_set={"X", "Y", "Z", "S", "H", "CNOT", "RZ", "RotXZX", "GlobalPhase"}, + fixed_decomps={qml.Rot: _rot_to_xzx}, +) +@qml.transforms.cancel_inverses +@qml.transforms.merge_rotations +@partial( + qml.transforms.decompose, + gate_set={"X", "Y", "Z", "S", "H", "CNOT", "RZ", "RotXZX", "GlobalPhase"}, + fixed_decomps={qml.Rot: _rot_to_xzx}, +) @qml.transforms.merge_rotations @partial( qml.transforms.decompose, gate_set={"X", "Y", "Z", "S", "H", "CNOT", "RZ", "RotXZX", "GlobalPhase"}, fixed_decomps={qml.Rot: _rot_to_xzx}, ) +@qml.transforms.merge_rotations @qml.qnode(qml.device("null.qubit", wires=3)) def mbqc_circ(x: float, y: float): """MBQC example to test custom decomposition to RotXZX.""" From 6539064dfcbf5495d191da5af19e972d5094024f Mon Sep 17 00:00:00 2001 From: Ali Asadi <10773383+maliasadi@users.noreply.github.com> Date: Wed, 10 Sep 2025 14:53:39 -0400 Subject: [PATCH 14/36] code format --- frontend/catalyst/from_plxpr/decompose.py | 4 +- frontend/test/lit/test_from_plxpr.py | 16 +-- test_new_decomp.py | 129 ---------------------- 3 files changed, 10 insertions(+), 139 deletions(-) delete mode 100644 test_new_decomp.py diff --git a/frontend/catalyst/from_plxpr/decompose.py b/frontend/catalyst/from_plxpr/decompose.py index 82289124fe..1707fbb7b7 100644 --- a/frontend/catalyst/from_plxpr/decompose.py +++ b/frontend/catalyst/from_plxpr/decompose.py @@ -43,7 +43,7 @@ class GraphSolutionInterpreter(qml.capture.PlxprInterpreter): def __init__( self, *, - operations=[], + operations=None, gate_set=None, fixed_decomps=None, alt_decomps=None, @@ -55,7 +55,7 @@ def __init__( "graph-based decomposition is enabled." ) - self._operations = operations + self._operations = [] if operations is None else operations self._decomp_graph_solution = {} self._target_gate_names = None self._fixed_decomps, self._alt_decomps = fixed_decomps, alt_decomps diff --git a/frontend/test/lit/test_from_plxpr.py b/frontend/test/lit/test_from_plxpr.py index 3376e3b504..a6c2bae85e 100644 --- a/frontend/test/lit/test_from_plxpr.py +++ b/frontend/test/lit/test_from_plxpr.py @@ -379,42 +379,42 @@ def test_pass_decomposition(): @qml.transforms.merge_rotations @partial(qml.transforms.decompose, gate_set={"RX", "RZ"}) @qml.qnode(dev) - def circuit(): + def circuit1(): return qml.probs() # CHECK: [[first_pass:%.+]] = transform.apply_registered_pass "decompose-lowering" # CHECK-NEXT: [[second_pass:%.+]] = transform.apply_registered_pass "merge-rotations" # CHECK-NEXT: transform.apply_registered_pass "remove-chained-self-inverse" to [[second_pass]] - print(circuit.mlir) + print(circuit1.mlir) @qml.qjit(target="mlir") @qml.transforms.cancel_inverses @partial(qml.transforms.decompose, gate_set={"RX", "RZ"}) @qml.transforms.merge_rotations @qml.qnode(dev) - def circuit(): + def circuit2(): return qml.probs() # CHECK: [[first_pass:%.+]] = transform.apply_registered_pass "merge-rotations" # CHECK-NEXT: [[second_pass:%.+]] = transform.apply_registered_pass "decompose-lowering" # CHECK-NEXT: transform.apply_registered_pass "remove-chained-self-inverse" to [[second_pass]] - print(circuit.mlir) + print(circuit2.mlir) @qml.qjit(target="mlir") @partial(qml.transforms.decompose, gate_set={"RX", "RZ"}) @qml.transforms.cancel_inverses @qml.transforms.merge_rotations @qml.qnode(dev) - def circuit(): + def circuit3(): return qml.probs() # CHECK: [[first_pass:%.+]] = transform.apply_registered_pass "merge-rotations" # CHECK-NEXT: [[second_pass:%.+]] = transform.apply_registered_pass "remove-chained-self-inverse" # CHECK-NEXT: transform.apply_registered_pass "decompose-lowering" to [[second_pass]] - print(circuit.mlir) + print(circuit3.mlir) @qml.qjit(target="mlir") @partial(qml.transforms.decompose, gate_set={"RX"}) @@ -423,7 +423,7 @@ def circuit(): @qml.transforms.merge_rotations @partial(qml.transforms.decompose, gate_set={"RX", "RZ"}) @qml.qnode(dev) - def circuit(): + def circuit4(): return qml.probs() # CHECK: [[first_pass:%.+]] = transform.apply_registered_pass "decompose-lowering" @@ -432,7 +432,7 @@ def circuit(): # CHECK-NEXT: [[remove_chained:%.+]] = transform.apply_registered_pass "remove-chained-self-inverse" to [[decomp_to_rz]] # CHECK-NEXT: transform.apply_registered_pass "decompose-lowering" to [[remove_chained]] - print(circuit.mlir) + print(circuit4.mlir) qml.decomposition.disable_graph() qml.capture.disable() diff --git a/test_new_decomp.py b/test_new_decomp.py deleted file mode 100644 index 3c6ab32551..0000000000 --- a/test_new_decomp.py +++ /dev/null @@ -1,129 +0,0 @@ -""" -This file contains a few tests for the end-to-end custom decomposition rules - -TODO: remove the file after testing -""" - -from functools import partial - -import numpy as np -import pennylane as qml -from pennylane.ftqc import RotXZX -from pennylane.wires import WiresLike - -from catalyst.jax_primitives import decomposition_rule - -qml.capture.enable() -qml.decomposition.enable_graph() - - -###################################### -# Custom decomposition rules -###################################### - - -@decomposition_rule -def _ry_to_rz_rx(phi, wires: WiresLike, **__): - """Decomposition of RY gate using RZ and RX gates.""" - qml.RZ(-np.pi / 2, wires=wires) - qml.RX(phi, wires=wires) - qml.RZ(np.pi / 2, wires=wires) - - -@decomposition_rule -def _rot_to_rz_ry_rz(phi, theta, omega, wires: WiresLike, **__): - """Decomposition of Rot gate using RZ and RY gates.""" - qml.RZ(phi, wires=wires) - qml.RY(theta, wires=wires) - qml.RZ(omega, wires=wires) - - -@decomposition_rule -def _u2_phaseshift_rot(phi, delta, wires, **__): - """Decomposition of U2 gate using Rot and PhaseShift gates.""" - pi_half = qml.math.ones_like(delta) * (np.pi / 2) - qml.Rot(delta, pi_half, -delta, wires=wires) - qml.PhaseShift(delta, wires=wires) - qml.PhaseShift(phi, wires=wires) - - -@decomposition_rule -def _xzx_decompose(phi, theta, omega, wires, **__): - """Decomposition of Rot gate using RX and RZ gates in XZX format.""" - qml.RX(phi, wires=wires) - qml.RZ(theta, wires=wires) - qml.RX(omega, wires=wires) - - -@qml.qjit() -@partial(qml.transforms.decompose, gate_set={"RX", "RZ", "PhaseShift"}) -@qml.qnode(qml.device("lightning.qubit", wires=3)) -def circuit(): - """Circuit to test custom decomposition rules.""" - qml.RY(0.5, wires=0) - qml.Rot(0.1, 0.2, 0.3, wires=1) - qml.U2(0.4, 0.5, wires=2) - RotXZX(0.6, 0.7, 0.8, wires=0) - - _ry_to_rz_rx(0, 0) - _rot_to_rz_ry_rz(0, 0, 0, 1) - _u2_phaseshift_rot(0, 0, 2) - _xzx_decompose(0, 0, 0, 0) - - return qml.expval(qml.Z(0)) - - -print(circuit.mlir) - - -################################################### -# MBQC Example with custom decomposition to RotXZX -################################################### - -qml.decomposition.enable_graph() -qml.capture.enable() - - -@qml.register_resources({qml.ftqc.RotXZX: 1}) -@decomposition_rule -def _rot_to_xzx(phi, theta, omega, wires, **__): - """Decomposition of Rot gate using RotXZX gate.""" - mat = qml.Rot.compute_matrix(phi, theta, omega) - lam, theta, phi = qml.math.decomposition.xzx_rotation_angles(mat) - qml.ftqc.RotXZX(lam, theta, phi, wires) - - -@qml.qjit() -@partial( - qml.transforms.decompose, - gate_set={"X", "Y", "Z", "S", "H", "CNOT", "RZ", "RotXZX", "GlobalPhase"}, - fixed_decomps={qml.Rot: _rot_to_xzx}, -) -@qml.transforms.cancel_inverses -@qml.transforms.merge_rotations -@partial( - qml.transforms.decompose, - gate_set={"X", "Y", "Z", "S", "H", "CNOT", "RZ", "RotXZX", "GlobalPhase"}, - fixed_decomps={qml.Rot: _rot_to_xzx}, -) -@qml.transforms.merge_rotations -@partial( - qml.transforms.decompose, - gate_set={"X", "Y", "Z", "S", "H", "CNOT", "RZ", "RotXZX", "GlobalPhase"}, - fixed_decomps={qml.Rot: _rot_to_xzx}, -) -@qml.transforms.merge_rotations -@qml.qnode(qml.device("null.qubit", wires=3)) -def mbqc_circ(x: float, y: float): - """MBQC example to test custom decomposition to RotXZX.""" - qml.RX(x, 0) - qml.RY(y, 1) - - _rot_to_xzx(float, float, float, int) - _ry_to_rz_rx(float, int) - _xzx_decompose(float, float, float, int) - - return qml.expval(qml.Z(0)), qml.expval(qml.Z(1)) - - -print(mbqc_circ.mlir) From c2f5ea2d72f355ec001bbeca25ad11c72fd4137f Mon Sep 17 00:00:00 2001 From: Ali Asadi <10773383+maliasadi@users.noreply.github.com> Date: Mon, 15 Sep 2025 14:24:17 -0400 Subject: [PATCH 15/36] Tidy up --- frontend/catalyst/from_plxpr/decompose.py | 233 ++++++++------------- frontend/catalyst/from_plxpr/from_plxpr.py | 2 - 2 files changed, 84 insertions(+), 151 deletions(-) diff --git a/frontend/catalyst/from_plxpr/decompose.py b/frontend/catalyst/from_plxpr/decompose.py index 1707fbb7b7..94bbd38e2b 100644 --- a/frontend/catalyst/from_plxpr/decompose.py +++ b/frontend/catalyst/from_plxpr/decompose.py @@ -18,32 +18,59 @@ from __future__ import annotations -from collections.abc import Callable, Iterable, Sequence +import inspect +from collections.abc import Callable +from copy import copy +from typing import get_type_hints import jax import pennylane as qml -# Support ctrl ops in decomposition (adapted from PL's DecomposeInterpreter) -from pennylane.capture.primitives import ctrl_transform_prim - # GraphSolutionInterpreter: from pennylane.decomposition import DecompositionGraph -from pennylane.decomposition.utils import translate_op_alias -from pennylane.operation import Operator +from pennylane.measurements import MidMeasureMP +from pennylane.wires import WiresLike + +from catalyst.jax_primitives import decomposition_rule + + +def create_decomposition_rule(func: Callable, num_wires: int = 1): + """Create a decomposition rule from a function.""" + + sig_func = inspect.signature(func) + type_hints = get_type_hints(func) + + args = {} + for name in sig_func.parameters.keys(): + typ = type_hints.get(name, None) + + # Skip tailing kwargs in the rules + if name == "__": + continue + + if typ is float or name in ("phi", "theta", "omega", "delta"): + args[name] = float + elif typ is int: + args[name] = int + elif typ is WiresLike or name == "wires": + args[name] = qml.math.array([0] * num_wires, like="jax") + else: + raise ValueError( + f"Unsupported type annotation {typ} for parameter {name} in func {func}." + ) + + return decomposition_rule(func)(**args) # pylint: disable=too-few-public-methods class GraphSolutionInterpreter(qml.capture.PlxprInterpreter): """Interpreter for getting the decomposition graph solution from a jaxpr when program capture is enabled. - - This interpreter should be used after the PreMlirDecomposeInterpreter. """ def __init__( self, *, - operations=None, gate_set=None, fixed_decomps=None, alt_decomps=None, @@ -55,165 +82,73 @@ def __init__( "graph-based decomposition is enabled." ) - self._operations = [] if operations is None else operations - self._decomp_graph_solution = {} - self._target_gate_names = None - self._fixed_decomps, self._alt_decomps = fixed_decomps, alt_decomps - - gate_set, _ = _resolve_gate_set(gate_set) self._gate_set = gate_set - self._env = {} + self._fixed_decomps = fixed_decomps + self._alt_decomps = alt_decomps + + self._captured = False + self._operations = set() + self._decomp_graph_solution = {} - # pylint: disable=too-many-branches, too-many-locals - def eval(self, jaxpr: "jax.extend.core.Jaxpr", consts: Sequence, *args) -> list: - """Evaluate a jaxpr. + def interpret_operation(self, op: "qml.operation.Operator"): + """Interpret a PennyLane operation instance. Args: - jaxpr (jax.extend.core.Jaxpr): the jaxpr to evaluate - consts (list[TensorLike]): the constant variables for the jaxpr - *args (tuple[TensorLike]): The arguments for the jaxpr. + op (Operator): a pennylane operator instance Returns: - list[TensorLike]: the results of the execution. + Any - """ - self._env = {} - self.setup() + This method is only called when the operator's output is a dropped variable, + so the output will not affect later equations in the circuit. - for arg, invar in zip(args, jaxpr.invars, strict=True): - self._env[invar] = arg - for const, constvar in zip(consts, jaxpr.constvars, strict=True): - self._env[constvar] = const + We cache the list of operations seen during the interpretation + to build the decomposition graph in the later stages. - if self._operations and not self._decomp_graph_solution: + See also: :meth:`~.interpret_operation_eqn`. - self._decomp_graph_solution = _solve_decomposition_graph( - self._operations, - self._gate_set, - fixed_decomps=self._fixed_decomps, - alt_decomps=self._alt_decomps, - ) + """ - for eqn in jaxpr.eqns: - primitive = eqn.primitive - custom_handler = self._primitive_registrations.get(primitive, None) - - if custom_handler: - invals = [self.read(invar) for invar in eqn.invars] - outvals = custom_handler(self, *invals, **eqn.params) - elif getattr(primitive, "prim_type", "") == "operator": - outvals = self.interpret_operation_eqn(eqn) - elif getattr(primitive, "prim_type", "") == "measurement": - outvals = self.interpret_measurement_eqn(eqn) - else: - invals = [self.read(invar) for invar in eqn.invars] - subfuns, params = primitive.get_bind_params(eqn.params) - outvals = primitive.bind(*subfuns, *invals, **params) - - if not primitive.multiple_results: - outvals = [outvals] - for outvar, outval in zip(eqn.outvars, outvals, strict=True): - self._env[outvar] = outval - - # Read the final result of the Jaxpr from the environment - outvals = [] - for var in jaxpr.outvars: - outval = self.read(var) - if isinstance(outval, qml.operation.Operator): - outvals.append(self.interpret_operation(outval)) - else: - outvals.append(outval) - self.cleanup() - self._env = {} - return outvals - - -# pylint: disable=too-many-arguments -@GraphSolutionInterpreter.register_primitive(ctrl_transform_prim) -def _(self, *invals, n_control, jaxpr, control_values, work_wires, n_consts): - consts = invals[:n_consts] - args = invals[n_consts:-n_control] - control_wires = invals[-n_control:] - - unroller = ControlTransformInterpreter( - control_wires, control_values=control_values, work_wires=work_wires - ) + self._operations.add(op) + data, struct = jax.tree_util.tree_flatten(op) + return jax.tree_util.tree_unflatten(struct, data) - def wrapper(*inner_args): - return unroller.eval(jaxpr, consts, *inner_args) + def interpret_measurement(self, measurement: "qml.measurement.MeasurementProcess"): + """Interpret a measurement process instance. - jaxpr = jax.make_jaxpr(wrapper)(*args) - return self.eval(jaxpr.jaxpr, jaxpr.consts, *args) + Args: + measurement (MeasurementProcess): a measurement instance. + See also :meth:`~.interpret_measurement_eqn`. -class ControlTransformInterpreter(qml.capture.PlxprInterpreter): - """Interpreter for replacing control transforms with individually controlled ops.""" + """ - def __init__(self, control_wires, control_values=None, work_wires=None): - super().__init__() - self.control_wires = control_wires - self.control_values = control_values - self.work_wires = work_wires + if not self._captured and not isinstance(measurement, MidMeasureMP): + self._captured = True + if self._fixed_decomps: + for rule in self._fixed_decomps.values(): + create_decomposition_rule(rule._impl) - def interpret_operation(self, op): - """Interpret operation.""" - with qml.capture.pause(): - ctrl_op = qml.ctrl( - op, - self.control_wires, - control_values=self.control_values, - work_wires=self.work_wires, + self._decomp_graph_solution = _solve_decomposition_graph( + self._operations, + self._gate_set, + fixed_decomps=self._fixed_decomps, + alt_decomps=self._alt_decomps, ) - super().interpret_operation(ctrl_op) - - -def _resolve_gate_set( - gate_set: set[type | str] | dict[type | str, float] = None, - stopping_condition: Callable[[qml.operation.Operator], bool] = None, -) -> tuple[set[type | str] | dict[type | str, float], Callable[[qml.operation.Operator], bool]]: - """Resolve the gate set and the stopping condition from arguments.""" - - if gate_set is None: - gate_set = set(qml.ops.__all__) - - if isinstance(gate_set, (str, type)): - gate_set = {gate_set} - - if isinstance(gate_set, dict): - - if any(v < 0 for v in gate_set.values()): - raise ValueError("Negative gate weights provided to gate_set are not supported.") - - if isinstance(gate_set, Iterable): - - gate_types = tuple(gate for gate in gate_set if isinstance(gate, type)) - gate_names = {translate_op_alias(gate) for gate in gate_set if isinstance(gate, str)} - - def gate_set_contains(op: Operator) -> bool: - return (op.name in gate_names) or isinstance(op, gate_types) - - elif isinstance(gate_set, Callable): # pylint:disable=isinstance-second-argument-not-valid-type - - gate_set_contains = gate_set - - else: - raise TypeError("Invalid gate_set type. Must be an iterable, dictionary, or function.") - - if stopping_condition: - - # Even when the user provides a stopping condition, we still need to check - # whether an operator belongs to the target gate set. This is to prevent - # the case of an operator missing the stopping condition but doesn't have - # a decomposition assigned due to being in the target gate set. - def _stopping_condition(op): - return gate_set_contains(op) or stopping_condition(op) - - else: - # If the stopping condition is not explicitly provided, the default is to simply check - # whether an operator belongs to the target gate set. - _stopping_condition = gate_set_contains - return gate_set, _stopping_condition + captured_ops = copy(self._operations) + for op, rule in self._decomp_graph_solution.items(): + for o in captured_ops: + if o.name == op.op.name: + create_decomposition_rule(rule, num_wires=len(o.wires)) + captured_ops.remove(o) + break + else: + # else query the number of wires by name + create_decomposition_rule(rule, num_wires=1) + + data, struct = jax.tree_util.tree_flatten(measurement) + return jax.tree_util.tree_unflatten(struct, data) # pylint: disable=protected-access diff --git a/frontend/catalyst/from_plxpr/from_plxpr.py b/frontend/catalyst/from_plxpr/from_plxpr.py index 904c0382dd..aab8c97b2a 100644 --- a/frontend/catalyst/from_plxpr/from_plxpr.py +++ b/frontend/catalyst/from_plxpr/from_plxpr.py @@ -20,7 +20,6 @@ from typing import Callable import jax -import jax.core import jax.numpy as jnp import pennylane as qml from jax._src.sharding_impls import UNSPECIFIED @@ -607,7 +606,6 @@ def handle_decomposition_rule(self, *, pyfun, func_jaxpr, is_qreg, num_params): """ Transform a quantum decomposition rule from PLxPR into JAXPR with quantum primitives. """ - if is_qreg: self.qubit_handler.insert_all_dangling_qubits() From b2237c093f6c659be4e9f05c48fa5dcd24c7b307 Mon Sep 17 00:00:00 2001 From: Ali Asadi <10773383+maliasadi@users.noreply.github.com> Date: Tue, 16 Sep 2025 02:10:24 -0400 Subject: [PATCH 16/36] Update tests --- frontend/catalyst/from_plxpr/decompose.py | 63 +++++++-- frontend/catalyst/jax_primitives_utils.py | 8 +- frontend/test/lit/test_decomposition.py | 148 +++++++++++++++++++++- 3 files changed, 193 insertions(+), 26 deletions(-) diff --git a/frontend/catalyst/from_plxpr/decompose.py b/frontend/catalyst/from_plxpr/decompose.py index 94bbd38e2b..1cb8afb46c 100644 --- a/frontend/catalyst/from_plxpr/decompose.py +++ b/frontend/catalyst/from_plxpr/decompose.py @@ -33,8 +33,45 @@ from catalyst.jax_primitives import decomposition_rule - -def create_decomposition_rule(func: Callable, num_wires: int = 1): +COMPILER_OPERATIONS_NUM_WIRES = { + "CNOT": 2, + "ControlledPhaseShift": 2, + "CRot": 2, + "CRX": 2, + "CRY": 2, + "CRZ": 2, + "CSWAP": 3, + "CY": 2, + "CZ": 2, + "Hadamard": 1, + "Identity": 1, + "IsingXX": 2, + "IsingXY": 2, + "IsingYY": 2, + "IsingZZ": 2, + "SingleExcitation": 2, + "DoubleExcitation": 4, + "ISWAP": 2, + "PauliX": 1, + "PauliY": 1, + "PauliZ": 1, + "PhaseShift": 1, + "PSWAP": 2, + "Rot": 1, + "RX": 1, + "RY": 1, + "RZ": 1, + "S": 1, + "SWAP": 2, + "T": 1, + "Toffoli": 3, + "U1": 1, + "U2": 1, + "U3": 1, +} + + +def create_decomposition_rule(func: Callable, op_name: str, num_wires: int): """Create a decomposition rule from a function.""" sig_func = inspect.signature(func) @@ -59,6 +96,10 @@ def create_decomposition_rule(func: Callable, num_wires: int = 1): f"Unsupported type annotation {typ} for parameter {name} in func {func}." ) + # Update the name of decomposition rule + rule_name = "_rule" if func.__name__[0] == "_" else "_rule_" + func.__name__ = op_name + rule_name + func.__name__ + "_wires_" + str(num_wires) + return decomposition_rule(func)(**args) @@ -125,10 +166,6 @@ def interpret_measurement(self, measurement: "qml.measurement.MeasurementProcess if not self._captured and not isinstance(measurement, MidMeasureMP): self._captured = True - if self._fixed_decomps: - for rule in self._fixed_decomps.values(): - create_decomposition_rule(rule._impl) - self._decomp_graph_solution = _solve_decomposition_graph( self._operations, self._gate_set, @@ -138,14 +175,14 @@ def interpret_measurement(self, measurement: "qml.measurement.MeasurementProcess captured_ops = copy(self._operations) for op, rule in self._decomp_graph_solution.items(): - for o in captured_ops: - if o.name == op.op.name: - create_decomposition_rule(rule, num_wires=len(o.wires)) - captured_ops.remove(o) - break + + if (o := next((o for o in captured_ops if o.name == op.op.name), None)) is not None: + create_decomposition_rule(rule, op_name=op.op.name, num_wires=len(o.wires)) + elif op.op.name in COMPILER_OPERATIONS_NUM_WIRES: + num_wires = COMPILER_OPERATIONS_NUM_WIRES[op.op.name] + create_decomposition_rule(rule, op_name=op.op.name, num_wires=num_wires) else: - # else query the number of wires by name - create_decomposition_rule(rule, num_wires=1) + raise ValueError(f"Could not capture {op} without the number of wires.") data, struct = jax.tree_util.tree_flatten(measurement) return jax.tree_util.tree_unflatten(struct, data) diff --git a/frontend/catalyst/jax_primitives_utils.py b/frontend/catalyst/jax_primitives_utils.py index a93b9b8078..7faa679c7e 100644 --- a/frontend/catalyst/jax_primitives_utils.py +++ b/frontend/catalyst/jax_primitives_utils.py @@ -140,13 +140,7 @@ def lower_callable_to_funcop(ctx, callable_, call_jaxpr, public=False): else: name = callable_.func.__name__ + ".partial" - # Make the function name more descriptive if it is a decomposition rule. - # This is expected by the MLIR decomposition pass. - kwargs["name"] = ( - "rule" + name - if public and name[0] == "_" and ("_to_" in name or "decompos" in name) - else name - ) + kwargs["name"] = name kwargs["jaxpr"] = call_jaxpr kwargs["effects"] = [] kwargs["name_stack"] = ctx.name_stack diff --git a/frontend/test/lit/test_decomposition.py b/frontend/test/lit/test_decomposition.py index bfbb94dc38..1dc361b253 100644 --- a/frontend/test/lit/test_decomposition.py +++ b/frontend/test/lit/test_decomposition.py @@ -549,6 +549,7 @@ def simple_circuit_9(): @qml.qjit(target="mlir") @partial(qml.transforms.decompose, gate_set={"RX", "RZ"}) @qml.qnode(qml.device("lightning.qubit", wires=1)) + # CHECK-DAG: %0 = transform.apply_registered_pass "decompose-lowering" # CHECK: public @circuit_9() -> tensor attributes {decompose_gatesets def circuit_9(): return qml.expval(qml.Z(0)) @@ -593,6 +594,7 @@ def circuit_10(): qml.transforms.decompose, gate_set={qml.RX, qml.RZ, qml.PauliZ, qml.PauliX, qml.Hadamard} ) @qml.qnode(qml.device("lightning.qubit", wires=1)) + # CHECK-DAG: %0 = transform.apply_registered_pass "decompose-lowering" # CHECK: public @circuit_11() -> tensor attributes {decompose_gatesets def circuit_11(): return qml.expval(qml.Z(0)) @@ -624,6 +626,7 @@ def simple_circuit_12(): @qml.qjit(target="mlir") @partial(qml.transforms.decompose, gate_set={qml.ftqc.RotXZX}) @qml.qnode(qml.device("lightning.qubit", wires=1)) + # CHECK-DAG: %0 = transform.apply_registered_pass "decompose-lowering" # CHECK: public @circuit_12() -> tensor attributes {decompose_gatesets def circuit_12(): return qml.expval(qml.Z(0)) @@ -637,8 +640,8 @@ def circuit_12(): test_decompose_gateset_with_rotxzx() -def test_decomposition_rule_name_update(): - """Test the name of the decomposition rule is updated in the MLIR output.""" +def test_decomposition_rule_name(): + """Test the name of the decomposition rule is not updated with circuit instantiation.""" qml.capture.enable() qml.decomposition.enable_graph() @@ -675,6 +678,7 @@ def _xzx_decompose(phi, theta, omega, wires, **__): @qml.qjit(target="mlir") @partial(qml.transforms.decompose, gate_set={"RX", "RZ", "PhaseShift"}) @qml.qnode(qml.device("lightning.qubit", wires=3)) + # CHECK-DAG: %0 = transform.apply_registered_pass "decompose-lowering" # CHECK: public @circuit_13() -> tensor attributes {decompose_gatesets def circuit_13(): _ry_to_rz_rx(float, int) @@ -683,14 +687,146 @@ def circuit_13(): _xzx_decompose(float, float, float, int) return qml.expval(qml.Z(0)) - # CHECK: func.func public @rule_ry_to_rz_rx(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor) -> !quantum.reg - # CHECK: func.func public @rule_rot_to_rz_ry_rz(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor) -> !quantum.reg - # CHECK: func.func public @rule_u2_phaseshift_rot_decomposition(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor, %arg3: tensor) -> !quantum.reg - # CHECK: func.func public @rule_xzx_decompose(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor) -> !quantum.reg + # CHECK: func.func public @_ry_to_rz_rx(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor) -> !quantum.reg + # CHECK: func.func public @_rot_to_rz_ry_rz(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor) -> !quantum.reg + # CHECK: func.func public @_u2_phaseshift_rot_decomposition(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor, %arg3: tensor) -> !quantum.reg + # CHECK: func.func public @_xzx_decompose(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor) -> !quantum.reg print(circuit_13.mlir) qml.decomposition.disable_graph() qml.capture.disable() +test_decomposition_rule_name() + + +def test_decomposition_rule_name_update(): + """Test the name of the decomposition rule is updated in the MLIR output.""" + + qml.capture.enable() + qml.decomposition.enable_graph() + + @qml.register_resources({qml.RZ: 2, qml.RX: 1}) + def rz_rx(phi, wires: WiresLike, **__): + """Decomposition of RY gate using RZ and RX gates.""" + qml.RZ(-np.pi / 2, wires=wires) + qml.RX(phi, wires=wires) + qml.RZ(np.pi / 2, wires=wires) + + @qml.register_resources({qml.RZ: 2, qml.RY: 1}) + def rz_ry_rz(phi, theta, omega, wires: WiresLike, **__): + """Decomposition of Rot gate using RZ and RY gates.""" + qml.RZ(phi, wires=wires) + qml.RY(theta, wires=wires) + qml.RZ(omega, wires=wires) + + @qml.register_resources({qml.RY: 1, qml.PhaseShift: 1}) + def ry_gp(wires: WiresLike, **__): + """Decomposition of PauliY gate using RY and GlobalPhase gates.""" + qml.RY(np.pi, wires=wires) + qml.GlobalPhase(-np.pi / 2, wires=wires) + + @qml.qjit(target="mlir") + @partial( + qml.transforms.decompose, + gate_set={"RX", "RZ", "PhaseShift"}, + fixed_decomps={ + qml.RY: rz_rx, + qml.Rot: rz_ry_rz, + qml.PauliY: ry_gp, + }, + ) + @qml.qnode(qml.device("lightning.qubit", wires=3)) + # CHECK-DAG: %0 = transform.apply_registered_pass "decompose-lowering" + # CHECK: public @circuit_14() -> tensor attributes {decompose_gatesets + def circuit_14(): + qml.RY(0.5, wires=0) + qml.Rot(0.1, 0.2, 0.3, wires=1) + qml.PauliY(wires=2) + return qml.expval(qml.Z(0)) + + # CHECK-DAG: func.func public @Rot_rule_rz_ry_rz_wires_1(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor<1xi64>) -> !quantum.reg + # CHECK-DAG: func.func public @RY_rule_rz_rx_wires_1(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg + # CHECK-DAG: func.func public @PauliY_rule_ry_gp_wires_1(%arg0: !quantum.reg, %arg1: tensor<1xi64>) -> !quantum.reg + print(circuit_14.mlir) + + qml.decomposition.disable_graph() + qml.capture.disable() + + test_decomposition_rule_name_update() + + +def test_decomposition_rule_name_update_multi_qubits(): + """Test the name of the decomposition rule with multi-qubit gates.""" + + qml.capture.enable() + qml.decomposition.enable_graph() + + @qml.qjit(target="mlir") + @partial( + qml.transforms.decompose, + gate_set={"RY", "RX", "CNOT", "Hadamard", "GlobalPhase"}, + ) + @qml.qnode(qml.device("lightning.qubit", wires=4)) + # CHECK-DAG: %0 = transform.apply_registered_pass "decompose-lowering" + # CHECK: public @circuit_15() -> tensor attributes {decompose_gatesets + def circuit_15(): + qml.SingleExcitation(0.5, wires=[0, 1]) + qml.SingleExcitationPlus(0.5, wires=[0, 1]) + qml.SingleExcitationMinus(0.5, wires=[0, 1]) + qml.DoubleExcitation(0.5, wires=[0, 1, 2, 3]) + return qml.expval(qml.Z(0)) + + # CHECK-DAG: func.func public @SingleExcitationPlus_rule_single_excitation_plus_decomp_wires_2(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<2xi64>) -> !quantum.reg + # CHECK-DAG: func.func public @CY_rule_cy_wires_2(%arg0: !quantum.reg, %arg1: tensor<2xi64>) -> !quantum.reg + # CHECK-DAG: func.func public @CRY_rule_cry_wires_2(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<2xi64>) -> !quantum.reg + # CHECK-DAG: func.func public @S_rule_s_phaseshift_wires_1(%arg0: !quantum.reg, %arg1: tensor<1xi64>) -> !quantum.reg + # CHECK-DAG: func.func public @PhaseShift_rule_phaseshift_to_rz_gp_wires_1(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg + # CHECK-DAG: func.func public @RZ_rule_rz_to_ry_rx_wires_1(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg + # CHECK-DAG: func.func public @Rot_rule_rot_to_rz_ry_rz_wires_1(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor<1xi64>) -> !quantum.reg + # CHECK-DAG: func.func public @DoubleExcitation_rule_doublexcit_wires_4(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<4xi64>) -> !quantum.reg + # CHECK-DAG: func.func public @SingleExcitationMinus_rule_single_excitation_minus_decomp_wires_2(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<2xi64>) -> !quantum.reg + # CHECK-DAG: func.func public @SingleExcitation_rule_single_excitation_decomp_wires_2(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<2xi64>) -> !quantum.reg + print(circuit_15.mlir) + + qml.decomposition.disable_graph() + qml.capture.disable() + + +test_decomposition_rule_name_update_multi_qubits() + + +def test_decomposition_rule_name_adjoint(): + """Test decomposition rule with qml.adjoint.""" + + qml.capture.enable() + qml.decomposition.enable_graph() + + @qml.qjit(target="mlir") + @partial( + qml.transforms.decompose, + gate_set={"RY", "RX", "CZ", "GlobalPhase"}, + ) + @qml.qnode(qml.device("lightning.qubit", wires=4)) + # CHECK: public @circuit_16() -> tensor attributes {decompose_gatesets + def circuit_16(): + # CHECK-DAG: %1 = quantum.adjoint(%0) : !quantum.reg + # CHECK-DAG: %2 = quantum.adjoint(%1) : !quantum.reg + # CHECK-DAG: %3 = quantum.adjoint(%2) : !quantum.reg + # CHECK-DAG: %4 = quantum.adjoint(%3) : !quantum.reg + qml.adjoint(qml.CNOT)(wires=[0, 1]) + qml.adjoint(qml.Hadamard)(wires=2) + qml.adjoint(qml.RZ)(0.5, wires=3) + qml.adjoint(qml.SingleExcitation)(0.1, wires=[0, 1]) + return qml.expval(qml.Z(0)) + + # CHECK-DAG: func.func public @CNOT_rule_cnot_to_cz_h_wires_2(%arg0: !quantum.reg, %arg1: tensor<2xi64>) -> !quantum.reg + # CHECK-DAG: func.func public @Hadamard_rule_hadamard_to_rz_ry_wires_1(%arg0: !quantum.reg, %arg1: tensor<1xi64>) -> !quantum.reg + # CHECK-DAG: func.func public @SingleExcitation_rule_SingleExcitation_rule_single_excitation_decomp_wires_2_wires_2(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<2xi64>) -> !quantum.reg + print(circuit_16.mlir) + + qml.decomposition.disable_graph() + qml.capture.disable() + +test_decomposition_rule_name_adjoint() From 51c717958487606c52a2ea5bdb4ea20e7a534c8d Mon Sep 17 00:00:00 2001 From: Ali Asadi <10773383+maliasadi@users.noreply.github.com> Date: Tue, 16 Sep 2025 07:08:27 -0400 Subject: [PATCH 17/36] Apply code review suggestions --- frontend/catalyst/from_plxpr/decompose.py | 9 ++----- frontend/catalyst/from_plxpr/from_plxpr.py | 11 ++++---- frontend/test/lit/test_decomposition.py | 30 ++++++++++++++++++++++ 3 files changed, 38 insertions(+), 12 deletions(-) diff --git a/frontend/catalyst/from_plxpr/decompose.py b/frontend/catalyst/from_plxpr/decompose.py index 1cb8afb46c..28c2933c56 100644 --- a/frontend/catalyst/from_plxpr/decompose.py +++ b/frontend/catalyst/from_plxpr/decompose.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -A transform for the new MLIT-based Catalyst decomposition system. +A transform for the new MLIR-based Catalyst decomposition system. """ @@ -175,7 +175,6 @@ def interpret_measurement(self, measurement: "qml.measurement.MeasurementProcess captured_ops = copy(self._operations) for op, rule in self._decomp_graph_solution.items(): - if (o := next((o for o in captured_ops if o.name == op.op.name), None)) is not None: create_decomposition_rule(rule, op_name=op.op.name, num_wires=len(o.wires)) elif op.op.name in COMPILER_OPERATIONS_NUM_WIRES: @@ -211,11 +210,7 @@ def is_solved_for(op): and solutions._all_op_indices[op] in solutions._visitor.distances ) - for ( - op_node, - op_node_idx, - ) in solutions._all_op_indices.items(): - + for op_node, op_node_idx in solutions._all_op_indices.items(): if is_solved_for(op_node) and op_node_idx in solutions._visitor.predecessors: d_node_idx = solutions._visitor.predecessors[op_node_idx] decomp_graph_solution[op_node] = solutions._graph[d_node_idx].rule._impl diff --git a/frontend/catalyst/from_plxpr/from_plxpr.py b/frontend/catalyst/from_plxpr/from_plxpr.py index aab8c97b2a..f04470b812 100644 --- a/frontend/catalyst/from_plxpr/from_plxpr.py +++ b/frontend/catalyst/from_plxpr/from_plxpr.py @@ -209,7 +209,7 @@ def handle_qnode( closed_jaxpr = ( ClosedJaxpr(qfunc_jaxpr, consts) if not self.requires_compiler_decompose - else handle_compiler_decompose( + else apply_compiler_decompose_to_plxpr( inner_jaxpr=qfunc_jaxpr, consts=consts, ncargs=non_const_args, @@ -232,8 +232,9 @@ def calling_convention(*args): device_release_p.bind() return retvals - # Add gate_set attribute to the quantum kernel primitive - setattr(qnode, "decompose_gatesets", self.decompose_gatesets) + if self.requires_compiler_decompose: + # Add gate_set attribute to the quantum kernel primitive + setattr(qnode, "decompose_gatesets", self.decompose_gatesets) return quantum_kernel_p.bind( wrap_init(calling_convention, debug_info=qfunc_jaxpr.debug_info), @@ -259,8 +260,8 @@ def calling_convention(*args): } -def handle_compiler_decompose(inner_jaxpr, consts, tgatesets, ncargs): - """Handle the compiler-specific decomposition for a given JAXPR.""" +def apply_compiler_decompose_to_plxpr(inner_jaxpr, consts, tgatesets, ncargs): + """Apply the compiler-specific decomposition for a given JAXPR.""" # disable the graph decomposition optimization is_graph = qml.decomposition.enabled_graph() diff --git a/frontend/test/lit/test_decomposition.py b/frontend/test/lit/test_decomposition.py index 1dc361b253..06a4f526ad 100644 --- a/frontend/test/lit/test_decomposition.py +++ b/frontend/test/lit/test_decomposition.py @@ -829,4 +829,34 @@ def circuit_16(): qml.decomposition.disable_graph() qml.capture.disable() + test_decomposition_rule_name_adjoint() + + +def test_decomposition_rule_name_ctrl(): + """Test decomposition rule with qml.ctrl.""" + + qml.capture.enable() + qml.decomposition.enable_graph() + + @qml.qjit(target="mlir") + @partial( + qml.transforms.decompose, + gate_set={"RX", "RZ"}, + ) + @qml.qnode(qml.device("lightning.qubit", wires=5)) + # CHECK: public @circuit_17() -> tensor attributes {decompose_gatesets + def circuit_17(): + # CHECK: %out_qubits:2 = quantum.custom "CRY"(%cst) %1, %2 : !quantum.bit, !quantum.bit + qml.ctrl(qml.RY, control=0)(0.5, 1) + qml.ctrl(qml.PauliX, control=0)(1) + return qml.expval(qml.Z(0)) + + # CHECK-DAG: func.func public @RY_rule_ry_to_rz_rx_wires_1(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg + print(circuit_17.mlir) + + qml.decomposition.disable_graph() + qml.capture.disable() + + +test_decomposition_rule_name_ctrl() From 2a4d1b48cfc004def9adcfe1e746d6499e3407a9 Mon Sep 17 00:00:00 2001 From: Ali Asadi <10773383+maliasadi@users.noreply.github.com> Date: Thu, 18 Sep 2025 00:25:26 -0400 Subject: [PATCH 18/36] Apply code review suggestions --- frontend/catalyst/from_plxpr/decompose.py | 284 ++++++++++------ frontend/catalyst/from_plxpr/from_plxpr.py | 18 +- frontend/catalyst/jax_primitives_utils.py | 13 +- frontend/catalyst/passes/builtin_passes.py | 1 + frontend/test/lit/test_decomposition.py | 34 +- .../from_plxpr/test_from_plxpr_decompose.py | 314 ------------------ 6 files changed, 220 insertions(+), 444 deletions(-) delete mode 100644 frontend/test/pytest/from_plxpr/test_from_plxpr_decompose.py diff --git a/frontend/catalyst/from_plxpr/decompose.py b/frontend/catalyst/from_plxpr/decompose.py index 28c2933c56..a1ac8b202f 100644 --- a/frontend/catalyst/from_plxpr/decompose.py +++ b/frontend/catalyst/from_plxpr/decompose.py @@ -33,82 +33,76 @@ from catalyst.jax_primitives import decomposition_rule -COMPILER_OPERATIONS_NUM_WIRES = { - "CNOT": 2, - "ControlledPhaseShift": 2, - "CRot": 2, - "CRX": 2, - "CRY": 2, - "CRZ": 2, - "CSWAP": 3, - "CY": 2, - "CZ": 2, - "Hadamard": 1, - "Identity": 1, - "IsingXX": 2, - "IsingXY": 2, - "IsingYY": 2, - "IsingZZ": 2, - "SingleExcitation": 2, - "DoubleExcitation": 4, - "ISWAP": 2, - "PauliX": 1, - "PauliY": 1, - "PauliZ": 1, - "PhaseShift": 1, - "PSWAP": 2, - "Rot": 1, - "RX": 1, - "RY": 1, - "RZ": 1, - "S": 1, - "SWAP": 2, - "T": 1, - "Toffoli": 3, - "U1": 1, - "U2": 1, - "U3": 1, -} - - -def create_decomposition_rule(func: Callable, op_name: str, num_wires: int): - """Create a decomposition rule from a function.""" - - sig_func = inspect.signature(func) - type_hints = get_type_hints(func) - - args = {} - for name in sig_func.parameters.keys(): - typ = type_hints.get(name, None) - - # Skip tailing kwargs in the rules - if name == "__": - continue - - if typ is float or name in ("phi", "theta", "omega", "delta"): - args[name] = float - elif typ is int: - args[name] = int - elif typ is WiresLike or name == "wires": - args[name] = qml.math.array([0] * num_wires, like="jax") - else: - raise ValueError( - f"Unsupported type annotation {typ} for parameter {name} in func {func}." - ) - - # Update the name of decomposition rule - rule_name = "_rule" if func.__name__[0] == "_" else "_rule_" - func.__name__ = op_name + rule_name + func.__name__ + "_wires_" + str(num_wires) - - return decomposition_rule(func)(**args) - # pylint: disable=too-few-public-methods class GraphSolutionInterpreter(qml.capture.PlxprInterpreter): """Interpreter for getting the decomposition graph solution from a jaxpr when program capture is enabled. + + This interpreter captures all operations seen during the interpretation + and builds a decomposition graph to find efficient decomposition pathways + to a target gate set. + + This interpreter should be used with `qml.decomposition.enable_graph()` + to enable graph-based decomposition. + + Note that this doesn't actually decompose the operations during interpretation. + It only captures the operations and builds the decomposition graph. + The actual decomposition is done later in the MLIR decomposition pass. + + See also: :class:`~.DecompositionGraph`. + + Args: + gate_set (set[Operator] or None): The target gate set to decompose to + fixed_decomps (dict or None): A dictionary of fixed decomposition rules + to use in the decomposition graph. + alt_decomps (dict or None): A dictionary of alternative decomposition rules + to use in the decomposition graph. + + Raises: + TypeError: if graph-based decomposition is not enabled. """ + # A mapping from operation names to the number of wires they act on. + # This is used when the operation is not in the captured operations + # but we still need to create a decomposition rule for it. + COMPILER_OPERATIONS_NUM_WIRES: dict[str, int] = { + "CNOT": 2, + "ControlledPhaseShift": 2, + "CRot": 2, + "CRX": 2, + "CRY": 2, + "CRZ": 2, + "CSWAP": 3, + "CY": 2, + "CZ": 2, + "Hadamard": 1, + "Identity": 1, + "IsingXX": 2, + "IsingXY": 2, + "IsingYY": 2, + "IsingZZ": 2, + "SingleExcitation": 2, + "DoubleExcitation": 4, + "ISWAP": 2, + "PauliX": 1, + "PauliY": 1, + "PauliZ": 1, + "PhaseShift": 1, + "PSWAP": 2, + "Rot": 1, + "RX": 1, + "RY": 1, + "RZ": 1, + "S": 1, + "SWAP": 2, + "T": 1, + "Toffoli": 3, + "U1": 1, + "U2": 1, + "U3": 1, + } + def __init__( self, *, @@ -117,7 +111,7 @@ def __init__( alt_decomps=None, ): # pylint: disable=too-many-arguments - if not qml.decomposition.enabled_graph(): + if not qml.decomposition.enabled_graph(): # pragma: no cover raise TypeError( "The GraphSolutionInterpreter can only be used when" "graph-based decomposition is enabled." @@ -164,55 +158,139 @@ def interpret_measurement(self, measurement: "qml.measurement.MeasurementProcess """ - if not self._captured and not isinstance(measurement, MidMeasureMP): + # If we haven't captured and compiled the decomposition rules yet, + if not self._captured: + # Capture the current operations and mark as captured self._captured = True - self._decomp_graph_solution = _solve_decomposition_graph( + + # Solve the decomposition graph to get the decomposition rules + # for all the captured operations + # I know it's a bit hacky to do this here, but it's the only + # place where we can be sure that we have seen all operations + # in the circuit before the measurement. + # TODO: Find a better way to do this. + self._decomp_graph_solution = self._solve_decomposition_graph( self._operations, self._gate_set, fixed_decomps=self._fixed_decomps, alt_decomps=self._alt_decomps, ) - captured_ops = copy(self._operations) + # Create decomposition rules for each operation in the solution + # and compile them to Catalyst JAXPR decomposition rules for op, rule in self._decomp_graph_solution.items(): - if (o := next((o for o in captured_ops if o.name == op.op.name), None)) is not None: - create_decomposition_rule(rule, op_name=op.op.name, num_wires=len(o.wires)) - elif op.op.name in COMPILER_OPERATIONS_NUM_WIRES: - num_wires = COMPILER_OPERATIONS_NUM_WIRES[op.op.name] - create_decomposition_rule(rule, op_name=op.op.name, num_wires=num_wires) - else: + if ( + o := next((o for o in self._operations if o.name == op.op.name), None) + ) is not None: + # TODO: This assumes that the operation names are unique in the circuit. + # If there are multiple operations with the same name but different number of wires, + # this will only capture the first one. + self._create_decomposition_rule( + rule, op_name=op.op.name, num_wires=len(o.wires) + ) + elif op.op.name in self.COMPILER_OPERATIONS_NUM_WIRES: + # In this part, we need to handle the case where an operation in the decomposition graph solution + # is not in the captured operations. This can happen if the operation is not directly called + # in the circuit, but is used inside a decomposition rule. In this case, we + # fall back to using the COMPILER_OPERATIONS_NUM_WIRES dictionary to get the number of wires. + num_wires = self.COMPILER_OPERATIONS_NUM_WIRES[op.op.name] + self._create_decomposition_rule(rule, op_name=op.op.name, num_wires=num_wires) + else: # pragma: no cover raise ValueError(f"Could not capture {op} without the number of wires.") data, struct = jax.tree_util.tree_flatten(measurement) return jax.tree_util.tree_unflatten(struct, data) + def _create_decomposition_rule(self, func: Callable, op_name: str, num_wires: int): + """Create a decomposition rule from a callable.""" + + sig_func = inspect.signature(func) + type_hints = get_type_hints(func) + + args = {} + for name in sig_func.parameters.keys(): + typ = type_hints.get(name, None) + + # Skip tailing args or kwargs in the rules + if name in ("__", "_"): + continue + + # TODO: This is a temporary solution until all rules have proper type annotations. + # Why? Because we need to pass the correct types to the decomposition_rule + # function to capture the rule correctly with JAX. + possible_names_for_params = { + "params", + "param", + "parameters", + "angles", + "angle", + "phi", + "omega", + "theta", + "weights", + "weight", + } + possible_names_for_wires = {"wires", "wire"} + + if typ is float or name in possible_names_for_params: + # TensorLike is a Union of float, int, array-like, so we use float here + # to cover the most common case as the JAX tracer doesn't like Union types + # and we don't have the actual values at this point. + args[name] = float + elif typ is WiresLike or name in possible_names_for_wires: + # Pass a dummy array of zeros with the correct number of wires + # This is required for the decomposition_rule to work correctly + # as it expects an array-like input for wires + args[name] = qml.math.array([0] * num_wires, like="jax") + elif typ is int: # pragma: no cover + # This is only for cases where the rule has an int parameter + # e.g., dimension in some gates. Not that common though! + # We cover this when adding end-to-end tests for rules + # in the MLIR PR. + args[name] = int + else: # pragma: no cover + raise ValueError( + f"Unsupported type annotation {typ} for parameter {name} in func {func}." + ) + + # Set custom attributes for the decomposition rule + # These attributes are used in the MLIR decomposition pass + # to identify the target gate and the number of wires + setattr(func, "target_gate", op_name) + setattr(func, "num_wires", num_wires) + + return decomposition_rule(func)(**args) + + # pylint: disable=protected-access + def _solve_decomposition_graph(self, operations, gate_set, fixed_decomps, alt_decomps): + """Get the decomposition graph solution for the given operations and gate set. + + TODO: Extend `DecompGraphSolution` API and avoid accessing protected members + directly in this function. + """ -# pylint: disable=protected-access -def _solve_decomposition_graph(operations, gate_set, fixed_decomps, alt_decomps): - """Get the decomposition graph solution for the given operations and gate set.""" - - # decomp_graph_solution - decomp_graph_solution = {} + # decomp_graph_solution + decomp_graph_solution = {} - decomp_graph = DecompositionGraph( - operations, - gate_set, - fixed_decomps=fixed_decomps, - alt_decomps=alt_decomps, - ) + decomp_graph = DecompositionGraph( + operations, + gate_set, + fixed_decomps=fixed_decomps, + alt_decomps=alt_decomps, + ) - # Find the efficient pathways to the target gate set - solutions = decomp_graph.solve() + # Find the efficient pathways to the target gate set + solutions = decomp_graph.solve() - def is_solved_for(op): - return ( - op in solutions._all_op_indices - and solutions._all_op_indices[op] in solutions._visitor.distances - ) + def is_solved_for(op): + return ( + op in solutions._all_op_indices + and solutions._all_op_indices[op] in solutions._visitor.distances + ) - for op_node, op_node_idx in solutions._all_op_indices.items(): - if is_solved_for(op_node) and op_node_idx in solutions._visitor.predecessors: - d_node_idx = solutions._visitor.predecessors[op_node_idx] - decomp_graph_solution[op_node] = solutions._graph[d_node_idx].rule._impl + for op_node, op_node_idx in solutions._all_op_indices.items(): + if is_solved_for(op_node) and op_node_idx in solutions._visitor.predecessors: + d_node_idx = solutions._visitor.predecessors[op_node_idx] + decomp_graph_solution[op_node] = solutions._graph[d_node_idx].rule._impl - return decomp_graph_solution + return decomp_graph_solution diff --git a/frontend/catalyst/from_plxpr/from_plxpr.py b/frontend/catalyst/from_plxpr/from_plxpr.py index d5771a7fd6..b5fa165096 100644 --- a/frontend/catalyst/from_plxpr/from_plxpr.py +++ b/frontend/catalyst/from_plxpr/from_plxpr.py @@ -45,9 +45,7 @@ from catalyst.device import extract_backend_info from catalyst.device.qjit_device import COMPILER_OPERATIONS -from catalyst.from_plxpr.decompose import ( - GraphSolutionInterpreter, -) +from catalyst.from_plxpr.decompose import GraphSolutionInterpreter from catalyst.from_plxpr.qubit_handler import QubitHandler from catalyst.jax_extras import jaxpr_pad_consts, make_jaxpr2, transient_jax_config from catalyst.jax_primitives import ( @@ -265,9 +263,14 @@ def apply_compiler_decompose_to_plxpr(inner_jaxpr, consts, tgatesets, ncargs): """Apply the compiler-specific decomposition for a given JAXPR.""" # disable the graph decomposition optimization - is_graph = qml.decomposition.enabled_graph() - if is_graph: - qml.decomposition.disable_graph() + # Why? Because for the compiler-specific decomposition we want to + # only decompose higher-level gates and templates that only have + # a single decomposition, and not do any further optimization + # based on the graph solution. + # Besides, the graph-based decomposition is not supported + # yet in from_plxpr for most gates and templates. + # TODO: Enable the graph-based decomposition + qml.decomposition.disable_graph() # First perform the pre-mlir decomposition to simplify the jaxpr # by decomposing high-level gates and templates @@ -277,8 +280,7 @@ def apply_compiler_decompose_to_plxpr(inner_jaxpr, consts, tgatesets, ncargs): inner_jaxpr, consts, (), {"gate_set": gate_set}, *ncargs ) - if is_graph: - qml.decomposition.enable_graph() + qml.decomposition.enable_graph() return final_jaxpr diff --git a/frontend/catalyst/jax_primitives_utils.py b/frontend/catalyst/jax_primitives_utils.py index 7faa679c7e..4e103cce6b 100644 --- a/frontend/catalyst/jax_primitives_utils.py +++ b/frontend/catalyst/jax_primitives_utils.py @@ -179,10 +179,19 @@ def only_single_expval(): func_op.attributes["diff_method"] = ir.StringAttr.get(diff_method) - gateset = getattr(callable_, "decompose_gatesets", []) - if gateset: + # Register the decomposition gatesets to the QNode FuncOp + # This will set a queue of gatesets that enables support for multiple + # levels of decomposition in the MLIR decomposition pass + if gateset := getattr(callable_, "decompose_gatesets", []): func_op.attributes["decompose_gatesets"] = get_mlir_attribute_from_pyval(gateset) + # Extract the target gate and number of wires from decomposition rules + # and set them as attributes on the FuncOp for use in the MLIR decomposition pass + if target_gate := getattr(callable_, "target_gate", None): + func_op.attributes["target_gate"] = get_mlir_attribute_from_pyval(target_gate) + if num_wires := getattr(callable_, "num_wires", None): + func_op.attributes["num_wires"] = get_mlir_attribute_from_pyval(num_wires) + return func_op diff --git a/frontend/catalyst/passes/builtin_passes.py b/frontend/catalyst/passes/builtin_passes.py index d539d758de..e39cdb8cf8 100644 --- a/frontend/catalyst/passes/builtin_passes.py +++ b/frontend/catalyst/passes/builtin_passes.py @@ -394,6 +394,7 @@ def circuit(x: float): return PassPipelineWrapper(qnode, "merge-rotations") +# pragma: no cover def decompose_lowering(qnode): """ Specify that the ``-decompose-lowering`` MLIR compiler pass diff --git a/frontend/test/lit/test_decomposition.py b/frontend/test/lit/test_decomposition.py index 06a4f526ad..fd14d674ae 100644 --- a/frontend/test/lit/test_decomposition.py +++ b/frontend/test/lit/test_decomposition.py @@ -745,9 +745,9 @@ def circuit_14(): qml.PauliY(wires=2) return qml.expval(qml.Z(0)) - # CHECK-DAG: func.func public @Rot_rule_rz_ry_rz_wires_1(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor<1xi64>) -> !quantum.reg - # CHECK-DAG: func.func public @RY_rule_rz_rx_wires_1(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg - # CHECK-DAG: func.func public @PauliY_rule_ry_gp_wires_1(%arg0: !quantum.reg, %arg1: tensor<1xi64>) -> !quantum.reg + # CHECK-DAG: func.func public @rz_ry_rz(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor<1xi64>) -> !quantum.reg + # CHECK-DAG: func.func public @rz_rx(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg + # CHECK-DAG: func.func public @ry_gp(%arg0: !quantum.reg, %arg1: tensor<1xi64>) -> !quantum.reg print(circuit_14.mlir) qml.decomposition.disable_graph() @@ -778,16 +778,16 @@ def circuit_15(): qml.DoubleExcitation(0.5, wires=[0, 1, 2, 3]) return qml.expval(qml.Z(0)) - # CHECK-DAG: func.func public @SingleExcitationPlus_rule_single_excitation_plus_decomp_wires_2(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<2xi64>) -> !quantum.reg - # CHECK-DAG: func.func public @CY_rule_cy_wires_2(%arg0: !quantum.reg, %arg1: tensor<2xi64>) -> !quantum.reg - # CHECK-DAG: func.func public @CRY_rule_cry_wires_2(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<2xi64>) -> !quantum.reg - # CHECK-DAG: func.func public @S_rule_s_phaseshift_wires_1(%arg0: !quantum.reg, %arg1: tensor<1xi64>) -> !quantum.reg - # CHECK-DAG: func.func public @PhaseShift_rule_phaseshift_to_rz_gp_wires_1(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg - # CHECK-DAG: func.func public @RZ_rule_rz_to_ry_rx_wires_1(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg - # CHECK-DAG: func.func public @Rot_rule_rot_to_rz_ry_rz_wires_1(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor<1xi64>) -> !quantum.reg - # CHECK-DAG: func.func public @DoubleExcitation_rule_doublexcit_wires_4(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<4xi64>) -> !quantum.reg - # CHECK-DAG: func.func public @SingleExcitationMinus_rule_single_excitation_minus_decomp_wires_2(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<2xi64>) -> !quantum.reg - # CHECK-DAG: func.func public @SingleExcitation_rule_single_excitation_decomp_wires_2(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<2xi64>) -> !quantum.reg + # CHECK-DAG: func.func public @_single_excitation_plus_decomp(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<2xi64>) -> !quantum.reg + # CHECK-DAG: func.func public @_cy(%arg0: !quantum.reg, %arg1: tensor<2xi64>) -> !quantum.reg + # CHECK-DAG: func.func public @_cry(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<2xi64>) -> !quantum.reg + # CHECK-DAG: func.func public @_s_phaseshift(%arg0: !quantum.reg, %arg1: tensor<1xi64>) -> !quantum.reg + # CHECK-DAG: func.func public @_phaseshift_to_rz_gp(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg + # CHECK-DAG: func.func public @_rz_to_ry_rx(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg + # CHECK-DAG: func.func public @_rot_to_rz_ry_rz(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor<1xi64>) -> !quantum.reg + # CHECK-DAG: func.func public @_doublexcit(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<4xi64>) -> !quantum.reg + # CHECK-DAG: func.func public @_single_excitation_minus_decomp(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<2xi64>) -> !quantum.reg + # CHECK-DAG: func.func public @_single_excitation_decomp(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<2xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 2 : i64, target_gate = "SingleExcitation"} print(circuit_15.mlir) qml.decomposition.disable_graph() @@ -821,9 +821,9 @@ def circuit_16(): qml.adjoint(qml.SingleExcitation)(0.1, wires=[0, 1]) return qml.expval(qml.Z(0)) - # CHECK-DAG: func.func public @CNOT_rule_cnot_to_cz_h_wires_2(%arg0: !quantum.reg, %arg1: tensor<2xi64>) -> !quantum.reg - # CHECK-DAG: func.func public @Hadamard_rule_hadamard_to_rz_ry_wires_1(%arg0: !quantum.reg, %arg1: tensor<1xi64>) -> !quantum.reg - # CHECK-DAG: func.func public @SingleExcitation_rule_SingleExcitation_rule_single_excitation_decomp_wires_2_wires_2(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<2xi64>) -> !quantum.reg + # CHECK-DAG: func.func public @_cnot_to_cz_h(%arg0: !quantum.reg, %arg1: tensor<2xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 2 : i64, target_gate = "CNOT"} + # CHECK-DAG: func.func public @_hadamard_to_rz_ry(%arg0: !quantum.reg, %arg1: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "Hadamard"} + # CHECK-DAG: func.func public @_single_excitation_decomp(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<2xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 2 : i64, target_gate = "SingleExcitation"} print(circuit_16.mlir) qml.decomposition.disable_graph() @@ -852,7 +852,7 @@ def circuit_17(): qml.ctrl(qml.PauliX, control=0)(1) return qml.expval(qml.Z(0)) - # CHECK-DAG: func.func public @RY_rule_ry_to_rz_rx_wires_1(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg + # CHECK-DAG: func.func public @_ry_to_rz_rx(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg print(circuit_17.mlir) qml.decomposition.disable_graph() diff --git a/frontend/test/pytest/from_plxpr/test_from_plxpr_decompose.py b/frontend/test/pytest/from_plxpr/test_from_plxpr_decompose.py deleted file mode 100644 index 4e21eadff4..0000000000 --- a/frontend/test/pytest/from_plxpr/test_from_plxpr_decompose.py +++ /dev/null @@ -1,314 +0,0 @@ -# Copyright 2025 Xanadu Quantum Technologies Inc. - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests the ``decompose`` transform with the new Catalyst graph-based decomposition system.""" -from functools import partial - -import numpy as np -import pennylane as qml -import pytest - -pytestmark = pytest.mark.usefixtures("disable_capture") - - -class TestDecomposeGraphEnabled: - """Tests the decompose transform with graph enabled.""" - - @pytest.mark.integration - def test_mixed_gate_set_specification(self): - """Tests that the gate_set can be specified as both a type and a string.""" - - qml.decomposition.enable_graph() - - tape = qml.tape.QuantumScript([qml.RX(0.5, wires=[0]), qml.CNOT(wires=[0, 1])]) - [new_tape], _ = qml.transforms.decompose(tape, gate_set={"RX", qml.CNOT}) - assert new_tape.operations == tape.operations - - qml.decomposition.disable_graph() - - @pytest.mark.integration - def test_gate_set_targeted_decompositions(self): - """Tests that a simple circuit is correctly decomposed into different gate sets.""" - - qml.decomposition.enable_graph() - - tape = qml.tape.QuantumScript( - [ - qml.H(0), # non-parametric op - qml.Rot(0.1, 0.2, 0.3, wires=[0]), # parametric single-qubit op - qml.MultiRZ(0.5, wires=[0, 1, 2]), # parametric multi-qubit op - ] - ) - - [new_tape], _ = qml.transforms.decompose(tape, gate_set={"Hadamard", "CNOT", "RZ", "RY"}) - assert new_tape.operations == [ - # H is in the target gate set - qml.H(0), - # Rot decomposes to ZYZ - qml.RZ(0.1, wires=[0]), - qml.RY(0.2, wires=[0]), - qml.RZ(0.3, wires=[0]), - # Decomposition of MultiRZ - qml.CNOT(wires=[2, 1]), - qml.CNOT(wires=[1, 0]), - qml.RZ(0.5, wires=[0]), - qml.CNOT(wires=[1, 0]), - qml.CNOT(wires=[2, 1]), - ] - - [new_tape], _ = qml.transforms.decompose(tape, gate_set={"RY", "RZ", "CZ", "GlobalPhase"}) - assert new_tape.operations == [ - # The H decomposes to RZ and RY - qml.RZ(np.pi, wires=[0]), - qml.RY(np.pi / 2, wires=[0]), - qml.GlobalPhase(-np.pi / 2), - # Rot decomposes to ZYZ - qml.RZ(0.1, wires=[0]), - qml.RY(0.2, wires=[0]), - qml.RZ(0.3, wires=[0]), - # CNOT decomposes to H and CZ, where H decomposes to RZ and RY - qml.RZ(np.pi, wires=[1]), - qml.RY(np.pi / 2, wires=[1]), - qml.GlobalPhase(-np.pi / 2), - qml.CZ(wires=[2, 1]), - qml.RZ(np.pi, wires=[1]), - qml.RY(np.pi / 2, wires=[1]), - qml.GlobalPhase(-np.pi / 2), - # second CNOT - qml.RZ(np.pi, wires=[0]), - qml.RY(np.pi / 2, wires=[0]), - qml.GlobalPhase(-np.pi / 2), - qml.CZ(wires=[1, 0]), - qml.RZ(np.pi, wires=[0]), - qml.RY(np.pi / 2, wires=[0]), - qml.GlobalPhase(-np.pi / 2), - # The middle RZ - qml.RZ(0.5, wires=[0]), - # The last two CNOTs - qml.RZ(np.pi, wires=[0]), - qml.RY(np.pi / 2, wires=[0]), - qml.GlobalPhase(-np.pi / 2), - qml.CZ(wires=[1, 0]), - qml.RZ(np.pi, wires=[0]), - qml.RY(np.pi / 2, wires=[0]), - qml.GlobalPhase(-np.pi / 2), - qml.RZ(np.pi, wires=[1]), - qml.RY(np.pi / 2, wires=[1]), - qml.GlobalPhase(-np.pi / 2), - qml.CZ(wires=[2, 1]), - qml.RZ(np.pi, wires=[1]), - qml.RY(np.pi / 2, wires=[1]), - qml.GlobalPhase(-np.pi / 2), - ] - - qml.decomposition.disable_graph() - - @pytest.mark.integration - def test_fixed_decomp(self): - """Tests that a fixed decomposition rule is used instead of the stock ones.""" - - qml.decomposition.enable_graph() - - @qml.register_resources({qml.RY: 2, qml.CZ: 1, qml.Z: 2}) - def my_cnot(wires, **__): - qml.RY(np.pi / 2, wires[1]) - qml.Z(wires[1]) - qml.CZ(wires=wires) - qml.RY(np.pi / 2, wires[1]) - qml.Z(wires[1]) - - tape = qml.tape.QuantumScript([qml.CNOT(wires=[1, 0])]) - [new_tape], _ = qml.transforms.decompose( - tape, - gate_set={"RY", "RZ", "CZ", "Hadamard", "GlobalPhase"}, - fixed_decomps={qml.CNOT: my_cnot}, - ) - assert new_tape.operations == [ - qml.RY(np.pi / 2, wires=[0]), - qml.RZ(np.pi, wires=[0]), - qml.GlobalPhase(-np.pi / 2), - qml.CZ(wires=[1, 0]), - qml.RY(np.pi / 2, wires=[0]), - qml.RZ(np.pi, wires=[0]), - qml.GlobalPhase(-np.pi / 2), - ] - - qml.decomposition.disable_graph() - - @pytest.mark.integration - def test_alt_decomp_not_used(self): - """Tests that alt_decomp isn't necessarily used if it's not efficient.""" - - qml.decomposition.enable_graph() - - @qml.register_resources({qml.RY: 2, qml.CZ: 1, qml.Z: 2}) - def my_cnot(wires, **__): - qml.RY(np.pi / 2, wires[1]) - qml.Z(wires[1]) - qml.CZ(wires=wires) - qml.RY(np.pi / 2, wires[1]) - qml.Z(wires[1]) - - tape = qml.tape.QuantumScript([qml.CNOT(wires=[1, 0])]) - [new_tape], _ = qml.transforms.decompose( - tape, - gate_set={"RY", "RZ", "CZ", "Hadamard", "GlobalPhase"}, - alt_decomps={qml.CNOT: [my_cnot]}, - ) - assert new_tape.operations == [ - qml.H(0), - qml.CZ(wires=[1, 0]), - qml.H(0), - ] - - qml.decomposition.disable_graph() - - @pytest.mark.integration - def test_alt_decomp(self): - """Tests that alternative decomposition rules are used when applicable.""" - - qml.decomposition.enable_graph() - - @qml.register_resources({qml.RY: 2, qml.CZ: 1, qml.Z: 2}) - def my_cnot(wires, **__): - qml.RY(np.pi / 2, wires[1]) - qml.Z(wires[1]) - qml.CZ(wires=wires) - qml.RY(np.pi / 2, wires[1]) - qml.Z(wires[1]) - - tape = qml.tape.QuantumScript([qml.CNOT(wires=[1, 0])]) - [new_tape], _ = qml.transforms.decompose( - tape, - gate_set={"RY", "RZ", "CZ", "PauliZ", "GlobalPhase"}, - alt_decomps={qml.CNOT: [my_cnot]}, - ) - assert new_tape.operations == [ - qml.RY(np.pi / 2, wires=[0]), - qml.Z(0), - qml.CZ(wires=[1, 0]), - qml.RY(np.pi / 2, wires=[0]), - qml.Z(0), - ] - - qml.decomposition.disable_graph() - - @pytest.mark.integration - def test_fall_back(self): - """Tests that op.decompose() is used for ops unsolved in the graph.""" - - qml.decomposition.enable_graph() - - class CustomOp(qml.operation.Operation): # pylint: disable=too-few-public-methods - """Dummy custom op.""" - - resource_keys = set() - - @property - def resource_params(self): - """Dummy resource params.""" - - return {} - - def decomposition(self): - """Decomposition of CustomOp into H-CNOT-H.""" - - return [qml.H(self.wires[1]), qml.CNOT(self.wires), qml.H(self.wires[1])] - - @qml.register_resources({qml.CZ: 1}) - def my_decomp(wires, **__): - qml.CZ(wires=wires) - - tape = qml.tape.QuantumScript([CustomOp(wires=[0, 1])]) - [new_tape], _ = qml.transforms.decompose( - tape, gate_set={"CNOT", "Hadamard"}, fixed_decomps={CustomOp: my_decomp} - ) - assert new_tape.operations == [qml.H(1), qml.CNOT(wires=[0, 1]), qml.H(1)] - - qml.decomposition.disable_graph() - - # @pytest.mark.integration - # def test_controlled_decomp(self): - # """Tests decomposing a controlled operation.""" - - # # The C(MultiRZ) is decomposed by applying control on the base decomposition. - # # The decomposition of MultiRZ contains two CNOTs - # # So this also tests applying control on an PauliX based operation - # # The decomposition of MultiRZ also contains an RZ gate - # # So this also tests logic involving custom controlled operators. - # ops = [qml.ctrl(qml.MultiRZ(0.5, wires=[0, 1]), control=[2])] - # tape = qml.tape.QuantumScript(ops) - # [new_tape], _ = qml.transforms.decompose(tape, gate_set={"RZ", "CNOT", "Toffoli"}) - # assert new_tape.operations == [ - # # Decomposition of C(CNOT) - # qml.Toffoli(wires=[2, 1, 0]), - # # Decomposition of C(RZ) -> CRZ - # qml.RZ(0.25, wires=[0]), - # qml.CNOT(wires=[2, 0]), - # qml.RZ(-0.25, wires=[0]), - # qml.CNOT(wires=[2, 0]), - # # Decomposition of C(CNOT) - # qml.Toffoli(wires=[2, 1, 0]), - # ] - - # @pytest.mark.integration - # def test_adjoint_decomp(self): - # """Tests decomposing an adjoint operation.""" - - # class CustomOp(qml.operation.Operator): # pylint: disable=too-few-public-methods - - # resource_keys = set() - - # @property - # def resource_params(self) -> dict: - # return {} - - # @qml.register_resources({qml.RX: 1, qml.RY: 1, qml.RZ: 1}) - # def custom_decomp(theta, phi, omega, wires): - # qml.RX(theta, wires[0]) - # qml.RY(phi, wires[0]) - # qml.RZ(omega, wires[0]) - - # tape = qml.tape.QuantumScript( - # [ - # qml.adjoint(qml.RX(0.5, wires=[0])), - # qml.adjoint(qml.adjoint(qml.MultiRZ(0.5, wires=[0, 1]))), - # qml.adjoint(CustomOp(0.1, 0.2, 0.3, wires=[0])), - # ] - # ) - # [new_tape], _ = qml.transforms.decompose( - # tape, gate_set={"CNOT", "RX", "RY", "RZ"}, fixed_decomps={CustomOp: custom_decomp} - # ) - # assert new_tape.operations == [ - # qml.RX(-0.5, wires=[0]), - # qml.CNOT(wires=[1, 0]), - # qml.RZ(0.5, wires=[0]), - # qml.CNOT(wires=[1, 0]), - # qml.RZ(-0.3, wires=[0]), - # qml.RY(-0.2, wires=[0]), - # qml.RX(-0.1, wires=[0]), - # ] - - -def test_decompose_qnode(): - """Tests that the decompose transform works with a QNode.""" - - @partial(qml.transforms.decompose, gate_set={"CZ", "Hadamard"}) - @qml.qnode(qml.device("default.qubit", wires=2)) - def circuit(): - qml.CNOT(wires=[0, 1]) - return qml.expval(qml.PauliZ(0)) - - res = circuit() - assert qml.math.allclose(res, 1.0) From fa14d9353b18f8711ef2350902d6fc5913b94645 Mon Sep 17 00:00:00 2001 From: Ali Asadi <10773383+maliasadi@users.noreply.github.com> Date: Thu, 18 Sep 2025 00:29:39 -0400 Subject: [PATCH 19/36] Update --- frontend/catalyst/from_plxpr/decompose.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/frontend/catalyst/from_plxpr/decompose.py b/frontend/catalyst/from_plxpr/decompose.py index a1ac8b202f..dba5ecd44b 100644 --- a/frontend/catalyst/from_plxpr/decompose.py +++ b/frontend/catalyst/from_plxpr/decompose.py @@ -20,7 +20,6 @@ import inspect from collections.abc import Callable -from copy import copy from typing import get_type_hints import jax @@ -28,7 +27,6 @@ # GraphSolutionInterpreter: from pennylane.decomposition import DecompositionGraph -from pennylane.measurements import MidMeasureMP from pennylane.wires import WiresLike from catalyst.jax_primitives import decomposition_rule @@ -183,16 +181,18 @@ def interpret_measurement(self, measurement: "qml.measurement.MeasurementProcess o := next((o for o in self._operations if o.name == op.op.name), None) ) is not None: # TODO: This assumes that the operation names are unique in the circuit. - # If there are multiple operations with the same name but different number of wires, - # this will only capture the first one. + # If there are multiple operations with the same name but different number + # of wires, this will only capture the first one. self._create_decomposition_rule( rule, op_name=op.op.name, num_wires=len(o.wires) ) elif op.op.name in self.COMPILER_OPERATIONS_NUM_WIRES: - # In this part, we need to handle the case where an operation in the decomposition graph solution - # is not in the captured operations. This can happen if the operation is not directly called - # in the circuit, but is used inside a decomposition rule. In this case, we - # fall back to using the COMPILER_OPERATIONS_NUM_WIRES dictionary to get the number of wires. + # In this part, we need to handle the case where an operation in + # the decomposition graph solution is not in the captured operations. + # This can happen if the operation is not directly called + # in the circuit, but is used inside a decomposition rule. + # In this case, we fall back to using the COMPILER_OPERATIONS_NUM_WIRES + # dictionary to get the number of wires. num_wires = self.COMPILER_OPERATIONS_NUM_WIRES[op.op.name] self._create_decomposition_rule(rule, op_name=op.op.name, num_wires=num_wires) else: # pragma: no cover From 9827a377e90a6518db1994c5fe5ba377034828ad Mon Sep 17 00:00:00 2001 From: Ali Asadi <10773383+maliasadi@users.noreply.github.com> Date: Thu, 18 Sep 2025 13:23:19 -0400 Subject: [PATCH 20/36] Update support for templates --- frontend/catalyst/from_plxpr/decompose.py | 39 ++++++++++++++++++++--- 1 file changed, 34 insertions(+), 5 deletions(-) diff --git a/frontend/catalyst/from_plxpr/decompose.py b/frontend/catalyst/from_plxpr/decompose.py index dba5ecd44b..41de82428f 100644 --- a/frontend/catalyst/from_plxpr/decompose.py +++ b/frontend/catalyst/from_plxpr/decompose.py @@ -64,7 +64,7 @@ class GraphSolutionInterpreter(qml.capture.PlxprInterpreter): # A mapping from operation names to the number of wires they act on. # This is used when the operation is not in the captured operations # but we still need to create a decomposition rule for it. - COMPILER_OPERATIONS_NUM_WIRES: dict[str, int] = { + compiler_ops_num_wires: dict[str, int] = { "CNOT": 2, "ControlledPhaseShift": 2, "CRot": 2, @@ -81,7 +81,11 @@ class GraphSolutionInterpreter(qml.capture.PlxprInterpreter): "IsingYY": 2, "IsingZZ": 2, "SingleExcitation": 2, + "SingleExcitationPlus": 2, + "SingleExcitationMinus": 2, "DoubleExcitation": 4, + "DoubleExcitationPlus": 4, + "DoubleExcitationMinus": 4, "ISWAP": 2, "PauliX": 1, "PauliY": 1, @@ -99,6 +103,8 @@ class GraphSolutionInterpreter(qml.capture.PlxprInterpreter): "U1": 1, "U2": 1, "U3": 1, + "MultiRZ": -1, # variable number of wires + "GlobalPhase": -1, # variable number of wires } def __init__( @@ -123,6 +129,29 @@ def __init__( self._operations = set() self._decomp_graph_solution = {} + def update_operations(self, operations): + """Update the set of captured operations. + + Args: + operations (set): a set of pennylane operator instances + """ + for op in operations: + # TODO: Although we deal with those ops not in compiler_ops_num_wires in the + # compiler-specific decomposition step, we should ideally have a way to specify + # the list of ops in the structured rule and their corresponding number of wires + # to solve the graph for them. + if op.name in self.compiler_ops_num_wires.keys(): + self._operations.add(op) + else: + try: + with qml.capture.pause(): + ops = op.decomposition() + self.update_operations(ops) + except: # pylint: disable=bare-except + # the compiler-specific decomposition step will handle those ops + # that we can't decompose here; also related to the TODO above. + pass # do nothing if we can't decompose it to the list of ops. + def interpret_operation(self, op: "qml.operation.Operator"): """Interpret a PennyLane operation instance. @@ -142,7 +171,7 @@ def interpret_operation(self, op: "qml.operation.Operator"): """ - self._operations.add(op) + self.update_operations({op}) data, struct = jax.tree_util.tree_flatten(op) return jax.tree_util.tree_unflatten(struct, data) @@ -186,14 +215,14 @@ def interpret_measurement(self, measurement: "qml.measurement.MeasurementProcess self._create_decomposition_rule( rule, op_name=op.op.name, num_wires=len(o.wires) ) - elif op.op.name in self.COMPILER_OPERATIONS_NUM_WIRES: + elif op.op.name in self.compiler_ops_num_wires: # In this part, we need to handle the case where an operation in # the decomposition graph solution is not in the captured operations. # This can happen if the operation is not directly called # in the circuit, but is used inside a decomposition rule. - # In this case, we fall back to using the COMPILER_OPERATIONS_NUM_WIRES + # In this case, we fall back to using the compiler_ops_num_wires # dictionary to get the number of wires. - num_wires = self.COMPILER_OPERATIONS_NUM_WIRES[op.op.name] + num_wires = self.compiler_ops_num_wires[op.op.name] self._create_decomposition_rule(rule, op_name=op.op.name, num_wires=num_wires) else: # pragma: no cover raise ValueError(f"Could not capture {op} without the number of wires.") From 54e539d425dd7631d9f1eda4f5d306a9324ac0fd Mon Sep 17 00:00:00 2001 From: Ali Asadi <10773383+maliasadi@users.noreply.github.com> Date: Thu, 18 Sep 2025 16:33:26 -0400 Subject: [PATCH 21/36] Update orders of catalyst decomps --- frontend/catalyst/from_plxpr/from_plxpr.py | 137 +++++++++++++++------ frontend/test/lit/test_decomposition.py | 33 ++--- frontend/test/lit/test_from_plxpr.py | 21 ---- 3 files changed, 118 insertions(+), 73 deletions(-) diff --git a/frontend/catalyst/from_plxpr/from_plxpr.py b/frontend/catalyst/from_plxpr/from_plxpr.py index b5fa165096..90c9c712d8 100644 --- a/frontend/catalyst/from_plxpr/from_plxpr.py +++ b/frontend/catalyst/from_plxpr/from_plxpr.py @@ -185,8 +185,8 @@ def __init__(self): self.qubit_handler = None # Compiler options for the new decomposition system - self.requires_compiler_decompose = False - self.decompose_gatesets = [] # queue of gatesets + self.requires_decompose_lowering = False + self.decompose_tkwargs = {} # target gateset super().__init__() @@ -207,15 +207,23 @@ def handle_qnode( closed_jaxpr = ( ClosedJaxpr(qfunc_jaxpr, consts) - if not self.requires_compiler_decompose + if not self.requires_decompose_lowering else apply_compiler_decompose_to_plxpr( inner_jaxpr=qfunc_jaxpr, consts=consts, ncargs=non_const_args, - tgatesets=self.decompose_gatesets, + tgateset=list(self.decompose_tkwargs.get("gate_set", [])), ) ) + if self.requires_decompose_lowering: + closed_jaxpr = collect_and_compile_graph_solutions( + inner_jaxpr=closed_jaxpr.jaxpr, + consts=closed_jaxpr.consts, + tkwargs=self.decompose_tkwargs, + ncargs=non_const_args, + ) + def calling_convention(*args): device_init_p.bind( shots, @@ -231,9 +239,15 @@ def calling_convention(*args): device_release_p.bind() return retvals - if self.requires_compiler_decompose: + if self.requires_decompose_lowering: # Add gate_set attribute to the quantum kernel primitive - setattr(qnode, "decompose_gatesets", self.decompose_gatesets) + # decompose_gatesets is treated as a queue of gatesets to be used + # but we only support a single gateset for now in from_plxpr + # as supporting multiple gatesets requires an MLIR/C++ graph-decomposition + # implementation. The current Python implementation cannot be mixed + # with other transforms in between. + gateset = [_get_operator_name(op) for op in self.decompose_tkwargs.get("gate_set", [])] + setattr(qnode, "decompose_gatesets", [gateset]) return quantum_kernel_p.bind( wrap_init(calling_convention, debug_info=qfunc_jaxpr.debug_info), @@ -259,22 +273,35 @@ def calling_convention(*args): } -def apply_compiler_decompose_to_plxpr(inner_jaxpr, consts, tgatesets, ncargs): - """Apply the compiler-specific decomposition for a given JAXPR.""" +def apply_compiler_decompose_to_plxpr(inner_jaxpr, consts, tgateset, ncargs): + """Apply the compiler-specific decomposition for a given JAXPR. + + Args: + inner_jaxpr (Jaxpr): The input JAXPR to be decomposed. + consts (list): The constants used in the JAXPR. + tgateset (list): A list of target gateset for decomposition. + ncargs (list): Non-constant arguments for the JAXPR. + qargs (list): All arguments including constants and non-constants. + + Returns: + ClosedJaxpr: The decomposed JAXPR. + """ + + # Disable the graph decomposition optimization - # disable the graph decomposition optimization # Why? Because for the compiler-specific decomposition we want to # only decompose higher-level gates and templates that only have # a single decomposition, and not do any further optimization # based on the graph solution. # Besides, the graph-based decomposition is not supported # yet in from_plxpr for most gates and templates. + # TODO: Enable the graph-based decomposition qml.decomposition.disable_graph() # First perform the pre-mlir decomposition to simplify the jaxpr # by decomposing high-level gates and templates - gate_set = COMPILER_OPERATIONS + list(set().union(*tgatesets)) + gate_set = COMPILER_OPERATIONS + tgateset final_jaxpr = qml.transforms.decompose.plxpr_transform( inner_jaxpr, consts, (), {"gate_set": gate_set}, *ncargs @@ -285,6 +312,35 @@ def apply_compiler_decompose_to_plxpr(inner_jaxpr, consts, tgatesets, ncargs): return final_jaxpr +def collect_and_compile_graph_solutions(inner_jaxpr, consts, tkwargs, ncargs): + """Collect and compile graph solutions for a given JAXPR. + + This function uses the GraphSolutionInterpreter to evaluate + the input JAXPR and obtain a new JAXPR that incorporates + the graph-based decomposition solutions. + + This function doesn't modify the underlying quantum function + but rather constructs a new JAXPR with decomposition rules. + + Args: + inner_jaxpr (Jaxpr): The input JAXPR to be decomposed. + consts (list): The constants used in the JAXPR. + tkwargs (list): The keyword arguments of the decompose transform. + ncargs (list): Non-constant arguments for the JAXPR. + + Returns: + ClosedJaxpr: The decomposed JAXPR. + """ + gds_interpreter = GraphSolutionInterpreter(**tkwargs) + + def gds_wrapper(*args): + return gds_interpreter.eval(inner_jaxpr, consts, *args) + + final_jaxpr = jax.make_jaxpr(gds_wrapper)(*ncargs) + + return final_jaxpr + + # pylint: disable-next=redefined-outer-name def register_transform(pl_transform, pass_name, decomposition): """Register pennylane transforms and their conversion to Catalyst transforms""" @@ -316,30 +372,19 @@ def handle_transform( and pl_plxpr_transform.__name__ == "decompose_plxpr_to_plxpr" and qml.decomposition.enabled_graph() ): - if not self.requires_compiler_decompose: - self.requires_compiler_decompose = True - - # A helper function to get the name of a pennylane operator - def get_operator_name(op): - """Get the name of a pennylane operator, handling wrapped operators. - - Note: Controlled and Adjoint ops aren't supported in `gate_set` - by PennyLane's DecompositionGraph; unit tests were added in PennyLane. - """ - if isinstance(op, str): - return op - - # Return NoNameOp if the operator has no _primitive.name attribute. - # This is to avoid errors when we capture the program - # as we deal with such ops later in the decomposition graph. - return getattr(op._primitive, "name", "NoNameOp") - - # Update the decompose_gatesets to be used by the quantum kernel primitive - tgateset = tkwargs.get("gate_set", []) + if not self.requires_decompose_lowering: + self.requires_decompose_lowering = True + else: + raise NotImplementedError( + "Multiple decomposition transforms are not yet supported." + ) - # We treat decompose_gatesets as a queue of gatesets to be used - # by the decompose-lowering pass at MLIR - self.decompose_gatesets.insert(0, [get_operator_name(op) for op in tgateset]) + # Update the decompose_gateset to be used by the quantum kernel primitive + # TODO: we originally wanted to treat decompose_gateset as a queue of + # gatesets to be used by the decompose-lowering pass at MLIR + # but this requires a C++ implementation of the graph-based decomposition + # which doesn't exist yet. + self.decompose_tkwargs = tkwargs # Note. We don't perform the compiler-specific decomposition here # to be able to support multiple decomposition transforms @@ -356,13 +401,14 @@ def get_operator_name(op): # the current jaxpr based on the current gateset # but we don't rewrite the jaxpr at this stage. - gds_interpreter = GraphSolutionInterpreter(*targs, **tkwargs) + # gds_interpreter = GraphSolutionInterpreter(*targs, **tkwargs) - def gds_wrapper(*args): - return gds_interpreter.eval(inner_jaxpr, consts, *args) + # def gds_wrapper(*args): + # return gds_interpreter.eval(inner_jaxpr, consts, *args) - final_jaxpr = jax.make_jaxpr(gds_wrapper)(*args) - return self.eval(final_jaxpr.jaxpr, consts, *non_const_args) + # final_jaxpr = jax.make_jaxpr(gds_wrapper)(*args) + # return self.eval(final_jaxpr.jaxpr, consts, *non_const_args) + return self.eval(inner_jaxpr, consts, *non_const_args) if catalyst_pass_name is None: # Use PL's ExpandTransformsInterpreter to expand this and any embedded @@ -863,3 +909,18 @@ def trace_from_pennylane( jaxpr = from_plxpr(plxpr)(*dynamic_args, **kwargs) return jaxpr, out_type, out_treedef, sig + + +def _get_operator_name(op): + """Get the name of a pennylane operator, handling wrapped operators. + + Note: Controlled and Adjoint ops aren't supported in `gate_set` + by PennyLane's DecompositionGraph; unit tests were added in PennyLane. + """ + if isinstance(op, str): + return op + + # Return NoNameOp if the operator has no _primitive.name attribute. + # This is to avoid errors when we capture the program + # as we deal with such ops later in the decomposition graph. + return getattr(op._primitive, "name", "NoNameOp") diff --git a/frontend/test/lit/test_decomposition.py b/frontend/test/lit/test_decomposition.py index fd14d674ae..ec9ad97fa6 100644 --- a/frontend/test/lit/test_decomposition.py +++ b/frontend/test/lit/test_decomposition.py @@ -778,15 +778,12 @@ def circuit_15(): qml.DoubleExcitation(0.5, wires=[0, 1, 2, 3]) return qml.expval(qml.Z(0)) - # CHECK-DAG: func.func public @_single_excitation_plus_decomp(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<2xi64>) -> !quantum.reg - # CHECK-DAG: func.func public @_cy(%arg0: !quantum.reg, %arg1: tensor<2xi64>) -> !quantum.reg - # CHECK-DAG: func.func public @_cry(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<2xi64>) -> !quantum.reg - # CHECK-DAG: func.func public @_s_phaseshift(%arg0: !quantum.reg, %arg1: tensor<1xi64>) -> !quantum.reg - # CHECK-DAG: func.func public @_phaseshift_to_rz_gp(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg - # CHECK-DAG: func.func public @_rz_to_ry_rx(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg - # CHECK-DAG: func.func public @_rot_to_rz_ry_rz(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor<1xi64>) -> !quantum.reg - # CHECK-DAG: func.func public @_doublexcit(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<4xi64>) -> !quantum.reg - # CHECK-DAG: func.func public @_single_excitation_minus_decomp(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<2xi64>) -> !quantum.reg + # CHECK-DAG: func.func public @_cry(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<2xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 2 : i64, target_gate = "CRY"} + # CHECK-DAG: func.func public @_s_phaseshift(%arg0: !quantum.reg, %arg1: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "S"} + # CHECK-DAG: func.func public @_phaseshift_to_rz_gp(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "PhaseShift"} + # func.func public @_rz_to_ry_rx(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "RZ"} + # CHECK-DAG: func.func public @_rot_to_rz_ry_rz(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "Rot"} + # CHECK-DAG: func.func public @_doublexcit(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<4xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 4 : i64, target_gate = "DoubleExcitation"} # CHECK-DAG: func.func public @_single_excitation_decomp(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<2xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 2 : i64, target_gate = "SingleExcitation"} print(circuit_15.mlir) @@ -809,6 +806,7 @@ def test_decomposition_rule_name_adjoint(): gate_set={"RY", "RX", "CZ", "GlobalPhase"}, ) @qml.qnode(qml.device("lightning.qubit", wires=4)) + # CHECK-DAG: %0 = transform.apply_registered_pass "decompose-lowering" # CHECK: public @circuit_16() -> tensor attributes {decompose_gatesets def circuit_16(): # CHECK-DAG: %1 = quantum.adjoint(%0) : !quantum.reg @@ -821,9 +819,11 @@ def circuit_16(): qml.adjoint(qml.SingleExcitation)(0.1, wires=[0, 1]) return qml.expval(qml.Z(0)) - # CHECK-DAG: func.func public @_cnot_to_cz_h(%arg0: !quantum.reg, %arg1: tensor<2xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 2 : i64, target_gate = "CNOT"} - # CHECK-DAG: func.func public @_hadamard_to_rz_ry(%arg0: !quantum.reg, %arg1: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "Hadamard"} # CHECK-DAG: func.func public @_single_excitation_decomp(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<2xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 2 : i64, target_gate = "SingleExcitation"} + # CHECK-DAG: func.func public @_hadamard_to_rz_ry(%arg0: !quantum.reg, %arg1: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "Hadamard"} + # CHECK-DAG: func.func public @_rz_to_ry_rx(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "RZ"} + # CHECK-DAG: func.func public @_rot_to_rz_ry_rz(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "Rot"} + # CHECK-DAG: func.func public @_cnot_to_cz_h(%arg0: !quantum.reg, %arg1: tensor<2xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 2 : i64, target_gate = "CNOT"} print(circuit_16.mlir) qml.decomposition.disable_graph() @@ -842,17 +842,22 @@ def test_decomposition_rule_name_ctrl(): @qml.qjit(target="mlir") @partial( qml.transforms.decompose, - gate_set={"RX", "RZ"}, + gate_set={"RX", "RZ", "H", "CZ"}, ) @qml.qnode(qml.device("lightning.qubit", wires=5)) - # CHECK: public @circuit_17() -> tensor attributes {decompose_gatesets + # CHECK-DAG: %0 = transform.apply_registered_pass "decompose-lowering" + # CHECK{LITERAL}: func.func public @circuit_17() -> tensor attributes {decompose_gatesets def circuit_17(): # CHECK: %out_qubits:2 = quantum.custom "CRY"(%cst) %1, %2 : !quantum.bit, !quantum.bit + # CHECK-NEXT: %out_qubits_0:2 = quantum.custom "CNOT"() %out_qubits#0, %out_qubits#1 : !quantum.bit, !quantum.bit qml.ctrl(qml.RY, control=0)(0.5, 1) qml.ctrl(qml.PauliX, control=0)(1) return qml.expval(qml.Z(0)) - # CHECK-DAG: func.func public @_ry_to_rz_rx(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg + # CHECK-DAG: func.func public @_cnot_to_cz_h(%arg0: !quantum.reg, %arg1: tensor<2xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 2 : i64, target_gate = "CNOT"} + # CHECK-DAG: func.func public @_cry(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<2xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 2 : i64, target_gate = "CRY"} + # CHECK-DAG: func.func public @_ry_to_rz_rx(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "RY"} + # CHECK-DAG: func.func public @_rot_to_rz_ry_rz(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "Rot"} print(circuit_17.mlir) qml.decomposition.disable_graph() diff --git a/frontend/test/lit/test_from_plxpr.py b/frontend/test/lit/test_from_plxpr.py index a6c2bae85e..f87678c3d7 100644 --- a/frontend/test/lit/test_from_plxpr.py +++ b/frontend/test/lit/test_from_plxpr.py @@ -416,26 +416,5 @@ def circuit3(): print(circuit3.mlir) - @qml.qjit(target="mlir") - @partial(qml.transforms.decompose, gate_set={"RX"}) - @qml.transforms.cancel_inverses - @partial(qml.transforms.decompose, gate_set={"RZ"}) - @qml.transforms.merge_rotations - @partial(qml.transforms.decompose, gate_set={"RX", "RZ"}) - @qml.qnode(dev) - def circuit4(): - return qml.probs() - - # CHECK: [[first_pass:%.+]] = transform.apply_registered_pass "decompose-lowering" - # CHECK-NEXT: [[merge_rot:%.+]] = transform.apply_registered_pass "merge-rotations" to [[first_pass]] - # CHECK-NEXT: [[decomp_to_rz:%.+]] = transform.apply_registered_pass "decompose-lowering" to [[merge_rot]] - # CHECK-NEXT: [[remove_chained:%.+]] = transform.apply_registered_pass "remove-chained-self-inverse" to [[decomp_to_rz]] - # CHECK-NEXT: transform.apply_registered_pass "decompose-lowering" to [[remove_chained]] - - print(circuit4.mlir) - - qml.decomposition.disable_graph() - qml.capture.disable() - test_pass_decomposition() From 4f8c6ee48c15dfda32a7c0b582045de2204a872f Mon Sep 17 00:00:00 2001 From: Ali Asadi <10773383+maliasadi@users.noreply.github.com> Date: Fri, 19 Sep 2025 09:00:58 -0400 Subject: [PATCH 22/36] Update tests --- frontend/catalyst/from_plxpr/decompose.py | 33 ++-------- frontend/test/lit/test_decomposition.py | 74 +++++++++++++++++++++++ 2 files changed, 78 insertions(+), 29 deletions(-) diff --git a/frontend/catalyst/from_plxpr/decompose.py b/frontend/catalyst/from_plxpr/decompose.py index 41de82428f..8ddb256bbe 100644 --- a/frontend/catalyst/from_plxpr/decompose.py +++ b/frontend/catalyst/from_plxpr/decompose.py @@ -81,11 +81,7 @@ class GraphSolutionInterpreter(qml.capture.PlxprInterpreter): "IsingYY": 2, "IsingZZ": 2, "SingleExcitation": 2, - "SingleExcitationPlus": 2, - "SingleExcitationMinus": 2, "DoubleExcitation": 4, - "DoubleExcitationPlus": 4, - "DoubleExcitationMinus": 4, "ISWAP": 2, "PauliX": 1, "PauliY": 1, @@ -129,29 +125,6 @@ def __init__( self._operations = set() self._decomp_graph_solution = {} - def update_operations(self, operations): - """Update the set of captured operations. - - Args: - operations (set): a set of pennylane operator instances - """ - for op in operations: - # TODO: Although we deal with those ops not in compiler_ops_num_wires in the - # compiler-specific decomposition step, we should ideally have a way to specify - # the list of ops in the structured rule and their corresponding number of wires - # to solve the graph for them. - if op.name in self.compiler_ops_num_wires.keys(): - self._operations.add(op) - else: - try: - with qml.capture.pause(): - ops = op.decomposition() - self.update_operations(ops) - except: # pylint: disable=bare-except - # the compiler-specific decomposition step will handle those ops - # that we can't decompose here; also related to the TODO above. - pass # do nothing if we can't decompose it to the list of ops. - def interpret_operation(self, op: "qml.operation.Operator"): """Interpret a PennyLane operation instance. @@ -171,7 +144,7 @@ def interpret_operation(self, op: "qml.operation.Operator"): """ - self.update_operations({op}) + self._operations.add(op) data, struct = jax.tree_util.tree_flatten(op) return jax.tree_util.tree_unflatten(struct, data) @@ -259,7 +232,9 @@ def _create_decomposition_rule(self, func: Callable, op_name: str, num_wires: in "weights", "weight", } - possible_names_for_wires = {"wires", "wire"} + + # TODO: Support work-wires when it's supported in Catalyst. + possible_names_for_wires = {"wires", "wire", "control_wires", "target_wires"} if typ is float or name in possible_names_for_params: # TensorLike is a Union of float, int, array-like, so we use float here diff --git a/frontend/test/lit/test_decomposition.py b/frontend/test/lit/test_decomposition.py index ec9ad97fa6..8d75c774b3 100644 --- a/frontend/test/lit/test_decomposition.py +++ b/frontend/test/lit/test_decomposition.py @@ -865,3 +865,77 @@ def circuit_17(): test_decomposition_rule_name_ctrl() + + +def test_qft_decomposition(): + """Test the decomposition of the QFT""" + + qml.capture.enable() + qml.decomposition.enable_graph() + + @qml.qjit(target="mlir") + @partial( + qml.transforms.decompose, + gate_set={"RX", "RY", "CNOT", "GlobalPhase"}, + ) + @qml.qnode(qml.device("lightning.qubit", wires=4)) + # CHECK: %0 = transform.apply_registered_pass "decompose-lowering" + # CHECK: func.func public @circuit_16(%arg0: tensor<3xf64>) -> tensor attributes {decompose_gatesets + def circuit_16(): + # %6 = scf.for %arg1 = %c0 to %c4 step %c1 iter_args(%arg2 = %0) -> (!quantum.reg) { + # %23 = scf.for %arg3 = %c0 to %22 step %c1 iter_args(%arg4 = %21) -> (!quantum.reg) { + # %7 = scf.for %arg1 = %c0 to %c2 step %c1 iter_args(%arg2 = %6) -> (!quantum.reg) { + qml.QFT(wires=[0, 1, 2, 3]) + return qml.expval(qml.Z(0)) + + # CHECK-DAG: func.func public @_cphase_to_rz_cnot(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<2xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 2 : i64, target_gate = "ControlledPhaseShift"} + # CHECK-DAG: func.func public @_rz_to_ry_rx(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "RZ"} + # CHECK-DAG: func.func public @_rot_to_rz_ry_rz(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "Rot"} + # CHECK-DAG: func.func public @_swap_to_cnot(%arg0: !quantum.reg, %arg1: tensor<2xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 2 : i64, target_gate = "SWAP"} + # CHECK-DAG: func.func public @_hadamard_to_rz_ry(%arg0: !quantum.reg, %arg1: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "Hadamard"} + print(circuit_16.mlir) + + +test_qft_decomposition() + + +def test_decompose_lowering_with_other_passes(): + """Test the decompose lowering pass with other passes in a pass pipeline.""" + + qml.capture.enable() + qml.decomposition.enable_graph() + + @qml.qjit(target="mlir") + @qml.transforms.merge_rotations + @qml.transforms.cancel_inverses + @partial( + qml.transforms.decompose, + gate_set={"RZ", "RY", "CNOT", "GlobalPhase"}, + ) + @qml.qnode(qml.device("lightning.qubit", wires=4)) + # CHECK: module attributes {transform.with_named_sequence} { + # CHECK-NEXT: transform.named_sequence @__transform_main(%arg0: !transform.op<"builtin.module">) { + # CHECK-NEXT: [[ONE:%.+]] = transform.apply_registered_pass "decompose-lowering" to %arg0 : (!transform.op<"builtin.module">) -> !transform.op<"builtin.module"> + # CHECK-NEXT: [[TWO:%.+]] = transform.apply_registered_pass "remove-chained-self-inverse" to [[ONE]] : (!transform.op<"builtin.module">) -> !transform.op<"builtin.module"> + # CHECK-NEXT: [[THREE:%.+]] = transform.apply_registered_pass "merge-rotations" to [[TWO]] : (!transform.op<"builtin.module">) -> !transform.op<"builtin.module"> + # CHECK-NEXT: transform.yield + # CHECK-NEXT: } + def circuit_17(): + + # CHECK: [[OUT_0:%.+]] = quantum.custom "PauliX"() %1 : !quantum.bit + # CHECK-NEXT: [[OUT_1:%.+]] = quantum.custom "PauliX"() [[OUT_0]] : !quantum.bit + # CHECK-NEXT: [[OUT_2:%.+]] = quantum.custom "RX"(%cst_0) [[OUT_1]] : !quantum.bit + # CHECK-NEXT: [[OUT_3:%.+]] = quantum.custom "RX"(%cst) [[OUT_2]] : !quantum.bit + qml.PauliX(0) + qml.PauliX(0) + qml.RX(0.1, wires=0) + qml.RX(-0.1, wires=0) + return qml.expval(qml.PauliX(0)) + + # CHECK-DAG: func.func public @_paulix_to_rx(%arg0: !quantum.reg, %arg1: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "PauliX"} + # CHECK-DAG: func.func public @_rx_to_rz_ry(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "RX"} + # CHECK-DAG: func.func public @_rot_to_rz_ry_rz(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "Rot"} + print(circuit_17.mlir) + + +test_decompose_lowering_with_other_passes() From fdaef6b90cdb7a1707d32e97dec5d3b71d172c70 Mon Sep 17 00:00:00 2001 From: Ali Asadi <10773383+maliasadi@users.noreply.github.com> Date: Fri, 19 Sep 2025 16:03:26 -0400 Subject: [PATCH 23/36] Support multi-qubit decomp rules --- frontend/catalyst/from_plxpr/decompose.py | 279 ++++++++++++++------- frontend/catalyst/from_plxpr/from_plxpr.py | 8 +- frontend/test/lit/test_decomposition.py | 123 ++++++++- 3 files changed, 306 insertions(+), 104 deletions(-) diff --git a/frontend/catalyst/from_plxpr/decompose.py b/frontend/catalyst/from_plxpr/decompose.py index 8ddb256bbe..f3d0555be4 100644 --- a/frontend/catalyst/from_plxpr/decompose.py +++ b/frontend/catalyst/from_plxpr/decompose.py @@ -18,14 +18,16 @@ from __future__ import annotations +import functools import inspect +import types from collections.abc import Callable from typing import get_type_hints import jax import pennylane as qml -# GraphSolutionInterpreter: +# DecompRuleInterpreter: from pennylane.decomposition import DecompositionGraph from pennylane.wires import WiresLike @@ -33,7 +35,7 @@ # pylint: disable=too-few-public-methods -class GraphSolutionInterpreter(qml.capture.PlxprInterpreter): +class DecompRuleInterpreter(qml.capture.PlxprInterpreter): """Interpreter for getting the decomposition graph solution from a jaxpr when program capture is enabled. @@ -64,6 +66,14 @@ class GraphSolutionInterpreter(qml.capture.PlxprInterpreter): # A mapping from operation names to the number of wires they act on. # This is used when the operation is not in the captured operations # but we still need to create a decomposition rule for it. + # + # Note that some operations have a variable number of wires, + # e.g., MultiRZ, GlobalPhase. For these, we set the number + # of wires to -1 to indicate a variable number. + # + # This will require a copy of the function to be made + # when creating the decomposition rule to avoid mutating + # the original function with attributes like num_wires. compiler_ops_num_wires: dict[str, int] = { "CNOT": 2, "ControlledPhaseShift": 2, @@ -113,7 +123,7 @@ def __init__( if not qml.decomposition.enabled_graph(): # pragma: no cover raise TypeError( - "The GraphSolutionInterpreter can only be used when" + "The DecompRuleInterpreter can only be used when" "graph-based decomposition is enabled." ) @@ -169,7 +179,7 @@ def interpret_measurement(self, measurement: "qml.measurement.MeasurementProcess # place where we can be sure that we have seen all operations # in the circuit before the measurement. # TODO: Find a better way to do this. - self._decomp_graph_solution = self._solve_decomposition_graph( + self._decomp_graph_solution = _solve_decomposition_graph( self._operations, self._gate_set, fixed_decomps=self._fixed_decomps, @@ -179,14 +189,26 @@ def interpret_measurement(self, measurement: "qml.measurement.MeasurementProcess # Create decomposition rules for each operation in the solution # and compile them to Catalyst JAXPR decomposition rules for op, rule in self._decomp_graph_solution.items(): + # Get number of wires if exists + op_num_wires = ( + op.op.params.get("num_wires", None) if hasattr(op.op, "params") else None + ) + if ( - o := next((o for o in self._operations if o.name == op.op.name), None) + o := next( + ( + o + for o in self._operations + if o.name == op.op.name and len(o.wires) == op_num_wires + ), + None, + ) ) is not None: - # TODO: This assumes that the operation names are unique in the circuit. - # If there are multiple operations with the same name but different number - # of wires, this will only capture the first one. - self._create_decomposition_rule( - rule, op_name=op.op.name, num_wires=len(o.wires) + _create_decomposition_rule( + rule, + op_name=op.op.name, + num_wires=len(o.wires), + requires_copy=self.compiler_ops_num_wires[op.op.name] == -1, ) elif op.op.name in self.compiler_ops_num_wires: # In this part, we need to handle the case where an operation in @@ -196,105 +218,172 @@ def interpret_measurement(self, measurement: "qml.measurement.MeasurementProcess # In this case, we fall back to using the compiler_ops_num_wires # dictionary to get the number of wires. num_wires = self.compiler_ops_num_wires[op.op.name] - self._create_decomposition_rule(rule, op_name=op.op.name, num_wires=num_wires) + _create_decomposition_rule( + rule, + op_name=op.op.name, + num_wires=num_wires, + requires_copy=self.compiler_ops_num_wires[op.op.name] == -1, + ) else: # pragma: no cover raise ValueError(f"Could not capture {op} without the number of wires.") data, struct = jax.tree_util.tree_flatten(measurement) return jax.tree_util.tree_unflatten(struct, data) - def _create_decomposition_rule(self, func: Callable, op_name: str, num_wires: int): - """Create a decomposition rule from a callable.""" - - sig_func = inspect.signature(func) - type_hints = get_type_hints(func) - - args = {} - for name in sig_func.parameters.keys(): - typ = type_hints.get(name, None) - - # Skip tailing args or kwargs in the rules - if name in ("__", "_"): - continue - - # TODO: This is a temporary solution until all rules have proper type annotations. - # Why? Because we need to pass the correct types to the decomposition_rule - # function to capture the rule correctly with JAX. - possible_names_for_params = { - "params", - "param", - "parameters", - "angles", - "angle", - "phi", - "omega", - "theta", - "weights", - "weight", - } - - # TODO: Support work-wires when it's supported in Catalyst. - possible_names_for_wires = {"wires", "wire", "control_wires", "target_wires"} - - if typ is float or name in possible_names_for_params: - # TensorLike is a Union of float, int, array-like, so we use float here - # to cover the most common case as the JAX tracer doesn't like Union types - # and we don't have the actual values at this point. - args[name] = float - elif typ is WiresLike or name in possible_names_for_wires: - # Pass a dummy array of zeros with the correct number of wires - # This is required for the decomposition_rule to work correctly - # as it expects an array-like input for wires - args[name] = qml.math.array([0] * num_wires, like="jax") - elif typ is int: # pragma: no cover - # This is only for cases where the rule has an int parameter - # e.g., dimension in some gates. Not that common though! - # We cover this when adding end-to-end tests for rules - # in the MLIR PR. - args[name] = int - else: # pragma: no cover - raise ValueError( - f"Unsupported type annotation {typ} for parameter {name} in func {func}." - ) - # Set custom attributes for the decomposition rule - # These attributes are used in the MLIR decomposition pass - # to identify the target gate and the number of wires - setattr(func, "target_gate", op_name) - setattr(func, "num_wires", num_wires) +def _create_decomposition_rule( + func: Callable, op_name: str, num_wires: int, requires_copy: bool = False +): + """Create a decomposition rule from a callable. - return decomposition_rule(func)(**args) + See also: :func:`~.decomposition_rule`. - # pylint: disable=protected-access - def _solve_decomposition_graph(self, operations, gate_set, fixed_decomps, alt_decomps): - """Get the decomposition graph solution for the given operations and gate set. + Args: + func (Callable): The decomposition function. + op_name (str): The name of the operation to decompose. + num_wires (int): The number of wires the operation acts on. - TODO: Extend `DecompGraphSolution` API and avoid accessing protected members - directly in this function. - """ + Returns: + None: The function is decorated in place. + """ - # decomp_graph_solution - decomp_graph_solution = {} + sig_func = inspect.signature(func) + type_hints = get_type_hints(func) + + args = {} + for name in sig_func.parameters.keys(): + typ = type_hints.get(name, None) + + # Skip tailing args or kwargs in the rules + if name in ("__", "_"): + continue + + # TODO: This is a temporary solution until all rules have proper type annotations. + # Why? Because we need to pass the correct types to the decomposition_rule + # function to capture the rule correctly with JAX. + possible_names_for_params = { + "params", + "param", + "parameters", + "angles", + "angle", + "phi", + "omega", + "theta", + "weights", + "weight", + } + + # TODO: Support work-wires when it's supported in Catalyst. + possible_names_for_wires = {"wires", "wire", "control_wires", "target_wires"} + + if typ is float or name in possible_names_for_params: + # TensorLike is a Union of float, int, array-like, so we use float here + # to cover the most common case as the JAX tracer doesn't like Union types + # and we don't have the actual values at this point. + args[name] = float + elif typ is WiresLike or name in possible_names_for_wires: + # Pass a dummy array of zeros with the correct number of wires + # This is required for the decomposition_rule to work correctly + # as it expects an array-like input for wires + args[name] = qml.math.array([0] * num_wires, like="jax") + elif typ is int: # pragma: no cover + # This is only for cases where the rule has an int parameter + # e.g., dimension in some gates. Not that common though! + # We cover this when adding end-to-end tests for rules + # in the MLIR PR. + args[name] = int + else: # pragma: no cover + raise ValueError( + f"Unsupported type annotation {typ} for parameter {name} in func {func}." + ) - decomp_graph = DecompositionGraph( - operations, - gate_set, - fixed_decomps=fixed_decomps, - alt_decomps=alt_decomps, + func_cp = make_def_copy(func) if requires_copy else func + + # Set custom attributes for the decomposition rule + # These attributes are used in the MLIR decomposition pass + # to identify the target gate and the number of wires + setattr(func_cp, "target_gate", op_name) + setattr(func_cp, "num_wires", num_wires) + + if requires_copy: + # Include number of wires in the function name to avoid name clashes + # when the same rule is compiled multiple times with different number of wires + # (e.g., MultiRZ, GlobalPhase) + func_cp.__name__ += f"_wires_{num_wires}" # pylint: disable=protected-access + + return decomposition_rule(func_cp)(**args) + + +# pylint: disable=protected-access +def _solve_decomposition_graph(operations, gate_set, fixed_decomps, alt_decomps): + """Get the decomposition graph solution for the given operations and gate set. + + TODO: Extend `DecompGraphSolution` API and avoid accessing protected members + directly in this function. + + Args: + operations (set[Operator]): The set of operations to decompose. + gate_set (set[Operator]): The target gate set to decompose to. + fixed_decomps (dict or None): A dictionary of fixed decomposition rules + to use in the decomposition graph. + alt_decomps (dict or None): A dictionary of alternative decomposition rules + to use in the decomposition graph. + + Returns: + dict: A dictionary mapping operations to their decomposition rules. + """ + + # decomp_graph_solution + decomp_graph_solution = {} + + decomp_graph = DecompositionGraph( + operations, + gate_set, + fixed_decomps=fixed_decomps, + alt_decomps=alt_decomps, + ) + + # Find the efficient pathways to the target gate set + solutions = decomp_graph.solve() + + def is_solved_for(op): + return ( + op in solutions._all_op_indices + and solutions._all_op_indices[op] in solutions._visitor.distances ) - # Find the efficient pathways to the target gate set - solutions = decomp_graph.solve() + for op_node, op_node_idx in solutions._all_op_indices.items(): + if is_solved_for(op_node) and op_node_idx in solutions._visitor.predecessors: + d_node_idx = solutions._visitor.predecessors[op_node_idx] + decomp_graph_solution[op_node] = solutions._graph[d_node_idx].rule._impl + + return decomp_graph_solution - def is_solved_for(op): - return ( - op in solutions._all_op_indices - and solutions._all_op_indices[op] in solutions._visitor.distances - ) - for op_node, op_node_idx in solutions._all_op_indices.items(): - if is_solved_for(op_node) and op_node_idx in solutions._visitor.predecessors: - d_node_idx = solutions._visitor.predecessors[op_node_idx] - decomp_graph_solution[op_node] = solutions._graph[d_node_idx].rule._impl +# pylint: disable=protected-access +def make_def_copy(func): + """Create a copy of a Python definition to avoid mutating the original. - return decomp_graph_solution + This is especially useful when compiling decomposition rules with + parametric number of wires (e.g., MultiRZ, GlobalPhase) multiple times, + as the compilation process may add attributes to the function that + can interfere with subsequent compilations. + + Args: + func (Callable): The function to copy. + + Returns: + Callable: A copy of the original function with the same attributes. + """ + # Create a new function object with the same code, globals, name, defaults, and closure + func_copy = types.FunctionType( + func.__code__, + func.__globals__, + name=func.__name__, + argdefs=func.__defaults__, + closure=func.__closure__, + ) + + # Now, we create and update the wrapper to copy over attributes like docstring, module, etc. + return functools.update_wrapper(func_copy, func) diff --git a/frontend/catalyst/from_plxpr/from_plxpr.py b/frontend/catalyst/from_plxpr/from_plxpr.py index 90c9c712d8..cf5b5e5594 100644 --- a/frontend/catalyst/from_plxpr/from_plxpr.py +++ b/frontend/catalyst/from_plxpr/from_plxpr.py @@ -45,7 +45,7 @@ from catalyst.device import extract_backend_info from catalyst.device.qjit_device import COMPILER_OPERATIONS -from catalyst.from_plxpr.decompose import GraphSolutionInterpreter +from catalyst.from_plxpr.decompose import DecompRuleInterpreter from catalyst.from_plxpr.qubit_handler import QubitHandler from catalyst.jax_extras import jaxpr_pad_consts, make_jaxpr2, transient_jax_config from catalyst.jax_primitives import ( @@ -315,7 +315,7 @@ def apply_compiler_decompose_to_plxpr(inner_jaxpr, consts, tgateset, ncargs): def collect_and_compile_graph_solutions(inner_jaxpr, consts, tkwargs, ncargs): """Collect and compile graph solutions for a given JAXPR. - This function uses the GraphSolutionInterpreter to evaluate + This function uses the DecompRuleInterpreter to evaluate the input JAXPR and obtain a new JAXPR that incorporates the graph-based decomposition solutions. @@ -331,7 +331,7 @@ def collect_and_compile_graph_solutions(inner_jaxpr, consts, tkwargs, ncargs): Returns: ClosedJaxpr: The decomposed JAXPR. """ - gds_interpreter = GraphSolutionInterpreter(**tkwargs) + gds_interpreter = DecompRuleInterpreter(**tkwargs) def gds_wrapper(*args): return gds_interpreter.eval(inner_jaxpr, consts, *args) @@ -401,7 +401,7 @@ def handle_transform( # the current jaxpr based on the current gateset # but we don't rewrite the jaxpr at this stage. - # gds_interpreter = GraphSolutionInterpreter(*targs, **tkwargs) + # gds_interpreter = DecompRuleInterpreter(*targs, **tkwargs) # def gds_wrapper(*args): # return gds_interpreter.eval(inner_jaxpr, consts, *args) diff --git a/frontend/test/lit/test_decomposition.py b/frontend/test/lit/test_decomposition.py index 8d75c774b3..53042a6a71 100644 --- a/frontend/test/lit/test_decomposition.py +++ b/frontend/test/lit/test_decomposition.py @@ -844,7 +844,7 @@ def test_decomposition_rule_name_ctrl(): qml.transforms.decompose, gate_set={"RX", "RZ", "H", "CZ"}, ) - @qml.qnode(qml.device("lightning.qubit", wires=5)) + @qml.qnode(qml.device("lightning.qubit", wires=2)) # CHECK-DAG: %0 = transform.apply_registered_pass "decompose-lowering" # CHECK{LITERAL}: func.func public @circuit_17() -> tensor attributes {decompose_gatesets def circuit_17(): @@ -912,12 +912,12 @@ def test_decompose_lowering_with_other_passes(): qml.transforms.decompose, gate_set={"RZ", "RY", "CNOT", "GlobalPhase"}, ) - @qml.qnode(qml.device("lightning.qubit", wires=4)) + @qml.qnode(qml.device("lightning.qubit", wires=1)) # CHECK: module attributes {transform.with_named_sequence} { # CHECK-NEXT: transform.named_sequence @__transform_main(%arg0: !transform.op<"builtin.module">) { # CHECK-NEXT: [[ONE:%.+]] = transform.apply_registered_pass "decompose-lowering" to %arg0 : (!transform.op<"builtin.module">) -> !transform.op<"builtin.module"> # CHECK-NEXT: [[TWO:%.+]] = transform.apply_registered_pass "remove-chained-self-inverse" to [[ONE]] : (!transform.op<"builtin.module">) -> !transform.op<"builtin.module"> - # CHECK-NEXT: [[THREE:%.+]] = transform.apply_registered_pass "merge-rotations" to [[TWO]] : (!transform.op<"builtin.module">) -> !transform.op<"builtin.module"> + # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "merge-rotations" to [[TWO]] : (!transform.op<"builtin.module">) -> !transform.op<"builtin.module"> # CHECK-NEXT: transform.yield # CHECK-NEXT: } def circuit_17(): @@ -925,7 +925,7 @@ def circuit_17(): # CHECK: [[OUT_0:%.+]] = quantum.custom "PauliX"() %1 : !quantum.bit # CHECK-NEXT: [[OUT_1:%.+]] = quantum.custom "PauliX"() [[OUT_0]] : !quantum.bit # CHECK-NEXT: [[OUT_2:%.+]] = quantum.custom "RX"(%cst_0) [[OUT_1]] : !quantum.bit - # CHECK-NEXT: [[OUT_3:%.+]] = quantum.custom "RX"(%cst) [[OUT_2]] : !quantum.bit + # CHECK-NEXT: {{%.+}} = quantum.custom "RX"(%cst) [[OUT_2]] : !quantum.bit qml.PauliX(0) qml.PauliX(0) qml.RX(0.1, wires=0) @@ -934,8 +934,121 @@ def circuit_17(): # CHECK-DAG: func.func public @_paulix_to_rx(%arg0: !quantum.reg, %arg1: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "PauliX"} # CHECK-DAG: func.func public @_rx_to_rz_ry(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "RX"} - # CHECK-DAG: func.func public @_rot_to_rz_ry_rz(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "Rot"} print(circuit_17.mlir) test_decompose_lowering_with_other_passes() + + +def test_decompose_lowering_with_ordered_passes(): + """Test the decompose lowering pass with other passes in a specific order in a pass pipeline.""" + + qml.capture.enable() + qml.decomposition.enable_graph() + + @qml.qjit(target="mlir") + @partial( + qml.transforms.decompose, + gate_set={"RZ", "RY", "CNOT", "GlobalPhase"}, + ) + @qml.transforms.merge_rotations + @qml.transforms.cancel_inverses + @qml.qnode(qml.device("lightning.qubit", wires=1)) + # CHECK: module attributes {transform.with_named_sequence} { + # CHECK-NEXT: transform.named_sequence @__transform_main(%arg0: !transform.op<"builtin.module">) { + # CHECK-NEXT: [[FIRST:%.+]] = transform.apply_registered_pass "remove-chained-self-inverse" to %arg0 : (!transform.op<"builtin.module">) -> !transform.op<"builtin.module"> + # CHECK-NEXT: [[SECOND:%.+]] = transform.apply_registered_pass "merge-rotations" to [[FIRST]] : (!transform.op<"builtin.module">) -> !transform.op<"builtin.module"> + # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "decompose-lowering" to [[SECOND]] : (!transform.op<"builtin.module">) -> !transform.op<"builtin.module"> + # CHECK-NEXT: transform.yield + # CHECK-NEXT: } + def circuit_18(x: float): + # CHECK: [[OUT:%.+]] = quantum.custom "PauliX"() %1 : !quantum.bit + # CHECK-NEXT: [[OUT_0:%.+]] = quantum.custom "PauliX"() [[OUT]] : !quantum.bit + # CHECK-NEXT: [[EXTRACTED:%.+]] = tensor.extract %arg0[] : tensor + # CHECK-NEXT: [[OUT_1:%.+]] = quantum.custom "RX"([[EXTRACTED]]) [[OUT_0]] : !quantum.bit + # CHECK-NEXT: [[NEGATED:%.+]] = stablehlo.negate %arg0 : tensor + # CHECK-NEXT: [[EXTRACTED_2:%.+]] = tensor.extract [[NEGATED]][] : tensor + # CHECK-NEXT: {{%.+}} = quantum.custom "RX"([[EXTRACTED_2]]) [[OUT_1]] : !quantum.bit + qml.PauliX(0) + qml.PauliX(0) + qml.RX(x, wires=0) + qml.RX(-x, wires=0) + return qml.expval(qml.PauliX(0)) + + # CHECK-DAG: func.func public @_paulix_to_rx(%arg0: !quantum.reg, %arg1: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "PauliX"} + # CHECK-DAG: func.func public @_rx_to_rz_ry(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "RX"} + # CHECK-DAG: func.func public @_rot_to_rz_ry_rz(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "Rot"} + print(circuit_18.mlir) + + +test_decompose_lowering_with_ordered_passes() + + +def test_decompose_lowering_multirz(): + """Test the decompose lowering pass with MultiRZ in the gate set.""" + + qml.capture.enable() + qml.decomposition.enable_graph() + + @qml.qjit(target="mlir") + @partial( + qml.transforms.decompose, + gate_set={"CNOT", "RZ"}, + ) + @qml.qnode(qml.device("lightning.qubit", wires=3)) + # CHECK: %0 = transform.apply_registered_pass "decompose-lowering" + def circuit_19(x: float): + # CHECK: [[EXTRACTED:%.+]] = tensor.extract %arg0[] : tensor + # CHECK-NEXT: [[OUT_QUBITS:%.+]] = quantum.multirz([[EXTRACTED]]) %1 : !quantum.bit + # CHECK-NEXT: [[BIT_1:%.+]] = quantum.extract %0[ 1] : !quantum.reg -> !quantum.bit + # CHECK-NEXT: [[EXTRACTED_0:%.+]] = tensor.extract %arg0[] : tensor + # CHECK-NEXT: [[OUT_QUBITS_1:%.+]] = quantum.multirz([[EXTRACTED_0]]) [[OUT_QUBITS]], [[BIT_1]] : !quantum.bit, !quantum.bit + # CHECK-NEXT: [[BIT_2:%.+]] = quantum.extract %0[ 2] : !quantum.reg -> !quantum.bit + # CHECK-NEXT: [[EXTRACTED_2:%.+]] = tensor.extract %arg0[] : tensor + # CHECK-NEXT: {{%.+}} = quantum.multirz([[EXTRACTED_2]]) {{%.+}}, {{%.+}}, [[BIT_2]] : !quantum.bit, !quantum.bit, !quantum.bit + qml.MultiRZ(x, wires=[0]) + qml.MultiRZ(x, wires=[0, 1]) + qml.MultiRZ(x, wires=[1, 0, 2]) + return qml.expval(qml.PauliX(0)) + + # CHECK-DAG: func.func public @_multi_rz_decomposition_wires_1(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "MultiRZ"} + # CHECK-DAG: func.func public @_multi_rz_decomposition_wires_2(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<2xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 2 : i64, target_gate = "MultiRZ"} + # CHECK-DAG: func.func public @_multi_rz_decomposition_wires_3(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<3xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 3 : i64, target_gate = "MultiRZ"} + # CHECK-DAG: %0 = scf.for %arg3 = %c0 to %c2 step %c1 iter_args(%arg4 = %arg0) -> (!quantum.reg) + # CHECK-DAG: %5 = scf.for %arg3 = %c1 to %c3 step %c1 iter_args(%arg4 = %4) -> (!quantum.reg) + print(circuit_19.mlir) + + +test_decompose_lowering_multirz() + + +def test_decompose_lowering_with_gphase(): + """Test the decompose lowering pass with GlobalPhase.""" + + qml.capture.enable() + qml.decomposition.enable_graph() + + @qml.qjit(target="mlir") + @partial( + qml.transforms.decompose, + gate_set={"RX", "RY", "GlobalPhase"}, + ) + @qml.qnode(qml.device("lightning.qubit", wires=3)) + # CHECK: %0 = transform.apply_registered_pass "decompose-lowering" + + def circuit_20(): + # CHECK: quantum.gphase(%cst_0) : + # CHECK-NEXT: [[EXTRACTED:%.+]] = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit + # CHECK-NEXT: [[OUT_QUBITS:%.+]] = quantum.custom "PhaseShift"(%cst) [[EXTRACTED]] : !quantum.bit + # CHECK-NEXT: {{%.+}} = quantum.custom "PhaseShift"(%cst) [[OUT_QUBITS]] : !quantum.bit + qml.GlobalPhase(0.5) + qml.ctrl(qml.GlobalPhase, control=0)(0.3) + qml.ctrl(qml.GlobalPhase, control=0)(phi=0.3, wires=[1, 2]) + return qml.expval(qml.PauliX(0)) + + # CHECK-DAG: func.func public @_phaseshift_to_rz_gp(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "PhaseShift"} + # CHECK-DAG: func.func public @_rz_to_ry_rx(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "RZ"} + print(circuit_20.mlir) + + +test_decompose_lowering_with_gphase() From 9c2681c69c3e75c0a70ec029c85e15ddc5b1c733 Mon Sep 17 00:00:00 2001 From: Ali Asadi <10773383+maliasadi@users.noreply.github.com> Date: Fri, 19 Sep 2025 16:17:35 -0400 Subject: [PATCH 24/36] Update changelog --- doc/releases/changelog-dev.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index e6a2d2d8e5..552380941f 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -2,6 +2,12 @@

New features since last release

+* A new experimental decomposition system is introduced in Catalyst enabling the + PennyLane's graph-based decomposition and MLIR-based lowering of decomposition rules. + This feature is integrated with PennyLane program capture and graph-based decomposition + including support for custom decomposition rules and operators. + [(#2029)](https://github.com/PennyLaneAI/catalyst/pull/2029) + * A new pass `--t-layer-reduction` has been added to reduce the depth and number of non-Clifford PPR operations by commuting adjacent PPRs and finding possible PPRs that can be merged. For more details, see the Figure 6 in [A Game of Surface Code](https://arXiv:1808.02892v3) paper. @@ -26,6 +32,7 @@ return qml.probs() ``` +

Improvements 🛠

* Significantly improved resource tracking with `null.qubit`. From 4d155dfa773bb83185c29a691f064beddfefaf28 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Tue, 23 Sep 2025 14:28:09 -0400 Subject: [PATCH 25/36] Decomposition pass is added to allow user defined decomposition rules present in MLIR (#2001) **Context:** Introduces a new MLIR pass that enables user-defined decomposition of quantum gate operations. The pass discovers decomposition functions in the module and uses MLIR's inlining infrastructure to replace target quantum operations with their decomposed implementations. **Description of the Change:** This PR will **recognize** the `func.func` as decomposition if they have name start with the pattern `_rule[.]*` where the `` is a `quantum.custom` op it expect to replace (It's not what generated from frontend). Also you can just mark the attributes `catalyst.decomposition` and `catalyst.decomposition.target_op` at decomposition function, it still works. And this pass just **discover** those decomposition functions and **replace** the target operation with `call` to the function and rely on the Inliner interface to inline these decomposition function if needed. The stages of the `--decompose-lowering`: 1. `decompose_lowering.cpp`: Main pass that orchestrates the decomposition process 2. `DecomposeLoweringPatterns`: Pattern rewriting logic that replaces `quantum.custom` operations with function call 3. `QuantumInlinerInterface`: Dialect interface that enables MLIR's inliner to process quantum operations correctly 4. Remove unused decomposition function (if the inliner decide not to inline the certain function, then it will leave a `func.call`, so the func is used, shouldn't be removed) The type signature between decomposition function provided from the frontend is different to the `quantum.custom` op, `OpSignatureAnalyzer` does the trick to get the enough information to generate a function call to decomposition function. Current right now, frontend choose to generate the decomposition function with type: ``` (qreg, param*, inWires*, inCtrlWires*?, inCtrlValues*?) -> qreg ``` We need to figure out the information that need to pass to create the function call: qreg, `in wires indices`, and `in ctrl wires`. `OpSignatureAnalyzer` will use the target `quantum.custom` op (that one we need to replace) to traverse the qubit, and find out the qreg, and wire indices. The traversal logic based on the following assumption for the quantum dialect: 1. The ordering of qubits is preserved in quantum instructions 2. For other instructions, they will use a register, so will hit an extract before them It promise that every gates, the result (qubits) will match to the operands (qubits) at the same index. That's the important assumptions to support the traversal logic here. Traverse the qubit, match the index, keep going up until reach the `qauntum.extract` operation. After then, we use those information to generate a function call and rely on the Inliner in MLIR infra to replace it. **Benefits:** The decision to use MLIR's inlining infrastructure (`InlinerPass`) is to provide future-proofing and flexibility. **Possible Drawbacks:** **Related GitHub Issues:** [sc-98593] --------- Co-authored-by: Ali Asadi <10773383+maliasadi@users.noreply.github.com> --- frontend/catalyst/from_plxpr/qubit_handler.py | 1 - mlir/include/Quantum/IR/QuantumOps.td | 1 + mlir/include/Quantum/Transforms/Patterns.h | 7 + mlir/lib/Quantum/IR/QuantumDialect.cpp | 66 ++- mlir/lib/Quantum/IR/QuantumOps.cpp | 25 +- mlir/lib/Quantum/Transforms/CMakeLists.txt | 1 + .../Transforms/DecomposeLoweringPatterns.cpp | 456 ++++++++++++++++ .../Quantum/Transforms/decompose_lowering.cpp | 176 +++++- mlir/test/Quantum/CanonicalizationTest.mlir | 11 +- mlir/test/Quantum/DecomposeLoweringTest.mlir | 510 ++++++++++++++++++ 10 files changed, 1237 insertions(+), 17 deletions(-) create mode 100644 mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp create mode 100644 mlir/test/Quantum/DecomposeLoweringTest.mlir diff --git a/frontend/catalyst/from_plxpr/qubit_handler.py b/frontend/catalyst/from_plxpr/qubit_handler.py index 17115dd2f9..50ce1a1975 100644 --- a/frontend/catalyst/from_plxpr/qubit_handler.py +++ b/frontend/catalyst/from_plxpr/qubit_handler.py @@ -89,7 +89,6 @@ class QubitHandler: wire_map: dict[int, AbstractQbit] # Note: No dynamic wire indices for now in from_plxpr. def __init__(self, qubit_or_qreg_ref: AbstractQreg | list[AbstractQbit] | tuple[AbstractQbit]): - if isinstance(qubit_or_qreg_ref, (list, tuple)): self.abstract_qreg_val = None self.qubit_indices = qubit_or_qreg_ref diff --git a/mlir/include/Quantum/IR/QuantumOps.td b/mlir/include/Quantum/IR/QuantumOps.td index 89498b185a..0b3a76296d 100644 --- a/mlir/include/Quantum/IR/QuantumOps.td +++ b/mlir/include/Quantum/IR/QuantumOps.td @@ -241,6 +241,7 @@ def ExtractOp : Memory_Op<"extract", [NoMemoryEffect]> { $qreg `[` ($idx^):($idx_attr)? `]` attr-dict `:` type($qreg) `->` type(results) }]; + let hasCanonicalizeMethod = 1; let hasVerifier = 1; let hasFolder = 1; } diff --git a/mlir/include/Quantum/Transforms/Patterns.h b/mlir/include/Quantum/Transforms/Patterns.h index a16569c01b..8b8ade74c1 100644 --- a/mlir/include/Quantum/Transforms/Patterns.h +++ b/mlir/include/Quantum/Transforms/Patterns.h @@ -15,8 +15,12 @@ #pragma once #include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringSet.h" +#include "llvm/Support/AllocatorBase.h" namespace catalyst { namespace quantum { @@ -26,6 +30,9 @@ void populateAdjointPatterns(mlir::RewritePatternSet &); void populateSelfInversePatterns(mlir::RewritePatternSet &); void populateMergeRotationsPatterns(mlir::RewritePatternSet &); void populateIonsDecompositionPatterns(mlir::RewritePatternSet &); +void populateDecomposeLoweringPatterns(mlir::RewritePatternSet &, + const llvm::StringMap &, + const llvm::StringSet &); void populateLoopBoundaryPatterns(mlir::RewritePatternSet &, unsigned int mode); } // namespace quantum diff --git a/mlir/lib/Quantum/IR/QuantumDialect.cpp b/mlir/lib/Quantum/IR/QuantumDialect.cpp index 7049f58e63..14f5e6e811 100644 --- a/mlir/lib/Quantum/IR/QuantumDialect.cpp +++ b/mlir/lib/Quantum/IR/QuantumDialect.cpp @@ -12,9 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "llvm/ADT/TypeSwitch.h" // needed for generated type parser + #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/DialectImplementation.h" // needed for generated type parser -#include "llvm/ADT/TypeSwitch.h" // needed for generated type parser +#include "mlir/Transforms/InliningUtils.h" #include "Quantum/IR/QuantumDialect.h" #include "Quantum/IR/QuantumOps.h" @@ -22,6 +25,65 @@ using namespace mlir; using namespace catalyst::quantum; +//===----------------------------------------------------------------------===// +// Quantum Dialect Interfaces +//===----------------------------------------------------------------------===// + +namespace { + +struct QuantumInlinerInterface : public DialectInlinerInterface { + using DialectInlinerInterface::DialectInlinerInterface; + + static constexpr StringRef decompAttr = "target_gate"; + + /// Returns true if the given operation 'callable' can be inlined into the + /// position given by the 'call'. Currently, we always inline quantum + /// decomposition functions. + bool isLegalToInline(Operation *call, Operation *callable, bool wouldBeCloned) const final + { + if (auto funcOp = dyn_cast(callable)) { + return funcOp->hasAttr(decompAttr); + } + return false; + } + + /// Returns true if the given region 'src' can be inlined into the region + /// 'dest'. Only allow for decomposition functions. + bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, + IRMapping &valueMapping) const final + { + if (auto funcOp = src->getParentOfType()) { + return funcOp->hasAttr(decompAttr); + } + return false; + } + + // Allow to inline operations from decomposition functions. + bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned, + IRMapping &valueMapping) const final + { + if (auto funcOp = op->getParentOfType()) { + return funcOp->hasAttr(decompAttr); + } + return false; + } + + /// Handle the given inlined terminator by replacing it with a new operation + /// as necessary. Required when the region has only one block. + void handleTerminator(Operation *op, ValueRange valuesToRepl) const final + { + auto yieldOp = dyn_cast(op); + if (!yieldOp) { + return; + } + + for (auto retValue : llvm::zip(valuesToRepl, yieldOp.getOperands())) { + std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue)); + } + } +}; +} // namespace + //===----------------------------------------------------------------------===// // Quantum dialect definitions. //===----------------------------------------------------------------------===// @@ -45,6 +107,8 @@ void QuantumDialect::initialize() #include "Quantum/IR/QuantumOps.cpp.inc" >(); + addInterfaces(); + declarePromisedInterfaces(); diff --git a/mlir/lib/Quantum/IR/QuantumOps.cpp b/mlir/lib/Quantum/IR/QuantumOps.cpp index 101ed77a57..b0eedf09d3 100644 --- a/mlir/lib/Quantum/IR/QuantumOps.cpp +++ b/mlir/lib/Quantum/IR/QuantumOps.cpp @@ -146,6 +146,28 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) return nullptr; } +LogicalResult ExtractOp::canonicalize(ExtractOp extract, mlir::PatternRewriter &rewriter) +{ + // Handle the pattern: %reg2 = insert %reg1[idx], %qubit -> %q = extract %reg2[idx] + // Convert to: %q = %qubit, and replace other uses of %reg2 with %reg1 + if (auto insert = dyn_cast_if_present(extract.getQreg().getDefiningOp())) { + bool bothStatic = extract.getIdxAttr().has_value() && insert.getIdxAttr().has_value(); + bool bothDynamic = !extract.getIdxAttr().has_value() && !insert.getIdxAttr().has_value(); + bool staticallyEqual = bothStatic && extract.getIdxAttrAttr() == insert.getIdxAttrAttr(); + bool dynamicallyEqual = bothDynamic && extract.getIdx() == insert.getIdx(); + // if other users of insert are also `insert`, we are good to go + bool valid = llvm::all_of(insert.getResult().getUsers(), [&](Operation *op) { + return isa(op) || op == extract.getOperation(); + }); + if ((staticallyEqual || dynamicallyEqual) && valid) { + rewriter.replaceOp(extract, insert.getQubit()); + rewriter.replaceOp(insert, insert.getInQreg()); + return success(); + } + } + return failure(); +} + LogicalResult InsertOp::canonicalize(InsertOp insert, mlir::PatternRewriter &rewriter) { if (auto extract = dyn_cast_if_present(insert.getQubit().getDefiningOp())) { @@ -153,9 +175,10 @@ LogicalResult InsertOp::canonicalize(InsertOp insert, mlir::PatternRewriter &rew bool bothDynamic = !extract.getIdxAttr().has_value() && !insert.getIdxAttr().has_value(); bool staticallyEqual = bothStatic && extract.getIdxAttrAttr() == insert.getIdxAttrAttr(); bool dynamicallyEqual = bothDynamic && extract.getIdx() == insert.getIdx(); + bool sameQreg = extract.getQreg() == insert.getInQreg(); bool oneUse = extract.getResult().hasOneUse(); - if ((staticallyEqual || dynamicallyEqual) && oneUse) { + if ((staticallyEqual || dynamicallyEqual) && oneUse && sameQreg) { rewriter.replaceOp(insert, insert.getInQreg()); rewriter.eraseOp(extract); return success(); diff --git a/mlir/lib/Quantum/Transforms/CMakeLists.txt b/mlir/lib/Quantum/Transforms/CMakeLists.txt index ddc54e3148..26b3ac8410 100644 --- a/mlir/lib/Quantum/Transforms/CMakeLists.txt +++ b/mlir/lib/Quantum/Transforms/CMakeLists.txt @@ -15,6 +15,7 @@ file(GLOB SRC merge_rotation.cpp MergeRotationsPatterns.cpp decompose_lowering.cpp + DecomposeLoweringPatterns.cpp DisentangleSWAP.cpp DisentangleCNOT.cpp ions_decompositions.cpp diff --git a/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp b/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp new file mode 100644 index 0000000000..9dcc4ea1ad --- /dev/null +++ b/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp @@ -0,0 +1,456 @@ +// Copyright 2025 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#define DEBUG_TYPE "decompose-lowering" + +#include + +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringSet.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/ValueRange.h" + +#include "Quantum/IR/QuantumOps.h" +#include "Quantum/Transforms/Patterns.h" + +using namespace mlir; +using namespace catalyst::quantum; + +namespace catalyst { +namespace quantum { + +/// A struct to represent qubit indices in quantum operations. +/// +/// This struct provides a way to handle qubit indices that can be either: +/// - A runtime Value (for dynamic indices computed at runtime) +/// - An IntegerAttr (for compile-time constant indices) +/// - Invalid/uninitialized (represented by std::monostate) +/// +/// The struct uses std::variant to ensure only one type is active at a time, +/// preventing invalid states. +/// +/// Example usage: +/// QubitIndex dynamicIdx(operandValue); // Runtime qubit index +/// QubitIndex staticIdx(IntegerAttr::get(...)); // Compile-time constant +/// QubitIndex invalidIdx; // Uninitialized state +/// +/// if (dynamicIdx) { // Check if valid +/// if (dynamicIdx.isValue()) { // Check if runtime value +/// Value idx = dynamicIdx.getValue(); // Get the Value +/// } +/// } +struct QubitIndex { + // use monostate to represent the invalid index + std::variant index; + + QubitIndex() : index(std::monostate()) {} + QubitIndex(Value val) : index(val) {} + QubitIndex(IntegerAttr attr) : index(attr) {} + + bool isValue() const { return std::holds_alternative(index); } + bool isAttr() const { return std::holds_alternative(index); } + operator bool() const { return isValue() || isAttr(); } + Value getValue() const { return isValue() ? std::get(index) : nullptr; } + IntegerAttr getAttr() const { return isAttr() ? std::get(index) : nullptr; } +}; + +// The goal of this class is to analyze the signature of a custom operation to get the enough +// information to prepare the call operands and results for replacing the op to calling the +// decomposition function. +class OpSignatureAnalyzer { + public: + OpSignatureAnalyzer() = delete; + OpSignatureAnalyzer(CustomOp op, bool enableQregMode) + : signature(OpSignature{ + .params = op.getParams(), + .inQubits = op.getInQubits(), + .inCtrlQubits = op.getInCtrlQubits(), + .inCtrlValues = op.getInCtrlValues(), + .outQubits = op.getOutQubits(), + .outCtrlQubits = op.getOutCtrlQubits(), + }) + { + if (!enableQregMode) + return; + + signature.sourceQreg = getSourceQreg(signature.inQubits.front()); + if (!signature.sourceQreg) { + op.emitError("Cannot get source qreg"); + isValid = false; + return; + } + + // input wire indices + for (Value qubit : signature.inQubits) { + const QubitIndex index = getExtractIndex(qubit); + if (!index) { + op.emitError("Cannot get index for input qubit"); + isValid = false; + return; + } + signature.inWireIndices.emplace_back(index); + } + + // input ctrl wire indices + for (Value ctrlQubit : signature.inCtrlQubits) { + const QubitIndex index = getExtractIndex(ctrlQubit); + if (!index) { + op.emitError("Cannot get index for ctrl qubit"); + isValid = false; + return; + } + signature.inCtrlWireIndices.emplace_back(index); + } + + // Output qubit indices are the same as input qubit indices + signature.outQubitIndices = signature.inWireIndices; + signature.outCtrlQubitIndices = signature.inCtrlWireIndices; + } + + operator bool() const { return isValid; } + + // Prepare the operands for calling the decomposition function + // There are two cases: + // 1. The first input is a qreg, which means the decomposition function is a qreg mode function + // 2. Otherwise, the decomposition function is a qubit mode function + // + // Type signatures: + // 1. qreg mode: + // - func(qreg, param*, inWires*, inCtrlWires*?, inCtrlValues*?) -> qreg + // 2. qubit mode: + // - func(param*, inQubits*, inCtrlQubits*?, inCtrlValues*?) -> outQubits* + llvm::SmallVector prepareCallOperands(func::FuncOp decompFunc, PatternRewriter &rewriter, + Location loc) + { + auto funcType = decompFunc.getFunctionType(); + auto funcInputs = funcType.getInputs(); + + SmallVector operands(funcInputs.size()); + + int operandIdx = 0; + if (isa(funcInputs[0])) { + Value updatedQreg = signature.sourceQreg; + for (auto [i, qubit] : llvm::enumerate(signature.inQubits)) { + const QubitIndex &index = signature.inWireIndices[i]; + updatedQreg = + rewriter.create(loc, updatedQreg.getType(), updatedQreg, + index.getValue(), index.getAttr(), qubit); + } + + operands[operandIdx++] = updatedQreg; + if (!signature.params.empty()) { + auto [startIdx, endIdx] = + findParamTypeRange(funcInputs, signature.params.size(), operandIdx); + ArrayRef paramsTypes = funcInputs.slice(startIdx, endIdx - startIdx); + auto updatedParams = generateParams(signature.params, paramsTypes, rewriter, loc); + for (Value param : updatedParams) { + operands[operandIdx++] = param; + } + } + + if (!signature.inWireIndices.empty()) { + operands[operandIdx] = fromTensorOrAsIs(signature.inWireIndices, + funcInputs[operandIdx], rewriter, loc); + operandIdx++; + } + + if (!signature.inCtrlWireIndices.empty()) { + operands[operandIdx] = fromTensorOrAsIs(signature.inCtrlWireIndices, + funcInputs[operandIdx], rewriter, loc); + operandIdx++; + } + } + else { + if (!signature.params.empty()) { + auto [startIdx, endIdx] = + findParamTypeRange(funcInputs, signature.params.size(), operandIdx); + ArrayRef paramsTypes = funcInputs.slice(startIdx, endIdx - startIdx); + auto updatedParams = generateParams(signature.params, paramsTypes, rewriter, loc); + for (Value param : updatedParams) { + operands[operandIdx++] = param; + } + } + + for (auto inQubit : signature.inQubits) { + operands[operandIdx] = + fromTensorOrAsIs(inQubit, funcInputs[operandIdx], rewriter, loc); + operandIdx++; + } + + for (auto inCtrlQubit : signature.inCtrlQubits) { + operands[operandIdx] = + fromTensorOrAsIs(inCtrlQubit, funcInputs[operandIdx], rewriter, loc); + operandIdx++; + } + } + + if (!signature.inCtrlValues.empty()) { + operands[operandIdx] = + fromTensorOrAsIs(signature.inCtrlValues, funcInputs[operandIdx], rewriter, loc); + operandIdx++; + } + + return operands; + } + + // Prepare the results for the call operation + SmallVector prepareCallResultForQreg(func::CallOp callOp, PatternRewriter &rewriter) + { + assert(callOp.getNumResults() == 1 && "only one qreg result for qreg mode is allowed"); + + auto qreg = callOp.getResult(0); + assert(isa(qreg.getType()) && "only allow to have qreg result"); + + SmallVector newResults; + rewriter.setInsertionPointAfter(callOp); + for (const QubitIndex &index : signature.outQubitIndices) { + auto extractOp = rewriter.create( + callOp.getLoc(), rewriter.getType(), qreg, index.getValue(), + index.getAttr()); + newResults.emplace_back(extractOp.getResult()); + } + for (const QubitIndex &index : signature.outCtrlQubitIndices) { + auto extractOp = rewriter.create( + callOp.getLoc(), rewriter.getType(), qreg, index.getValue(), + index.getAttr()); + newResults.emplace_back(extractOp.getResult()); + } + return newResults; + } + + private: + bool isValid = true; + + struct OpSignature { + ValueRange params; + ValueRange inQubits; + ValueRange inCtrlQubits; + ValueRange inCtrlValues; + ValueRange outQubits; + ValueRange outCtrlQubits; + + // Qreg mode specific information + Value sourceQreg = nullptr; + SmallVector inWireIndices; + SmallVector inCtrlWireIndices; + SmallVector outQubitIndices; + SmallVector outCtrlQubitIndices; + } signature; + + Value fromTensorOrAsIs(ValueRange values, Type type, PatternRewriter &rewriter, Location loc) + { + if (isa(type)) { + return rewriter.create(loc, type, values); + } + return values.front(); + } + + static size_t getElementsCount(Type type) + { + if (isa(type)) { + auto tensorType = cast(type); + return tensorType.getNumElements() > 0 ? tensorType.getNumElements() : 1; + } + return 1; + } + + // Helper function to find the range of function input types that correspond to params + static std::pair findParamTypeRange(ArrayRef funcInputs, + size_t sigParamCount, size_t startIdx = 0) + { + size_t paramTypeCount = 0; + size_t paramTypeEnd = startIdx; + + while (paramTypeCount < sigParamCount) { + assert(paramTypeEnd < funcInputs.size() && + "param type end should be less than function input size"); + paramTypeCount += getElementsCount(funcInputs[paramTypeEnd]); + paramTypeEnd++; + } + + assert(paramTypeCount == sigParamCount && + "param type count should be equal to signature param count"); + + return {startIdx, paramTypeEnd}; + } + + // generate params for calling the decomposition function based on function type requirements + SmallVector generateParams(ValueRange signatureParams, ArrayRef funcParamTypes, + PatternRewriter &rewriter, Location loc) + { + SmallVector operands; + size_t sigParamIdx = 0; + + for (Type funcParamType : funcParamTypes) { + const size_t numElements = getElementsCount(funcParamType); + + // collect numElements of signature params + SmallVector tensorElements; + for (size_t i = 0; i < numElements && sigParamIdx < signatureParams.size(); i++) { + tensorElements.push_back(signatureParams[sigParamIdx++]); + } + operands.push_back(fromTensorOrAsIs(tensorElements, funcParamType, rewriter, loc)); + } + + return operands; + } + + Value fromTensorOrAsIs(ArrayRef indices, Type type, PatternRewriter &rewriter, + Location loc) + { + SmallVector values; + for (const QubitIndex &index : indices) { + if (index.isValue()) { + values.emplace_back(index.getValue()); + } + else if (index.isAttr()) { + auto attr = index.getAttr(); + auto constantValue = rewriter.create(loc, attr.getType(), attr); + values.emplace_back(constantValue); + } + } + + if (isa(type)) { + return rewriter.create(loc, type, values); + } + + assert(values.size() == 1 && "number of values should be 1 for non-tensor type"); + return values.front(); + } + + Value getSourceQreg(Value qubit) + { + while (qubit) { + if (auto extractOp = qubit.getDefiningOp()) { + return extractOp.getQreg(); + } + + if (auto customOp = dyn_cast_or_null(qubit.getDefiningOp())) { + if (customOp.getQubitOperands().empty()) { + break; + } + qubit = customOp.getQubitOperands()[0]; + } + } + + return nullptr; + } + + QubitIndex getExtractIndex(Value qubit) + { + while (qubit) { + if (auto extractOp = qubit.getDefiningOp()) { + if (Value idx = extractOp.getIdx()) { + return QubitIndex(idx); + } + if (IntegerAttr idxAttr = extractOp.getIdxAttrAttr()) { + return QubitIndex(idxAttr); + } + } + + if (auto customOp = dyn_cast_or_null(qubit.getDefiningOp())) { + auto qubitOperands = customOp.getQubitOperands(); + auto qubitResults = customOp.getQubitResults(); + auto it = + llvm::find_if(qubitResults, [&](Value result) { return result == qubit; }); + + if (it != qubitResults.end()) { + size_t resultIndex = std::distance(qubitResults.begin(), it); + if (resultIndex < qubitOperands.size()) { + qubit = qubitOperands[resultIndex]; + continue; + } + } + } + + break; + } + + return QubitIndex(); + } +}; + +struct DecomposeLoweringRewritePattern : public OpRewritePattern { + private: + const llvm::StringMap &decompositionRegistry; + const llvm::StringSet &targetGateSet; + + public: + DecomposeLoweringRewritePattern(MLIRContext *context, + const llvm::StringMap ®istry, + const llvm::StringSet &gateSet) + : OpRewritePattern(context), decompositionRegistry(registry), targetGateSet(gateSet) + { + } + + LogicalResult matchAndRewrite(CustomOp op, PatternRewriter &rewriter) const override + { + StringRef gateName = op.getGateName(); + + // Only decompose the op if it is not in the target gate set + if (targetGateSet.contains(gateName)) { + return failure(); + } + + // Find the corresponding decomposition function for the op + auto it = decompositionRegistry.find(gateName); + if (it == decompositionRegistry.end()) { + return failure(); + } + func::FuncOp decompFunc = it->second; + + // Here is the assumption that the decomposition function must have at least one input and + // one result + assert(decompFunc.getFunctionType().getNumInputs() > 0 && + "Decomposition function must have at least one input"); + assert(decompFunc.getFunctionType().getNumResults() >= 1 && + "Decomposition function must have at least one result"); + + auto enableQreg = isa(decompFunc.getFunctionType().getInput(0)); + auto analyzer = OpSignatureAnalyzer(op, enableQreg); + assert(analyzer && "Analyzer should be valid"); + + rewriter.setInsertionPointAfter(op); + auto callOperands = analyzer.prepareCallOperands(decompFunc, rewriter, op.getLoc()); + auto callOp = + rewriter.create(op.getLoc(), decompFunc.getFunctionType().getResults(), + decompFunc.getSymName(), callOperands); + + // Replace the op with the call op and adjust the insert ops for the qreg mode + if (callOp.getNumResults() == 1 && isa(callOp.getResult(0).getType())) { + auto results = analyzer.prepareCallResultForQreg(callOp, rewriter); + rewriter.replaceOp(op, results); + } + else { + rewriter.replaceOp(op, callOp->getResults()); + } + + return success(); + } +}; + +void populateDecomposeLoweringPatterns(RewritePatternSet &patterns, + const llvm::StringMap &decompositionRegistry, + const llvm::StringSet &targetGateSet) +{ + patterns.add(patterns.getContext(), decompositionRegistry, + targetGateSet); +} + +} // namespace quantum +} // namespace catalyst diff --git a/mlir/lib/Quantum/Transforms/decompose_lowering.cpp b/mlir/lib/Quantum/Transforms/decompose_lowering.cpp index 8f0d6b638e..bbddf92023 100644 --- a/mlir/lib/Quantum/Transforms/decompose_lowering.cpp +++ b/mlir/lib/Quantum/Transforms/decompose_lowering.cpp @@ -14,15 +14,25 @@ #define DEBUG_TYPE "decompose-lowering" -#include "Catalyst/IR/CatalystDialect.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringSet.h" +#include "llvm/Support/AllocatorBase.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/UB/IR/UBOps.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" -#include "llvm/Support/Debug.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" -#include "Catalyst/IR/CatalystDialect.h" #include "Quantum/IR/QuantumOps.h" #include "Quantum/Transforms/Patterns.h" -using namespace llvm; using namespace mlir; using namespace catalyst::quantum; @@ -32,17 +42,167 @@ namespace quantum { #define GEN_PASS_DECL_DECOMPOSELOWERINGPASS #include "Quantum/Transforms/Passes.h.inc" -struct DecomposeLoweringPass : public impl::DecomposeLoweringPassBase { - using impl::DecomposeLoweringPassBase::DecomposeLoweringPassBase; +namespace DecompUtils { + +static constexpr StringRef target_gate_attr_name = "target_gate"; +static constexpr StringRef decomp_gateset_attr_name = "decomp_gateset"; + +// Check if a function is a decomposition function +// It's expected that the decomposition function would have this attribute: +// `catalyst.decomposition.target_op` And this attribute is set by the `markDecompositionAttributes` +// functionq The decomposition attribute are used to determine if a function is a decomposition +// function, and target_op is that the decomposition function want to replace +bool isDecompositionFunction(func::FuncOp func) { return func->hasAttr(target_gate_attr_name); } + +StringRef getTargetGateName(func::FuncOp func) +{ + if (auto target_op_attr = func->getAttrOfType(target_gate_attr_name)) { + return target_op_attr.getValue(); + } + return StringRef{}; +} + +} // namespace DecompUtils + +/// A module pass that work through a module, register all decomposition functions, and apply the +/// decomposition patterns +struct DecomposeLoweringPass : impl::DecomposeLoweringPassBase { + using DecomposeLoweringPassBase::DecomposeLoweringPassBase; + + void getDependentDialects(DialectRegistry ®istry) const override + { + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + } + + private: + llvm::StringMap decompositionRegistry; + llvm::StringSet targetGateSet; + + // Function to discover and register decomposition functions from a module + // It's bookkeeping the targetOp and the decomposition function that can decompose the targetOp + void discoverAndRegisterDecompositions(ModuleOp module, + llvm::StringMap &decompositionRegistry) + { + module.walk([&](func::FuncOp func) { + if (StringRef targetOp = DecompUtils::getTargetGateName(func); !targetOp.empty()) { + decompositionRegistry[targetOp] = func; + } + // No need to walk into the function body + return WalkResult::skip(); + }); + } + + // Find the target gate set from the module.It's expected that the decomposition function would + // have this attribute: `decomp_gateset` And this attribute is set by the frontend, it contains + // the target gate set that the circuit function want to finally decompose into. Since each + // module only contains one circuit function, we can just find the target gate set from the + // function with the `decomp_gateset` attribute + void findTargetGateSet(ModuleOp module, llvm::StringSet &targetGateSet) + { + module.walk([&](func::FuncOp func) { + if (auto gate_set_attr = + func->getAttrOfType(DecompUtils::decomp_gateset_attr_name)) { + for (auto gate : gate_set_attr.getValue()) { + StringRef gate_name = cast(gate).getValue(); + targetGateSet.insert(gate_name); + } + return WalkResult::interrupt(); + } + // No need to walk into the function body + return WalkResult::skip(); + }); + } + + // Remove unused decomposition functions: + // Since the decomposition functions are marked as public from the frontend, + // there is no way to remove them with any DCE pass automatically. + // So we need to manually remove them from the module + void removeDecompositionFunctions(ModuleOp module, + llvm::StringMap &decompositionRegistry) + { + llvm::DenseSet usedDecompositionFunctions; + + module.walk([&](func::CallOp callOp) { + if (auto targetFunc = module.lookupSymbol(callOp.getCallee())) { + if (DecompUtils::isDecompositionFunction(targetFunc)) { + usedDecompositionFunctions.insert(targetFunc); + } + } + }); + + // remove unused decomposition functions + module.walk([&](func::FuncOp func) { + if (DecompUtils::isDecompositionFunction(func) && + !usedDecompositionFunctions.contains(func)) { + func.erase(); + } + return WalkResult::skip(); + }); + } + + public: + void runOnOperation() final + { + ModuleOp module = cast(getOperation()); + + // Step 1: Discover and register all decomposition functions in the module + discoverAndRegisterDecompositions(module, decompositionRegistry); + if (decompositionRegistry.empty()) { + return; + } + + // Step 1.1: Find the target gate set + findTargetGateSet(module, targetGateSet); + + // Step 2: Canonicalize the module + RewritePatternSet patternsCanonicalization(&getContext()); + catalyst::quantum::CustomOp::getCanonicalizationPatterns(patternsCanonicalization, + &getContext()); + if (failed(applyPatternsGreedily(module, std::move(patternsCanonicalization)))) { + return signalPassFailure(); + } + + // Step 3: Apply the decomposition patterns + RewritePatternSet decompositionPatterns(&getContext()); + populateDecomposeLoweringPatterns(decompositionPatterns, decompositionRegistry, + targetGateSet); + if (failed(applyPatternsGreedily(module, std::move(decompositionPatterns)))) { + return signalPassFailure(); + } + + // Step 4: Inline and canonicalize/CSE the module again + PassManager pm(&getContext()); + pm.addPass(createInlinerPass()); + pm.addPass(createCanonicalizerPass()); + pm.addPass(createCSEPass()); + if (failed(pm.run(module))) { + return signalPassFailure(); + } + + // Step 5. Remove redundant decomposition functions + removeDecompositionFunctions(module, decompositionRegistry); - void runOnOperation() override { llvm::errs() << "Decompose Lowering Pass!\n"; } + // Step 6. Canonicalize the extract/insert pair + RewritePatternSet patternsInsertExtract(&getContext()); + catalyst::quantum::InsertOp::getCanonicalizationPatterns(patternsInsertExtract, + &getContext()); + catalyst::quantum::ExtractOp::getCanonicalizationPatterns(patternsInsertExtract, + &getContext()); + if (failed(applyPatternsGreedily(module, std::move(patternsInsertExtract)))) { + return signalPassFailure(); + } + } }; } // namespace quantum std::unique_ptr createDecomposeLoweringPass() { - return std::make_unique(); + return std::make_unique(); } } // namespace catalyst diff --git a/mlir/test/Quantum/CanonicalizationTest.mlir b/mlir/test/Quantum/CanonicalizationTest.mlir index 4c698ed4e7..4b9620575d 100644 --- a/mlir/test/Quantum/CanonicalizationTest.mlir +++ b/mlir/test/Quantum/CanonicalizationTest.mlir @@ -83,8 +83,7 @@ func.func @test_extract_insert_no_fold_static(%r1: !quantum.reg, %i1: i64, %i2: %q2 = quantum.extract %r2[0] : !quantum.reg -> !quantum.bit %r3 = quantum.insert %r2[%i1], %q2 : !quantum.reg, !quantum.bit - // CHECK: quantum.extract - // CHECK: quantum.insert + %q3 = quantum.extract %r3[%i1] : !quantum.reg -> !quantum.bit %r4 = quantum.insert %r3[%i2], %q3 : !quantum.reg, !quantum.bit @@ -167,14 +166,14 @@ func.func @test_interleaved_extract_insert() -> tensor<4xf64> { // CHECK: [[QBIT:%.+]] = quantum.extract [[QREG:%.+]][ // CHECK: [[QBIT_1:%.+]] = quantum.custom "Hadamard"() [[QBIT]] // CHECK: [[QREG_1:%.+]] = quantum.insert [[QREG]] - // CHECK-NOT: quantum.insert - // COM: check that insert op canonicalization correctly removes unnecessary extract/inserts + // CHECK-NOT: quantum.insert + // COM: check that insert op canonicalization correctly removes unnecessary extract/inserts // CHECK: quantum.compbasis qreg [[QREG_1]] %1 = quantum.extract %0[%c0_i64] : !quantum.reg -> !quantum.bit %out_qubits = quantum.custom "Hadamard"() %1 : !quantum.bit %2 = quantum.extract %0[%c1_i64] : !quantum.reg -> !quantum.bit - %3 = quantum.insert %0[%c0_i64], %out_qubits : !quantum.reg, !quantum.bit - %4 = quantum.insert %3[%c1_i64], %2 : !quantum.reg, !quantum.bit + %3 = quantum.insert %0[%c1_i64], %2 : !quantum.reg, !quantum.bit + %4 = quantum.insert %3[%c0_i64], %out_qubits : !quantum.reg, !quantum.bit %5 = quantum.compbasis qreg %4 : !quantum.obs %6 = quantum.probs %5 : tensor<4xf64> quantum.dealloc %4 : !quantum.reg diff --git a/mlir/test/Quantum/DecomposeLoweringTest.mlir b/mlir/test/Quantum/DecomposeLoweringTest.mlir new file mode 100644 index 0000000000..91bfbe7778 --- /dev/null +++ b/mlir/test/Quantum/DecomposeLoweringTest.mlir @@ -0,0 +1,510 @@ +// Copyright 2025 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// RUN: quantum-opt --decompose-lowering --split-input-file -verify-diagnostics %s | FileCheck %s + +module @two_hadamards { + func.func public @test_two_hadamards() -> tensor<4xf64> { + %0 = quantum.alloc( 2) : !quantum.reg + %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit + // CHECK: [[CST_PI2:%.+]] = arith.constant 1.5707963267948966 : f64 + // CHECK: [[CST_PI:%.+]] = arith.constant 3.1415926535897931 : f64 + // CHECK: [[REG:%.+]] = quantum.alloc( 2) : !quantum.reg + // CHECK: [[QUBIT:%.+]] = quantum.extract [[REG]][ 0] : !quantum.reg -> !quantum.bit + + // CHECK: [[QUBIT1:%.+]] = quantum.custom "RZ"([[CST_PI]]) [[QUBIT]] : !quantum.bit + // CHECK: [[QUBIT2:%.+]] = quantum.custom "RY"([[CST_PI2]]) [[QUBIT1]] : !quantum.bit + // CHECK-NOT: quantum.custom "Hadamard" + %out_qubits = quantum.custom "Hadamard"() %1 : !quantum.bit + + // CHECK: [[QUBIT3:%.+]] = quantum.custom "RZ"([[CST_PI]]) [[QUBIT2]] : !quantum.bit + // CHECK: [[QUBIT4:%.+]] = quantum.custom "RY"([[CST_PI2]]) [[QUBIT3]] : !quantum.bit + // CHECK-NOT: quantum.custom "Hadamard" + %out_qubits_0 = quantum.custom "Hadamard"() %out_qubits : !quantum.bit + + // CHECK: [[UPDATED_REG:%.+]] = quantum.insert [[REG]][ 0], [[QUBIT4]] : !quantum.reg, !quantum.bit + %2 = quantum.insert %0[ 0], %out_qubits_0 : !quantum.reg, !quantum.bit + %3 = quantum.compbasis qreg %2 : !quantum.obs + %4 = quantum.probs %3 : tensor<4xf64> + quantum.dealloc %2 : !quantum.reg + return %4 : tensor<4xf64> + } + + // Decomposition function should be applied and removed from the module + // CHECK-NOT: func.func private @Hadamard_to_RY_decomp + func.func private @Hadamard_to_RY_decomp(%arg0: !quantum.bit) -> !quantum.bit attributes {target_gate = "Hadamard", llvm.linkage = #llvm.linkage} { + %cst = arith.constant 3.1415926535897931 : f64 + %cst_0 = arith.constant 1.5707963267948966 : f64 + %out_qubits = quantum.custom "RZ"(%cst) %arg0 : !quantum.bit + %out_qubits_1 = quantum.custom "RY"(%cst_0) %out_qubits : !quantum.bit + return %out_qubits_1 : !quantum.bit + } +} + +// ----- + +// Test single Hadamard decomposition +module @single_hadamard { + func.func @test_single_hadamard() -> !quantum.bit { + // CHECK: [[CST_PI2:%.+]] = arith.constant 1.5707963267948966 : f64 + // CHECK: [[CST_PI:%.+]] = arith.constant 3.1415926535897931 : f64 + // CHECK: [[REG:%.+]] = quantum.alloc( 1) : !quantum.reg + // CHECK: [[QUBIT:%.+]] = quantum.extract [[REG]][ 0] : !quantum.reg -> !quantum.bit + %0 = quantum.alloc( 1) : !quantum.reg + %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit + + // CHECK: [[QUBIT1:%.+]] = quantum.custom "RZ"([[CST_PI]]) [[QUBIT]] : !quantum.bit + // CHECK: [[QUBIT2:%.+]] = quantum.custom "RY"([[CST_PI2]]) [[QUBIT1]] : !quantum.bit + // CHECK-NOT: quantum.custom "Hadamard" + %2 = quantum.custom "Hadamard"() %1 : !quantum.bit + + // CHECK: return [[QUBIT2]] + return %2 : !quantum.bit + } + + // Decomposition function should be applied and removed from the module + // CHECK-NOT: func.func private @Hadamard_to_RY_decomp + func.func private @Hadamard_to_RY_decomp(%arg0: !quantum.bit) -> !quantum.bit attributes {target_gate = "Hadamard", llvm.linkage = #llvm.linkage} { + %cst = arith.constant 3.1415926535897931 : f64 + %cst_0 = arith.constant 1.5707963267948966 : f64 + %out_qubits = quantum.custom "RZ"(%cst) %arg0 : !quantum.bit + %out_qubits_1 = quantum.custom "RY"(%cst_0) %out_qubits : !quantum.bit + return %out_qubits_1 : !quantum.bit + } +} + +// ----- +module @recursive { + func.func public @test_recursive() -> tensor<4xf64> { + %0 = quantum.alloc( 2) : !quantum.reg + %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit + // CHECK: [[CST_PI2:%.+]] = arith.constant 1.5707963267948966 : f64 + // CHECK: [[CST_PI:%.+]] = arith.constant 3.1415926535897931 : f64 + // CHECK: [[REG:%.+]] = quantum.alloc( 2) : !quantum.reg + // CHECK: [[QUBIT:%.+]] = quantum.extract [[REG]][ 0] : !quantum.reg -> !quantum.bit + + // CHECK: [[QUBIT1:%.+]] = quantum.custom "RZ"([[CST_PI]]) [[QUBIT]] : !quantum.bit + // CHECK: [[QUBIT2:%.+]] = quantum.custom "RY"([[CST_PI2]]) [[QUBIT1]] : !quantum.bit + // CHECK-NOT: quantum.custom "Hadamard" + %out_qubits = quantum.custom "Hadamard"() %1 : !quantum.bit + + // CHECK: [[QUBIT3:%.+]] = quantum.custom "RZ"([[CST_PI]]) [[QUBIT2]] : !quantum.bit + // CHECK: [[QUBIT4:%.+]] = quantum.custom "RY"([[CST_PI2]]) [[QUBIT3]] : !quantum.bit + // CHECK-NOT: quantum.custom "Hadamard" + %out_qubits_0 = quantum.custom "Hadamard"() %out_qubits : !quantum.bit + + // CHECK: [[UPDATED_REG:%.+]] = quantum.insert [[REG]][ 0], [[QUBIT4]] : !quantum.reg, !quantum.bit + %2 = quantum.insert %0[ 0], %out_qubits_0 : !quantum.reg, !quantum.bit + %3 = quantum.compbasis qreg %2 : !quantum.obs + %4 = quantum.probs %3 : tensor<4xf64> + quantum.dealloc %2 : !quantum.reg + return %4 : tensor<4xf64> + } + + // Decomposition function should be applied and removed from the module + // CHECK-NOT: func.func private @Hadamard_to_RY_decomp + func.func private @Hadamard_to_RY_decomp(%arg0: !quantum.bit) -> !quantum.bit attributes {target_gate = "Hadamard", llvm.linkage = #llvm.linkage} { + %out_qubits_0 = quantum.custom "RZRY"() %arg0 : !quantum.bit + return %out_qubits_0 : !quantum.bit + } + + // Decomposition function should be applied and removed from the module + // CHECK-NOT: func.func private @RZRY_decomp + func.func private @RZRY_decomp(%arg0: !quantum.bit) -> !quantum.bit attributes {target_gate = "RZRY", llvm.linkage = #llvm.linkage} { + %cst = arith.constant 3.1415926535897931 : f64 + %cst_0 = arith.constant 1.5707963267948966 : f64 + %out_qubits_1 = quantum.custom "RZ"(%cst) %arg0 : !quantum.bit + %out_qubits_2 = quantum.custom "RY"(%cst_0) %out_qubits_1 : !quantum.bit + return %out_qubits_2 : !quantum.bit + } +} + +// ----- +module @recursive { + func.func public @test_recursive() -> tensor<4xf64> { + %0 = quantum.alloc( 2) : !quantum.reg + %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit + // CHECK: [[CST_PI2:%.+]] = arith.constant 1.5707963267948966 : f64 + // CHECK: [[CST_PI:%.+]] = arith.constant 3.1415926535897931 : f64 + // CHECK: [[REG:%.+]] = quantum.alloc( 2) : !quantum.reg + // CHECK: [[QUBIT:%.+]] = quantum.extract [[REG]][ 0] : !quantum.reg -> !quantum.bit + + // CHECK: [[QUBIT1:%.+]] = quantum.custom "RZ"([[CST_PI]]) [[QUBIT]] : !quantum.bit + // CHECK: [[QUBIT2:%.+]] = quantum.custom "RY"([[CST_PI2]]) [[QUBIT1]] : !quantum.bit + // CHECK-NOT: quantum.custom "Hadamard" + %out_qubits = quantum.custom "Hadamard"() %1 : !quantum.bit + + // CHECK: [[QUBIT3:%.+]] = quantum.custom "RZ"([[CST_PI]]) [[QUBIT2]] : !quantum.bit + // CHECK: [[QUBIT4:%.+]] = quantum.custom "RY"([[CST_PI2]]) [[QUBIT3]] : !quantum.bit + // CHECK-NOT: quantum.custom "Hadamard" + %out_qubits_0 = quantum.custom "Hadamard"() %out_qubits : !quantum.bit + + // CHECK: [[UPDATED_REG:%.+]] = quantum.insert [[REG]][ 0], [[QUBIT4]] : !quantum.reg, !quantum.bit + %2 = quantum.insert %0[ 0], %out_qubits_0 : !quantum.reg, !quantum.bit + %3 = quantum.compbasis qreg %2 : !quantum.obs + %4 = quantum.probs %3 : tensor<4xf64> + quantum.dealloc %2 : !quantum.reg + return %4 : tensor<4xf64> + } + + // Decomposition function should be applied and removed from the module + // CHECK-NOT: func.func private @Hadamard_to_RY_decomp + func.func private @Hadamard_to_RY_decomp(%arg0: !quantum.bit) -> !quantum.bit attributes {target_gate = "Hadamard", llvm.linkage = #llvm.linkage} { + %out_qubits_0 = quantum.custom "RZRY"() %arg0 : !quantum.bit + return %out_qubits_0 : !quantum.bit + } + + // Decomposition function should be applied and removed from the module + // CHECK-NOT: func.func private @RZRY_decomp + func.func private @RZRY_decomp(%arg0: !quantum.bit) -> !quantum.bit attributes {target_gate = "RZRY", llvm.linkage = #llvm.linkage} { + %cst = arith.constant 3.1415926535897931 : f64 + %cst_0 = arith.constant 1.5707963267948966 : f64 + %out_qubits_1 = quantum.custom "RZ"(%cst) %arg0 : !quantum.bit + %out_qubits_2 = quantum.custom "RY"(%cst_0) %out_qubits_1 : !quantum.bit + return %out_qubits_2 : !quantum.bit + } +} + +// ----- + +// Test parametric gates and wires +module @param_rxry { + func.func public @test_param_rxry(%arg0: tensor, %arg1: tensor) -> tensor<2xf64> { + %c0_i64 = arith.constant 0 : i64 + + // CHECK: [[REG:%.+]] = quantum.alloc( 1) : !quantum.reg + %0 = quantum.alloc( 1) : !quantum.reg + + // CHECK: [[WIRE:%.+]] = tensor.extract %arg1[] : tensor + %extracted = tensor.extract %arg1[] : tensor + + // CHECK: [[QUBIT:%.+]] = quantum.extract [[REG]][[[WIRE]]] : !quantum.reg -> !quantum.bit + %1 = quantum.extract %0[%extracted] : !quantum.reg -> !quantum.bit + + // CHECK: [[PARAM:%.+]] = tensor.extract %arg0[] : tensor + %param_0 = tensor.extract %arg0[] : tensor + + // CHECK: [[QUBIT1:%.+]] = quantum.custom "RX"([[PARAM]]) [[QUBIT]] : !quantum.bit + // CHECK: [[QUBIT2:%.+]] = quantum.custom "RY"([[PARAM]]) [[QUBIT1]] : !quantum.bit + // CHECK-NOT: quantum.custom "ParametrizedRXRY" + %out_qubits = quantum.custom "ParametrizedRXRY"(%param_0) %1 : !quantum.bit + + // CHECK: [[UPDATED_REG:%.+]] = quantum.insert [[REG]][ 0], [[QUBIT2]] : !quantum.reg, !quantum.bit + %2 = quantum.insert %0[ 0], %out_qubits : !quantum.reg, !quantum.bit + %3 = quantum.compbasis qreg %2 : !quantum.obs + %4 = quantum.probs %3 : tensor<2xf64> + quantum.dealloc %2 : !quantum.reg + return %4 : tensor<2xf64> + } + + // Decomposition function expects tensor while operation provides f64 + // CHECK-NOT: func.func private @ParametrizedRX_decomp + func.func private @ParametrizedRXRY_decomp(%arg0: tensor, %arg1: !quantum.bit) -> !quantum.bit + attributes {target_gate = "ParametrizedRXRY", llvm.linkage = #llvm.linkage} { + %extracted = tensor.extract %arg0[] : tensor + %out_qubits = quantum.custom "RX"(%extracted) %arg1 : !quantum.bit + %extracted_0 = tensor.extract %arg0[] : tensor + %out_qubits_1 = quantum.custom "RY"(%extracted_0) %out_qubits : !quantum.bit + return %out_qubits_1 : !quantum.bit + } +} +// ----- + +// Test parametric gates and wires +module @param_rxry_2 { + func.func public @test_param_rxry_2(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor<2xf64> { + %c0_i64 = arith.constant 0 : i64 + + // CHECK: [[REG:%.+]] = quantum.alloc( 1) : !quantum.reg + %0 = quantum.alloc( 1) : !quantum.reg + + // CHECK: [[WIRE:%.+]] = tensor.extract %arg2[] : tensor + %extracted = tensor.extract %arg2[] : tensor + + // CHECK: [[QUBIT:%.+]] = quantum.extract [[REG]][[[WIRE]]] : !quantum.reg -> !quantum.bit + %1 = quantum.extract %0[%extracted] : !quantum.reg -> !quantum.bit + + // CHECK: [[PARAM_0:%.+]] = tensor.extract %arg0[] : tensor + %param_0 = tensor.extract %arg0[] : tensor + + // CHECK: [[PARAM_1:%.+]] = tensor.extract %arg1[] : tensor + %param_1 = tensor.extract %arg1[] : tensor + + // CHECK: [[QUBIT1:%.+]] = quantum.custom "RX"([[PARAM_0]]) [[QUBIT]] : !quantum.bit + // CHECK: [[QUBIT2:%.+]] = quantum.custom "RY"([[PARAM_1]]) [[QUBIT1]] : !quantum.bit + // CHECK-NOT: quantum.custom "ParametrizedRXRY" + %out_qubits = quantum.custom "ParametrizedRXRY"(%param_0, %param_1) %1 : !quantum.bit + + // CHECK: [[UPDATED_REG:%.+]] = quantum.insert [[REG]][ 0], [[QUBIT2]] : !quantum.reg, !quantum.bit + %2 = quantum.insert %0[ 0], %out_qubits : !quantum.reg, !quantum.bit + %3 = quantum.compbasis qreg %2 : !quantum.obs + %4 = quantum.probs %3 : tensor<2xf64> + quantum.dealloc %2 : !quantum.reg + return %4 : tensor<2xf64> + } + + // Decomposition function expects tensor while operation provides f64 + // CHECK-NOT: func.func private @ParametrizedRX_decomp + func.func private @ParametrizedRXRY_decomp(%arg0: tensor, %arg1: tensor, %arg2: !quantum.bit) -> !quantum.bit + attributes {target_gate = "ParametrizedRXRY", llvm.linkage = #llvm.linkage} { + %extracted_param_0 = tensor.extract %arg0[] : tensor + %out_qubits = quantum.custom "RX"(%extracted_param_0) %arg2 : !quantum.bit + %extracted_param_1 = tensor.extract %arg1[] : tensor + %out_qubits_1 = quantum.custom "RY"(%extracted_param_1) %out_qubits : !quantum.bit + return %out_qubits_1 : !quantum.bit + } +} +// ----- + +// Test recursive and qreg-based gate decomposition +module @qreg_base_circuit { + func.func public @test_qreg_base_circuit() -> tensor<2xf64> { + // CHECK: [[CST:%.+]] = arith.constant 1.000000e+00 : f64 + %cst = arith.constant 1.000000e+00 : f64 + + // CHECK: [[CST_0:%.+]] = stablehlo.constant dense<0.000000e+00> : tensor + // CHECK: [[CST_1:%.+]] = arith.constant dense<0> : tensor<1xi64> + // CHECK: [[CST_2:%.+]] = arith.constant dense<1.000000e+00> : tensor + // CHECK: [[REG:%.+]] = quantum.alloc( 1) : !quantum.reg + %0 = quantum.alloc( 1) : !quantum.reg + + // CHECK: [[EXTRACT_QUBIT:%.+]] = quantum.extract [[REG]][ 0] : !quantum.reg -> !quantum.bit + // CHECK: [[MRES:%.+]], [[OUT_QUBIT:%.+]] = quantum.measure [[EXTRACT_QUBIT]] : i1, !quantum.bit + // CHECK: [[REG1:%.+]] = quantum.insert [[REG]][ 0], [[OUT_QUBIT]] : !quantum.reg, !quantum.bit + // CHECK: [[COMPARE:%.+]] = stablehlo.compare NE, [[CST_2]], [[CST_0]], FLOAT : (tensor, tensor) -> tensor + // CHECK: [[EXTRACTED:%.+]] = tensor.extract [[COMPARE]][] : tensor + // CHECK: [[CONDITIONAL:%.+]] = scf.if [[EXTRACTED]] -> (!quantum.reg) { + // CHECK: [[SLICE1:%.+]] = stablehlo.slice [[CST_1]] [0:1] : (tensor<1xi64>) -> tensor<1xi64> + // CHECK: [[RESHAPE1:%.+]] = stablehlo.reshape [[SLICE1]] : (tensor<1xi64>) -> tensor + // CHECK: [[EXTRACTED_3:%.+]] = tensor.extract [[RESHAPE1]][] : tensor + // CHECK: [[FROM_ELEMENTS:%.+]] = tensor.from_elements [[EXTRACTED_3]] : tensor<1xi64> + // CHECK: [[SLICE2:%.+]] = stablehlo.slice [[FROM_ELEMENTS]] [0:1] : (tensor<1xi64>) -> tensor<1xi64> + // CHECK: [[RESHAPE2:%.+]] = stablehlo.reshape [[SLICE2]] : (tensor<1xi64>) -> tensor + // CHECK: [[EXTRACTED_4:%.+]] = tensor.extract [[RESHAPE2]][] : tensor + // CHECK: [[EXTRACT1:%.+]] = quantum.extract [[REG1]][[[EXTRACTED_4]]] : !quantum.reg -> !quantum.bit + // CHECK: [[RZ1:%.+]] = quantum.custom "RZ"([[CST]]) [[EXTRACT1]] : !quantum.bit + // CHECK: [[INSERT1:%.+]] = quantum.insert [[REG1]][[[EXTRACTED_4]]], [[RZ1]] : !quantum.reg, !quantum.bit + // CHECK: [[EXTRACT2:%.+]] = quantum.extract [[INSERT1]][[[EXTRACTED_3]]] : !quantum.reg -> !quantum.bit + // CHECK: [[INSERT2:%.+]] = quantum.insert [[REG1]][[[EXTRACTED_3]]], [[EXTRACT2]] : !quantum.reg, !quantum.bit + // CHECK: [[EXTRACT3:%.+]] = quantum.extract [[INSERT2]][[[EXTRACTED_4]]] : !quantum.reg -> !quantum.bit + // CHECK: [[RZ2:%.+]] = quantum.custom "RZ"([[CST]]) [[EXTRACT3]] : !quantum.bit + // CHECK: [[INSERT3:%.+]] = quantum.insert [[INSERT2]][[[EXTRACTED_4]]], [[RZ2]] : !quantum.reg, !quantum.bit + // CHECK: [[EXTRACT4:%.+]] = quantum.extract [[INSERT3]][[[EXTRACTED_3]]] : !quantum.reg -> !quantum.bit + // CHECK: [[INSERT4:%.+]] = quantum.insert [[INSERT2]][[[EXTRACTED_3]]], [[EXTRACT4]] : !quantum.reg, !quantum.bit + // CHECK: scf.yield [[INSERT4]] : !quantum.reg + // CHECK: } else { + // CHECK: scf.yield [[REG1]] : !quantum.reg + // CHECK: } + // CHECK-NOT: quantum.custom "Test" + %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit + %out_qubits = quantum.custom "Test"(%cst) %1 : !quantum.bit + %2 = quantum.insert %0[ 0], %out_qubits : !quantum.reg, !quantum.bit + %3 = quantum.compbasis qreg %2 : !quantum.obs + %4 = quantum.probs %3 : tensor<2xf64> + + quantum.dealloc %2 : !quantum.reg + quantum.device_release + return %4 : tensor<2xf64> + } + + // Decomposition function should be applied and removed from the module + // CHECK-NOT: func.func private @Test_rule_1 + func.func private @Test_rule_1(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg + attributes {target_gate = "Test", llvm.linkage = #llvm.linkage} { + %cst = stablehlo.constant dense<0.000000e+00> : tensor + %10 = quantum.extract %arg0[ 0] : !quantum.reg -> !quantum.bit + %mres, %out_qubit = quantum.measure %10 : i1, !quantum.bit + %11 = quantum.insert %arg0[ 0], %out_qubit : !quantum.reg, !quantum.bit + %0 = stablehlo.compare NE, %arg1, %cst, FLOAT : (tensor, tensor) -> tensor + %extracted = tensor.extract %0[] : tensor + %1 = scf.if %extracted -> (!quantum.reg) { + %2 = stablehlo.slice %arg2 [0:1] : (tensor<1xi64>) -> tensor<1xi64> + %3 = stablehlo.reshape %2 : (tensor<1xi64>) -> tensor + %extracted_0 = tensor.extract %3[] : tensor + %4 = quantum.extract %11[%extracted_0] : !quantum.reg -> !quantum.bit + %extracted_1 = tensor.extract %arg1[] : tensor + %out_qubits = quantum.custom "RzDecomp"(%extracted_1) %4 : !quantum.bit + %5 = stablehlo.slice %arg2 [0:1] : (tensor<1xi64>) -> tensor<1xi64> + %6 = stablehlo.reshape %5 : (tensor<1xi64>) -> tensor + %extracted_2 = tensor.extract %3[] : tensor + %7 = quantum.insert %11[%extracted_2], %out_qubits : !quantum.reg, !quantum.bit + %extracted_3 = tensor.extract %6[] : tensor + %8 = quantum.extract %7[%extracted_3] : !quantum.reg -> !quantum.bit + %extracted_4 = tensor.extract %arg1[] : tensor + %out_qubits_5 = quantum.custom "RzDecomp"(%extracted_4) %8 : !quantum.bit + %extracted_6 = tensor.extract %6[] : tensor + %9 = quantum.insert %7[%extracted_6], %out_qubits_5 : !quantum.reg, !quantum.bit + scf.yield %9 : !quantum.reg + } else { + scf.yield %11 : !quantum.reg + } + return %1 : !quantum.reg + } + + // Decomposition function should be applied and removed from the module + // CHECK-NOT: func.func private @RzDecomp_rule_1 + func.func private @RzDecomp_rule_1(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg + attributes {target_gate = "RzDecomp", llvm.linkage = #llvm.linkage} { + %0 = stablehlo.slice %arg2 [0:1] : (tensor<1xi64>) -> tensor<1xi64> + %1 = stablehlo.reshape %0 : (tensor<1xi64>) -> tensor + %extracted = tensor.extract %1[] : tensor + %2 = quantum.extract %arg0[%extracted] : !quantum.reg -> !quantum.bit + %extracted_0 = tensor.extract %arg1[] : tensor + %out_qubits = quantum.custom "RZ"(%extracted_0) %2 : !quantum.bit + %extracted_1 = tensor.extract %1[] : tensor + %3 = quantum.insert %arg0[%extracted_1], %out_qubits : !quantum.reg, !quantum.bit + return %3 : !quantum.reg + } +} + +// ----- + +module @multi_wire_cnot_decomposition { + func.func public @test_cnot_decomposition() -> tensor<4xf64> { + %0 = quantum.alloc( 2) : !quantum.reg + %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit + %2 = quantum.extract %0[ 1] : !quantum.reg -> !quantum.bit + + // CHECK: [[CST_PI:%.+]] = arith.constant 3.1415926535897931 : f64 + // CHECK: [[CST_PI2:%.+]] = arith.constant 1.5707963267948966 : f64 + // CHECK: [[WIRE_TENSOR:%.+]] = arith.constant dense<[0, 1]> : tensor<2xi64> + // CHECK: [[REG:%.+]] = quantum.alloc( 2) : !quantum.reg + // CHECK: [[SLICE1:%.+]] = stablehlo.slice [[WIRE_TENSOR]] [0:1] : (tensor<2xi64>) -> tensor<1xi64> + // CHECK: [[RESHAPE1:%.+]] = stablehlo.reshape [[SLICE1]] : (tensor<1xi64>) -> tensor + // CHECK: [[SLICE2:%.+]] = stablehlo.slice [[WIRE_TENSOR]] [1:2] : (tensor<2xi64>) -> tensor<1xi64> + // CHECK: [[RESHAPE2:%.+]] = stablehlo.reshape [[SLICE2]] : (tensor<1xi64>) -> tensor + // CHECK: [[EXTRACTED:%.+]] = tensor.extract [[RESHAPE2]][] : tensor + // CHECK: [[QUBIT1:%.+]] = quantum.extract [[REG]][[[EXTRACTED]]] : !quantum.reg -> !quantum.bit + // CHECK: [[RZ1:%.+]] = quantum.custom "RZ"([[CST_PI]]) [[QUBIT1]] : !quantum.bit + // CHECK: [[RY1:%.+]] = quantum.custom "RY"([[CST_PI2]]) [[RZ1]] : !quantum.bit + // CHECK: [[INSERT1:%.+]] = quantum.insert [[REG]][[[EXTRACTED]]], [[RY1]] : !quantum.reg, !quantum.bit + // CHECK: [[EXTRACTED2:%.+]] = tensor.extract [[RESHAPE1]][] : tensor + // CHECK: [[QUBIT0:%.+]] = quantum.extract [[INSERT1]][[[EXTRACTED2]]] : !quantum.reg -> !quantum.bit + // CHECK: [[QUBIT1_2:%.+]] = quantum.extract [[INSERT1]][[[EXTRACTED]]] : !quantum.reg -> !quantum.bit + // CHECK: [[CZ_RESULT:%.+]]:2 = quantum.custom "CZ"() [[QUBIT0]], [[QUBIT1_2]] : !quantum.bit, !quantum.bit + // CHECK: [[INSERT2:%.+]] = quantum.insert [[INSERT1]][[[EXTRACTED2]]], [[CZ_RESULT]]#0 : !quantum.reg, !quantum.bit + // CHECK: [[RZ2:%.+]] = quantum.custom "RZ"([[CST_PI]]) [[CZ_RESULT]]#1 : !quantum.bit + // CHECK: [[RY2:%.+]] = quantum.custom "RY"([[CST_PI2]]) [[RZ2]] : !quantum.bit + // CHECK: [[INSERT3:%.+]] = quantum.insert [[INSERT2]][[[EXTRACTED]]], [[RY2]] : !quantum.reg, !quantum.bit + // CHECK: [[FINAL_QUBIT0:%.+]] = quantum.extract [[INSERT3]][ 0] : !quantum.reg -> !quantum.bit + // CHECK: [[FINAL_QUBIT1:%.+]] = quantum.extract [[INSERT3]][ 1] : !quantum.reg -> !quantum.bit + // CHECK-NOT: quantum.custom "CNOT" + %3, %4 = quantum.custom "CNOT"() %1, %2 : !quantum.bit, !quantum.bit + + // CHECK: [[FINAL_INSERT1:%.+]] = quantum.insert [[REG]][ 0], [[FINAL_QUBIT0]] : !quantum.reg, !quantum.bit + // CHECK: [[FINAL_INSERT2:%.+]] = quantum.insert [[FINAL_INSERT1]][ 1], [[FINAL_QUBIT1]] : !quantum.reg, !quantum.bit + %5 = quantum.insert %0[ 0], %3 : !quantum.reg, !quantum.bit + %6 = quantum.insert %5[ 1], %4 : !quantum.reg, !quantum.bit + %7 = quantum.compbasis qreg %6 : !quantum.obs + %8 = quantum.probs %7 : tensor<4xf64> + quantum.dealloc %6 : !quantum.reg + return %8 : tensor<4xf64> + } + + // Decomposition function should be applied and removed from the module + // CHECK-NOT: func.func private @CNOT_rule_cz_rz_ry + func.func private @CNOT_rule_cz_rz_ry(%arg0: !quantum.reg, %arg1: tensor<2xi64>) -> !quantum.reg attributes {target_gate = "CNOT", llvm.linkage = #llvm.linkage} { + // CNOT decomposition: CNOT = (I ⊗ H) * CZ * (I ⊗ H) + %cst = arith.constant 1.5707963267948966 : f64 + %cst_0 = arith.constant 3.1415926535897931 : f64 + + // Extract wire indices from tensor + %0 = stablehlo.slice %arg1 [0:1] : (tensor<2xi64>) -> tensor<1xi64> + %1 = stablehlo.reshape %0 : (tensor<1xi64>) -> tensor + %2 = stablehlo.slice %arg1 [1:2] : (tensor<2xi64>) -> tensor<1xi64> + %3 = stablehlo.reshape %2 : (tensor<1xi64>) -> tensor + + // Step 1: Apply H to target qubit (H = RZ(π) * RY(π/2)) + %extracted = tensor.extract %3[] : tensor + %4 = quantum.extract %arg0[%extracted] : !quantum.reg -> !quantum.bit + %out_qubits = quantum.custom "RZ"(%cst_0) %4 : !quantum.bit + %out_qubits_1 = quantum.custom "RY"(%cst) %out_qubits : !quantum.bit + %extracted_2 = tensor.extract %3[] : tensor + %5 = quantum.insert %arg0[%extracted_2], %out_qubits_1 : !quantum.reg, !quantum.bit + + // Step 2: Apply CZ gate + %extracted_3 = tensor.extract %1[] : tensor + %6 = quantum.extract %5[%extracted_3] : !quantum.reg -> !quantum.bit + %extracted_4 = tensor.extract %3[] : tensor + %7 = quantum.extract %5[%extracted_4] : !quantum.reg -> !quantum.bit + %out_qubits_5:2 = quantum.custom "CZ"() %6, %7 : !quantum.bit, !quantum.bit + %extracted_6 = tensor.extract %1[] : tensor + %8 = quantum.insert %5[%extracted_6], %out_qubits_5#0 : !quantum.reg, !quantum.bit + %extracted_7 = tensor.extract %3[] : tensor + %9 = quantum.insert %8[%extracted_7], %out_qubits_5#1 : !quantum.reg, !quantum.bit + + // Step 3: Apply H to target qubit again + %extracted_8 = tensor.extract %3[] : tensor + %10 = quantum.extract %9[%extracted_8] : !quantum.reg -> !quantum.bit + %out_qubits_9 = quantum.custom "RZ"(%cst_0) %10 : !quantum.bit + %out_qubits_10 = quantum.custom "RY"(%cst) %out_qubits_9 : !quantum.bit + %extracted_11 = tensor.extract %3[] : tensor + %11 = quantum.insert %9[%extracted_11], %out_qubits_10 : !quantum.reg, !quantum.bit + + return %11 : !quantum.reg + } +} + +// ----- + +module @cnot_alternative_decomposition { + func.func public @test_cnot_alternative_decomposition() -> tensor<4xf64> { + %0 = quantum.alloc( 2) : !quantum.reg + %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit + %2 = quantum.extract %0[ 1] : !quantum.reg -> !quantum.bit + + // CHECK: [[CST_PI:%.+]] = arith.constant 3.1415926535897931 : f64 + // CHECK: [[CST_PI2:%.+]] = arith.constant 1.5707963267948966 : f64 + // CHECK: [[REG:%.+]] = quantum.alloc( 2) : !quantum.reg + // CHECK: [[QUBIT0:%.+]] = quantum.extract [[REG]][ 0] : !quantum.reg -> !quantum.bit + // CHECK: [[QUBIT1:%.+]] = quantum.extract [[REG]][ 1] : !quantum.reg -> !quantum.bit + // CHECK: [[RZ1:%.+]] = quantum.custom "RZ"([[CST_PI]]) [[QUBIT1]] : !quantum.bit + // CHECK: [[RY1:%.+]] = quantum.custom "RY"([[CST_PI2]]) [[RZ1]] : !quantum.bit + // CHECK: [[CZ_RESULT:%.+]]:2 = quantum.custom "CZ"() [[QUBIT0]], [[RY1]] : !quantum.bit, !quantum.bit + // CHECK: [[RZ2:%.+]] = quantum.custom "RZ"([[CST_PI]]) [[CZ_RESULT]]#1 : !quantum.bit + // CHECK: [[RY2:%.+]] = quantum.custom "RY"([[CST_PI2]]) [[RZ2]] : !quantum.bit + // CHECK-NOT: quantum.custom "CNOT" + %3, %4 = quantum.custom "CNOT"() %1, %2 : !quantum.bit, !quantum.bit + + // CHECK: [[FINAL_INSERT1:%.+]] = quantum.insert [[REG]][ 0], [[CZ_RESULT]]#0 : !quantum.reg, !quantum.bit + // CHECK: [[FINAL_INSERT2:%.+]] = quantum.insert [[FINAL_INSERT1]][ 1], [[RY2]] : !quantum.reg, !quantum.bit + %5 = quantum.insert %0[ 0], %3 : !quantum.reg, !quantum.bit + %6 = quantum.insert %5[ 1], %4 : !quantum.reg, !quantum.bit + %7 = quantum.compbasis qreg %6 : !quantum.obs + %8 = quantum.probs %7 : tensor<4xf64> + quantum.dealloc %6 : !quantum.reg + return %8 : tensor<4xf64> + } + + // Decomposition function should be applied and removed from the module + // CHECK-NOT: func.func private @CNOT_rule_h_cnot_h + func.func private @CNOT_rule_h_cnot_h(%arg0: !quantum.bit, %arg1: !quantum.bit) -> (!quantum.bit, !quantum.bit) attributes {target_gate = "CNOT", llvm.linkage = #llvm.linkage} { + // CNOT decomposition: CNOT = (I ⊗ H) * CZ * (I ⊗ H) + %cst = arith.constant 1.5707963267948966 : f64 + %cst_0 = arith.constant 3.1415926535897931 : f64 + + // Step 1: Apply H to target qubit (H = RZ(π) * RY(π/2)) + %out_qubits = quantum.custom "RZ"(%cst_0) %arg1 : !quantum.bit + %out_qubits_1 = quantum.custom "RY"(%cst) %out_qubits : !quantum.bit + + // Step 2: Apply CZ gate + %out_qubits_2:2 = quantum.custom "CZ"() %arg0, %out_qubits_1 : !quantum.bit, !quantum.bit + + // Step 3: Apply H to target qubit again + %out_qubits_3 = quantum.custom "RZ"(%cst_0) %out_qubits_2#1 : !quantum.bit + %out_qubits_4 = quantum.custom "RY"(%cst) %out_qubits_3 : !quantum.bit + + return %out_qubits_2#0, %out_qubits_4 : !quantum.bit, !quantum.bit + } +} From 2651f9e001fdb67815d5624cc311dd6f60cb8866 Mon Sep 17 00:00:00 2001 From: Ali Asadi <10773383+maliasadi@users.noreply.github.com> Date: Tue, 23 Sep 2025 15:12:53 -0400 Subject: [PATCH 26/36] Apply code review suggestions --- frontend/catalyst/from_plxpr/from_plxpr.py | 140 +++++++++--------- frontend/catalyst/passes/builtin_passes.py | 3 +- frontend/test/lit/test_decomposition.py | 95 ++++++------ .../test/pytest/from_plxpr/test_from_plxpr.py | 31 ++++ 4 files changed, 149 insertions(+), 120 deletions(-) diff --git a/frontend/catalyst/from_plxpr/from_plxpr.py b/frontend/catalyst/from_plxpr/from_plxpr.py index cf5b5e5594..41c5895281 100644 --- a/frontend/catalyst/from_plxpr/from_plxpr.py +++ b/frontend/catalyst/from_plxpr/from_plxpr.py @@ -208,7 +208,7 @@ def handle_qnode( closed_jaxpr = ( ClosedJaxpr(qfunc_jaxpr, consts) if not self.requires_decompose_lowering - else apply_compiler_decompose_to_plxpr( + else _apply_compiler_decompose_to_plxpr( inner_jaxpr=qfunc_jaxpr, consts=consts, ncargs=non_const_args, @@ -217,7 +217,7 @@ def handle_qnode( ) if self.requires_decompose_lowering: - closed_jaxpr = collect_and_compile_graph_solutions( + closed_jaxpr = _collect_and_compile_graph_solutions( inner_jaxpr=closed_jaxpr.jaxpr, consts=closed_jaxpr.consts, tkwargs=self.decompose_tkwargs, @@ -273,74 +273,6 @@ def calling_convention(*args): } -def apply_compiler_decompose_to_plxpr(inner_jaxpr, consts, tgateset, ncargs): - """Apply the compiler-specific decomposition for a given JAXPR. - - Args: - inner_jaxpr (Jaxpr): The input JAXPR to be decomposed. - consts (list): The constants used in the JAXPR. - tgateset (list): A list of target gateset for decomposition. - ncargs (list): Non-constant arguments for the JAXPR. - qargs (list): All arguments including constants and non-constants. - - Returns: - ClosedJaxpr: The decomposed JAXPR. - """ - - # Disable the graph decomposition optimization - - # Why? Because for the compiler-specific decomposition we want to - # only decompose higher-level gates and templates that only have - # a single decomposition, and not do any further optimization - # based on the graph solution. - # Besides, the graph-based decomposition is not supported - # yet in from_plxpr for most gates and templates. - - # TODO: Enable the graph-based decomposition - qml.decomposition.disable_graph() - - # First perform the pre-mlir decomposition to simplify the jaxpr - # by decomposing high-level gates and templates - gate_set = COMPILER_OPERATIONS + tgateset - - final_jaxpr = qml.transforms.decompose.plxpr_transform( - inner_jaxpr, consts, (), {"gate_set": gate_set}, *ncargs - ) - - qml.decomposition.enable_graph() - - return final_jaxpr - - -def collect_and_compile_graph_solutions(inner_jaxpr, consts, tkwargs, ncargs): - """Collect and compile graph solutions for a given JAXPR. - - This function uses the DecompRuleInterpreter to evaluate - the input JAXPR and obtain a new JAXPR that incorporates - the graph-based decomposition solutions. - - This function doesn't modify the underlying quantum function - but rather constructs a new JAXPR with decomposition rules. - - Args: - inner_jaxpr (Jaxpr): The input JAXPR to be decomposed. - consts (list): The constants used in the JAXPR. - tkwargs (list): The keyword arguments of the decompose transform. - ncargs (list): Non-constant arguments for the JAXPR. - - Returns: - ClosedJaxpr: The decomposed JAXPR. - """ - gds_interpreter = DecompRuleInterpreter(**tkwargs) - - def gds_wrapper(*args): - return gds_interpreter.eval(inner_jaxpr, consts, *args) - - final_jaxpr = jax.make_jaxpr(gds_wrapper)(*ncargs) - - return final_jaxpr - - # pylint: disable-next=redefined-outer-name def register_transform(pl_transform, pass_name, decomposition): """Register pennylane transforms and their conversion to Catalyst transforms""" @@ -911,6 +843,74 @@ def trace_from_pennylane( return jaxpr, out_type, out_treedef, sig +def _apply_compiler_decompose_to_plxpr(inner_jaxpr, consts, tgateset, ncargs): + """Apply the compiler-specific decomposition for a given JAXPR. + + Args: + inner_jaxpr (Jaxpr): The input JAXPR to be decomposed. + consts (list): The constants used in the JAXPR. + tgateset (list): A list of target gateset for decomposition. + ncargs (list): Non-constant arguments for the JAXPR. + qargs (list): All arguments including constants and non-constants. + + Returns: + ClosedJaxpr: The decomposed JAXPR. + """ + + # Disable the graph decomposition optimization + + # Why? Because for the compiler-specific decomposition we want to + # only decompose higher-level gates and templates that only have + # a single decomposition, and not do any further optimization + # based on the graph solution. + # Besides, the graph-based decomposition is not supported + # yet in from_plxpr for most gates and templates. + + # TODO: Enable the graph-based decomposition + qml.decomposition.disable_graph() + + # First perform the pre-mlir decomposition to simplify the jaxpr + # by decomposing high-level gates and templates + gate_set = set(COMPILER_OPERATIONS + tgateset) + + final_jaxpr = qml.transforms.decompose.plxpr_transform( + inner_jaxpr, consts, (), {"gate_set": gate_set}, *ncargs + ) + + qml.decomposition.enable_graph() + + return final_jaxpr + + +def _collect_and_compile_graph_solutions(inner_jaxpr, consts, tkwargs, ncargs): + """Collect and compile graph solutions for a given JAXPR. + + This function uses the DecompRuleInterpreter to evaluate + the input JAXPR and obtain a new JAXPR that incorporates + the graph-based decomposition solutions. + + This function doesn't modify the underlying quantum function + but rather constructs a new JAXPR with decomposition rules. + + Args: + inner_jaxpr (Jaxpr): The input JAXPR to be decomposed. + consts (list): The constants used in the JAXPR. + tkwargs (list): The keyword arguments of the decompose transform. + ncargs (list): Non-constant arguments for the JAXPR. + + Returns: + ClosedJaxpr: The decomposed JAXPR. + """ + gds_interpreter = DecompRuleInterpreter(**tkwargs) + + def gds_wrapper(*args): + return gds_interpreter.eval(inner_jaxpr, consts, *args) + + final_jaxpr = jax.make_jaxpr(gds_wrapper)(*ncargs) + + return final_jaxpr + + def _get_operator_name(op): """Get the name of a pennylane operator, handling wrapped operators. diff --git a/frontend/catalyst/passes/builtin_passes.py b/frontend/catalyst/passes/builtin_passes.py index e39cdb8cf8..1c1b924d5c 100644 --- a/frontend/catalyst/passes/builtin_passes.py +++ b/frontend/catalyst/passes/builtin_passes.py @@ -394,7 +394,6 @@ def circuit(x: float): return PassPipelineWrapper(qnode, "merge-rotations") -# pragma: no cover def decompose_lowering(qnode): """ Specify that the ``-decompose-lowering`` MLIR compiler pass @@ -411,7 +410,7 @@ def decompose_lowering(qnode): // TODO: add example here """ - return PassPipelineWrapper(qnode, "decompose-lowering") + return PassPipelineWrapper(qnode, "decompose-lowering") # pragma: no cover def ions_decomposition(qnode): # pragma: nocover diff --git a/frontend/test/lit/test_decomposition.py b/frontend/test/lit/test_decomposition.py index 53042a6a71..c965ac50b1 100644 --- a/frontend/test/lit/test_decomposition.py +++ b/frontend/test/lit/test_decomposition.py @@ -880,8 +880,8 @@ def test_qft_decomposition(): ) @qml.qnode(qml.device("lightning.qubit", wires=4)) # CHECK: %0 = transform.apply_registered_pass "decompose-lowering" - # CHECK: func.func public @circuit_16(%arg0: tensor<3xf64>) -> tensor attributes {decompose_gatesets - def circuit_16(): + # CHECK: func.func public @circuit_18(%arg0: tensor<3xf64>) -> tensor attributes {decompose_gatesets + def circuit_18(): # %6 = scf.for %arg1 = %c0 to %c4 step %c1 iter_args(%arg2 = %0) -> (!quantum.reg) { # %23 = scf.for %arg3 = %c0 to %22 step %c1 iter_args(%arg4 = %21) -> (!quantum.reg) { # %7 = scf.for %arg1 = %c0 to %c2 step %c1 iter_args(%arg2 = %6) -> (!quantum.reg) { @@ -893,7 +893,7 @@ def circuit_16(): # CHECK-DAG: func.func public @_rot_to_rz_ry_rz(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "Rot"} # CHECK-DAG: func.func public @_swap_to_cnot(%arg0: !quantum.reg, %arg1: tensor<2xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 2 : i64, target_gate = "SWAP"} # CHECK-DAG: func.func public @_hadamard_to_rz_ry(%arg0: !quantum.reg, %arg1: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "Hadamard"} - print(circuit_16.mlir) + print(circuit_18.mlir) test_qft_decomposition() @@ -920,7 +920,7 @@ def test_decompose_lowering_with_other_passes(): # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "merge-rotations" to [[TWO]] : (!transform.op<"builtin.module">) -> !transform.op<"builtin.module"> # CHECK-NEXT: transform.yield # CHECK-NEXT: } - def circuit_17(): + def circuit_19(): # CHECK: [[OUT_0:%.+]] = quantum.custom "PauliX"() %1 : !quantum.bit # CHECK-NEXT: [[OUT_1:%.+]] = quantum.custom "PauliX"() [[OUT_0]] : !quantum.bit @@ -934,12 +934,50 @@ def circuit_17(): # CHECK-DAG: func.func public @_paulix_to_rx(%arg0: !quantum.reg, %arg1: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "PauliX"} # CHECK-DAG: func.func public @_rx_to_rz_ry(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "RX"} - print(circuit_17.mlir) + print(circuit_19.mlir) test_decompose_lowering_with_other_passes() +def test_decompose_lowering_multirz(): + """Test the decompose lowering pass with MultiRZ in the gate set.""" + + qml.capture.enable() + qml.decomposition.enable_graph() + + @qml.qjit(target="mlir") + @partial( + qml.transforms.decompose, + gate_set={"CNOT", "RZ"}, + ) + @qml.qnode(qml.device("lightning.qubit", wires=3)) + # CHECK: %0 = transform.apply_registered_pass "decompose-lowering" + def circuit_20(x: float): + # CHECK: [[EXTRACTED:%.+]] = tensor.extract %arg0[] : tensor + # CHECK-NEXT: [[OUT_QUBITS:%.+]] = quantum.multirz([[EXTRACTED]]) %1 : !quantum.bit + # CHECK-NEXT: [[BIT_1:%.+]] = quantum.extract %0[ 1] : !quantum.reg -> !quantum.bit + # CHECK-NEXT: [[EXTRACTED_0:%.+]] = tensor.extract %arg0[] : tensor + # CHECK-NEXT: [[OUT_QUBITS_1:%.+]] = quantum.multirz([[EXTRACTED_0]]) [[OUT_QUBITS]], [[BIT_1]] : !quantum.bit, !quantum.bit + # CHECK-NEXT: [[BIT_2:%.+]] = quantum.extract %0[ 2] : !quantum.reg -> !quantum.bit + # CHECK-NEXT: [[EXTRACTED_2:%.+]] = tensor.extract %arg0[] : tensor + # CHECK-NEXT: {{%.+}} = quantum.multirz([[EXTRACTED_2]]) {{%.+}}, {{%.+}}, [[BIT_2]] : !quantum.bit, !quantum.bit, !quantum.bit + qml.MultiRZ(x, wires=[0]) + qml.MultiRZ(x, wires=[0, 1]) + qml.MultiRZ(x, wires=[1, 0, 2]) + return qml.expval(qml.PauliX(0)) + + # CHECK-DAG: func.func public @_multi_rz_decomposition_wires_1(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "MultiRZ"} + # CHECK-DAG: func.func public @_multi_rz_decomposition_wires_2(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<2xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 2 : i64, target_gate = "MultiRZ"} + # CHECK-DAG: func.func public @_multi_rz_decomposition_wires_3(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<3xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 3 : i64, target_gate = "MultiRZ"} + # CHECK-DAG: %0 = scf.for %arg3 = %c0 to %c2 step %c1 iter_args(%arg4 = %arg0) -> (!quantum.reg) + # CHECK-DAG: %5 = scf.for %arg3 = %c1 to %c3 step %c1 iter_args(%arg4 = %4) -> (!quantum.reg) + print(circuit_20.mlir) + + +test_decompose_lowering_multirz() + + def test_decompose_lowering_with_ordered_passes(): """Test the decompose lowering pass with other passes in a specific order in a pass pipeline.""" @@ -961,7 +999,7 @@ def test_decompose_lowering_with_ordered_passes(): # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "decompose-lowering" to [[SECOND]] : (!transform.op<"builtin.module">) -> !transform.op<"builtin.module"> # CHECK-NEXT: transform.yield # CHECK-NEXT: } - def circuit_18(x: float): + def circuit_21(x: float): # CHECK: [[OUT:%.+]] = quantum.custom "PauliX"() %1 : !quantum.bit # CHECK-NEXT: [[OUT_0:%.+]] = quantum.custom "PauliX"() [[OUT]] : !quantum.bit # CHECK-NEXT: [[EXTRACTED:%.+]] = tensor.extract %arg0[] : tensor @@ -978,50 +1016,12 @@ def circuit_18(x: float): # CHECK-DAG: func.func public @_paulix_to_rx(%arg0: !quantum.reg, %arg1: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "PauliX"} # CHECK-DAG: func.func public @_rx_to_rz_ry(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "RX"} # CHECK-DAG: func.func public @_rot_to_rz_ry_rz(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "Rot"} - print(circuit_18.mlir) + print(circuit_21.mlir) test_decompose_lowering_with_ordered_passes() -def test_decompose_lowering_multirz(): - """Test the decompose lowering pass with MultiRZ in the gate set.""" - - qml.capture.enable() - qml.decomposition.enable_graph() - - @qml.qjit(target="mlir") - @partial( - qml.transforms.decompose, - gate_set={"CNOT", "RZ"}, - ) - @qml.qnode(qml.device("lightning.qubit", wires=3)) - # CHECK: %0 = transform.apply_registered_pass "decompose-lowering" - def circuit_19(x: float): - # CHECK: [[EXTRACTED:%.+]] = tensor.extract %arg0[] : tensor - # CHECK-NEXT: [[OUT_QUBITS:%.+]] = quantum.multirz([[EXTRACTED]]) %1 : !quantum.bit - # CHECK-NEXT: [[BIT_1:%.+]] = quantum.extract %0[ 1] : !quantum.reg -> !quantum.bit - # CHECK-NEXT: [[EXTRACTED_0:%.+]] = tensor.extract %arg0[] : tensor - # CHECK-NEXT: [[OUT_QUBITS_1:%.+]] = quantum.multirz([[EXTRACTED_0]]) [[OUT_QUBITS]], [[BIT_1]] : !quantum.bit, !quantum.bit - # CHECK-NEXT: [[BIT_2:%.+]] = quantum.extract %0[ 2] : !quantum.reg -> !quantum.bit - # CHECK-NEXT: [[EXTRACTED_2:%.+]] = tensor.extract %arg0[] : tensor - # CHECK-NEXT: {{%.+}} = quantum.multirz([[EXTRACTED_2]]) {{%.+}}, {{%.+}}, [[BIT_2]] : !quantum.bit, !quantum.bit, !quantum.bit - qml.MultiRZ(x, wires=[0]) - qml.MultiRZ(x, wires=[0, 1]) - qml.MultiRZ(x, wires=[1, 0, 2]) - return qml.expval(qml.PauliX(0)) - - # CHECK-DAG: func.func public @_multi_rz_decomposition_wires_1(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "MultiRZ"} - # CHECK-DAG: func.func public @_multi_rz_decomposition_wires_2(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<2xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 2 : i64, target_gate = "MultiRZ"} - # CHECK-DAG: func.func public @_multi_rz_decomposition_wires_3(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<3xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 3 : i64, target_gate = "MultiRZ"} - # CHECK-DAG: %0 = scf.for %arg3 = %c0 to %c2 step %c1 iter_args(%arg4 = %arg0) -> (!quantum.reg) - # CHECK-DAG: %5 = scf.for %arg3 = %c1 to %c3 step %c1 iter_args(%arg4 = %4) -> (!quantum.reg) - print(circuit_19.mlir) - - -test_decompose_lowering_multirz() - - def test_decompose_lowering_with_gphase(): """Test the decompose lowering pass with GlobalPhase.""" @@ -1035,8 +1035,7 @@ def test_decompose_lowering_with_gphase(): ) @qml.qnode(qml.device("lightning.qubit", wires=3)) # CHECK: %0 = transform.apply_registered_pass "decompose-lowering" - - def circuit_20(): + def circuit_22(): # CHECK: quantum.gphase(%cst_0) : # CHECK-NEXT: [[EXTRACTED:%.+]] = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit # CHECK-NEXT: [[OUT_QUBITS:%.+]] = quantum.custom "PhaseShift"(%cst) [[EXTRACTED]] : !quantum.bit @@ -1048,7 +1047,7 @@ def circuit_20(): # CHECK-DAG: func.func public @_phaseshift_to_rz_gp(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "PhaseShift"} # CHECK-DAG: func.func public @_rz_to_ry_rx(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "RZ"} - print(circuit_20.mlir) + print(circuit_22.mlir) test_decompose_lowering_with_gphase() diff --git a/frontend/test/pytest/from_plxpr/test_from_plxpr.py b/frontend/test/pytest/from_plxpr/test_from_plxpr.py index 47b0b65be8..29f83fdc87 100644 --- a/frontend/test/pytest/from_plxpr/test_from_plxpr.py +++ b/frontend/test/pytest/from_plxpr/test_from_plxpr.py @@ -15,6 +15,8 @@ This module tests the from_plxpr conversion function. """ +from functools import partial + import jax import numpy as np import pennylane as qml @@ -965,5 +967,34 @@ def workflow(x, y): assert qml.math.allclose(results, expected) +class TestGraphDecomposition: + """Test the new graph-based decomposition integration with from_plxpr.""" + + def test_with_multiple_decomps_transforms(self): + """Test that a circuit with multiple decompositions and transforms can be converted.""" + + qml.capture.enable() + qml.decomposition.enable_graph() + + @qml.qjit(target="mlir") + @partial( + qml.transforms.decompose, + gate_set={"RX", "RY"}, + ) + @partial( + qml.transforms.decompose, + gate_set={"NOT", "GlobalPhase"}, + ) + @qml.qnode(qml.device("lightning.qubit", wires=0)) + def circuit(x): + qml.GlobalPhase(x) + return qml.expval(qml.PauliX(0)) + + with pytest.raises( + NotImplementedError, match="Multiple decomposition transforms are not yet supported." + ): + circuit(0.2) + + if __name__ == "__main__": pytest.main(["-x", __file__]) From e879db8b274b3ecba6466c958cf0fbcb2f8d5f82 Mon Sep 17 00:00:00 2001 From: Ali Asadi <10773383+maliasadi@users.noreply.github.com> Date: Tue, 23 Sep 2025 16:44:43 -0400 Subject: [PATCH 27/36] Update tests --- doc/releases/changelog-dev.md | 1 - frontend/catalyst/from_plxpr/decompose.py | 4 +--- frontend/catalyst/from_plxpr/from_plxpr.py | 2 ++ frontend/catalyst/passes/pass_api.py | 2 +- frontend/test/lit/test_decomposition.py | 16 ++++++++++++++++ frontend/test/lit/test_from_plxpr.py | 5 ++++- .../test/pytest/from_plxpr/test_from_plxpr.py | 3 +++ .../Catalyst/Transforms/RegisterAllPasses.cpp | 2 +- runtime/include/RuntimeCAPI.h | 1 - runtime/lib/capi/RuntimeCAPI.cpp | 8 -------- 10 files changed, 28 insertions(+), 16 deletions(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 8970f8f8f8..290b8ff1a8 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -165,7 +165,6 @@ return qml.probs() ``` -

Improvements 🛠

* Significantly improved resource tracking with `null.qubit`. diff --git a/frontend/catalyst/from_plxpr/decompose.py b/frontend/catalyst/from_plxpr/decompose.py index f3d0555be4..f169212b0d 100644 --- a/frontend/catalyst/from_plxpr/decompose.py +++ b/frontend/catalyst/from_plxpr/decompose.py @@ -190,9 +190,7 @@ def interpret_measurement(self, measurement: "qml.measurement.MeasurementProcess # and compile them to Catalyst JAXPR decomposition rules for op, rule in self._decomp_graph_solution.items(): # Get number of wires if exists - op_num_wires = ( - op.op.params.get("num_wires", None) if hasattr(op.op, "params") else None - ) + op_num_wires = op.op.params.get("num_wires", None) if ( o := next( diff --git a/frontend/catalyst/from_plxpr/from_plxpr.py b/frontend/catalyst/from_plxpr/from_plxpr.py index ef5fa9329d..a9fcbf0784 100644 --- a/frontend/catalyst/from_plxpr/from_plxpr.py +++ b/frontend/catalyst/from_plxpr/from_plxpr.py @@ -15,6 +15,8 @@ This submodule defines a utility for converting plxpr into Catalyst jaxpr. """ # pylint: disable=protected-access +# pylint: disable=too-many-lines + from copy import copy from functools import partial from typing import Callable diff --git a/frontend/catalyst/passes/pass_api.py b/frontend/catalyst/passes/pass_api.py index 2c59d53e68..33cefda4a7 100644 --- a/frontend/catalyst/passes/pass_api.py +++ b/frontend/catalyst/passes/pass_api.py @@ -374,10 +374,10 @@ def dictionary_to_list_of_passes(pass_pipeline: PipelineDict | str, *flags, **va def _API_name_to_pass_name(): return { "cancel_inverses": "remove-chained-self-inverse", + "decompose_lowering": "decompose-lowering", "disentangle_cnot": "disentangle-CNOT", "disentangle_swap": "disentangle-SWAP", "merge_rotations": "merge-rotations", - "decompose_lowering": "decompose-lowering", "ions_decomposition": "ions-decomposition", "to_ppr": "to-ppr", "commute_ppr": "commute-ppr", diff --git a/frontend/test/lit/test_decomposition.py b/frontend/test/lit/test_decomposition.py index c965ac50b1..c31d0d86b1 100644 --- a/frontend/test/lit/test_decomposition.py +++ b/frontend/test/lit/test_decomposition.py @@ -31,6 +31,7 @@ # RUN: %PYTHON %s | FileCheck %s # pylint: disable=line-too-long +# pylint: disable=too-many-lines TEST_PATH = os.path.dirname(__file__) @@ -895,6 +896,9 @@ def circuit_18(): # CHECK-DAG: func.func public @_hadamard_to_rz_ry(%arg0: !quantum.reg, %arg1: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "Hadamard"} print(circuit_18.mlir) + qml.decomposition.disable_graph() + qml.capture.disable() + test_qft_decomposition() @@ -936,6 +940,9 @@ def circuit_19(): # CHECK-DAG: func.func public @_rx_to_rz_ry(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "RX"} print(circuit_19.mlir) + qml.decomposition.disable_graph() + qml.capture.disable() + test_decompose_lowering_with_other_passes() @@ -974,6 +981,9 @@ def circuit_20(x: float): # CHECK-DAG: %5 = scf.for %arg3 = %c1 to %c3 step %c1 iter_args(%arg4 = %4) -> (!quantum.reg) print(circuit_20.mlir) + qml.decomposition.disable_graph() + qml.capture.disable() + test_decompose_lowering_multirz() @@ -1018,6 +1028,9 @@ def circuit_21(x: float): # CHECK-DAG: func.func public @_rot_to_rz_ry_rz(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "Rot"} print(circuit_21.mlir) + qml.decomposition.disable_graph() + qml.capture.disable() + test_decompose_lowering_with_ordered_passes() @@ -1049,5 +1062,8 @@ def circuit_22(): # CHECK-DAG: func.func public @_rz_to_ry_rx(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "RZ"} print(circuit_22.mlir) + qml.decomposition.disable_graph() + qml.capture.disable() + test_decompose_lowering_with_gphase() diff --git a/frontend/test/lit/test_from_plxpr.py b/frontend/test/lit/test_from_plxpr.py index f87678c3d7..14cea4a8e6 100644 --- a/frontend/test/lit/test_from_plxpr.py +++ b/frontend/test/lit/test_from_plxpr.py @@ -47,7 +47,7 @@ def main(): print(main.mlir) - qml.capture.enable() + qml.capture.disable() test_conditional_capture() @@ -416,5 +416,8 @@ def circuit3(): print(circuit3.mlir) + qml.decomposition.disable_graph() + qml.capture.disable() + test_pass_decomposition() diff --git a/frontend/test/pytest/from_plxpr/test_from_plxpr.py b/frontend/test/pytest/from_plxpr/test_from_plxpr.py index 29f83fdc87..1f45619b7a 100644 --- a/frontend/test/pytest/from_plxpr/test_from_plxpr.py +++ b/frontend/test/pytest/from_plxpr/test_from_plxpr.py @@ -995,6 +995,9 @@ def circuit(x): ): circuit(0.2) + qml.decomposition.disable_graph() + qml.capture.disable() + if __name__ == "__main__": pytest.main(["-x", __file__]) diff --git a/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp b/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp index 73f8327216..88d97a8674 100644 --- a/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp +++ b/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp @@ -38,6 +38,7 @@ void catalyst::registerAllCatalystPasses() mlir::registerPass(catalyst::createCliffordTToPPRPass); mlir::registerPass(catalyst::createMergePPRIntoPPMPass); mlir::registerPass(catalyst::createPPMCompilationPass); + mlir::registerPass(catalyst::createDecomposeLoweringPass); mlir::registerPass(catalyst::createDecomposeNonCliffordPPRPass); mlir::registerPass(catalyst::createDecomposeCliffordPPRPass); mlir::registerPass(catalyst::createCountPPMSpecsPass); @@ -68,7 +69,6 @@ void catalyst::registerAllCatalystPasses() mlir::registerPass(catalyst::createRegisterInactiveCallbackPass); mlir::registerPass(catalyst::createRemoveChainedSelfInversePass); mlir::registerPass(catalyst::createMergeRotationsPass); - mlir::registerPass(catalyst::createDecomposeLoweringPass); mlir::registerPass(catalyst::createScatterLoweringPass); mlir::registerPass(catalyst::createStablehloLegalizeControlFlowPass); mlir::registerPass(catalyst::createStablehloLegalizeSortPass); diff --git a/runtime/include/RuntimeCAPI.h b/runtime/include/RuntimeCAPI.h index b248979a11..87414df96f 100644 --- a/runtime/include/RuntimeCAPI.h +++ b/runtime/include/RuntimeCAPI.h @@ -63,7 +63,6 @@ void __catalyst__qis__RX(double, QUBIT *, const Modifiers *); void __catalyst__qis__RY(double, QUBIT *, const Modifiers *); void __catalyst__qis__RZ(double, QUBIT *, const Modifiers *); void __catalyst__qis__Rot(double, double, double, QUBIT *, const Modifiers *); -void __catalyst__qis__RotXZX(double, double, double, QUBIT *, const Modifiers *); void __catalyst__qis__CNOT(QUBIT *, QUBIT *, const Modifiers *); void __catalyst__qis__CY(QUBIT *, QUBIT *, const Modifiers *); void __catalyst__qis__CZ(QUBIT *, QUBIT *, const Modifiers *); diff --git a/runtime/lib/capi/RuntimeCAPI.cpp b/runtime/lib/capi/RuntimeCAPI.cpp index 893af5a7fd..94ed41e13a 100644 --- a/runtime/lib/capi/RuntimeCAPI.cpp +++ b/runtime/lib/capi/RuntimeCAPI.cpp @@ -627,14 +627,6 @@ void __catalyst__qis__Rot(double phi, double theta, double omega, QUBIT *qubit, MODIFIERS_ARGS(modifiers)); } -void __catalyst__qis__RotXZX(double phi, double theta, double omega, QUBIT *qubit, - const Modifiers *modifiers) -{ - getQuantumDevicePtr()->NamedOperation("RotXZX", {phi, theta, omega}, - {reinterpret_cast(qubit)}, - MODIFIERS_ARGS(modifiers)); -} - void __catalyst__qis__CNOT(QUBIT *control, QUBIT *target, const Modifiers *modifiers) { RT_FAIL_IF(control == target, From 12279d066a1aaa879db238f0b4c2ce2732e5dbc9 Mon Sep 17 00:00:00 2001 From: Ali Asadi <10773383+maliasadi@users.noreply.github.com> Date: Tue, 23 Sep 2025 23:53:24 -0400 Subject: [PATCH 28/36] Update --- frontend/catalyst/from_plxpr/decompose.py | 112 ++++++++++-------- frontend/test/lit/test_decomposition.py | 14 +-- .../test/pytest/from_plxpr/test_from_plxpr.py | 6 +- 3 files changed, 71 insertions(+), 61 deletions(-) diff --git a/frontend/catalyst/from_plxpr/decompose.py b/frontend/catalyst/from_plxpr/decompose.py index f169212b0d..232d2ab886 100644 --- a/frontend/catalyst/from_plxpr/decompose.py +++ b/frontend/catalyst/from_plxpr/decompose.py @@ -29,6 +29,7 @@ # DecompRuleInterpreter: from pennylane.decomposition import DecompositionGraph +from pennylane.typing import TensorLike from pennylane.wires import WiresLike from catalyst.jax_primitives import decomposition_rule @@ -63,7 +64,8 @@ class DecompRuleInterpreter(qml.capture.PlxprInterpreter): TypeError: if graph-based decomposition is not enabled. """ - # A mapping from operation names to the number of wires they act on. + # A mapping from operation names to the number of wires they act on + # and the number of parameters they have. # This is used when the operation is not in the captured operations # but we still need to create a decomposition rule for it. # @@ -74,43 +76,43 @@ class DecompRuleInterpreter(qml.capture.PlxprInterpreter): # This will require a copy of the function to be made # when creating the decomposition rule to avoid mutating # the original function with attributes like num_wires. - compiler_ops_num_wires: dict[str, int] = { - "CNOT": 2, - "ControlledPhaseShift": 2, - "CRot": 2, - "CRX": 2, - "CRY": 2, - "CRZ": 2, - "CSWAP": 3, - "CY": 2, - "CZ": 2, - "Hadamard": 1, - "Identity": 1, - "IsingXX": 2, - "IsingXY": 2, - "IsingYY": 2, - "IsingZZ": 2, - "SingleExcitation": 2, - "DoubleExcitation": 4, - "ISWAP": 2, - "PauliX": 1, - "PauliY": 1, - "PauliZ": 1, - "PhaseShift": 1, - "PSWAP": 2, - "Rot": 1, - "RX": 1, - "RY": 1, - "RZ": 1, - "S": 1, - "SWAP": 2, - "T": 1, - "Toffoli": 3, - "U1": 1, - "U2": 1, - "U3": 1, - "MultiRZ": -1, # variable number of wires - "GlobalPhase": -1, # variable number of wires + compiler_ops_num_wires: dict[str, tuple[int, int]] = { + "CNOT": (2, 0), + "ControlledPhaseShift": (2, 1), + "CRot": (2, 3), + "CRX": (2, 1), + "CRY": (2, 1), + "CRZ": (2, 1), + "CSWAP": (3, 0), + "CY": (2, 0), + "CZ": (2, 0), + "Hadamard": (1, 0), + "Identity": (1, 0), + "IsingXX": (2, 1), + "IsingXY": (2, 1), + "IsingYY": (2, 1), + "IsingZZ": (2, 1), + "SingleExcitation": (2, 1), + "DoubleExcitation": (4, 1), + "ISWAP": (2, 0), + "PauliX": (1, 0), + "PauliY": (1, 0), + "PauliZ": (1, 0), + "PhaseShift": (1, 1), + "PSWAP": (2, 1), + "Rot": (1, 3), + "RX": (1, 1), + "RY": (1, 1), + "RZ": (1, 1), + "S": (1, 0), + "SWAP": (2, 0), + "T": (1, 0), + "Toffoli": (3, 0), + "U1": (1, 1), + "U2": (1, 2), + "U3": (1, 3), + "MultiRZ": (-1, 1), + "GlobalPhase": (-1, 1), } def __init__( @@ -191,7 +193,6 @@ def interpret_measurement(self, measurement: "qml.measurement.MeasurementProcess for op, rule in self._decomp_graph_solution.items(): # Get number of wires if exists op_num_wires = op.op.params.get("num_wires", None) - if ( o := next( ( @@ -202,11 +203,13 @@ def interpret_measurement(self, measurement: "qml.measurement.MeasurementProcess None, ) ) is not None: + num_wires, num_params = self.compiler_ops_num_wires[op.op.name] _create_decomposition_rule( rule, op_name=op.op.name, num_wires=len(o.wires), - requires_copy=self.compiler_ops_num_wires[op.op.name] == -1, + num_params=num_params, + requires_copy=num_wires == -1, ) elif op.op.name in self.compiler_ops_num_wires: # In this part, we need to handle the case where an operation in @@ -215,12 +218,13 @@ def interpret_measurement(self, measurement: "qml.measurement.MeasurementProcess # in the circuit, but is used inside a decomposition rule. # In this case, we fall back to using the compiler_ops_num_wires # dictionary to get the number of wires. - num_wires = self.compiler_ops_num_wires[op.op.name] + num_wires, num_params = self.compiler_ops_num_wires[op.op.name] _create_decomposition_rule( rule, op_name=op.op.name, num_wires=num_wires, - requires_copy=self.compiler_ops_num_wires[op.op.name] == -1, + num_params=num_params, + requires_copy=num_wires == -1, ) else: # pragma: no cover raise ValueError(f"Could not capture {op} without the number of wires.") @@ -230,7 +234,7 @@ def interpret_measurement(self, measurement: "qml.measurement.MeasurementProcess def _create_decomposition_rule( - func: Callable, op_name: str, num_wires: int, requires_copy: bool = False + func: Callable, op_name: str, num_wires: int, num_params: int, requires_copy: bool = False ): """Create a decomposition rule from a callable. @@ -240,9 +244,10 @@ def _create_decomposition_rule( func (Callable): The decomposition function. op_name (str): The name of the operation to decompose. num_wires (int): The number of wires the operation acts on. - - Returns: - None: The function is decorated in place. + num_params (int): The number of parameters the operation takes. + requires_copy (bool): Whether to create a copy of the function + to avoid mutating the original. This is required for operations + with a variable number of wires (e.g., MultiRZ, GlobalPhase). """ sig_func = inspect.signature(func) @@ -259,23 +264,26 @@ def _create_decomposition_rule( # TODO: This is a temporary solution until all rules have proper type annotations. # Why? Because we need to pass the correct types to the decomposition_rule # function to capture the rule correctly with JAX. - possible_names_for_params = { - "params", + possible_names_for_single_param = { "param", - "parameters", - "angles", "angle", "phi", "omega", "theta", - "weights", "weight", } + possible_names_for_multi_params = { + "params", + "angles", + "weights", + } # TODO: Support work-wires when it's supported in Catalyst. possible_names_for_wires = {"wires", "wire", "control_wires", "target_wires"} - if typ is float or name in possible_names_for_params: + if typ is TensorLike or name in possible_names_for_multi_params: + args[name] = qml.math.array([0.0] * num_params, like="jax", dtype=float) + elif typ is float or name in possible_names_for_single_param: # TensorLike is a Union of float, int, array-like, so we use float here # to cover the most common case as the JAX tracer doesn't like Union types # and we don't have the actual values at this point. diff --git a/frontend/test/lit/test_decomposition.py b/frontend/test/lit/test_decomposition.py index c31d0d86b1..4c4c340d09 100644 --- a/frontend/test/lit/test_decomposition.py +++ b/frontend/test/lit/test_decomposition.py @@ -779,10 +779,10 @@ def circuit_15(): qml.DoubleExcitation(0.5, wires=[0, 1, 2, 3]) return qml.expval(qml.Z(0)) - # CHECK-DAG: func.func public @_cry(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<2xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 2 : i64, target_gate = "CRY"} + # CHECK-DAG: func.func public @_cry(%arg0: !quantum.reg, %arg1: tensor<1xf64>, %arg2: tensor<2xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 2 : i64, target_gate = "CRY"} # CHECK-DAG: func.func public @_s_phaseshift(%arg0: !quantum.reg, %arg1: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "S"} # CHECK-DAG: func.func public @_phaseshift_to_rz_gp(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "PhaseShift"} - # func.func public @_rz_to_ry_rx(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "RZ"} + # CHECK-DAG: func.func public @_rz_to_ry_rx(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "RZ"} # CHECK-DAG: func.func public @_rot_to_rz_ry_rz(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "Rot"} # CHECK-DAG: func.func public @_doublexcit(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<4xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 4 : i64, target_gate = "DoubleExcitation"} # CHECK-DAG: func.func public @_single_excitation_decomp(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<2xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 2 : i64, target_gate = "SingleExcitation"} @@ -856,7 +856,7 @@ def circuit_17(): return qml.expval(qml.Z(0)) # CHECK-DAG: func.func public @_cnot_to_cz_h(%arg0: !quantum.reg, %arg1: tensor<2xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 2 : i64, target_gate = "CNOT"} - # CHECK-DAG: func.func public @_cry(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<2xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 2 : i64, target_gate = "CRY"} + # CHECK-DAG: func.func public @_cry(%arg0: !quantum.reg, %arg1: tensor<1xf64>, %arg2: tensor<2xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 2 : i64, target_gate = "CRY"} # CHECK-DAG: func.func public @_ry_to_rz_rx(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "RY"} # CHECK-DAG: func.func public @_rot_to_rz_ry_rz(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "Rot"} print(circuit_17.mlir) @@ -889,7 +889,7 @@ def circuit_18(): qml.QFT(wires=[0, 1, 2, 3]) return qml.expval(qml.Z(0)) - # CHECK-DAG: func.func public @_cphase_to_rz_cnot(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<2xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 2 : i64, target_gate = "ControlledPhaseShift"} + # CHECK-DAG: func.func public @_cphase_to_rz_cnot(%arg0: !quantum.reg, %arg1: tensor<1xf64>, %arg2: tensor<2xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 2 : i64, target_gate = "ControlledPhaseShift"} # CHECK-DAG: func.func public @_rz_to_ry_rx(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "RZ"} # CHECK-DAG: func.func public @_rot_to_rz_ry_rz(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "Rot"} # CHECK-DAG: func.func public @_swap_to_cnot(%arg0: !quantum.reg, %arg1: tensor<2xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 2 : i64, target_gate = "SWAP"} @@ -974,9 +974,9 @@ def circuit_20(x: float): qml.MultiRZ(x, wires=[1, 0, 2]) return qml.expval(qml.PauliX(0)) - # CHECK-DAG: func.func public @_multi_rz_decomposition_wires_1(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "MultiRZ"} - # CHECK-DAG: func.func public @_multi_rz_decomposition_wires_2(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<2xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 2 : i64, target_gate = "MultiRZ"} - # CHECK-DAG: func.func public @_multi_rz_decomposition_wires_3(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<3xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 3 : i64, target_gate = "MultiRZ"} + # CHECK-DAG: func.func public @_multi_rz_decomposition_wires_1(%arg0: !quantum.reg, %arg1: tensor<1xf64>, %arg2: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "MultiRZ"} + # CHECK-DAG: func.func public @_multi_rz_decomposition_wires_2(%arg0: !quantum.reg, %arg1: tensor<1xf64>, %arg2: tensor<2xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 2 : i64, target_gate = "MultiRZ"} + # CHECK-DAG: func.func public @_multi_rz_decomposition_wires_3(%arg0: !quantum.reg, %arg1: tensor<1xf64>, %arg2: tensor<3xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 3 : i64, target_gate = "MultiRZ"} # CHECK-DAG: %0 = scf.for %arg3 = %c0 to %c2 step %c1 iter_args(%arg4 = %arg0) -> (!quantum.reg) # CHECK-DAG: %5 = scf.for %arg3 = %c1 to %c3 step %c1 iter_args(%arg4 = %4) -> (!quantum.reg) print(circuit_20.mlir) diff --git a/frontend/test/pytest/from_plxpr/test_from_plxpr.py b/frontend/test/pytest/from_plxpr/test_from_plxpr.py index 1f45619b7a..9d2ce69a2c 100644 --- a/frontend/test/pytest/from_plxpr/test_from_plxpr.py +++ b/frontend/test/pytest/from_plxpr/test_from_plxpr.py @@ -995,8 +995,10 @@ def circuit(x): ): circuit(0.2) - qml.decomposition.disable_graph() - qml.capture.disable() + qml.decomposition.disable_graph() + qml.capture.disable() + + assert qml.decomposition.enabled_graph() is False if __name__ == "__main__": From e7d071f1e4873b8fe73bb4d002f80fa5ab89a3ec Mon Sep 17 00:00:00 2001 From: Ali Asadi <10773383+maliasadi@users.noreply.github.com> Date: Wed, 24 Sep 2025 12:49:43 -0400 Subject: [PATCH 29/36] Fix the issue with dynamic qubit allocs --- frontend/test/lit/test_dynamic_qubit_allocation.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/frontend/test/lit/test_dynamic_qubit_allocation.py b/frontend/test/lit/test_dynamic_qubit_allocation.py index 85b3ede88a..8b4d7f114c 100644 --- a/frontend/test/lit/test_dynamic_qubit_allocation.py +++ b/frontend/test/lit/test_dynamic_qubit_allocation.py @@ -82,7 +82,8 @@ def test_basic_dynalloc(): # CHECK: [[CNOTout:%.+]]:2 = quantum.custom "CNOT"() [[dyn_bit2]], [[dev_bit1]] # CHECK: [[insert0:%.+]] = quantum.insert [[dyn_qreg]][ 1], [[Xout]] # CHECK: [[insert1:%.+]] = quantum.insert [[insert0]][ 2], [[CNOTout]]#0 - # CHECK: quantum.dealloc [[insert1]] + # CHECK: [[insert2:%.+]] = quantum.insert [[insert1]][ 3] + # CHECK: quantum.dealloc [[insert2]] with qml.allocate(4) as qs1: qml.X(qs1[1]) From b28f1d64e6fe1a855d39735ce321016ee4bef4a2 Mon Sep 17 00:00:00 2001 From: Ali Asadi <10773383+maliasadi@users.noreply.github.com> Date: Wed, 24 Sep 2025 13:22:02 -0400 Subject: [PATCH 30/36] Update lit tests --- frontend/test/lit/test_decomposition.py | 76 +++++++++++++++++++++++++ 1 file changed, 76 insertions(+) diff --git a/frontend/test/lit/test_decomposition.py b/frontend/test/lit/test_decomposition.py index 4c4c340d09..99b44d5dd6 100644 --- a/frontend/test/lit/test_decomposition.py +++ b/frontend/test/lit/test_decomposition.py @@ -1067,3 +1067,79 @@ def circuit_22(): test_decompose_lowering_with_gphase() + + +def test_decompose_lowering_alt_decomps(): + """Test the decompose lowering pass with alternative decompositions.""" + + qml.capture.enable() + qml.decomposition.enable_graph() + + @qml.register_resources({qml.RY: 1}) + def custom_rot_cheap(params, wires: WiresLike): + qml.RY(params[1], wires=wires) + + @qml.qjit(target="mlir") + @partial( + qml.transforms.decompose, + gate_set={"RY", "RZ"}, + alt_decomps={qml.Rot: [custom_rot_cheap]}, + ) + @qml.qnode(qml.device("lightning.qubit", wires=3), shots=1000) + def circ(x: float, y: float): + qml.Rot(x, y, x + y, wires=1) + return qml.expval(qml.PauliZ(0)) + + # CHECK-DAG: func.func public @custom_rot_cheap(%arg0: !quantum.reg, %arg1: tensor<3xf64>, %arg2: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "Rot"} + print(circ.mlir) + + qml.decomposition.disable_graph() + qml.capture.disable() + + +test_decompose_lowering_alt_decomps() + + +def test_decompose_lowering_with_tensorlike(): + """Test the decompose lowering pass with fixed decompositions + using TensorLike parameters.""" + + qml.capture.enable() + qml.decomposition.enable_graph() + + @qml.register_resources({qml.RZ: 2, qml.RY: 1}) + def custom_rot(params: TensorLike, wires: WiresLike): + qml.RZ(params[0], wires=wires) + qml.RY(params[1], wires=wires) + qml.RZ(params[2], wires=wires) + + @qml.register_resources({qml.RZ: 1, qml.CNOT: 4}) + def custom_multirz(params: TensorLike, wires: WiresLike): + qml.CNOT(wires=(wires[2], wires[1])) + qml.CNOT(wires=(wires[1], wires[0])) + qml.RZ(params[0], wires=wires[0]) + qml.CNOT(wires=(wires[1], wires[0])) + qml.CNOT(wires=(wires[2], wires[1])) + + @qml.qjit(target="mlir") + @partial( + qml.transforms.decompose, + gate_set={"RY", "RX", qml.CNOT}, + fixed_decomps={qml.Rot: custom_rot, qml.MultiRZ: custom_multirz}, + ) + @qml.qnode(qml.device("lightning.qubit", wires=3), shots=1000) + def circ(x: float, y: float): + qml.Rot(x, y, x + y, wires=1) + qml.MultiRZ(x + y, wires=[0, 1, 2]) + return qml.expval(qml.PauliZ(0)) + + # CHECK-DAG: func.func public @custom_multirz_wires_3(%arg0: !quantum.reg, %arg1: tensor<1xf64>, %arg2: tensor<3xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 3 : i64, target_gate = "MultiRZ"} + # CHECK-DAG: func.func public @_rz_to_ry_rx(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "RZ"} + # CHECK-DAG: func.func public @custom_rot(%arg0: !quantum.reg, %arg1: tensor<3xf64>, %arg2: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "Rot"} + print(circ.mlir) + + qml.decomposition.disable_graph() + qml.capture.disable() + + +test_decompose_lowering_with_tensorlike() From 974eea34184a1c47c0dce405de572ad8347c001b Mon Sep 17 00:00:00 2001 From: Ali Asadi <10773383+maliasadi@users.noreply.github.com> Date: Thu, 25 Sep 2025 12:26:02 -0400 Subject: [PATCH 31/36] Update tests --- frontend/catalyst/device/qjit_device.py | 10 -- frontend/catalyst/from_plxpr/decompose.py | 122 +++++++++++---------- frontend/catalyst/from_plxpr/from_plxpr.py | 5 +- frontend/test/lit/test_decomposition.py | 14 +-- 4 files changed, 74 insertions(+), 77 deletions(-) diff --git a/frontend/catalyst/device/qjit_device.py b/frontend/catalyst/device/qjit_device.py index dec5a8176b..df1fc2675d 100644 --- a/frontend/catalyst/device/qjit_device.py +++ b/frontend/catalyst/device/qjit_device.py @@ -108,16 +108,6 @@ RUNTIME_MPS = ["ExpectationMP", "SampleMP", "VarianceMP", "CountsMP", "StateMP", "ProbabilityMP"] -# A list of operations that can be represented -# in the Catalyst compiler. This will be a superset of -# the operations supported by the runtime. -# FIXME: ops with OpName(params, wires) signatures can be -# represented in the Catalyst compiler. Unfortunately, -# the signature info is not sufficient as there are -# templates with the same signature that should be -# disambiguated. -COMPILER_OPERATIONS = RUNTIME_OPERATIONS - # The runtime interface does not care about specific gate properties, so set them all to True. RUNTIME_OPERATIONS = { op: OperatorProperties(invertible=True, controllable=True, differentiable=True) diff --git a/frontend/catalyst/from_plxpr/decompose.py b/frontend/catalyst/from_plxpr/decompose.py index 232d2ab886..0fd15edbcf 100644 --- a/frontend/catalyst/from_plxpr/decompose.py +++ b/frontend/catalyst/from_plxpr/decompose.py @@ -26,14 +26,73 @@ import jax import pennylane as qml - -# DecompRuleInterpreter: from pennylane.decomposition import DecompositionGraph from pennylane.typing import TensorLike from pennylane.wires import WiresLike from catalyst.jax_primitives import decomposition_rule +# A mapping from operation names to the number of wires they act on +# and the number of parameters they have. +# This is used when the operation is not in the captured operations +# but we still need to create a decomposition rule for it. +# +# Note that some operations have a variable number of wires, +# e.g., MultiRZ, GlobalPhase. For these, we set the number +# of wires to -1 to indicate a variable number. +# +# This will require a copy of the function to be made +# when creating the decomposition rule to avoid mutating +# the original function with attributes like num_wires. + +# A list of operations that can be represented +# in the Catalyst compiler. This will be a superset of +# the operations supported by the runtime. + +# FIXME: ops with OpName(params, wires) signatures can be +# represented in the Catalyst compiler. Unfortunately, +# the signature info is not sufficient as there are +# templates with the same signature that should be +# disambiguated. +COMPILER_OPS_FOR_DECOMPOSITION: dict[str, tuple[int, int]] = { + "CNOT": (2, 0), + "ControlledPhaseShift": (2, 1), + "CRot": (2, 3), + "CRX": (2, 1), + "CRY": (2, 1), + "CRZ": (2, 1), + "CSWAP": (3, 0), + "CY": (2, 0), + "CZ": (2, 0), + "Hadamard": (1, 0), + "Identity": (1, 0), + "IsingXX": (2, 1), + "IsingXY": (2, 1), + "IsingYY": (2, 1), + "IsingZZ": (2, 1), + "SingleExcitation": (2, 1), + "DoubleExcitation": (4, 1), + "ISWAP": (2, 0), + "PauliX": (1, 0), + "PauliY": (1, 0), + "PauliZ": (1, 0), + "PhaseShift": (1, 1), + "PSWAP": (2, 1), + "Rot": (1, 3), + "RX": (1, 1), + "RY": (1, 1), + "RZ": (1, 1), + "S": (1, 0), + "SWAP": (2, 0), + "T": (1, 0), + "Toffoli": (3, 0), + "U1": (1, 1), + "U2": (1, 2), + "U3": (1, 3), + "MultiRZ": (-1, 1), + "GlobalPhase": (-1, 1), +} + # pylint: disable=too-few-public-methods class DecompRuleInterpreter(qml.capture.PlxprInterpreter): @@ -64,57 +123,6 @@ class DecompRuleInterpreter(qml.capture.PlxprInterpreter): TypeError: if graph-based decomposition is not enabled. """ - # A mapping from operation names to the number of wires they act on - # and the number of parameters they have. - # This is used when the operation is not in the captured operations - # but we still need to create a decomposition rule for it. - # - # Note that some operations have a variable number of wires, - # e.g., MultiRZ, GlobalPhase. For these, we set the number - # of wires to -1 to indicate a variable number. - # - # This will require a copy of the function to be made - # when creating the decomposition rule to avoid mutating - # the original function with attributes like num_wires. - compiler_ops_num_wires: dict[str, tuple[int, int]] = { - "CNOT": (2, 0), - "ControlledPhaseShift": (2, 1), - "CRot": (2, 3), - "CRX": (2, 1), - "CRY": (2, 1), - "CRZ": (2, 1), - "CSWAP": (3, 0), - "CY": (2, 0), - "CZ": (2, 0), - "Hadamard": (1, 0), - "Identity": (1, 0), - "IsingXX": (2, 1), - "IsingXY": (2, 1), - "IsingYY": (2, 1), - "IsingZZ": (2, 1), - "SingleExcitation": (2, 1), - "DoubleExcitation": (4, 1), - "ISWAP": (2, 0), - "PauliX": (1, 0), - "PauliY": (1, 0), - "PauliZ": (1, 0), - "PhaseShift": (1, 1), - "PSWAP": (2, 1), - "Rot": (1, 3), - "RX": (1, 1), - "RY": (1, 1), - "RZ": (1, 1), - "S": (1, 0), - "SWAP": (2, 0), - "T": (1, 0), - "Toffoli": (3, 0), - "U1": (1, 1), - "U2": (1, 2), - "U3": (1, 3), - "MultiRZ": (-1, 1), - "GlobalPhase": (-1, 1), - } - def __init__( self, *, @@ -203,7 +211,7 @@ def interpret_measurement(self, measurement: "qml.measurement.MeasurementProcess None, ) ) is not None: - num_wires, num_params = self.compiler_ops_num_wires[op.op.name] + num_wires, num_params = COMPILER_OPS_FOR_DECOMPOSITION[op.op.name] _create_decomposition_rule( rule, op_name=op.op.name, @@ -211,14 +219,14 @@ def interpret_measurement(self, measurement: "qml.measurement.MeasurementProcess num_params=num_params, requires_copy=num_wires == -1, ) - elif op.op.name in self.compiler_ops_num_wires: + elif op.op.name in COMPILER_OPS_FOR_DECOMPOSITION: # In this part, we need to handle the case where an operation in # the decomposition graph solution is not in the captured operations. # This can happen if the operation is not directly called # in the circuit, but is used inside a decomposition rule. - # In this case, we fall back to using the compiler_ops_num_wires + # In this case, we fall back to using the COMPILER_OPS_FOR_DECOMPOSITION # dictionary to get the number of wires. - num_wires, num_params = self.compiler_ops_num_wires[op.op.name] + num_wires, num_params = COMPILER_OPS_FOR_DECOMPOSITION[op.op.name] _create_decomposition_rule( rule, op_name=op.op.name, diff --git a/frontend/catalyst/from_plxpr/from_plxpr.py b/frontend/catalyst/from_plxpr/from_plxpr.py index a9fcbf0784..6ae7e5a0d3 100644 --- a/frontend/catalyst/from_plxpr/from_plxpr.py +++ b/frontend/catalyst/from_plxpr/from_plxpr.py @@ -46,8 +46,7 @@ from pennylane.transforms import unitary_to_rot as pl_unitary_to_rot from catalyst.device import extract_backend_info -from catalyst.device.qjit_device import COMPILER_OPERATIONS -from catalyst.from_plxpr.decompose import DecompRuleInterpreter +from catalyst.from_plxpr.decompose import COMPILER_OPS_FOR_DECOMPOSITION, DecompRuleInterpreter from catalyst.from_plxpr.qubit_handler import QubitHandler, QubitIndexRecorder, get_in_qubit_values from catalyst.jax_extras import jaxpr_pad_consts, make_jaxpr2, transient_jax_config from catalyst.jax_primitives import ( @@ -958,7 +957,7 @@ def _apply_compiler_decompose_to_plxpr(inner_jaxpr, consts, tgateset, ncargs): # First perform the pre-mlir decomposition to simplify the jaxpr # by decomposing high-level gates and templates - gate_set = set(COMPILER_OPERATIONS + tgateset) + gate_set = set(COMPILER_OPS_FOR_DECOMPOSITION.keys()).union(tgateset) final_jaxpr = qml.transforms.decompose.plxpr_transform( inner_jaxpr, consts, (), {"gate_set": gate_set}, *ncargs diff --git a/frontend/test/lit/test_decomposition.py b/frontend/test/lit/test_decomposition.py index 99b44d5dd6..d53d450b25 100644 --- a/frontend/test/lit/test_decomposition.py +++ b/frontend/test/lit/test_decomposition.py @@ -804,12 +804,11 @@ def test_decomposition_rule_name_adjoint(): @qml.qjit(target="mlir") @partial( qml.transforms.decompose, - gate_set={"RY", "RX", "CZ", "GlobalPhase"}, + gate_set={"RY", "RX", "CZ", "GlobalPhase", "Adjoint(SingleExcitation)"}, ) @qml.qnode(qml.device("lightning.qubit", wires=4)) # CHECK-DAG: %0 = transform.apply_registered_pass "decompose-lowering" - # CHECK: public @circuit_16() -> tensor attributes {decompose_gatesets - def circuit_16(): + def circuit_16(x: float): # CHECK-DAG: %1 = quantum.adjoint(%0) : !quantum.reg # CHECK-DAG: %2 = quantum.adjoint(%1) : !quantum.reg # CHECK-DAG: %3 = quantum.adjoint(%2) : !quantum.reg @@ -818,6 +817,7 @@ def circuit_16(): qml.adjoint(qml.Hadamard)(wires=2) qml.adjoint(qml.RZ)(0.5, wires=3) qml.adjoint(qml.SingleExcitation)(0.1, wires=[0, 1]) + qml.adjoint(qml.SingleExcitation(x, wires=[0, 1])) return qml.expval(qml.Z(0)) # CHECK-DAG: func.func public @_single_excitation_decomp(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<2xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 2 : i64, target_gate = "SingleExcitation"} @@ -1086,12 +1086,12 @@ def custom_rot_cheap(params, wires: WiresLike): alt_decomps={qml.Rot: [custom_rot_cheap]}, ) @qml.qnode(qml.device("lightning.qubit", wires=3), shots=1000) - def circ(x: float, y: float): + def circuit_23(x: float, y: float): qml.Rot(x, y, x + y, wires=1) return qml.expval(qml.PauliZ(0)) # CHECK-DAG: func.func public @custom_rot_cheap(%arg0: !quantum.reg, %arg1: tensor<3xf64>, %arg2: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "Rot"} - print(circ.mlir) + print(circuit_23.mlir) qml.decomposition.disable_graph() qml.capture.disable() @@ -1128,7 +1128,7 @@ def custom_multirz(params: TensorLike, wires: WiresLike): fixed_decomps={qml.Rot: custom_rot, qml.MultiRZ: custom_multirz}, ) @qml.qnode(qml.device("lightning.qubit", wires=3), shots=1000) - def circ(x: float, y: float): + def circuit_24(x: float, y: float): qml.Rot(x, y, x + y, wires=1) qml.MultiRZ(x + y, wires=[0, 1, 2]) return qml.expval(qml.PauliZ(0)) @@ -1136,7 +1136,7 @@ def circ(x: float, y: float): # CHECK-DAG: func.func public @custom_multirz_wires_3(%arg0: !quantum.reg, %arg1: tensor<1xf64>, %arg2: tensor<3xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 3 : i64, target_gate = "MultiRZ"} # CHECK-DAG: func.func public @_rz_to_ry_rx(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "RZ"} # CHECK-DAG: func.func public @custom_rot(%arg0: !quantum.reg, %arg1: tensor<3xf64>, %arg2: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "Rot"} - print(circ.mlir) + print(circuit_24.mlir) qml.decomposition.disable_graph() qml.capture.disable() From 7bf749a2f9a0a47398b40c5a5a692bc6fc802d46 Mon Sep 17 00:00:00 2001 From: Ali Asadi <10773383+maliasadi@users.noreply.github.com> Date: Thu, 25 Sep 2025 12:29:53 -0400 Subject: [PATCH 32/36] Add the draft C++ decomposition graph solver --- delightning/CMakeLists.txt | 8 ++ delightning/Makefile | 14 +++ delightning/src/main.cpp | 221 +++++++++++++++++++++++++++++++++++++ 3 files changed, 243 insertions(+) create mode 100644 delightning/CMakeLists.txt create mode 100644 delightning/Makefile create mode 100644 delightning/src/main.cpp diff --git a/delightning/CMakeLists.txt b/delightning/CMakeLists.txt new file mode 100644 index 0000000000..03949991e9 --- /dev/null +++ b/delightning/CMakeLists.txt @@ -0,0 +1,8 @@ +cmake_minimum_required(VERSION 3.20) + +project(delightning VERSION 0.1.0 LANGUAGES CXX) + +set(CMAKE_CXX_STANDARD 20) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +add_executable(delightning src/main.cpp) diff --git a/delightning/Makefile b/delightning/Makefile new file mode 100644 index 0000000000..7b82711c13 --- /dev/null +++ b/delightning/Makefile @@ -0,0 +1,14 @@ +CXX=g++ +CXXFLAGS=-std=c++17 -Wall -Wextra -O2 + +TARGET=delightning +SRCS=src/main.cpp + +$(TARGET): $(SRCS) + $(CXX) $(CXXFLAGS) -o $(TARGET) $(SRCS) + +run: $(TARGET) + ./$(TARGET) + +clean: + rm -f $(TARGET) diff --git a/delightning/src/main.cpp b/delightning/src/main.cpp new file mode 100644 index 0000000000..f18e4d76bc --- /dev/null +++ b/delightning/src/main.cpp @@ -0,0 +1,221 @@ + +#include +#include +#include +#include +#include + + +// Operator +// ______________________________ +// * name: Operator name +// + getName(): string + +class Operator { + // TODO: string_view + std::string name; + +public: + Operator() = default; + explicit Operator(const std::string& name) : name(name) {} + std::string getName() const { return name; } + + bool operator==(const Operator& other) const { + return name == other.name; + } + bool operator!=(const Operator& other) const { + return !(*this == other); + } +}; + +namespace std { + template <> + struct hash { + std::size_t operator()(const Operator& op) const noexcept { + return std::hash()(op.getName()); + } + }; +} + + + + +// ResourceOp +// ______________________________ +// * resources>: Resources +// + getResources(): umap +// + total_cost(): int +// + op_cost(Operator): int +// + has_op(Operator): bool + +class ResourceOp { + std::unordered_map resources; + size_t total = 0; + +public: + ResourceOp() = default; + explicit ResourceOp(const std::unordered_map& resources) : resources(resources) {} + + const std::unordered_map& getResources() const { + return resources; + } + + size_t total_cost() { + if (total == 0 && !resources.empty()) { + for (const auto& pair : resources) { + total += pair.second; + } + } + return total; + } + + size_t op_cost(const Operator& op) const { + auto it = resources.find(op); + return (it != resources.end()) ? it->second : 0; + } + + bool has_op(const Operator& op) const { + return resources.find(op) != resources.end(); + } +}; + + + +// RuleRefOp +// _________________________________________ +// * Op: Operator +// * Resource: Resources +// * RuleRef: Pointer to the rule + +class RuleRefOp { + Operator op; + ResourceOp resources; + std::string rule_ref; + +public: + RuleRefOp(const Operator& op, const ResourceOp& resources, const std::string& rule_ref) + : op(op), resources(resources), rule_ref(rule_ref) {} + + const Operator getOperator() const { return op; } + const ResourceOp getResources() const { return resources; } + const std::string getRuleRef() const { return rule_ref; } +}; + + + +// Solver +// _________________________________________ +// * Ops>: Operators +// * Gateset>: Operators +// * Rules>: Rules +// + graph(): void +// + show(): stdout +// + solve(): map + +class Solver { + std::vector ops; + std::vector gateset; + std::vector rules; + +public: + Solver(const std::vector& ops, + const std::vector& gateset, + const std::vector& rules) + : ops(ops), gateset(gateset), rules(rules) {} + + void graph() { + // Placeholder for graph generation logic + std::cout << "Graph generation not implemented.\n"; + } + + std::unordered_map solve() { + std::unordered_map solutions; + + for (const auto& rule: rules) { + solutions[rule.getOperator()] = rule.getRuleRef(); + } + return solutions; + } + + void show() { + std::cout << "Not implemented.\n"; + } +}; + +// ---------------------------- +// Simple Tests +// ---------------------------- + +void test_operator() { + Operator op1("H"); + Operator op2("X"); + Operator op3("H"); + + assert(op1.getName() == "H"); + assert(op2.getName() == "X"); + assert(!(op1 == op2)); + assert(op1 != op2); + assert(op1 == op3); + + std::cout << "[PASS] Operator tests" << std::endl; +} + +void test_resourceop() { + Operator op1("H"); + Operator op2("X"); + + std::unordered_map res{{op1, 3}, {op2, 5}}; + ResourceOp r(res); + + assert(r.total_cost() == 8); + assert(r.op_cost(op1) == 3); + assert(r.op_cost(op2) == 5); + assert(r.op_cost(Operator("Z")) == 0); + assert(r.has_op(op1)); + assert(!r.has_op(Operator("Z"))); + + std::cout << "[PASS] ResourceOp tests" << std::endl; +} + + +void test_rulerefop() { + Operator op("CX"); + ResourceOp r({{op, 2}}); + RuleRefOp rr(op, r, "rule1"); + + assert(rr.getOperator() == op); + assert(rr.getRuleRef() == "rule1"); + assert(rr.getResources().op_cost(op) == 2); + + std::cout << "[PASS] RuleRefOp tests" << std::endl; +} + +void test_solver() { + Operator op1("H"), op2("X"); + ResourceOp r1({{op1, 1}}); + ResourceOp r2({{op2, 2}}); + + RuleRefOp rr1(op1, r1, "ruleH"); + RuleRefOp rr2(op2, r2, "ruleX"); + + Solver solver({op1, op2}, {"H", "X"}, {rr1, rr2}); + // solver.graph(); + // solver.show(); + + auto solution = solver.solve(); + assert(solution.at(op1) == "ruleH"); + assert(solution.at(op2) == "ruleX"); + + std::cout << "[PASS] Solver tests" << std::endl; +} + +int main() { + test_operator(); + test_resourceop(); + test_rulerefop(); + test_solver(); + + std::cout << "All tests passed!" << std::endl; + return 0; +} + From 4a170dea95292f113c033374d13a66bfb68dc820 Mon Sep 17 00:00:00 2001 From: Ali Asadi <10773383+maliasadi@users.noreply.github.com> Date: Thu, 25 Sep 2025 12:32:16 -0400 Subject: [PATCH 33/36] Add Dijkstra --- delightning/src/main.cpp | 164 ++++++++++++++++++++++++++++++++------- 1 file changed, 138 insertions(+), 26 deletions(-) diff --git a/delightning/src/main.cpp b/delightning/src/main.cpp index f18e4d76bc..1865a191b4 100644 --- a/delightning/src/main.cpp +++ b/delightning/src/main.cpp @@ -4,7 +4,8 @@ #include #include #include - +#include +#include // Operator // ______________________________ @@ -96,9 +97,12 @@ class RuleRefOp { RuleRefOp(const Operator& op, const ResourceOp& resources, const std::string& rule_ref) : op(op), resources(resources), rule_ref(rule_ref) {} - const Operator getOperator() const { return op; } - const ResourceOp getResources() const { return resources; } - const std::string getRuleRef() const { return rule_ref; } + const Operator& getOperator() const { return op; } + const std::string& getRuleRef() const { return rule_ref; } + + // TODO: make this const ref to avoid copy overhead + // this is currently required for my simple Dijkstra sovler + ResourceOp getResources() const { return resources; } }; @@ -110,35 +114,90 @@ class RuleRefOp { // * Rules>: Rules // + graph(): void // + show(): stdout -// + solve(): map +// + solve(): map class Solver { std::vector ops; std::vector gateset; std::vector rules; + // Cached solutions map + std::unordered_map solutions; + public: Solver(const std::vector& ops, const std::vector& gateset, const std::vector& rules) : ops(ops), gateset(gateset), rules(rules) {} - void graph() { - // Placeholder for graph generation logic - std::cout << "Graph generation not implemented.\n"; - } std::unordered_map solve() { - std::unordered_map solutions; + if (!solutions.empty()) { + return solutions; + } - for (const auto& rule: rules) { - solutions[rule.getOperator()] = rule.getRuleRef(); + // We need to create a distance map for our Dijkstra's algorithm + // For now, I keep everything simple starting with max distance + // TODO: do this part implicitly for performance + std::unordered_map distances; + for (const auto& op : ops) { + distances[op] = std::numeric_limits::max(); } + + // There are different ways to implement Dijkstra's algorithm + // Here, I use a simple priority queue for demonstration + // TODO: optimize with a better priority queue or min-heap + using QElement = std::pair; // (distance, operator) + auto cmp = [](const QElement& left, const QElement& right) { return left.first > right.first; }; + std::priority_queue, decltype(cmp)> queue(cmp); + + // Initialize the queue with all operators and distance 0 + for (const auto& op : ops) { + queue.push({0, op}); + } + + // Dijkstra's algorithm main loop + while (!queue.empty()) { + auto [current_distance, current_op] = queue.top(); + queue.pop(); + // If we found a better path, skip processing + if (current_distance > distances[current_op]) { + continue; + } + + // std::cerr << "[DEBUG] Exploring neighbors of operator: " + // << current_op.getName() << "\n"; + + // Explore neighbors :) + for (const auto& rule: rules) { + if (rule.getOperator() == current_op) { + // std::cerr << "[DEBUG] Found applicable rule: " << rule.getRuleRef() + // << " for operator: " << current_op.getName() + // << " with total cost: " << rule.getResources().total_cost() + // << "\n"; + size_t new_distance = current_distance + rule.getResources().total_cost(); + if (new_distance < distances[current_op]) { + // std::cerr << "[DEBUG] Updating distance for operator: " << current_op.getName() + // << " from " << distances[current_op] + // << " to " << new_distance << "\n"; + distances[current_op] = new_distance; + queue.push({new_distance, current_op}); + + // Update solution + solutions[current_op] = rule.getRuleRef(); + } + } + } + } + return solutions; } void show() { - std::cout << "Not implemented.\n"; + for (const auto& [op, rule] : solutions) { + std::cout << "Operator " << op.getName() + << " decomposed using rule: " << rule << "\n"; + } } }; @@ -190,30 +249,83 @@ void test_rulerefop() { std::cout << "[PASS] RuleRefOp tests" << std::endl; } -void test_solver() { - Operator op1("H"), op2("X"); - ResourceOp r1({{op1, 1}}); - ResourceOp r2({{op2, 2}}); +void test_solver1() { + + Operator cnot("CNOT"); + Operator cz("CZ"); + Operator h("H"); + + ResourceOp cz_to_cnot({{cnot, 1}, {h, 2}}); + ResourceOp h_self({{h, 1}}); - RuleRefOp rr1(op1, r1, "ruleH"); - RuleRefOp rr2(op2, r2, "ruleX"); + RuleRefOp rule1(cz, cz_to_cnot, "cz_decomp_rule"); + RuleRefOp rule2(h, h_self, "h_rule"); - Solver solver({op1, op2}, {"H", "X"}, {rr1, rr2}); - // solver.graph(); + std::vector ops = {cz, h}; + std::vector gateset = {"CNOT", "H"}; + std::vector rules = {rule1, rule2}; + + Solver solver(ops, gateset, rules); + auto solutions = solver.solve(); + assert(solutions.size() == 2); + assert(solutions[cz] == "cz_decomp_rule"); + assert(solutions[h] == "h_rule"); // solver.show(); - auto solution = solver.solve(); - assert(solution.at(op1) == "ruleH"); - assert(solution.at(op2) == "ruleX"); + std::cout << "[PASS] Solver tests (1)" << std::endl; +} + +void test_solver2() { + + Operator cz("CZ"); + Operator cnot("CNOT"); + Operator h("H"); + Operator rz("RZ"); + Operator rx("RX"); + + ResourceOp cz_to_h_cnot({{h, 1}, {cnot, 1}}); + RuleRefOp rule1(cz, cz_to_h_cnot, "cz_h_cnot_rule"); + + ResourceOp cz_to_rx_rz_cnot({{rx, 1}, {rz, 1}, {cnot, 1}}); + RuleRefOp rule2(cz, cz_to_rx_rz_cnot, "cz_rx_rz_cnot_rule"); + + ResourceOp h_to_rz_rx_rz({{rz, 2}, {rx, 1}}); + RuleRefOp rule3(h, h_to_rz_rx_rz, "h_rz_rx_rz_rule"); + + ResourceOp h_to_rz_rz({{rz, 2}}); + RuleRefOp rule4(h, h_to_rz_rz, "h_rz_rz_rule"); - std::cout << "[PASS] Solver tests" << std::endl; + ResourceOp rz_self({{rz, 1}}); + RuleRefOp rule5(rz, rz_self, "rz_rule"); + + ResourceOp rx_self({{rx, 1}}); + RuleRefOp rule6(rx, rx_self, "rx_rule"); + + ResourceOp cnot_self({{cnot, 1}}); + RuleRefOp rule7(cnot, cnot_self, "cnot_rule"); + + std::vector ops = {cz, h}; + std::vector gateset = {"CNOT", "RZ", "RX"}; + std::vector rules = {rule1, rule2, rule3, rule4, rule5, rule6, rule7}; + + Solver solver(ops, gateset, rules); + + auto solutions = solver.solve(); + assert(solutions.size() == 2); + assert(solutions[cz] == "cz_h_cnot_rule"); + assert(solutions[h] == "h_rz_rz_rule"); + // solver.show(); + + std::cout << "[PASS] Solver tests (2)" << std::endl; } + int main() { test_operator(); test_resourceop(); test_rulerefop(); - test_solver(); + test_solver1(); + test_solver2(); std::cout << "All tests passed!" << std::endl; return 0; From 16f2107caf1b4d44ba82ff37cb707e8bab7e9a9f Mon Sep 17 00:00:00 2001 From: Ali Asadi <10773383+maliasadi@users.noreply.github.com> Date: Thu, 25 Sep 2025 18:13:23 -0400 Subject: [PATCH 34/36] Update solver --- delightning/Makefile | 2 +- delightning/src/main.cpp | 223 ++++++++++++++++++++++++++++++++++----- 2 files changed, 199 insertions(+), 26 deletions(-) diff --git a/delightning/Makefile b/delightning/Makefile index 7b82711c13..1d338ddcd9 100644 --- a/delightning/Makefile +++ b/delightning/Makefile @@ -1,5 +1,5 @@ CXX=g++ -CXXFLAGS=-std=c++17 -Wall -Wextra -O2 +CXXFLAGS=-std=c++20 -Wall -Wextra -O2 TARGET=delightning SRCS=src/main.cpp diff --git a/delightning/src/main.cpp b/delightning/src/main.cpp index 1865a191b4..925a70fd41 100644 --- a/delightning/src/main.cpp +++ b/delightning/src/main.cpp @@ -6,6 +6,7 @@ #include #include #include +#include // Operator // ______________________________ @@ -110,28 +111,116 @@ class RuleRefOp { // Solver // _________________________________________ // * Ops>: Operators -// * Gateset>: Operators +// * Gateset>: Operators // * Rules>: Rules // + graph(): void // + show(): stdout // + solve(): map class Solver { +private: std::vector ops; - std::vector gateset; + std::vector gateset; std::vector rules; - // Cached solutions map + // Cached solutions and distances maps std::unordered_map solutions; + std::unordered_map distances; + + size_t computeCost(const RuleRefOp& rule) const { + size_t total = 0; + // std::cerr << "[DEBUG] Computing cost for rule: " << rule.getRuleRef() << "\n"; + for (const auto& [dep, count] : rule.getResources().getResources()) { + auto it = distances.find(dep); + // std::cerr << "[DEBUG] Dependency: " << dep.getName() + // << " with count: " << count + // << " and cost: " << it->second << "\n"; + if (it == distances.end() || it->second == std::numeric_limits::max()) { + // Dependency not found or unreachable yet :( + return std::numeric_limits::max(); + } + total += count * it->second; + } + return total; + } + + auto initGraph() { + using NodeOp = std::pair; + auto cmp = [](const NodeOp& left, const NodeOp& right) { + return left.first > right.first; + }; + std::priority_queue, decltype(cmp)> queue(cmp); + + for (const auto& op : ops) { + distances[op] = std::numeric_limits::max(); + queue.push({distances[op], op}); + } + + for (const auto& g : gateset) { + distances[g] = 1; + queue.push({1, g}); + solutions[g] = "base_op"; + } + return queue; + } public: Solver(const std::vector& ops, - const std::vector& gateset, + const std::vector& gateset, const std::vector& rules) : ops(ops), gateset(gateset), rules(rules) {} + bool isBasisGate(const Operator& op) const { + return std::find(gateset.begin(), gateset.end(), op) != gateset.end(); + } std::unordered_map solve() { + auto queue = initGraph(); + + while (!queue.empty()) { + auto [current_distance, current_op] = queue.top(); + queue.pop(); + + // If we found a better path, skip processing + if (current_distance > distances[current_op]) { + continue; + } + + // std::cerr << "[DEBUG] Exploring neighbors of operator: " + // << current_op.getName() << "\n"; + + // Explore neighbors :) + for (const auto& rule : rules) { + // std::cerr << "[DEBUG] Considering rule: " << rule.getRuleRef() + // << " for operator: " << rule.getOperator().getName() + // << " with total cost: " << rule.getResources().total_cost() + // << "\n"; + if (rule.getOperator() != current_op) { + continue; + } + + // std::cerr << "[DEBUG] Found applicable rule: " << rule.getRuleRef() + // << " for operator: " << current_op.getName() << "\n"; + size_t new_distance = computeCost(rule); + + // std::cerr << "[DEBUG] New computed distance for operator: " + // << current_op.getName() << " is " << new_distance << "\n"; + + if (new_distance < distances[current_op]) { + distances[current_op] = new_distance; + queue.push({new_distance, current_op}); + solutions[current_op] = rule.getRuleRef(); + // std::cerr << "[DEBUG] Updating distance for operator: " + // << current_op.getName() << " to " << new_distance << "\n"; + } + } + } + + return solutions; + } + + // For testing purposes (my first try) + std::unordered_map simple_solver() { if (!solutions.empty()) { return solutions; } @@ -193,6 +282,7 @@ class Solver { return solutions; } + void show() { for (const auto& [op, rule] : solutions) { std::cout << "Operator " << op.getName() @@ -262,15 +352,15 @@ void test_solver1() { RuleRefOp rule2(h, h_self, "h_rule"); std::vector ops = {cz, h}; - std::vector gateset = {"CNOT", "H"}; + std::vector gateset = {cnot, h}; std::vector rules = {rule1, rule2}; Solver solver(ops, gateset, rules); auto solutions = solver.solve(); - assert(solutions.size() == 2); + // solver.show(); + assert(solutions.size() == 3); assert(solutions[cz] == "cz_decomp_rule"); assert(solutions[h] == "h_rule"); - // solver.show(); std::cout << "[PASS] Solver tests (1)" << std::endl; } @@ -289,43 +379,126 @@ void test_solver2() { ResourceOp cz_to_rx_rz_cnot({{rx, 1}, {rz, 1}, {cnot, 1}}); RuleRefOp rule2(cz, cz_to_rx_rz_cnot, "cz_rx_rz_cnot_rule"); - ResourceOp h_to_rz_rx_rz({{rz, 2}, {rx, 1}}); - RuleRefOp rule3(h, h_to_rz_rx_rz, "h_rz_rx_rz_rule"); - ResourceOp h_to_rz_rz({{rz, 2}}); - RuleRefOp rule4(h, h_to_rz_rz, "h_rz_rz_rule"); - - ResourceOp rz_self({{rz, 1}}); - RuleRefOp rule5(rz, rz_self, "rz_rule"); + RuleRefOp rule3(h, h_to_rz_rz, "h_rz_rz_rule"); - ResourceOp rx_self({{rx, 1}}); - RuleRefOp rule6(rx, rx_self, "rx_rule"); - - ResourceOp cnot_self({{cnot, 1}}); - RuleRefOp rule7(cnot, cnot_self, "cnot_rule"); + ResourceOp h_to_rz_rx_rz({{rz, 2}, {rx, 1}}); + RuleRefOp rule4(h, h_to_rz_rx_rz, "h_rz_rx_rz_rule"); - std::vector ops = {cz, h}; - std::vector gateset = {"CNOT", "RZ", "RX"}; - std::vector rules = {rule1, rule2, rule3, rule4, rule5, rule6, rule7}; + std::vector ops = {h, cz}; + std::vector gateset = {cnot, rz, rx}; + std::vector rules = {rule1, rule2, rule3, rule4}; Solver solver(ops, gateset, rules); - auto solutions = solver.solve(); - assert(solutions.size() == 2); + // solver.show(); + assert(solutions.size() == 5); assert(solutions[cz] == "cz_h_cnot_rule"); assert(solutions[h] == "h_rz_rz_rule"); - // solver.show(); + assert(solutions[rz] == "base_op"); + assert(solutions[rx] == "base_op"); + assert(solutions[cnot] == "base_op"); std::cout << "[PASS] Solver tests (2)" << std::endl; } +void test_solver3() { + // Define Operators + Operator single_exc("SingleExcitation"); + Operator single_exc_plus("SingleExcitationPlus"); + Operator double_exc("DoubleExcitation"); + Operator cry("CRY"); + Operator s("S"); + Operator phase("PhaseShift"); + Operator rz("RZ"); + Operator rx("RX"); + Operator ry("RY"); + Operator rot("Rot"); + Operator hadamard("Hadamard"); + Operator cnot("CNOT"); + Operator cy("CY"); + Operator t("T"); + Operator global_phase("GlobalPhase"); + Operator phaseshift("PhaseShift"); + + // ('SingleExcitation', {H:2, CNOT:2, RY:2}, _single_excitation_decomp) + ResourceOp res_single_exc({{hadamard, 2}, {cnot, 2}, {ry, 2}}); + RuleRefOp rule_single_exc(single_exc, res_single_exc, "_single_excitation_decomp"); + + // ('SingleExcitationPlus', {H:2, CY:1, CNOT:2, RY:2, S:1, RZ:1, GlobalPhase:1}, _single_excitation_plus_decomp) + ResourceOp res_single_exc_plus({ + {hadamard, 2}, {cy, 1}, {cnot, 2}, {ry, 2}, + {s, 1}, {rz, 1}, {global_phase, 1}}); + RuleRefOp rule_single_exc_plus(single_exc_plus, res_single_exc_plus, "_single_excitation_plus_decomp"); + + // ('DoubleExcitation', {CNOT:14, H:6, RY:8}, _doublexcit) + ResourceOp res_double_exc1({{cnot, 14}, {hadamard, 6}, {ry, 8}}); + RuleRefOp rule_double_exc1(double_exc, res_double_exc1, "_doublexcit"); + + // ('CRY', {RY:2, CNOT:2}, _cry) + ResourceOp res_cry({{ry, 2}, {cnot, 2}}); + RuleRefOp rule_cry(cry, res_cry, "_cry"); + + // ('S', {PhaseShift:1}, _s_phaseshift) + ResourceOp res_s1({{phase, 1}}); + RuleRefOp rule_s1(s, res_s1, "_s_phaseshift"); + + // ('S', {T:1}, _s_to_t) + ResourceOp res_s2({{t, 1}}); + RuleRefOp rule_s2(s, res_s2, "_s_to_t"); + + // ('PhaseShift', {RZ:1, GlobalPhase:1}, _phaseshift_to_rz_gp) + ResourceOp res_phase({{rz, 1}, {global_phase, 1}}); + RuleRefOp rule_phase(phase, res_phase, "_phaseshift_to_rz_gp"); + + // ('RZ', {Rot:1}, _rz_to_rot) + ResourceOp res_rz1({{rot, 1}}); + RuleRefOp rule_rz1(rz, res_rz1, "_rz_to_rot"); + + // ('RZ', {RY:2, RX:1}, _rz_to_ry_rx) + ResourceOp res_rz2({{ry, 2}, {rx, 1}}); + RuleRefOp rule_rz2(rz, res_rz2, "_rz_to_ry_rx"); + + // ('Rot', {RZ:2, RY:1}, _rot_to_rz_ry_rz) + ResourceOp res_rot({{rz, 2}, {ry, 1}}); + RuleRefOp rule_rot(rot, res_rot, "_rot_to_rz_ry_rz"); + + + std::vector ops = {single_exc, single_exc_plus, double_exc}; + std::vector gateset = {ry, rx, cnot, hadamard, global_phase}; + std::vector rules = { + rule_single_exc, rule_single_exc_plus, + rule_double_exc1, + rule_cry, rule_s1, rule_s2, + rule_phase, rule_rz1, rule_rz2, + rule_rot + }; + + Solver solver(ops, gateset, rules); + auto solutions = solver.solve(); + // solver.show(); + assert(solutions.size() == 8); + assert(solutions[single_exc] == "_single_excitation_decomp"); + assert(solutions[single_exc_plus] == "_single_excitation_plus_decomp"); + assert(solutions[double_exc] == "_doublexcit"); + assert(solutions[ry] == "base_op"); + assert(solutions[rx] == "base_op"); + assert(solutions[cnot] == "base_op"); + assert(solutions[hadamard] == "base_op"); + assert(solutions[global_phase] == "base_op"); + + std::cout << "[PASS] Solver tests (3)" << std::endl; +} + + int main() { test_operator(); test_resourceop(); test_rulerefop(); test_solver1(); test_solver2(); + test_solver3(); std::cout << "All tests passed!" << std::endl; return 0; From f4f6a4e1953d0344438b25de4e2bbd49941013dd Mon Sep 17 00:00:00 2001 From: Ali Asadi <10773383+maliasadi@users.noreply.github.com> Date: Thu, 25 Sep 2025 18:43:29 -0400 Subject: [PATCH 35/36] Add a simple parser --- delightning/src/main.cpp | 169 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 169 insertions(+) diff --git a/delightning/src/main.cpp b/delightning/src/main.cpp index 925a70fd41..ed7f8157c0 100644 --- a/delightning/src/main.cpp +++ b/delightning/src/main.cpp @@ -7,6 +7,9 @@ #include #include #include +#include +#include + // Operator // ______________________________ @@ -291,6 +294,22 @@ class Solver { } }; + +auto parse_quantum_custom_ops(const std::string& mlir_code) { + std::unordered_set ops; + + std::regex pattern(R"(quantum\.custom\s+\"([A-Za-z0-9_\.]+)\")"); + + std::smatch matches; + std::string::const_iterator search_start(mlir_code.cbegin()); + while (std::regex_search(search_start, mlir_code.cend(), matches, pattern)) { + ops.emplace(matches[1].str()); + search_start = matches.suffix().first; + } + + return ops; +} + // ---------------------------- // Simple Tests // ---------------------------- @@ -491,6 +510,155 @@ void test_solver3() { std::cout << "[PASS] Solver tests (3)" << std::endl; } +void test_solver4() { + std::string mlir_code = R"( + func.func public @circuit_15() -> tensor attributes {decompose_gatesets = [["GlobalPhase", "RY", "Hadamard", "CNOT", "RX"]], diff_method = "adjoint", llvm.linkage = #llvm.linkage, qnode} { +       %cst = arith.constant 1.250000e-01 : f64 +       %cst_0 = arith.constant -1.250000e-01 : f64 +       %cst_1 = arith.constant -2.500000e-01 : f64 +       %cst_2 = arith.constant 2.500000e-01 : f64 +       %cst_3 = arith.constant 5.000000e-01 : f64 +       %c0_i64 = arith.constant 0 : i64 +       quantum.device shots(%c0_i64) ["/home/ali/miniforge3/envs/decomp/lib/python3.12/site-packages/pennylane_lightning/liblightning_qubit_catalyst.so", "LightningSimulator", "{'mcmc': False, 'num_burnin': 0, 'kernel_name': None}"] +       %0 = quantum.alloc( 4) : !quantum.reg +       %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit +       %2 = quantum.extract %0[ 1] : !quantum.reg -> !quantum.bit +       %out_qubits:2 = quantum.custom "SingleExcitation"(%cst_3) %1, %2 : !quantum.bit, !quantum.bit +       %out_qubits_4 = quantum.custom "Hadamard"() %out_qubits#1 : !quantum.bit +       %out_qubits_5:2 = quantum.custom "CNOT"() %out_qubits_4, %out_qubits#0 : !quantum.bit, !quantum.bit +       %out_qubits_6 = quantum.custom "RY"(%cst_2) %out_qubits_5#1 : !quantum.bit +       %out_qubits_7 = quantum.custom "RY"(%cst_2) %out_qubits_5#0 : !quantum.bit +       %out_qubits_8:2 = quantum.custom "CY"() %out_qubits_7, %out_qubits_6 : !quantum.bit, !quantum.bit +       %out_qubits_9 = quantum.custom "S"() %out_qubits_8#0 : !quantum.bit +       %out_qubits_10 = quantum.custom "Hadamard"() %out_qubits_9 : !quantum.bit +       %out_qubits_11 = quantum.custom "RZ"(%cst_1) %out_qubits_10 : !quantum.bit +       %out_qubits_12:2 = quantum.custom "CNOT"() %out_qubits_8#1, %out_qubits_11 : !quantum.bit, !quantum.bit +       quantum.gphase(%cst_0) : +       %out_qubits_13 = quantum.custom "Hadamard"() %out_qubits_12#1 : !quantum.bit +       %out_qubits_14:2 = quantum.custom "CNOT"() %out_qubits_13, %out_qubits_12#0 : !quantum.bit, !quantum.bit +       %out_qubits_15 = quantum.custom "RY"(%cst_2) %out_qubits_14#1 : !quantum.bit +       %out_qubits_16 = quantum.custom "RY"(%cst_2) %out_qubits_14#0 : !quantum.bit +       %out_qubits_17:2 = quantum.custom "CY"() %out_qubits_16, %out_qubits_15 : !quantum.bit, !quantum.bit +       %out_qubits_18 = quantum.custom "S"() %out_qubits_17#0 : !quantum.bit +       %out_qubits_19 = quantum.custom "Hadamard"() %out_qubits_18 : !quantum.bit +       %out_qubits_20 = quantum.custom "RZ"(%cst_2) %out_qubits_19 : !quantum.bit +       %out_qubits_21:2 = quantum.custom "CNOT"() %out_qubits_17#1, %out_qubits_20 : !quantum.bit, !quantum.bit +       quantum.gphase(%cst) : +       %3 = quantum.extract %0[ 2] : !quantum.reg -> !quantum.bit +       %4 = quantum.extract %0[ 3] : !quantum.reg -> !quantum.bit +       %out_qubits_22:4 = quantum.custom "DoubleExcitation"(%cst_3) %out_qubits_21#0, %out_qubits_21#1, %3, %4 : !quantum.bit, !quantum.bit, !quantum.bit, !quantum.bit +       %5 = quantum.insert %0[ 0], %out_qubits_22#0 : !quantum.reg, !quantum.bit +       %6 = quantum.insert %5[ 1], %out_qubits_22#1 : !quantum.reg, !quantum.bit +       %7 = quantum.insert %6[ 2], %out_qubits_22#2 : !quantum.reg, !quantum.bit +       %8 = quantum.insert %7[ 3], %out_qubits_22#3 : !quantum.reg, !quantum.bit +       %9 = quantum.extract %8[ 0] : !quantum.reg -> !quantum.bit +       %10 = quantum.namedobs %9[ PauliZ] : !quantum.obs +       %11 = quantum.expval %10 : f64 +       %from_elements = tensor.from_elements %11 : tensor +       %12 = quantum.insert %8[ 0], %9 : !quantum.reg, !quantum.bit +       quantum.dealloc %12 : !quantum.reg +       quantum.device_release +       return %from_elements : tensor +     } + )"; + + auto parsed_ops = parse_quantum_custom_ops(mlir_code); + + // std::cout << "Parsed quantum.custom operations:" << std::endl; + // for (const auto& op : parsed_ops) { + // std::cout << op.getName() << std::endl; + // } + + // Define Operators + Operator single_exc("SingleExcitation"); + Operator single_exc_plus("SingleExcitationPlus"); + Operator double_exc("DoubleExcitation"); + Operator cry("CRY"); + Operator s("S"); + Operator phase("PhaseShift"); + Operator rz("RZ"); + Operator rx("RX"); + Operator ry("RY"); + Operator rot("Rot"); + Operator hadamard("Hadamard"); + Operator cnot("CNOT"); + Operator cy("CY"); + Operator t("T"); + Operator global_phase("GlobalPhase"); + Operator phaseshift("PhaseShift"); + + // ('SingleExcitation', {H:2, CNOT:2, RY:2}, _single_excitation_decomp) + ResourceOp res_single_exc({{hadamard, 2}, {cnot, 2}, {ry, 2}}); + RuleRefOp rule_single_exc(single_exc, res_single_exc, "_single_excitation_decomp"); + + // ('SingleExcitationPlus', {H:2, CY:1, CNOT:2, RY:2, S:1, RZ:1, GlobalPhase:1}, _single_excitation_plus_decomp) + ResourceOp res_single_exc_plus({ + {hadamard, 2}, {cy, 1}, {cnot, 2}, {ry, 2}, + {s, 1}, {rz, 1}, {global_phase, 1}}); + RuleRefOp rule_single_exc_plus(single_exc_plus, res_single_exc_plus, "_single_excitation_plus_decomp"); + + // ('DoubleExcitation', {CNOT:14, H:6, RY:8}, _doublexcit) + ResourceOp res_double_exc1({{cnot, 14}, {hadamard, 6}, {ry, 8}}); + RuleRefOp rule_double_exc1(double_exc, res_double_exc1, "_doublexcit"); + + // ('CRY', {RY:2, CNOT:2}, _cry) + ResourceOp res_cry({{ry, 2}, {cnot, 2}}); + RuleRefOp rule_cry(cry, res_cry, "_cry"); + + // ('S', {PhaseShift:1}, _s_phaseshift) + ResourceOp res_s1({{phase, 1}}); + RuleRefOp rule_s1(s, res_s1, "_s_phaseshift"); + + // ('S', {T:1}, _s_to_t) + ResourceOp res_s2({{t, 1}}); + RuleRefOp rule_s2(s, res_s2, "_s_to_t"); + + // ('PhaseShift', {RZ:1, GlobalPhase:1}, _phaseshift_to_rz_gp) + ResourceOp res_phase({{rz, 1}, {global_phase, 1}}); + RuleRefOp rule_phase(phase, res_phase, "_phaseshift_to_rz_gp"); + + // ('RZ', {Rot:1}, _rz_to_rot) + ResourceOp res_rz1({{rot, 1}}); + RuleRefOp rule_rz1(rz, res_rz1, "_rz_to_rot"); + + // ('RZ', {RY:2, RX:1}, _rz_to_ry_rx) + ResourceOp res_rz2({{ry, 2}, {rx, 1}}); + RuleRefOp rule_rz2(rz, res_rz2, "_rz_to_ry_rx"); + + // ('Rot', {RZ:2, RY:1}, _rot_to_rz_ry_rz) + ResourceOp res_rot({{rz, 2}, {ry, 1}}); + RuleRefOp rule_rot(rot, res_rot, "_rot_to_rz_ry_rz"); + + + std::vector ops(parsed_ops.begin(), parsed_ops.end()); + std::vector gateset = {ry, rx, cnot, hadamard, global_phase}; + std::vector rules = { + rule_single_exc, rule_single_exc_plus, + rule_double_exc1, + rule_cry, rule_s1, rule_s2, + rule_phase, rule_rz1, rule_rz2, + rule_rot + }; + + Solver solver(ops, gateset, rules); + auto solutions = solver.solve(); + // solver.show(); + + assert(solutions.size() == 9); + assert(solutions[single_exc] == "_single_excitation_decomp"); + assert(solutions[double_exc] == "_doublexcit"); + assert(solutions[rz] == "_rz_to_rot"); + assert(solutions[s] == "_s_phaseshift"); + assert(solutions[global_phase] == "base_op"); + assert(solutions[hadamard] == "base_op"); + assert(solutions[ry] == "base_op"); + assert(solutions[rx] == "base_op"); + assert(solutions[cnot] == "base_op"); + + std::cout << "[PASS] Solver tests (4)" << std::endl; + +} + int main() { test_operator(); @@ -499,6 +667,7 @@ int main() { test_solver1(); test_solver2(); test_solver3(); + test_solver4(); std::cout << "All tests passed!" << std::endl; return 0; From b000433aa7f22600798ddb83d7c6d01aa1cb1839 Mon Sep 17 00:00:00 2001 From: Ali Asadi <10773383+maliasadi@users.noreply.github.com> Date: Fri, 26 Sep 2025 08:22:49 -0400 Subject: [PATCH 36/36] Update solver w/ graph --- delightning/src/main.cpp | 166 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 163 insertions(+), 3 deletions(-) diff --git a/delightning/src/main.cpp b/delightning/src/main.cpp index ed7f8157c0..6fb85511fc 100644 --- a/delightning/src/main.cpp +++ b/delightning/src/main.cpp @@ -111,7 +111,8 @@ class RuleRefOp { -// Solver +// BasicSolver <- experimental! +// Check PLSolver following PL's implementation // _________________________________________ // * Ops>: Operators // * Gateset>: Operators @@ -120,7 +121,7 @@ class RuleRefOp { // + show(): stdout // + solve(): map -class Solver { +class BasicSolver { private: std::vector ops; std::vector gateset; @@ -168,7 +169,7 @@ class Solver { } public: - Solver(const std::vector& ops, + BasicSolver(const std::vector& ops, const std::vector& gateset, const std::vector& rules) : ops(ops), gateset(gateset), rules(rules) {} @@ -295,6 +296,165 @@ class Solver { }; +// PLSolver w/ Operator and Rule Nodes + +enum class NodeType { + OPERATOR, + RULE +}; + +struct Node { + NodeType type; + Operator op; + RuleRefOp rule; + size_t index; +}; + +struct Edge { + size_t target; + size_t weight; +}; + +class Graph { +private: + std::vector nodes; + std::vector> adjList; + +public: + Graph() = default; + + size_t addNode(const Node& node) { + const size_t idx = nodes.size(); + nodes.push_back(node); + adjList.emplace_back(); + return idx; + } + + void addEdge(size_t from, size_t to, size_t weight) { + adjList[from].push_back({to, weight}); + } + + const Node& getNode(size_t index) const { + return nodes[index]; + } + + size_t size() const { + return nodes.size(); + } + + const std::vector& getNeighbors(size_t index) const { + return adjList[index]; + } +}; + + +Graph buildGraph( + const std::vector& ops, + const std::vector& gateset, + const std::vector& rules) +{ + Graph graph; + std::unordered_map opNodes; + + // Create Operator nodes + for (const auto& op: ops) { + size_t idx = graph.addNode({NodeType::OPERATOR, op, RuleRefOp(op, {}, ""), 0}); + opNodes[op] = idx; + } + + for (const auto &op: gateset) { + size_t idx = graph.addNode({NodeType::OPERATOR, op, RuleRefOp(op, {}, ""), 0}); + opNodes[op] = idx; + } + + // Create Rule nodes and edges + for (const auto& rule: rules) { + size_t ruleIdx = graph.addNode({NodeType::RULE, {}, rule, 0}); + auto op = rule.getOperator(); + size_t opIdx = opNodes[op]; + + // Op -> Rule edge + graph.addEdge(opIdx, ruleIdx, 0); + + // Rule -> deps edges + for (const auto &[dep, count] : rule.getResources().getResources()) { + if (!opNodes.count(dep)) { + size_t depIdx = graph.addNode({NodeType::OPERATOR, dep, RuleRefOp(dep, {}, ""), 0}); + opNodes[dep] = depIdx; + } + graph.addEdge(ruleIdx, opNodes[dep], count); + } + + } + + return graph; +} + + +std::unordered_map +solveGraph(Graph& graph) { + using ElemPair = std::pair; // (distance, nodeIndex) + auto cmp = [](const ElemPair& a, const ElemPair& b) { return a.first > b.first; }; + std::priority_queue, decltype(cmp)> queue(cmp); + + std::vector dist(graph.size(), std::numeric_limits::max()); + std::unordered_map solutions; + + // Start with gateset operators = cost 0 + for (size_t i = 0; i < graph.size(); i++) { + auto& node = graph.getNode(i); + if (node.type == NodeType::OPERATOR && dist[i] == std::numeric_limits::max()) { + // Basis gate → distance 0 + if (solutions.count(node.op) == 0) { + dist[i] = 0; + queue.push({0, i}); + } + } + } + + while (!queue.empty()) { + auto [curDist, u] = queue.top(); + queue.pop(); + + if (curDist > dist[u]) continue; + + auto& uNode = graph.getNode(u); + + // Explore neighbors + for (auto& edge : graph.getNeighbors(u)) { + auto& vNode = graph.getNode(edge.target); + + size_t newDist = 0; + if (uNode.type == NodeType::OPERATOR && vNode.type == NodeType::RULE) { + // Operator → Rule: defer cost to expansion + newDist = curDist; + } else if (uNode.type == NodeType::RULE && vNode.type == NodeType::OPERATOR) { + // Rule → Operator: accumulate resource counts + size_t count = uNode.rule.getResources().op_cost(vNode.op); + newDist = curDist + count * dist[edge.target]; + } else { + continue; + } + + if (newDist < dist[edge.target]) { + dist[edge.target] = newDist; + queue.push({newDist, edge.target}); + + // If we reached an operator from a rule, record the chosen rule + if (vNode.type == NodeType::OPERATOR && uNode.type == NodeType::RULE) { + solutions[vNode.op] = uNode.rule.getRuleRef(); + } + } + } + } + + return solutions; +} + +// ---------------------------- +// MLIR Parser for quantum.custom ops +// ---------------------------- + auto parse_quantum_custom_ops(const std::string& mlir_code) { std::unordered_set ops;