@@ -31,31 +31,34 @@ void batch_prefill(const torch::Tensor& query,
                   const torch::Tensor& mask,
                   const torch::Tensor& seq_len,
                   float scale,
-                   int num_heads,
-                   int num_kv_heads,
                   torch::Tensor& output) {
+  auto num_heads = query.size(-2);
+  auto num_kv_heads = key.size(-2);
   atb::_npu_flash_attention(
       query, key, value, mask, seq_len, scale, num_heads, num_kv_heads, output);
 }
 
 void batch_decode(const torch::Tensor& query,
                  const torch::Tensor& k_cache,
                  const torch::Tensor& v_cache,
-                  int num_kv_heads,
-                  int num_heads,
                  float scale,
                  const torch::Tensor& block_table,
                  const torch::Tensor& seq_lens,
                  torch::Tensor& output) {
-  atb::_npu_paged_attention(query,
+  auto head_size = query.size(-1);
+  auto num_heads = query.size(-2);
+  auto num_kv_heads = k_cache.size(-2);
+  auto q = query.view({-1, num_heads, head_size});
+  auto o = output.view({-1, num_heads, head_size});
+  atb::_npu_paged_attention(q,
                            k_cache,
                            v_cache,
                            num_kv_heads,
                            num_heads,
                            scale,
                            block_table,
                            seq_lens,
-                            output);
+                            o);
 }
 
 }  // namespace xllm::kernel::npu
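
For context, a minimal sketch of how a `batch_prefill` call site simplifies now that the head counts are derived from `size(-2)` instead of being passed in. The `key`/`value` parameters in the declaration are inferred from the function body, and all shapes and values below are illustrative assumptions, not taken from the PR:

```cpp
#include <cmath>
#include <torch/torch.h>

// Declaration matching the post-change signature; key/value parameters are
// inferred from the call inside batch_prefill and may differ in the header.
namespace xllm::kernel::npu {
void batch_prefill(const torch::Tensor& query,
                   const torch::Tensor& key,
                   const torch::Tensor& value,
                   const torch::Tensor& mask,
                   const torch::Tensor& seq_len,
                   float scale,
                   torch::Tensor& output);
}  // namespace xllm::kernel::npu

int main() {
  // Assumed layout: query is [num_tokens, num_heads, head_size] and
  // key/value are [num_tokens, num_kv_heads, head_size] (GQA), so that
  // size(-2) yields the head counts callers previously passed explicitly.
  auto query = torch::randn({256, 32, 128});
  auto key = torch::randn({256, 8, 128});
  auto value = torch::randn({256, 8, 128});
  auto mask = torch::ones({256, 256});
  auto seq_len = torch::tensor({256}, torch::kInt32);
  auto output = torch::empty_like(query);
  float scale = 1.0f / std::sqrt(128.0f);

  // Before: batch_prefill(query, key, value, mask, seq_len, scale,
  //                       /*num_heads=*/32, /*num_kv_heads=*/8, output);
  // After: head counts are read from the tensor shapes inside the function.
  xllm::kernel::npu::batch_prefill(query, key, value, mask, seq_len, scale,
                                   output);
  return 0;
}
```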
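On the `batch_decode` side, the change additionally flattens `query` and `output` before calling `atb::_npu_paged_attention`. A small sketch of what that `view` does, assuming a 4-D decode-time layout of `[num_seqs, 1, num_heads, head_size]` (the exact input rank is an assumption; the diff only requires the last two dimensions to be heads and head size):

```cpp
#include <cassert>
#include <torch/torch.h>

int main() {
  // Assumed decode layout: one new token per sequence.
  auto query = torch::randn({4, 1, 32, 128});
  auto head_size = query.size(-1);  // 128
  auto num_heads = query.size(-2);  // 32

  // view({-1, num_heads, head_size}) collapses every leading dimension into
  // a single token axis, producing the 3-D tensor that the paged-attention
  // call in this diff receives.
  auto q = query.view({-1, num_heads, head_size});
  assert(q.size(0) == 4 && q.size(1) == 32 && q.size(2) == 128);
  return 0;
}
```

The same reshape is applied to `output`, so the kernel writes into a 3-D view of the caller's buffer and no copy-back is needed.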