[MHA] add mha dispatch logic #776
Open
Purpose
Two different MHA kernel implementations are available on ROCm GPUs: the Triton kernel and the AITER kernels (including CK and ASM, dispatched inside AITER). We benchmarked their performance and observed a gap in various cases; based on these results, this PR adds a simple dispatch logic.
Benchmark result
We benchmark the shapes from DeepSeek-V3 in the TP8 scenario on MI355 for now, i.e., num_heads=16, qk_head_dim=192, v_head_dim=128. The seq_len ranges from 1k to 64k and the batch size from 1 to 64.
Here the aiter kernel corresponds to the FA3 ASM kernel. Overall, the ASM kernel demonstrates superior performance for seq_len of 4k and above, while the Triton kernel performs better at relatively short seq_len such as 1k.
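For illustration, below is a minimal sketch of what a seq_len-based dispatch could look like. The function name, threshold constant, and backend strings are hypothetical and only reflect the benchmark observations above; they are not the actual implementation in this PR.

```python
# Hypothetical sketch of the seq_len-based dispatch described above.
# The 4k cutoff reflects the benchmark observation on MI355 that the
# AITER FA3 asm kernel wins at seq_len >= 4k, while the Triton kernel
# wins at shorter sequence lengths (e.g. ~1k).

SEQ_LEN_THRESHOLD = 4096  # assumed cutoff between triton and aiter kernels


def select_mha_backend(seq_len: int) -> str:
    """Pick an MHA backend based on sequence length (illustrative helper)."""
    if seq_len >= SEQ_LEN_THRESHOLD:
        # Long sequences: the AITER FA3 asm kernel was faster in benchmarks.
        return "aiter"
    # Short sequences: the Triton kernel performed better.
    return "triton"


if __name__ == "__main__":
    print(select_mha_backend(1024))  # -> "triton"
    print(select_mha_backend(8192))  # -> "aiter"
```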
Test Result
Essential Elements of an Effective PR Description Checklist
supported_models.md and examples for a new model.