Optimize layout for SubgroupMatrixLoad on Intel #25384

Open · wants to merge 4 commits into main
Conversation

jchen10
Contributor

@jchen10 jchen10 commented Jul 14, 2025

This introduces a new LayoutProgram to pre-process the input matrix A, converting it to a layout that is more efficient for the SubgroupMatrixLoad operation on Intel GPUs.
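The actual transform is a WebGPU shader inside the new LayoutProgram; the following is only an illustrative NumPy sketch of the repacking idea, assuming 8x16 subgroup-matrix tiles as discussed in the review below (the function name and tile constants are mine, not from the PR):

```python
import numpy as np

# Illustrative sketch (not the actual WGSL shader): repack a row-major
# matrix A so that each 8x16 subgroup-matrix tile becomes contiguous,
# letting SubgroupMatrixLoad read each tile as one linear span.
TILE_ROWS, TILE_COLS = 8, 16

def relayout_for_subgroup_load(a: np.ndarray) -> np.ndarray:
    m, k = a.shape
    assert m % TILE_ROWS == 0 and k % TILE_COLS == 0
    # Split A into a grid of [8x16] tiles, then lay the tiles out one
    # after another in row-major tile order. A [32x64] input becomes
    # [128x16]: 16 tiles of 8 rows each, stacked vertically.
    tiles = (a.reshape(m // TILE_ROWS, TILE_ROWS, k // TILE_COLS, TILE_COLS)
              .transpose(0, 2, 1, 3))   # [tile_row, tile_col, 8, 16]
    return tiles.reshape(-1, TILE_COLS)
```

With this packing, the first 8 output rows are exactly tile s0 (`a[0:8, 0:16]`), the next 8 are s1 (`a[0:8, 16:32]`), and so on, which is the layout qjia7 diagrams in the review below.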

@jchen10
Contributor Author

jchen10 commented Jul 14, 2025

On LNL with the latest driver 32.0.101.6913, Prefill can reach 828 tps.

model_benchmark.exe -i ..\models\phi3.5-web-accuracy4-gqa --prompt_file prompt.txt -g 128 -r 10
Batch size: 1, prompt tokens: 1024, tokens to generate: 128
Prompt processing (time to first token):
        avg (us):       1.23556e+06
        avg (tokens/s): 828.772
        p50 (us):       1.24797e+06
        stddev (us):    23209
        n:              10 * 1024 token(s)
Token generation:
        avg (us):       35356.1
        avg (tokens/s): 28.2837
        p50 (us):       35227.7
        stddev (us):    1167.08
        n:              1270 * 1 token(s)
Token sampling:
        avg (us):       10.38
        avg (tokens/s): 96339.1
        p50 (us):       7.5
        stddev (us):    6.6114
        n:              10 * 1 token(s)
E2E generation (entire generation loop):
        avg (ms):       5725.89
        p50 (ms):       5728.04
        stddev (ms):    25.9712
        n:              10
Peak working set size (bytes): 4416569344

@xhcao @JianhuiD PTAL

Contributor

@qjia7 qjia7 left a comment


Excellent work, Jie!

If I understand correctly, your layout shader works as below (assume s0 is subgroup matrix 0).
Input: [32x64] with subgroup matrix [8x16]

s0, s1, s2, s3,
s4, s5, s6, s7
s8, s9, s10, s11
s12, s13, s14, s15

Output: [128x16] with subgroup matrix [8x16]

s0,
s1,
s2,
s3, 
s4,
s5,
s6,
s7,
s8,
s9,
s10,
s11,
s12,
s13,
s14,
s15,

This change ensures that each subgroup's data is contiguous in memory. I am wondering whether it would further help performance if we reassigned the layout like below:

s0,
s4,
s8,
s12,
s1,
s5,
s9,
s13,
s2,
s6,
s10,
s14,
s3,
s7,
s11,
s15

This would make sure that all subgroups in one workgroup access contiguous data in memory, instead of only the data within each single subgroup being contiguous. Just curious about the result.
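For illustration, the suggested reordering amounts to linearizing the grid of 8x16 tiles in column-major instead of row-major order; a hedged NumPy sketch (function name and tile constants are mine, not from the PR):

```python
import numpy as np

# Sketch of the suggested alternative: linearize the grid of 8x16 tiles
# in column-major order (s0, s4, s8, s12, s1, ...), so that consecutive
# subgroups within a workgroup touch adjacent memory, rather than
# row-major order (s0, s1, s2, ...) where only each individual
# subgroup's tile is contiguous.
TILE_ROWS, TILE_COLS = 8, 16

def relayout_column_major_tiles(a: np.ndarray) -> np.ndarray:
    m, k = a.shape
    assert m % TILE_ROWS == 0 and k % TILE_COLS == 0
    tiles = (a.reshape(m // TILE_ROWS, TILE_ROWS, k // TILE_COLS, TILE_COLS)
              .transpose(2, 0, 1, 3))   # [tile_col, tile_row, 8, 16]
    return tiles.reshape(-1, TILE_COLS)
```

For a [32x64] input the output is still [128x16], but the second 8-row block is now s4 (`a[8:16, 0:16]`) instead of s1, matching the s0, s4, s8, s12, s1, ... ordering diagrammed above.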

@guschmue guschmue added the ep:WebGPU ort-web webgpu provider label Jul 14, 2025
@jchen10
Contributor Author

jchen10 commented Jul 15, 2025

> Excellent work, Jie! If I understand correctly, your layout shader is like below … Just curious about the result.

Good point. I tried this approach; unfortunately it only helped slightly, by less than 20 tps.

@jchen10
Contributor Author

jchen10 commented Jul 15, 2025

Without this PR, the perf data is as below. So the improvement is 628 -> 828 tps, about 32%.

model_benchmark.exe -i ..\models\phi3.5-web-accuracy4-gqa --prompt_file prompt.txt -g 128 -r 10
Batch size: 1, prompt tokens: 1024, tokens to generate: 128
Prompt processing (time to first token):
        avg (us):       1.62918e+06
        avg (tokens/s): 628.536
        p50 (us):       1.64457e+06
        stddev (us):    26716.1
        n:              10 * 1024 token(s)
Token generation:
        avg (us):       35091.2
        avg (tokens/s): 28.4971
        p50 (us):       35000.6
        stddev (us):    1001.65
        n:              1270 * 1 token(s)
Token sampling:
        avg (us):       7.29
        avg (tokens/s): 137174
        p50 (us):       7.3
        stddev (us):    0.276687
        n:              10 * 1 token(s)
E2E generation (entire generation loop):
        avg (ms):       6085.87
        p50 (ms):       6096.89
        stddev (ms):    34.2907
        n:              10
Peak working set size (bytes): 4398272512

qjia7
qjia7 previously approved these changes Jul 15, 2025
Contributor

@qjia7 qjia7 left a comment


LGTM, thanks!

@jchen10
Contributor Author

jchen10 commented Jul 15, 2025

@sushraja-msft PTAL

guschmue
guschmue previously approved these changes Jul 15, 2025
sushraja-msft
sushraja-msft previously approved these changes Jul 15, 2025
Contributor

@sushraja-msft sushraja-msft left a comment


LGTM otherwise

@guschmue
Contributor

CI nagging: run 'lintrunner -a'

@jchen10 jchen10 dismissed stale reviews from sushraja-msft, guschmue, and qjia7 via 291f374 July 16, 2025 05:10
@jchen10
Contributor Author

jchen10 commented Jul 16, 2025

CI nagging: run 'lintrunner -a'

Done, thanks!

qjia7
qjia7 previously approved these changes Jul 17, 2025
Contributor

@qjia7 qjia7 left a comment


One more comment. Still LGTM otherwise.

@jchen10
Contributor Author

jchen10 commented Jul 18, 2025

> One more comment. Still LGTM otherwise.

Done, thanks!

@fs-eire @guschmue Could you please help kick off the CI, thanks!

@jchen10 jchen10 requested a review from guschmue July 21, 2025 14:05