from itertools import product

import torch
import triton
from sgl_kernel.flash_attn import flash_attn_with_kvcache


def flash_attn_baseline(
    q,
    k_cache,
    v_cache,
    causal,
    window_size,
    softmax_scale,
    softmax_sink,
    cache_seqlens,
    page_table,
    cu_seqlens_q,
    max_seqlen_q,
):
    """Baseline Flash Attention over the paged KV cache; returns output and softmax LSE."""
    out, lse, *rest = flash_attn_with_kvcache(
        q,
        k_cache,
        v_cache,
        causal=causal,
        softmax_sink=softmax_sink,
        window_size=window_size,
        softmax_scale=softmax_scale,
        page_table=page_table,
        cache_seqlens=cache_seqlens,
        cu_seqlens_q=cu_seqlens_q,
        max_seqlen_q=max_seqlen_q,
        return_softmax_lse=True,
    )
    return out, lse


# Benchmark configurations
causal = [True, False]
local = [True, False]
use_softmax_sink = [True, False]
batch_size = [1, 16]
q_seq_length_range = [1, 512, 1024]
kv_seq_length_range = [512, 1024, 2048, 4096, 8192, 16384]
page_size_range = [32, 64, 128]
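# Causal attention and local (sliding-window) attention are treated as mutually
# exclusive in this sweep, so configurations enabling both are filtered out.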
configs = list(
    filter(
        lambda cfg: not (cfg[0] and cfg[1]),
        product(
            causal,
            local,
            use_softmax_sink,
            batch_size,
            q_seq_length_range,
            kv_seq_length_range,
            page_size_range,
        ),
    )
)

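# triton.testing runs the benchmark once per configuration in `configs` and
# reports the timings for each provider.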
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=[
            "causal",
            "local",
            "use_softmax_sink",
            "batch_size",
            "q_seq_length",
            "kv_seq_length",
            "page_size",
        ],
        x_vals=[list(c) for c in configs],
        line_arg="provider",
        line_vals=["flash_attn"],
        line_names=["Flash Attention"],
        styles=[("blue", "-")],
        ylabel="us",
        plot_name="flash-attention-performance",
        args={},
    )
)
def benchmark(
    causal,
    local,
    use_softmax_sink,
    batch_size,
    q_seq_length,
    kv_seq_length,
    page_size,
    provider,
):
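    # All tensors are created in bfloat16 on the XPU (Intel GPU) device.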
    dtype = torch.bfloat16
    device = torch.device("xpu")

    # Attention parameters
    num_heads = 16
    head_dim = 64

    # Create input tensors
    q = torch.randn(
        (batch_size * q_seq_length, num_heads, head_dim), device=device, dtype=dtype
    )
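    # Paged KV cache: keys and values are stored in pages of `page_size` tokens,
    # shaped (num_pages, page_size, num_heads, head_dim).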
    num_pages = (batch_size * kv_seq_length + page_size - 1) // page_size
    k_cache = torch.randn(
        (num_pages, page_size, num_heads, head_dim), device=device, dtype=dtype
    )
    v_cache = torch.randn(
        (num_pages, page_size, num_heads, head_dim), device=device, dtype=dtype
    )
    cache_seqlens = (
        torch.ones(batch_size, device=device, dtype=torch.int32) * kv_seq_length
    )
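    # The page table maps each sequence's logical pages to physical pages in the
    # cache; a random permutation simulates a fragmented allocation.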
    page_table = (
        torch.randperm(num_pages, device=device, dtype=torch.int32)
        .reshape(batch_size, -1)
        .contiguous()
    )
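    # Cumulative query lengths for the packed (varlen) query layout:
    # [0, q_len, 2 * q_len, ..., batch_size * q_len].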
    cu_seqlens_q = torch.arange(
        0,
        (batch_size + 1) * q_seq_length,
        step=q_seq_length,
        device=device,
        dtype=torch.int32,
    )
    max_seqlen_q = q_seq_length
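    # (-1, -1) disables the sliding window; otherwise a random (left, right)
    # window is used for local attention.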
    window_size = (-1, -1) if not local else torch.randint(0, kv_seq_length, (2,))

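    # Optional attention-sink values, one scalar per head.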
    softmax_sink = (
        torch.randn(num_heads, device=device, dtype=dtype) if use_softmax_sink else None
    )

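    # Standard attention scaling: 1 / sqrt(head_dim).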
    softmax_scale = 1.0 / (head_dim**0.5)

    quantiles = [0.5, 0.2, 0.8]

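    # Inputs are cloned inside the lambda so each timed call sees fresh tensors;
    # do_bench reports the median / 20th / 80th percentile latencies in ms.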
    if provider == "flash_attn":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: flash_attn_baseline(
                q.clone(),
                k_cache.clone(),
                v_cache.clone(),
                causal=causal,
                window_size=window_size,
                softmax_scale=softmax_scale,
                softmax_sink=softmax_sink,
                cache_seqlens=cache_seqlens,
                page_table=page_table,
                cu_seqlens_q=cu_seqlens_q,
                max_seqlen_q=max_seqlen_q,
            ),
            quantiles=quantiles,
        )

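    # Convert milliseconds to microseconds to match the report's ylabel ("us").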
    return 1000 * ms, 1000 * max_ms, 1000 * min_ms


if __name__ == "__main__":
    benchmark.run(print_data=True)
    print("Benchmark finished!")