-
Notifications
You must be signed in to change notification settings - Fork 30
Open
Description
https://github.com/termi-official/thunderbolt-reactant-experiments/blob/main/scripts/linsolve-mwe.jl
// TBAA metadata aliases: a root node, one type descriptor rooted at it, and an
// access tag whose base and access type are both that descriptor (offset 0).
// The tag is what the kernel's llvm.load / llvm.store ops reference through
// `tbaa = [#tbaa_tag]`.
#tbaa_root = #llvm.tbaa_root<id = "custom_tbaa">
#tbaa_type_desc = #llvm.tbaa_type_desc<id = "custom_tbaa_addrspace(1)", members = {<#tbaa_root, 0>}>
#tbaa_tag = #llvm.tbaa_tag<base_type = #tbaa_type_desc, access_type = #tbaa_type_desc, offset = 0>
module @"reactant_spmv!" attributes {gpu.container_module, mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
llvm.module_flags [#llvm.mlir.module_flag<warning, "Dwarf Version", 2 : i32>, #llvm.mlir.module_flag<warning, "Debug Info Version", 3 : i32>]
// Module-level string constants and external function declarations.
// None of these symbols is referenced by the gpu.func or the func.func ops
// visible in this module; they are declarations only (no bodies).
// String constant holding an out-of-memory message with a %d placeholder.
llvm.mlir.global private unnamed_addr constant @mlir.llvm.nameless_global_0("ERROR: Out of dynamic GPU memory (trying to allocate %d bytes)\0A\00") {addr_space = 0 : i32, alignment = 1 : i64, dso_local, sym_visibility = "private"}
// NUL-terminated string constant "exception".
llvm.mlir.global private unnamed_addr constant @exception16("exception\00") {addr_space = 0 : i32, alignment = 1 : i64, dso_local, sym_visibility = "private"}
// Bodiless declarations: @malloc and @vprintf plus a family of @jl_*_type
// functions, each of which takes no arguments and returns an !llvm.ptr.
llvm.func local_unnamed_addr @jl_bool_type() -> !llvm.ptr attributes {sym_visibility = "private"}
llvm.func local_unnamed_addr @malloc(i64) -> !llvm.ptr attributes {sym_visibility = "private"}
llvm.func local_unnamed_addr @vprintf(!llvm.ptr, !llvm.ptr) -> i32 attributes {sym_visibility = "private"}
llvm.func local_unnamed_addr @jl_int32_type() -> !llvm.ptr attributes {sym_visibility = "private"}
llvm.func local_unnamed_addr @jl_uint8_type() -> !llvm.ptr attributes {sym_visibility = "private"}
llvm.func local_unnamed_addr @jl_uint32_type() -> !llvm.ptr attributes {sym_visibility = "private"}
llvm.func local_unnamed_addr @jl_int8_type() -> !llvm.ptr attributes {sym_visibility = "private"}
llvm.func local_unnamed_addr @jl_float64_type() -> !llvm.ptr attributes {sym_visibility = "private"}
llvm.func local_unnamed_addr @jl_int64_type() -> !llvm.ptr attributes {sym_visibility = "private"}
llvm.func local_unnamed_addr @jl_float32_type() -> !llvm.ptr attributes {sym_visibility = "private"}
llvm.func local_unnamed_addr @jl_uint64_type() -> !llvm.ptr attributes {sym_visibility = "private"}
llvm.func local_unnamed_addr @jl_uint16_type() -> !llvm.ptr attributes {sym_visibility = "private"}
llvm.func local_unnamed_addr @jl_int16_type() -> !llvm.ptr attributes {sym_visibility = "private"}
gpu.module @gpumod___call__Z16gpu_spmv_kernel_16CompilerMetadataI11DynamicSize12DynamicCheckv16CartesianIndicesILi1E5TupleI5OneToI5Int64EEE7NDRangeILi1ES0_S0_S8_S8_EE13CuTracedArrayI7Float64Li1ELi1E4_5__E22GenericSparseMatrixCSRISD_S5_SC_IS5_Li1ELi1E4_6__ESC_IS5_Li1ELi1E5_16__ESC_ISD_Li1ELi1E5_16__EESE__304 {
// GPU kernel: each thread computes one f64 value as a running sum of products
// and stores it into %arg0. With i the 1-based global thread index and k the
// 1-based loop counter:
//   %arg0 - f64 output, written at [i-1]; guarded by a bound of 5
//   %arg1 - i64 array, read at [i-1] and [i]; guarded by a bound of 6
//   %arg2 - i64 array, read at [k-1]; guarded by a bound of 16
//   %arg3 - f64 array, read at [k-1]; shares the guard of %arg2
//   %arg4 - f64 array, read at [%arg2[k-1] - 1]; guarded by a bound of 5
// NOTE(review): the mangled name mentions GenericSparseMatrixCSR, so %arg1 /
// %arg2 / %arg3 are presumably the CSR row pointers / column indices /
// non-zero values and %arg4 the dense input vector - confirm against the
// Julia source.
// Every failed bounds check branches to ^bb2, which returns without storing.
gpu.func @__call__Z16gpu_spmv_kernel_16CompilerMetadataI11DynamicSize12DynamicCheckv16CartesianIndicesILi1E5TupleI5OneToI5Int64EEE7NDRangeILi1ES0_S0_S8_S8_EE13CuTracedArrayI7Float64Li1ELi1E4_5__E22GenericSparseMatrixCSRISD_S5_SC_IS5_Li1ELi1E4_6__ESC_IS5_Li1ELi1E5_16__ESC_ISD_Li1ELi1E5_16__EESE__304(%arg0: !llvm.ptr<1>, %arg1: !llvm.ptr<1>, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<1>, %arg4: !llvm.ptr<1>) kernel {
%c1_i32 = arith.constant 1 : i32
%c1_i64 = arith.constant 1 : i64
%c5_i64 = arith.constant 5 : i64
%true = arith.constant true
%c6_i64 = arith.constant 6 : i64
%cst = arith.constant 0.000000e+00 : f64
%c16_i64 = arith.constant 16 : i64
// 1-based global index: %8 = i = ((ctaid.x + 1) - 1) * 5 + (tid.x + 1).
// The range annotations restrict ctaid.x to [0, 1) and tid.x to [0, 5).
%0 = nvvm.read.ptx.sreg.ctaid.x range <i32, 0, 1> : i32
%1 = arith.addi %0, %c1_i32 : i32
%2 = nvvm.read.ptx.sreg.tid.x range <i32, 0, 5> : i32
%3 = arith.addi %2, %c1_i32 : i32
%4 = arith.extui %1 : i32 to i64
%5 = arith.extui %3 : i32 to i64
%6 = arith.subi %4, %c1_i64 : i64
%7 = arith.muli %6, %c5_i64 : i64
%8 = arith.addi %7, %5 : i64
// %13 is false only when 1 <= i <= 5; any other thread jumps straight to the
// exit at ^bb11 without doing work.
%9 = arith.cmpi sge, %8, %c1_i64 : i64
%10 = arith.cmpi sle, %8, %c5_i64 : i64
%11 = arith.andi %9, %10 : i1
%12 = arith.cmpi sgt, %8, %c5_i64 : i64
%13 = arith.select %11, %12, %true : i1
llvm.cond_br %13, ^bb11, ^bb1
^bb1: // pred: ^bb0
// %14 = i - 1 is the 0-based index; the unsigned compare against 6 guards the
// %arg1[i-1] load in ^bb3.
%14 = arith.subi %8, %c1_i64 : i64
%15 = arith.cmpi uge, %14, %c6_i64 : i64
llvm.cond_br %15, ^bb2, ^bb3
^bb2: // 5 preds: ^bb1, ^bb3, ^bb6, ^bb7, ^bb9
// Shared early exit for all failed bounds checks: nothing is stored.
gpu.return
^bb3: // pred: ^bb1
// %17 = start = %arg1[i-1]; then guard the %arg1[i] load with the same bound.
%16 = llvm.getelementptr inbounds %arg1[%14] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, i64
%17 = llvm.load %16 {alignment = 8 : i64, tbaa = [#tbaa_tag]} : !llvm.ptr<1> -> i64
%18 = arith.cmpi uge, %8, %c6_i64 : i64
llvm.cond_br %18, ^bb2, ^bb4
^bb4: // pred: ^bb3
// %21 = stop = %arg1[i] - 1.
// %24 = last = (start > stop) ? start - 1 : stop.
// If last < start the loop range is empty and ^bb9 receives a sum of 0.0.
%19 = llvm.getelementptr inbounds %arg1[%8] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, i64
%20 = llvm.load %19 {alignment = 8 : i64, tbaa = [#tbaa_tag]} : !llvm.ptr<1> -> i64
%21 = arith.subi %20, %c1_i64 : i64
%22 = arith.cmpi sgt, %17, %21 : i64
%23 = arith.subi %17, %c1_i64 : i64
%24 = arith.select %22, %23, %21 {fastmathFlags = #llvm.fastmath<none>} : i64
%25 = arith.cmpi slt, %24, %17 : i64
llvm.cond_br %25, ^bb9(%cst : f64), ^bb5
^bb5: // pred: ^bb4
// Enter the loop with k = start and an accumulator of 0.0.
llvm.br ^bb6(%17, %cst : i64, f64)
^bb6(%26: i64, %27: f64): // 2 preds: ^bb5, ^bb8
// Loop header: %26 = k (1-based counter), %27 = running sum.
// %28 = k - 1 is bounds-checked against 16 before the %arg2 / %arg3 loads.
%28 = arith.subi %26, %c1_i64 : i64
%29 = arith.cmpi uge, %28, %c16_i64 : i64
llvm.cond_br %29, ^bb2, ^bb7
^bb7: // pred: ^bb6
// %31 = %arg2[k-1] (i64 index), %33 = %arg3[k-1] (f64 value).
// %34 = %31 - 1 is bounds-checked against 5 before it indexes %arg4.
%30 = llvm.getelementptr inbounds %arg2[%28] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, i64
%31 = llvm.load %30 {alignment = 8 : i64, tbaa = [#tbaa_tag]} : !llvm.ptr<1> -> i64
%32 = llvm.getelementptr inbounds %arg3[%28] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f64
%33 = llvm.load %32 {alignment = 8 : i64, tbaa = [#tbaa_tag]} : !llvm.ptr<1> -> f64
%34 = arith.subi %31, %c1_i64 : i64
%35 = arith.cmpi uge, %34, %c5_i64 : i64
llvm.cond_br %35, ^bb2, ^bb8
^bb8: // pred: ^bb7
// sum += %arg3[k-1] * %arg4[%34]. The exit test compares the current k (not
// k + 1) with last, so the iteration for k == last is included.
%36 = llvm.getelementptr inbounds %arg4[%34] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f64
%37 = llvm.load %36 {alignment = 8 : i64, tbaa = [#tbaa_tag]} : !llvm.ptr<1> -> f64
%38 = arith.mulf %33, %37 {fastmathFlags = #llvm.fastmath<none>} : f64
%39 = arith.addf %27, %38 {fastmathFlags = #llvm.fastmath<none>} : f64
%40 = arith.addi %26, %c1_i64 : i64
%41 = arith.cmpi eq, %26, %24 : i64
llvm.cond_br %41, ^bb9(%39 : f64), ^bb6(%40, %39 : i64, f64)
^bb9(%42: f64): // 2 preds: ^bb4, ^bb8
// %42 is the final sum; bounds-check i - 1 against 5 before the store.
%43 = arith.cmpi uge, %14, %c5_i64 : i64
llvm.cond_br %43, ^bb2, ^bb10
^bb10: // pred: ^bb9
// Write the result: %arg0[i-1] = sum.
%44 = llvm.getelementptr inbounds %arg0[%14] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f64
llvm.store %42, %44 {alignment = 8 : i64, tbaa = [#tbaa_tag]} : f64, !llvm.ptr<1>
llvm.br ^bb11
^bb11: // 2 preds: ^bb0, ^bb10
// Normal exit path, reached by out-of-range threads and after the store.
llvm.br ^bb12
^bb12: // pred: ^bb11
gpu.return
}
}
// Host-side launch wrapper: obtains an async token from enzymexla.get_stream
// and launches the GPU kernel on it with a grid of (1, 1, 1) blocks and
// (5, 1, 1) threads per block, zero bytes of dynamic shared memory, and the
// five pointer arguments forwarded unchanged. The token returned by
// gpu.launch_func (%1) is not used afterwards.
func.func private @"##call__Z16gpu_spmv_kernel_16CompilerMetadataI11DynamicSize12DynamicCheckv16CartesianIndicesILi1E5TupleI5OneToI5Int64EEE7NDRangeILi1ES0_S0_S8_S8_EE13CuTracedArrayI7Float64Li1ELi1E4_5__E22GenericSparseMatrixCSRISD_S5_SC_IS5_Li1ELi1E4_6__ESC_IS5_Li1ELi1E5_16__ESC_ISD_Li1ELi1E5_16__EESE_#304$call$3"(%arg0: !llvm.ptr<1>, %arg1: !llvm.ptr<1>, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<1>, %arg4: !llvm.ptr<1>) {
%c1_i64 = arith.constant 1 : i64
%c5_i64 = arith.constant 5 : i64
%c0_i32 = arith.constant 0 : i32
%0 = "enzymexla.get_stream"() : () -> !gpu.async.token
%1 = gpu.launch_func async [%0] @gpumod___call__Z16gpu_spmv_kernel_16CompilerMetadataI11DynamicSize12DynamicCheckv16CartesianIndicesILi1E5TupleI5OneToI5Int64EEE7NDRangeILi1ES0_S0_S8_S8_EE13CuTracedArrayI7Float64Li1ELi1E4_5__E22GenericSparseMatrixCSRISD_S5_SC_IS5_Li1ELi1E4_6__ESC_IS5_Li1ELi1E5_16__ESC_ISD_Li1ELi1E5_16__EESE__304::@__call__Z16gpu_spmv_kernel_16CompilerMetadataI11DynamicSize12DynamicCheckv16CartesianIndicesILi1E5TupleI5OneToI5Int64EEE7NDRangeILi1ES0_S0_S8_S8_EE13CuTracedArrayI7Float64Li1ELi1E4_5__E22GenericSparseMatrixCSRISD_S5_SC_IS5_Li1ELi1E4_6__ESC_IS5_Li1ELi1E5_16__ESC_ISD_Li1ELi1E5_16__EESE__304 blocks in (%c1_i64, %c1_i64, %c1_i64) threads in (%c5_i64, %c1_i64, %c1_i64) : i64 dynamic_shared_memory_size %c0_i32 args(%arg0 : !llvm.ptr<1>, %arg1 : !llvm.ptr<1>, %arg2 : !llvm.ptr<1>, %arg3 : !llvm.ptr<1>, %arg4 : !llvm.ptr<1>)
return
}
// Entry point: takes five tensors (5xf64, 6xi64, 16xi64, 16xf64, 5xf64) and
// passes them, in order, to the launch wrapper through enzymexla.jit_call.
// The single tensor<5xf64> result is declared to alias operand 0
// (output_operand_aliases with operand_index = 0, matching tf.aliasing_output = 0
// on %arg0), and the call carries the xla_side_effect_free attribute.
func.func @main(%arg0: tensor<5xf64> {tf.aliasing_output = 0 : i32}, %arg1: tensor<6xi64>, %arg2: tensor<16xi64>, %arg3: tensor<16xf64>, %arg4: tensor<5xf64>) -> tensor<5xf64> {
%0 = enzymexla.jit_call @"##call__Z16gpu_spmv_kernel_16CompilerMetadataI11DynamicSize12DynamicCheckv16CartesianIndicesILi1E5TupleI5OneToI5Int64EEE7NDRangeILi1ES0_S0_S8_S8_EE13CuTracedArrayI7Float64Li1ELi1E4_5__E22GenericSparseMatrixCSRISD_S5_SC_IS5_Li1ELi1E4_6__ESC_IS5_Li1ELi1E5_16__ESC_ISD_Li1ELi1E5_16__EESE_#304$call$3" (%arg0, %arg1, %arg2, %arg3, %arg4) {output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [], operand_index = 0, operand_tuple_indices = []>], xla_side_effect_free} : (tensor<5xf64>, tensor<6xi64>, tensor<16xi64>, tensor<16xf64>, tensor<5xf64>) -> tensor<5xf64>
return %0 : tensor<5xf64>
}
}
julia> @code_hlo raise=true spmv!(br, Ar, xr)
loc(callsite(fused<#llvm.di_subprogram<id = distinct[0]<>, compileUnit = <id = distinct[1]<>, sourceLanguage = DW_LANG_Julia, file = <"julia" in ".">, producer = "julia", isOptimized = true, emissionKind = None, nameTableKind = None>, scope = #llvm.di_file<"/mnt/software/lux/Reactant.jl/envs/sparse/linsolve.jl" in ".">, name = "macro expansion;", linkageName = "macro expansion", file = <"/mnt/software/lux/Reactant.jl/envs/sparse/linsolve.jl" in ".">, subprogramFlags = "Definition|Optimized", type = <>>>["/mnt/software/lux/Reactant.jl/envs/sparse/linsolve.jl":86:0] at callsite(fused<#llvm.di_subprogram<id = distinct[2]<>, compileUnit = <id = distinct[1]<>, sourceLanguage = DW_LANG_Julia, file = <"julia" in ".">, producer = "julia", isOptimized = true, emissionKind = None, nameTableKind = None>, scope = #llvm.di_file<"/mnt/.julia/packages/KernelAbstractions/sWSE0/src/macros.jl" in ".">, name = "gpu_spmv_kernel!;", linkageName = "gpu_spmv_kernel!", file = <"/mnt/.julia/packages/KernelAbstractions/sWSE0/src/macros.jl" in ".">, subprogramFlags = "Definition|Optimized", type = <>>>["/mnt/.julia/packages/KernelAbstractions/sWSE0/src/macros.jl":322:0] at fused<#llvm.di_subprogram<id = distinct[3]<>, compileUnit = <id = distinct[1]<>, sourceLanguage = DW_LANG_Julia, file = <"julia" in ".">, producer = "julia", isOptimized = true, emissionKind = None, nameTableKind = None>, name = "gpu_spmv_kernel!", linkageName = "julia_gpu_spmv_kernel!_63126", file = <"none" in ".">, subprogramFlags = "Definition|Optimized", type = <>>>["none":0:0]))): error: 'scf.yield' op must be the last operation in the parent block
┌ Error: Compilation failed, MLIR module written to /tmp/reactant_aUkj6J/module_001_NhsJ_post_all_pm.mlir
└ @ Reactant.MLIR.IR /mnt/software/lux/Reactant.jl/src/mlir/IR/Pass.jl:119
ERROR: "failed to run pass manager on module"
Stacktrace:
[1] run!(pm::Reactant.MLIR.IR.PassManager, mod::Reactant.MLIR.IR.Module, key::String)
@ Reactant.MLIR.IR /mnt/software/lux/Reactant.jl/src/mlir/IR/Pass.jl:163
[2] run_pass_pipeline!(mod::Reactant.MLIR.IR.Module, pass_pipeline::String, key::String; enable_verifier::Bool)
@ Reactant.Compiler /mnt/software/lux/Reactant.jl/src/Compiler.jl:934
[3] run_pass_pipeline!
@ /mnt/software/lux/Reactant.jl/src/Compiler.jl:929 [inlined]
[4] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::typeof(spmv!), args::Tuple{…}, callcache::Dict{…}, sdycache::Dict{…}; fn_kwargs::@NamedTuple{}, optimize::Bool, cudnn_hlo_optimize::Bool, shardy_passes::Symbol, no_nan::Bool, transpose_propagate::Symbol, reshape_propagate::Symbol, optimize_communications::Bool, assert_nonallocating::Bool, backend::String, raise::Bool, raise_first::Bool, donated_args::Symbol, optimize_then_pad::Bool, runtime::Val{…}, kwargs::@Kwargs{})
@ Reactant.Compiler /mnt/software/lux/Reactant.jl/src/Compiler.jl:1303
[5] (::Reactant.Compiler.var"#8#9"{Nothing, @Kwargs{…}, typeof(spmv!), Tuple{…}})(ctx::Reactant.MLIR.IR.Context)
@ Reactant.Compiler /mnt/software/lux/Reactant.jl/src/Compiler.jl:982
[6] with_context(f::Reactant.Compiler.var"#8#9"{Nothing, @Kwargs{…}, typeof(spmv!), Tuple{…}}; allow_use_existing::Bool)
@ Reactant.MLIR.IR /mnt/software/lux/Reactant.jl/src/mlir/IR/Context.jl:99
[7] with_context(f::Function)
@ Reactant.MLIR.IR /mnt/software/lux/Reactant.jl/src/mlir/IR/Context.jl:82
[8] compile_mlir(f::Function, args::Tuple{…}; client::Nothing, kwargs::@Kwargs{…})
@ Reactant.Compiler /mnt/software/lux/Reactant.jl/src/Compiler.jl:979
[9] top-level scope
@ /mnt/software/lux/Reactant.jl/src/Compiler.jl:2145
Some type information was truncated. Use `show(err)` to see complete types.
Metadata
Metadata
Assignees
Labels
No labels