217 changes: 185 additions & 32 deletions flashinfer/cudnn/prefill.py
@@ -15,10 +15,19 @@

# Global cudnn handle. Needs to be made per-device in the future.
_cudnn_handle = None
_dummy_scale_tensor = None


def _get_dummy_scale_tensor(device: torch.device):
global _dummy_scale_tensor

# Cache the dummy scale so repeated calls do not allocate; rebuild it only
# when the requested device changes.
if _dummy_scale_tensor is None or _dummy_scale_tensor.device != device:
_dummy_scale_tensor = torch.tensor([1.0], device=device, dtype=torch.float32)
return _dummy_scale_tensor


def _create_cudnn_handle(stream: torch.cuda.Stream):
global _cudnn_handle

if _cudnn_handle is None:
_cudnn_handle = cudnn.create_handle()
cudnn.set_stream(_cudnn_handle, stream.cuda_stream)
@@ -49,6 +58,16 @@ class UIDs(Enum):
O_UID = 1000 # Output tensor
STATS_UID = 1001 # Stats tensor

Q_SCALE_UID = 150 # Query scale tensor
K_SCALE_UID = 151 # Key scale tensor
V_SCALE_UID = 152 # Value scale tensor
S_SCALE_UID = 153 # S (softmax score) scale tensor
S_DESCALE_UID = 154 # S (softmax score) descale tensor
O_SCALE_UID = 155 # Output scale tensor

S_AMAX_UID = 160 # S (softmax score) amax tensor
O_AMAX_UID = 161 # Output amax tensor


def _sdpa_prefill_key_fn(
q: torch.Tensor,
@@ -136,6 +155,13 @@ def _build_prefill_graph(
graph_s_qo = max_token_seq_q
graph_s_kv = max_sequence_kv

if not cudnn.datatypes.is_torch_available():
raise RuntimeError("torch is not available")

cudnn_q_data_type = cudnn.datatypes._torch_to_cudnn_data_type(q.dtype)
cudnn_k_data_type = cudnn.datatypes._torch_to_cudnn_data_type(k_cache.dtype)
cudnn_v_data_type = cudnn.datatypes._torch_to_cudnn_data_type(v_cache.dtype)

with cudnn.graph(handle) as (g, _):
# Create tensors from the input tensors
if q.dim() == 3:
@@ -149,9 +175,62 @@
name="q",
dim=(graph_b, h_qo, graph_s_qo, d_qk),
stride=(h_qo * d_qk, d_qk, d_qk * h_qo, 1),
data_type=cudnn.data_type.BFLOAT16,
data_type=cudnn_q_data_type,
)

if (
cudnn_q_data_type == cudnn.data_type.FP8_E4M3
or cudnn_q_data_type == cudnn.data_type.FP8_E5M2
):
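# fp8 attention needs explicit fp32 scale/descale factors for Q, K, V, the
# intermediate softmax score S, and the output O; declare one (1, 1, 1, 1)
# graph tensor for each.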
cudnn_q_scale = g.tensor(
name="q_scale",
dim=(1, 1, 1, 1),
stride=(1, 1, 1, 1),
data_type=cudnn.data_type.FLOAT,
)

cudnn_k_scale = g.tensor(
name="k_scale",
dim=(1, 1, 1, 1),
stride=(1, 1, 1, 1),
data_type=cudnn.data_type.FLOAT,
)

cudnn_v_scale = g.tensor(
name="v_scale",
dim=(1, 1, 1, 1),
stride=(1, 1, 1, 1),
data_type=cudnn.data_type.FLOAT,
)

cudnn_s_scale = g.tensor(
name="s_scale",
dim=(1, 1, 1, 1),
stride=(1, 1, 1, 1),
data_type=cudnn.data_type.FLOAT,
)

cudnn_s_descale = g.tensor(
name="s_descale",
dim=(1, 1, 1, 1),
stride=(1, 1, 1, 1),
data_type=cudnn.data_type.FLOAT,
)

cudnn_o_scale = g.tensor(
name="o_scale",
dim=(1, 1, 1, 1),
stride=(1, 1, 1, 1),
data_type=cudnn.data_type.FLOAT,
)

cudnn_q_scale.set_uid(UIDs.Q_SCALE_UID.value)
cudnn_k_scale.set_uid(UIDs.K_SCALE_UID.value)
cudnn_v_scale.set_uid(UIDs.V_SCALE_UID.value)
cudnn_s_scale.set_uid(UIDs.S_SCALE_UID.value)
cudnn_s_descale.set_uid(UIDs.S_DESCALE_UID.value)
cudnn_o_scale.set_uid(UIDs.O_SCALE_UID.value)

if batch_offsets_q is not None:
ragged_q = g.tensor_like(batch_offsets_q)
ragged_q.set_uid(UIDs.RAGGED_Q_UID.value)
@@ -175,7 +254,7 @@
name="k_cache",
dim=(graph_b, h_kv, graph_s_kv, d_qk),
stride=(h_kv * d_qk * graph_s_kv, d_qk, d_qk * h_kv, 1),
data_type=cudnn.data_type.BFLOAT16,
data_type=cudnn_k_data_type,
)

if batch_offsets_k is not None:
@@ -187,7 +266,7 @@
name="v_cache",
dim=(graph_b, h_kv, graph_s_kv, d_vo),
stride=(h_kv * d_vo * graph_s_kv, d_vo, d_vo * h_kv, 1),
data_type=cudnn.data_type.BFLOAT16,
data_type=cudnn_v_data_type,
)

if batch_offsets_v is not None:
@@ -200,14 +279,14 @@
name="k_cache",
dim=k_cache.shape,
stride=k_cache.stride(),
data_type=cudnn.data_type.BFLOAT16,
data_type=cudnn_k_data_type,
)

cudnn_v_cache = g.tensor(
name="v_cache",
dim=v_cache.shape,
stride=v_cache.stride(),
data_type=cudnn.data_type.BFLOAT16,
data_type=cudnn_v_data_type,
)

cudnn_q.set_uid(UIDs.Q_UID.value)
@@ -238,32 +317,83 @@
actual_seq_lens_q is not None and actual_seq_lens_kv is not None
)

O, Stats = g.sdpa(
name="sdpa",
q=cudnn_q,
k=cudnn_k_cache,
v=cudnn_v_cache,
seq_len_q=(
cudnn_actual_seq_lens_q if actual_seq_lens_q is not None else None
),
seq_len_kv=(
cudnn_actual_seq_lens_kv if actual_seq_lens_kv is not None else None
),
use_padding_mask=padding_mask,
attn_scale=scale,
generate_stats=return_lse,
use_causal_mask_bottom_right=bottom_right_causal_mask,
paged_attention_k_table=(
cudnn_k_block_tables if block_tables is not None else None
),
paged_attention_v_table=(
cudnn_v_block_tables if block_tables is not None else None
),
paged_attention_max_seq_len_kv=(
graph_s_kv if block_tables is not None else None
),
compute_data_type=cudnn.data_type.FLOAT,
)
if cudnn_q_data_type == cudnn.data_type.BFLOAT16:
O, Stats = g.sdpa(
name="sdpa",
q=cudnn_q,
k=cudnn_k_cache,
v=cudnn_v_cache,
seq_len_q=(
cudnn_actual_seq_lens_q
if actual_seq_lens_q is not None
else None
),
seq_len_kv=(
cudnn_actual_seq_lens_kv
if actual_seq_lens_kv is not None
else None
),
use_padding_mask=padding_mask,
attn_scale=scale,
generate_stats=return_lse,
use_causal_mask_bottom_right=bottom_right_causal_mask,
paged_attention_k_table=(
cudnn_k_block_tables if block_tables is not None else None
),
paged_attention_v_table=(
cudnn_v_block_tables if block_tables is not None else None
),
paged_attention_max_seq_len_kv=(
graph_s_kv if block_tables is not None else None
),
compute_data_type=cudnn.data_type.FLOAT,
)

elif (
cudnn_q_data_type == cudnn.data_type.FP8_E4M3
or cudnn_q_data_type == cudnn.data_type.FP8_E5M2
):
O, Stats, amax_s, amax_o = g.sdpa_fp8(
q=cudnn_q,
k=cudnn_k_cache,
v=cudnn_v_cache,
descale_q=cudnn_q_scale,
descale_k=cudnn_k_scale,
descale_v=cudnn_v_scale,
scale_s=cudnn_s_scale,
descale_s=cudnn_s_descale,
scale_o=cudnn_o_scale,
generate_stats=True,
attn_scale=scale,
use_causal_mask_bottom_right=bottom_right_causal_mask,
use_padding_mask=padding_mask,
seq_len_q=(
cudnn_actual_seq_lens_q
if actual_seq_lens_q is not None
else None
),
seq_len_kv=(
cudnn_actual_seq_lens_kv
if actual_seq_lens_kv is not None
else None
),
paged_attention_k_table=(
cudnn_k_block_tables if block_tables is not None else None
),
paged_attention_v_table=(
cudnn_v_block_tables if block_tables is not None else None
),
paged_attention_max_seq_len_kv=(
graph_s_kv if block_tables is not None else None
),
)

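# Register UIDs for the amax tensors but mark them as non-outputs; this
# API does not surface them to the caller.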
amax_s.set_uid(UIDs.S_AMAX_UID.value).set_output(False).set_dim(
(1, 1, 1, 1)
).set_stride((1, 1, 1, 1)).set_data_type(cudnn.data_type.FLOAT)
amax_o.set_uid(UIDs.O_AMAX_UID.value).set_output(False).set_dim(
(1, 1, 1, 1)
).set_stride((1, 1, 1, 1)).set_data_type(cudnn.data_type.FLOAT)

if batch_offsets_o is not None:
ragged_o = g.tensor_like(batch_offsets_o)
Expand All @@ -279,7 +409,7 @@ def _build_prefill_graph(
[graph_b, h_qo, graph_s_qo, d_vo]
).set_stride(
[graph_s_qo * d_vo * h_qo, d_vo, d_vo * h_qo, 1]
).set_data_type(cudnn.data_type.BFLOAT16)
).set_data_type(cudnn_q_data_type)

if return_lse:
Stats.set_uid(UIDs.STATS_UID.value).set_output(
@@ -314,6 +444,9 @@ def _batch_prefill_with_kv_cache(
block_tables: Optional[torch.Tensor] = None,
causal: bool,
return_lse: bool,
q_scale: Optional[torch.Tensor] = None,
k_scale: Optional[torch.Tensor] = None,
v_scale: Optional[torch.Tensor] = None,
batch_offsets_q: Optional[torch.Tensor] = None,
batch_offsets_o: Optional[torch.Tensor] = None,
batch_offsets_k: Optional[torch.Tensor] = None,
@@ -374,6 +507,17 @@ def _batch_prefill_with_kv_cache(
if batch_offsets_stats is not None:
var_map[UIDs.RAGGED_STATS_UID.value] = batch_offsets_stats

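# The fp8 graph binds six scale UIDs; the S and O scale slots are not
# exposed by this API, so they are filled with a constant 1.0 dummy tensor.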
if q_scale is not None:
dummy_scale_tensor = _get_dummy_scale_tensor(q.device)
var_map[UIDs.Q_SCALE_UID.value] = q_scale
var_map[UIDs.S_SCALE_UID.value] = dummy_scale_tensor
var_map[UIDs.S_DESCALE_UID.value] = dummy_scale_tensor
var_map[UIDs.O_SCALE_UID.value] = dummy_scale_tensor
if k_scale is not None:
var_map[UIDs.K_SCALE_UID.value] = k_scale
if v_scale is not None:
var_map[UIDs.V_SCALE_UID.value] = v_scale

handle = _create_cudnn_handle(torch.cuda.current_stream(q.device))
graph.execute(var_map, workspace=workspace_buffer, handle=handle)

@@ -397,6 +541,9 @@ def cudnn_batch_prefill_with_kv_cache(
block_tables: Optional[torch.Tensor] = None,
causal: bool,
return_lse: bool,
q_scale: Optional[torch.Tensor] = None,
k_scale: Optional[torch.Tensor] = None,
v_scale: Optional[torch.Tensor] = None,
batch_offsets_q: Optional[torch.Tensor] = None,
batch_offsets_o: Optional[torch.Tensor] = None,
batch_offsets_k: Optional[torch.Tensor] = None,
@@ -425,6 +572,9 @@ def cudnn_batch_prefill_with_kv_cache(
out: Optional pre-allocated output tensor
lse: Optional pre-allocated tensor for the log-sum-exp values; used only when return_lse is True, otherwise None is returned
is_cuda_graph_compatible: Whether the prefill operation is compatible with CUDA graph
q_scale: Optional scale tensor for the query, of shape (1, 1, 1, 1), on GPU
k_scale: Optional scale tensor for the key, of shape (1, 1, 1, 1), on GPU
v_scale: Optional scale tensor for the value, of shape (1, 1, 1, 1), on GPU
batch_offsets_q: Optional batch offsets for query tensor of shape (batch_size,) on GPU
batch_offsets_o: Optional batch offsets for output tensor of shape (batch_size,) on GPU
batch_offsets_k: Optional batch offsets for key tensor of shape (batch_size,) on GPU
@@ -488,6 +638,9 @@ def cudnn_batch_prefill_with_kv_cache(
block_tables=block_tables,
causal=causal,
return_lse=return_lse,
q_scale=q_scale,
k_scale=k_scale,
v_scale=v_scale,
batch_offsets_q=batch_offsets_q,
batch_offsets_o=batch_offsets_o,
batch_offsets_k=batch_offsets_k,
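A quick usage sketch for reviewers of the fp8 path added above. Only the `q_scale`/`k_scale`/`v_scale` keyword arguments and their (1, 1, 1, 1) fp32-on-GPU layout are taken from this diff; the remaining argument names, shapes, and values are assumptions about the existing `cudnn_batch_prefill_with_kv_cache` signature, so treat this as a sketch rather than a test.

    import math
    import torch
    from flashinfer.cudnn.prefill import cudnn_batch_prefill_with_kv_cache

    b, s_qo, s_kv, h, d = 2, 128, 256, 8, 128
    device = torch.device("cuda")

    # fp8 Q/K/V inputs; torch.randn cannot emit fp8 directly, so cast after.
    q = torch.randn(b * s_qo, h, d, device=device).to(torch.float8_e4m3fn)
    k_cache = torch.randn(b, h, s_kv, d, device=device).to(torch.float8_e4m3fn)
    v_cache = torch.randn(b, h, s_kv, d, device=device).to(torch.float8_e4m3fn)

    # Per the new docstring: fp32 scale tensors of shape (1, 1, 1, 1) on GPU.
    q_scale = torch.full((1, 1, 1, 1), 0.5, dtype=torch.float32, device=device)
    k_scale = torch.full((1, 1, 1, 1), 0.5, dtype=torch.float32, device=device)
    v_scale = torch.full((1, 1, 1, 1), 0.5, dtype=torch.float32, device=device)

    out, lse = cudnn_batch_prefill_with_kv_cache(
        q, k_cache, v_cache,
        scale=1.0 / math.sqrt(d),
        workspace_buffer=torch.empty(
            128 * 1024 * 1024, dtype=torch.uint8, device=device
        ),
        max_token_per_sequence=s_qo,
        max_sequence_kv=s_kv,
        actual_seq_lens_q=torch.full((b, 1, 1, 1), s_qo, dtype=torch.int32, device=device),
        actual_seq_lens_kv=torch.full((b, 1, 1, 1), s_kv, dtype=torch.int32, device=device),
        causal=True,
        return_lse=True,
        q_scale=q_scale,
        k_scale=k_scale,
        v_scale=v_scale,
    )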
24 changes: 15 additions & 9 deletions flashinfer/prefill.py
@@ -1980,9 +1980,9 @@ def run(
q: torch.Tensor,
paged_kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
*args,
q_scale: Optional[float] = None,
k_scale: Optional[float] = None,
v_scale: Optional[float] = None,
q_scale: Optional[Union[float, torch.Tensor]] = None,
k_scale: Optional[Union[float, torch.Tensor]] = None,
v_scale: Optional[Union[float, torch.Tensor]] = None,
out: Optional[torch.Tensor] = None,
lse: Optional[torch.Tensor] = None,
return_lse: bool = False,
@@ -2012,9 +2012,11 @@ def run(

*args
Additional arguments for custom kernels.
k_scale : Optional[float]
q_scale : Optional[Union[float, torch.Tensor]]
The calibration scale of the query for fp8 input; if not provided, it is set to ``1.0``.
k_scale : Optional[Union[float, torch.Tensor]]
The calibration scale of the key for fp8 input; if not provided, it is set to ``1.0``.
v_scale : Optional[float]
v_scale : Optional[Union[float, torch.Tensor]]
The calibration scale of the value for fp8 input; if not provided, it is set to ``1.0``.
out : Optional[torch.Tensor]
The output tensor, if not provided, will be allocated internally.
@@ -2061,10 +2063,11 @@ def run(
logits_soft_cap = 0.0
if sm_scale is None:
sm_scale = 1.0 / math.sqrt(q.size(-1))
if q_scale is not None:
sm_scale *= q_scale
if k_scale is not None:
sm_scale *= k_scale
if self._backend != "cudnn":
if q_scale is not None:
sm_scale *= q_scale
if k_scale is not None:
sm_scale *= k_scale
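# With the cudnn backend, q_scale/k_scale are forwarded to the cudnn
# graph below instead of being folded into sm_scale.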
if rope_scale is None:
rope_scale = 1.0
if rope_theta is None:
@@ -2143,6 +2146,9 @@
block_tables=self._block_tables,
causal=self._causal,
return_lse=return_lse,
q_scale=q_scale,
k_scale=k_scale,
v_scale=v_scale,
batch_offsets_q=self._qo_indptr_buf,
batch_offsets_o=self._qo_indptr_buf,
out=out,
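On the flashinfer/prefill.py side, the behavior now forks on the backend: non-cudnn backends keep folding float scales into the softmax scale, while the cudnn backend forwards the scales to the graph. A minimal sketch of the folding arithmetic with made-up numbers (only the folding itself comes from this diff):

    import math

    d = 128                       # head dimension
    q_scale, k_scale = 0.5, 0.25  # fp8 calibration scales as plain floats

    # What run() does when self._backend != "cudnn": fold both scales into
    # the softmax scale, so dequantization costs no extra kernel.
    sm_scale = 1.0 / math.sqrt(d)
    sm_scale *= q_scale
    sm_scale *= k_scale

    # softmax(((q * q_scale) @ (k * k_scale).T) / sqrt(d)) equals
    # softmax((q @ k.T) * sm_scale), which is why the folding is sound.

    # With backend="cudnn", run() skips the folding and instead passes
    # q_scale/k_scale/v_scale (floats or (1, 1, 1, 1) fp32 GPU tensors)
    # straight through to cudnn_batch_prefill_with_kv_cache.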