diff --git a/flashinfer/cudnn/prefill.py b/flashinfer/cudnn/prefill.py
index fc573cf7cb..6bb1a9a129 100644
--- a/flashinfer/cudnn/prefill.py
+++ b/flashinfer/cudnn/prefill.py
@@ -15,10 +15,19 @@
 # Global cudnn handle. need to make it per device in future
 _cudnn_handle = None
+_dummy_scale_tensor = None
+
+
+def _get_dummy_scale_tensor(device: torch.device):
+    global _dummy_scale_tensor
+
+    # Lazily create and cache a float32 tensor holding 1.0, used wherever the
+    # graph expects a scale factor but no calibrated value is supplied. The
+    # cached tensor keeps a stable address across calls, which matters for
+    # CUDA graph replay.
+    if _dummy_scale_tensor is None or _dummy_scale_tensor.device != device:
+        _dummy_scale_tensor = torch.tensor([1.0], device=device, dtype=torch.float32)
+    return _dummy_scale_tensor
 
 
 def _create_cudnn_handle(stream: torch.cuda.Stream):
     global _cudnn_handle
+
     if _cudnn_handle is None:
         _cudnn_handle = cudnn.create_handle()
     cudnn.set_stream(_cudnn_handle, stream.cuda_stream)
@@ -49,6 +58,16 @@ class UIDs(Enum):
     O_UID = 1000  # Output tensor
     STATS_UID = 1001  # Stats tensor
 
+    Q_SCALE_UID = 150  # Query scale tensor
+    K_SCALE_UID = 151  # Key scale tensor
+    V_SCALE_UID = 152  # Value scale tensor
+    S_SCALE_UID = 153  # Softmax scale tensor
+    S_DESCALE_UID = 154  # Softmax descale tensor
+    O_SCALE_UID = 155  # Output scale tensor
+
+    S_AMAX_UID = 160  # Softmax amax tensor
+    O_AMAX_UID = 161  # Output amax tensor
+
 
 def _sdpa_prefill_key_fn(
     q: torch.Tensor,
@@ -136,6 +155,13 @@ def _build_prefill_graph(
     graph_s_qo = max_token_seq_q
     graph_s_kv = max_sequence_kv
 
+    if not cudnn.datatypes.is_torch_available():
+        raise RuntimeError("torch is not available")
+
+    # Derive the cuDNN data types from the torch dtypes so that bf16 and fp8
+    # inputs share a single graph-building path.
+    cudnn_q_data_type = cudnn.datatypes._torch_to_cudnn_data_type(q.dtype)
+    cudnn_k_data_type = cudnn.datatypes._torch_to_cudnn_data_type(k_cache.dtype)
+    cudnn_v_data_type = cudnn.datatypes._torch_to_cudnn_data_type(v_cache.dtype)
+
     with cudnn.graph(handle) as (g, _):
         # Create tensors from the input tensors
         if q.dim() == 3:
@@ -149,9 +175,62 @@ def _build_prefill_graph(
                 name="q",
                 dim=(graph_b, h_qo, graph_s_qo, d_qk),
                 stride=(h_qo * d_qk, d_qk, d_qk * h_qo, 1),
-                data_type=cudnn.data_type.BFLOAT16,
+                data_type=cudnn_q_data_type,
             )
 
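+            # For fp8 inputs the graph also needs per-tensor scale/descale
+            # factors. Declare them as 1x1x1x1 float tensors bound to fixed
+            # UIDs so values can be supplied through the variant map at
+            # execution time.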
+            if (
+                cudnn_q_data_type == cudnn.data_type.FP8_E4M3
+                or cudnn_q_data_type == cudnn.data_type.FP8_E5M2
+            ):
+                cudnn_q_scale = g.tensor(
+                    name="q_scale",
+                    dim=(1, 1, 1, 1),
+                    stride=(1, 1, 1, 1),
+                    data_type=cudnn.data_type.FLOAT,
+                )
+
+                cudnn_k_scale = g.tensor(
+                    name="k_scale",
+                    dim=(1, 1, 1, 1),
+                    stride=(1, 1, 1, 1),
+                    data_type=cudnn.data_type.FLOAT,
+                )
+
+                cudnn_v_scale = g.tensor(
+                    name="v_scale",
+                    dim=(1, 1, 1, 1),
+                    stride=(1, 1, 1, 1),
+                    data_type=cudnn.data_type.FLOAT,
+                )
+
+                cudnn_s_scale = g.tensor(
+                    name="s_scale",
+                    dim=(1, 1, 1, 1),
+                    stride=(1, 1, 1, 1),
+                    data_type=cudnn.data_type.FLOAT,
+                )
+
+                cudnn_s_descale = g.tensor(
+                    name="s_descale",
+                    dim=(1, 1, 1, 1),
+                    stride=(1, 1, 1, 1),
+                    data_type=cudnn.data_type.FLOAT,
+                )
+
+                cudnn_o_scale = g.tensor(
+                    name="o_scale",
+                    dim=(1, 1, 1, 1),
+                    stride=(1, 1, 1, 1),
+                    data_type=cudnn.data_type.FLOAT,
+                )
+
+                cudnn_q_scale.set_uid(UIDs.Q_SCALE_UID.value)
+                cudnn_k_scale.set_uid(UIDs.K_SCALE_UID.value)
+                cudnn_v_scale.set_uid(UIDs.V_SCALE_UID.value)
+                cudnn_s_scale.set_uid(UIDs.S_SCALE_UID.value)
+                cudnn_s_descale.set_uid(UIDs.S_DESCALE_UID.value)
+                cudnn_o_scale.set_uid(UIDs.O_SCALE_UID.value)
+
             if batch_offsets_q is not None:
                 ragged_q = g.tensor_like(batch_offsets_q)
                 ragged_q.set_uid(UIDs.RAGGED_Q_UID.value)
@@ -175,7 +254,7 @@ def _build_prefill_graph(
                 name="k_cache",
                 dim=(graph_b, h_kv, graph_s_kv, d_qk),
                 stride=(h_kv * d_qk * graph_s_kv, d_qk, d_qk * h_kv, 1),
-                data_type=cudnn.data_type.BFLOAT16,
+                data_type=cudnn_k_data_type,
             )
 
             if batch_offsets_k is not None:
@@ -187,7 +266,7 @@ def _build_prefill_graph(
             cudnn_v_cache = g.tensor(
                 name="v_cache",
                 dim=(graph_b, h_kv, graph_s_kv, d_vo),
                 stride=(h_kv * d_vo * graph_s_kv, d_vo, d_vo * h_kv, 1),
-                data_type=cudnn.data_type.BFLOAT16,
+                data_type=cudnn_v_data_type,
             )
 
             if batch_offsets_v is not None:
@@ -200,14 +279,14 @@ def _build_prefill_graph(
                 name="k_cache",
                 dim=k_cache.shape,
                 stride=k_cache.stride(),
-                data_type=cudnn.data_type.BFLOAT16,
+                data_type=cudnn_k_data_type,
             )
 
             cudnn_v_cache = g.tensor(
                 name="v_cache",
                 dim=v_cache.shape,
                 stride=v_cache.stride(),
-                data_type=cudnn.data_type.BFLOAT16,
+                data_type=cudnn_v_data_type,
             )
 
         cudnn_q.set_uid(UIDs.Q_UID.value)
@@ -238,32 +317,83 @@ def _build_prefill_graph(
             actual_seq_lens_q is not None and actual_seq_lens_kv is not None
         )
 
-        O, Stats = g.sdpa(
-            name="sdpa",
-            q=cudnn_q,
-            k=cudnn_k_cache,
-            v=cudnn_v_cache,
-            seq_len_q=(
-                cudnn_actual_seq_lens_q if actual_seq_lens_q is not None else None
-            ),
-            seq_len_kv=(
-                cudnn_actual_seq_lens_kv if actual_seq_lens_kv is not None else None
-            ),
-            use_padding_mask=padding_mask,
-            attn_scale=scale,
-            generate_stats=return_lse,
-            use_causal_mask_bottom_right=bottom_right_causal_mask,
-            paged_attention_k_table=(
-                cudnn_k_block_tables if block_tables is not None else None
-            ),
-            paged_attention_v_table=(
-                cudnn_v_block_tables if block_tables is not None else None
-            ),
-            paged_attention_max_seq_len_kv=(
-                graph_s_kv if block_tables is not None else None
-            ),
-            compute_data_type=cudnn.data_type.FLOAT,
-        )
+        if cudnn_q_data_type == cudnn.data_type.BFLOAT16:
+            O, Stats = g.sdpa(
+                name="sdpa",
+                q=cudnn_q,
+                k=cudnn_k_cache,
+                v=cudnn_v_cache,
+                seq_len_q=(
+                    cudnn_actual_seq_lens_q
+                    if actual_seq_lens_q is not None
+                    else None
+                ),
+                seq_len_kv=(
+                    cudnn_actual_seq_lens_kv
+                    if actual_seq_lens_kv is not None
+                    else None
+                ),
+                use_padding_mask=padding_mask,
+                attn_scale=scale,
+                generate_stats=return_lse,
+                use_causal_mask_bottom_right=bottom_right_causal_mask,
+                paged_attention_k_table=(
+                    cudnn_k_block_tables if block_tables is not None else None
+                ),
+                paged_attention_v_table=(
+                    cudnn_v_block_tables if block_tables is not None else None
+                ),
+                paged_attention_max_seq_len_kv=(
+                    graph_s_kv if block_tables is not None else None
+                ),
+                compute_data_type=cudnn.data_type.FLOAT,
+            )
+
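+        # fp8 path: g.sdpa_fp8 consumes the descale/scale tensors declared
+        # above and additionally produces amax tensors for the softmax and
+        # output; these are bound to UIDs but marked as non-outputs since
+        # nothing downstream reads them.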
+        elif (
+            cudnn_q_data_type == cudnn.data_type.FP8_E4M3
+            or cudnn_q_data_type == cudnn.data_type.FP8_E5M2
+        ):
+            O, Stats, amax_s, amax_o = g.sdpa_fp8(
+                q=cudnn_q,
+                k=cudnn_k_cache,
+                v=cudnn_v_cache,
+                descale_q=cudnn_q_scale,
+                descale_k=cudnn_k_scale,
+                descale_v=cudnn_v_scale,
+                scale_s=cudnn_s_scale,
+                descale_s=cudnn_s_descale,
+                scale_o=cudnn_o_scale,
+                generate_stats=True,
+                attn_scale=scale,
+                use_causal_mask_bottom_right=bottom_right_causal_mask,
+                use_padding_mask=padding_mask,
+                seq_len_q=(
+                    cudnn_actual_seq_lens_q
+                    if actual_seq_lens_q is not None
+                    else None
+                ),
+                seq_len_kv=(
+                    cudnn_actual_seq_lens_kv
+                    if actual_seq_lens_kv is not None
+                    else None
+                ),
+                paged_attention_k_table=(
+                    cudnn_k_block_tables if block_tables is not None else None
+                ),
+                paged_attention_v_table=(
+                    cudnn_v_block_tables if block_tables is not None else None
+                ),
+                paged_attention_max_seq_len_kv=(
+                    graph_s_kv if block_tables is not None else None
+                ),
+            )
+
+            amax_s.set_uid(UIDs.S_AMAX_UID.value).set_output(False).set_dim(
+                (1, 1, 1, 1)
+            ).set_stride((1, 1, 1, 1)).set_data_type(cudnn.data_type.FLOAT)
+            amax_o.set_uid(UIDs.O_AMAX_UID.value).set_output(False).set_dim(
+                (1, 1, 1, 1)
+            ).set_stride((1, 1, 1, 1)).set_data_type(cudnn.data_type.FLOAT)
 
         if batch_offsets_o is not None:
             ragged_o = g.tensor_like(batch_offsets_o)
@@ -279,7 +409,7 @@ def _build_prefill_graph(
             [graph_b, h_qo, graph_s_qo, d_vo]
         ).set_stride(
             [graph_s_qo * d_vo * h_qo, d_vo, d_vo * h_qo, 1]
-        ).set_data_type(cudnn.data_type.BFLOAT16)
+        ).set_data_type(cudnn_q_data_type)
 
         if return_lse:
             Stats.set_uid(UIDs.STATS_UID.value).set_output(
@@ -314,6 +444,9 @@ def _batch_prefill_with_kv_cache(
     block_tables: Optional[torch.Tensor] = None,
     causal: bool,
     return_lse: bool,
+    q_scale: Optional[torch.Tensor] = None,
+    k_scale: Optional[torch.Tensor] = None,
+    v_scale: Optional[torch.Tensor] = None,
     batch_offsets_q: Optional[torch.Tensor] = None,
     batch_offsets_o: Optional[torch.Tensor] = None,
     batch_offsets_k: Optional[torch.Tensor] = None,
@@ -374,6 +507,17 @@ def _batch_prefill_with_kv_cache(
     if batch_offsets_stats is not None:
         var_map[UIDs.RAGGED_STATS_UID.value] = batch_offsets_stats
 
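+    # Bind the fp8 scale factors. s_scale, s_descale, and o_scale are fixed at
+    # 1.0 via a shared dummy tensor, since the softmax and output are not
+    # re-quantized with calibrated factors here.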
+    if q_scale is not None:
+        dummy_scale_tensor = _get_dummy_scale_tensor(q.device)
+        var_map[UIDs.Q_SCALE_UID.value] = q_scale
+        var_map[UIDs.S_SCALE_UID.value] = dummy_scale_tensor
+        var_map[UIDs.S_DESCALE_UID.value] = dummy_scale_tensor
+        var_map[UIDs.O_SCALE_UID.value] = dummy_scale_tensor
+    if k_scale is not None:
+        var_map[UIDs.K_SCALE_UID.value] = k_scale
+    if v_scale is not None:
+        var_map[UIDs.V_SCALE_UID.value] = v_scale
+
     handle = _create_cudnn_handle(torch.cuda.current_stream(q.device))
 
     graph.execute(var_map, workspace=workspace_buffer, handle=handle)
@@ -397,6 +541,9 @@ def cudnn_batch_prefill_with_kv_cache(
     block_tables: Optional[torch.Tensor] = None,
     causal: bool,
     return_lse: bool,
+    q_scale: Optional[torch.Tensor] = None,
+    k_scale: Optional[torch.Tensor] = None,
+    v_scale: Optional[torch.Tensor] = None,
     batch_offsets_q: Optional[torch.Tensor] = None,
     batch_offsets_o: Optional[torch.Tensor] = None,
     batch_offsets_k: Optional[torch.Tensor] = None,
@@ -425,6 +572,9 @@ def cudnn_batch_prefill_with_kv_cache(
         out: Optional pre-allocated output tensor
         lse: Optional pre-allocated tensor for log-sum-exp values if return_lse is True else returns None
        is_cuda_graph_compatible: Whether the prefill operation is compatible with CUDA graph
+        q_scale: Optional scale tensor for the query tensor; shape (1, 1, 1, 1), on GPU
+        k_scale: Optional scale tensor for the key tensor; shape (1, 1, 1, 1), on GPU
+        v_scale: Optional scale tensor for the value tensor; shape (1, 1, 1, 1), on GPU
         batch_offsets_q: Optional batch offsets for query tensor of shape (batch_size,) on GPU
         batch_offsets_o: Optional batch offsets for output tensor of shape (batch_size,) on GPU
         batch_offsets_k: Optional batch offsets for key tensor of shape (batch_size,) on GPU
@@ -488,6 +638,9 @@ def cudnn_batch_prefill_with_kv_cache(
         block_tables=block_tables,
         causal=causal,
         return_lse=return_lse,
+        q_scale=q_scale,
+        k_scale=k_scale,
+        v_scale=v_scale,
         batch_offsets_q=batch_offsets_q,
         batch_offsets_o=batch_offsets_o,
         batch_offsets_k=batch_offsets_k,
diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py
index 49abe60897..3a3e962fad 100755
--- a/flashinfer/prefill.py
+++ b/flashinfer/prefill.py
@@ -1980,9 +1980,9 @@ def run(
         q: torch.Tensor,
         paged_kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
         *args,
-        q_scale: Optional[float] = None,
-        k_scale: Optional[float] = None,
-        v_scale: Optional[float] = None,
+        q_scale: Optional[Union[float, torch.Tensor]] = None,
+        k_scale: Optional[Union[float, torch.Tensor]] = None,
+        v_scale: Optional[Union[float, torch.Tensor]] = None,
         out: Optional[torch.Tensor] = None,
         lse: Optional[torch.Tensor] = None,
         return_lse: bool = False,
@@ -2012,9 +2012,11 @@ def run(
         *args
             Additional arguments for custom kernels.
-        k_scale : Optional[float]
+        q_scale : Optional[Union[float, torch.Tensor]]
+            The calibration scale of query for fp8 input, if not provided, will be set to ``1.0``.
+        k_scale : Optional[Union[float, torch.Tensor]]
             The calibration scale of key for fp8 input, if not provided, will be set to ``1.0``.
-        v_scale : Optional[float]
+        v_scale : Optional[Union[float, torch.Tensor]]
             The calibration scale of value for fp8 input, if not provided, will be set to ``1.0``.
         out : Optional[torch.Tensor]
             The output tensor, if not provided, will be allocated internally.
@@ -2061,10 +2063,11 @@ def run(
             logits_soft_cap = 0.0
         if sm_scale is None:
             sm_scale = 1.0 / math.sqrt(q.size(-1))
-        if q_scale is not None:
-            sm_scale *= q_scale
-        if k_scale is not None:
-            sm_scale *= k_scale
+        # The cudnn backend consumes q_scale/k_scale directly as graph inputs,
+        # so they must not also be folded into sm_scale.
+        if self._backend != "cudnn":
+            if q_scale is not None:
+                sm_scale *= q_scale
+            if k_scale is not None:
+                sm_scale *= k_scale
         if rope_scale is None:
             rope_scale = 1.0
         if rope_theta is None:
@@ -2143,6 +2146,9 @@ def run(
                 block_tables=self._block_tables,
                 causal=self._causal,
                 return_lse=return_lse,
+                q_scale=q_scale,
+                k_scale=k_scale,
+                v_scale=v_scale,
                 batch_offsets_q=self._qo_indptr_buf,
                 batch_offsets_o=self._qo_indptr_buf,
                 out=out,
diff --git a/tests/attention/test_cudnn_prefill.py b/tests/attention/test_cudnn_prefill.py
index d264db8ae4..68b0ff4b01 100644
--- a/tests/attention/test_cudnn_prefill.py
+++ b/tests/attention/test_cudnn_prefill.py
@@ -180,3 +180,202 @@ def test_cudnn_prefill(
     output_ref = wrapper.run(q, kv_cache)
 
     torch.testing.assert_close(output, output_ref, atol=3e-3, rtol=1e-2)
+
+
+@pytest.mark.parametrize("batch_size", [1, 4])
+@pytest.mark.parametrize("s_qo", [8, 17, 700])
+@pytest.mark.parametrize("s_kv", [8, 32, 1066])
+@pytest.mark.parametrize("page_size", [8, 16, 64])
+@pytest.mark.parametrize("num_kv_heads", [1, 4])
+@pytest.mark.parametrize("num_qo_heads", [4])
+@pytest.mark.parametrize("causal", [True, False])
+@pytest.mark.parametrize("return_lse", [True, False])
+@pytest.mark.parametrize("is_cuda_graph_compatible", [True])
+def test_cudnn_prefill_fp8(
+    batch_size,
+    s_qo,
+    s_kv,
+    page_size,
+    num_kv_heads,
+    num_qo_heads,
+    causal,
+    return_lse,
+    is_cuda_graph_compatible,
+):
+    head_dim = 128
+    if s_qo > s_kv:
+        pytest.skip("s_qo > s_kv, skipping test")
+
+    # Basic test setup
+    seed = 1
+    torch.manual_seed(seed)
+    device = "cuda:0"
+
+    actual_seq_lens_q = torch.randint(
+        1, s_qo + 1, (batch_size, 1, 1, 1), dtype=torch.int32, device=device
+    )
+    actual_seq_lens_kv = torch.randint(
+        s_qo, s_kv + 1, (batch_size, 1, 1, 1), dtype=torch.int32, device=device
+    )
+
+    cumsum_s_qo = torch.sum(actual_seq_lens_q)
+    q = torch.randn(
+        cumsum_s_qo, num_qo_heads, head_dim, device=device, dtype=torch.bfloat16
+    )
+
+    # Per-tensor quantization: derive the scale from the absolute maximum,
+    # then cast the rescaled values to fp8 (e4m3).
+    q_scale = q.amax().item() / 256
+
+    q_scale = torch.tensor(q_scale, device=device, dtype=torch.float32)
+    q_fp8 = (q / q_scale).to(torch.float8_e4m3fn)
+
+    q_indptr = torch.cat(
+        [
+            torch.tensor([0], device=device),
+            torch.cumsum(actual_seq_lens_q.view(-1), dim=0) * head_dim * num_qo_heads,
+        ]
+    ).int()
+
+    # Initialize KV Cache
+    num_pages_per_seq = (s_kv + page_size - 1) // page_size
+    total_num_pages = num_pages_per_seq * batch_size
+
+    kv_cache_shape = (total_num_pages, 2, num_kv_heads, page_size, head_dim)
+    kv_cache = torch.randn(size=kv_cache_shape, dtype=torch.bfloat16).to(device)
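+    # Reinterpret the cache with NHD-style strides (position-major, then head,
+    # then dim within each page) to match the "NHD" layout the cudnn wrapper
+    # below is constructed with.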
+    kv_cache = kv_cache.as_strided(
+        kv_cache.shape,
+        (
+            2 * page_size * num_kv_heads * head_dim,
+            page_size * num_kv_heads * head_dim,
+            head_dim,
+            num_kv_heads * head_dim,
+            1,
+        ),
+    )
+    k_cache_view = kv_cache[:, 0, :, :, :]
+    v_cache_view = kv_cache[:, 1, :, :, :]
+
+    v_cache = v_cache_view.as_strided(
+        v_cache_view.shape,
+        (2 * page_size * num_kv_heads * head_dim, head_dim, num_kv_heads * head_dim, 1),
+    )
+    k_cache = k_cache_view.as_strided(
+        k_cache_view.shape,
+        (2 * page_size * num_kv_heads * head_dim, head_dim, num_kv_heads * head_dim, 1),
+    )
+
+    k_scale = k_cache.amax().item() / 256
+    v_scale = v_cache.amax().item() / 256
+    k_cache_fp8 = (k_cache / k_scale).to(torch.float8_e4m3fn)
+    v_cache_fp8 = (v_cache / v_scale).to(torch.float8_e4m3fn)
+
+    k_scale_tensor = torch.tensor(k_scale, device=device, dtype=torch.float32)
+    v_scale_tensor = torch.tensor(v_scale, device=device, dtype=torch.float32)
+
+    kv_indptr = torch.cat(
+        [
+            torch.tensor([0], device=device),
+            torch.cumsum(
+                (actual_seq_lens_kv.flatten() + page_size - 1) // page_size,
+                dim=0,
+            ),
+        ]
+    ).int()
+
+    # kv_indices
+    kv_indices = torch.zeros(kv_indptr[-1], device=device, dtype=torch.int32)
+    for i in range(len(kv_indptr) - 1):
+        start_idx = kv_indptr[i]
+        end_idx = kv_indptr[i + 1]
+        kv_indices[start_idx:end_idx] = torch.arange(
+            i * num_pages_per_seq,
+            i * num_pages_per_seq + (end_idx - start_idx),
+            device=device,
+        )
+
+    # kv_last_page_len
+    kv_last_page_len = torch.where(
+        actual_seq_lens_kv.flatten() % page_size == 0,
+        torch.full((batch_size,), page_size, device=device),
+        actual_seq_lens_kv.flatten() % page_size,
+    ).int()
+
+    # Now initialize the page tables
+    block_tables = torch.tensor(
+        [
+            [k + i * num_pages_per_seq for k in range(num_pages_per_seq)]
+            for i in range(batch_size)
+        ],
+        dtype=torch.int,
+        device=device,
+    )
+
+    # Initialize scale
+    scale = float(1.0 / (head_dim**0.5))
+
+    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device)
+
+    wrapper_cudnn = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
+        workspace_buffer, "NHD", backend="cudnn"
+    )
+    wrapper_cudnn.plan(
+        q_indptr,
+        kv_indptr,
+        kv_indices,
+        kv_last_page_len,
+        num_qo_heads,
+        num_kv_heads,
+        head_dim,
+        page_size,
+        pos_encoding_mode="NONE",
+        causal=causal,
+        q_data_type=torch.float8_e4m3fn,
+        seq_lens=actual_seq_lens_kv,
+        seq_lens_q=actual_seq_lens_q,
+        sm_scale=scale,
+        max_token_per_sequence=s_qo,
+        max_sequence_kv=s_kv,
+        block_tables=block_tables,
+    )
+
+    output = wrapper_cudnn.run(
+        q_fp8,
+        (k_cache_fp8, v_cache_fp8),
+        q_scale=q_scale,
+        k_scale=k_scale_tensor,
+        v_scale=v_scale_tensor,
+    )
+
+    qo_indptr = torch.cat(
+        [
+            torch.tensor([0], device=device),
+            torch.cumsum(actual_seq_lens_q.view(-1), dim=0),
+        ]
+    ).int()
+
+    # Workspace buffer
+    workspace_buffer_ref = torch.empty(
+        128 * 1024 * 1024, dtype=torch.int8, device=device
+    )
+
+    wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
+        workspace_buffer_ref, "HND", backend="fa2"
+    )
+    wrapper.plan(
+        qo_indptr,
+        kv_indptr,
+        kv_indices,
+        kv_last_page_len,
+        num_qo_heads,
+        num_kv_heads,
+        head_dim,
+        page_size,
+        pos_encoding_mode="NONE",
+        causal=causal,
+        q_data_type=torch.bfloat16,
+    )
+
+    output_ref = wrapper.run(q, kv_cache)
+
+    # Compare the fp8 result against the bf16 reference with loose tolerances,
+    # since per-tensor fp8 quantization introduces noticeable error.
+    output_bf16 = output.to(torch.bfloat16)
+
+    torch.testing.assert_close(output_bf16, output_ref, atol=1e-1, rtol=1e-1)
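
Usage sketch (not part of the patch; condensed from test_cudnn_prefill_fp8 above
to a single sequence on a single page — all variable names are illustrative):

    import torch
    import flashinfer

    device = "cuda:0"
    num_qo_heads = num_kv_heads = 4
    head_dim, page_size, seq_len = 128, 8, 8

    q = torch.randn(seq_len, num_qo_heads, head_dim, device=device, dtype=torch.bfloat16)
    # NHD paged layout: (pages, 2, page_size, num_kv_heads, head_dim)
    kv = torch.randn(1, 2, page_size, num_kv_heads, head_dim, device=device, dtype=torch.bfloat16)

    # Per-tensor fp8 (e4m3) quantization, mirroring the test.
    q_scale = torch.tensor(q.amax().item() / 256, device=device)
    kv_scale = torch.tensor(kv.amax().item() / 256, device=device)
    q_fp8 = (q / q_scale).to(torch.float8_e4m3fn)
    k_fp8 = (kv[:, 0] / kv_scale).to(torch.float8_e4m3fn)
    v_fp8 = (kv[:, 1] / kv_scale).to(torch.float8_e4m3fn)

    workspace = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device)
    wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
        workspace, "NHD", backend="cudnn"
    )
    seq_lens = torch.tensor([seq_len], device=device, dtype=torch.int32).view(1, 1, 1, 1)
    wrapper.plan(
        # q_indptr holds ragged element offsets, as in the test
        torch.tensor([0, seq_len * num_qo_heads * head_dim], device=device, dtype=torch.int32),
        torch.tensor([0, 1], device=device, dtype=torch.int32),  # kv_indptr
        torch.tensor([0], device=device, dtype=torch.int32),  # kv_indices
        torch.tensor([seq_len], device=device, dtype=torch.int32),  # kv_last_page_len
        num_qo_heads,
        num_kv_heads,
        head_dim,
        page_size,
        pos_encoding_mode="NONE",
        causal=True,
        q_data_type=torch.float8_e4m3fn,
        seq_lens=seq_lens,
        seq_lens_q=seq_lens,
        sm_scale=head_dim**-0.5,
        max_token_per_sequence=seq_len,
        max_sequence_kv=seq_len,
        block_tables=torch.tensor([[0]], device=device, dtype=torch.int32),
    )
    out = wrapper.run(
        q_fp8, (k_fp8, v_fp8), q_scale=q_scale, k_scale=kv_scale, v_scale=kv_scale
    )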