diff --git a/delightning/CMakeLists.txt b/delightning/CMakeLists.txt new file mode 100644 index 0000000000..03949991e9 --- /dev/null +++ b/delightning/CMakeLists.txt @@ -0,0 +1,8 @@ +cmake_minimum_required(VERSION 3.20) + +project(delightning VERSION 0.1.0 LANGUAGES CXX) + +set(CMAKE_CXX_STANDARD 20) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +add_executable(delightning src/main.cpp) diff --git a/delightning/Makefile b/delightning/Makefile new file mode 100644 index 0000000000..1d338ddcd9 --- /dev/null +++ b/delightning/Makefile @@ -0,0 +1,14 @@ +CXX=g++ +CXXFLAGS=-std=c++20 -Wall -Wextra -O2 + +TARGET=delightning +SRCS=src/main.cpp + +$(TARGET): $(SRCS) + $(CXX) $(CXXFLAGS) -o $(TARGET) $(SRCS) + +run: $(TARGET) + ./$(TARGET) + +clean: + rm -f $(TARGET) diff --git a/delightning/src/main.cpp b/delightning/src/main.cpp new file mode 100644 index 0000000000..6fb85511fc --- /dev/null +++ b/delightning/src/main.cpp @@ -0,0 +1,835 @@ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +// Operator +// ______________________________ +// * name: Operator name +// + getName(): string + +class Operator { + // TODO: string_view + std::string name; + +public: + Operator() = default; + explicit Operator(const std::string& name) : name(name) {} + std::string getName() const { return name; } + + bool operator==(const Operator& other) const { + return name == other.name; + } + bool operator!=(const Operator& other) const { + return !(*this == other); + } +}; + +namespace std { + template <> + struct hash { + std::size_t operator()(const Operator& op) const noexcept { + return std::hash()(op.getName()); + } + }; +} + + + + +// ResourceOp +// ______________________________ +// * resources>: Resources +// + getResources(): umap +// + total_cost(): int +// + op_cost(Operator): int +// + has_op(Operator): bool + +class ResourceOp { + std::unordered_map resources; + size_t total = 0; + +public: + ResourceOp() = default; + explicit 
ResourceOp(const std::unordered_map& resources) : resources(resources) {} + + const std::unordered_map& getResources() const { + return resources; + } + + size_t total_cost() { + if (total == 0 && !resources.empty()) { + for (const auto& pair : resources) { + total += pair.second; + } + } + return total; + } + + size_t op_cost(const Operator& op) const { + auto it = resources.find(op); + return (it != resources.end()) ? it->second : 0; + } + + bool has_op(const Operator& op) const { + return resources.find(op) != resources.end(); + } +}; + + + +// RuleRefOp +// _________________________________________ +// * Op: Operator +// * Resource: Resources +// * RuleRef: Pointer to the rule + +class RuleRefOp { + Operator op; + ResourceOp resources; + std::string rule_ref; + +public: + RuleRefOp(const Operator& op, const ResourceOp& resources, const std::string& rule_ref) + : op(op), resources(resources), rule_ref(rule_ref) {} + + const Operator& getOperator() const { return op; } + const std::string& getRuleRef() const { return rule_ref; } + + // TODO: make this const ref to avoid copy overhead + // this is currently required for my simple Dijkstra solver + ResourceOp getResources() const { return resources; } +}; + + + +// BasicSolver <- experimental! 
+// Check PLSolver following PL's implementation +// _________________________________________ +// * Ops>: Operators +// * Gateset>: Operators +// * Rules>: Rules +// + graph(): void +// + show(): stdout +// + solve(): map + +class BasicSolver { +private: + std::vector ops; + std::vector gateset; + std::vector rules; + + // Cached solutions and distances maps + std::unordered_map solutions; + std::unordered_map distances; + + size_t computeCost(const RuleRefOp& rule) const { + size_t total = 0; + // std::cerr << "[DEBUG] Computing cost for rule: " << rule.getRuleRef() << "\n"; + for (const auto& [dep, count] : rule.getResources().getResources()) { + auto it = distances.find(dep); + // std::cerr << "[DEBUG] Dependency: " << dep.getName() + // << " with count: " << count + // << " and cost: " << it->second << "\n"; + if (it == distances.end() || it->second == std::numeric_limits::max()) { + // Dependency not found or unreachable yet :( + return std::numeric_limits::max(); + } + total += count * it->second; + } + return total; + } + + auto initGraph() { + using NodeOp = std::pair; + auto cmp = [](const NodeOp& left, const NodeOp& right) { + return left.first > right.first; + }; + std::priority_queue, decltype(cmp)> queue(cmp); + + for (const auto& op : ops) { + distances[op] = std::numeric_limits::max(); + queue.push({distances[op], op}); + } + + for (const auto& g : gateset) { + distances[g] = 1; + queue.push({1, g}); + solutions[g] = "base_op"; + } + return queue; + } + +public: + BasicSolver(const std::vector& ops, + const std::vector& gateset, + const std::vector& rules) + : ops(ops), gateset(gateset), rules(rules) {} + + bool isBasisGate(const Operator& op) const { + return std::find(gateset.begin(), gateset.end(), op) != gateset.end(); + } + + std::unordered_map solve() { + auto queue = initGraph(); + + while (!queue.empty()) { + auto [current_distance, current_op] = queue.top(); + queue.pop(); + + // If we found a better path, skip processing + if 
(current_distance > distances[current_op]) { + continue; + } + + // std::cerr << "[DEBUG] Exploring neighbors of operator: " + // << current_op.getName() << "\n"; + + // Explore neighbors :) + for (const auto& rule : rules) { + // std::cerr << "[DEBUG] Considering rule: " << rule.getRuleRef() + // << " for operator: " << rule.getOperator().getName() + // << " with total cost: " << rule.getResources().total_cost() + // << "\n"; + if (rule.getOperator() != current_op) { + continue; + } + + // std::cerr << "[DEBUG] Found applicable rule: " << rule.getRuleRef() + // << " for operator: " << current_op.getName() << "\n"; + size_t new_distance = computeCost(rule); + + // std::cerr << "[DEBUG] New computed distance for operator: " + // << current_op.getName() << " is " << new_distance << "\n"; + + if (new_distance < distances[current_op]) { + distances[current_op] = new_distance; + queue.push({new_distance, current_op}); + solutions[current_op] = rule.getRuleRef(); + // std::cerr << "[DEBUG] Updating distance for operator: " + // << current_op.getName() << " to " << new_distance << "\n"; + } + } + } + + return solutions; + } + + // For testing purposes (my first try) + std::unordered_map simple_solver() { + if (!solutions.empty()) { + return solutions; + } + + // We need to create a distance map for our Dijkstra's algorithm + // For now, I keep everything simple starting with max distance + // TODO: do this part implicitly for performance + std::unordered_map distances; + for (const auto& op : ops) { + distances[op] = std::numeric_limits::max(); + } + + // There are different ways to implement Dijkstra's algorithm + // Here, I use a simple priority queue for demonstration + // TODO: optimize with a better priority queue or min-heap + using QElement = std::pair; // (distance, operator) + auto cmp = [](const QElement& left, const QElement& right) { return left.first > right.first; }; + std::priority_queue, decltype(cmp)> queue(cmp); + + // Initialize the queue with all 
operators and distance 0 + for (const auto& op : ops) { + queue.push({0, op}); + } + + // Dijkstra's algorithm main loop + while (!queue.empty()) { + auto [current_distance, current_op] = queue.top(); + queue.pop(); + // If we found a better path, skip processing + if (current_distance > distances[current_op]) { + continue; + } + + // std::cerr << "[DEBUG] Exploring neighbors of operator: " + // << current_op.getName() << "\n"; + + // Explore neighbors :) + for (const auto& rule: rules) { + if (rule.getOperator() == current_op) { + // std::cerr << "[DEBUG] Found applicable rule: " << rule.getRuleRef() + // << " for operator: " << current_op.getName() + // << " with total cost: " << rule.getResources().total_cost() + // << "\n"; + size_t new_distance = current_distance + rule.getResources().total_cost(); + if (new_distance < distances[current_op]) { + // std::cerr << "[DEBUG] Updating distance for operator: " << current_op.getName() + // << " from " << distances[current_op] + // << " to " << new_distance << "\n"; + distances[current_op] = new_distance; + queue.push({new_distance, current_op}); + + // Update solution + solutions[current_op] = rule.getRuleRef(); + } + } + } + } + + return solutions; + } + + + void show() { + for (const auto& [op, rule] : solutions) { + std::cout << "Operator " << op.getName() + << " decomposed using rule: " << rule << "\n"; + } + } +}; + + +// PLSolver w/ Operator and Rule Nodes + +enum class NodeType { + OPERATOR, + RULE +}; + +struct Node { + NodeType type; + Operator op; + RuleRefOp rule; + size_t index; +}; + +struct Edge { + size_t target; + size_t weight; +}; + +class Graph { +private: + std::vector nodes; + std::vector> adjList; + +public: + Graph() = default; + + size_t addNode(const Node& node) { + const size_t idx = nodes.size(); + nodes.push_back(node); + adjList.emplace_back(); + return idx; + } + + void addEdge(size_t from, size_t to, size_t weight) { + adjList[from].push_back({to, weight}); + } + + const Node& 
getNode(size_t index) const { + return nodes[index]; + } + + size_t size() const { + return nodes.size(); + } + + const std::vector& getNeighbors(size_t index) const { + return adjList[index]; + } +}; + + +Graph buildGraph( + const std::vector& ops, + const std::vector& gateset, + const std::vector& rules) +{ + Graph graph; + std::unordered_map opNodes; + + // Create Operator nodes + for (const auto& op: ops) { + size_t idx = graph.addNode({NodeType::OPERATOR, op, RuleRefOp(op, {}, ""), 0}); + opNodes[op] = idx; + } + + for (const auto &op: gateset) { + size_t idx = graph.addNode({NodeType::OPERATOR, op, RuleRefOp(op, {}, ""), 0}); + opNodes[op] = idx; + } + + // Create Rule nodes and edges + for (const auto& rule: rules) { + size_t ruleIdx = graph.addNode({NodeType::RULE, {}, rule, 0}); + auto op = rule.getOperator(); + size_t opIdx = opNodes[op]; + + // Op -> Rule edge + graph.addEdge(opIdx, ruleIdx, 0); + + // Rule -> deps edges + for (const auto &[dep, count] : rule.getResources().getResources()) { + if (!opNodes.count(dep)) { + size_t depIdx = graph.addNode({NodeType::OPERATOR, dep, RuleRefOp(dep, {}, ""), 0}); + opNodes[dep] = depIdx; + } + graph.addEdge(ruleIdx, opNodes[dep], count); + } + + } + + return graph; +} + + +std::unordered_map +solveGraph(Graph& graph) { + using ElemPair = std::pair; // (distance, nodeIndex) + auto cmp = [](const ElemPair& a, const ElemPair& b) { return a.first > b.first; }; + std::priority_queue, decltype(cmp)> queue(cmp); + + std::vector dist(graph.size(), std::numeric_limits::max()); + std::unordered_map solutions; + + // Start with gateset operators = cost 0 + for (size_t i = 0; i < graph.size(); i++) { + auto& node = graph.getNode(i); + if (node.type == NodeType::OPERATOR && dist[i] == std::numeric_limits::max()) { + // Basis gate → distance 0 + if (solutions.count(node.op) == 0) { + dist[i] = 0; + queue.push({0, i}); + } + } + } + + while (!queue.empty()) { + auto [curDist, u] = queue.top(); + queue.pop(); + + if (curDist > 
dist[u]) continue; + + auto& uNode = graph.getNode(u); + + // Explore neighbors + for (auto& edge : graph.getNeighbors(u)) { + auto& vNode = graph.getNode(edge.target); + + size_t newDist = 0; + if (uNode.type == NodeType::OPERATOR && vNode.type == NodeType::RULE) { + // Operator → Rule: defer cost to expansion + newDist = curDist; + } else if (uNode.type == NodeType::RULE && vNode.type == NodeType::OPERATOR) { + // Rule → Operator: accumulate resource counts + size_t count = uNode.rule.getResources().op_cost(vNode.op); + newDist = curDist + count * dist[edge.target]; + } else { + continue; + } + + if (newDist < dist[edge.target]) { + dist[edge.target] = newDist; + queue.push({newDist, edge.target}); + + // If we reached an operator from a rule, record the chosen rule + if (vNode.type == NodeType::OPERATOR && uNode.type == NodeType::RULE) { + solutions[vNode.op] = uNode.rule.getRuleRef(); + } + } + } + } + + return solutions; +} + +// ---------------------------- +// MLIR Parser for quantum.custom ops +// ---------------------------- + +auto parse_quantum_custom_ops(const std::string& mlir_code) { + std::unordered_set ops; + + std::regex pattern(R"(quantum\.custom\s+\"([A-Za-z0-9_\.]+)\")"); + + std::smatch matches; + std::string::const_iterator search_start(mlir_code.cbegin()); + while (std::regex_search(search_start, mlir_code.cend(), matches, pattern)) { + ops.emplace(matches[1].str()); + search_start = matches.suffix().first; + } + + return ops; +} + +// ---------------------------- +// Simple Tests +// ---------------------------- + +void test_operator() { + Operator op1("H"); + Operator op2("X"); + Operator op3("H"); + + assert(op1.getName() == "H"); + assert(op2.getName() == "X"); + assert(!(op1 == op2)); + assert(op1 != op2); + assert(op1 == op3); + + std::cout << "[PASS] Operator tests" << std::endl; +} + +void test_resourceop() { + Operator op1("H"); + Operator op2("X"); + + std::unordered_map res{{op1, 3}, {op2, 5}}; + ResourceOp r(res); + + 
assert(r.total_cost() == 8); + assert(r.op_cost(op1) == 3); + assert(r.op_cost(op2) == 5); + assert(r.op_cost(Operator("Z")) == 0); + assert(r.has_op(op1)); + assert(!r.has_op(Operator("Z"))); + + std::cout << "[PASS] ResourceOp tests" << std::endl; +} + + +void test_rulerefop() { + Operator op("CX"); + ResourceOp r({{op, 2}}); + RuleRefOp rr(op, r, "rule1"); + + assert(rr.getOperator() == op); + assert(rr.getRuleRef() == "rule1"); + assert(rr.getResources().op_cost(op) == 2); + + std::cout << "[PASS] RuleRefOp tests" << std::endl; +} + +void test_solver1() { + + Operator cnot("CNOT"); + Operator cz("CZ"); + Operator h("H"); + + ResourceOp cz_to_cnot({{cnot, 1}, {h, 2}}); + ResourceOp h_self({{h, 1}}); + + RuleRefOp rule1(cz, cz_to_cnot, "cz_decomp_rule"); + RuleRefOp rule2(h, h_self, "h_rule"); + + std::vector ops = {cz, h}; + std::vector gateset = {cnot, h}; + std::vector rules = {rule1, rule2}; + + Solver solver(ops, gateset, rules); + auto solutions = solver.solve(); + // solver.show(); + assert(solutions.size() == 3); + assert(solutions[cz] == "cz_decomp_rule"); + assert(solutions[h] == "h_rule"); + + std::cout << "[PASS] Solver tests (1)" << std::endl; +} + +void test_solver2() { + + Operator cz("CZ"); + Operator cnot("CNOT"); + Operator h("H"); + Operator rz("RZ"); + Operator rx("RX"); + + ResourceOp cz_to_h_cnot({{h, 1}, {cnot, 1}}); + RuleRefOp rule1(cz, cz_to_h_cnot, "cz_h_cnot_rule"); + + ResourceOp cz_to_rx_rz_cnot({{rx, 1}, {rz, 1}, {cnot, 1}}); + RuleRefOp rule2(cz, cz_to_rx_rz_cnot, "cz_rx_rz_cnot_rule"); + + ResourceOp h_to_rz_rz({{rz, 2}}); + RuleRefOp rule3(h, h_to_rz_rz, "h_rz_rz_rule"); + + ResourceOp h_to_rz_rx_rz({{rz, 2}, {rx, 1}}); + RuleRefOp rule4(h, h_to_rz_rx_rz, "h_rz_rx_rz_rule"); + + std::vector ops = {h, cz}; + std::vector gateset = {cnot, rz, rx}; + std::vector rules = {rule1, rule2, rule3, rule4}; + + Solver solver(ops, gateset, rules); + auto solutions = solver.solve(); + // solver.show(); + assert(solutions.size() == 5); + 
assert(solutions[cz] == "cz_h_cnot_rule"); + assert(solutions[h] == "h_rz_rz_rule"); + assert(solutions[rz] == "base_op"); + assert(solutions[rx] == "base_op"); + assert(solutions[cnot] == "base_op"); + + std::cout << "[PASS] Solver tests (2)" << std::endl; +} + + +void test_solver3() { + // Define Operators + Operator single_exc("SingleExcitation"); + Operator single_exc_plus("SingleExcitationPlus"); + Operator double_exc("DoubleExcitation"); + Operator cry("CRY"); + Operator s("S"); + Operator phase("PhaseShift"); + Operator rz("RZ"); + Operator rx("RX"); + Operator ry("RY"); + Operator rot("Rot"); + Operator hadamard("Hadamard"); + Operator cnot("CNOT"); + Operator cy("CY"); + Operator t("T"); + Operator global_phase("GlobalPhase"); + Operator phaseshift("PhaseShift"); + + // ('SingleExcitation', {H:2, CNOT:2, RY:2}, _single_excitation_decomp) + ResourceOp res_single_exc({{hadamard, 2}, {cnot, 2}, {ry, 2}}); + RuleRefOp rule_single_exc(single_exc, res_single_exc, "_single_excitation_decomp"); + + // ('SingleExcitationPlus', {H:2, CY:1, CNOT:2, RY:2, S:1, RZ:1, GlobalPhase:1}, _single_excitation_plus_decomp) + ResourceOp res_single_exc_plus({ + {hadamard, 2}, {cy, 1}, {cnot, 2}, {ry, 2}, + {s, 1}, {rz, 1}, {global_phase, 1}}); + RuleRefOp rule_single_exc_plus(single_exc_plus, res_single_exc_plus, "_single_excitation_plus_decomp"); + + // ('DoubleExcitation', {CNOT:14, H:6, RY:8}, _doublexcit) + ResourceOp res_double_exc1({{cnot, 14}, {hadamard, 6}, {ry, 8}}); + RuleRefOp rule_double_exc1(double_exc, res_double_exc1, "_doublexcit"); + + // ('CRY', {RY:2, CNOT:2}, _cry) + ResourceOp res_cry({{ry, 2}, {cnot, 2}}); + RuleRefOp rule_cry(cry, res_cry, "_cry"); + + // ('S', {PhaseShift:1}, _s_phaseshift) + ResourceOp res_s1({{phase, 1}}); + RuleRefOp rule_s1(s, res_s1, "_s_phaseshift"); + + // ('S', {T:1}, _s_to_t) + ResourceOp res_s2({{t, 1}}); + RuleRefOp rule_s2(s, res_s2, "_s_to_t"); + + // ('PhaseShift', {RZ:1, GlobalPhase:1}, _phaseshift_to_rz_gp) + ResourceOp 
res_phase({{rz, 1}, {global_phase, 1}}); + RuleRefOp rule_phase(phase, res_phase, "_phaseshift_to_rz_gp"); + + // ('RZ', {Rot:1}, _rz_to_rot) + ResourceOp res_rz1({{rot, 1}}); + RuleRefOp rule_rz1(rz, res_rz1, "_rz_to_rot"); + + // ('RZ', {RY:2, RX:1}, _rz_to_ry_rx) + ResourceOp res_rz2({{ry, 2}, {rx, 1}}); + RuleRefOp rule_rz2(rz, res_rz2, "_rz_to_ry_rx"); + + // ('Rot', {RZ:2, RY:1}, _rot_to_rz_ry_rz) + ResourceOp res_rot({{rz, 2}, {ry, 1}}); + RuleRefOp rule_rot(rot, res_rot, "_rot_to_rz_ry_rz"); + + + std::vector ops = {single_exc, single_exc_plus, double_exc}; + std::vector gateset = {ry, rx, cnot, hadamard, global_phase}; + std::vector rules = { + rule_single_exc, rule_single_exc_plus, + rule_double_exc1, + rule_cry, rule_s1, rule_s2, + rule_phase, rule_rz1, rule_rz2, + rule_rot + }; + + Solver solver(ops, gateset, rules); + auto solutions = solver.solve(); + // solver.show(); + assert(solutions.size() == 8); + assert(solutions[single_exc] == "_single_excitation_decomp"); + assert(solutions[single_exc_plus] == "_single_excitation_plus_decomp"); + assert(solutions[double_exc] == "_doublexcit"); + assert(solutions[ry] == "base_op"); + assert(solutions[rx] == "base_op"); + assert(solutions[cnot] == "base_op"); + assert(solutions[hadamard] == "base_op"); + assert(solutions[global_phase] == "base_op"); + + std::cout << "[PASS] Solver tests (3)" << std::endl; +} + +void test_solver4() { + std::string mlir_code = R"( + func.func public @circuit_15() -> tensor attributes {decompose_gatesets = [["GlobalPhase", "RY", "Hadamard", "CNOT", "RX"]], diff_method = "adjoint", llvm.linkage = #llvm.linkage, qnode} { +       %cst = arith.constant 1.250000e-01 : f64 +       %cst_0 = arith.constant -1.250000e-01 : f64 +       %cst_1 = arith.constant -2.500000e-01 : f64 +       %cst_2 = arith.constant 2.500000e-01 : f64 +       %cst_3 = arith.constant 5.000000e-01 : f64 +       %c0_i64 = arith.constant 0 : i64 +       quantum.device shots(%c0_i64) 
["/home/ali/miniforge3/envs/decomp/lib/python3.12/site-packages/pennylane_lightning/liblightning_qubit_catalyst.so", "LightningSimulator", "{'mcmc': False, 'num_burnin': 0, 'kernel_name': None}"] +       %0 = quantum.alloc( 4) : !quantum.reg +       %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit +       %2 = quantum.extract %0[ 1] : !quantum.reg -> !quantum.bit +       %out_qubits:2 = quantum.custom "SingleExcitation"(%cst_3) %1, %2 : !quantum.bit, !quantum.bit +       %out_qubits_4 = quantum.custom "Hadamard"() %out_qubits#1 : !quantum.bit +       %out_qubits_5:2 = quantum.custom "CNOT"() %out_qubits_4, %out_qubits#0 : !quantum.bit, !quantum.bit +       %out_qubits_6 = quantum.custom "RY"(%cst_2) %out_qubits_5#1 : !quantum.bit +       %out_qubits_7 = quantum.custom "RY"(%cst_2) %out_qubits_5#0 : !quantum.bit +       %out_qubits_8:2 = quantum.custom "CY"() %out_qubits_7, %out_qubits_6 : !quantum.bit, !quantum.bit +       %out_qubits_9 = quantum.custom "S"() %out_qubits_8#0 : !quantum.bit +       %out_qubits_10 = quantum.custom "Hadamard"() %out_qubits_9 : !quantum.bit +       %out_qubits_11 = quantum.custom "RZ"(%cst_1) %out_qubits_10 : !quantum.bit +       %out_qubits_12:2 = quantum.custom "CNOT"() %out_qubits_8#1, %out_qubits_11 : !quantum.bit, !quantum.bit +       quantum.gphase(%cst_0) : +       %out_qubits_13 = quantum.custom "Hadamard"() %out_qubits_12#1 : !quantum.bit +       %out_qubits_14:2 = quantum.custom "CNOT"() %out_qubits_13, %out_qubits_12#0 : !quantum.bit, !quantum.bit +       %out_qubits_15 = quantum.custom "RY"(%cst_2) %out_qubits_14#1 : !quantum.bit +       %out_qubits_16 = quantum.custom "RY"(%cst_2) %out_qubits_14#0 : !quantum.bit +       %out_qubits_17:2 = quantum.custom "CY"() %out_qubits_16, %out_qubits_15 : !quantum.bit, !quantum.bit +       %out_qubits_18 = quantum.custom "S"() %out_qubits_17#0 : !quantum.bit +       %out_qubits_19 = quantum.custom "Hadamard"() %out_qubits_18 : !quantum.bit +       %out_qubits_20 = 
quantum.custom "RZ"(%cst_2) %out_qubits_19 : !quantum.bit +       %out_qubits_21:2 = quantum.custom "CNOT"() %out_qubits_17#1, %out_qubits_20 : !quantum.bit, !quantum.bit +       quantum.gphase(%cst) : +       %3 = quantum.extract %0[ 2] : !quantum.reg -> !quantum.bit +       %4 = quantum.extract %0[ 3] : !quantum.reg -> !quantum.bit +       %out_qubits_22:4 = quantum.custom "DoubleExcitation"(%cst_3) %out_qubits_21#0, %out_qubits_21#1, %3, %4 : !quantum.bit, !quantum.bit, !quantum.bit, !quantum.bit +       %5 = quantum.insert %0[ 0], %out_qubits_22#0 : !quantum.reg, !quantum.bit +       %6 = quantum.insert %5[ 1], %out_qubits_22#1 : !quantum.reg, !quantum.bit +       %7 = quantum.insert %6[ 2], %out_qubits_22#2 : !quantum.reg, !quantum.bit +       %8 = quantum.insert %7[ 3], %out_qubits_22#3 : !quantum.reg, !quantum.bit +       %9 = quantum.extract %8[ 0] : !quantum.reg -> !quantum.bit +       %10 = quantum.namedobs %9[ PauliZ] : !quantum.obs +       %11 = quantum.expval %10 : f64 +       %from_elements = tensor.from_elements %11 : tensor +       %12 = quantum.insert %8[ 0], %9 : !quantum.reg, !quantum.bit +       quantum.dealloc %12 : !quantum.reg +       quantum.device_release +       return %from_elements : tensor +     } + )"; + + auto parsed_ops = parse_quantum_custom_ops(mlir_code); + + // std::cout << "Parsed quantum.custom operations:" << std::endl; + // for (const auto& op : parsed_ops) { + // std::cout << op.getName() << std::endl; + // } + + // Define Operators + Operator single_exc("SingleExcitation"); + Operator single_exc_plus("SingleExcitationPlus"); + Operator double_exc("DoubleExcitation"); + Operator cry("CRY"); + Operator s("S"); + Operator phase("PhaseShift"); + Operator rz("RZ"); + Operator rx("RX"); + Operator ry("RY"); + Operator rot("Rot"); + Operator hadamard("Hadamard"); + Operator cnot("CNOT"); + Operator cy("CY"); + Operator t("T"); + Operator global_phase("GlobalPhase"); + Operator phaseshift("PhaseShift"); + + // ('SingleExcitation', 
{H:2, CNOT:2, RY:2}, _single_excitation_decomp) + ResourceOp res_single_exc({{hadamard, 2}, {cnot, 2}, {ry, 2}}); + RuleRefOp rule_single_exc(single_exc, res_single_exc, "_single_excitation_decomp"); + + // ('SingleExcitationPlus', {H:2, CY:1, CNOT:2, RY:2, S:1, RZ:1, GlobalPhase:1}, _single_excitation_plus_decomp) + ResourceOp res_single_exc_plus({ + {hadamard, 2}, {cy, 1}, {cnot, 2}, {ry, 2}, + {s, 1}, {rz, 1}, {global_phase, 1}}); + RuleRefOp rule_single_exc_plus(single_exc_plus, res_single_exc_plus, "_single_excitation_plus_decomp"); + + // ('DoubleExcitation', {CNOT:14, H:6, RY:8}, _doublexcit) + ResourceOp res_double_exc1({{cnot, 14}, {hadamard, 6}, {ry, 8}}); + RuleRefOp rule_double_exc1(double_exc, res_double_exc1, "_doublexcit"); + + // ('CRY', {RY:2, CNOT:2}, _cry) + ResourceOp res_cry({{ry, 2}, {cnot, 2}}); + RuleRefOp rule_cry(cry, res_cry, "_cry"); + + // ('S', {PhaseShift:1}, _s_phaseshift) + ResourceOp res_s1({{phase, 1}}); + RuleRefOp rule_s1(s, res_s1, "_s_phaseshift"); + + // ('S', {T:1}, _s_to_t) + ResourceOp res_s2({{t, 1}}); + RuleRefOp rule_s2(s, res_s2, "_s_to_t"); + + // ('PhaseShift', {RZ:1, GlobalPhase:1}, _phaseshift_to_rz_gp) + ResourceOp res_phase({{rz, 1}, {global_phase, 1}}); + RuleRefOp rule_phase(phase, res_phase, "_phaseshift_to_rz_gp"); + + // ('RZ', {Rot:1}, _rz_to_rot) + ResourceOp res_rz1({{rot, 1}}); + RuleRefOp rule_rz1(rz, res_rz1, "_rz_to_rot"); + + // ('RZ', {RY:2, RX:1}, _rz_to_ry_rx) + ResourceOp res_rz2({{ry, 2}, {rx, 1}}); + RuleRefOp rule_rz2(rz, res_rz2, "_rz_to_ry_rx"); + + // ('Rot', {RZ:2, RY:1}, _rot_to_rz_ry_rz) + ResourceOp res_rot({{rz, 2}, {ry, 1}}); + RuleRefOp rule_rot(rot, res_rot, "_rot_to_rz_ry_rz"); + + + std::vector ops(parsed_ops.begin(), parsed_ops.end()); + std::vector gateset = {ry, rx, cnot, hadamard, global_phase}; + std::vector rules = { + rule_single_exc, rule_single_exc_plus, + rule_double_exc1, + rule_cry, rule_s1, rule_s2, + rule_phase, rule_rz1, rule_rz2, + rule_rot + }; + + Solver 
solver(ops, gateset, rules); + auto solutions = solver.solve(); + // solver.show(); + + assert(solutions.size() == 9); + assert(solutions[single_exc] == "_single_excitation_decomp"); + assert(solutions[double_exc] == "_doublexcit"); + assert(solutions[rz] == "_rz_to_rot"); + assert(solutions[s] == "_s_phaseshift"); + assert(solutions[global_phase] == "base_op"); + assert(solutions[hadamard] == "base_op"); + assert(solutions[ry] == "base_op"); + assert(solutions[rx] == "base_op"); + assert(solutions[cnot] == "base_op"); + + std::cout << "[PASS] Solver tests (4)" << std::endl; + +} + + +int main() { + test_operator(); + test_resourceop(); + test_rulerefop(); + test_solver1(); + test_solver2(); + test_solver3(); + test_solver4(); + + std::cout << "All tests passed!" << std::endl; + return 0; +} + diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 318a1afb07..fa1bdf853d 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -2,6 +2,13 @@

New features since last release

+* A new experimental decomposition system is introduced in Catalyst, enabling + PennyLane's graph-based decomposition and MLIR-based lowering of decomposition rules. + This feature is integrated with PennyLane program capture and graph-based decomposition + including support for custom decomposition rules and operators. + [(#2001)](https://github.com/PennyLaneAI/catalyst/pull/2001) + [(#2029)](https://github.com/PennyLaneAI/catalyst/pull/2029) + * Catalyst now supports dynamic wire allocation with ``qml.allocate()`` and ``qml.deallocate()`` when program capture is enabled. [(#2002)](https://github.com/PennyLaneAI/catalyst/pull/2002) diff --git a/frontend/catalyst/from_plxpr/decompose.py b/frontend/catalyst/from_plxpr/decompose.py new file mode 100644 index 0000000000..0fd15edbcf --- /dev/null +++ b/frontend/catalyst/from_plxpr/decompose.py @@ -0,0 +1,403 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +A transform for the new MLIR-based Catalyst decomposition system. 
+""" + + +from __future__ import annotations + +import functools +import inspect +import types +from collections.abc import Callable +from typing import get_type_hints + +import jax +import pennylane as qml +from pennylane.decomposition import DecompositionGraph +from pennylane.typing import TensorLike +from pennylane.wires import WiresLike + +from catalyst.jax_primitives import decomposition_rule + +# A mapping from operation names to the number of wires they act on +# and the number of parameters they have. +# This is used when the operation is not in the captured operations +# but we still need to create a decomposition rule for it. +# +# Note that some operations have a variable number of wires, +# e.g., MultiRZ, GlobalPhase. For these, we set the number +# of wires to -1 to indicate a variable number. +# +# This will require a copy of the function to be made +# when creating the decomposition rule to avoid mutating +# the original function with attributes like num_wires. + +# A list of operations that can be represented +# in the Catalyst compiler. This will be a superset of +# the operations supported by the runtime. + +# FIXME: ops with OpName(params, wires) signatures can be +# represented in the Catalyst compiler. Unfortunately, +# the signature info is not sufficient as there are +# templates with the same signature that should be +# disambiguated. 
+COMPILER_OPS_FOR_DECOMPOSITION: dict[str, tuple[int, int]] = { + "CNOT": (2, 0), + "ControlledPhaseShift": (2, 1), + "CRot": (2, 3), + "CRX": (2, 1), + "CRY": (2, 1), + "CRZ": (2, 1), + "CSWAP": (3, 0), + "CY": (2, 0), + "CZ": (2, 0), + "Hadamard": (1, 0), + "Identity": (1, 0), + "IsingXX": (2, 1), + "IsingXY": (2, 1), + "IsingYY": (2, 1), + "IsingZZ": (2, 1), + "SingleExcitation": (2, 1), + "DoubleExcitation": (4, 1), + "ISWAP": (2, 0), + "PauliX": (1, 0), + "PauliY": (1, 0), + "PauliZ": (1, 0), + "PhaseShift": (1, 1), + "PSWAP": (2, 1), + "Rot": (1, 3), + "RX": (1, 1), + "RY": (1, 1), + "RZ": (1, 1), + "S": (1, 0), + "SWAP": (2, 0), + "T": (1, 0), + "Toffoli": (3, 0), + "U1": (1, 1), + "U2": (1, 2), + "U3": (1, 3), + "MultiRZ": (-1, 1), + "GlobalPhase": (-1, 1), +} + + +# pylint: disable=too-few-public-methods +class DecompRuleInterpreter(qml.capture.PlxprInterpreter): + """Interpreter for getting the decomposition graph solution + from a jaxpr when program capture is enabled. + + This interpreter captures all operations seen during the interpretation + and builds a decomposition graph to find efficient decomposition pathways + to a target gate set. + + This interpreter should be used with `qml.decomposition.enable_graph()` + to enable graph-based decomposition. + + Note that this doesn't actually decompose the operations during interpretation. + It only captures the operations and builds the decomposition graph. + The actual decomposition is done later in the MLIR decomposition pass. + + See also: :class:`~.DecompositionGraph`. + + Args: + gate_set (set[Operator] or None): The target gate set to decompose to + fixed_decomps (dict or None): A dictionary of fixed decomposition rules + to use in the decomposition graph. + alt_decomps (dict or None): A dictionary of alternative decomposition rules + to use in the decomposition graph. + + Raises: + TypeError: if graph-based decomposition is not enabled. 
+ """ + + def __init__( + self, + *, + gate_set=None, + fixed_decomps=None, + alt_decomps=None, + ): # pylint: disable=too-many-arguments + + if not qml.decomposition.enabled_graph(): # pragma: no cover + raise TypeError( + "The DecompRuleInterpreter can only be used when" + "graph-based decomposition is enabled." + ) + + self._gate_set = gate_set + self._fixed_decomps = fixed_decomps + self._alt_decomps = alt_decomps + + self._captured = False + self._operations = set() + self._decomp_graph_solution = {} + + def interpret_operation(self, op: "qml.operation.Operator"): + """Interpret a PennyLane operation instance. + + Args: + op (Operator): a pennylane operator instance + + Returns: + Any + + This method is only called when the operator's output is a dropped variable, + so the output will not affect later equations in the circuit. + + We cache the list of operations seen during the interpretation + to build the decomposition graph in the later stages. + + See also: :meth:`~.interpret_operation_eqn`. + + """ + + self._operations.add(op) + data, struct = jax.tree_util.tree_flatten(op) + return jax.tree_util.tree_unflatten(struct, data) + + def interpret_measurement(self, measurement: "qml.measurement.MeasurementProcess"): + """Interpret a measurement process instance. + + Args: + measurement (MeasurementProcess): a measurement instance. + + See also :meth:`~.interpret_measurement_eqn`. + + """ + + # If we haven't captured and compiled the decomposition rules yet, + if not self._captured: + # Capture the current operations and mark as captured + self._captured = True + + # Solve the decomposition graph to get the decomposition rules + # for all the captured operations + # I know it's a bit hacky to do this here, but it's the only + # place where we can be sure that we have seen all operations + # in the circuit before the measurement. + # TODO: Find a better way to do this. 
+ self._decomp_graph_solution = _solve_decomposition_graph( + self._operations, + self._gate_set, + fixed_decomps=self._fixed_decomps, + alt_decomps=self._alt_decomps, + ) + + # Create decomposition rules for each operation in the solution + # and compile them to Catalyst JAXPR decomposition rules + for op, rule in self._decomp_graph_solution.items(): + # Get number of wires if exists + op_num_wires = op.op.params.get("num_wires", None) + if ( + o := next( + ( + o + for o in self._operations + if o.name == op.op.name and len(o.wires) == op_num_wires + ), + None, + ) + ) is not None: + num_wires, num_params = COMPILER_OPS_FOR_DECOMPOSITION[op.op.name] + _create_decomposition_rule( + rule, + op_name=op.op.name, + num_wires=len(o.wires), + num_params=num_params, + requires_copy=num_wires == -1, + ) + elif op.op.name in COMPILER_OPS_FOR_DECOMPOSITION: + # In this part, we need to handle the case where an operation in + # the decomposition graph solution is not in the captured operations. + # This can happen if the operation is not directly called + # in the circuit, but is used inside a decomposition rule. + # In this case, we fall back to using the COMPILER_OPS_FOR_DECOMPOSITION + # dictionary to get the number of wires. + num_wires, num_params = COMPILER_OPS_FOR_DECOMPOSITION[op.op.name] + _create_decomposition_rule( + rule, + op_name=op.op.name, + num_wires=num_wires, + num_params=num_params, + requires_copy=num_wires == -1, + ) + else: # pragma: no cover + raise ValueError(f"Could not capture {op} without the number of wires.") + + data, struct = jax.tree_util.tree_flatten(measurement) + return jax.tree_util.tree_unflatten(struct, data) + + +def _create_decomposition_rule( + func: Callable, op_name: str, num_wires: int, num_params: int, requires_copy: bool = False +): + """Create a decomposition rule from a callable. + + See also: :func:`~.decomposition_rule`. + + Args: + func (Callable): The decomposition function. 
+ op_name (str): The name of the operation to decompose. + num_wires (int): The number of wires the operation acts on. + num_params (int): The number of parameters the operation takes. + requires_copy (bool): Whether to create a copy of the function + to avoid mutating the original. This is required for operations + with a variable number of wires (e.g., MultiRZ, GlobalPhase). + """ + + sig_func = inspect.signature(func) + type_hints = get_type_hints(func) + + args = {} + for name in sig_func.parameters.keys(): + typ = type_hints.get(name, None) + + # Skip tailing args or kwargs in the rules + if name in ("__", "_"): + continue + + # TODO: This is a temporary solution until all rules have proper type annotations. + # Why? Because we need to pass the correct types to the decomposition_rule + # function to capture the rule correctly with JAX. + possible_names_for_single_param = { + "param", + "angle", + "phi", + "omega", + "theta", + "weight", + } + possible_names_for_multi_params = { + "params", + "angles", + "weights", + } + + # TODO: Support work-wires when it's supported in Catalyst. + possible_names_for_wires = {"wires", "wire", "control_wires", "target_wires"} + + if typ is TensorLike or name in possible_names_for_multi_params: + args[name] = qml.math.array([0.0] * num_params, like="jax", dtype=float) + elif typ is float or name in possible_names_for_single_param: + # TensorLike is a Union of float, int, array-like, so we use float here + # to cover the most common case as the JAX tracer doesn't like Union types + # and we don't have the actual values at this point. 
+ args[name] = float + elif typ is WiresLike or name in possible_names_for_wires: + # Pass a dummy array of zeros with the correct number of wires + # This is required for the decomposition_rule to work correctly + # as it expects an array-like input for wires + args[name] = qml.math.array([0] * num_wires, like="jax") + elif typ is int: # pragma: no cover + # This is only for cases where the rule has an int parameter + # e.g., dimension in some gates. Not that common though! + # We cover this when adding end-to-end tests for rules + # in the MLIR PR. + args[name] = int + else: # pragma: no cover + raise ValueError( + f"Unsupported type annotation {typ} for parameter {name} in func {func}." + ) + + func_cp = make_def_copy(func) if requires_copy else func + + # Set custom attributes for the decomposition rule + # These attributes are used in the MLIR decomposition pass + # to identify the target gate and the number of wires + setattr(func_cp, "target_gate", op_name) + setattr(func_cp, "num_wires", num_wires) + + if requires_copy: + # Include number of wires in the function name to avoid name clashes + # when the same rule is compiled multiple times with different number of wires + # (e.g., MultiRZ, GlobalPhase) + func_cp.__name__ += f"_wires_{num_wires}" # pylint: disable=protected-access + + return decomposition_rule(func_cp)(**args) + + +# pylint: disable=protected-access +def _solve_decomposition_graph(operations, gate_set, fixed_decomps, alt_decomps): + """Get the decomposition graph solution for the given operations and gate set. + + TODO: Extend `DecompGraphSolution` API and avoid accessing protected members + directly in this function. + + Args: + operations (set[Operator]): The set of operations to decompose. + gate_set (set[Operator]): The target gate set to decompose to. + fixed_decomps (dict or None): A dictionary of fixed decomposition rules + to use in the decomposition graph. 
+ alt_decomps (dict or None): A dictionary of alternative decomposition rules + to use in the decomposition graph. + + Returns: + dict: A dictionary mapping operations to their decomposition rules. + """ + + # decomp_graph_solution + decomp_graph_solution = {} + + decomp_graph = DecompositionGraph( + operations, + gate_set, + fixed_decomps=fixed_decomps, + alt_decomps=alt_decomps, + ) + + # Find the efficient pathways to the target gate set + solutions = decomp_graph.solve() + + def is_solved_for(op): + return ( + op in solutions._all_op_indices + and solutions._all_op_indices[op] in solutions._visitor.distances + ) + + for op_node, op_node_idx in solutions._all_op_indices.items(): + if is_solved_for(op_node) and op_node_idx in solutions._visitor.predecessors: + d_node_idx = solutions._visitor.predecessors[op_node_idx] + decomp_graph_solution[op_node] = solutions._graph[d_node_idx].rule._impl + + return decomp_graph_solution + + +# pylint: disable=protected-access +def make_def_copy(func): + """Create a copy of a Python definition to avoid mutating the original. + + This is especially useful when compiling decomposition rules with + parametric number of wires (e.g., MultiRZ, GlobalPhase) multiple times, + as the compilation process may add attributes to the function that + can interfere with subsequent compilations. + + Args: + func (Callable): The function to copy. + + Returns: + Callable: A copy of the original function with the same attributes. + """ + # Create a new function object with the same code, globals, name, defaults, and closure + func_copy = types.FunctionType( + func.__code__, + func.__globals__, + name=func.__name__, + argdefs=func.__defaults__, + closure=func.__closure__, + ) + + # Now, we create and update the wrapper to copy over attributes like docstring, module, etc. 
+ return functools.update_wrapper(func_copy, func) diff --git a/frontend/catalyst/from_plxpr/from_plxpr.py b/frontend/catalyst/from_plxpr/from_plxpr.py index aba015a42e..6ae7e5a0d3 100644 --- a/frontend/catalyst/from_plxpr/from_plxpr.py +++ b/frontend/catalyst/from_plxpr/from_plxpr.py @@ -15,12 +15,13 @@ This submodule defines a utility for converting plxpr into Catalyst jaxpr. """ # pylint: disable=protected-access +# pylint: disable=too-many-lines + from copy import copy from functools import partial from typing import Callable import jax -import jax.core import jax.numpy as jnp import pennylane as qml from jax._src.sharding_impls import UNSPECIFIED @@ -45,6 +46,7 @@ from pennylane.transforms import unitary_to_rot as pl_unitary_to_rot from catalyst.device import extract_backend_info +from catalyst.from_plxpr.decompose import COMPILER_OPS_FOR_DECOMPOSITION, DecompRuleInterpreter from catalyst.from_plxpr.qubit_handler import QubitHandler, QubitIndexRecorder, get_in_qubit_values from catalyst.jax_extras import jaxpr_pad_consts, make_jaxpr2, transient_jax_config from catalyst.jax_primitives import ( @@ -177,11 +179,16 @@ def f(x): class WorkflowInterpreter(PlxprInterpreter): - """An interpreter that converts a qnode primitive from a plxpr variant to a catalxpr variant.""" + """An interpreter that converts a qnode primitive from a plxpr variant to a catalyst jaxpr variant.""" def __init__(self): self._pass_pipeline = [] self.init_qreg = None + + # Compiler options for the new decomposition system + self.requires_decompose_lowering = False + self.decompose_tkwargs = {} # target gateset + super().__init__() @@ -201,7 +208,24 @@ def handle_qnode( consts = args[shots_len : n_consts + shots_len] non_const_args = args[shots_len + n_consts :] - closed_jaxpr = ClosedJaxpr(qfunc_jaxpr, consts) + closed_jaxpr = ( + ClosedJaxpr(qfunc_jaxpr, consts) + if not self.requires_decompose_lowering + else _apply_compiler_decompose_to_plxpr( + inner_jaxpr=qfunc_jaxpr, + consts=consts, + 
ncargs=non_const_args, + tgateset=list(self.decompose_tkwargs.get("gate_set", [])), + ) + ) + + if self.requires_decompose_lowering: + closed_jaxpr = _collect_and_compile_graph_solutions( + inner_jaxpr=closed_jaxpr.jaxpr, + consts=closed_jaxpr.consts, + tkwargs=self.decompose_tkwargs, + ncargs=non_const_args, + ) def calling_convention(*args): device_init_p.bind( @@ -220,6 +244,16 @@ def calling_convention(*args): device_release_p.bind() return retvals + if self.requires_decompose_lowering: + # Add gate_set attribute to the quantum kernel primitive + # decompose_gatesets is treated as a queue of gatesets to be used + # but we only support a single gateset for now in from_plxpr + # as supporting multiple gatesets requires an MLIR/C++ graph-decomposition + # implementation. The current Python implementation cannot be mixed + # with other transforms in between. + gateset = [_get_operator_name(op) for op in self.decompose_tkwargs.get("gate_set", [])] + setattr(qnode, "decompose_gatesets", [gateset]) + return quantum_kernel_p.bind( wrap_init(calling_convention, debug_info=qfunc_jaxpr.debug_info), *non_const_args, @@ -268,6 +302,51 @@ def handle_transform( non_const_args = args[args_slice] targs = args[targs_slice] + # If the transform is a decomposition transform + # and the graph-based decomposition is enabled + if ( + hasattr(pl_plxpr_transform, "__name__") + and pl_plxpr_transform.__name__ == "decompose_plxpr_to_plxpr" + and qml.decomposition.enabled_graph() + ): + if not self.requires_decompose_lowering: + self.requires_decompose_lowering = True + else: + raise NotImplementedError( + "Multiple decomposition transforms are not yet supported." + ) + + # Update the decompose_gateset to be used by the quantum kernel primitive + # TODO: we originally wanted to treat decompose_gateset as a queue of + # gatesets to be used by the decompose-lowering pass at MLIR + # but this requires a C++ implementation of the graph-based decomposition + # which doesn't exist yet. 
+ self.decompose_tkwargs = tkwargs + + # Note. We don't perform the compiler-specific decomposition here + # to be able to support multiple decomposition transforms + # and collect all the required gatesets + # as well as being able to support other transforms in between. + + # The compiler specific transformation will be performed + # in the qnode handler. + + # Add the decompose-lowering pass to the start of the pipeline + self._pass_pipeline.insert(0, Pass("decompose-lowering")) + + # We still need to construct and solve the graph based on + # the current jaxpr based on the current gateset + # but we don't rewrite the jaxpr at this stage. + + # gds_interpreter = DecompRuleInterpreter(*targs, **tkwargs) + + # def gds_wrapper(*args): + # return gds_interpreter.eval(inner_jaxpr, consts, *args) + + # final_jaxpr = jax.make_jaxpr(gds_wrapper)(*args) + # return self.eval(final_jaxpr.jaxpr, consts, *non_const_args) + return self.eval(inner_jaxpr, consts, *non_const_args) + if catalyst_pass_name is None: # Use PL's ExpandTransformsInterpreter to expand this and any embedded # transform according to PL rules. It works by overriding the primitive @@ -287,10 +366,10 @@ def wrapper(*args): ) return self.eval(final_jaxpr.jaxpr, final_jaxpr.consts, *non_const_args) - else: - # Apply the corresponding Catalyst pass counterpart - self._pass_pipeline.insert(0, Pass(catalyst_pass_name)) - return self.eval(inner_jaxpr, consts, *non_const_args) + + # Apply the corresponding Catalyst pass counterpart + self._pass_pipeline.insert(0, Pass(catalyst_pass_name)) + return self.eval(inner_jaxpr, consts, *non_const_args) # This is our registration factory for PL transforms. 
The loop below iterates @@ -301,6 +380,7 @@ def wrapper(*args): register_transform(pl_transform, pass_name, decomposition) +# pylint: disable=too-many-instance-attributes class PLxPRToQuantumJaxprInterpreter(PlxprInterpreter): """ Unlike the previous interpreters which modified the getattr and setattr @@ -571,7 +651,6 @@ def handle_decomposition_rule(self, *, pyfun, func_jaxpr, is_qreg, num_params): """ Transform a quantum decomposition rule from PLxPR into JAXPR with quantum primitives. """ - if is_qreg: self.init_qreg.insert_all_dangling_qubits() @@ -848,3 +927,86 @@ def trace_from_pennylane( jaxpr = from_plxpr(plxpr)(*dynamic_args, **kwargs) return jaxpr, out_type, out_treedef, sig + + +def _apply_compiler_decompose_to_plxpr(inner_jaxpr, consts, tgateset, ncargs): + """Apply the compiler-specific decomposition for a given JAXPR. + + Args: + inner_jaxpr (Jaxpr): The input JAXPR to be decomposed. + consts (list): The constants used in the JAXPR. + tgateset (list): A list of target gateset for decomposition. + ncargs (list): Non-constant arguments for the JAXPR. + qargs (list): All arguments including constants and non-constants. + + Returns: + ClosedJaxpr: The decomposed JAXPR. + """ + + # Disable the graph decomposition optimization + + # Why? Because for the compiler-specific decomposition we want to + # only decompose higher-level gates and templates that only have + # a single decomposition, and not do any further optimization + # based on the graph solution. + # Besides, the graph-based decomposition is not supported + # yet in from_plxpr for most gates and templates. 
+ + # TODO: Enable the graph-based decomposition + qml.decomposition.disable_graph() + + # First perform the pre-mlir decomposition to simplify the jaxpr + # by decomposing high-level gates and templates + gate_set = set(COMPILER_OPS_FOR_DECOMPOSITION.keys()).union(tgateset) + + final_jaxpr = qml.transforms.decompose.plxpr_transform( + inner_jaxpr, consts, (), {"gate_set": gate_set}, *ncargs + ) + + qml.decomposition.enable_graph() + + return final_jaxpr + + +def _collect_and_compile_graph_solutions(inner_jaxpr, consts, tkwargs, ncargs): + """Collect and compile graph solutions for a given JAXPR. + + This function uses the DecompRuleInterpreter to evaluate + the input JAXPR and obtain a new JAXPR that incorporates + the graph-based decomposition solutions. + + This function doesn't modify the underlying quantum function + but rather constructs a new JAXPR with decomposition rules. + + Args: + inner_jaxpr (Jaxpr): The input JAXPR to be decomposed. + consts (list): The constants used in the JAXPR. + tkwargs (list): The keyword arguments of the decompose transform. + ncargs (list): Non-constant arguments for the JAXPR. + + Returns: + ClosedJaxpr: The decomposed JAXPR. + """ + gds_interpreter = DecompRuleInterpreter(**tkwargs) + + def gds_wrapper(*args): + return gds_interpreter.eval(inner_jaxpr, consts, *args) + + final_jaxpr = jax.make_jaxpr(gds_wrapper)(*ncargs) + + return final_jaxpr + + +def _get_operator_name(op): + """Get the name of a pennylane operator, handling wrapped operators. + + Note: Controlled and Adjoint ops aren't supported in `gate_set` + by PennyLane's DecompositionGraph; unit tests were added in PennyLane. + """ + if isinstance(op, str): + return op + + # Return NoNameOp if the operator has no _primitive.name attribute. + # This is to avoid errors when we capture the program + # as we deal with such ops later in the decomposition graph. 
+ return getattr(op._primitive, "name", "NoNameOp") diff --git a/frontend/catalyst/jax_primitives.py b/frontend/catalyst/jax_primitives.py index 8f754913e7..fec8e41684 100644 --- a/frontend/catalyst/jax_primitives.py +++ b/frontend/catalyst/jax_primitives.py @@ -395,7 +395,7 @@ def wrapper(*args, **kwargs): return wrapper -def decomposition_rule(func=None, *, is_qreg=False, num_params=0): +def decomposition_rule(func=None, *, is_qreg=True, num_params=0): """ Denotes the creation of a quantum definition in the intermediate representation. """ @@ -590,7 +590,10 @@ def _decomposition_rule_lowering(ctx, *, pyfun, func_jaxpr, **_): """Lower a quantum decomposition rule into MLIR in a single step process. The step is the compilation of the definition of the function fn. """ - lower_callable(ctx, pyfun, func_jaxpr) + + # Set the visibility of the decomposition rule to public + # to avoid the elimination by the compiler + lower_callable(ctx, pyfun, func_jaxpr, public=True) return () diff --git a/frontend/catalyst/jax_primitives_utils.py b/frontend/catalyst/jax_primitives_utils.py index 411aba7cc6..1227b7f1b5 100644 --- a/frontend/catalyst/jax_primitives_utils.py +++ b/frontend/catalyst/jax_primitives_utils.py @@ -26,6 +26,8 @@ from mlir_quantum.dialects._transform_ops_gen import ApplyRegisteredPassOp, NamedSequenceOp, YieldOp from mlir_quantum.dialects.catalyst import LaunchKernelOp +from catalyst.jax_extras.lowering import get_mlir_attribute_from_pyval + def get_call_jaxpr(jaxpr): """Extracts the `call_jaxpr` from a JAXPR if it exists.""" "" @@ -44,7 +46,16 @@ def get_call_equation(jaxpr): def lower_jaxpr(ctx, jaxpr, context=None): - """Lowers a call primitive jaxpr, may be either func_p or quantum_kernel_p""" + """Lowers a call primitive jaxpr, may be either func_p or quantum_kernel_p + + Args: + ctx: LoweringRuleContext + jaxpr: JAXPR to be lowered + context: additional context to distinguish different FuncOps + + Returns: + FuncOp + """ equation = 
get_call_equation(jaxpr) call_jaxpr = equation.params["call_jaxpr"] callable_ = equation.params.get("fn") @@ -54,7 +65,8 @@ def lower_jaxpr(ctx, jaxpr, context=None): return lower_callable(ctx, callable_, call_jaxpr, pipeline=pipeline, context=context) -def lower_callable(ctx, callable_, call_jaxpr, pipeline=None, context=None): +# pylint: disable=too-many-arguments, too-many-positional-arguments +def lower_callable(ctx, callable_, call_jaxpr, pipeline=None, context=None, public=False): """Lowers _callable to MLIR. If callable_ is a qnode, then we will first create a module, then @@ -66,6 +78,8 @@ def lower_callable(ctx, callable_, call_jaxpr, pipeline=None, context=None): ctx: LoweringRuleContext callable_: python function call_jaxpr: jaxpr representing callable_ + public: whether the visibility should be marked public + Returns: FuncOp """ @@ -73,25 +87,49 @@ def lower_callable(ctx, callable_, call_jaxpr, pipeline=None, context=None): pipeline = tuple() if not isinstance(callable_, qml.QNode): - return get_or_create_funcop(ctx, callable_, call_jaxpr, pipeline, context=context) + return get_or_create_funcop( + ctx, callable_, call_jaxpr, pipeline, context=context, public=public + ) return get_or_create_qnode_funcop(ctx, callable_, call_jaxpr, pipeline, context=context) -def get_or_create_funcop(ctx, callable_, call_jaxpr, pipeline, context=None): - """Get funcOp from cache, or create it from scratch""" +# pylint: disable=too-many-arguments, too-many-positional-arguments +def get_or_create_funcop(ctx, callable_, call_jaxpr, pipeline, context=None, public=False): + """Get funcOp from cache, or create it from scratch + + Args: + ctx: LoweringRuleContext + callable_: python function + call_jaxpr: jaxpr representing callable_ + context: additional context to distinguish different FuncOps + public: whether the visibility should be marked public + + Returns: + FuncOp + """ if context is None: context = tuple() key = (callable_, *context, *pipeline) if func_op := 
get_cached(ctx, key): return func_op - func_op = lower_callable_to_funcop(ctx, callable_, call_jaxpr) + func_op = lower_callable_to_funcop(ctx, callable_, call_jaxpr, public=public) cache(ctx, key, func_op) return func_op -def lower_callable_to_funcop(ctx, callable_, call_jaxpr): - """Lower callable to either a FuncOp""" +def lower_callable_to_funcop(ctx, callable_, call_jaxpr, public=False): + """Lower callable to either a FuncOp + + Args: + ctx: LoweringRuleContext + callable_: python function + call_jaxpr: jaxpr representing callable_ + public: whether the visibility should be marked public + + Returns: + FuncOp + """ if isinstance(call_jaxpr, core.Jaxpr): call_jaxpr = core.ClosedJaxpr(call_jaxpr, ()) @@ -101,10 +139,16 @@ def lower_callable_to_funcop(ctx, callable_, call_jaxpr): name = callable_.__name__ else: name = callable_.func.__name__ + ".partial" + kwargs["name"] = name kwargs["jaxpr"] = call_jaxpr kwargs["effects"] = [] kwargs["name_stack"] = ctx.name_stack + + # Make the visibility of the function public=True + # to avoid elimination by the compiler + kwargs["public"] = public + func_op = mlir.lower_jaxpr_to_fun(**kwargs) if isinstance(callable_, qml.QNode): @@ -135,6 +179,19 @@ def only_single_expval(): func_op.attributes["diff_method"] = ir.StringAttr.get(diff_method) + # Register the decomposition gatesets to the QNode FuncOp + # This will set a queue of gatesets that enables support for multiple + # levels of decomposition in the MLIR decomposition pass + if gateset := getattr(callable_, "decompose_gatesets", []): + func_op.attributes["decompose_gatesets"] = get_mlir_attribute_from_pyval(gateset) + + # Extract the target gate and number of wires from decomposition rules + # and set them as attributes on the FuncOp for use in the MLIR decomposition pass + if target_gate := getattr(callable_, "target_gate", None): + func_op.attributes["target_gate"] = get_mlir_attribute_from_pyval(target_gate) + if num_wires := getattr(callable_, "num_wires", None): 
+ func_op.attributes["num_wires"] = get_mlir_attribute_from_pyval(num_wires) + return func_op diff --git a/frontend/catalyst/passes/builtin_passes.py b/frontend/catalyst/passes/builtin_passes.py index 564e75b401..d4cbe1f84d 100644 --- a/frontend/catalyst/passes/builtin_passes.py +++ b/frontend/catalyst/passes/builtin_passes.py @@ -394,6 +394,25 @@ def circuit(x: float): return PassPipelineWrapper(qnode, "merge-rotations") +def decompose_lowering(qnode): + """ + Specify that the ``-decompose-lowering`` MLIR compiler pass + for applying the compiled decomposition rules to the QNode + recursively. + + Args: + fn (QNode): the QNode to apply the cancel inverses compiler pass to + + Returns: + ~.QNode: + + **Example** + // TODO: add example here + + """ + return PassPipelineWrapper(qnode, "decompose-lowering") # pragma: no cover + + def ions_decomposition(qnode): # pragma: nocover """ Specify that the ``--ions-decomposition`` MLIR compiler pass should be diff --git a/frontend/catalyst/passes/pass_api.py b/frontend/catalyst/passes/pass_api.py index 37872a1f0c..33cefda4a7 100644 --- a/frontend/catalyst/passes/pass_api.py +++ b/frontend/catalyst/passes/pass_api.py @@ -374,6 +374,7 @@ def dictionary_to_list_of_passes(pass_pipeline: PipelineDict | str, *flags, **va def _API_name_to_pass_name(): return { "cancel_inverses": "remove-chained-self-inverse", + "decompose_lowering": "decompose-lowering", "disentangle_cnot": "disentangle-CNOT", "disentangle_swap": "disentangle-SWAP", "merge_rotations": "merge-rotations", diff --git a/frontend/test/lit/test_decomposition.py b/frontend/test/lit/test_decomposition.py index 0c61109e9a..d53d450b25 100644 --- a/frontend/test/lit/test_decomposition.py +++ b/frontend/test/lit/test_decomposition.py @@ -3,8 +3,10 @@ import pathlib import platform from copy import deepcopy +from functools import partial import jax +import numpy as np import pennylane as qml from pennylane.devices.capabilities import OperatorProperties from pennylane.typing 
import TensorLike @@ -29,6 +31,7 @@ # RUN: %PYTHON %s | FileCheck %s # pylint: disable=line-too-long +# pylint: disable=too-many-lines TEST_PATH = os.path.dirname(__file__) @@ -273,7 +276,7 @@ def decompose_to_matrix(): def test_decomposition_rule_wire_param(): """Test decomposition rule with passing a parameter that is a wire/integer""" - @decomposition_rule + @decomposition_rule(is_qreg=False) def Hadamard0(wire: WiresLike): qml.Hadamard(wire) @@ -288,7 +291,7 @@ def circuit(_: float): Hadamard0(int) return qml.probs() - # CHECK: func.func private @Hadamard0([[QBIT:%.+]]: !quantum.bit) -> !quantum.bit + # CHECK: func.func public @Hadamard0([[QBIT:%.+]]: !quantum.bit) -> !quantum.bit # CHECK-NEXT: [[QUBIT_OUT:%.+]] = quantum.custom "Hadamard"() [[QBIT]] : !quantum.bit # CHECK-NEXT: return [[QUBIT_OUT]] : !quantum.bit @@ -303,7 +306,7 @@ def circuit(_: float): def test_decomposition_rule_gate_param_param(): """Test decomposition rule with passing a regular parameter""" - @decomposition_rule(num_params=1) + @decomposition_rule(is_qreg=False, num_params=1) def RX_on_wire_0(param: TensorLike, w0: WiresLike): qml.RX(param, wires=w0) @@ -316,7 +319,7 @@ def circuit_2(_: float): RX_on_wire_0(float, int) return qml.probs() - # CHECK: func.func private @RX_on_wire_0([[PARAM_TENSOR:%.+]]: tensor, [[QUBIT:%.+]]: !quantum.bit) -> !quantum.bit + # CHECK: func.func public @RX_on_wire_0([[PARAM_TENSOR:%.+]]: tensor, [[QUBIT:%.+]]: !quantum.bit) -> !quantum.bit # CHECK-NEXT: [[PARAM:%.+]] = tensor.extract [[PARAM_TENSOR]][] : tensor # CHECK-NEXT: [[QUBIT_1:%.+]] = quantum.custom "RX"([[PARAM]]) [[QUBIT]] : !quantum.bit # CHECK-NEXT: return [[QUBIT_1]] : !quantum.bit @@ -336,7 +339,7 @@ def test_multiple_decomposition_rules(): @decomposition_rule def identity(): ... 
- @decomposition_rule(num_params=1) + @decomposition_rule(is_qreg=True) def all_wires_rx(param: TensorLike, w0: WiresLike, w1: WiresLike, w2: WiresLike): qml.RX(param, wires=w0) qml.RX(param, wires=w1) @@ -355,8 +358,8 @@ def circuit_3(_: float): qml.Hadamard(0) return qml.probs() - # CHECK: func.func private @identity - # CHECK: func.func private @all_wires_rx + # CHECK: func.func public @identity + # CHECK: func.func public @all_wires_rx print(circuit_3.mlir) qml.capture.disable() @@ -384,7 +387,7 @@ def circuit_4(_: float): qml.Hadamard(0) return qml.probs() - # CHECK: func.func private @shaped_wires_rule([[QREG:%.+]]: !quantum.reg, [[PARAM_TENSOR:%.+]]: tensor, [[QUBITS:%.+]]: tensor<3xi64>) -> !quantum.reg + # CHECK: func.func public @shaped_wires_rule([[QREG:%.+]]: !quantum.reg, [[PARAM_TENSOR:%.+]]: tensor, [[QUBITS:%.+]]: tensor<3xi64>) -> !quantum.reg # CHECK-NEXT: [[IDX_0:%.+]] = stablehlo.slice [[QUBITS]] [0:1] : (tensor<3xi64>) -> tensor<1xi64> # CHECK-NEXT: [[RIDX_0:%.+]] = stablehlo.reshape [[IDX_0]] : (tensor<1xi64>) -> tensor # CHECK-NEXT: [[EXTRACTED:%.+]] = tensor.extract [[RIDX_0]][] : tensor @@ -409,7 +412,7 @@ def shaped_wires_rule(param: TensorLike, wires: WiresLike): qml.RX(param, wires=wires[1]) qml.RX(param, wires=wires[2]) - @decomposition_rule(num_params=1, is_qreg=False) + @decomposition_rule(is_qreg=False, num_params=1) def expanded_wires_rule(param: TensorLike, w1, w2, w3): shaped_wires_rule(param, [w1, w2, w3]) @@ -421,7 +424,7 @@ def circuit_5(_: float): qml.Hadamard(0) return qml.probs() - # CHECK: func.func private @expanded_wires_rule(%arg0: tensor, %arg1: !quantum.bit, %arg2: !quantum.bit, %arg3: !quantum.bit) -> (!quantum.bit, !quantum.bit, !quantum.bit) + # CHECK: func.func public @expanded_wires_rule(%arg0: tensor, %arg1: !quantum.bit, %arg2: !quantum.bit, %arg3: !quantum.bit) -> (!quantum.bit, !quantum.bit, !quantum.bit) print(circuit_5.mlir) qml.capture.disable() @@ -452,7 +455,7 @@ def circuit_6(): cond_RX(float, 
jax.core.ShapedArray((1,), int)) return qml.probs() - # CHECK: func.func private @cond_RX([[QREG:%.+]]: !quantum.reg, [[PARAM_TENSOR:%.+]]: tensor, [[QUBITS:%.+]]: tensor<1xi64>) -> !quantum.reg + # CHECK: func.func public @cond_RX([[QREG:%.+]]: !quantum.reg, [[PARAM_TENSOR:%.+]]: tensor, [[QUBITS:%.+]]: tensor<1xi64>) -> !quantum.reg # CHECK-NEXT: [[ZERO:%.+]] = stablehlo.constant dense<0.000000e+00> : tensor # CHECK-NEXT: [[COND_TENSOR:%.+]] = stablehlo.compare NE, [[PARAM_TENSOR]], [[ZERO]], FLOAT : (tensor, tensor) -> tensor # CHECK-NEXT: [[COND:%.+]] = tensor.extract [[COND_TENSOR]][] : tensor @@ -479,17 +482,17 @@ def test_decomposition_rule_caller(): qml.capture.enable() @decomposition_rule(is_qreg=True) - def Op1_decomp(_: TensorLike, wires: WiresLike): + def rule_op1_decomp(_: TensorLike, wires: WiresLike): qml.Hadamard(wires=wires[0]) qml.Hadamard(wires=[1]) @decomposition_rule(is_qreg=True) - def Op2_decomp(param: TensorLike, wires: WiresLike): + def rule_op2_decomp(param: TensorLike, wires: WiresLike): qml.RX(param, wires=wires[0]) def decomps_caller(param: TensorLike, wires: WiresLike): - Op1_decomp(param, wires) - Op2_decomp(param, wires) + rule_op1_decomp(param, wires) + rule_op2_decomp(param, wires) @qml.qjit(autograph=False) @qml.qnode(qml.device("lightning.qubit", wires=1)) @@ -500,11 +503,643 @@ def circuit_7(): decomps_caller(float, jax.core.ShapedArray((2,), int)) return qml.probs() - # CHECK: func.func private @Op1_decomp(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<2xi64>) -> !quantum.reg - # CHECK: func.func private @Op2_decomp(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<2xi64>) -> !quantum.reg - + # CHECK: func.func public @rule_op1_decomp(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<2xi64>) -> !quantum.reg + # CHECK: func.func public @rule_op2_decomp(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<2xi64>) -> !quantum.reg print(circuit_7.mlir) qml.capture.disable() test_decomposition_rule_caller() + + +def 
test_decompose_gateset_without_graph(): + """Test the decompose transform to a target gate set without the graph decomposition.""" + + qml.capture.enable() + + @qml.qjit(target="mlir") + @partial(qml.transforms.decompose, gate_set={"RX", "RZ"}) + @qml.qnode(qml.device("lightning.qubit", wires=1)) + # CHECK: public @circuit_8() -> tensor attributes {diff_method = "adjoint", llvm.linkage = #llvm.linkage, qnode} + def circuit_8(): + return qml.expval(qml.Z(0)) + + print(circuit_8.mlir) + + qml.capture.disable() + + +test_decompose_gateset_without_graph() + + +def test_decompose_gateset_with_graph(): + """Test the decompose transform to a target gate set with the graph decomposition.""" + + qml.capture.enable() + qml.decomposition.enable_graph() + + @qml.qjit(target="mlir") + @partial(qml.transforms.decompose, gate_set={"RX"}) + @qml.qnode(qml.device("lightning.qubit", wires=1)) + # CHECK: public @simple_circuit_9() -> tensor attributes {decompose_gatesets + def simple_circuit_9(): + return qml.expval(qml.Z(0)) + + print(simple_circuit_9.mlir) + + @qml.qjit(target="mlir") + @partial(qml.transforms.decompose, gate_set={"RX", "RZ"}) + @qml.qnode(qml.device("lightning.qubit", wires=1)) + # CHECK-DAG: %0 = transform.apply_registered_pass "decompose-lowering" + # CHECK: public @circuit_9() -> tensor attributes {decompose_gatesets + def circuit_9(): + return qml.expval(qml.Z(0)) + + print(circuit_9.mlir) + + qml.decomposition.disable_graph() + qml.capture.disable() + + +test_decompose_gateset_with_graph() + + +def test_decompose_gateset_operator_with_graph(): + """Test the decompose transform to a target gate set with the graph decomposition.""" + + qml.capture.enable() + qml.decomposition.enable_graph() + + @qml.qjit(target="mlir") + @partial(qml.transforms.decompose, gate_set={qml.RX}) + @qml.qnode(qml.device("lightning.qubit", wires=1)) + # CHECK: public @simple_circuit_10() -> tensor attributes {decompose_gatesets + def simple_circuit_10(): + return qml.expval(qml.Z(0)) 
+ + print(simple_circuit_10.mlir) + + @qml.qjit(target="mlir") + @partial( + qml.transforms.decompose, gate_set={qml.RX, qml.RZ, "PauliZ", qml.PauliX, qml.Hadamard} + ) + @qml.qnode(qml.device("lightning.qubit", wires=1)) + # CHECK: public @circuit_10() -> tensor attributes {decompose_gatesets + def circuit_10(): + return qml.expval(qml.Z(0)) + + print(circuit_10.mlir) + + @qml.qjit(target="mlir") + @partial( + qml.transforms.decompose, gate_set={qml.RX, qml.RZ, qml.PauliZ, qml.PauliX, qml.Hadamard} + ) + @qml.qnode(qml.device("lightning.qubit", wires=1)) + # CHECK-DAG: %0 = transform.apply_registered_pass "decompose-lowering" + # CHECK: public @circuit_11() -> tensor attributes {decompose_gatesets + def circuit_11(): + return qml.expval(qml.Z(0)) + + print(circuit_11.mlir) + + qml.decomposition.disable_graph() + qml.capture.disable() + + +test_decompose_gateset_operator_with_graph() + + +def test_decompose_gateset_with_rotxzx(): + """Test the decompose transform with a custom operator with the graph decomposition.""" + + qml.capture.enable() + qml.decomposition.enable_graph() + + @qml.qjit(target="mlir") + @partial(qml.transforms.decompose, gate_set={"RotXZX"}) + @qml.qnode(qml.device("lightning.qubit", wires=1)) + # CHECK: public @simple_circuit_12() -> tensor attributes {decompose_gatesets + def simple_circuit_12(): + return qml.expval(qml.Z(0)) + + print(simple_circuit_12.mlir) + + @qml.qjit(target="mlir") + @partial(qml.transforms.decompose, gate_set={qml.ftqc.RotXZX}) + @qml.qnode(qml.device("lightning.qubit", wires=1)) + # CHECK-DAG: %0 = transform.apply_registered_pass "decompose-lowering" + # CHECK: public @circuit_12() -> tensor attributes {decompose_gatesets + def circuit_12(): + return qml.expval(qml.Z(0)) + + print(circuit_12.mlir) + + qml.decomposition.disable_graph() + qml.capture.disable() + + +test_decompose_gateset_with_rotxzx() + + +def test_decomposition_rule_name(): + """Test the name of the decomposition rule is not updated with circuit 
instantiation.""" + + qml.capture.enable() + qml.decomposition.enable_graph() + + @decomposition_rule + def _ry_to_rz_rx(phi, wires: WiresLike, **__): + """Decomposition of RY gate using RZ and RX gates.""" + qml.RZ(-np.pi / 2, wires=wires) + qml.RX(phi, wires=wires) + qml.RZ(np.pi / 2, wires=wires) + + @decomposition_rule + def _rot_to_rz_ry_rz(phi, theta, omega, wires: WiresLike, **__): + """Decomposition of Rot gate using RZ and RY gates.""" + qml.RZ(phi, wires=wires) + qml.RY(theta, wires=wires) + qml.RZ(omega, wires=wires) + + @decomposition_rule + def _u2_phaseshift_rot_decomposition(phi, delta, wires, **__): + """Decomposition of U2 gate using Rot and PhaseShift gates.""" + pi_half = qml.math.ones_like(delta) * (np.pi / 2) + qml.Rot(delta, pi_half, -delta, wires=wires) + qml.PhaseShift(delta, wires=wires) + qml.PhaseShift(phi, wires=wires) + + @decomposition_rule + def _xzx_decompose(phi, theta, omega, wires, **__): + """Decomposition of Rot gate using RX and RZ gates in XZX format.""" + qml.RX(phi, wires=wires) + qml.RZ(theta, wires=wires) + qml.RX(omega, wires=wires) + + @qml.qjit(target="mlir") + @partial(qml.transforms.decompose, gate_set={"RX", "RZ", "PhaseShift"}) + @qml.qnode(qml.device("lightning.qubit", wires=3)) + # CHECK-DAG: %0 = transform.apply_registered_pass "decompose-lowering" + # CHECK: public @circuit_13() -> tensor attributes {decompose_gatesets + def circuit_13(): + _ry_to_rz_rx(float, int) + _rot_to_rz_ry_rz(float, float, float, int) + _u2_phaseshift_rot_decomposition(float, float, int) + _xzx_decompose(float, float, float, int) + return qml.expval(qml.Z(0)) + + # CHECK: func.func public @_ry_to_rz_rx(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor) -> !quantum.reg + # CHECK: func.func public @_rot_to_rz_ry_rz(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor) -> !quantum.reg + # CHECK: func.func public @_u2_phaseshift_rot_decomposition(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor, %arg3: tensor) -> 
!quantum.reg + # CHECK: func.func public @_xzx_decompose(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor) -> !quantum.reg + print(circuit_13.mlir) + + qml.decomposition.disable_graph() + qml.capture.disable() + + +test_decomposition_rule_name() + + +def test_decomposition_rule_name_update(): + """Test the name of the decomposition rule is updated in the MLIR output.""" + + qml.capture.enable() + qml.decomposition.enable_graph() + + @qml.register_resources({qml.RZ: 2, qml.RX: 1}) + def rz_rx(phi, wires: WiresLike, **__): + """Decomposition of RY gate using RZ and RX gates.""" + qml.RZ(-np.pi / 2, wires=wires) + qml.RX(phi, wires=wires) + qml.RZ(np.pi / 2, wires=wires) + + @qml.register_resources({qml.RZ: 2, qml.RY: 1}) + def rz_ry_rz(phi, theta, omega, wires: WiresLike, **__): + """Decomposition of Rot gate using RZ and RY gates.""" + qml.RZ(phi, wires=wires) + qml.RY(theta, wires=wires) + qml.RZ(omega, wires=wires) + + @qml.register_resources({qml.RY: 1, qml.PhaseShift: 1}) + def ry_gp(wires: WiresLike, **__): + """Decomposition of PauliY gate using RY and GlobalPhase gates.""" + qml.RY(np.pi, wires=wires) + qml.GlobalPhase(-np.pi / 2, wires=wires) + + @qml.qjit(target="mlir") + @partial( + qml.transforms.decompose, + gate_set={"RX", "RZ", "PhaseShift"}, + fixed_decomps={ + qml.RY: rz_rx, + qml.Rot: rz_ry_rz, + qml.PauliY: ry_gp, + }, + ) + @qml.qnode(qml.device("lightning.qubit", wires=3)) + # CHECK-DAG: %0 = transform.apply_registered_pass "decompose-lowering" + # CHECK: public @circuit_14() -> tensor attributes {decompose_gatesets + def circuit_14(): + qml.RY(0.5, wires=0) + qml.Rot(0.1, 0.2, 0.3, wires=1) + qml.PauliY(wires=2) + return qml.expval(qml.Z(0)) + + # CHECK-DAG: func.func public @rz_ry_rz(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor<1xi64>) -> !quantum.reg + # CHECK-DAG: func.func public @rz_rx(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg + # CHECK-DAG: 
func.func public @ry_gp(%arg0: !quantum.reg, %arg1: tensor<1xi64>) -> !quantum.reg + print(circuit_14.mlir) + + qml.decomposition.disable_graph() + qml.capture.disable() + + +test_decomposition_rule_name_update() + + +def test_decomposition_rule_name_update_multi_qubits(): + """Test the name of the decomposition rule with multi-qubit gates.""" + + qml.capture.enable() + qml.decomposition.enable_graph() + + @qml.qjit(target="mlir") + @partial( + qml.transforms.decompose, + gate_set={"RY", "RX", "CNOT", "Hadamard", "GlobalPhase"}, + ) + @qml.qnode(qml.device("lightning.qubit", wires=4)) + # CHECK-DAG: %0 = transform.apply_registered_pass "decompose-lowering" + # CHECK: public @circuit_15() -> tensor attributes {decompose_gatesets + def circuit_15(): + qml.SingleExcitation(0.5, wires=[0, 1]) + qml.SingleExcitationPlus(0.5, wires=[0, 1]) + qml.SingleExcitationMinus(0.5, wires=[0, 1]) + qml.DoubleExcitation(0.5, wires=[0, 1, 2, 3]) + return qml.expval(qml.Z(0)) + + # CHECK-DAG: func.func public @_cry(%arg0: !quantum.reg, %arg1: tensor<1xf64>, %arg2: tensor<2xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 2 : i64, target_gate = "CRY"} + # CHECK-DAG: func.func public @_s_phaseshift(%arg0: !quantum.reg, %arg1: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "S"} + # CHECK-DAG: func.func public @_phaseshift_to_rz_gp(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "PhaseShift"} + # CHECK-DAG: func.func public @_rz_to_ry_rx(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "RZ"} + # CHECK-DAG: func.func public @_rot_to_rz_ry_rz(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : 
i64, target_gate = "Rot"} + # CHECK-DAG: func.func public @_doublexcit(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<4xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 4 : i64, target_gate = "DoubleExcitation"} + # CHECK-DAG: func.func public @_single_excitation_decomp(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<2xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 2 : i64, target_gate = "SingleExcitation"} + print(circuit_15.mlir) + + qml.decomposition.disable_graph() + qml.capture.disable() + + +test_decomposition_rule_name_update_multi_qubits() + + +def test_decomposition_rule_name_adjoint(): + """Test decomposition rule with qml.adjoint.""" + + qml.capture.enable() + qml.decomposition.enable_graph() + + @qml.qjit(target="mlir") + @partial( + qml.transforms.decompose, + gate_set={"RY", "RX", "CZ", "GlobalPhase", "Adjoint(SingleExcitation)"}, + ) + @qml.qnode(qml.device("lightning.qubit", wires=4)) + # CHECK-DAG: %0 = transform.apply_registered_pass "decompose-lowering" + def circuit_16(x: float): + # CHECK-DAG: %1 = quantum.adjoint(%0) : !quantum.reg + # CHECK-DAG: %2 = quantum.adjoint(%1) : !quantum.reg + # CHECK-DAG: %3 = quantum.adjoint(%2) : !quantum.reg + # CHECK-DAG: %4 = quantum.adjoint(%3) : !quantum.reg + qml.adjoint(qml.CNOT)(wires=[0, 1]) + qml.adjoint(qml.Hadamard)(wires=2) + qml.adjoint(qml.RZ)(0.5, wires=3) + qml.adjoint(qml.SingleExcitation)(0.1, wires=[0, 1]) + qml.adjoint(qml.SingleExcitation(x, wires=[0, 1])) + return qml.expval(qml.Z(0)) + + # CHECK-DAG: func.func public @_single_excitation_decomp(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<2xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 2 : i64, target_gate = "SingleExcitation"} + # CHECK-DAG: func.func public @_hadamard_to_rz_ry(%arg0: !quantum.reg, %arg1: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "Hadamard"} + # CHECK-DAG: 
func.func public @_rz_to_ry_rx(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "RZ"} + # CHECK-DAG: func.func public @_rot_to_rz_ry_rz(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "Rot"} + # CHECK-DAG: func.func public @_cnot_to_cz_h(%arg0: !quantum.reg, %arg1: tensor<2xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 2 : i64, target_gate = "CNOT"} + print(circuit_16.mlir) + + qml.decomposition.disable_graph() + qml.capture.disable() + + +test_decomposition_rule_name_adjoint() + + +def test_decomposition_rule_name_ctrl(): + """Test decomposition rule with qml.ctrl.""" + + qml.capture.enable() + qml.decomposition.enable_graph() + + @qml.qjit(target="mlir") + @partial( + qml.transforms.decompose, + gate_set={"RX", "RZ", "H", "CZ"}, + ) + @qml.qnode(qml.device("lightning.qubit", wires=2)) + # CHECK-DAG: %0 = transform.apply_registered_pass "decompose-lowering" + # CHECK{LITERAL}: func.func public @circuit_17() -> tensor attributes {decompose_gatesets + def circuit_17(): + # CHECK: %out_qubits:2 = quantum.custom "CRY"(%cst) %1, %2 : !quantum.bit, !quantum.bit + # CHECK-NEXT: %out_qubits_0:2 = quantum.custom "CNOT"() %out_qubits#0, %out_qubits#1 : !quantum.bit, !quantum.bit + qml.ctrl(qml.RY, control=0)(0.5, 1) + qml.ctrl(qml.PauliX, control=0)(1) + return qml.expval(qml.Z(0)) + + # CHECK-DAG: func.func public @_cnot_to_cz_h(%arg0: !quantum.reg, %arg1: tensor<2xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 2 : i64, target_gate = "CNOT"} + # CHECK-DAG: func.func public @_cry(%arg0: !quantum.reg, %arg1: tensor<1xf64>, %arg2: tensor<2xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 2 : i64, target_gate = "CRY"} + # CHECK-DAG: func.func public 
@_ry_to_rz_rx(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "RY"} + # CHECK-DAG: func.func public @_rot_to_rz_ry_rz(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "Rot"} + print(circuit_17.mlir) + + qml.decomposition.disable_graph() + qml.capture.disable() + + +test_decomposition_rule_name_ctrl() + + +def test_qft_decomposition(): + """Test the decomposition of the QFT""" + + qml.capture.enable() + qml.decomposition.enable_graph() + + @qml.qjit(target="mlir") + @partial( + qml.transforms.decompose, + gate_set={"RX", "RY", "CNOT", "GlobalPhase"}, + ) + @qml.qnode(qml.device("lightning.qubit", wires=4)) + # CHECK: %0 = transform.apply_registered_pass "decompose-lowering" + # CHECK: func.func public @circuit_18(%arg0: tensor<3xf64>) -> tensor attributes {decompose_gatesets + def circuit_18(): + # %6 = scf.for %arg1 = %c0 to %c4 step %c1 iter_args(%arg2 = %0) -> (!quantum.reg) { + # %23 = scf.for %arg3 = %c0 to %22 step %c1 iter_args(%arg4 = %21) -> (!quantum.reg) { + # %7 = scf.for %arg1 = %c0 to %c2 step %c1 iter_args(%arg2 = %6) -> (!quantum.reg) { + qml.QFT(wires=[0, 1, 2, 3]) + return qml.expval(qml.Z(0)) + + # CHECK-DAG: func.func public @_cphase_to_rz_cnot(%arg0: !quantum.reg, %arg1: tensor<1xf64>, %arg2: tensor<2xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 2 : i64, target_gate = "ControlledPhaseShift"} + # CHECK-DAG: func.func public @_rz_to_ry_rx(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "RZ"} + # CHECK-DAG: func.func public @_rot_to_rz_ry_rz(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, 
num_wires = 1 : i64, target_gate = "Rot"} + # CHECK-DAG: func.func public @_swap_to_cnot(%arg0: !quantum.reg, %arg1: tensor<2xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 2 : i64, target_gate = "SWAP"} + # CHECK-DAG: func.func public @_hadamard_to_rz_ry(%arg0: !quantum.reg, %arg1: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "Hadamard"} + print(circuit_18.mlir) + + qml.decomposition.disable_graph() + qml.capture.disable() + + +test_qft_decomposition() + + +def test_decompose_lowering_with_other_passes(): + """Test the decompose lowering pass with other passes in a pass pipeline.""" + + qml.capture.enable() + qml.decomposition.enable_graph() + + @qml.qjit(target="mlir") + @qml.transforms.merge_rotations + @qml.transforms.cancel_inverses + @partial( + qml.transforms.decompose, + gate_set={"RZ", "RY", "CNOT", "GlobalPhase"}, + ) + @qml.qnode(qml.device("lightning.qubit", wires=1)) + # CHECK: module attributes {transform.with_named_sequence} { + # CHECK-NEXT: transform.named_sequence @__transform_main(%arg0: !transform.op<"builtin.module">) { + # CHECK-NEXT: [[ONE:%.+]] = transform.apply_registered_pass "decompose-lowering" to %arg0 : (!transform.op<"builtin.module">) -> !transform.op<"builtin.module"> + # CHECK-NEXT: [[TWO:%.+]] = transform.apply_registered_pass "remove-chained-self-inverse" to [[ONE]] : (!transform.op<"builtin.module">) -> !transform.op<"builtin.module"> + # CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "merge-rotations" to [[TWO]] : (!transform.op<"builtin.module">) -> !transform.op<"builtin.module"> + # CHECK-NEXT: transform.yield + # CHECK-NEXT: } + def circuit_19(): + + # CHECK: [[OUT_0:%.+]] = quantum.custom "PauliX"() %1 : !quantum.bit + # CHECK-NEXT: [[OUT_1:%.+]] = quantum.custom "PauliX"() [[OUT_0]] : !quantum.bit + # CHECK-NEXT: [[OUT_2:%.+]] = quantum.custom "RX"(%cst_0) [[OUT_1]] : !quantum.bit + # CHECK-NEXT: {{%.+}} = quantum.custom 
"RX"(%cst) [[OUT_2]] : !quantum.bit + qml.PauliX(0) + qml.PauliX(0) + qml.RX(0.1, wires=0) + qml.RX(-0.1, wires=0) + return qml.expval(qml.PauliX(0)) + + # CHECK-DAG: func.func public @_paulix_to_rx(%arg0: !quantum.reg, %arg1: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "PauliX"} + # CHECK-DAG: func.func public @_rx_to_rz_ry(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "RX"} + print(circuit_19.mlir) + + qml.decomposition.disable_graph() + qml.capture.disable() + + +test_decompose_lowering_with_other_passes() + + +def test_decompose_lowering_multirz(): + """Test the decompose lowering pass with MultiRZ in the gate set.""" + + qml.capture.enable() + qml.decomposition.enable_graph() + + @qml.qjit(target="mlir") + @partial( + qml.transforms.decompose, + gate_set={"CNOT", "RZ"}, + ) + @qml.qnode(qml.device("lightning.qubit", wires=3)) + # CHECK: %0 = transform.apply_registered_pass "decompose-lowering" + def circuit_20(x: float): + # CHECK: [[EXTRACTED:%.+]] = tensor.extract %arg0[] : tensor + # CHECK-NEXT: [[OUT_QUBITS:%.+]] = quantum.multirz([[EXTRACTED]]) %1 : !quantum.bit + # CHECK-NEXT: [[BIT_1:%.+]] = quantum.extract %0[ 1] : !quantum.reg -> !quantum.bit + # CHECK-NEXT: [[EXTRACTED_0:%.+]] = tensor.extract %arg0[] : tensor + # CHECK-NEXT: [[OUT_QUBITS_1:%.+]] = quantum.multirz([[EXTRACTED_0]]) [[OUT_QUBITS]], [[BIT_1]] : !quantum.bit, !quantum.bit + # CHECK-NEXT: [[BIT_2:%.+]] = quantum.extract %0[ 2] : !quantum.reg -> !quantum.bit + # CHECK-NEXT: [[EXTRACTED_2:%.+]] = tensor.extract %arg0[] : tensor + # CHECK-NEXT: {{%.+}} = quantum.multirz([[EXTRACTED_2]]) {{%.+}}, {{%.+}}, [[BIT_2]] : !quantum.bit, !quantum.bit, !quantum.bit + qml.MultiRZ(x, wires=[0]) + qml.MultiRZ(x, wires=[0, 1]) + qml.MultiRZ(x, wires=[1, 0, 2]) + return qml.expval(qml.PauliX(0)) + + # CHECK-DAG: func.func public 
@_multi_rz_decomposition_wires_1(%arg0: !quantum.reg, %arg1: tensor<1xf64>, %arg2: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "MultiRZ"} + # CHECK-DAG: func.func public @_multi_rz_decomposition_wires_2(%arg0: !quantum.reg, %arg1: tensor<1xf64>, %arg2: tensor<2xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 2 : i64, target_gate = "MultiRZ"} + # CHECK-DAG: func.func public @_multi_rz_decomposition_wires_3(%arg0: !quantum.reg, %arg1: tensor<1xf64>, %arg2: tensor<3xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 3 : i64, target_gate = "MultiRZ"} + # CHECK-DAG: %0 = scf.for %arg3 = %c0 to %c2 step %c1 iter_args(%arg4 = %arg0) -> (!quantum.reg) + # CHECK-DAG: %5 = scf.for %arg3 = %c1 to %c3 step %c1 iter_args(%arg4 = %4) -> (!quantum.reg) + print(circuit_20.mlir) + + qml.decomposition.disable_graph() + qml.capture.disable() + + +test_decompose_lowering_multirz() + + +def test_decompose_lowering_with_ordered_passes(): + """Test the decompose lowering pass with other passes in a specific order in a pass pipeline.""" + + qml.capture.enable() + qml.decomposition.enable_graph() + + @qml.qjit(target="mlir") + @partial( + qml.transforms.decompose, + gate_set={"RZ", "RY", "CNOT", "GlobalPhase"}, + ) + @qml.transforms.merge_rotations + @qml.transforms.cancel_inverses + @qml.qnode(qml.device("lightning.qubit", wires=1)) + # CHECK: module attributes {transform.with_named_sequence} { + # CHECK-NEXT: transform.named_sequence @__transform_main(%arg0: !transform.op<"builtin.module">) { + # CHECK-NEXT: [[FIRST:%.+]] = transform.apply_registered_pass "remove-chained-self-inverse" to %arg0 : (!transform.op<"builtin.module">) -> !transform.op<"builtin.module"> + # CHECK-NEXT: [[SECOND:%.+]] = transform.apply_registered_pass "merge-rotations" to [[FIRST]] : (!transform.op<"builtin.module">) -> !transform.op<"builtin.module"> + # CHECK-NEXT: {{%.+}} = 
transform.apply_registered_pass "decompose-lowering" to [[SECOND]] : (!transform.op<"builtin.module">) -> !transform.op<"builtin.module"> + # CHECK-NEXT: transform.yield + # CHECK-NEXT: } + def circuit_21(x: float): + # CHECK: [[OUT:%.+]] = quantum.custom "PauliX"() %1 : !quantum.bit + # CHECK-NEXT: [[OUT_0:%.+]] = quantum.custom "PauliX"() [[OUT]] : !quantum.bit + # CHECK-NEXT: [[EXTRACTED:%.+]] = tensor.extract %arg0[] : tensor + # CHECK-NEXT: [[OUT_1:%.+]] = quantum.custom "RX"([[EXTRACTED]]) [[OUT_0]] : !quantum.bit + # CHECK-NEXT: [[NEGATED:%.+]] = stablehlo.negate %arg0 : tensor + # CHECK-NEXT: [[EXTRACTED_2:%.+]] = tensor.extract [[NEGATED]][] : tensor + # CHECK-NEXT: {{%.+}} = quantum.custom "RX"([[EXTRACTED_2]]) [[OUT_1]] : !quantum.bit + qml.PauliX(0) + qml.PauliX(0) + qml.RX(x, wires=0) + qml.RX(-x, wires=0) + return qml.expval(qml.PauliX(0)) + + # CHECK-DAG: func.func public @_paulix_to_rx(%arg0: !quantum.reg, %arg1: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "PauliX"} + # CHECK-DAG: func.func public @_rx_to_rz_ry(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "RX"} + # CHECK-DAG: func.func public @_rot_to_rz_ry_rz(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "Rot"} + print(circuit_21.mlir) + + qml.decomposition.disable_graph() + qml.capture.disable() + + +test_decompose_lowering_with_ordered_passes() + + +def test_decompose_lowering_with_gphase(): + """Test the decompose lowering pass with GlobalPhase.""" + + qml.capture.enable() + qml.decomposition.enable_graph() + + @qml.qjit(target="mlir") + @partial( + qml.transforms.decompose, + gate_set={"RX", "RY", "GlobalPhase"}, + ) + @qml.qnode(qml.device("lightning.qubit", wires=3)) + # CHECK: %0 = 
transform.apply_registered_pass "decompose-lowering" + def circuit_22(): + # CHECK: quantum.gphase(%cst_0) : + # CHECK-NEXT: [[EXTRACTED:%.+]] = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit + # CHECK-NEXT: [[OUT_QUBITS:%.+]] = quantum.custom "PhaseShift"(%cst) [[EXTRACTED]] : !quantum.bit + # CHECK-NEXT: {{%.+}} = quantum.custom "PhaseShift"(%cst) [[OUT_QUBITS]] : !quantum.bit + qml.GlobalPhase(0.5) + qml.ctrl(qml.GlobalPhase, control=0)(0.3) + qml.ctrl(qml.GlobalPhase, control=0)(phi=0.3, wires=[1, 2]) + return qml.expval(qml.PauliX(0)) + + # CHECK-DAG: func.func public @_phaseshift_to_rz_gp(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "PhaseShift"} + # CHECK-DAG: func.func public @_rz_to_ry_rx(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "RZ"} + print(circuit_22.mlir) + + qml.decomposition.disable_graph() + qml.capture.disable() + + +test_decompose_lowering_with_gphase() + + +def test_decompose_lowering_alt_decomps(): + """Test the decompose lowering pass with alternative decompositions.""" + + qml.capture.enable() + qml.decomposition.enable_graph() + + @qml.register_resources({qml.RY: 1}) + def custom_rot_cheap(params, wires: WiresLike): + qml.RY(params[1], wires=wires) + + @qml.qjit(target="mlir") + @partial( + qml.transforms.decompose, + gate_set={"RY", "RZ"}, + alt_decomps={qml.Rot: [custom_rot_cheap]}, + ) + @qml.qnode(qml.device("lightning.qubit", wires=3), shots=1000) + def circuit_23(x: float, y: float): + qml.Rot(x, y, x + y, wires=1) + return qml.expval(qml.PauliZ(0)) + + # CHECK-DAG: func.func public @custom_rot_cheap(%arg0: !quantum.reg, %arg1: tensor<3xf64>, %arg2: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "Rot"} + print(circuit_23.mlir) + + 
qml.decomposition.disable_graph() + qml.capture.disable() + + +test_decompose_lowering_alt_decomps() + + +def test_decompose_lowering_with_tensorlike(): + """Test the decompose lowering pass with fixed decompositions + using TensorLike parameters.""" + + qml.capture.enable() + qml.decomposition.enable_graph() + + @qml.register_resources({qml.RZ: 2, qml.RY: 1}) + def custom_rot(params: TensorLike, wires: WiresLike): + qml.RZ(params[0], wires=wires) + qml.RY(params[1], wires=wires) + qml.RZ(params[2], wires=wires) + + @qml.register_resources({qml.RZ: 1, qml.CNOT: 4}) + def custom_multirz(params: TensorLike, wires: WiresLike): + qml.CNOT(wires=(wires[2], wires[1])) + qml.CNOT(wires=(wires[1], wires[0])) + qml.RZ(params[0], wires=wires[0]) + qml.CNOT(wires=(wires[1], wires[0])) + qml.CNOT(wires=(wires[2], wires[1])) + + @qml.qjit(target="mlir") + @partial( + qml.transforms.decompose, + gate_set={"RY", "RX", qml.CNOT}, + fixed_decomps={qml.Rot: custom_rot, qml.MultiRZ: custom_multirz}, + ) + @qml.qnode(qml.device("lightning.qubit", wires=3), shots=1000) + def circuit_24(x: float, y: float): + qml.Rot(x, y, x + y, wires=1) + qml.MultiRZ(x + y, wires=[0, 1, 2]) + return qml.expval(qml.PauliZ(0)) + + # CHECK-DAG: func.func public @custom_multirz_wires_3(%arg0: !quantum.reg, %arg1: tensor<1xf64>, %arg2: tensor<3xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 3 : i64, target_gate = "MultiRZ"} + # CHECK-DAG: func.func public @_rz_to_ry_rx(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "RZ"} + # CHECK-DAG: func.func public @custom_rot(%arg0: !quantum.reg, %arg1: tensor<3xf64>, %arg2: tensor<1xi64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage, num_wires = 1 : i64, target_gate = "Rot"} + print(circuit_24.mlir) + + qml.decomposition.disable_graph() + qml.capture.disable() + + +test_decompose_lowering_with_tensorlike() diff --git 
a/frontend/test/lit/test_dynamic_qubit_allocation.py b/frontend/test/lit/test_dynamic_qubit_allocation.py index 85b3ede88a..8b4d7f114c 100644 --- a/frontend/test/lit/test_dynamic_qubit_allocation.py +++ b/frontend/test/lit/test_dynamic_qubit_allocation.py @@ -82,7 +82,8 @@ def test_basic_dynalloc(): # CHECK: [[CNOTout:%.+]]:2 = quantum.custom "CNOT"() [[dyn_bit2]], [[dev_bit1]] # CHECK: [[insert0:%.+]] = quantum.insert [[dyn_qreg]][ 1], [[Xout]] # CHECK: [[insert1:%.+]] = quantum.insert [[insert0]][ 2], [[CNOTout]]#0 - # CHECK: quantum.dealloc [[insert1]] + # CHECK: [[insert2:%.+]] = quantum.insert [[insert1]][ 3] + # CHECK: quantum.dealloc [[insert2]] with qml.allocate(4) as qs1: qml.X(qs1[1]) diff --git a/frontend/test/lit/test_from_plxpr.py b/frontend/test/lit/test_from_plxpr.py index 7fd8618cc6..14cea4a8e6 100644 --- a/frontend/test/lit/test_from_plxpr.py +++ b/frontend/test/lit/test_from_plxpr.py @@ -18,6 +18,8 @@ """Lit tests for the PLxPR to JAXPR with quantum primitives pipeline""" +from functools import partial + import pennylane as qml @@ -45,7 +47,7 @@ def main(): print(main.mlir) - qml.capture.enable() + qml.capture.disable() test_conditional_capture() @@ -362,3 +364,60 @@ def circuit(): test_pass_application() + + +def test_pass_decomposition(): + """Application of pass decorator with decomposition.""" + + dev = qml.device("null.qubit", wires=1) + + qml.capture.enable() + qml.decomposition.enable_graph() + + @qml.qjit(target="mlir") + @qml.transforms.cancel_inverses + @qml.transforms.merge_rotations + @partial(qml.transforms.decompose, gate_set={"RX", "RZ"}) + @qml.qnode(dev) + def circuit1(): + return qml.probs() + + # CHECK: [[first_pass:%.+]] = transform.apply_registered_pass "decompose-lowering" + # CHECK-NEXT: [[second_pass:%.+]] = transform.apply_registered_pass "merge-rotations" + # CHECK-NEXT: transform.apply_registered_pass "remove-chained-self-inverse" to [[second_pass]] + + print(circuit1.mlir) + + @qml.qjit(target="mlir") + 
@qml.transforms.cancel_inverses + @partial(qml.transforms.decompose, gate_set={"RX", "RZ"}) + @qml.transforms.merge_rotations + @qml.qnode(dev) + def circuit2(): + return qml.probs() + + # CHECK: [[first_pass:%.+]] = transform.apply_registered_pass "merge-rotations" + # CHECK-NEXT: [[second_pass:%.+]] = transform.apply_registered_pass "decompose-lowering" + # CHECK-NEXT: transform.apply_registered_pass "remove-chained-self-inverse" to [[second_pass]] + + print(circuit2.mlir) + + @qml.qjit(target="mlir") + @partial(qml.transforms.decompose, gate_set={"RX", "RZ"}) + @qml.transforms.cancel_inverses + @qml.transforms.merge_rotations + @qml.qnode(dev) + def circuit3(): + return qml.probs() + + # CHECK: [[first_pass:%.+]] = transform.apply_registered_pass "merge-rotations" + # CHECK-NEXT: [[second_pass:%.+]] = transform.apply_registered_pass "remove-chained-self-inverse" + # CHECK-NEXT: transform.apply_registered_pass "decompose-lowering" to [[second_pass]] + + print(circuit3.mlir) + + qml.decomposition.disable_graph() + qml.capture.disable() + + +test_pass_decomposition() diff --git a/frontend/test/pytest/from_plxpr/test_from_plxpr.py b/frontend/test/pytest/from_plxpr/test_from_plxpr.py index 47b0b65be8..9d2ce69a2c 100644 --- a/frontend/test/pytest/from_plxpr/test_from_plxpr.py +++ b/frontend/test/pytest/from_plxpr/test_from_plxpr.py @@ -15,6 +15,8 @@ This module tests the from_plxpr conversion function. 
""" +from functools import partial + import jax import numpy as np import pennylane as qml @@ -965,5 +967,39 @@ def workflow(x, y): assert qml.math.allclose(results, expected) +class TestGraphDecomposition: + """Test the new graph-based decomposition integration with from_plxpr.""" + + def test_with_multiple_decomps_transforms(self): + """Test that a circuit with multiple decompositions and transforms can be converted.""" + + qml.capture.enable() + qml.decomposition.enable_graph() + + @qml.qjit(target="mlir") + @partial( + qml.transforms.decompose, + gate_set={"RX", "RY"}, + ) + @partial( + qml.transforms.decompose, + gate_set={"NOT", "GlobalPhase"}, + ) + @qml.qnode(qml.device("lightning.qubit", wires=0)) + def circuit(x): + qml.GlobalPhase(x) + return qml.expval(qml.PauliX(0)) + + with pytest.raises( + NotImplementedError, match="Multiple decomposition transforms are not yet supported." + ): + circuit(0.2) + + qml.decomposition.disable_graph() + qml.capture.disable() + + assert qml.decomposition.enabled_graph() is False + + if __name__ == "__main__": pytest.main(["-x", __file__]) diff --git a/mlir/include/Quantum/IR/QuantumOps.td b/mlir/include/Quantum/IR/QuantumOps.td index 89498b185a..0b3a76296d 100644 --- a/mlir/include/Quantum/IR/QuantumOps.td +++ b/mlir/include/Quantum/IR/QuantumOps.td @@ -241,6 +241,7 @@ def ExtractOp : Memory_Op<"extract", [NoMemoryEffect]> { $qreg `[` ($idx^):($idx_attr)? 
`]` attr-dict `:` type($qreg) `->` type(results) }]; + let hasCanonicalizeMethod = 1; let hasVerifier = 1; let hasFolder = 1; } diff --git a/mlir/include/Quantum/Transforms/Passes.h b/mlir/include/Quantum/Transforms/Passes.h index 00f33d8fa4..33b25c0179 100644 --- a/mlir/include/Quantum/Transforms/Passes.h +++ b/mlir/include/Quantum/Transforms/Passes.h @@ -30,6 +30,7 @@ std::unique_ptr createRemoveChainedSelfInversePass(); std::unique_ptr createAnnotateFunctionPass(); std::unique_ptr createSplitMultipleTapesPass(); std::unique_ptr createMergeRotationsPass(); +std::unique_ptr createDecomposeLoweringPass(); std::unique_ptr createDisentangleCNOTPass(); std::unique_ptr createDisentangleSWAPPass(); std::unique_ptr createIonsDecompositionPass(); diff --git a/mlir/include/Quantum/Transforms/Passes.td b/mlir/include/Quantum/Transforms/Passes.td index f0a344190e..918b2032dc 100644 --- a/mlir/include/Quantum/Transforms/Passes.td +++ b/mlir/include/Quantum/Transforms/Passes.td @@ -110,6 +110,12 @@ def MergeRotationsPass : Pass<"merge-rotations"> { let constructor = "catalyst::createMergeRotationsPass()"; } +def DecomposeLoweringPass : Pass<"decompose-lowering"> { + let summary = "Replace quantum operations with compiled decomposition rules."; + + let constructor = "catalyst::createDecomposeLoweringPass()"; +} + def DisentangleCNOTPass : Pass<"disentangle-CNOT"> { let summary = "Replace a CNOT gate with two single qubit gates whenever possible."; diff --git a/mlir/include/Quantum/Transforms/Patterns.h b/mlir/include/Quantum/Transforms/Patterns.h index a16569c01b..8b8ade74c1 100644 --- a/mlir/include/Quantum/Transforms/Patterns.h +++ b/mlir/include/Quantum/Transforms/Patterns.h @@ -15,8 +15,12 @@ #pragma once #include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringSet.h" +#include 
"llvm/Support/AllocatorBase.h" namespace catalyst { namespace quantum { @@ -26,6 +30,9 @@ void populateAdjointPatterns(mlir::RewritePatternSet &); void populateSelfInversePatterns(mlir::RewritePatternSet &); void populateMergeRotationsPatterns(mlir::RewritePatternSet &); void populateIonsDecompositionPatterns(mlir::RewritePatternSet &); +void populateDecomposeLoweringPatterns(mlir::RewritePatternSet &, + const llvm::StringMap &, + const llvm::StringSet &); void populateLoopBoundaryPatterns(mlir::RewritePatternSet &, unsigned int mode); } // namespace quantum diff --git a/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp b/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp index 0e7be8337b..88d97a8674 100644 --- a/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp +++ b/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp @@ -38,6 +38,7 @@ void catalyst::registerAllCatalystPasses() mlir::registerPass(catalyst::createCliffordTToPPRPass); mlir::registerPass(catalyst::createMergePPRIntoPPMPass); mlir::registerPass(catalyst::createPPMCompilationPass); + mlir::registerPass(catalyst::createDecomposeLoweringPass); mlir::registerPass(catalyst::createDecomposeNonCliffordPPRPass); mlir::registerPass(catalyst::createDecomposeCliffordPPRPass); mlir::registerPass(catalyst::createCountPPMSpecsPass); diff --git a/mlir/lib/Quantum/IR/QuantumDialect.cpp b/mlir/lib/Quantum/IR/QuantumDialect.cpp index 7049f58e63..14f5e6e811 100644 --- a/mlir/lib/Quantum/IR/QuantumDialect.cpp +++ b/mlir/lib/Quantum/IR/QuantumDialect.cpp @@ -12,9 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+#include "llvm/ADT/TypeSwitch.h" // needed for generated type parser + #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/DialectImplementation.h" // needed for generated type parser -#include "llvm/ADT/TypeSwitch.h" // needed for generated type parser +#include "mlir/Transforms/InliningUtils.h" #include "Quantum/IR/QuantumDialect.h" #include "Quantum/IR/QuantumOps.h" @@ -22,6 +25,65 @@ using namespace mlir; using namespace catalyst::quantum; +//===----------------------------------------------------------------------===// +// Quantum Dialect Interfaces +//===----------------------------------------------------------------------===// + +namespace { + +struct QuantumInlinerInterface : public DialectInlinerInterface { + using DialectInlinerInterface::DialectInlinerInterface; + + static constexpr StringRef decompAttr = "target_gate"; + + /// Returns true if the given operation 'callable' can be inlined into the + /// position given by the 'call'. Currently, we always inline quantum + /// decomposition functions. + bool isLegalToInline(Operation *call, Operation *callable, bool wouldBeCloned) const final + { + if (auto funcOp = dyn_cast(callable)) { + return funcOp->hasAttr(decompAttr); + } + return false; + } + + /// Returns true if the given region 'src' can be inlined into the region + /// 'dest'. Only allow for decomposition functions. + bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, + IRMapping &valueMapping) const final + { + if (auto funcOp = src->getParentOfType()) { + return funcOp->hasAttr(decompAttr); + } + return false; + } + + // Allow to inline operations from decomposition functions. 
+ bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned, + IRMapping &valueMapping) const final + { + if (auto funcOp = op->getParentOfType()) { + return funcOp->hasAttr(decompAttr); + } + return false; + } + + /// Handle the given inlined terminator by replacing it with a new operation + /// as necessary. Required when the region has only one block. + void handleTerminator(Operation *op, ValueRange valuesToRepl) const final + { + auto yieldOp = dyn_cast(op); + if (!yieldOp) { + return; + } + + for (auto retValue : llvm::zip(valuesToRepl, yieldOp.getOperands())) { + std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue)); + } + } +}; +} // namespace + //===----------------------------------------------------------------------===// // Quantum dialect definitions. //===----------------------------------------------------------------------===// @@ -45,6 +107,8 @@ void QuantumDialect::initialize() #include "Quantum/IR/QuantumOps.cpp.inc" >(); + addInterfaces(); + declarePromisedInterfaces(); diff --git a/mlir/lib/Quantum/IR/QuantumOps.cpp b/mlir/lib/Quantum/IR/QuantumOps.cpp index 101ed77a57..b0eedf09d3 100644 --- a/mlir/lib/Quantum/IR/QuantumOps.cpp +++ b/mlir/lib/Quantum/IR/QuantumOps.cpp @@ -146,6 +146,28 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) return nullptr; } +LogicalResult ExtractOp::canonicalize(ExtractOp extract, mlir::PatternRewriter &rewriter) +{ + // Handle the pattern: %reg2 = insert %reg1[idx], %qubit -> %q = extract %reg2[idx] + // Convert to: %q = %qubit, and replace other uses of %reg2 with %reg1 + if (auto insert = dyn_cast_if_present(extract.getQreg().getDefiningOp())) { + bool bothStatic = extract.getIdxAttr().has_value() && insert.getIdxAttr().has_value(); + bool bothDynamic = !extract.getIdxAttr().has_value() && !insert.getIdxAttr().has_value(); + bool staticallyEqual = bothStatic && extract.getIdxAttrAttr() == insert.getIdxAttrAttr(); + bool dynamicallyEqual = bothDynamic && extract.getIdx() == insert.getIdx(); 
+ // if other users of insert are also `insert`, we are good to go + bool valid = llvm::all_of(insert.getResult().getUsers(), [&](Operation *op) { + return isa(op) || op == extract.getOperation(); + }); + if ((staticallyEqual || dynamicallyEqual) && valid) { + rewriter.replaceOp(extract, insert.getQubit()); + rewriter.replaceOp(insert, insert.getInQreg()); + return success(); + } + } + return failure(); +} + LogicalResult InsertOp::canonicalize(InsertOp insert, mlir::PatternRewriter &rewriter) { if (auto extract = dyn_cast_if_present(insert.getQubit().getDefiningOp())) { @@ -153,9 +175,10 @@ LogicalResult InsertOp::canonicalize(InsertOp insert, mlir::PatternRewriter &rew bool bothDynamic = !extract.getIdxAttr().has_value() && !insert.getIdxAttr().has_value(); bool staticallyEqual = bothStatic && extract.getIdxAttrAttr() == insert.getIdxAttrAttr(); bool dynamicallyEqual = bothDynamic && extract.getIdx() == insert.getIdx(); + bool sameQreg = extract.getQreg() == insert.getInQreg(); bool oneUse = extract.getResult().hasOneUse(); - if ((staticallyEqual || dynamicallyEqual) && oneUse) { + if ((staticallyEqual || dynamicallyEqual) && oneUse && sameQreg) { rewriter.replaceOp(insert, insert.getInQreg()); rewriter.eraseOp(extract); return success(); diff --git a/mlir/lib/Quantum/Transforms/CMakeLists.txt b/mlir/lib/Quantum/Transforms/CMakeLists.txt index 3a244ac4d6..26b3ac8410 100644 --- a/mlir/lib/Quantum/Transforms/CMakeLists.txt +++ b/mlir/lib/Quantum/Transforms/CMakeLists.txt @@ -14,6 +14,8 @@ file(GLOB SRC SplitMultipleTapes.cpp merge_rotation.cpp MergeRotationsPatterns.cpp + decompose_lowering.cpp + DecomposeLoweringPatterns.cpp DisentangleSWAP.cpp DisentangleCNOT.cpp ions_decompositions.cpp diff --git a/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp b/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp new file mode 100644 index 0000000000..9dcc4ea1ad --- /dev/null +++ b/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp @@ -0,0 +1,456 @@ +// 
Copyright 2025 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#define DEBUG_TYPE "decompose-lowering" + +#include + +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringSet.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/ValueRange.h" + +#include "Quantum/IR/QuantumOps.h" +#include "Quantum/Transforms/Patterns.h" + +using namespace mlir; +using namespace catalyst::quantum; + +namespace catalyst { +namespace quantum { + +/// A struct to represent qubit indices in quantum operations. +/// +/// This struct provides a way to handle qubit indices that can be either: +/// - A runtime Value (for dynamic indices computed at runtime) +/// - An IntegerAttr (for compile-time constant indices) +/// - Invalid/uninitialized (represented by std::monostate) +/// +/// The struct uses std::variant to ensure only one type is active at a time, +/// preventing invalid states. 
+/// +/// Example usage: +/// QubitIndex dynamicIdx(operandValue); // Runtime qubit index +/// QubitIndex staticIdx(IntegerAttr::get(...)); // Compile-time constant +/// QubitIndex invalidIdx; // Uninitialized state +/// +/// if (dynamicIdx) { // Check if valid +/// if (dynamicIdx.isValue()) { // Check if runtime value +/// Value idx = dynamicIdx.getValue(); // Get the Value +/// } +/// } +struct QubitIndex { + // use monostate to represent the invalid index + std::variant index; + + QubitIndex() : index(std::monostate()) {} + QubitIndex(Value val) : index(val) {} + QubitIndex(IntegerAttr attr) : index(attr) {} + + bool isValue() const { return std::holds_alternative(index); } + bool isAttr() const { return std::holds_alternative(index); } + operator bool() const { return isValue() || isAttr(); } + Value getValue() const { return isValue() ? std::get(index) : nullptr; } + IntegerAttr getAttr() const { return isAttr() ? std::get(index) : nullptr; } +}; + +// The goal of this class is to analyze the signature of a custom operation to get the enough +// information to prepare the call operands and results for replacing the op to calling the +// decomposition function. 
+class OpSignatureAnalyzer { + public: + OpSignatureAnalyzer() = delete; + OpSignatureAnalyzer(CustomOp op, bool enableQregMode) + : signature(OpSignature{ + .params = op.getParams(), + .inQubits = op.getInQubits(), + .inCtrlQubits = op.getInCtrlQubits(), + .inCtrlValues = op.getInCtrlValues(), + .outQubits = op.getOutQubits(), + .outCtrlQubits = op.getOutCtrlQubits(), + }) + { + if (!enableQregMode) + return; + + signature.sourceQreg = getSourceQreg(signature.inQubits.front()); + if (!signature.sourceQreg) { + op.emitError("Cannot get source qreg"); + isValid = false; + return; + } + + // input wire indices + for (Value qubit : signature.inQubits) { + const QubitIndex index = getExtractIndex(qubit); + if (!index) { + op.emitError("Cannot get index for input qubit"); + isValid = false; + return; + } + signature.inWireIndices.emplace_back(index); + } + + // input ctrl wire indices + for (Value ctrlQubit : signature.inCtrlQubits) { + const QubitIndex index = getExtractIndex(ctrlQubit); + if (!index) { + op.emitError("Cannot get index for ctrl qubit"); + isValid = false; + return; + } + signature.inCtrlWireIndices.emplace_back(index); + } + + // Output qubit indices are the same as input qubit indices + signature.outQubitIndices = signature.inWireIndices; + signature.outCtrlQubitIndices = signature.inCtrlWireIndices; + } + + operator bool() const { return isValid; } + + // Prepare the operands for calling the decomposition function + // There are two cases: + // 1. The first input is a qreg, which means the decomposition function is a qreg mode function + // 2. Otherwise, the decomposition function is a qubit mode function + // + // Type signatures: + // 1. qreg mode: + // - func(qreg, param*, inWires*, inCtrlWires*?, inCtrlValues*?) -> qreg + // 2. qubit mode: + // - func(param*, inQubits*, inCtrlQubits*?, inCtrlValues*?) 
-> outQubits* + llvm::SmallVector prepareCallOperands(func::FuncOp decompFunc, PatternRewriter &rewriter, + Location loc) + { + auto funcType = decompFunc.getFunctionType(); + auto funcInputs = funcType.getInputs(); + + SmallVector operands(funcInputs.size()); + + int operandIdx = 0; + if (isa(funcInputs[0])) { + Value updatedQreg = signature.sourceQreg; + for (auto [i, qubit] : llvm::enumerate(signature.inQubits)) { + const QubitIndex &index = signature.inWireIndices[i]; + updatedQreg = + rewriter.create(loc, updatedQreg.getType(), updatedQreg, + index.getValue(), index.getAttr(), qubit); + } + + operands[operandIdx++] = updatedQreg; + if (!signature.params.empty()) { + auto [startIdx, endIdx] = + findParamTypeRange(funcInputs, signature.params.size(), operandIdx); + ArrayRef paramsTypes = funcInputs.slice(startIdx, endIdx - startIdx); + auto updatedParams = generateParams(signature.params, paramsTypes, rewriter, loc); + for (Value param : updatedParams) { + operands[operandIdx++] = param; + } + } + + if (!signature.inWireIndices.empty()) { + operands[operandIdx] = fromTensorOrAsIs(signature.inWireIndices, + funcInputs[operandIdx], rewriter, loc); + operandIdx++; + } + + if (!signature.inCtrlWireIndices.empty()) { + operands[operandIdx] = fromTensorOrAsIs(signature.inCtrlWireIndices, + funcInputs[operandIdx], rewriter, loc); + operandIdx++; + } + } + else { + if (!signature.params.empty()) { + auto [startIdx, endIdx] = + findParamTypeRange(funcInputs, signature.params.size(), operandIdx); + ArrayRef paramsTypes = funcInputs.slice(startIdx, endIdx - startIdx); + auto updatedParams = generateParams(signature.params, paramsTypes, rewriter, loc); + for (Value param : updatedParams) { + operands[operandIdx++] = param; + } + } + + for (auto inQubit : signature.inQubits) { + operands[operandIdx] = + fromTensorOrAsIs(inQubit, funcInputs[operandIdx], rewriter, loc); + operandIdx++; + } + + for (auto inCtrlQubit : signature.inCtrlQubits) { + operands[operandIdx] = + 
fromTensorOrAsIs(inCtrlQubit, funcInputs[operandIdx], rewriter, loc); + operandIdx++; + } + } + + if (!signature.inCtrlValues.empty()) { + operands[operandIdx] = + fromTensorOrAsIs(signature.inCtrlValues, funcInputs[operandIdx], rewriter, loc); + operandIdx++; + } + + return operands; + } + + // Prepare the results for the call operation + SmallVector prepareCallResultForQreg(func::CallOp callOp, PatternRewriter &rewriter) + { + assert(callOp.getNumResults() == 1 && "only one qreg result for qreg mode is allowed"); + + auto qreg = callOp.getResult(0); + assert(isa(qreg.getType()) && "only allow to have qreg result"); + + SmallVector newResults; + rewriter.setInsertionPointAfter(callOp); + for (const QubitIndex &index : signature.outQubitIndices) { + auto extractOp = rewriter.create( + callOp.getLoc(), rewriter.getType(), qreg, index.getValue(), + index.getAttr()); + newResults.emplace_back(extractOp.getResult()); + } + for (const QubitIndex &index : signature.outCtrlQubitIndices) { + auto extractOp = rewriter.create( + callOp.getLoc(), rewriter.getType(), qreg, index.getValue(), + index.getAttr()); + newResults.emplace_back(extractOp.getResult()); + } + return newResults; + } + + private: + bool isValid = true; + + struct OpSignature { + ValueRange params; + ValueRange inQubits; + ValueRange inCtrlQubits; + ValueRange inCtrlValues; + ValueRange outQubits; + ValueRange outCtrlQubits; + + // Qreg mode specific information + Value sourceQreg = nullptr; + SmallVector inWireIndices; + SmallVector inCtrlWireIndices; + SmallVector outQubitIndices; + SmallVector outCtrlQubitIndices; + } signature; + + Value fromTensorOrAsIs(ValueRange values, Type type, PatternRewriter &rewriter, Location loc) + { + if (isa(type)) { + return rewriter.create(loc, type, values); + } + return values.front(); + } + + static size_t getElementsCount(Type type) + { + if (isa(type)) { + auto tensorType = cast(type); + return tensorType.getNumElements() > 0 ? 
tensorType.getNumElements() : 1; + } + return 1; + } + + // Helper function to find the range of function input types that correspond to params + static std::pair findParamTypeRange(ArrayRef funcInputs, + size_t sigParamCount, size_t startIdx = 0) + { + size_t paramTypeCount = 0; + size_t paramTypeEnd = startIdx; + + while (paramTypeCount < sigParamCount) { + assert(paramTypeEnd < funcInputs.size() && + "param type end should be less than function input size"); + paramTypeCount += getElementsCount(funcInputs[paramTypeEnd]); + paramTypeEnd++; + } + + assert(paramTypeCount == sigParamCount && + "param type count should be equal to signature param count"); + + return {startIdx, paramTypeEnd}; + } + + // generate params for calling the decomposition function based on function type requirements + SmallVector generateParams(ValueRange signatureParams, ArrayRef funcParamTypes, + PatternRewriter &rewriter, Location loc) + { + SmallVector operands; + size_t sigParamIdx = 0; + + for (Type funcParamType : funcParamTypes) { + const size_t numElements = getElementsCount(funcParamType); + + // collect numElements of signature params + SmallVector tensorElements; + for (size_t i = 0; i < numElements && sigParamIdx < signatureParams.size(); i++) { + tensorElements.push_back(signatureParams[sigParamIdx++]); + } + operands.push_back(fromTensorOrAsIs(tensorElements, funcParamType, rewriter, loc)); + } + + return operands; + } + + Value fromTensorOrAsIs(ArrayRef indices, Type type, PatternRewriter &rewriter, + Location loc) + { + SmallVector values; + for (const QubitIndex &index : indices) { + if (index.isValue()) { + values.emplace_back(index.getValue()); + } + else if (index.isAttr()) { + auto attr = index.getAttr(); + auto constantValue = rewriter.create(loc, attr.getType(), attr); + values.emplace_back(constantValue); + } + } + + if (isa(type)) { + return rewriter.create(loc, type, values); + } + + assert(values.size() == 1 && "number of values should be 1 for non-tensor type"); 
+ return values.front(); + } + + // Trace `qubit` back to the register it was extracted from. + // Follows the first qubit operand of each producing CustomOp until an ExtractOp is reached + // and returns that op's register. Returns nullptr when the chain ends anywhere else (a + // value with no defining op, an op that is neither ExtractOp nor CustomOp, or a CustomOp + // without qubit operands), which also guarantees that the loop terminates. + Value getSourceQreg(Value qubit) + { + while (qubit) { + Operation *defOp = qubit.getDefiningOp(); + if (auto extractOp = dyn_cast_or_null<ExtractOp>(defOp)) { + return extractOp.getQreg(); + } + + auto customOp = dyn_cast_or_null<CustomOp>(defOp); + if (!customOp || customOp.getQubitOperands().empty()) { + break; + } + qubit = customOp.getQubitOperands()[0]; + } + + return nullptr; + } + + // Trace `qubit` back to the ExtractOp that produced it and return that op's index (a + // dynamic Value or a static IntegerAttr). Each CustomOp on the way is crossed from the + // matching qubit result to the qubit operand at the same position. Returns an invalid + // QubitIndex when no ExtractOp is reached. + QubitIndex getExtractIndex(Value qubit) + { + while (qubit) { + if (auto extractOp = qubit.getDefiningOp<ExtractOp>()) { + if (Value idx = extractOp.getIdx()) { + return QubitIndex(idx); + } + if (IntegerAttr idxAttr = extractOp.getIdxAttrAttr()) { + return QubitIndex(idxAttr); + } + } + + if (auto customOp = dyn_cast_or_null<CustomOp>(qubit.getDefiningOp())) { + auto qubitOperands = customOp.getQubitOperands(); + auto qubitResults = customOp.getQubitResults(); + auto it = + llvm::find_if(qubitResults, [&](Value result) { return result == qubit; }); + + if (it != qubitResults.end()) { + size_t resultIndex = std::distance(qubitResults.begin(), it); + if (resultIndex < qubitOperands.size()) { + qubit = qubitOperands[resultIndex]; + continue; + } + } + } + + break; + } + + return QubitIndex(); + } +}; + +// Rewrite pattern that replaces a CustomOp whose gate is outside the target gate set with a +// call to the decomposition function registered for that gate name. +struct DecomposeLoweringRewritePattern : public OpRewritePattern<CustomOp> { + private: + const llvm::StringMap<func::FuncOp> &decompositionRegistry; + const llvm::StringSet<> &targetGateSet; + + public: + DecomposeLoweringRewritePattern(MLIRContext *context, + const llvm::StringMap<func::FuncOp> &registry, + const llvm::StringSet<> &gateSet) + : OpRewritePattern(context), decompositionRegistry(registry), targetGateSet(gateSet) + { + } + + LogicalResult matchAndRewrite(CustomOp op, PatternRewriter &rewriter) const override + { + StringRef gateName = op.getGateName(); + + // Only decompose the op if it is not in the target gate set + if (targetGateSet.contains(gateName)) { + return failure(); + } + + // Find the corresponding decomposition function for the op + auto it = decompositionRegistry.find(gateName); + if (it == decompositionRegistry.end()) { + return
failure(); + } + func::FuncOp decompFunc = it->second; + + // Here is the assumption that the decomposition function must have at least one input and + // one result + assert(decompFunc.getFunctionType().getNumInputs() > 0 && + "Decomposition function must have at least one input"); + assert(decompFunc.getFunctionType().getNumResults() >= 1 && + "Decomposition function must have at least one result"); + + // A leading qreg input selects qreg mode; otherwise the function takes qubits directly + auto enableQreg = isa<quantum::QuregType>(decompFunc.getFunctionType().getInput(0)); + auto analyzer = OpSignatureAnalyzer(op, enableQreg); + // An invalid analyzer has already emitted a diagnostic on `op`. Bail out before + // touching the IR instead of asserting, so that release builds do not continue with an + // incomplete signature. + if (!analyzer) { + return failure(); + } + + rewriter.setInsertionPointAfter(op); + auto callOperands = analyzer.prepareCallOperands(decompFunc, rewriter, op.getLoc()); + auto callOp = + rewriter.create<func::CallOp>(op.getLoc(), decompFunc.getFunctionType().getResults(), + decompFunc.getSymName(), callOperands); + + // Replace the op with the call op and adjust the insert ops for the qreg mode + if (callOp.getNumResults() == 1 && isa<quantum::QuregType>(callOp.getResult(0).getType())) { + auto results = analyzer.prepareCallResultForQreg(callOp, rewriter); + rewriter.replaceOp(op, results); + } + else { + rewriter.replaceOp(op, callOp->getResults()); + } + + return success(); + } +}; + +// Register the decompose-lowering rewrite pattern. The pattern keeps references to +// `decompositionRegistry` and `targetGateSet`, so both must outlive the pattern set. +void populateDecomposeLoweringPatterns(RewritePatternSet &patterns, + const llvm::StringMap<func::FuncOp> &decompositionRegistry, + const llvm::StringSet<> &targetGateSet) +{ + patterns.add<DecomposeLoweringRewritePattern>(patterns.getContext(), decompositionRegistry, + targetGateSet); +} + +} // namespace quantum +} // namespace catalyst diff --git a/mlir/lib/Quantum/Transforms/decompose_lowering.cpp b/mlir/lib/Quantum/Transforms/decompose_lowering.cpp new file mode 100644 index 0000000000..bbddf92023 --- /dev/null +++ b/mlir/lib/Quantum/Transforms/decompose_lowering.cpp @@ -0,0 +1,208 @@ +// Copyright 2025 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#define DEBUG_TYPE "decompose-lowering" + +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringSet.h" +#include "llvm/Support/AllocatorBase.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/UB/IR/UBOps.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" + +#include "Quantum/IR/QuantumOps.h" +#include "Quantum/Transforms/Patterns.h" + +using namespace mlir; +using namespace catalyst::quantum; + +namespace catalyst { +namespace quantum { +#define GEN_PASS_DEF_DECOMPOSELOWERINGPASS +#define GEN_PASS_DECL_DECOMPOSELOWERINGPASS +#include "Quantum/Transforms/Passes.h.inc" + +namespace DecompUtils { + +static constexpr StringRef target_gate_attr_name = "target_gate"; +static constexpr StringRef decomp_gateset_attr_name = "decomp_gateset"; + +// Check if a function is a decomposition function. +// It's expected that a decomposition function has the `target_gate` attribute +// (see `target_gate_attr_name`), which is set by the `markDecompositionAttributes` function. +// The presence of this attribute marks a function as a decomposition function, and its +// value names the gate that the decomposition function is meant to replace. +bool isDecompositionFunction(func::FuncOp func) { return func->hasAttr(target_gate_attr_name);
} + +StringRef getTargetGateName(func::FuncOp func) +{ + if (auto target_op_attr = func->getAttrOfType(target_gate_attr_name)) { + return target_op_attr.getValue(); + } + return StringRef{}; +} + +} // namespace DecompUtils + +/// A module pass that walks through a module, registers all decomposition functions, and applies +/// the decomposition patterns. +struct DecomposeLoweringPass : impl::DecomposeLoweringPassBase { + using DecomposeLoweringPassBase::DecomposeLoweringPassBase; + + void getDependentDialects(DialectRegistry ®istry) const override + { + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + } + + private: + llvm::StringMap decompositionRegistry; + llvm::StringSet targetGateSet; + + // Discover and register the decomposition functions of a module. + // For each target op, it records the decomposition function that can decompose that op. + void discoverAndRegisterDecompositions(ModuleOp module, + llvm::StringMap &decompositionRegistry) + { + module.walk([&](func::FuncOp func) { + if (StringRef targetOp = DecompUtils::getTargetGateName(func); !targetOp.empty()) { + decompositionRegistry[targetOp] = func; + } + // No need to walk into the function body + return WalkResult::skip(); + }); + } + + // Find the target gate set from the module. It's expected that a function in the module + // has this attribute: `decomp_gateset`. This attribute is set by the frontend; it contains + // the target gate set that the circuit function wants to finally decompose into.
Since each + // module only contains one circuit function, we can just find the target gate set from the + // function with the `decomp_gateset` attribute + void findTargetGateSet(ModuleOp module, llvm::StringSet &targetGateSet) + { + module.walk([&](func::FuncOp func) { + if (auto gate_set_attr = + func->getAttrOfType(DecompUtils::decomp_gateset_attr_name)) { + for (auto gate : gate_set_attr.getValue()) { + StringRef gate_name = cast(gate).getValue(); + targetGateSet.insert(gate_name); + } + return WalkResult::interrupt(); + } + // No need to walk into the function body + return WalkResult::skip(); + }); + } + + // Remove unused decomposition functions: + // Since the decomposition functions are marked as public from the frontend, + // there is no way to remove them with any DCE pass automatically. + // So we need to manually remove them from the module + void removeDecompositionFunctions(ModuleOp module, + llvm::StringMap &decompositionRegistry) + { + llvm::DenseSet usedDecompositionFunctions; + + module.walk([&](func::CallOp callOp) { + if (auto targetFunc = module.lookupSymbol(callOp.getCallee())) { + if (DecompUtils::isDecompositionFunction(targetFunc)) { + usedDecompositionFunctions.insert(targetFunc); + } + } + }); + + // remove unused decomposition functions + module.walk([&](func::FuncOp func) { + if (DecompUtils::isDecompositionFunction(func) && + !usedDecompositionFunctions.contains(func)) { + func.erase(); + } + return WalkResult::skip(); + }); + } + + public: + void runOnOperation() final + { + ModuleOp module = cast(getOperation()); + + // Step 1: Discover and register all decomposition functions in the module + discoverAndRegisterDecompositions(module, decompositionRegistry); + if (decompositionRegistry.empty()) { + return; + } + + // Step 1.1: Find the target gate set + findTargetGateSet(module, targetGateSet); + + // Step 2: Canonicalize the module + RewritePatternSet patternsCanonicalization(&getContext()); + 
catalyst::quantum::CustomOp::getCanonicalizationPatterns(patternsCanonicalization, + &getContext()); + if (failed(applyPatternsGreedily(module, std::move(patternsCanonicalization)))) { + return signalPassFailure(); + } + + // Step 3: Apply the decomposition patterns + RewritePatternSet decompositionPatterns(&getContext()); + populateDecomposeLoweringPatterns(decompositionPatterns, decompositionRegistry, + targetGateSet); + if (failed(applyPatternsGreedily(module, std::move(decompositionPatterns)))) { + return signalPassFailure(); + } + + // Step 4: Inline and canonicalize/CSE the module again + PassManager pm(&getContext()); + pm.addPass(createInlinerPass()); + pm.addPass(createCanonicalizerPass()); + pm.addPass(createCSEPass()); + if (failed(pm.run(module))) { + return signalPassFailure(); + } + + // Step 5. Remove redundant decomposition functions + removeDecompositionFunctions(module, decompositionRegistry); + + // Step 6. Canonicalize the extract/insert pair + RewritePatternSet patternsInsertExtract(&getContext()); + catalyst::quantum::InsertOp::getCanonicalizationPatterns(patternsInsertExtract, + &getContext()); + catalyst::quantum::ExtractOp::getCanonicalizationPatterns(patternsInsertExtract, + &getContext()); + if (failed(applyPatternsGreedily(module, std::move(patternsInsertExtract)))) { + return signalPassFailure(); + } + } +}; + +} // namespace quantum + +std::unique_ptr createDecomposeLoweringPass() +{ + return std::make_unique(); +} + +} // namespace catalyst diff --git a/mlir/test/Quantum/CanonicalizationTest.mlir b/mlir/test/Quantum/CanonicalizationTest.mlir index 4c698ed4e7..4b9620575d 100644 --- a/mlir/test/Quantum/CanonicalizationTest.mlir +++ b/mlir/test/Quantum/CanonicalizationTest.mlir @@ -83,8 +83,7 @@ func.func @test_extract_insert_no_fold_static(%r1: !quantum.reg, %i1: i64, %i2: %q2 = quantum.extract %r2[0] : !quantum.reg -> !quantum.bit %r3 = quantum.insert %r2[%i1], %q2 : !quantum.reg, !quantum.bit - // CHECK: quantum.extract - // CHECK: 
quantum.insert + %q3 = quantum.extract %r3[%i1] : !quantum.reg -> !quantum.bit %r4 = quantum.insert %r3[%i2], %q3 : !quantum.reg, !quantum.bit @@ -167,14 +166,14 @@ func.func @test_interleaved_extract_insert() -> tensor<4xf64> { // CHECK: [[QBIT:%.+]] = quantum.extract [[QREG:%.+]][ // CHECK: [[QBIT_1:%.+]] = quantum.custom "Hadamard"() [[QBIT]] // CHECK: [[QREG_1:%.+]] = quantum.insert [[QREG]] - // CHECK-NOT: quantum.insert - // COM: check that insert op canonicalization correctly removes unnecessary extract/inserts + // CHECK-NOT: quantum.insert + // COM: check that insert op canonicalization correctly removes unnecessary extract/inserts // CHECK: quantum.compbasis qreg [[QREG_1]] %1 = quantum.extract %0[%c0_i64] : !quantum.reg -> !quantum.bit %out_qubits = quantum.custom "Hadamard"() %1 : !quantum.bit %2 = quantum.extract %0[%c1_i64] : !quantum.reg -> !quantum.bit - %3 = quantum.insert %0[%c0_i64], %out_qubits : !quantum.reg, !quantum.bit - %4 = quantum.insert %3[%c1_i64], %2 : !quantum.reg, !quantum.bit + %3 = quantum.insert %0[%c1_i64], %2 : !quantum.reg, !quantum.bit + %4 = quantum.insert %3[%c0_i64], %out_qubits : !quantum.reg, !quantum.bit %5 = quantum.compbasis qreg %4 : !quantum.obs %6 = quantum.probs %5 : tensor<4xf64> quantum.dealloc %4 : !quantum.reg diff --git a/mlir/test/Quantum/DecomposeLoweringTest.mlir b/mlir/test/Quantum/DecomposeLoweringTest.mlir new file mode 100644 index 0000000000..91bfbe7778 --- /dev/null +++ b/mlir/test/Quantum/DecomposeLoweringTest.mlir @@ -0,0 +1,510 @@ +// Copyright 2025 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// RUN: quantum-opt --decompose-lowering --split-input-file -verify-diagnostics %s | FileCheck %s + +module @two_hadamards { + func.func public @test_two_hadamards() -> tensor<4xf64> { + %0 = quantum.alloc( 2) : !quantum.reg + %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit + // CHECK: [[CST_PI2:%.+]] = arith.constant 1.5707963267948966 : f64 + // CHECK: [[CST_PI:%.+]] = arith.constant 3.1415926535897931 : f64 + // CHECK: [[REG:%.+]] = quantum.alloc( 2) : !quantum.reg + // CHECK: [[QUBIT:%.+]] = quantum.extract [[REG]][ 0] : !quantum.reg -> !quantum.bit + + // CHECK: [[QUBIT1:%.+]] = quantum.custom "RZ"([[CST_PI]]) [[QUBIT]] : !quantum.bit + // CHECK: [[QUBIT2:%.+]] = quantum.custom "RY"([[CST_PI2]]) [[QUBIT1]] : !quantum.bit + // CHECK-NOT: quantum.custom "Hadamard" + %out_qubits = quantum.custom "Hadamard"() %1 : !quantum.bit + + // CHECK: [[QUBIT3:%.+]] = quantum.custom "RZ"([[CST_PI]]) [[QUBIT2]] : !quantum.bit + // CHECK: [[QUBIT4:%.+]] = quantum.custom "RY"([[CST_PI2]]) [[QUBIT3]] : !quantum.bit + // CHECK-NOT: quantum.custom "Hadamard" + %out_qubits_0 = quantum.custom "Hadamard"() %out_qubits : !quantum.bit + + // CHECK: [[UPDATED_REG:%.+]] = quantum.insert [[REG]][ 0], [[QUBIT4]] : !quantum.reg, !quantum.bit + %2 = quantum.insert %0[ 0], %out_qubits_0 : !quantum.reg, !quantum.bit + %3 = quantum.compbasis qreg %2 : !quantum.obs + %4 = quantum.probs %3 : tensor<4xf64> + quantum.dealloc %2 : !quantum.reg + return %4 : tensor<4xf64> + } + + // Decomposition function should be applied and removed from the module + // CHECK-NOT: 
func.func private @Hadamard_to_RY_decomp + func.func private @Hadamard_to_RY_decomp(%arg0: !quantum.bit) -> !quantum.bit attributes {target_gate = "Hadamard", llvm.linkage = #llvm.linkage} { + %cst = arith.constant 3.1415926535897931 : f64 + %cst_0 = arith.constant 1.5707963267948966 : f64 + %out_qubits = quantum.custom "RZ"(%cst) %arg0 : !quantum.bit + %out_qubits_1 = quantum.custom "RY"(%cst_0) %out_qubits : !quantum.bit + return %out_qubits_1 : !quantum.bit + } +} + +// ----- + +// Test single Hadamard decomposition +module @single_hadamard { + func.func @test_single_hadamard() -> !quantum.bit { + // CHECK: [[CST_PI2:%.+]] = arith.constant 1.5707963267948966 : f64 + // CHECK: [[CST_PI:%.+]] = arith.constant 3.1415926535897931 : f64 + // CHECK: [[REG:%.+]] = quantum.alloc( 1) : !quantum.reg + // CHECK: [[QUBIT:%.+]] = quantum.extract [[REG]][ 0] : !quantum.reg -> !quantum.bit + %0 = quantum.alloc( 1) : !quantum.reg + %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit + + // CHECK: [[QUBIT1:%.+]] = quantum.custom "RZ"([[CST_PI]]) [[QUBIT]] : !quantum.bit + // CHECK: [[QUBIT2:%.+]] = quantum.custom "RY"([[CST_PI2]]) [[QUBIT1]] : !quantum.bit + // CHECK-NOT: quantum.custom "Hadamard" + %2 = quantum.custom "Hadamard"() %1 : !quantum.bit + + // CHECK: return [[QUBIT2]] + return %2 : !quantum.bit + } + + // Decomposition function should be applied and removed from the module + // CHECK-NOT: func.func private @Hadamard_to_RY_decomp + func.func private @Hadamard_to_RY_decomp(%arg0: !quantum.bit) -> !quantum.bit attributes {target_gate = "Hadamard", llvm.linkage = #llvm.linkage} { + %cst = arith.constant 3.1415926535897931 : f64 + %cst_0 = arith.constant 1.5707963267948966 : f64 + %out_qubits = quantum.custom "RZ"(%cst) %arg0 : !quantum.bit + %out_qubits_1 = quantum.custom "RY"(%cst_0) %out_qubits : !quantum.bit + return %out_qubits_1 : !quantum.bit + } +} + +// ----- +module @recursive { + func.func public @test_recursive() -> tensor<4xf64> { + %0 = 
quantum.alloc( 2) : !quantum.reg + %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit + // CHECK: [[CST_PI2:%.+]] = arith.constant 1.5707963267948966 : f64 + // CHECK: [[CST_PI:%.+]] = arith.constant 3.1415926535897931 : f64 + // CHECK: [[REG:%.+]] = quantum.alloc( 2) : !quantum.reg + // CHECK: [[QUBIT:%.+]] = quantum.extract [[REG]][ 0] : !quantum.reg -> !quantum.bit + + // CHECK: [[QUBIT1:%.+]] = quantum.custom "RZ"([[CST_PI]]) [[QUBIT]] : !quantum.bit + // CHECK: [[QUBIT2:%.+]] = quantum.custom "RY"([[CST_PI2]]) [[QUBIT1]] : !quantum.bit + // CHECK-NOT: quantum.custom "Hadamard" + %out_qubits = quantum.custom "Hadamard"() %1 : !quantum.bit + + // CHECK: [[QUBIT3:%.+]] = quantum.custom "RZ"([[CST_PI]]) [[QUBIT2]] : !quantum.bit + // CHECK: [[QUBIT4:%.+]] = quantum.custom "RY"([[CST_PI2]]) [[QUBIT3]] : !quantum.bit + // CHECK-NOT: quantum.custom "Hadamard" + %out_qubits_0 = quantum.custom "Hadamard"() %out_qubits : !quantum.bit + + // CHECK: [[UPDATED_REG:%.+]] = quantum.insert [[REG]][ 0], [[QUBIT4]] : !quantum.reg, !quantum.bit + %2 = quantum.insert %0[ 0], %out_qubits_0 : !quantum.reg, !quantum.bit + %3 = quantum.compbasis qreg %2 : !quantum.obs + %4 = quantum.probs %3 : tensor<4xf64> + quantum.dealloc %2 : !quantum.reg + return %4 : tensor<4xf64> + } + + // Decomposition function should be applied and removed from the module + // CHECK-NOT: func.func private @Hadamard_to_RY_decomp + func.func private @Hadamard_to_RY_decomp(%arg0: !quantum.bit) -> !quantum.bit attributes {target_gate = "Hadamard", llvm.linkage = #llvm.linkage} { + %out_qubits_0 = quantum.custom "RZRY"() %arg0 : !quantum.bit + return %out_qubits_0 : !quantum.bit + } + + // Decomposition function should be applied and removed from the module + // CHECK-NOT: func.func private @RZRY_decomp + func.func private @RZRY_decomp(%arg0: !quantum.bit) -> !quantum.bit attributes {target_gate = "RZRY", llvm.linkage = #llvm.linkage} { + %cst = arith.constant 3.1415926535897931 : f64 + %cst_0 = 
arith.constant 1.5707963267948966 : f64 + %out_qubits_1 = quantum.custom "RZ"(%cst) %arg0 : !quantum.bit + %out_qubits_2 = quantum.custom "RY"(%cst_0) %out_qubits_1 : !quantum.bit + return %out_qubits_2 : !quantum.bit + } +} + +// ----- +module @recursive { + func.func public @test_recursive() -> tensor<4xf64> { + %0 = quantum.alloc( 2) : !quantum.reg + %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit + // CHECK: [[CST_PI2:%.+]] = arith.constant 1.5707963267948966 : f64 + // CHECK: [[CST_PI:%.+]] = arith.constant 3.1415926535897931 : f64 + // CHECK: [[REG:%.+]] = quantum.alloc( 2) : !quantum.reg + // CHECK: [[QUBIT:%.+]] = quantum.extract [[REG]][ 0] : !quantum.reg -> !quantum.bit + + // CHECK: [[QUBIT1:%.+]] = quantum.custom "RZ"([[CST_PI]]) [[QUBIT]] : !quantum.bit + // CHECK: [[QUBIT2:%.+]] = quantum.custom "RY"([[CST_PI2]]) [[QUBIT1]] : !quantum.bit + // CHECK-NOT: quantum.custom "Hadamard" + %out_qubits = quantum.custom "Hadamard"() %1 : !quantum.bit + + // CHECK: [[QUBIT3:%.+]] = quantum.custom "RZ"([[CST_PI]]) [[QUBIT2]] : !quantum.bit + // CHECK: [[QUBIT4:%.+]] = quantum.custom "RY"([[CST_PI2]]) [[QUBIT3]] : !quantum.bit + // CHECK-NOT: quantum.custom "Hadamard" + %out_qubits_0 = quantum.custom "Hadamard"() %out_qubits : !quantum.bit + + // CHECK: [[UPDATED_REG:%.+]] = quantum.insert [[REG]][ 0], [[QUBIT4]] : !quantum.reg, !quantum.bit + %2 = quantum.insert %0[ 0], %out_qubits_0 : !quantum.reg, !quantum.bit + %3 = quantum.compbasis qreg %2 : !quantum.obs + %4 = quantum.probs %3 : tensor<4xf64> + quantum.dealloc %2 : !quantum.reg + return %4 : tensor<4xf64> + } + + // Decomposition function should be applied and removed from the module + // CHECK-NOT: func.func private @Hadamard_to_RY_decomp + func.func private @Hadamard_to_RY_decomp(%arg0: !quantum.bit) -> !quantum.bit attributes {target_gate = "Hadamard", llvm.linkage = #llvm.linkage} { + %out_qubits_0 = quantum.custom "RZRY"() %arg0 : !quantum.bit + return %out_qubits_0 : !quantum.bit + } + + // 
Decomposition function should be applied and removed from the module + // CHECK-NOT: func.func private @RZRY_decomp + func.func private @RZRY_decomp(%arg0: !quantum.bit) -> !quantum.bit attributes {target_gate = "RZRY", llvm.linkage = #llvm.linkage} { + %cst = arith.constant 3.1415926535897931 : f64 + %cst_0 = arith.constant 1.5707963267948966 : f64 + %out_qubits_1 = quantum.custom "RZ"(%cst) %arg0 : !quantum.bit + %out_qubits_2 = quantum.custom "RY"(%cst_0) %out_qubits_1 : !quantum.bit + return %out_qubits_2 : !quantum.bit + } +} + +// ----- + +// Test parametric gates and wires +module @param_rxry { + func.func public @test_param_rxry(%arg0: tensor, %arg1: tensor) -> tensor<2xf64> { + %c0_i64 = arith.constant 0 : i64 + + // CHECK: [[REG:%.+]] = quantum.alloc( 1) : !quantum.reg + %0 = quantum.alloc( 1) : !quantum.reg + + // CHECK: [[WIRE:%.+]] = tensor.extract %arg1[] : tensor + %extracted = tensor.extract %arg1[] : tensor + + // CHECK: [[QUBIT:%.+]] = quantum.extract [[REG]][[[WIRE]]] : !quantum.reg -> !quantum.bit + %1 = quantum.extract %0[%extracted] : !quantum.reg -> !quantum.bit + + // CHECK: [[PARAM:%.+]] = tensor.extract %arg0[] : tensor + %param_0 = tensor.extract %arg0[] : tensor + + // CHECK: [[QUBIT1:%.+]] = quantum.custom "RX"([[PARAM]]) [[QUBIT]] : !quantum.bit + // CHECK: [[QUBIT2:%.+]] = quantum.custom "RY"([[PARAM]]) [[QUBIT1]] : !quantum.bit + // CHECK-NOT: quantum.custom "ParametrizedRXRY" + %out_qubits = quantum.custom "ParametrizedRXRY"(%param_0) %1 : !quantum.bit + + // CHECK: [[UPDATED_REG:%.+]] = quantum.insert [[REG]][ 0], [[QUBIT2]] : !quantum.reg, !quantum.bit + %2 = quantum.insert %0[ 0], %out_qubits : !quantum.reg, !quantum.bit + %3 = quantum.compbasis qreg %2 : !quantum.obs + %4 = quantum.probs %3 : tensor<2xf64> + quantum.dealloc %2 : !quantum.reg + return %4 : tensor<2xf64> + } + + // Decomposition function expects tensor while operation provides f64; it must still be applied and removed + // CHECK-NOT: func.func private @ParametrizedRXRY_decomp + func.func private
@ParametrizedRXRY_decomp(%arg0: tensor, %arg1: !quantum.bit) -> !quantum.bit + attributes {target_gate = "ParametrizedRXRY", llvm.linkage = #llvm.linkage} { + %extracted = tensor.extract %arg0[] : tensor + %out_qubits = quantum.custom "RX"(%extracted) %arg1 : !quantum.bit + %extracted_0 = tensor.extract %arg0[] : tensor + %out_qubits_1 = quantum.custom "RY"(%extracted_0) %out_qubits : !quantum.bit + return %out_qubits_1 : !quantum.bit + } +} +// ----- + +// Test parametric gates and wires (two gate parameters) +module @param_rxry_2 { + func.func public @test_param_rxry_2(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor<2xf64> { + %c0_i64 = arith.constant 0 : i64 + + // CHECK: [[REG:%.+]] = quantum.alloc( 1) : !quantum.reg + %0 = quantum.alloc( 1) : !quantum.reg + + // CHECK: [[WIRE:%.+]] = tensor.extract %arg2[] : tensor + %extracted = tensor.extract %arg2[] : tensor + + // CHECK: [[QUBIT:%.+]] = quantum.extract [[REG]][[[WIRE]]] : !quantum.reg -> !quantum.bit + %1 = quantum.extract %0[%extracted] : !quantum.reg -> !quantum.bit + + // CHECK: [[PARAM_0:%.+]] = tensor.extract %arg0[] : tensor + %param_0 = tensor.extract %arg0[] : tensor + + // CHECK: [[PARAM_1:%.+]] = tensor.extract %arg1[] : tensor + %param_1 = tensor.extract %arg1[] : tensor + + // CHECK: [[QUBIT1:%.+]] = quantum.custom "RX"([[PARAM_0]]) [[QUBIT]] : !quantum.bit + // CHECK: [[QUBIT2:%.+]] = quantum.custom "RY"([[PARAM_1]]) [[QUBIT1]] : !quantum.bit + // CHECK-NOT: quantum.custom "ParametrizedRXRY" + %out_qubits = quantum.custom "ParametrizedRXRY"(%param_0, %param_1) %1 : !quantum.bit + + // CHECK: [[UPDATED_REG:%.+]] = quantum.insert [[REG]][ 0], [[QUBIT2]] : !quantum.reg, !quantum.bit + %2 = quantum.insert %0[ 0], %out_qubits : !quantum.reg, !quantum.bit + %3 = quantum.compbasis qreg %2 : !quantum.obs + %4 = quantum.probs %3 : tensor<2xf64> + quantum.dealloc %2 : !quantum.reg + return %4 : tensor<2xf64> + } + + // Decomposition function expects tensor while operation provides f64 + // CHECK-NOT: func.func private
@ParametrizedRXRY_decomp + func.func private @ParametrizedRXRY_decomp(%arg0: tensor, %arg1: tensor, %arg2: !quantum.bit) -> !quantum.bit + attributes {target_gate = "ParametrizedRXRY", llvm.linkage = #llvm.linkage} { + %extracted_param_0 = tensor.extract %arg0[] : tensor + %out_qubits = quantum.custom "RX"(%extracted_param_0) %arg2 : !quantum.bit + %extracted_param_1 = tensor.extract %arg1[] : tensor + %out_qubits_1 = quantum.custom "RY"(%extracted_param_1) %out_qubits : !quantum.bit + return %out_qubits_1 : !quantum.bit + } +} +// ----- + +// Test recursive and qreg-based gate decomposition +module @qreg_base_circuit { + func.func public @test_qreg_base_circuit() -> tensor<2xf64> { + // CHECK: [[CST:%.+]] = arith.constant 1.000000e+00 : f64 + %cst = arith.constant 1.000000e+00 : f64 + + // CHECK: [[CST_0:%.+]] = stablehlo.constant dense<0.000000e+00> : tensor + // CHECK: [[CST_1:%.+]] = arith.constant dense<0> : tensor<1xi64> + // CHECK: [[CST_2:%.+]] = arith.constant dense<1.000000e+00> : tensor + // CHECK: [[REG:%.+]] = quantum.alloc( 1) : !quantum.reg + %0 = quantum.alloc( 1) : !quantum.reg + + // CHECK: [[EXTRACT_QUBIT:%.+]] = quantum.extract [[REG]][ 0] : !quantum.reg -> !quantum.bit + // CHECK: [[MRES:%.+]], [[OUT_QUBIT:%.+]] = quantum.measure [[EXTRACT_QUBIT]] : i1, !quantum.bit + // CHECK: [[REG1:%.+]] = quantum.insert [[REG]][ 0], [[OUT_QUBIT]] : !quantum.reg, !quantum.bit + // CHECK: [[COMPARE:%.+]] = stablehlo.compare NE, [[CST_2]], [[CST_0]], FLOAT : (tensor, tensor) -> tensor + // CHECK: [[EXTRACTED:%.+]] = tensor.extract [[COMPARE]][] : tensor + // CHECK: [[CONDITIONAL:%.+]] = scf.if [[EXTRACTED]] -> (!quantum.reg) { + // CHECK: [[SLICE1:%.+]] = stablehlo.slice [[CST_1]] [0:1] : (tensor<1xi64>) -> tensor<1xi64> + // CHECK: [[RESHAPE1:%.+]] = stablehlo.reshape [[SLICE1]] : (tensor<1xi64>) -> tensor + // CHECK: [[EXTRACTED_3:%.+]] = tensor.extract [[RESHAPE1]][] : tensor + // CHECK: [[FROM_ELEMENTS:%.+]] = tensor.from_elements [[EXTRACTED_3]] :
tensor<1xi64> + // CHECK: [[SLICE2:%.+]] = stablehlo.slice [[FROM_ELEMENTS]] [0:1] : (tensor<1xi64>) -> tensor<1xi64> + // CHECK: [[RESHAPE2:%.+]] = stablehlo.reshape [[SLICE2]] : (tensor<1xi64>) -> tensor + // CHECK: [[EXTRACTED_4:%.+]] = tensor.extract [[RESHAPE2]][] : tensor + // CHECK: [[EXTRACT1:%.+]] = quantum.extract [[REG1]][[[EXTRACTED_4]]] : !quantum.reg -> !quantum.bit + // CHECK: [[RZ1:%.+]] = quantum.custom "RZ"([[CST]]) [[EXTRACT1]] : !quantum.bit + // CHECK: [[INSERT1:%.+]] = quantum.insert [[REG1]][[[EXTRACTED_4]]], [[RZ1]] : !quantum.reg, !quantum.bit + // CHECK: [[EXTRACT2:%.+]] = quantum.extract [[INSERT1]][[[EXTRACTED_3]]] : !quantum.reg -> !quantum.bit + // CHECK: [[INSERT2:%.+]] = quantum.insert [[REG1]][[[EXTRACTED_3]]], [[EXTRACT2]] : !quantum.reg, !quantum.bit + // CHECK: [[EXTRACT3:%.+]] = quantum.extract [[INSERT2]][[[EXTRACTED_4]]] : !quantum.reg -> !quantum.bit + // CHECK: [[RZ2:%.+]] = quantum.custom "RZ"([[CST]]) [[EXTRACT3]] : !quantum.bit + // CHECK: [[INSERT3:%.+]] = quantum.insert [[INSERT2]][[[EXTRACTED_4]]], [[RZ2]] : !quantum.reg, !quantum.bit + // CHECK: [[EXTRACT4:%.+]] = quantum.extract [[INSERT3]][[[EXTRACTED_3]]] : !quantum.reg -> !quantum.bit + // CHECK: [[INSERT4:%.+]] = quantum.insert [[INSERT2]][[[EXTRACTED_3]]], [[EXTRACT4]] : !quantum.reg, !quantum.bit + // CHECK: scf.yield [[INSERT4]] : !quantum.reg + // CHECK: } else { + // CHECK: scf.yield [[REG1]] : !quantum.reg + // CHECK: } + // CHECK-NOT: quantum.custom "Test" + %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit + %out_qubits = quantum.custom "Test"(%cst) %1 : !quantum.bit + %2 = quantum.insert %0[ 0], %out_qubits : !quantum.reg, !quantum.bit + %3 = quantum.compbasis qreg %2 : !quantum.obs + %4 = quantum.probs %3 : tensor<2xf64> + + quantum.dealloc %2 : !quantum.reg + quantum.device_release + return %4 : tensor<2xf64> + } + + // Decomposition function should be applied and removed from the module + // CHECK-NOT: func.func private @Test_rule_1 + 
func.func private @Test_rule_1(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg + attributes {target_gate = "Test", llvm.linkage = #llvm.linkage} { + %cst = stablehlo.constant dense<0.000000e+00> : tensor + %10 = quantum.extract %arg0[ 0] : !quantum.reg -> !quantum.bit + %mres, %out_qubit = quantum.measure %10 : i1, !quantum.bit + %11 = quantum.insert %arg0[ 0], %out_qubit : !quantum.reg, !quantum.bit + %0 = stablehlo.compare NE, %arg1, %cst, FLOAT : (tensor, tensor) -> tensor + %extracted = tensor.extract %0[] : tensor + %1 = scf.if %extracted -> (!quantum.reg) { + %2 = stablehlo.slice %arg2 [0:1] : (tensor<1xi64>) -> tensor<1xi64> + %3 = stablehlo.reshape %2 : (tensor<1xi64>) -> tensor + %extracted_0 = tensor.extract %3[] : tensor + %4 = quantum.extract %11[%extracted_0] : !quantum.reg -> !quantum.bit + %extracted_1 = tensor.extract %arg1[] : tensor + %out_qubits = quantum.custom "RzDecomp"(%extracted_1) %4 : !quantum.bit + %5 = stablehlo.slice %arg2 [0:1] : (tensor<1xi64>) -> tensor<1xi64> + %6 = stablehlo.reshape %5 : (tensor<1xi64>) -> tensor + %extracted_2 = tensor.extract %3[] : tensor + %7 = quantum.insert %11[%extracted_2], %out_qubits : !quantum.reg, !quantum.bit + %extracted_3 = tensor.extract %6[] : tensor + %8 = quantum.extract %7[%extracted_3] : !quantum.reg -> !quantum.bit + %extracted_4 = tensor.extract %arg1[] : tensor + %out_qubits_5 = quantum.custom "RzDecomp"(%extracted_4) %8 : !quantum.bit + %extracted_6 = tensor.extract %6[] : tensor + %9 = quantum.insert %7[%extracted_6], %out_qubits_5 : !quantum.reg, !quantum.bit + scf.yield %9 : !quantum.reg + } else { + scf.yield %11 : !quantum.reg + } + return %1 : !quantum.reg + } + + // Decomposition function should be applied and removed from the module + // CHECK-NOT: func.func private @RzDecomp_rule_1 + func.func private @RzDecomp_rule_1(%arg0: !quantum.reg, %arg1: tensor, %arg2: tensor<1xi64>) -> !quantum.reg + attributes {target_gate = "RzDecomp", llvm.linkage = 
#llvm.linkage} { + %0 = stablehlo.slice %arg2 [0:1] : (tensor<1xi64>) -> tensor<1xi64> + %1 = stablehlo.reshape %0 : (tensor<1xi64>) -> tensor + %extracted = tensor.extract %1[] : tensor + %2 = quantum.extract %arg0[%extracted] : !quantum.reg -> !quantum.bit + %extracted_0 = tensor.extract %arg1[] : tensor + %out_qubits = quantum.custom "RZ"(%extracted_0) %2 : !quantum.bit + %extracted_1 = tensor.extract %1[] : tensor + %3 = quantum.insert %arg0[%extracted_1], %out_qubits : !quantum.reg, !quantum.bit + return %3 : !quantum.reg + } +} + +// ----- + +module @multi_wire_cnot_decomposition { + func.func public @test_cnot_decomposition() -> tensor<4xf64> { + %0 = quantum.alloc( 2) : !quantum.reg + %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit + %2 = quantum.extract %0[ 1] : !quantum.reg -> !quantum.bit + + // CHECK: [[CST_PI:%.+]] = arith.constant 3.1415926535897931 : f64 + // CHECK: [[CST_PI2:%.+]] = arith.constant 1.5707963267948966 : f64 + // CHECK: [[WIRE_TENSOR:%.+]] = arith.constant dense<[0, 1]> : tensor<2xi64> + // CHECK: [[REG:%.+]] = quantum.alloc( 2) : !quantum.reg + // CHECK: [[SLICE1:%.+]] = stablehlo.slice [[WIRE_TENSOR]] [0:1] : (tensor<2xi64>) -> tensor<1xi64> + // CHECK: [[RESHAPE1:%.+]] = stablehlo.reshape [[SLICE1]] : (tensor<1xi64>) -> tensor + // CHECK: [[SLICE2:%.+]] = stablehlo.slice [[WIRE_TENSOR]] [1:2] : (tensor<2xi64>) -> tensor<1xi64> + // CHECK: [[RESHAPE2:%.+]] = stablehlo.reshape [[SLICE2]] : (tensor<1xi64>) -> tensor + // CHECK: [[EXTRACTED:%.+]] = tensor.extract [[RESHAPE2]][] : tensor + // CHECK: [[QUBIT1:%.+]] = quantum.extract [[REG]][[[EXTRACTED]]] : !quantum.reg -> !quantum.bit + // CHECK: [[RZ1:%.+]] = quantum.custom "RZ"([[CST_PI]]) [[QUBIT1]] : !quantum.bit + // CHECK: [[RY1:%.+]] = quantum.custom "RY"([[CST_PI2]]) [[RZ1]] : !quantum.bit + // CHECK: [[INSERT1:%.+]] = quantum.insert [[REG]][[[EXTRACTED]]], [[RY1]] : !quantum.reg, !quantum.bit + // CHECK: [[EXTRACTED2:%.+]] = tensor.extract [[RESHAPE1]][] : tensor + // 
CHECK: [[QUBIT0:%.+]] = quantum.extract [[INSERT1]][[[EXTRACTED2]]] : !quantum.reg -> !quantum.bit + // CHECK: [[QUBIT1_2:%.+]] = quantum.extract [[INSERT1]][[[EXTRACTED]]] : !quantum.reg -> !quantum.bit + // CHECK: [[CZ_RESULT:%.+]]:2 = quantum.custom "CZ"() [[QUBIT0]], [[QUBIT1_2]] : !quantum.bit, !quantum.bit + // CHECK: [[INSERT2:%.+]] = quantum.insert [[INSERT1]][[[EXTRACTED2]]], [[CZ_RESULT]]#0 : !quantum.reg, !quantum.bit + // CHECK: [[RZ2:%.+]] = quantum.custom "RZ"([[CST_PI]]) [[CZ_RESULT]]#1 : !quantum.bit + // CHECK: [[RY2:%.+]] = quantum.custom "RY"([[CST_PI2]]) [[RZ2]] : !quantum.bit + // CHECK: [[INSERT3:%.+]] = quantum.insert [[INSERT2]][[[EXTRACTED]]], [[RY2]] : !quantum.reg, !quantum.bit + // CHECK: [[FINAL_QUBIT0:%.+]] = quantum.extract [[INSERT3]][ 0] : !quantum.reg -> !quantum.bit + // CHECK: [[FINAL_QUBIT1:%.+]] = quantum.extract [[INSERT3]][ 1] : !quantum.reg -> !quantum.bit + // CHECK-NOT: quantum.custom "CNOT" + %3, %4 = quantum.custom "CNOT"() %1, %2 : !quantum.bit, !quantum.bit + + // CHECK: [[FINAL_INSERT1:%.+]] = quantum.insert [[REG]][ 0], [[FINAL_QUBIT0]] : !quantum.reg, !quantum.bit + // CHECK: [[FINAL_INSERT2:%.+]] = quantum.insert [[FINAL_INSERT1]][ 1], [[FINAL_QUBIT1]] : !quantum.reg, !quantum.bit + %5 = quantum.insert %0[ 0], %3 : !quantum.reg, !quantum.bit + %6 = quantum.insert %5[ 1], %4 : !quantum.reg, !quantum.bit + %7 = quantum.compbasis qreg %6 : !quantum.obs + %8 = quantum.probs %7 : tensor<4xf64> + quantum.dealloc %6 : !quantum.reg + return %8 : tensor<4xf64> + } + + // Decomposition function should be applied and removed from the module + // CHECK-NOT: func.func private @CNOT_rule_cz_rz_ry + func.func private @CNOT_rule_cz_rz_ry(%arg0: !quantum.reg, %arg1: tensor<2xi64>) -> !quantum.reg attributes {target_gate = "CNOT", llvm.linkage = #llvm.linkage} { + // CNOT decomposition: CNOT = (I ⊗ H) * CZ * (I ⊗ H) + %cst = arith.constant 1.5707963267948966 : f64 + %cst_0 = arith.constant 3.1415926535897931 : f64 + + // Extract 
wire indices from tensor + %0 = stablehlo.slice %arg1 [0:1] : (tensor<2xi64>) -> tensor<1xi64> + %1 = stablehlo.reshape %0 : (tensor<1xi64>) -> tensor + %2 = stablehlo.slice %arg1 [1:2] : (tensor<2xi64>) -> tensor<1xi64> + %3 = stablehlo.reshape %2 : (tensor<1xi64>) -> tensor + + // Step 1: Apply H to target qubit (H = RZ(π) * RY(π/2)) + %extracted = tensor.extract %3[] : tensor + %4 = quantum.extract %arg0[%extracted] : !quantum.reg -> !quantum.bit + %out_qubits = quantum.custom "RZ"(%cst_0) %4 : !quantum.bit + %out_qubits_1 = quantum.custom "RY"(%cst) %out_qubits : !quantum.bit + %extracted_2 = tensor.extract %3[] : tensor + %5 = quantum.insert %arg0[%extracted_2], %out_qubits_1 : !quantum.reg, !quantum.bit + + // Step 2: Apply CZ gate + %extracted_3 = tensor.extract %1[] : tensor + %6 = quantum.extract %5[%extracted_3] : !quantum.reg -> !quantum.bit + %extracted_4 = tensor.extract %3[] : tensor + %7 = quantum.extract %5[%extracted_4] : !quantum.reg -> !quantum.bit + %out_qubits_5:2 = quantum.custom "CZ"() %6, %7 : !quantum.bit, !quantum.bit + %extracted_6 = tensor.extract %1[] : tensor + %8 = quantum.insert %5[%extracted_6], %out_qubits_5#0 : !quantum.reg, !quantum.bit + %extracted_7 = tensor.extract %3[] : tensor + %9 = quantum.insert %8[%extracted_7], %out_qubits_5#1 : !quantum.reg, !quantum.bit + + // Step 3: Apply H to target qubit again + %extracted_8 = tensor.extract %3[] : tensor + %10 = quantum.extract %9[%extracted_8] : !quantum.reg -> !quantum.bit + %out_qubits_9 = quantum.custom "RZ"(%cst_0) %10 : !quantum.bit + %out_qubits_10 = quantum.custom "RY"(%cst) %out_qubits_9 : !quantum.bit + %extracted_11 = tensor.extract %3[] : tensor + %11 = quantum.insert %9[%extracted_11], %out_qubits_10 : !quantum.reg, !quantum.bit + + return %11 : !quantum.reg + } +} + +// ----- + +module @cnot_alternative_decomposition { + func.func public @test_cnot_alternative_decomposition() -> tensor<4xf64> { + %0 = quantum.alloc( 2) : !quantum.reg + %1 = quantum.extract %0[ 0] : 
!quantum.reg -> !quantum.bit + %2 = quantum.extract %0[ 1] : !quantum.reg -> !quantum.bit + + // CHECK: [[CST_PI:%.+]] = arith.constant 3.1415926535897931 : f64 + // CHECK: [[CST_PI2:%.+]] = arith.constant 1.5707963267948966 : f64 + // CHECK: [[REG:%.+]] = quantum.alloc( 2) : !quantum.reg + // CHECK: [[QUBIT0:%.+]] = quantum.extract [[REG]][ 0] : !quantum.reg -> !quantum.bit + // CHECK: [[QUBIT1:%.+]] = quantum.extract [[REG]][ 1] : !quantum.reg -> !quantum.bit + // CHECK: [[RZ1:%.+]] = quantum.custom "RZ"([[CST_PI]]) [[QUBIT1]] : !quantum.bit + // CHECK: [[RY1:%.+]] = quantum.custom "RY"([[CST_PI2]]) [[RZ1]] : !quantum.bit + // CHECK: [[CZ_RESULT:%.+]]:2 = quantum.custom "CZ"() [[QUBIT0]], [[RY1]] : !quantum.bit, !quantum.bit + // CHECK: [[RZ2:%.+]] = quantum.custom "RZ"([[CST_PI]]) [[CZ_RESULT]]#1 : !quantum.bit + // CHECK: [[RY2:%.+]] = quantum.custom "RY"([[CST_PI2]]) [[RZ2]] : !quantum.bit + // CHECK-NOT: quantum.custom "CNOT" + %3, %4 = quantum.custom "CNOT"() %1, %2 : !quantum.bit, !quantum.bit + + // CHECK: [[FINAL_INSERT1:%.+]] = quantum.insert [[REG]][ 0], [[CZ_RESULT]]#0 : !quantum.reg, !quantum.bit + // CHECK: [[FINAL_INSERT2:%.+]] = quantum.insert [[FINAL_INSERT1]][ 1], [[RY2]] : !quantum.reg, !quantum.bit + %5 = quantum.insert %0[ 0], %3 : !quantum.reg, !quantum.bit + %6 = quantum.insert %5[ 1], %4 : !quantum.reg, !quantum.bit + %7 = quantum.compbasis qreg %6 : !quantum.obs + %8 = quantum.probs %7 : tensor<4xf64> + quantum.dealloc %6 : !quantum.reg + return %8 : tensor<4xf64> + } + + // Decomposition function should be applied and removed from the module + // CHECK-NOT: func.func private @CNOT_rule_h_cnot_h + func.func private @CNOT_rule_h_cnot_h(%arg0: !quantum.bit, %arg1: !quantum.bit) -> (!quantum.bit, !quantum.bit) attributes {target_gate = "CNOT", llvm.linkage = #llvm.linkage} { + // CNOT decomposition: CNOT = (I ⊗ H) * CZ * (I ⊗ H) + %cst = arith.constant 1.5707963267948966 : f64 + %cst_0 = arith.constant 3.1415926535897931 : f64 + + // Step 1: 
Apply H to target qubit (H = RZ(π) * RY(π/2)) + %out_qubits = quantum.custom "RZ"(%cst_0) %arg1 : !quantum.bit + %out_qubits_1 = quantum.custom "RY"(%cst) %out_qubits : !quantum.bit + + // Step 2: Apply CZ gate + %out_qubits_2:2 = quantum.custom "CZ"() %arg0, %out_qubits_1 : !quantum.bit, !quantum.bit + + // Step 3: Apply H to target qubit again + %out_qubits_3 = quantum.custom "RZ"(%cst_0) %out_qubits_2#1 : !quantum.bit + %out_qubits_4 = quantum.custom "RY"(%cst) %out_qubits_3 : !quantum.bit + + return %out_qubits_2#0, %out_qubits_4 : !quantum.bit, !quantum.bit + } +}