
cpu: aarch64: prevent large post-op kernel generation#4857

Open
Sqvid wants to merge 1 commit into main from sqvid/fix-large-po-kernels

Conversation


@Sqvid Sqvid commented Mar 19, 2026

Description

Prevents the generation of very large post-op kernels, which were causing exceptions to be thrown during JIT assembly.

Fixes: Issue #4089

Benchmarks

Generated the test set with:

```
./build/tests/benchdnn/benchdnn --conv --dt=bf16,f32 --stag=axb --dtag=axb --attr-post-ops=sum,sum+tanh,exp,sum+exp,gelu_erf --impl=brgconv:sve --batch=shapes_gemm
```

and filtered for changes of more than +/- 10%.

SVE-256:

No changes observed

SVE-128:

Speedups are mostly because post-op configurations that were previously skipped (they would crash the brgconv implementation and fall back to the reference implementation) can now run after this patch.

| old impl | new impl | prb | old avg | new avg | speedup ratio (old avg / new avg) |
|---|---|---|---|---|---|
| ref:any | brgconv:sve_128 | --mode=P --conv --dt=bf16:bf16:bf16 --stag=axb --dtag=axb --attr-post-ops=gelu_erf g1ic16id5oc16od3kd3pd4dd4n"big_padding_and_dilation_w.r.t._kernel_size-1" | 0.0131649 | 0.00278983 | 4.7189x |
| ref:any | brgconv:sve_128 | --mode=P --conv --dt=bf16:bf16:bf16 --stag=axb --dtag=axb --attr-post-ops=gelu_erf g1ic16ih5oc16oh3kh3ph4dh4n"big_padding_and_dilation_w.r.t._kernel_size-1" | 0.00653319 | 0.00258148 | 2.5308x |
| ref:any | brgconv:sve_128 | --mode=P --conv --dt=bf16:bf16:bf16 --stag=axb --dtag=axb --attr-post-ops=gelu_erf g1ic16iw5oc16ow3kw3pw4dw4n"big_padding_and_dilation_w.r.t._kernel_size-1" | 0.0053604 | 0.00244591 | 2.1916x |
| ref:any | brgconv:sve_128 | --mode=P --conv --dt=bf16:bf16:bf16 --stag=axb --dtag=axb --attr-post-ops=gelu_erf g1mb1ic64ih1000oc64oh1000kh3ph128dh127 | 8831.42 | 40.7008 | 216.9839x |
| ref:any | brgconv:sve_128 | --mode=P --conv --dt=bf16:bf16:bf16 --stag=axb --dtag=axb --attr-post-ops=gelu_erf g3mb4ic3id5oc48od3kd3pd0n"ic%simd_width_!=0_with_im2col" | 0.0816859 | 0.00491753 | 16.6112x |
| ref:any | brgconv:sve_128 | --mode=P --conv --dt=bf16:bf16:bf16 --stag=axb --dtag=axb --attr-post-ops=gelu_erf g3mb4ic3ih5oc48oh3kh3ph0n"ic%simd_width_!=0_with_im2col" | 0.0174365 | 0.00297653 | 5.8580x |
| ref:any | brgconv:sve_128 | --mode=P --conv --dt=bf16:bf16:bf16 --stag=axb --dtag=axb --attr-post-ops=gelu_erf g3mb4ic3iw5oc48ow3kw3pw0n"ic%simd_width_!=0_with_im2col" | 0.0067145 | 0.00258478 | 2.5977x |
| ref:any | brgconv:sve_128 | --mode=P --conv --dt=bf16:bf16:bf16 --stag=axb --dtag=axb --attr-post-ops=sum+tanh g1ic16id5oc16od3kd3pd4dd4n"big_padding_and_dilation_w.r.t._kernel_size-1" | 0.013488 | 0.00274914 | 4.9063x |
| ref:any | brgconv:sve_128 | --mode=P --conv --dt=bf16:bf16:bf16 --stag=axb --dtag=axb --attr-post-ops=sum+tanh g1ic16ih5oc16oh3kh3ph4dh4n"big_padding_and_dilation_w.r.t._kernel_size-1" | 0.00657586 | 0.00256887 | 2.5598x |
| ref:any | brgconv:sve_128 | --mode=P --conv --dt=bf16:bf16:bf16 --stag=axb --dtag=axb --attr-post-ops=sum+tanh g1ic16iw5oc16ow3kw3pw4dw4n"big_padding_and_dilation_w.r.t._kernel_size-1" | 0.00502907 | 0.0024492 | 2.0534x |
| ref:any | brgconv:sve_128 | --mode=P --conv --dt=bf16:bf16:bf16 --stag=axb --dtag=axb --attr-post-ops=sum+tanh g1mb1ic64ih1000oc64oh1000kh3ph128dh127 | 8844.93 | 36.3443 | 243.3650x |
| ref:any | brgconv:sve_128 | --mode=P --conv --dt=bf16:bf16:bf16 --stag=axb --dtag=axb --attr-post-ops=sum+tanh g3mb4ic3id5oc48od3kd3pd0n"ic%simd_width_!=0_with_im2col" | 0.0835807 | 0.00474915 | 17.5991x |
| ref:any | brgconv:sve_128 | --mode=P --conv --dt=bf16:bf16:bf16 --stag=axb --dtag=axb --attr-post-ops=sum+tanh g3mb4ic3ih5oc48oh3kh3ph0n"ic%simd_width_!=0_with_im2col" | 0.0184007 | 0.00294531 | 6.2475x |
| ref:any | brgconv:sve_128 | --mode=P --conv --dt=bf16:bf16:bf16 --stag=axb --dtag=axb --attr-post-ops=sum+tanh g3mb4ic3iw5oc48ow3kw3pw0n"ic%simd_width_!=0_with_im2col" | 0.00696479 | 0.00255392 | 2.7271x |
| brgconv:sve_128 | brgconv:sve_128 | --mode=P --conv --stag=axb --dtag=axb --attr-post-ops=sum+tanh g1ic16ih5oc16oh3kh3ph4dh4n"big_padding_and_dilation_w.r.t._kernel_size-1" | 0.00251713 | 0.00280959 | 0.8959x |

Checklist

General

  • Do all unit and benchdnn tests (make test and make test_benchdnn_*) pass locally for each commit?
  • Have you formatted the code using clang-format?

Bug fixes

  • Have you included information on how to reproduce the issue (either in a github issue or in this PR)?

@github-actions github-actions bot added the platform:cpu-aarch64 Codeowner: @oneapi-src/onednn-cpu-aarch64 label Mar 19, 2026
@Sqvid Sqvid marked this pull request as ready for review March 20, 2026 13:53
@Sqvid Sqvid requested a review from a team as a code owner March 20, 2026 13:53

Sqvid commented Mar 20, 2026

@Sqvid Sqvid linked an issue Mar 20, 2026 that may be closed by this pull request
@jondea
Copy link
Contributor

jondea commented Mar 20, 2026

The nightly conv test suite is currently a bit rubbish (I need to fix it), so would you be able to check a couple of shapes from the model files, e.g. yolo or resnet with some random post-ops? The ref->brgconv cases will always be a win, but I note that the only brgconv->brgconv change is a regression (although it is a tiny shape, so this could be noise).

```cpp
Label mb_loop_end;
mov(x16, 0);
L(mb_loop_begin);
cmp(x16, mb);
```

This will fail if mb>4096 (see CMP docs). It would be safer to use mov_imm(x16) and sub down to zero.


Should be fixed now. Thanks for spotting this!


Sqvid commented Mar 20, 2026

I have rerun the benchmarks with a batch composed of resnet_50 and yolov2 shapes. The batch was generated with:

```
./build/tests/benchdnn/benchdnn --conv --mode=L --dt=bf16,f32 --attr-post-ops=relu,sum+relu,sum+gelu_tanh --impl=brgconv:sve --batch=shapes_resnet_50 --batch=shapes_yolov2
```

Results

Filtered for changes of more than +/- 10%.

| old impl | new impl | prb | old avg | new avg | speedup ratio (old avg / new avg) |
|---|---|---|---|---|---|
| ref:any | brgconv:sve_128 | --mode=P --conv --dt=bf16:bf16:bf16 --attr-post-ops=sum+gelu_tanh g1mb16ic3ih610oc32oh608kh3ph0n"yolov2:conv1*6" | 2475.11 | 41.3693 | 59.8296x |
| ref:any | brgconv:sve_128 | --mode=P --conv --dt=bf16:bf16:bf16 --attr-post-ops=sum+gelu_tanh g1mb16ic32ih306oc64oh304kh3ph0n"yolov2:conv2*9" | 8229.8 | 36.5077 | 225.4264x |
| ref:any | brgconv:sve_128 | --mode=P --conv --dt=bf16:bf16:bf16 --attr-post-ops=sum+gelu_tanh g1mb16ic64ih154oc128oh152kh3ph0n"yolov2:conv3*18" | 7785.78 | 28.7 | 271.2815x |
| ref:any | brgconv:sve_128 | --mode=P --conv --dt=bf16:bf16:bf16 --attr-post-ops=sum+gelu_tanh g1mb16ic128ih78oc256oh76kh3ph0n"yolov2:conv5*18" | 7565.69 | 24.9778 | 302.8966x |
| ref:any | brgconv:sve_128 | --mode=P --conv --dt=bf16:bf16:bf16 --attr-post-ops=sum+gelu_tanh g1mb16ic256ih40oc512oh38kh3ph0n"yolov2:conv7*27" | 7446.77 | 23.9786 | 310.5590x |
| ref:any | brgconv:sve_128 | --mode=P --conv --dt=bf16:bf16:bf16 --attr-post-ops=sum+gelu_tanh g1mb16ic512ih21oc1024oh19kh3ph0n"yolov2:conv9*27" | 7472.93 | 25.592 | 292.0026x |
| ref:any | brgconv:sve_128 | --mode=P --conv --dt=bf16:bf16:bf16 --attr-post-ops=sum+gelu_tanh g1mb16ic1024ih21oc1024oh19kh3ph0n"yolov2:conv11*18" | 14813.3 | 51.246 | 289.0626x |
| ref:any | brgconv:sve_128 | --mode=P --conv --dt=bf16:bf16:bf16 --attr-post-ops=sum+gelu_tanh g1mb16ic1280ih21oc1024oh19kh3ph0n"yolov2:conv13*9" | 18827.3 | 64.0633 | 293.8859x |
| ref:any | brgconv:sve_128 | --mode=P --conv --dt=bf16:bf16:bf16 --attr-post-ops=sum+gelu_tanh g1mb50ic3ih224oc64oh112kh7sh2ph3n"resnet_50:conv1" | 2095.52 | 13.8916 | 150.8480x |
| ref:any | brgconv:sve_128 | --mode=P --conv --dt=bf16:bf16:bf16 --attr-post-ops=sum+gelu_tanh mb50ic64ih56oc64oh56kh3ph1n"resnet_50:res2a_branch2b*3" | 1614.22 | 6.10266 | 264.5109x |
| ref:any | brgconv:sve_128 | --mode=P --conv --dt=bf16:bf16:bf16 --attr-post-ops=sum+gelu_tanh mb50ic128ih28oc128oh28kh3ph1n"resnet_50:res3a_branch2b*4" | 1529.69 | 5.11747 | 298.9153x |
| ref:any | brgconv:sve_128 | --mode=P --conv --dt=bf16:bf16:bf16 --attr-post-ops=sum+gelu_tanh mb50ic256ih14oc256oh14kh3ph1n"resnet_50:res4a_branch2b*6" | 1450.94 | 5.22144 | 277.8812x |
| ref:any | brgconv:sve_128 | --mode=P --conv --dt=bf16:bf16:bf16 --attr-post-ops=sum+gelu_tanh mb50ic512ih7oc512oh7kh3ph1n"resnet_50:res5a_branch2b*3" | 1294.69 | 5.64159 | 229.4903x |

The generated post-op kernels were being unrolled excessively, leading
to extremely large code sizes. This led to exceptions being thrown
during JIT assembly. This patch addresses the defect by using branching
and looping instead of unrolling.

Resolves: Issue #4089
Signed-off-by: Siddhartha Menon <siddhartha.menon@arm.com>
@Sqvid Sqvid force-pushed the sqvid/fix-large-po-kernels branch from 0bb7484 to 852d208 Compare March 20, 2026 15:20
@Sqvid Sqvid requested a review from jondea March 20, 2026 15:31

Labels

platform:cpu-aarch64 Codeowner: @oneapi-src/onednn-cpu-aarch64


Development

Successfully merging this pull request may close these issues.

Large kernels can generate xbyak_aarch64 exceptions

2 participants