@@ -71,12 +71,9 @@ def flash_context_attention(q, k, v, out, b_start_loc, b_seq_len, max_input_len)
             ext.prompt_flash_attention(single_out, single_q, single_k, single_v, None, mask, [], head, scale, 2147473647, 0, "BSH", numKeyValueHeads)
         return out
 
-    def fused_context_attention(q, k, v, out, b_start_loc, b_seq_len, max_input_len):
-        batch, head, dim = b_start_loc.shape[0], q.shape[1], q.shape[2]
-        numKeyValueHeads = k.shape[1]
-        assert k.shape[1] == v.shape[1]
+    def fused_context_attention(q, k, v, out, b_start_loc, b_seq_len, max_input_len, head, numKeyValueHeads, dim):
+        batch = b_start_loc.shape[0]
         scale = 1 / math.sqrt(dim)
-
         mask_key_str = str(batch) + ":" + str(max_input_len)
         if mask_key_str not in mask_cache:
             mask = torch.tril(torch.ones(max_input_len, max_input_len, dtype=torch.bool), diagonal=0).cuda()
@@ -86,11 +83,7 @@ def fused_context_attention(q, k, v, out, b_start_loc, b_seq_len, max_input_len)
             print(f"cache mask in context attention, batch:seqLen={mask_key_str}")
 
         mask = mask_cache[mask_key_str]
-        ext.prompt_flash_attention(
-            out.view(batch, max_input_len, head * dim),
-            q.view(batch, max_input_len, head * dim),
-            k.view(batch, max_input_len, numKeyValueHeads * dim),
-            v.view(batch, max_input_len, numKeyValueHeads * dim),
+        ext.prompt_flash_attention(out, q, k, v,
             None, mask, b_seq_len, head, scale, 2147473647, 0, "BSH", numKeyValueHeads)
         return out
 
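Note on the new `fused_context_attention` signature: since the reshapes were dropped from the function body, the call site is now expected to pass `head`, `numKeyValueHeads`, and `dim` explicitly and hand over tensors already in the `BSH` layout. Below is a minimal caller-side sketch, assuming a hypothetical wrapper name and reusing the same `view` reshapes that were removed above; it is not part of this commit.

```python
# Hypothetical caller of the patched fused_context_attention (defined in the
# diff above). All names here are illustrative, not taken from this commit.
def call_fused_context_attention(q, k, v, out, b_start_loc, b_seq_len,
                                 max_input_len, head, numKeyValueHeads, dim):
    batch = b_start_loc.shape[0]
    # The BSH reshapes that used to live inside fused_context_attention are
    # assumed to move to the call site.
    q_bsh = q.view(batch, max_input_len, head * dim)
    k_bsh = k.view(batch, max_input_len, numKeyValueHeads * dim)
    v_bsh = v.view(batch, max_input_len, numKeyValueHeads * dim)
    out_bsh = out.view(batch, max_input_len, head * dim)
    return fused_context_attention(q_bsh, k_bsh, v_bsh, out_bsh,
                                   b_start_loc, b_seq_len, max_input_len,
                                   head, numKeyValueHeads, dim)
```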
@@ -101,18 +94,10 @@ def fused_context_attention(q, k, v, out, b_start_loc, b_seq_len, max_input_len)
 
 
 def patch_paged_token_attention_inference():
-    def paged_token_attention(q, k_cache, v_cache, out, kv_head_num, b_seq_len, block_table:torch.Tensor, block_size):
-        # numKeyValueHeads = k_cache.shape[1]
-        # assert k_cache.shape[1] == v_cache.shape[1]
-        batch, head, dim = q.shape
-        kv_cache_len = k_cache.shape[0]
-        ext.paged_attention(out.view(batch, 1, head * dim),
-                            q.view(batch, 1, head * dim),
-                            k_cache.view(kv_cache_len, 1, kv_head_num * dim),
-                            v_cache.view(kv_cache_len, 1, kv_head_num * dim),
-                            None, None,
-                            b_seq_len, block_table, head, kv_head_num,
-                            1.0 / math.sqrt(dim), "BSH", block_size, 0,
+    def paged_token_attention(q, k_cache, v_cache, out, q_head_num, kv_head_num, head_dim, b_seq_len, block_table:torch.Tensor, block_size):
+        ext.paged_attention(out, q, k_cache, v_cache, None, None,
+                            b_seq_len, block_table, q_head_num, kv_head_num,
+                            1.0 / math.sqrt(head_dim), "BSH", block_size, 0,
                             None, None, None, None, None, None, None, None
                             )
         return out
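Similarly, a minimal sketch of how the reworked `paged_token_attention` might be called, assuming the caller derives `q_head_num`/`head_dim` from the query shape, a `(tokens, kv_heads, head_dim)` cache layout, and performs the `BSH` reshapes that were removed from the function body. The wrapper below is hypothetical and not part of this commit.

```python
# Hypothetical caller of the patched paged_token_attention (defined in the
# diff above). Shapes and names are assumptions for illustration only.
def call_paged_token_attention(q, k_cache, v_cache, out, b_seq_len,
                               block_table, block_size):
    batch, q_head_num, head_dim = q.shape  # assumed (tokens, heads, head_dim)
    kv_head_num = k_cache.shape[1]         # assumed (tokens, kv_heads, head_dim)
    kv_cache_len = k_cache.shape[0]
    return paged_token_attention(
        q.view(batch, 1, q_head_num * head_dim),
        k_cache.view(kv_cache_len, 1, kv_head_num * head_dim),
        v_cache.view(kv_cache_len, 1, kv_head_num * head_dim),
        out.view(batch, 1, q_head_num * head_dim),
        q_head_num, kv_head_num, head_dim,
        b_seq_len, block_table, block_size,
    )
```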