Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 81 additions & 18 deletions mlir/lib/Quantum/Transforms/DecomposeLoweringPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ namespace quantum {
/// - A runtime Value (for dynamic indices computed at runtime)
/// - An IntegerAttr (for compile-time constant indices)
/// - Invalid/uninitialized (represented by std::monostate)
/// And a qreg value to represent the qreg that the index belongs to
///
/// The struct uses std::variant to ensure only one type is active at a time,
/// preventing invalid states.
Expand All @@ -54,17 +55,21 @@ namespace quantum {
/// Value idx = dynamicIdx.getValue(); // Get the Value
/// }
/// }
struct QubitIndex {
class QubitIndex {
private:
// use monostate to represent the invalid index
std::variant<std::monostate, Value, IntegerAttr> index;
Value qreg;

QubitIndex() : index(std::monostate()) {}
QubitIndex(Value val) : index(val) {}
QubitIndex(IntegerAttr attr) : index(attr) {}
public:
QubitIndex() : index(std::monostate()), qreg(nullptr) {}
QubitIndex(Value val, Value qreg) : index(val), qreg(qreg) {}
QubitIndex(IntegerAttr attr, Value qreg) : index(attr), qreg(qreg) {}

bool isValue() const { return std::holds_alternative<Value>(index); }
bool isAttr() const { return std::holds_alternative<IntegerAttr>(index); }
operator bool() const { return isValue() || isAttr(); }
Value getReg() const { return qreg; }
Value getValue() const { return isValue() ? std::get<Value>(index) : nullptr; }
IntegerAttr getAttr() const { return isAttr() ? std::get<IntegerAttr>(index) : nullptr; }
};
Expand All @@ -75,26 +80,20 @@ struct QubitIndex {
class OpSignatureAnalyzer {
public:
OpSignatureAnalyzer() = delete;
OpSignatureAnalyzer(CustomOp op, bool enableQregMode)
OpSignatureAnalyzer(CustomOp op, bool enableQregMode, PatternRewriter &rewriter)
: signature(OpSignature{
.params = op.getParams(),
.inQubits = op.getInQubits(),
.inCtrlQubits = op.getInCtrlQubits(),
.inCtrlValues = op.getInCtrlValues(),
.outQubits = op.getOutQubits(),
.outCtrlQubits = op.getOutCtrlQubits(),
.rewriter = rewriter,
})
{
if (!enableQregMode)
return;

signature.sourceQreg = getSourceQreg(signature.inQubits.front());
if (!signature.sourceQreg) {
op.emitError("Cannot get source qreg");
isValid = false;
return;
}

// input wire indices
for (Value qubit : signature.inQubits) {
const QubitIndex index = getExtractIndex(qubit);
Expand All @@ -117,13 +116,54 @@ class OpSignatureAnalyzer {
signature.inCtrlWireIndices.emplace_back(index);
}

assert((signature.inWireIndices.size() + signature.inCtrlWireIndices.size()) > 0 &&
"inWireIndices or inCtrlWireIndices should not be empty");

// Get the first qreg as reference
Value refQreg = !signature.inWireIndices.empty() ? signature.inWireIndices[0].getReg()
: signature.inCtrlWireIndices[0].getReg();

// Check if any qreg is different
signature.needAllocQreg =
std::any_of(signature.inWireIndices.begin(), signature.inWireIndices.end(),
[refQreg](const QubitIndex &idx) { return idx.getReg() != refQreg; }) ||
std::any_of(signature.inCtrlWireIndices.begin(), signature.inCtrlWireIndices.end(),
[refQreg](const QubitIndex &idx) { return idx.getReg() != refQreg; });

// If needAllocQreg, the indices should be updated to from 0 to nqubits - 1
// Since we will use the new qreg for the indices
if (signature.needAllocQreg) {
for (auto [i, index] : llvm::enumerate(signature.inWireIndices)) {
auto attr = IntegerAttr::get(rewriter.getI64Type(), i);
signature.inWireIndices[i] = QubitIndex(attr, index.getReg());
}
for (auto [i, index] : llvm::enumerate(signature.inCtrlWireIndices)) {
auto attr =
IntegerAttr::get(rewriter.getI64Type(), i + signature.inWireIndices.size());
signature.inCtrlWireIndices[i] = QubitIndex(attr, index.getReg());
}
}

// Output qubit indices are the same as input qubit indices
signature.outQubitIndices = signature.inWireIndices;
signature.outCtrlQubitIndices = signature.inCtrlWireIndices;
}

operator bool() const { return isValid; }

Value getUpdatedQreg(PatternRewriter &rewriter, Location loc)
{
if (signature.needAllocQreg) {
// allocate a new qreg with the number of qubits
auto nqubits = signature.inWireIndices.size() + signature.inCtrlWireIndices.size();
IntegerAttr nqubitsAttr = IntegerAttr::get(rewriter.getI64Type(), nqubits);
auto allocOp = rewriter.create<quantum::AllocOp>(
loc, quantum::QuregType::get(rewriter.getContext()), nullptr, nqubitsAttr);
return allocOp.getQreg();
}
return signature.inWireIndices[0].getReg();
}

// Prepare the operands for calling the decomposition function
// There are two cases:
// 1. The first input is a qreg, which means the decomposition function is a qreg mode function
Expand All @@ -144,14 +184,23 @@ class OpSignatureAnalyzer {

int operandIdx = 0;
if (isa<quantum::QuregType>(funcInputs[0])) {
Value updatedQreg = signature.sourceQreg;
// Allocate a new qreg if needed
Value updatedQreg = getUpdatedQreg(rewriter, loc);

for (auto [i, qubit] : llvm::enumerate(signature.inQubits)) {
const QubitIndex &index = signature.inWireIndices[i];
updatedQreg =
rewriter.create<quantum::InsertOp>(loc, updatedQreg.getType(), updatedQreg,
index.getValue(), index.getAttr(), qubit);
}

for (auto [i, qubit] : llvm::enumerate(signature.inCtrlQubits)) {
const QubitIndex &index = signature.inCtrlWireIndices[i];
updatedQreg =
rewriter.create<quantum::InsertOp>(loc, updatedQreg.getType(), updatedQreg,
index.getValue(), index.getAttr(), qubit);
}

operands[operandIdx++] = updatedQreg;
if (!signature.params.empty()) {
auto [startIdx, endIdx] =
Expand Down Expand Up @@ -218,6 +267,7 @@ class OpSignatureAnalyzer {

SmallVector<Value> newResults;
rewriter.setInsertionPointAfter(callOp);

for (const QubitIndex &index : signature.outQubitIndices) {
auto extractOp = rewriter.create<quantum::ExtractOp>(
callOp.getLoc(), rewriter.getType<quantum::QubitType>(), qreg, index.getValue(),
Expand All @@ -230,6 +280,12 @@ class OpSignatureAnalyzer {
index.getAttr());
newResults.emplace_back(extractOp.getResult());
}

// FIXME: Dealloc should be fine, but it will cause the error in lightning now
// if (signature.needAllocQreg) {
// rewriter.create<quantum::DeallocOp>(callOp.getLoc(), qreg);
// }

return newResults;
}

Expand All @@ -245,11 +301,17 @@ class OpSignatureAnalyzer {
ValueRange outCtrlQubits;

// Qreg mode specific information
Value sourceQreg = nullptr;
SmallVector<QubitIndex> inWireIndices;
SmallVector<QubitIndex> inCtrlWireIndices;
SmallVector<QubitIndex> outQubitIndices;
SmallVector<QubitIndex> outCtrlQubitIndices;

// Qreg mode specific information, if true, a new qreg should be allocated before function
// call and deallocated after function call
bool needAllocQreg = false;

// Rewriter
PatternRewriter &rewriter;
} signature;

Value fromTensorOrAsIs(ValueRange values, Type type, PatternRewriter &rewriter, Location loc)
Expand Down Expand Up @@ -356,10 +418,10 @@ class OpSignatureAnalyzer {
while (qubit) {
if (auto extractOp = qubit.getDefiningOp<quantum::ExtractOp>()) {
if (Value idx = extractOp.getIdx()) {
return QubitIndex(idx);
return QubitIndex(idx, extractOp.getQreg());
}
if (IntegerAttr idxAttr = extractOp.getIdxAttrAttr()) {
return QubitIndex(idxAttr);
return QubitIndex(idxAttr, extractOp.getQreg());
}
}

Expand Down Expand Up @@ -422,10 +484,11 @@ struct DecomposeLoweringRewritePattern : public OpRewritePattern<CustomOp> {
"Decomposition function must have at least one result");

auto enableQreg = isa<quantum::QuregType>(decompFunc.getFunctionType().getInput(0));
auto analyzer = OpSignatureAnalyzer(op, enableQreg);
assert(analyzer && "Analyzer should be valid");

rewriter.setInsertionPointAfter(op);
auto analyzer = OpSignatureAnalyzer(op, enableQreg, rewriter);
assert(analyzer && "Analyzer should be valid");

auto callOperands = analyzer.prepareCallOperands(decompFunc, rewriter, op.getLoc());
auto callOp =
rewriter.create<func::CallOp>(op.getLoc(), decompFunc.getFunctionType().getResults(),
Expand Down
85 changes: 85 additions & 0 deletions mlir/test/Quantum/DecomposeLoweringTest.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,91 @@ module @single_hadamard {
}
}

// -----

module @cz_hadamard {
func.func public @test_cz_hadamard() -> tensor<2xf64> attributes {decompose_gatesets = [["CZ", "Hadamard"]]} {
%cst = arith.constant dense<[0, 1]> : tensor<2xi64>
%0 = quantum.alloc( 1) : !quantum.reg
%1 = quantum.alloc( 1) : !quantum.reg

// Extract qubits from different qregs (this will trigger needAllocQreg)
%2 = quantum.extract %1[ 0] : !quantum.reg -> !quantum.bit
%3 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit

// CHECK: [[CST:%.+]] = arith.constant dense<[0, 1]> : tensor<2xi64>
// CHECK: [[REG0:%.+]] = quantum.alloc( 1) : !quantum.reg
// CHECK: [[REG1:%.+]] = quantum.alloc( 1) : !quantum.reg
// CHECK: [[QUBIT1:%.+]] = quantum.extract [[REG1]][ 0] : !quantum.reg -> !quantum.bit
// CHECK: [[QUBIT2:%.+]] = quantum.extract [[REG0]][ 0] : !quantum.reg -> !quantum.bit
// CHECK: [[NEW_REG:%.+]] = quantum.alloc( 2) : !quantum.reg
// CHECK: [[INSERT1:%.+]] = quantum.insert [[NEW_REG]][ 0], [[QUBIT1]] : !quantum.reg, !quantum.bit
// CHECK: [[INSERT2:%.+]] = quantum.insert [[INSERT1]][ 1], [[QUBIT2]] : !quantum.reg, !quantum.bit
// CHECK: [[SLICE1:%.+]] = stablehlo.slice [[CST]] [1:2] : (tensor<2xi64>) -> tensor<1xi64>
// CHECK: [[RESHAPE1:%.+]] = stablehlo.reshape [[SLICE1]] : (tensor<1xi64>) -> tensor<i64>
// CHECK: [[EXTRACTED:%.+]] = tensor.extract [[RESHAPE1]][] : tensor<i64>
// CHECK: [[EXTRACT1:%.+]] = quantum.extract [[INSERT2]][[[EXTRACTED]]] : !quantum.reg -> !quantum.bit
// CHECK: [[H1:%.+]] = quantum.custom "Hadamard"() [[EXTRACT1]] : !quantum.bit
// CHECK: [[SLICE2:%.+]] = stablehlo.slice [[CST]] [0:1] : (tensor<2xi64>) -> tensor<1xi64>
// CHECK: [[RESHAPE2:%.+]] = stablehlo.reshape [[SLICE2]] : (tensor<1xi64>) -> tensor<i64>
// CHECK: [[INSERT_H:%.+]] = quantum.insert [[INSERT2]][[[EXTRACTED]]], [[H1]] : !quantum.reg, !quantum.bit
// CHECK: [[EXTRACTED_0:%.+]] = tensor.extract [[RESHAPE2]][] : tensor<i64>
// CHECK: [[EXTRACT2:%.+]] = quantum.extract [[INSERT_H]][[[EXTRACTED_0]]] : !quantum.reg -> !quantum.bit
// CHECK: [[EXTRACT3:%.+]] = quantum.extract [[INSERT_H]][[[EXTRACTED]]] : !quantum.reg -> !quantum.bit
// CHECK: [[CZ_RESULT:%.+]]:2 = quantum.custom "CZ"() [[EXTRACT2]], [[EXTRACT3]] : !quantum.bit, !quantum.bit
// CHECK: [[INSERT_CZ1:%.+]] = quantum.insert [[INSERT_H]][[[EXTRACTED_0]]], [[CZ_RESULT]]#0 : !quantum.reg, !quantum.bit
// CHECK: [[H2:%.+]] = quantum.custom "Hadamard"() [[CZ_RESULT]]#1 : !quantum.bit
// CHECK: [[INSERT_CZ2:%.+]] = quantum.insert [[INSERT_CZ1]][[[EXTRACTED]]], [[H2]] : !quantum.reg, !quantum.bit
// CHECK: [[FINAL_EXTRACT1:%.+]] = quantum.extract [[INSERT_CZ2]][ 0] : !quantum.reg -> !quantum.bit
// CHECK: [[FINAL_EXTRACT2:%.+]] = quantum.extract [[INSERT_CZ2]][ 1] : !quantum.reg -> !quantum.bit
// CHECK-NOT: quantum.custom "CNOT"
%out_qubits:2 = quantum.custom "CNOT"() %2, %3 : !quantum.bit, !quantum.bit

%4 = quantum.insert %1[ 0], %out_qubits#0 : !quantum.reg, !quantum.bit
quantum.dealloc %4 : !quantum.reg
%5 = quantum.compbasis qubits %out_qubits#1 : !quantum.obs
%6 = quantum.probs %5 : tensor<2xf64>
%7 = quantum.insert %0[ 0], %out_qubits#1 : !quantum.reg, !quantum.bit
quantum.dealloc %7 : !quantum.reg
return %6 : tensor<2xf64>
}

// Decomposition function for CNOT gate into CZ and Hadamard
// CHECK-NOT: func.func private @cz_hadamard
func.func private @cz_hadamard(%arg0: !quantum.reg, %arg1: tensor<2xi64>) -> !quantum.reg attributes {target_gate = "CNOT", llvm.linkage = #llvm.linkage<internal>} {
%0 = stablehlo.slice %arg1 [1:2] : (tensor<2xi64>) -> tensor<1xi64>
%1 = stablehlo.reshape %0 : (tensor<1xi64>) -> tensor<i64>
%extracted = tensor.extract %1[] : tensor<i64>
%2 = quantum.extract %arg0[%extracted] : !quantum.reg -> !quantum.bit
%out_qubits = quantum.custom "Hadamard"() %2 : !quantum.bit
%3 = stablehlo.slice %arg1 [0:1] : (tensor<2xi64>) -> tensor<1xi64>
%4 = stablehlo.reshape %3 : (tensor<1xi64>) -> tensor<i64>
%5 = stablehlo.slice %arg1 [1:2] : (tensor<2xi64>) -> tensor<1xi64>
%6 = stablehlo.reshape %5 : (tensor<1xi64>) -> tensor<i64>
%extracted_0 = tensor.extract %1[] : tensor<i64>
%7 = quantum.insert %arg0[%extracted_0], %out_qubits : !quantum.reg, !quantum.bit
%extracted_1 = tensor.extract %4[] : tensor<i64>
%8 = quantum.extract %7[%extracted_1] : !quantum.reg -> !quantum.bit
%extracted_2 = tensor.extract %6[] : tensor<i64>
%9 = quantum.extract %7[%extracted_2] : !quantum.reg -> !quantum.bit
%out_qubits_3:2 = quantum.custom "CZ"() %8, %9 : !quantum.bit, !quantum.bit
%10 = stablehlo.slice %arg1 [1:2] : (tensor<2xi64>) -> tensor<1xi64>
%11 = stablehlo.reshape %10 : (tensor<1xi64>) -> tensor<i64>
%extracted_4 = tensor.extract %4[] : tensor<i64>
%12 = quantum.insert %7[%extracted_4], %out_qubits_3#0 : !quantum.reg, !quantum.bit
%extracted_5 = tensor.extract %6[] : tensor<i64>
%13 = quantum.insert %12[%extracted_5], %out_qubits_3#1 : !quantum.reg, !quantum.bit
%extracted_6 = tensor.extract %11[] : tensor<i64>
%14 = quantum.extract %13[%extracted_6] : !quantum.reg -> !quantum.bit
%out_qubits_7 = quantum.custom "Hadamard"() %14 : !quantum.bit
%extracted_8 = tensor.extract %11[] : tensor<i64>
%15 = quantum.insert %13[%extracted_8], %out_qubits_7 : !quantum.reg, !quantum.bit
return %15 : !quantum.reg
}
}



// -----
module @recursive {
func.func public @test_recursive() -> tensor<4xf64> {
Expand Down