
Commit aae62b1

WIP enabling llama4 models
Signed-off-by: Artur Fierka <[email protected]>
1 parent: e3dd6a6

File tree

2 files changed: +23 −1 lines changed


vllm_gaudi/attention/backends/hpu_attn.py

Lines changed: 5 additions & 1 deletion
@@ -18,7 +18,7 @@
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer,
-                                              AttentionMetadata, AttentionType)
+                                              AttentionMetadata, AttentionType, AttentionMetadataBuilder)
 from vllm.attention.backends.mla.common import MLACommonImpl
 from vllm.attention.backends.utils import CommonAttentionState
 from vllm_gaudi.attention.ops.hpu_paged_attn import (HPUPagedAttention,

@@ -47,6 +47,10 @@ def get_metadata_cls() -> type["AttentionMetadata"]:
     def get_state_cls() -> type["CommonAttentionState"]:
         return CommonAttentionState
 
+    @staticmethod
+    def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
+        return HPUAttentionMetadataBuilder
+
     @staticmethod
     def get_kv_cache_shape(
         num_blocks: int,
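
Note that the added hook uses `Type[...]` and returns `HPUAttentionMetadataBuilder`, neither of which appears elsewhere in this commit: `Type` would need `from typing import Type`, and the builder the commit actually defines is `HPUPagedAttentionMetadataBuilder` in hpu_paged_attn.py. That is expected for a WIP change. As a hedged sketch only, assuming the enclosing class is the backend in this file and the intent is to reuse the new builder, the hook could be written with the built-in `type` to match `get_metadata_cls()` and `get_state_cls()`:

# Sketch, not the commit's code. The enclosing class name HPUAttentionBackend
# is an assumption; the builder import follows this commit's other file.
from vllm.attention.backends.abstract import (AttentionBackend,
                                              AttentionMetadataBuilder)
from vllm_gaudi.attention.ops.hpu_paged_attn import (
    HPUPagedAttentionMetadataBuilder)


class HPUAttentionBackend(AttentionBackend):

    @staticmethod
    def get_builder_cls() -> type["AttentionMetadataBuilder"]:
        # Built-in `type` matches the neighboring methods, so no
        # `from typing import Type` is required.
        return HPUPagedAttentionMetadataBuilder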

vllm_gaudi/attention/ops/hpu_paged_attn.py

Lines changed: 18 additions & 0 deletions
@@ -8,6 +8,7 @@
 from typing import Optional
 
 import torch
+from vllm.attention.backends.abstract import AttentionMetadataBuilder
 from vllm_gaudi.extension import cache_ops, ops
 
 # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.

@@ -24,6 +25,23 @@ class HPUPagedAttentionMetadata:
     alibi_blocks: Optional[torch.Tensor]
 
 
+@dataclass
+class HPUPagedAttentionMetadataBuilder(AttentionMetadataBuilder[HPUPagedAttentionMetadata]):
+
+    def __init__(self, input_builder: "ModelRunnerInputBuilderBase") -> None:
+        """Create the builder, remember some configuration and parameters."""
+        self.input_builder = input_builder
+
+    def prepare(self) -> None:
+        """Prepare for one batch."""
+        pass
+
+    def build(self, seq_lens: list[int], query_lens: list[int],
+              cuda_graph_pad_size: int, batch_size: int) -> HPUPagedAttentionMetadata:
+        """Build attention metadata with on-device tensors."""
+        return HPUPagedAttentionMetadata
+
+
 class HPUPagedAttention:
 
     @staticmethod
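
The builder here is still a stub: `build()` returns the `HPUPagedAttentionMetadata` class object rather than an instance, the `@dataclass` decorator is redundant once `__init__` is written by hand, and `ModelRunnerInputBuilderBase` appears only as a string annotation with no import. A hedged sketch of where this might land; the constructor call below is an assumption, since `alibi_blocks` is the only metadata field visible in this diff:

# Sketch only; the WIP commit returns the class object from build().
# HPUPagedAttentionMetadata and AttentionMetadataBuilder come from this
# file's existing definitions/imports; input_builder is left untyped because
# its class is not imported anywhere in the diff.
class HPUPagedAttentionMetadataBuilder(
        AttentionMetadataBuilder[HPUPagedAttentionMetadata]):

    def __init__(self, input_builder) -> None:
        # Keep a handle to the model-runner input builder for batch state.
        self.input_builder = input_builder

    def prepare(self) -> None:
        # Reset any per-batch state before each build() call.
        pass

    def build(self, seq_lens: list[int], query_lens: list[int],
              cuda_graph_pad_size: int,
              batch_size: int) -> HPUPagedAttentionMetadata:
        # Return an instance, not the class. Any fields beyond alibi_blocks
        # would need to be populated here as well (hypothetical).
        return HPUPagedAttentionMetadata(alibi_blocks=None)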
