diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index 016102b814ed..6d623fb2c097 100644
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -21,7 +21,7 @@
 (https://arxiv.org/abs/2405.04434 and
 https://github.com/flashinfer-ai/flashinfer/pull/551).
 
 Deepseek's MLA attention works the following way:
-* Use a single latent vector to represent the per-token entry of the KV cache. 
+* Use a single latent vector to represent the per-token entry of the KV cache.
 * For decode (i.e. the memory friendly approach) the attention "simulates" a
 multi-head attention, while the compute is similar to multi-query attention.
@@ -79,7 +79,7 @@
     torch.cat([q_nope, q_pe], dim=-1),
     torch.cat([k_nope, k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1),
     v
-) 
+)
 return spda_o @ W_O
 
 NOTE: in the actual code,
@@ -117,20 +117,20 @@
 
 ## Chunked Prefill
 
-For chunked prefill we want to use the compute friendly algorithm. We are 
-assuming sufficiently large Sq / Skv ratio, in the future may want to switch to 
+For chunked prefill we want to use the compute-friendly algorithm. We are
+assuming a sufficiently large Sq / Skv ratio; in the future we may want to switch to
 the data-movement friendly approach if the chunk (i.e. `Sq`) is small.
 
 However, the compute-friendly approach can potentially run out of memory if Skv
 is large due to: `k_nope = (kv_c @ W_UK).view(Skv, N, P)`
 
-To mitigate this, we chunk the computation of attention with respect to the 
-current context (i.e. `cache_kv_c` and `cache_k_pe`) so that we can used a 
+To mitigate this, we chunk the computation of attention with respect to the
+current context (i.e. `cache_kv_c` and `cache_k_pe`) so that we can use a
 fixed workspace size.
 
 The chunked prefill approach is as follows:
 
-MCC        Max chunk of context to process per iter, computed dynamically, 
+MCC        Max chunk of context to process per iter, computed dynamically,
            used to bound the memory usage
 
 q_c        = h_t @ W_DQ
@@ -152,7 +152,7 @@
     new_v,
     causal=True,
     return_softmax_lse=True
-) 
+)
 
 // Compute attention with the already existing context
 for chunk_idx in range(cdiv(C, MCC)):
@@ -193,15 +193,21 @@
 
 import vllm.envs as envs
 from vllm import _custom_ops as ops
-from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
-                                              AttentionMetadata,
-                                              MLAAttentionImpl)
+from vllm.attention.backends.abstract import (
+    AttentionBackend,
+    AttentionLayer,
+    AttentionMetadata,
+    MLAAttentionImpl,
+)
 from vllm.attention.backends.utils import get_mla_dims
 from vllm.attention.ops.merge_attn_states import merge_attn_states
 from vllm.logger import init_logger
-from vllm.model_executor.layers.linear import (ColumnParallelLinear,
-                                               LinearBase, RowParallelLinear,
-                                               UnquantizedLinearMethod)
+from vllm.model_executor.layers.linear import (
+    ColumnParallelLinear,
+    LinearBase,
+    RowParallelLinear,
+    UnquantizedLinearMethod,
+)
 from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
 from vllm.platforms import current_platform
 from vllm.utils import cdiv, round_down
@@ -209,10 +215,12 @@
 
 try:
     from vllm.vllm_flash_attn import flash_attn_varlen_func
+
     is_vllm_fa = True
 except ImportError:
     # For rocm use upstream flash attention
     from flash_attn import flash_attn_varlen_func
+
     is_vllm_fa = False
 
 if TYPE_CHECKING:
@@ -224,6 +232,7 @@
     from aiter.ops.triton.fused_qk_concat import fused_qk_rope_cat_and_cache_mla
 
     if envs.VLLM_AITER_TRITON_FP8_BMM:
+
         def dynamic_per_batched_tensor_quant(
             x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn
         ):
@@ -234,18 +243,28 @@ def dynamic_per_batched_tensor_quant(
             x_scl_sat = (x * scale).clamp(min=-DTYPE_MAX, max=DTYPE_MAX)
             return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
 
-        from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant
+        from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import (
+            batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant,
+        )
+
         # @torch.compiler.disable
-        def aiter_triton_fp8_bmm_wrapper(x, w, w_s, group_size = 128, y = None, transpose_bm = False):
+        def aiter_triton_fp8_bmm_wrapper(
+            x, w, w_s, group_size=128, y=None, transpose_bm=False
+        ):
             if y is not None:
-                batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant(x, w, w_s, group_size = group_size, YQ=y, transpose_bm=transpose_bm)
+                batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant(
+                    x, w, w_s, group_size=group_size, YQ=y, transpose_bm=transpose_bm
+                )
             else:
-                y = batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant(x, w, w_s, group_size = group_size, transpose_bm = transpose_bm)
+                y = batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant(
+                    x, w, w_s, group_size=group_size, transpose_bm=transpose_bm
+                )
             return y
-    
+
+
 if envs.VLLM_AITER_TRITON_FUSED_CONCAT_ZEROS:
     from aiter.ops.triton.fused_concat_zeros import fused_concat_zeros
-    
+
 
 logger = init_logger(__name__)
@@ -285,7 +304,7 @@ def use_cascade_attention(*args, **kwargs) -> bool:
 
 @dataclass
 class MLACommonPrefillMetadata:
-    """ Prefill Specific Metadata """
+    """Prefill Specific Metadata"""
 
     @dataclass
     class ChunkedContextMetadata:
@@ -325,6 +344,7 @@ class MLACommonMetadata(Generic[D]):
     NOTE: Please read the comment at the top of the file before trying to
     understand this class
     """
+
     # NOTE(sang): Definition of context_len, query_len, and seq_len.
     #   |---------- N-1 iteration --------|
     #   |---------------- N iteration ---------------------|
@@ -354,11 +374,11 @@ class MLACommonMetadata(Generic[D]):
 
     def __post_init__(self):
         supported_head_sizes = MLACommonBackend.get_supported_head_sizes()
-        if self.head_dim is not None and self.head_dim \
-                not in supported_head_sizes:
+        if self.head_dim is not None and self.head_dim not in supported_head_sizes:
             raise ValueError(
                 f"Only {supported_head_sizes} are supported for head_dim, "
-                f"received {self.head_dim}.")
+                f"received {self.head_dim}."
+            )
 
 
 M = TypeVar("M", bound=MLACommonMetadata)
@@ -370,18 +390,18 @@ class MLACommonMetadataBuilder(Generic[M]):
     understand this class
     """
 
-    def __init__(self,
-                 runner: "GPUModelRunner",
-                 metadata_cls: Optional[type[M]] = None):
-        self.metadata_cls = metadata_cls \
-            if metadata_cls is not None else MLACommonMetadata
+    def __init__(
+        self, runner: "GPUModelRunner", metadata_cls: Optional[type[M]] = None
+    ):
+        self.metadata_cls = (
+            metadata_cls if metadata_cls is not None else MLACommonMetadata
+        )
        self.runner = runner
         scheduler_config = runner.scheduler_config
         model_config = runner.model_config
         cache_config = runner.cache_config
         self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
-        self.num_heads = model_config.get_num_attention_heads(
-            runner.parallel_config)
+        self.num_heads = model_config.get_num_attention_heads(runner.parallel_config)
         self.mla_dims = get_mla_dims(model_config)
         self.aot_schedule = is_vllm_fa and (get_flash_attn_version() == 3)
@@ -394,8 +414,9 @@ def __init__(self,
                 # Make sure there is enough for 8 full-length requests or at least
                 # 4 pages of cache per request
                 max(
-                    8 * model_config.max_model_len, 4 *
-                    scheduler_config.max_num_seqs * cache_config.block_size),
+                    8 * model_config.max_model_len,
+                    4 * scheduler_config.max_num_seqs * cache_config.block_size,
+                ),
                 # For long-context models, try not to over-allocate (which
                 # limits kv-cache space); cap the workspace at 64k tokens,
                 # which would result in the workspace being:
@@ -404,18 +425,21 @@ def __init__(self,
                 # which would result in up-projected context being
                 # 2*(192*128)*(64*1024) = 3gb
                 # (assuming 192 QK head dim, 128 heads, and fp16)
-                128 * 1024)
-            assert self.chunked_prefill_workspace_size >= \
-                scheduler_config.max_num_seqs * cache_config.block_size
+                128 * 1024,
+            )
+            assert (
+                self.chunked_prefill_workspace_size
+                >= scheduler_config.max_num_seqs * cache_config.block_size
+            )
             self.chunked_prefill_workspace = torch.empty(
-                (self.chunked_prefill_workspace_size,
-                 model_config.get_head_size()),
+                (self.chunked_prefill_workspace_size, model_config.get_head_size()),
                 dtype=model_config.dtype,
                 device=runner.device,
             )
 
-    def reorder_batch(self, input_batch: "InputBatch",
-                      scheduler_output: "SchedulerOutput") -> bool:
+    def reorder_batch(
+        self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput"
+    ) -> bool:
         # We now want to reorder the batch so that the "decode" requests are at
         # the front and the "prefill" requests are at the back, using the least
         # number of swaps possible. (NOTE for now we loosely use "decode" to mean requests
@@ -474,14 +498,17 @@ def reorder_batch(self, input_batch: "InputBatch",
 
         return modified_batch
 
-    def _build_decode(self, input_positions: torch.Tensor,
-                      block_table: torch.Tensor,
-                      work_indptr: torch.Tensor,
-                      work_info_set: torch.Tensor,
-                      reduce_indptr: torch.Tensor,
-                      reduce_final_map: torch.Tensor,
-                      reduce_partial_map: torch.Tensor,
-                      seq_lens: torch.Tensor):
+    def _build_decode(
+        self,
+        input_positions: torch.Tensor,
+        block_table: torch.Tensor,
+        work_indptr: torch.Tensor,
+        work_info_set: torch.Tensor,
+        reduce_indptr: torch.Tensor,
+        reduce_final_map: torch.Tensor,
+        reduce_partial_map: torch.Tensor,
+        seq_lens: torch.Tensor,
+    ):
         return MLACommonDecodeMetadata(
             input_positions=input_positions,
             block_table=block_table,
@@ -493,22 +520,33 @@ def _build_decode(self, input_positions: torch.Tensor,
             seq_lens=seq_lens,
         )
 
-    def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
-              common_prefix_len: int) -> M:
+    def build(
+        self,
+        num_reqs: int,
+        num_actual_tokens: int,
+        max_query_len: int,
+        common_prefix_len: int,
+    ) -> M:
         assert self._num_decodes + self._num_prefills == num_reqs
 
         # Note(simon): be careful about the CPU <> GPU memory movement in this
         # function. We should avoid GPU -> CPU sync as much as possible because
         # it blocks on all previous kernels.
         device = self.runner.device
-        block_table = (
-            self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
-        query_start_loc = self.runner.query_start_loc_cpu[:num_reqs + 1].to(
-            device, non_blocking=True)
-        slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
-            device, non_blocking=True).long()
-        input_positions = self.runner.positions_cpu[:num_actual_tokens].to(
-            device, non_blocking=True).long()
+        block_table = self.runner.input_batch.block_table.get_device_tensor()[:num_reqs]
+        query_start_loc = self.runner.query_start_loc_cpu[: num_reqs + 1].to(
+            device, non_blocking=True
+        )
+        slot_mapping = (
+            self.runner.slot_mapping_cpu[:num_actual_tokens]
+            .to(device, non_blocking=True)
+            .long()
+        )
+        input_positions = (
+            self.runner.positions_cpu[:num_actual_tokens]
+            .to(device, non_blocking=True)
+            .long()
+        )
 
         seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs]
         seq_lens = seq_lens_cpu.to(device, non_blocking=True)
@@ -518,16 +556,21 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
             reqs_start = self._num_decodes  # prefill_start
             tokens_start = self._num_decode_tokens
 
-            context_lens_cpu = self.runner.input_batch.\
-                num_computed_tokens_cpu_tensor[reqs_start:num_reqs]
+            context_lens_cpu = self.runner.input_batch.num_computed_tokens_cpu_tensor[
+                reqs_start:num_reqs
+            ]
             max_context_len_cpu = context_lens_cpu.max().item()
             num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
-            prefill_query_start_loc = query_start_loc[
-                reqs_start:] - query_start_loc[reqs_start]
+            prefill_query_start_loc = (
+                query_start_loc[reqs_start:] - query_start_loc[reqs_start]
+            )
 
             chunked_context_metadata = None
-            if self.chunked_prefill_enabled and self._num_prefills > 0 \
-                and max_context_len_cpu > 0:
+            if (
+                self.chunked_prefill_enabled
+                and self._num_prefills > 0
+                and max_context_len_cpu > 0
+            ):
                 # NOTE: it is recommended you read the `Chunked Prefill` section
                 # in the comment at the top of the file before trying to
                 # understand the following code
@@ -536,15 +579,15 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
                 # prefill in the batch, we could probably use a more advanced
                 # algorithm here and allocate more workspace to prefills with
                 # longer context lengths
-                max_context_chunk = (self.chunked_prefill_workspace_size //
-                                     num_prefills_with_context_cpu)
+                max_context_chunk = (
+                    self.chunked_prefill_workspace_size // num_prefills_with_context_cpu
+                )
 
                 if self.aot_schedule:
                     # align max_context_chunk to page_size by rounding down,
                     # currently the `gather_cache` kernel cannot handle
                     # `context_chunk_starts` that are not aligned to page_size
-                    max_context_chunk = round_down(max_context_chunk,
-                                                   self.page_size)
+                    max_context_chunk = round_down(max_context_chunk, self.page_size)
 
                 assert max_context_chunk > 0
                 num_chunks = cdiv(max_context_len_cpu, max_context_chunk)
@@ -555,34 +598,41 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
                 # [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]]
                 # Note(simon): this is done in CPU because of the downstream
                 # use of `to_list`.
-                chunk_starts = \
-                    torch.arange(num_chunks, dtype=torch.int32) \
-                    .unsqueeze(1).expand(-1, self._num_prefills) \
+                chunk_starts = (
+                    torch.arange(num_chunks, dtype=torch.int32)
+                    .unsqueeze(1)
+                    .expand(-1, self._num_prefills)
                     * max_context_chunk
-                chunk_ends = torch.min(context_lens_cpu.unsqueeze(0),
-                                       chunk_starts + max_context_chunk)
+                )
+                chunk_ends = torch.min(
+                    context_lens_cpu.unsqueeze(0), chunk_starts + max_context_chunk
+                )
                 chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
-                cu_seq_lens_cpu = torch.zeros(num_chunks,
-                                              self._num_prefills + 1,
-                                              dtype=torch.int32,
-                                              pin_memory=True)
-                torch.cumsum(chunk_seq_lens,
-                             dim=1,
-                             out=cu_seq_lens_cpu[:, 1:],
-                             dtype=torch.int32)
+                cu_seq_lens_cpu = torch.zeros(
+                    num_chunks,
+                    self._num_prefills + 1,
+                    dtype=torch.int32,
+                    pin_memory=True,
+                )
+                torch.cumsum(
+                    chunk_seq_lens, dim=1, out=cu_seq_lens_cpu[:, 1:], dtype=torch.int32
+                )
 
-                chunked_context_metadata = \
+                chunked_context_metadata = (
                     MLACommonPrefillMetadata.ChunkedContextMetadata(
-                    cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
-                    starts=chunk_starts.to(device, non_blocking=True),
-                    seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
-                    max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
-                    workspace=self.chunked_prefill_workspace,
+                        cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
+                        starts=chunk_starts.to(device, non_blocking=True),
+                        seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
+                        max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
+                        workspace=self.chunked_prefill_workspace,
+                    )
                 )
 
-                assert max(chunked_context_metadata.max_seq_lens) <= \
-                    self.chunked_prefill_workspace_size
+                assert (
+                    max(chunked_context_metadata.max_seq_lens)
+                    <= self.chunked_prefill_workspace_size
+                )
 
             prefill_metadata = MLACommonPrefillMetadata(
                 input_positions=input_positions[tokens_start:],
@@ -595,9 +645,14 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
         decode_metadata = None
         if self._num_decodes > 0:
             decode_metadata = self._build_decode(
-                input_positions=input_positions[:self._num_decode_tokens],
-                block_table=block_table[:self._num_decodes, ...],
-                seq_lens=seq_lens[:self._num_decodes],
+                input_positions=input_positions[: self._num_decode_tokens],
+                block_table=block_table[: self._num_decodes, ...],
+                work_indptr=self.runner.work_indptr,
+                work_info_set=self.runner.work_info_set,
+                reduce_indptr=self.runner.reduce_indptr,
+                reduce_final_map=self.runner.reduce_final_map,
+                reduce_partial_map=self.runner.reduce_partial_map,
+                seq_lens=seq_lens[: self._num_decodes],
             )
 
         return self.metadata_cls(
@@ -668,8 +723,11 @@ def __init__(
         # if current_platform.is_cuda():
         #     self.rotary_emb = rotary_emb.forward_cuda
-        self.use_rocm_aiter = current_platform.is_rocm(
-        ) and envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_ROPE
+        self.use_rocm_aiter = (
+            current_platform.is_rocm()
+            and envs.VLLM_ROCM_USE_AITER
+            and envs.VLLM_ROCM_USE_AITER_ROPE
+        )
 
         if self.use_rocm_aiter:
             self.rotary_emb = rotary_emb.forward_hip
@@ -679,7 +737,7 @@ def __init__(
             self.rotary_emb = rotary_emb.forward_native
             if current_platform.is_cuda():
                 self.rotary_emb = rotary_emb.forward_cuda
-        self.cos_cache, self.sin_cache = rotary_emb.cos_sin_cache.chunk(2, dim = -1)
+        self.cos_cache, self.sin_cache = rotary_emb.cos_sin_cache.chunk(2, dim=-1)
         self.rotary_emb_is_neox_style = rotary_emb.is_neox_style
 
         self.q_proj = q_proj
@@ -693,9 +751,9 @@ def __init__(
         self.flash_attn_varlen_func = flash_attn_varlen_func
         self.vllm_flash_attn_version = get_flash_attn_version()
         if self.vllm_flash_attn_version is not None:
-            self.flash_attn_varlen_func = \
-                functools.partial(flash_attn_varlen_func,
-                                  fa_version=self.vllm_flash_attn_version)
+            self.flash_attn_varlen_func = functools.partial(
+                flash_attn_varlen_func, fa_version=self.vllm_flash_attn_version
+            )
 
         # For MLA the v head dim is smaller than qk head dim so we pad out
         # v with 0s to match the qk head dim for attention backends that do
@@ -703,19 +761,17 @@ def __init__(
         # We don't need to pad V if we are on a hopper system with FA3
         self._pad_v = self.vllm_flash_attn_version is None or not (
             self.vllm_flash_attn_version == 3
-            and current_platform.get_device_capability()[0] == 9)
-
-    def _flash_attn_varlen_diff_headdims(self,
-                                         q,
-                                         k,
-                                         v,
-                                         return_softmax_lse=False,
-                                         softmax_scale=None,
-                                         **kwargs):
+            and current_platform.get_device_capability()[0] == 9
+        )
+
+    def _flash_attn_varlen_diff_headdims(
+        self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs
+    ):
         maybe_padded_v = v
         if self._pad_v:
             maybe_padded_v = torch.nn.functional.pad(
-                v, [0, q.shape[-1] - v.shape[-1]], value=0)
+                v, [0, q.shape[-1] - v.shape[-1]], value=0
+            )
 
         attn_out = self.flash_attn_varlen_func(
             q=q,
@@ -733,7 +789,7 @@ def _flash_attn_varlen_diff_headdims(self,
 
         # unpad if necessary
         if self._pad_v:
-            attn_out = attn_out[..., :v.shape[-1]]
+            attn_out = attn_out[..., : v.shape[-1]]
 
         # Remain consistent with old `flash_attn_varlen_func` where there
         # is only one output tensor if `return_softmax_lse` is False.
@@ -747,7 +803,9 @@ def _v_up_proj_and_o_proj(self, x):
         if envs.VLLM_AITER_TRITON_FP8_BMM:
             # Multiply + Transpose (N, B, L) x (N, L, V) -> (N, B, V) -> (B, N, V)
             # print(f"{x.dtype=}")
-            x = aiter_triton_fp8_bmm_wrapper(x, self.W_V, self.W_V_scale, group_size = 128, transpose_bm = True)
+            x = aiter_triton_fp8_bmm_wrapper(
+                x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True
+            )
             # Convert from (B, N, V) to (B, N * V)
             x = x.reshape(-1, self.num_heads * self.v_head_dim)
         else:
@@ -759,15 +817,19 @@ def _v_up_proj_and_o_proj(self, x):
 
     # Return `ql_nope`, `q_pe`
     def _q_proj_and_k_up_proj(self, x):
-        q_nope, q_pe = self.q_proj(x)[0]\
-            .view(-1, self.num_heads, self.qk_head_dim)\
+        q_nope, q_pe = (
+            self.q_proj(x)[0]
+            .view(-1, self.num_heads, self.qk_head_dim)
             .split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+        )
 
         # Convert from (B, N, P) to (N, B, P)
         q_nope = q_nope.transpose(0, 1)
         if envs.VLLM_AITER_TRITON_FP8_BMM:
             # Multiply + Transpose (N, B, P) x (N, P, L) -> (N, B, L) -> (B, N, L)
-            ql_nope = aiter_triton_fp8_bmm_wrapper(q_nope, self.W_K, self.W_K_scale, group_size = 128, transpose_bm = True)
+            ql_nope = aiter_triton_fp8_bmm_wrapper(
+                q_nope, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True
+            )
         else:
             # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
             ql_nope = torch.bmm(q_nope, self.W_UK_T)
@@ -784,17 +846,18 @@ def get_layer_weight(layer):
                 return getattr(layer, attr)
             raise AttributeError(
                 f"Layer '{layer}' has no recognized weight attribute:"
-                f" {WEIGHT_NAMES}.")
+                f" {WEIGHT_NAMES}."
+            )
 
         def get_and_maybe_dequant_weights(layer: LinearBase):
             if not isinstance(layer.quant_method, UnquantizedLinearMethod):
                 # NOTE: This should only be used offline, since it's O(N^3)
-                eye = torch.eye(layer.input_size_per_partition,
-                                dtype=act_dtype,
-                                device=get_layer_weight(layer).device)
-                dequant_weights = layer.quant_method.apply(layer,
-                                                           eye,
-                                                           bias=None)
+                eye = torch.eye(
+                    layer.input_size_per_partition,
+                    dtype=act_dtype,
+                    device=get_layer_weight(layer).device,
+                )
+                dequant_weights = layer.quant_method.apply(layer, eye, bias=None)
                 del eye
                 # standardize to (output, input)
                 return dequant_weights.T
@@ -807,12 +870,14 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
 
         assert kv_b_proj_weight.shape == (
             self.kv_lora_rank,
-            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
-                f"{kv_b_proj_weight.shape=}, "
-                f"{self.kv_lora_rank=}, "
-                f"{self.num_heads=}, "
-                f"{self.qk_nope_head_dim=}, "
-                f"{self.v_head_dim=}")
+            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
+        ), (
+            f"{kv_b_proj_weight.shape=}, "
+            f"{self.kv_lora_rank=}, "
+            f"{self.num_heads=}, "
+            f"{self.qk_nope_head_dim=}, "
+            f"{self.v_head_dim=}"
+        )
         kv_b_proj_weight = kv_b_proj_weight.view(
             self.kv_lora_rank,
             self.num_heads,
@@ -820,8 +885,9 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
         )
 
         W_UK, W_UV = kv_b_proj_weight.split(
-            [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
-
+            [self.qk_nope_head_dim, self.v_head_dim], dim=-1
+        )
+
         if envs.VLLM_AITER_TRITON_FUSED_ROPE_CACHE_CONCAT:
             kv_cache_size = 8192
             max_position_embedding = self.cos_cache.shape[0]
@@ -830,21 +896,56 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
                 if decode_batch_size > prefill_decode_size:
                     continue
 
-                k_scale = torch.ones([1,], dtype=torch.float32, device=W_UK.device)[0]
-
-                q = torch.empty((decode_batch_size, self.num_heads, self.kv_lora_rank + self.qk_rope_head_dim), dtype=torch.bfloat16, device=W_UK.device)
-                decode_ql_nope = q[..., :self.kv_lora_rank]
-                decode_q_pe = q[..., self.kv_lora_rank:]
-
-                k = torch.empty((prefill_decode_size, 1, self.kv_lora_rank + self.qk_rope_head_dim), dtype=torch.bfloat16, device=W_UK.device)
-                k_c_normed = k[..., :self.kv_lora_rank].squeeze(1)
-                k_pe = k[..., self.kv_lora_rank:]
+                k_scale = torch.ones(
+                    [
+                        1,
+                    ],
+                    dtype=torch.float32,
+                    device=W_UK.device,
+                )[0]
+
+                q = torch.empty(
+                    (
+                        decode_batch_size,
+                        self.num_heads,
+                        self.kv_lora_rank + self.qk_rope_head_dim,
+                    ),
+                    dtype=torch.bfloat16,
+                    device=W_UK.device,
+                )
+                decode_ql_nope = q[..., : self.kv_lora_rank]
+                decode_q_pe = q[..., self.kv_lora_rank :]
+
+                k = torch.empty(
+                    (
+                        prefill_decode_size,
+                        1,
+                        self.kv_lora_rank + self.qk_rope_head_dim,
+                    ),
+                    dtype=torch.bfloat16,
+                    device=W_UK.device,
+                )
+                k_c_normed = k[..., : self.kv_lora_rank].squeeze(1)
+                k_pe = k[..., self.kv_lora_rank :]
+
+                input_positions = torch.randint(
+                    0,
+                    max_position_embedding,
+                    (decode_batch_size,),
+                    device=W_UK.device,
+                )
+                slot_mapping = torch.randperm(kv_cache_size, device=W_UK.device)[
+                    :prefill_decode_size
+                ]
+                kv_cache = torch.empty(
+                    (kv_cache_size, 1, self.kv_lora_rank + self.qk_rope_head_dim),
+                    dtype=torch.bfloat16,
+                    device=W_UK.device,
+                )
 
-                input_positions = torch.randint(0, max_position_embedding, (decode_batch_size, ), device=W_UK.device)
-                slot_mapping = torch.randperm(kv_cache_size, device=W_UK.device)[:prefill_decode_size]
-                kv_cache = torch.empty((kv_cache_size, 1, self.kv_lora_rank + self.qk_rope_head_dim), dtype=torch.bfloat16, device=W_UK.device)
-
-                logger.info(f"[Triton] compiling fused_qk_rope_cat_and_cache_mla with (decode tokens, total tokens) = ({decode_batch_size}, {prefill_decode_size})")
+                logger.info(
+                    f"[Triton] compiling fused_qk_rope_cat_and_cache_mla with (decode tokens, total tokens) = ({decode_batch_size}, {prefill_decode_size})"
+                )
                 fused_qk_rope_cat_and_cache_mla(
                     decode_ql_nope,
                     decode_q_pe,
@@ -857,32 +958,62 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
                     self.sin_cache,
                     k_scale,
                     self.rotary_emb_is_neox_style,
-                    output_q_nope_zeros=True
+                    output_q_nope_zeros=True,
                 )
 
         if envs.VLLM_AITER_TRITON_FUSED_CONCAT_ZEROS:
             max_batch_size = 256
-            logger.info(f"[Triton] compiling fused_concat_zeros with shape = [1~{max_batch_size}] {self.num_heads} [{self.kv_lora_rank} : {self.qk_rope_head_dim}]")
-            for m in range(1, max_batch_size+1):
-                x1 = torch.empty((m, self.num_heads, self.kv_lora_rank), dtype=torch.bfloat16, device=W_UK.device)
-                x2 = torch.empty((m, self.num_heads, self.qk_rope_head_dim), dtype=torch.bfloat16, device=W_UK.device)
+            logger.info(
+                f"[Triton] compiling fused_concat_zeros with shape = [1~{max_batch_size}] {self.num_heads} [{self.kv_lora_rank} : {self.qk_rope_head_dim}]"
+            )
+            for m in range(1, max_batch_size + 1):
+                x1 = torch.empty(
+                    (m, self.num_heads, self.kv_lora_rank),
+                    dtype=torch.bfloat16,
+                    device=W_UK.device,
+                )
+                x2 = torch.empty(
+                    (m, self.num_heads, self.qk_rope_head_dim),
+                    dtype=torch.bfloat16,
+                    device=W_UK.device,
+                )
                 fused_concat_zeros(x1, x2)
 
         if envs.VLLM_AITER_TRITON_FP8_BMM:
             max_batch_size = 256
-            W_K = W_UK.transpose(0, 1)  # 16 512 128
-            W_V = W_UV.permute(1, 2, 0)  # 16 128 512
-            self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant(W_K, dtype=torch.float8_e4m3fnuz)
-            self.W_V, self.W_V_scale = dynamic_per_batched_tensor_quant(W_V, dtype=torch.float8_e4m3fnuz)
-            logger.info(f"[Triton] compiling fp8 BMM with shape = {self.W_K.shape[0]} [1~{max_batch_size}] {self.W_K.shape[1]} {self.W_K.shape[2]}")
-            logger.info(f"[Triton] compiling fp8 BMM with shape = {self.W_V.shape[0]} [1~{max_batch_size}] {self.W_V.shape[1]} {self.W_V.shape[2]}")
-            for m in range(1, max_batch_size+1):
-                x = torch.empty((self.W_K.shape[0], m, self.W_K.shape[2]), dtype=torch.bfloat16, device=self.W_K.device)
-                aiter_triton_fp8_bmm_wrapper(x, self.W_K, self.W_K_scale, group_size = 128, transpose_bm = True)
-
-                x = torch.empty((self.W_V.shape[0], m, self.W_V.shape[2]), dtype=torch.bfloat16, device=self.W_V.device)
-                aiter_triton_fp8_bmm_wrapper(x, self.W_V, self.W_V_scale, group_size = 128, transpose_bm = True)
-
+            W_K = W_UK.transpose(0, 1)  # 16 512 128
+            W_V = W_UV.permute(1, 2, 0)  # 16 128 512
+            self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant(
+                W_K, dtype=torch.float8_e4m3fnuz
+            )
+            self.W_V, self.W_V_scale = dynamic_per_batched_tensor_quant(
+                W_V, dtype=torch.float8_e4m3fnuz
+            )
+            logger.info(
+                f"[Triton] compiling fp8 BMM with shape = {self.W_K.shape[0]} [1~{max_batch_size}] {self.W_K.shape[1]} {self.W_K.shape[2]}"
+            )
+            logger.info(
+                f"[Triton] compiling fp8 BMM with shape = {self.W_V.shape[0]} [1~{max_batch_size}] {self.W_V.shape[1]} {self.W_V.shape[2]}"
+            )
+            for m in range(1, max_batch_size + 1):
+                x = torch.empty(
+                    (self.W_K.shape[0], m, self.W_K.shape[2]),
+                    dtype=torch.bfloat16,
+                    device=self.W_K.device,
+                )
+                aiter_triton_fp8_bmm_wrapper(
+                    x, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True
+                )
+
+                x = torch.empty(
+                    (self.W_V.shape[0], m, self.W_V.shape[2]),
+                    dtype=torch.bfloat16,
+                    device=self.W_V.device,
+                )
+                aiter_triton_fp8_bmm_wrapper(
+                    x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True
+                )
+
         else:
             # Convert from (L, N, V) to (N, L, V)
             self.W_UV = W_UV.transpose(0, 1)
@@ -915,21 +1046,17 @@ def _compute_prefill_context(
                 seq_starts=prefill_metadata.chunked_context.starts[i],
             )
 
-            kv_c_normed = workspace[:toks]\
-                [..., :self.kv_lora_rank]
-            k_pe = workspace[:toks]\
-                [..., self.kv_lora_rank:].unsqueeze(1)
+            kv_c_normed = workspace[:toks][..., : self.kv_lora_rank]
+            k_pe = workspace[:toks][..., self.kv_lora_rank :].unsqueeze(1)
 
-            kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \
-                -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
-            k_nope, v = kv_nope\
-                .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
+            kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
+                -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim
+            )
+            k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
 
-            k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
-                          dim=-1)
+            k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
 
-            attn_output, attn_softmax_lse = \
-                self._flash_attn_varlen_diff_headdims(
+            attn_output, attn_softmax_lse = self._flash_attn_varlen_diff_headdims(
                 q=q,
                 k=k,
                 v=v,
@@ -972,10 +1099,10 @@ def _forward_prefill(
         assert attn_metadata.prefill is not None
         has_context = attn_metadata.prefill.chunked_context is not None
 
-        kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\
-            -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
-        k_nope, v = kv_nope\
-            .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
+        kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
+            -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim
+        )
+        k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
 
         k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
 
@@ -994,8 +1121,9 @@ def _forward_prefill(
 
         if has_context:
             suffix_output, suffix_lse = output
-            context_output, context_lse = self._compute_prefill_context( \
-                q, kv_c_and_k_pe_cache, attn_metadata)
+            context_output, context_lse = self._compute_prefill_context(
+                q, kv_c_and_k_pe_cache, attn_metadata
+            )
 
             output = torch.empty_like(suffix_output)
             merge_attn_states(
@@ -1047,9 +1175,11 @@ def forward(
         # Restore head dim (for rotary embedding)
         k_pe = k_pe.unsqueeze(1)
 
-        assert attn_metadata.num_decodes is not None and \
-            attn_metadata.num_prefills is not None and \
-            attn_metadata.num_decode_tokens is not None
+        assert (
+            attn_metadata.num_decodes is not None
+            and attn_metadata.num_prefills is not None
+            and attn_metadata.num_decode_tokens is not None
+        )
 
         has_decode = attn_metadata.num_decodes > 0
         has_prefill = attn_metadata.num_prefills > 0
@@ -1064,37 +1194,47 @@ def forward(
 
         if has_decode:
             assert attn_metadata.decode is not None
-            decode_ql_nope, decode_q_pe = \
-                self._q_proj_and_k_up_proj(decode_hs_or_q_c)
+            decode_ql_nope, decode_q_pe = self._q_proj_and_k_up_proj(decode_hs_or_q_c)
 
             if envs.VLLM_AITER_TRITON_FUSED_ROPE_CACHE_CONCAT:
                 pass  # the rope operator for decode is now fused with the concat_and_cache_mla operator via fused_qk_rope_cat_and_cache_mla
             elif self.use_rocm_aiter:
-                self.rotary_emb(attn_metadata.decode.input_positions,
-                                decode_q_pe, decode_k_pe)
+                self.rotary_emb(
+                    attn_metadata.decode.input_positions, decode_q_pe, decode_k_pe
+                )
             else:
                 decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
                     attn_metadata.decode.input_positions,
-                    decode_q_pe.contiguous(), decode_k_pe)
+                    decode_q_pe.contiguous(),
+                    decode_k_pe,
+                )
 
         if has_prefill:
             assert attn_metadata.prefill is not None
-            prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\
-                .view(-1, self.num_heads, self.qk_head_dim)
-            prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
+            prefill_q = self.q_proj(prefill_hs_or_q_c)[0].view(
+                -1, self.num_heads, self.qk_head_dim
+            )
+            prefill_q_pe = prefill_q[..., self.qk_nope_head_dim :]
 
             if self.use_rocm_aiter:
-                self.rotary_emb(attn_metadata.prefill.input_positions,
-                                prefill_q_pe, prefill_k_pe)
+                self.rotary_emb(
+                    attn_metadata.prefill.input_positions, prefill_q_pe, prefill_k_pe
+                )
             else:
                 prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
                     attn_metadata.prefill.input_positions,
-                    prefill_q_pe.contiguous(), prefill_k_pe)
+                    prefill_q_pe.contiguous(),
+                    prefill_k_pe,
+                )
 
         # write the latent and rope to kv cache
         q_nope_pe, q_nope_zeros = None, None
-        if envs.VLLM_AITER_TRITON_FUSED_ROPE_CACHE_CONCAT and has_decode and kv_cache.numel() > 0:
+        if (
+            envs.VLLM_AITER_TRITON_FUSED_ROPE_CACHE_CONCAT
+            and has_decode
+            and kv_cache.numel() > 0
+        ):
             q_nope_pe, q_nope_zeros = fused_qk_rope_cat_and_cache_mla(
                 decode_ql_nope,
                 decode_q_pe,
@@ -1107,7 +1247,7 @@ def forward(
                 self.sin_cache,
                 layer._k_scale,
                 self.rotary_emb_is_neox_style,
-                output_q_nope_zeros=True
+                output_q_nope_zeros=True,
             )
         elif kv_cache.numel() > 0:
             ops.concat_and_cache_mla(
@@ -1121,11 +1261,17 @@ def forward(
 
         if has_prefill:
             output[num_decode_tokens:] = self._forward_prefill(
-                prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
-                attn_metadata)
+                prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache, attn_metadata
+            )
 
         if has_decode:
             output[:num_decode_tokens] = self._forward_decode(
-                decode_ql_nope, decode_q_pe, kv_cache, attn_metadata, q_nope_pe=q_nope_pe, q_nope_zeros=q_nope_zeros)
+                decode_ql_nope,
+                decode_q_pe,
+                kv_cache,
+                attn_metadata,
+                q_nope_pe=q_nope_pe,
+                q_nope_zeros=q_nope_zeros,
+            )
 
         return output_padded
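
Note on `dynamic_per_batched_tensor_quant`: the hunks above only show the tail of its body. A minimal standalone sketch of the same per-batched-tensor scheme, assuming the scale is derived from the tensor's absolute maximum (that line is not visible in the diff):

    import torch

    def per_batched_tensor_fp8_quant(x, dtype=torch.float8_e4m3fn):
        # Assumed: one dynamic scale for the whole tensor, from its absolute
        # maximum; the diff only shows the clamp-and-convert lines below.
        DTYPE_MAX = torch.finfo(dtype).max
        scale = DTYPE_MAX / x.abs().amax().clamp(min=1e-12)
        x_scl_sat = (x * scale).clamp(min=-DTYPE_MAX, max=DTYPE_MAX)
        return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()

    # Round trip: dequantizing with the returned reciprocal scale
    # approximately recovers the bf16 weight.
    w = torch.randn(16, 512, 128, dtype=torch.bfloat16)
    w_q, w_s = per_batched_tensor_fp8_quant(w)
    w_dq = w_q.to(torch.bfloat16) * w_s
    print((w - w_dq).abs().max())  # small fp8 quantization error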
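Note on the chunked-context arithmetic in `build()`: the `chunk_starts` / `chunk_ends` / `cu_seq_lens_cpu` bookkeeping is easier to follow with concrete numbers. A CPU-only sketch with hypothetical sizes:

    import torch

    num_prefills = 4
    max_context_chunk = 256  # "MCC" in the file-level comment
    context_lens_cpu = torch.tensor([100, 300, 500, 700], dtype=torch.int32)

    # cdiv(max_context_len, max_context_chunk)
    num_chunks = (int(context_lens_cpu.max()) + max_context_chunk - 1) // max_context_chunk

    # (num_chunks, num_prefills) grid of start offsets:
    # [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]]
    chunk_starts = (
        torch.arange(num_chunks, dtype=torch.int32)
        .unsqueeze(1)
        .expand(-1, num_prefills)
        * max_context_chunk
    )
    chunk_ends = torch.min(context_lens_cpu.unsqueeze(0), chunk_starts + max_context_chunk)
    chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
    # -> [[100, 256, 256, 256], [0, 44, 244, 256], [0, 0, 0, 188]]

    # Per-chunk cumulative lengths, i.e. the cu_seqlens the attention kernels consume.
    cu_seq_lens = torch.zeros(num_chunks, num_prefills + 1, dtype=torch.int32)
    torch.cumsum(chunk_seq_lens, dim=1, out=cu_seq_lens[:, 1:], dtype=torch.int32)

Each row of `chunk_seq_lens` is bounded by `max_context_chunk`, which is what lets the up-projected context fit in the fixed `chunked_prefill_workspace`.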
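Note on `_pad_v` in `_flash_attn_varlen_diff_headdims`: padding V with zeros is lossless because attention outputs are convex combinations of the V rows, so the padded columns stay zero and can simply be sliced off. A small sketch using `scaled_dot_product_attention` in place of flash attention (SDPA natively accepts a mismatched V head dim, which makes the check possible):

    import torch
    import torch.nn.functional as F

    # Toy shapes (batch, heads, seq, head_dim); 192 vs 128 as in the
    # file-level comment (192 QK head dim, 128 V head dim).
    B, H, S, D_qk, D_v = 1, 2, 4, 192, 128
    q, k = torch.randn(B, H, S, D_qk), torch.randn(B, H, S, D_qk)
    v = torch.randn(B, H, S, D_v)

    # Pad V's head dim with zeros so q/k/v agree, then slice the output back.
    v_pad = F.pad(v, [0, D_qk - D_v], value=0)
    out = F.scaled_dot_product_attention(q, k, v_pad)[..., :D_v]

    ref = F.scaled_dot_product_attention(q, k, v)
    assert torch.allclose(out, ref, atol=1e-5)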
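Note on the absorbed decode projections: `_q_proj_and_k_up_proj` reduces the K up-projection to a per-head batched matmul, as the shape comments in the hunk describe. A shape-only sketch with hypothetical sizes (N heads, B tokens, P = qk_nope_head_dim, L = kv_lora_rank):

    import torch

    N, B, P, L = 16, 3, 128, 512
    q_nope = torch.randn(B, N, P)   # per-token, per-head "nope" query
    W_UK_T = torch.randn(N, P, L)   # per-head up-projection, pre-transposed

    # Convert (B, N, P) -> (N, B, P), then multiply per head:
    # (N, B, P) x (N, P, L) -> (N, B, L)
    ql_nope = torch.bmm(q_nope.transpose(0, 1), W_UK_T)
    assert ql_nope.shape == (N, B, L)

Per the comments in the diff, the fp8 path (`aiter_triton_fp8_bmm_wrapper`) fuses this same contraction with per-token-group activation quantization and the final (N, B, L) -> (B, N, L) transpose (`transpose_bm=True`).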