@@ -150,20 +150,35 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
         # TODO(lucas): handle this more gracefully
         # Note: model_config may be None during testing
         if model_config is not None and model_config.use_mla:
-            # if `VLLM_ATTENTION_BACKEND` is not set and we are using MLA, then
-            # we default to FlashMLA backend, so we need to force the blocksize
-            # here
-            use_flashmla = (envs.VLLM_ATTENTION_BACKEND is None \
-                            or envs.VLLM_ATTENTION_BACKEND == "FLASHMLA")
+            # If `VLLM_ATTENTION_BACKEND` is not set and we are using MLA,
+            # then we default to the FlashMLA backend for non-Blackwell GPUs
+            # and to CutlassMLA on Blackwell. In each case, we force the
+            # required block_size.
+            use_flashmla = False
+            use_cutlass_mla = False
+
+            if envs.VLLM_ATTENTION_BACKEND is None:
+                # Default case
+                if cls.is_device_capability(100):
+                    # Blackwell => force CutlassMLA.
+                    use_cutlass_mla = True
+                    envs.VLLM_ATTENTION_BACKEND = "CUTLASS_MLA_VLLM_V1"
+                else:
+                    # Not Blackwell
+                    use_flashmla = True
+            else:
+                # Forced case
+                use_flashmla = (envs.VLLM_ATTENTION_BACKEND == "FLASHMLA")
+                use_cutlass_mla = (
+                    envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA_VLLM_V1")
+
             from vllm.attention.ops.flashmla import is_flashmla_supported
             if use_flashmla and is_flashmla_supported()[0] \
                     and cache_config.block_size != 64:
                 cache_config.block_size = 64
                 logger.info(
                     "Forcing kv cache block size to 64 for FlashMLA backend.")

-            use_cutlass_mla = (envs.VLLM_ATTENTION_BACKEND is not None \
-                               and envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA_VLLM_V1")
             if use_cutlass_mla and cache_config.block_size != 128:
                 cache_config.block_size = 128
                 logger.info("Forcing kv cache block size to 128 for "
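For reference, a minimal standalone sketch of the selection logic this diff introduces. It is not vLLM code: `backend_env`, `is_blackwell`, `flashmla_supported`, and `pick_mla_backend_and_block_size` are hypothetical stand-ins for `envs.VLLM_ATTENTION_BACKEND`, `cls.is_device_capability(100)`, `is_flashmla_supported()[0]`, and the in-place config update, and it collapses the env-var mutation into a plain return value.

```python
from typing import Optional, Tuple


def pick_mla_backend_and_block_size(
    backend_env: Optional[str],  # stand-in for envs.VLLM_ATTENTION_BACKEND
    is_blackwell: bool,          # stand-in for cls.is_device_capability(100)
    flashmla_supported: bool,    # stand-in for is_flashmla_supported()[0]
    block_size: int,             # stand-in for cache_config.block_size
) -> Tuple[str, int]:
    """Sketch of the diff's logic: pick a default MLA backend per GPU
    generation, then force the block size that backend requires."""
    if backend_env is None:
        # Default case: CutlassMLA on Blackwell (SM100), FlashMLA elsewhere.
        backend = "CUTLASS_MLA_VLLM_V1" if is_blackwell else "FLASHMLA"
    else:
        # Forced case: honor the user's explicit choice.
        backend = backend_env

    if backend == "FLASHMLA" and flashmla_supported and block_size != 64:
        block_size = 64    # FlashMLA requires a KV-cache block size of 64.
    elif backend == "CUTLASS_MLA_VLLM_V1" and block_size != 128:
        block_size = 128   # CutlassMLA requires a KV-cache block size of 128.
    return backend, block_size


# No backend forced, Blackwell GPU, default 16-token blocks.
assert pick_mla_backend_and_block_size(None, True, True, 16) == \
    ("CUTLASS_MLA_VLLM_V1", 128)
# No backend forced, pre-Blackwell GPU with FlashMLA available.
assert pick_mla_backend_and_block_size(None, False, True, 16) == \
    ("FLASHMLA", 64)
```

Note that, as in the diff, a forced `FLASHMLA` choice on hardware where FlashMLA is unsupported leaves the block size untouched; only a usable FlashMLA triggers the override to 64.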