@@ -24,7 +24,7 @@ __global__ void update_attn_mask_offsets_kernel(
2424 int * attn_mask_offsets_decoder,
2525 const bool * is_block_step,
2626 int * decode_states,
27- const int * mask_rollback,
27+ int * mask_rollback,
2828 const int real_bsz,
2929 const int max_model_len,
3030 const int decode_states_len) {
@@ -58,7 +58,7 @@ __global__ void update_attn_mask_offsets_kernel(
5858 // Status: decoder -- normal or chunk_prefill
5959 // TODO: support speculative decoding.
6060 attn_mask_offsets_decoder[bid] -= mask_rollback[bid];
61-
61+ mask_rollback[bid] = 0 ;
6262 for (int i = 0 ; i < seq_len_this_time; i++) {
6363 attn_mask_offsets[(query_start_id + i) * 2 + 1 ] =
6464 attn_mask_offsets_decoder[bid] + 1 + i;
@@ -117,7 +117,7 @@ std::vector<paddle::Tensor> UpdateAttnMaskOffsets(
117117 const_cast <int *>(attn_mask_offsets_decoder.data <int >()),
118118 is_block_step.data <bool >(),
119119 const_cast <int *>(decode_states.data <int >()),
120- mask_rollback.data <int >(),
120+ const_cast < int *>( mask_rollback.data <int >() ),
121121 real_bsz,
122122 max_model_len,
123123 decode_states_len);
@@ -136,6 +136,7 @@ PD_BUILD_STATIC_OP(update_attn_mask_offsets)
136136 " is_block_step" ,
137137 " decode_states" ,
138138 " mask_rollback" })
139- .Outputs({" attn_mask_offsets" , " decode_states_out" })
140- .SetInplaceMap({{" decode_states" , " decode_states_out" }})
139+ .Outputs({" attn_mask_offsets" , " decode_states_out" , " mask_rollback_out" })
140+ .SetInplaceMap({{" decode_states" , " decode_states_out" },
141+ {" mask_rollback" , " mask_rollback_out" }})
141142 .SetKernelFn(PD_KERNEL(UpdateAttnMaskOffsets));
0 commit comments