Skip to content

Commit 87d319c

Browse files
authored
[AMD][CI] Support Triton attention with ExampleConnector (vllm-project#34931)
Signed-off-by: Ryan Rock <ryan.rock@amd.com>
1 parent a9ec392 commit 87d319c

File tree

3 files changed

+28
-17
lines changed

3 files changed

+28
-17
lines changed

tests/v1/kv_connector/unit/test_example_connector.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from vllm import LLM, EngineArgs, SamplingParams
1010
from vllm.assets.image import ImageAsset
11-
from vllm.config import KVTransferConfig
11+
from vllm.config import AttentionConfig, KVTransferConfig
1212
from vllm.multimodal.utils import encode_image_url
1313
from vllm.platforms import current_platform
1414

@@ -110,14 +110,17 @@ def process_prompt(processor, llm: LLM, question: str, image_urls: list[Image]):
110110
print("-" * 50)
111111

112112

113-
@pytest.mark.skipif(
114-
current_platform.is_rocm(),
115-
reason=(
116-
"hipErrorLaunchFailure when running this test, see issue:"
117-
"https://github.com/ROCm/pytorch/issues/2822"
113+
@pytest.mark.parametrize(
114+
"attn_backend",
115+
(
116+
["FLASH_ATTN", "TRITON_ATTN"]
117+
if current_platform.is_cuda()
118+
else ["TRITON_ATTN"]
119+
if current_platform.is_rocm()
120+
else []
118121
),
119122
)
120-
def test_shared_storage_connector_hashes(tmp_path):
123+
def test_shared_storage_connector_hashes(tmp_path, attn_backend):
121124
"""
122125
Tests that ExampleConnector saves KV to the storage locations
123126
with proper hashes; that are unique for inputs with identical text but
@@ -138,6 +141,7 @@ def test_shared_storage_connector_hashes(tmp_path):
138141
max_model_len=8192,
139142
max_num_seqs=1,
140143
gpu_memory_utilization=0.4,
144+
attention_config=AttentionConfig(backend=attn_backend),
141145
enforce_eager=True,
142146
kv_transfer_config=kv_transfer_config,
143147
limit_mm_per_prompt={"image": 2},

tests/v1/kv_connector/unit/test_multi_connector.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
2121
NixlKVConnectorStats,
2222
)
23-
from vllm.platforms import current_platform
2423

2524
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
2625

@@ -97,13 +96,6 @@ def _compare_directories(dir1: Path, dir2: Path) -> bool:
9796
return True
9897

9998

100-
@pytest.mark.skipif(
101-
current_platform.is_rocm(),
102-
reason=(
103-
"hipErrorLaunchFailure when running this test, see issue:"
104-
"https://github.com/ROCm/pytorch/issues/2822"
105-
),
106-
)
10799
def test_multi_example_connector_consistency():
108100
"""
109101
Tests that MultiConnector with two ExampleConnectors saves

vllm/distributed/kv_transfer/kv_connector/v1/example_connector.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from vllm.model_executor.layers.attention.mla_attention import MLACommonMetadata
1818
from vllm.utils.hashing import safe_hash
1919
from vllm.v1.attention.backend import AttentionMetadata
20+
from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
2021
from vllm.v1.core.sched.output import SchedulerOutput
2122

2223
if TYPE_CHECKING:
@@ -118,12 +119,12 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> Non
118119
The number of elements in kv_caches and layer_names should be
119120
the same.
120121
"""
121-
attn_metadata = forward_context.attn_metadata
122122

123123
def inject_kv_into_layer(
124124
dst_kv_cache_layer: torch.Tensor,
125125
src_kv_cache: torch.Tensor,
126126
slot_mapping: torch.Tensor,
127+
attn_metadata: AttentionMetadata,
127128
) -> None:
128129
"""Inject the KV cache into the layer.
129130
@@ -145,6 +146,10 @@ def inject_kv_into_layer(
145146
num_pages * page_size, -1
146147
)
147148
dst_kv_cache_layer[slot_mapping, ...] = src_kv_cache
149+
elif isinstance(attn_metadata, TritonAttentionMetadata):
150+
block_idxs = slot_mapping // self._block_size
151+
offsets = slot_mapping % self._block_size
152+
dst_kv_cache_layer[block_idxs, :, offsets] = src_kv_cache
148153
else:
149154
num_pages = dst_kv_cache_layer_shape[1]
150155
page_size = dst_kv_cache_layer_shape[2]
@@ -186,7 +191,13 @@ def inject_kv_into_layer(
186191
layer_name, request.token_ids, request.mm_hashes
187192
)
188193
kv_cache = safetensors.torch.load_file(filename)["kv_cache"].cuda()
189-
inject_kv_into_layer(kv_cache_layer, kv_cache, request.slot_mapping)
194+
if isinstance(attn_metadata, dict):
195+
inject_kv_into_layer(
196+
kv_cache_layer,
197+
kv_cache,
198+
request.slot_mapping,
199+
attn_metadata[layer_name],
200+
)
190201

191202
def wait_for_layer_load(self, layer_name: str) -> None:
192203
"""Blocking until the KV for a specific layer is loaded into vLLM's
@@ -229,6 +240,10 @@ def extract_kv_from_layer(
229240
if isinstance(attn_metadata, MLACommonMetadata):
230241
num_pages, page_size = layer.shape[0], layer.shape[1]
231242
return layer.reshape(num_pages * page_size, -1)[slot_mapping, ...]
243+
elif isinstance(attn_metadata, TritonAttentionMetadata):
244+
block_idxs = slot_mapping // self._block_size
245+
offsets = slot_mapping % self._block_size
246+
return layer[block_idxs, :, offsets]
232247
num_pages, page_size = layer.shape[1], layer.shape[2]
233248
return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping, ...]
234249

0 commit comments

Comments (0)