
Commit ba915e0

[BugFix]Fix attention mask bug in D-Node of PD-split mode (#5245)
1 parent: 61fc368

File tree: 3 files changed, +12 −8 lines


custom_ops/gpu_ops/update_attn_mask_offsets.cu

Lines changed: 6 additions & 5 deletions
@@ -24,7 +24,7 @@ __global__ void update_attn_mask_offsets_kernel(
     int* attn_mask_offsets_decoder,
     const bool* is_block_step,
     int* decode_states,
-    const int* mask_rollback,
+    int* mask_rollback,
     const int real_bsz,
     const int max_model_len,
     const int decode_states_len) {
@@ -58,7 +58,7 @@ __global__ void update_attn_mask_offsets_kernel(
       // Status: decoder -- normal or chunk_prefill
       // TODO: support speculative decoding.
       attn_mask_offsets_decoder[bid] -= mask_rollback[bid];
-
+      mask_rollback[bid] = 0;
       for (int i = 0; i < seq_len_this_time; i++) {
         attn_mask_offsets[(query_start_id + i) * 2 + 1] =
             attn_mask_offsets_decoder[bid] + 1 + i;
@@ -117,7 +117,7 @@ std::vector<paddle::Tensor> UpdateAttnMaskOffsets(
       const_cast<int*>(attn_mask_offsets_decoder.data<int>()),
       is_block_step.data<bool>(),
       const_cast<int*>(decode_states.data<int>()),
-      mask_rollback.data<int>(),
+      const_cast<int*>(mask_rollback.data<int>()),
       real_bsz,
       max_model_len,
       decode_states_len);
@@ -136,6 +136,7 @@ PD_BUILD_STATIC_OP(update_attn_mask_offsets)
              "is_block_step",
              "decode_states",
              "mask_rollback"})
-    .Outputs({"attn_mask_offsets", "decode_states_out"})
-    .SetInplaceMap({{"decode_states", "decode_states_out"}})
+    .Outputs({"attn_mask_offsets", "decode_states_out", "mask_rollback_out"})
+    .SetInplaceMap({{"decode_states", "decode_states_out"},
+                    {"mask_rollback", "mask_rollback_out"}})
    .SetKernelFn(PD_KERNEL(UpdateAttnMaskOffsets));
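
The decode branch now consumes the rollback and clears it in place, which is why `mask_rollback` loses its `const` qualifier and gains an in-place output. Below is a minimal Python sketch mirroring that branch's semantics; the helper name `decode_branch` is ours, while `bid`, `query_start_id`, `seq_len_this_time`, and the array names come from the diff above.

import numpy as np

def decode_branch(bid, seq_len_this_time, query_start_id,
                  attn_mask_offsets, attn_mask_offsets_decoder, mask_rollback):
    # Apply the rollback once, then clear it so a later launch cannot
    # subtract it again (presumably why the kernel now resets it).
    attn_mask_offsets_decoder[bid] -= mask_rollback[bid]
    mask_rollback[bid] = 0
    # Same offset layout as the kernel: one (start, end) pair per query token.
    for i in range(seq_len_this_time):
        attn_mask_offsets[(query_start_id + i) * 2 + 1] = (
            attn_mask_offsets_decoder[bid] + 1 + i
        )

Registering `mask_rollback_out` in `SetInplaceMap` is what lets the op mutate the tensor in place, matching the `const_cast` on the host side and the now non-const device pointer.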

fastdeploy/engine/common_engine.py

Lines changed: 0 additions & 3 deletions
@@ -319,9 +319,6 @@ def start_worker_queue_service(self, start_queue):
         )
         self.cfg.cache_config.cache_queue_port = self.cache_task_queue.get_server_port()

-        self.llm_logger.info(
-            f"local {min(self.cfg.worker_num_per_node * self.cfg.node_rank + self.cfg.parallel_config.local_data_parallel_id,self.cfg.parallel_config.data_parallel_size - 1)}"
-        )
         self.engine_worker_queue = EngineWorkerQueue(
             address=address,
             is_server=False,

fastdeploy/spec_decode/mtp.py

Lines changed: 6 additions & 0 deletions
@@ -515,6 +515,12 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int):
                 self.model_inputs["attn_mask_offsets_decoder"][idx : idx + 1] = (
                     inputs["attention_mask_offset"][prefill_end_index - 1] + 1
                 )
+                if (
+                    self.fd_config.scheduler_config.splitwise_role == "decode"
+                ):  # In PD, we continue to decode after P generates the first token
+                    self.model_inputs["seq_lens_encoder"][idx : idx + 1] = 0
+                    # P-D split needs to roll back one step
+                    self.model_inputs["mask_rollback"][idx : idx + 1] = 1

                 # has_prefill_task = True
             elif request.task_type.value == RequestType.DECODE.value:  # decode task
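
Taken together with the kernel change, the handshake is: when the decode node ingests a request whose first token was already produced by the prefill node, it zeroes `seq_lens_encoder` and raises `mask_rollback`; the next kernel launch subtracts that one step from the decoder mask offset and clears the flag. A self-contained sketch of that flow follows; `max_bsz` and all concrete values are illustrative, not taken from the repo.

import numpy as np

max_bsz = 4
attn_mask_offsets_decoder = np.zeros(max_bsz, dtype=np.int32)
seq_lens_encoder = np.zeros(max_bsz, dtype=np.int32)
mask_rollback = np.zeros(max_bsz, dtype=np.int32)

# D-node side (insert_tasks_v1): P already generated the first token,
# so decoding continues here and the mask offset must roll back one step.
idx = 0
attn_mask_offsets_decoder[idx] = 10  # illustrative offset after P's prefill
seq_lens_encoder[idx] = 0
mask_rollback[idx] = 1

# Kernel side (update_attn_mask_offsets): consume the flag exactly once.
attn_mask_offsets_decoder[idx] -= mask_rollback[idx]
mask_rollback[idx] = 0

assert attn_mask_offsets_decoder[idx] == 9
# A second launch now subtracts 0 instead of rolling back again.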
