@ritu-thombre99 (Contributor) commented Jul 25, 2025

Context:

Description of the Change:

When running a quantum circuit on hardware with connectivity constraints, two-qubit gates such as CNOT can only be executed on physical qubits that are connected by an edge in the hardware's coupling graph.

Hence, SWAP gates must be inserted to route the logical qubits so that each two-qubit gate acts on qubits mapped to an edge of the device, while respecting the other compilation constraints: the gate dependency order of the input circuit must be preserved, and the compiled circuit must be equivalent to the input circuit.

The following image shows a simple example:

image
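
To make the routing problem concrete, here is a minimal sketch of a naive router that moves the control qubit along a shortest path until the CNOT lands on a hardware edge. This is not SABRE and not the code in this PR; the names `shortest_path` and `route` are hypothetical, and the coupling graph is assumed to be connected and undirected.

```python
from collections import deque

def shortest_path(edges, src, dst):
    """Breadth-first search for a shortest path on an undirected coupling graph."""
    adj = {}
    for a, b in edges:
        adj.setdefault(a, []).append(b)
        adj.setdefault(b, []).append(a)
    prev = {src: None}
    queue = deque([src])
    while queue:
        node = queue.popleft()
        if node == dst:
            break
        for nxt in adj.get(node, []):
            if nxt not in prev:
                prev[nxt] = node
                queue.append(nxt)
    if dst not in prev:
        raise ValueError(f"no path between physical qubits {src} and {dst}")
    path, node = [], dst
    while node is not None:
        path.append(node)
        node = prev[node]
    return path[::-1]

def route(gates, edges, mapping):
    """Route logical CNOTs (control, target) onto hardware edges.

    mapping is a logical -> physical dict; a copy is updated and returned.
    """
    mapping = dict(mapping)
    routed = []
    for ctrl, tgt in gates:
        path = shortest_path(edges, mapping[ctrl], mapping[tgt])
        # Swap the control along the path until it sits next to the target.
        for a, b in zip(path[:-2], path[1:-1]):
            routed.append(("SWAP", a, b))
            inverse = {p: l for l, p in mapping.items()}
            la, lb = inverse.get(a), inverse.get(b)
            if la is not None:
                mapping[la] = b
            if lb is not None:
                mapping[lb] = a
        routed.append(("CNOT", mapping[ctrl], mapping[tgt]))
    return routed, mapping

edges = [(0, 1), (1, 2), (2, 3), (3, 4)]
routed, final_mapping = route([(0, 1), (1, 2), (0, 2)], edges, {0: 0, 1: 1, 2: 2})
print(routed)
```

With the identity initial mapping, only the last CNOT needs a SWAP. A heuristic router such as SABRE instead picks SWAPs by scoring them against all pending gates, which usually yields fewer SWAPs on larger circuits.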

Benefits:

  1. Support for qubit mapping and routing using the state-of-the-art SABRE algorithm
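
For readers unfamiliar with SABRE, the sketch below illustrates only its basic idea: score each candidate SWAP by the summed hardware distance of the gates in the front layer, then pick the lowest score. It is an illustration under simplifying assumptions, not the pass implemented in this PR. The full algorithm also uses a lookahead set, a decay term, and forward/backward passes to refine the initial mapping. The names `distance_matrix` and `choose_swap` are hypothetical.

```python
from collections import deque

def distance_matrix(edges, num_qubits):
    """All-pairs hop distances on an undirected coupling graph, via BFS."""
    adj = {q: [] for q in range(num_qubits)}
    for a, b in edges:
        adj[a].append(b)
        adj[b].append(a)
    dist = [[float("inf")] * num_qubits for _ in range(num_qubits)]
    for start in range(num_qubits):
        dist[start][start] = 0
        queue = deque([start])
        while queue:
            node = queue.popleft()
            for nxt in adj[node]:
                if dist[start][nxt] == float("inf"):
                    dist[start][nxt] = dist[start][node] + 1
                    queue.append(nxt)
    return dist

def choose_swap(front_layer, edges, mapping, dist):
    """Pick the SWAP that minimizes the summed distance of front-layer gates.

    front_layer: logical (q1, q2) pairs whose dependencies are satisfied.
    mapping: logical -> physical dict. Ties are broken by edge order.
    """
    busy = {mapping[q] for gate in front_layer for q in gate}
    # Only SWAPs touching a qubit used by the front layer can help.
    candidates = [e for e in edges if e[0] in busy or e[1] in busy]

    def score(swap):
        a, b = swap
        trial = {l: (b if p == a else a if p == b else p)
                 for l, p in mapping.items()}
        return sum(dist[trial[q1]][trial[q2]] for q1, q2 in front_layer)

    best = min(candidates, key=score)
    return best, score(best)

edges = [(0, 1), (1, 2), (2, 3), (3, 4)]
dist = distance_matrix(edges, 5)
mapping = {q: q for q in range(5)}
print(choose_swap([(0, 2)], edges, mapping, dist))
```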

Possible Drawbacks:

  1. Assumes that the MLIR is compiled ahead of time and that the function takes no arguments.
  2. Currently, the pass doesn't work with parametric gates. Params are SSA values, which are lost when the original function is deleted, so a runtime segmentation fault occurs if parametric gates are used.
  3. The original return value is replaced with qml.state, which returns a statevector of size 2^N, where N is the number of physical qubits on the device, calculated from the provided coupling_map.
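
To illustrate the third point, here is a small sketch that parses a hardware-graph string of the form used in the example below and derives the statevector size. It assumes N is the number of distinct qubit labels appearing in the edges; the pass may compute N differently, and the helper names are hypothetical.

```python
def parse_hardware_graph(spec):
    """Parse a string like "(0,1);(1,2);" into a list of integer edge pairs."""
    edges = []
    for item in spec.split(";"):
        item = item.strip()
        if not item:
            continue  # skip the empty piece after a trailing ';'
        a, b = item.strip("()").split(",")
        edges.append((int(a), int(b)))
    return edges

def num_physical_qubits(edges):
    """Assumption: N is the count of distinct qubit labels in the edges."""
    return len({q for edge in edges for q in edge})

edges = parse_hardware_graph("(0,1);(1,2);(2,3);(3,4);")
n = num_physical_qubits(edges)
print(n, 2 ** n)  # 5 32
```

For the 5-qubit line graph this gives 32 amplitudes, even though the input circuit only uses 3 logical qubits.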

Related GitHub Issues:
#1928, #1939

Example Usage

The input circuit is a 3-qubit circuit with CNOTs on all possible connections:

import pennylane as qml

dev = qml.device("lightning.qubit")

# Hardware coupling graph: a line of 5 physical qubits, 0-1-2-3-4
my_pass_pipeline = {
    "route-circuit": {"hardware-graph": "(0,1);(1,2);(2,3);(3,4);"},
}

@qml.qjit(circuit_transform_pipeline=my_pass_pipeline, keep_intermediate=True)
@qml.qnode(dev)
def circ():
    qml.H(0)
    qml.CNOT([0, 1])
    qml.CNOT([1, 2])
    qml.CNOT([0, 2])
    return qml.state()

print(circ.mlir_opt)  # MLIR after the routing pass
print(circ())         # statevector over the physical qubits
  • The example above can be run directly from Python.
  • It can also be run from the Catalyst CLI: quantum-opt --route-circuit="hardware-graph=(0,1);(1,2);(2,3);(3,4);" input.mlir

The output can be verified by inspecting the intermediate MLIR files or by printing mlir_opt:

Random Initial Mapping: 
0->4
1->1
2->2
3->3
4->0
module @circ {
  llvm.func @__catalyst__rt__finalize()
  llvm.func @__catalyst__rt__initialize(!llvm.ptr)
  llvm.func @__catalyst__rt__device_release()
  llvm.func @__catalyst__rt__qubit_release_array(!llvm.ptr)
  llvm.func @__catalyst__qis__State(!llvm.ptr, i64, ...)
  llvm.func @__catalyst__rt__num_qubits() -> i64
  llvm.func @__catalyst__qis__CNOT(!llvm.ptr, !llvm.ptr, !llvm.ptr)
  llvm.func @__catalyst__qis__SWAP(!llvm.ptr, !llvm.ptr, !llvm.ptr)
  llvm.func @__catalyst__qis__Hadamard(!llvm.ptr, !llvm.ptr)
  llvm.func @__catalyst__rt__array_get_element_ptr_1d(!llvm.ptr, i64) -> !llvm.ptr
  llvm.func @__catalyst__rt__qubit_allocate_array(i64) -> !llvm.ptr
  llvm.mlir.global internal constant @"{'mcmc': False, 'num_burnin': 0, 'kernel_name': None}"("{'mcmc': False, 'num_burnin': 0, 'kernel_name': None}\00") {addr_space = 0 : i32}
  llvm.mlir.global internal constant @LightningSimulator("LightningSimulator\00") {addr_space = 0 : i32}
  llvm.mlir.global internal constant @"/Users/ritu.thombre/Desktop/catalyst/.venv/lib/python3.12/site-packages/pennylane_lightning/liblightning_qubit_catalyst.dylib"("/Users/ritu.thombre/Desktop/catalyst/.venv/lib/python3.12/site-packages/pennylane_lightning/liblightning_qubit_catalyst.dylib\00") {addr_space = 0 : i32}
  llvm.func @__catalyst__rt__device_init(!llvm.ptr, !llvm.ptr, !llvm.ptr, i64, i1)
  llvm.func @_mlir_memref_to_llvm_alloc(i64) -> !llvm.ptr
  llvm.func @jit_circ() -> !llvm.struct<(struct<(ptr, ptr, i64)>, struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>)> attributes {llvm.copy_memref, llvm.emit_c_interface} {
    %0 = llvm.mlir.poison : !llvm.struct<(struct<(ptr, ptr, i64)>, struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>)>
    %1 = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %2 = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64)>
    %3 = llvm.mlir.zero : !llvm.ptr
    %4 = llvm.mlir.constant(1 : index) : i64
    %5 = llvm.mlir.constant(0 : index) : i64
    %6 = llvm.mlir.constant(3735928559 : index) : i64
    %7 = llvm.call @circ_0() : () -> !llvm.struct<(struct<(ptr, ptr, i64)>, struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>)>
    %8 = llvm.extractvalue %7[0] : !llvm.struct<(struct<(ptr, ptr, i64)>, struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>)> 
    %9 = llvm.extractvalue %7[1] : !llvm.struct<(struct<(ptr, ptr, i64)>, struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>)> 
    %10 = llvm.extractvalue %7[0, 0] : !llvm.struct<(struct<(ptr, ptr, i64)>, struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>)> 
    %11 = llvm.ptrtoint %10 : !llvm.ptr to i64
    %12 = llvm.icmp "eq" %6, %11 : i64
    llvm.cond_br %12, ^bb1, ^bb2
  ^bb1:  // pred: ^bb0
    %13 = llvm.getelementptr %3[1] : (!llvm.ptr) -> !llvm.ptr, i64
    %14 = llvm.ptrtoint %13 : !llvm.ptr to i64
    %15 = llvm.call @_mlir_memref_to_llvm_alloc(%14) : (i64) -> !llvm.ptr
    %16 = llvm.insertvalue %15, %2[0] : !llvm.struct<(ptr, ptr, i64)> 
    %17 = llvm.insertvalue %15, %16[1] : !llvm.struct<(ptr, ptr, i64)> 
    %18 = llvm.insertvalue %5, %17[2] : !llvm.struct<(ptr, ptr, i64)> 
    %19 = llvm.getelementptr %3[1] : (!llvm.ptr) -> !llvm.ptr, i64
    %20 = llvm.ptrtoint %19 : !llvm.ptr to i64
    %21 = llvm.mul %20, %4 : i64
    %22 = llvm.extractvalue %7[0, 1] : !llvm.struct<(struct<(ptr, ptr, i64)>, struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>)> 
    %23 = llvm.extractvalue %7[0, 2] : !llvm.struct<(struct<(ptr, ptr, i64)>, struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>)> 
    %24 = llvm.getelementptr inbounds %22[%23] : (!llvm.ptr, i64) -> !llvm.ptr, i64
    "llvm.intr.memcpy"(%15, %24, %21) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i64) -> ()
    llvm.br ^bb3(%18 : !llvm.struct<(ptr, ptr, i64)>)
  ^bb2:  // pred: ^bb0
    llvm.br ^bb3(%8 : !llvm.struct<(ptr, ptr, i64)>)
  ^bb3(%25: !llvm.struct<(ptr, ptr, i64)>):  // 2 preds: ^bb1, ^bb2
    llvm.br ^bb4
  ^bb4:  // pred: ^bb3
    %26 = llvm.extractvalue %7[1, 0] : !llvm.struct<(struct<(ptr, ptr, i64)>, struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>)> 
    %27 = llvm.ptrtoint %26 : !llvm.ptr to i64
    %28 = llvm.icmp "eq" %6, %27 : i64
    llvm.cond_br %28, ^bb5, ^bb6
  ^bb5:  // pred: ^bb4
    %29 = llvm.extractvalue %7[1, 3] : !llvm.struct<(struct<(ptr, ptr, i64)>, struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>)> 
    %30 = llvm.alloca %4 x !llvm.array<1 x i64> : (i64) -> !llvm.ptr
    llvm.store %29, %30 : !llvm.array<1 x i64>, !llvm.ptr
    %31 = llvm.getelementptr inbounds %30[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<1 x i64>
    %32 = llvm.load %31 : !llvm.ptr -> i64
    %33 = llvm.getelementptr %3[%32] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(f64, f64)>
    %34 = llvm.ptrtoint %33 : !llvm.ptr to i64
    %35 = llvm.getelementptr %3[1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(f64, f64)>
    %36 = llvm.ptrtoint %35 : !llvm.ptr to i64
    %37 = llvm.add %34, %36 : i64
    %38 = llvm.call @_mlir_memref_to_llvm_alloc(%37) : (i64) -> !llvm.ptr
    %39 = llvm.ptrtoint %38 : !llvm.ptr to i64
    %40 = llvm.sub %36, %4 : i64
    %41 = llvm.add %39, %40 : i64
    %42 = llvm.urem %41, %36 : i64
    %43 = llvm.sub %41, %42 : i64
    %44 = llvm.inttoptr %43 : i64 to !llvm.ptr
    %45 = llvm.insertvalue %38, %1[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
    %46 = llvm.insertvalue %44, %45[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
    %47 = llvm.insertvalue %5, %46[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
    %48 = llvm.insertvalue %32, %47[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
    %49 = llvm.insertvalue %4, %48[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
    %50 = llvm.extractvalue %7[1, 3, 0] : !llvm.struct<(struct<(ptr, ptr, i64)>, struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>)> 
    %51 = llvm.mul %50, %4 : i64
    %52 = llvm.getelementptr %3[1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(f64, f64)>
    %53 = llvm.ptrtoint %52 : !llvm.ptr to i64
    %54 = llvm.mul %51, %53 : i64
    %55 = llvm.extractvalue %7[1, 1] : !llvm.struct<(struct<(ptr, ptr, i64)>, struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>)> 
    %56 = llvm.extractvalue %7[1, 2] : !llvm.struct<(struct<(ptr, ptr, i64)>, struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>)> 
    %57 = llvm.getelementptr inbounds %55[%56] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(f64, f64)>
    "llvm.intr.memcpy"(%44, %57, %54) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i64) -> ()
    llvm.br ^bb7(%49 : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>)
  ^bb6:  // pred: ^bb4
    llvm.br ^bb7(%9 : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>)
  ^bb7(%58: !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>):  // 2 preds: ^bb5, ^bb6
    llvm.br ^bb8
  ^bb8:  // pred: ^bb7
    %59 = llvm.insertvalue %25, %0[0] : !llvm.struct<(struct<(ptr, ptr, i64)>, struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>)> 
    %60 = llvm.insertvalue %58, %59[1] : !llvm.struct<(struct<(ptr, ptr, i64)>, struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>)> 
    llvm.return %60 : !llvm.struct<(struct<(ptr, ptr, i64)>, struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>)>
  }
  llvm.func @_catalyst_pyface_jit_circ(%arg0: !llvm.ptr, %arg1: !llvm.ptr) {
    llvm.call @_catalyst_ciface_jit_circ(%arg0) : (!llvm.ptr) -> ()
    llvm.return
  }
  llvm.func @_catalyst_ciface_jit_circ(%arg0: !llvm.ptr) attributes {llvm.copy_memref, llvm.emit_c_interface} {
    %0 = llvm.call @jit_circ() : () -> !llvm.struct<(struct<(ptr, ptr, i64)>, struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>)>
    llvm.store %0, %arg0 : !llvm.struct<(struct<(ptr, ptr, i64)>, struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>)>, !llvm.ptr
    llvm.return
  }
  llvm.func @circ_0() -> !llvm.struct<(struct<(ptr, ptr, i64)>, struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>)> {
    %0 = llvm.mlir.poison : !llvm.struct<(struct<(ptr, ptr, i64)>, struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>)>
    %1 = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
    %2 = llvm.mlir.constant(0 : index) : i64
    %3 = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64)>
    %4 = llvm.mlir.constant(64 : index) : i64
    %5 = llvm.mlir.constant(1 : index) : i64
    %6 = llvm.mlir.zero : !llvm.ptr
    %7 = llvm.mlir.constant(4 : i64) : i64
    %8 = llvm.mlir.constant(3 : i64) : i64
    %9 = llvm.mlir.constant(2 : i64) : i64
    %10 = llvm.mlir.constant(5 : i64) : i64
    %11 = llvm.mlir.constant(false) : i1
    %12 = llvm.mlir.addressof @"{'mcmc': False, 'num_burnin': 0, 'kernel_name': None}" : !llvm.ptr
    %13 = llvm.mlir.addressof @LightningSimulator : !llvm.ptr
    %14 = llvm.mlir.addressof @"/Users/ritu.thombre/Desktop/catalyst/.venv/lib/python3.12/site-packages/pennylane_lightning/liblightning_qubit_catalyst.dylib" : !llvm.ptr
    %15 = llvm.mlir.constant(64 : i64) : i64
    %16 = llvm.mlir.constant(0 : i64) : i64
    %17 = llvm.mlir.constant(1 : i64) : i64
    %18 = llvm.alloca %17 x !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> : (i64) -> !llvm.ptr
    %19 = llvm.getelementptr inbounds %14[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<126 x i8>
    %20 = llvm.getelementptr inbounds %13[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<19 x i8>
    %21 = llvm.getelementptr inbounds %12[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<54 x i8>
    llvm.call @__catalyst__rt__device_init(%19, %20, %21, %16, %11) : (!llvm.ptr, !llvm.ptr, !llvm.ptr, i64, i1) -> ()
    %22 = llvm.call @__catalyst__rt__qubit_allocate_array(%10) : (i64) -> !llvm.ptr
    %23 = llvm.call @__catalyst__rt__array_get_element_ptr_1d(%22, %16) : (!llvm.ptr, i64) -> !llvm.ptr
    %24 = llvm.call @__catalyst__rt__array_get_element_ptr_1d(%22, %17) : (!llvm.ptr, i64) -> !llvm.ptr
    %25 = llvm.load %24 : !llvm.ptr -> !llvm.ptr
    %26 = llvm.call @__catalyst__rt__array_get_element_ptr_1d(%22, %9) : (!llvm.ptr, i64) -> !llvm.ptr
    %27 = llvm.load %26 : !llvm.ptr -> !llvm.ptr
    %28 = llvm.call @__catalyst__rt__array_get_element_ptr_1d(%22, %8) : (!llvm.ptr, i64) -> !llvm.ptr
    %29 = llvm.load %28 : !llvm.ptr -> !llvm.ptr
    %30 = llvm.call @__catalyst__rt__array_get_element_ptr_1d(%22, %7) : (!llvm.ptr, i64) -> !llvm.ptr
    %31 = llvm.load %30 : !llvm.ptr -> !llvm.ptr
    llvm.call @__catalyst__qis__Hadamard(%31, %6) : (!llvm.ptr, !llvm.ptr) -> ()
    llvm.call @__catalyst__qis__SWAP(%31, %29, %6) : (!llvm.ptr, !llvm.ptr, !llvm.ptr) -> ()
    llvm.call @__catalyst__qis__SWAP(%29, %27, %6) : (!llvm.ptr, !llvm.ptr, !llvm.ptr) -> ()
    llvm.call @__catalyst__qis__CNOT(%27, %25, %6) : (!llvm.ptr, !llvm.ptr, !llvm.ptr) -> ()
    llvm.call @__catalyst__qis__CNOT(%27, %29, %6) : (!llvm.ptr, !llvm.ptr, !llvm.ptr) -> ()
    llvm.call @__catalyst__qis__SWAP(%25, %27, %6) : (!llvm.ptr, !llvm.ptr, !llvm.ptr) -> ()
    llvm.call @__catalyst__qis__CNOT(%27, %29, %6) : (!llvm.ptr, !llvm.ptr, !llvm.ptr) -> ()
    llvm.call @__catalyst__qis__SWAP(%25, %27, %6) : (!llvm.ptr, !llvm.ptr, !llvm.ptr) -> ()
    llvm.call @__catalyst__qis__CNOT(%27, %29, %6) : (!llvm.ptr, !llvm.ptr, !llvm.ptr) -> ()
    %32 = llvm.call @__catalyst__rt__num_qubits() : () -> i64
    %33 = llvm.shl %17, %32 : i64
    %34 = llvm.icmp "ult" %32, %15 : i64
    %35 = llvm.select %34, %33, %16 : i1, i64
    %36 = llvm.getelementptr %6[1] : (!llvm.ptr) -> !llvm.ptr, i64
    %37 = llvm.ptrtoint %36 : !llvm.ptr to i64
    %38 = llvm.add %37, %4 : i64
    %39 = llvm.call @_mlir_memref_to_llvm_alloc(%38) : (i64) -> !llvm.ptr
    %40 = llvm.ptrtoint %39 : !llvm.ptr to i64
    %41 = llvm.sub %4, %5 : i64
    %42 = llvm.add %40, %41 : i64
    %43 = llvm.urem %42, %4 : i64
    %44 = llvm.sub %42, %43 : i64
    %45 = llvm.inttoptr %44 : i64 to !llvm.ptr
    %46 = llvm.insertvalue %39, %3[0] : !llvm.struct<(ptr, ptr, i64)> 
    %47 = llvm.insertvalue %45, %46[1] : !llvm.struct<(ptr, ptr, i64)> 
    %48 = llvm.insertvalue %2, %47[2] : !llvm.struct<(ptr, ptr, i64)> 
    llvm.store %35, %45 : i64, !llvm.ptr
    %49 = llvm.getelementptr %6[%35] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(f64, f64)>
    %50 = llvm.ptrtoint %49 : !llvm.ptr to i64
    %51 = llvm.getelementptr %6[1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(f64, f64)>
    %52 = llvm.ptrtoint %51 : !llvm.ptr to i64
    %53 = llvm.add %50, %52 : i64
    %54 = llvm.call @_mlir_memref_to_llvm_alloc(%53) : (i64) -> !llvm.ptr
    %55 = llvm.ptrtoint %54 : !llvm.ptr to i64
    %56 = llvm.sub %52, %5 : i64
    %57 = llvm.add %55, %56 : i64
    %58 = llvm.urem %57, %52 : i64
    %59 = llvm.sub %57, %58 : i64
    %60 = llvm.inttoptr %59 : i64 to !llvm.ptr
    %61 = llvm.insertvalue %54, %1[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
    %62 = llvm.insertvalue %60, %61[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
    %63 = llvm.insertvalue %2, %62[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
    %64 = llvm.insertvalue %35, %63[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
    %65 = llvm.insertvalue %5, %64[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
    llvm.store %65, %18 : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, !llvm.ptr
    llvm.call @__catalyst__qis__State(%18, %16) vararg(!llvm.func<void (ptr, i64, ...)>) : (!llvm.ptr, i64) -> ()
    llvm.call @__catalyst__rt__qubit_release_array(%22) : (!llvm.ptr) -> ()
    llvm.call @__catalyst__rt__device_release() : () -> ()
    %66 = llvm.insertvalue %48, %0[0] : !llvm.struct<(struct<(ptr, ptr, i64)>, struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>)> 
    %67 = llvm.insertvalue %65, %66[1] : !llvm.struct<(struct<(ptr, ptr, i64)>, struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>)> 
    llvm.return %67 : !llvm.struct<(struct<(ptr, ptr, i64)>, struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>)>
  }
  llvm.func @setup() {
    %0 = llvm.mlir.zero : !llvm.ptr
    llvm.call @__catalyst__rt__initialize(%0) : (!llvm.ptr) -> ()
    llvm.return
  }
  llvm.func @teardown() {
    llvm.call @__catalyst__rt__finalize() : () -> ()
    llvm.return
  }
}
[0.70710678+0.j 0.        +0.j 0.        +0.j 0.        +0.j
 0.        +0.j 0.        +0.j 0.        +0.j 0.        +0.j
 0.        +0.j 0.        +0.j 0.        +0.j 0.        +0.j
 0.        +0.j 0.        +0.j 0.        +0.j 0.        +0.j
 0.        +0.j 0.        +0.j 0.        +0.j 0.        +0.j
 0.        +0.j 0.        +0.j 0.        +0.j 0.        +0.j
 0.        +0.j 0.        +0.j 0.        +0.j 0.        +0.j
 0.70710678+0.j 0.        +0.j 0.        +0.j 0.        +0.j]

TODO

  1. Improve the ExtractOp fetching from CustomOp
  2. Optimize code
  3. Support for parametric gates

@github-actions
Contributor

Hello. You may have forgotten to update the changelog!
Please edit doc/releases/changelog-dev.md on your branch with:

  • A one-to-two sentence description of the change. You may include a small working example for new features.
  • A link back to this PR.
  • Your name (or GitHub username) in the contributors section.

ritu-thombre99 and others added 17 commits August 4, 2025 15:54
In particular, the MLIR is automatically obtained from the user
function, as are the input types and values and result types.

The user function is currently traced with the static_argnums feature
to effect the desired concretization, however this comes with two caveats:
- static_argnums are only allowed for hashable values, which arrays are
  not, so scalar arguments are currently assumed
- while this achieves a removal of the function arguments, and in
  principle makes the program entirely static, computation based on
  those values will not necessarily be folded away;
  the latter would require something like jax.ensure_compile_time_eval,
  or additional constant folding / partial eval in the compiler

Instead of using static_argnums, which only works with hashable types,
we use a closure to make the target function parameter-less.

One observation is that non-scalar arguments are nonetheless hoisted
to the entry point function by jax, but this is a general problem we
have encountered before.
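
The closure approach from the commit message above can be sketched in plain Python. This is only an illustration of the idea, not the code from the commits: `make_parameterless` and `weighted_sum` are hypothetical names, and no JAX tracing is involved here.

```python
def make_parameterless(fn, *args, **kwargs):
    """Return a zero-argument function that calls fn with the captured values."""
    def closed():
        return fn(*args, **kwargs)
    return closed

def weighted_sum(weights, scale):
    # 'weights' is a list, which is unhashable and so could not be passed
    # as a static argument; capturing it in a closure avoids that limit.
    return scale * sum(weights)

closed = make_parameterless(weighted_sum, [0.5, 1.5, 2.0], scale=2)
print(closed())  # 8.0
```

Because the values are captured rather than hashed, arrays and other unhashable arguments work the same way as scalars.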
@ritu-thombre99 ritu-thombre99 requested review from mehrdad2m and removed request for erick-xanadu August 6, 2025 18:39