
Commit c0858f1

alexm-redhat authored and mgoin committed
[Attention] Make CutlassMLA the default backend for SM100 (blackwell) (vllm-project#21626)
Signed-off-by: Alexander Matveev <[email protected]>
Signed-off-by: mgoin <[email protected]>
Co-authored-by: mgoin <[email protected]>
Signed-off-by: Diego-Castan <[email protected]>
1 parent 69fc3b2 commit c0858f1
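
For illustration, a minimal sketch of the default-selection rule this commit introduces (the function name and return shape below are hypothetical, not vLLM's API; see the diff for the actual logic): on SM100 (Blackwell, device capability 100) the default MLA backend becomes CutlassMLA, which requires a kv-cache block size of 128, while all other GPUs keep the previous FlashMLA default with a block size of 64.

def default_mla_backend(device_capability: int) -> tuple[str, int]:
    """Return (attention backend name, required kv-cache block_size)."""
    if device_capability == 100:
        # SM100 / Blackwell => CutlassMLA, which needs 128-token blocks.
        return "CUTLASS_MLA_VLLM_V1", 128
    # All other GPUs keep the previous FlashMLA default (64-token blocks).
    return "FLASHMLA", 64

assert default_mla_backend(100) == ("CUTLASS_MLA_VLLM_V1", 128)
assert default_mla_backend(90) == ("FLASHMLA", 64)  # e.g. Hopper (SM90)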

File tree

1 file changed (+22, -7 lines)


vllm/platforms/cuda.py

Lines changed: 22 additions & 7 deletions
@@ -150,20 +150,35 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
         # TODO(lucas): handle this more gracefully
         # Note: model_config may be None during testing
         if model_config is not None and model_config.use_mla:
-            # if `VLLM_ATTENTION_BACKEND` is not set and we are using MLA, then
-            # we default to FlashMLA backend, so we need to force the blocksize
-            # here
-            use_flashmla = (envs.VLLM_ATTENTION_BACKEND is None \
-                or envs.VLLM_ATTENTION_BACKEND == "FLASHMLA")
+            # If `VLLM_ATTENTION_BACKEND` is not set and we are using MLA,
+            # then we default to FlashMLA backend for non-blackwell GPUs,
+            # else we default to CutlassMLA. For each case, we force the
+            # required block_size.
+            use_flashmla = False
+            use_cutlass_mla = False
+
+            if envs.VLLM_ATTENTION_BACKEND is None:
+                # Default case
+                if cls.is_device_capability(100):
+                    # Blackwell => Force CutlassMLA.
+                    use_cutlass_mla = True
+                    envs.VLLM_ATTENTION_BACKEND = "CUTLASS_MLA_VLLM_V1"
+                else:
+                    # Not Blackwell
+                    use_flashmla = True
+            else:
+                # Forced case
+                use_flashmla = (envs.VLLM_ATTENTION_BACKEND == "FLASHMLA")
+                use_cutlass_mla = (
+                    envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA_VLLM_V1")

             from vllm.attention.ops.flashmla import is_flashmla_supported
             if use_flashmla and is_flashmla_supported()[0] \
                     and cache_config.block_size != 64:
                 cache_config.block_size = 64
                 logger.info(
                     "Forcing kv cache block size to 64 for FlashMLA backend.")

-            use_cutlass_mla = (envs.VLLM_ATTENTION_BACKEND is not None \
-                and envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA_VLLM_V1")
             if use_cutlass_mla and cache_config.block_size != 128:
                 cache_config.block_size = 128
                 logger.info("Forcing kv cache block size to 128 for "

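As before, the default can be overridden explicitly (the "Forced case" branch in the diff above) via the VLLM_ATTENTION_BACKEND environment variable. A usage sketch, assuming the variable is set before vLLM reads its environment:

import os

# Pin the MLA attention backend explicitly instead of relying on the
# per-GPU default; must happen before vLLM reads its environment.
os.environ["VLLM_ATTENTION_BACKEND"] = "CUTLASS_MLA_VLLM_V1"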