diff --git a/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp b/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp index 9dcc4ea1ad..7604780a2d 100644 --- a/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp +++ b/mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp @@ -40,6 +40,7 @@ namespace quantum { /// - A runtime Value (for dynamic indices computed at runtime) /// - An IntegerAttr (for compile-time constant indices) /// - Invalid/uninitialized (represented by std::monostate) +/// And a qreg value to represent the qreg that the index belongs to /// /// The struct uses std::variant to ensure only one type is active at a time, /// preventing invalid states. @@ -54,17 +55,21 @@ namespace quantum { /// Value idx = dynamicIdx.getValue(); // Get the Value /// } /// } -struct QubitIndex { +class QubitIndex { + private: // use monostate to represent the invalid index std::variant index; + Value qreg; - QubitIndex() : index(std::monostate()) {} - QubitIndex(Value val) : index(val) {} - QubitIndex(IntegerAttr attr) : index(attr) {} + public: + QubitIndex() : index(std::monostate()), qreg(nullptr) {} + QubitIndex(Value val, Value qreg) : index(val), qreg(qreg) {} + QubitIndex(IntegerAttr attr, Value qreg) : index(attr), qreg(qreg) {} bool isValue() const { return std::holds_alternative(index); } bool isAttr() const { return std::holds_alternative(index); } operator bool() const { return isValue() || isAttr(); } + Value getReg() const { return qreg; } Value getValue() const { return isValue() ? std::get(index) : nullptr; } IntegerAttr getAttr() const { return isAttr() ? std::get(index) : nullptr; } }; @@ -75,7 +80,7 @@ struct QubitIndex { class OpSignatureAnalyzer { public: OpSignatureAnalyzer() = delete; - OpSignatureAnalyzer(CustomOp op, bool enableQregMode) + OpSignatureAnalyzer(CustomOp op, bool enableQregMode, PatternRewriter &rewriter) : signature(OpSignature{ .params = op.getParams(), .inQubits = op.getInQubits(), @@ -83,18 +88,12 @@ class OpSignatureAnalyzer { .inCtrlValues = op.getInCtrlValues(), .outQubits = op.getOutQubits(), .outCtrlQubits = op.getOutCtrlQubits(), + .rewriter = rewriter, }) { if (!enableQregMode) return; - signature.sourceQreg = getSourceQreg(signature.inQubits.front()); - if (!signature.sourceQreg) { - op.emitError("Cannot get source qreg"); - isValid = false; - return; - } - // input wire indices for (Value qubit : signature.inQubits) { const QubitIndex index = getExtractIndex(qubit); @@ -117,6 +116,34 @@ class OpSignatureAnalyzer { signature.inCtrlWireIndices.emplace_back(index); } + assert((signature.inWireIndices.size() + signature.inCtrlWireIndices.size()) > 0 && + "inWireIndices or inCtrlWireIndices should not be empty"); + + // Get the first qreg as reference + Value refQreg = !signature.inWireIndices.empty() ? signature.inWireIndices[0].getReg() + : signature.inCtrlWireIndices[0].getReg(); + + // Check if any qreg is different + signature.needAllocQreg = + std::any_of(signature.inWireIndices.begin(), signature.inWireIndices.end(), + [refQreg](const QubitIndex &idx) { return idx.getReg() != refQreg; }) || + std::any_of(signature.inCtrlWireIndices.begin(), signature.inCtrlWireIndices.end(), + [refQreg](const QubitIndex &idx) { return idx.getReg() != refQreg; }); + + // If needAllocQreg, the indices should be updated to from 0 to nqubits - 1 + // Since we will use the new qreg for the indices + if (signature.needAllocQreg) { + for (auto [i, index] : llvm::enumerate(signature.inWireIndices)) { + auto attr = IntegerAttr::get(rewriter.getI64Type(), i); + signature.inWireIndices[i] = QubitIndex(attr, index.getReg()); + } + for (auto [i, index] : llvm::enumerate(signature.inCtrlWireIndices)) { + auto attr = + IntegerAttr::get(rewriter.getI64Type(), i + signature.inWireIndices.size()); + signature.inCtrlWireIndices[i] = QubitIndex(attr, index.getReg()); + } + } + // Output qubit indices are the same as input qubit indices signature.outQubitIndices = signature.inWireIndices; signature.outCtrlQubitIndices = signature.inCtrlWireIndices; @@ -124,6 +151,19 @@ class OpSignatureAnalyzer { operator bool() const { return isValid; } + Value getUpdatedQreg(PatternRewriter &rewriter, Location loc) + { + if (signature.needAllocQreg) { + // allocate a new qreg with the number of qubits + auto nqubits = signature.inWireIndices.size() + signature.inCtrlWireIndices.size(); + IntegerAttr nqubitsAttr = IntegerAttr::get(rewriter.getI64Type(), nqubits); + auto allocOp = rewriter.create( + loc, quantum::QuregType::get(rewriter.getContext()), nullptr, nqubitsAttr); + return allocOp.getQreg(); + } + return signature.inWireIndices[0].getReg(); + } + // Prepare the operands for calling the decomposition function // There are two cases: // 1. The first input is a qreg, which means the decomposition function is a qreg mode function @@ -144,7 +184,9 @@ class OpSignatureAnalyzer { int operandIdx = 0; if (isa(funcInputs[0])) { - Value updatedQreg = signature.sourceQreg; + // Allocate a new qreg if needed + Value updatedQreg = getUpdatedQreg(rewriter, loc); + for (auto [i, qubit] : llvm::enumerate(signature.inQubits)) { const QubitIndex &index = signature.inWireIndices[i]; updatedQreg = @@ -152,6 +194,13 @@ class OpSignatureAnalyzer { index.getValue(), index.getAttr(), qubit); } + for (auto [i, qubit] : llvm::enumerate(signature.inCtrlQubits)) { + const QubitIndex &index = signature.inCtrlWireIndices[i]; + updatedQreg = + rewriter.create(loc, updatedQreg.getType(), updatedQreg, + index.getValue(), index.getAttr(), qubit); + } + operands[operandIdx++] = updatedQreg; if (!signature.params.empty()) { auto [startIdx, endIdx] = @@ -218,6 +267,7 @@ class OpSignatureAnalyzer { SmallVector newResults; rewriter.setInsertionPointAfter(callOp); + for (const QubitIndex &index : signature.outQubitIndices) { auto extractOp = rewriter.create( callOp.getLoc(), rewriter.getType(), qreg, index.getValue(), @@ -230,6 +280,12 @@ class OpSignatureAnalyzer { index.getAttr()); newResults.emplace_back(extractOp.getResult()); } + + // FIXME: Dealloc should be fine, but it will cause the error in lightning now + // if (signature.needAllocQreg) { + // rewriter.create(callOp.getLoc(), qreg); + // } + return newResults; } @@ -245,11 +301,17 @@ class OpSignatureAnalyzer { ValueRange outCtrlQubits; // Qreg mode specific information - Value sourceQreg = nullptr; SmallVector inWireIndices; SmallVector inCtrlWireIndices; SmallVector outQubitIndices; SmallVector outCtrlQubitIndices; + + // Qreg mode specific information, if true, a new qreg should be allocated before function + // call and deallocated after function call + bool needAllocQreg = false; + + // Rewriter + PatternRewriter &rewriter; } signature; Value fromTensorOrAsIs(ValueRange values, Type type, PatternRewriter &rewriter, Location loc) @@ -356,10 +418,10 @@ class OpSignatureAnalyzer { while (qubit) { if (auto extractOp = qubit.getDefiningOp()) { if (Value idx = extractOp.getIdx()) { - return QubitIndex(idx); + return QubitIndex(idx, extractOp.getQreg()); } if (IntegerAttr idxAttr = extractOp.getIdxAttrAttr()) { - return QubitIndex(idxAttr); + return QubitIndex(idxAttr, extractOp.getQreg()); } } @@ -422,10 +484,11 @@ struct DecomposeLoweringRewritePattern : public OpRewritePattern { "Decomposition function must have at least one result"); auto enableQreg = isa(decompFunc.getFunctionType().getInput(0)); - auto analyzer = OpSignatureAnalyzer(op, enableQreg); - assert(analyzer && "Analyzer should be valid"); rewriter.setInsertionPointAfter(op); + auto analyzer = OpSignatureAnalyzer(op, enableQreg, rewriter); + assert(analyzer && "Analyzer should be valid"); + auto callOperands = analyzer.prepareCallOperands(decompFunc, rewriter, op.getLoc()); auto callOp = rewriter.create(op.getLoc(), decompFunc.getFunctionType().getResults(), diff --git a/mlir/test/Quantum/DecomposeLoweringTest.mlir b/mlir/test/Quantum/DecomposeLoweringTest.mlir index 91bfbe7778..b616ec2532 100644 --- a/mlir/test/Quantum/DecomposeLoweringTest.mlir +++ b/mlir/test/Quantum/DecomposeLoweringTest.mlir @@ -84,6 +84,91 @@ module @single_hadamard { } } +// ----- + +module @cz_hadamard { + func.func public @test_cz_hadamard() -> tensor<2xf64> attributes {decompose_gatesets = [["CZ", "Hadamard"]]} { + %cst = arith.constant dense<[0, 1]> : tensor<2xi64> + %0 = quantum.alloc( 1) : !quantum.reg + %1 = quantum.alloc( 1) : !quantum.reg + + // Extract qubits from different qregs (this will trigger needAllocQreg) + %2 = quantum.extract %1[ 0] : !quantum.reg -> !quantum.bit + %3 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit + + // CHECK: [[CST:%.+]] = arith.constant dense<[0, 1]> : tensor<2xi64> + // CHECK: [[REG0:%.+]] = quantum.alloc( 1) : !quantum.reg + // CHECK: [[REG1:%.+]] = quantum.alloc( 1) : !quantum.reg + // CHECK: [[QUBIT1:%.+]] = quantum.extract [[REG1]][ 0] : !quantum.reg -> !quantum.bit + // CHECK: [[QUBIT2:%.+]] = quantum.extract [[REG0]][ 0] : !quantum.reg -> !quantum.bit + // CHECK: [[NEW_REG:%.+]] = quantum.alloc( 2) : !quantum.reg + // CHECK: [[INSERT1:%.+]] = quantum.insert [[NEW_REG]][ 0], [[QUBIT1]] : !quantum.reg, !quantum.bit + // CHECK: [[INSERT2:%.+]] = quantum.insert [[INSERT1]][ 1], [[QUBIT2]] : !quantum.reg, !quantum.bit + // CHECK: [[SLICE1:%.+]] = stablehlo.slice [[CST]] [1:2] : (tensor<2xi64>) -> tensor<1xi64> + // CHECK: [[RESHAPE1:%.+]] = stablehlo.reshape [[SLICE1]] : (tensor<1xi64>) -> tensor + // CHECK: [[EXTRACTED:%.+]] = tensor.extract [[RESHAPE1]][] : tensor + // CHECK: [[EXTRACT1:%.+]] = quantum.extract [[INSERT2]][[[EXTRACTED]]] : !quantum.reg -> !quantum.bit + // CHECK: [[H1:%.+]] = quantum.custom "Hadamard"() [[EXTRACT1]] : !quantum.bit + // CHECK: [[SLICE2:%.+]] = stablehlo.slice [[CST]] [0:1] : (tensor<2xi64>) -> tensor<1xi64> + // CHECK: [[RESHAPE2:%.+]] = stablehlo.reshape [[SLICE2]] : (tensor<1xi64>) -> tensor + // CHECK: [[INSERT_H:%.+]] = quantum.insert [[INSERT2]][[[EXTRACTED]]], [[H1]] : !quantum.reg, !quantum.bit + // CHECK: [[EXTRACTED_0:%.+]] = tensor.extract [[RESHAPE2]][] : tensor + // CHECK: [[EXTRACT2:%.+]] = quantum.extract [[INSERT_H]][[[EXTRACTED_0]]] : !quantum.reg -> !quantum.bit + // CHECK: [[EXTRACT3:%.+]] = quantum.extract [[INSERT_H]][[[EXTRACTED]]] : !quantum.reg -> !quantum.bit + // CHECK: [[CZ_RESULT:%.+]]:2 = quantum.custom "CZ"() [[EXTRACT2]], [[EXTRACT3]] : !quantum.bit, !quantum.bit + // CHECK: [[INSERT_CZ1:%.+]] = quantum.insert [[INSERT_H]][[[EXTRACTED_0]]], [[CZ_RESULT]]#0 : !quantum.reg, !quantum.bit + // CHECK: [[H2:%.+]] = quantum.custom "Hadamard"() [[CZ_RESULT]]#1 : !quantum.bit + // CHECK: [[INSERT_CZ2:%.+]] = quantum.insert [[INSERT_CZ1]][[[EXTRACTED]]], [[H2]] : !quantum.reg, !quantum.bit + // CHECK: [[FINAL_EXTRACT1:%.+]] = quantum.extract [[INSERT_CZ2]][ 0] : !quantum.reg -> !quantum.bit + // CHECK: [[FINAL_EXTRACT2:%.+]] = quantum.extract [[INSERT_CZ2]][ 1] : !quantum.reg -> !quantum.bit + // CHECK-NOT: quantum.custom "CNOT" + %out_qubits:2 = quantum.custom "CNOT"() %2, %3 : !quantum.bit, !quantum.bit + + %4 = quantum.insert %1[ 0], %out_qubits#0 : !quantum.reg, !quantum.bit + quantum.dealloc %4 : !quantum.reg + %5 = quantum.compbasis qubits %out_qubits#1 : !quantum.obs + %6 = quantum.probs %5 : tensor<2xf64> + %7 = quantum.insert %0[ 0], %out_qubits#1 : !quantum.reg, !quantum.bit + quantum.dealloc %7 : !quantum.reg + return %6 : tensor<2xf64> + } + + // Decomposition function for CNOT gate into CZ and Hadamard + // CHECK-NOT: func.func private @cz_hadamard + func.func private @cz_hadamard(%arg0: !quantum.reg, %arg1: tensor<2xi64>) -> !quantum.reg attributes {target_gate = "CNOT", llvm.linkage = #llvm.linkage} { + %0 = stablehlo.slice %arg1 [1:2] : (tensor<2xi64>) -> tensor<1xi64> + %1 = stablehlo.reshape %0 : (tensor<1xi64>) -> tensor + %extracted = tensor.extract %1[] : tensor + %2 = quantum.extract %arg0[%extracted] : !quantum.reg -> !quantum.bit + %out_qubits = quantum.custom "Hadamard"() %2 : !quantum.bit + %3 = stablehlo.slice %arg1 [0:1] : (tensor<2xi64>) -> tensor<1xi64> + %4 = stablehlo.reshape %3 : (tensor<1xi64>) -> tensor + %5 = stablehlo.slice %arg1 [1:2] : (tensor<2xi64>) -> tensor<1xi64> + %6 = stablehlo.reshape %5 : (tensor<1xi64>) -> tensor + %extracted_0 = tensor.extract %1[] : tensor + %7 = quantum.insert %arg0[%extracted_0], %out_qubits : !quantum.reg, !quantum.bit + %extracted_1 = tensor.extract %4[] : tensor + %8 = quantum.extract %7[%extracted_1] : !quantum.reg -> !quantum.bit + %extracted_2 = tensor.extract %6[] : tensor + %9 = quantum.extract %7[%extracted_2] : !quantum.reg -> !quantum.bit + %out_qubits_3:2 = quantum.custom "CZ"() %8, %9 : !quantum.bit, !quantum.bit + %10 = stablehlo.slice %arg1 [1:2] : (tensor<2xi64>) -> tensor<1xi64> + %11 = stablehlo.reshape %10 : (tensor<1xi64>) -> tensor + %extracted_4 = tensor.extract %4[] : tensor + %12 = quantum.insert %7[%extracted_4], %out_qubits_3#0 : !quantum.reg, !quantum.bit + %extracted_5 = tensor.extract %6[] : tensor + %13 = quantum.insert %12[%extracted_5], %out_qubits_3#1 : !quantum.reg, !quantum.bit + %extracted_6 = tensor.extract %11[] : tensor + %14 = quantum.extract %13[%extracted_6] : !quantum.reg -> !quantum.bit + %out_qubits_7 = quantum.custom "Hadamard"() %14 : !quantum.bit + %extracted_8 = tensor.extract %11[] : tensor + %15 = quantum.insert %13[%extracted_8], %out_qubits_7 : !quantum.reg, !quantum.bit + return %15 : !quantum.reg + } +} + + + // ----- module @recursive { func.func public @test_recursive() -> tensor<4xf64> {