SABRE qubit mapping and routing implementation in Catalyst #1940
Draft: ritu-thombre99 wants to merge 56 commits into main from ritu/sabre
Conversation
Hello. You may have forgotten to update the changelog!
In particular, the MLIR is automatically obtained from the user function, as are the input types, values, and result types. The user function is currently traced with the static_argnums feature to effect the desired concretization; however, this comes with two caveats:
- static_argnums are only allowed for hashable values, which arrays are not, so scalar arguments are currently assumed;
- while this achieves a removal of the function arguments, and in principle makes the program entirely static, computation based on those values will not necessarily be folded away; the latter would require something like jax.ensure_compile_time_eval, or additional constant folding / partial evaluation in the compiler.
Instead of using static_argnums, which only works with hashable types, we use a closure to make the target function parameter-less. One observation is that non-scalar arguments are nonetheless hoisted to the entry-point function by jax, but this is a general problem we have encountered before.
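Roughly, the closure trick looks like this in plain jax (an illustrative sketch with a hypothetical helper name; the actual tracing path in Catalyst differs):

```python
import jax
import jax.numpy as jnp

def lower_with_closure(fn, *concrete_args):
    """Lower `fn` with its arguments baked in, without using static_argnums.

    static_argnums only accepts hashable values (so arrays are out);
    closing over the concrete arguments instead makes the traced
    function parameter-less.
    """
    def closed():
        return fn(*concrete_args)
    # The lowered module has no entry-point arguments, although jax may
    # still hoist non-scalar captured values into the entry point.
    return jax.jit(closed).lower().as_text()

def f(x, y):
    return jnp.sin(x) * y

print(lower_with_closure(f, jnp.arange(3.0), 2.0))
```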
Context:
Description of the Change:
When running a quantum circuit on hardware with connectivity constraints, two-qubit gates like CNOTs can only be executed on physical qubits that are connected by an edge on the hardware.
Hence, the necessary SWAPs need to be inserted to route the logical qubits so that each two-qubit gate acts on qubits mapped to an edge of the device, while respecting the other compilation constraints, i.e. preserving the gate dependencies of the input quantum circuit and ensuring the compiled quantum circuit is equivalent to the input quantum circuit.
The following image shows a simple example:
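The same idea can also be sketched in code (an illustrative sketch of SWAP-based routing of a single CNOT on a linear coupling graph, using networkx for shortest paths; this is not the pass implemented in this PR):

```python
import networkx as nx  # assumed dependency, for the sketch only

def route_cnot(coupling_edges, mapping, control, target):
    """Make CNOT(control, target) executable by inserting SWAPs.

    `mapping` is logical -> physical and is updated in place.
    Illustrative only: a real router such as SABRE scores candidate
    SWAPs over a whole front of gates rather than one gate at a time.
    """
    graph = nx.Graph(coupling_edges)
    inverse = {p: l for l, p in mapping.items()}
    ops = []
    path = nx.shortest_path(graph, mapping[control], mapping[target])
    # Swap the control along the path until it sits next to the target.
    for a, b in zip(path[:-2], path[1:-1]):
        ops.append(("SWAP", a, b))
        la, lb = inverse[a], inverse.get(b)
        mapping[la] = b
        inverse[b] = la
        if lb is not None:
            mapping[lb] = a
        inverse[a] = lb
    ops.append(("CNOT", mapping[control], mapping[target]))
    return ops

# CNOT between logical qubits 0 and 2 on a line 0-1-2-3-4:
edges = [(0, 1), (1, 2), (2, 3), (3, 4)]
print(route_cnot(edges, {0: 0, 1: 2, 2: 4}, control=0, target=2))
# [('SWAP', 0, 1), ('SWAP', 1, 2), ('SWAP', 2, 3), ('CNOT', 3, 4)]
```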
Benefits:
SABRE heuristic for qubit mapping and routing.
Possible Drawbacks:
Related GitHub Issues:
#1928, #1939
Example Usage
The input circuit is a 3-qubit circuit with a CNOT on all possible connections:
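For reference, such an input can be produced with Catalyst roughly as follows (an illustrative sketch; the exact circuit used for the listing below may differ):

```python
import pennylane as qml
from catalyst import qjit

# Hypothetical reconstruction of the example input: a Hadamard followed by a
# CNOT on every pair of three logical qubits.
dev = qml.device("lightning.qubit", wires=3)

@qjit
@qml.qnode(dev)
def circ():
    qml.Hadamard(wires=0)
    for control, target in [(0, 1), (0, 2), (1, 2)]:
        qml.CNOT(wires=[control, target])
    return qml.state()

print(circ.mlir)  # unrouted MLIR, which can be saved as input.mlir
```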
quantum-opt --route-circuit="hardware-graph=(0,1);(1,2);(2,3);(3,4);" input.mlir

The output can be verified using the intermediate MLIR files or by printing mlir_opt:
Random Initial Mapping: 0->4 1->1 2->2 3->3 4->0 module @circ { llvm.func @__catalyst__rt__finalize() llvm.func @__catalyst__rt__initialize(!llvm.ptr) llvm.func @__catalyst__rt__device_release() llvm.func @__catalyst__rt__qubit_release_array(!llvm.ptr) llvm.func @__catalyst__qis__State(!llvm.ptr, i64, ...) llvm.func @__catalyst__rt__num_qubits() -> i64 llvm.func @__catalyst__qis__CNOT(!llvm.ptr, !llvm.ptr, !llvm.ptr) llvm.func @__catalyst__qis__SWAP(!llvm.ptr, !llvm.ptr, !llvm.ptr) llvm.func @__catalyst__qis__Hadamard(!llvm.ptr, !llvm.ptr) llvm.func @__catalyst__rt__array_get_element_ptr_1d(!llvm.ptr, i64) -> !llvm.ptr llvm.func @__catalyst__rt__qubit_allocate_array(i64) -> !llvm.ptr llvm.mlir.global internal constant @"{'mcmc': False, 'num_burnin': 0, 'kernel_name': None}"("{'mcmc': False, 'num_burnin': 0, 'kernel_name': None}\00") {addr_space = 0 : i32} llvm.mlir.global internal constant @LightningSimulator("LightningSimulator\00") {addr_space = 0 : i32} llvm.mlir.global internal constant @"/Users/ritu.thombre/Desktop/catalyst/.venv/lib/python3.12/site-packages/pennylane_lightning/liblightning_qubit_catalyst.dylib"("/Users/ritu.thombre/Desktop/catalyst/.venv/lib/python3.12/site-packages/pennylane_lightning/liblightning_qubit_catalyst.dylib\00") {addr_space = 0 : i32} llvm.func @__catalyst__rt__device_init(!llvm.ptr, !llvm.ptr, !llvm.ptr, i64, i1) llvm.func @_mlir_memref_to_llvm_alloc(i64) -> !llvm.ptr llvm.func @jit_circ() -> !llvm.struct<(struct<(ptr, ptr, i64)>, struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>)> attributes {llvm.copy_memref, llvm.emit_c_interface} { %0 = llvm.mlir.poison : !llvm.struct<(struct<(ptr, ptr, i64)>, struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>)> %1 = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %2 = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64)> %3 = llvm.mlir.zero : !llvm.ptr %4 = llvm.mlir.constant(1 : index) : i64 %5 = llvm.mlir.constant(0 : index) : i64 %6 = llvm.mlir.constant(3735928559 : index) : i64 %7 = llvm.call @circ_0() : () -> !llvm.struct<(struct<(ptr, ptr, i64)>, struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>)> %8 = llvm.extractvalue %7[0] : !llvm.struct<(struct<(ptr, ptr, i64)>, struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>)> %9 = llvm.extractvalue %7[1] : !llvm.struct<(struct<(ptr, ptr, i64)>, struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>)> %10 = llvm.extractvalue %7[0, 0] : !llvm.struct<(struct<(ptr, ptr, i64)>, struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>)> %11 = llvm.ptrtoint %10 : !llvm.ptr to i64 %12 = llvm.icmp "eq" %6, %11 : i64 llvm.cond_br %12, ^bb1, ^bb2 ^bb1: // pred: ^bb0 %13 = llvm.getelementptr %3[1] : (!llvm.ptr) -> !llvm.ptr, i64 %14 = llvm.ptrtoint %13 : !llvm.ptr to i64 %15 = llvm.call @_mlir_memref_to_llvm_alloc(%14) : (i64) -> !llvm.ptr %16 = llvm.insertvalue %15, %2[0] : !llvm.struct<(ptr, ptr, i64)> %17 = llvm.insertvalue %15, %16[1] : !llvm.struct<(ptr, ptr, i64)> %18 = llvm.insertvalue %5, %17[2] : !llvm.struct<(ptr, ptr, i64)> %19 = llvm.getelementptr %3[1] : (!llvm.ptr) -> !llvm.ptr, i64 %20 = llvm.ptrtoint %19 : !llvm.ptr to i64 %21 = llvm.mul %20, %4 : i64 %22 = llvm.extractvalue %7[0, 1] : !llvm.struct<(struct<(ptr, ptr, i64)>, struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>)> %23 = llvm.extractvalue %7[0, 2] : !llvm.struct<(struct<(ptr, ptr, i64)>, struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>)> %24 = llvm.getelementptr inbounds %22[%23] : (!llvm.ptr, i64) -> !llvm.ptr, i64 "llvm.intr.memcpy"(%15, 
%24, %21) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i64) -> () llvm.br ^bb3(%18 : !llvm.struct<(ptr, ptr, i64)>) ^bb2: // pred: ^bb0 llvm.br ^bb3(%8 : !llvm.struct<(ptr, ptr, i64)>) ^bb3(%25: !llvm.struct<(ptr, ptr, i64)>): // 2 preds: ^bb1, ^bb2 llvm.br ^bb4 ^bb4: // pred: ^bb3 %26 = llvm.extractvalue %7[1, 0] : !llvm.struct<(struct<(ptr, ptr, i64)>, struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>)> %27 = llvm.ptrtoint %26 : !llvm.ptr to i64 %28 = llvm.icmp "eq" %6, %27 : i64 llvm.cond_br %28, ^bb5, ^bb6 ^bb5: // pred: ^bb4 %29 = llvm.extractvalue %7[1, 3] : !llvm.struct<(struct<(ptr, ptr, i64)>, struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>)> %30 = llvm.alloca %4 x !llvm.array<1 x i64> : (i64) -> !llvm.ptr llvm.store %29, %30 : !llvm.array<1 x i64>, !llvm.ptr %31 = llvm.getelementptr inbounds %30[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<1 x i64> %32 = llvm.load %31 : !llvm.ptr -> i64 %33 = llvm.getelementptr %3[%32] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(f64, f64)> %34 = llvm.ptrtoint %33 : !llvm.ptr to i64 %35 = llvm.getelementptr %3[1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(f64, f64)> %36 = llvm.ptrtoint %35 : !llvm.ptr to i64 %37 = llvm.add %34, %36 : i64 %38 = llvm.call @_mlir_memref_to_llvm_alloc(%37) : (i64) -> !llvm.ptr %39 = llvm.ptrtoint %38 : !llvm.ptr to i64 %40 = llvm.sub %36, %4 : i64 %41 = llvm.add %39, %40 : i64 %42 = llvm.urem %41, %36 : i64 %43 = llvm.sub %41, %42 : i64 %44 = llvm.inttoptr %43 : i64 to !llvm.ptr %45 = llvm.insertvalue %38, %1[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %46 = llvm.insertvalue %44, %45[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %47 = llvm.insertvalue %5, %46[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %48 = llvm.insertvalue %32, %47[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %49 = llvm.insertvalue %4, %48[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %50 = llvm.extractvalue %7[1, 3, 0] : !llvm.struct<(struct<(ptr, ptr, i64)>, struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>)> %51 = llvm.mul %50, %4 : i64 %52 = llvm.getelementptr %3[1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(f64, f64)> %53 = llvm.ptrtoint %52 : !llvm.ptr to i64 %54 = llvm.mul %51, %53 : i64 %55 = llvm.extractvalue %7[1, 1] : !llvm.struct<(struct<(ptr, ptr, i64)>, struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>)> %56 = llvm.extractvalue %7[1, 2] : !llvm.struct<(struct<(ptr, ptr, i64)>, struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>)> %57 = llvm.getelementptr inbounds %55[%56] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(f64, f64)> "llvm.intr.memcpy"(%44, %57, %54) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i64) -> () llvm.br ^bb7(%49 : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>) ^bb6: // pred: ^bb4 llvm.br ^bb7(%9 : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>) ^bb7(%58: !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>): // 2 preds: ^bb5, ^bb6 llvm.br ^bb8 ^bb8: // pred: ^bb7 %59 = llvm.insertvalue %25, %0[0] : !llvm.struct<(struct<(ptr, ptr, i64)>, struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>)> %60 = llvm.insertvalue %58, %59[1] : !llvm.struct<(struct<(ptr, ptr, i64)>, struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>)> llvm.return %60 : !llvm.struct<(struct<(ptr, ptr, i64)>, struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>)> } llvm.func @_catalyst_pyface_jit_circ(%arg0: !llvm.ptr, %arg1: !llvm.ptr) { llvm.call 
@_catalyst_ciface_jit_circ(%arg0) : (!llvm.ptr) -> () llvm.return } llvm.func @_catalyst_ciface_jit_circ(%arg0: !llvm.ptr) attributes {llvm.copy_memref, llvm.emit_c_interface} { %0 = llvm.call @jit_circ() : () -> !llvm.struct<(struct<(ptr, ptr, i64)>, struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>)> llvm.store %0, %arg0 : !llvm.struct<(struct<(ptr, ptr, i64)>, struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>)>, !llvm.ptr llvm.return } llvm.func @circ_0() -> !llvm.struct<(struct<(ptr, ptr, i64)>, struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>)> { %0 = llvm.mlir.poison : !llvm.struct<(struct<(ptr, ptr, i64)>, struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>)> %1 = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %2 = llvm.mlir.constant(0 : index) : i64 %3 = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64)> %4 = llvm.mlir.constant(64 : index) : i64 %5 = llvm.mlir.constant(1 : index) : i64 %6 = llvm.mlir.zero : !llvm.ptr %7 = llvm.mlir.constant(4 : i64) : i64 %8 = llvm.mlir.constant(3 : i64) : i64 %9 = llvm.mlir.constant(2 : i64) : i64 %10 = llvm.mlir.constant(5 : i64) : i64 %11 = llvm.mlir.constant(false) : i1 %12 = llvm.mlir.addressof @"{'mcmc': False, 'num_burnin': 0, 'kernel_name': None}" : !llvm.ptr %13 = llvm.mlir.addressof @LightningSimulator : !llvm.ptr %14 = llvm.mlir.addressof @"/Users/ritu.thombre/Desktop/catalyst/.venv/lib/python3.12/site-packages/pennylane_lightning/liblightning_qubit_catalyst.dylib" : !llvm.ptr %15 = llvm.mlir.constant(64 : i64) : i64 %16 = llvm.mlir.constant(0 : i64) : i64 %17 = llvm.mlir.constant(1 : i64) : i64 %18 = llvm.alloca %17 x !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> : (i64) -> !llvm.ptr %19 = llvm.getelementptr inbounds %14[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<126 x i8> %20 = llvm.getelementptr inbounds %13[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<19 x i8> %21 = llvm.getelementptr inbounds %12[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<54 x i8> llvm.call @__catalyst__rt__device_init(%19, %20, %21, %16, %11) : (!llvm.ptr, !llvm.ptr, !llvm.ptr, i64, i1) -> () %22 = llvm.call @__catalyst__rt__qubit_allocate_array(%10) : (i64) -> !llvm.ptr %23 = llvm.call @__catalyst__rt__array_get_element_ptr_1d(%22, %16) : (!llvm.ptr, i64) -> !llvm.ptr %24 = llvm.call @__catalyst__rt__array_get_element_ptr_1d(%22, %17) : (!llvm.ptr, i64) -> !llvm.ptr %25 = llvm.load %24 : !llvm.ptr -> !llvm.ptr %26 = llvm.call @__catalyst__rt__array_get_element_ptr_1d(%22, %9) : (!llvm.ptr, i64) -> !llvm.ptr %27 = llvm.load %26 : !llvm.ptr -> !llvm.ptr %28 = llvm.call @__catalyst__rt__array_get_element_ptr_1d(%22, %8) : (!llvm.ptr, i64) -> !llvm.ptr %29 = llvm.load %28 : !llvm.ptr -> !llvm.ptr %30 = llvm.call @__catalyst__rt__array_get_element_ptr_1d(%22, %7) : (!llvm.ptr, i64) -> !llvm.ptr %31 = llvm.load %30 : !llvm.ptr -> !llvm.ptr llvm.call @__catalyst__qis__Hadamard(%31, %6) : (!llvm.ptr, !llvm.ptr) -> () llvm.call @__catalyst__qis__SWAP(%31, %29, %6) : (!llvm.ptr, !llvm.ptr, !llvm.ptr) -> () llvm.call @__catalyst__qis__SWAP(%29, %27, %6) : (!llvm.ptr, !llvm.ptr, !llvm.ptr) -> () llvm.call @__catalyst__qis__CNOT(%27, %25, %6) : (!llvm.ptr, !llvm.ptr, !llvm.ptr) -> () llvm.call @__catalyst__qis__CNOT(%27, %29, %6) : (!llvm.ptr, !llvm.ptr, !llvm.ptr) -> () llvm.call @__catalyst__qis__SWAP(%25, %27, %6) : (!llvm.ptr, !llvm.ptr, !llvm.ptr) -> () llvm.call @__catalyst__qis__CNOT(%27, %29, %6) : (!llvm.ptr, !llvm.ptr, !llvm.ptr) -> () llvm.call @__catalyst__qis__SWAP(%25, %27, %6) : 
(!llvm.ptr, !llvm.ptr, !llvm.ptr) -> () llvm.call @__catalyst__qis__CNOT(%27, %29, %6) : (!llvm.ptr, !llvm.ptr, !llvm.ptr) -> () %32 = llvm.call @__catalyst__rt__num_qubits() : () -> i64 %33 = llvm.shl %17, %32 : i64 %34 = llvm.icmp "ult" %32, %15 : i64 %35 = llvm.select %34, %33, %16 : i1, i64 %36 = llvm.getelementptr %6[1] : (!llvm.ptr) -> !llvm.ptr, i64 %37 = llvm.ptrtoint %36 : !llvm.ptr to i64 %38 = llvm.add %37, %4 : i64 %39 = llvm.call @_mlir_memref_to_llvm_alloc(%38) : (i64) -> !llvm.ptr %40 = llvm.ptrtoint %39 : !llvm.ptr to i64 %41 = llvm.sub %4, %5 : i64 %42 = llvm.add %40, %41 : i64 %43 = llvm.urem %42, %4 : i64 %44 = llvm.sub %42, %43 : i64 %45 = llvm.inttoptr %44 : i64 to !llvm.ptr %46 = llvm.insertvalue %39, %3[0] : !llvm.struct<(ptr, ptr, i64)> %47 = llvm.insertvalue %45, %46[1] : !llvm.struct<(ptr, ptr, i64)> %48 = llvm.insertvalue %2, %47[2] : !llvm.struct<(ptr, ptr, i64)> llvm.store %35, %45 : i64, !llvm.ptr %49 = llvm.getelementptr %6[%35] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(f64, f64)> %50 = llvm.ptrtoint %49 : !llvm.ptr to i64 %51 = llvm.getelementptr %6[1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(f64, f64)> %52 = llvm.ptrtoint %51 : !llvm.ptr to i64 %53 = llvm.add %50, %52 : i64 %54 = llvm.call @_mlir_memref_to_llvm_alloc(%53) : (i64) -> !llvm.ptr %55 = llvm.ptrtoint %54 : !llvm.ptr to i64 %56 = llvm.sub %52, %5 : i64 %57 = llvm.add %55, %56 : i64 %58 = llvm.urem %57, %52 : i64 %59 = llvm.sub %57, %58 : i64 %60 = llvm.inttoptr %59 : i64 to !llvm.ptr %61 = llvm.insertvalue %54, %1[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %62 = llvm.insertvalue %60, %61[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %63 = llvm.insertvalue %2, %62[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %64 = llvm.insertvalue %35, %63[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %65 = llvm.insertvalue %5, %64[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> llvm.store %65, %18 : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, !llvm.ptr llvm.call @__catalyst__qis__State(%18, %16) vararg(!llvm.func<void (ptr, i64, ...)>) : (!llvm.ptr, i64) -> () llvm.call @__catalyst__rt__qubit_release_array(%22) : (!llvm.ptr) -> () llvm.call @__catalyst__rt__device_release() : () -> () %66 = llvm.insertvalue %48, %0[0] : !llvm.struct<(struct<(ptr, ptr, i64)>, struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>)> %67 = llvm.insertvalue %65, %66[1] : !llvm.struct<(struct<(ptr, ptr, i64)>, struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>)> llvm.return %67 : !llvm.struct<(struct<(ptr, ptr, i64)>, struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>)> } llvm.func @setup() { %0 = llvm.mlir.zero : !llvm.ptr llvm.call @__catalyst__rt__initialize(%0) : (!llvm.ptr) -> () llvm.return } llvm.func @teardown() { llvm.call @__catalyst__rt__finalize() : () -> () llvm.return } } [0.70710678+0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j 0.70710678+0.j 0. +0.j 0. +0.j 0. +0.j]TODO