@@ -216,7 +216,7 @@ void BatchInputBuilder::process_sequences_multithreaded(uint32_t start_idx,
216216 state_.q_seq_lens .insert (state_.q_seq_lens .end (),
217217 state.q_seq_lens .begin (),
218218 state.q_seq_lens .end ());
219- #elif defined(USE_MLU)
219+ #elif defined(USE_MLU) || defined(USE_CUDA)
220220 int32_t seq_len_offset = state_.seq_lens .back ();
221221 // skip the first element which is 0
222222 for (size_t i = 1 ; i < state.seq_lens .size (); ++i) {
@@ -248,6 +248,16 @@ void BatchInputBuilder::process_sequences_multithreaded(uint32_t start_idx,
248248 state.kv_cache_start_offsets .begin (),
249249 state.kv_cache_start_offsets .end ());
250250 }
251+ // for flashinfer
252+ state_.paged_kv_indptr .insert (state_.paged_kv_indptr .end (),
253+ state.paged_kv_indptr .begin (),
254+ state.paged_kv_indptr .end ());
255+ state_.paged_kv_indices .insert (state_.paged_kv_indices .end (),
256+ state.paged_kv_indices .begin (),
257+ state.paged_kv_indices .end ());
258+ state_.paged_kv_last_page_len .insert (state_.paged_kv_last_page_len .end (),
259+ state.paged_kv_last_page_len .begin (),
260+ state.paged_kv_last_page_len .end ());
251261 }
252262 for (const auto & write_block_ids : thread_write_block_ids) {
253263 write_block_ids_.insert (write_block_ids.begin (), write_block_ids.end ());
@@ -288,7 +298,7 @@ void BatchInputBuilder::process_single_sequence(
288298#if defined(USE_NPU)
289299 state.seq_lens .push_back (seq_len);
290300 state.q_seq_lens .push_back (q_seq_len);
291- #elif defined(USE_MLU)
301+ #elif defined(USE_MLU) || defined(USE_CUDA)
292302 state.seq_lens .push_back (state.seq_lens .back () + seq_len);
293303 state.q_seq_lens .push_back (state.q_seq_lens .back () + q_seq_len);
294304#endif
@@ -448,7 +458,12 @@ void BatchInputBuilder::setup_kv_cache_info(
448458 block_size = block.size ();
449459 block_ids.push_back (block.id ());
450460 u_block_ids.emplace_back (block.id ());
461+ state.paged_kv_indices .push_back (block.id ());
451462 }
463+ state.paged_kv_indptr .push_back (state.paged_kv_indptr .back () + blocks.size ());
464+ int32_t last_page_len =
465+ (seq_len % block_size == 0 ) ? block_size : seq_len % block_size;
466+ state.paged_kv_last_page_len .push_back (last_page_len);
452467
453468 int32_t kv_cache_block_idx = n_kv_cache_tokens / block_size;
454469 for (auto iter = block_ids.begin () + kv_cache_block_idx;
@@ -517,12 +532,15 @@ void BatchInputBuilder::padding_decode_batch_size(
517532#if defined(USE_NPU)
518533 state_.seq_lens .push_back (num_decoding_tokens);
519534 state_.q_seq_lens .push_back (num_decoding_tokens);
520- #elif defined(USE_MLU)
535+ #elif defined(USE_MLU) || defined(USE_CUDA)
521536 state_.seq_lens .push_back (state_.seq_lens .back () + num_decoding_tokens);
522537 state_.q_seq_lens .push_back (state_.q_seq_lens .back () +
523538 num_decoding_tokens);
524539#endif
525540 state_.block_tables_vec .emplace_back ();
541+ state_.paged_kv_indices .push_back (0 );
542+ state_.paged_kv_indptr .push_back (state_.paged_kv_indptr .back () + 1 );
543+ state_.paged_kv_last_page_len .push_back (1 );
526544 }
527545 }
528546 }
@@ -560,6 +578,14 @@ ForwardInput BatchInputBuilder::state_to_forward_input() {
560578 input_params.decode_seq_range =
561579 util::find_ones_indices (input_params.q_seq_lens_vec );
562580
581+ // for flashinfer
582+ input_params.paged_kv_indptr =
583+ torch::tensor (state_.paged_kv_indptr , torch::kInt );
584+ input_params.paged_kv_indices =
585+ torch::tensor (state_.paged_kv_indices , torch::kInt );
586+ input_params.paged_kv_last_page_len =
587+ torch::tensor (state_.paged_kv_last_page_len , torch::kInt );
588+
563589 // Setup multimodal data
564590 input_params.mm_data = MMData::batch (mm_data_vec_);
565591
@@ -634,6 +660,12 @@ RawForwardInput BatchInputBuilder::state_to_raw_forward_input() {
634660 raw_forward_input.transfer_kv_infos = std::move (state_.transfer_kv_infos );
635661 raw_forward_input.prefill_seq_len = state_.prefill_seq_len ;
636662
663+ // for flashinfer
664+ raw_forward_input.paged_kv_indptr = std::move (state_.paged_kv_indptr );
665+ raw_forward_input.paged_kv_indices = std::move (state_.paged_kv_indices );
666+ raw_forward_input.paged_kv_last_page_len =
667+ std::move (state_.paged_kv_last_page_len );
668+
637669 raw_forward_input.embedding_ids = std::move (state_.embedding_ids );
638670 raw_forward_input.extra_token_ids = std::move (state_.extra_token_ids );
639671 // beam search kernel input
0 commit comments