graph: re-enable sdpa training bwd #4814

Open
ElaineBao wants to merge 23 commits into main from yixin/sdpa_ukernel_train

Conversation

@ElaineBao
Contributor

@ElaineBao ElaineBao commented Mar 12, 2026

Description

Note: please review #4825 first.

Implementation of Proposal 2.C in the RFC:

SDPA backward with dropout and gradients for masks are not enabled in this PR; they will be covered in separate PRs.

  • Rebased onto the latest implementation of fused sdpa training (#4498).

  • There are still some correctness issues, but they do not appear to be computation errors (see the log below).
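For orientation, the fused backward op chain in the verbose line below (bmm_dprobs; mul_o_do; reducesum_correction; sub_dp_corrected; mul_softmax_bwd; bmm_dq; bmm_dk; bmm_dv; ...) corresponds to the standard SDPA backward math. Below is a minimal numpy sketch of that math, with shapes taken from the log's GQA case, computed entirely in f32, and with dropout and mask gradients omitted as in this PR; the variable names are illustrative only and are not oneDNN API.

```python
import numpy as np

rng = np.random.default_rng(0)
B, G, H, S, D = 2, 2, 8, 128, 64  # batch, kv groups, heads per group, seq len, head dim
q  = rng.standard_normal((B, G, H, S, D)).astype(np.float32)
kt = rng.standard_normal((B, G, 1, D, S)).astype(np.float32)   # K^T, broadcast over heads (GQA)
v  = rng.standard_normal((B, G, 1, S, D)).astype(np.float32)
do = rng.standard_normal((B, G, H, S, D)).astype(np.float32)   # incoming gradient dO
scale = 1.0 / np.sqrt(D)

# forward: bmm1; scale_mul; softmax; bmm2
s = (q @ kt) * scale
p = np.exp(s - s.max(-1, keepdims=True))
p /= p.sum(-1, keepdims=True)
o = p @ v

# backward, roughly following the fused op chain in the verbose line:
dp   = do @ v.swapaxes(-1, -2)                           # bmm_dprobs
corr = (do * o).sum(-1, keepdims=True)                   # mul_o_do; reducesum_correction
ds   = p * (dp - corr) * scale                           # sub_dp_corrected; mul_softmax_bwd; scale_mul
dq   = ds @ kt.swapaxes(-1, -2)                          # bmm_dq
dk   = (ds.swapaxes(-1, -2) @ q).sum(2, keepdims=True)   # bmm_dk, reduced over the head group (reduce_dk)
dv   = (p.swapaxes(-1, -2) @ do).sum(2, keepdims=True)   # bmm_dv, reduced over the head group (reduce_dv)

# sanity: each row of the softmax backward (before scaling) sums to zero
assert np.allclose((p * (dp - corr)).sum(-1), 0.0, atol=1e-3)
```

The group reductions at the end mirror why dK and dV in the log keep the broadcast shape 2x2x1x...: the gradients of the shared KV heads accumulate over the 8 query heads in each group.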

DNNL_VERBOSE=1 ./tests/benchdnn/benchdnn --graph --engine=gpu --case=../tests/benchdnn/inputs/graph/complex_fusion/mha/gqa-plain-training-backward-bf16-f32.json
onednn_verbose,v1,info,oneDNN v3.12.0 (commit 6caca25bb9380200bf99de51b2871dcc10c06a98)
.....
onednn_verbose,v1,graph,exec,gpu,100002,sdp,bmm1;scale_mul;mask_add;subtract;exp;bmm_dprobs;mul_o_do;reducesum_correction;sub_dp_corrected;mul_softmax_bwd;scale_mul;typecast;bmm_dq;bmm_dk;typecast;transpose_dk;bmm_dv;reduce_dk;reduce_dv,,in0_bf16:100:strided:undef:2x2x8x128x64:131072s65536s8192s64s1 in1_bf16:101:strided:undef:2x2x1x64x128:16384s8192s8192s128s1 in2_bf16:102:strided:undef:1:1 in3_bf16:103:strided:undef:2x2x8x128x128:262144s131072s16384s128s1 in4_f32:8:strided:undef:2x2x8x128x1:2048s1024s128s1s1 in5_bf16:105:strided:undef:2x2x8x128x64:131072s65536s8192s64s1 in6_bf16:104:strided:undef:2x2x1x128x64:16384s8192s8192s64s1 in7_bf16:10:strided:undef:2x2x8x128x64:131072s65536s8192s64s1 in8_bf16:105:strided:undef:2x2x8x128x64:131072s65536s8192s64s1 in9_bf16:102:strided:undef:1:1 in10_bf16:101:strided:undef:2x2x1x64x128:16384s8192s8192s128s1 in11_bf16:100:strided:undef:2x2x8x128x64:131072s65536s8192s64s1 in12_bf16:105:strided:undef:2x2x8x128x64:131072s65536s8192s64s1 out0_bf16:31:strided:undef:2x2x8x128x64:131072s65536s8192s64s1 out1_bf16:37:strided:undef:2x2x1x64x128:16384s8192s8192s128s1 out2_bf16:15:strided:undef:2x2x1x128x64:16384s8192s8192s64s1,fpm:strict,sdp_bwd_primitive_kernel_t,dnnl_backend,1.00293
onednn_verbose,v1,primitive,exec,cpu,reorder,jit_direct_copy:uni,undef,src:f32::blocked:abcde::f0 dst:f32::blocked:abcde::f0,,,2x2x1x128x64,0.0209961
onednn_verbose,v1,primitive,exec,cpu,reorder,jit:uni,undef,src:bf16::blocked:abcde::f0 dst:f32::blocked:abcde::f0,,,2x2x1x128x64,0.0209961
onednn_verbose,v1,primitive,exec,cpu,reorder,jit_direct_copy:uni,undef,src:f32::blocked:abcde::f0 dst:f32::blocked:abcde::f0,,,2x2x8x128x64,0.0620117
onednn_verbose,v1,primitive,exec,cpu,reorder,jit:uni,undef,src:bf16::blocked:abcde::f0 dst:f32::blocked:abcde::f0,,,2x2x8x128x64,0.0620117
[ 360][0:0:0:5:40] exp_f32:    -2.74586 exp:       -2.75 got:    -2.71875 diff: 0.03125 rdiff:0.0113636
[ 367][0:0:0:5:47] exp_f32:     2.75372 exp:        2.75 got:     2.71875 diff: 0.03125 rdiff:0.0113636
[ 394][0:0:0:6:10] exp_f32:    -2.98334 exp:    -2.98438 got:    -3.01562 diff: 0.03125 rdiff:0.0104712
[ 395][0:0:0:6:11] exp_f32:     3.11424 exp:     3.10938 got:     3.14062 diff: 0.03125 rdiff:0.0100503
[ 396][0:0:0:6:12] exp_f32:    -3.04886 exp:    -3.04688 got:    -3.07812 diff: 0.03125 rdiff:0.0102564
[ 400][0:0:0:6:16] exp_f32:      2.9805 exp:     2.98438 got:     3.01562 diff: 0.03125 rdiff:0.0104712
[ 441][0:0:0:6:57] exp_f32:     3.04877 exp:     3.04688 got:     3.07812 diff: 0.03125 rdiff:0.0102564
[ 445][0:0:0:6:61] exp_f32:     3.04584 exp:     3.04688 got:     3.07812 diff: 0.03125 rdiff:0.0102564
[ 896][0:0:0:14:0] exp_f32:     -10.651 exp:     -10.625 got:      -10.75 diff:   0.125 rdiff:0.0117647
[ 938][0:0:0:14:42] exp_f32:     10.6508 exp:      10.625 got:       10.75 diff:   0.125 rdiff:0.0117647
[COMPARE_STATS]: trh=0 err_max_diff:  0.1875 err_max_rdiff: 17.4051 all_max_diff:  0.1875 all_max_rdiff: 90.9556
[COMPARE_STATS] Norm check is prohibited; error_to_total_ratio: 8232/262144; allowed_ratio: 256/262144;
Error: Function 'doit' at (/home/gta/yixin/oneDNN/tests/benchdnn/graph/graph.cpp:773) returned '1'
0:FAILED (errors:8232 total:262144) (1308 ms) __REPRO: --graph --engine=gpu --case=../tests/benchdnn/inputs/graph/complex_fusion/mha/gqa-plain-training-backward-bf16-f32.json
===========================================================
= Failed cases summary (--summary=no-failures to disable) =
===========================================================
0:FAILED (errors:8232 total:262144) (1308 ms) __REPRO: --graph --engine=gpu --case=../tests/benchdnn/inputs/graph/complex_fusion/mha/gqa-plain-training-backward-bf16-f32.json
============================
tests:1 passed:0 skipped:0 mistrusted:0 unimplemented:0 invalid_arguments:0 failed:1 listed:0
total: 1.32s; create_pd: 0.62s (47%); create_prim: 0.15s (11%); fill: 0.00s (0%); execute: 0.06s (4%); compute_ref: 0.00s (0%); compare: 0.00s (0%);
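One observation supporting the "not a computation error" read: the per-element diffs above are on the bf16 rounding grain. bf16 keeps 8 significand bits, so near |x| ≈ 2.75 one ulp is 2^-6 = 0.015625, and the observed diff of 0.03125 is exactly 2 ulps. A quick standalone check (plain round-to-nearest-even f32→bf16, not oneDNN code):

```python
import struct

def f32_to_bf16(x: float) -> float:
    """Round an f32 value to bf16 (round-to-nearest-even) and return it as f32."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000  # RNE on the low 16 bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]

# First failing element from the log: its f32 reference rounds to the printed "exp"
print(f32_to_bf16(-2.74586))   # -2.75, matching the log's "exp" column
# One bf16 ulp at |x| ~ 2.75 (binary exponent 1, 7 stored mantissa bits)
print(2.0 ** (1 - 7))          # 0.015625 -> the observed 0.03125 diff is 2 ulps
```

Whether 2-ulp accumulation is acceptable is a separate question from the failing error ratio (8232/262144 vs. the allowed 256/262144), but it does point at precision rather than a wrong formula.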

@ElaineBao ElaineBao self-assigned this Mar 12, 2026
@ElaineBao ElaineBao added the component:graph-api Codeowner: @oneapi-src/onednn-graph label Mar 12, 2026
@ElaineBao ElaineBao force-pushed the yixin/sdpa_ukernel_train branch from 01a743b to 3a0e59e Compare March 12, 2026 08:00
@github-actions github-actions bot added platform:gpu-intel Codeowner: @oneapi-src/onednn-gpu-intel component:tests Codeowner: @oneapi-src/onednn-arch component:examples component:common labels Mar 12, 2026
@ElaineBao ElaineBao force-pushed the yixin/sdpa_ukernel_train branch 3 times, most recently from c6df869 to 3011cd7 Compare March 13, 2026 04:51
@ElaineBao ElaineBao marked this pull request as ready for review March 13, 2026 04:52
@ElaineBao ElaineBao requested review from a team as code owners March 13, 2026 04:52
@ElaineBao ElaineBao force-pushed the yixin/sdpa_ukernel_train branch from 3011cd7 to 20240af Compare March 13, 2026 05:00
BACKEND_DNNL_ADD_PASS(pipeline, fuse_mul_sigmoid_to_swish);
BACKEND_DNNL_ADD_PASS(pipeline, fuse_to_dnnl_sum);
BACKEND_DNNL_ADD_PASS(pipeline, fuse_to_shuffle);
BACKEND_DNNL_ADD_PASS(pipeline, decompose_softmax);
Contributor


Will this affect all patterns containing softmax?

Contributor Author


Good point, I think so. I'll need to update the patterns that contain softmax; currently there is only one such pattern, the single-softmax pattern.
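For context on what the decompose_softmax pass changes for such patterns: the decomposition corresponds to the standard numerically-stable softmax expansion. A small numpy sketch of the math (not the actual pass output; the comments use the op names that appear in the fused-kernel verbose line):

```python
import numpy as np

def decomposed_softmax(s: np.ndarray, axis: int = -1) -> np.ndarray:
    """Softmax written as the primitive ops a decomposition pass would emit."""
    shifted = s - s.max(axis=axis, keepdims=True)   # "subtract" (row max, for stability)
    e = np.exp(shifted)                             # "exp"
    return e / e.sum(axis=axis, keepdims=True)      # reduce_sum + divide

s = np.random.default_rng(0).standard_normal((2, 4, 8)).astype(np.float32)
out = decomposed_softmax(s)
# agrees with the naive single-op formulation on well-scaled inputs
assert np.allclose(out, np.exp(s) / np.exp(s).sum(-1, keepdims=True), atol=1e-6)
```

Any pattern that previously matched a single Softmax op would now see this sub/exp/reduce/div chain instead, which is why the single-softmax pattern needs updating.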

@TaoLv TaoLv force-pushed the yixin/sdpa_ukernel_train branch from 20240af to 5638f5f Compare March 19, 2026 02:09
