@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import torch
import torch.nn as nn
import types
@@ -1296,6 +1297,10 @@ def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
self.register_scale("descale_amax", mod_extra_config.scale.inputs[3].type(torch.float32), self.scale_format)
self.register_scale("scale_output", 1 / mod_extra_config.scale.outputs[0].type(torch.float32), self.scale_format)
self.register_scale("scale_amax", 1 / self.descale_amax, self.scale_format)
self.qkv_slice_thld = int(os.getenv("VLLM_FUSEDSDPA_QKV_SLICE_SEQ_LEN_THLD", 8192))
if self.qkv_slice_thld > 0:
self.q_chunk_size = int(os.getenv("VLLM_FUSEDSDPA_Q_SLICE_CHUNK_SIZE", self.qkv_slice_thld))
self.kv_chunk_size = int(os.getenv("VLLM_FUSEDSDPA_KV_SLICE_CHUNK_SIZE", self.qkv_slice_thld))

def forward_qdq(
self,
@@ -1330,6 +1335,41 @@ def forward_qdq(
seq_padding_type,
)
return results

def fp8_fsdpa_fwd(self,
q,
k,
v,
attn_mask,
dropout_p,
scale,
is_causal,
softmax_mode,
):
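        # Wrapper around the raw HPU recompute-forward kernel. Besides the FP8 attention output,
        # the returned tuple carries the per-row softmax statistics (max and inverse sum) that the
        # chunked path below consumes for online merging.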
results = torch.ops.hpu.fp8_sdpa_recomp_fwd(
q,
k,
v,
attn_mask,
dropout_p,
scale,
is_causal,
True, # requires_backward
softmax_mode, # softmax_mode
self.scale_q, # d_scale_q
self.scale_k, # d_scale_k
self.scale_v, # d_scale_v
self.scale_amax, # q_scale_s
self.scale_output, # q_scale_o
self.descale_amax, # d_scale_s
False, # is_amax_s
False, # is_amax_o
None, # valid_seq_len
"right", # seq_padding_type
(-1, -1), # window_size
None, # sink
)
return results

def forward_quant(
self,
@@ -1345,32 +1385,142 @@ def forward_quant(
valid_seq_len=None,
seq_padding_type="None",
):
sm_mode = softmax_mode if softmax_mode == "fp32" else "None"
sm_mode = softmax_mode if softmax_mode == "fp32" else "none"
qinput = self.quant_q(q).detach()
kinput = self.quant_k(k).detach()
vinput = self.quant_v(v).detach()
results = self.fp8_fused_sdpa(
qinput,
kinput,
vinput,
attn_mask=attn_mask,
dropout_p=dropout_p,
is_causal=is_causal,
scale=scale,
softmax_mode=sm_mode,
d_scale_q=self.scale_q,
d_scale_k=self.scale_k,
d_scale_v=self.scale_v,
q_scale_s=self.scale_amax,
q_scale_o=self.scale_output,
d_scale_s=self.descale_amax,
is_amax_s=False,
valid_seq_len=valid_seq_len,
seq_padding_type=seq_padding_type,
)
output = results[0]
d_out = self.dequant_output(output)
return d_out
q_len = q.shape[-2]
kv_len = kinput.size(-2)

# for prefill with prefix caching
if self.qkv_slice_thld > 0 and q_len != 1 and q_len != kv_len and kv_len > self.qkv_slice_thld:
assert attn_mask is not None, "Attention mask is required for FSDPA with prefix caching."
ctx_len = kv_len - q_len
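            # Keys/values split into the cached prefix ([0, ctx_len)), which every query token may
            # attend to without masking, and the newly computed suffix, which needs causal masking.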
from habana_frameworks.torch.hpex.kernels.Fp8FusedSDPA import is_gqa, gqa_input_reshape_fwd, gqa_output_reshape
gqa = is_gqa(qinput, kinput)
if gqa:
qinput, kinput, vinput, attn_mask = gqa_input_reshape_fwd(qinput, kinput, vinput, attn_mask)

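            # Tile the query and the prefix context into fixed-size chunks; partial results are
            # combined with a flash-attention style online softmax merge.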
num_q_chunks = (q_len + self.q_chunk_size - 1) // self.q_chunk_size
num_context_kv_chunks = (ctx_len + self.kv_chunk_size - 1) // self.kv_chunk_size
num_causal_kv_chunks = num_q_chunks
chunk_outputs = []
for q_chunk_idx in range(num_q_chunks):
q_start = q_chunk_idx * self.q_chunk_size
q_end = min((q_chunk_idx + 1) * self.q_chunk_size, q_len)
q_chunk = qinput[..., q_start:q_end, :]

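                # Running output, row max and inverse row sum for this query chunk.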
last_out = None
last_m = None
last_linv = None
for kv_chunk_idx in range(num_context_kv_chunks):
kv_start = kv_chunk_idx * self.kv_chunk_size
kv_end = min((kv_chunk_idx + 1) * self.kv_chunk_size, ctx_len)
k_chunk = kinput[..., kv_start:kv_end, :]
v_chunk = vinput[..., kv_start:kv_end, :]

chunk_res = self.fp8_fsdpa_fwd(q_chunk, k_chunk, v_chunk, None, dropout_p, scale, False, sm_mode)
chunk_out, chunk_m, chunk_linv = (gqa_output_reshape(x) for x in (chunk_res[:3])) if gqa else chunk_res[:3]

chunk_m = chunk_m.to(torch.float32)
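                    # The 128.0 factor below mirrors how the kernel scales its softmax statistics when
                    # the softmax is not computed in fp32 (assumption based on this code path, not on kernel docs).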
chunk_linv = chunk_linv.to(torch.float32) * 128.0 if softmax_mode != "fp32" else chunk_linv.to(torch.float32)
chunk_out = self.dequant_output(chunk_out).to(torch.float32)

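                    # The first context chunk initializes the running stats; later chunks are merged
                    # online by rescaling both partial results to a common row max.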
if kv_chunk_idx == 0:
last_out = chunk_out
last_m = chunk_m
last_linv = chunk_linv
else:
new_m = torch.maximum(last_m, chunk_m)
last_linv_rescaled = (1.0 / last_linv) * torch.exp(last_m - new_m)
chunk_linv_rescaled = (1.0 / chunk_linv) * torch.exp(chunk_m - new_m)
last_linv = 1.0 / (last_linv_rescaled + chunk_linv_rescaled)
last_out = (last_linv_rescaled * last_linv) * last_out + (
chunk_linv_rescaled * last_linv) * chunk_out
last_m = new_m

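                # Diagonal block of the causal region: this query chunk attending to its own new tokens.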
kv_causal_start = ctx_len + q_start
kv_causal_end = ctx_len + q_end
k_causal_chunk = kinput[..., kv_causal_start:kv_causal_end, :]
v_causal_chunk = vinput[..., kv_causal_start:kv_causal_end, :]

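                # The final query chunk may be shorter than q_chunk_size, so build an explicit
                # lower-triangular mask for it instead of relying on the kernel's is_causal path.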
bs = q_chunk.size(0)
q_chunk_len = q_chunk.size(-2)
if q_chunk.size(-2) < self.q_chunk_size:
mask = (1 - torch.tril(
torch.ones(bs,
1,
1,
q_chunk_len,
q_chunk_len,
dtype=q.dtype,
device=q.device))) * torch.finfo(
q.dtype).min
causal_chunk_res = self.fp8_fsdpa_fwd(q_chunk, k_causal_chunk, v_causal_chunk, mask, dropout_p, scale, False, sm_mode)
else:
causal_chunk_res = self.fp8_fsdpa_fwd(q_chunk, k_causal_chunk, v_causal_chunk, None, dropout_p, scale, True, sm_mode)

causal_chunk_out, causal_chunk_m, causal_chunk_linv = (gqa_output_reshape(x) for x in (causal_chunk_res[:3])) if gqa else causal_chunk_res[:3]
causal_chunk_m = causal_chunk_m.to(torch.float32)
causal_chunk_linv = causal_chunk_linv.to(torch.float32) * 128.0 if softmax_mode != "fp32" else causal_chunk_linv.to(torch.float32)
causal_chunk_out = self.dequant_output(causal_chunk_out).to(torch.float32)

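                # Fold the causal region into the running result: any full blocks preceding this
                # query chunk, plus the diagonal block computed above.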
if num_causal_kv_chunks == 1:
new_m = torch.maximum(last_m, causal_chunk_m)
last_linv_rescaled = (1.0 / last_linv) * torch.exp(last_m - new_m)
chunk_linv_rescaled = (1.0 / causal_chunk_linv) * torch.exp(causal_chunk_m - new_m)
last_linv = 1.0 / (last_linv_rescaled + chunk_linv_rescaled)
last_out = (last_linv_rescaled * last_linv) * last_out + (
chunk_linv_rescaled * last_linv) * causal_chunk_out
last_m = new_m
                else:
                    for kv_chunk_idx in range(0, q_chunk_idx):
                        kv_causal_start = ctx_len + kv_chunk_idx * self.q_chunk_size
                        kv_causal_end = ctx_len + (kv_chunk_idx + 1) * self.q_chunk_size
                        k_causal_chunk = kinput[..., kv_causal_start:kv_causal_end, :]
                        v_causal_chunk = vinput[..., kv_causal_start:kv_causal_end, :]

                        chunk_res = self.fp8_fsdpa_fwd(q_chunk, k_causal_chunk, v_causal_chunk, None, dropout_p, scale, False, sm_mode)

                        chunk_out, chunk_m, chunk_linv = (gqa_output_reshape(x) for x in (chunk_res[:3])) if gqa else chunk_res[:3]
                        chunk_m = chunk_m.to(torch.float32)
                        chunk_linv = chunk_linv.to(torch.float32) * 128.0 if softmax_mode != "fp32" else chunk_linv.to(torch.float32)
                        chunk_out = self.dequant_output(chunk_out).to(torch.float32)

                        new_m = torch.maximum(last_m, chunk_m)
                        last_linv_rescaled = (1.0 / last_linv) * torch.exp(last_m - new_m)
                        chunk_linv_rescaled = (1.0 / chunk_linv) * torch.exp(chunk_m - new_m)
                        last_linv = 1.0 / (last_linv_rescaled + chunk_linv_rescaled)
                        last_out = (last_linv_rescaled * last_linv) * last_out + (
                            chunk_linv_rescaled * last_linv) * chunk_out
                        last_m = new_m

                    # Merge the diagonal (causal) block computed above so this query chunk
                    # also attends to its own new tokens.
                    new_m = torch.maximum(last_m, causal_chunk_m)
                    last_linv_rescaled = (1.0 / last_linv) * torch.exp(last_m - new_m)
                    chunk_linv_rescaled = (1.0 / causal_chunk_linv) * torch.exp(causal_chunk_m - new_m)
                    last_linv = 1.0 / (last_linv_rescaled + chunk_linv_rescaled)
                    last_out = (last_linv_rescaled * last_linv) * last_out + (
                        chunk_linv_rescaled * last_linv) * causal_chunk_out
                    last_m = new_m

chunk_outputs.append(last_out)
output = torch.cat(chunk_outputs, dim=-2)
return output.to(q.dtype)
else:
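            # Unsliced path: a single fused FP8 SDPA call over the full sequence.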
results = self.fp8_fused_sdpa(
qinput,
kinput,
vinput,
attn_mask=attn_mask,
dropout_p=dropout_p,
is_causal=is_causal,
scale=scale,
softmax_mode=sm_mode,
d_scale_q=self.scale_q,
d_scale_k=self.scale_k,
d_scale_v=self.scale_v,
q_scale_s=self.scale_amax,
q_scale_o=self.scale_output,
d_scale_s=self.descale_amax,
is_amax_s=False,
valid_seq_len=valid_seq_len,
seq_padding_type=seq_padding_type,
)
output = results[0]
d_out = self.dequant_output(output)
return d_out

def forward_measure(
self,