Skip to content

Inefficient MLIR after linalg lowering #535

@renezander90

Description

@renezander90
from qrisp import QuantumVariable, h, cx, measure, control, x
from qrisp.jasp import make_jaspr

def bell():
    """Minimal reproducer that is passed to ``make_jaspr`` below.

    Allocates two qubits, applies H to qubit 0, measures qubit 1, applies X
    to qubit 1 under ``control(c==0)``, and returns the measurement of
    qubit 0.

    NOTE(review): despite the name, no entangling gate is applied (``cx`` is
    imported but unused), so this does not prepare a Bell state. Presumably
    the circuit was reduced on purpose to produce the ``c == 0`` comparison
    chain seen in the MLIR output; confirm before "fixing" it, because any
    change (including reordering the independent qubit operations) alters
    the traced IR.
    """
    qv = QuantumVariable(2)
    h(qv[0])
    # qv[1] has had no gate applied before this measurement.
    c = measure(qv[1])
    # The comparison `c==0` is what lowers to the extui/cmpi/extui/cmpi
    # linalg chain in the printed MLIR.
    # NOTE(review): in that MLIR, %22 is true exactly when c == 0, yet the
    # `scf.if %22` then-branch yields the state unchanged and the X gate is in
    # the else-branch, i.e. X appears to be applied when c != 0. That looks
    # inverted relative to `control(c==0)`; verify the branch ordering in the
    # StableHLO -> scf lowering.
    with control(c==0):
        x(qv[1])
    return measure(qv[0])

# Build the Jaspr for `bell`, convert it to MLIR with StableHLO lowered
# (`lower_stablehlo=True`), and print it under a header line.
jaspr = make_jaspr(bell)()
module = jaspr.to_mlir(lower_stablehlo=True)
# A single print with a newline separator emits exactly the same text as
# printing the header and the module on separate calls.
print("=== Original MLIR ===", module, sep="\n")

=== Original MLIR ===
builtin.module @jasp_module {
func.func public @main(%arg0 : !jasp.QuantumState) -> (tensor<i1>, !jasp.QuantumState) {
%0 = arith.constant dense<2> : tensor<i64>
%1, %2 = jasp.create_qubits %0, %arg0 : !jasp.QuantumState, tensor<i64> -> !jasp.QubitArray, !jasp.QuantumState
%3 = arith.constant dense<0> : tensor<i64>
%4 = jasp.get_qubit %1, %3 : !jasp.QubitArray, tensor<i64> -> !jasp.Qubit
%5 = jasp.quantum_gate "h" (%4) , %2 : (!jasp.Qubit) , !jasp.QuantumState -> !jasp.QuantumState
%6 = arith.constant dense<1> : tensor<i64>
%7 = jasp.get_qubit %1, %6 : !jasp.QubitArray, tensor<i64> -> !jasp.Qubit
%8, %9 = jasp.measure %7, %5 : !jasp.Qubit, !jasp.QuantumState -> tensor<i1>, !jasp.QuantumState
%10 = tensor.empty() : tensor<i64>
%11 = linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%8 : tensor<i1>) outs(%10 : tensor<i64>) {
^bb0(%arg5 : i1, %arg6 : i64):
%12 = arith.extui %arg5 : i1 to i64
linalg.yield %12 : i64
} -> tensor<i64>
%13 = tensor.empty() : tensor<i1>
%14 = arith.constant 0 : i64
%15 = linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%11 : tensor<i64>) outs(%13 : tensor<i1>) {
^bb1(%arg3 : i64, %arg4 : i1):
%16 = arith.cmpi eq, %arg3, %14 : i64
linalg.yield %16 : i1
} -> tensor<i1>
%17 = tensor.empty() : tensor<i32>
%18 = linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%15 : tensor<i1>) outs(%17 : tensor<i32>) {
^bb2(%arg1 : i1, %arg2 : i32):
%19 = arith.extui %arg1 : i1 to i32
linalg.yield %19 : i32
} -> tensor<i32>
%20 = tensor.extract %18[] : tensor<i32>
%21 = arith.constant 0 : i32
%22 = arith.cmpi ne, %20, %21 : i32

%23 = scf.if %22 -> (!jasp.QuantumState) {
scf.yield %9 : !jasp.QuantumState
} else {
%24 = arith.constant dense<1> : tensor<i64>
%25 = jasp.get_qubit %1, %24 : !jasp.QubitArray, tensor<i64> -> !jasp.Qubit
%26 = jasp.quantum_gate "x" (%25) , %9 : (!jasp.Qubit) , !jasp.QuantumState -> !jasp.QuantumState
scf.yield %26 : !jasp.QuantumState
}
%27, %28 = jasp.measure %4, %23 : !jasp.Qubit, !jasp.QuantumState -> tensor<i1>, !jasp.QuantumState
func.return %27, %28 : tensor<i1>, !jasp.QuantumState
}
}

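The entire chain %10 -> %22 can be collapsed to a single NOT. It zero-extends the i1 measurement %8 to i64, compares it with 0, zero-extends the result to i32, and compares that with 0 again, so %22 is simply the logical NOT of %8 (an `arith.xori` with `true`). PR #534 introduces rewrites that resolve this situation and similar ones; with them applied, the module becomes: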

builtin.module @jasp_module {
func.func public @main(%arg0 : !jasp.QuantumState) -> (tensor<i1>, !jasp.QuantumState) {
%0 = arith.constant dense<2> : tensor<i64>
%1, %2 = jasp.create_qubits %0, %arg0 : !jasp.QuantumState, tensor<i64> -> !jasp.QubitArray, !jasp.QuantumState
%3 = arith.constant dense<0> : tensor<i64>
%4 = jasp.get_qubit %1, %3 : !jasp.QubitArray, tensor<i64> -> !jasp.Qubit
%5 = jasp.quantum_gate "h" (%4) , %2 : (!jasp.Qubit) , !jasp.QuantumState -> !jasp.QuantumState
%6 = arith.constant dense<1> : tensor<i64>
%7 = jasp.get_qubit %1, %6 : !jasp.QubitArray, tensor<i64> -> !jasp.Qubit
%8, %9 = jasp.measure %7, %5 : !jasp.Qubit, !jasp.QuantumState -> tensor<i1>, !jasp.QuantumState
%10 = tensor.extract %8[] : tensor<i1>
%11 = arith.constant true
%12 = arith.xori %10, %11 : i1
%13 = scf.if %12 -> (!jasp.QuantumState) {
scf.yield %9 : !jasp.QuantumState
} else {
%14 = arith.constant dense<1> : tensor<i64>
%15 = jasp.get_qubit %1, %14 : !jasp.QubitArray, tensor<i64> -> !jasp.Qubit
%16 = jasp.quantum_gate "x" (%15) , %9 : (!jasp.Qubit) , !jasp.QuantumState -> !jasp.QuantumState
scf.yield %16 : !jasp.QuantumState
}
%17, %18 = jasp.measure %4, %13 : !jasp.Qubit, !jasp.QuantumState -> tensor<i1>, !jasp.QuantumState
func.return %17, %18 : tensor<i1>, !jasp.QuantumState
}
}

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions