Revert "[DEBUG] Revert "Enable SPV_INTEL_fp_fast_math_mode (#4058)"(#4473)" #4576
By relanding this change to
Right, I thought we didn't want to set the fp fast math mode until inductor changes the precision check.
@anmyachev have you already reported the problem to the PyTorch team? The benchmark verification code would need to be updated to allow Triton to set the fp fast math flag.
I am in touch with @chuanqi129 about this topic. For now they are busy with PyTorch 2.8 release testing (the status as of last Friday).
A few more updates/findings:
Refs:
All tensors for comparison:
Hi @etaf, the mismatch on lines 19 and 20 (from the screenshot above) is also present on the main branch. I also checked
For reference, I run the model this way:
UPD: intel/torch-xpu-ops#1855 may be related to this problem; need to recheck after this problem is fixed.
Hi @anmyachev, sorry, I'm not familiar with this issue, so I didn't respond immediately. I need some time to understand the background and details before getting back to you.
Hi @anmyachev, the model started failing after we switched to conv channels-last mode, so you can verify it before that change. You may use PyTorch commit 0504480f37714a289b2ba32c9cf32a5e50e86d38, which does not contain the channels-last change.
Got it, thanks @etaf
Signed-off-by: Anatoly Myachev <[email protected]>
third_party/intel/triton_xpu.cc
!fastMath.has_value()) {
// Default to allow contract when default fp fusion is not disabled.
if ((!enableFpFusion.has_value() || enableFpFusion.value()) &&
!fastMath.has_value() && inst.hasNoNaNs()) {
@whitneywhtsang with the hasNoNaNs change, detectron2_fasterrcnn_r_50_fpn passes the accuracy check locally. This change is inspired by
opt.NoNaNsFPMath = true;
Could you take a look?
Here are the CI runs:
- benchmarks CI: https://github.com/intel/intel-xpu-backend-for-triton/actions/runs/16377530523/job/46281040257 (I forgot to change the CI tag, so use test)
- E2E accuracy test + float16 + torchbench (to test detectron2_fasterrcnn_r_50_fpn): https://github.com/intel/intel-xpu-backend-for-triton/actions/runs/16377556558
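The updated condition in the diff above can be read as a small truth table. The following is a minimal sketch in Python, assuming the three inputs behave as their names in the diff suggest: `enable_fp_fusion` and `fast_math` are optional values, and `has_no_nans` is a per-instruction flag. The function name and the Python types are hypothetical; this is an illustration, not the code in triton_xpu.cc.

```python
from typing import Optional


def allows_default_contract(enable_fp_fusion: Optional[bool],
                            fast_math: Optional[str],
                            has_no_nans: bool) -> bool:
    """Hypothetical mirror of the updated condition in the diff.

    contract is added by default only when fp fusion is not explicitly
    disabled, no fast-math mode was requested, and the instruction
    already carries the no-NaNs flag.
    """
    fusion_not_disabled = enable_fp_fusion is None or enable_fp_fusion
    return fusion_not_disabled and fast_math is None and has_no_nans


# Fusion left at its default, no fast-math mode, instruction has no-NaNs set.
print(allows_default_contract(None, None, True))
# Same settings, but the instruction lacks the no-NaNs flag.
print(allows_default_contract(None, None, False))
```

Under this reading, the extra `has_no_nans` term only narrows the set of instructions that receive contract; it never adds the flag where the previous condition would not have.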
Is the performance regression recovered?
Started a BMG performance run: https://github.com/intel/intel-xpu-backend-for-triton/actions/runs/16380688900
From the above BMG run result, the geomean for ATTN D_HEAD=128 CAUSAL=1 is 60 TFLOPS, which is the bad (regressed) performance.
From the above BMG run result, the geomean for ATTN D_HEAD=128 CAUSAL=1 is 60 TFLOPS, which is the bad performance.
Hm, how many teraflops do you expect to see? I don't remember.
I expect to see above 80 TFLOPS; see #4514 (comment).
I'll look into why that is. However, I doubt that adding contract to all operations, regardless of whether it is safe for accuracy, is the expected default behavior when the fast-math option is not enabled. I suspect that when the AllowFPOpFusion option is enabled, some analysis is performed that adds it only where it is safe. Maybe you've already tested this and know exactly how it works?
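To make the accuracy concern concrete: LLVM's contract flag permits fusing a multiply followed by an add into a single fused multiply-add, which rounds once instead of twice. The sketch below emulates both variants for float32 in plain Python. The helper names are hypothetical, and the fused path is an emulation (exact product in binary64, then one rounding to binary32), not a bit-exact FMA for every input. It only shows that the two forms can differ; it does not claim this is the cause of the mismatch discussed here.

```python
import struct


def f32(x: float) -> float:
    """Round a Python float (binary64) to the nearest binary32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def mul_add_unfused(a: float, b: float, c: float) -> float:
    """a * b + c with a float32 rounding after the multiply and after the add."""
    return f32(f32(a * b) + c)


def mul_add_fused(a: float, b: float, c: float) -> float:
    """Emulated fused multiply-add with a single float32 rounding at the end.

    The product of two binary32 values is exact in binary64, so for the
    inputs used here only the final f32() call rounds.
    """
    return f32(a * b + c)


# Inputs chosen so that the float32 rounding of a * b discards a low-order bit.
a = b = f32(1.0 + 2.0 ** -12)
c = -1.0

unfused = mul_add_unfused(a, b, c)
fused = mul_add_fused(a, b, c)
print(unfused, fused, unfused == fused)
```

For these inputs the two results differ by 2^-24: the unfused form rounds the product 1 + 2^-11 + 2^-24 down to 1 + 2^-11 before the add, while the fused form keeps the low-order term.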
Managed to find the problem kernel. The only difference is in the use of
Correct kernel
; ModuleID = 'LLVMDialectModule'
source_filename = "LLVMDialectModule"
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
target triple = "spir64-unknown-unknown"
; Function Attrs: mustprogress nofree nosync nounwind willreturn memory(none)
declare !dbg !9 spir_func i64 @_Z12get_local_idj(i32) local_unnamed_addr #0
; Function Attrs: mustprogress nofree nosync nounwind willreturn memory(none)
declare !dbg !12 spir_func i64 @_Z12get_group_idj(i32) local_unnamed_addr #0
; Function Attrs: mustprogress nofree nounwind willreturn memory(argmem: readwrite)
define spir_kernel void @triton_poi_fused__native_batch_norm_legit_no_training_relu_2(ptr addrspace(1) captures(none) %0, ptr addrspace(1) readonly captures(none) %1, ptr addrspace(1) readonly captures(none) %2, ptr addrspace(1) readonly captures(none) %3, ptr addrspace(1) readonly captures(none) %4, i32 %5, ptr addrspace(1) readnone captures(none) %6) local_unnamed_addr #1 !dbg !13 !intel_reqd_sub_group_size !14 !max_work_group_size !15 {
%8 = tail call spir_func i64 @_Z12get_group_idj(i32 0) #3, !dbg !16
%9 = trunc i64 %8 to i32, !dbg !16
%10 = shl i32 %9, 10, !dbg !17
%11 = tail call spir_func i64 @_Z12get_local_idj(i32 0) #3, !dbg !18
%12 = trunc i64 %11 to i32, !dbg !18
%13 = shl i32 %12, 3, !dbg !18
%14 = and i32 %13, 1016, !dbg !18
%15 = or disjoint i32 %14, %10, !dbg !19
%16 = srem i32 %15, 64, !dbg !20
%17 = sext i32 %15 to i64, !dbg !21
%18 = getelementptr half, ptr addrspace(1) %0, i64 %17, !dbg !21
%19 = getelementptr inbounds nuw i8, ptr addrspace(1) %18, i64 4, !dbg !22
%20 = getelementptr inbounds nuw i8, ptr addrspace(1) %18, i64 8, !dbg !22
%21 = getelementptr inbounds nuw i8, ptr addrspace(1) %18, i64 12, !dbg !22
%22 = load <2 x half>, ptr addrspace(1) %21, align 4, !dbg !22
%23 = load half, ptr addrspace(1) %18, align 16, !dbg !22
%24 = getelementptr inbounds nuw i8, ptr addrspace(1) %18, i64 2, !dbg !22
%25 = load half, ptr addrspace(1) %24, align 2, !dbg !22
%26 = load half, ptr addrspace(1) %19, align 4, !dbg !22
%27 = getelementptr inbounds nuw i8, ptr addrspace(1) %18, i64 6, !dbg !22
%28 = load half, ptr addrspace(1) %27, align 2, !dbg !22
%29 = load half, ptr addrspace(1) %20, align 8, !dbg !22
%30 = getelementptr inbounds nuw i8, ptr addrspace(1) %18, i64 10, !dbg !22
%31 = load half, ptr addrspace(1) %30, align 2, !dbg !22
%32 = extractelement <2 x half> %22, i64 0, !dbg !22
%33 = extractelement <2 x half> %22, i64 1, !dbg !22
%34 = fpext half %23 to float, !dbg !23
%35 = fpext half %25 to float, !dbg !23
%36 = fpext half %26 to float, !dbg !23
%37 = fpext half %28 to float, !dbg !23
%38 = fpext half %29 to float, !dbg !23
%39 = fpext half %31 to float, !dbg !23
%40 = fpext half %32 to float, !dbg !23
%41 = fpext half %33 to float, !dbg !23
%42 = sext i32 %16 to i64, !dbg !24
%43 = getelementptr half, ptr addrspace(1) %1, i64 %42, !dbg !24
%44 = getelementptr inbounds nuw i8, ptr addrspace(1) %43, i64 4, !dbg !25
%45 = getelementptr inbounds nuw i8, ptr addrspace(1) %43, i64 8, !dbg !25
%46 = getelementptr inbounds nuw i8, ptr addrspace(1) %43, i64 12, !dbg !25
%47 = load <2 x half>, ptr addrspace(1) %46, align 4, !dbg !25
%48 = load half, ptr addrspace(1) %43, align 16, !dbg !25
%49 = getelementptr inbounds nuw i8, ptr addrspace(1) %43, i64 2, !dbg !25
%50 = load half, ptr addrspace(1) %49, align 2, !dbg !25
%51 = load half, ptr addrspace(1) %44, align 4, !dbg !25
%52 = getelementptr inbounds nuw i8, ptr addrspace(1) %43, i64 6, !dbg !25
%53 = load half, ptr addrspace(1) %52, align 2, !dbg !25
%54 = load half, ptr addrspace(1) %45, align 8, !dbg !25
%55 = getelementptr inbounds nuw i8, ptr addrspace(1) %43, i64 10, !dbg !25
%56 = load half, ptr addrspace(1) %55, align 2, !dbg !25
%57 = extractelement <2 x half> %47, i64 0, !dbg !25
%58 = extractelement <2 x half> %47, i64 1, !dbg !25
%59 = fpext half %48 to float, !dbg !26
%60 = fpext half %50 to float, !dbg !26
%61 = fpext half %51 to float, !dbg !26
%62 = fpext half %53 to float, !dbg !26
%63 = fpext half %54 to float, !dbg !26
%64 = fpext half %56 to float, !dbg !26
%65 = fpext half %57 to float, !dbg !26
%66 = fpext half %58 to float, !dbg !26
%67 = getelementptr half, ptr addrspace(1) %2, i64 %42, !dbg !27
%68 = getelementptr inbounds nuw i8, ptr addrspace(1) %67, i64 4, !dbg !28
%69 = getelementptr inbounds nuw i8, ptr addrspace(1) %67, i64 8, !dbg !28
%70 = getelementptr inbounds nuw i8, ptr addrspace(1) %67, i64 12, !dbg !28
%71 = load <2 x half>, ptr addrspace(1) %70, align 4, !dbg !28
%72 = load half, ptr addrspace(1) %67, align 16, !dbg !28
%73 = getelementptr inbounds nuw i8, ptr addrspace(1) %67, i64 2, !dbg !28
%74 = load half, ptr addrspace(1) %73, align 2, !dbg !28
%75 = load half, ptr addrspace(1) %68, align 4, !dbg !28
%76 = getelementptr inbounds nuw i8, ptr addrspace(1) %67, i64 6, !dbg !28
%77 = load half, ptr addrspace(1) %76, align 2, !dbg !28
%78 = load half, ptr addrspace(1) %69, align 8, !dbg !28
%79 = getelementptr inbounds nuw i8, ptr addrspace(1) %67, i64 10, !dbg !28
%80 = load half, ptr addrspace(1) %79, align 2, !dbg !28
%81 = extractelement <2 x half> %71, i64 0, !dbg !28
%82 = extractelement <2 x half> %71, i64 1, !dbg !28
%83 = fpext half %72 to float, !dbg !29
%84 = fpext half %74 to float, !dbg !29
%85 = fpext half %75 to float, !dbg !29
%86 = fpext half %77 to float, !dbg !29
%87 = fpext half %78 to float, !dbg !29
%88 = fpext half %80 to float, !dbg !29
%89 = fpext half %81 to float, !dbg !29
%90 = fpext half %82 to float, !dbg !29
%91 = getelementptr half, ptr addrspace(1) %3, i64 %42, !dbg !30
%92 = getelementptr inbounds nuw i8, ptr addrspace(1) %91, i64 4, !dbg !31
%93 = getelementptr inbounds nuw i8, ptr addrspace(1) %91, i64 8, !dbg !31
%94 = getelementptr inbounds nuw i8, ptr addrspace(1) %91, i64 12, !dbg !31
%95 = load <2 x half>, ptr addrspace(1) %94, align 4, !dbg !31
%96 = load half, ptr addrspace(1) %91, align 16, !dbg !31
%97 = getelementptr inbounds nuw i8, ptr addrspace(1) %91, i64 2, !dbg !31
%98 = load half, ptr addrspace(1) %97, align 2, !dbg !31
%99 = load half, ptr addrspace(1) %92, align 4, !dbg !31
%100 = getelementptr inbounds nuw i8, ptr addrspace(1) %91, i64 6, !dbg !31
%101 = load half, ptr addrspace(1) %100, align 2, !dbg !31
%102 = load half, ptr addrspace(1) %93, align 8, !dbg !31
%103 = getelementptr inbounds nuw i8, ptr addrspace(1) %91, i64 10, !dbg !31
%104 = load half, ptr addrspace(1) %103, align 2, !dbg !31
%105 = extractelement <2 x half> %95, i64 0, !dbg !31
%106 = extractelement <2 x half> %95, i64 1, !dbg !31
%107 = fpext half %96 to float, !dbg !32
%108 = fpext half %98 to float, !dbg !32
%109 = fpext half %99 to float, !dbg !32
%110 = fpext half %101 to float, !dbg !32
%111 = fpext half %102 to float, !dbg !32
%112 = fpext half %104 to float, !dbg !32
%113 = fpext half %105 to float, !dbg !32
%114 = fpext half %106 to float, !dbg !32
%115 = getelementptr half, ptr addrspace(1) %4, i64 %42, !dbg !33
%116 = getelementptr inbounds nuw i8, ptr addrspace(1) %115, i64 4, !dbg !34
%117 = getelementptr inbounds nuw i8, ptr addrspace(1) %115, i64 8, !dbg !34
%118 = getelementptr inbounds nuw i8, ptr addrspace(1) %115, i64 12, !dbg !34
%119 = load <2 x half>, ptr addrspace(1) %118, align 4, !dbg !34
%120 = load half, ptr addrspace(1) %115, align 16, !dbg !34
%121 = getelementptr inbounds nuw i8, ptr addrspace(1) %115, i64 2, !dbg !34
%122 = load half, ptr addrspace(1) %121, align 2, !dbg !34
%123 = load half, ptr addrspace(1) %116, align 4, !dbg !34
%124 = getelementptr inbounds nuw i8, ptr addrspace(1) %115, i64 6, !dbg !34
%125 = load half, ptr addrspace(1) %124, align 2, !dbg !34
%126 = load half, ptr addrspace(1) %117, align 8, !dbg !34
%127 = getelementptr inbounds nuw i8, ptr addrspace(1) %115, i64 10, !dbg !34
%128 = load half, ptr addrspace(1) %127, align 2, !dbg !34
%129 = extractelement <2 x half> %119, i64 0, !dbg !34
%130 = extractelement <2 x half> %119, i64 1, !dbg !34
%131 = fpext half %120 to float, !dbg !35
%132 = fpext half %122 to float, !dbg !35
%133 = fpext half %123 to float, !dbg !35
%134 = fpext half %125 to float, !dbg !35
%135 = fpext half %126 to float, !dbg !35
%136 = fpext half %128 to float, !dbg !35
%137 = fpext half %129 to float, !dbg !35
%138 = fpext half %130 to float, !dbg !35
%139 = fsub float %34, %59, !dbg !36
%140 = fsub float %35, %60, !dbg !36
%141 = fsub float %36, %61, !dbg !36
%142 = fsub float %37, %62, !dbg !36
%143 = fsub float %38, %63, !dbg !36
%144 = fsub float %39, %64, !dbg !36
%145 = fsub float %40, %65, !dbg !36
%146 = fsub float %41, %66, !dbg !36
%147 = fadd float %83, 0x3EE4F8B580000000, !dbg !37
%148 = fadd float %84, 0x3EE4F8B580000000, !dbg !37
%149 = fadd float %85, 0x3EE4F8B580000000, !dbg !37
%150 = fadd float %86, 0x3EE4F8B580000000, !dbg !37
%151 = fadd float %87, 0x3EE4F8B580000000, !dbg !37
%152 = fadd float %88, 0x3EE4F8B580000000, !dbg !37
%153 = fadd float %89, 0x3EE4F8B580000000, !dbg !37
%154 = fadd float %90, 0x3EE4F8B580000000, !dbg !37
%155 = tail call spir_func noundef float @_Z16__spirv_ocl_sqrtf(float noundef %147) #4, !dbg !38
%156 = tail call spir_func noundef float @_Z16__spirv_ocl_sqrtf(float noundef %148) #4, !dbg !38
%157 = tail call spir_func noundef float @_Z16__spirv_ocl_sqrtf(float noundef %149) #4, !dbg !38
%158 = tail call spir_func noundef float @_Z16__spirv_ocl_sqrtf(float noundef %150) #4, !dbg !38
%159 = tail call spir_func noundef float @_Z16__spirv_ocl_sqrtf(float noundef %151) #4, !dbg !38
%160 = tail call spir_func noundef float @_Z16__spirv_ocl_sqrtf(float noundef %152) #4, !dbg !38
%161 = tail call spir_func noundef float @_Z16__spirv_ocl_sqrtf(float noundef %153) #4, !dbg !38
%162 = tail call spir_func noundef float @_Z16__spirv_ocl_sqrtf(float noundef %154) #4, !dbg !38
%163 = fdiv float 1.000000e+00, %155, !dbg !39
%164 = fdiv float 1.000000e+00, %156, !dbg !39
%165 = fdiv float 1.000000e+00, %157, !dbg !39
%166 = fdiv float 1.000000e+00, %158, !dbg !39
%167 = fdiv float 1.000000e+00, %159, !dbg !39
%168 = fdiv float 1.000000e+00, %160, !dbg !39
%169 = fdiv float 1.000000e+00, %161, !dbg !39
%170 = fdiv float 1.000000e+00, %162, !dbg !39
%171 = fmul float %139, %163, !dbg !40
%172 = fmul float %140, %164, !dbg !40
%173 = fmul float %141, %165, !dbg !40
%174 = fmul float %142, %166, !dbg !40
%175 = fmul float %143, %167, !dbg !40
%176 = fmul float %144, %168, !dbg !40
%177 = fmul float %145, %169, !dbg !40
%178 = fmul float %146, %170, !dbg !40
%179 = fmul float %171, %107, !dbg !41
%180 = fmul float %172, %108, !dbg !41
%181 = fmul float %173, %109, !dbg !41
%182 = fmul float %174, %110, !dbg !41
%183 = fmul float %175, %111, !dbg !41
%184 = fmul float %176, %112, !dbg !41
%185 = fmul float %177, %113, !dbg !41
%186 = fmul float %178, %114, !dbg !41
%187 = fadd float %179, %131, !dbg !42
%188 = fadd float %180, %132, !dbg !42
%189 = fadd float %181, %133, !dbg !42
%190 = fadd float %182, %134, !dbg !42
%191 = fadd float %183, %135, !dbg !42
%192 = fadd float %184, %136, !dbg !42
%193 = fadd float %185, %137, !dbg !42
%194 = fadd float %186, %138, !dbg !42
%195 = fcmp olt float %187, 0.000000e+00, !dbg !43
%196 = fcmp olt float %188, 0.000000e+00, !dbg !43
%197 = fcmp olt float %189, 0.000000e+00, !dbg !43
%198 = fcmp olt float %190, 0.000000e+00, !dbg !43
%199 = fcmp olt float %191, 0.000000e+00, !dbg !43
%200 = fcmp olt float %192, 0.000000e+00, !dbg !43
%201 = fcmp olt float %193, 0.000000e+00, !dbg !43
%202 = fcmp olt float %194, 0.000000e+00, !dbg !43
%203 = select i1 %195, float 0.000000e+00, float %187, !dbg !47
%204 = select i1 %196, float 0.000000e+00, float %188, !dbg !47
%205 = select i1 %197, float 0.000000e+00, float %189, !dbg !47
%206 = select i1 %198, float 0.000000e+00, float %190, !dbg !47
%207 = select i1 %199, float 0.000000e+00, float %191, !dbg !47
%208 = select i1 %200, float 0.000000e+00, float %192, !dbg !47
%209 = select i1 %201, float 0.000000e+00, float %193, !dbg !47
%210 = select i1 %202, float 0.000000e+00, float %194, !dbg !47
%211 = fptrunc float %203 to half, !dbg !48
%212 = fptrunc float %204 to half, !dbg !48
%213 = fptrunc float %205 to half, !dbg !48
%214 = fptrunc float %206 to half, !dbg !48
%215 = fptrunc float %207 to half, !dbg !48
%216 = fptrunc float %208 to half, !dbg !48
%217 = fptrunc float %209 to half, !dbg !48
%218 = fptrunc float %210 to half, !dbg !48
%219 = insertelement <2 x half> poison, half %211, i64 0, !dbg !48
%220 = insertelement <2 x half> %219, half %212, i64 1, !dbg !48
%221 = bitcast <2 x half> %220 to i32, !dbg !48
%222 = insertelement <2 x half> poison, half %213, i64 0, !dbg !48
%223 = insertelement <2 x half> %222, half %214, i64 1, !dbg !48
%224 = bitcast <2 x half> %223 to i32, !dbg !48
%225 = insertelement <2 x half> poison, half %215, i64 0, !dbg !48
%226 = insertelement <2 x half> %225, half %216, i64 1, !dbg !48
%227 = bitcast <2 x half> %226 to i32, !dbg !48
%228 = insertelement <2 x half> poison, half %217, i64 0, !dbg !48
%229 = insertelement <2 x half> %228, half %218, i64 1, !dbg !48
%230 = bitcast <2 x half> %229 to i32, !dbg !48
%231 = insertelement <4 x i32> poison, i32 %221, i64 0, !dbg !48
%232 = insertelement <4 x i32> %231, i32 %224, i64 1, !dbg !48
%233 = insertelement <4 x i32> %232, i32 %227, i64 2, !dbg !48
%234 = insertelement <4 x i32> %233, i32 %230, i64 3, !dbg !48
store <4 x i32> %234, ptr addrspace(1) %18, align 16, !dbg !48
ret void, !dbg !49
}
; Function Attrs: convergent mustprogress nofree nounwind willreturn memory(none)
declare dso_local spir_func noundef float @_Z16__spirv_ocl_sqrtf(float noundef) local_unnamed_addr #2
attributes #0 = { mustprogress nofree nosync nounwind willreturn memory(none) }
attributes #1 = { mustprogress nofree nounwind willreturn memory(argmem: readwrite) }
attributes #2 = { convergent mustprogress nofree nounwind willreturn memory(none) "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
attributes #3 = { nounwind willreturn memory(none) }
attributes #4 = { convergent nounwind willreturn memory(none) }
!llvm.dbg.cu = !{!0}
!llvm.module.flags = !{!2, !3, !4, !5}
!opencl.spir.version = !{!6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6}
!spirv.Source = !{!7, !7, !7, !7, !7, !7, !7, !7, !7, !7, !7, !7, !7, !7, !7, !7, !7, !7, !7, !7}
!llvm.ident = !{!8, !8, !8, !8, !8, !8, !8, !8, !8, !8, !8, !8, !8, !8, !8, !8, !8, !8, !8, !8}
!0 = distinct !DICompileUnit(language: DW_LANG_C, file: !1, producer: "triton", isOptimized: true, runtimeVersion: 0, emissionKind: LineTablesOnly)
!1 = !DIFile(filename: "ci4ch6recibbahxfnjzez75dv5ritrfki3463z37medubgfvzt3z.py", directory: "/home/jovyan/intel-xpu-backend-for-triton/torchinductor_cache3/i4")
!2 = !{i32 2, !"Debug Info Version", i32 3}
!3 = !{i32 1, !"wchar_size", i32 4}
!4 = !{i32 1, !"sycl-device", i32 1}
!5 = !{i32 7, !"frame-pointer", i32 2}
!6 = !{i32 1, i32 2}
!7 = !{i32 3, i32 100000}
!8 = !{!"Intel(R) oneAPI DPC++/C++ Compiler 2025.0.0 (2025.0.0.20241008)"}
!9 = !DISubprogram(name: "_Z12get_local_idj", linkageName: "_Z12get_local_idj", scope: !1, file: !1, line: 18, type: !10, scopeLine: 18, spFlags: DISPFlagOptimized)
!10 = !DISubroutineType(cc: DW_CC_normal, types: !11)
!11 = !{}
!12 = !DISubprogram(name: "_Z12get_group_idj", linkageName: "_Z12get_group_idj", scope: !1, file: !1, line: 18, type: !10, scopeLine: 18, spFlags: DISPFlagOptimized)
!13 = distinct !DISubprogram(name: "triton_poi_fused__native_batch_norm_legit_no_training_relu_2", linkageName: "triton_poi_fused__native_batch_norm_legit_no_training_relu_2", scope: !1, file: !1, line: 18, type: !10, scopeLine: 18, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !0)
!14 = !{i32 32}
!15 = !{i64 128, i64 1, i64 1}
!16 = !DILocation(line: 20, column: 28, scope: !13)
!17 = !DILocation(line: 20, column: 33, scope: !13)
!18 = !DILocation(line: 21, column: 36, scope: !13)
!19 = !DILocation(line: 21, column: 23, scope: !13)
!20 = !DILocation(line: 24, column: 19, scope: !13)
!21 = !DILocation(line: 25, column: 34, scope: !13)
!22 = !DILocation(line: 25, column: 39, scope: !13)
!23 = !DILocation(line: 25, column: 48, scope: !13)
!24 = !DILocation(line: 26, column: 30, scope: !13)
!25 = !DILocation(line: 26, column: 35, scope: !13)
!26 = !DILocation(line: 26, column: 74, scope: !13)
!27 = !DILocation(line: 27, column: 30, scope: !13)
!28 = !DILocation(line: 27, column: 35, scope: !13)
!29 = !DILocation(line: 27, column: 74, scope: !13)
!30 = !DILocation(line: 28, column: 31, scope: !13)
!31 = !DILocation(line: 28, column: 36, scope: !13)
!32 = !DILocation(line: 28, column: 75, scope: !13)
!33 = !DILocation(line: 29, column: 31, scope: !13)
!34 = !DILocation(line: 29, column: 36, scope: !13)
!35 = !DILocation(line: 29, column: 75, scope: !13)
!36 = !DILocation(line: 32, column: 18, scope: !13)
!37 = !DILocation(line: 35, column: 18, scope: !13)
!38 = !DILocation(line: 36, column: 26, scope: !13)
!39 = !DILocation(line: 38, column: 21, scope: !13)
!40 = !DILocation(line: 41, column: 19, scope: !13)
!41 = !DILocation(line: 43, column: 20, scope: !13)
!42 = !DILocation(line: 45, column: 20, scope: !13)
!43 = !DILocation(line: 111, column: 15, scope: !44, inlinedAt: !46)
!44 = distinct !DILexicalBlockFile(scope: !13, file: !45, discriminator: 0)
!45 = !DIFile(filename: "triton_helpers.py", directory: "/home/jovyan/intel-xpu-backend-for-triton/.scripts_cache/pytorch/torch/_inductor/runtime")
!46 = !DILocation(line: 48, column: 42, scope: !13)
!47 = !DILocation(line: 114, column: 29, scope: !44, inlinedAt: !46)
!48 = !DILocation(line: 49, column: 40, scope: !13)
!49 = !DILocation(line: 49, column: 4, scope: !13)
Incorrect kernel
; ModuleID = 'LLVMDialectModule'
source_filename = "LLVMDialectModule"
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
target triple = "spir64-unknown-unknown"
; Function Attrs: mustprogress nofree nosync nounwind willreturn memory(none)
declare !dbg !9 spir_func i64 @_Z12get_local_idj(i32) local_unnamed_addr #0
; Function Attrs: mustprogress nofree nosync nounwind willreturn memory(none)
declare !dbg !12 spir_func i64 @_Z12get_group_idj(i32) local_unnamed_addr #0
; Function Attrs: mustprogress nofree nounwind willreturn memory(argmem: readwrite)
define spir_kernel void @triton_poi_fused__native_batch_norm_legit_no_training_relu_2(ptr addrspace(1) captures(none) %0, ptr addrspace(1) readonly captures(none) %1, ptr addrspace(1) readonly captures(none) %2, ptr addrspace(1) readonly captures(none) %3, ptr addrspace(1) readonly captures(none) %4, i32 %5, ptr addrspace(1) readnone captures(none) %6) local_unnamed_addr #1 !dbg !13 !intel_reqd_sub_group_size !14 !max_work_group_size !15 {
%8 = tail call spir_func i64 @_Z12get_group_idj(i32 0) #3, !dbg !16
%9 = trunc i64 %8 to i32, !dbg !16
%10 = shl i32 %9, 10, !dbg !17
%11 = tail call spir_func i64 @_Z12get_local_idj(i32 0) #3, !dbg !18
%12 = trunc i64 %11 to i32, !dbg !18
%13 = shl i32 %12, 3, !dbg !18
%14 = and i32 %13, 1016, !dbg !18
%15 = or disjoint i32 %14, %10, !dbg !19
%16 = srem i32 %15, 64, !dbg !20
%17 = sext i32 %15 to i64, !dbg !21
%18 = getelementptr half, ptr addrspace(1) %0, i64 %17, !dbg !21
%19 = getelementptr inbounds nuw i8, ptr addrspace(1) %18, i64 4, !dbg !22
%20 = getelementptr inbounds nuw i8, ptr addrspace(1) %18, i64 8, !dbg !22
%21 = getelementptr inbounds nuw i8, ptr addrspace(1) %18, i64 12, !dbg !22
%22 = load <2 x half>, ptr addrspace(1) %21, align 4, !dbg !22
%23 = load half, ptr addrspace(1) %18, align 16, !dbg !22
%24 = getelementptr inbounds nuw i8, ptr addrspace(1) %18, i64 2, !dbg !22
%25 = load half, ptr addrspace(1) %24, align 2, !dbg !22
%26 = load half, ptr addrspace(1) %19, align 4, !dbg !22
%27 = getelementptr inbounds nuw i8, ptr addrspace(1) %18, i64 6, !dbg !22
%28 = load half, ptr addrspace(1) %27, align 2, !dbg !22
%29 = load half, ptr addrspace(1) %20, align 8, !dbg !22
%30 = getelementptr inbounds nuw i8, ptr addrspace(1) %18, i64 10, !dbg !22
%31 = load half, ptr addrspace(1) %30, align 2, !dbg !22
%32 = extractelement <2 x half> %22, i64 0, !dbg !22
%33 = extractelement <2 x half> %22, i64 1, !dbg !22
%34 = fpext half %23 to float, !dbg !23
%35 = fpext half %25 to float, !dbg !23
%36 = fpext half %26 to float, !dbg !23
%37 = fpext half %28 to float, !dbg !23
%38 = fpext half %29 to float, !dbg !23
%39 = fpext half %31 to float, !dbg !23
%40 = fpext half %32 to float, !dbg !23
%41 = fpext half %33 to float, !dbg !23
%42 = sext i32 %16 to i64, !dbg !24
%43 = getelementptr half, ptr addrspace(1) %1, i64 %42, !dbg !24
%44 = getelementptr inbounds nuw i8, ptr addrspace(1) %43, i64 4, !dbg !25
%45 = getelementptr inbounds nuw i8, ptr addrspace(1) %43, i64 8, !dbg !25
%46 = getelementptr inbounds nuw i8, ptr addrspace(1) %43, i64 12, !dbg !25
%47 = load <2 x half>, ptr addrspace(1) %46, align 4, !dbg !25
%48 = load half, ptr addrspace(1) %43, align 16, !dbg !25
%49 = getelementptr inbounds nuw i8, ptr addrspace(1) %43, i64 2, !dbg !25
%50 = load half, ptr addrspace(1) %49, align 2, !dbg !25
%51 = load half, ptr addrspace(1) %44, align 4, !dbg !25
%52 = getelementptr inbounds nuw i8, ptr addrspace(1) %43, i64 6, !dbg !25
%53 = load half, ptr addrspace(1) %52, align 2, !dbg !25
%54 = load half, ptr addrspace(1) %45, align 8, !dbg !25
%55 = getelementptr inbounds nuw i8, ptr addrspace(1) %43, i64 10, !dbg !25
%56 = load half, ptr addrspace(1) %55, align 2, !dbg !25
%57 = extractelement <2 x half> %47, i64 0, !dbg !25
%58 = extractelement <2 x half> %47, i64 1, !dbg !25
%59 = fpext half %48 to float, !dbg !26
%60 = fpext half %50 to float, !dbg !26
%61 = fpext half %51 to float, !dbg !26
%62 = fpext half %53 to float, !dbg !26
%63 = fpext half %54 to float, !dbg !26
%64 = fpext half %56 to float, !dbg !26
%65 = fpext half %57 to float, !dbg !26
%66 = fpext half %58 to float, !dbg !26
%67 = getelementptr half, ptr addrspace(1) %2, i64 %42, !dbg !27
%68 = getelementptr inbounds nuw i8, ptr addrspace(1) %67, i64 4, !dbg !28
%69 = getelementptr inbounds nuw i8, ptr addrspace(1) %67, i64 8, !dbg !28
%70 = getelementptr inbounds nuw i8, ptr addrspace(1) %67, i64 12, !dbg !28
%71 = load <2 x half>, ptr addrspace(1) %70, align 4, !dbg !28
%72 = load half, ptr addrspace(1) %67, align 16, !dbg !28
%73 = getelementptr inbounds nuw i8, ptr addrspace(1) %67, i64 2, !dbg !28
%74 = load half, ptr addrspace(1) %73, align 2, !dbg !28
%75 = load half, ptr addrspace(1) %68, align 4, !dbg !28
%76 = getelementptr inbounds nuw i8, ptr addrspace(1) %67, i64 6, !dbg !28
%77 = load half, ptr addrspace(1) %76, align 2, !dbg !28
%78 = load half, ptr addrspace(1) %69, align 8, !dbg !28
%79 = getelementptr inbounds nuw i8, ptr addrspace(1) %67, i64 10, !dbg !28
%80 = load half, ptr addrspace(1) %79, align 2, !dbg !28
%81 = extractelement <2 x half> %71, i64 0, !dbg !28
%82 = extractelement <2 x half> %71, i64 1, !dbg !28
%83 = fpext half %72 to float, !dbg !29
%84 = fpext half %74 to float, !dbg !29
%85 = fpext half %75 to float, !dbg !29
%86 = fpext half %77 to float, !dbg !29
%87 = fpext half %78 to float, !dbg !29
%88 = fpext half %80 to float, !dbg !29
%89 = fpext half %81 to float, !dbg !29
%90 = fpext half %82 to float, !dbg !29
%91 = getelementptr half, ptr addrspace(1) %3, i64 %42, !dbg !30
%92 = getelementptr inbounds nuw i8, ptr addrspace(1) %91, i64 4, !dbg !31
%93 = getelementptr inbounds nuw i8, ptr addrspace(1) %91, i64 8, !dbg !31
%94 = getelementptr inbounds nuw i8, ptr addrspace(1) %91, i64 12, !dbg !31
%95 = load <2 x half>, ptr addrspace(1) %94, align 4, !dbg !31
%96 = load half, ptr addrspace(1) %91, align 16, !dbg !31
%97 = getelementptr inbounds nuw i8, ptr addrspace(1) %91, i64 2, !dbg !31
%98 = load half, ptr addrspace(1) %97, align 2, !dbg !31
%99 = load half, ptr addrspace(1) %92, align 4, !dbg !31
%100 = getelementptr inbounds nuw i8, ptr addrspace(1) %91, i64 6, !dbg !31
%101 = load half, ptr addrspace(1) %100, align 2, !dbg !31
%102 = load half, ptr addrspace(1) %93, align 8, !dbg !31
%103 = getelementptr inbounds nuw i8, ptr addrspace(1) %91, i64 10, !dbg !31
%104 = load half, ptr addrspace(1) %103, align 2, !dbg !31
%105 = extractelement <2 x half> %95, i64 0, !dbg !31
%106 = extractelement <2 x half> %95, i64 1, !dbg !31
%107 = fpext half %96 to float, !dbg !32
%108 = fpext half %98 to float, !dbg !32
%109 = fpext half %99 to float, !dbg !32
%110 = fpext half %101 to float, !dbg !32
%111 = fpext half %102 to float, !dbg !32
%112 = fpext half %104 to float, !dbg !32
%113 = fpext half %105 to float, !dbg !32
%114 = fpext half %106 to float, !dbg !32
%115 = getelementptr half, ptr addrspace(1) %4, i64 %42, !dbg !33
%116 = getelementptr inbounds nuw i8, ptr addrspace(1) %115, i64 4, !dbg !34
%117 = getelementptr inbounds nuw i8, ptr addrspace(1) %115, i64 8, !dbg !34
%118 = getelementptr inbounds nuw i8, ptr addrspace(1) %115, i64 12, !dbg !34
%119 = load <2 x half>, ptr addrspace(1) %118, align 4, !dbg !34
%120 = load half, ptr addrspace(1) %115, align 16, !dbg !34
%121 = getelementptr inbounds nuw i8, ptr addrspace(1) %115, i64 2, !dbg !34
%122 = load half, ptr addrspace(1) %121, align 2, !dbg !34
%123 = load half, ptr addrspace(1) %116, align 4, !dbg !34
%124 = getelementptr inbounds nuw i8, ptr addrspace(1) %115, i64 6, !dbg !34
%125 = load half, ptr addrspace(1) %124, align 2, !dbg !34
%126 = load half, ptr addrspace(1) %117, align 8, !dbg !34
%127 = getelementptr inbounds nuw i8, ptr addrspace(1) %115, i64 10, !dbg !34
%128 = load half, ptr addrspace(1) %127, align 2, !dbg !34
%129 = extractelement <2 x half> %119, i64 0, !dbg !34
%130 = extractelement <2 x half> %119, i64 1, !dbg !34
%131 = fpext half %120 to float, !dbg !35
%132 = fpext half %122 to float, !dbg !35
%133 = fpext half %123 to float, !dbg !35
%134 = fpext half %125 to float, !dbg !35
%135 = fpext half %126 to float, !dbg !35
%136 = fpext half %128 to float, !dbg !35
%137 = fpext half %129 to float, !dbg !35
%138 = fpext half %130 to float, !dbg !35
%139 = fsub float %34, %59, !dbg !36
%140 = fsub float %35, %60, !dbg !36
%141 = fsub float %36, %61, !dbg !36
%142 = fsub float %37, %62, !dbg !36
%143 = fsub float %38, %63, !dbg !36
%144 = fsub float %39, %64, !dbg !36
%145 = fsub float %40, %65, !dbg !36
%146 = fsub float %41, %66, !dbg !36
%147 = fadd contract float %83, 0x3EE4F8B580000000, !dbg !37
%148 = fadd contract float %84, 0x3EE4F8B580000000, !dbg !37
%149 = fadd contract float %85, 0x3EE4F8B580000000, !dbg !37
%150 = fadd contract float %86, 0x3EE4F8B580000000, !dbg !37
%151 = fadd contract float %87, 0x3EE4F8B580000000, !dbg !37
%152 = fadd contract float %88, 0x3EE4F8B580000000, !dbg !37
%153 = fadd contract float %89, 0x3EE4F8B580000000, !dbg !37
%154 = fadd contract float %90, 0x3EE4F8B580000000, !dbg !37
%155 = tail call spir_func noundef float @_Z16__spirv_ocl_sqrtf(float noundef %147) #4, !dbg !38
%156 = tail call spir_func noundef float @_Z16__spirv_ocl_sqrtf(float noundef %148) #4, !dbg !38
%157 = tail call spir_func noundef float @_Z16__spirv_ocl_sqrtf(float noundef %149) #4, !dbg !38
%158 = tail call spir_func noundef float @_Z16__spirv_ocl_sqrtf(float noundef %150) #4, !dbg !38
%159 = tail call spir_func noundef float @_Z16__spirv_ocl_sqrtf(float noundef %151) #4, !dbg !38
%160 = tail call spir_func noundef float @_Z16__spirv_ocl_sqrtf(float noundef %152) #4, !dbg !38
%161 = tail call spir_func noundef float @_Z16__spirv_ocl_sqrtf(float noundef %153) #4, !dbg !38
%162 = tail call spir_func noundef float @_Z16__spirv_ocl_sqrtf(float noundef %154) #4, !dbg !38
%163 = fdiv float 1.000000e+00, %155, !dbg !39
%164 = fdiv float 1.000000e+00, %156, !dbg !39
%165 = fdiv float 1.000000e+00, %157, !dbg !39
%166 = fdiv float 1.000000e+00, %158, !dbg !39
%167 = fdiv float 1.000000e+00, %159, !dbg !39
%168 = fdiv float 1.000000e+00, %160, !dbg !39
%169 = fdiv float 1.000000e+00, %161, !dbg !39
%170 = fdiv float 1.000000e+00, %162, !dbg !39
%171 = fmul contract float %139, %163, !dbg !40
%172 = fmul contract float %140, %164, !dbg !40
%173 = fmul contract float %141, %165, !dbg !40
%174 = fmul contract float %142, %166, !dbg !40
%175 = fmul contract float %143, %167, !dbg !40
%176 = fmul contract float %144, %168, !dbg !40
%177 = fmul contract float %145, %169, !dbg !40
%178 = fmul contract float %146, %170, !dbg !40
%179 = fmul contract float %171, %107, !dbg !41
%180 = fmul contract float %172, %108, !dbg !41
%181 = fmul contract float %173, %109, !dbg !41
%182 = fmul contract float %174, %110, !dbg !41
%183 = fmul contract float %175, %111, !dbg !41
%184 = fmul contract float %176, %112, !dbg !41
%185 = fmul contract float %177, %113, !dbg !41
%186 = fmul contract float %178, %114, !dbg !41
%187 = fadd contract float %179, %131, !dbg !42
%188 = fadd contract float %180, %132, !dbg !42
%189 = fadd contract float %181, %133, !dbg !42
%190 = fadd contract float %182, %134, !dbg !42
%191 = fadd contract float %183, %135, !dbg !42
%192 = fadd contract float %184, %136, !dbg !42
%193 = fadd contract float %185, %137, !dbg !42
%194 = fadd contract float %186, %138, !dbg !42
%195 = fcmp olt float %187, 0.000000e+00, !dbg !43
%196 = fcmp olt float %188, 0.000000e+00, !dbg !43
%197 = fcmp olt float %189, 0.000000e+00, !dbg !43
%198 = fcmp olt float %190, 0.000000e+00, !dbg !43
%199 = fcmp olt float %191, 0.000000e+00, !dbg !43
%200 = fcmp olt float %192, 0.000000e+00, !dbg !43
%201 = fcmp olt float %193, 0.000000e+00, !dbg !43
%202 = fcmp olt float %194, 0.000000e+00, !dbg !43
%203 = select i1 %195, float 0.000000e+00, float %187, !dbg !47
%204 = select i1 %196, float 0.000000e+00, float %188, !dbg !47
%205 = select i1 %197, float 0.000000e+00, float %189, !dbg !47
%206 = select i1 %198, float 0.000000e+00, float %190, !dbg !47
%207 = select i1 %199, float 0.000000e+00, float %191, !dbg !47
%208 = select i1 %200, float 0.000000e+00, float %192, !dbg !47
%209 = select i1 %201, float 0.000000e+00, float %193, !dbg !47
%210 = select i1 %202, float 0.000000e+00, float %194, !dbg !47
%211 = fptrunc float %203 to half, !dbg !48
%212 = fptrunc float %204 to half, !dbg !48
%213 = fptrunc float %205 to half, !dbg !48
%214 = fptrunc float %206 to half, !dbg !48
%215 = fptrunc float %207 to half, !dbg !48
%216 = fptrunc float %208 to half, !dbg !48
%217 = fptrunc float %209 to half, !dbg !48
%218 = fptrunc float %210 to half, !dbg !48
%219 = insertelement <2 x half> poison, half %211, i64 0, !dbg !48
%220 = insertelement <2 x half> %219, half %212, i64 1, !dbg !48
%221 = bitcast <2 x half> %220 to i32, !dbg !48
%222 = insertelement <2 x half> poison, half %213, i64 0, !dbg !48
%223 = insertelement <2 x half> %222, half %214, i64 1, !dbg !48
%224 = bitcast <2 x half> %223 to i32, !dbg !48
%225 = insertelement <2 x half> poison, half %215, i64 0, !dbg !48
%226 = insertelement <2 x half> %225, half %216, i64 1, !dbg !48
%227 = bitcast <2 x half> %226 to i32, !dbg !48
%228 = insertelement <2 x half> poison, half %217, i64 0, !dbg !48
%229 = insertelement <2 x half> %228, half %218, i64 1, !dbg !48
%230 = bitcast <2 x half> %229 to i32, !dbg !48
%231 = insertelement <4 x i32> poison, i32 %221, i64 0, !dbg !48
%232 = insertelement <4 x i32> %231, i32 %224, i64 1, !dbg !48
%233 = insertelement <4 x i32> %232, i32 %227, i64 2, !dbg !48
%234 = insertelement <4 x i32> %233, i32 %230, i64 3, !dbg !48
store <4 x i32> %234, ptr addrspace(1) %18, align 16, !dbg !48
ret void, !dbg !49
}
; Function Attrs: convergent mustprogress nofree nounwind willreturn memory(none)
declare dso_local spir_func noundef float @_Z16__spirv_ocl_sqrtf(float noundef) local_unnamed_addr #2
attributes #0 = { mustprogress nofree nosync nounwind willreturn memory(none) }
attributes #1 = { mustprogress nofree nounwind willreturn memory(argmem: readwrite) }
attributes #2 = { convergent mustprogress nofree nounwind willreturn memory(none) "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
attributes #3 = { nounwind willreturn memory(none) }
attributes #4 = { convergent nounwind willreturn memory(none) }
!llvm.dbg.cu = !{!0}
!llvm.module.flags = !{!2, !3, !4, !5}
!opencl.spir.version = !{!6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6}
!spirv.Source = !{!7, !7, !7, !7, !7, !7, !7, !7, !7, !7, !7, !7, !7, !7, !7, !7, !7, !7, !7, !7}
!llvm.ident = !{!8, !8, !8, !8, !8, !8, !8, !8, !8, !8, !8, !8, !8, !8, !8, !8, !8, !8, !8, !8}
!0 = distinct !DICompileUnit(language: DW_LANG_C, file: !1, producer: "triton", isOptimized: true, runtimeVersion: 0, emissionKind: LineTablesOnly)
!1 = !DIFile(filename: "chsgifbghvzneiwpg4l7venb33d6syzbr2mypkbv7rvm6ptn6svr.py", directory: "/home/jovyan/intel-xpu-backend-for-triton/torchinductor_cache4/hs")
!2 = !{i32 2, !"Debug Info Version", i32 3}
!3 = !{i32 1, !"wchar_size", i32 4}
!4 = !{i32 1, !"sycl-device", i32 1}
!5 = !{i32 7, !"frame-pointer", i32 2}
!6 = !{i32 1, i32 2}
!7 = !{i32 3, i32 100000}
!8 = !{!"Intel(R) oneAPI DPC++/C++ Compiler 2025.0.0 (2025.0.0.20241008)"}
!9 = !DISubprogram(name: "_Z12get_local_idj", linkageName: "_Z12get_local_idj", scope: !1, file: !1, line: 18, type: !10, scopeLine: 18, spFlags: DISPFlagOptimized)
!10 = !DISubroutineType(cc: DW_CC_normal, types: !11)
!11 = !{}
!12 = !DISubprogram(name: "_Z12get_group_idj", linkageName: "_Z12get_group_idj", scope: !1, file: !1, line: 18, type: !10, scopeLine: 18, spFlags: DISPFlagOptimized)
!13 = distinct !DISubprogram(name: "triton_poi_fused__native_batch_norm_legit_no_training_relu_2", linkageName: "triton_poi_fused__native_batch_norm_legit_no_training_relu_2", scope: !1, file: !1, line: 18, type: !10, scopeLine: 18, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !0)
!14 = !{i32 32}
!15 = !{i64 128, i64 1, i64 1}
!16 = !DILocation(line: 20, column: 28, scope: !13)
!17 = !DILocation(line: 20, column: 33, scope: !13)
!18 = !DILocation(line: 21, column: 36, scope: !13)
!19 = !DILocation(line: 21, column: 23, scope: !13)
!20 = !DILocation(line: 24, column: 19, scope: !13)
!21 = !DILocation(line: 25, column: 34, scope: !13)
!22 = !DILocation(line: 25, column: 39, scope: !13)
!23 = !DILocation(line: 25, column: 48, scope: !13)
!24 = !DILocation(line: 26, column: 30, scope: !13)
!25 = !DILocation(line: 26, column: 35, scope: !13)
!26 = !DILocation(line: 26, column: 74, scope: !13)
!27 = !DILocation(line: 27, column: 30, scope: !13)
!28 = !DILocation(line: 27, column: 35, scope: !13)
!29 = !DILocation(line: 27, column: 74, scope: !13)
!30 = !DILocation(line: 28, column: 31, scope: !13)
!31 = !DILocation(line: 28, column: 36, scope: !13)
!32 = !DILocation(line: 28, column: 75, scope: !13)
!33 = !DILocation(line: 29, column: 31, scope: !13)
!34 = !DILocation(line: 29, column: 36, scope: !13)
!35 = !DILocation(line: 29, column: 75, scope: !13)
!36 = !DILocation(line: 32, column: 18, scope: !13)
!37 = !DILocation(line: 35, column: 18, scope: !13)
!38 = !DILocation(line: 36, column: 26, scope: !13)
!39 = !DILocation(line: 38, column: 21, scope: !13)
!40 = !DILocation(line: 41, column: 19, scope: !13)
!41 = !DILocation(line: 43, column: 20, scope: !13)
!42 = !DILocation(line: 45, column: 20, scope: !13)
!43 = !DILocation(line: 111, column: 15, scope: !44, inlinedAt: !46)
!44 = distinct !DILexicalBlockFile(scope: !13, file: !45, discriminator: 0)
!45 = !DIFile(filename: "triton_helpers.py", directory: "/home/jovyan/intel-xpu-backend-for-triton/.scripts_cache/pytorch/torch/_inductor/runtime")
!46 = !DILocation(line: 48, column: 42, scope: !13)
!47 = !DILocation(line: 114, column: 29, scope: !44, inlinedAt: !46)
!48 = !DILocation(line: 49, column: 40, scope: !13)
!49 = !DILocation(line: 49, column: 4, scope: !13)
UPD: Mismatched elements: 15 / 92143616 (0.0%)
Greatest absolute difference: 0.00390625 at index (0, 29, 0, 598) (up to 0.0005 allowed)
Greatest relative difference: 0.0008788108825683594 at index (0, 29, 0, 598) (up to 0.0005 allowed)
This is how I was able to figure out that this particular kernel was the cause of this behavior:
Run command:
Modification in compiler:
intel.set_spv_target_triple(llvm_mod)
disable_fast_math_kernels = {
'triton_poi_fused_ceil_div_mul_sub_0.ttir',
'triton_poi_fused_add_25.ttir',
'triton_poi_fused_convolution_relu_11.ttir',
'triton_poi_fused__to_copy_add_clamp_ge_mul_scalar_tensor_sub_where_2.ttir',
'triton_red_fused_add_div_index_mul_scalar_tensor_sum_where_6.ttir',
'triton_poi_fused_mul_0.ttir',
'triton_poi_fused__native_batch_norm_legit_no_training_add_relu_15.ttir',
'triton_poi_fused_add_2.ttir',
'triton_poi_fused__to_copy_add_clamp_ge_mul_scalar_tensor_where_0.ttir',
'triton_poi_fused__native_batch_norm_legit_no_training_add_relu_19.ttir',
'triton_poi_fused__native_batch_norm_legit_no_training_add_relu_18.ttir',
'triton_poi_fused_clone_16.ttir',
'triton_poi_fused__native_batch_norm_legit_no_training_relu_8.ttir',
'triton_red_fused_add_div_index_mul_scalar_tensor_sum_where_6.ttir',
'triton_poi_fused_add_21.ttir',
'triton_red_fused_add_div_index_mul_scalar_tensor_sum_where_2.ttir',
'triton_poi_fused_convolution_relu_14.ttir',
'triton_poi_fused_clone_12.ttir',
'triton_poi_fused__to_copy_add_clamp_ge_mul_scalar_tensor_where_1.ttir',
'triton_poi_fused_clone_9.ttir',
'triton_poi_fused_clone_15.ttir',
'triton_poi_fused_ceil_div_mul_sub_0.ttir',
'triton_poi_fused_stack_3.ttir',
'triton_red_fused_add_div_index_mul_scalar_tensor_sum_where_2.ttir',
'triton_poi_fused__to_copy_add_clamp_ge_mul_scalar_tensor_sub_where_2.ttir',
'triton_poi_fused__to_copy_add_clamp_ge_mul_scalar_tensor_sub_where_3.ttir',
'triton_poi_fused_clone_3.ttir',
'triton_poi_fused__to_copy_add_clamp_ge_mul_scalar_tensor_sub_where_2.ttir',
'triton_poi_fused__native_batch_norm_legit_no_training_relu_12.ttir',
'triton_poi_fused_mul_scalar_tensor_where_1.ttir',
'triton_poi_fused_add_22.ttir',
#'triton_poi_fused__to_copy_add_clamp_ge_mul_scalar_tensor_sub_where_3.ttir',
#'triton_poi_fused__native_batch_norm_legit_no_training_add_relu_10.ttir',
#'triton_poi_fused__to_copy__unsafe_index_add_convolution_24.ttir',
#'triton_poi_fused__to_copy__unsafe_index_add_convolution_25.ttir',
#'triton_red_fused_add_div_index_mul_scalar_tensor_sum_where_2.ttir',
#'triton_poi_fused_convolution_28.ttir',
#'triton_poi_fused_clone_19.ttir',
#'triton_poi_fused__to_copy__unsafe_index_add_convolution_23.ttir',
#'triton_poi_fused_clone_17.ttir',
'triton_poi_fused__to_copy_add_clamp_ge_mul_scalar_tensor_where_0.ttir',
'triton_poi_fused__to_copy_add_clamp_ge_mul_scalar_tensor_where_0.ttir',
'triton_poi_fused_stack_2.ttir',
'triton_poi_fused_add_23.ttir',
'triton_poi_fused_ceil_div_mul_sub_1.ttir',
'triton_poi_fused_clone_18.ttir',
'triton_poi_fused_add_24.ttir',
#'triton_red_fused_add_div_index_mul_scalar_tensor_sum_where_6.ttir',
#'triton_poi_fused__to_copy_add_clamp_ge_mul_scalar_tensor_sub_where_2.ttir',
#'triton_poi_fused_ceil_div_mul_sub_0.ttir',
#'triton_poi_fused__native_batch_norm_legit_no_training_relu_2.ttir', # <- uncommenting this line results in a precision error
#'triton_poi_fused_ceil_div_mul_sub_0.ttir',
#'triton_poi_fused_mul_0.ttir',
#'triton_poi_fused__to_copy__unsafe_index_add_convolution_26.ttir',
}
def get_function_name():
    return glob.glob('*.ttir', root_dir=metadata['cache_dir'])[0]
if get_function_name() not in disable_fast_math_kernels:
    intel.set_fast_math(llvm_mod)
else:
    print("ignore: ", get_function_name()) |
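The manual comment/uncomment search above could also be automated as a bisection. A minimal sketch, assuming exactly one kernel is responsible and assuming a hypothetical `accuracy_fails(fast_math_kernels)` callback that rebuilds with fast math enabled only for the given kernels and reruns the accuracy check (no such callback exists in Triton; it is stubbed here, and the kernel list is a shortened illustration):

```python
from typing import Callable, Sequence


def find_culprit(kernels: Sequence[str],
                 accuracy_fails: Callable[[frozenset], bool]) -> str:
    """Bisect for the single kernel whose fast-math build breaks accuracy.

    Assumes `kernels` is non-empty, that exactly one kernel is the culprit,
    and that `accuracy_fails(subset)` is True iff the culprit is in `subset`.
    """
    candidates = list(kernels)
    while len(candidates) > 1:
        half = candidates[:len(candidates) // 2]
        if accuracy_fails(frozenset(half)):
            candidates = half                    # culprit is in the first half
        else:
            candidates = candidates[len(half):]  # culprit is in the rest
    return candidates[0]


# Stub standing in for "rebuild and rerun the benchmark" (hypothetical).
CULPRIT = "triton_poi_fused__native_batch_norm_legit_no_training_relu_2.ttir"
names = [
    "triton_poi_fused_add_25.ttir",
    "triton_poi_fused_clone_16.ttir",
    CULPRIT,
    "triton_poi_fused_mul_0.ttir",
    "triton_poi_fused_stack_3.ttir",
]
found = find_culprit(names, lambda enabled: CULPRIT in enabled)
print(found)
```

With N kernels this needs about log2(N) benchmark runs instead of N, but it only works if a single kernel is at fault; with several interacting kernels the one-by-one approach above is safer.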
@whitneywhtsang @chengjunlu as per the message above, I don't see a way other than just disabling fp_fusion for this particular benchmark. Thoughts? |
I am ok to disable it. |
Decided to check IGC code generation... @whitneywhtsang Is it expected to have ![]() |
We need PyTorch to agree to either change the reference implementation or disable fp_fusion for this particular benchmark. |
It is legal to contract to mad instructions; let's discuss this with the IGC team tomorrow. |
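For context on why contraction can change results at all: a fused multiply-add rounds once, while a separate multiply and add round twice. A minimal sketch that emulates float32 rounding with the standard library only; the helper names and the input values are made up for illustration and are not taken from the kernel:

```python
import struct
from fractions import Fraction


def round_f32(x) -> float:
    """Round to float32 (going through a Python float) and return a float.

    The intermediate conversion to double is exact for the inputs below.
    """
    return struct.unpack("<f", struct.pack("<f", float(x)))[0]


def mul_add_separate(a: float, b: float, c: float) -> float:
    """a * b + c with a float32 rounding after the multiply and after the add."""
    product = round_f32(Fraction(a) * Fraction(b))
    return round_f32(Fraction(product) + Fraction(c))


def mul_add_fused(a: float, b: float, c: float) -> float:
    """a * b + c computed exactly, then rounded to float32 once (like fma/mad)."""
    return round_f32(Fraction(a) * Fraction(b) + Fraction(c))


# Illustrative float32 inputs: the exact product 1 + 2**-11 + 2**-24 lies
# halfway between two float32 values, so the separate multiply drops 2**-24.
a = b = 1.0 + 2.0 ** -12
c = -(1.0 + 2.0 ** -11)

separate = mul_add_separate(a, b, c)
fused = mul_add_fused(a, b, c)
print(separate, fused)
```

Both results are valid float32 roundings of the same expression; they just round at different points, which is why a reference computed without contraction can disagree with a contracted kernel.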
NOTE2: The changes after applying
(Pdb) in_out_ptr2[(0, 20, 17, 246)] # <- problematic value in float32 with contraction (i.e. like `-ffp-contract=on`)
tensor(-5.5410, device='xpu:0')
(Pdb) in_out_ptr3[(0, 20, 17, 246)] # <- problematic value in float32 without contraction (i.e. like `-ffp-contract=off`)
tensor(-5.5410, device='xpu:0')
(Pdb) in_out_ptr3[(0, 20, 17, 246)].to(torch.float16) # <- problematic value in float16 without contraction (i.e. like `-ffp-contract=off`)
tensor(-5.5430, device='xpu:0', dtype=torch.float16)
(Pdb) in_out_ptr2[(0, 20, 17, 246)].to(torch.float16) # <- problematic value in float16 with contraction (i.e. like `-ffp-contract=on`)
tensor(-5.5391, device='xpu:0', dtype=torch.float16)
(Pdb) in_out_ptr2[(0, 20, 17, 246)].to(torch.float16) - in_out_ptr3[(0, 20, 17, 246)].to(torch.float16)
tensor(0.0039, device='xpu:0', dtype=torch.float16)
@chuanqi129 @mengfei25 @whitneywhtsang I think that's enough evidence and, considering that we have no choice, we can merge it right now.
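The 0.0039 above is exactly one float16 ULP: for magnitudes in [4, 8) float16 values are spaced 2**-8 = 0.00390625 apart, which also matches the "Greatest absolute difference" reported earlier. A minimal sketch of the effect using only the standard library; `a` and `b` are hypothetical float32 values picked to sit on either side of a float16 rounding midpoint, not the actual kernel outputs:

```python
import struct


def to_f16(x: float) -> float:
    """Round to the nearest IEEE half-precision value, returned as a float."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


lo_f16 = -5.54296875              # -1419 / 256, exactly representable in fp16
hi_f16 = -5.5390625               # -1418 / 256, the neighbouring fp16 value
midpoint = (lo_f16 + hi_f16) / 2  # -5.541015625, the fp16 rounding boundary

f32_ulp = 2.0 ** -21              # float32 spacing for magnitudes in [4, 8)
a = midpoint + f32_ulp            # hypothetical fp32 result just above the boundary
b = midpoint - f32_ulp            # hypothetical fp32 result just below the boundary

print(a - b)                      # fp32 disagreement, below 1e-6
print(to_f16(a), to_f16(b))       # the two values land on different fp16 neighbours
print(to_f16(a) - to_f16(b))      # fp16 disagreement: one full fp16 ULP
```

So a float32 difference well under 1e-6 can become a float16 difference of 0.0039 after `fptrunc`, and relative to a value around 5.54 that is about 7e-4, above the 0.0005 tolerance.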
Reproducer
import triton
import triton.language as tl
import torch
import copy
import os
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
'''
@triton_heuristics.pointwise(
size_hints={'x': 134217728},
filename=__file__,
triton_meta={'signature': {'in_out_ptr0': '*fp16', 'in_ptr0': '*fp16', 'in_ptr1': '*fp16', 'in_ptr2': '*fp16', 'in_ptr3': '*fp16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='xpu', index=0, multi_processor_count=56, cc={'architecture': 13136561920, 'device_id': 3034, 'driver_version': '1.6.33276+22', 'gpu_eu_count': 448, 'gpu_subslice_count': 56, 'has_atomic64': True, 'has_bfloat16_conversions': True, 'has_fp16': True, 'has_fp64': True, 'has_subgroup_2d_block_io': True, 'has_subgroup_matrix_multiply_accumulate': True, 'has_subgroup_matrix_multiply_accumulate_tensor_float32': False, 'max_compute_units': 448, 'max_num_sub_groups': 64, 'max_work_group_size': 1024, 'name': 'Intel(R) Data Center GPU Max 1100', 'platform_name': 'Intel(R) oneAPI Unified Runtime over Level-Zero', 'sub_group_sizes': [16, 32], 'total_memory': 51522830336, 'type': 'gpu', 'vendor': 'Intel(R) Corporation', 'version': '12.60.7'}, major=None, regs_per_multiprocessor=None, max_threads_per_multi_processor=None, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]},
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__native_batch_norm_legit_no_training_relu_2', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 5, 'num_reduction': 0, 'backend_hash': 'BF1FE552CB8E524B1A91FA16118F1D335628999A2F77D64AB02D0408920F518D', 'are_deterministic_algorithms_enabled': True, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 552862208}},
min_elem_per_thread=0
)
'''
@triton.jit
def triton_poi_fused__native_batch_norm_legit_no_training_relu_2(in_out_ptr0, in_ptr2, in_out_ptr2, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    x2 = xindex
    x0 = (xindex % 64)
    tmp0 = tl.load(in_out_ptr0 + x2).to(tl.float32)
    tmp15 = tl.load(in_ptr2 + x0).to(tl.float32)
    tmp14 = tmp0 * 0.021739
    tmp17 = tmp14 * tmp15
    tmp20 = tmp17 + 0.1
    tl.store(in_out_ptr2 + x2, tmp20)
@triton.jit
def triton_poi_fused__native_batch_norm_legit_no_training_relu_2_no_contract(in_out_ptr0, in_ptr2, in_out_ptr2, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    x2 = xindex
    x0 = (xindex % 64)
    tmp0 = tl.load(in_out_ptr0 + x2).to(tl.float32)
    tmp15 = tl.load(in_ptr2 + x0).to(tl.float32)
    tmp14 = tmp0 * 0.021739
    tmp17 = tmp14 * tmp15
    tmp20 = tmp17 + 0.1
    tl.store(in_out_ptr2 + x2, tmp20)
def main():
    torch.manual_seed(42)
    #in_out_ptr0 = torch.rand((4, 64, 592, 608), device='xpu:0', dtype=torch.float16)
    in_out_ptr0 = torch.load(
        '.scripts_cache/pytorch/triton_poi_fused__native_batch_norm_legit_no_training_relu_2_tensor_reduced.pt',
        map_location=torch.device("xpu:0"))
    in_out_ptr1 = copy.deepcopy(in_out_ptr0)
    # if we write to float32 buffers instead of float16 buffers (i.e. avoid the fp trunc from fp32 to fp16)
    # the accuracy check starts to pass
    in_out_ptr2 = copy.deepcopy(in_out_ptr0).to(torch.float32)
    in_out_ptr3 = copy.deepcopy(in_out_ptr0).to(torch.float32)
    breakpoint()
    in_ptr2 = torch.tensor([1.0918, 1.1143, 0.8721, 1.3779, 1.1670, 1.1963, 0.9111, 1.5654, 2.5527,
                            0.8813, 1.1953, 0.8633, 0.9434, 2.4863, 1.0371, 1.2842, 1.4512, 1.3008,
                            1.0088, 1.3760, 1.0635, 1.1025, 2.6348, 1.0762, 1.0830, 1.4102, 1.9824,
                            0.7451, 0.9307, 2.0840, 1.4150, 0.8799, 0.8721, 0.5127, 1.2891, 0.9023,
                            0.8804, 1.8643, 1.3809, 1.0107, 1.3945, 1.0410, 1.2939, 0.5283, 0.9634,
                            2.6680, 1.1680, 1.1094, 0.8320, 1.0039, 1.4688, 1.5918, 0.5527, 1.2539,
                            1.5557, 1.0693, 0.8511, 0.9292, 1.4111, 0.9453, 0.8613, 0.9160, 1.0430,
                            1.0137], device='xpu:0', dtype=torch.float16)
    xnumel = 92143616
    XBLOCK = 1024
    # 89984 -> 22496
    triton_poi_fused__native_batch_norm_legit_no_training_relu_2[(22496, 1, 1)](
        in_out_ptr0, in_ptr2, in_out_ptr2, XBLOCK)
    os.environ['TRITON_DEFAULT_FP_FUSION'] = "0"
    triton_poi_fused__native_batch_norm_legit_no_training_relu_2_no_contract[(22496, 1, 1)](
        in_out_ptr1, in_ptr2, in_out_ptr3, XBLOCK)
    try:
        #torch.testing.assert_close(in_out_ptr0, in_out_ptr1, atol=1e-4, rtol=1e-4)
        torch.testing.assert_close(in_out_ptr2, in_out_ptr3, atol=1e-6, rtol=1e-6)
        torch.testing.assert_close(in_out_ptr2.to(torch.float16), in_out_ptr3.to(torch.float16), atol=1e-6, rtol=1e-6) # <- accuracy issue again
        breakpoint()
    except Exception as err:
        breakpoint()
        raise
if __name__ == "__main__":
    main() |
This reverts commit 38a1984.
Known accuracy impact on the following models: detectron2 and doctr_reco_predictor (from #4412, on PVC) and LayoutLMForSequenceClassification (from #4509, on ARL).
We can't wait for the fp64 patch to be removed on the torch-xpu-ops side because of #4514 (comment). I think it will take at least a month before the next runs of E2E models for PyTorch; I will try to find time to help prepare their references for this.
CI:
detectron2_fasterrcnn_r_50_fpn
fails accuracy check)