from itertools import product

import torch
import triton
from sgl_kernel.flash_attn import flash_attn_with_kvcache


def flash_attn_baseline(
    q,
    k_cache,
    v_cache,
    causal,
    softmax_scale,
    cache_seqlens,
    page_table,
    cu_seqlens_q,
    max_seqlen_q,
):
    """Baseline: flash_attn_with_kvcache over a paged KV cache; returns (out, lse)."""
    out, lse, *rest = flash_attn_with_kvcache(
        q,
        k_cache,
        v_cache,
        causal=causal,
        softmax_scale=softmax_scale,
        page_table=page_table,
        cache_seqlens=cache_seqlens,
        cu_seqlens_q=cu_seqlens_q,
        max_seqlen_q=max_seqlen_q,
        return_softmax_lse=True,
    )
    return out, lse


# Benchmark configurations
causal = [True, False]
batch_size = [1, 16]
q_seq_length_range = [1, 512, 1024]
kv_seq_length_range = [512, 1024, 2048, 4096, 8192, 16384]
page_size_range = [32, 64, 128]
configs = list(
    product(
        causal, batch_size, q_seq_length_range, kv_seq_length_range, page_size_range
    )
)


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["causal", "batch_size", "q_seq_length", "kv_seq_length", "page_size"],
        x_vals=[list(c) for c in configs],
        line_arg="provider",
        line_vals=["flash_attn"],
        line_names=["Flash Attention"],
        styles=[("blue", "-")],
        ylabel="us",
        plot_name="flash-attention-performance",
        args={},
    )
)
def benchmark(causal, batch_size, q_seq_length, kv_seq_length, page_size, provider):
    dtype = torch.bfloat16
    device = torch.device("xpu")

    # Attention parameters
    num_heads = 16
    head_dim = 64

    # Create input tensors
    q = torch.randn(
        (batch_size * q_seq_length, num_heads, head_dim), device=device, dtype=dtype
    )
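    # Paged KV cache: KV tokens live in fixed-size pages of page_size tokens each,
    # so each cache tensor is laid out as (num_pages, page_size, num_heads, head_dim).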
    num_pages = (batch_size * kv_seq_length + page_size - 1) // page_size
    k_cache = torch.randn(
        (num_pages, page_size, num_heads, head_dim), device=device, dtype=dtype
    )
    v_cache = torch.randn(
        (num_pages, page_size, num_heads, head_dim), device=device, dtype=dtype
    )
    cache_seqlens = (
        torch.ones(batch_size, device=device, dtype=torch.int32) * kv_seq_length
    )
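    # Give each sequence a disjoint, randomly ordered set of physical pages.
    # reshape(batch_size, -1) assumes num_pages is divisible by batch_size, which
    # holds for every config above because kv_seq_length is a multiple of page_size.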
    page_table = (
        torch.randperm(num_pages, device=device, dtype=torch.int32)
        .reshape(batch_size, -1)
        .contiguous()
    )
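    # Cumulative query offsets delimiting each sequence in the packed q tensor:
    # [0, q_seq_length, 2 * q_seq_length, ..., batch_size * q_seq_length].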
    cu_seqlens_q = torch.arange(
        0,
        (batch_size + 1) * q_seq_length,
        step=q_seq_length,
        device=device,
        dtype=torch.int32,
    )
    max_seqlen_q = q_seq_length

    softmax_scale = 1.0 / (head_dim**0.5)

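    # do_bench reports latencies (in ms) at these quantiles: median, 20th, and 80th percentile.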
    quantiles = [0.5, 0.2, 0.8]

    if provider == "flash_attn":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: flash_attn_baseline(
                q.clone(),
                k_cache.clone(),
                v_cache.clone(),
                causal=causal,
                softmax_scale=softmax_scale,
                cache_seqlens=cache_seqlens,
                page_table=page_table,
                cu_seqlens_q=cu_seqlens_q,
                max_seqlen_q=max_seqlen_q,
            ),
            quantiles=quantiles,
        )

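    # Convert do_bench's milliseconds to microseconds to match the plot's "us" ylabel.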
    return 1000 * ms, 1000 * max_ms, 1000 * min_ms


if __name__ == "__main__":
    benchmark.run(print_data=True)