Skip to content

Revert "[DEBUG] Revert "Enable SPV_INTEL_fp_fast_math_mode (#4058)"(#4473)" #4576

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jul 22, 2025

Conversation

anmyachev
Copy link
Contributor

@anmyachev anmyachev commented Jun 26, 2025

This reverts commit 38a1984.

Known cases of impact on accuracy of the following models: detectron2 and doctr_reco_predictor from #4412 on PVC and LayoutLMForSequenceClassification from #4509 on ARL

We can't wait for the fp64 patch to be removed on torch-xpu-ops side because #4514 (comment). I think it will take at least a month before the next launches of E2E models for PyTorch, I will try to have time to help prepare their references for this.

CI:

@anmyachev anmyachev linked an issue Jun 26, 2025 that may be closed by this pull request
@anmyachev anmyachev marked this pull request as ready for review June 26, 2025 08:11
@anmyachev anmyachev requested a review from chengjunlu June 26, 2025 08:49
@whitneywhtsang
Copy link
Contributor

whitneywhtsang commented Jun 26, 2025

By relanding this change to main, is it going to make our inductor or E2E CI fail?
If it does, wonder if we should have patches to increase tolerance until their references are changed.

@etiotto
Copy link
Contributor

etiotto commented Jun 26, 2025

By relanding this change to main, is it going to make our inductor or E2E CI fail?

Right I thought we didn't want to set the fp fast math mode until inductor changes the precision check.

@etiotto
Copy link
Contributor

etiotto commented Jul 2, 2025

@anmyachev have you already reported to the PyTorch team the problem ? The benchmark verification code would need to be updated to allow Triton to set the fp fast math flag.

@anmyachev
Copy link
Contributor Author

@anmyachev have you already reported to the PyTorch team the problem ? The benchmark verification code would need to be updated to allow Triton to set the fp fast math flag.

I am in touch with @chuanqi129 about this topic. For now they are busy with pytorch 2.8 release testing (the status from last Friday).

@anmyachev
Copy link
Contributor Author

anmyachev commented Jul 16, 2025

Few more updates/findings:

image

Refs:

All tensors for comparison:
pytorch_new_results.txt
pytorch_correct_results.txt

@anmyachev
Copy link
Contributor Author

anmyachev commented Jul 16, 2025

Hi @etaf,

mismatch of lines 19, 20 (from screenshot above) is also present on the main branch. I also checked 3.2.x and 3.3.x Triton releases, the order of the lines is the same on them (402.5... and then 508.0...). It turns out that this is not a regression in Triton, so there are two options - a very long-lived bug on Triton side, or a bug in PyTorch code that generates the reference for comparison. Could you check whether the correct order is obtained on PyTorch side (508.0..., and then 402.5...)?

For the reference: I run the model this way: TORCH_COMPILE_DEBUG=1 python benchmarks/dynamo/torchbench.py --accuracy --float16 -d xpu -n10 --inference --only detectron2_fasterrcnn_r_50_fpn --backend=inductor --cold-start-latency > ../../detectron2_fasterrcnn_r_50_fpn_output.txt 2>&1 and then I analyze the logs.

UPD: intel/torch-xpu-ops#1855 can relate to this problem, need to recheck after this problem is fixed

@etaf
Copy link

etaf commented Jul 17, 2025

Hi @etaf,

mismatch of lines 19, 20 (from screenshot above) is also present on the main branch. I also checked 3.2.x and 3.3.x Triton releases, the order of the lines is the same on them (402.5... and then 508.0...). It turns out that this is not a regression in Triton, so there are two options - a very long-lived bug on Triton side, or a bug in PyTorch code that generates the reference for comparison. Could you check whether the correct order is obtained on PyTorch side (508.0..., and then 402.5...)?

For the reference: I run the model this way: TORCH_COMPILE_DEBUG=1 python benchmarks/dynamo/torchbench.py --accuracy --float16 -d xpu -n10 --inference --only detectron2_fasterrcnn_r_50_fpn --backend=inductor --cold-start-latency > ../../detectron2_fasterrcnn_r_50_fpn_output.txt 2>&1 and then I analyze the logs.

UPD: intel/torch-xpu-ops#1855 can relate to this problem, need to recheck after this problem is fixed

Hi, @anmyachev, sorry, I'm not familiar with this issue, so I didn't respond immediately. I need some time to understand the background and details before getting back to you.

@etaf
Copy link

etaf commented Jul 17, 2025

Hi, @anmyachev , the model failed after we switched to conv channel last mode, so you can verify it before that, you may use this pytorch commit 0504480f37714a289b2ba32c9cf32a5e50e86d38 which do not contain the channel last change.

@anmyachev
Copy link
Contributor Author

0504480f37714a289b2ba32c9cf32a5e50e86d38

Got it, thanks @etaf

!fastMath.has_value()) {
// Default to allow contract when default fp fusion is not disabled.
if ((!enableFpFusion.has_value() || enableFpFusion.value()) &&
!fastMath.has_value() && inst.hasNoNaNs()) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@whitneywhtsang with hasNoNaNs change detectron2_fasterrcnn_r_50_fpn passes accuracy check locally. This change is inspired by

opt.NoNaNsFPMath = true;

Could you take a look?

Here is

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the performance regression recovered?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From the above BMG run result, the geomean for ATTN D_HEAD=128 CAUSAL=1 is 60TFlops, which is the bad performance.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From the above BMG run result, the geomean for ATTN D_HEAD=128 CAUSAL=1 is 60TFlops, which is the bad performance.

Hm, how many teraflops do you expect to see? I don't remember

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Expect to see above 80TFLops, #4514 (comment)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll look into why that is. However, I have doubts that adding contract to all operations regardless of their safety in terms of accuracy is what is expected from the default behavior when the fast-math option is not enabled. I suspect that when AllowFPOpFusion option is enabled, some analysis is performed that adds it only where it is safe. Maybe you've already tested this and know exactly how it works?

@anmyachev anmyachev marked this pull request as draft July 19, 2025 13:55
@anmyachev
Copy link
Contributor Author

anmyachev commented Jul 20, 2025

Managed to find the problem kernel. The only difference is in the use of contract at .llir level.

Correct kernel
; ModuleID = 'LLVMDialectModule'
source_filename = "LLVMDialectModule"
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
target triple = "spir64-unknown-unknown"

; Function Attrs: mustprogress nofree nosync nounwind willreturn memory(none)
declare !dbg !9 spir_func i64 @_Z12get_local_idj(i32) local_unnamed_addr #0

; Function Attrs: mustprogress nofree nosync nounwind willreturn memory(none)
declare !dbg !12 spir_func i64 @_Z12get_group_idj(i32) local_unnamed_addr #0

; Function Attrs: mustprogress nofree nounwind willreturn memory(argmem: readwrite)
define spir_kernel void @triton_poi_fused__native_batch_norm_legit_no_training_relu_2(ptr addrspace(1) captures(none) %0, ptr addrspace(1) readonly captures(none) %1, ptr addrspace(1) readonly captures(none) %2, ptr addrspace(1) readonly captures(none) %3, ptr addrspace(1) readonly captures(none) %4, i32 %5, ptr addrspace(1) readnone captures(none) %6) local_unnamed_addr #1 !dbg !13 !intel_reqd_sub_group_size !14 !max_work_group_size !15 {
  %8 = tail call spir_func i64 @_Z12get_group_idj(i32 0) #3, !dbg !16
  %9 = trunc i64 %8 to i32, !dbg !16
  %10 = shl i32 %9, 10, !dbg !17
  %11 = tail call spir_func i64 @_Z12get_local_idj(i32 0) #3, !dbg !18
  %12 = trunc i64 %11 to i32, !dbg !18
  %13 = shl i32 %12, 3, !dbg !18
  %14 = and i32 %13, 1016, !dbg !18
  %15 = or disjoint i32 %14, %10, !dbg !19
  %16 = srem i32 %15, 64, !dbg !20
  %17 = sext i32 %15 to i64, !dbg !21
  %18 = getelementptr half, ptr addrspace(1) %0, i64 %17, !dbg !21
  %19 = getelementptr inbounds nuw i8, ptr addrspace(1) %18, i64 4, !dbg !22
  %20 = getelementptr inbounds nuw i8, ptr addrspace(1) %18, i64 8, !dbg !22
  %21 = getelementptr inbounds nuw i8, ptr addrspace(1) %18, i64 12, !dbg !22
  %22 = load <2 x half>, ptr addrspace(1) %21, align 4, !dbg !22
  %23 = load half, ptr addrspace(1) %18, align 16, !dbg !22
  %24 = getelementptr inbounds nuw i8, ptr addrspace(1) %18, i64 2, !dbg !22
  %25 = load half, ptr addrspace(1) %24, align 2, !dbg !22
  %26 = load half, ptr addrspace(1) %19, align 4, !dbg !22
  %27 = getelementptr inbounds nuw i8, ptr addrspace(1) %18, i64 6, !dbg !22
  %28 = load half, ptr addrspace(1) %27, align 2, !dbg !22
  %29 = load half, ptr addrspace(1) %20, align 8, !dbg !22
  %30 = getelementptr inbounds nuw i8, ptr addrspace(1) %18, i64 10, !dbg !22
  %31 = load half, ptr addrspace(1) %30, align 2, !dbg !22
  %32 = extractelement <2 x half> %22, i64 0, !dbg !22
  %33 = extractelement <2 x half> %22, i64 1, !dbg !22
  %34 = fpext half %23 to float, !dbg !23
  %35 = fpext half %25 to float, !dbg !23
  %36 = fpext half %26 to float, !dbg !23
  %37 = fpext half %28 to float, !dbg !23
  %38 = fpext half %29 to float, !dbg !23
  %39 = fpext half %31 to float, !dbg !23
  %40 = fpext half %32 to float, !dbg !23
  %41 = fpext half %33 to float, !dbg !23
  %42 = sext i32 %16 to i64, !dbg !24
  %43 = getelementptr half, ptr addrspace(1) %1, i64 %42, !dbg !24
  %44 = getelementptr inbounds nuw i8, ptr addrspace(1) %43, i64 4, !dbg !25
  %45 = getelementptr inbounds nuw i8, ptr addrspace(1) %43, i64 8, !dbg !25
  %46 = getelementptr inbounds nuw i8, ptr addrspace(1) %43, i64 12, !dbg !25
  %47 = load <2 x half>, ptr addrspace(1) %46, align 4, !dbg !25
  %48 = load half, ptr addrspace(1) %43, align 16, !dbg !25
  %49 = getelementptr inbounds nuw i8, ptr addrspace(1) %43, i64 2, !dbg !25
  %50 = load half, ptr addrspace(1) %49, align 2, !dbg !25
  %51 = load half, ptr addrspace(1) %44, align 4, !dbg !25
  %52 = getelementptr inbounds nuw i8, ptr addrspace(1) %43, i64 6, !dbg !25
  %53 = load half, ptr addrspace(1) %52, align 2, !dbg !25
  %54 = load half, ptr addrspace(1) %45, align 8, !dbg !25
  %55 = getelementptr inbounds nuw i8, ptr addrspace(1) %43, i64 10, !dbg !25
  %56 = load half, ptr addrspace(1) %55, align 2, !dbg !25
  %57 = extractelement <2 x half> %47, i64 0, !dbg !25
  %58 = extractelement <2 x half> %47, i64 1, !dbg !25
  %59 = fpext half %48 to float, !dbg !26
  %60 = fpext half %50 to float, !dbg !26
  %61 = fpext half %51 to float, !dbg !26
  %62 = fpext half %53 to float, !dbg !26
  %63 = fpext half %54 to float, !dbg !26
  %64 = fpext half %56 to float, !dbg !26
  %65 = fpext half %57 to float, !dbg !26
  %66 = fpext half %58 to float, !dbg !26
  %67 = getelementptr half, ptr addrspace(1) %2, i64 %42, !dbg !27
  %68 = getelementptr inbounds nuw i8, ptr addrspace(1) %67, i64 4, !dbg !28
  %69 = getelementptr inbounds nuw i8, ptr addrspace(1) %67, i64 8, !dbg !28
  %70 = getelementptr inbounds nuw i8, ptr addrspace(1) %67, i64 12, !dbg !28
  %71 = load <2 x half>, ptr addrspace(1) %70, align 4, !dbg !28
  %72 = load half, ptr addrspace(1) %67, align 16, !dbg !28
  %73 = getelementptr inbounds nuw i8, ptr addrspace(1) %67, i64 2, !dbg !28
  %74 = load half, ptr addrspace(1) %73, align 2, !dbg !28
  %75 = load half, ptr addrspace(1) %68, align 4, !dbg !28
  %76 = getelementptr inbounds nuw i8, ptr addrspace(1) %67, i64 6, !dbg !28
  %77 = load half, ptr addrspace(1) %76, align 2, !dbg !28
  %78 = load half, ptr addrspace(1) %69, align 8, !dbg !28
  %79 = getelementptr inbounds nuw i8, ptr addrspace(1) %67, i64 10, !dbg !28
  %80 = load half, ptr addrspace(1) %79, align 2, !dbg !28
  %81 = extractelement <2 x half> %71, i64 0, !dbg !28
  %82 = extractelement <2 x half> %71, i64 1, !dbg !28
  %83 = fpext half %72 to float, !dbg !29
  %84 = fpext half %74 to float, !dbg !29
  %85 = fpext half %75 to float, !dbg !29
  %86 = fpext half %77 to float, !dbg !29
  %87 = fpext half %78 to float, !dbg !29
  %88 = fpext half %80 to float, !dbg !29
  %89 = fpext half %81 to float, !dbg !29
  %90 = fpext half %82 to float, !dbg !29
  %91 = getelementptr half, ptr addrspace(1) %3, i64 %42, !dbg !30
  %92 = getelementptr inbounds nuw i8, ptr addrspace(1) %91, i64 4, !dbg !31
  %93 = getelementptr inbounds nuw i8, ptr addrspace(1) %91, i64 8, !dbg !31
  %94 = getelementptr inbounds nuw i8, ptr addrspace(1) %91, i64 12, !dbg !31
  %95 = load <2 x half>, ptr addrspace(1) %94, align 4, !dbg !31
  %96 = load half, ptr addrspace(1) %91, align 16, !dbg !31
  %97 = getelementptr inbounds nuw i8, ptr addrspace(1) %91, i64 2, !dbg !31
  %98 = load half, ptr addrspace(1) %97, align 2, !dbg !31
  %99 = load half, ptr addrspace(1) %92, align 4, !dbg !31
  %100 = getelementptr inbounds nuw i8, ptr addrspace(1) %91, i64 6, !dbg !31
  %101 = load half, ptr addrspace(1) %100, align 2, !dbg !31
  %102 = load half, ptr addrspace(1) %93, align 8, !dbg !31
  %103 = getelementptr inbounds nuw i8, ptr addrspace(1) %91, i64 10, !dbg !31
  %104 = load half, ptr addrspace(1) %103, align 2, !dbg !31
  %105 = extractelement <2 x half> %95, i64 0, !dbg !31
  %106 = extractelement <2 x half> %95, i64 1, !dbg !31
  %107 = fpext half %96 to float, !dbg !32
  %108 = fpext half %98 to float, !dbg !32
  %109 = fpext half %99 to float, !dbg !32
  %110 = fpext half %101 to float, !dbg !32
  %111 = fpext half %102 to float, !dbg !32
  %112 = fpext half %104 to float, !dbg !32
  %113 = fpext half %105 to float, !dbg !32
  %114 = fpext half %106 to float, !dbg !32
  %115 = getelementptr half, ptr addrspace(1) %4, i64 %42, !dbg !33
  %116 = getelementptr inbounds nuw i8, ptr addrspace(1) %115, i64 4, !dbg !34
  %117 = getelementptr inbounds nuw i8, ptr addrspace(1) %115, i64 8, !dbg !34
  %118 = getelementptr inbounds nuw i8, ptr addrspace(1) %115, i64 12, !dbg !34
  %119 = load <2 x half>, ptr addrspace(1) %118, align 4, !dbg !34
  %120 = load half, ptr addrspace(1) %115, align 16, !dbg !34
  %121 = getelementptr inbounds nuw i8, ptr addrspace(1) %115, i64 2, !dbg !34
  %122 = load half, ptr addrspace(1) %121, align 2, !dbg !34
  %123 = load half, ptr addrspace(1) %116, align 4, !dbg !34
  %124 = getelementptr inbounds nuw i8, ptr addrspace(1) %115, i64 6, !dbg !34
  %125 = load half, ptr addrspace(1) %124, align 2, !dbg !34
  %126 = load half, ptr addrspace(1) %117, align 8, !dbg !34
  %127 = getelementptr inbounds nuw i8, ptr addrspace(1) %115, i64 10, !dbg !34
  %128 = load half, ptr addrspace(1) %127, align 2, !dbg !34
  %129 = extractelement <2 x half> %119, i64 0, !dbg !34
  %130 = extractelement <2 x half> %119, i64 1, !dbg !34
  %131 = fpext half %120 to float, !dbg !35
  %132 = fpext half %122 to float, !dbg !35
  %133 = fpext half %123 to float, !dbg !35
  %134 = fpext half %125 to float, !dbg !35
  %135 = fpext half %126 to float, !dbg !35
  %136 = fpext half %128 to float, !dbg !35
  %137 = fpext half %129 to float, !dbg !35
  %138 = fpext half %130 to float, !dbg !35
  %139 = fsub float %34, %59, !dbg !36
  %140 = fsub float %35, %60, !dbg !36
  %141 = fsub float %36, %61, !dbg !36
  %142 = fsub float %37, %62, !dbg !36
  %143 = fsub float %38, %63, !dbg !36
  %144 = fsub float %39, %64, !dbg !36
  %145 = fsub float %40, %65, !dbg !36
  %146 = fsub float %41, %66, !dbg !36
  %147 = fadd float %83, 0x3EE4F8B580000000, !dbg !37
  %148 = fadd float %84, 0x3EE4F8B580000000, !dbg !37
  %149 = fadd float %85, 0x3EE4F8B580000000, !dbg !37
  %150 = fadd float %86, 0x3EE4F8B580000000, !dbg !37
  %151 = fadd float %87, 0x3EE4F8B580000000, !dbg !37
  %152 = fadd float %88, 0x3EE4F8B580000000, !dbg !37
  %153 = fadd float %89, 0x3EE4F8B580000000, !dbg !37
  %154 = fadd float %90, 0x3EE4F8B580000000, !dbg !37
  %155 = tail call spir_func noundef float @_Z16__spirv_ocl_sqrtf(float noundef %147) #4, !dbg !38
  %156 = tail call spir_func noundef float @_Z16__spirv_ocl_sqrtf(float noundef %148) #4, !dbg !38
  %157 = tail call spir_func noundef float @_Z16__spirv_ocl_sqrtf(float noundef %149) #4, !dbg !38
  %158 = tail call spir_func noundef float @_Z16__spirv_ocl_sqrtf(float noundef %150) #4, !dbg !38
  %159 = tail call spir_func noundef float @_Z16__spirv_ocl_sqrtf(float noundef %151) #4, !dbg !38
  %160 = tail call spir_func noundef float @_Z16__spirv_ocl_sqrtf(float noundef %152) #4, !dbg !38
  %161 = tail call spir_func noundef float @_Z16__spirv_ocl_sqrtf(float noundef %153) #4, !dbg !38
  %162 = tail call spir_func noundef float @_Z16__spirv_ocl_sqrtf(float noundef %154) #4, !dbg !38
  %163 = fdiv float 1.000000e+00, %155, !dbg !39
  %164 = fdiv float 1.000000e+00, %156, !dbg !39
  %165 = fdiv float 1.000000e+00, %157, !dbg !39
  %166 = fdiv float 1.000000e+00, %158, !dbg !39
  %167 = fdiv float 1.000000e+00, %159, !dbg !39
  %168 = fdiv float 1.000000e+00, %160, !dbg !39
  %169 = fdiv float 1.000000e+00, %161, !dbg !39
  %170 = fdiv float 1.000000e+00, %162, !dbg !39
  %171 = fmul float %139, %163, !dbg !40
  %172 = fmul float %140, %164, !dbg !40
  %173 = fmul float %141, %165, !dbg !40
  %174 = fmul float %142, %166, !dbg !40
  %175 = fmul float %143, %167, !dbg !40
  %176 = fmul float %144, %168, !dbg !40
  %177 = fmul float %145, %169, !dbg !40
  %178 = fmul float %146, %170, !dbg !40
  %179 = fmul float %171, %107, !dbg !41
  %180 = fmul float %172, %108, !dbg !41
  %181 = fmul float %173, %109, !dbg !41
  %182 = fmul float %174, %110, !dbg !41
  %183 = fmul float %175, %111, !dbg !41
  %184 = fmul float %176, %112, !dbg !41
  %185 = fmul float %177, %113, !dbg !41
  %186 = fmul float %178, %114, !dbg !41
  %187 = fadd float %179, %131, !dbg !42
  %188 = fadd float %180, %132, !dbg !42
  %189 = fadd float %181, %133, !dbg !42
  %190 = fadd float %182, %134, !dbg !42
  %191 = fadd float %183, %135, !dbg !42
  %192 = fadd float %184, %136, !dbg !42
  %193 = fadd float %185, %137, !dbg !42
  %194 = fadd float %186, %138, !dbg !42
  %195 = fcmp olt float %187, 0.000000e+00, !dbg !43
  %196 = fcmp olt float %188, 0.000000e+00, !dbg !43
  %197 = fcmp olt float %189, 0.000000e+00, !dbg !43
  %198 = fcmp olt float %190, 0.000000e+00, !dbg !43
  %199 = fcmp olt float %191, 0.000000e+00, !dbg !43
  %200 = fcmp olt float %192, 0.000000e+00, !dbg !43
  %201 = fcmp olt float %193, 0.000000e+00, !dbg !43
  %202 = fcmp olt float %194, 0.000000e+00, !dbg !43
  %203 = select i1 %195, float 0.000000e+00, float %187, !dbg !47
  %204 = select i1 %196, float 0.000000e+00, float %188, !dbg !47
  %205 = select i1 %197, float 0.000000e+00, float %189, !dbg !47
  %206 = select i1 %198, float 0.000000e+00, float %190, !dbg !47
  %207 = select i1 %199, float 0.000000e+00, float %191, !dbg !47
  %208 = select i1 %200, float 0.000000e+00, float %192, !dbg !47
  %209 = select i1 %201, float 0.000000e+00, float %193, !dbg !47
  %210 = select i1 %202, float 0.000000e+00, float %194, !dbg !47
  %211 = fptrunc float %203 to half, !dbg !48
  %212 = fptrunc float %204 to half, !dbg !48
  %213 = fptrunc float %205 to half, !dbg !48
  %214 = fptrunc float %206 to half, !dbg !48
  %215 = fptrunc float %207 to half, !dbg !48
  %216 = fptrunc float %208 to half, !dbg !48
  %217 = fptrunc float %209 to half, !dbg !48
  %218 = fptrunc float %210 to half, !dbg !48
  %219 = insertelement <2 x half> poison, half %211, i64 0, !dbg !48
  %220 = insertelement <2 x half> %219, half %212, i64 1, !dbg !48
  %221 = bitcast <2 x half> %220 to i32, !dbg !48
  %222 = insertelement <2 x half> poison, half %213, i64 0, !dbg !48
  %223 = insertelement <2 x half> %222, half %214, i64 1, !dbg !48
  %224 = bitcast <2 x half> %223 to i32, !dbg !48
  %225 = insertelement <2 x half> poison, half %215, i64 0, !dbg !48
  %226 = insertelement <2 x half> %225, half %216, i64 1, !dbg !48
  %227 = bitcast <2 x half> %226 to i32, !dbg !48
  %228 = insertelement <2 x half> poison, half %217, i64 0, !dbg !48
  %229 = insertelement <2 x half> %228, half %218, i64 1, !dbg !48
  %230 = bitcast <2 x half> %229 to i32, !dbg !48
  %231 = insertelement <4 x i32> poison, i32 %221, i64 0, !dbg !48
  %232 = insertelement <4 x i32> %231, i32 %224, i64 1, !dbg !48
  %233 = insertelement <4 x i32> %232, i32 %227, i64 2, !dbg !48
  %234 = insertelement <4 x i32> %233, i32 %230, i64 3, !dbg !48
  store <4 x i32> %234, ptr addrspace(1) %18, align 16, !dbg !48
  ret void, !dbg !49
}

; Function Attrs: convergent mustprogress nofree nounwind willreturn memory(none)
declare dso_local spir_func noundef float @_Z16__spirv_ocl_sqrtf(float noundef) local_unnamed_addr #2

attributes #0 = { mustprogress nofree nosync nounwind willreturn memory(none) }
attributes #1 = { mustprogress nofree nounwind willreturn memory(argmem: readwrite) }
attributes #2 = { convergent mustprogress nofree nounwind willreturn memory(none) "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
attributes #3 = { nounwind willreturn memory(none) }
attributes #4 = { convergent nounwind willreturn memory(none) }

!llvm.dbg.cu = !{!0}
!llvm.module.flags = !{!2, !3, !4, !5}
!opencl.spir.version = !{!6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6}
!spirv.Source = !{!7, !7, !7, !7, !7, !7, !7, !7, !7, !7, !7, !7, !7, !7, !7, !7, !7, !7, !7, !7}
!llvm.ident = !{!8, !8, !8, !8, !8, !8, !8, !8, !8, !8, !8, !8, !8, !8, !8, !8, !8, !8, !8, !8}

!0 = distinct !DICompileUnit(language: DW_LANG_C, file: !1, producer: "triton", isOptimized: true, runtimeVersion: 0, emissionKind: LineTablesOnly)
!1 = !DIFile(filename: "ci4ch6recibbahxfnjzez75dv5ritrfki3463z37medubgfvzt3z.py", directory: "/home/jovyan/intel-xpu-backend-for-triton/torchinductor_cache3/i4")
!2 = !{i32 2, !"Debug Info Version", i32 3}
!3 = !{i32 1, !"wchar_size", i32 4}
!4 = !{i32 1, !"sycl-device", i32 1}
!5 = !{i32 7, !"frame-pointer", i32 2}
!6 = !{i32 1, i32 2}
!7 = !{i32 3, i32 100000}
!8 = !{!"Intel(R) oneAPI DPC++/C++ Compiler 2025.0.0 (2025.0.0.20241008)"}
!9 = !DISubprogram(name: "_Z12get_local_idj", linkageName: "_Z12get_local_idj", scope: !1, file: !1, line: 18, type: !10, scopeLine: 18, spFlags: DISPFlagOptimized)
!10 = !DISubroutineType(cc: DW_CC_normal, types: !11)
!11 = !{}
!12 = !DISubprogram(name: "_Z12get_group_idj", linkageName: "_Z12get_group_idj", scope: !1, file: !1, line: 18, type: !10, scopeLine: 18, spFlags: DISPFlagOptimized)
!13 = distinct !DISubprogram(name: "triton_poi_fused__native_batch_norm_legit_no_training_relu_2", linkageName: "triton_poi_fused__native_batch_norm_legit_no_training_relu_2", scope: !1, file: !1, line: 18, type: !10, scopeLine: 18, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !0)
!14 = !{i32 32}
!15 = !{i64 128, i64 1, i64 1}
!16 = !DILocation(line: 20, column: 28, scope: !13)
!17 = !DILocation(line: 20, column: 33, scope: !13)
!18 = !DILocation(line: 21, column: 36, scope: !13)
!19 = !DILocation(line: 21, column: 23, scope: !13)
!20 = !DILocation(line: 24, column: 19, scope: !13)
!21 = !DILocation(line: 25, column: 34, scope: !13)
!22 = !DILocation(line: 25, column: 39, scope: !13)
!23 = !DILocation(line: 25, column: 48, scope: !13)
!24 = !DILocation(line: 26, column: 30, scope: !13)
!25 = !DILocation(line: 26, column: 35, scope: !13)
!26 = !DILocation(line: 26, column: 74, scope: !13)
!27 = !DILocation(line: 27, column: 30, scope: !13)
!28 = !DILocation(line: 27, column: 35, scope: !13)
!29 = !DILocation(line: 27, column: 74, scope: !13)
!30 = !DILocation(line: 28, column: 31, scope: !13)
!31 = !DILocation(line: 28, column: 36, scope: !13)
!32 = !DILocation(line: 28, column: 75, scope: !13)
!33 = !DILocation(line: 29, column: 31, scope: !13)
!34 = !DILocation(line: 29, column: 36, scope: !13)
!35 = !DILocation(line: 29, column: 75, scope: !13)
!36 = !DILocation(line: 32, column: 18, scope: !13)
!37 = !DILocation(line: 35, column: 18, scope: !13)
!38 = !DILocation(line: 36, column: 26, scope: !13)
!39 = !DILocation(line: 38, column: 21, scope: !13)
!40 = !DILocation(line: 41, column: 19, scope: !13)
!41 = !DILocation(line: 43, column: 20, scope: !13)
!42 = !DILocation(line: 45, column: 20, scope: !13)
!43 = !DILocation(line: 111, column: 15, scope: !44, inlinedAt: !46)
!44 = distinct !DILexicalBlockFile(scope: !13, file: !45, discriminator: 0)
!45 = !DIFile(filename: "triton_helpers.py", directory: "/home/jovyan/intel-xpu-backend-for-triton/.scripts_cache/pytorch/torch/_inductor/runtime")
!46 = !DILocation(line: 48, column: 42, scope: !13)
!47 = !DILocation(line: 114, column: 29, scope: !44, inlinedAt: !46)
!48 = !DILocation(line: 49, column: 40, scope: !13)
!49 = !DILocation(line: 49, column: 4, scope: !13)
Incorrect kernel
; ModuleID = 'LLVMDialectModule'
source_filename = "LLVMDialectModule"
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
target triple = "spir64-unknown-unknown"

; Function Attrs: mustprogress nofree nosync nounwind willreturn memory(none)
declare !dbg !9 spir_func i64 @_Z12get_local_idj(i32) local_unnamed_addr #0

; Function Attrs: mustprogress nofree nosync nounwind willreturn memory(none)
declare !dbg !12 spir_func i64 @_Z12get_group_idj(i32) local_unnamed_addr #0

; Function Attrs: mustprogress nofree nounwind willreturn memory(argmem: readwrite)
define spir_kernel void @triton_poi_fused__native_batch_norm_legit_no_training_relu_2(ptr addrspace(1) captures(none) %0, ptr addrspace(1) readonly captures(none) %1, ptr addrspace(1) readonly captures(none) %2, ptr addrspace(1) readonly captures(none) %3, ptr addrspace(1) readonly captures(none) %4, i32 %5, ptr addrspace(1) readnone captures(none) %6) local_unnamed_addr #1 !dbg !13 !intel_reqd_sub_group_size !14 !max_work_group_size !15 {
  %8 = tail call spir_func i64 @_Z12get_group_idj(i32 0) #3, !dbg !16
  %9 = trunc i64 %8 to i32, !dbg !16
  %10 = shl i32 %9, 10, !dbg !17
  %11 = tail call spir_func i64 @_Z12get_local_idj(i32 0) #3, !dbg !18
  %12 = trunc i64 %11 to i32, !dbg !18
  %13 = shl i32 %12, 3, !dbg !18
  %14 = and i32 %13, 1016, !dbg !18
  %15 = or disjoint i32 %14, %10, !dbg !19
  %16 = srem i32 %15, 64, !dbg !20
  %17 = sext i32 %15 to i64, !dbg !21
  %18 = getelementptr half, ptr addrspace(1) %0, i64 %17, !dbg !21
  %19 = getelementptr inbounds nuw i8, ptr addrspace(1) %18, i64 4, !dbg !22
  %20 = getelementptr inbounds nuw i8, ptr addrspace(1) %18, i64 8, !dbg !22
  %21 = getelementptr inbounds nuw i8, ptr addrspace(1) %18, i64 12, !dbg !22
  %22 = load <2 x half>, ptr addrspace(1) %21, align 4, !dbg !22
  %23 = load half, ptr addrspace(1) %18, align 16, !dbg !22
  %24 = getelementptr inbounds nuw i8, ptr addrspace(1) %18, i64 2, !dbg !22
  %25 = load half, ptr addrspace(1) %24, align 2, !dbg !22
  %26 = load half, ptr addrspace(1) %19, align 4, !dbg !22
  %27 = getelementptr inbounds nuw i8, ptr addrspace(1) %18, i64 6, !dbg !22
  %28 = load half, ptr addrspace(1) %27, align 2, !dbg !22
  %29 = load half, ptr addrspace(1) %20, align 8, !dbg !22
  %30 = getelementptr inbounds nuw i8, ptr addrspace(1) %18, i64 10, !dbg !22
  %31 = load half, ptr addrspace(1) %30, align 2, !dbg !22
  %32 = extractelement <2 x half> %22, i64 0, !dbg !22
  %33 = extractelement <2 x half> %22, i64 1, !dbg !22
  %34 = fpext half %23 to float, !dbg !23
  %35 = fpext half %25 to float, !dbg !23
  %36 = fpext half %26 to float, !dbg !23
  %37 = fpext half %28 to float, !dbg !23
  %38 = fpext half %29 to float, !dbg !23
  %39 = fpext half %31 to float, !dbg !23
  %40 = fpext half %32 to float, !dbg !23
  %41 = fpext half %33 to float, !dbg !23
  %42 = sext i32 %16 to i64, !dbg !24
  %43 = getelementptr half, ptr addrspace(1) %1, i64 %42, !dbg !24
  %44 = getelementptr inbounds nuw i8, ptr addrspace(1) %43, i64 4, !dbg !25
  %45 = getelementptr inbounds nuw i8, ptr addrspace(1) %43, i64 8, !dbg !25
  %46 = getelementptr inbounds nuw i8, ptr addrspace(1) %43, i64 12, !dbg !25
  %47 = load <2 x half>, ptr addrspace(1) %46, align 4, !dbg !25
  %48 = load half, ptr addrspace(1) %43, align 16, !dbg !25
  %49 = getelementptr inbounds nuw i8, ptr addrspace(1) %43, i64 2, !dbg !25
  %50 = load half, ptr addrspace(1) %49, align 2, !dbg !25
  %51 = load half, ptr addrspace(1) %44, align 4, !dbg !25
  %52 = getelementptr inbounds nuw i8, ptr addrspace(1) %43, i64 6, !dbg !25
  %53 = load half, ptr addrspace(1) %52, align 2, !dbg !25
  %54 = load half, ptr addrspace(1) %45, align 8, !dbg !25
  %55 = getelementptr inbounds nuw i8, ptr addrspace(1) %43, i64 10, !dbg !25
  %56 = load half, ptr addrspace(1) %55, align 2, !dbg !25
  %57 = extractelement <2 x half> %47, i64 0, !dbg !25
  %58 = extractelement <2 x half> %47, i64 1, !dbg !25
  %59 = fpext half %48 to float, !dbg !26
  %60 = fpext half %50 to float, !dbg !26
  %61 = fpext half %51 to float, !dbg !26
  %62 = fpext half %53 to float, !dbg !26
  %63 = fpext half %54 to float, !dbg !26
  %64 = fpext half %56 to float, !dbg !26
  %65 = fpext half %57 to float, !dbg !26
  %66 = fpext half %58 to float, !dbg !26
  %67 = getelementptr half, ptr addrspace(1) %2, i64 %42, !dbg !27
  %68 = getelementptr inbounds nuw i8, ptr addrspace(1) %67, i64 4, !dbg !28
  %69 = getelementptr inbounds nuw i8, ptr addrspace(1) %67, i64 8, !dbg !28
  %70 = getelementptr inbounds nuw i8, ptr addrspace(1) %67, i64 12, !dbg !28
  %71 = load <2 x half>, ptr addrspace(1) %70, align 4, !dbg !28
  %72 = load half, ptr addrspace(1) %67, align 16, !dbg !28
  %73 = getelementptr inbounds nuw i8, ptr addrspace(1) %67, i64 2, !dbg !28
  %74 = load half, ptr addrspace(1) %73, align 2, !dbg !28
  %75 = load half, ptr addrspace(1) %68, align 4, !dbg !28
  %76 = getelementptr inbounds nuw i8, ptr addrspace(1) %67, i64 6, !dbg !28
  %77 = load half, ptr addrspace(1) %76, align 2, !dbg !28
  %78 = load half, ptr addrspace(1) %69, align 8, !dbg !28
  %79 = getelementptr inbounds nuw i8, ptr addrspace(1) %67, i64 10, !dbg !28
  %80 = load half, ptr addrspace(1) %79, align 2, !dbg !28
  %81 = extractelement <2 x half> %71, i64 0, !dbg !28
  %82 = extractelement <2 x half> %71, i64 1, !dbg !28
  %83 = fpext half %72 to float, !dbg !29
  %84 = fpext half %74 to float, !dbg !29
  %85 = fpext half %75 to float, !dbg !29
  %86 = fpext half %77 to float, !dbg !29
  %87 = fpext half %78 to float, !dbg !29
  %88 = fpext half %80 to float, !dbg !29
  %89 = fpext half %81 to float, !dbg !29
  %90 = fpext half %82 to float, !dbg !29
  %91 = getelementptr half, ptr addrspace(1) %3, i64 %42, !dbg !30
  %92 = getelementptr inbounds nuw i8, ptr addrspace(1) %91, i64 4, !dbg !31
  %93 = getelementptr inbounds nuw i8, ptr addrspace(1) %91, i64 8, !dbg !31
  %94 = getelementptr inbounds nuw i8, ptr addrspace(1) %91, i64 12, !dbg !31
  %95 = load <2 x half>, ptr addrspace(1) %94, align 4, !dbg !31
  %96 = load half, ptr addrspace(1) %91, align 16, !dbg !31
  %97 = getelementptr inbounds nuw i8, ptr addrspace(1) %91, i64 2, !dbg !31
  %98 = load half, ptr addrspace(1) %97, align 2, !dbg !31
  %99 = load half, ptr addrspace(1) %92, align 4, !dbg !31
  %100 = getelementptr inbounds nuw i8, ptr addrspace(1) %91, i64 6, !dbg !31
  %101 = load half, ptr addrspace(1) %100, align 2, !dbg !31
  %102 = load half, ptr addrspace(1) %93, align 8, !dbg !31
  %103 = getelementptr inbounds nuw i8, ptr addrspace(1) %91, i64 10, !dbg !31
  %104 = load half, ptr addrspace(1) %103, align 2, !dbg !31
  %105 = extractelement <2 x half> %95, i64 0, !dbg !31
  %106 = extractelement <2 x half> %95, i64 1, !dbg !31
  %107 = fpext half %96 to float, !dbg !32
  %108 = fpext half %98 to float, !dbg !32
  %109 = fpext half %99 to float, !dbg !32
  %110 = fpext half %101 to float, !dbg !32
  %111 = fpext half %102 to float, !dbg !32
  %112 = fpext half %104 to float, !dbg !32
  %113 = fpext half %105 to float, !dbg !32
  %114 = fpext half %106 to float, !dbg !32
  %115 = getelementptr half, ptr addrspace(1) %4, i64 %42, !dbg !33
  %116 = getelementptr inbounds nuw i8, ptr addrspace(1) %115, i64 4, !dbg !34
  %117 = getelementptr inbounds nuw i8, ptr addrspace(1) %115, i64 8, !dbg !34
  %118 = getelementptr inbounds nuw i8, ptr addrspace(1) %115, i64 12, !dbg !34
  %119 = load <2 x half>, ptr addrspace(1) %118, align 4, !dbg !34
  %120 = load half, ptr addrspace(1) %115, align 16, !dbg !34
  %121 = getelementptr inbounds nuw i8, ptr addrspace(1) %115, i64 2, !dbg !34
  %122 = load half, ptr addrspace(1) %121, align 2, !dbg !34
  %123 = load half, ptr addrspace(1) %116, align 4, !dbg !34
  %124 = getelementptr inbounds nuw i8, ptr addrspace(1) %115, i64 6, !dbg !34
  %125 = load half, ptr addrspace(1) %124, align 2, !dbg !34
  %126 = load half, ptr addrspace(1) %117, align 8, !dbg !34
  %127 = getelementptr inbounds nuw i8, ptr addrspace(1) %115, i64 10, !dbg !34
  %128 = load half, ptr addrspace(1) %127, align 2, !dbg !34
  %129 = extractelement <2 x half> %119, i64 0, !dbg !34
  %130 = extractelement <2 x half> %119, i64 1, !dbg !34
  %131 = fpext half %120 to float, !dbg !35
  %132 = fpext half %122 to float, !dbg !35
  %133 = fpext half %123 to float, !dbg !35
  %134 = fpext half %125 to float, !dbg !35
  %135 = fpext half %126 to float, !dbg !35
  %136 = fpext half %128 to float, !dbg !35
  %137 = fpext half %129 to float, !dbg !35
  %138 = fpext half %130 to float, !dbg !35
  %139 = fsub float %34, %59, !dbg !36
  %140 = fsub float %35, %60, !dbg !36
  %141 = fsub float %36, %61, !dbg !36
  %142 = fsub float %37, %62, !dbg !36
  %143 = fsub float %38, %63, !dbg !36
  %144 = fsub float %39, %64, !dbg !36
  %145 = fsub float %40, %65, !dbg !36
  %146 = fsub float %41, %66, !dbg !36
  %147 = fadd contract float %83, 0x3EE4F8B580000000, !dbg !37
  %148 = fadd contract float %84, 0x3EE4F8B580000000, !dbg !37
  %149 = fadd contract float %85, 0x3EE4F8B580000000, !dbg !37
  %150 = fadd contract float %86, 0x3EE4F8B580000000, !dbg !37
  %151 = fadd contract float %87, 0x3EE4F8B580000000, !dbg !37
  %152 = fadd contract float %88, 0x3EE4F8B580000000, !dbg !37
  %153 = fadd contract float %89, 0x3EE4F8B580000000, !dbg !37
  %154 = fadd contract float %90, 0x3EE4F8B580000000, !dbg !37
  %155 = tail call spir_func noundef float @_Z16__spirv_ocl_sqrtf(float noundef %147) #4, !dbg !38
  %156 = tail call spir_func noundef float @_Z16__spirv_ocl_sqrtf(float noundef %148) #4, !dbg !38
  %157 = tail call spir_func noundef float @_Z16__spirv_ocl_sqrtf(float noundef %149) #4, !dbg !38
  %158 = tail call spir_func noundef float @_Z16__spirv_ocl_sqrtf(float noundef %150) #4, !dbg !38
  %159 = tail call spir_func noundef float @_Z16__spirv_ocl_sqrtf(float noundef %151) #4, !dbg !38
  %160 = tail call spir_func noundef float @_Z16__spirv_ocl_sqrtf(float noundef %152) #4, !dbg !38
  %161 = tail call spir_func noundef float @_Z16__spirv_ocl_sqrtf(float noundef %153) #4, !dbg !38
  %162 = tail call spir_func noundef float @_Z16__spirv_ocl_sqrtf(float noundef %154) #4, !dbg !38
  %163 = fdiv float 1.000000e+00, %155, !dbg !39
  %164 = fdiv float 1.000000e+00, %156, !dbg !39
  %165 = fdiv float 1.000000e+00, %157, !dbg !39
  %166 = fdiv float 1.000000e+00, %158, !dbg !39
  %167 = fdiv float 1.000000e+00, %159, !dbg !39
  %168 = fdiv float 1.000000e+00, %160, !dbg !39
  %169 = fdiv float 1.000000e+00, %161, !dbg !39
  %170 = fdiv float 1.000000e+00, %162, !dbg !39
  %171 = fmul contract float %139, %163, !dbg !40
  %172 = fmul contract float %140, %164, !dbg !40
  %173 = fmul contract float %141, %165, !dbg !40
  %174 = fmul contract float %142, %166, !dbg !40
  %175 = fmul contract float %143, %167, !dbg !40
  %176 = fmul contract float %144, %168, !dbg !40
  %177 = fmul contract float %145, %169, !dbg !40
  %178 = fmul contract float %146, %170, !dbg !40
  %179 = fmul contract float %171, %107, !dbg !41
  %180 = fmul contract float %172, %108, !dbg !41
  %181 = fmul contract float %173, %109, !dbg !41
  %182 = fmul contract float %174, %110, !dbg !41
  %183 = fmul contract float %175, %111, !dbg !41
  %184 = fmul contract float %176, %112, !dbg !41
  %185 = fmul contract float %177, %113, !dbg !41
  %186 = fmul contract float %178, %114, !dbg !41
  %187 = fadd contract float %179, %131, !dbg !42
  %188 = fadd contract float %180, %132, !dbg !42
  %189 = fadd contract float %181, %133, !dbg !42
  %190 = fadd contract float %182, %134, !dbg !42
  %191 = fadd contract float %183, %135, !dbg !42
  %192 = fadd contract float %184, %136, !dbg !42
  %193 = fadd contract float %185, %137, !dbg !42
  %194 = fadd contract float %186, %138, !dbg !42
  %195 = fcmp olt float %187, 0.000000e+00, !dbg !43
  %196 = fcmp olt float %188, 0.000000e+00, !dbg !43
  %197 = fcmp olt float %189, 0.000000e+00, !dbg !43
  %198 = fcmp olt float %190, 0.000000e+00, !dbg !43
  %199 = fcmp olt float %191, 0.000000e+00, !dbg !43
  %200 = fcmp olt float %192, 0.000000e+00, !dbg !43
  %201 = fcmp olt float %193, 0.000000e+00, !dbg !43
  %202 = fcmp olt float %194, 0.000000e+00, !dbg !43
  %203 = select i1 %195, float 0.000000e+00, float %187, !dbg !47
  %204 = select i1 %196, float 0.000000e+00, float %188, !dbg !47
  %205 = select i1 %197, float 0.000000e+00, float %189, !dbg !47
  %206 = select i1 %198, float 0.000000e+00, float %190, !dbg !47
  %207 = select i1 %199, float 0.000000e+00, float %191, !dbg !47
  %208 = select i1 %200, float 0.000000e+00, float %192, !dbg !47
  %209 = select i1 %201, float 0.000000e+00, float %193, !dbg !47
  %210 = select i1 %202, float 0.000000e+00, float %194, !dbg !47
  %211 = fptrunc float %203 to half, !dbg !48
  %212 = fptrunc float %204 to half, !dbg !48
  %213 = fptrunc float %205 to half, !dbg !48
  %214 = fptrunc float %206 to half, !dbg !48
  %215 = fptrunc float %207 to half, !dbg !48
  %216 = fptrunc float %208 to half, !dbg !48
  %217 = fptrunc float %209 to half, !dbg !48
  %218 = fptrunc float %210 to half, !dbg !48
  %219 = insertelement <2 x half> poison, half %211, i64 0, !dbg !48
  %220 = insertelement <2 x half> %219, half %212, i64 1, !dbg !48
  %221 = bitcast <2 x half> %220 to i32, !dbg !48
  %222 = insertelement <2 x half> poison, half %213, i64 0, !dbg !48
  %223 = insertelement <2 x half> %222, half %214, i64 1, !dbg !48
  %224 = bitcast <2 x half> %223 to i32, !dbg !48
  %225 = insertelement <2 x half> poison, half %215, i64 0, !dbg !48
  %226 = insertelement <2 x half> %225, half %216, i64 1, !dbg !48
  %227 = bitcast <2 x half> %226 to i32, !dbg !48
  %228 = insertelement <2 x half> poison, half %217, i64 0, !dbg !48
  %229 = insertelement <2 x half> %228, half %218, i64 1, !dbg !48
  %230 = bitcast <2 x half> %229 to i32, !dbg !48
  %231 = insertelement <4 x i32> poison, i32 %221, i64 0, !dbg !48
  %232 = insertelement <4 x i32> %231, i32 %224, i64 1, !dbg !48
  %233 = insertelement <4 x i32> %232, i32 %227, i64 2, !dbg !48
  %234 = insertelement <4 x i32> %233, i32 %230, i64 3, !dbg !48
  store <4 x i32> %234, ptr addrspace(1) %18, align 16, !dbg !48
  ret void, !dbg !49
}

; Function Attrs: convergent mustprogress nofree nounwind willreturn memory(none)
declare dso_local spir_func noundef float @_Z16__spirv_ocl_sqrtf(float noundef) local_unnamed_addr #2

attributes #0 = { mustprogress nofree nosync nounwind willreturn memory(none) }
attributes #1 = { mustprogress nofree nounwind willreturn memory(argmem: readwrite) }
attributes #2 = { convergent mustprogress nofree nounwind willreturn memory(none) "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
attributes #3 = { nounwind willreturn memory(none) }
attributes #4 = { convergent nounwind willreturn memory(none) }

!llvm.dbg.cu = !{!0}
!llvm.module.flags = !{!2, !3, !4, !5}
!opencl.spir.version = !{!6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6}
!spirv.Source = !{!7, !7, !7, !7, !7, !7, !7, !7, !7, !7, !7, !7, !7, !7, !7, !7, !7, !7, !7, !7}
!llvm.ident = !{!8, !8, !8, !8, !8, !8, !8, !8, !8, !8, !8, !8, !8, !8, !8, !8, !8, !8, !8, !8}

!0 = distinct !DICompileUnit(language: DW_LANG_C, file: !1, producer: "triton", isOptimized: true, runtimeVersion: 0, emissionKind: LineTablesOnly)
!1 = !DIFile(filename: "chsgifbghvzneiwpg4l7venb33d6syzbr2mypkbv7rvm6ptn6svr.py", directory: "/home/jovyan/intel-xpu-backend-for-triton/torchinductor_cache4/hs")
!2 = !{i32 2, !"Debug Info Version", i32 3}
!3 = !{i32 1, !"wchar_size", i32 4}
!4 = !{i32 1, !"sycl-device", i32 1}
!5 = !{i32 7, !"frame-pointer", i32 2}
!6 = !{i32 1, i32 2}
!7 = !{i32 3, i32 100000}
!8 = !{!"Intel(R) oneAPI DPC++/C++ Compiler 2025.0.0 (2025.0.0.20241008)"}
!9 = !DISubprogram(name: "_Z12get_local_idj", linkageName: "_Z12get_local_idj", scope: !1, file: !1, line: 18, type: !10, scopeLine: 18, spFlags: DISPFlagOptimized)
!10 = !DISubroutineType(cc: DW_CC_normal, types: !11)
!11 = !{}
!12 = !DISubprogram(name: "_Z12get_group_idj", linkageName: "_Z12get_group_idj", scope: !1, file: !1, line: 18, type: !10, scopeLine: 18, spFlags: DISPFlagOptimized)
!13 = distinct !DISubprogram(name: "triton_poi_fused__native_batch_norm_legit_no_training_relu_2", linkageName: "triton_poi_fused__native_batch_norm_legit_no_training_relu_2", scope: !1, file: !1, line: 18, type: !10, scopeLine: 18, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !0)
!14 = !{i32 32}
!15 = !{i64 128, i64 1, i64 1}
!16 = !DILocation(line: 20, column: 28, scope: !13)
!17 = !DILocation(line: 20, column: 33, scope: !13)
!18 = !DILocation(line: 21, column: 36, scope: !13)
!19 = !DILocation(line: 21, column: 23, scope: !13)
!20 = !DILocation(line: 24, column: 19, scope: !13)
!21 = !DILocation(line: 25, column: 34, scope: !13)
!22 = !DILocation(line: 25, column: 39, scope: !13)
!23 = !DILocation(line: 25, column: 48, scope: !13)
!24 = !DILocation(line: 26, column: 30, scope: !13)
!25 = !DILocation(line: 26, column: 35, scope: !13)
!26 = !DILocation(line: 26, column: 74, scope: !13)
!27 = !DILocation(line: 27, column: 30, scope: !13)
!28 = !DILocation(line: 27, column: 35, scope: !13)
!29 = !DILocation(line: 27, column: 74, scope: !13)
!30 = !DILocation(line: 28, column: 31, scope: !13)
!31 = !DILocation(line: 28, column: 36, scope: !13)
!32 = !DILocation(line: 28, column: 75, scope: !13)
!33 = !DILocation(line: 29, column: 31, scope: !13)
!34 = !DILocation(line: 29, column: 36, scope: !13)
!35 = !DILocation(line: 29, column: 75, scope: !13)
!36 = !DILocation(line: 32, column: 18, scope: !13)
!37 = !DILocation(line: 35, column: 18, scope: !13)
!38 = !DILocation(line: 36, column: 26, scope: !13)
!39 = !DILocation(line: 38, column: 21, scope: !13)
!40 = !DILocation(line: 41, column: 19, scope: !13)
!41 = !DILocation(line: 43, column: 20, scope: !13)
!42 = !DILocation(line: 45, column: 20, scope: !13)
!43 = !DILocation(line: 111, column: 15, scope: !44, inlinedAt: !46)
!44 = distinct !DILexicalBlockFile(scope: !13, file: !45, discriminator: 0)
!45 = !DIFile(filename: "triton_helpers.py", directory: "/home/jovyan/intel-xpu-backend-for-triton/.scripts_cache/pytorch/torch/_inductor/runtime")
!46 = !DILocation(line: 48, column: 42, scope: !13)
!47 = !DILocation(line: 114, column: 29, scope: !44, inlinedAt: !46)
!48 = !DILocation(line: 49, column: 40, scope: !13)
!49 = !DILocation(line: 49, column: 4, scope: !13)

UPD:
The numerical difference is not that big, however, this will subsequently lead to the lines getting mixed up:

Mismatched elements: 15 / 92143616 (0.0%)
Greatest absolute difference: 0.00390625 at index (0, 29, 0, 598) (up to 0.0005 allowed)
Greatest relative difference: 0.0008788108825683594 at index (0, 29, 0, 598) (up to 0.0005 allowed)

This is how I was able to figure out that this particular kernel was the cause of this behavior:

Run command: TORCHINDUCTOR_CACHE_DIR=/home/jovyan/intel-xpu-backend-for-triton/torchinductor_cache2 TORCHINDUCTOR_UNIQUE_KERNEL_NAMES=1 TORCH_COMPILE_DEBUG=1 python benchmarks/dynamo/torchbench.py --accuracy --float16 -d xpu -n10 --inference --only detectron2_fasterrcnn_r_50_fpn --backend=inductor --cold-start-latency

Modification in compiler:

        intel.set_spv_target_triple(llvm_mod)
        disable_fast_math_kernels = {
            'triton_poi_fused_ceil_div_mul_sub_0.ttir',
            'triton_poi_fused_add_25.ttir',
            'triton_poi_fused_convolution_relu_11.ttir',
            'triton_poi_fused__to_copy_add_clamp_ge_mul_scalar_tensor_sub_where_2.ttir',
            'triton_red_fused_add_div_index_mul_scalar_tensor_sum_where_6.ttir',
            'triton_poi_fused_mul_0.ttir',
            'triton_poi_fused__native_batch_norm_legit_no_training_add_relu_15.ttir',
            'triton_poi_fused_add_2.ttir',
            'triton_poi_fused__to_copy_add_clamp_ge_mul_scalar_tensor_where_0.ttir',
            'triton_poi_fused__native_batch_norm_legit_no_training_add_relu_19.ttir',
            'triton_poi_fused__native_batch_norm_legit_no_training_add_relu_18.ttir',
            'triton_poi_fused_clone_16.ttir',
            'triton_poi_fused__native_batch_norm_legit_no_training_relu_8.ttir',
            'triton_red_fused_add_div_index_mul_scalar_tensor_sum_where_6.ttir',
            'triton_poi_fused_add_21.ttir',
            'triton_red_fused_add_div_index_mul_scalar_tensor_sum_where_2.ttir',
            'triton_poi_fused_convolution_relu_14.ttir',
            'triton_poi_fused_clone_12.ttir',
            'triton_poi_fused__to_copy_add_clamp_ge_mul_scalar_tensor_where_1.ttir',
            'triton_poi_fused_clone_9.ttir',
            'triton_poi_fused_clone_15.ttir',
            'triton_poi_fused_ceil_div_mul_sub_0.ttir',
            'triton_poi_fused_stack_3.ttir',
            'triton_red_fused_add_div_index_mul_scalar_tensor_sum_where_2.ttir',
            'triton_poi_fused__to_copy_add_clamp_ge_mul_scalar_tensor_sub_where_2.ttir',
            'triton_poi_fused__to_copy_add_clamp_ge_mul_scalar_tensor_sub_where_3.ttir',
            'triton_poi_fused_clone_3.ttir',
            'triton_poi_fused__to_copy_add_clamp_ge_mul_scalar_tensor_sub_where_2.ttir',
            'triton_poi_fused__native_batch_norm_legit_no_training_relu_12.ttir',
            'triton_poi_fused_mul_scalar_tensor_where_1.ttir',
            'triton_poi_fused_add_22.ttir',
            #'triton_poi_fused__to_copy_add_clamp_ge_mul_scalar_tensor_sub_where_3.ttir',
            #'triton_poi_fused__native_batch_norm_legit_no_training_add_relu_10.ttir',
            #'triton_poi_fused__to_copy__unsafe_index_add_convolution_24.ttir',
            #'triton_poi_fused__to_copy__unsafe_index_add_convolution_25.ttir',
            #'triton_red_fused_add_div_index_mul_scalar_tensor_sum_where_2.ttir',
            #'triton_poi_fused_convolution_28.ttir',
            #'triton_poi_fused_clone_19.ttir',
            #'triton_poi_fused__to_copy__unsafe_index_add_convolution_23.ttir',
            #'triton_poi_fused_clone_17.ttir',
            'triton_poi_fused__to_copy_add_clamp_ge_mul_scalar_tensor_where_0.ttir',
            'triton_poi_fused__to_copy_add_clamp_ge_mul_scalar_tensor_where_0.ttir',
            'triton_poi_fused_stack_2.ttir',
            'triton_poi_fused_add_23.ttir',
            'triton_poi_fused_ceil_div_mul_sub_1.ttir',
            'triton_poi_fused_clone_18.ttir',
            'triton_poi_fused_add_24.ttir',
            #'triton_red_fused_add_div_index_mul_scalar_tensor_sum_where_6.ttir',
            #'triton_poi_fused__to_copy_add_clamp_ge_mul_scalar_tensor_sub_where_2.ttir',
            #'triton_poi_fused_ceil_div_mul_sub_0.ttir',

            #'triton_poi_fused__native_batch_norm_legit_no_training_relu_2.ttir',   # <- uncommenting this line results in a precision error 
            #'triton_poi_fused_ceil_div_mul_sub_0.ttir',

            #'triton_poi_fused_mul_0.ttir',
            #'triton_poi_fused__to_copy__unsafe_index_add_convolution_26.ttir',
        }

        def get_function_name():
            return glob.glob('*.ttir', root_dir=metadata['cache_dir'])[0]

        if not get_function_name() in disable_fast_math_kernels:
            intel.set_fast_math(llvm_mod)
        else:
            print("ignore: ", get_function_name())

@anmyachev
Copy link
Contributor Author

Mismatched elements: 15 / 92143616 (0.0%)
Greatest absolute difference: 0.00390625 at index (0, 29, 0, 598) (up to 0.0005 allowed)
Greatest relative difference: 0.0008788108825683594 at index (0, 29, 0, 598) (up to 0.0005 allowed)

@whitneywhtsang @chengjunlu as per the message above, I don't see a way other than just disabling fp_fusion for this particular benchmark. Thoughts?

@chengjunlu
Copy link
Contributor

Mismatched elements: 15 / 92143616 (0.0%)
Greatest absolute difference: 0.00390625 at index (0, 29, 0, 598) (up to 0.0005 allowed)
Greatest relative difference: 0.0008788108825683594 at index (0, 29, 0, 598) (up to 0.0005 allowed)

@whitneywhtsang @chengjunlu as per the message above, I don't see a way other than just disabling fp_fusion for this particular benchmark. Thoughts?

I am ok to disable it.

@anmyachev
Copy link
Contributor Author

Decided to check IGC code generation...

@whitneywhtsang Is it expected to have mad instructions instead of fma?

image

@whitneywhtsang
Copy link
Contributor

@whitneywhtsang @chengjunlu as per the message above, I don't see a way other than just disabling fp_fusion for this particular benchmark. Thoughts?

We need PyTorch to agree to either changing reference implementation or disabling fp_fusion for this particular benchmark.

@whitneywhtsang
Copy link
Contributor

@whitneywhtsang Is it expected to have mad instructions instead of fma?

It is legal to contract to mad instructions, let's discuss about this with IGC team tomorrow.

@anmyachev anmyachev marked this pull request as ready for review July 22, 2025 12:48
@anmyachev
Copy link
Contributor Author

anmyachev commented Jul 22, 2025

NOTE2: The changes after applying mad operation are minor (tensors pass accuracy check with atol && rtol ==1e-6), but after converting to float16 type, the difference becomes significant and this affects the final accuracy check.

(Pdb) in_out_ptr2[(0, 20, 17, 246)]  # <- problematic value in float32 with contraction (i.e. like `-ff-contact=1`)
tensor(-5.5410, device='xpu:0')
(Pdb) in_out_ptr3[(0, 20, 17, 246)]  # <- problematic value in float32 without contraction (i.e. like `-ff-contact=0`)
tensor(-5.5410, device='xpu:0')
(Pdb) in_out_ptr3[(0, 20, 17, 246)].to(torch.float16)  # <- problematic value in float16 with contraction (i.e. like `-ff-contact=1`)
tensor(-5.5430, device='xpu:0', dtype=torch.float16)
(Pdb) in_out_ptr2[(0, 20, 17, 246)].to(torch.float16)  # <- problematic value in float16 with contraction (i.e. like `-ff-contact=0`)
tensor(-5.5391, device='xpu:0', dtype=torch.float16)
(Pdb) in_out_ptr2[(0, 20, 17, 246)].to(torch.float16) - in_out_ptr3[(0, 20, 17, 246)].to(torch.float16)
tensor(0.0039, device='xpu:0', dtype=torch.float16)

@chuanqi129 @mengfei25
According to this we can confirm that there is no bug here (neither in Triton nor in IGC), but just a typical situation of error accumulation in real numbers. And since we can't have -ff-contact disabled by default because we're following Triton's upstream behavior, it turns out that this will need to be done for the model itself in torch-xpu-ops (just set TRITON_DEFAULT_FP_FUSION=0). I also suspect that any minor change in the aten operators or kernels might change the floats slightly and this problem would go away, so that would need to be double-checked.

@whitneywhtsang I think that's enough as evidence and considering that we have no choice, we can merge it right now.

NOTE: I found one more interesting detail that looks like a bug through. If write the results to a buffer of fp32, instead of fp16, then the results begin to undergo correctness checking, regardless of whether contract is enabled or not (I double checked that .visaasm contains mad instructions). It turns out that there can be two options: the first is that we have incorrectly written the conversion from fp32 to fp16, or the code itself for converting these types somehow affects the generation of the code of the entire kernel.

Reproducer
import triton
import triton.language as tl
import torch

import copy
import os

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

'''
@triton_heuristics.pointwise(
    size_hints={'x': 134217728}, 
    filename=__file__,
    triton_meta={'signature': {'in_out_ptr0': '*fp16', 'in_ptr0': '*fp16', 'in_ptr1': '*fp16', 'in_ptr2': '*fp16', 'in_ptr3': '*fp16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='xpu', index=0, multi_processor_count=56, cc={'architecture': 13136561920, 'device_id': 3034, 'driver_version': '1.6.33276+22', 'gpu_eu_count': 448, 'gpu_subslice_count': 56, 'has_atomic64': True, 'has_bfloat16_conversions': True, 'has_fp16': True, 'has_fp64': True, 'has_subgroup_2d_block_io': True, 'has_subgroup_matrix_multiply_accumulate': True, 'has_subgroup_matrix_multiply_accumulate_tensor_float32': False, 'max_compute_units': 448, 'max_num_sub_groups': 64, 'max_work_group_size': 1024, 'name': 'Intel(R) Data Center GPU Max 1100', 'platform_name': 'Intel(R) oneAPI Unified Runtime over Level-Zero', 'sub_group_sizes': [16, 32], 'total_memory': 51522830336, 'type': 'gpu', 'vendor': 'Intel(R) Corporation', 'version': '12.60.7'}, major=None, regs_per_multiprocessor=None, max_threads_per_multi_processor=None, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__native_batch_norm_legit_no_training_relu_2', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 5, 'num_reduction': 0, 'backend_hash': 'BF1FE552CB8E524B1A91FA16118F1D335628999A2F77D64AB02D0408920F518D', 'are_deterministic_algorithms_enabled': True, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 552862208}},
    min_elem_per_thread=0
)
'''
@triton.jit
def triton_poi_fused__native_batch_norm_legit_no_training_relu_2(in_out_ptr0, in_ptr2, in_out_ptr2, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    x2 = xindex
    x0 = (xindex % 64)
    tmp0 = tl.load(in_out_ptr0 + x2).to(tl.float32)
    tmp15 = tl.load(in_ptr2 + x0).to(tl.float32)
    tmp14 = tmp0 * 0.021739
    tmp17 = tmp14 * tmp15
    tmp20 = tmp17 + 0.1
    tl.store(in_out_ptr2 + x2, tmp20)


@triton.jit
def triton_poi_fused__native_batch_norm_legit_no_training_relu_2_no_contract(in_out_ptr0, in_ptr2, in_out_ptr2, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    x2 = xindex
    x0 = (xindex % 64)
    tmp0 = tl.load(in_out_ptr0 + x2).to(tl.float32)
    tmp15 = tl.load(in_ptr2 + x0).to(tl.float32)
    tmp14 = tmp0 * 0.021739
    tmp17 = tmp14 * tmp15
    tmp20 = tmp17 + 0.1
    tl.store(in_out_ptr2 + x2, tmp20)


def main():
    torch.manual_seed(42)
    #in_out_ptr0 = torch.rand((4, 64, 592, 608), device='xpu:0', dtype=torch.float16)
    in_out_ptr0 = torch.load(
        '.scripts_cache/pytorch/triton_poi_fused__native_batch_norm_legit_no_training_relu_2_tensor_reduced.pt',
        map_location=torch.device("xpu:0"))
    in_out_ptr1 = copy.deepcopy(in_out_ptr0)
    # if write to float32 buffers instead of float16 buffers (e.c. to avoid fp trunc from fp32 to fp16)
    # the accuracy check starts to run successfully
    in_out_ptr2 = copy.deepcopy(in_out_ptr0).to(torch.float32)
    in_out_ptr3 = copy.deepcopy(in_out_ptr0).to(torch.float32)
    breakpoint()
    in_ptr2 = torch.tensor([1.0918, 1.1143, 0.8721, 1.3779, 1.1670, 1.1963, 0.9111, 1.5654, 2.5527,
        0.8813, 1.1953, 0.8633, 0.9434, 2.4863, 1.0371, 1.2842, 1.4512, 1.3008,
        1.0088, 1.3760, 1.0635, 1.1025, 2.6348, 1.0762, 1.0830, 1.4102, 1.9824,
        0.7451, 0.9307, 2.0840, 1.4150, 0.8799, 0.8721, 0.5127, 1.2891, 0.9023,
        0.8804, 1.8643, 1.3809, 1.0107, 1.3945, 1.0410, 1.2939, 0.5283, 0.9634,
        2.6680, 1.1680, 1.1094, 0.8320, 1.0039, 1.4688, 1.5918, 0.5527, 1.2539,
        1.5557, 1.0693, 0.8511, 0.9292, 1.4111, 0.9453, 0.8613, 0.9160, 1.0430,
        1.0137], device='xpu:0', dtype=torch.float16)

    xnumel = 92143616
    XBLOCK = 1024

    # 89984 -> 22496
    triton_poi_fused__native_batch_norm_legit_no_training_relu_2[(22496, 1, 1)](
        in_out_ptr0, in_ptr2, in_out_ptr2, XBLOCK)

    os.environ['TRITON_DEFAULT_FP_FUSION'] = "0"
    triton_poi_fused__native_batch_norm_legit_no_training_relu_2_no_contract[(22496, 1, 1)](
        in_out_ptr1, in_ptr2, in_out_ptr3, XBLOCK)

    try:
        #torch.testing.assert_close(in_out_ptr0, in_out_ptr1, atol=1e-4, rtol=1e-4)
        torch.testing.assert_close(in_out_ptr2, in_out_ptr3, atol=1e-6, rtol=1e-6)
        torch.testing.assert_close(in_out_ptr2.to(torch.float16), in_out_ptr3.to(torch.float16), atol=1e-6, rtol=1e-6) # <- accuracy issue again
        breakpoint()
    except Exception as err:
        breakpoint()
        raise 



if __name__ == "__main__":
    main()

</details/

@anmyachev anmyachev enabled auto-merge (squash) July 22, 2025 18:08
@anmyachev anmyachev merged commit deb07d3 into main Jul 22, 2025
15 checks passed
@anmyachev anmyachev deleted the amyachev/issue4514 branch July 22, 2025 23:28
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Enable SPV_INTEL_fp_fast_math_mode back
5 participants