Skip to content

[CPU][FP8] Support FP8 SDPA for CPU backend #2689

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions test/prototype/inductor/test_int8_sdpa_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
from torch.testing._internal.inductor_utils import HAS_CPU

import torchao
from torchao.prototype.inductor.fx_passes.int8_sdpa_fusion import (
_int8_sdpa_init,
from torchao.prototype.inductor.fx_passes.qsdpa_fusion import (
_qsdpa_init,
custom_pass,
)
from torchao.utils import TORCH_VERSION_AT_LEAST_2_7
Expand Down Expand Up @@ -120,7 +120,7 @@ def _check_common(
)
source_code = "\n".join(source_code)
if has_fuse_pattern:
self.assertGreaterEqual(counters["inductor"]["int8_fuse_attention"], 1)
self.assertGreaterEqual(counters["inductor"]["qsdpa_fuse_attention"], 1)
if contains:
self.assertTrue(
any(
Expand Down Expand Up @@ -157,7 +157,7 @@ def _check_common(
)
@config.patch({"freezing": True})
def _test_sdpa_int8_rewriter(self):
from torch.export import export_for_training
from torch.export import export

import torchao.quantization.pt2e.quantizer.x86_inductor_quantizer as xiq
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
Expand Down Expand Up @@ -193,13 +193,13 @@ def _test_sdpa_int8_rewriter(self):
),
config.patch(post_grad_custom_pre_pass=custom_pass),
):
_int8_sdpa_init()
_qsdpa_init()
quantizer = X86InductorQuantizer()
quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())
quantizer.set_function_type_qconfig(
torch.matmul, quantizer.get_global_quantization_config()
)
export_model = export_for_training(
export_model = export(
mod,
inputs,
strict=True,
Expand Down
204 changes: 141 additions & 63 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,50 +155,101 @@ def _scaled_dot_product_int8_op_ref(
out = torch.clamp(torch.round(out / o_scale) + o_zp, min=0, max=255)
return out.to(torch.uint8)

def _scaled_dot_product_fp8_op_ref(
self,
q,
k,
v,
attn_mask=None,
dropout_p=0,
is_causal=False,
q_scale=1.0,
k_scale=1.0,
v_scale=1.0,
a_scale=1.0,
o_scale=1.0,
):
q = q.to(torch.float) * q_scale
k = k.to(torch.float) * k_scale
v = v.to(torch.float) * v_scale
scale_factor = 1 / math.sqrt(q.size(-1))
attn = q @ k.transpose(-2, -1)

attn = attn * scale_factor
if attn_mask is not None:
attn = attn + attn_mask.to(torch.float)
attn_max = attn.max(dim=-1, keepdim=True).values
attn = attn - attn_max
attn = torch.exp(attn)
attn_sum = torch.sum(attn, dim=-1, keepdim=True)
attn = attn / attn_sum
attn = torch.clamp(attn / a_scale, min=-448, max=448)
attn = attn.to(torch.float8_e4m3fn).to(torch.float)
attn = attn * a_scale
out = attn @ v
out = torch.clamp(out / o_scale, min=-448, max=448)
return out.to(torch.float8_e4m3fn)

@pytest.mark.skipif(
not TORCH_VERSION_AT_LEAST_2_7, reason="int8 sdpa requires torch 2.7 or later"
not TORCH_VERSION_AT_LEAST_2_7,
reason="quantized sdpa requires torch 2.7 or later",
)
@pytest.mark.skipif(not IS_LINUX, reason="only support on linux")
@pytest.mark.skipif(
"CPU" not in torch._C._dispatch_dump("torchao::qscaled_dot_product"),
reason="cpp kernels not built",
)
@parametrize("input_dtype", [torch.uint8, torch.float8_e4m3fn])
@parametrize("batch_size", [56, 120])
@parametrize("n_head", [2, 16])
@parametrize("q_seq_len", [18, 89])
@parametrize("kv_seq_len", [100, 253])
@parametrize("head_dim", [32, 64])
@parametrize("mask_dtype", [None, torch.float32, torch.bfloat16])
def test_scaled_dot_product_int8_op(
self, batch_size, n_head, q_seq_len, kv_seq_len, head_dim, mask_dtype
def test_quantized_scaled_dot_product_op(
self,
input_dtype,
batch_size,
n_head,
q_seq_len,
kv_seq_len,
head_dim,
mask_dtype,
):
torch.manual_seed(1234)
device = "cpu"
q_scale = float(1.7907238006591797)
q_zp = int(127)
k_scale = float(1.8039721250534058)
k_zp = int(125)
v_scale = float(1.839004635810852)
v_zp = int(127)
a_scale = float(0.003919653594493866)
a_zp = int(120)
o_scale = float(1.8191684484481812)
o_zp = int(128)
if input_dtype == torch.uint8:
q_scale = float(1.7907238006591797)
k_scale = float(1.8039721250534058)
v_scale = float(1.839004635810852)
a_scale = float(0.003919653594493866)
o_scale = float(1.8191684484481812)
q_zp = int(127)
k_zp = int(125)
v_zp = int(127)
a_zp = int(120)
o_zp = int(128)
atol, rtol = 1.0, 5e-6
else:
q_scale = float(5.96875)
k_scale = float(5.78125)
v_scale = float(0.98046875)
a_scale = float(4.84375)
o_scale = float(3.171875)
atol, rtol = 0.125, 5e-6
q_shape = [batch_size, q_seq_len, n_head, head_dim]
kv_shape = [batch_size, kv_seq_len, n_head, head_dim]
mask_shape = [batch_size, 1, 1, kv_seq_len]
q = torch.randn(q_shape, dtype=torch.float, device=device).transpose(1, 2) * 100
k = (
torch.randn(kv_shape, dtype=torch.float, device=device).transpose(1, 2)
* 100
)
v = (
torch.randn(kv_shape, dtype=torch.float, device=device).transpose(1, 2)
* 100
)
q = q.to(torch.uint8)
k = k.to(torch.uint8)
v = v.to(torch.uint8)
q = torch.randn(q_shape, dtype=torch.float, device=device).transpose(1, 2)
k = torch.randn(kv_shape, dtype=torch.float, device=device).transpose(1, 2)
v = torch.randn(kv_shape, dtype=torch.float, device=device).transpose(1, 2)
if input_dtype == torch.uint8:
q *= 100
k *= 100
v *= 100
q = q.to(input_dtype)
k = k.to(input_dtype)
v = v.to(input_dtype)
attn_mask = (
torch.randn(mask_shape, dtype=mask_dtype, device=device)
if mask_dtype is not None
Expand All @@ -211,44 +262,71 @@ def test_scaled_dot_product_int8_op(
attn_mask.clone() if mask_dtype is not None else None,
)

math_ref = self._scaled_dot_product_int8_op_ref(
q2,
k2,
v2,
attn_mask=attn_mask,
dropout_p=0.0,
is_causal=False,
q_scale=q_scale,
q_zp=q_zp,
k_scale=k_scale,
k_zp=k_zp,
v_scale=v_scale,
v_zp=v_zp,
a_scale=a_scale,
a_zp=a_zp,
o_scale=o_scale,
o_zp=o_zp,
)
actual = torch.ops.torchao.qscaled_dot_product(
q,
k,
v,
attn_mask=attn_mask_2,
dropout_p=0.0,
is_causal=False,
q_scale=q_scale,
q_zp=q_zp,
k_scale=k_scale,
k_zp=k_zp,
v_scale=v_scale,
v_zp=v_zp,
a_scale=a_scale,
a_zp=a_zp,
o_scale=o_scale,
o_zp=o_zp,
)

self.assertEqual(actual, math_ref, atol=1.0, rtol=5e-6)
if input_dtype == torch.uint8:
math_ref = self._scaled_dot_product_int8_op_ref(
q2,
k2,
v2,
attn_mask=attn_mask,
dropout_p=0.0,
is_causal=False,
q_scale=q_scale,
q_zp=q_zp,
k_scale=k_scale,
k_zp=k_zp,
v_scale=v_scale,
v_zp=v_zp,
a_scale=a_scale,
a_zp=a_zp,
o_scale=o_scale,
o_zp=o_zp,
)
actual = torch.ops.torchao.qscaled_dot_product(
q,
k,
v,
attn_mask=attn_mask_2,
dropout_p=0.0,
is_causal=False,
q_scale=q_scale,
q_zp=q_zp,
k_scale=k_scale,
k_zp=k_zp,
v_scale=v_scale,
v_zp=v_zp,
a_scale=a_scale,
a_zp=a_zp,
o_scale=o_scale,
o_zp=o_zp,
)
else:
math_ref = self._scaled_dot_product_fp8_op_ref(
q2,
k2,
v2,
attn_mask=attn_mask,
dropout_p=0.0,
is_causal=False,
q_scale=q_scale,
k_scale=k_scale,
v_scale=v_scale,
a_scale=a_scale,
o_scale=o_scale,
)
actual = torch.ops.torchao.qscaled_dot_product(
q,
k,
v,
attn_mask=attn_mask_2,
dropout_p=0.0,
is_causal=False,
q_scale=q_scale,
k_scale=k_scale,
v_scale=v_scale,
a_scale=a_scale,
o_scale=o_scale,
)
self.assertEqual(actual.float(), math_ref.float(), atol=atol, rtol=rtol)


instantiate_parametrized_tests(TestOps)
Expand Down
Loading
Loading