@@ -198,7 +198,8 @@ def __init__(self, device: torch.device):
198
198
199
199
200
200
def run_attention_backend (backend : _Backend , kv_cache_spec : FullAttentionSpec ,
201
- vllm_config , device : torch .device ,
201
+ layer_names : list [str ], vllm_config ,
202
+ device : torch .device ,
202
203
common_attn_metadata : CommonAttentionMetadata ,
203
204
query : torch .Tensor , key : torch .Tensor ,
204
205
value : torch .Tensor ,
@@ -211,31 +212,33 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
211
212
if backend == _Backend .FLASHINFER_VLLM_V1 :
212
213
import unittest .mock
213
214
214
- from vllm .v1 .attention .backends .flashinfer import PerLayerParameters
215
+ from vllm .v1 .attention .backends .utils import PerLayerParameters
215
216
216
- def mock_get_per_layer_parameters (vllm_config , impl_cls ):
217
+ def mock_get_per_layer_parameters (vllm_config , layer_names , impl_cls ):
217
218
# Return mock parameters for a single layer
218
219
head_size = vllm_config .model_config .get_head_size ()
219
220
return {
220
- "mock_layer" :
221
+ layer_name :
221
222
PerLayerParameters (
222
223
window_left = - 1 , # No sliding window
223
224
logits_soft_cap = 0.0 , # No soft cap
224
225
sm_scale = 1.0 / (head_size ** 0.5 ) # Standard scale
225
226
)
227
+ for layer_name in layer_names
226
228
}
227
229
228
230
with unittest .mock .patch (
229
231
'vllm.v1.attention.backends.flashinfer.get_per_layer_parameters' ,
230
232
mock_get_per_layer_parameters ):
231
- builder = builder_cls (kv_cache_spec , vllm_config , device )
233
+ builder = builder_cls (kv_cache_spec , layer_names , vllm_config ,
234
+ device )
232
235
attn_metadata = builder .build (
233
236
common_prefix_len = 0 ,
234
237
common_attn_metadata = common_attn_metadata ,
235
238
)
236
239
else :
237
240
# Build metadata
238
- builder = builder_cls (kv_cache_spec , vllm_config , device )
241
+ builder = builder_cls (kv_cache_spec , layer_names , vllm_config , device )
239
242
attn_metadata = builder .build (
240
243
common_prefix_len = 0 ,
241
244
common_attn_metadata = common_attn_metadata ,
@@ -427,8 +430,8 @@ def test_backend_correctness(batch_spec_name: str, model: str):
427
430
set_kv_cache_layout ("HND" )
428
431
429
432
backend_output = run_attention_backend (backend_name , kv_cache_spec ,
430
- vllm_config , device ,
431
- common_attn_metadata ,
433
+ [ "placeholder" ], vllm_config ,
434
+ device , common_attn_metadata ,
432
435
query_vllm , key_vllm ,
433
436
value_vllm ,
434
437
kv_cache_for_backend )
0 commit comments