from qrisp import QuantumVariable, h, cx, measure, control, x
from qrisp.jasp import make_jaspr
def bell():
    """Build a small two-qubit Jasp test circuit and return qubit 0's measurement.

    Applies H to qubit 0, measures qubit 1 into ``c``, applies X to qubit 1
    under the classical condition ``c == 0``, and finally measures qubit 0.

    NOTE(review): despite the name, no entangling gate is applied (``cx`` is
    imported but unused), so this does not prepare a Bell state. The gate
    sequence is kept unchanged because the MLIR listings below are traced
    from exactly this body.
    """
    qv = QuantumVariable(2)
    h(qv[0])
    c = measure(qv[1])
    # The classical comparison ``c == 0`` is what lowers to the
    # extui/cmpi chain feeding the scf.if condition in the MLIR.
    with control(c == 0):
        x(qv[1])
    return measure(qv[0])
# Trace `bell` (it takes no arguments) into a Jaspr, then lower it to MLIR
# with the StableHLO ops lowered as well (lower_stablehlo=True).
jaspr = make_jaspr(bell)()
module = jaspr.to_mlir(lower_stablehlo=True)
# Banner line followed by the textual MLIR module, one per line.
print("=== Original MLIR ===", module, sep="\n")
=== Original MLIR ===
// Lowering of `bell` as printed by the script (before simplification).
builtin.module @jasp_module {
func.func public @main(%arg0 : !jasp.QuantumState) -> (tensor, !jasp.QuantumState) {
// Allocate the 2-qubit QuantumVariable and apply H to qubit 0.
// NOTE(review): create_qubits lists its types as (QuantumState, tensor) while
// its operands are (%0 tensor, %arg0 state) -- check the printer's type order.
%0 = arith.constant dense<2> : tensor
%1, %2 = jasp.create_qubits %0, %arg0 : !jasp.QuantumState, tensor -> !jasp.QubitArray, !jasp.QuantumState
%3 = arith.constant dense<0> : tensor
%4 = jasp.get_qubit %1, %3 : !jasp.QubitArray, tensor -> !jasp.Qubit
%5 = jasp.quantum_gate "h" (%4) , %2 : (!jasp.Qubit) , !jasp.QuantumState -> !jasp.QuantumState
// Measure qubit 1: %8 is the classical result c (i1 elements, per the first
// linalg.generic's block argument); %9 is the updated quantum state.
%6 = arith.constant dense<1> : tensor
%7 = jasp.get_qubit %1, %6 : !jasp.QubitArray, tensor -> !jasp.Qubit
%8, %9 = jasp.measure %7, %5 : !jasp.Qubit, !jasp.QuantumState -> tensor, !jasp.QuantumState
// %10..%22 evaluate the condition `c == 0` in four steps:
//   %11 = extui(c) to i64
//   %15 = (%11 == 0), an i1
//   %18 = extui(%15) to i32
//   %22 = (%18 != 0), which is again `c == 0`
// so the whole chain is equivalent to a single logical NOT of c.
%10 = tensor.empty() : tensor
%11 = linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%8 : tensor) outs(%10 : tensor) {
^bb0(%arg5 : i1, %arg6 : i64):
%12 = arith.extui %arg5 : i1 to i64
linalg.yield %12 : i64
} -> tensor
%13 = tensor.empty() : tensor
%14 = arith.constant 0 : i64
%15 = linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%11 : tensor) outs(%13 : tensor) {
^bb1(%arg3 : i64, %arg4 : i1):
%16 = arith.cmpi eq, %arg3, %14 : i64
linalg.yield %16 : i1
} -> tensor
%17 = tensor.empty() : tensor
%18 = linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%15 : tensor) outs(%17 : tensor) {
^bb2(%arg1 : i1, %arg2 : i32):
%19 = arith.extui %arg1 : i1 to i32
linalg.yield %19 : i32
} -> tensor
%20 = tensor.extract %18[] : tensor
%21 = arith.constant 0 : i32
%22 = arith.cmpi ne, %20, %21 : i32
// Conditional X on qubit 1. The then-region (taken when %22 is true, i.e.
// c == 0) returns the state unchanged; the X gate sits in the else-region.
// NOTE(review): this looks inverted relative to the source
// `with control(c==0): x(qv[1])`, which applies X when c == 0 -- confirm the
// branch order used when the conditional is lowered to scf.if.
%23 = scf.if %22 -> (!jasp.QuantumState) {
scf.yield %9 : !jasp.QuantumState
} else {
%24 = arith.constant dense<1> : tensor
%25 = jasp.get_qubit %1, %24 : !jasp.QubitArray, tensor -> !jasp.Qubit
%26 = jasp.quantum_gate "x" (%25) , %9 : (!jasp.Qubit) , !jasp.QuantumState -> !jasp.QuantumState
scf.yield %26 : !jasp.QuantumState
}
// Measure qubit 0 (%4) on the post-conditional state and return the result
// together with the final quantum state.
%27, %28 = jasp.measure %4, %23 : !jasp.Qubit, !jasp.QuantumState -> tensor, !jasp.QuantumState
func.return %27, %28 : tensor, !jasp.QuantumState
}
}
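The entire chain %10 -> %22 (zero-extend the i1 measurement result to i64, compare it with 0, zero-extend that i1 to i32, compare it with 0 again) just computes `c == 0`, so it can be collapsed to a single logical NOT of the measurement result. PR #534 introduces rewrites that simplify this (and similar) patterns, producing: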
// The same function after the PR #534 rewrites: the linalg.generic
// extui/cmpi chain is replaced by a scalar extract and a single XOR.
builtin.module @jasp_module {
func.func public @main(%arg0 : !jasp.QuantumState) -> (tensor, !jasp.QuantumState) {
// Allocate the 2-qubit QuantumVariable and apply H to qubit 0.
%0 = arith.constant dense<2> : tensor
%1, %2 = jasp.create_qubits %0, %arg0 : !jasp.QuantumState, tensor -> !jasp.QubitArray, !jasp.QuantumState
%3 = arith.constant dense<0> : tensor
%4 = jasp.get_qubit %1, %3 : !jasp.QubitArray, tensor -> !jasp.Qubit
%5 = jasp.quantum_gate "h" (%4) , %2 : (!jasp.Qubit) , !jasp.QuantumState -> !jasp.QuantumState
// Measure qubit 1: %8 is the classical result c, %9 the updated state.
%6 = arith.constant dense<1> : tensor
%7 = jasp.get_qubit %1, %6 : !jasp.QubitArray, tensor -> !jasp.Qubit
%8, %9 = jasp.measure %7, %5 : !jasp.Qubit, !jasp.QuantumState -> tensor, !jasp.QuantumState
// %10 = c as a scalar i1; %12 = c XOR true, i.e. NOT c (the same value as
// `c == 0`).
%10 = tensor.extract %8[] : tensor
%11 = arith.constant true
%12 = arith.xori %10, %11 : i1
// Conditional X on qubit 1. The then-region (taken when NOT c, i.e. c == 0)
// returns the state unchanged; the X gate sits in the else-region.
// NOTE(review): this looks inverted relative to the source
// `with control(c==0): x(qv[1])`, which applies X when c == 0 -- confirm the
// branch order used when the conditional is lowered to scf.if.
%13 = scf.if %12 -> (!jasp.QuantumState) {
scf.yield %9 : !jasp.QuantumState
} else {
%14 = arith.constant dense<1> : tensor
%15 = jasp.get_qubit %1, %14 : !jasp.QubitArray, tensor -> !jasp.Qubit
%16 = jasp.quantum_gate "x" (%15) , %9 : (!jasp.Qubit) , !jasp.QuantumState -> !jasp.QuantumState
scf.yield %16 : !jasp.QuantumState
}
// Measure qubit 0 (%4) on the post-conditional state and return the result
// together with the final quantum state.
%17, %18 = jasp.measure %4, %13 : !jasp.Qubit, !jasp.QuantumState -> tensor, !jasp.QuantumState
func.return %17, %18 : tensor, !jasp.QuantumState
}
}
=== Original MLIR ===
// Lowering of `bell` as printed by the script (before simplification).
builtin.module @jasp_module {
func.func public @main(%arg0 : !jasp.QuantumState) -> (tensor, !jasp.QuantumState) {
// Allocate the 2-qubit QuantumVariable and apply H to qubit 0.
// NOTE(review): create_qubits lists its types as (QuantumState, tensor) while
// its operands are (%0 tensor, %arg0 state) -- check the printer's type order.
%0 = arith.constant dense<2> : tensor
%1, %2 = jasp.create_qubits %0, %arg0 : !jasp.QuantumState, tensor -> !jasp.QubitArray, !jasp.QuantumState
%3 = arith.constant dense<0> : tensor
%4 = jasp.get_qubit %1, %3 : !jasp.QubitArray, tensor -> !jasp.Qubit
%5 = jasp.quantum_gate "h" (%4) , %2 : (!jasp.Qubit) , !jasp.QuantumState -> !jasp.QuantumState
// Measure qubit 1: %8 is the classical result c (i1 elements, per the first
// linalg.generic's block argument); %9 is the updated quantum state.
%6 = arith.constant dense<1> : tensor
%7 = jasp.get_qubit %1, %6 : !jasp.QubitArray, tensor -> !jasp.Qubit
%8, %9 = jasp.measure %7, %5 : !jasp.Qubit, !jasp.QuantumState -> tensor, !jasp.QuantumState
// %10..%22 evaluate the condition `c == 0` in four steps:
//   %11 = extui(c) to i64
//   %15 = (%11 == 0), an i1
//   %18 = extui(%15) to i32
//   %22 = (%18 != 0), which is again `c == 0`
// so the whole chain is equivalent to a single logical NOT of c.
%10 = tensor.empty() : tensor
%11 = linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%8 : tensor) outs(%10 : tensor) {
^bb0(%arg5 : i1, %arg6 : i64):
%12 = arith.extui %arg5 : i1 to i64
linalg.yield %12 : i64
} -> tensor
%13 = tensor.empty() : tensor
%14 = arith.constant 0 : i64
%15 = linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%11 : tensor) outs(%13 : tensor) {
^bb1(%arg3 : i64, %arg4 : i1):
%16 = arith.cmpi eq, %arg3, %14 : i64
linalg.yield %16 : i1
} -> tensor
%17 = tensor.empty() : tensor
%18 = linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%15 : tensor) outs(%17 : tensor) {
^bb2(%arg1 : i1, %arg2 : i32):
%19 = arith.extui %arg1 : i1 to i32
linalg.yield %19 : i32
} -> tensor
%20 = tensor.extract %18[] : tensor
%21 = arith.constant 0 : i32
%22 = arith.cmpi ne, %20, %21 : i32
// Conditional X on qubit 1. The then-region (taken when %22 is true, i.e.
// c == 0) returns the state unchanged; the X gate sits in the else-region.
// NOTE(review): this looks inverted relative to the source
// `with control(c==0): x(qv[1])`, which applies X when c == 0 -- confirm the
// branch order used when the conditional is lowered to scf.if.
%23 = scf.if %22 -> (!jasp.QuantumState) {
scf.yield %9 : !jasp.QuantumState
} else {
%24 = arith.constant dense<1> : tensor
%25 = jasp.get_qubit %1, %24 : !jasp.QubitArray, tensor -> !jasp.Qubit
%26 = jasp.quantum_gate "x" (%25) , %9 : (!jasp.Qubit) , !jasp.QuantumState -> !jasp.QuantumState
scf.yield %26 : !jasp.QuantumState
}
// Measure qubit 0 (%4) on the post-conditional state and return the result
// together with the final quantum state.
%27, %28 = jasp.measure %4, %23 : !jasp.Qubit, !jasp.QuantumState -> tensor, !jasp.QuantumState
func.return %27, %28 : tensor, !jasp.QuantumState
}
}
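The entire chain %10 -> %22 (zero-extend the i1 measurement result to i64, compare it with 0, zero-extend that i1 to i32, compare it with 0 again) just computes `c == 0`, so it can be collapsed to a single logical NOT of the measurement result. PR #534 introduces rewrites that simplify this (and similar) patterns, producing: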
// The same function after the PR #534 rewrites: the linalg.generic
// extui/cmpi chain is replaced by a scalar extract and a single XOR.
builtin.module @jasp_module {
func.func public @main(%arg0 : !jasp.QuantumState) -> (tensor, !jasp.QuantumState) {
// Allocate the 2-qubit QuantumVariable and apply H to qubit 0.
%0 = arith.constant dense<2> : tensor
%1, %2 = jasp.create_qubits %0, %arg0 : !jasp.QuantumState, tensor -> !jasp.QubitArray, !jasp.QuantumState
%3 = arith.constant dense<0> : tensor
%4 = jasp.get_qubit %1, %3 : !jasp.QubitArray, tensor -> !jasp.Qubit
%5 = jasp.quantum_gate "h" (%4) , %2 : (!jasp.Qubit) , !jasp.QuantumState -> !jasp.QuantumState
// Measure qubit 1: %8 is the classical result c, %9 the updated state.
%6 = arith.constant dense<1> : tensor
%7 = jasp.get_qubit %1, %6 : !jasp.QubitArray, tensor -> !jasp.Qubit
%8, %9 = jasp.measure %7, %5 : !jasp.Qubit, !jasp.QuantumState -> tensor, !jasp.QuantumState
// %10 = c as a scalar i1; %12 = c XOR true, i.e. NOT c (the same value as
// `c == 0`).
%10 = tensor.extract %8[] : tensor
%11 = arith.constant true
%12 = arith.xori %10, %11 : i1
// Conditional X on qubit 1. The then-region (taken when NOT c, i.e. c == 0)
// returns the state unchanged; the X gate sits in the else-region.
// NOTE(review): this looks inverted relative to the source
// `with control(c==0): x(qv[1])`, which applies X when c == 0 -- confirm the
// branch order used when the conditional is lowered to scf.if.
%13 = scf.if %12 -> (!jasp.QuantumState) {
scf.yield %9 : !jasp.QuantumState
} else {
%14 = arith.constant dense<1> : tensor
%15 = jasp.get_qubit %1, %14 : !jasp.QubitArray, tensor -> !jasp.Qubit
%16 = jasp.quantum_gate "x" (%15) , %9 : (!jasp.Qubit) , !jasp.QuantumState -> !jasp.QuantumState
scf.yield %16 : !jasp.QuantumState
}
// Measure qubit 0 (%4) on the post-conditional state and return the result
// together with the final quantum state.
%17, %18 = jasp.measure %4, %13 : !jasp.Qubit, !jasp.QuantumState -> tensor, !jasp.QuantumState
func.return %17, %18 : tensor, !jasp.QuantumState
}
}